feat: added lua

This commit is contained in:
2026-03-25 13:02:08 +03:00
parent 57f3d9e976
commit cf2e095fbe
5 changed files with 1633 additions and 15 deletions
+47 -13
View File
@@ -9,16 +9,16 @@
![GitHub License](https://img.shields.io/github/license/edmand46/arpack) ![GitHub License](https://img.shields.io/github/license/edmand46/arpack)
Binary serialization code generator for Go, C#, and TypeScript. Define messages once as Go structs — get zero-allocation `Marshal`/`Unmarshal` for Go, `unsafe` pointer-based `Serialize`/`Deserialize` for C#, and `DataView`-based serialization for TypeScript/browser. Binary serialization code generator for Go, C#, TypeScript, and Lua. Define messages once as Go structs — get zero-allocation `Marshal`/`Unmarshal` for Go, `unsafe` pointer-based `Serialize`/`Deserialize` for C#, `DataView`-based serialization for TypeScript/browser, and pure Lua implementation for Defold/LuaJIT.
## Features ## Features
- **Single source of truth** — define messages in Go, generate code for Go, C#, and TypeScript - **Single source of truth** — define messages in Go, generate code for Go, C#, TypeScript, and Lua
- **Float quantization** — compress `float32`/`float64` to 8 or 16 bits with a `pack` struct tag - **Float quantization** — compress `float32`/`float64` to 8 or 16 bits with a `pack` struct tag
- **Boolean packing** — consecutive `bool` fields are packed into single bytes (up to 8 per byte) - **Boolean packing** — consecutive `bool` fields are packed into single bytes (up to 8 per byte)
- **Enums** — `type Opcode uint16` + `const` block becomes C#/TypeScript enums - **Enums** — `type Opcode uint16` + `const` block becomes C#/TypeScript enums
- **Nested types, fixed arrays, slices** — full support for complex message structures - **Nested types, fixed arrays, slices** — full support for complex message structures
- **Cross-language binary compatibility** — Go, C#, and TypeScript produce identical wire formats - **Cross-language binary compatibility** — Go, C#, TypeScript, and Lua produce identical wire formats
- **Browser support** — TypeScript target uses native DataView API for zero-dependency serialization - **Browser support** — TypeScript target uses native DataView API for zero-dependency serialization
## When to use ## When to use
@@ -31,6 +31,7 @@ Typical setups:
- **Custom Go game server + Unity** — roll your own server without pulling in a serialization framework. ArPack generates plain `Marshal`/`Unmarshal` methods with zero allocations on the hot path. - **Custom Go game server + Unity** — roll your own server without pulling in a serialization framework. ArPack generates plain `Marshal`/`Unmarshal` methods with zero allocations on the hot path.
- **Any Go service + .NET client** — works anywhere you control both ends and want a compact binary protocol without Protobuf's runtime overhead or code-gen complexity. - **Any Go service + .NET client** — works anywhere you control both ends and want a compact binary protocol without Protobuf's runtime overhead or code-gen complexity.
- **Go backend + Browser/WebSocket** — generate TypeScript classes for browser-based clients. Uses native DataView API with zero dependencies. - **Go backend + Browser/WebSocket** — generate TypeScript classes for browser-based clients. Uses native DataView API with zero dependencies.
- **Go backend + Defold/Lua** — generate Lua modules for Defold game engine. Pure Lua implementation compatible with LuaJIT.
## Installation ## Installation
@@ -46,6 +47,9 @@ arpack -in messages.go -out-go ./gen -out-cs ../Unity/Assets/Scripts -out-ts ./w
# Generate only TypeScript # Generate only TypeScript
arpack -in messages.go -out-ts ./web/src/messages arpack -in messages.go -out-ts ./web/src/messages
# Generate only Lua (for Defold)
arpack -in messages.go -out-lua ./defold/scripts/messages
``` ```
| Flag | Description | | Flag | Description |
@@ -54,12 +58,14 @@ arpack -in messages.go -out-ts ./web/src/messages
| `-out-go` | Output directory for generated Go code | | `-out-go` | Output directory for generated Go code |
| `-out-cs` | Output directory for generated C# code | | `-out-cs` | Output directory for generated C# code |
| `-out-ts` | Output directory for generated TypeScript code | | `-out-ts` | Output directory for generated TypeScript code |
| `-out-lua` | Output directory for generated Lua code |
| `-cs-namespace` | C# namespace (default: `Arpack.Messages`) | | `-cs-namespace` | C# namespace (default: `Arpack.Messages`) |
**Output files:** **Output files:**
- Go: `{name}_gen.go` - Go: `{name}_gen.go`
- C#: `{Name}.gen.cs` - C#: `{Name}.gen.cs`
- TypeScript: `{Name}.gen.ts` - TypeScript: `{Name}.gen.ts`
- Lua: `{name}_gen.lua` (snake_case for Lua `require()` compatibility)
## Schema Definition ## Schema Definition
@@ -98,16 +104,19 @@ type MoveMessage struct {
### Supported Types ### Supported Types
| Type | Wire Size | | Type | Wire Size | Lua Support |
|---|---| |---|---|---|
| `bool` (packed) | 1 bit (up to 8 per byte) | | `bool` (packed) | 1 bit (up to 8 per byte) | ✓ (uses BitOp library) |
| `int8`, `uint8` | 1 byte | | `int8`, `uint8` | 1 byte | ✓ |
| `int16`, `uint16` | 2 bytes | | `int16`, `uint16` | 2 bytes | ✓ |
| `int32`, `uint32`, `float32` | 4 bytes | | `int32`, `uint32`, `float32` | 4 bytes | ✓ |
| `int64`, `uint64`, `float64` | 8 bytes | | `int64`, `uint64` | 8 bytes | ✗ (LuaJIT limitation) |
| `string` | 2-byte length prefix + UTF-8 | | `float64` | 8 bytes | ✓ |
| `[N]T` | N × sizeof(T) | | `string` | 2-byte length prefix + UTF-8 | ✓ |
| `[]T` | 2-byte length prefix + N × sizeof(T) | | `[N]T` | N × sizeof(T) | ✓ |
| `[]T` | 2-byte length prefix + N × sizeof(T) | ✓ |
**Note:** `int64`/`uint64` are not supported in Lua target. LuaJIT (used by Defold) represents numbers as double-precision floats, which can only safely represent integers up to 2^53. Use `int32`/`uint32` instead.
### Float Quantization ### Float Quantization
@@ -168,6 +177,31 @@ Uses native DataView API for browser-compatible serialization with zero dependen
**Note:** TypeScript field names are converted to camelCase (e.g., `PlayerID``playerId`). **Note:** TypeScript field names are converted to camelCase (e.g., `PlayerID``playerId`).
### Lua
```lua
local messages = require("messages.messages_gen")
-- Create message
local msg = messages.new_move_message()
msg.player_id = 123
msg.active = true
-- Serialize
local data = messages.serialize_move_message(msg)
-- Deserialize
local decoded, bytes_read = messages.deserialize_move_message(data, 1)
```
Uses pure Lua with inline helper functions for byte manipulation. Compatible with LuaJIT (Defold). All identifiers use snake_case (e.g., `MoveMessage``move_message`, `PlayerID``player_id`).
**Requirements:** The generated Lua code requires the [BitOp library](https://bitop.luajit.org/) for bit manipulation. This library is included in LuaJIT (used by Defold).
**Limitations:**
- Lua target does not support `int64`/`uint64` types. Use `int32`/`uint32` instead. This is because LuaJIT represents numbers as double-precision floats, which can only safely represent integers up to 2^53.
- Generated file uses snake_case naming (e.g., `messages_gen.lua`) for proper Lua `require()` resolution.
## Wire Format ## Wire Format
- Little-endian byte order - Little-endian byte order
+51 -2
View File
@@ -17,6 +17,7 @@ func main() {
outGo := flag.String("out-go", "", "output directory for generated Go code") outGo := flag.String("out-go", "", "output directory for generated Go code")
outCS := flag.String("out-cs", "", "output directory for generated C# code") outCS := flag.String("out-cs", "", "output directory for generated C# code")
outTS := flag.String("out-ts", "", "output directory for generated TypeScript code") outTS := flag.String("out-ts", "", "output directory for generated TypeScript code")
outLua := flag.String("out-lua", "", "output directory for generated Lua code")
namespace := flag.String("cs-namespace", "Arpack.Messages", "C# namespace") namespace := flag.String("cs-namespace", "Arpack.Messages", "C# namespace")
flag.Parse() flag.Parse()
@@ -24,8 +25,8 @@ func main() {
log.Fatal("arpack: -in is required") log.Fatal("arpack: -in is required")
} }
if *outGo == "" && *outCS == "" && *outTS == "" { if *outGo == "" && *outCS == "" && *outTS == "" && *outLua == "" {
log.Fatal("arpack: at least one of -out-go, -out-cs, or -out-ts is required") log.Fatal("arpack: at least one of -out-go, -out-cs, -out-ts, or -out-lua is required")
} }
schema, err := parser.ParseSchemaFile(*in) schema, err := parser.ParseSchemaFile(*in)
@@ -96,8 +97,56 @@ func main() {
fmt.Printf("arpack: wrote %s\n", outPath) fmt.Printf("arpack: wrote %s\n", outPath)
} }
if *outLua != "" {
src, err := generator.GenerateLuaSchema(schema, baseName)
if err != nil {
log.Fatalf("arpack: Lua generation error: %v", err)
}
// Use snake_case filename for Lua require() compatibility
outPath := filepath.Join(*outLua, toSnakeCase(baseName)+"_gen.lua")
if err := os.MkdirAll(*outLua, 0755); err != nil {
log.Fatalf("arpack: mkdir %s: %v", *outLua, err)
}
if err := os.WriteFile(outPath, src, 0644); err != nil {
log.Fatalf("arpack: write %s: %v", outPath, err)
}
fmt.Printf("arpack: wrote %s\n", outPath)
}
} }
func toTitle(s string) string { func toTitle(s string) string {
return strings.ToUpper(s[:1]) + strings.ToLower(s[1:]) return strings.ToUpper(s[:1]) + strings.ToLower(s[1:])
} }
func toSnakeCase(s string) string {
if s == "" {
return ""
}
var b strings.Builder
var prevUpper bool
for i, c := range s {
isUpper := c >= 'A' && c <= 'Z'
if i > 0 && isUpper {
nextLower := false
if i+1 < len(s) {
nextChar := rune(s[i+1])
nextLower = nextChar >= 'a' && nextChar <= 'z'
}
if !prevUpper || nextLower {
b.WriteByte('_')
}
}
b.WriteRune(c)
prevUpper = isUpper
}
return strings.ToLower(b.String())
}
+202
View File
@@ -86,6 +86,84 @@ func TestE2E_CrossLanguage(t *testing.T) {
} else { } else {
t.Log("node not found, skipping TypeScript cross-language e2e tests") t.Log("node not found, skipping TypeScript cross-language e2e tests")
} }
if _, err := exec.LookPath("luajit"); err == nil {
// Use a simpler test schema without int64/uint64 for Lua
luaSchema := parser.Schema{
Messages: []parser.Message{
{
Name: "Vector3",
Fields: []parser.Field{
{Name: "X", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32, Quant: &parser.QuantInfo{Min: -500, Max: 500, Bits: 16}},
{Name: "Y", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32, Quant: &parser.QuantInfo{Min: -500, Max: 500, Bits: 16}},
{Name: "Z", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32, Quant: &parser.QuantInfo{Min: -500, Max: 500, Bits: 16}},
},
},
{
Name: "MoveMessage",
Fields: []parser.Field{
{Name: "Position", Kind: parser.KindNested, TypeName: "Vector3"},
{Name: "Velocity", Kind: parser.KindFixedArray, FixedLen: 3, Elem: &parser.Field{Kind: parser.KindPrimitive, Primitive: parser.KindFloat32}},
{Name: "Waypoints", Kind: parser.KindSlice, Elem: &parser.Field{Kind: parser.KindNested, TypeName: "Vector3"}},
{Name: "PlayerID", Kind: parser.KindPrimitive, Primitive: parser.KindUint32},
{Name: "Active", Kind: parser.KindPrimitive, Primitive: parser.KindBool},
{Name: "Visible", Kind: parser.KindPrimitive, Primitive: parser.KindBool},
{Name: "Ghost", Kind: parser.KindPrimitive, Primitive: parser.KindBool},
{Name: "Name", Kind: parser.KindPrimitive, Primitive: parser.KindString},
},
},
{
Name: "EnvelopeMessage",
Fields: []parser.Field{
{Name: "Code", Kind: parser.KindPrimitive, Primitive: parser.KindUint16},
{Name: "Counter", Kind: parser.KindPrimitive, Primitive: parser.KindUint8},
},
},
},
Enums: []parser.Enum{
{
Name: "Opcode",
Primitive: parser.KindUint16,
Values: []parser.EnumValue{
{Name: "Unknown", Value: "0"},
{Name: "Join", Value: "1"},
{Name: "Leave", Value: "2"},
},
},
},
}
luaSrc, err := generator.GenerateLuaSchema(luaSchema, "messages")
if err != nil {
t.Fatalf("GenerateLuaSchema: %v", err)
}
luaDir := buildLuaHarness(t, luaSrc)
luaCases := []struct {
name string
typ string
epsilon float64
}{
{"Vector3", "Vector3", 0.02},
{"MoveMessage", "MoveMessage", 0.02},
{"EnvelopeMessage", "EnvelopeMessage", 0},
}
for _, tc := range luaCases {
t.Run("Go_to_Lua/"+tc.name, func(t *testing.T) {
hex := runHarness(t, goDir, "go", "ser", tc.typ, "")
out := runHarness(t, luaDir, "lua", "deser", tc.typ, hex)
checkOutput(t, tc.typ, out, tc.epsilon)
})
t.Run("Lua_to_Go/"+tc.name, func(t *testing.T) {
hex := runHarness(t, luaDir, "lua", "ser", tc.typ, "")
out := runHarness(t, goDir, "go", "deser", tc.typ, hex)
checkOutput(t, tc.typ, out, tc.epsilon)
})
}
} else {
t.Log("luajit not found, skipping Lua cross-language e2e tests")
}
} }
func buildGoHarness(t *testing.T, generatedSrc []byte) string { func buildGoHarness(t *testing.T, generatedSrc []byte) string {
@@ -139,6 +217,16 @@ func buildTSHarness(t *testing.T, generatedSrc []byte) string {
return dir return dir
} }
func buildLuaHarness(t *testing.T, generatedSrc []byte) string {
t.Helper()
dir := t.TempDir()
write(t, filepath.Join(dir, "messages_gen.lua"), generatedSrc)
write(t, filepath.Join(dir, "harness.lua"), []byte(luaHarnessSource))
return dir
}
func runHarness(t *testing.T, dir, lang, op, typ, hexInput string) string { func runHarness(t *testing.T, dir, lang, op, typ, hexInput string) string {
t.Helper() t.Helper()
var cmd *exec.Cmd var cmd *exec.Cmd
@@ -161,6 +249,12 @@ func runHarness(t *testing.T, dir, lang, op, typ, hexInput string) string {
args = append(args, hexInput) args = append(args, hexInput)
} }
cmd = exec.Command("node", append([]string{filepath.Join(dir, "dist", "harness.js")}, args...)...) cmd = exec.Command("node", append([]string{filepath.Join(dir, "dist", "harness.js")}, args...)...)
case "lua":
args := []string{filepath.Join(dir, "harness.lua"), op, typ}
if hexInput != "" {
args = append(args, hexInput)
}
cmd = exec.Command("luajit", args...)
} }
cmd.Dir = dir cmd.Dir = dir
out, err := cmd.CombinedOutput() out, err := cmd.CombinedOutput()
@@ -740,3 +834,111 @@ function main() {
main(); main();
` `
const luaHarnessSource = `-- Lua E2E Harness
-- Usage: luajit harness.lua <op> <type> [hex_input]
-- op: 'ser' or 'deser'
-- type: message type name
local messages = require("messages_gen")
local function hexToBytes(hex)
local bytes = {}
for i = 1, #hex, 2 do
local byte = tonumber(hex:sub(i, i+1), 16)
table.insert(bytes, string.char(byte))
end
return table.concat(bytes)
end
local function bytesToHex(data)
local hex = {}
for i = 1, #data do
table.insert(hex, string.format("%02x", string.byte(data, i)))
end
return table.concat(hex)
end
local function serializeVector3()
local msg = messages.new_vector3()
msg.x = 123.45
msg.y = -200.0
msg.z = 0.0
return bytesToHex(messages.serialize_vector3(msg))
end
local function deserializeVector3(hex)
local data = hexToBytes(hex)
local msg = messages.deserialize_vector3(data, 1)
print(string.format("X=%.10g", msg.x))
print(string.format("Y=%.10g", msg.y))
print(string.format("Z=%.10g", msg.z))
end
local function serializeMoveMessage()
local msg = messages.new_move_message()
msg.position = messages.new_vector3()
msg.position.x = 10.0
msg.position.y = 20.0
msg.position.z = 30.0
msg.velocity = {1.0, 2.0, 3.0}
msg.waypoints = {}
local wp = messages.new_vector3()
wp.x = 10.0
wp.y = 20.0
wp.z = 0.0
table.insert(msg.waypoints, wp)
msg.player_id = 777
msg.active = true
msg.visible = false
msg.ghost = true
msg.name = "TestPlayer"
return bytesToHex(messages.serialize_move_message(msg))
end
local function deserializeMoveMessage(hex)
local data = hexToBytes(hex)
local msg = messages.deserialize_move_message(data, 1)
print(string.format("PlayerID=%d", msg.player_id))
print(string.format("Active=%s", tostring(msg.active)))
print(string.format("Visible=%s", tostring(msg.visible)))
print(string.format("Ghost=%s", tostring(msg.ghost)))
print(string.format("Name=%s", msg.name))
end
local function serializeEnvelopeMessage()
local msg = messages.new_envelope_message()
msg.code = 2 -- Join
msg.counter = 7
return bytesToHex(messages.serialize_envelope_message(msg))
end
local function deserializeEnvelopeMessage(hex)
local data = hexToBytes(hex)
local msg = messages.deserialize_envelope_message(data, 1)
print(string.format("Code=%d", msg.code))
print(string.format("Counter=%d", msg.counter))
end
local op = arg[1]
local typ = arg[2]
local hexInput = arg[3]
local key = op .. ":" .. typ
if key == "ser:Vector3" then
print(serializeVector3())
elseif key == "deser:Vector3" then
deserializeVector3(hexInput)
elseif key == "ser:MoveMessage" then
print(serializeMoveMessage())
elseif key == "deser:MoveMessage" then
deserializeMoveMessage(hexInput)
elseif key == "ser:EnvelopeMessage" then
print(serializeEnvelopeMessage())
elseif key == "deser:EnvelopeMessage" then
deserializeEnvelopeMessage(hexInput)
else
error("Unknown op:type " .. key)
end
`
+776
View File
@@ -0,0 +1,776 @@
package generator
import (
"fmt"
"strings"
"github.com/edmand46/arpack/parser"
)
func GenerateLuaSchema(schema parser.Schema, moduleName string) ([]byte, error) {
if err := checkLuaUnsupportedTypes(schema); err != nil {
return nil, err
}
messages := schema.Messages
var b strings.Builder
b.WriteString("-- <auto-generated> arpack </auto-generated>\n")
b.WriteString("-- Code generated by arpack. DO NOT EDIT.\n\n")
b.WriteString("local M = {}\n\n")
b.WriteString("-- Load BitOp library for bit operations (Defold/LuaJIT)\n")
b.WriteString("local bit = require('bit')\n\n")
writeLuaHelpers(&b)
enumNames := make(map[string]struct{}, len(schema.Enums))
for _, enum := range schema.Enums {
enumNames[enum.Name] = struct{}{}
}
for _, enum := range schema.Enums {
writeLuaEnum(&b, enum)
b.WriteString("\n")
}
for _, msg := range messages {
writeLuaConstructor(&b, msg, enumNames)
b.WriteString("\n")
}
for _, msg := range messages {
if err := writeLuaSerializer(&b, msg, enumNames); err != nil {
return nil, fmt.Errorf("message %s: %w", msg.Name, err)
}
b.WriteString("\n")
}
for _, msg := range messages {
if err := writeLuaDeserializer(&b, msg, enumNames); err != nil {
return nil, fmt.Errorf("message %s: %w", msg.Name, err)
}
b.WriteString("\n")
}
b.WriteString("return M\n")
return []byte(b.String()), nil
}
func writeLuaHelpers(b *strings.Builder) {
b.WriteString("-- Inline helpers for little-endian byte operations\n\n")
b.WriteString("-- Error handling for truncated data\n")
b.WriteString("local function check_bounds(data, offset, needed, context)\n")
b.WriteString(" local available = #data - offset + 1\n")
b.WriteString(" if available < needed then\n")
b.WriteString(" error(string.format(\"arpack: buffer too short for %s: need %d bytes, have %d\", context, needed, available))\n")
b.WriteString(" end\n")
b.WriteString("end\n\n")
b.WriteString("local function read_u8(data, offset)\n")
b.WriteString(" if offset > #data then error(\"arpack: buffer too short for u8\") end\n")
b.WriteString(" return string.byte(data, offset), 1\n")
b.WriteString("end\n\n")
b.WriteString("local function write_u8(n)\n")
b.WriteString(" return string.char(n)\n")
b.WriteString("end\n\n")
b.WriteString("local function read_u16_le(data, offset)\n")
b.WriteString(" local b1, b2 = string.byte(data, offset, offset + 1)\n")
b.WriteString(" if not b2 then error(\"arpack: buffer too short for u16\") end\n")
b.WriteString(" return b1 + b2 * 256, 2\n")
b.WriteString("end\n\n")
b.WriteString("local function write_u16_le(n)\n")
b.WriteString(" return string.char(n % 256, math.floor(n / 256))\n")
b.WriteString("end\n\n")
b.WriteString("local function read_u32_le(data, offset)\n")
b.WriteString(" local b1, b2, b3, b4 = string.byte(data, offset, offset + 3)\n")
b.WriteString(" if not b4 then error(\"arpack: buffer too short for u32\") end\n")
b.WriteString(" return b1 + b2 * 256 + b3 * 65536 + b4 * 16777216, 4\n")
b.WriteString("end\n\n")
b.WriteString("local function write_u32_le(n)\n")
b.WriteString(" return string.char(\n")
b.WriteString(" n % 256,\n")
b.WriteString(" math.floor(n / 256) % 256,\n")
b.WriteString(" math.floor(n / 65536) % 256,\n")
b.WriteString(" math.floor(n / 16777216) % 256\n")
b.WriteString(" )\n")
b.WriteString("end\n\n")
b.WriteString("local function read_i8(data, offset)\n")
b.WriteString(" if offset > #data then error(\"arpack: buffer too short for i8\") end\n")
b.WriteString(" local v = string.byte(data, offset)\n")
b.WriteString(" if v >= 128 then v = v - 256 end\n")
b.WriteString(" return v, 1\n")
b.WriteString("end\n\n")
b.WriteString("local function write_i8(n)\n")
b.WriteString(" if n < 0 then n = n + 256 end\n")
b.WriteString(" return string.char(n)\n")
b.WriteString("end\n\n")
b.WriteString("local function read_i16_le(data, offset)\n")
b.WriteString(" local v = read_u16_le(data, offset)\n")
b.WriteString(" if v >= 32768 then v = v - 65536 end\n")
b.WriteString(" return v, 2\n")
b.WriteString("end\n\n")
b.WriteString("local function write_i16_le(n)\n")
b.WriteString(" if n < 0 then n = n + 65536 end\n")
b.WriteString(" return write_u16_le(n)\n")
b.WriteString("end\n\n")
b.WriteString("local function read_i32_le(data, offset)\n")
b.WriteString(" local v = read_u32_le(data, offset)\n")
b.WriteString(" if v >= 2147483648 then v = v - 4294967296 end\n")
b.WriteString(" return v, 4\n")
b.WriteString("end\n\n")
b.WriteString("local function write_i32_le(n)\n")
b.WriteString(" if n < 0 then n = n + 4294967296 end\n")
b.WriteString(" return write_u32_le(n)\n")
b.WriteString("end\n\n")
b.WriteString("local function read_f32_le(data, offset)\n")
b.WriteString(" local u32 = read_u32_le(data, offset)\n")
b.WriteString(" if u32 == 0 then return 0.0, 4 end\n")
b.WriteString(" local sign = (u32 >= 2147483648) and -1 or 1\n")
b.WriteString(" if sign < 0 then u32 = u32 - 2147483648 end\n")
b.WriteString(" local exp = math.floor(u32 / 8388608) % 256\n")
b.WriteString(" local mant = u32 % 8388608\n")
b.WriteString(" if exp == 0 then\n")
b.WriteString(" if mant == 0 then\n")
b.WriteString(" return sign < 0 and (-1 / math.huge) or 0.0, 4\n")
b.WriteString(" end\n")
b.WriteString(" return sign * (mant / 8388608) * math.pow(2, -126), 4\n")
b.WriteString(" elseif exp == 255 then\n")
b.WriteString(" if mant == 0 then return sign * math.huge, 4\n")
b.WriteString(" else return 0.0 / 0.0, 4 end\n")
b.WriteString(" end\n")
b.WriteString(" return sign * (1 + mant / 8388608) * math.pow(2, exp - 127), 4\n")
b.WriteString("end\n\n")
b.WriteString("local function write_f32_le(n)\n")
b.WriteString(" if n ~= n then return write_u32_le(2143289344) end\n")
b.WriteString(" if n == math.huge then return write_u32_le(2139095040) end\n")
b.WriteString(" if n == -math.huge then return write_u32_le(4286578688) end\n")
b.WriteString(" -- Check for negative zero: 1/-0.0 == -math.huge\n")
b.WriteString(" if n == 0 then\n")
b.WriteString(" if 1/n == -math.huge then return write_u32_le(2147483648) end\n")
b.WriteString(" return write_u32_le(0)\n")
b.WriteString(" end\n")
b.WriteString(" local sign = 0\n")
b.WriteString(" if n < 0 then sign = 2147483648; n = -n end\n")
b.WriteString(" local exp\n")
b.WriteString(" local mant\n")
b.WriteString(" if n < math.pow(2, -126) then\n")
b.WriteString(" mant = math.floor(n / math.pow(2, -149) + 0.5)\n")
b.WriteString(" if mant <= 0 then return write_u32_le(sign) end\n")
b.WriteString(" if mant >= 8388608 then\n")
b.WriteString(" exp = 1\n")
b.WriteString(" mant = 0\n")
b.WriteString(" else\n")
b.WriteString(" exp = 0\n")
b.WriteString(" end\n")
b.WriteString(" else\n")
b.WriteString(" exp = math.floor(math.log(n, 2))\n")
b.WriteString(" mant = math.floor((n / math.pow(2, exp) - 1) * 8388608 + 0.5)\n")
b.WriteString(" if mant >= 8388608 then\n")
b.WriteString(" exp = exp + 1\n")
b.WriteString(" mant = 0\n")
b.WriteString(" end\n")
b.WriteString(" exp = exp + 127\n")
b.WriteString(" if exp >= 255 then\n")
b.WriteString(" return write_u32_le(sign + 2139095040)\n")
b.WriteString(" end\n")
b.WriteString(" end\n")
b.WriteString(" return write_u32_le(sign + exp * 8388608 + mant)\n")
b.WriteString("end\n\n")
b.WriteString("local function read_f64_le(data, offset)\n")
b.WriteString(" -- Read 8 bytes directly to avoid precision loss from 64-bit arithmetic\n")
b.WriteString(" local b1, b2, b3, b4, b5, b6, b7, b8 = string.byte(data, offset, offset + 7)\n")
b.WriteString(" if not b8 then error(\"arpack: buffer too short for f64\") end\n")
b.WriteString(" local low = b1 + b2 * 256 + b3 * 65536 + b4 * 16777216\n")
b.WriteString(" local high = b5 + b6 * 256 + b7 * 65536 + b8 * 16777216\n")
b.WriteString(" -- Decode IEEE 754 double from low/high parts separately\n")
b.WriteString(" if low == 0 and high == 0 then return 0.0, 8 end\n")
b.WriteString(" local sign = (high >= 2147483648) and -1 or 1\n")
b.WriteString(" if sign < 0 then high = high - 2147483648 end\n")
b.WriteString(" local exp = math.floor(high / 1048576) % 2048\n")
b.WriteString(" local high_mant = high % 1048576\n")
b.WriteString(" if exp == 0 then\n")
b.WriteString(" local mant = high_mant * 4294967296 + low\n")
b.WriteString(" if mant == 0 then\n")
b.WriteString(" return sign < 0 and (-1 / math.huge) or 0.0, 8\n")
b.WriteString(" end\n")
b.WriteString(" return sign * (mant / 4503599627370496) * math.pow(2, -1022), 8\n")
b.WriteString(" elseif exp == 2047 then\n")
b.WriteString(" if high_mant == 0 and low == 0 then return sign * math.huge, 8\n")
b.WriteString(" else return 0.0 / 0.0, 8 end\n")
b.WriteString(" end\n")
b.WriteString(" local mant = high_mant * 4294967296 + low\n")
b.WriteString(" return sign * (1 + mant / 4503599627370496) * math.pow(2, exp - 1023), 8\n")
b.WriteString("end\n\n")
b.WriteString("local function write_f64_le(n)\n")
b.WriteString(" -- Handle special values\n")
b.WriteString(" if n ~= n then -- NaN\n")
b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 248, 127)\n")
b.WriteString(" end\n")
b.WriteString(" if n == math.huge then\n")
b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 240, 127)\n")
b.WriteString(" end\n")
b.WriteString(" if n == -math.huge then\n")
b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 240, 255)\n")
b.WriteString(" end\n")
b.WriteString(" -- Check for negative zero: 1/-0.0 == -math.huge\n")
b.WriteString(" if n == 0 then\n")
b.WriteString(" if 1/n == -math.huge then\n")
b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 0, 128)\n")
b.WriteString(" end\n")
b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 0, 0)\n")
b.WriteString(" end\n")
b.WriteString(" local sign = 0\n")
b.WriteString(" if n < 0 then sign = 2147483648; n = -n end\n")
b.WriteString(" local exp\n")
b.WriteString(" local mant\n")
b.WriteString(" if n < math.pow(2, -1022) then\n")
b.WriteString(" mant = math.floor(n / math.pow(2, -1074) + 0.5)\n")
b.WriteString(" if mant <= 0 then\n")
b.WriteString(" local high = sign\n")
b.WriteString(" return string.char(0, 0, 0, 0, high % 256, math.floor(high / 256) % 256, math.floor(high / 65536) % 256, math.floor(high / 16777216) % 256)\n")
b.WriteString(" end\n")
b.WriteString(" if mant >= 4503599627370496 then\n")
b.WriteString(" exp = 1\n")
b.WriteString(" mant = 0\n")
b.WriteString(" else\n")
b.WriteString(" exp = 0\n")
b.WriteString(" end\n")
b.WriteString(" else\n")
b.WriteString(" exp = math.floor(math.log(n, 2))\n")
b.WriteString(" mant = (n / math.pow(2, exp) - 1) * 4503599627370496\n")
b.WriteString(" mant = math.floor(mant + 0.5)\n")
b.WriteString(" if mant >= 4503599627370496 then\n")
b.WriteString(" exp = exp + 1\n")
b.WriteString(" mant = 0\n")
b.WriteString(" end\n")
b.WriteString(" exp = exp + 1023\n")
b.WriteString(" if exp >= 2047 then\n")
b.WriteString(" exp = 2047\n")
b.WriteString(" mant = 0\n")
b.WriteString(" end\n")
b.WriteString(" end\n")
b.WriteString(" local high_mant = math.floor(mant / 4294967296)\n")
b.WriteString(" local low = mant % 4294967296\n")
b.WriteString(" local high = sign + exp * 1048576 + high_mant\n")
b.WriteString(" return string.char(\n")
b.WriteString(" low % 256,\n")
b.WriteString(" math.floor(low / 256) % 256,\n")
b.WriteString(" math.floor(low / 65536) % 256,\n")
b.WriteString(" math.floor(low / 16777216) % 256,\n")
b.WriteString(" high % 256,\n")
b.WriteString(" math.floor(high / 256) % 256,\n")
b.WriteString(" math.floor(high / 65536) % 256,\n")
b.WriteString(" math.floor(high / 16777216) % 256\n")
b.WriteString(" )\n")
b.WriteString("end\n\n")
b.WriteString("local function read_bool(data, offset)\n")
b.WriteString(" if offset > #data then error(\"arpack: buffer too short for bool\") end\n")
b.WriteString(" return string.byte(data, offset) ~= 0, 1\n")
b.WriteString("end\n\n")
b.WriteString("local function write_bool(v)\n")
b.WriteString(" return string.char(v and 1 or 0)\n")
b.WriteString("end\n\n")
b.WriteString("local function read_string(data, offset)\n")
b.WriteString(" local len, header_bytes = read_u16_le(data, offset)\n")
b.WriteString(" if len == 0 then return '', header_bytes end\n")
b.WriteString(" local available = #data - offset - header_bytes + 1\n")
b.WriteString(" if available < len then\n")
b.WriteString(" error(string.format(\"arpack: buffer too short for string: need %d bytes, have %d\", len, available))\n")
b.WriteString(" end\n")
b.WriteString(" -- string.sub is 1-based; data starts at offset + header_bytes\n")
b.WriteString(" return string.sub(data, offset + header_bytes, offset + header_bytes + len - 1), header_bytes + len\n")
b.WriteString("end\n\n")
b.WriteString("local function write_string(s)\n")
b.WriteString(" local len = #s\n")
b.WriteString(" return write_u16_le(len) .. s\n")
b.WriteString("end\n\n")
}
func writeLuaEnum(b *strings.Builder, enum parser.Enum) {
fmt.Fprintf(b, "M.%s = {\n", enum.Name)
for i, value := range enum.Values {
fmt.Fprintf(b, " %s = %s", value.Name, value.Value)
if i < len(enum.Values)-1 {
b.WriteString(",")
}
b.WriteString("\n")
}
b.WriteString("}\n")
}
func writeLuaConstructor(b *strings.Builder, msg parser.Message, enumNames map[string]struct{}) {
fmt.Fprintf(b, "function M.new_%s()\n", toSnakeCase(msg.Name))
b.WriteString(" return {\n")
for _, f := range msg.Fields {
defaultValue := luaDefaultValue(f, enumNames)
fmt.Fprintf(b, " %s = %s,\n", luaFieldName(f.Name), defaultValue)
}
b.WriteString(" }\n")
b.WriteString("end\n")
}
func writeLuaSerializer(b *strings.Builder, msg parser.Message, enumNames map[string]struct{}) error {
segs := segmentFields(msg.Fields)
fmt.Fprintf(b, "function M.serialize_%s(msg)\n", toSnakeCase(msg.Name))
b.WriteString(" local parts = {}\n")
b.WriteString(" local part_idx = 0\n")
for i, seg := range segs {
if seg.single != nil {
if err := writeLuaSerializeField(b, "msg", *seg.single, " ", enumNames); err != nil {
return err
}
} else {
writeLuaBoolGroupSerialize(b, "msg", seg.bools, i, " ")
}
}
b.WriteString(" return table.concat(parts)\n")
b.WriteString("end\n")
return nil
}
func writeLuaDeserializer(b *strings.Builder, msg parser.Message, enumNames map[string]struct{}) error {
segs := segmentFields(msg.Fields)
minSize := packedMinWireSize(msg.Fields)
fmt.Fprintf(b, "function M.deserialize_%s(data, offset)\n", toSnakeCase(msg.Name))
b.WriteString(" offset = offset or 1\n")
fmt.Fprintf(b, " local msg = M.new_%s()\n", toSnakeCase(msg.Name))
b.WriteString(" local start_offset = offset\n")
b.WriteString(" local bytes_read = 0\n")
fmt.Fprintf(b, " if #data < offset + %d - 1 then\n", minSize)
fmt.Fprintf(b, " error(\"arpack: buffer too short for %s\")\n", msg.Name)
b.WriteString(" end\n")
for i, seg := range segs {
if seg.single != nil {
if err := writeLuaDeserializeField(b, "msg", *seg.single, " ", enumNames); err != nil {
return err
}
} else {
writeLuaBoolGroupDeserialize(b, "msg", seg.bools, i, " ")
}
}
b.WriteString(" return msg, offset - start_offset\n")
b.WriteString("end\n")
return nil
}
func writeLuaBoolGroupSerialize(b *strings.Builder, recv string, bools []parser.Field, groupIdx int, indent string) {
varName := fmt.Sprintf("_bool_byte_%d", groupIdx)
fmt.Fprintf(b, "%slocal %s = 0\n", indent, varName)
for bit, f := range bools {
fmt.Fprintf(b, "%sif %s.%s then %s = bit.bor(%s, %d) end\n",
indent, recv, luaFieldName(f.Name), varName, varName, 1<<bit)
}
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u8(%s)\n", indent, varName)
}
func writeLuaBoolGroupDeserialize(b *strings.Builder, recv string, bools []parser.Field, groupIdx int, indent string) {
varName := fmt.Sprintf("_bool_byte_%d", groupIdx)
fmt.Fprintf(b, "%sif #data < offset then error(\"arpack: buffer too short for bool group\") end\n", indent)
fmt.Fprintf(b, "%slocal %s = string.byte(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%soffset = offset + 1\n", indent)
for bit, f := range bools {
fmt.Fprintf(b, "%s%s.%s = bit.band(%s, %d) ~= 0\n",
indent, recv, luaFieldName(f.Name), varName, 1<<bit)
}
}
func writeLuaSerializeField(b *strings.Builder, recv string, f parser.Field, indent string, enumNames map[string]struct{}) error {
access := recv + "." + luaFieldName(f.Name)
switch f.Kind {
case parser.KindPrimitive:
return writeLuaSerializePrimitive(b, access, f, indent, enumNames)
case parser.KindNested:
fmt.Fprintf(b, "%slocal _nested_%s = M.serialize_%s(%s)\n", indent, f.Name, toSnakeCase(f.TypeName), access)
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = _nested_%s\n", indent, f.Name)
case parser.KindFixedArray:
iVar := "_i_" + strings.ToLower(f.Name)
fmt.Fprintf(b, "%sfor %s = 1, %d do\n", indent, iVar, f.FixedLen)
elemField := parser.Field{
Name: f.Name + "[" + iVar + "]",
Kind: f.Elem.Kind,
Primitive: f.Elem.Primitive,
NamedType: f.Elem.NamedType,
Quant: f.Elem.Quant,
TypeName: f.Elem.TypeName,
Elem: f.Elem.Elem,
FixedLen: f.Elem.FixedLen,
}
if err := writeLuaSerializeField(b, recv, elemField, indent+" ", enumNames); err != nil {
return err
}
fmt.Fprintf(b, "%send\n", indent)
case parser.KindSlice:
lenVar := "_len_" + strings.ToLower(f.Name)
fmt.Fprintf(b, "%slocal %s = #(%s or {})\n", indent, lenVar, access)
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u16_le(%s)\n", indent, lenVar)
iVar := "_i_" + strings.ToLower(f.Name)
fmt.Fprintf(b, "%sfor %s = 1, %s do\n", indent, iVar, lenVar)
if f.Elem.Kind == parser.KindNested {
// For nested types in slices, serialize directly
fmt.Fprintf(b, "%slocal _nested_%s = M.serialize_%s(%s[%s])\n",
indent, f.Name, toSnakeCase(f.Elem.TypeName), access, iVar)
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = _nested_%s\n",
indent, f.Name)
} else {
elemField := parser.Field{
Name: f.Name + "[" + iVar + "]",
Kind: f.Elem.Kind,
Primitive: f.Elem.Primitive,
NamedType: f.Elem.NamedType,
Quant: f.Elem.Quant,
TypeName: f.Elem.TypeName,
Elem: f.Elem.Elem,
FixedLen: f.Elem.FixedLen,
}
if err := writeLuaSerializeField(b, recv, elemField, indent+" ", enumNames); err != nil {
return err
}
}
fmt.Fprintf(b, "%send\n", indent)
}
return nil
}
func writeLuaSerializePrimitive(b *strings.Builder, access string, f parser.Field, indent string, enumNames map[string]struct{}) error {
if f.Quant != nil {
return writeLuaSerializeQuant(b, access, f, indent)
}
valueExpr := luaSerializeValueExpr(access, f, enumNames)
switch f.Primitive {
case parser.KindFloat32:
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_f32_le(%s)\n", indent, valueExpr)
case parser.KindFloat64:
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_f64_le(%s)\n", indent, valueExpr)
case parser.KindInt8:
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_i8(%s)\n", indent, valueExpr)
case parser.KindUint8:
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u8(%s)\n", indent, valueExpr)
case parser.KindBool:
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_bool(%s)\n", indent, valueExpr)
case parser.KindInt16:
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_i16_le(%s)\n", indent, valueExpr)
case parser.KindUint16:
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u16_le(%s)\n", indent, valueExpr)
case parser.KindInt32:
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_i32_le(%s)\n", indent, valueExpr)
case parser.KindUint32:
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u32_le(%s)\n", indent, valueExpr)
case parser.KindInt64, parser.KindUint64:
return fmt.Errorf("int64/uint64 serialization not supported in Lua")
case parser.KindString:
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_string(%s or '')\n", indent, valueExpr)
}
return nil
}
func writeLuaSerializeQuant(b *strings.Builder, access string, f parser.Field, indent string) error {
q := f.Quant
maxUint := q.MaxUint()
varName := "_q_" + sanitizeLuaVarName(access)
fmt.Fprintf(b, "%slocal %s = math.floor(((%s - (%g)) / (%g - (%g))) * %g + 0.5)\n",
indent, varName, access, q.Min, q.Max, q.Min, maxUint)
if q.Bits == 8 {
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u8(%s)\n", indent, varName)
} else {
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u16_le(%s)\n", indent, varName)
}
return nil
}
func writeLuaDeserializeField(b *strings.Builder, recv string, f parser.Field, indent string, enumNames map[string]struct{}) error {
access := recv + "." + luaFieldName(f.Name)
switch f.Kind {
case parser.KindPrimitive:
return writeLuaDeserializePrimitive(b, access, f, indent, enumNames)
case parser.KindNested:
fmt.Fprintf(b, "%slocal _nested_%s, _n_%s = M.deserialize_%s(data, offset)\n", indent, f.Name, f.Name, toSnakeCase(f.TypeName))
fmt.Fprintf(b, "%s%s = _nested_%s\n", indent, access, f.Name)
fmt.Fprintf(b, "%soffset = offset + _n_%s\n", indent, f.Name)
case parser.KindFixedArray:
iVar := "_i_" + strings.ToLower(f.Name)
fmt.Fprintf(b, "%s%s = {}\n", indent, access)
fmt.Fprintf(b, "%sfor %s = 1, %d do\n", indent, iVar, f.FixedLen)
elemField := parser.Field{
Name: f.Name + "[" + iVar + "]",
Kind: f.Elem.Kind,
Primitive: f.Elem.Primitive,
NamedType: f.Elem.NamedType,
Quant: f.Elem.Quant,
TypeName: f.Elem.TypeName,
Elem: f.Elem.Elem,
FixedLen: f.Elem.FixedLen,
}
if err := writeLuaDeserializeField(b, recv, elemField, indent+" ", enumNames); err != nil {
return err
}
fmt.Fprintf(b, "%send\n", indent)
case parser.KindSlice:
lenVar := "_len_" + strings.ToLower(f.Name)
fmt.Fprintf(b, "%slocal %s = read_u16_le(data, offset)\n", indent, lenVar)
fmt.Fprintf(b, "%soffset = offset + 2\n", indent)
fmt.Fprintf(b, "%s%s = {}\n", indent, access)
iVar := "_i_" + strings.ToLower(f.Name)
fmt.Fprintf(b, "%sfor %s = 1, %s do\n", indent, iVar, lenVar)
if f.Elem.Kind == parser.KindNested {
// For nested types in slices, deserialize directly
fmt.Fprintf(b, "%slocal _nested_%s, _n_%s = M.deserialize_%s(data, offset)\n",
indent, f.Name, f.Name, toSnakeCase(f.Elem.TypeName))
fmt.Fprintf(b, "%s%s[%s] = _nested_%s\n",
indent, access, iVar, f.Name)
fmt.Fprintf(b, "%soffset = offset + _n_%s\n",
indent, f.Name)
} else {
elemField := parser.Field{
Name: f.Name + "[" + iVar + "]",
Kind: f.Elem.Kind,
Primitive: f.Elem.Primitive,
NamedType: f.Elem.NamedType,
Quant: f.Elem.Quant,
TypeName: f.Elem.TypeName,
Elem: f.Elem.Elem,
FixedLen: f.Elem.FixedLen,
}
if err := writeLuaDeserializeField(b, recv, elemField, indent+" ", enumNames); err != nil {
return err
}
}
fmt.Fprintf(b, "%send\n", indent)
}
return nil
}
func writeLuaDeserializePrimitive(b *strings.Builder, access string, f parser.Field, indent string, enumNames map[string]struct{}) error {
if f.Quant != nil {
return writeLuaDeserializeQuant(b, access, f, indent, enumNames)
}
varName := "_v_" + sanitizeLuaVarName(access)
switch f.Primitive {
case parser.KindFloat32:
fmt.Fprintf(b, "%slocal %s, _n = read_f32_le(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(varName, f, enumNames))
fmt.Fprintf(b, "%soffset = offset + _n\n", indent)
case parser.KindFloat64:
fmt.Fprintf(b, "%slocal %s, _n = read_f64_le(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(varName, f, enumNames))
fmt.Fprintf(b, "%soffset = offset + _n\n", indent)
case parser.KindInt8:
fmt.Fprintf(b, "%slocal %s, _n = read_i8(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(varName, f, enumNames))
fmt.Fprintf(b, "%soffset = offset + _n\n", indent)
case parser.KindUint8:
fmt.Fprintf(b, "%slocal %s, _n = read_u8(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(varName, f, enumNames))
fmt.Fprintf(b, "%soffset = offset + _n\n", indent)
case parser.KindBool:
fmt.Fprintf(b, "%slocal %s, _n = read_bool(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(varName, f, enumNames))
fmt.Fprintf(b, "%soffset = offset + _n\n", indent)
case parser.KindInt16:
fmt.Fprintf(b, "%slocal %s, _n = read_i16_le(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(varName, f, enumNames))
fmt.Fprintf(b, "%soffset = offset + _n\n", indent)
case parser.KindUint16:
fmt.Fprintf(b, "%slocal %s, _n = read_u16_le(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(varName, f, enumNames))
fmt.Fprintf(b, "%soffset = offset + _n\n", indent)
case parser.KindInt32:
fmt.Fprintf(b, "%slocal %s, _n = read_i32_le(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(varName, f, enumNames))
fmt.Fprintf(b, "%soffset = offset + _n\n", indent)
case parser.KindUint32:
fmt.Fprintf(b, "%slocal %s, _n = read_u32_le(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(varName, f, enumNames))
fmt.Fprintf(b, "%soffset = offset + _n\n", indent)
case parser.KindInt64, parser.KindUint64:
return fmt.Errorf("int64/uint64 deserialization not supported in Lua")
case parser.KindString:
fmt.Fprintf(b, "%slocal %s, _n = read_string(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(varName, f, enumNames))
fmt.Fprintf(b, "%soffset = offset + _n\n", indent)
}
return nil
}
func writeLuaDeserializeQuant(b *strings.Builder, access string, f parser.Field, indent string, enumNames map[string]struct{}) error {
q := f.Quant
maxUint := q.MaxUint()
varName := "_q_" + sanitizeLuaVarName(access)
if q.Bits == 8 {
fmt.Fprintf(b, "%slocal %s = read_u8(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%soffset = offset + 1\n", indent)
} else {
fmt.Fprintf(b, "%slocal %s = read_u16_le(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%soffset = offset + 2\n", indent)
}
expr := fmt.Sprintf("%s / %g * (%g - (%g)) + (%g)", varName, maxUint, q.Max, q.Min, q.Min)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(expr, f, enumNames))
return nil
}
func luaFieldName(name string) string {
return toSnakeCase(name)
}
func toSnakeCase(s string) string {
if s == "" {
return ""
}
var b strings.Builder
var prevUpper bool
for i, c := range s {
isUpper := c >= 'A' && c <= 'Z'
// Add underscore before uppercase letter if:
// - It's not the first character
// - Previous character was lowercase, OR
// - Previous character was uppercase AND next character (if exists) is lowercase
// (this handles cases like "PlayerID" -> "player_id", not "player_i_d")
if i > 0 && isUpper {
nextLower := false
if i+1 < len(s) {
nextChar := rune(s[i+1])
nextLower = nextChar >= 'a' && nextChar <= 'z'
}
if !prevUpper || nextLower {
b.WriteByte('_')
}
}
b.WriteRune(c)
prevUpper = isUpper
}
return strings.ToLower(b.String())
}
func luaDefaultValue(f parser.Field, enumNames map[string]struct{}) string {
switch f.Kind {
case parser.KindPrimitive:
if luaIsEnumType(f, enumNames) {
return "0"
}
switch f.Primitive {
case parser.KindFloat32, parser.KindFloat64, parser.KindInt8, parser.KindInt16, parser.KindInt32,
parser.KindUint8, parser.KindUint16, parser.KindUint32, parser.KindInt64, parser.KindUint64:
return "0"
case parser.KindBool:
return "false"
case parser.KindString:
return "''"
}
case parser.KindNested:
return fmt.Sprintf("M.new_%s()", toSnakeCase(f.TypeName))
case parser.KindFixedArray, parser.KindSlice:
return "{}"
}
return "nil"
}
func luaTypeName(f parser.Field, enumNames map[string]struct{}) string {
switch f.Kind {
case parser.KindPrimitive:
if luaIsEnumType(f, enumNames) {
return "number"
}
return "number"
case parser.KindNested:
return f.TypeName
case parser.KindFixedArray, parser.KindSlice:
return "table"
}
return "any"
}
func luaSerializeValueExpr(access string, f parser.Field, enumNames map[string]struct{}) string {
if !luaIsEnumType(f, enumNames) {
return access
}
return access
}
func luaDeserializeValueExpr(expr string, f parser.Field, enumNames map[string]struct{}) string {
if !luaIsEnumType(f, enumNames) {
return expr
}
return expr
}
func luaIsEnumType(f parser.Field, enumNames map[string]struct{}) bool {
if f.NamedType == "" || enumNames == nil {
return false
}
_, ok := enumNames[f.NamedType]
return ok
}
func sanitizeLuaVarName(s string) string {
var b strings.Builder
for _, c := range s {
if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') {
b.WriteRune(c)
} else {
b.WriteRune('_')
}
}
return strings.ToLower(b.String())
}
func checkLuaUnsupportedTypes(schema parser.Schema) error {
for _, msg := range schema.Messages {
for _, f := range msg.Fields {
if err := checkFieldLuaSupport(msg.Name, f); err != nil {
return err
}
}
}
return nil
}
func checkFieldLuaSupport(msgName string, f parser.Field) error {
switch f.Kind {
case parser.KindPrimitive:
if f.Primitive == parser.KindInt64 || f.Primitive == parser.KindUint64 {
return fmt.Errorf("Lua target does not support int64/uint64: field %s in message %s (LuaJIT/Defold uses double-precision floats which can only safely represent integers up to 2^53)", f.Name, msgName)
}
case parser.KindFixedArray, parser.KindSlice:
if f.Elem != nil {
return checkFieldLuaSupport(msgName, *f.Elem)
}
}
return nil
}
+557
View File
@@ -0,0 +1,557 @@
package generator
import (
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
"github.com/edmand46/arpack/parser"
)
func TestGenerateLua_BasicTypes(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "BasicTypes",
Fields: []parser.Field{
{Name: "Int8Field", Kind: parser.KindPrimitive, Primitive: parser.KindInt8},
{Name: "Int16Field", Kind: parser.KindPrimitive, Primitive: parser.KindInt16},
{Name: "Int32Field", Kind: parser.KindPrimitive, Primitive: parser.KindInt32},
{Name: "Uint8Field", Kind: parser.KindPrimitive, Primitive: parser.KindUint8},
{Name: "Uint16Field", Kind: parser.KindPrimitive, Primitive: parser.KindUint16},
{Name: "Uint32Field", Kind: parser.KindPrimitive, Primitive: parser.KindUint32},
{Name: "Float32Field", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32},
{Name: "Float64Field", Kind: parser.KindPrimitive, Primitive: parser.KindFloat64},
{Name: "BoolField", Kind: parser.KindPrimitive, Primitive: parser.KindBool},
{Name: "StringField", Kind: parser.KindPrimitive, Primitive: parser.KindString},
},
},
},
}
lua, err := GenerateLuaSchema(schema, "test")
if err != nil {
t.Fatalf("GenerateLuaSchema failed: %v", err)
}
luaStr := string(lua)
if !strings.Contains(luaStr, "function M.new_basic_types()") {
t.Error("Missing constructor for BasicTypes")
}
if !strings.Contains(luaStr, "function M.serialize_basic_types(msg)") {
t.Error("Missing serializer for BasicTypes")
}
if !strings.Contains(luaStr, "function M.deserialize_basic_types(data, offset)") {
t.Error("Missing deserializer for BasicTypes")
}
if !strings.Contains(luaStr, "int8_field = 0") {
t.Error("Missing int8_field in constructor")
}
if !strings.Contains(luaStr, "string_field = ''") {
t.Error("Missing string_field default value")
}
if !strings.Contains(luaStr, "bool_field = false") {
t.Error("Missing bool_field default value")
}
}
func TestGenerateLua_Enum(t *testing.T) {
schema := parser.Schema{
Enums: []parser.Enum{
{
Name: "Opcode",
Primitive: parser.KindUint16,
Values: []parser.EnumValue{
{Name: "Unknown", Value: "0"},
{Name: "Join", Value: "1"},
{Name: "Leave", Value: "2"},
},
},
},
Messages: []parser.Message{
{
Name: "MessageWithEnum",
Fields: []parser.Field{
{Name: "Op", Kind: parser.KindPrimitive, Primitive: parser.KindUint16, NamedType: "Opcode"},
},
},
},
}
enumNames := map[string]struct{}{"Opcode": {}}
_ = enumNames
lua, err := GenerateLuaSchema(schema, "test")
if err != nil {
t.Fatalf("GenerateLuaSchema failed: %v", err)
}
luaStr := string(lua)
if !strings.Contains(luaStr, "M.Opcode = {") {
t.Error("Missing Opcode enum table")
}
if !strings.Contains(luaStr, "Unknown = 0") {
t.Error("Missing Unknown enum value")
}
if !strings.Contains(luaStr, "Join = 1") {
t.Error("Missing Join enum value")
}
}
func TestGenerateLua_NestedMessage(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "Vector3",
Fields: []parser.Field{
{Name: "X", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32},
{Name: "Y", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32},
{Name: "Z", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32},
},
},
{
Name: "Player",
Fields: []parser.Field{
{Name: "Position", Kind: parser.KindNested, TypeName: "Vector3"},
{Name: "Health", Kind: parser.KindPrimitive, Primitive: parser.KindInt32},
},
},
},
}
lua, err := GenerateLuaSchema(schema, "test")
if err != nil {
t.Fatalf("GenerateLuaSchema failed: %v", err)
}
luaStr := string(lua)
if !strings.Contains(luaStr, "function M.new_vector3()") {
t.Error("Missing constructor for Vector3")
}
if !strings.Contains(luaStr, "function M.new_player()") {
t.Error("Missing constructor for Player")
}
if !strings.Contains(luaStr, "position = M.new_vector3()") {
t.Error("Missing nested initialization in Player constructor")
}
if !strings.Contains(luaStr, "M.serialize_vector3") {
t.Error("Missing Vector3 serializer call")
}
}
func TestGenerateLua_FixedArray(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "WithFixedArray",
Fields: []parser.Field{
{
Name: "Values",
Kind: parser.KindFixedArray,
FixedLen: 3,
Elem: &parser.Field{
Kind: parser.KindPrimitive,
Primitive: parser.KindFloat32,
},
},
},
},
},
}
lua, err := GenerateLuaSchema(schema, "test")
if err != nil {
t.Fatalf("GenerateLuaSchema failed: %v", err)
}
luaStr := string(lua)
if !strings.Contains(luaStr, "values = {}") {
t.Error("Missing values array initialization")
}
if !strings.Contains(luaStr, "for _i_values = 1, 3 do") {
t.Error("Missing fixed array loop in serializer")
}
}
func TestGenerateLua_Slice(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "WithSlice",
Fields: []parser.Field{
{
Name: "Items",
Kind: parser.KindSlice,
Elem: &parser.Field{
Kind: parser.KindPrimitive,
Primitive: parser.KindInt32,
},
},
},
},
},
}
lua, err := GenerateLuaSchema(schema, "test")
if err != nil {
t.Fatalf("GenerateLuaSchema failed: %v", err)
}
luaStr := string(lua)
if !strings.Contains(luaStr, "items = {}") {
t.Error("Missing items slice initialization")
}
if !strings.Contains(luaStr, "local _len_items = #(msg.items or {})") {
t.Error("Missing slice length serialization")
}
if !strings.Contains(luaStr, "for _i_items = 1, _len_items do") {
t.Error("Missing slice iteration in serializer")
}
}
func TestGenerateLua_BoolPacking(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "WithBools",
Fields: []parser.Field{
{Name: "A", Kind: parser.KindPrimitive, Primitive: parser.KindBool},
{Name: "B", Kind: parser.KindPrimitive, Primitive: parser.KindBool},
{Name: "C", Kind: parser.KindPrimitive, Primitive: parser.KindBool},
{Name: "Value", Kind: parser.KindPrimitive, Primitive: parser.KindInt32},
},
},
},
}
lua, err := GenerateLuaSchema(schema, "test")
if err != nil {
t.Fatalf("GenerateLuaSchema failed: %v", err)
}
luaStr := string(lua)
if !strings.Contains(luaStr, "local _bool_byte_0 = 0") {
t.Error("Missing bool byte packing variable")
}
if !strings.Contains(luaStr, "if msg.a then _bool_byte_0 = bit.bor(_bool_byte_0, 1) end") {
t.Error("Missing first bool packing check with bit.bor")
}
if !strings.Contains(luaStr, "if msg.b then _bool_byte_0 = bit.bor(_bool_byte_0, 2) end") {
t.Error("Missing second bool packing check with bit.bor")
}
if !strings.Contains(luaStr, "msg.a = bit.band(_bool_byte_0, 1) ~= 0") {
t.Error("Missing bit.band for bool deserialization")
}
}
func TestGenerateLua_QuantizedFloat(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "WithQuantized",
Fields: []parser.Field{
{
Name: "Position",
Kind: parser.KindPrimitive,
Primitive: parser.KindFloat32,
Quant: &parser.QuantInfo{
Min: -500,
Max: 500,
Bits: 16,
},
},
},
},
},
}
lua, err := GenerateLuaSchema(schema, "test")
if err != nil {
t.Fatalf("GenerateLuaSchema failed: %v", err)
}
luaStr := string(lua)
if !strings.Contains(luaStr, "math.floor") {
t.Error("Missing math.floor for quantization")
}
if !strings.Contains(luaStr, "write_u16_le") {
t.Error("Missing u16 write for 16-bit quantization")
}
}
func TestToSnakeCase(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"", ""},
{"A", "a"},
{"AB", "ab"},
{"AbCd", "ab_cd"},
{"ABC", "abc"},
{"PlayerID", "player_id"},
{"HTTPResponse", "http_response"},
{"XMLHttpRequest", "xml_http_request"},
{"getHTTPResponse", "get_http_response"},
}
for _, tt := range tests {
result := toSnakeCase(tt.input)
if result != tt.expected {
t.Errorf("toSnakeCase(%q) = %q, want %q", tt.input, result, tt.expected)
}
}
}
func TestLuaHelpersGenerated(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "Empty",
Fields: []parser.Field{},
},
},
}
lua, err := GenerateLuaSchema(schema, "test")
if err != nil {
t.Fatalf("GenerateLuaSchema failed: %v", err)
}
luaStr := string(lua)
helpers := []string{
"local bit = require('bit')",
"buffer too short for u8",
"buffer too short for bool",
"local function write_u8(n)",
"buffer too short for u16",
"local function write_u16_le(n)",
"buffer too short for u32",
"local function write_u32_le(n)",
"local function read_f32_le(data, offset)",
"local function write_f32_le(n)",
"local function read_f64_le(data, offset)",
"local function write_f64_le(n)",
"local function write_bool(v)",
"buffer too short for string",
"local function write_string(s)",
}
for _, helper := range helpers {
if !strings.Contains(luaStr, helper) {
t.Errorf("Missing helper: %s", helper)
}
}
}
func TestGenerateLua_Int64NotSupported(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "WithInt64",
Fields: []parser.Field{
{Name: "Value", Kind: parser.KindPrimitive, Primitive: parser.KindInt64},
},
},
},
}
_, err := GenerateLuaSchema(schema, "test")
if err == nil {
t.Fatal("Expected error for int64 field, got nil")
}
if !strings.Contains(err.Error(), "int64/uint64") {
t.Errorf("Expected error mentioning int64/uint64, got: %v", err)
}
if !strings.Contains(err.Error(), "LuaJIT/Defold") {
t.Errorf("Expected error mentioning LuaJIT/Defold, got: %v", err)
}
}
func TestGenerateLua_Uint64NotSupported(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "WithUint64",
Fields: []parser.Field{
{Name: "Value", Kind: parser.KindPrimitive, Primitive: parser.KindUint64},
},
},
},
}
_, err := GenerateLuaSchema(schema, "test")
if err == nil {
t.Fatal("Expected error for uint64 field, got nil")
}
if !strings.Contains(err.Error(), "int64/uint64") {
t.Errorf("Expected error mentioning int64/uint64, got: %v", err)
}
}
func TestGenerateLua_Int64InSliceNotSupported(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "WithInt64Slice",
Fields: []parser.Field{
{
Name: "Values",
Kind: parser.KindSlice,
Elem: &parser.Field{
Kind: parser.KindPrimitive,
Primitive: parser.KindInt64,
},
},
},
},
},
}
_, err := GenerateLuaSchema(schema, "test")
if err == nil {
t.Fatal("Expected error for int64 in slice, got nil")
}
if !strings.Contains(err.Error(), "int64/uint64") {
t.Errorf("Expected error mentioning int64/uint64, got: %v", err)
}
}
func TestGenerateLua_BoundsChecks(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "SimpleMessage",
Fields: []parser.Field{
{Name: "ID", Kind: parser.KindPrimitive, Primitive: parser.KindUint32},
{Name: "Name", Kind: parser.KindPrimitive, Primitive: parser.KindString},
},
},
},
}
lua, err := GenerateLuaSchema(schema, "test")
if err != nil {
t.Fatalf("GenerateLuaSchema failed: %v", err)
}
luaStr := string(lua)
// Check that bounds check function exists
if !strings.Contains(luaStr, "check_bounds") {
t.Error("Missing check_bounds function")
}
// Check that read_u16_le has bounds check
if !strings.Contains(luaStr, "buffer too short for u16") {
t.Error("Missing bounds check in read_u16_le")
}
// Check that read_u32_le has bounds check
if !strings.Contains(luaStr, "buffer too short for u32") {
t.Error("Missing bounds check in read_u32_le")
}
// Check that read_string has bounds check
if !strings.Contains(luaStr, "buffer too short for string") {
t.Error("Missing bounds check in read_string")
}
// Check that deserialize function has min size check (message name is preserved in error)
if !strings.Contains(luaStr, "buffer too short for SimpleMessage") {
t.Error("Missing min size check in deserialize function")
}
// Check that read_u8 has bounds check
if !strings.Contains(luaStr, "buffer too short for u8") {
t.Error("Missing bounds check in read_u8")
}
// Check that read_bool has bounds check
if !strings.Contains(luaStr, "buffer too short for bool") {
t.Error("Missing bounds check in read_bool")
}
// Check that read_i8 has bounds check
if !strings.Contains(luaStr, "buffer too short for i8") {
t.Error("Missing bounds check in read_i8")
}
}
func TestGenerateLua_RuntimeFloatEdgeCases(t *testing.T) {
if _, err := exec.LookPath("luajit"); err != nil {
t.Skip("luajit not found")
}
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "FloatEdges",
Fields: []parser.Field{
{Name: "F32", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32},
{Name: "F64", Kind: parser.KindPrimitive, Primitive: parser.KindFloat64},
},
},
},
}
lua, err := GenerateLuaSchema(schema, "messages")
if err != nil {
t.Fatalf("GenerateLuaSchema failed: %v", err)
}
dir := t.TempDir()
modulePath := filepath.Join(dir, "messages_gen.lua")
if err := os.WriteFile(modulePath, lua, 0o600); err != nil {
t.Fatalf("write module: %v", err)
}
scriptPath := filepath.Join(dir, "check.lua")
script := `local messages = require("messages_gen")
local function bytes_to_hex(s)
return (s:gsub(".", function(c) return string.format("%02x", string.byte(c)) end))
end
local neg_zero = string.char(0, 0, 0, 128, 0, 0, 0, 0, 0, 0, 0, 128)
local msg = messages.deserialize_float_edges(neg_zero, 1)
print(bytes_to_hex(messages.serialize_float_edges(msg)))
local subnormal = string.char(1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0)
msg = messages.deserialize_float_edges(subnormal, 1)
print(bytes_to_hex(messages.serialize_float_edges(msg)))
`
if err := os.WriteFile(scriptPath, []byte(script), 0o600); err != nil {
t.Fatalf("write script: %v", err)
}
cmd := exec.Command("luajit", "check.lua")
cmd.Dir = dir
out, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("luajit failed: %v\n%s", err, out)
}
lines := strings.Split(strings.TrimSpace(string(out)), "\n")
if len(lines) != 2 {
t.Fatalf("expected 2 output lines, got %d: %q", len(lines), string(out))
}
if lines[0] != "000000800000000000000080" {
t.Fatalf("negative zero roundtrip mismatch: %s", lines[0])
}
if lines[1] != "010000000100000000000000" {
t.Fatalf("subnormal roundtrip mismatch: %s", lines[1])
}
}