diff --git a/README.md b/README.md index 9f46458..ea8d0f0 100644 --- a/README.md +++ b/README.md @@ -9,16 +9,16 @@ ![GitHub License](https://img.shields.io/github/license/edmand46/arpack) -Binary serialization code generator for Go, C#, and TypeScript. Define messages once as Go structs — get zero-allocation `Marshal`/`Unmarshal` for Go, `unsafe` pointer-based `Serialize`/`Deserialize` for C#, and `DataView`-based serialization for TypeScript/browser. +Binary serialization code generator for Go, C#, TypeScript, and Lua. Define messages once as Go structs — get zero-allocation `Marshal`/`Unmarshal` for Go, `unsafe` pointer-based `Serialize`/`Deserialize` for C#, `DataView`-based serialization for TypeScript/browser, and pure Lua implementation for Defold/LuaJIT. ## Features -- **Single source of truth** — define messages in Go, generate code for Go, C#, and TypeScript +- **Single source of truth** — define messages in Go, generate code for Go, C#, TypeScript, and Lua - **Float quantization** — compress `float32`/`float64` to 8 or 16 bits with a `pack` struct tag - **Boolean packing** — consecutive `bool` fields are packed into single bytes (up to 8 per byte) - **Enums** — `type Opcode uint16` + `const` block becomes C#/TypeScript enums - **Nested types, fixed arrays, slices** — full support for complex message structures -- **Cross-language binary compatibility** — Go, C#, and TypeScript produce identical wire formats +- **Cross-language binary compatibility** — Go, C#, TypeScript, and Lua produce identical wire formats - **Browser support** — TypeScript target uses native DataView API for zero-dependency serialization ## When to use @@ -31,6 +31,7 @@ Typical setups: - **Custom Go game server + Unity** — roll your own server without pulling in a serialization framework. ArPack generates plain `Marshal`/`Unmarshal` methods with zero allocations on the hot path. - **Any Go service + .NET client** — works anywhere you control both ends and want a compact binary protocol without Protobuf's runtime overhead or code-gen complexity. - **Go backend + Browser/WebSocket** — generate TypeScript classes for browser-based clients. Uses native DataView API with zero dependencies. +- **Go backend + Defold/Lua** — generate Lua modules for Defold game engine. Pure Lua implementation compatible with LuaJIT. ## Installation @@ -46,6 +47,9 @@ arpack -in messages.go -out-go ./gen -out-cs ../Unity/Assets/Scripts -out-ts ./w # Generate only TypeScript arpack -in messages.go -out-ts ./web/src/messages + +# Generate only Lua (for Defold) +arpack -in messages.go -out-lua ./defold/scripts/messages ``` | Flag | Description | @@ -54,12 +58,14 @@ arpack -in messages.go -out-ts ./web/src/messages | `-out-go` | Output directory for generated Go code | | `-out-cs` | Output directory for generated C# code | | `-out-ts` | Output directory for generated TypeScript code | +| `-out-lua` | Output directory for generated Lua code | | `-cs-namespace` | C# namespace (default: `Arpack.Messages`) | **Output files:** - Go: `{name}_gen.go` - C#: `{Name}.gen.cs` - TypeScript: `{Name}.gen.ts` +- Lua: `{name}_gen.lua` (snake_case for Lua `require()` compatibility) ## Schema Definition @@ -98,16 +104,19 @@ type MoveMessage struct { ### Supported Types -| Type | Wire Size | -|---|---| -| `bool` (packed) | 1 bit (up to 8 per byte) | -| `int8`, `uint8` | 1 byte | -| `int16`, `uint16` | 2 bytes | -| `int32`, `uint32`, `float32` | 4 bytes | -| `int64`, `uint64`, `float64` | 8 bytes | -| `string` | 2-byte length prefix + UTF-8 | -| `[N]T` | N × sizeof(T) | -| `[]T` | 2-byte length prefix + N × sizeof(T) | +| Type | Wire Size | Lua Support | +|---|---|---| +| `bool` (packed) | 1 bit (up to 8 per byte) | ✓ (uses BitOp library) | +| `int8`, `uint8` | 1 byte | ✓ | +| `int16`, `uint16` | 2 bytes | ✓ | +| `int32`, `uint32`, `float32` | 4 bytes | ✓ | +| `int64`, `uint64` | 8 bytes | ✗ (LuaJIT limitation) | +| `float64` | 8 bytes | ✓ | +| `string` | 2-byte length prefix + UTF-8 | ✓ | +| `[N]T` | N × sizeof(T) | ✓ | +| `[]T` | 2-byte length prefix + N × sizeof(T) | ✓ | + +**Note:** `int64`/`uint64` are not supported in Lua target. LuaJIT (used by Defold) represents numbers as double-precision floats, which can only safely represent integers up to 2^53. Use `int32`/`uint32` instead. ### Float Quantization @@ -168,6 +177,31 @@ Uses native DataView API for browser-compatible serialization with zero dependen **Note:** TypeScript field names are converted to camelCase (e.g., `PlayerID` → `playerId`). +### Lua + +```lua +local messages = require("messages.messages_gen") + +-- Create message +local msg = messages.new_move_message() +msg.player_id = 123 +msg.active = true + +-- Serialize +local data = messages.serialize_move_message(msg) + +-- Deserialize +local decoded, bytes_read = messages.deserialize_move_message(data, 1) +``` + +Uses pure Lua with inline helper functions for byte manipulation. Compatible with LuaJIT (Defold). All identifiers use snake_case (e.g., `MoveMessage` → `move_message`, `PlayerID` → `player_id`). + +**Requirements:** The generated Lua code requires the [BitOp library](https://bitop.luajit.org/) for bit manipulation. This library is included in LuaJIT (used by Defold). + +**Limitations:** +- Lua target does not support `int64`/`uint64` types. Use `int32`/`uint32` instead. This is because LuaJIT represents numbers as double-precision floats, which can only safely represent integers up to 2^53. +- Generated file uses snake_case naming (e.g., `messages_gen.lua`) for proper Lua `require()` resolution. + ## Wire Format - Little-endian byte order diff --git a/cmd/arpack/main.go b/cmd/arpack/main.go index 70016e7..ff6cbbd 100644 --- a/cmd/arpack/main.go +++ b/cmd/arpack/main.go @@ -17,6 +17,7 @@ func main() { outGo := flag.String("out-go", "", "output directory for generated Go code") outCS := flag.String("out-cs", "", "output directory for generated C# code") outTS := flag.String("out-ts", "", "output directory for generated TypeScript code") + outLua := flag.String("out-lua", "", "output directory for generated Lua code") namespace := flag.String("cs-namespace", "Arpack.Messages", "C# namespace") flag.Parse() @@ -24,8 +25,8 @@ func main() { log.Fatal("arpack: -in is required") } - if *outGo == "" && *outCS == "" && *outTS == "" { - log.Fatal("arpack: at least one of -out-go, -out-cs, or -out-ts is required") + if *outGo == "" && *outCS == "" && *outTS == "" && *outLua == "" { + log.Fatal("arpack: at least one of -out-go, -out-cs, -out-ts, or -out-lua is required") } schema, err := parser.ParseSchemaFile(*in) @@ -96,8 +97,56 @@ func main() { fmt.Printf("arpack: wrote %s\n", outPath) } + + if *outLua != "" { + src, err := generator.GenerateLuaSchema(schema, baseName) + if err != nil { + log.Fatalf("arpack: Lua generation error: %v", err) + } + + // Use snake_case filename for Lua require() compatibility + outPath := filepath.Join(*outLua, toSnakeCase(baseName)+"_gen.lua") + if err := os.MkdirAll(*outLua, 0755); err != nil { + log.Fatalf("arpack: mkdir %s: %v", *outLua, err) + } + if err := os.WriteFile(outPath, src, 0644); err != nil { + log.Fatalf("arpack: write %s: %v", outPath, err) + } + + fmt.Printf("arpack: wrote %s\n", outPath) + } } func toTitle(s string) string { return strings.ToUpper(s[:1]) + strings.ToLower(s[1:]) } + +func toSnakeCase(s string) string { + if s == "" { + return "" + } + + var b strings.Builder + var prevUpper bool + + for i, c := range s { + isUpper := c >= 'A' && c <= 'Z' + + if i > 0 && isUpper { + nextLower := false + if i+1 < len(s) { + nextChar := rune(s[i+1]) + nextLower = nextChar >= 'a' && nextChar <= 'z' + } + + if !prevUpper || nextLower { + b.WriteByte('_') + } + } + + b.WriteRune(c) + prevUpper = isUpper + } + + return strings.ToLower(b.String()) +} diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 3d1bf53..ddfe9ba 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -86,6 +86,84 @@ func TestE2E_CrossLanguage(t *testing.T) { } else { t.Log("node not found, skipping TypeScript cross-language e2e tests") } + + if _, err := exec.LookPath("luajit"); err == nil { + // Use a simpler test schema without int64/uint64 for Lua + luaSchema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "Vector3", + Fields: []parser.Field{ + {Name: "X", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32, Quant: &parser.QuantInfo{Min: -500, Max: 500, Bits: 16}}, + {Name: "Y", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32, Quant: &parser.QuantInfo{Min: -500, Max: 500, Bits: 16}}, + {Name: "Z", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32, Quant: &parser.QuantInfo{Min: -500, Max: 500, Bits: 16}}, + }, + }, + { + Name: "MoveMessage", + Fields: []parser.Field{ + {Name: "Position", Kind: parser.KindNested, TypeName: "Vector3"}, + {Name: "Velocity", Kind: parser.KindFixedArray, FixedLen: 3, Elem: &parser.Field{Kind: parser.KindPrimitive, Primitive: parser.KindFloat32}}, + {Name: "Waypoints", Kind: parser.KindSlice, Elem: &parser.Field{Kind: parser.KindNested, TypeName: "Vector3"}}, + {Name: "PlayerID", Kind: parser.KindPrimitive, Primitive: parser.KindUint32}, + {Name: "Active", Kind: parser.KindPrimitive, Primitive: parser.KindBool}, + {Name: "Visible", Kind: parser.KindPrimitive, Primitive: parser.KindBool}, + {Name: "Ghost", Kind: parser.KindPrimitive, Primitive: parser.KindBool}, + {Name: "Name", Kind: parser.KindPrimitive, Primitive: parser.KindString}, + }, + }, + { + Name: "EnvelopeMessage", + Fields: []parser.Field{ + {Name: "Code", Kind: parser.KindPrimitive, Primitive: parser.KindUint16}, + {Name: "Counter", Kind: parser.KindPrimitive, Primitive: parser.KindUint8}, + }, + }, + }, + Enums: []parser.Enum{ + { + Name: "Opcode", + Primitive: parser.KindUint16, + Values: []parser.EnumValue{ + {Name: "Unknown", Value: "0"}, + {Name: "Join", Value: "1"}, + {Name: "Leave", Value: "2"}, + }, + }, + }, + } + + luaSrc, err := generator.GenerateLuaSchema(luaSchema, "messages") + if err != nil { + t.Fatalf("GenerateLuaSchema: %v", err) + } + luaDir := buildLuaHarness(t, luaSrc) + + luaCases := []struct { + name string + typ string + epsilon float64 + }{ + {"Vector3", "Vector3", 0.02}, + {"MoveMessage", "MoveMessage", 0.02}, + {"EnvelopeMessage", "EnvelopeMessage", 0}, + } + + for _, tc := range luaCases { + t.Run("Go_to_Lua/"+tc.name, func(t *testing.T) { + hex := runHarness(t, goDir, "go", "ser", tc.typ, "") + out := runHarness(t, luaDir, "lua", "deser", tc.typ, hex) + checkOutput(t, tc.typ, out, tc.epsilon) + }) + t.Run("Lua_to_Go/"+tc.name, func(t *testing.T) { + hex := runHarness(t, luaDir, "lua", "ser", tc.typ, "") + out := runHarness(t, goDir, "go", "deser", tc.typ, hex) + checkOutput(t, tc.typ, out, tc.epsilon) + }) + } + } else { + t.Log("luajit not found, skipping Lua cross-language e2e tests") + } } func buildGoHarness(t *testing.T, generatedSrc []byte) string { @@ -139,6 +217,16 @@ func buildTSHarness(t *testing.T, generatedSrc []byte) string { return dir } +func buildLuaHarness(t *testing.T, generatedSrc []byte) string { + t.Helper() + dir := t.TempDir() + + write(t, filepath.Join(dir, "messages_gen.lua"), generatedSrc) + write(t, filepath.Join(dir, "harness.lua"), []byte(luaHarnessSource)) + + return dir +} + func runHarness(t *testing.T, dir, lang, op, typ, hexInput string) string { t.Helper() var cmd *exec.Cmd @@ -161,6 +249,12 @@ func runHarness(t *testing.T, dir, lang, op, typ, hexInput string) string { args = append(args, hexInput) } cmd = exec.Command("node", append([]string{filepath.Join(dir, "dist", "harness.js")}, args...)...) + case "lua": + args := []string{filepath.Join(dir, "harness.lua"), op, typ} + if hexInput != "" { + args = append(args, hexInput) + } + cmd = exec.Command("luajit", args...) } cmd.Dir = dir out, err := cmd.CombinedOutput() @@ -740,3 +834,111 @@ function main() { main(); ` + +const luaHarnessSource = `-- Lua E2E Harness +-- Usage: luajit harness.lua [hex_input] +-- op: 'ser' or 'deser' +-- type: message type name + +local messages = require("messages_gen") + +local function hexToBytes(hex) + local bytes = {} + for i = 1, #hex, 2 do + local byte = tonumber(hex:sub(i, i+1), 16) + table.insert(bytes, string.char(byte)) + end + return table.concat(bytes) +end + +local function bytesToHex(data) + local hex = {} + for i = 1, #data do + table.insert(hex, string.format("%02x", string.byte(data, i))) + end + return table.concat(hex) +end + +local function serializeVector3() + local msg = messages.new_vector3() + msg.x = 123.45 + msg.y = -200.0 + msg.z = 0.0 + return bytesToHex(messages.serialize_vector3(msg)) +end + +local function deserializeVector3(hex) + local data = hexToBytes(hex) + local msg = messages.deserialize_vector3(data, 1) + print(string.format("X=%.10g", msg.x)) + print(string.format("Y=%.10g", msg.y)) + print(string.format("Z=%.10g", msg.z)) +end + +local function serializeMoveMessage() + local msg = messages.new_move_message() + msg.position = messages.new_vector3() + msg.position.x = 10.0 + msg.position.y = 20.0 + msg.position.z = 30.0 + msg.velocity = {1.0, 2.0, 3.0} + msg.waypoints = {} + local wp = messages.new_vector3() + wp.x = 10.0 + wp.y = 20.0 + wp.z = 0.0 + table.insert(msg.waypoints, wp) + msg.player_id = 777 + msg.active = true + msg.visible = false + msg.ghost = true + msg.name = "TestPlayer" + return bytesToHex(messages.serialize_move_message(msg)) +end + +local function deserializeMoveMessage(hex) + local data = hexToBytes(hex) + local msg = messages.deserialize_move_message(data, 1) + print(string.format("PlayerID=%d", msg.player_id)) + print(string.format("Active=%s", tostring(msg.active))) + print(string.format("Visible=%s", tostring(msg.visible))) + print(string.format("Ghost=%s", tostring(msg.ghost))) + print(string.format("Name=%s", msg.name)) +end + +local function serializeEnvelopeMessage() + local msg = messages.new_envelope_message() + msg.code = 2 -- Join + msg.counter = 7 + return bytesToHex(messages.serialize_envelope_message(msg)) +end + +local function deserializeEnvelopeMessage(hex) + local data = hexToBytes(hex) + local msg = messages.deserialize_envelope_message(data, 1) + print(string.format("Code=%d", msg.code)) + print(string.format("Counter=%d", msg.counter)) +end + +local op = arg[1] +local typ = arg[2] +local hexInput = arg[3] + +local key = op .. ":" .. typ + +if key == "ser:Vector3" then + print(serializeVector3()) +elseif key == "deser:Vector3" then + deserializeVector3(hexInput) +elseif key == "ser:MoveMessage" then + print(serializeMoveMessage()) +elseif key == "deser:MoveMessage" then + deserializeMoveMessage(hexInput) +elseif key == "ser:EnvelopeMessage" then + print(serializeEnvelopeMessage()) +elseif key == "deser:EnvelopeMessage" then + deserializeEnvelopeMessage(hexInput) +else + error("Unknown op:type " .. key) +end +` diff --git a/generator/lua.go b/generator/lua.go new file mode 100644 index 0000000..53322f5 --- /dev/null +++ b/generator/lua.go @@ -0,0 +1,776 @@ +package generator + +import ( + "fmt" + "strings" + + "github.com/edmand46/arpack/parser" +) + +func GenerateLuaSchema(schema parser.Schema, moduleName string) ([]byte, error) { + if err := checkLuaUnsupportedTypes(schema); err != nil { + return nil, err + } + + messages := schema.Messages + var b strings.Builder + + b.WriteString("-- arpack \n") + b.WriteString("-- Code generated by arpack. DO NOT EDIT.\n\n") + + b.WriteString("local M = {}\n\n") + b.WriteString("-- Load BitOp library for bit operations (Defold/LuaJIT)\n") + b.WriteString("local bit = require('bit')\n\n") + + writeLuaHelpers(&b) + + enumNames := make(map[string]struct{}, len(schema.Enums)) + for _, enum := range schema.Enums { + enumNames[enum.Name] = struct{}{} + } + + for _, enum := range schema.Enums { + writeLuaEnum(&b, enum) + b.WriteString("\n") + } + + for _, msg := range messages { + writeLuaConstructor(&b, msg, enumNames) + b.WriteString("\n") + } + + for _, msg := range messages { + if err := writeLuaSerializer(&b, msg, enumNames); err != nil { + return nil, fmt.Errorf("message %s: %w", msg.Name, err) + } + b.WriteString("\n") + } + + for _, msg := range messages { + if err := writeLuaDeserializer(&b, msg, enumNames); err != nil { + return nil, fmt.Errorf("message %s: %w", msg.Name, err) + } + b.WriteString("\n") + } + + b.WriteString("return M\n") + + return []byte(b.String()), nil +} + +func writeLuaHelpers(b *strings.Builder) { + b.WriteString("-- Inline helpers for little-endian byte operations\n\n") + + b.WriteString("-- Error handling for truncated data\n") + b.WriteString("local function check_bounds(data, offset, needed, context)\n") + b.WriteString(" local available = #data - offset + 1\n") + b.WriteString(" if available < needed then\n") + b.WriteString(" error(string.format(\"arpack: buffer too short for %s: need %d bytes, have %d\", context, needed, available))\n") + b.WriteString(" end\n") + b.WriteString("end\n\n") + + b.WriteString("local function read_u8(data, offset)\n") + b.WriteString(" if offset > #data then error(\"arpack: buffer too short for u8\") end\n") + b.WriteString(" return string.byte(data, offset), 1\n") + b.WriteString("end\n\n") + + b.WriteString("local function write_u8(n)\n") + b.WriteString(" return string.char(n)\n") + b.WriteString("end\n\n") + + b.WriteString("local function read_u16_le(data, offset)\n") + b.WriteString(" local b1, b2 = string.byte(data, offset, offset + 1)\n") + b.WriteString(" if not b2 then error(\"arpack: buffer too short for u16\") end\n") + b.WriteString(" return b1 + b2 * 256, 2\n") + b.WriteString("end\n\n") + + b.WriteString("local function write_u16_le(n)\n") + b.WriteString(" return string.char(n % 256, math.floor(n / 256))\n") + b.WriteString("end\n\n") + + b.WriteString("local function read_u32_le(data, offset)\n") + b.WriteString(" local b1, b2, b3, b4 = string.byte(data, offset, offset + 3)\n") + b.WriteString(" if not b4 then error(\"arpack: buffer too short for u32\") end\n") + b.WriteString(" return b1 + b2 * 256 + b3 * 65536 + b4 * 16777216, 4\n") + b.WriteString("end\n\n") + + b.WriteString("local function write_u32_le(n)\n") + b.WriteString(" return string.char(\n") + b.WriteString(" n % 256,\n") + b.WriteString(" math.floor(n / 256) % 256,\n") + b.WriteString(" math.floor(n / 65536) % 256,\n") + b.WriteString(" math.floor(n / 16777216) % 256\n") + b.WriteString(" )\n") + b.WriteString("end\n\n") + + b.WriteString("local function read_i8(data, offset)\n") + b.WriteString(" if offset > #data then error(\"arpack: buffer too short for i8\") end\n") + b.WriteString(" local v = string.byte(data, offset)\n") + b.WriteString(" if v >= 128 then v = v - 256 end\n") + b.WriteString(" return v, 1\n") + b.WriteString("end\n\n") + + b.WriteString("local function write_i8(n)\n") + b.WriteString(" if n < 0 then n = n + 256 end\n") + b.WriteString(" return string.char(n)\n") + b.WriteString("end\n\n") + + b.WriteString("local function read_i16_le(data, offset)\n") + b.WriteString(" local v = read_u16_le(data, offset)\n") + b.WriteString(" if v >= 32768 then v = v - 65536 end\n") + b.WriteString(" return v, 2\n") + b.WriteString("end\n\n") + + b.WriteString("local function write_i16_le(n)\n") + b.WriteString(" if n < 0 then n = n + 65536 end\n") + b.WriteString(" return write_u16_le(n)\n") + b.WriteString("end\n\n") + + b.WriteString("local function read_i32_le(data, offset)\n") + b.WriteString(" local v = read_u32_le(data, offset)\n") + b.WriteString(" if v >= 2147483648 then v = v - 4294967296 end\n") + b.WriteString(" return v, 4\n") + b.WriteString("end\n\n") + + b.WriteString("local function write_i32_le(n)\n") + b.WriteString(" if n < 0 then n = n + 4294967296 end\n") + b.WriteString(" return write_u32_le(n)\n") + b.WriteString("end\n\n") + + b.WriteString("local function read_f32_le(data, offset)\n") + b.WriteString(" local u32 = read_u32_le(data, offset)\n") + b.WriteString(" if u32 == 0 then return 0.0, 4 end\n") + b.WriteString(" local sign = (u32 >= 2147483648) and -1 or 1\n") + b.WriteString(" if sign < 0 then u32 = u32 - 2147483648 end\n") + b.WriteString(" local exp = math.floor(u32 / 8388608) % 256\n") + b.WriteString(" local mant = u32 % 8388608\n") + b.WriteString(" if exp == 0 then\n") + b.WriteString(" if mant == 0 then\n") + b.WriteString(" return sign < 0 and (-1 / math.huge) or 0.0, 4\n") + b.WriteString(" end\n") + b.WriteString(" return sign * (mant / 8388608) * math.pow(2, -126), 4\n") + b.WriteString(" elseif exp == 255 then\n") + b.WriteString(" if mant == 0 then return sign * math.huge, 4\n") + b.WriteString(" else return 0.0 / 0.0, 4 end\n") + b.WriteString(" end\n") + b.WriteString(" return sign * (1 + mant / 8388608) * math.pow(2, exp - 127), 4\n") + b.WriteString("end\n\n") + + b.WriteString("local function write_f32_le(n)\n") + b.WriteString(" if n ~= n then return write_u32_le(2143289344) end\n") + b.WriteString(" if n == math.huge then return write_u32_le(2139095040) end\n") + b.WriteString(" if n == -math.huge then return write_u32_le(4286578688) end\n") + b.WriteString(" -- Check for negative zero: 1/-0.0 == -math.huge\n") + b.WriteString(" if n == 0 then\n") + b.WriteString(" if 1/n == -math.huge then return write_u32_le(2147483648) end\n") + b.WriteString(" return write_u32_le(0)\n") + b.WriteString(" end\n") + b.WriteString(" local sign = 0\n") + b.WriteString(" if n < 0 then sign = 2147483648; n = -n end\n") + b.WriteString(" local exp\n") + b.WriteString(" local mant\n") + b.WriteString(" if n < math.pow(2, -126) then\n") + b.WriteString(" mant = math.floor(n / math.pow(2, -149) + 0.5)\n") + b.WriteString(" if mant <= 0 then return write_u32_le(sign) end\n") + b.WriteString(" if mant >= 8388608 then\n") + b.WriteString(" exp = 1\n") + b.WriteString(" mant = 0\n") + b.WriteString(" else\n") + b.WriteString(" exp = 0\n") + b.WriteString(" end\n") + b.WriteString(" else\n") + b.WriteString(" exp = math.floor(math.log(n, 2))\n") + b.WriteString(" mant = math.floor((n / math.pow(2, exp) - 1) * 8388608 + 0.5)\n") + b.WriteString(" if mant >= 8388608 then\n") + b.WriteString(" exp = exp + 1\n") + b.WriteString(" mant = 0\n") + b.WriteString(" end\n") + b.WriteString(" exp = exp + 127\n") + b.WriteString(" if exp >= 255 then\n") + b.WriteString(" return write_u32_le(sign + 2139095040)\n") + b.WriteString(" end\n") + b.WriteString(" end\n") + b.WriteString(" return write_u32_le(sign + exp * 8388608 + mant)\n") + b.WriteString("end\n\n") + + b.WriteString("local function read_f64_le(data, offset)\n") + b.WriteString(" -- Read 8 bytes directly to avoid precision loss from 64-bit arithmetic\n") + b.WriteString(" local b1, b2, b3, b4, b5, b6, b7, b8 = string.byte(data, offset, offset + 7)\n") + b.WriteString(" if not b8 then error(\"arpack: buffer too short for f64\") end\n") + b.WriteString(" local low = b1 + b2 * 256 + b3 * 65536 + b4 * 16777216\n") + b.WriteString(" local high = b5 + b6 * 256 + b7 * 65536 + b8 * 16777216\n") + b.WriteString(" -- Decode IEEE 754 double from low/high parts separately\n") + b.WriteString(" if low == 0 and high == 0 then return 0.0, 8 end\n") + b.WriteString(" local sign = (high >= 2147483648) and -1 or 1\n") + b.WriteString(" if sign < 0 then high = high - 2147483648 end\n") + b.WriteString(" local exp = math.floor(high / 1048576) % 2048\n") + b.WriteString(" local high_mant = high % 1048576\n") + b.WriteString(" if exp == 0 then\n") + b.WriteString(" local mant = high_mant * 4294967296 + low\n") + b.WriteString(" if mant == 0 then\n") + b.WriteString(" return sign < 0 and (-1 / math.huge) or 0.0, 8\n") + b.WriteString(" end\n") + b.WriteString(" return sign * (mant / 4503599627370496) * math.pow(2, -1022), 8\n") + b.WriteString(" elseif exp == 2047 then\n") + b.WriteString(" if high_mant == 0 and low == 0 then return sign * math.huge, 8\n") + b.WriteString(" else return 0.0 / 0.0, 8 end\n") + b.WriteString(" end\n") + b.WriteString(" local mant = high_mant * 4294967296 + low\n") + b.WriteString(" return sign * (1 + mant / 4503599627370496) * math.pow(2, exp - 1023), 8\n") + b.WriteString("end\n\n") + + b.WriteString("local function write_f64_le(n)\n") + b.WriteString(" -- Handle special values\n") + b.WriteString(" if n ~= n then -- NaN\n") + b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 248, 127)\n") + b.WriteString(" end\n") + b.WriteString(" if n == math.huge then\n") + b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 240, 127)\n") + b.WriteString(" end\n") + b.WriteString(" if n == -math.huge then\n") + b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 240, 255)\n") + b.WriteString(" end\n") + b.WriteString(" -- Check for negative zero: 1/-0.0 == -math.huge\n") + b.WriteString(" if n == 0 then\n") + b.WriteString(" if 1/n == -math.huge then\n") + b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 0, 128)\n") + b.WriteString(" end\n") + b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 0, 0)\n") + b.WriteString(" end\n") + b.WriteString(" local sign = 0\n") + b.WriteString(" if n < 0 then sign = 2147483648; n = -n end\n") + b.WriteString(" local exp\n") + b.WriteString(" local mant\n") + b.WriteString(" if n < math.pow(2, -1022) then\n") + b.WriteString(" mant = math.floor(n / math.pow(2, -1074) + 0.5)\n") + b.WriteString(" if mant <= 0 then\n") + b.WriteString(" local high = sign\n") + b.WriteString(" return string.char(0, 0, 0, 0, high % 256, math.floor(high / 256) % 256, math.floor(high / 65536) % 256, math.floor(high / 16777216) % 256)\n") + b.WriteString(" end\n") + b.WriteString(" if mant >= 4503599627370496 then\n") + b.WriteString(" exp = 1\n") + b.WriteString(" mant = 0\n") + b.WriteString(" else\n") + b.WriteString(" exp = 0\n") + b.WriteString(" end\n") + b.WriteString(" else\n") + b.WriteString(" exp = math.floor(math.log(n, 2))\n") + b.WriteString(" mant = (n / math.pow(2, exp) - 1) * 4503599627370496\n") + b.WriteString(" mant = math.floor(mant + 0.5)\n") + b.WriteString(" if mant >= 4503599627370496 then\n") + b.WriteString(" exp = exp + 1\n") + b.WriteString(" mant = 0\n") + b.WriteString(" end\n") + b.WriteString(" exp = exp + 1023\n") + b.WriteString(" if exp >= 2047 then\n") + b.WriteString(" exp = 2047\n") + b.WriteString(" mant = 0\n") + b.WriteString(" end\n") + b.WriteString(" end\n") + b.WriteString(" local high_mant = math.floor(mant / 4294967296)\n") + b.WriteString(" local low = mant % 4294967296\n") + b.WriteString(" local high = sign + exp * 1048576 + high_mant\n") + b.WriteString(" return string.char(\n") + b.WriteString(" low % 256,\n") + b.WriteString(" math.floor(low / 256) % 256,\n") + b.WriteString(" math.floor(low / 65536) % 256,\n") + b.WriteString(" math.floor(low / 16777216) % 256,\n") + b.WriteString(" high % 256,\n") + b.WriteString(" math.floor(high / 256) % 256,\n") + b.WriteString(" math.floor(high / 65536) % 256,\n") + b.WriteString(" math.floor(high / 16777216) % 256\n") + b.WriteString(" )\n") + b.WriteString("end\n\n") + + b.WriteString("local function read_bool(data, offset)\n") + b.WriteString(" if offset > #data then error(\"arpack: buffer too short for bool\") end\n") + b.WriteString(" return string.byte(data, offset) ~= 0, 1\n") + b.WriteString("end\n\n") + + b.WriteString("local function write_bool(v)\n") + b.WriteString(" return string.char(v and 1 or 0)\n") + b.WriteString("end\n\n") + + b.WriteString("local function read_string(data, offset)\n") + b.WriteString(" local len, header_bytes = read_u16_le(data, offset)\n") + b.WriteString(" if len == 0 then return '', header_bytes end\n") + b.WriteString(" local available = #data - offset - header_bytes + 1\n") + b.WriteString(" if available < len then\n") + b.WriteString(" error(string.format(\"arpack: buffer too short for string: need %d bytes, have %d\", len, available))\n") + b.WriteString(" end\n") + b.WriteString(" -- string.sub is 1-based; data starts at offset + header_bytes\n") + b.WriteString(" return string.sub(data, offset + header_bytes, offset + header_bytes + len - 1), header_bytes + len\n") + b.WriteString("end\n\n") + + b.WriteString("local function write_string(s)\n") + b.WriteString(" local len = #s\n") + b.WriteString(" return write_u16_le(len) .. s\n") + b.WriteString("end\n\n") +} + +func writeLuaEnum(b *strings.Builder, enum parser.Enum) { + fmt.Fprintf(b, "M.%s = {\n", enum.Name) + for i, value := range enum.Values { + fmt.Fprintf(b, " %s = %s", value.Name, value.Value) + if i < len(enum.Values)-1 { + b.WriteString(",") + } + b.WriteString("\n") + } + b.WriteString("}\n") +} + +func writeLuaConstructor(b *strings.Builder, msg parser.Message, enumNames map[string]struct{}) { + fmt.Fprintf(b, "function M.new_%s()\n", toSnakeCase(msg.Name)) + b.WriteString(" return {\n") + for _, f := range msg.Fields { + defaultValue := luaDefaultValue(f, enumNames) + fmt.Fprintf(b, " %s = %s,\n", luaFieldName(f.Name), defaultValue) + } + b.WriteString(" }\n") + b.WriteString("end\n") +} + +func writeLuaSerializer(b *strings.Builder, msg parser.Message, enumNames map[string]struct{}) error { + segs := segmentFields(msg.Fields) + + fmt.Fprintf(b, "function M.serialize_%s(msg)\n", toSnakeCase(msg.Name)) + b.WriteString(" local parts = {}\n") + b.WriteString(" local part_idx = 0\n") + + for i, seg := range segs { + if seg.single != nil { + if err := writeLuaSerializeField(b, "msg", *seg.single, " ", enumNames); err != nil { + return err + } + } else { + writeLuaBoolGroupSerialize(b, "msg", seg.bools, i, " ") + } + } + + b.WriteString(" return table.concat(parts)\n") + b.WriteString("end\n") + return nil +} + +func writeLuaDeserializer(b *strings.Builder, msg parser.Message, enumNames map[string]struct{}) error { + segs := segmentFields(msg.Fields) + minSize := packedMinWireSize(msg.Fields) + + fmt.Fprintf(b, "function M.deserialize_%s(data, offset)\n", toSnakeCase(msg.Name)) + b.WriteString(" offset = offset or 1\n") + fmt.Fprintf(b, " local msg = M.new_%s()\n", toSnakeCase(msg.Name)) + b.WriteString(" local start_offset = offset\n") + b.WriteString(" local bytes_read = 0\n") + fmt.Fprintf(b, " if #data < offset + %d - 1 then\n", minSize) + fmt.Fprintf(b, " error(\"arpack: buffer too short for %s\")\n", msg.Name) + b.WriteString(" end\n") + + for i, seg := range segs { + if seg.single != nil { + if err := writeLuaDeserializeField(b, "msg", *seg.single, " ", enumNames); err != nil { + return err + } + } else { + writeLuaBoolGroupDeserialize(b, "msg", seg.bools, i, " ") + } + } + + b.WriteString(" return msg, offset - start_offset\n") + b.WriteString("end\n") + return nil +} + +func writeLuaBoolGroupSerialize(b *strings.Builder, recv string, bools []parser.Field, groupIdx int, indent string) { + varName := fmt.Sprintf("_bool_byte_%d", groupIdx) + fmt.Fprintf(b, "%slocal %s = 0\n", indent, varName) + for bit, f := range bools { + fmt.Fprintf(b, "%sif %s.%s then %s = bit.bor(%s, %d) end\n", + indent, recv, luaFieldName(f.Name), varName, varName, 1<= 'A' && c <= 'Z' + + // Add underscore before uppercase letter if: + // - It's not the first character + // - Previous character was lowercase, OR + // - Previous character was uppercase AND next character (if exists) is lowercase + // (this handles cases like "PlayerID" -> "player_id", not "player_i_d") + if i > 0 && isUpper { + nextLower := false + if i+1 < len(s) { + nextChar := rune(s[i+1]) + nextLower = nextChar >= 'a' && nextChar <= 'z' + } + + if !prevUpper || nextLower { + b.WriteByte('_') + } + } + + b.WriteRune(c) + prevUpper = isUpper + } + + return strings.ToLower(b.String()) +} + +func luaDefaultValue(f parser.Field, enumNames map[string]struct{}) string { + switch f.Kind { + case parser.KindPrimitive: + if luaIsEnumType(f, enumNames) { + return "0" + } + switch f.Primitive { + case parser.KindFloat32, parser.KindFloat64, parser.KindInt8, parser.KindInt16, parser.KindInt32, + parser.KindUint8, parser.KindUint16, parser.KindUint32, parser.KindInt64, parser.KindUint64: + return "0" + case parser.KindBool: + return "false" + case parser.KindString: + return "''" + } + case parser.KindNested: + return fmt.Sprintf("M.new_%s()", toSnakeCase(f.TypeName)) + case parser.KindFixedArray, parser.KindSlice: + return "{}" + } + return "nil" +} + +func luaTypeName(f parser.Field, enumNames map[string]struct{}) string { + switch f.Kind { + case parser.KindPrimitive: + if luaIsEnumType(f, enumNames) { + return "number" + } + return "number" + case parser.KindNested: + return f.TypeName + case parser.KindFixedArray, parser.KindSlice: + return "table" + } + return "any" +} + +func luaSerializeValueExpr(access string, f parser.Field, enumNames map[string]struct{}) string { + if !luaIsEnumType(f, enumNames) { + return access + } + return access +} + +func luaDeserializeValueExpr(expr string, f parser.Field, enumNames map[string]struct{}) string { + if !luaIsEnumType(f, enumNames) { + return expr + } + return expr +} + +func luaIsEnumType(f parser.Field, enumNames map[string]struct{}) bool { + if f.NamedType == "" || enumNames == nil { + return false + } + _, ok := enumNames[f.NamedType] + return ok +} + +func sanitizeLuaVarName(s string) string { + var b strings.Builder + for _, c := range s { + if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') { + b.WriteRune(c) + } else { + b.WriteRune('_') + } + } + return strings.ToLower(b.String()) +} + +func checkLuaUnsupportedTypes(schema parser.Schema) error { + for _, msg := range schema.Messages { + for _, f := range msg.Fields { + if err := checkFieldLuaSupport(msg.Name, f); err != nil { + return err + } + } + } + return nil +} + +func checkFieldLuaSupport(msgName string, f parser.Field) error { + switch f.Kind { + case parser.KindPrimitive: + if f.Primitive == parser.KindInt64 || f.Primitive == parser.KindUint64 { + return fmt.Errorf("Lua target does not support int64/uint64: field %s in message %s (LuaJIT/Defold uses double-precision floats which can only safely represent integers up to 2^53)", f.Name, msgName) + } + case parser.KindFixedArray, parser.KindSlice: + if f.Elem != nil { + return checkFieldLuaSupport(msgName, *f.Elem) + } + } + return nil +} diff --git a/generator/lua_test.go b/generator/lua_test.go new file mode 100644 index 0000000..737a852 --- /dev/null +++ b/generator/lua_test.go @@ -0,0 +1,557 @@ +package generator + +import ( + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + + "github.com/edmand46/arpack/parser" +) + +func TestGenerateLua_BasicTypes(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "BasicTypes", + Fields: []parser.Field{ + {Name: "Int8Field", Kind: parser.KindPrimitive, Primitive: parser.KindInt8}, + {Name: "Int16Field", Kind: parser.KindPrimitive, Primitive: parser.KindInt16}, + {Name: "Int32Field", Kind: parser.KindPrimitive, Primitive: parser.KindInt32}, + {Name: "Uint8Field", Kind: parser.KindPrimitive, Primitive: parser.KindUint8}, + {Name: "Uint16Field", Kind: parser.KindPrimitive, Primitive: parser.KindUint16}, + {Name: "Uint32Field", Kind: parser.KindPrimitive, Primitive: parser.KindUint32}, + {Name: "Float32Field", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32}, + {Name: "Float64Field", Kind: parser.KindPrimitive, Primitive: parser.KindFloat64}, + {Name: "BoolField", Kind: parser.KindPrimitive, Primitive: parser.KindBool}, + {Name: "StringField", Kind: parser.KindPrimitive, Primitive: parser.KindString}, + }, + }, + }, + } + + lua, err := GenerateLuaSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateLuaSchema failed: %v", err) + } + + luaStr := string(lua) + + if !strings.Contains(luaStr, "function M.new_basic_types()") { + t.Error("Missing constructor for BasicTypes") + } + if !strings.Contains(luaStr, "function M.serialize_basic_types(msg)") { + t.Error("Missing serializer for BasicTypes") + } + if !strings.Contains(luaStr, "function M.deserialize_basic_types(data, offset)") { + t.Error("Missing deserializer for BasicTypes") + } + + if !strings.Contains(luaStr, "int8_field = 0") { + t.Error("Missing int8_field in constructor") + } + if !strings.Contains(luaStr, "string_field = ''") { + t.Error("Missing string_field default value") + } + if !strings.Contains(luaStr, "bool_field = false") { + t.Error("Missing bool_field default value") + } +} + +func TestGenerateLua_Enum(t *testing.T) { + schema := parser.Schema{ + Enums: []parser.Enum{ + { + Name: "Opcode", + Primitive: parser.KindUint16, + Values: []parser.EnumValue{ + {Name: "Unknown", Value: "0"}, + {Name: "Join", Value: "1"}, + {Name: "Leave", Value: "2"}, + }, + }, + }, + Messages: []parser.Message{ + { + Name: "MessageWithEnum", + Fields: []parser.Field{ + {Name: "Op", Kind: parser.KindPrimitive, Primitive: parser.KindUint16, NamedType: "Opcode"}, + }, + }, + }, + } + + enumNames := map[string]struct{}{"Opcode": {}} + _ = enumNames + + lua, err := GenerateLuaSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateLuaSchema failed: %v", err) + } + + luaStr := string(lua) + + if !strings.Contains(luaStr, "M.Opcode = {") { + t.Error("Missing Opcode enum table") + } + if !strings.Contains(luaStr, "Unknown = 0") { + t.Error("Missing Unknown enum value") + } + if !strings.Contains(luaStr, "Join = 1") { + t.Error("Missing Join enum value") + } +} + +func TestGenerateLua_NestedMessage(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "Vector3", + Fields: []parser.Field{ + {Name: "X", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32}, + {Name: "Y", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32}, + {Name: "Z", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32}, + }, + }, + { + Name: "Player", + Fields: []parser.Field{ + {Name: "Position", Kind: parser.KindNested, TypeName: "Vector3"}, + {Name: "Health", Kind: parser.KindPrimitive, Primitive: parser.KindInt32}, + }, + }, + }, + } + + lua, err := GenerateLuaSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateLuaSchema failed: %v", err) + } + + luaStr := string(lua) + + if !strings.Contains(luaStr, "function M.new_vector3()") { + t.Error("Missing constructor for Vector3") + } + if !strings.Contains(luaStr, "function M.new_player()") { + t.Error("Missing constructor for Player") + } + if !strings.Contains(luaStr, "position = M.new_vector3()") { + t.Error("Missing nested initialization in Player constructor") + } + if !strings.Contains(luaStr, "M.serialize_vector3") { + t.Error("Missing Vector3 serializer call") + } +} + +func TestGenerateLua_FixedArray(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "WithFixedArray", + Fields: []parser.Field{ + { + Name: "Values", + Kind: parser.KindFixedArray, + FixedLen: 3, + Elem: &parser.Field{ + Kind: parser.KindPrimitive, + Primitive: parser.KindFloat32, + }, + }, + }, + }, + }, + } + + lua, err := GenerateLuaSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateLuaSchema failed: %v", err) + } + + luaStr := string(lua) + + if !strings.Contains(luaStr, "values = {}") { + t.Error("Missing values array initialization") + } + if !strings.Contains(luaStr, "for _i_values = 1, 3 do") { + t.Error("Missing fixed array loop in serializer") + } +} + +func TestGenerateLua_Slice(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "WithSlice", + Fields: []parser.Field{ + { + Name: "Items", + Kind: parser.KindSlice, + Elem: &parser.Field{ + Kind: parser.KindPrimitive, + Primitive: parser.KindInt32, + }, + }, + }, + }, + }, + } + + lua, err := GenerateLuaSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateLuaSchema failed: %v", err) + } + + luaStr := string(lua) + + if !strings.Contains(luaStr, "items = {}") { + t.Error("Missing items slice initialization") + } + if !strings.Contains(luaStr, "local _len_items = #(msg.items or {})") { + t.Error("Missing slice length serialization") + } + if !strings.Contains(luaStr, "for _i_items = 1, _len_items do") { + t.Error("Missing slice iteration in serializer") + } +} + +func TestGenerateLua_BoolPacking(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "WithBools", + Fields: []parser.Field{ + {Name: "A", Kind: parser.KindPrimitive, Primitive: parser.KindBool}, + {Name: "B", Kind: parser.KindPrimitive, Primitive: parser.KindBool}, + {Name: "C", Kind: parser.KindPrimitive, Primitive: parser.KindBool}, + {Name: "Value", Kind: parser.KindPrimitive, Primitive: parser.KindInt32}, + }, + }, + }, + } + + lua, err := GenerateLuaSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateLuaSchema failed: %v", err) + } + + luaStr := string(lua) + + if !strings.Contains(luaStr, "local _bool_byte_0 = 0") { + t.Error("Missing bool byte packing variable") + } + if !strings.Contains(luaStr, "if msg.a then _bool_byte_0 = bit.bor(_bool_byte_0, 1) end") { + t.Error("Missing first bool packing check with bit.bor") + } + if !strings.Contains(luaStr, "if msg.b then _bool_byte_0 = bit.bor(_bool_byte_0, 2) end") { + t.Error("Missing second bool packing check with bit.bor") + } + if !strings.Contains(luaStr, "msg.a = bit.band(_bool_byte_0, 1) ~= 0") { + t.Error("Missing bit.band for bool deserialization") + } +} + +func TestGenerateLua_QuantizedFloat(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "WithQuantized", + Fields: []parser.Field{ + { + Name: "Position", + Kind: parser.KindPrimitive, + Primitive: parser.KindFloat32, + Quant: &parser.QuantInfo{ + Min: -500, + Max: 500, + Bits: 16, + }, + }, + }, + }, + }, + } + + lua, err := GenerateLuaSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateLuaSchema failed: %v", err) + } + + luaStr := string(lua) + + if !strings.Contains(luaStr, "math.floor") { + t.Error("Missing math.floor for quantization") + } + if !strings.Contains(luaStr, "write_u16_le") { + t.Error("Missing u16 write for 16-bit quantization") + } +} + +func TestToSnakeCase(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"", ""}, + {"A", "a"}, + {"AB", "ab"}, + {"AbCd", "ab_cd"}, + {"ABC", "abc"}, + {"PlayerID", "player_id"}, + {"HTTPResponse", "http_response"}, + {"XMLHttpRequest", "xml_http_request"}, + {"getHTTPResponse", "get_http_response"}, + } + + for _, tt := range tests { + result := toSnakeCase(tt.input) + if result != tt.expected { + t.Errorf("toSnakeCase(%q) = %q, want %q", tt.input, result, tt.expected) + } + } +} + +func TestLuaHelpersGenerated(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "Empty", + Fields: []parser.Field{}, + }, + }, + } + + lua, err := GenerateLuaSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateLuaSchema failed: %v", err) + } + + luaStr := string(lua) + + helpers := []string{ + "local bit = require('bit')", + "buffer too short for u8", + "buffer too short for bool", + "local function write_u8(n)", + "buffer too short for u16", + "local function write_u16_le(n)", + "buffer too short for u32", + "local function write_u32_le(n)", + "local function read_f32_le(data, offset)", + "local function write_f32_le(n)", + "local function read_f64_le(data, offset)", + "local function write_f64_le(n)", + "local function write_bool(v)", + "buffer too short for string", + "local function write_string(s)", + } + + for _, helper := range helpers { + if !strings.Contains(luaStr, helper) { + t.Errorf("Missing helper: %s", helper) + } + } +} + +func TestGenerateLua_Int64NotSupported(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "WithInt64", + Fields: []parser.Field{ + {Name: "Value", Kind: parser.KindPrimitive, Primitive: parser.KindInt64}, + }, + }, + }, + } + + _, err := GenerateLuaSchema(schema, "test") + if err == nil { + t.Fatal("Expected error for int64 field, got nil") + } + if !strings.Contains(err.Error(), "int64/uint64") { + t.Errorf("Expected error mentioning int64/uint64, got: %v", err) + } + if !strings.Contains(err.Error(), "LuaJIT/Defold") { + t.Errorf("Expected error mentioning LuaJIT/Defold, got: %v", err) + } +} + +func TestGenerateLua_Uint64NotSupported(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "WithUint64", + Fields: []parser.Field{ + {Name: "Value", Kind: parser.KindPrimitive, Primitive: parser.KindUint64}, + }, + }, + }, + } + + _, err := GenerateLuaSchema(schema, "test") + if err == nil { + t.Fatal("Expected error for uint64 field, got nil") + } + if !strings.Contains(err.Error(), "int64/uint64") { + t.Errorf("Expected error mentioning int64/uint64, got: %v", err) + } +} + +func TestGenerateLua_Int64InSliceNotSupported(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "WithInt64Slice", + Fields: []parser.Field{ + { + Name: "Values", + Kind: parser.KindSlice, + Elem: &parser.Field{ + Kind: parser.KindPrimitive, + Primitive: parser.KindInt64, + }, + }, + }, + }, + }, + } + + _, err := GenerateLuaSchema(schema, "test") + if err == nil { + t.Fatal("Expected error for int64 in slice, got nil") + } + if !strings.Contains(err.Error(), "int64/uint64") { + t.Errorf("Expected error mentioning int64/uint64, got: %v", err) + } +} + +func TestGenerateLua_BoundsChecks(t *testing.T) { + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "SimpleMessage", + Fields: []parser.Field{ + {Name: "ID", Kind: parser.KindPrimitive, Primitive: parser.KindUint32}, + {Name: "Name", Kind: parser.KindPrimitive, Primitive: parser.KindString}, + }, + }, + }, + } + + lua, err := GenerateLuaSchema(schema, "test") + if err != nil { + t.Fatalf("GenerateLuaSchema failed: %v", err) + } + + luaStr := string(lua) + + // Check that bounds check function exists + if !strings.Contains(luaStr, "check_bounds") { + t.Error("Missing check_bounds function") + } + + // Check that read_u16_le has bounds check + if !strings.Contains(luaStr, "buffer too short for u16") { + t.Error("Missing bounds check in read_u16_le") + } + + // Check that read_u32_le has bounds check + if !strings.Contains(luaStr, "buffer too short for u32") { + t.Error("Missing bounds check in read_u32_le") + } + + // Check that read_string has bounds check + if !strings.Contains(luaStr, "buffer too short for string") { + t.Error("Missing bounds check in read_string") + } + + // Check that deserialize function has min size check (message name is preserved in error) + if !strings.Contains(luaStr, "buffer too short for SimpleMessage") { + t.Error("Missing min size check in deserialize function") + } + + // Check that read_u8 has bounds check + if !strings.Contains(luaStr, "buffer too short for u8") { + t.Error("Missing bounds check in read_u8") + } + + // Check that read_bool has bounds check + if !strings.Contains(luaStr, "buffer too short for bool") { + t.Error("Missing bounds check in read_bool") + } + + // Check that read_i8 has bounds check + if !strings.Contains(luaStr, "buffer too short for i8") { + t.Error("Missing bounds check in read_i8") + } +} + +func TestGenerateLua_RuntimeFloatEdgeCases(t *testing.T) { + if _, err := exec.LookPath("luajit"); err != nil { + t.Skip("luajit not found") + } + + schema := parser.Schema{ + Messages: []parser.Message{ + { + Name: "FloatEdges", + Fields: []parser.Field{ + {Name: "F32", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32}, + {Name: "F64", Kind: parser.KindPrimitive, Primitive: parser.KindFloat64}, + }, + }, + }, + } + + lua, err := GenerateLuaSchema(schema, "messages") + if err != nil { + t.Fatalf("GenerateLuaSchema failed: %v", err) + } + + dir := t.TempDir() + modulePath := filepath.Join(dir, "messages_gen.lua") + if err := os.WriteFile(modulePath, lua, 0o600); err != nil { + t.Fatalf("write module: %v", err) + } + + scriptPath := filepath.Join(dir, "check.lua") + script := `local messages = require("messages_gen") + +local function bytes_to_hex(s) + return (s:gsub(".", function(c) return string.format("%02x", string.byte(c)) end)) +end + +local neg_zero = string.char(0, 0, 0, 128, 0, 0, 0, 0, 0, 0, 0, 128) +local msg = messages.deserialize_float_edges(neg_zero, 1) +print(bytes_to_hex(messages.serialize_float_edges(msg))) + +local subnormal = string.char(1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0) +msg = messages.deserialize_float_edges(subnormal, 1) +print(bytes_to_hex(messages.serialize_float_edges(msg))) +` + if err := os.WriteFile(scriptPath, []byte(script), 0o600); err != nil { + t.Fatalf("write script: %v", err) + } + + cmd := exec.Command("luajit", "check.lua") + cmd.Dir = dir + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("luajit failed: %v\n%s", err, out) + } + + lines := strings.Split(strings.TrimSpace(string(out)), "\n") + if len(lines) != 2 { + t.Fatalf("expected 2 output lines, got %d: %q", len(lines), string(out)) + } + + if lines[0] != "000000800000000000000080" { + t.Fatalf("negative zero roundtrip mismatch: %s", lines[0]) + } + if lines[1] != "010000000100000000000000" { + t.Fatalf("subnormal roundtrip mismatch: %s", lines[1]) + } +}