feat: added lua

This commit is contained in:
2026-03-25 13:02:08 +03:00
parent 57f3d9e976
commit cf2e095fbe
5 changed files with 1633 additions and 15 deletions
+776
View File
@@ -0,0 +1,776 @@
package generator
import (
"fmt"
"strings"
"github.com/edmand46/arpack/parser"
)
func GenerateLuaSchema(schema parser.Schema, moduleName string) ([]byte, error) {
if err := checkLuaUnsupportedTypes(schema); err != nil {
return nil, err
}
messages := schema.Messages
var b strings.Builder
b.WriteString("-- <auto-generated> arpack </auto-generated>\n")
b.WriteString("-- Code generated by arpack. DO NOT EDIT.\n\n")
b.WriteString("local M = {}\n\n")
b.WriteString("-- Load BitOp library for bit operations (Defold/LuaJIT)\n")
b.WriteString("local bit = require('bit')\n\n")
writeLuaHelpers(&b)
enumNames := make(map[string]struct{}, len(schema.Enums))
for _, enum := range schema.Enums {
enumNames[enum.Name] = struct{}{}
}
for _, enum := range schema.Enums {
writeLuaEnum(&b, enum)
b.WriteString("\n")
}
for _, msg := range messages {
writeLuaConstructor(&b, msg, enumNames)
b.WriteString("\n")
}
for _, msg := range messages {
if err := writeLuaSerializer(&b, msg, enumNames); err != nil {
return nil, fmt.Errorf("message %s: %w", msg.Name, err)
}
b.WriteString("\n")
}
for _, msg := range messages {
if err := writeLuaDeserializer(&b, msg, enumNames); err != nil {
return nil, fmt.Errorf("message %s: %w", msg.Name, err)
}
b.WriteString("\n")
}
b.WriteString("return M\n")
return []byte(b.String()), nil
}
func writeLuaHelpers(b *strings.Builder) {
b.WriteString("-- Inline helpers for little-endian byte operations\n\n")
b.WriteString("-- Error handling for truncated data\n")
b.WriteString("local function check_bounds(data, offset, needed, context)\n")
b.WriteString(" local available = #data - offset + 1\n")
b.WriteString(" if available < needed then\n")
b.WriteString(" error(string.format(\"arpack: buffer too short for %s: need %d bytes, have %d\", context, needed, available))\n")
b.WriteString(" end\n")
b.WriteString("end\n\n")
b.WriteString("local function read_u8(data, offset)\n")
b.WriteString(" if offset > #data then error(\"arpack: buffer too short for u8\") end\n")
b.WriteString(" return string.byte(data, offset), 1\n")
b.WriteString("end\n\n")
b.WriteString("local function write_u8(n)\n")
b.WriteString(" return string.char(n)\n")
b.WriteString("end\n\n")
b.WriteString("local function read_u16_le(data, offset)\n")
b.WriteString(" local b1, b2 = string.byte(data, offset, offset + 1)\n")
b.WriteString(" if not b2 then error(\"arpack: buffer too short for u16\") end\n")
b.WriteString(" return b1 + b2 * 256, 2\n")
b.WriteString("end\n\n")
b.WriteString("local function write_u16_le(n)\n")
b.WriteString(" return string.char(n % 256, math.floor(n / 256))\n")
b.WriteString("end\n\n")
b.WriteString("local function read_u32_le(data, offset)\n")
b.WriteString(" local b1, b2, b3, b4 = string.byte(data, offset, offset + 3)\n")
b.WriteString(" if not b4 then error(\"arpack: buffer too short for u32\") end\n")
b.WriteString(" return b1 + b2 * 256 + b3 * 65536 + b4 * 16777216, 4\n")
b.WriteString("end\n\n")
b.WriteString("local function write_u32_le(n)\n")
b.WriteString(" return string.char(\n")
b.WriteString(" n % 256,\n")
b.WriteString(" math.floor(n / 256) % 256,\n")
b.WriteString(" math.floor(n / 65536) % 256,\n")
b.WriteString(" math.floor(n / 16777216) % 256\n")
b.WriteString(" )\n")
b.WriteString("end\n\n")
b.WriteString("local function read_i8(data, offset)\n")
b.WriteString(" if offset > #data then error(\"arpack: buffer too short for i8\") end\n")
b.WriteString(" local v = string.byte(data, offset)\n")
b.WriteString(" if v >= 128 then v = v - 256 end\n")
b.WriteString(" return v, 1\n")
b.WriteString("end\n\n")
b.WriteString("local function write_i8(n)\n")
b.WriteString(" if n < 0 then n = n + 256 end\n")
b.WriteString(" return string.char(n)\n")
b.WriteString("end\n\n")
b.WriteString("local function read_i16_le(data, offset)\n")
b.WriteString(" local v = read_u16_le(data, offset)\n")
b.WriteString(" if v >= 32768 then v = v - 65536 end\n")
b.WriteString(" return v, 2\n")
b.WriteString("end\n\n")
b.WriteString("local function write_i16_le(n)\n")
b.WriteString(" if n < 0 then n = n + 65536 end\n")
b.WriteString(" return write_u16_le(n)\n")
b.WriteString("end\n\n")
b.WriteString("local function read_i32_le(data, offset)\n")
b.WriteString(" local v = read_u32_le(data, offset)\n")
b.WriteString(" if v >= 2147483648 then v = v - 4294967296 end\n")
b.WriteString(" return v, 4\n")
b.WriteString("end\n\n")
b.WriteString("local function write_i32_le(n)\n")
b.WriteString(" if n < 0 then n = n + 4294967296 end\n")
b.WriteString(" return write_u32_le(n)\n")
b.WriteString("end\n\n")
b.WriteString("local function read_f32_le(data, offset)\n")
b.WriteString(" local u32 = read_u32_le(data, offset)\n")
b.WriteString(" if u32 == 0 then return 0.0, 4 end\n")
b.WriteString(" local sign = (u32 >= 2147483648) and -1 or 1\n")
b.WriteString(" if sign < 0 then u32 = u32 - 2147483648 end\n")
b.WriteString(" local exp = math.floor(u32 / 8388608) % 256\n")
b.WriteString(" local mant = u32 % 8388608\n")
b.WriteString(" if exp == 0 then\n")
b.WriteString(" if mant == 0 then\n")
b.WriteString(" return sign < 0 and (-1 / math.huge) or 0.0, 4\n")
b.WriteString(" end\n")
b.WriteString(" return sign * (mant / 8388608) * math.pow(2, -126), 4\n")
b.WriteString(" elseif exp == 255 then\n")
b.WriteString(" if mant == 0 then return sign * math.huge, 4\n")
b.WriteString(" else return 0.0 / 0.0, 4 end\n")
b.WriteString(" end\n")
b.WriteString(" return sign * (1 + mant / 8388608) * math.pow(2, exp - 127), 4\n")
b.WriteString("end\n\n")
b.WriteString("local function write_f32_le(n)\n")
b.WriteString(" if n ~= n then return write_u32_le(2143289344) end\n")
b.WriteString(" if n == math.huge then return write_u32_le(2139095040) end\n")
b.WriteString(" if n == -math.huge then return write_u32_le(4286578688) end\n")
b.WriteString(" -- Check for negative zero: 1/-0.0 == -math.huge\n")
b.WriteString(" if n == 0 then\n")
b.WriteString(" if 1/n == -math.huge then return write_u32_le(2147483648) end\n")
b.WriteString(" return write_u32_le(0)\n")
b.WriteString(" end\n")
b.WriteString(" local sign = 0\n")
b.WriteString(" if n < 0 then sign = 2147483648; n = -n end\n")
b.WriteString(" local exp\n")
b.WriteString(" local mant\n")
b.WriteString(" if n < math.pow(2, -126) then\n")
b.WriteString(" mant = math.floor(n / math.pow(2, -149) + 0.5)\n")
b.WriteString(" if mant <= 0 then return write_u32_le(sign) end\n")
b.WriteString(" if mant >= 8388608 then\n")
b.WriteString(" exp = 1\n")
b.WriteString(" mant = 0\n")
b.WriteString(" else\n")
b.WriteString(" exp = 0\n")
b.WriteString(" end\n")
b.WriteString(" else\n")
b.WriteString(" exp = math.floor(math.log(n, 2))\n")
b.WriteString(" mant = math.floor((n / math.pow(2, exp) - 1) * 8388608 + 0.5)\n")
b.WriteString(" if mant >= 8388608 then\n")
b.WriteString(" exp = exp + 1\n")
b.WriteString(" mant = 0\n")
b.WriteString(" end\n")
b.WriteString(" exp = exp + 127\n")
b.WriteString(" if exp >= 255 then\n")
b.WriteString(" return write_u32_le(sign + 2139095040)\n")
b.WriteString(" end\n")
b.WriteString(" end\n")
b.WriteString(" return write_u32_le(sign + exp * 8388608 + mant)\n")
b.WriteString("end\n\n")
b.WriteString("local function read_f64_le(data, offset)\n")
b.WriteString(" -- Read 8 bytes directly to avoid precision loss from 64-bit arithmetic\n")
b.WriteString(" local b1, b2, b3, b4, b5, b6, b7, b8 = string.byte(data, offset, offset + 7)\n")
b.WriteString(" if not b8 then error(\"arpack: buffer too short for f64\") end\n")
b.WriteString(" local low = b1 + b2 * 256 + b3 * 65536 + b4 * 16777216\n")
b.WriteString(" local high = b5 + b6 * 256 + b7 * 65536 + b8 * 16777216\n")
b.WriteString(" -- Decode IEEE 754 double from low/high parts separately\n")
b.WriteString(" if low == 0 and high == 0 then return 0.0, 8 end\n")
b.WriteString(" local sign = (high >= 2147483648) and -1 or 1\n")
b.WriteString(" if sign < 0 then high = high - 2147483648 end\n")
b.WriteString(" local exp = math.floor(high / 1048576) % 2048\n")
b.WriteString(" local high_mant = high % 1048576\n")
b.WriteString(" if exp == 0 then\n")
b.WriteString(" local mant = high_mant * 4294967296 + low\n")
b.WriteString(" if mant == 0 then\n")
b.WriteString(" return sign < 0 and (-1 / math.huge) or 0.0, 8\n")
b.WriteString(" end\n")
b.WriteString(" return sign * (mant / 4503599627370496) * math.pow(2, -1022), 8\n")
b.WriteString(" elseif exp == 2047 then\n")
b.WriteString(" if high_mant == 0 and low == 0 then return sign * math.huge, 8\n")
b.WriteString(" else return 0.0 / 0.0, 8 end\n")
b.WriteString(" end\n")
b.WriteString(" local mant = high_mant * 4294967296 + low\n")
b.WriteString(" return sign * (1 + mant / 4503599627370496) * math.pow(2, exp - 1023), 8\n")
b.WriteString("end\n\n")
b.WriteString("local function write_f64_le(n)\n")
b.WriteString(" -- Handle special values\n")
b.WriteString(" if n ~= n then -- NaN\n")
b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 248, 127)\n")
b.WriteString(" end\n")
b.WriteString(" if n == math.huge then\n")
b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 240, 127)\n")
b.WriteString(" end\n")
b.WriteString(" if n == -math.huge then\n")
b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 240, 255)\n")
b.WriteString(" end\n")
b.WriteString(" -- Check for negative zero: 1/-0.0 == -math.huge\n")
b.WriteString(" if n == 0 then\n")
b.WriteString(" if 1/n == -math.huge then\n")
b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 0, 128)\n")
b.WriteString(" end\n")
b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 0, 0)\n")
b.WriteString(" end\n")
b.WriteString(" local sign = 0\n")
b.WriteString(" if n < 0 then sign = 2147483648; n = -n end\n")
b.WriteString(" local exp\n")
b.WriteString(" local mant\n")
b.WriteString(" if n < math.pow(2, -1022) then\n")
b.WriteString(" mant = math.floor(n / math.pow(2, -1074) + 0.5)\n")
b.WriteString(" if mant <= 0 then\n")
b.WriteString(" local high = sign\n")
b.WriteString(" return string.char(0, 0, 0, 0, high % 256, math.floor(high / 256) % 256, math.floor(high / 65536) % 256, math.floor(high / 16777216) % 256)\n")
b.WriteString(" end\n")
b.WriteString(" if mant >= 4503599627370496 then\n")
b.WriteString(" exp = 1\n")
b.WriteString(" mant = 0\n")
b.WriteString(" else\n")
b.WriteString(" exp = 0\n")
b.WriteString(" end\n")
b.WriteString(" else\n")
b.WriteString(" exp = math.floor(math.log(n, 2))\n")
b.WriteString(" mant = (n / math.pow(2, exp) - 1) * 4503599627370496\n")
b.WriteString(" mant = math.floor(mant + 0.5)\n")
b.WriteString(" if mant >= 4503599627370496 then\n")
b.WriteString(" exp = exp + 1\n")
b.WriteString(" mant = 0\n")
b.WriteString(" end\n")
b.WriteString(" exp = exp + 1023\n")
b.WriteString(" if exp >= 2047 then\n")
b.WriteString(" exp = 2047\n")
b.WriteString(" mant = 0\n")
b.WriteString(" end\n")
b.WriteString(" end\n")
b.WriteString(" local high_mant = math.floor(mant / 4294967296)\n")
b.WriteString(" local low = mant % 4294967296\n")
b.WriteString(" local high = sign + exp * 1048576 + high_mant\n")
b.WriteString(" return string.char(\n")
b.WriteString(" low % 256,\n")
b.WriteString(" math.floor(low / 256) % 256,\n")
b.WriteString(" math.floor(low / 65536) % 256,\n")
b.WriteString(" math.floor(low / 16777216) % 256,\n")
b.WriteString(" high % 256,\n")
b.WriteString(" math.floor(high / 256) % 256,\n")
b.WriteString(" math.floor(high / 65536) % 256,\n")
b.WriteString(" math.floor(high / 16777216) % 256\n")
b.WriteString(" )\n")
b.WriteString("end\n\n")
b.WriteString("local function read_bool(data, offset)\n")
b.WriteString(" if offset > #data then error(\"arpack: buffer too short for bool\") end\n")
b.WriteString(" return string.byte(data, offset) ~= 0, 1\n")
b.WriteString("end\n\n")
b.WriteString("local function write_bool(v)\n")
b.WriteString(" return string.char(v and 1 or 0)\n")
b.WriteString("end\n\n")
b.WriteString("local function read_string(data, offset)\n")
b.WriteString(" local len, header_bytes = read_u16_le(data, offset)\n")
b.WriteString(" if len == 0 then return '', header_bytes end\n")
b.WriteString(" local available = #data - offset - header_bytes + 1\n")
b.WriteString(" if available < len then\n")
b.WriteString(" error(string.format(\"arpack: buffer too short for string: need %d bytes, have %d\", len, available))\n")
b.WriteString(" end\n")
b.WriteString(" -- string.sub is 1-based; data starts at offset + header_bytes\n")
b.WriteString(" return string.sub(data, offset + header_bytes, offset + header_bytes + len - 1), header_bytes + len\n")
b.WriteString("end\n\n")
b.WriteString("local function write_string(s)\n")
b.WriteString(" local len = #s\n")
b.WriteString(" return write_u16_le(len) .. s\n")
b.WriteString("end\n\n")
}
func writeLuaEnum(b *strings.Builder, enum parser.Enum) {
fmt.Fprintf(b, "M.%s = {\n", enum.Name)
for i, value := range enum.Values {
fmt.Fprintf(b, " %s = %s", value.Name, value.Value)
if i < len(enum.Values)-1 {
b.WriteString(",")
}
b.WriteString("\n")
}
b.WriteString("}\n")
}
func writeLuaConstructor(b *strings.Builder, msg parser.Message, enumNames map[string]struct{}) {
fmt.Fprintf(b, "function M.new_%s()\n", toSnakeCase(msg.Name))
b.WriteString(" return {\n")
for _, f := range msg.Fields {
defaultValue := luaDefaultValue(f, enumNames)
fmt.Fprintf(b, " %s = %s,\n", luaFieldName(f.Name), defaultValue)
}
b.WriteString(" }\n")
b.WriteString("end\n")
}
func writeLuaSerializer(b *strings.Builder, msg parser.Message, enumNames map[string]struct{}) error {
segs := segmentFields(msg.Fields)
fmt.Fprintf(b, "function M.serialize_%s(msg)\n", toSnakeCase(msg.Name))
b.WriteString(" local parts = {}\n")
b.WriteString(" local part_idx = 0\n")
for i, seg := range segs {
if seg.single != nil {
if err := writeLuaSerializeField(b, "msg", *seg.single, " ", enumNames); err != nil {
return err
}
} else {
writeLuaBoolGroupSerialize(b, "msg", seg.bools, i, " ")
}
}
b.WriteString(" return table.concat(parts)\n")
b.WriteString("end\n")
return nil
}
func writeLuaDeserializer(b *strings.Builder, msg parser.Message, enumNames map[string]struct{}) error {
segs := segmentFields(msg.Fields)
minSize := packedMinWireSize(msg.Fields)
fmt.Fprintf(b, "function M.deserialize_%s(data, offset)\n", toSnakeCase(msg.Name))
b.WriteString(" offset = offset or 1\n")
fmt.Fprintf(b, " local msg = M.new_%s()\n", toSnakeCase(msg.Name))
b.WriteString(" local start_offset = offset\n")
b.WriteString(" local bytes_read = 0\n")
fmt.Fprintf(b, " if #data < offset + %d - 1 then\n", minSize)
fmt.Fprintf(b, " error(\"arpack: buffer too short for %s\")\n", msg.Name)
b.WriteString(" end\n")
for i, seg := range segs {
if seg.single != nil {
if err := writeLuaDeserializeField(b, "msg", *seg.single, " ", enumNames); err != nil {
return err
}
} else {
writeLuaBoolGroupDeserialize(b, "msg", seg.bools, i, " ")
}
}
b.WriteString(" return msg, offset - start_offset\n")
b.WriteString("end\n")
return nil
}
func writeLuaBoolGroupSerialize(b *strings.Builder, recv string, bools []parser.Field, groupIdx int, indent string) {
varName := fmt.Sprintf("_bool_byte_%d", groupIdx)
fmt.Fprintf(b, "%slocal %s = 0\n", indent, varName)
for bit, f := range bools {
fmt.Fprintf(b, "%sif %s.%s then %s = bit.bor(%s, %d) end\n",
indent, recv, luaFieldName(f.Name), varName, varName, 1<<bit)
}
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u8(%s)\n", indent, varName)
}
func writeLuaBoolGroupDeserialize(b *strings.Builder, recv string, bools []parser.Field, groupIdx int, indent string) {
varName := fmt.Sprintf("_bool_byte_%d", groupIdx)
fmt.Fprintf(b, "%sif #data < offset then error(\"arpack: buffer too short for bool group\") end\n", indent)
fmt.Fprintf(b, "%slocal %s = string.byte(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%soffset = offset + 1\n", indent)
for bit, f := range bools {
fmt.Fprintf(b, "%s%s.%s = bit.band(%s, %d) ~= 0\n",
indent, recv, luaFieldName(f.Name), varName, 1<<bit)
}
}
func writeLuaSerializeField(b *strings.Builder, recv string, f parser.Field, indent string, enumNames map[string]struct{}) error {
access := recv + "." + luaFieldName(f.Name)
switch f.Kind {
case parser.KindPrimitive:
return writeLuaSerializePrimitive(b, access, f, indent, enumNames)
case parser.KindNested:
fmt.Fprintf(b, "%slocal _nested_%s = M.serialize_%s(%s)\n", indent, f.Name, toSnakeCase(f.TypeName), access)
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = _nested_%s\n", indent, f.Name)
case parser.KindFixedArray:
iVar := "_i_" + strings.ToLower(f.Name)
fmt.Fprintf(b, "%sfor %s = 1, %d do\n", indent, iVar, f.FixedLen)
elemField := parser.Field{
Name: f.Name + "[" + iVar + "]",
Kind: f.Elem.Kind,
Primitive: f.Elem.Primitive,
NamedType: f.Elem.NamedType,
Quant: f.Elem.Quant,
TypeName: f.Elem.TypeName,
Elem: f.Elem.Elem,
FixedLen: f.Elem.FixedLen,
}
if err := writeLuaSerializeField(b, recv, elemField, indent+" ", enumNames); err != nil {
return err
}
fmt.Fprintf(b, "%send\n", indent)
case parser.KindSlice:
lenVar := "_len_" + strings.ToLower(f.Name)
fmt.Fprintf(b, "%slocal %s = #(%s or {})\n", indent, lenVar, access)
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u16_le(%s)\n", indent, lenVar)
iVar := "_i_" + strings.ToLower(f.Name)
fmt.Fprintf(b, "%sfor %s = 1, %s do\n", indent, iVar, lenVar)
if f.Elem.Kind == parser.KindNested {
// For nested types in slices, serialize directly
fmt.Fprintf(b, "%slocal _nested_%s = M.serialize_%s(%s[%s])\n",
indent, f.Name, toSnakeCase(f.Elem.TypeName), access, iVar)
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = _nested_%s\n",
indent, f.Name)
} else {
elemField := parser.Field{
Name: f.Name + "[" + iVar + "]",
Kind: f.Elem.Kind,
Primitive: f.Elem.Primitive,
NamedType: f.Elem.NamedType,
Quant: f.Elem.Quant,
TypeName: f.Elem.TypeName,
Elem: f.Elem.Elem,
FixedLen: f.Elem.FixedLen,
}
if err := writeLuaSerializeField(b, recv, elemField, indent+" ", enumNames); err != nil {
return err
}
}
fmt.Fprintf(b, "%send\n", indent)
}
return nil
}
func writeLuaSerializePrimitive(b *strings.Builder, access string, f parser.Field, indent string, enumNames map[string]struct{}) error {
if f.Quant != nil {
return writeLuaSerializeQuant(b, access, f, indent)
}
valueExpr := luaSerializeValueExpr(access, f, enumNames)
switch f.Primitive {
case parser.KindFloat32:
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_f32_le(%s)\n", indent, valueExpr)
case parser.KindFloat64:
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_f64_le(%s)\n", indent, valueExpr)
case parser.KindInt8:
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_i8(%s)\n", indent, valueExpr)
case parser.KindUint8:
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u8(%s)\n", indent, valueExpr)
case parser.KindBool:
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_bool(%s)\n", indent, valueExpr)
case parser.KindInt16:
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_i16_le(%s)\n", indent, valueExpr)
case parser.KindUint16:
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u16_le(%s)\n", indent, valueExpr)
case parser.KindInt32:
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_i32_le(%s)\n", indent, valueExpr)
case parser.KindUint32:
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u32_le(%s)\n", indent, valueExpr)
case parser.KindInt64, parser.KindUint64:
return fmt.Errorf("int64/uint64 serialization not supported in Lua")
case parser.KindString:
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_string(%s or '')\n", indent, valueExpr)
}
return nil
}
func writeLuaSerializeQuant(b *strings.Builder, access string, f parser.Field, indent string) error {
q := f.Quant
maxUint := q.MaxUint()
varName := "_q_" + sanitizeLuaVarName(access)
fmt.Fprintf(b, "%slocal %s = math.floor(((%s - (%g)) / (%g - (%g))) * %g + 0.5)\n",
indent, varName, access, q.Min, q.Max, q.Min, maxUint)
if q.Bits == 8 {
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u8(%s)\n", indent, varName)
} else {
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u16_le(%s)\n", indent, varName)
}
return nil
}
func writeLuaDeserializeField(b *strings.Builder, recv string, f parser.Field, indent string, enumNames map[string]struct{}) error {
access := recv + "." + luaFieldName(f.Name)
switch f.Kind {
case parser.KindPrimitive:
return writeLuaDeserializePrimitive(b, access, f, indent, enumNames)
case parser.KindNested:
fmt.Fprintf(b, "%slocal _nested_%s, _n_%s = M.deserialize_%s(data, offset)\n", indent, f.Name, f.Name, toSnakeCase(f.TypeName))
fmt.Fprintf(b, "%s%s = _nested_%s\n", indent, access, f.Name)
fmt.Fprintf(b, "%soffset = offset + _n_%s\n", indent, f.Name)
case parser.KindFixedArray:
iVar := "_i_" + strings.ToLower(f.Name)
fmt.Fprintf(b, "%s%s = {}\n", indent, access)
fmt.Fprintf(b, "%sfor %s = 1, %d do\n", indent, iVar, f.FixedLen)
elemField := parser.Field{
Name: f.Name + "[" + iVar + "]",
Kind: f.Elem.Kind,
Primitive: f.Elem.Primitive,
NamedType: f.Elem.NamedType,
Quant: f.Elem.Quant,
TypeName: f.Elem.TypeName,
Elem: f.Elem.Elem,
FixedLen: f.Elem.FixedLen,
}
if err := writeLuaDeserializeField(b, recv, elemField, indent+" ", enumNames); err != nil {
return err
}
fmt.Fprintf(b, "%send\n", indent)
case parser.KindSlice:
lenVar := "_len_" + strings.ToLower(f.Name)
fmt.Fprintf(b, "%slocal %s = read_u16_le(data, offset)\n", indent, lenVar)
fmt.Fprintf(b, "%soffset = offset + 2\n", indent)
fmt.Fprintf(b, "%s%s = {}\n", indent, access)
iVar := "_i_" + strings.ToLower(f.Name)
fmt.Fprintf(b, "%sfor %s = 1, %s do\n", indent, iVar, lenVar)
if f.Elem.Kind == parser.KindNested {
// For nested types in slices, deserialize directly
fmt.Fprintf(b, "%slocal _nested_%s, _n_%s = M.deserialize_%s(data, offset)\n",
indent, f.Name, f.Name, toSnakeCase(f.Elem.TypeName))
fmt.Fprintf(b, "%s%s[%s] = _nested_%s\n",
indent, access, iVar, f.Name)
fmt.Fprintf(b, "%soffset = offset + _n_%s\n",
indent, f.Name)
} else {
elemField := parser.Field{
Name: f.Name + "[" + iVar + "]",
Kind: f.Elem.Kind,
Primitive: f.Elem.Primitive,
NamedType: f.Elem.NamedType,
Quant: f.Elem.Quant,
TypeName: f.Elem.TypeName,
Elem: f.Elem.Elem,
FixedLen: f.Elem.FixedLen,
}
if err := writeLuaDeserializeField(b, recv, elemField, indent+" ", enumNames); err != nil {
return err
}
}
fmt.Fprintf(b, "%send\n", indent)
}
return nil
}
func writeLuaDeserializePrimitive(b *strings.Builder, access string, f parser.Field, indent string, enumNames map[string]struct{}) error {
if f.Quant != nil {
return writeLuaDeserializeQuant(b, access, f, indent, enumNames)
}
varName := "_v_" + sanitizeLuaVarName(access)
switch f.Primitive {
case parser.KindFloat32:
fmt.Fprintf(b, "%slocal %s, _n = read_f32_le(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(varName, f, enumNames))
fmt.Fprintf(b, "%soffset = offset + _n\n", indent)
case parser.KindFloat64:
fmt.Fprintf(b, "%slocal %s, _n = read_f64_le(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(varName, f, enumNames))
fmt.Fprintf(b, "%soffset = offset + _n\n", indent)
case parser.KindInt8:
fmt.Fprintf(b, "%slocal %s, _n = read_i8(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(varName, f, enumNames))
fmt.Fprintf(b, "%soffset = offset + _n\n", indent)
case parser.KindUint8:
fmt.Fprintf(b, "%slocal %s, _n = read_u8(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(varName, f, enumNames))
fmt.Fprintf(b, "%soffset = offset + _n\n", indent)
case parser.KindBool:
fmt.Fprintf(b, "%slocal %s, _n = read_bool(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(varName, f, enumNames))
fmt.Fprintf(b, "%soffset = offset + _n\n", indent)
case parser.KindInt16:
fmt.Fprintf(b, "%slocal %s, _n = read_i16_le(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(varName, f, enumNames))
fmt.Fprintf(b, "%soffset = offset + _n\n", indent)
case parser.KindUint16:
fmt.Fprintf(b, "%slocal %s, _n = read_u16_le(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(varName, f, enumNames))
fmt.Fprintf(b, "%soffset = offset + _n\n", indent)
case parser.KindInt32:
fmt.Fprintf(b, "%slocal %s, _n = read_i32_le(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(varName, f, enumNames))
fmt.Fprintf(b, "%soffset = offset + _n\n", indent)
case parser.KindUint32:
fmt.Fprintf(b, "%slocal %s, _n = read_u32_le(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(varName, f, enumNames))
fmt.Fprintf(b, "%soffset = offset + _n\n", indent)
case parser.KindInt64, parser.KindUint64:
return fmt.Errorf("int64/uint64 deserialization not supported in Lua")
case parser.KindString:
fmt.Fprintf(b, "%slocal %s, _n = read_string(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(varName, f, enumNames))
fmt.Fprintf(b, "%soffset = offset + _n\n", indent)
}
return nil
}
func writeLuaDeserializeQuant(b *strings.Builder, access string, f parser.Field, indent string, enumNames map[string]struct{}) error {
q := f.Quant
maxUint := q.MaxUint()
varName := "_q_" + sanitizeLuaVarName(access)
if q.Bits == 8 {
fmt.Fprintf(b, "%slocal %s = read_u8(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%soffset = offset + 1\n", indent)
} else {
fmt.Fprintf(b, "%slocal %s = read_u16_le(data, offset)\n", indent, varName)
fmt.Fprintf(b, "%soffset = offset + 2\n", indent)
}
expr := fmt.Sprintf("%s / %g * (%g - (%g)) + (%g)", varName, maxUint, q.Max, q.Min, q.Min)
fmt.Fprintf(b, "%s%s = %s\n", indent, access, luaDeserializeValueExpr(expr, f, enumNames))
return nil
}
func luaFieldName(name string) string {
return toSnakeCase(name)
}
func toSnakeCase(s string) string {
if s == "" {
return ""
}
var b strings.Builder
var prevUpper bool
for i, c := range s {
isUpper := c >= 'A' && c <= 'Z'
// Add underscore before uppercase letter if:
// - It's not the first character
// - Previous character was lowercase, OR
// - Previous character was uppercase AND next character (if exists) is lowercase
// (this handles cases like "PlayerID" -> "player_id", not "player_i_d")
if i > 0 && isUpper {
nextLower := false
if i+1 < len(s) {
nextChar := rune(s[i+1])
nextLower = nextChar >= 'a' && nextChar <= 'z'
}
if !prevUpper || nextLower {
b.WriteByte('_')
}
}
b.WriteRune(c)
prevUpper = isUpper
}
return strings.ToLower(b.String())
}
func luaDefaultValue(f parser.Field, enumNames map[string]struct{}) string {
switch f.Kind {
case parser.KindPrimitive:
if luaIsEnumType(f, enumNames) {
return "0"
}
switch f.Primitive {
case parser.KindFloat32, parser.KindFloat64, parser.KindInt8, parser.KindInt16, parser.KindInt32,
parser.KindUint8, parser.KindUint16, parser.KindUint32, parser.KindInt64, parser.KindUint64:
return "0"
case parser.KindBool:
return "false"
case parser.KindString:
return "''"
}
case parser.KindNested:
return fmt.Sprintf("M.new_%s()", toSnakeCase(f.TypeName))
case parser.KindFixedArray, parser.KindSlice:
return "{}"
}
return "nil"
}
func luaTypeName(f parser.Field, enumNames map[string]struct{}) string {
switch f.Kind {
case parser.KindPrimitive:
if luaIsEnumType(f, enumNames) {
return "number"
}
return "number"
case parser.KindNested:
return f.TypeName
case parser.KindFixedArray, parser.KindSlice:
return "table"
}
return "any"
}
func luaSerializeValueExpr(access string, f parser.Field, enumNames map[string]struct{}) string {
if !luaIsEnumType(f, enumNames) {
return access
}
return access
}
func luaDeserializeValueExpr(expr string, f parser.Field, enumNames map[string]struct{}) string {
if !luaIsEnumType(f, enumNames) {
return expr
}
return expr
}
func luaIsEnumType(f parser.Field, enumNames map[string]struct{}) bool {
if f.NamedType == "" || enumNames == nil {
return false
}
_, ok := enumNames[f.NamedType]
return ok
}
func sanitizeLuaVarName(s string) string {
var b strings.Builder
for _, c := range s {
if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') {
b.WriteRune(c)
} else {
b.WriteRune('_')
}
}
return strings.ToLower(b.String())
}
func checkLuaUnsupportedTypes(schema parser.Schema) error {
for _, msg := range schema.Messages {
for _, f := range msg.Fields {
if err := checkFieldLuaSupport(msg.Name, f); err != nil {
return err
}
}
}
return nil
}
func checkFieldLuaSupport(msgName string, f parser.Field) error {
switch f.Kind {
case parser.KindPrimitive:
if f.Primitive == parser.KindInt64 || f.Primitive == parser.KindUint64 {
return fmt.Errorf("Lua target does not support int64/uint64: field %s in message %s (LuaJIT/Defold uses double-precision floats which can only safely represent integers up to 2^53)", f.Name, msgName)
}
case parser.KindFixedArray, parser.KindSlice:
if f.Elem != nil {
return checkFieldLuaSupport(msgName, *f.Elem)
}
}
return nil
}
+557
View File
@@ -0,0 +1,557 @@
package generator
import (
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
"github.com/edmand46/arpack/parser"
)
func TestGenerateLua_BasicTypes(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "BasicTypes",
Fields: []parser.Field{
{Name: "Int8Field", Kind: parser.KindPrimitive, Primitive: parser.KindInt8},
{Name: "Int16Field", Kind: parser.KindPrimitive, Primitive: parser.KindInt16},
{Name: "Int32Field", Kind: parser.KindPrimitive, Primitive: parser.KindInt32},
{Name: "Uint8Field", Kind: parser.KindPrimitive, Primitive: parser.KindUint8},
{Name: "Uint16Field", Kind: parser.KindPrimitive, Primitive: parser.KindUint16},
{Name: "Uint32Field", Kind: parser.KindPrimitive, Primitive: parser.KindUint32},
{Name: "Float32Field", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32},
{Name: "Float64Field", Kind: parser.KindPrimitive, Primitive: parser.KindFloat64},
{Name: "BoolField", Kind: parser.KindPrimitive, Primitive: parser.KindBool},
{Name: "StringField", Kind: parser.KindPrimitive, Primitive: parser.KindString},
},
},
},
}
lua, err := GenerateLuaSchema(schema, "test")
if err != nil {
t.Fatalf("GenerateLuaSchema failed: %v", err)
}
luaStr := string(lua)
if !strings.Contains(luaStr, "function M.new_basic_types()") {
t.Error("Missing constructor for BasicTypes")
}
if !strings.Contains(luaStr, "function M.serialize_basic_types(msg)") {
t.Error("Missing serializer for BasicTypes")
}
if !strings.Contains(luaStr, "function M.deserialize_basic_types(data, offset)") {
t.Error("Missing deserializer for BasicTypes")
}
if !strings.Contains(luaStr, "int8_field = 0") {
t.Error("Missing int8_field in constructor")
}
if !strings.Contains(luaStr, "string_field = ''") {
t.Error("Missing string_field default value")
}
if !strings.Contains(luaStr, "bool_field = false") {
t.Error("Missing bool_field default value")
}
}
func TestGenerateLua_Enum(t *testing.T) {
schema := parser.Schema{
Enums: []parser.Enum{
{
Name: "Opcode",
Primitive: parser.KindUint16,
Values: []parser.EnumValue{
{Name: "Unknown", Value: "0"},
{Name: "Join", Value: "1"},
{Name: "Leave", Value: "2"},
},
},
},
Messages: []parser.Message{
{
Name: "MessageWithEnum",
Fields: []parser.Field{
{Name: "Op", Kind: parser.KindPrimitive, Primitive: parser.KindUint16, NamedType: "Opcode"},
},
},
},
}
enumNames := map[string]struct{}{"Opcode": {}}
_ = enumNames
lua, err := GenerateLuaSchema(schema, "test")
if err != nil {
t.Fatalf("GenerateLuaSchema failed: %v", err)
}
luaStr := string(lua)
if !strings.Contains(luaStr, "M.Opcode = {") {
t.Error("Missing Opcode enum table")
}
if !strings.Contains(luaStr, "Unknown = 0") {
t.Error("Missing Unknown enum value")
}
if !strings.Contains(luaStr, "Join = 1") {
t.Error("Missing Join enum value")
}
}
func TestGenerateLua_NestedMessage(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "Vector3",
Fields: []parser.Field{
{Name: "X", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32},
{Name: "Y", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32},
{Name: "Z", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32},
},
},
{
Name: "Player",
Fields: []parser.Field{
{Name: "Position", Kind: parser.KindNested, TypeName: "Vector3"},
{Name: "Health", Kind: parser.KindPrimitive, Primitive: parser.KindInt32},
},
},
},
}
lua, err := GenerateLuaSchema(schema, "test")
if err != nil {
t.Fatalf("GenerateLuaSchema failed: %v", err)
}
luaStr := string(lua)
if !strings.Contains(luaStr, "function M.new_vector3()") {
t.Error("Missing constructor for Vector3")
}
if !strings.Contains(luaStr, "function M.new_player()") {
t.Error("Missing constructor for Player")
}
if !strings.Contains(luaStr, "position = M.new_vector3()") {
t.Error("Missing nested initialization in Player constructor")
}
if !strings.Contains(luaStr, "M.serialize_vector3") {
t.Error("Missing Vector3 serializer call")
}
}
func TestGenerateLua_FixedArray(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "WithFixedArray",
Fields: []parser.Field{
{
Name: "Values",
Kind: parser.KindFixedArray,
FixedLen: 3,
Elem: &parser.Field{
Kind: parser.KindPrimitive,
Primitive: parser.KindFloat32,
},
},
},
},
},
}
lua, err := GenerateLuaSchema(schema, "test")
if err != nil {
t.Fatalf("GenerateLuaSchema failed: %v", err)
}
luaStr := string(lua)
if !strings.Contains(luaStr, "values = {}") {
t.Error("Missing values array initialization")
}
if !strings.Contains(luaStr, "for _i_values = 1, 3 do") {
t.Error("Missing fixed array loop in serializer")
}
}
func TestGenerateLua_Slice(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "WithSlice",
Fields: []parser.Field{
{
Name: "Items",
Kind: parser.KindSlice,
Elem: &parser.Field{
Kind: parser.KindPrimitive,
Primitive: parser.KindInt32,
},
},
},
},
},
}
lua, err := GenerateLuaSchema(schema, "test")
if err != nil {
t.Fatalf("GenerateLuaSchema failed: %v", err)
}
luaStr := string(lua)
if !strings.Contains(luaStr, "items = {}") {
t.Error("Missing items slice initialization")
}
if !strings.Contains(luaStr, "local _len_items = #(msg.items or {})") {
t.Error("Missing slice length serialization")
}
if !strings.Contains(luaStr, "for _i_items = 1, _len_items do") {
t.Error("Missing slice iteration in serializer")
}
}
func TestGenerateLua_BoolPacking(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "WithBools",
Fields: []parser.Field{
{Name: "A", Kind: parser.KindPrimitive, Primitive: parser.KindBool},
{Name: "B", Kind: parser.KindPrimitive, Primitive: parser.KindBool},
{Name: "C", Kind: parser.KindPrimitive, Primitive: parser.KindBool},
{Name: "Value", Kind: parser.KindPrimitive, Primitive: parser.KindInt32},
},
},
},
}
lua, err := GenerateLuaSchema(schema, "test")
if err != nil {
t.Fatalf("GenerateLuaSchema failed: %v", err)
}
luaStr := string(lua)
if !strings.Contains(luaStr, "local _bool_byte_0 = 0") {
t.Error("Missing bool byte packing variable")
}
if !strings.Contains(luaStr, "if msg.a then _bool_byte_0 = bit.bor(_bool_byte_0, 1) end") {
t.Error("Missing first bool packing check with bit.bor")
}
if !strings.Contains(luaStr, "if msg.b then _bool_byte_0 = bit.bor(_bool_byte_0, 2) end") {
t.Error("Missing second bool packing check with bit.bor")
}
if !strings.Contains(luaStr, "msg.a = bit.band(_bool_byte_0, 1) ~= 0") {
t.Error("Missing bit.band for bool deserialization")
}
}
func TestGenerateLua_QuantizedFloat(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "WithQuantized",
Fields: []parser.Field{
{
Name: "Position",
Kind: parser.KindPrimitive,
Primitive: parser.KindFloat32,
Quant: &parser.QuantInfo{
Min: -500,
Max: 500,
Bits: 16,
},
},
},
},
},
}
lua, err := GenerateLuaSchema(schema, "test")
if err != nil {
t.Fatalf("GenerateLuaSchema failed: %v", err)
}
luaStr := string(lua)
if !strings.Contains(luaStr, "math.floor") {
t.Error("Missing math.floor for quantization")
}
if !strings.Contains(luaStr, "write_u16_le") {
t.Error("Missing u16 write for 16-bit quantization")
}
}
func TestToSnakeCase(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"", ""},
{"A", "a"},
{"AB", "ab"},
{"AbCd", "ab_cd"},
{"ABC", "abc"},
{"PlayerID", "player_id"},
{"HTTPResponse", "http_response"},
{"XMLHttpRequest", "xml_http_request"},
{"getHTTPResponse", "get_http_response"},
}
for _, tt := range tests {
result := toSnakeCase(tt.input)
if result != tt.expected {
t.Errorf("toSnakeCase(%q) = %q, want %q", tt.input, result, tt.expected)
}
}
}
func TestLuaHelpersGenerated(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "Empty",
Fields: []parser.Field{},
},
},
}
lua, err := GenerateLuaSchema(schema, "test")
if err != nil {
t.Fatalf("GenerateLuaSchema failed: %v", err)
}
luaStr := string(lua)
helpers := []string{
"local bit = require('bit')",
"buffer too short for u8",
"buffer too short for bool",
"local function write_u8(n)",
"buffer too short for u16",
"local function write_u16_le(n)",
"buffer too short for u32",
"local function write_u32_le(n)",
"local function read_f32_le(data, offset)",
"local function write_f32_le(n)",
"local function read_f64_le(data, offset)",
"local function write_f64_le(n)",
"local function write_bool(v)",
"buffer too short for string",
"local function write_string(s)",
}
for _, helper := range helpers {
if !strings.Contains(luaStr, helper) {
t.Errorf("Missing helper: %s", helper)
}
}
}
func TestGenerateLua_Int64NotSupported(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "WithInt64",
Fields: []parser.Field{
{Name: "Value", Kind: parser.KindPrimitive, Primitive: parser.KindInt64},
},
},
},
}
_, err := GenerateLuaSchema(schema, "test")
if err == nil {
t.Fatal("Expected error for int64 field, got nil")
}
if !strings.Contains(err.Error(), "int64/uint64") {
t.Errorf("Expected error mentioning int64/uint64, got: %v", err)
}
if !strings.Contains(err.Error(), "LuaJIT/Defold") {
t.Errorf("Expected error mentioning LuaJIT/Defold, got: %v", err)
}
}
func TestGenerateLua_Uint64NotSupported(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "WithUint64",
Fields: []parser.Field{
{Name: "Value", Kind: parser.KindPrimitive, Primitive: parser.KindUint64},
},
},
},
}
_, err := GenerateLuaSchema(schema, "test")
if err == nil {
t.Fatal("Expected error for uint64 field, got nil")
}
if !strings.Contains(err.Error(), "int64/uint64") {
t.Errorf("Expected error mentioning int64/uint64, got: %v", err)
}
}
func TestGenerateLua_Int64InSliceNotSupported(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "WithInt64Slice",
Fields: []parser.Field{
{
Name: "Values",
Kind: parser.KindSlice,
Elem: &parser.Field{
Kind: parser.KindPrimitive,
Primitive: parser.KindInt64,
},
},
},
},
},
}
_, err := GenerateLuaSchema(schema, "test")
if err == nil {
t.Fatal("Expected error for int64 in slice, got nil")
}
if !strings.Contains(err.Error(), "int64/uint64") {
t.Errorf("Expected error mentioning int64/uint64, got: %v", err)
}
}
func TestGenerateLua_BoundsChecks(t *testing.T) {
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "SimpleMessage",
Fields: []parser.Field{
{Name: "ID", Kind: parser.KindPrimitive, Primitive: parser.KindUint32},
{Name: "Name", Kind: parser.KindPrimitive, Primitive: parser.KindString},
},
},
},
}
lua, err := GenerateLuaSchema(schema, "test")
if err != nil {
t.Fatalf("GenerateLuaSchema failed: %v", err)
}
luaStr := string(lua)
// Check that bounds check function exists
if !strings.Contains(luaStr, "check_bounds") {
t.Error("Missing check_bounds function")
}
// Check that read_u16_le has bounds check
if !strings.Contains(luaStr, "buffer too short for u16") {
t.Error("Missing bounds check in read_u16_le")
}
// Check that read_u32_le has bounds check
if !strings.Contains(luaStr, "buffer too short for u32") {
t.Error("Missing bounds check in read_u32_le")
}
// Check that read_string has bounds check
if !strings.Contains(luaStr, "buffer too short for string") {
t.Error("Missing bounds check in read_string")
}
// Check that deserialize function has min size check (message name is preserved in error)
if !strings.Contains(luaStr, "buffer too short for SimpleMessage") {
t.Error("Missing min size check in deserialize function")
}
// Check that read_u8 has bounds check
if !strings.Contains(luaStr, "buffer too short for u8") {
t.Error("Missing bounds check in read_u8")
}
// Check that read_bool has bounds check
if !strings.Contains(luaStr, "buffer too short for bool") {
t.Error("Missing bounds check in read_bool")
}
// Check that read_i8 has bounds check
if !strings.Contains(luaStr, "buffer too short for i8") {
t.Error("Missing bounds check in read_i8")
}
}
func TestGenerateLua_RuntimeFloatEdgeCases(t *testing.T) {
if _, err := exec.LookPath("luajit"); err != nil {
t.Skip("luajit not found")
}
schema := parser.Schema{
Messages: []parser.Message{
{
Name: "FloatEdges",
Fields: []parser.Field{
{Name: "F32", Kind: parser.KindPrimitive, Primitive: parser.KindFloat32},
{Name: "F64", Kind: parser.KindPrimitive, Primitive: parser.KindFloat64},
},
},
},
}
lua, err := GenerateLuaSchema(schema, "messages")
if err != nil {
t.Fatalf("GenerateLuaSchema failed: %v", err)
}
dir := t.TempDir()
modulePath := filepath.Join(dir, "messages_gen.lua")
if err := os.WriteFile(modulePath, lua, 0o600); err != nil {
t.Fatalf("write module: %v", err)
}
scriptPath := filepath.Join(dir, "check.lua")
script := `local messages = require("messages_gen")
local function bytes_to_hex(s)
return (s:gsub(".", function(c) return string.format("%02x", string.byte(c)) end))
end
local neg_zero = string.char(0, 0, 0, 128, 0, 0, 0, 0, 0, 0, 0, 128)
local msg = messages.deserialize_float_edges(neg_zero, 1)
print(bytes_to_hex(messages.serialize_float_edges(msg)))
local subnormal = string.char(1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0)
msg = messages.deserialize_float_edges(subnormal, 1)
print(bytes_to_hex(messages.serialize_float_edges(msg)))
`
if err := os.WriteFile(scriptPath, []byte(script), 0o600); err != nil {
t.Fatalf("write script: %v", err)
}
cmd := exec.Command("luajit", "check.lua")
cmd.Dir = dir
out, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("luajit failed: %v\n%s", err, out)
}
lines := strings.Split(strings.TrimSpace(string(out)), "\n")
if len(lines) != 2 {
t.Fatalf("expected 2 output lines, got %d: %q", len(lines), string(out))
}
if lines[0] != "000000800000000000000080" {
t.Fatalf("negative zero roundtrip mismatch: %s", lines[0])
}
if lines[1] != "010000000100000000000000" {
t.Fatalf("subnormal roundtrip mismatch: %s", lines[1])
}
}