From c1890216c5bdaff171f5ceb740f5f16cac6b4d6f Mon Sep 17 00:00:00 2001 From: edmand46 Date: Wed, 1 Apr 2026 10:53:51 +0300 Subject: [PATCH] v1.0.0 --- .github/workflows/tests.yml | 4 +- Makefile | 5 +- README.md | 59 ++++++++++- benchmarks/arpackmsg/messages_gen.go | 26 ++++- .../unity/Assets/Benchmarks/Messages.gen.cs | 33 ++++++- e2e/e2e_test.go | 52 ++++++---- generator/csharp.go | 35 ++++++- generator/generator_test.go | 99 +++++++++++++++++++ generator/go.go | 27 ++++- generator/lua.go | 12 ++- generator/lua_test.go | 64 +++++++++++- generator/policy.go | 78 +++++++++++++++ generator/ts.go | 48 +++++++-- generator/ts_test.go | 78 +++++++++++++-- parser/parser.go | 74 ++++++++++++-- parser/parser_test.go | 90 +++++++++++++++++ 16 files changed, 722 insertions(+), 62 deletions(-) create mode 100644 generator/policy.go diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ad4785b..59e3e60 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -32,8 +32,8 @@ jobs: - name: Download dependencies run: go mod download - - name: Run unit tests - run: go test -v ./parser/... ./generator/... + - name: Run full Go test suite + run: make test - name: Run benchmarks (short) run: go test -bench=. -benchtime=100ms -run=^$ ./benchmarks/... diff --git a/Makefile b/Makefile index 01f4f2e..e78b676 100644 --- a/Makefile +++ b/Makefile @@ -1,9 +1,12 @@ UNITY_ASSETS := benchmarks/unity/Assets -.PHONY: bench-image generate bench size gen-unity +.PHONY: test bench-image generate bench size gen-unity IMAGE := arpack-bench +test: + go test ./... + bench-image: docker build -f Dockerfile.bench -t $(IMAGE) . diff --git a/README.md b/README.md index 86019f5..d8b797c 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,17 @@ arpack -in messages.go -out-lua ./defold/scripts/messages - TypeScript: `{Name}.gen.ts` - Lua: `{name}_gen.lua` (snake_case for Lua `require()` compatibility) +## v1 Contract + +ArPack `v1` intentionally supports a narrow schema model: + +- Input is a single Go source file. +- Message types must be defined in that same file. +- External package types, pointers, and platform-dependent integer aliases (`int`, `uint`, `uintptr`) are not supported. +- Wire format is stable within `v1.x` for unchanged schemas. + +This is a deliberate product boundary for predictable code generation and cross-language compatibility. + ## Schema Definition Messages are defined as Go structs in a single `.go` file: @@ -116,6 +127,8 @@ type MoveMessage struct { | `[N]T` | N × sizeof(T) | ✓ | | `[]T` | 2-byte length prefix + N × sizeof(T) | ✓ | +**Note:** platform-dependent `int`, `uint`, and `uintptr` are not supported. Use explicit widths like `int32`, `uint32`, `int64`, or `uint64`. + **Note:** `int64`/`uint64` are not supported in Lua target. LuaJIT (used by Defold) represents numbers as double-precision floats, which can only safely represent integers up to 2^53. Use `int32`/`uint32` instead. ### Float Quantization @@ -135,6 +148,8 @@ Y float32 `pack:"min=0,max=1,bits=8"` // 1 byte instead of 4 Values are linearly mapped: `encoded = (value - min) / (max - min) * maxUint`. +Quantized values must stay within the declared `[min, max]` range. Generated serializers fail fast on out-of-range or `NaN` inputs instead of silently truncating them. + ## Generated Code ### Go @@ -146,6 +161,8 @@ func (m *MoveMessage) Unmarshal(data []byte) (int, error) `Marshal` appends to the buffer and returns it. `Unmarshal` reads from the buffer and returns bytes consumed. +**Failure behavior:** generated `Marshal` panics if a string/slice exceeds the `uint16` wire limit or if a quantized value is outside its declared range. + ### C# ```csharp @@ -155,6 +172,8 @@ public static unsafe int Deserialize(byte* buffer, out MoveMessage msg) Uses unsafe pointers for zero-copy serialization. Returns bytes written/consumed. +**Failure behavior:** generated `Serialize` throws if a string/slice exceeds the `uint16` wire limit or if a quantized value is outside its declared range. + ### TypeScript ```typescript @@ -177,6 +196,8 @@ Uses native DataView API for browser-compatible serialization with zero dependen **Note:** TypeScript field names are converted to camelCase (e.g., `PlayerID` → `playerId`). +**Failure behavior:** generated `serialize(...)` throws `RangeError` if a string/slice exceeds the `uint16` wire limit or if a quantized value is outside its declared range. + ### Lua ```lua @@ -201,6 +222,7 @@ Uses pure Lua with inline helper functions for byte manipulation. Compatible wit **Limitations:** - Lua target does not support `int64`/`uint64` types. Use `int32`/`uint32` instead. This is because LuaJIT represents numbers as double-precision floats, which can only safely represent integers up to 2^53. - Variable-length fields use `uint16` length prefixes, so `string` byte length and `[]T` element count must not exceed `65535`. Serialization raises an error if the limit is exceeded. +- Quantized values must stay within the declared `[min, max]` range. Serialization raises a Lua error on out-of-range or `NaN` inputs. - Deserialization raises Lua errors on malformed or truncated input. If you need a recoverable boundary, wrap decode calls in `pcall(...)`. - Generated file uses snake_case naming (e.g., `messages_gen.lua`) for proper Lua `require()` resolution. @@ -212,6 +234,24 @@ Uses pure Lua with inline helper functions for byte manipulation. Compatible wit - Booleans packed as bitfields (LSB first, up to 8 per byte) - Quantized floats stored as `uint8` or `uint16` +## Compatibility Guarantees + +Within `v1.x`, the following are considered compatibility guarantees for a fixed schema: + +- Same field declaration order produces the same wire layout. +- Go, C#, TypeScript, and Lua generators produce identical wire bytes for supported types. +- `string` and `[]T` always use `uint16` length prefixes. +- Consecutive `bool` fields are bit-packed in declaration order, least-significant bit first. +- Enum fields use their declared underlying integer type on the wire. + +The following are breaking changes: + +- changing field order +- changing a field type +- changing quantization parameters +- changing enum underlying types +- changing how booleans are grouped or how lengths are encoded + ## Benchmarks ### Go Results (M3 Max) @@ -254,9 +294,20 @@ make gen-unity ## Running Tests ```bash -# Unit tests -go test ./parser/... ./generator/... +# Full test suite +make test -# End-to-end cross-language tests -go test ./e2e/... +# Benchmarks +go test ./benchmarks/... -bench=. -benchmem ``` + +## Troubleshooting + +- `unknown type "..."` + The field type is not a supported primitive and is not defined in the same schema file. +- `external package types not supported` + Copy the wire-facing type definition into the schema file instead of referencing another package. +- `... exceeds uint16 limit` + A `string` encoded to more than `65535` bytes, or a slice contains more than `65535` elements. +- `quantized value out of range` + The runtime value does not satisfy the declared `pack:"min=...,max=..."` bounds. diff --git a/benchmarks/arpackmsg/messages_gen.go b/benchmarks/arpackmsg/messages_gen.go index bf19f82..a9ea510 100644 --- a/benchmarks/arpackmsg/messages_gen.go +++ b/benchmarks/arpackmsg/messages_gen.go @@ -8,11 +8,27 @@ import ( "math" ) +func arpackEnsureUint16Length(length int, context string) uint16 { + if length > 65535 { + panic("arpack: " + context + " exceeds uint16 limit") + } + return uint16(length) +} + +func arpackEnsureQuantizedRange(value float64, min float64, max float64, context string) { + if value != value || value < min || value > max { + panic("arpack: quantized value out of range for " + context) + } +} + func (m *Vector3) Marshal(buf []byte) []byte { + arpackEnsureQuantizedRange(float64(m.X), -500, 500, "X") _qm_X := uint16((m.X - (-500)) / (500 - (-500)) * 65535) buf = binary.LittleEndian.AppendUint16(buf, _qm_X) + arpackEnsureQuantizedRange(float64(m.Y), -500, 500, "Y") _qm_Y := uint16((m.Y - (-500)) / (500 - (-500)) * 65535) buf = binary.LittleEndian.AppendUint16(buf, _qm_Y) + arpackEnsureQuantizedRange(float64(m.Z), -500, 500, "Z") _qm_Z := uint16((m.Z - (-500)) / (500 - (-500)) * 65535) buf = binary.LittleEndian.AppendUint16(buf, _qm_Z) return buf @@ -49,7 +65,7 @@ func (m *MoveMessage) Marshal(buf []byte) []byte { for _iVelocity := 0; _iVelocity < 3; _iVelocity++ { buf = binary.LittleEndian.AppendUint32(buf, math.Float32bits(m.Velocity[_iVelocity])) } - buf = binary.LittleEndian.AppendUint16(buf, uint16(len(m.Waypoints))) + buf = binary.LittleEndian.AppendUint16(buf, arpackEnsureUint16Length(len(m.Waypoints), "slice length for Waypoints")) for _iWaypoints := range m.Waypoints { buf = m.Waypoints[_iWaypoints].Marshal(buf) } @@ -65,7 +81,7 @@ func (m *MoveMessage) Marshal(buf []byte) []byte { _boolByte4 |= 1 << 2 } buf = append(buf, _boolByte4) - buf = binary.LittleEndian.AppendUint16(buf, uint16(len(m.Name))) + buf = binary.LittleEndian.AppendUint16(buf, arpackEnsureUint16Length(len(m.Name), "string length for Name")) buf = append(buf, m.Name...) return buf } @@ -130,12 +146,12 @@ func (m *SpawnMessage) Marshal(buf []byte) []byte { buf = binary.LittleEndian.AppendUint64(buf, m.EntityID) buf = m.Position.Marshal(buf) buf = binary.LittleEndian.AppendUint16(buf, uint16(m.Health)) - buf = binary.LittleEndian.AppendUint16(buf, uint16(len(m.Tags))) + buf = binary.LittleEndian.AppendUint16(buf, arpackEnsureUint16Length(len(m.Tags), "slice length for Tags")) for _iTags := range m.Tags { - buf = binary.LittleEndian.AppendUint16(buf, uint16(len(m.Tags[_iTags]))) + buf = binary.LittleEndian.AppendUint16(buf, arpackEnsureUint16Length(len(m.Tags[_iTags]), "string length for Tags[_iTags]")) buf = append(buf, m.Tags[_iTags]...) } - buf = binary.LittleEndian.AppendUint16(buf, uint16(len(m.Data))) + buf = binary.LittleEndian.AppendUint16(buf, arpackEnsureUint16Length(len(m.Data), "slice length for Data")) for _iData := range m.Data { buf = append(buf, m.Data[_iData]) } diff --git a/benchmarks/unity/Assets/Benchmarks/Messages.gen.cs b/benchmarks/unity/Assets/Benchmarks/Messages.gen.cs index 7fbdfad..1368e0a 100644 --- a/benchmarks/unity/Assets/Benchmarks/Messages.gen.cs +++ b/benchmarks/unity/Assets/Benchmarks/Messages.gen.cs @@ -7,6 +7,26 @@ using System.Text; namespace Arpack.Messages { + internal static class ArpackGenerated + { + internal static ushort EnsureU16Length(int length, string context) + { + if (length > 65535) + { + throw new InvalidOperationException("arpack: " + context + " exceeds uint16 limit"); + } + return (ushort)length; + } + + internal static void EnsureQuantizedRange(double value, double min, double max, string context) + { + if (double.IsNaN(value) || value < min || value > max) + { + throw new ArgumentOutOfRangeException(context, "arpack: quantized value out of range for " + context); + } + } + } + public enum Opcode : ushort { Unknown = 0, @@ -23,8 +43,11 @@ namespace Arpack.Messages public int Serialize(byte* buffer) { byte* ptr = buffer; + ArpackGenerated.EnsureQuantizedRange(X, -500, 500, "X"); *(ushort*)ptr = (ushort)((X - (-500f)) / (500f - (-500f)) * 65535f); ptr += 2; + ArpackGenerated.EnsureQuantizedRange(Y, -500, 500, "Y"); *(ushort*)ptr = (ushort)((Y - (-500f)) / (500f - (-500f)) * 65535f); ptr += 2; + ArpackGenerated.EnsureQuantizedRange(Z, -500, 500, "Z"); *(ushort*)ptr = (ushort)((Z - (-500f)) / (500f - (-500f)) * 65535f); ptr += 2; return (int)(ptr - buffer); } @@ -59,7 +82,7 @@ namespace Arpack.Messages { *(float*)ptr = Velocity[_iVelocity]; ptr += 4; } - *(ushort*)ptr = (ushort)(Waypoints?.Length ?? 0); ptr += 2; + ushort _lenWaypoints = ArpackGenerated.EnsureU16Length(Waypoints?.Length ?? 0, "slice length for Waypoints"); *(ushort*)ptr = _lenWaypoints; ptr += 2; if (Waypoints != null) { for (int _iWaypoints = 0; _iWaypoints < Waypoints.Length; _iWaypoints++) @@ -74,7 +97,7 @@ namespace Arpack.Messages if (Ghost) _boolByte4 |= (byte)(1 << 2); *ptr = _boolByte4; ptr++; int _slenName = Name != null ? Encoding.UTF8.GetByteCount(Name) : 0; - *(ushort*)ptr = (ushort)_slenName; ptr += 2; + *(ushort*)ptr = ArpackGenerated.EnsureU16Length(_slenName, "string length for Name"); ptr += 2; if (Name != null && _slenName > 0) { fixed (char* _charsName = Name) @@ -128,13 +151,13 @@ namespace Arpack.Messages *(ulong*)ptr = EntityID; ptr += 8; ptr += Position.Serialize(ptr); *(short*)ptr = Health; ptr += 2; - *(ushort*)ptr = (ushort)(Tags?.Length ?? 0); ptr += 2; + ushort _lenTags = ArpackGenerated.EnsureU16Length(Tags?.Length ?? 0, "slice length for Tags"); *(ushort*)ptr = _lenTags; ptr += 2; if (Tags != null) { for (int _iTags = 0; _iTags < Tags.Length; _iTags++) { int _slenTags__iTags_ = Tags[_iTags] != null ? Encoding.UTF8.GetByteCount(Tags[_iTags]) : 0; - *(ushort*)ptr = (ushort)_slenTags__iTags_; ptr += 2; + *(ushort*)ptr = ArpackGenerated.EnsureU16Length(_slenTags__iTags_, "string length for Tags[_iTags]"); ptr += 2; if (Tags[_iTags] != null && _slenTags__iTags_ > 0) { fixed (char* _charsTags__iTags_ = Tags[_iTags]) @@ -145,7 +168,7 @@ namespace Arpack.Messages ptr += _slenTags__iTags_; } } - *(ushort*)ptr = (ushort)(Data?.Length ?? 0); ptr += 2; + ushort _lenData = ArpackGenerated.EnsureU16Length(Data?.Length ?? 0, "slice length for Data"); *(ushort*)ptr = _lenData; ptr += 2; if (Data != null) { for (int _iData = 0; _iData < Data.Length; _iData++) diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 9edf76b..5308587 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -47,6 +47,16 @@ func TestE2E_CrossLanguage(t *testing.T) { } csDir := buildCSHarness(t, csSrc) + for _, tc := range cases { + t.Run("Wire/Go_EQ_CS/"+tc.name, func(t *testing.T) { + goHex := strings.TrimSpace(runHarness(t, goDir, "go", "ser", tc.typ, "")) + csHex := strings.TrimSpace(runHarness(t, csDir, "cs", "ser", tc.typ, "")) + if goHex != csHex { + t.Fatalf("wire drift between Go and C# for %s:\ngo=%s\ncs=%s", tc.typ, goHex, csHex) + } + }) + } + for _, tc := range cases { t.Run("Go_to_CS/"+tc.name, func(t *testing.T) { hex := runHarness(t, goDir, "go", "ser", tc.typ, "") @@ -70,13 +80,15 @@ func TestE2E_CrossLanguage(t *testing.T) { } tsDir := buildTSHarness(t, tsSrc) - t.Run("QuantizedWire/Go_EQ_TS/Vector3", func(t *testing.T) { - goHex := strings.TrimSpace(runHarness(t, goDir, "go", "ser", "Vector3", "")) - tsHex := strings.TrimSpace(runHarness(t, tsDir, "ts", "ser", "Vector3", "")) - if goHex != tsHex { - t.Fatalf("quantized wire drift between Go and TS for Vector3:\ngo=%s\nts=%s", goHex, tsHex) - } - }) + for _, tc := range cases { + t.Run("Wire/Go_EQ_TS/"+tc.name, func(t *testing.T) { + goHex := strings.TrimSpace(runHarness(t, goDir, "go", "ser", tc.typ, "")) + tsHex := strings.TrimSpace(runHarness(t, tsDir, "ts", "ser", tc.typ, "")) + if goHex != tsHex { + t.Fatalf("wire drift between Go and TS for %s:\ngo=%s\nts=%s", tc.typ, goHex, tsHex) + } + }) + } for _, tc := range cases { t.Run("Go_to_TS/"+tc.name, func(t *testing.T) { @@ -146,14 +158,6 @@ func TestE2E_CrossLanguage(t *testing.T) { } luaDir := buildLuaHarness(t, luaSrc) - t.Run("QuantizedWire/Go_EQ_Lua/Vector3", func(t *testing.T) { - goHex := strings.TrimSpace(runHarness(t, goDir, "go", "ser", "Vector3", "")) - luaHex := strings.TrimSpace(runHarness(t, luaDir, "lua", "ser", "Vector3", "")) - if goHex != luaHex { - t.Fatalf("quantized wire drift between Go and Lua for Vector3:\ngo=%s\nlua=%s", goHex, luaHex) - } - }) - luaCases := []struct { name string typ string @@ -164,6 +168,16 @@ func TestE2E_CrossLanguage(t *testing.T) { {"EnvelopeMessage", "EnvelopeMessage", 0}, } + for _, tc := range luaCases { + t.Run("Wire/Go_EQ_Lua/"+tc.name, func(t *testing.T) { + goHex := strings.TrimSpace(runHarness(t, goDir, "go", "ser", tc.typ, "")) + luaHex := strings.TrimSpace(runHarness(t, luaDir, "lua", "ser", tc.typ, "")) + if goHex != luaHex { + t.Fatalf("wire drift between Go and Lua for %s:\ngo=%s\nlua=%s", tc.typ, goHex, luaHex) + } + }) + } + for _, tc := range luaCases { t.Run("Go_to_Lua/"+tc.name, func(t *testing.T) { hex := runHarness(t, goDir, "go", "ser", tc.typ, "") @@ -893,10 +907,10 @@ end local function serializeMoveMessage() local msg = messages.new_move_message() msg.position = messages.new_vector3() - msg.position.x = 10.0 - msg.position.y = 20.0 - msg.position.z = 30.0 - msg.velocity = {1.0, 2.0, 3.0} + msg.position.x = 50.0 + msg.position.y = -100.0 + msg.position.z = 0.0 + msg.velocity = {1.5, -2.5, 0.0} msg.waypoints = {} local wp = messages.new_vector3() wp.x = 10.0 diff --git a/generator/csharp.go b/generator/csharp.go index f655e7b..335d436 100644 --- a/generator/csharp.go +++ b/generator/csharp.go @@ -14,6 +14,8 @@ func GenerateCSharp(messages []parser.Message, namespace string) ([]byte, error) func GenerateCSharpSchema(schema parser.Schema, namespace string) ([]byte, error) { messages := schema.Messages var b strings.Builder + needsLengthGuards := schemaNeedsLengthGuards(messages) + needsQuantGuards := schemaNeedsQuantRangeGuards(messages) b.WriteString("// arpack \n") b.WriteString("// Code generated by arpack. DO NOT EDIT.\n") @@ -27,6 +29,30 @@ func GenerateCSharpSchema(schema parser.Schema, namespace string) ([]byte, error fmt.Fprintf(&b, "namespace %s\n{\n", namespace) + if needsLengthGuards || needsQuantGuards { + b.WriteString(" internal static class ArpackGenerated\n {\n") + if needsLengthGuards { + b.WriteString(" internal static ushort EnsureU16Length(int length, string context)\n") + b.WriteString(" {\n") + b.WriteString(" if (length > 65535)\n") + b.WriteString(" {\n") + b.WriteString(" throw new InvalidOperationException(\"arpack: \" + context + \" exceeds uint16 limit\");\n") + b.WriteString(" }\n") + b.WriteString(" return (ushort)length;\n") + b.WriteString(" }\n\n") + } + if needsQuantGuards { + b.WriteString(" internal static void EnsureQuantizedRange(double value, double min, double max, string context)\n") + b.WriteString(" {\n") + b.WriteString(" if (double.IsNaN(value) || value < min || value > max)\n") + b.WriteString(" {\n") + b.WriteString(" throw new ArgumentOutOfRangeException(context, \"arpack: quantized value out of range for \" + context);\n") + b.WriteString(" }\n") + b.WriteString(" }\n") + } + b.WriteString(" }\n\n") + } + enumNames := make(map[string]struct{}, len(schema.Enums)) for _, enum := range schema.Enums { enumNames[enum.Name] = struct{}{} @@ -151,7 +177,9 @@ func writeCSharpSerializeField(b *strings.Builder, f parser.Field, indent string } fmt.Fprintf(b, "%s}\n", indent) case parser.KindSlice: - fmt.Fprintf(b, "%s*(ushort*)ptr = (ushort)(%s?.Length ?? 0); ptr += 2;\n", indent, f.Name) + lenVar := "_len" + sanitizeVarName(f.Name) + fmt.Fprintf(b, "%sushort %s = ArpackGenerated.EnsureU16Length(%s?.Length ?? 0, %q); *(ushort*)ptr = %s; ptr += 2;\n", + indent, lenVar, f.Name, lengthContext(f), lenVar) fmt.Fprintf(b, "%sif (%s != null)\n%s{\n", indent, f.Name, indent) iVar := "_i" + f.Name fmt.Fprintf(b, "%s for (int %s = 0; %s < %s.Length; %s++)\n%s {\n", @@ -212,7 +240,8 @@ func writeCSharpSerializePrimitive( lenVar := "_slen" + sanitizeVarName(access) fmt.Fprintf(b, "%sint %s = %s != null ? Encoding.UTF8.GetByteCount(%s) : 0;\n", indent, lenVar, valueExpr, valueExpr) - fmt.Fprintf(b, "%s*(ushort*)ptr = (ushort)%s; ptr += 2;\n", indent, lenVar) + fmt.Fprintf(b, "%s*(ushort*)ptr = ArpackGenerated.EnsureU16Length(%s, %q); ptr += 2;\n", + indent, lenVar, lengthContext(f)) fmt.Fprintf(b, "%sif (%s != null && %s > 0)\n%s{\n", indent, valueExpr, lenVar, indent) fmt.Fprintf(b, "%s fixed (char* _chars%s = %s)\n%s {\n", indent, sanitizeVarName(access), valueExpr, indent) @@ -228,6 +257,8 @@ func writeCSharpSerializePrimitive( func writeCSharpSerializeQuant(b *strings.Builder, access string, f parser.Field, indent string) error { q := f.Quant maxUint := q.MaxUint() + fmt.Fprintf(b, "%sArpackGenerated.EnsureQuantizedRange(%s, %g, %g, %q);\n", + indent, access, q.Min, q.Max, quantContext(f)) if q.Bits == 8 { fmt.Fprintf(b, "%s*ptr = (byte)((%s - (%gf)) / (%gf - (%gf)) * %gf); ptr += 1;\n", indent, access, q.Min, q.Max, q.Min, maxUint) diff --git a/generator/generator_test.go b/generator/generator_test.go index d545af8..7f0f497 100644 --- a/generator/generator_test.go +++ b/generator/generator_test.go @@ -254,10 +254,109 @@ func TestGenerateCSharp_Output(t *testing.T) { if !strings.Contains(code, "public Opcode Code;") { t.Error("EnvelopeMessage.Code should use generated enum type") } + if !strings.Contains(code, "internal static class ArpackGenerated") { + t.Error("missing shared ArpackGenerated helper class") + } + if !strings.Contains(code, "EnsureU16Length") { + t.Error("missing uint16 length guard helper") + } + if !strings.Contains(code, "EnsureQuantizedRange") { + t.Error("missing quantized range guard helper") + } t.Logf("Generated C# (%d bytes):\n%s", len(src), code) } +func TestGenerateGo_RuntimeGuards(t *testing.T) { + schemaSrc := `package messages + +type Quantized struct { + Value float32 ` + "`" + `pack:"min=0,max=1,bits=8"` + "`" + ` +} + +type LengthLimited struct { + Name string + Items []uint8 +} +` + + schema, err := parser.ParseSchemaSource(schemaSrc) + if err != nil { + t.Fatalf("ParseSchemaSource: %v", err) + } + + src, err := GenerateGoSchema(schema, "messages") + if err != nil { + t.Fatalf("GenerateGoSchema: %v", err) + } + + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "messages.go"), []byte(schemaSrc), 0644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(dir, "messages_arpack.go"), src, 0644); err != nil { + t.Fatal(err) + } + + runtimeTests := `package messages + +import ( + "strings" + "testing" +) + +func expectPanicContaining(t *testing.T, want string, fn func()) { + t.Helper() + defer func() { + r := recover() + if r == nil { + t.Fatalf("expected panic containing %q, got nil", want) + } + if !strings.Contains(r.(string), want) { + t.Fatalf("expected panic containing %q, got %v", want, r) + } + }() + fn() +} + +func TestLengthGuard_String(t *testing.T) { + expectPanicContaining(t, "string length for Name exceeds uint16 limit", func() { + msg := LengthLimited{Name: strings.Repeat("a", 65536)} + _ = msg.Marshal(nil) + }) +} + +func TestLengthGuard_Slice(t *testing.T) { + expectPanicContaining(t, "slice length for Items exceeds uint16 limit", func() { + msg := LengthLimited{Items: make([]uint8, 65536)} + _ = msg.Marshal(nil) + }) +} + +func TestQuantizedRangeGuard(t *testing.T) { + expectPanicContaining(t, "quantized value out of range for Value", func() { + msg := Quantized{Value: 1.5} + _ = msg.Marshal(nil) + }) +} +` + if err := os.WriteFile(filepath.Join(dir, "guards_test.go"), []byte(runtimeTests), 0644); err != nil { + t.Fatal(err) + } + + goMod := "module messages\n\ngo 1.21\n" + if err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte(goMod), 0644); err != nil { + t.Fatal(err) + } + + cmd := exec.Command("go", "test", "./...") + cmd.Dir = dir + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("go test failed:\n%s", out) + } +} + func TestBoolPacking_GoCode(t *testing.T) { msgs, err := parser.ParseFile(samplePath) if err != nil { diff --git a/generator/go.go b/generator/go.go index 1e8c878..a046abe 100644 --- a/generator/go.go +++ b/generator/go.go @@ -15,6 +15,8 @@ func GenerateGo(messages []parser.Message, pkgName string) ([]byte, error) { func GenerateGoSchema(schema parser.Schema, pkgName string) ([]byte, error) { messages := schema.Messages var b strings.Builder + needsLengthGuards := schemaNeedsLengthGuards(messages) + needsQuantGuards := schemaNeedsQuantRangeGuards(messages) b.WriteString("// Code generated by arpack. DO NOT EDIT.\n\n") fmt.Fprintf(&b, "package %s\n\n", pkgName) @@ -27,6 +29,23 @@ func GenerateGoSchema(schema parser.Schema, pkgName string) ([]byte, error) { } b.WriteString(")\n\n") + if needsLengthGuards { + b.WriteString("func arpackEnsureUint16Length(length int, context string) uint16 {\n") + b.WriteString("\tif length > 65535 {\n") + b.WriteString("\t\tpanic(\"arpack: \" + context + \" exceeds uint16 limit\")\n") + b.WriteString("\t}\n") + b.WriteString("\treturn uint16(length)\n") + b.WriteString("}\n\n") + } + + if needsQuantGuards { + b.WriteString("func arpackEnsureQuantizedRange(value float64, min float64, max float64, context string) {\n") + b.WriteString("\tif value != value || value < min || value > max {\n") + b.WriteString("\t\tpanic(\"arpack: quantized value out of range for \" + context)\n") + b.WriteString("\t}\n") + b.WriteString("}\n\n") + } + for _, msg := range messages { if err := writeGoMessage(&b, msg); err != nil { return nil, fmt.Errorf("message %s: %w", msg.Name, err) @@ -120,7 +139,8 @@ func writeGoMarshalField(b *strings.Builder, recv string, f parser.Field, indent } fmt.Fprintf(b, "%s}\n", indent) case parser.KindSlice: - fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint16(buf, uint16(len(%s)))\n", indent, access) + fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint16(buf, arpackEnsureUint16Length(len(%s), %q))\n", + indent, access, lengthContext(f)) fmt.Fprintf(b, "%sfor _i%s := range %s {\n", indent, f.Name, access) elemField := parser.Field{ Name: f.Name + "[_i" + f.Name + "]", @@ -169,7 +189,8 @@ func writeGoMarshalPrimitive(b *strings.Builder, access string, f parser.Field, case parser.KindUint64: fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint64(buf, %s)\n", indent, valueExpr) case parser.KindString: - fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint16(buf, uint16(len(%s)))\n", indent, valueExpr) + fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint16(buf, arpackEnsureUint16Length(len(%s), %q))\n", + indent, valueExpr, lengthContext(f)) fmt.Fprintf(b, "%sbuf = append(buf, %s...)\n", indent, valueExpr) } return nil @@ -179,6 +200,8 @@ func writeGoMarshalQuant(b *strings.Builder, access string, f parser.Field, inde q := f.Quant varName := "_q" + sanitizeVarName(access) valueExpr := goMarshalValueExpr(access, f) + fmt.Fprintf(b, "%sarpackEnsureQuantizedRange(float64(%s), %g, %g, %q)\n", + indent, valueExpr, q.Min, q.Max, quantContext(f)) if q.Bits == 8 { fmt.Fprintf(b, "%s%s := uint8((%s - (%g)) / (%g - (%g)) * %g)\n", indent, varName, valueExpr, q.Min, q.Max, q.Min, q.MaxUint()) diff --git a/generator/lua.go b/generator/lua.go index 2cca660..34777ad 100644 --- a/generator/lua.go +++ b/generator/lua.go @@ -76,6 +76,13 @@ func writeLuaHelpers(b *strings.Builder) { b.WriteString(" return n\n") b.WriteString("end\n\n") + b.WriteString("local function ensure_quant_range(value, min, max, context)\n") + b.WriteString(" if value ~= value or value < min or value > max then\n") + b.WriteString(" error(string.format(\"arpack: quantized value out of range for %s\", context))\n") + b.WriteString(" end\n") + b.WriteString(" return value\n") + b.WriteString("end\n\n") + b.WriteString("local function read_u8(data, offset)\n") b.WriteString(" if offset > #data then error(\"arpack: buffer too short for u8\") end\n") b.WriteString(" return string.byte(data, offset), 1\n") @@ -506,8 +513,11 @@ func writeLuaSerializeQuant(b *strings.Builder, access string, f parser.Field, i q := f.Quant maxUint := q.MaxUint() varName := "_q_" + sanitizeLuaVarName(access) + valueVar := "_quant_value_" + sanitizeLuaVarName(access) + fmt.Fprintf(b, "%slocal %s = ensure_quant_range(%s, %g, %g, %q)\n", + indent, valueVar, access, q.Min, q.Max, quantContext(f)) fmt.Fprintf(b, "%slocal %s = math.floor(((%s - (%g)) / (%g - (%g))) * %g)\n", - indent, varName, access, q.Min, q.Max, q.Min, maxUint) + indent, varName, valueVar, q.Min, q.Max, q.Min, maxUint) if q.Bits == 8 { fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u8(%s)\n", indent, varName) } else { diff --git a/generator/lua_test.go b/generator/lua_test.go index 312bcd5..5b1758d 100644 --- a/generator/lua_test.go +++ b/generator/lua_test.go @@ -281,9 +281,12 @@ func TestGenerateLua_QuantizedFloat(t *testing.T) { luaStr := string(lua) - if !strings.Contains(luaStr, "math.floor(((msg.position - (-500)) / (500 - (-500))) * 65535)") { + if !strings.Contains(luaStr, "math.floor(((_quant_value_msg_position - (-500)) / (500 - (-500))) * 65535)") { t.Error("Missing truncating quantization code for Lua") } + if !strings.Contains(luaStr, `ensure_quant_range(msg.position, -500, 500, "Position")`) { + t.Error("Missing quantized range guard for Lua") + } if strings.Contains(luaStr, "math.floor(((msg.position - (-500)) / (500 - (-500))) * 65535 + 0.5)") { t.Error("Lua quantization should not round to nearest") } @@ -338,6 +341,7 @@ func TestLuaHelpersGenerated(t *testing.T) { "buffer too short for u8", "buffer too short for bool", "local function ensure_u16_limit(n, context)", + "local function ensure_quant_range(value, min, max, context)", "local function write_u8(n)", "buffer too short for u16", "local function write_u16_le(n)", @@ -703,3 +707,61 @@ print(bytes_to_hex(messages.serialize_float_edges(msg))) t.Fatalf("subnormal roundtrip mismatch: %s", lines[1]) } } + +func TestGenerateLua_RuntimeQuantizedRangeGuard(t *testing.T) { + if _, err := exec.LookPath("luajit"); err != nil { + t.Skip("luajit not found") + } + + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "WithQuantized", + Fields: []parser.Field{ + { + Name: "Position", + Kind: parser.KindPrimitive, + Primitive: parser.KindFloat32, + Quant: &parser.QuantInfo{Min: -500, Max: 500, Bits: 16}, + }, + }, + }, + }, + } + + lua, err := GenerateLuaSchema(schema, "messages") + if err != nil { + t.Fatalf("GenerateLuaSchema failed: %v", err) + } + + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "messages_gen.lua"), lua, 0o600); err != nil { + t.Fatalf("write module: %v", err) + } + + script := `local messages = require("messages_gen") +local msg = messages.new_with_quantized() +msg.position = 501 +local ok, res = pcall(messages.serialize_with_quantized, msg) +if ok then + print("OK") +else + print(res) +end +` + if err := os.WriteFile(filepath.Join(dir, "check.lua"), []byte(script), 0o600); err != nil { + t.Fatalf("write script: %v", err) + } + + cmd := exec.Command("luajit", "check.lua") + cmd.Dir = dir + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("luajit failed: %v\n%s", err, out) + } + + got := strings.TrimSpace(string(out)) + if !strings.Contains(got, "quantized value out of range for Position") { + t.Fatalf("expected quantized range guard, got %q", got) + } +} diff --git a/generator/policy.go b/generator/policy.go new file mode 100644 index 0000000..d3a8561 --- /dev/null +++ b/generator/policy.go @@ -0,0 +1,78 @@ +package generator + +import "github.com/edmand46/arpack/parser" + +const maxUint16Len = 65535 + +func lengthContext(f parser.Field) string { + switch { + case f.Kind == parser.KindSlice: + if f.Name != "" { + return "slice length for " + f.Name + } + return "slice length" + case f.Kind == parser.KindPrimitive && f.Primitive == parser.KindString: + if f.Name != "" { + return "string length for " + f.Name + } + return "string length" + default: + return "length" + } +} + +func quantContext(f parser.Field) string { + if f.Name != "" { + return f.Name + } + return "value" +} + +func schemaNeedsLengthGuards(messages []parser.Message) bool { + for _, msg := range messages { + for _, f := range msg.Fields { + if fieldNeedsLengthGuard(f) { + return true + } + } + } + return false +} + +func fieldNeedsLengthGuard(f parser.Field) bool { + switch f.Kind { + case parser.KindPrimitive: + return f.Primitive == parser.KindString + case parser.KindFixedArray, parser.KindSlice: + if f.Kind == parser.KindSlice { + return true + } + if f.Elem != nil { + return fieldNeedsLengthGuard(*f.Elem) + } + } + return false +} + +func schemaNeedsQuantRangeGuards(messages []parser.Message) bool { + for _, msg := range messages { + for _, f := range msg.Fields { + if fieldNeedsQuantRangeGuard(f) { + return true + } + } + } + return false +} + +func fieldNeedsQuantRangeGuard(f parser.Field) bool { + switch f.Kind { + case parser.KindPrimitive: + return f.Quant != nil + case parser.KindFixedArray, parser.KindSlice: + if f.Elem != nil { + return fieldNeedsQuantRangeGuard(*f.Elem) + } + } + return false +} diff --git a/generator/ts.go b/generator/ts.go index 5ca32f3..e1035f1 100644 --- a/generator/ts.go +++ b/generator/ts.go @@ -16,10 +16,31 @@ func GenerateTypeScript(messages []parser.Message, namespace string) ([]byte, er func GenerateTypeScriptSchema(schema parser.Schema, namespace string) ([]byte, error) { messages := schema.Messages var b strings.Builder + needsLengthGuards := schemaNeedsLengthGuards(messages) + needsQuantGuards := schemaNeedsQuantRangeGuards(messages) b.WriteString("// arpack \n") b.WriteString("// Code generated by arpack. DO NOT EDIT.\n\n") + if needsLengthGuards { + b.WriteString("const arpackTextEncoder = new TextEncoder();\n") + b.WriteString("const arpackTextDecoder = new TextDecoder();\n\n") + b.WriteString("function arpackEnsureUint16Length(length: number, context: string): number {\n") + b.WriteString(" if (length > 65535) {\n") + b.WriteString(" throw new RangeError(\"arpack: \" + context + \" exceeds uint16 limit\");\n") + b.WriteString(" }\n") + b.WriteString(" return length;\n") + b.WriteString("}\n\n") + } + + if needsQuantGuards { + b.WriteString("function arpackEnsureQuantizedRange(value: number, min: number, max: number, context: string): void {\n") + b.WriteString(" if (Number.isNaN(value) || value < min || value > max) {\n") + b.WriteString(" throw new RangeError(\"arpack: quantized value out of range for \" + context);\n") + b.WriteString(" }\n") + b.WriteString("}\n\n") + } + enumNames := make(map[string]struct{}, len(schema.Enums)) for _, enum := range schema.Enums { enumNames[enum.Name] = struct{}{} @@ -148,7 +169,10 @@ func writeTSSerializeField(b *strings.Builder, recv string, f parser.Field, inde } fmt.Fprintf(b, "%s}\n", indent) case parser.KindSlice: - fmt.Fprintf(b, "%sview.setUint16(pos, %s.length, true);\n", indent, access) + lenVar := "_len" + sanitizeVarName(access) + fmt.Fprintf(b, "%sconst %s = arpackEnsureUint16Length(%s.length, %q);\n", + indent, lenVar, access, lengthContext(f)) + fmt.Fprintf(b, "%sview.setUint16(pos, %s, true);\n", indent, lenVar) fmt.Fprintf(b, "%spos += 2;\n", indent) iVar := "_i" + f.Name fmt.Fprintf(b, "%sfor (const %s of %s) {\n", indent, iVar, access) @@ -222,8 +246,11 @@ func writeTSSerializePrimitiveElement(b *strings.Builder, access string, f parse fmt.Fprintf(b, "%spos += 8;\n", indent) case parser.KindString: lenVar := "_slen" + sanitizeVarName(access) - fmt.Fprintf(b, "%sconst %s = new TextEncoder().encode(%s);\n", indent, lenVar, valueExpr) - fmt.Fprintf(b, "%sview.setUint16(pos, %s.length, true);\n", indent, lenVar) + guardVar := "_slenChecked" + sanitizeVarName(access) + fmt.Fprintf(b, "%sconst %s = arpackTextEncoder.encode(%s);\n", indent, lenVar, valueExpr) + fmt.Fprintf(b, "%sconst %s = arpackEnsureUint16Length(%s.length, %q);\n", + indent, guardVar, lenVar, lengthContext(f)) + fmt.Fprintf(b, "%sview.setUint16(pos, %s, true);\n", indent, guardVar) fmt.Fprintf(b, "%spos += 2;\n", indent) fmt.Fprintf(b, "%snew Uint8Array(view.buffer, pos, %s.length).set(%s);\n", indent, lenVar, lenVar) fmt.Fprintf(b, "%spos += %s.length;\n", indent, lenVar) @@ -273,8 +300,11 @@ func writeTSSerializePrimitive(b *strings.Builder, access string, f parser.Field fmt.Fprintf(b, "%spos += 8;\n", indent) case parser.KindString: lenVar := "_slen" + sanitizeVarName(access) - fmt.Fprintf(b, "%sconst %s = new TextEncoder().encode(%s);\n", indent, lenVar, valueExpr) - fmt.Fprintf(b, "%sview.setUint16(pos, %s.length, true);\n", indent, lenVar) + guardVar := "_slenChecked" + sanitizeVarName(access) + fmt.Fprintf(b, "%sconst %s = arpackTextEncoder.encode(%s);\n", indent, lenVar, valueExpr) + fmt.Fprintf(b, "%sconst %s = arpackEnsureUint16Length(%s.length, %q);\n", + indent, guardVar, lenVar, lengthContext(f)) + fmt.Fprintf(b, "%sview.setUint16(pos, %s, true);\n", indent, guardVar) fmt.Fprintf(b, "%spos += 2;\n", indent) fmt.Fprintf(b, "%snew Uint8Array(view.buffer, pos, %s.length).set(%s);\n", indent, lenVar, lenVar) fmt.Fprintf(b, "%spos += %s.length;\n", indent, lenVar) @@ -286,6 +316,8 @@ func writeTSSerializeQuant(b *strings.Builder, access string, f parser.Field, in q := f.Quant maxUint := q.MaxUint() varName := "_q" + sanitizeVarName(access) + fmt.Fprintf(b, "%sarpackEnsureQuantizedRange(%s, %g, %g, %q);\n", + indent, access, q.Min, q.Max, quantContext(f)) if q.Bits == 8 { fmt.Fprintf(b, "%sconst %s = Math.trunc((%s - (%g)) / (%g - (%g)) * %g);\n", indent, varName, access, q.Min, q.Max, q.Min, maxUint) @@ -304,6 +336,8 @@ func writeTSSerializeQuantElement(b *strings.Builder, access string, f parser.Fi q := f.Quant maxUint := q.MaxUint() varName := "_q" + sanitizeVarName(access) + fmt.Fprintf(b, "%sarpackEnsureQuantizedRange(%s, %g, %g, %q);\n", + indent, access, q.Min, q.Max, quantContext(f)) if q.Bits == 8 { fmt.Fprintf(b, "%sconst %s = Math.trunc((%s - (%g)) / (%g - (%g)) * %g);\n", indent, varName, access, q.Min, q.Max, q.Min, maxUint) @@ -438,7 +472,7 @@ func writeTSDeserializePrimitiveElement(b *strings.Builder, access string, f par lenVar := "_slen" + sanitizeVarName(access) fmt.Fprintf(b, "%sconst %s = view.getUint16(pos, true);\n", indent, lenVar) fmt.Fprintf(b, "%spos += 2;\n", indent) - expr := fmt.Sprintf("new TextDecoder().decode(new Uint8Array(view.buffer, pos, %s))", lenVar) + expr := fmt.Sprintf("arpackTextDecoder.decode(new Uint8Array(view.buffer, pos, %s))", lenVar) fmt.Fprintf(b, "%s%s = %s;\n", indent, access, tsDeserializeValueExpr(expr, f, enumNames)) fmt.Fprintf(b, "%spos += %s;\n", indent, lenVar) } @@ -499,7 +533,7 @@ func writeTSDeserializePrimitive(b *strings.Builder, access string, f parser.Fie lenVar := "_slen" + sanitizeVarName(access) fmt.Fprintf(b, "%sconst %s = view.getUint16(pos, true);\n", indent, lenVar) fmt.Fprintf(b, "%spos += 2;\n", indent) - expr := fmt.Sprintf("new TextDecoder().decode(new Uint8Array(view.buffer, pos, %s))", lenVar) + expr := fmt.Sprintf("arpackTextDecoder.decode(new Uint8Array(view.buffer, pos, %s))", lenVar) fmt.Fprintf(b, "%s%s = %s;\n", indent, access, tsDeserializeValueExpr(expr, f, enumNames)) fmt.Fprintf(b, "%spos += %s;\n", indent, lenVar) } diff --git a/generator/ts_test.go b/generator/ts_test.go index 366d5d0..2a82b5c 100644 --- a/generator/ts_test.go +++ b/generator/ts_test.go @@ -101,11 +101,17 @@ func TestGenerateTypeScript_QuantizedFloats(t *testing.T) { if !strings.Contains(code, "Math.trunc((this.q8 - (0)) / (100 - (0)) * 255)") { t.Error("Missing 8-bit quantization code") } + if !strings.Contains(code, `arpackEnsureQuantizedRange(this.q8, 0, 100, "Q8");`) { + t.Error("Missing 8-bit quantized range guard") + } // Check 16-bit quantization (using camelCase field names) if !strings.Contains(code, "Math.trunc((this.q16 - (-500)) / (500 - (-500)) * 65535)") { t.Error("Missing 16-bit quantization code") } + if !strings.Contains(code, `arpackEnsureQuantizedRange(this.q16, -500, 500, "Q16");`) { + t.Error("Missing 16-bit quantized range guard") + } // Check deserialization with dequantization if !strings.Contains(code, "/ 255 * (100 - (0)) + (0)") { @@ -286,8 +292,11 @@ func TestGenerateTypeScript_Slices(t *testing.T) { } // Check length prefix in serialize (using camelCase field name) - if !strings.Contains(code, "view.setUint16(pos, this.items.length, true);") { - t.Error("Missing slice length prefix in serialize") + if !strings.Contains(code, `arpackEnsureUint16Length(this.items.length, "slice length for Items")`) { + t.Error("Missing slice length guard in serialize") + } + if !strings.Contains(code, "view.setUint16(pos, _lenthis_items, true);") { + t.Error("Missing guarded slice length prefix in serialize") } // Check length reading in deserialize @@ -385,17 +394,74 @@ func TestGenerateTypeScript_Strings(t *testing.T) { code := string(src) // Check TextEncoder usage - if !strings.Contains(code, "new TextEncoder().encode(") { - t.Error("Missing TextEncoder in serialize") + if !strings.Contains(code, "const arpackTextEncoder = new TextEncoder();") { + t.Error("Missing shared TextEncoder helper") } // Check length prefix if !strings.Contains(code, "view.setUint16(pos, _slen") { t.Error("Missing string length prefix in serialize") } + if !strings.Contains(code, `arpackEnsureUint16Length(_slen`) { + t.Error("Missing string length guard in serialize") + } // Check TextDecoder usage - if !strings.Contains(code, "new TextDecoder().decode(") { - t.Error("Missing TextDecoder in deserialize") + if !strings.Contains(code, "const arpackTextDecoder = new TextDecoder();") { + t.Error("Missing shared TextDecoder helper") + } + if !strings.Contains(code, "arpackTextDecoder.decode(") { + t.Error("Missing shared TextDecoder in deserialize") + } +} + +func TestGenerateTypeScript_LengthAndRangeHelpers(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + PackageName: "test", + Name: "LengthAndQuant", + Fields: []parser.Field{ + {Name: "Name", Kind: parser.KindPrimitive, Primitive: parser.KindString}, + { + Name: "Items", + Kind: parser.KindSlice, + Elem: &parser.Field{ + Kind: parser.KindPrimitive, + Primitive: parser.KindUint8, + }, + }, + { + Name: "Ratio", + Kind: parser.KindPrimitive, + Primitive: parser.KindFloat32, + Quant: &parser.QuantInfo{Min: 0, Max: 1, Bits: 8}, + }, + }, + }, + }, + } + + src, err := GenerateTypeScriptSchema(schema, "Test") + if err != nil { + t.Fatalf("GenerateTypeScriptSchema: %v", err) + } + + code := string(src) + + if !strings.Contains(code, "function arpackEnsureUint16Length(length: number, context: string): number") { + t.Error("Missing uint16 length helper") + } + if !strings.Contains(code, "function arpackEnsureQuantizedRange(value: number, min: number, max: number, context: string): void") { + t.Error("Missing quantized range helper") + } + if !strings.Contains(code, `arpackEnsureUint16Length(this.items.length, "slice length for Items")`) { + t.Error("Missing slice length guard") + } + if !strings.Contains(code, `arpackEnsureUint16Length(_slen`) { + t.Error("Missing string length helper call") + } + if !strings.Contains(code, `arpackEnsureQuantizedRange(this.ratio, 0, 1, "Ratio");`) { + t.Error("Missing quantized range helper call") } } diff --git a/parser/parser.go b/parser/parser.go index fd1aada..4b6803a 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -55,6 +55,8 @@ func parseASTFile(fset *token.FileSet, f *ast.File) (Schema, error) { knownStructs := map[string]bool{} namedPrimitives := map[string]PrimitiveKind{} + unsupportedNamedPrimitives := map[string]string{} + unresolvedNamedPrimitives := map[string]string{} var enumOrder []string for _, decl := range f.Decls { @@ -73,8 +75,14 @@ func parseASTFile(fset *token.FileSet, f *ast.File) (Schema, error) { case *ast.StructType: knownStructs[typeSpec.Name.Name] = true case *ast.Ident: + switch t.Name { + case "int", "uint", "uintptr": + unsupportedNamedPrimitives[typeSpec.Name.Name] = t.Name + continue + } primKind, isPrimitive := goPrimitiveKind(t.Name) if !isPrimitive { + unresolvedNamedPrimitives[typeSpec.Name.Name] = t.Name continue } namedPrimitives[typeSpec.Name.Name] = primKind @@ -85,6 +93,19 @@ func parseASTFile(fset *token.FileSet, f *ast.File) (Schema, error) { } } + for changed := true; changed; { + changed = false + for name, target := range unresolvedNamedPrimitives { + if _, ok := unsupportedNamedPrimitives[name]; ok { + continue + } + if baseType, ok := unsupportedNamedPrimitives[target]; ok { + unsupportedNamedPrimitives[name] = baseType + changed = true + } + } + } + info, err := typeCheckFile(fset, f) if err != nil { return Schema{}, err @@ -119,7 +140,14 @@ func parseASTFile(fset *token.FileSet, f *ast.File) (Schema, error) { continue } - msg, err := parseStruct(pkgName, typeSpec.Name.Name, structType, knownStructs, namedPrimitives) + msg, err := parseStruct( + pkgName, + typeSpec.Name.Name, + structType, + knownStructs, + namedPrimitives, + unsupportedNamedPrimitives, + ) if err != nil { return Schema{}, fmt.Errorf("struct %s: %w", typeSpec.Name.Name, err) } @@ -187,6 +215,7 @@ func parseStruct( st *ast.StructType, knownStructs map[string]bool, namedPrimitives map[string]PrimitiveKind, + unsupportedNamedPrimitives map[string]string, ) (Message, error) { msg := Message{PackageName: pkg, Name: name} @@ -202,7 +231,14 @@ func parseStruct( } for _, fieldName := range astField.Names { - field, err := parseFieldType(fieldName.Name, astField.Type, rawTag, knownStructs, namedPrimitives) + field, err := parseFieldType( + fieldName.Name, + astField.Type, + rawTag, + knownStructs, + namedPrimitives, + unsupportedNamedPrimitives, + ) if err != nil { return Message{}, fmt.Errorf("field %s: %w", fieldName.Name, err) } @@ -219,14 +255,22 @@ func parseFieldType( rawTag string, knownStructs map[string]bool, namedPrimitives map[string]PrimitiveKind, + unsupportedNamedPrimitives map[string]string, ) (Field, error) { switch t := expr.(type) { case *ast.Ident: - return parsePrimitiveOrNested(name, t.Name, rawTag, knownStructs, namedPrimitives) + return parsePrimitiveOrNested( + name, + t.Name, + rawTag, + knownStructs, + namedPrimitives, + unsupportedNamedPrimitives, + ) case *ast.ArrayType: if t.Len == nil { - elem, err := parseFieldType("", t.Elt, rawTag, knownStructs, namedPrimitives) + elem, err := parseFieldType("", t.Elt, rawTag, knownStructs, namedPrimitives, unsupportedNamedPrimitives) if err != nil { return Field{}, fmt.Errorf("slice element: %w", err) } @@ -243,7 +287,7 @@ func parseFieldType( return Field{}, fmt.Errorf("array length: %w", err) } - elem, err := parseFieldType("", t.Elt, rawTag, knownStructs, namedPrimitives) + elem, err := parseFieldType("", t.Elt, rawTag, knownStructs, namedPrimitives, unsupportedNamedPrimitives) if err != nil { return Field{}, fmt.Errorf("array element: %w", err) } @@ -271,9 +315,25 @@ func parsePrimitiveOrNested( rawTag string, knownStructs map[string]bool, namedPrimitives map[string]PrimitiveKind, + unsupportedNamedPrimitives map[string]string, ) (Field, error) { + switch typeName { + case "int", "uint", "uintptr": + return Field{}, fmt.Errorf( + "platform-dependent type %q is not supported; use int32/int64, uint32/uint64, or fixed-size integer IDs instead", + typeName, + ) + } + primKind, isPrimitive := goPrimitiveKind(typeName) if !isPrimitive { + if baseType, ok := unsupportedNamedPrimitives[typeName]; ok { + return Field{}, fmt.Errorf( + "type %q aliases unsupported platform-dependent %q; use int32/int64, uint32/uint64, or fixed-size integer IDs instead", + typeName, + baseType, + ) + } if namedPrimitive, ok := namedPrimitives[typeName]; ok { return buildPrimitiveField(name, typeName, namedPrimitive, rawTag) } @@ -385,7 +445,7 @@ func goPrimitiveKind(name string) (PrimitiveKind, bool) { return KindInt8, true case "int16": return KindInt16, true - case "int32", "int": + case "int32": return KindInt32, true case "int64": return KindInt64, true @@ -393,7 +453,7 @@ func goPrimitiveKind(name string) (PrimitiveKind, bool) { return KindUint8, true case "uint16": return KindUint16, true - case "uint32", "uint": + case "uint32": return KindUint32, true case "uint64": return KindUint64, true diff --git a/parser/parser_test.go b/parser/parser_test.go index 4e46e45..de65a25 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -1,6 +1,7 @@ package parser import ( + "strings" "testing" ) @@ -266,3 +267,92 @@ type Msg struct { t.Fatal("expected error for unknown nested type, got nil") } } + +func TestUnsupportedPlatformDependentIntTypes(t *testing.T) { + cases := []struct { + name string + src string + wantErr string + }{ + { + name: "direct int field", + src: `package p +type Msg struct { + X int +} +`, + wantErr: `platform-dependent type "int" is not supported`, + }, + { + name: "direct uint field", + src: `package p +type Msg struct { + X uint +} +`, + wantErr: `platform-dependent type "uint" is not supported`, + }, + { + name: "direct uintptr field", + src: `package p +type Msg struct { + X uintptr +} +`, + wantErr: `platform-dependent type "uintptr" is not supported`, + }, + { + name: "alias of int", + src: `package p +type Counter int +type Msg struct { + X Counter +} +`, + wantErr: `type "Counter" aliases unsupported platform-dependent "int"`, + }, + { + name: "alias of uint", + src: `package p +type Counter uint +type Msg struct { + X Counter +} +`, + wantErr: `type "Counter" aliases unsupported platform-dependent "uint"`, + }, + { + name: "alias of uintptr", + src: `package p +type Handle uintptr +type Msg struct { + X Handle +} +`, + wantErr: `type "Handle" aliases unsupported platform-dependent "uintptr"`, + }, + { + name: "transitive alias of int", + src: `package p +type Base int +type Counter Base +type Msg struct { + X Counter +} +`, + wantErr: `type "Counter" aliases unsupported platform-dependent "int"`, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + _, err := ParseSource(tc.src) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), tc.wantErr) { + t.Fatalf("expected error containing %q, got %v", tc.wantErr, err) + } + }) + } +}