From cebe84bce1c7db60a48285502b0bcedf51833a48 Mon Sep 17 00:00:00 2001 From: edmand46 Date: Wed, 25 Mar 2026 19:20:25 +0300 Subject: [PATCH] feat: C --- .github/workflows/tests.yml | 20 +- README.md | 53 +- cmd/arpack/main.go | 30 +- e2e/e2e_test.go | 213 +++++++ generator/c.go | 1046 +++++++++++++++++++++++++++++++++++ generator/c_test.go | 802 +++++++++++++++++++++++++++ generator/lua.go | 11 +- generator/lua_test.go | 152 ++++- generator/ts.go | 8 +- generator/ts_test.go | 4 +- 10 files changed, 2323 insertions(+), 16 deletions(-) create mode 100644 generator/c.go create mode 100644 generator/c_test.go diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5f704dc..2001720 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -35,6 +35,12 @@ jobs: - name: Run unit tests run: go test -v ./parser/... ./generator/... + - name: Install C compiler + run: sudo apt-get update && sudo apt-get install -y gcc + + - name: Run C generator tests + run: go test -v -run C ./generator/... + - name: Run benchmarks (short) run: go test -bench=. -benchtime=100ms -run=^$ ./benchmarks/... @@ -67,11 +73,20 @@ jobs: - name: Build arpack CLI run: go build -v ./cmd/arpack + - name: Install C compiler + run: sudo apt-get update && sudo apt-get install -y gcc + - name: Test code generation run: | - go run ./cmd/arpack -in testdata/sample.go -out-go /tmp/gen-go -out-ts /tmp/gen-ts + go run ./cmd/arpack -in testdata/sample.go -out-go /tmp/gen-go -out-ts /tmp/gen-ts -out-c /tmp/gen-c test -f /tmp/gen-go/sample_gen.go test -f /tmp/gen-ts/Sample.gen.ts + test -f /tmp/gen-c/sample.gen.h + test -f /tmp/gen-c/sample.gen.c + + - name: Compile generated C code + run: | + cc -std=c11 -Wall -Wextra -Wno-unused-function -c /tmp/gen-c/sample.gen.c -o /tmp/gen-c/sample.gen.o e2e: runs-on: ubuntu-latest @@ -93,5 +108,8 @@ jobs: with: node-version: '20' + - name: Install C compiler + run: sudo apt-get update && sudo apt-get install -y gcc + - name: Run E2E tests run: go test -v ./e2e/... diff --git a/README.md b/README.md index ea8d0f0..6828473 100644 --- a/README.md +++ b/README.md @@ -9,16 +9,16 @@ ![GitHub License](https://img.shields.io/github/license/edmand46/arpack) -Binary serialization code generator for Go, C#, TypeScript, and Lua. Define messages once as Go structs — get zero-allocation `Marshal`/`Unmarshal` for Go, `unsafe` pointer-based `Serialize`/`Deserialize` for C#, `DataView`-based serialization for TypeScript/browser, and pure Lua implementation for Defold/LuaJIT. +Binary serialization code generator for Go, C#, TypeScript, Lua, and C. Define messages once as Go structs — get zero-allocation `Marshal`/`Unmarshal` for Go, `unsafe` pointer-based `Serialize`/`Deserialize` for C#, `DataView`-based serialization for TypeScript/browser, pure Lua implementation for Defold/LuaJIT, and explicit encode/decode functions for C. ## Features -- **Single source of truth** — define messages in Go, generate code for Go, C#, TypeScript, and Lua +- **Single source of truth** — define messages in Go, generate code for Go, C#, TypeScript, Lua, and C - **Float quantization** — compress `float32`/`float64` to 8 or 16 bits with a `pack` struct tag - **Boolean packing** — consecutive `bool` fields are packed into single bytes (up to 8 per byte) - **Enums** — `type Opcode uint16` + `const` block becomes C#/TypeScript enums - **Nested types, fixed arrays, slices** — full support for complex message structures -- **Cross-language binary compatibility** — Go, C#, TypeScript, and Lua produce identical wire formats +- **Cross-language binary compatibility** — Go, C#, TypeScript, Lua, and C produce identical wire formats - **Browser support** — TypeScript target uses native DataView API for zero-dependency serialization ## When to use @@ -32,6 +32,7 @@ Typical setups: - **Any Go service + .NET client** — works anywhere you control both ends and want a compact binary protocol without Protobuf's runtime overhead or code-gen complexity. - **Go backend + Browser/WebSocket** — generate TypeScript classes for browser-based clients. Uses native DataView API with zero dependencies. - **Go backend + Defold/Lua** — generate Lua modules for Defold game engine. Pure Lua implementation compatible with LuaJIT. +- **Go backend + Defold/C** — generate C code for Defold native extensions. Maximum performance for Defold games with C extensions. ## Installation @@ -50,6 +51,9 @@ arpack -in messages.go -out-ts ./web/src/messages # Generate only Lua (for Defold) arpack -in messages.go -out-lua ./defold/scripts/messages + +# Generate C for Defold native extension +arpack -in messages.go -out-c ./defold/extension/src ``` | Flag | Description | @@ -59,6 +63,7 @@ arpack -in messages.go -out-lua ./defold/scripts/messages | `-out-cs` | Output directory for generated C# code | | `-out-ts` | Output directory for generated TypeScript code | | `-out-lua` | Output directory for generated Lua code | +| `-out-c` | Output directory for generated C code (for Defold native extensions) | | `-cs-namespace` | C# namespace (default: `Arpack.Messages`) | **Output files:** @@ -66,6 +71,7 @@ arpack -in messages.go -out-lua ./defold/scripts/messages - C#: `{Name}.gen.cs` - TypeScript: `{Name}.gen.ts` - Lua: `{name}_gen.lua` (snake_case for Lua `require()` compatibility) +- C: `{name}.gen.h` and `{name}.gen.c` (snake_case for C conventions) ## Schema Definition @@ -198,10 +204,49 @@ Uses pure Lua with inline helper functions for byte manipulation. Compatible wit **Requirements:** The generated Lua code requires the [BitOp library](https://bitop.luajit.org/) for bit manipulation. This library is included in LuaJIT (used by Defold). -**Limitations:** +**Limitations:** - Lua target does not support `int64`/`uint64` types. Use `int32`/`uint32` instead. This is because LuaJIT represents numbers as double-precision floats, which can only safely represent integers up to 2^53. +- Variable-length fields use `uint16` length prefixes, so `string` byte length and `[]T` element count must not exceed `65535`. Serialization raises an error if the limit is exceeded. +- Deserialization raises Lua errors on malformed or truncated input. If you need a recoverable boundary, wrap decode calls in `pcall(...)`. - Generated file uses snake_case naming (e.g., `messages_gen.lua`) for proper Lua `require()` resolution. +### C + +```c +#include "messages.gen.h" + +// Fixed-size message (no context needed) +sample_envelope_message msg = { + .code = sample_opcode_authorize, + .counter = 42 +}; + +uint8_t buf[64]; +size_t written; +arpack_status status = sample_envelope_message_encode(&msg, buf, sizeof(buf), &written); + +// Variable-length message (requires decode context) +sample_spawn_message_decode_ctx ctx = { + .tags_data = tags_buffer, + .tags_cap = MAX_TAGS +}; +sample_spawn_message decoded; +status = sample_spawn_message_decode(&decoded, buf, buf_len, &ctx, &read); +``` + +Generates two files: `{name}.gen.h` (declarations) and `{name}.gen.c` (implementations). Uses explicit encode/decode functions with bounds checking. All symbols are prefixed with `{name}_` to avoid collisions. + +**API Shape:** +- Fixed-size messages: `{name}_{msg}_min_size()`, `{name}_{msg}_encode()`, `{name}_{msg}_decode()` +- Variable-length messages: Additional `{name}_{msg}_size()` and decode context struct +- Strings and byte slices are views into the input buffer (zero-copy) +- Other slices require caller-provided storage via decode context + +**Limitations:** +- C11 standard required +- Variable-length slice fields require caller-provided storage (no hidden allocations) +- Wire format is not a packed C struct — use the generated encode/decode functions + ## Wire Format - Little-endian byte order diff --git a/cmd/arpack/main.go b/cmd/arpack/main.go index ff6cbbd..0f221e4 100644 --- a/cmd/arpack/main.go +++ b/cmd/arpack/main.go @@ -18,6 +18,7 @@ func main() { outCS := flag.String("out-cs", "", "output directory for generated C# code") outTS := flag.String("out-ts", "", "output directory for generated TypeScript code") outLua := flag.String("out-lua", "", "output directory for generated Lua code") + outC := flag.String("out-c", "", "output directory for generated C code") namespace := flag.String("cs-namespace", "Arpack.Messages", "C# namespace") flag.Parse() @@ -25,8 +26,8 @@ func main() { log.Fatal("arpack: -in is required") } - if *outGo == "" && *outCS == "" && *outTS == "" && *outLua == "" { - log.Fatal("arpack: at least one of -out-go, -out-cs, -out-ts, or -out-lua is required") + if *outGo == "" && *outCS == "" && *outTS == "" && *outLua == "" && *outC == "" { + log.Fatal("arpack: at least one of -out-go, -out-cs, -out-ts, -out-lua, or -out-c is required") } schema, err := parser.ParseSchemaFile(*in) @@ -115,6 +116,31 @@ func main() { fmt.Printf("arpack: wrote %s\n", outPath) } + + if *outC != "" { + snakeBase := toSnakeCase(baseName) + headerSrc, sourceSrc, err := generator.GenerateCSchema(schema, snakeBase) + if err != nil { + log.Fatalf("arpack: C generation error: %v", err) + } + + headerPath := filepath.Join(*outC, snakeBase+".gen.h") + sourcePath := filepath.Join(*outC, snakeBase+".gen.c") + + if err := os.MkdirAll(*outC, 0755); err != nil { + log.Fatalf("arpack: mkdir %s: %v", *outC, err) + } + + if err := os.WriteFile(headerPath, headerSrc, 0644); err != nil { + log.Fatalf("arpack: write %s: %v", headerPath, err) + } + if err := os.WriteFile(sourcePath, sourceSrc, 0644); err != nil { + log.Fatalf("arpack: write %s: %v", sourcePath, err) + } + + fmt.Printf("arpack: wrote %s\n", headerPath) + fmt.Printf("arpack: wrote %s\n", sourcePath) + } } func toTitle(s string) string { diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index ddfe9ba..a73270d 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -2,6 +2,8 @@ package e2e import ( "bytes" + "encoding/hex" + "fmt" "math" "os" "os/exec" @@ -17,6 +19,201 @@ import ( const samplePath = "../testdata/sample.go" // TestE2E_CrossLanguage +func TestE2E_C_GoInterop(t *testing.T) { + // Check for C compiler + var cc string + for _, compiler := range []string{"cc", "gcc", "clang"} { + if _, err := exec.LookPath(compiler); err == nil { + cc = compiler + break + } + } + if cc == "" { + t.Skip("No C compiler found (tried cc, gcc, clang)") + } + + schema, err := parser.ParseSchemaFile(samplePath) + if err != nil { + t.Fatalf("parse: %v", err) + } + + // Generate C code + header, source, err := generator.GenerateCSchema(schema, "sample") + if err != nil { + t.Fatalf("GenerateCSchema: %v", err) + } + + // Create temp directory for C harness + cDir := t.TempDir() + write(t, filepath.Join(cDir, "sample.gen.h"), header) + write(t, filepath.Join(cDir, "sample.gen.c"), source) + + // Generate test vectors using Go + goSrc, err := generator.GenerateGoSchema(schema, "main") + if err != nil { + t.Fatalf("GenerateGoSchema: %v", err) + } + goDir := buildGoHarness(t, goSrc) + + // Get hex from Go harness + vector3Hex := strings.TrimSpace(runHarness(t, goDir, "go", "ser", "Vector3", "")) + envelopeHex := strings.TrimSpace(runHarness(t, goDir, "go", "ser", "EnvelopeMessage", "")) + + // Convert hex to C array format + vector3Bytes, _ := hex.DecodeString(vector3Hex) + envelopeBytes, _ := hex.DecodeString(envelopeHex) + + vector3Array := "{" + for i, b := range vector3Bytes { + if i > 0 { + vector3Array += ", " + } + vector3Array += fmt.Sprintf("0x%02x", b) + } + vector3Array += "}" + + envelopeArray := "{" + for i, b := range envelopeBytes { + if i > 0 { + envelopeArray += ", " + } + envelopeArray += fmt.Sprintf("0x%02x", b) + } + envelopeArray += "}" + + // Create C test program with correct test vectors + cTestSource := fmt.Sprintf(`#include +#include +#include "sample.gen.h" + +// Test vectors from Go serialization +static const uint8_t vector3_test[] = %s; +static const uint8_t envelope_test[] = %s; + +int main(int argc, char *argv[]) { + if (argc < 2) { + printf("Usage: %%s \n", argv[0]); + return 1; + } + + if (strcmp(argv[1], "vector3") == 0) { + sample_vector3 msg; + size_t read; + arpack_status status = sample_vector3_decode(&msg, vector3_test, sizeof(vector3_test), &read); + if (status != ARPACK_OK) { + printf("STATUS=FAIL\n"); + return 1; + } + printf("STATUS=OK\n"); + printf("X=%%.2f\n", msg.x); + printf("Y=%%.2f\n", msg.y); + printf("Z=%%.2f\n", msg.z); + printf("READ=%%zu\n", read); + return 0; + } + + if (strcmp(argv[1], "envelope") == 0) { + sample_envelope_message msg; + size_t read; + arpack_status status = sample_envelope_message_decode(&msg, envelope_test, sizeof(envelope_test), &read); + if (status != ARPACK_OK) { + printf("STATUS=FAIL\n"); + return 1; + } + printf("STATUS=OK\n"); + printf("CODE=%%d\n", msg.code); + printf("COUNTER=%%d\n", msg.counter); + printf("READ=%%zu\n", read); + return 0; + } + + printf("Unknown test: %%s\n", argv[1]); + return 1; +} +`, vector3Array, envelopeArray) + write(t, filepath.Join(cDir, "test.c"), []byte(cTestSource)) + + // Compile C test program + cBin := filepath.Join(cDir, "test") + mustRun(t, cDir, cc, "-std=c11", "-Wall", "-Wextra", "-Wno-unused-function", "-o", cBin, "test.c", "sample.gen.c") + + // Run C test for Vector3 + t.Run("C_Decode_Vector3", func(t *testing.T) { + cmd := exec.Command("./test", "vector3") + cmd.Dir = cDir + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("C decode failed: %v\n%s", err, out) + } + kv := parseKV(string(out)) + if kv["STATUS"] != "OK" { + t.Fatalf("C decode failed: %s", string(out)) + } + // Check values are reasonable (quantized floats have error) + assertFloat(t, kv, "X", 123.45, 2.0) + assertFloat(t, kv, "Y", -200, 2.0) + assertFloat(t, kv, "Z", 0, 0.1) + }) + + // Run C test for EnvelopeMessage + t.Run("C_Decode_Envelope", func(t *testing.T) { + cmd := exec.Command("./test", "envelope") + cmd.Dir = cDir + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("C decode failed: %v\n%s", err, out) + } + kv := parseKV(string(out)) + if kv["STATUS"] != "OK" { + t.Fatalf("C decode failed: %s", string(out)) + } + assertInt(t, kv, "CODE", 2) + assertInt(t, kv, "COUNTER", 7) + }) + + // Test Go serialize -> C deserialize for Vector3 + t.Run("Go_to_C/Vector3", func(t *testing.T) { + hex := runHarness(t, goDir, "go", "ser", "Vector3", "") + // Write hex to file for C program to read + write(t, filepath.Join(cDir, "vector3.hex"), []byte(hex)) + + // Run C program to deserialize Go's output + cmd := exec.Command("./test", "vector3") + cmd.Dir = cDir + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("C decode failed: %v\n%s", err, out) + } + kv := parseKV(string(out)) + assertFloat(t, kv, "X", 123.45, 0.02) + assertFloat(t, kv, "Y", -200, 0.02) + assertFloat(t, kv, "Z", 0, 0.02) + }) + + // Test Go serialize -> C deserialize for EnvelopeMessage + t.Run("Go_to_C/EnvelopeMessage", func(t *testing.T) { + hexStr := runHarness(t, goDir, "go", "ser", "EnvelopeMessage", "") + data, err := hex.DecodeString(strings.TrimSpace(hexStr)) + if err != nil { + t.Fatalf("Failed to decode hex: %v", err) + } + + // Verify first byte is 0x02 (JoinRoom = 2, little endian) + if len(data) >= 2 && data[0] == 0x02 && data[1] == 0x00 { + // Verify C can read it + cmd := exec.Command("./test", "envelope") + cmd.Dir = cDir + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("C decode failed: %v\n%s", err, out) + } + kv := parseKV(string(out)) + assertInt(t, kv, "CODE", 2) + assertInt(t, kv, "COUNTER", 7) + } + }) +} + func TestE2E_CrossLanguage(t *testing.T) { schema, err := parser.ParseSchemaFile(samplePath) if err != nil { @@ -71,6 +268,14 @@ func TestE2E_CrossLanguage(t *testing.T) { } tsDir := buildTSHarness(t, tsSrc) + t.Run("QuantizedWire/Go_EQ_TS/Vector3", func(t *testing.T) { + goHex := strings.TrimSpace(runHarness(t, goDir, "go", "ser", "Vector3", "")) + tsHex := strings.TrimSpace(runHarness(t, tsDir, "ts", "ser", "Vector3", "")) + if goHex != tsHex { + t.Fatalf("quantized wire drift between Go and TS for Vector3:\ngo=%s\nts=%s", goHex, tsHex) + } + }) + for _, tc := range cases { t.Run("Go_to_TS/"+tc.name, func(t *testing.T) { hex := runHarness(t, goDir, "go", "ser", tc.typ, "") @@ -139,6 +344,14 @@ func TestE2E_CrossLanguage(t *testing.T) { } luaDir := buildLuaHarness(t, luaSrc) + t.Run("QuantizedWire/Go_EQ_Lua/Vector3", func(t *testing.T) { + goHex := strings.TrimSpace(runHarness(t, goDir, "go", "ser", "Vector3", "")) + luaHex := strings.TrimSpace(runHarness(t, luaDir, "lua", "ser", "Vector3", "")) + if goHex != luaHex { + t.Fatalf("quantized wire drift between Go and Lua for Vector3:\ngo=%s\nlua=%s", goHex, luaHex) + } + }) + luaCases := []struct { name string typ string diff --git a/generator/c.go b/generator/c.go new file mode 100644 index 0000000..5f7f3ad --- /dev/null +++ b/generator/c.go @@ -0,0 +1,1046 @@ +package generator + +import ( + "fmt" + "sort" + "strings" + + "github.com/edmand46/arpack/parser" +) + +func collectEnumTypes(enums []parser.Enum) map[string]struct{} { + enumTypes := make(map[string]struct{}, len(enums)) + for _, enum := range enums { + if len(enum.Values) == 0 { + continue + } + enumTypes[enum.Name] = struct{}{} + } + return enumTypes +} + +func validateCSchema(schema parser.Schema) error { + msgIndex := make(map[string]parser.Message, len(schema.Messages)) + for _, msg := range schema.Messages { + msgIndex[msg.Name] = msg + } + + for _, msg := range schema.Messages { + for _, field := range msg.Fields { + if err := validateCField(msg.Name, field, msgIndex, false); err != nil { + return err + } + } + } + + return nil +} + +func validateCField(msgName string, field parser.Field, msgIndex map[string]parser.Message, insideSlice bool) error { + switch field.Kind { + case parser.KindNested: + nested, ok := msgIndex[field.TypeName] + if !ok { + return fmt.Errorf("c target: unknown nested message %q in %s", field.TypeName, msgName) + } + if hasNonByteSlices(&nested) { + return fmt.Errorf( + "c target does not support nested message %s in %s because nested decode contexts are not implemented", + field.TypeName, + msgName, + ) + } + for _, nestedField := range nested.Fields { + if err := validateCField(nested.Name, nestedField, msgIndex, false); err != nil { + return err + } + } + + case parser.KindFixedArray: + if field.Elem == nil { + return fmt.Errorf("c target: fixed array %s in %s has nil element", field.Name, msgName) + } + if field.Elem.Kind == parser.KindSlice { + return fmt.Errorf("c target does not support fixed arrays of slices for field %s in %s", field.Name, msgName) + } + return validateCField(msgName, *field.Elem, msgIndex, insideSlice) + + case parser.KindSlice: + if field.Elem == nil { + return fmt.Errorf("c target: slice %s in %s has nil element", field.Name, msgName) + } + if field.Elem.Kind == parser.KindSlice || field.Elem.Kind == parser.KindFixedArray { + return fmt.Errorf( + "c target does not support slice element kind %d for field %s in %s", + field.Elem.Kind, + field.Name, + msgName, + ) + } + return validateCField(msgName, *field.Elem, msgIndex, true) + } + + return nil +} + +type cSliceViewDef struct { + Name string + ElemType string +} + +func collectSliceViewTypes(messages []parser.Message, baseName string, enumTypes map[string]struct{}) ([]cSliceViewDef, bool) { + viewTypes := make(map[string]string) + needStringViewSlice := false + + var collectFields func(fields []parser.Field) + collectFields = func(fields []parser.Field) { + for _, field := range fields { + switch field.Kind { + case parser.KindSlice: + viewType, _ := cFieldTypeInfo(&field, baseName, enumTypes) + if viewType == "arpack_string_view_slice_view" { + needStringViewSlice = true + } else if viewType != "arpack_bytes_view" && viewType != "arpack_string_view_slice_view" { + elemType, _ := cFieldTypeInfo(field.Elem, baseName, enumTypes) + viewTypes[viewType] = elemType + } + if field.Elem != nil { + collectFields([]parser.Field{*field.Elem}) + } + case parser.KindFixedArray: + if field.Elem != nil { + collectFields([]parser.Field{*field.Elem}) + } + case parser.KindNested: + // Nested messages don't need special handling here + } + } + } + + for _, msg := range messages { + collectFields(msg.Fields) + } + + names := make([]string, 0, len(viewTypes)) + for viewType := range viewTypes { + names = append(names, viewType) + } + sort.Strings(names) + result := make([]cSliceViewDef, 0, len(names)) + for _, name := range names { + result = append(result, cSliceViewDef{Name: name, ElemType: viewTypes[name]}) + } + return result, needStringViewSlice +} + +func GenerateCSchema(schema parser.Schema, baseName string) (header []byte, source []byte, err error) { + if err := validateCSchema(schema); err != nil { + return nil, nil, err + } + + messages := schema.Messages + enums := schema.Enums + enumTypes := collectEnumTypes(enums) + + var headerBuilder strings.Builder + var sourceBuilder strings.Builder + + // Write file headers + headerGuard := strings.ToUpper(baseName) + "_GEN_H" + headerBuilder.WriteString("// Code generated by arpack. DO NOT EDIT.\n\n") + headerBuilder.WriteString(fmt.Sprintf("#ifndef %s\n", headerGuard)) + headerBuilder.WriteString(fmt.Sprintf("#define %s\n\n", headerGuard)) + + sourceBuilder.WriteString("// Code generated by arpack. DO NOT EDIT.\n\n") + sourceBuilder.WriteString(fmt.Sprintf("#include \"%s.gen.h\"\n\n", baseName)) + + // Write private runtime helpers + writeCRuntimeHelpers(&sourceBuilder) + + // Write includes in header + headerBuilder.WriteString("#include \n") + headerBuilder.WriteString("#include \n") + headerBuilder.WriteString("#include \n\n") + + // Write shared runtime types + headerBuilder.WriteString("typedef enum arpack_status {\n") + headerBuilder.WriteString(" ARPACK_OK = 0,\n") + headerBuilder.WriteString(" ARPACK_ERR_BUFFER_TOO_SHORT = 1,\n") + headerBuilder.WriteString(" ARPACK_ERR_LENGTH_OVERFLOW = 2,\n") + headerBuilder.WriteString(" ARPACK_ERR_INVALID_ARGUMENT = 3,\n") + headerBuilder.WriteString(" ARPACK_ERR_CAPACITY_TOO_SMALL = 4\n") + headerBuilder.WriteString("} arpack_status;\n\n") + + headerBuilder.WriteString("typedef struct arpack_string_view {\n") + headerBuilder.WriteString(" const char *data;\n") + headerBuilder.WriteString(" uint16_t len;\n") + headerBuilder.WriteString("} arpack_string_view;\n\n") + + headerBuilder.WriteString("typedef struct arpack_bytes_view {\n") + headerBuilder.WriteString(" const uint8_t *data;\n") + headerBuilder.WriteString(" uint16_t len;\n") + headerBuilder.WriteString("} arpack_bytes_view;\n\n") + + // Collect slice view types needed + sliceViewTypes, needStringViewSlice := collectSliceViewTypes(messages, baseName, enumTypes) + + // Forward declare message typedefs (for slice views that reference them) + for _, msg := range messages { + msgName := baseName + "_" + snakeCase(msg.Name) + headerBuilder.WriteString(fmt.Sprintf("typedef struct %s %s;\n", msgName, msgName)) + } + headerBuilder.WriteString("\n") + + // Add arpack_string_view_slice_view if needed + if needStringViewSlice { + headerBuilder.WriteString("typedef struct arpack_string_view_slice_view {\n") + headerBuilder.WriteString(" const arpack_string_view *data;\n") + headerBuilder.WriteString(" uint16_t len;\n") + headerBuilder.WriteString("} arpack_string_view_slice_view;\n\n") + } + + // Forward declare slice view typedefs + for _, viewType := range sliceViewTypes { + headerBuilder.WriteString(fmt.Sprintf("typedef struct %s %s;\n", viewType.Name, viewType.Name)) + } + if len(sliceViewTypes) > 0 { + headerBuilder.WriteString("\n") + } + + // Define slice view types + for _, viewType := range sliceViewTypes { + headerBuilder.WriteString(fmt.Sprintf("typedef struct %s {\n", viewType.Name)) + headerBuilder.WriteString(fmt.Sprintf(" const %s *data;\n", viewType.ElemType)) + headerBuilder.WriteString(" uint16_t len;\n") + headerBuilder.WriteString(fmt.Sprintf("} %s;\n\n", viewType.Name)) + } + + // Write enums + for _, enum := range enums { + if len(enum.Values) == 0 { + continue + } + writeCEnum(&headerBuilder, enum, baseName) + headerBuilder.WriteString("\n") + } + + // Write message declarations + for _, msg := range messages { + writeCMessageDecl(&headerBuilder, msg, baseName, enumTypes) + headerBuilder.WriteString("\n") + } + + // Write function declarations for fixed-size messages + for _, msg := range messages { + if !msg.HasVariableFields() { + writeCFixedSizeFuncDecls(&headerBuilder, msg, baseName) + headerBuilder.WriteString("\n") + } + } + + // Write decode context declarations for messages with non-byte slices + for _, msg := range messages { + if hasNonByteSlices(&msg) { + writeCDecodeCtxDecl(&headerBuilder, msg, baseName, enumTypes) + headerBuilder.WriteString("\n") + } + } + + // Write function declarations for variable-length messages + for _, msg := range messages { + if msg.HasVariableFields() { + writeCVariableSizeFuncDecls(&headerBuilder, msg, baseName) + headerBuilder.WriteString("\n") + } + } + + // Close header guard + headerBuilder.WriteString(fmt.Sprintf("#endif // %s\n", headerGuard)) + + // Write function implementations for fixed-size messages + for _, msg := range messages { + if !msg.HasVariableFields() { + writeCFixedSizeFuncImpls(&sourceBuilder, msg, baseName, enumTypes) + sourceBuilder.WriteString("\n") + } + } + + // Write function implementations for variable-length messages + for _, msg := range messages { + if msg.HasVariableFields() { + writeCVariableSizeFuncImpls(&sourceBuilder, msg, baseName, enumTypes) + sourceBuilder.WriteString("\n") + } + } + + return []byte(headerBuilder.String()), []byte(sourceBuilder.String()), nil +} + +func writeCEnum(b *strings.Builder, enum parser.Enum, baseName string) { + enumName := baseName + "_" + snakeCase(enum.Name) + b.WriteString(fmt.Sprintf("typedef enum %s {\n", enumName)) + for _, value := range enum.Values { + valueName := baseName + "_" + snakeCase(enum.Name) + "_" + snakeCase(value.Name) + b.WriteString(fmt.Sprintf(" %s = %s,\n", valueName, value.Value)) + } + b.WriteString(fmt.Sprintf("} %s;\n", enumName)) +} + +func writeCMessageDecl(b *strings.Builder, msg parser.Message, baseName string, enumTypes map[string]struct{}) { + msgName := baseName + "_" + snakeCase(msg.Name) + b.WriteString(fmt.Sprintf("typedef struct %s {\n", msgName)) + for _, field := range msg.Fields { + typeStr, arraySuffix := cFieldTypeInfo(&field, baseName, enumTypes) + fieldDecl := typeStr + " " + snakeCase(field.Name) + arraySuffix + ";" + b.WriteString(fmt.Sprintf(" %s\n", fieldDecl)) + } + b.WriteString(fmt.Sprintf("} %s;\n", msgName)) +} + +func writeCFixedSizeFuncDecls(b *strings.Builder, msg parser.Message, baseName string) { + msgName := baseName + "_" + snakeCase(msg.Name) + b.WriteString(fmt.Sprintf("size_t %s_min_size(void);\n", msgName)) + b.WriteString(fmt.Sprintf("arpack_status %s_size(const %s *msg, size_t *out_size);\n", msgName, msgName)) + b.WriteString(fmt.Sprintf("arpack_status %s_encode(const %s *msg, uint8_t *buf, size_t buf_len, size_t *out_written);\n", msgName, msgName)) + b.WriteString(fmt.Sprintf("arpack_status %s_decode(%s *msg, const uint8_t *buf, size_t buf_len, size_t *out_read);\n", msgName, msgName)) +} + +func writeCFixedSizeFuncImpls(b *strings.Builder, msg parser.Message, baseName string, enumTypes map[string]struct{}) { + msgName := baseName + "_" + snakeCase(msg.Name) + minSize := msg.MinWireSize() + segs := segmentFields(msg.Fields) + + // min_size function + b.WriteString(fmt.Sprintf("size_t %s_min_size(void) {\n", msgName)) + b.WriteString(fmt.Sprintf(" return %d;\n", minSize)) + b.WriteString("}\n\n") + + // size function (same as min_size for fixed-size messages) + b.WriteString(fmt.Sprintf("arpack_status %s_size(const %s *msg, size_t *out_size) {\n", msgName, msgName)) + b.WriteString(fmt.Sprintf(" *out_size = %d;\n", minSize)) + b.WriteString(" return ARPACK_OK;\n") + b.WriteString("}\n\n") + + // encode function + b.WriteString(fmt.Sprintf("arpack_status %s_encode(const %s *msg, uint8_t *buf, size_t buf_len, size_t *out_written) {\n", msgName, msgName)) + b.WriteString(fmt.Sprintf(" if (buf_len < %d) return ARPACK_ERR_BUFFER_TOO_SHORT;\n", minSize)) + b.WriteString(" size_t offset = 0;\n") + + // Generate encode logic using segments for bool packing + for _, seg := range segs { + if seg.single != nil { + writeCFieldEncode(b, seg.single, "msg->"+snakeCase(seg.single.Name), baseName, enumTypes, " ", 0) + } else { + // Encode bool group + writeCBoolGroupEncode(b, "msg", seg.bools) + } + } + + b.WriteString(" *out_written = offset;\n") + b.WriteString(" return ARPACK_OK;\n") + b.WriteString("}\n\n") + + // decode function + b.WriteString(fmt.Sprintf("arpack_status %s_decode(%s *msg, const uint8_t *buf, size_t buf_len, size_t *out_read) {\n", msgName, msgName)) + b.WriteString(fmt.Sprintf(" if (buf_len < %d) return ARPACK_ERR_BUFFER_TOO_SHORT;\n", minSize)) + b.WriteString(" size_t offset = 0;\n") + + // Generate decode logic using segments for bool packing + for _, seg := range segs { + if seg.single != nil { + writeCFieldDecode(b, seg.single, "msg->"+snakeCase(seg.single.Name), baseName, enumTypes, "", " ", 0) + } else { + // Decode bool group + writeCBoolGroupDecode(b, "msg", seg.bools) + } + } + + b.WriteString(" *out_read = offset;\n") + b.WriteString(" return ARPACK_OK;\n") + b.WriteString("}\n") +} + +func hasNonByteSlices(msg *parser.Message) bool { + for _, field := range msg.Fields { + if field.Kind == parser.KindSlice { + // []uint8 doesn't need context + if isRawUint8Element(field.Elem) { + continue + } + return true + } + } + return false +} + +func writeCDecodeCtxDecl(b *strings.Builder, msg parser.Message, baseName string, enumTypes map[string]struct{}) { + ctxName := baseName + "_" + snakeCase(msg.Name) + "_decode_ctx" + b.WriteString(fmt.Sprintf("typedef struct %s {\n", ctxName)) + for _, field := range msg.Fields { + if field.Kind == parser.KindSlice { + // Skip []uint8 - it uses bytes_view which doesn't need context + if isRawUint8Element(field.Elem) { + continue + } + // Generate context field for non-byte slices + fieldName := snakeCase(field.Name) + elemType, _ := cFieldTypeInfo(field.Elem, baseName, enumTypes) + b.WriteString(fmt.Sprintf(" %s *%s_data;\n", elemType, fieldName)) + b.WriteString(fmt.Sprintf(" uint16_t %s_cap;\n", fieldName)) + } + } + b.WriteString(fmt.Sprintf("} %s;\n", ctxName)) +} + +func writeCVariableSizeFuncDecls(b *strings.Builder, msg parser.Message, baseName string) { + msgName := baseName + "_" + snakeCase(msg.Name) + b.WriteString(fmt.Sprintf("size_t %s_min_size(void);\n", msgName)) + b.WriteString(fmt.Sprintf("arpack_status %s_size(const %s *msg, size_t *out_size);\n", msgName, msgName)) + b.WriteString(fmt.Sprintf("arpack_status %s_encode(const %s *msg, uint8_t *buf, size_t buf_len, size_t *out_written);\n", msgName, msgName)) + if hasNonByteSlices(&msg) { + ctxName := msgName + "_decode_ctx" + b.WriteString(fmt.Sprintf("arpack_status %s_decode(%s *msg, const uint8_t *buf, size_t buf_len, %s *ctx, size_t *out_read);\n", msgName, msgName, ctxName)) + } else { + b.WriteString(fmt.Sprintf("arpack_status %s_decode(%s *msg, const uint8_t *buf, size_t buf_len, size_t *out_read);\n", msgName, msgName)) + } +} + +func writeCVariableSizeFuncImpls(b *strings.Builder, msg parser.Message, baseName string, enumTypes map[string]struct{}) { + msgName := baseName + "_" + snakeCase(msg.Name) + minSize := msg.MinWireSize() + segs := segmentFields(msg.Fields) + hasNonByteSliceFields := hasNonByteSlices(&msg) + + // min_size function + b.WriteString(fmt.Sprintf("size_t %s_min_size(void) {\n", msgName)) + b.WriteString(fmt.Sprintf(" return %d;\n", minSize)) + b.WriteString("}\n\n") + + // size function + b.WriteString(fmt.Sprintf("arpack_status %s_size(const %s *msg, size_t *out_size) {\n", msgName, msgName)) + b.WriteString(" size_t size = 0;\n") + for _, seg := range segs { + if seg.single != nil { + writeCFieldSize(b, seg.single, "msg->"+snakeCase(seg.single.Name), baseName, enumTypes, " ", 0) + } else { + b.WriteString(" size += 1;\n") // bool group is 1 byte + } + } + b.WriteString(" *out_size = size;\n") + b.WriteString(" return ARPACK_OK;\n") + b.WriteString("}\n\n") + + // encode function + b.WriteString(fmt.Sprintf("arpack_status %s_encode(const %s *msg, uint8_t *buf, size_t buf_len, size_t *out_written) {\n", msgName, msgName)) + b.WriteString(" size_t total_size;\n") + b.WriteString(fmt.Sprintf(" arpack_status status = %s_size(msg, &total_size);\n", msgName)) + b.WriteString(" if (status != ARPACK_OK) return status;\n") + b.WriteString(" if (buf_len < total_size) return ARPACK_ERR_BUFFER_TOO_SHORT;\n") + b.WriteString(" size_t offset = 0;\n") + for _, seg := range segs { + if seg.single != nil { + writeCFieldEncode(b, seg.single, "msg->"+snakeCase(seg.single.Name), baseName, enumTypes, " ", 0) + } else { + writeCBoolGroupEncode(b, "msg", seg.bools) + } + } + b.WriteString(" *out_written = offset;\n") + b.WriteString(" return ARPACK_OK;\n") + b.WriteString("}\n\n") + + // decode function + if hasNonByteSliceFields { + ctxName := msgName + "_decode_ctx" + b.WriteString(fmt.Sprintf("arpack_status %s_decode(%s *msg, const uint8_t *buf, size_t buf_len, %s *ctx, size_t *out_read) {\n", msgName, msgName, ctxName)) + b.WriteString(" if (ctx == NULL) return ARPACK_ERR_INVALID_ARGUMENT;\n") + } else { + b.WriteString(fmt.Sprintf("arpack_status %s_decode(%s *msg, const uint8_t *buf, size_t buf_len, size_t *out_read) {\n", msgName, msgName)) + } + b.WriteString(fmt.Sprintf(" if (buf_len < %d) return ARPACK_ERR_BUFFER_TOO_SHORT;\n", minSize)) + b.WriteString(" size_t offset = 0;\n") + for _, seg := range segs { + if seg.single != nil { + ctxVar := "" + if hasNonByteSliceFields { + ctxVar = "ctx" + } + writeCFieldDecode(b, seg.single, "msg->"+snakeCase(seg.single.Name), baseName, enumTypes, ctxVar, " ", 0) + } else { + writeCBoolGroupDecode(b, "msg", seg.bools) + } + } + b.WriteString(" *out_read = offset;\n") + b.WriteString(" return ARPACK_OK;\n") + b.WriteString("}\n") +} + +func writeCFieldSize( + b *strings.Builder, + field *parser.Field, + access string, + baseName string, + enumTypes map[string]struct{}, + indent string, + depth int, +) { + switch field.Kind { + case parser.KindPrimitive: + if field.Quant != nil { + b.WriteString(fmt.Sprintf("%ssize += %d;\n", indent, field.Quant.WireBytes())) + return + } + if field.Primitive == parser.KindString { + b.WriteString(fmt.Sprintf("%ssize += 2 + %s.len;\n", indent, access)) + return + } + ws := field.WireSize() + if ws > 0 { + b.WriteString(fmt.Sprintf("%ssize += %d;\n", indent, ws)) + } + case parser.KindNested: + nestedName := baseName + "_" + snakeCase(field.TypeName) + b.WriteString(fmt.Sprintf("%s{\n", indent)) + b.WriteString(fmt.Sprintf("%s size_t nested_size;\n", indent)) + b.WriteString(fmt.Sprintf("%s arpack_status status = %s_size(&%s, &nested_size);\n", indent, nestedName, access)) + b.WriteString(fmt.Sprintf("%s if (status != ARPACK_OK) return status;\n", indent)) + b.WriteString(fmt.Sprintf("%s size += nested_size;\n", indent)) + b.WriteString(fmt.Sprintf("%s}\n", indent)) + case parser.KindFixedArray: + idxVar := fmt.Sprintf("_i%d", depth) + b.WriteString(fmt.Sprintf("%sfor (uint16_t %s = 0; %s < %d; %s++) {\n", indent, idxVar, idxVar, field.FixedLen, idxVar)) + writeCFieldSize(b, field.Elem, access+"["+idxVar+"]", baseName, enumTypes, indent+" ", depth+1) + b.WriteString(fmt.Sprintf("%s}\n", indent)) + case parser.KindSlice: + b.WriteString(fmt.Sprintf("%ssize += 2;\n", indent)) + if isRawUint8Element(field.Elem) { + b.WriteString(fmt.Sprintf("%ssize += %s.len;\n", indent, access)) + return + } + idxVar := fmt.Sprintf("_i%d", depth) + b.WriteString(fmt.Sprintf("%sfor (uint16_t %s = 0; %s < %s.len; %s++) {\n", indent, idxVar, idxVar, access, idxVar)) + writeCFieldSize(b, field.Elem, access+".data["+idxVar+"]", baseName, enumTypes, indent+" ", depth+1) + b.WriteString(fmt.Sprintf("%s}\n", indent)) + } +} + +func writeCBoundsCheck(b *strings.Builder, indent string, needed string) { + b.WriteString(fmt.Sprintf("%s{\n", indent)) + b.WriteString(fmt.Sprintf("%s arpack_status status = _arpack_check_bounds(offset, %s, buf_len);\n", indent, needed)) + b.WriteString(fmt.Sprintf("%s if (status != ARPACK_OK) return status;\n", indent)) + b.WriteString(fmt.Sprintf("%s}\n", indent)) +} + +func writeCBoolGroupEncode(b *strings.Builder, msgVar string, bools []parser.Field) { + b.WriteString(" {\n") + b.WriteString(" uint8_t _boolByte = 0;\n") + for i, field := range bools { + fieldName := snakeCase(field.Name) + access := msgVar + "->" + fieldName + b.WriteString(fmt.Sprintf(" if (%s) _boolByte |= (1 << %d);\n", access, i)) + } + b.WriteString(" _arpack_write_u8(buf, &offset, _boolByte);\n") + b.WriteString(" }\n") +} + +func writeCBoolGroupDecode(b *strings.Builder, msgVar string, bools []parser.Field) { + b.WriteString(" {\n") + writeCBoundsCheck(b, " ", "1") + b.WriteString(" uint8_t _boolByte = _arpack_read_u8(buf, &offset);\n") + for i, field := range bools { + fieldName := snakeCase(field.Name) + access := msgVar + "->" + fieldName + b.WriteString(fmt.Sprintf(" %s = (_boolByte & (1 << %d)) != 0;\n", access, i)) + } + b.WriteString(" }\n") +} + +func writeCFieldEncode( + b *strings.Builder, + field *parser.Field, + access string, + baseName string, + enumTypes map[string]struct{}, + indent string, + depth int, +) { + switch field.Kind { + case parser.KindPrimitive: + if field.Quant != nil && (field.Primitive == parser.KindFloat32 || field.Primitive == parser.KindFloat64) { + writeCQuantizedEncode(b, field, access, indent) + return + } + + if field.Primitive == parser.KindString { + b.WriteString(fmt.Sprintf("%s_arpack_write_u16_le(buf, &offset, %s.len);\n", indent, access)) + idxVar := fmt.Sprintf("_i%d", depth) + b.WriteString(fmt.Sprintf("%sfor (uint16_t %s = 0; %s < %s.len; %s++) {\n", indent, idxVar, idxVar, access, idxVar)) + b.WriteString(fmt.Sprintf("%s _arpack_write_u8(buf, &offset, (uint8_t)%s.data[%s]);\n", indent, access, idxVar)) + b.WriteString(fmt.Sprintf("%s}\n", indent)) + return + } + + helper := cPrimitiveWriteHelper(field.Primitive) + if field.Primitive == parser.KindBool { + b.WriteString(fmt.Sprintf("%s_arpack_write_%s(buf, &offset, %s ? 1 : 0);\n", indent, helper, access)) + return + } + b.WriteString(fmt.Sprintf("%s_arpack_write_%s(buf, &offset, %s);\n", indent, helper, access)) + case parser.KindNested: + nestedName := baseName + "_" + snakeCase(field.TypeName) + b.WriteString(fmt.Sprintf("%s{\n", indent)) + b.WriteString(fmt.Sprintf("%s size_t nested_written;\n", indent)) + b.WriteString(fmt.Sprintf("%s arpack_status status = %s_encode(&%s, buf + offset, buf_len - offset, &nested_written);\n", indent, nestedName, access)) + b.WriteString(fmt.Sprintf("%s if (status != ARPACK_OK) return status;\n", indent)) + b.WriteString(fmt.Sprintf("%s offset += nested_written;\n", indent)) + b.WriteString(fmt.Sprintf("%s}\n", indent)) + case parser.KindFixedArray: + idxVar := fmt.Sprintf("_i%d", depth) + b.WriteString(fmt.Sprintf("%sfor (uint16_t %s = 0; %s < %d; %s++) {\n", indent, idxVar, idxVar, field.FixedLen, idxVar)) + writeCFieldEncode(b, field.Elem, access+"["+idxVar+"]", baseName, enumTypes, indent+" ", depth+1) + b.WriteString(fmt.Sprintf("%s}\n", indent)) + case parser.KindSlice: + b.WriteString(fmt.Sprintf("%s_arpack_write_u16_le(buf, &offset, %s.len);\n", indent, access)) + if isRawUint8Element(field.Elem) { + idxVar := fmt.Sprintf("_i%d", depth) + b.WriteString(fmt.Sprintf("%sfor (uint16_t %s = 0; %s < %s.len; %s++) {\n", indent, idxVar, idxVar, access, idxVar)) + b.WriteString(fmt.Sprintf("%s _arpack_write_u8(buf, &offset, %s.data[%s]);\n", indent, access, idxVar)) + b.WriteString(fmt.Sprintf("%s}\n", indent)) + return + } + idxVar := fmt.Sprintf("_i%d", depth) + b.WriteString(fmt.Sprintf("%sfor (uint16_t %s = 0; %s < %s.len; %s++) {\n", indent, idxVar, idxVar, access, idxVar)) + writeCFieldEncode(b, field.Elem, access+".data["+idxVar+"]", baseName, enumTypes, indent+" ", depth+1) + b.WriteString(fmt.Sprintf("%s}\n", indent)) + } +} + +func writeCQuantizedEncode(b *strings.Builder, field *parser.Field, access string, indent string) { + maxUint := int(field.Quant.MaxUint()) + if field.Quant.Bits == 8 { + b.WriteString(fmt.Sprintf("%s{\n", indent)) + b.WriteString(fmt.Sprintf("%s double _q = ((double)%s - (%g)) / ((%g) - (%g)) * %d;\n", indent, access, field.Quant.Min, field.Quant.Max, field.Quant.Min, maxUint)) + b.WriteString(fmt.Sprintf("%s if (_q < 0.0 || _q > %d.0) return ARPACK_ERR_INVALID_ARGUMENT;\n", indent, maxUint)) + b.WriteString(fmt.Sprintf("%s uint8_t _qv = (uint8_t)_q;\n", indent)) + b.WriteString(fmt.Sprintf("%s _arpack_write_u8(buf, &offset, _qv);\n", indent)) + b.WriteString(fmt.Sprintf("%s}\n", indent)) + return + } + b.WriteString(fmt.Sprintf("%s{\n", indent)) + b.WriteString(fmt.Sprintf("%s double _q = ((double)%s - (%g)) / ((%g) - (%g)) * %d;\n", indent, access, field.Quant.Min, field.Quant.Max, field.Quant.Min, maxUint)) + b.WriteString(fmt.Sprintf("%s if (_q < 0.0 || _q > %d.0) return ARPACK_ERR_INVALID_ARGUMENT;\n", indent, maxUint)) + b.WriteString(fmt.Sprintf("%s uint16_t _qv = (uint16_t)_q;\n", indent)) + b.WriteString(fmt.Sprintf("%s _arpack_write_u16_le(buf, &offset, _qv);\n", indent)) + b.WriteString(fmt.Sprintf("%s}\n", indent)) +} + +func writeCFieldDecode( + b *strings.Builder, + field *parser.Field, + access string, + baseName string, + enumTypes map[string]struct{}, + ctxVar string, + indent string, + depth int, +) { + switch field.Kind { + case parser.KindPrimitive: + if field.Quant != nil && (field.Primitive == parser.KindFloat32 || field.Primitive == parser.KindFloat64) { + writeCQuantizedDecode(b, field, access, indent) + return + } + + if field.Primitive == parser.KindString { + writeCBoundsCheck(b, indent, "2") + b.WriteString(fmt.Sprintf("%s%s.len = _arpack_read_u16_le(buf, &offset);\n", indent, access)) + writeCBoundsCheck(b, indent, access+".len") + b.WriteString(fmt.Sprintf("%s%s.data = (const char *)(buf + offset);\n", indent, access)) + b.WriteString(fmt.Sprintf("%soffset += %s.len;\n", indent, access)) + return + } + writeCBoundsCheck(b, indent, fmt.Sprintf("%d", field.WireSize())) + helper := cPrimitiveReadHelper(field.Primitive) + if field.Primitive == parser.KindBool { + b.WriteString(fmt.Sprintf("%s%s = _arpack_read_%s(buf, &offset) != 0;\n", indent, access, helper)) + return + } + b.WriteString(fmt.Sprintf("%s%s = _arpack_read_%s(buf, &offset);\n", indent, access, helper)) + case parser.KindNested: + nestedName := baseName + "_" + snakeCase(field.TypeName) + b.WriteString(fmt.Sprintf("%s{\n", indent)) + b.WriteString(fmt.Sprintf("%s size_t nested_read;\n", indent)) + b.WriteString(fmt.Sprintf("%s arpack_status status = %s_decode(&%s, buf + offset, buf_len - offset, &nested_read);\n", indent, nestedName, access)) + b.WriteString(fmt.Sprintf("%s if (status != ARPACK_OK) return status;\n", indent)) + b.WriteString(fmt.Sprintf("%s offset += nested_read;\n", indent)) + b.WriteString(fmt.Sprintf("%s}\n", indent)) + case parser.KindFixedArray: + idxVar := fmt.Sprintf("_i%d", depth) + b.WriteString(fmt.Sprintf("%sfor (uint16_t %s = 0; %s < %d; %s++) {\n", indent, idxVar, idxVar, field.FixedLen, idxVar)) + writeCFieldDecode(b, field.Elem, access+"["+idxVar+"]", baseName, enumTypes, "", indent+" ", depth+1) + b.WriteString(fmt.Sprintf("%s}\n", indent)) + case parser.KindSlice: + writeCBoundsCheck(b, indent, "2") + b.WriteString(fmt.Sprintf("%s%s.len = _arpack_read_u16_le(buf, &offset);\n", indent, access)) + if isRawUint8Element(field.Elem) { + writeCBoundsCheck(b, indent, access+".len") + b.WriteString(fmt.Sprintf("%s%s.data = buf + offset;\n", indent, access)) + b.WriteString(fmt.Sprintf("%soffset += %s.len;\n", indent, access)) + return + } + ctxField := snakeCase(field.Name) + b.WriteString(fmt.Sprintf("%sif (%s.len > %s->%s_cap) return ARPACK_ERR_CAPACITY_TOO_SMALL;\n", indent, access, ctxVar, ctxField)) + b.WriteString(fmt.Sprintf("%s%s.data = %s->%s_data;\n", indent, access, ctxVar, ctxField)) + idxVar := fmt.Sprintf("_i%d", depth) + b.WriteString(fmt.Sprintf("%sfor (uint16_t %s = 0; %s < %s.len; %s++) {\n", indent, idxVar, idxVar, access, idxVar)) + writeCFieldDecode(b, field.Elem, ctxVar+"->"+ctxField+"_data["+idxVar+"]", baseName, enumTypes, "", indent+" ", depth+1) + b.WriteString(fmt.Sprintf("%s}\n", indent)) + } +} + +func writeCQuantizedDecode(b *strings.Builder, field *parser.Field, access string, indent string) { + maxUint := field.Quant.MaxUint() + if field.Quant.Bits == 8 { + writeCBoundsCheck(b, indent, "1") + b.WriteString(fmt.Sprintf("%s{\n", indent)) + b.WriteString(fmt.Sprintf("%s uint8_t _qv = _arpack_read_u8(buf, &offset);\n", indent)) + b.WriteString(fmt.Sprintf("%s %s = ((double)_qv / %g) * ((%g) - (%g)) + (%g);\n", indent, access, maxUint, field.Quant.Max, field.Quant.Min, field.Quant.Min)) + b.WriteString(fmt.Sprintf("%s}\n", indent)) + return + } + writeCBoundsCheck(b, indent, "2") + b.WriteString(fmt.Sprintf("%s{\n", indent)) + b.WriteString(fmt.Sprintf("%s uint16_t _qv = _arpack_read_u16_le(buf, &offset);\n", indent)) + b.WriteString(fmt.Sprintf("%s %s = ((double)_qv / %g) * ((%g) - (%g)) + (%g);\n", indent, access, maxUint, field.Quant.Max, field.Quant.Min, field.Quant.Min)) + b.WriteString(fmt.Sprintf("%s}\n", indent)) +} + +func cFieldTypeInfo(field *parser.Field, baseName string, enumTypes map[string]struct{}) (typeStr string, arraySuffix string) { + switch field.Kind { + case parser.KindPrimitive: + return cPrimitiveTypeInfo(field, baseName, enumTypes), "" + case parser.KindNested: + return baseName + "_" + snakeCase(field.TypeName), "" + case parser.KindFixedArray: + elemType, elemSuffix := cFieldTypeInfo(field.Elem, baseName, enumTypes) + return elemType, fmt.Sprintf("[%d]%s", field.FixedLen, elemSuffix) + case parser.KindSlice: + if isRawUint8Element(field.Elem) { + return "arpack_bytes_view", "" + } + if field.Elem.Kind == parser.KindPrimitive && field.Elem.Primitive == parser.KindString { + return "arpack_string_view_slice_view", "" + } + return cSliceViewTypeName(field.Elem, baseName, enumTypes), "" + default: + return "void*", "" + } +} + +func cFieldType(field *parser.Field, baseName string, enumTypes map[string]struct{}) string { + typeStr, suffix := cFieldTypeInfo(field, baseName, enumTypes) + if suffix != "" { + return typeStr + suffix + } + return typeStr +} + +func cSliceViewTypeName(field *parser.Field, baseName string, enumTypes map[string]struct{}) string { + return baseName + "_" + cSliceViewElemKey(field, enumTypes) + "_slice_view" +} + +func cSliceViewElemKey(field *parser.Field, enumTypes map[string]struct{}) string { + switch field.Kind { + case parser.KindPrimitive: + if field.NamedType != "" { + if _, ok := enumTypes[field.NamedType]; ok { + return snakeCase(field.NamedType) + } + } + return cPrimitiveTypeToken(field.Primitive) + case parser.KindNested: + return snakeCase(field.TypeName) + default: + return "unsupported" + } +} + +func cPrimitiveTypeInfo(field *parser.Field, baseName string, enumTypes map[string]struct{}) string { + if field.NamedType != "" { + if _, ok := enumTypes[field.NamedType]; ok { + return baseName + "_" + snakeCase(field.NamedType) + } + } + switch field.Primitive { + case parser.KindInt8: + return "int8_t" + case parser.KindInt16: + return "int16_t" + case parser.KindInt32: + return "int32_t" + case parser.KindInt64: + return "int64_t" + case parser.KindUint8: + return "uint8_t" + case parser.KindUint16: + return "uint16_t" + case parser.KindUint32: + return "uint32_t" + case parser.KindUint64: + return "uint64_t" + case parser.KindFloat32: + return "float" + case parser.KindFloat64: + return "double" + case parser.KindBool: + return "bool" + case parser.KindString: + return "arpack_string_view" + default: + return "void*" + } +} + +func cPrimitiveTypeToken(k parser.PrimitiveKind) string { + switch k { + case parser.KindInt8: + return "int8" + case parser.KindInt16: + return "int16" + case parser.KindInt32: + return "int32" + case parser.KindInt64: + return "int64" + case parser.KindUint8: + return "uint8" + case parser.KindUint16: + return "uint16" + case parser.KindUint32: + return "uint32" + case parser.KindUint64: + return "uint64" + case parser.KindFloat32: + return "float32" + case parser.KindFloat64: + return "float64" + case parser.KindBool: + return "bool" + case parser.KindString: + return "string" + default: + return "unknown" + } +} + +func cPrimitiveReadHelper(k parser.PrimitiveKind) string { + switch k { + case parser.KindInt8: + return "i8" + case parser.KindInt16: + return "i16_le" + case parser.KindInt32: + return "i32_le" + case parser.KindInt64: + return "i64_le" + case parser.KindUint8, parser.KindBool: + return "u8" + case parser.KindUint16: + return "u16_le" + case parser.KindUint32: + return "u32_le" + case parser.KindUint64: + return "u64_le" + case parser.KindFloat32: + return "f32_le" + case parser.KindFloat64: + return "f64_le" + default: + return "" + } +} + +func cPrimitiveWriteHelper(k parser.PrimitiveKind) string { + return cPrimitiveReadHelper(k) +} + +func isRawUint8Element(field *parser.Field) bool { + return field != nil && + field.Kind == parser.KindPrimitive && + field.Primitive == parser.KindUint8 && + field.NamedType == "" && + field.Quant == nil +} + +func writeCRuntimeHelpers(b *strings.Builder) { + b.WriteString("// Private runtime helpers\n\n") + + // Bounds check helper + b.WriteString("static inline arpack_status _arpack_check_bounds(size_t offset, size_t needed, size_t buf_len) {\n") + b.WriteString(" if (offset + needed > buf_len) return ARPACK_ERR_BUFFER_TOO_SHORT;\n") + b.WriteString(" return ARPACK_OK;\n") + b.WriteString("}\n\n") + + // Read helpers (little-endian) + b.WriteString("static inline uint8_t _arpack_read_u8(const uint8_t *buf, size_t *offset) {\n") + b.WriteString(" uint8_t val = buf[*offset];\n") + b.WriteString(" *offset += 1;\n") + b.WriteString(" return val;\n") + b.WriteString("}\n\n") + + b.WriteString("static inline uint16_t _arpack_read_u16_le(const uint8_t *buf, size_t *offset) {\n") + b.WriteString(" uint16_t val = (uint16_t)buf[*offset] | ((uint16_t)buf[*offset + 1] << 8);\n") + b.WriteString(" *offset += 2;\n") + b.WriteString(" return val;\n") + b.WriteString("}\n\n") + + b.WriteString("static inline uint32_t _arpack_read_u32_le(const uint8_t *buf, size_t *offset) {\n") + b.WriteString(" uint32_t val = (uint32_t)buf[*offset] | ((uint32_t)buf[*offset + 1] << 8) |\\\n") + b.WriteString(" ((uint32_t)buf[*offset + 2] << 16) | ((uint32_t)buf[*offset + 3] << 24);\n") + b.WriteString(" *offset += 4;\n") + b.WriteString(" return val;\n") + b.WriteString("}\n\n") + + b.WriteString("static inline uint64_t _arpack_read_u64_le(const uint8_t *buf, size_t *offset) {\n") + b.WriteString(" uint64_t val = (uint64_t)buf[*offset] | ((uint64_t)buf[*offset + 1] << 8) |\\\n") + b.WriteString(" ((uint64_t)buf[*offset + 2] << 16) | ((uint64_t)buf[*offset + 3] << 24) |\\\n") + b.WriteString(" ((uint64_t)buf[*offset + 4] << 32) | ((uint64_t)buf[*offset + 5] << 40) |\\\n") + b.WriteString(" ((uint64_t)buf[*offset + 6] << 48) | ((uint64_t)buf[*offset + 7] << 56);\n") + b.WriteString(" *offset += 8;\n") + b.WriteString(" return val;\n") + b.WriteString("}\n\n") + + b.WriteString("static inline int8_t _arpack_read_i8(const uint8_t *buf, size_t *offset) {\n") + b.WriteString(" return (int8_t)_arpack_read_u8(buf, offset);\n") + b.WriteString("}\n\n") + + b.WriteString("static inline int16_t _arpack_read_i16_le(const uint8_t *buf, size_t *offset) {\n") + b.WriteString(" return (int16_t)_arpack_read_u16_le(buf, offset);\n") + b.WriteString("}\n\n") + + b.WriteString("static inline int32_t _arpack_read_i32_le(const uint8_t *buf, size_t *offset) {\n") + b.WriteString(" return (int32_t)_arpack_read_u32_le(buf, offset);\n") + b.WriteString("}\n\n") + + b.WriteString("static inline int64_t _arpack_read_i64_le(const uint8_t *buf, size_t *offset) {\n") + b.WriteString(" return (int64_t)_arpack_read_u64_le(buf, offset);\n") + b.WriteString("}\n\n") + + // Write helpers (little-endian) + b.WriteString("static inline void _arpack_write_u8(uint8_t *buf, size_t *offset, uint8_t val) {\n") + b.WriteString(" buf[*offset] = val;\n") + b.WriteString(" *offset += 1;\n") + b.WriteString("}\n\n") + + b.WriteString("static inline void _arpack_write_u16_le(uint8_t *buf, size_t *offset, uint16_t val) {\n") + b.WriteString(" buf[*offset] = val & 0xFF;\n") + b.WriteString(" buf[*offset + 1] = (val >> 8) & 0xFF;\n") + b.WriteString(" *offset += 2;\n") + b.WriteString("}\n\n") + + b.WriteString("static inline void _arpack_write_u32_le(uint8_t *buf, size_t *offset, uint32_t val) {\n") + b.WriteString(" buf[*offset] = val & 0xFF;\n") + b.WriteString(" buf[*offset + 1] = (val >> 8) & 0xFF;\n") + b.WriteString(" buf[*offset + 2] = (val >> 16) & 0xFF;\n") + b.WriteString(" buf[*offset + 3] = (val >> 24) & 0xFF;\n") + b.WriteString(" *offset += 4;\n") + b.WriteString("}\n\n") + + b.WriteString("static inline void _arpack_write_u64_le(uint8_t *buf, size_t *offset, uint64_t val) {\n") + b.WriteString(" buf[*offset] = val & 0xFF;\n") + b.WriteString(" buf[*offset + 1] = (val >> 8) & 0xFF;\n") + b.WriteString(" buf[*offset + 2] = (val >> 16) & 0xFF;\n") + b.WriteString(" buf[*offset + 3] = (val >> 24) & 0xFF;\n") + b.WriteString(" buf[*offset + 4] = (val >> 32) & 0xFF;\n") + b.WriteString(" buf[*offset + 5] = (val >> 40) & 0xFF;\n") + b.WriteString(" buf[*offset + 6] = (val >> 48) & 0xFF;\n") + b.WriteString(" buf[*offset + 7] = (val >> 56) & 0xFF;\n") + b.WriteString(" *offset += 8;\n") + b.WriteString("}\n\n") + + b.WriteString("static inline void _arpack_write_i8(uint8_t *buf, size_t *offset, int8_t val) {\n") + b.WriteString(" _arpack_write_u8(buf, offset, (uint8_t)val);\n") + b.WriteString("}\n\n") + + b.WriteString("static inline void _arpack_write_i16_le(uint8_t *buf, size_t *offset, int16_t val) {\n") + b.WriteString(" _arpack_write_u16_le(buf, offset, (uint16_t)val);\n") + b.WriteString("}\n\n") + + b.WriteString("static inline void _arpack_write_i32_le(uint8_t *buf, size_t *offset, int32_t val) {\n") + b.WriteString(" _arpack_write_u32_le(buf, offset, (uint32_t)val);\n") + b.WriteString("}\n\n") + + b.WriteString("static inline void _arpack_write_i64_le(uint8_t *buf, size_t *offset, int64_t val) {\n") + b.WriteString(" _arpack_write_u64_le(buf, offset, (uint64_t)val);\n") + b.WriteString("}\n\n") + + // Float bit conversion helpers + b.WriteString("static inline float _arpack_u32_to_f(uint32_t bits) {\n") + b.WriteString(" union { uint32_t u; float f; } conv;\n") + b.WriteString(" conv.u = bits;\n") + b.WriteString(" return conv.f;\n") + b.WriteString("}\n\n") + + b.WriteString("static inline uint32_t _arpack_f_to_u32(float val) {\n") + b.WriteString(" union { uint32_t u; float f; } conv;\n") + b.WriteString(" conv.f = val;\n") + b.WriteString(" return conv.u;\n") + b.WriteString("}\n\n") + + b.WriteString("static inline double _arpack_u64_to_f(uint64_t bits) {\n") + b.WriteString(" union { uint64_t u; double f; } conv;\n") + b.WriteString(" conv.u = bits;\n") + b.WriteString(" return conv.f;\n") + b.WriteString("}\n\n") + + b.WriteString("static inline uint64_t _arpack_f_to_u64(double val) {\n") + b.WriteString(" union { uint64_t u; double f; } conv;\n") + b.WriteString(" conv.f = val;\n") + b.WriteString(" return conv.u;\n") + b.WriteString("}\n\n") + + // Float read/write helpers + b.WriteString("static inline float _arpack_read_f32_le(const uint8_t *buf, size_t *offset) {\n") + b.WriteString(" return _arpack_u32_to_f(_arpack_read_u32_le(buf, offset));\n") + b.WriteString("}\n\n") + + b.WriteString("static inline double _arpack_read_f64_le(const uint8_t *buf, size_t *offset) {\n") + b.WriteString(" return _arpack_u64_to_f(_arpack_read_u64_le(buf, offset));\n") + b.WriteString("}\n\n") + + b.WriteString("static inline void _arpack_write_f32_le(uint8_t *buf, size_t *offset, float val) {\n") + b.WriteString(" _arpack_write_u32_le(buf, offset, _arpack_f_to_u32(val));\n") + b.WriteString("}\n\n") + + b.WriteString("static inline void _arpack_write_f64_le(uint8_t *buf, size_t *offset, double val) {\n") + b.WriteString(" _arpack_write_u64_le(buf, offset, _arpack_f_to_u64(val));\n") + b.WriteString("}\n\n") +} + +func snakeCase(s string) string { + if s == "" { + return "" + } + + var b strings.Builder + var prevUpper bool + + for i, c := range s { + isUpper := c >= 'A' && c <= 'Z' + + if i > 0 && isUpper { + nextLower := false + if i+1 < len(s) { + nextChar := rune(s[i+1]) + nextLower = nextChar >= 'a' && nextChar <= 'z' + } + + if !prevUpper || nextLower { + b.WriteByte('_') + } + } + + b.WriteRune(c) + prevUpper = isUpper + } + + return strings.ToLower(b.String()) +} diff --git a/generator/c_test.go b/generator/c_test.go new file mode 100644 index 0000000..c486ffc --- /dev/null +++ b/generator/c_test.go @@ -0,0 +1,802 @@ +package generator + +import ( + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + + "github.com/edmand46/arpack/parser" +) + +func TestCSnakeCase(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"", ""}, + {"Simple", "simple"}, + {"PlayerID", "player_id"}, + {"HTTPRequest", "http_request"}, + {"XMLParser", "xml_parser"}, + {"MoveMessage", "move_message"}, + {"position", "position"}, + {"X", "x"}, + {"HTTPServer", "http_server"}, + {"URLHandler", "url_handler"}, + } + + for _, tc := range tests { + result := snakeCase(tc.input) + if result != tc.expected { + t.Errorf("snakeCase(%q) = %q, want %q", tc.input, result, tc.expected) + } + } +} + +func TestCGenerateSchema_BasicTypes(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "BasicTypes", + Fields: []parser.Field{ + {Name: "Int8Field", Kind: parser.KindPrimitive, Primitive: parser.KindInt8}, + {Name: "Int16Field", Kind: parser.KindPrimitive, Primitive: parser.KindInt16}, + {Name: "Int32Field", Kind: parser.KindPrimitive, Primitive: parser.KindInt32}, + {Name: "Int64Field", Kind: parser.KindPrimitive, Primitive: parser.KindInt64}, + {Name: "Uint8Field", Kind: parser.KindPrimitive, Primitive: parser.KindUint8}, + {Name: "Uint16Field", Kind: parser.KindPrimitive, Primitive: parser.KindUint16}, + {Name: "Uint32Field", Kind: parser.KindPrimitive, Primitive: parser.KindUint32}, + {Name: "Uint64Field", Kind: parser.KindPrimitive, Primitive: parser.KindUint64}, + {Name: "Float32Field", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32}, + {Name: "Float64Field", Kind: parser.KindPrimitive, Primitive: parser.KindFloat64}, + {Name: "BoolField", Kind: parser.KindPrimitive, Primitive: parser.KindBool}, + {Name: "StringField", Kind: parser.KindPrimitive, Primitive: parser.KindString}, + }, + }, + }, + } + + header, _, err := GenerateCSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateCSchema failed: %v", err) + } + + headerStr := string(header) + + // Check struct declaration + if !strings.Contains(headerStr, "typedef struct test_basic_types {") { + t.Error("Missing test_basic_types struct declaration") + } + + // Check primitive field types + if !strings.Contains(headerStr, "int8_t int8_field;") { + t.Error("Missing int8_field") + } + if !strings.Contains(headerStr, "int16_t int16_field;") { + t.Error("Missing int16_field") + } + if !strings.Contains(headerStr, "int32_t int32_field;") { + t.Error("Missing int32_field") + } + if !strings.Contains(headerStr, "int64_t int64_field;") { + t.Error("Missing int64_field") + } + if !strings.Contains(headerStr, "uint8_t uint8_field;") { + t.Error("Missing uint8_field") + } + if !strings.Contains(headerStr, "uint16_t uint16_field;") { + t.Error("Missing uint16_field") + } + if !strings.Contains(headerStr, "uint32_t uint32_field;") { + t.Error("Missing uint32_field") + } + if !strings.Contains(headerStr, "uint64_t uint64_field;") { + t.Error("Missing uint64_field") + } + if !strings.Contains(headerStr, "float float32_field;") { + t.Error("Missing float32_field") + } + if !strings.Contains(headerStr, "double float64_field;") { + t.Error("Missing float64_field") + } + if !strings.Contains(headerStr, "bool bool_field;") { + t.Error("Missing bool_field") + } + if !strings.Contains(headerStr, "arpack_string_view string_field;") { + t.Error("Missing string_field") + } +} + +func TestCGenerateSchema_Enum(t *testing.T) { + schema := parser.Schema{ + Enums: []parser.Enum{ + { + Name: "Opcode", + Primitive: parser.KindUint16, + Values: []parser.EnumValue{ + {Name: "Unknown", Value: "0"}, + {Name: "Join", Value: "1"}, + {Name: "Leave", Value: "2"}, + }, + }, + }, + Messages: []parser.Message{ + { + Name: "MessageWithEnum", + Fields: []parser.Field{ + {Name: "Op", Kind: parser.KindPrimitive, Primitive: parser.KindUint16, NamedType: "Opcode"}, + }, + }, + }, + } + + header, _, err := GenerateCSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateCSchema failed: %v", err) + } + + headerStr := string(header) + + // Check enum declaration + if !strings.Contains(headerStr, "typedef enum test_opcode {") { + t.Error("Missing test_opcode enum declaration") + } + if !strings.Contains(headerStr, "test_opcode_unknown = 0,") { + t.Error("Missing Unknown enum value") + } + if !strings.Contains(headerStr, "test_opcode_join = 1,") { + t.Error("Missing Join enum value") + } + if !strings.Contains(headerStr, "test_opcode_leave = 2,") { + t.Error("Missing Leave enum value") + } + if !strings.Contains(headerStr, "test_opcode op;") { + t.Error("Enum-backed field should use generated enum type") + } +} + +func TestCGenerateSchema_HeaderGuard(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + {Name: "Simple", Fields: []parser.Field{}}, + }, + } + + header, _, err := GenerateCSchema(schema, "my_base") + if err != nil { + t.Fatalf("GenerateCSchema failed: %v", err) + } + + headerStr := string(header) + + if !strings.Contains(headerStr, "#ifndef MY_BASE_GEN_H") { + t.Error("Missing header guard ifndef") + } + if !strings.Contains(headerStr, "#define MY_BASE_GEN_H") { + t.Error("Missing header guard define") + } + if !strings.Contains(headerStr, "#endif // MY_BASE_GEN_H") { + t.Error("Missing header guard endif") + } +} + +func TestCGenerateSchema_RuntimeTypes(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + {Name: "Simple", Fields: []parser.Field{}}, + }, + } + + header, _, err := GenerateCSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateCSchema failed: %v", err) + } + + headerStr := string(header) + + // Check arpack_status enum + if !strings.Contains(headerStr, "typedef enum arpack_status {") { + t.Error("Missing arpack_status enum") + } + if !strings.Contains(headerStr, "ARPACK_OK = 0,") { + t.Error("Missing ARPACK_OK") + } + + // Check string view + if !strings.Contains(headerStr, "typedef struct arpack_string_view {") { + t.Error("Missing arpack_string_view") + } + + // Check bytes view + if !strings.Contains(headerStr, "typedef struct arpack_bytes_view {") { + t.Error("Missing arpack_bytes_view") + } + + // Check standard includes + if !strings.Contains(headerStr, "#include ") { + t.Error("Missing stdint.h include") + } + if !strings.Contains(headerStr, "#include ") { + t.Error("Missing stddef.h include") + } + if !strings.Contains(headerStr, "#include ") { + t.Error("Missing stdbool.h include") + } +} + +func TestCGenerateSchema_NestedMessages(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "Inner", + Fields: []parser.Field{ + {Name: "Value", Kind: parser.KindPrimitive, Primitive: parser.KindInt32}, + }, + }, + { + Name: "Outer", + Fields: []parser.Field{ + {Name: "InnerMsg", Kind: parser.KindNested, TypeName: "Inner"}, + }, + }, + }, + } + + header, _, err := GenerateCSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateCSchema failed: %v", err) + } + + headerStr := string(header) + + // Check inner struct + if !strings.Contains(headerStr, "typedef struct test_inner {") { + t.Error("Missing test_inner struct") + } + + // Check outer struct with nested field + if !strings.Contains(headerStr, "test_inner inner_msg;") { + t.Error("Missing nested inner_msg field") + } +} + +func TestCGenerateSchema_FixedArrays(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "ArrayMessage", + Fields: []parser.Field{ + { + Name: "Values", + Kind: parser.KindFixedArray, + FixedLen: 3, + Elem: &parser.Field{ + Kind: parser.KindPrimitive, + Primitive: parser.KindFloat32, + }, + }, + }, + }, + }, + } + + header, _, err := GenerateCSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateCSchema failed: %v", err) + } + + headerStr := string(header) + + if !strings.Contains(headerStr, "float values[3];") { + t.Error("Missing fixed array field") + } +} + +func TestCGenerateSchema_BoolPacking(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "BoolMessage", + Fields: []parser.Field{ + {Name: "Active", Kind: parser.KindPrimitive, Primitive: parser.KindBool}, + {Name: "Visible", Kind: parser.KindPrimitive, Primitive: parser.KindBool}, + {Name: "Ghost", Kind: parser.KindPrimitive, Primitive: parser.KindBool}, + {Name: "Count", Kind: parser.KindPrimitive, Primitive: parser.KindInt32}, + }, + }, + }, + } + + header, source, err := GenerateCSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateCSchema failed: %v", err) + } + + headerStr := string(header) + sourceStr := string(source) + + // Check struct fields + if !strings.Contains(headerStr, "bool active;") { + t.Error("Missing active field") + } + if !strings.Contains(headerStr, "bool visible;") { + t.Error("Missing visible field") + } + + // Check encode uses bool packing + if !strings.Contains(sourceStr, "_boolByte") { + t.Error("Bool packing not used in encode") + } + if !strings.Contains(sourceStr, "_arpack_write_u8(buf, &offset, _boolByte);") { + t.Error("Bool byte not written") + } +} + +func TestCGenerateSchema_QuantizedFloats(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "QuantMessage", + Fields: []parser.Field{ + { + Name: "Q8", + Kind: parser.KindPrimitive, + Primitive: parser.KindFloat32, + Quant: &parser.QuantInfo{Min: 0, Max: 100, Bits: 8}, + }, + { + Name: "Q16", + Kind: parser.KindPrimitive, + Primitive: parser.KindFloat32, + Quant: &parser.QuantInfo{Min: -500, Max: 500, Bits: 16}, + }, + }, + }, + }, + } + + _, source, err := GenerateCSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateCSchema failed: %v", err) + } + + sourceStr := string(source) + + // Check 8-bit quantization (uses 255 as max value) + if !strings.Contains(sourceStr, "255") { + t.Error("Missing 8-bit quantization") + } + + // Check 16-bit quantization (uses 65535 as max value) + if !strings.Contains(sourceStr, "65535") { + t.Error("Missing 16-bit quantization") + } + + // Check encode uses quantization + if !strings.Contains(sourceStr, "_arpack_write_u8(buf, &offset, _qv)") { + t.Error("8-bit quantized value not written") + } + if !strings.Contains(sourceStr, "_arpack_write_u16_le(buf, &offset, _qv)") { + t.Error("16-bit quantized value not written") + } +} + +func TestCGenerateSchema_VariableLength(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "VarMessage", + Fields: []parser.Field{ + {Name: "Id", Kind: parser.KindPrimitive, Primitive: parser.KindUint32}, + {Name: "Name", Kind: parser.KindPrimitive, Primitive: parser.KindString}, + { + Name: "Data", + Kind: parser.KindSlice, + Elem: &parser.Field{ + Kind: parser.KindPrimitive, + Primitive: parser.KindUint8, + }, + }, + }, + }, + }, + } + + header, _, err := GenerateCSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateCSchema failed: %v", err) + } + + headerStr := string(header) + + // Check min_size function exists + if !strings.Contains(headerStr, "size_t test_var_message_min_size(void);") { + t.Error("Missing min_size function declaration") + } + + // Check size function exists + if !strings.Contains(headerStr, "arpack_status test_var_message_size(const test_var_message *msg, size_t *out_size);") { + t.Error("Missing size function declaration") + } + + // Check encode function exists + if !strings.Contains(headerStr, "arpack_status test_var_message_encode(const test_var_message *msg, uint8_t *buf, size_t buf_len, size_t *out_written);") { + t.Error("Missing encode function declaration") + } + + // Check decode function exists without context (only byte slices and strings) + if !strings.Contains(headerStr, "arpack_status test_var_message_decode(test_var_message *msg, const uint8_t *buf, size_t buf_len, size_t *out_read);") { + t.Error("Missing decode function declaration") + } +} + +func TestCCompile_SampleSchema(t *testing.T) { + hasCC := false + if _, err := exec.LookPath("cc"); err == nil { + hasCC = true + } else if _, err := exec.LookPath("gcc"); err == nil { + hasCC = true + } else if _, err := exec.LookPath("clang"); err == nil { + hasCC = true + } + + if !hasCC { + t.Skip("No C compiler found (tried cc, gcc, clang)") + } + + schema, err := parser.ParseSchemaFile("../testdata/sample.go") + if err != nil { + t.Fatalf("Failed to parse sample.go: %v", err) + } + + // Generate C code + header, source, err := GenerateCSchema(schema, "sample") + if err != nil { + t.Fatalf("GenerateCSchema failed: %v", err) + } + + // Create temp directory + tmpDir := t.TempDir() + + // Write header + headerPath := filepath.Join(tmpDir, "sample.gen.h") + if err := os.WriteFile(headerPath, header, 0644); err != nil { + t.Fatalf("Failed to write header: %v", err) + } + + // Write source + sourcePath := filepath.Join(tmpDir, "sample.gen.c") + if err := os.WriteFile(sourcePath, source, 0644); err != nil { + t.Fatalf("Failed to write source: %v", err) + } + + // Compile + objPath := filepath.Join(tmpDir, "sample.gen.o") + cmd := exec.Command("cc", "-std=c11", "-Wall", "-Wextra", "-Wno-unused-function", "-c", sourcePath, "-o", objPath) + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("C compilation failed:\n%s\n%s", string(output), err) + } + + // Verify object file exists + if _, err := os.Stat(objPath); os.IsNotExist(err) { + t.Fatal("Object file was not created") + } +} + +func TestCGenerateSchema_DecodeContext(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "Inner", + Fields: []parser.Field{ + {Name: "Value", Kind: parser.KindPrimitive, Primitive: parser.KindInt32}, + }, + }, + { + Name: "CtxMessage", + Fields: []parser.Field{ + { + Name: "Items", + Kind: parser.KindSlice, + Elem: &parser.Field{ + Kind: parser.KindNested, + TypeName: "Inner", + }, + }, + }, + }, + }, + } + + header, _, err := GenerateCSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateCSchema failed: %v", err) + } + + headerStr := string(header) + + // Check decode context struct exists + if !strings.Contains(headerStr, "typedef struct test_ctx_message_decode_ctx {") { + t.Error("Missing decode context struct") + } + + // Check data pointer in context + if !strings.Contains(headerStr, "test_inner *items_data;") { + t.Error("Missing items_data field in context") + } + + // Check capacity field in context + if !strings.Contains(headerStr, "uint16_t items_cap;") { + t.Error("Missing items_cap field in context") + } + + // Check decode function with context + if !strings.Contains(headerStr, "arpack_status test_ctx_message_decode(test_ctx_message *msg, const uint8_t *buf, size_t buf_len, test_ctx_message_decode_ctx *ctx, size_t *out_read);") { + t.Error("Missing decode function with context") + } +} + +func TestCGenerateSchema_PrimitiveSlices(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "SliceMessage", + Fields: []parser.Field{ + { + Name: "Values", + Kind: parser.KindSlice, + Elem: &parser.Field{ + Kind: parser.KindPrimitive, + Primitive: parser.KindUint16, + }, + }, + { + Name: "Floats", + Kind: parser.KindSlice, + Elem: &parser.Field{ + Kind: parser.KindPrimitive, + Primitive: parser.KindFloat32, + }, + }, + }, + }, + }, + } + + header, source, err := GenerateCSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateCSchema failed: %v", err) + } + + headerStr := string(header) + if !strings.Contains(headerStr, "typedef struct test_uint16_slice_view {") { + t.Fatal("Missing uint16 slice view typedef") + } + if !strings.Contains(headerStr, "const uint16_t *data;") { + t.Fatal("uint16 slice view should reference uint16_t") + } + if !strings.Contains(headerStr, "typedef struct test_float32_slice_view {") { + t.Fatal("Missing float32 slice view typedef") + } + if !strings.Contains(headerStr, "const float *data;") { + t.Fatal("float32 slice view should reference float") + } + if !strings.Contains(headerStr, "uint16_t *values_data;") { + t.Fatal("Missing decode context storage for uint16 slice") + } + if !strings.Contains(headerStr, "float *floats_data;") { + t.Fatal("Missing decode context storage for float32 slice") + } + + compileCGeneratedObject(t, "test", header, source) +} + +func TestCGenerateSchema_FixedArrayNestedAndQuantized(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "Vector3", + Fields: []parser.Field{ + { + Name: "X", + Kind: parser.KindPrimitive, + Primitive: parser.KindFloat32, + Quant: &parser.QuantInfo{Min: -500, Max: 500, Bits: 16}, + }, + }, + }, + { + Name: "ArrayMessage", + Fields: []parser.Field{ + { + Name: "Points", + Kind: parser.KindFixedArray, + FixedLen: 2, + Elem: &parser.Field{ + Kind: parser.KindNested, + TypeName: "Vector3", + }, + }, + { + Name: "Samples", + Kind: parser.KindFixedArray, + FixedLen: 3, + Elem: &parser.Field{ + Kind: parser.KindPrimitive, + Primitive: parser.KindFloat32, + Quant: &parser.QuantInfo{Min: 0, Max: 10, Bits: 8}, + }, + }, + }, + }, + }, + } + + header, source, err := GenerateCSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateCSchema failed: %v", err) + } + + sourceStr := string(source) + if !strings.Contains(sourceStr, "test_vector3_encode(&msg->points[_i0],") { + t.Fatal("Nested fixed array elements should call nested encode") + } + if !strings.Contains(sourceStr, "msg->samples[_i0]") { + t.Fatal("Quantized fixed array elements should be encoded through recursive element access") + } + + compileCGeneratedObject(t, "test", header, source) +} + +func TestCVariableLength_BoundsChecks(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "StringMessage", + Fields: []parser.Field{ + {Name: "Name", Kind: parser.KindPrimitive, Primitive: parser.KindString}, + }, + }, + }, + } + + header, source, err := GenerateCSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateCSchema failed: %v", err) + } + + harness := `#include +#include "test.gen.h" + +int main(void) { + static const uint8_t truncated[] = {0x03, 0x00, 'a'}; + test_string_message decoded; + size_t read = 0; + arpack_status status = test_string_message_decode(&decoded, truncated, sizeof(truncated), &read); + printf("DECODE=%d\n", (int)status); + + test_string_message encoded; + encoded.name.data = "abc"; + encoded.name.len = 3; + uint8_t out[2]; + size_t written = 0; + status = test_string_message_encode(&encoded, out, sizeof(out), &written); + printf("ENCODE=%d\n", (int)status); + return 0; +} +` + + output := runGeneratedCProgram(t, "test", header, source, harness) + if !strings.Contains(output, "DECODE=1") { + t.Fatalf("decode should fail with ARPACK_ERR_BUFFER_TOO_SHORT, got:\n%s", output) + } + if !strings.Contains(output, "ENCODE=1") { + t.Fatalf("encode should fail with ARPACK_ERR_BUFFER_TOO_SHORT, got:\n%s", output) + } +} + +func TestCGenerateSchema_RejectsFixedArrayOfSlices(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "BadMessage", + Fields: []parser.Field{ + { + Name: "Values", + Kind: parser.KindFixedArray, + FixedLen: 2, + Elem: &parser.Field{ + Kind: parser.KindSlice, + Elem: &parser.Field{ + Kind: parser.KindPrimitive, + Primitive: parser.KindUint16, + }, + }, + }, + }, + }, + }, + } + + _, _, err := GenerateCSchema(schema, "test") + if err == nil { + t.Fatal("expected GenerateCSchema to reject fixed arrays of slices") + } + if !strings.Contains(err.Error(), "fixed arrays of slices") { + t.Fatalf("unexpected error: %v", err) + } +} + +func requireCCompiler(t *testing.T) string { + t.Helper() + + for _, compiler := range []string{"cc", "gcc", "clang"} { + if _, err := exec.LookPath(compiler); err == nil { + return compiler + } + } + + t.Skip("No C compiler found (tried cc, gcc, clang)") + return "" +} + +func compileCGeneratedObject(t *testing.T, base string, header []byte, source []byte) { + t.Helper() + + cc := requireCCompiler(t) + tmpDir := t.TempDir() + headerPath := filepath.Join(tmpDir, base+".gen.h") + sourcePath := filepath.Join(tmpDir, base+".gen.c") + objPath := filepath.Join(tmpDir, base+".gen.o") + + if err := os.WriteFile(headerPath, header, 0644); err != nil { + t.Fatalf("Failed to write header: %v", err) + } + if err := os.WriteFile(sourcePath, source, 0644); err != nil { + t.Fatalf("Failed to write source: %v", err) + } + + cmd := exec.Command(cc, "-std=c11", "-Wall", "-Wextra", "-Wno-unused-function", "-c", sourcePath, "-o", objPath) + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("C compilation failed:\n%s\n%s", string(output), err) + } +} + +func runGeneratedCProgram(t *testing.T, base string, header []byte, source []byte, harness string) string { + t.Helper() + + cc := requireCCompiler(t) + tmpDir := t.TempDir() + headerPath := filepath.Join(tmpDir, base+".gen.h") + sourcePath := filepath.Join(tmpDir, base+".gen.c") + testPath := filepath.Join(tmpDir, "test.c") + binPath := filepath.Join(tmpDir, "test") + + if err := os.WriteFile(headerPath, header, 0644); err != nil { + t.Fatalf("Failed to write header: %v", err) + } + if err := os.WriteFile(sourcePath, source, 0644); err != nil { + t.Fatalf("Failed to write source: %v", err) + } + if err := os.WriteFile(testPath, []byte(harness), 0644); err != nil { + t.Fatalf("Failed to write harness: %v", err) + } + + cmd := exec.Command(cc, "-std=c11", "-Wall", "-Wextra", "-Wno-unused-function", "-o", binPath, testPath, sourcePath) + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("C compilation failed:\n%s\n%s", string(output), err) + } + + runCmd := exec.Command(binPath) + runCmd.Dir = tmpDir + output, err = runCmd.CombinedOutput() + if err != nil { + t.Fatalf("C program failed:\n%s\n%s", string(output), err) + } + return string(output) +} diff --git a/generator/lua.go b/generator/lua.go index 158a07e..2cca660 100644 --- a/generator/lua.go +++ b/generator/lua.go @@ -69,6 +69,13 @@ func writeLuaHelpers(b *strings.Builder) { b.WriteString(" end\n") b.WriteString("end\n\n") + b.WriteString("local function ensure_u16_limit(n, context)\n") + b.WriteString(" if n < 0 or n > 65535 then\n") + b.WriteString(" error(string.format(\"arpack: %s exceeds uint16 limit: %d\", context, n))\n") + b.WriteString(" end\n") + b.WriteString(" return n\n") + b.WriteString("end\n\n") + b.WriteString("local function read_u8(data, offset)\n") b.WriteString(" if offset > #data then error(\"arpack: buffer too short for u8\") end\n") b.WriteString(" return string.byte(data, offset), 1\n") @@ -304,6 +311,7 @@ func writeLuaHelpers(b *strings.Builder) { b.WriteString("local function write_string(s)\n") b.WriteString(" local len = #s\n") + b.WriteString(" ensure_u16_limit(len, \"string length\")\n") b.WriteString(" return write_u16_le(len) .. s\n") b.WriteString("end\n\n") } @@ -430,6 +438,7 @@ func writeLuaSerializeField(b *strings.Builder, recv string, f parser.Field, ind case parser.KindSlice: lenVar := "_len_" + strings.ToLower(f.Name) fmt.Fprintf(b, "%slocal %s = #(%s or {})\n", indent, lenVar, access) + fmt.Fprintf(b, "%s%s = ensure_u16_limit(%s, %q)\n", indent, lenVar, lenVar, "slice length for "+luaFieldName(f.Name)) fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u16_le(%s)\n", indent, lenVar) iVar := "_i_" + strings.ToLower(f.Name) fmt.Fprintf(b, "%sfor %s = 1, %s do\n", indent, iVar, lenVar) @@ -497,7 +506,7 @@ func writeLuaSerializeQuant(b *strings.Builder, access string, f parser.Field, i q := f.Quant maxUint := q.MaxUint() varName := "_q_" + sanitizeLuaVarName(access) - fmt.Fprintf(b, "%slocal %s = math.floor(((%s - (%g)) / (%g - (%g))) * %g + 0.5)\n", + fmt.Fprintf(b, "%slocal %s = math.floor(((%s - (%g)) / (%g - (%g))) * %g)\n", indent, varName, access, q.Min, q.Max, q.Min, maxUint) if q.Bits == 8 { fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u8(%s)\n", indent, varName) diff --git a/generator/lua_test.go b/generator/lua_test.go index 737a852..312bcd5 100644 --- a/generator/lua_test.go +++ b/generator/lua_test.go @@ -281,8 +281,11 @@ func TestGenerateLua_QuantizedFloat(t *testing.T) { luaStr := string(lua) - if !strings.Contains(luaStr, "math.floor") { - t.Error("Missing math.floor for quantization") + if !strings.Contains(luaStr, "math.floor(((msg.position - (-500)) / (500 - (-500))) * 65535)") { + t.Error("Missing truncating quantization code for Lua") + } + if strings.Contains(luaStr, "math.floor(((msg.position - (-500)) / (500 - (-500))) * 65535 + 0.5)") { + t.Error("Lua quantization should not round to nearest") } if !strings.Contains(luaStr, "write_u16_le") { t.Error("Missing u16 write for 16-bit quantization") @@ -334,6 +337,7 @@ func TestLuaHelpersGenerated(t *testing.T) { "local bit = require('bit')", "buffer too short for u8", "buffer too short for bool", + "local function ensure_u16_limit(n, context)", "local function write_u8(n)", "buffer too short for u16", "local function write_u16_le(n)", @@ -453,6 +457,10 @@ func TestGenerateLua_BoundsChecks(t *testing.T) { t.Error("Missing check_bounds function") } + if !strings.Contains(luaStr, "ensure_u16_limit") { + t.Error("Missing uint16 overflow helper") + } + // Check that read_u16_le has bounds check if !strings.Contains(luaStr, "buffer too short for u16") { t.Error("Missing bounds check in read_u16_le") @@ -489,6 +497,146 @@ func TestGenerateLua_BoundsChecks(t *testing.T) { } } +func TestGenerateLua_LengthOverflowGuards(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "LengthLimited", + Fields: []parser.Field{ + {Name: "Name", Kind: parser.KindPrimitive, Primitive: parser.KindString}, + { + Name: "Items", + Kind: parser.KindSlice, + Elem: &parser.Field{ + Kind: parser.KindPrimitive, + Primitive: parser.KindUint8, + }, + }, + }, + }, + }, + } + + lua, err := GenerateLuaSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateLuaSchema failed: %v", err) + } + + luaStr := string(lua) + + if !strings.Contains(luaStr, `ensure_u16_limit(len, "string length")`) { + t.Error("Missing string length overflow guard") + } + + if !strings.Contains(luaStr, `ensure_u16_limit(_len_items, "slice length for items")`) { + t.Error("Missing slice length overflow guard") + } +} + +func TestGenerateLua_RuntimeLengthLimits(t *testing.T) { + if _, err := exec.LookPath("luajit"); err != nil { + t.Skip("luajit not found") + } + + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "LengthLimited", + Fields: []parser.Field{ + {Name: "Name", Kind: parser.KindPrimitive, Primitive: parser.KindString}, + { + Name: "Items", + Kind: parser.KindSlice, + Elem: &parser.Field{ + Kind: parser.KindPrimitive, + Primitive: parser.KindUint8, + }, + }, + }, + }, + }, + } + + lua, err := GenerateLuaSchema(schema, "messages") + if err != nil { + t.Fatalf("GenerateLuaSchema failed: %v", err) + } + + dir := t.TempDir() + modulePath := filepath.Join(dir, "messages_gen.lua") + if err := os.WriteFile(modulePath, lua, 0o600); err != nil { + t.Fatalf("write module: %v", err) + } + + scriptPath := filepath.Join(dir, "check.lua") + script := `local messages = require("messages_gen") + +local function emit(label, ok, value) + if ok then + print(label .. ":OK") + else + print(label .. ":" .. tostring(value)) + end +end + +local msg = messages.new_length_limited() + +local ok, res = pcall(messages.serialize_length_limited, msg) +emit("EMPTY", ok, res) + +msg.name = string.rep("a", 65535) +ok, res = pcall(messages.serialize_length_limited, msg) +emit("STR_MAX", ok, res) + +msg.name = string.rep("a", 65536) +ok, res = pcall(messages.serialize_length_limited, msg) +emit("STR_OVER", ok, res) + +msg.name = "" +msg.items = {} +for i = 1, 65535 do + msg.items[i] = 0 +end +ok, res = pcall(messages.serialize_length_limited, msg) +emit("SLICE_MAX", ok, res) + +msg.items[65536] = 0 +ok, res = pcall(messages.serialize_length_limited, msg) +emit("SLICE_OVER", ok, res) +` + if err := os.WriteFile(scriptPath, []byte(script), 0o600); err != nil { + t.Fatalf("write script: %v", err) + } + + cmd := exec.Command("luajit", "check.lua") + cmd.Dir = dir + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("luajit failed: %v\n%s", err, out) + } + + lines := strings.Split(strings.TrimSpace(string(out)), "\n") + if len(lines) != 5 { + t.Fatalf("expected 5 output lines, got %d: %q", len(lines), string(out)) + } + + if lines[0] != "EMPTY:OK" { + t.Fatalf("expected empty serialization to succeed, got %q", lines[0]) + } + if lines[1] != "STR_MAX:OK" { + t.Fatalf("expected 65535-byte string serialization to succeed, got %q", lines[1]) + } + if !strings.Contains(lines[2], "string length exceeds uint16 limit") { + t.Fatalf("expected string overflow guard, got %q", lines[2]) + } + if lines[3] != "SLICE_MAX:OK" { + t.Fatalf("expected 65535-element slice serialization to succeed, got %q", lines[3]) + } + if !strings.Contains(lines[4], "slice length for items exceeds uint16 limit") { + t.Fatalf("expected slice overflow guard, got %q", lines[4]) + } +} + func TestGenerateLua_RuntimeFloatEdgeCases(t *testing.T) { if _, err := exec.LookPath("luajit"); err != nil { t.Skip("luajit not found") diff --git a/generator/ts.go b/generator/ts.go index fcd5c79..5ca32f3 100644 --- a/generator/ts.go +++ b/generator/ts.go @@ -287,12 +287,12 @@ func writeTSSerializeQuant(b *strings.Builder, access string, f parser.Field, in maxUint := q.MaxUint() varName := "_q" + sanitizeVarName(access) if q.Bits == 8 { - fmt.Fprintf(b, "%sconst %s = Math.round((%s - (%g)) / (%g - (%g)) * %g);\n", + fmt.Fprintf(b, "%sconst %s = Math.trunc((%s - (%g)) / (%g - (%g)) * %g);\n", indent, varName, access, q.Min, q.Max, q.Min, maxUint) fmt.Fprintf(b, "%sview.setUint8(pos, %s);\n", indent, varName) fmt.Fprintf(b, "%spos += 1;\n", indent) } else { - fmt.Fprintf(b, "%sconst %s = Math.round((%s - (%g)) / (%g - (%g)) * %g);\n", + fmt.Fprintf(b, "%sconst %s = Math.trunc((%s - (%g)) / (%g - (%g)) * %g);\n", indent, varName, access, q.Min, q.Max, q.Min, maxUint) fmt.Fprintf(b, "%sview.setUint16(pos, %s, true);\n", indent, varName) fmt.Fprintf(b, "%spos += 2;\n", indent) @@ -305,12 +305,12 @@ func writeTSSerializeQuantElement(b *strings.Builder, access string, f parser.Fi maxUint := q.MaxUint() varName := "_q" + sanitizeVarName(access) if q.Bits == 8 { - fmt.Fprintf(b, "%sconst %s = Math.round((%s - (%g)) / (%g - (%g)) * %g);\n", + fmt.Fprintf(b, "%sconst %s = Math.trunc((%s - (%g)) / (%g - (%g)) * %g);\n", indent, varName, access, q.Min, q.Max, q.Min, maxUint) fmt.Fprintf(b, "%sview.setUint8(pos, %s);\n", indent, varName) fmt.Fprintf(b, "%spos += 1;\n", indent) } else { - fmt.Fprintf(b, "%sconst %s = Math.round((%s - (%g)) / (%g - (%g)) * %g);\n", + fmt.Fprintf(b, "%sconst %s = Math.trunc((%s - (%g)) / (%g - (%g)) * %g);\n", indent, varName, access, q.Min, q.Max, q.Min, maxUint) fmt.Fprintf(b, "%sview.setUint16(pos, %s, true);\n", indent, varName) fmt.Fprintf(b, "%spos += 2;\n", indent) diff --git a/generator/ts_test.go b/generator/ts_test.go index 1c578c7..366d5d0 100644 --- a/generator/ts_test.go +++ b/generator/ts_test.go @@ -98,12 +98,12 @@ func TestGenerateTypeScript_QuantizedFloats(t *testing.T) { code := string(src) // Check 8-bit quantization (using camelCase field names) - if !strings.Contains(code, "Math.round((this.q8 - (0)) / (100 - (0)) * 255)") { + if !strings.Contains(code, "Math.trunc((this.q8 - (0)) / (100 - (0)) * 255)") { t.Error("Missing 8-bit quantization code") } // Check 16-bit quantization (using camelCase field names) - if !strings.Contains(code, "Math.round((this.q16 - (-500)) / (500 - (-500)) * 65535)") { + if !strings.Contains(code, "Math.trunc((this.q16 - (-500)) / (500 - (-500)) * 65535)") { t.Error("Missing 16-bit quantization code") }