package generator

import (
	"fmt"
	"strings"

	"github.com/edmand46/arpack/parser"
)

// GenerateLuaSchema renders schema as a single Lua module. The module consists
// of a header, the file-local binary helpers, one table per enum (M.<Enum>),
// and for every message a constructor (M.new_<name>), a serializer
// (M.serialize_<name>) and a deserializer (M.deserialize_<name>). The module
// returns the table M.
//
// Generation fails up front if any field uses a type the Lua target rejects
// (see checkLuaUnsupportedTypes). Per-message generation errors are wrapped
// with the message name.
//
// NOTE(review): moduleName is accepted but not referenced by this function.
func GenerateLuaSchema(schema parser.Schema, moduleName string) ([]byte, error) {
	if err := checkLuaUnsupportedTypes(schema); err != nil {
		return nil, err
	}

	messages := schema.Messages

	var b strings.Builder
	b.WriteString("-- arpack \n")
	b.WriteString("-- Code generated by arpack. DO NOT EDIT.\n\n")
	b.WriteString("local M = {}\n\n")
	b.WriteString("-- Load BitOp library for bit operations (Defold/LuaJIT)\n")
	b.WriteString("local bit = require('bit')\n\n")

	writeLuaHelpers(&b)

	// Set of enum type names; used to recognise enum-typed fields.
	enumNames := make(map[string]struct{}, len(schema.Enums))
	for _, enum := range schema.Enums {
		enumNames[enum.Name] = struct{}{}
	}

	for _, enum := range schema.Enums {
		writeLuaEnum(&b, enum)
		b.WriteString("\n")
	}

	for _, msg := range messages {
		writeLuaConstructor(&b, msg, enumNames)
		b.WriteString("\n")
	}

	for _, msg := range messages {
		if err := writeLuaSerializer(&b, msg, enumNames); err != nil {
			return nil, fmt.Errorf("message %s: %w", msg.Name, err)
		}
		b.WriteString("\n")
	}

	for _, msg := range messages {
		if err := writeLuaDeserializer(&b, msg, enumNames); err != nil {
			return nil, fmt.Errorf("message %s: %w", msg.Name, err)
		}
		b.WriteString("\n")
	}

	b.WriteString("return M\n")
	return []byte(b.String()), nil
}

// writeLuaHelpers emits the file-local Lua helpers used by the generated
// (de)serializers. Multi-byte values are little-endian.
//
// Every read_* helper takes (data, offset), where offset is a 1-based Lua
// string index, and returns (value, bytes_consumed). Truncated input raises an
// "arpack: buffer too short ..." error. Every write_* helper returns the
// encoded bytes as a Lua string.
func writeLuaHelpers(b *strings.Builder) {
	b.WriteString("-- Inline helpers for little-endian byte operations\n\n")
	writeLuaIntHelpers(b)
	writeLuaFloat32Helpers(b)
	writeLuaFloat64Helpers(b)
	writeLuaBoolStringHelpers(b)
}

// writeLuaIntHelpers emits check_bounds plus the u8/u16/u32 and i8/i16/i32
// readers and writers. Signed values use two's complement: writers add 2^bits
// to negatives, and readers subtract 2^bits when the sign bit is set.
func writeLuaIntHelpers(b *strings.Builder) {
	b.WriteString("-- Error handling for truncated data\n")
	b.WriteString("local function check_bounds(data, offset, needed, context)\n")
	b.WriteString(" local available = #data - offset + 1\n")
	b.WriteString(" if available < needed then\n")
	b.WriteString(" error(string.format(\"arpack: buffer too short for %s: need %d bytes, have %d\", context, needed, available))\n")
	b.WriteString(" end\n")
	b.WriteString("end\n\n")

	b.WriteString("local function read_u8(data, offset)\n")
	b.WriteString(" if offset > #data then error(\"arpack: buffer too short for u8\") end\n")
	b.WriteString(" return string.byte(data, offset), 1\n")
	b.WriteString("end\n\n")
	b.WriteString("local function write_u8(n)\n")
	b.WriteString(" return string.char(n)\n")
	b.WriteString("end\n\n")

	b.WriteString("local function read_u16_le(data, offset)\n")
	b.WriteString(" local b1, b2 = string.byte(data, offset, offset + 1)\n")
	b.WriteString(" if not b2 then error(\"arpack: buffer too short for u16\") end\n")
	b.WriteString(" return b1 + b2 * 256, 2\n")
	b.WriteString("end\n\n")
	b.WriteString("local function write_u16_le(n)\n")
	b.WriteString(" return string.char(n % 256, math.floor(n / 256))\n")
	b.WriteString("end\n\n")

	b.WriteString("local function read_u32_le(data, offset)\n")
	b.WriteString(" local b1, b2, b3, b4 = string.byte(data, offset, offset + 3)\n")
	b.WriteString(" if not b4 then error(\"arpack: buffer too short for u32\") end\n")
	b.WriteString(" return b1 + b2 * 256 + b3 * 65536 + b4 * 16777216, 4\n")
	b.WriteString("end\n\n")
	b.WriteString("local function write_u32_le(n)\n")
	b.WriteString(" return string.char(\n")
	b.WriteString(" n % 256,\n")
	b.WriteString(" math.floor(n / 256) % 256,\n")
	b.WriteString(" math.floor(n / 65536) % 256,\n")
	b.WriteString(" math.floor(n / 16777216) % 256\n")
	b.WriteString(" )\n")
	b.WriteString("end\n\n")

	b.WriteString("local function read_i8(data, offset)\n")
	b.WriteString(" if offset > #data then error(\"arpack: buffer too short for i8\") end\n")
	b.WriteString(" local v = string.byte(data, offset)\n")
	b.WriteString(" if v >= 128 then v = v - 256 end\n")
	b.WriteString(" return v, 1\n")
	b.WriteString("end\n\n")
	b.WriteString("local function write_i8(n)\n")
	b.WriteString(" if n < 0 then n = n + 256 end\n")
	b.WriteString(" return string.char(n)\n")
	b.WriteString("end\n\n")

	b.WriteString("local function read_i16_le(data, offset)\n")
	b.WriteString(" local v = read_u16_le(data, offset)\n")
	b.WriteString(" if v >= 32768 then v = v - 65536 end\n")
	b.WriteString(" return v, 2\n")
	b.WriteString("end\n\n")
	b.WriteString("local function write_i16_le(n)\n")
	b.WriteString(" if n < 0 then n = n + 65536 end\n")
	b.WriteString(" return write_u16_le(n)\n")
	b.WriteString("end\n\n")

	b.WriteString("local function read_i32_le(data, offset)\n")
	b.WriteString(" local v = read_u32_le(data, offset)\n")
	b.WriteString(" if v >= 2147483648 then v = v - 4294967296 end\n")
	b.WriteString(" return v, 4\n")
	b.WriteString("end\n\n")
	b.WriteString("local function write_i32_le(n)\n")
	b.WriteString(" if n < 0 then n = n + 4294967296 end\n")
	b.WriteString(" return write_u32_le(n)\n")
	b.WriteString("end\n\n")
}

// writeLuaFloat32Helpers emits read_f32_le/write_f32_le. They decode and
// encode IEEE 754 binary32 values with plain arithmetic (no bit operations),
// handling signed zero, subnormals, infinities and NaN explicitly.
func writeLuaFloat32Helpers(b *strings.Builder) {
	b.WriteString("local function read_f32_le(data, offset)\n")
	b.WriteString(" local u32 = read_u32_le(data, offset)\n")
	b.WriteString(" if u32 == 0 then return 0.0, 4 end\n")
	b.WriteString(" local sign = (u32 >= 2147483648) and -1 or 1\n")
	b.WriteString(" if sign < 0 then u32 = u32 - 2147483648 end\n")
	b.WriteString(" local exp = math.floor(u32 / 8388608) % 256\n")
	b.WriteString(" local mant = u32 % 8388608\n")
	b.WriteString(" if exp == 0 then\n")
	b.WriteString(" if mant == 0 then\n")
	b.WriteString(" return sign < 0 and (-1 / math.huge) or 0.0, 4\n")
	b.WriteString(" end\n")
	b.WriteString(" return sign * (mant / 8388608) * math.pow(2, -126), 4\n")
	b.WriteString(" elseif exp == 255 then\n")
	b.WriteString(" if mant == 0 then return sign * math.huge, 4\n")
	b.WriteString(" else return 0.0 / 0.0, 4 end\n")
	b.WriteString(" end\n")
	b.WriteString(" return sign * (1 + mant / 8388608) * math.pow(2, exp - 127), 4\n")
	b.WriteString("end\n\n")

	b.WriteString("local function write_f32_le(n)\n")
	b.WriteString(" if n ~= n then return write_u32_le(2143289344) end\n")
	b.WriteString(" if n == math.huge then return write_u32_le(2139095040) end\n")
	b.WriteString(" if n == -math.huge then return write_u32_le(4286578688) end\n")
	b.WriteString(" -- Check for negative zero: 1/-0.0 == -math.huge\n")
	b.WriteString(" if n == 0 then\n")
	b.WriteString(" if 1/n == -math.huge then return write_u32_le(2147483648) end\n")
	b.WriteString(" return write_u32_le(0)\n")
	b.WriteString(" end\n")
	b.WriteString(" local sign = 0\n")
	b.WriteString(" if n < 0 then sign = 2147483648; n = -n end\n")
	b.WriteString(" local exp\n")
	b.WriteString(" local mant\n")
	b.WriteString(" if n < math.pow(2, -126) then\n")
	b.WriteString(" mant = math.floor(n / math.pow(2, -149) + 0.5)\n")
	b.WriteString(" if mant <= 0 then return write_u32_le(sign) end\n")
	b.WriteString(" if mant >= 8388608 then\n")
	b.WriteString(" exp = 1\n")
	b.WriteString(" mant = 0\n")
	b.WriteString(" else\n")
	b.WriteString(" exp = 0\n")
	b.WriteString(" end\n")
	b.WriteString(" else\n")
	b.WriteString(" exp = math.floor(math.log(n, 2))\n")
	// math.log can land one off near powers of two (or return 1024 near
	// DBL_MAX). Normalise so that 2^exp <= n < 2^(exp+1); otherwise the
	// mantissa computed below can go negative or be mis-scaled. The scale is
	// recomputed with math.pow rather than halved, so it never stays at inf.
	b.WriteString(" -- math.log can be off by one near powers of two; normalize so 2^exp <= n < 2^(exp+1)\n")
	b.WriteString(" local scale = math.pow(2, exp)\n")
	b.WriteString(" if scale > n then\n")
	b.WriteString(" exp = exp - 1\n")
	b.WriteString(" scale = math.pow(2, exp)\n")
	b.WriteString(" elseif scale * 2 <= n then\n")
	b.WriteString(" exp = exp + 1\n")
	b.WriteString(" scale = math.pow(2, exp)\n")
	b.WriteString(" end\n")
	b.WriteString(" mant = math.floor((n / scale - 1) * 8388608 + 0.5)\n")
	b.WriteString(" if mant >= 8388608 then\n")
	b.WriteString(" exp = exp + 1\n")
	b.WriteString(" mant = 0\n")
	b.WriteString(" end\n")
	b.WriteString(" exp = exp + 127\n")
	b.WriteString(" if exp >= 255 then\n")
	b.WriteString(" return write_u32_le(sign + 2139095040)\n")
	b.WriteString(" end\n")
	b.WriteString(" end\n")
	b.WriteString(" return write_u32_le(sign + exp * 8388608 + mant)\n")
	b.WriteString("end\n\n")
}

// writeLuaFloat64Helpers emits read_f64_le/write_f64_le. They decode and
// encode IEEE 754 binary64 values by handling the low and high 32-bit words
// separately, with explicit cases for signed zero, subnormals, infinities and
// NaN.
func writeLuaFloat64Helpers(b *strings.Builder) {
	b.WriteString("local function read_f64_le(data, offset)\n")
	b.WriteString(" -- Read 8 bytes directly to avoid precision loss from 64-bit arithmetic\n")
	b.WriteString(" local b1, b2, b3, b4, b5, b6, b7, b8 = string.byte(data, offset, offset + 7)\n")
	b.WriteString(" if not b8 then error(\"arpack: buffer too short for f64\") end\n")
	b.WriteString(" local low = b1 + b2 * 256 + b3 * 65536 + b4 * 16777216\n")
	b.WriteString(" local high = b5 + b6 * 256 + b7 * 65536 + b8 * 16777216\n")
	b.WriteString(" -- Decode IEEE 754 double from low/high parts separately\n")
	b.WriteString(" if low == 0 and high == 0 then return 0.0, 8 end\n")
	b.WriteString(" local sign = (high >= 2147483648) and -1 or 1\n")
	b.WriteString(" if sign < 0 then high = high - 2147483648 end\n")
	b.WriteString(" local exp = math.floor(high / 1048576) % 2048\n")
	b.WriteString(" local high_mant = high % 1048576\n")
	b.WriteString(" if exp == 0 then\n")
	b.WriteString(" local mant = high_mant * 4294967296 + low\n")
	b.WriteString(" if mant == 0 then\n")
	b.WriteString(" return sign < 0 and (-1 / math.huge) or 0.0, 8\n")
	b.WriteString(" end\n")
	b.WriteString(" return sign * (mant / 4503599627370496) * math.pow(2, -1022), 8\n")
	b.WriteString(" elseif exp == 2047 then\n")
	b.WriteString(" if high_mant == 0 and low == 0 then return sign * math.huge, 8\n")
	b.WriteString(" else return 0.0 / 0.0, 8 end\n")
	b.WriteString(" end\n")
	b.WriteString(" local mant = high_mant * 4294967296 + low\n")
	b.WriteString(" return sign * (1 + mant / 4503599627370496) * math.pow(2, exp - 1023), 8\n")
	b.WriteString("end\n\n")

	b.WriteString("local function write_f64_le(n)\n")
	b.WriteString(" -- Handle special values\n")
	b.WriteString(" if n ~= n then -- NaN\n")
	b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 248, 127)\n")
	b.WriteString(" end\n")
	b.WriteString(" if n == math.huge then\n")
	b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 240, 127)\n")
	b.WriteString(" end\n")
	b.WriteString(" if n == -math.huge then\n")
	b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 240, 255)\n")
	b.WriteString(" end\n")
	b.WriteString(" -- Check for negative zero: 1/-0.0 == -math.huge\n")
	b.WriteString(" if n == 0 then\n")
	b.WriteString(" if 1/n == -math.huge then\n")
	b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 0, 128)\n")
	b.WriteString(" end\n")
	b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 0, 0)\n")
	b.WriteString(" end\n")
	b.WriteString(" local sign = 0\n")
	b.WriteString(" if n < 0 then sign = 2147483648; n = -n end\n")
	b.WriteString(" local exp\n")
	b.WriteString(" local mant\n")
	b.WriteString(" if n < math.pow(2, -1022) then\n")
	b.WriteString(" mant = math.floor(n / math.pow(2, -1074) + 0.5)\n")
	b.WriteString(" if mant <= 0 then\n")
	b.WriteString(" local high = sign\n")
	b.WriteString(" return string.char(0, 0, 0, 0, high % 256, math.floor(high / 256) % 256, math.floor(high / 65536) % 256, math.floor(high / 16777216) % 256)\n")
	b.WriteString(" end\n")
	b.WriteString(" if mant >= 4503599627370496 then\n")
	b.WriteString(" exp = 1\n")
	b.WriteString(" mant = 0\n")
	b.WriteString(" else\n")
	b.WriteString(" exp = 0\n")
	b.WriteString(" end\n")
	b.WriteString(" else\n")
	b.WriteString(" exp = math.floor(math.log(n, 2))\n")
	// Same exponent normalisation as the f32 writer. Without it, a double one
	// ulp below a power of two is encoded as that power of two, and values
	// near DBL_MAX (log2 rounding to 1024) are encoded as +inf.
	b.WriteString(" -- math.log can be off by one near powers of two; normalize so 2^exp <= n < 2^(exp+1)\n")
	b.WriteString(" local scale = math.pow(2, exp)\n")
	b.WriteString(" if scale > n then\n")
	b.WriteString(" exp = exp - 1\n")
	b.WriteString(" scale = math.pow(2, exp)\n")
	b.WriteString(" elseif scale * 2 <= n then\n")
	b.WriteString(" exp = exp + 1\n")
	b.WriteString(" scale = math.pow(2, exp)\n")
	b.WriteString(" end\n")
	b.WriteString(" mant = (n / scale - 1) * 4503599627370496\n")
	b.WriteString(" mant = math.floor(mant + 0.5)\n")
	b.WriteString(" if mant >= 4503599627370496 then\n")
	b.WriteString(" exp = exp + 1\n")
	b.WriteString(" mant = 0\n")
	b.WriteString(" end\n")
	b.WriteString(" exp = exp + 1023\n")
	b.WriteString(" if exp >= 2047 then\n")
	b.WriteString(" exp = 2047\n")
	b.WriteString(" mant = 0\n")
	b.WriteString(" end\n")
	b.WriteString(" end\n")
	b.WriteString(" local high_mant = math.floor(mant / 4294967296)\n")
	b.WriteString(" local low = mant % 4294967296\n")
	b.WriteString(" local high = sign + exp * 1048576 + high_mant\n")
	b.WriteString(" return string.char(\n")
	b.WriteString(" low % 256,\n")
	b.WriteString(" math.floor(low / 256) % 256,\n")
	b.WriteString(" math.floor(low / 65536) % 256,\n")
	b.WriteString(" math.floor(low / 16777216) % 256,\n")
	b.WriteString(" high % 256,\n")
	b.WriteString(" math.floor(high / 256) % 256,\n")
	b.WriteString(" math.floor(high / 65536) % 256,\n")
	b.WriteString(" math.floor(high / 16777216) % 256\n")
	b.WriteString(" )\n")
	b.WriteString("end\n\n")
}

// writeLuaBoolStringHelpers emits read_bool/write_bool and
// read_string/write_string. A bool is one byte, and any non-zero byte reads as
// true. A string is a u16 little-endian byte-length prefix followed by the raw
// bytes.
func writeLuaBoolStringHelpers(b *strings.Builder) {
	b.WriteString("local function read_bool(data, offset)\n")
	b.WriteString(" if offset > #data then error(\"arpack: buffer too short for bool\") end\n")
	b.WriteString(" return string.byte(data, offset) ~= 0, 1\n")
	b.WriteString("end\n\n")
	b.WriteString("local function write_bool(v)\n")
	b.WriteString(" return string.char(v and 1 or 0)\n")
	b.WriteString("end\n\n")

	b.WriteString("local function read_string(data, offset)\n")
	b.WriteString(" local len, header_bytes = read_u16_le(data, offset)\n")
	b.WriteString(" if len == 0 then return '', header_bytes end\n")
	b.WriteString(" local available = #data - offset - header_bytes + 1\n")
	b.WriteString(" if available < len then\n")
	b.WriteString(" error(string.format(\"arpack: buffer too short for string: need %d bytes, have %d\", len, available))\n")
	b.WriteString(" end\n")
	b.WriteString(" -- string.sub is 1-based; data starts at offset + header_bytes\n")
	b.WriteString(" return string.sub(data, offset + header_bytes, offset + header_bytes + len - 1), header_bytes + len\n")
	b.WriteString("end\n\n")

	// The length prefix is a u16. Without this guard, a longer string reaches
	// string.char with an out-of-range high byte and fails with an opaque Lua
	// error instead of an arpack one.
	b.WriteString("local function write_string(s)\n")
	b.WriteString(" local len = #s\n")
	b.WriteString(" if len > 65535 then\n")
	b.WriteString(" error(string.format(\"arpack: string too long: %d bytes exceeds u16 length prefix (max 65535)\", len))\n")
	b.WriteString(" end\n")
	b.WriteString(" return write_u16_le(len) .. s\n")
	b.WriteString("end\n\n")
}

// writeLuaEnum emits `M.<Enum> = { Name = Value, ... }`, one entry per line.
// Names and values are copied verbatim from the schema.
func writeLuaEnum(b *strings.Builder, enum parser.Enum) {
	fmt.Fprintf(b, "M.%s = {\n", enum.Name)
	for i, value := range enum.Values {
		fmt.Fprintf(b, " %s = %s", value.Name, value.Value)
		if i < len(enum.Values)-1 {
			b.WriteString(",")
		}
		b.WriteString("\n")
	}
	b.WriteString("}\n")
}

// writeLuaConstructor emits M.new_<snake_name>(). It returns a table holding
// every field of msg, each set to the value produced by luaDefaultValue.
func writeLuaConstructor(b *strings.Builder, msg parser.Message, enumNames map[string]struct{}) {
	fmt.Fprintf(b, "function M.new_%s()\n", toSnakeCase(msg.Name))
	b.WriteString(" return {\n")
	for _, f := range msg.Fields {
		defaultValue := luaDefaultValue(f, enumNames)
		fmt.Fprintf(b, " %s = %s,\n", luaFieldName(f.Name), defaultValue)
	}
	b.WriteString(" }\n")
	b.WriteString("end\n")
}

// writeLuaSerializer emits M.serialize_<snake_name>(msg), which returns
// table.concat(parts).
//
// The fields are split by segmentFields. A single field is emitted by
// writeLuaSerializeField; a group of bools is emitted by
// writeLuaBoolGroupSerialize, with the segment index keeping its temporary
// variable name unique. The `parts`/`part_idx` locals are presumably filled by
// those field writers (not visible here).
func writeLuaSerializer(b *strings.Builder, msg parser.Message, enumNames map[string]struct{}) error {
	segs := segmentFields(msg.Fields)
	fmt.Fprintf(b, "function M.serialize_%s(msg)\n", toSnakeCase(msg.Name))
	b.WriteString(" local parts = {}\n")
	b.WriteString(" local part_idx = 0\n")
	for i, seg := range segs {
		if seg.single != nil {
			if err := writeLuaSerializeField(b, "msg", *seg.single, " ", enumNames); err != nil {
				return err
			}
		} else {
			writeLuaBoolGroupSerialize(b, "msg", seg.bools, i, " ")
		}
	}
	b.WriteString(" return table.concat(parts)\n")
	b.WriteString("end\n")
	return nil
}

// writeLuaDeserializer emits M.deserialize_<snake_name>(data, offset).
// offset defaults to 1, and decoding starts from M.new_<snake_name>(). Input
// shorter than the packed minimum wire size raises an error. The function
// returns the message and the number of bytes consumed.
func writeLuaDeserializer(b *strings.Builder, msg parser.Message, enumNames map[string]struct{}) error {
	segs := segmentFields(msg.Fields)
	minSize :=
packedMinWireSize(msg.Fields) fmt.Fprintf(b, "function M.deserialize_%s(data, offset)\n", toSnakeCase(msg.Name)) b.WriteString(" offset = offset or 1\n") fmt.Fprintf(b, " local msg = M.new_%s()\n", toSnakeCase(msg.Name)) b.WriteString(" local start_offset = offset\n") b.WriteString(" local bytes_read = 0\n") fmt.Fprintf(b, " if #data < offset + %d - 1 then\n", minSize) fmt.Fprintf(b, " error(\"arpack: buffer too short for %s\")\n", msg.Name) b.WriteString(" end\n") for i, seg := range segs { if seg.single != nil { if err := writeLuaDeserializeField(b, "msg", *seg.single, " ", enumNames); err != nil { return err } } else { writeLuaBoolGroupDeserialize(b, "msg", seg.bools, i, " ") } } b.WriteString(" return msg, offset - start_offset\n") b.WriteString("end\n") return nil } func writeLuaBoolGroupSerialize(b *strings.Builder, recv string, bools []parser.Field, groupIdx int, indent string) { varName := fmt.Sprintf("_bool_byte_%d", groupIdx) fmt.Fprintf(b, "%slocal %s = 0\n", indent, varName) for bit, f := range bools { fmt.Fprintf(b, "%sif %s.%s then %s = bit.bor(%s, %d) end\n", indent, recv, luaFieldName(f.Name), varName, varName, 1<= 'A' && c <= 'Z' // Add underscore before uppercase letter if: // - It's not the first character // - Previous character was lowercase, OR // - Previous character was uppercase AND next character (if exists) is lowercase // (this handles cases like "PlayerID" -> "player_id", not "player_i_d") if i > 0 && isUpper { nextLower := false if i+1 < len(s) { nextChar := rune(s[i+1]) nextLower = nextChar >= 'a' && nextChar <= 'z' } if !prevUpper || nextLower { b.WriteByte('_') } } b.WriteRune(c) prevUpper = isUpper } return strings.ToLower(b.String()) } func luaDefaultValue(f parser.Field, enumNames map[string]struct{}) string { switch f.Kind { case parser.KindPrimitive: if luaIsEnumType(f, enumNames) { return "0" } switch f.Primitive { case parser.KindFloat32, parser.KindFloat64, parser.KindInt8, parser.KindInt16, parser.KindInt32, 
parser.KindUint8, parser.KindUint16, parser.KindUint32, parser.KindInt64, parser.KindUint64: return "0" case parser.KindBool: return "false" case parser.KindString: return "''" } case parser.KindNested: return fmt.Sprintf("M.new_%s()", toSnakeCase(f.TypeName)) case parser.KindFixedArray, parser.KindSlice: return "{}" } return "nil" } func luaSerializeValueExpr(access string, f parser.Field, enumNames map[string]struct{}) string { if !luaIsEnumType(f, enumNames) { return access } return access } func luaDeserializeValueExpr(expr string, f parser.Field, enumNames map[string]struct{}) string { if !luaIsEnumType(f, enumNames) { return expr } return expr } func luaIsEnumType(f parser.Field, enumNames map[string]struct{}) bool { if f.NamedType == "" || enumNames == nil { return false } _, ok := enumNames[f.NamedType] return ok } func sanitizeLuaVarName(s string) string { var b strings.Builder for _, c := range s { if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') { b.WriteRune(c) } else { b.WriteRune('_') } } return strings.ToLower(b.String()) } func checkLuaUnsupportedTypes(schema parser.Schema) error { for _, msg := range schema.Messages { for _, f := range msg.Fields { if err := checkFieldLuaSupport(msg.Name, f); err != nil { return err } } } return nil } func checkFieldLuaSupport(msgName string, f parser.Field) error { switch f.Kind { case parser.KindPrimitive: if f.Primitive == parser.KindInt64 || f.Primitive == parser.KindUint64 { return fmt.Errorf("lua target does not support int64/uint64: field %s in message %s (LuaJIT/Defold uses double-precision floats which can only safely represent integers up to 2^53)", f.Name, msgName) } case parser.KindFixedArray, parser.KindSlice: if f.Elem != nil { return checkFieldLuaSupport(msgName, *f.Elem) } } return nil }