package generator
import (
"fmt"
"strings"
"github.com/edmand46/arpack/parser"
)
// GenerateLuaSchema renders schema as the source text of a single Lua module.
//
// The emitted module requires the BitOp library ('bit') and defines local
// encoding helpers. It then populates and returns a table M holding:
//   - one table per enum;
//   - M.new_<msg>, M.serialize_<msg> and M.deserialize_<msg> for each message.
//
// It fails if the schema uses a type the Lua target cannot represent, or if
// emitting any serializer/deserializer fails. In the latter case the error is
// wrapped with the offending message name.
//
// moduleName is currently not referenced by the generator.
func GenerateLuaSchema(schema parser.Schema, moduleName string) ([]byte, error) {
	if err := checkLuaUnsupportedTypes(schema); err != nil {
		return nil, err
	}

	// Set of enum type names, used to recognise enum-typed fields.
	enums := make(map[string]struct{}, len(schema.Enums))
	for _, e := range schema.Enums {
		enums[e.Name] = struct{}{}
	}

	var sb strings.Builder
	sb.WriteString("-- arpack \n" +
		"-- Code generated by arpack. DO NOT EDIT.\n\n" +
		"local M = {}\n\n" +
		"-- Load BitOp library for bit operations (Defold/LuaJIT)\n" +
		"local bit = require('bit')\n\n")
	writeLuaHelpers(&sb)

	for _, e := range schema.Enums {
		writeLuaEnum(&sb, e)
		sb.WriteString("\n")
	}
	for _, m := range schema.Messages {
		writeLuaConstructor(&sb, m, enums)
		sb.WriteString("\n")
	}

	// Emit all serializers first, then all deserializers, one pass each.
	emitters := []func(*strings.Builder, parser.Message, map[string]struct{}) error{
		writeLuaSerializer,
		writeLuaDeserializer,
	}
	for _, emit := range emitters {
		for _, m := range schema.Messages {
			if err := emit(&sb, m, enums); err != nil {
				return nil, fmt.Errorf("message %s: %w", m.Name, err)
			}
			sb.WriteString("\n")
		}
	}

	sb.WriteString("return M\n")
	return []byte(sb.String()), nil
}
// writeLuaHelpers emits the file-local Lua helper functions used by the
// generated serializers and deserializers.
//
// Helper conventions:
//   - Every read_* helper takes (data, offset), where offset is a 1-based
//     index into the Lua string data, and returns (value, bytes_consumed).
//   - Every write_* helper returns the encoded bytes as a Lua string.
//   - Multi-byte integers are little-endian.
//   - Floats are IEEE 754 binary32/binary64.
//   - Strings carry a u16 little-endian length prefix.
//
// The emitted code only uses Lua 5.1 standard-library functions (string.*,
// math.floor, math.pow, math.frexp, error), so it behaves the same on LuaJIT
// and on a plain Lua 5.1 interpreter.
func writeLuaHelpers(b *strings.Builder) {
	b.WriteString("-- Inline helpers for little-endian byte operations\n\n")
	b.WriteString("-- Error handling for truncated data\n")
	b.WriteString("local function check_bounds(data, offset, needed, context)\n")
	b.WriteString(" local available = #data - offset + 1\n")
	b.WriteString(" if available < needed then\n")
	b.WriteString(" error(string.format(\"arpack: buffer too short for %s: need %d bytes, have %d\", context, needed, available))\n")
	b.WriteString(" end\n")
	b.WriteString("end\n\n")

	// --- unsigned integers ---
	b.WriteString("local function read_u8(data, offset)\n")
	b.WriteString(" if offset > #data then error(\"arpack: buffer too short for u8\") end\n")
	b.WriteString(" return string.byte(data, offset), 1\n")
	b.WriteString("end\n\n")
	b.WriteString("local function write_u8(n)\n")
	b.WriteString(" return string.char(n)\n")
	b.WriteString("end\n\n")
	b.WriteString("local function read_u16_le(data, offset)\n")
	b.WriteString(" local b1, b2 = string.byte(data, offset, offset + 1)\n")
	b.WriteString(" if not b2 then error(\"arpack: buffer too short for u16\") end\n")
	b.WriteString(" return b1 + b2 * 256, 2\n")
	b.WriteString("end\n\n")
	b.WriteString("local function write_u16_le(n)\n")
	b.WriteString(" return string.char(n % 256, math.floor(n / 256))\n")
	b.WriteString("end\n\n")
	b.WriteString("local function read_u32_le(data, offset)\n")
	b.WriteString(" local b1, b2, b3, b4 = string.byte(data, offset, offset + 3)\n")
	b.WriteString(" if not b4 then error(\"arpack: buffer too short for u32\") end\n")
	b.WriteString(" return b1 + b2 * 256 + b3 * 65536 + b4 * 16777216, 4\n")
	b.WriteString("end\n\n")
	b.WriteString("local function write_u32_le(n)\n")
	b.WriteString(" return string.char(\n")
	b.WriteString(" n % 256,\n")
	b.WriteString(" math.floor(n / 256) % 256,\n")
	b.WriteString(" math.floor(n / 65536) % 256,\n")
	b.WriteString(" math.floor(n / 16777216) % 256\n")
	b.WriteString(" )\n")
	b.WriteString("end\n\n")

	// --- signed integers (two's complement on top of the unsigned helpers) ---
	b.WriteString("local function read_i8(data, offset)\n")
	b.WriteString(" if offset > #data then error(\"arpack: buffer too short for i8\") end\n")
	b.WriteString(" local v = string.byte(data, offset)\n")
	b.WriteString(" if v >= 128 then v = v - 256 end\n")
	b.WriteString(" return v, 1\n")
	b.WriteString("end\n\n")
	b.WriteString("local function write_i8(n)\n")
	b.WriteString(" if n < 0 then n = n + 256 end\n")
	b.WriteString(" return string.char(n)\n")
	b.WriteString("end\n\n")
	b.WriteString("local function read_i16_le(data, offset)\n")
	b.WriteString(" local v = read_u16_le(data, offset)\n")
	b.WriteString(" if v >= 32768 then v = v - 65536 end\n")
	b.WriteString(" return v, 2\n")
	b.WriteString("end\n\n")
	b.WriteString("local function write_i16_le(n)\n")
	b.WriteString(" if n < 0 then n = n + 65536 end\n")
	b.WriteString(" return write_u16_le(n)\n")
	b.WriteString("end\n\n")
	b.WriteString("local function read_i32_le(data, offset)\n")
	b.WriteString(" local v = read_u32_le(data, offset)\n")
	b.WriteString(" if v >= 2147483648 then v = v - 4294967296 end\n")
	b.WriteString(" return v, 4\n")
	b.WriteString("end\n\n")
	b.WriteString("local function write_i32_le(n)\n")
	b.WriteString(" if n < 0 then n = n + 4294967296 end\n")
	b.WriteString(" return write_u32_le(n)\n")
	b.WriteString("end\n\n")

	// --- binary32 ---
	b.WriteString("local function read_f32_le(data, offset)\n")
	b.WriteString(" local u32 = read_u32_le(data, offset)\n")
	b.WriteString(" if u32 == 0 then return 0.0, 4 end\n")
	b.WriteString(" local sign = (u32 >= 2147483648) and -1 or 1\n")
	b.WriteString(" if sign < 0 then u32 = u32 - 2147483648 end\n")
	b.WriteString(" local exp = math.floor(u32 / 8388608) % 256\n")
	b.WriteString(" local mant = u32 % 8388608\n")
	b.WriteString(" if exp == 0 then\n")
	b.WriteString(" if mant == 0 then\n")
	b.WriteString(" return sign < 0 and (-1 / math.huge) or 0.0, 4\n")
	b.WriteString(" end\n")
	b.WriteString(" return sign * (mant / 8388608) * math.pow(2, -126), 4\n")
	b.WriteString(" elseif exp == 255 then\n")
	b.WriteString(" if mant == 0 then return sign * math.huge, 4\n")
	b.WriteString(" else return 0.0 / 0.0, 4 end\n")
	b.WriteString(" end\n")
	b.WriteString(" return sign * (1 + mant / 8388608) * math.pow(2, exp - 127), 4\n")
	b.WriteString("end\n\n")
	b.WriteString("local function write_f32_le(n)\n")
	b.WriteString(" if n ~= n then return write_u32_le(2143289344) end\n")
	b.WriteString(" if n == math.huge then return write_u32_le(2139095040) end\n")
	b.WriteString(" if n == -math.huge then return write_u32_le(4286578688) end\n")
	b.WriteString(" -- Check for negative zero: 1/-0.0 == -math.huge\n")
	b.WriteString(" if n == 0 then\n")
	b.WriteString(" if 1/n == -math.huge then return write_u32_le(2147483648) end\n")
	b.WriteString(" return write_u32_le(0)\n")
	b.WriteString(" end\n")
	b.WriteString(" local sign = 0\n")
	b.WriteString(" if n < 0 then sign = 2147483648; n = -n end\n")
	b.WriteString(" local exp\n")
	b.WriteString(" local mant\n")
	b.WriteString(" if n < math.pow(2, -126) then\n")
	b.WriteString(" mant = math.floor(n / math.pow(2, -149) + 0.5)\n")
	b.WriteString(" if mant <= 0 then return write_u32_le(sign) end\n")
	b.WriteString(" if mant >= 8388608 then\n")
	b.WriteString(" exp = 1\n")
	b.WriteString(" mant = 0\n")
	b.WriteString(" else\n")
	b.WriteString(" exp = 0\n")
	b.WriteString(" end\n")
	b.WriteString(" else\n")
	// Exponent via math.frexp (n = frac * 2^e2, 0.5 <= frac < 1) rather than
	// math.floor(math.log(n, 2)). Two problems with math.log:
	//   - Plain Lua 5.1 ignores the base argument and returns ln(n).
	//   - Elsewhere it can round up to an integer for n just below a power of
	//     two, which corrupts the exponent/mantissa split.
	// frexp is exact, and frac*2 equals n / 2^exp exactly.
	b.WriteString(" local frac, e2 = math.frexp(n)\n")
	b.WriteString(" exp = e2 - 1\n")
	b.WriteString(" mant = math.floor((frac * 2 - 1) * 8388608 + 0.5)\n")
	b.WriteString(" if mant >= 8388608 then\n")
	b.WriteString(" exp = exp + 1\n")
	b.WriteString(" mant = 0\n")
	b.WriteString(" end\n")
	b.WriteString(" exp = exp + 127\n")
	b.WriteString(" if exp >= 255 then\n")
	b.WriteString(" return write_u32_le(sign + 2139095040)\n")
	b.WriteString(" end\n")
	b.WriteString(" end\n")
	b.WriteString(" return write_u32_le(sign + exp * 8388608 + mant)\n")
	b.WriteString("end\n\n")

	// --- binary64 (split into two 32-bit halves to stay exact in doubles) ---
	b.WriteString("local function read_f64_le(data, offset)\n")
	b.WriteString(" -- Read 8 bytes directly to avoid precision loss from 64-bit arithmetic\n")
	b.WriteString(" local b1, b2, b3, b4, b5, b6, b7, b8 = string.byte(data, offset, offset + 7)\n")
	b.WriteString(" if not b8 then error(\"arpack: buffer too short for f64\") end\n")
	b.WriteString(" local low = b1 + b2 * 256 + b3 * 65536 + b4 * 16777216\n")
	b.WriteString(" local high = b5 + b6 * 256 + b7 * 65536 + b8 * 16777216\n")
	b.WriteString(" -- Decode IEEE 754 double from low/high parts separately\n")
	b.WriteString(" if low == 0 and high == 0 then return 0.0, 8 end\n")
	b.WriteString(" local sign = (high >= 2147483648) and -1 or 1\n")
	b.WriteString(" if sign < 0 then high = high - 2147483648 end\n")
	b.WriteString(" local exp = math.floor(high / 1048576) % 2048\n")
	b.WriteString(" local high_mant = high % 1048576\n")
	b.WriteString(" if exp == 0 then\n")
	b.WriteString(" local mant = high_mant * 4294967296 + low\n")
	b.WriteString(" if mant == 0 then\n")
	b.WriteString(" return sign < 0 and (-1 / math.huge) or 0.0, 8\n")
	b.WriteString(" end\n")
	b.WriteString(" return sign * (mant / 4503599627370496) * math.pow(2, -1022), 8\n")
	b.WriteString(" elseif exp == 2047 then\n")
	b.WriteString(" if high_mant == 0 and low == 0 then return sign * math.huge, 8\n")
	b.WriteString(" else return 0.0 / 0.0, 8 end\n")
	b.WriteString(" end\n")
	b.WriteString(" local mant = high_mant * 4294967296 + low\n")
	b.WriteString(" return sign * (1 + mant / 4503599627370496) * math.pow(2, exp - 1023), 8\n")
	b.WriteString("end\n\n")
	b.WriteString("local function write_f64_le(n)\n")
	b.WriteString(" -- Handle special values\n")
	b.WriteString(" if n ~= n then -- NaN\n")
	b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 248, 127)\n")
	b.WriteString(" end\n")
	b.WriteString(" if n == math.huge then\n")
	b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 240, 127)\n")
	b.WriteString(" end\n")
	b.WriteString(" if n == -math.huge then\n")
	b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 240, 255)\n")
	b.WriteString(" end\n")
	b.WriteString(" -- Check for negative zero: 1/-0.0 == -math.huge\n")
	b.WriteString(" if n == 0 then\n")
	b.WriteString(" if 1/n == -math.huge then\n")
	b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 0, 128)\n")
	b.WriteString(" end\n")
	b.WriteString(" return string.char(0, 0, 0, 0, 0, 0, 0, 0)\n")
	b.WriteString(" end\n")
	b.WriteString(" local sign = 0\n")
	b.WriteString(" if n < 0 then sign = 2147483648; n = -n end\n")
	b.WriteString(" local exp\n")
	b.WriteString(" local mant\n")
	b.WriteString(" if n < math.pow(2, -1022) then\n")
	b.WriteString(" mant = math.floor(n / math.pow(2, -1074) + 0.5)\n")
	b.WriteString(" if mant <= 0 then\n")
	b.WriteString(" local high = sign\n")
	b.WriteString(" return string.char(0, 0, 0, 0, high % 256, math.floor(high / 256) % 256, math.floor(high / 65536) % 256, math.floor(high / 16777216) % 256)\n")
	b.WriteString(" end\n")
	b.WriteString(" if mant >= 4503599627370496 then\n")
	b.WriteString(" exp = 1\n")
	b.WriteString(" mant = 0\n")
	b.WriteString(" else\n")
	b.WriteString(" exp = 0\n")
	b.WriteString(" end\n")
	b.WriteString(" else\n")
	// Same exact-exponent technique as write_f32_le. For a normal double,
	// (frac*2 - 1) * 2^52 is already an exact integer.
	b.WriteString(" local frac, e2 = math.frexp(n)\n")
	b.WriteString(" exp = e2 - 1\n")
	b.WriteString(" mant = math.floor((frac * 2 - 1) * 4503599627370496 + 0.5)\n")
	b.WriteString(" if mant >= 4503599627370496 then\n")
	b.WriteString(" exp = exp + 1\n")
	b.WriteString(" mant = 0\n")
	b.WriteString(" end\n")
	b.WriteString(" exp = exp + 1023\n")
	b.WriteString(" if exp >= 2047 then\n")
	b.WriteString(" exp = 2047\n")
	b.WriteString(" mant = 0\n")
	b.WriteString(" end\n")
	b.WriteString(" end\n")
	b.WriteString(" local high_mant = math.floor(mant / 4294967296)\n")
	b.WriteString(" local low = mant % 4294967296\n")
	b.WriteString(" local high = sign + exp * 1048576 + high_mant\n")
	b.WriteString(" return string.char(\n")
	b.WriteString(" low % 256,\n")
	b.WriteString(" math.floor(low / 256) % 256,\n")
	b.WriteString(" math.floor(low / 65536) % 256,\n")
	b.WriteString(" math.floor(low / 16777216) % 256,\n")
	b.WriteString(" high % 256,\n")
	b.WriteString(" math.floor(high / 256) % 256,\n")
	b.WriteString(" math.floor(high / 65536) % 256,\n")
	b.WriteString(" math.floor(high / 16777216) % 256\n")
	b.WriteString(" )\n")
	b.WriteString("end\n\n")

	// --- bool and string ---
	b.WriteString("local function read_bool(data, offset)\n")
	b.WriteString(" if offset > #data then error(\"arpack: buffer too short for bool\") end\n")
	b.WriteString(" return string.byte(data, offset) ~= 0, 1\n")
	b.WriteString("end\n\n")
	b.WriteString("local function write_bool(v)\n")
	b.WriteString(" return string.char(v and 1 or 0)\n")
	b.WriteString("end\n\n")
	b.WriteString("local function read_string(data, offset)\n")
	b.WriteString(" local len, header_bytes = read_u16_le(data, offset)\n")
	b.WriteString(" if len == 0 then return '', header_bytes end\n")
	b.WriteString(" local available = #data - offset - header_bytes + 1\n")
	b.WriteString(" if available < len then\n")
	b.WriteString(" error(string.format(\"arpack: buffer too short for string: need %d bytes, have %d\", len, available))\n")
	b.WriteString(" end\n")
	b.WriteString(" -- string.sub is 1-based; data starts at offset + header_bytes\n")
	b.WriteString(" return string.sub(data, offset + header_bytes, offset + header_bytes + len - 1), header_bytes + len\n")
	b.WriteString("end\n\n")
	b.WriteString("local function write_string(s)\n")
	b.WriteString(" local len = #s\n")
	// The u16 prefix cannot describe longer strings. Fail with a clear arpack
	// error instead of the opaque "invalid value" from string.char inside
	// write_u16_le.
	b.WriteString(" if len > 65535 then\n")
	b.WriteString(" error(string.format(\"arpack: string too long: %d bytes exceeds u16 length prefix (max 65535)\", len))\n")
	b.WriteString(" end\n")
	b.WriteString(" return write_u16_le(len) .. s\n")
	b.WriteString("end\n\n")
}
// writeLuaEnum emits enum as a Lua table literal assigned to M.<EnumName>.
// Each value becomes one `name = value` entry on its own line, with entries
// separated by commas. Names and values are written verbatim.
func writeLuaEnum(b *strings.Builder, enum parser.Enum) {
	fmt.Fprintf(b, "M.%s = {\n", enum.Name)
	for idx, v := range enum.Values {
		// Put the separator before every entry except the first, so the
		// last entry gets no trailing comma.
		if idx > 0 {
			b.WriteString(",\n")
		}
		fmt.Fprintf(b, " %s = %s", v.Name, v.Value)
	}
	if len(enum.Values) > 0 {
		b.WriteString("\n")
	}
	b.WriteString("}\n")
}
// writeLuaConstructor emits `function M.new_<snake_name>()`. The generated
// function returns a fresh table with every field of msg set to its Lua
// default value, as chosen by luaDefaultValue.
func writeLuaConstructor(b *strings.Builder, msg parser.Message, enumNames map[string]struct{}) {
	b.WriteString("function M.new_" + toSnakeCase(msg.Name) + "()\n")
	b.WriteString(" return {\n")
	for _, field := range msg.Fields {
		def := luaDefaultValue(field, enumNames)
		name := luaFieldName(field.Name)
		fmt.Fprintf(b, " %s = %s,\n", name, def)
	}
	b.WriteString(" }\n" + "end\n")
}
// writeLuaSerializer emits `function M.serialize_<snake_name>(msg)`.
//
// The generated function declares a `parts` table and a `part_idx` counter,
// runs the per-segment field code, and returns table.concat(parts). The
// field code presumably appends encoded chunks to `parts` (not visible here).
// Runs of bool fields are delegated to writeLuaBoolGroupSerialize; any
// other field goes to writeLuaSerializeField, whose error is returned as-is.
func writeLuaSerializer(b *strings.Builder, msg parser.Message, enumNames map[string]struct{}) error {
	segments := segmentFields(msg.Fields)
	const indent = " "

	fmt.Fprintf(b, "function M.serialize_%s(msg)\n", toSnakeCase(msg.Name))
	b.WriteString(" local parts = {}\n" +
		" local part_idx = 0\n")
	for idx, seg := range segments {
		if seg.single == nil {
			writeLuaBoolGroupSerialize(b, "msg", seg.bools, idx, indent)
			continue
		}
		if err := writeLuaSerializeField(b, "msg", *seg.single, indent, enumNames); err != nil {
			return err
		}
	}
	b.WriteString(" return table.concat(parts)\n" +
		"end\n")
	return nil
}
// writeLuaDeserializer emits `function M.deserialize_<snake_name>(data, offset)`.
//
// The generated function:
//   - defaults offset to 1 (Lua strings are 1-based);
//   - starts from M.new_<snake_name>();
//   - raises a Lua error when fewer than packedMinWireSize(msg.Fields) bytes
//     remain;
//   - decodes each segment;
//   - returns (msg, offset - start_offset).
//
// The returned byte count is only meaningful if the per-field code advances
// the Lua local `offset`; that code is emitted elsewhere (TODO confirm).
// Errors from writeLuaDeserializeField are returned as-is.
func writeLuaDeserializer(b *strings.Builder, msg parser.Message, enumNames map[string]struct{}) error {
	segments := segmentFields(msg.Fields)
	minSize := packedMinWireSize(msg.Fields)
	snake := toSnakeCase(msg.Name)
	const indent = " "

	fmt.Fprintf(b, "function M.deserialize_%s(data, offset)\n", snake)
	b.WriteString(" offset = offset or 1\n")
	fmt.Fprintf(b, " local msg = M.new_%s()\n", snake)
	b.WriteString(" local start_offset = offset\n" +
		" local bytes_read = 0\n")

	// Up-front guard: the last required byte is at offset + minSize - 1.
	fmt.Fprintf(b, " if #data < offset + %d - 1 then\n", minSize)
	fmt.Fprintf(b, " error(\"arpack: buffer too short for %s\")\n", msg.Name)
	b.WriteString(" end\n")

	for idx, seg := range segments {
		if seg.single == nil {
			writeLuaBoolGroupDeserialize(b, "msg", seg.bools, idx, indent)
			continue
		}
		if err := writeLuaDeserializeField(b, "msg", *seg.single, indent, enumNames); err != nil {
			return err
		}
	}
	b.WriteString(" return msg, offset - start_offset\n" +
		"end\n")
	return nil
}
func writeLuaBoolGroupSerialize(b *strings.Builder, recv string, bools []parser.Field, groupIdx int, indent string) {
varName := fmt.Sprintf("_bool_byte_%d", groupIdx)
fmt.Fprintf(b, "%slocal %s = 0\n", indent, varName)
for bit, f := range bools {
fmt.Fprintf(b, "%sif %s.%s then %s = bit.bor(%s, %d) end\n",
indent, recv, luaFieldName(f.Name), varName, varName, 1<= 'A' && c <= 'Z'
// Add underscore before uppercase letter if:
// - It's not the first character
// - Previous character was lowercase, OR
// - Previous character was uppercase AND next character (if exists) is lowercase
// (this handles cases like "PlayerID" -> "player_id", not "player_i_d")
if i > 0 && isUpper {
nextLower := false
if i+1 < len(s) {
nextChar := rune(s[i+1])
nextLower = nextChar >= 'a' && nextChar <= 'z'
}
if !prevUpper || nextLower {
b.WriteByte('_')
}
}
b.WriteRune(c)
prevUpper = isUpper
}
return strings.ToLower(b.String())
}
// luaDefaultValue returns the Lua source expression used to initialise field
// f in a generated constructor.
//
//   - Nested messages get a freshly constructed message table.
//   - Arrays and slices get an empty table.
//   - Enums and numeric primitives get 0, bools get false, strings get ''.
//   - Anything else yields nil.
func luaDefaultValue(f parser.Field, enumNames map[string]struct{}) string {
	switch f.Kind {
	case parser.KindNested:
		return "M.new_" + toSnakeCase(f.TypeName) + "()"
	case parser.KindFixedArray, parser.KindSlice:
		return "{}"
	case parser.KindPrimitive:
		if luaIsEnumType(f, enumNames) {
			return "0"
		}
		switch f.Primitive {
		case parser.KindBool:
			return "false"
		case parser.KindString:
			return "''"
		case parser.KindInt8, parser.KindInt16, parser.KindInt32, parser.KindInt64,
			parser.KindUint8, parser.KindUint16, parser.KindUint32, parser.KindUint64,
			parser.KindFloat32, parser.KindFloat64:
			return "0"
		}
	}
	return "nil"
}
// luaSerializeValueExpr returns the Lua expression that produces the wire
// value of field f, given the Lua expression access that reads the field.
//
// The expression is always returned unchanged. Enum-typed fields need no
// conversion: writeLuaEnum emits enum members as bare `name = value` entries,
// and luaDefaultValue initialises enum fields to 0, so an enum field already
// holds its underlying value. The original code branched on
// luaIsEnumType, but both branches returned access.
//
// f and enumNames are kept so existing callers compile unchanged.
func luaSerializeValueExpr(access string, f parser.Field, enumNames map[string]struct{}) string {
	return access
}
// luaDeserializeValueExpr returns the Lua expression that turns a decoded
// wire value expr into the value stored in field f.
//
// The expression is always returned unchanged. Enum fields are stored as
// their plain underlying values (writeLuaEnum emits bare `name = value`
// entries), so nothing needs converting. The original code branched on
// luaIsEnumType, but both branches returned expr.
//
// f and enumNames are kept so existing callers compile unchanged.
func luaDeserializeValueExpr(expr string, f parser.Field, enumNames map[string]struct{}) string {
	return expr
}
// luaIsEnumType reports whether f is declared with a named type listed in
// enumNames. Fields without a named type are never enums. A nil enumNames
// behaves like an empty set.
func luaIsEnumType(f parser.Field, enumNames map[string]struct{}) bool {
	if f.NamedType == "" {
		return false
	}
	// Indexing a nil map is legal in Go and simply reports "not found".
	_, found := enumNames[f.NamedType]
	return found
}
// sanitizeLuaVarName maps s to a lowercase string made only of ASCII letters,
// digits and underscores.
//
// Each rune that is not an ASCII letter or digit becomes one '_'. This
// includes multi-byte characters and each invalid UTF-8 byte. Uppercase ASCII
// letters are lowered.
func sanitizeLuaVarName(s string) string {
	return strings.Map(func(r rune) rune {
		switch {
		case r >= 'a' && r <= 'z', r >= '0' && r <= '9':
			return r
		case r >= 'A' && r <= 'Z':
			return r + ('a' - 'A')
		default:
			return '_'
		}
	}, s)
}
// checkLuaUnsupportedTypes validates every field of every message against the
// Lua target. It returns the first incompatibility found, or nil if the whole
// schema can be generated.
func checkLuaUnsupportedTypes(schema parser.Schema) error {
	for _, m := range schema.Messages {
		for i := range m.Fields {
			err := checkFieldLuaSupport(m.Name, m.Fields[i])
			if err != nil {
				return err
			}
		}
	}
	return nil
}
// checkFieldLuaSupport reports whether field f of message msgName can be
// represented by the Lua target.
//
// 64-bit integers (int64/uint64) are rejected, whether used directly or as
// the element type of arrays and slices, including nested ones. All other
// kinds are accepted.
//
// The error always names the top-level message field. The original version
// reported the element descriptor's own Name, which may be empty for
// array/slice elements, producing messages like "field  in message X".
func checkFieldLuaSupport(msgName string, f parser.Field) error {
	// Unwrap array/slice layers down to the innermost element type.
	elem := f
	for elem.Kind == parser.KindFixedArray || elem.Kind == parser.KindSlice {
		if elem.Elem == nil {
			return nil
		}
		elem = *elem.Elem
	}

	if elem.Kind == parser.KindPrimitive &&
		(elem.Primitive == parser.KindInt64 || elem.Primitive == parser.KindUint64) {
		return fmt.Errorf("lua target does not support int64/uint64: field %s in message %s (LuaJIT/Defold uses double-precision floats which can only safely represent integers up to 2^53)", f.Name, msgName)
	}
	return nil
}