v1.0.0
This commit is contained in:
+33
-2
@@ -14,6 +14,8 @@ func GenerateCSharp(messages []parser.Message, namespace string) ([]byte, error)
|
||||
func GenerateCSharpSchema(schema parser.Schema, namespace string) ([]byte, error) {
|
||||
messages := schema.Messages
|
||||
var b strings.Builder
|
||||
needsLengthGuards := schemaNeedsLengthGuards(messages)
|
||||
needsQuantGuards := schemaNeedsQuantRangeGuards(messages)
|
||||
|
||||
b.WriteString("// <auto-generated> arpack </auto-generated>\n")
|
||||
b.WriteString("// Code generated by arpack. DO NOT EDIT.\n")
|
||||
@@ -27,6 +29,30 @@ func GenerateCSharpSchema(schema parser.Schema, namespace string) ([]byte, error
|
||||
|
||||
fmt.Fprintf(&b, "namespace %s\n{\n", namespace)
|
||||
|
||||
if needsLengthGuards || needsQuantGuards {
|
||||
b.WriteString(" internal static class ArpackGenerated\n {\n")
|
||||
if needsLengthGuards {
|
||||
b.WriteString(" internal static ushort EnsureU16Length(int length, string context)\n")
|
||||
b.WriteString(" {\n")
|
||||
b.WriteString(" if (length > 65535)\n")
|
||||
b.WriteString(" {\n")
|
||||
b.WriteString(" throw new InvalidOperationException(\"arpack: \" + context + \" exceeds uint16 limit\");\n")
|
||||
b.WriteString(" }\n")
|
||||
b.WriteString(" return (ushort)length;\n")
|
||||
b.WriteString(" }\n\n")
|
||||
}
|
||||
if needsQuantGuards {
|
||||
b.WriteString(" internal static void EnsureQuantizedRange(double value, double min, double max, string context)\n")
|
||||
b.WriteString(" {\n")
|
||||
b.WriteString(" if (double.IsNaN(value) || value < min || value > max)\n")
|
||||
b.WriteString(" {\n")
|
||||
b.WriteString(" throw new ArgumentOutOfRangeException(context, \"arpack: quantized value out of range for \" + context);\n")
|
||||
b.WriteString(" }\n")
|
||||
b.WriteString(" }\n")
|
||||
}
|
||||
b.WriteString(" }\n\n")
|
||||
}
|
||||
|
||||
enumNames := make(map[string]struct{}, len(schema.Enums))
|
||||
for _, enum := range schema.Enums {
|
||||
enumNames[enum.Name] = struct{}{}
|
||||
@@ -151,7 +177,9 @@ func writeCSharpSerializeField(b *strings.Builder, f parser.Field, indent string
|
||||
}
|
||||
fmt.Fprintf(b, "%s}\n", indent)
|
||||
case parser.KindSlice:
|
||||
fmt.Fprintf(b, "%s*(ushort*)ptr = (ushort)(%s?.Length ?? 0); ptr += 2;\n", indent, f.Name)
|
||||
lenVar := "_len" + sanitizeVarName(f.Name)
|
||||
fmt.Fprintf(b, "%sushort %s = ArpackGenerated.EnsureU16Length(%s?.Length ?? 0, %q); *(ushort*)ptr = %s; ptr += 2;\n",
|
||||
indent, lenVar, f.Name, lengthContext(f), lenVar)
|
||||
fmt.Fprintf(b, "%sif (%s != null)\n%s{\n", indent, f.Name, indent)
|
||||
iVar := "_i" + f.Name
|
||||
fmt.Fprintf(b, "%s for (int %s = 0; %s < %s.Length; %s++)\n%s {\n",
|
||||
@@ -212,7 +240,8 @@ func writeCSharpSerializePrimitive(
|
||||
lenVar := "_slen" + sanitizeVarName(access)
|
||||
fmt.Fprintf(b, "%sint %s = %s != null ? Encoding.UTF8.GetByteCount(%s) : 0;\n",
|
||||
indent, lenVar, valueExpr, valueExpr)
|
||||
fmt.Fprintf(b, "%s*(ushort*)ptr = (ushort)%s; ptr += 2;\n", indent, lenVar)
|
||||
fmt.Fprintf(b, "%s*(ushort*)ptr = ArpackGenerated.EnsureU16Length(%s, %q); ptr += 2;\n",
|
||||
indent, lenVar, lengthContext(f))
|
||||
fmt.Fprintf(b, "%sif (%s != null && %s > 0)\n%s{\n", indent, valueExpr, lenVar, indent)
|
||||
fmt.Fprintf(b, "%s fixed (char* _chars%s = %s)\n%s {\n",
|
||||
indent, sanitizeVarName(access), valueExpr, indent)
|
||||
@@ -228,6 +257,8 @@ func writeCSharpSerializePrimitive(
|
||||
func writeCSharpSerializeQuant(b *strings.Builder, access string, f parser.Field, indent string) error {
|
||||
q := f.Quant
|
||||
maxUint := q.MaxUint()
|
||||
fmt.Fprintf(b, "%sArpackGenerated.EnsureQuantizedRange(%s, %g, %g, %q);\n",
|
||||
indent, access, q.Min, q.Max, quantContext(f))
|
||||
if q.Bits == 8 {
|
||||
fmt.Fprintf(b, "%s*ptr = (byte)((%s - (%gf)) / (%gf - (%gf)) * %gf); ptr += 1;\n",
|
||||
indent, access, q.Min, q.Max, q.Min, maxUint)
|
||||
|
||||
@@ -254,10 +254,109 @@ func TestGenerateCSharp_Output(t *testing.T) {
|
||||
if !strings.Contains(code, "public Opcode Code;") {
|
||||
t.Error("EnvelopeMessage.Code should use generated enum type")
|
||||
}
|
||||
if !strings.Contains(code, "internal static class ArpackGenerated") {
|
||||
t.Error("missing shared ArpackGenerated helper class")
|
||||
}
|
||||
if !strings.Contains(code, "EnsureU16Length") {
|
||||
t.Error("missing uint16 length guard helper")
|
||||
}
|
||||
if !strings.Contains(code, "EnsureQuantizedRange") {
|
||||
t.Error("missing quantized range guard helper")
|
||||
}
|
||||
|
||||
t.Logf("Generated C# (%d bytes):\n%s", len(src), code)
|
||||
}
|
||||
|
||||
func TestGenerateGo_RuntimeGuards(t *testing.T) {
|
||||
schemaSrc := `package messages
|
||||
|
||||
type Quantized struct {
|
||||
Value float32 ` + "`" + `pack:"min=0,max=1,bits=8"` + "`" + `
|
||||
}
|
||||
|
||||
type LengthLimited struct {
|
||||
Name string
|
||||
Items []uint8
|
||||
}
|
||||
`
|
||||
|
||||
schema, err := parser.ParseSchemaSource(schemaSrc)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseSchemaSource: %v", err)
|
||||
}
|
||||
|
||||
src, err := GenerateGoSchema(schema, "messages")
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateGoSchema: %v", err)
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "messages.go"), []byte(schemaSrc), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(dir, "messages_arpack.go"), src, 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
runtimeTests := `package messages
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func expectPanicContaining(t *testing.T, want string, fn func()) {
|
||||
t.Helper()
|
||||
defer func() {
|
||||
r := recover()
|
||||
if r == nil {
|
||||
t.Fatalf("expected panic containing %q, got nil", want)
|
||||
}
|
||||
if !strings.Contains(r.(string), want) {
|
||||
t.Fatalf("expected panic containing %q, got %v", want, r)
|
||||
}
|
||||
}()
|
||||
fn()
|
||||
}
|
||||
|
||||
func TestLengthGuard_String(t *testing.T) {
|
||||
expectPanicContaining(t, "string length for Name exceeds uint16 limit", func() {
|
||||
msg := LengthLimited{Name: strings.Repeat("a", 65536)}
|
||||
_ = msg.Marshal(nil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLengthGuard_Slice(t *testing.T) {
|
||||
expectPanicContaining(t, "slice length for Items exceeds uint16 limit", func() {
|
||||
msg := LengthLimited{Items: make([]uint8, 65536)}
|
||||
_ = msg.Marshal(nil)
|
||||
})
|
||||
}
|
||||
|
||||
func TestQuantizedRangeGuard(t *testing.T) {
|
||||
expectPanicContaining(t, "quantized value out of range for Value", func() {
|
||||
msg := Quantized{Value: 1.5}
|
||||
_ = msg.Marshal(nil)
|
||||
})
|
||||
}
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(dir, "guards_test.go"), []byte(runtimeTests), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
goMod := "module messages\n\ngo 1.21\n"
|
||||
if err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte(goMod), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("go", "test", "./...")
|
||||
cmd.Dir = dir
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Fatalf("go test failed:\n%s", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBoolPacking_GoCode(t *testing.T) {
|
||||
msgs, err := parser.ParseFile(samplePath)
|
||||
if err != nil {
|
||||
|
||||
+25
-2
@@ -15,6 +15,8 @@ func GenerateGo(messages []parser.Message, pkgName string) ([]byte, error) {
|
||||
func GenerateGoSchema(schema parser.Schema, pkgName string) ([]byte, error) {
|
||||
messages := schema.Messages
|
||||
var b strings.Builder
|
||||
needsLengthGuards := schemaNeedsLengthGuards(messages)
|
||||
needsQuantGuards := schemaNeedsQuantRangeGuards(messages)
|
||||
|
||||
b.WriteString("// Code generated by arpack. DO NOT EDIT.\n\n")
|
||||
fmt.Fprintf(&b, "package %s\n\n", pkgName)
|
||||
@@ -27,6 +29,23 @@ func GenerateGoSchema(schema parser.Schema, pkgName string) ([]byte, error) {
|
||||
}
|
||||
b.WriteString(")\n\n")
|
||||
|
||||
if needsLengthGuards {
|
||||
b.WriteString("func arpackEnsureUint16Length(length int, context string) uint16 {\n")
|
||||
b.WriteString("\tif length > 65535 {\n")
|
||||
b.WriteString("\t\tpanic(\"arpack: \" + context + \" exceeds uint16 limit\")\n")
|
||||
b.WriteString("\t}\n")
|
||||
b.WriteString("\treturn uint16(length)\n")
|
||||
b.WriteString("}\n\n")
|
||||
}
|
||||
|
||||
if needsQuantGuards {
|
||||
b.WriteString("func arpackEnsureQuantizedRange(value float64, min float64, max float64, context string) {\n")
|
||||
b.WriteString("\tif value != value || value < min || value > max {\n")
|
||||
b.WriteString("\t\tpanic(\"arpack: quantized value out of range for \" + context)\n")
|
||||
b.WriteString("\t}\n")
|
||||
b.WriteString("}\n\n")
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
if err := writeGoMessage(&b, msg); err != nil {
|
||||
return nil, fmt.Errorf("message %s: %w", msg.Name, err)
|
||||
@@ -120,7 +139,8 @@ func writeGoMarshalField(b *strings.Builder, recv string, f parser.Field, indent
|
||||
}
|
||||
fmt.Fprintf(b, "%s}\n", indent)
|
||||
case parser.KindSlice:
|
||||
fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint16(buf, uint16(len(%s)))\n", indent, access)
|
||||
fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint16(buf, arpackEnsureUint16Length(len(%s), %q))\n",
|
||||
indent, access, lengthContext(f))
|
||||
fmt.Fprintf(b, "%sfor _i%s := range %s {\n", indent, f.Name, access)
|
||||
elemField := parser.Field{
|
||||
Name: f.Name + "[_i" + f.Name + "]",
|
||||
@@ -169,7 +189,8 @@ func writeGoMarshalPrimitive(b *strings.Builder, access string, f parser.Field,
|
||||
case parser.KindUint64:
|
||||
fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint64(buf, %s)\n", indent, valueExpr)
|
||||
case parser.KindString:
|
||||
fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint16(buf, uint16(len(%s)))\n", indent, valueExpr)
|
||||
fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint16(buf, arpackEnsureUint16Length(len(%s), %q))\n",
|
||||
indent, valueExpr, lengthContext(f))
|
||||
fmt.Fprintf(b, "%sbuf = append(buf, %s...)\n", indent, valueExpr)
|
||||
}
|
||||
return nil
|
||||
@@ -179,6 +200,8 @@ func writeGoMarshalQuant(b *strings.Builder, access string, f parser.Field, inde
|
||||
q := f.Quant
|
||||
varName := "_q" + sanitizeVarName(access)
|
||||
valueExpr := goMarshalValueExpr(access, f)
|
||||
fmt.Fprintf(b, "%sarpackEnsureQuantizedRange(float64(%s), %g, %g, %q)\n",
|
||||
indent, valueExpr, q.Min, q.Max, quantContext(f))
|
||||
if q.Bits == 8 {
|
||||
fmt.Fprintf(b, "%s%s := uint8((%s - (%g)) / (%g - (%g)) * %g)\n",
|
||||
indent, varName, valueExpr, q.Min, q.Max, q.Min, q.MaxUint())
|
||||
|
||||
+11
-1
@@ -76,6 +76,13 @@ func writeLuaHelpers(b *strings.Builder) {
|
||||
b.WriteString(" return n\n")
|
||||
b.WriteString("end\n\n")
|
||||
|
||||
b.WriteString("local function ensure_quant_range(value, min, max, context)\n")
|
||||
b.WriteString(" if value ~= value or value < min or value > max then\n")
|
||||
b.WriteString(" error(string.format(\"arpack: quantized value out of range for %s\", context))\n")
|
||||
b.WriteString(" end\n")
|
||||
b.WriteString(" return value\n")
|
||||
b.WriteString("end\n\n")
|
||||
|
||||
b.WriteString("local function read_u8(data, offset)\n")
|
||||
b.WriteString(" if offset > #data then error(\"arpack: buffer too short for u8\") end\n")
|
||||
b.WriteString(" return string.byte(data, offset), 1\n")
|
||||
@@ -506,8 +513,11 @@ func writeLuaSerializeQuant(b *strings.Builder, access string, f parser.Field, i
|
||||
q := f.Quant
|
||||
maxUint := q.MaxUint()
|
||||
varName := "_q_" + sanitizeLuaVarName(access)
|
||||
valueVar := "_quant_value_" + sanitizeLuaVarName(access)
|
||||
fmt.Fprintf(b, "%slocal %s = ensure_quant_range(%s, %g, %g, %q)\n",
|
||||
indent, valueVar, access, q.Min, q.Max, quantContext(f))
|
||||
fmt.Fprintf(b, "%slocal %s = math.floor(((%s - (%g)) / (%g - (%g))) * %g)\n",
|
||||
indent, varName, access, q.Min, q.Max, q.Min, maxUint)
|
||||
indent, varName, valueVar, q.Min, q.Max, q.Min, maxUint)
|
||||
if q.Bits == 8 {
|
||||
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u8(%s)\n", indent, varName)
|
||||
} else {
|
||||
|
||||
+63
-1
@@ -281,9 +281,12 @@ func TestGenerateLua_QuantizedFloat(t *testing.T) {
|
||||
|
||||
luaStr := string(lua)
|
||||
|
||||
if !strings.Contains(luaStr, "math.floor(((msg.position - (-500)) / (500 - (-500))) * 65535)") {
|
||||
if !strings.Contains(luaStr, "math.floor(((_quant_value_msg_position - (-500)) / (500 - (-500))) * 65535)") {
|
||||
t.Error("Missing truncating quantization code for Lua")
|
||||
}
|
||||
if !strings.Contains(luaStr, `ensure_quant_range(msg.position, -500, 500, "Position")`) {
|
||||
t.Error("Missing quantized range guard for Lua")
|
||||
}
|
||||
if strings.Contains(luaStr, "math.floor(((msg.position - (-500)) / (500 - (-500))) * 65535 + 0.5)") {
|
||||
t.Error("Lua quantization should not round to nearest")
|
||||
}
|
||||
@@ -338,6 +341,7 @@ func TestLuaHelpersGenerated(t *testing.T) {
|
||||
"buffer too short for u8",
|
||||
"buffer too short for bool",
|
||||
"local function ensure_u16_limit(n, context)",
|
||||
"local function ensure_quant_range(value, min, max, context)",
|
||||
"local function write_u8(n)",
|
||||
"buffer too short for u16",
|
||||
"local function write_u16_le(n)",
|
||||
@@ -703,3 +707,61 @@ print(bytes_to_hex(messages.serialize_float_edges(msg)))
|
||||
t.Fatalf("subnormal roundtrip mismatch: %s", lines[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateLua_RuntimeQuantizedRangeGuard(t *testing.T) {
|
||||
if _, err := exec.LookPath("luajit"); err != nil {
|
||||
t.Skip("luajit not found")
|
||||
}
|
||||
|
||||
schema := parser.Schema{
|
||||
Messages: []parser.Message{
|
||||
{
|
||||
Name: "WithQuantized",
|
||||
Fields: []parser.Field{
|
||||
{
|
||||
Name: "Position",
|
||||
Kind: parser.KindPrimitive,
|
||||
Primitive: parser.KindFloat32,
|
||||
Quant: &parser.QuantInfo{Min: -500, Max: 500, Bits: 16},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
lua, err := GenerateLuaSchema(schema, "messages")
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateLuaSchema failed: %v", err)
|
||||
}
|
||||
|
||||
dir := t.TempDir()
|
||||
if err := os.WriteFile(filepath.Join(dir, "messages_gen.lua"), lua, 0o600); err != nil {
|
||||
t.Fatalf("write module: %v", err)
|
||||
}
|
||||
|
||||
script := `local messages = require("messages_gen")
|
||||
local msg = messages.new_with_quantized()
|
||||
msg.position = 501
|
||||
local ok, res = pcall(messages.serialize_with_quantized, msg)
|
||||
if ok then
|
||||
print("OK")
|
||||
else
|
||||
print(res)
|
||||
end
|
||||
`
|
||||
if err := os.WriteFile(filepath.Join(dir, "check.lua"), []byte(script), 0o600); err != nil {
|
||||
t.Fatalf("write script: %v", err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("luajit", "check.lua")
|
||||
cmd.Dir = dir
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Fatalf("luajit failed: %v\n%s", err, out)
|
||||
}
|
||||
|
||||
got := strings.TrimSpace(string(out))
|
||||
if !strings.Contains(got, "quantized value out of range for Position") {
|
||||
t.Fatalf("expected quantized range guard, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
package generator
|
||||
|
||||
import "github.com/edmand46/arpack/parser"
|
||||
|
||||
const maxUint16Len = 65535
|
||||
|
||||
func lengthContext(f parser.Field) string {
|
||||
switch {
|
||||
case f.Kind == parser.KindSlice:
|
||||
if f.Name != "" {
|
||||
return "slice length for " + f.Name
|
||||
}
|
||||
return "slice length"
|
||||
case f.Kind == parser.KindPrimitive && f.Primitive == parser.KindString:
|
||||
if f.Name != "" {
|
||||
return "string length for " + f.Name
|
||||
}
|
||||
return "string length"
|
||||
default:
|
||||
return "length"
|
||||
}
|
||||
}
|
||||
|
||||
func quantContext(f parser.Field) string {
|
||||
if f.Name != "" {
|
||||
return f.Name
|
||||
}
|
||||
return "value"
|
||||
}
|
||||
|
||||
func schemaNeedsLengthGuards(messages []parser.Message) bool {
|
||||
for _, msg := range messages {
|
||||
for _, f := range msg.Fields {
|
||||
if fieldNeedsLengthGuard(f) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func fieldNeedsLengthGuard(f parser.Field) bool {
|
||||
switch f.Kind {
|
||||
case parser.KindPrimitive:
|
||||
return f.Primitive == parser.KindString
|
||||
case parser.KindFixedArray, parser.KindSlice:
|
||||
if f.Kind == parser.KindSlice {
|
||||
return true
|
||||
}
|
||||
if f.Elem != nil {
|
||||
return fieldNeedsLengthGuard(*f.Elem)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func schemaNeedsQuantRangeGuards(messages []parser.Message) bool {
|
||||
for _, msg := range messages {
|
||||
for _, f := range msg.Fields {
|
||||
if fieldNeedsQuantRangeGuard(f) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func fieldNeedsQuantRangeGuard(f parser.Field) bool {
|
||||
switch f.Kind {
|
||||
case parser.KindPrimitive:
|
||||
return f.Quant != nil
|
||||
case parser.KindFixedArray, parser.KindSlice:
|
||||
if f.Elem != nil {
|
||||
return fieldNeedsQuantRangeGuard(*f.Elem)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
+41
-7
@@ -16,10 +16,31 @@ func GenerateTypeScript(messages []parser.Message, namespace string) ([]byte, er
|
||||
func GenerateTypeScriptSchema(schema parser.Schema, namespace string) ([]byte, error) {
|
||||
messages := schema.Messages
|
||||
var b strings.Builder
|
||||
needsLengthGuards := schemaNeedsLengthGuards(messages)
|
||||
needsQuantGuards := schemaNeedsQuantRangeGuards(messages)
|
||||
|
||||
b.WriteString("// <auto-generated> arpack </auto-generated>\n")
|
||||
b.WriteString("// Code generated by arpack. DO NOT EDIT.\n\n")
|
||||
|
||||
if needsLengthGuards {
|
||||
b.WriteString("const arpackTextEncoder = new TextEncoder();\n")
|
||||
b.WriteString("const arpackTextDecoder = new TextDecoder();\n\n")
|
||||
b.WriteString("function arpackEnsureUint16Length(length: number, context: string): number {\n")
|
||||
b.WriteString(" if (length > 65535) {\n")
|
||||
b.WriteString(" throw new RangeError(\"arpack: \" + context + \" exceeds uint16 limit\");\n")
|
||||
b.WriteString(" }\n")
|
||||
b.WriteString(" return length;\n")
|
||||
b.WriteString("}\n\n")
|
||||
}
|
||||
|
||||
if needsQuantGuards {
|
||||
b.WriteString("function arpackEnsureQuantizedRange(value: number, min: number, max: number, context: string): void {\n")
|
||||
b.WriteString(" if (Number.isNaN(value) || value < min || value > max) {\n")
|
||||
b.WriteString(" throw new RangeError(\"arpack: quantized value out of range for \" + context);\n")
|
||||
b.WriteString(" }\n")
|
||||
b.WriteString("}\n\n")
|
||||
}
|
||||
|
||||
enumNames := make(map[string]struct{}, len(schema.Enums))
|
||||
for _, enum := range schema.Enums {
|
||||
enumNames[enum.Name] = struct{}{}
|
||||
@@ -148,7 +169,10 @@ func writeTSSerializeField(b *strings.Builder, recv string, f parser.Field, inde
|
||||
}
|
||||
fmt.Fprintf(b, "%s}\n", indent)
|
||||
case parser.KindSlice:
|
||||
fmt.Fprintf(b, "%sview.setUint16(pos, %s.length, true);\n", indent, access)
|
||||
lenVar := "_len" + sanitizeVarName(access)
|
||||
fmt.Fprintf(b, "%sconst %s = arpackEnsureUint16Length(%s.length, %q);\n",
|
||||
indent, lenVar, access, lengthContext(f))
|
||||
fmt.Fprintf(b, "%sview.setUint16(pos, %s, true);\n", indent, lenVar)
|
||||
fmt.Fprintf(b, "%spos += 2;\n", indent)
|
||||
iVar := "_i" + f.Name
|
||||
fmt.Fprintf(b, "%sfor (const %s of %s) {\n", indent, iVar, access)
|
||||
@@ -222,8 +246,11 @@ func writeTSSerializePrimitiveElement(b *strings.Builder, access string, f parse
|
||||
fmt.Fprintf(b, "%spos += 8;\n", indent)
|
||||
case parser.KindString:
|
||||
lenVar := "_slen" + sanitizeVarName(access)
|
||||
fmt.Fprintf(b, "%sconst %s = new TextEncoder().encode(%s);\n", indent, lenVar, valueExpr)
|
||||
fmt.Fprintf(b, "%sview.setUint16(pos, %s.length, true);\n", indent, lenVar)
|
||||
guardVar := "_slenChecked" + sanitizeVarName(access)
|
||||
fmt.Fprintf(b, "%sconst %s = arpackTextEncoder.encode(%s);\n", indent, lenVar, valueExpr)
|
||||
fmt.Fprintf(b, "%sconst %s = arpackEnsureUint16Length(%s.length, %q);\n",
|
||||
indent, guardVar, lenVar, lengthContext(f))
|
||||
fmt.Fprintf(b, "%sview.setUint16(pos, %s, true);\n", indent, guardVar)
|
||||
fmt.Fprintf(b, "%spos += 2;\n", indent)
|
||||
fmt.Fprintf(b, "%snew Uint8Array(view.buffer, pos, %s.length).set(%s);\n", indent, lenVar, lenVar)
|
||||
fmt.Fprintf(b, "%spos += %s.length;\n", indent, lenVar)
|
||||
@@ -273,8 +300,11 @@ func writeTSSerializePrimitive(b *strings.Builder, access string, f parser.Field
|
||||
fmt.Fprintf(b, "%spos += 8;\n", indent)
|
||||
case parser.KindString:
|
||||
lenVar := "_slen" + sanitizeVarName(access)
|
||||
fmt.Fprintf(b, "%sconst %s = new TextEncoder().encode(%s);\n", indent, lenVar, valueExpr)
|
||||
fmt.Fprintf(b, "%sview.setUint16(pos, %s.length, true);\n", indent, lenVar)
|
||||
guardVar := "_slenChecked" + sanitizeVarName(access)
|
||||
fmt.Fprintf(b, "%sconst %s = arpackTextEncoder.encode(%s);\n", indent, lenVar, valueExpr)
|
||||
fmt.Fprintf(b, "%sconst %s = arpackEnsureUint16Length(%s.length, %q);\n",
|
||||
indent, guardVar, lenVar, lengthContext(f))
|
||||
fmt.Fprintf(b, "%sview.setUint16(pos, %s, true);\n", indent, guardVar)
|
||||
fmt.Fprintf(b, "%spos += 2;\n", indent)
|
||||
fmt.Fprintf(b, "%snew Uint8Array(view.buffer, pos, %s.length).set(%s);\n", indent, lenVar, lenVar)
|
||||
fmt.Fprintf(b, "%spos += %s.length;\n", indent, lenVar)
|
||||
@@ -286,6 +316,8 @@ func writeTSSerializeQuant(b *strings.Builder, access string, f parser.Field, in
|
||||
q := f.Quant
|
||||
maxUint := q.MaxUint()
|
||||
varName := "_q" + sanitizeVarName(access)
|
||||
fmt.Fprintf(b, "%sarpackEnsureQuantizedRange(%s, %g, %g, %q);\n",
|
||||
indent, access, q.Min, q.Max, quantContext(f))
|
||||
if q.Bits == 8 {
|
||||
fmt.Fprintf(b, "%sconst %s = Math.trunc((%s - (%g)) / (%g - (%g)) * %g);\n",
|
||||
indent, varName, access, q.Min, q.Max, q.Min, maxUint)
|
||||
@@ -304,6 +336,8 @@ func writeTSSerializeQuantElement(b *strings.Builder, access string, f parser.Fi
|
||||
q := f.Quant
|
||||
maxUint := q.MaxUint()
|
||||
varName := "_q" + sanitizeVarName(access)
|
||||
fmt.Fprintf(b, "%sarpackEnsureQuantizedRange(%s, %g, %g, %q);\n",
|
||||
indent, access, q.Min, q.Max, quantContext(f))
|
||||
if q.Bits == 8 {
|
||||
fmt.Fprintf(b, "%sconst %s = Math.trunc((%s - (%g)) / (%g - (%g)) * %g);\n",
|
||||
indent, varName, access, q.Min, q.Max, q.Min, maxUint)
|
||||
@@ -438,7 +472,7 @@ func writeTSDeserializePrimitiveElement(b *strings.Builder, access string, f par
|
||||
lenVar := "_slen" + sanitizeVarName(access)
|
||||
fmt.Fprintf(b, "%sconst %s = view.getUint16(pos, true);\n", indent, lenVar)
|
||||
fmt.Fprintf(b, "%spos += 2;\n", indent)
|
||||
expr := fmt.Sprintf("new TextDecoder().decode(new Uint8Array(view.buffer, pos, %s))", lenVar)
|
||||
expr := fmt.Sprintf("arpackTextDecoder.decode(new Uint8Array(view.buffer, pos, %s))", lenVar)
|
||||
fmt.Fprintf(b, "%s%s = %s;\n", indent, access, tsDeserializeValueExpr(expr, f, enumNames))
|
||||
fmt.Fprintf(b, "%spos += %s;\n", indent, lenVar)
|
||||
}
|
||||
@@ -499,7 +533,7 @@ func writeTSDeserializePrimitive(b *strings.Builder, access string, f parser.Fie
|
||||
lenVar := "_slen" + sanitizeVarName(access)
|
||||
fmt.Fprintf(b, "%sconst %s = view.getUint16(pos, true);\n", indent, lenVar)
|
||||
fmt.Fprintf(b, "%spos += 2;\n", indent)
|
||||
expr := fmt.Sprintf("new TextDecoder().decode(new Uint8Array(view.buffer, pos, %s))", lenVar)
|
||||
expr := fmt.Sprintf("arpackTextDecoder.decode(new Uint8Array(view.buffer, pos, %s))", lenVar)
|
||||
fmt.Fprintf(b, "%s%s = %s;\n", indent, access, tsDeserializeValueExpr(expr, f, enumNames))
|
||||
fmt.Fprintf(b, "%spos += %s;\n", indent, lenVar)
|
||||
}
|
||||
|
||||
+72
-6
@@ -101,11 +101,17 @@ func TestGenerateTypeScript_QuantizedFloats(t *testing.T) {
|
||||
if !strings.Contains(code, "Math.trunc((this.q8 - (0)) / (100 - (0)) * 255)") {
|
||||
t.Error("Missing 8-bit quantization code")
|
||||
}
|
||||
if !strings.Contains(code, `arpackEnsureQuantizedRange(this.q8, 0, 100, "Q8");`) {
|
||||
t.Error("Missing 8-bit quantized range guard")
|
||||
}
|
||||
|
||||
// Check 16-bit quantization (using camelCase field names)
|
||||
if !strings.Contains(code, "Math.trunc((this.q16 - (-500)) / (500 - (-500)) * 65535)") {
|
||||
t.Error("Missing 16-bit quantization code")
|
||||
}
|
||||
if !strings.Contains(code, `arpackEnsureQuantizedRange(this.q16, -500, 500, "Q16");`) {
|
||||
t.Error("Missing 16-bit quantized range guard")
|
||||
}
|
||||
|
||||
// Check deserialization with dequantization
|
||||
if !strings.Contains(code, "/ 255 * (100 - (0)) + (0)") {
|
||||
@@ -286,8 +292,11 @@ func TestGenerateTypeScript_Slices(t *testing.T) {
|
||||
}
|
||||
|
||||
// Check length prefix in serialize (using camelCase field name)
|
||||
if !strings.Contains(code, "view.setUint16(pos, this.items.length, true);") {
|
||||
t.Error("Missing slice length prefix in serialize")
|
||||
if !strings.Contains(code, `arpackEnsureUint16Length(this.items.length, "slice length for Items")`) {
|
||||
t.Error("Missing slice length guard in serialize")
|
||||
}
|
||||
if !strings.Contains(code, "view.setUint16(pos, _lenthis_items, true);") {
|
||||
t.Error("Missing guarded slice length prefix in serialize")
|
||||
}
|
||||
|
||||
// Check length reading in deserialize
|
||||
@@ -385,17 +394,74 @@ func TestGenerateTypeScript_Strings(t *testing.T) {
|
||||
code := string(src)
|
||||
|
||||
// Check TextEncoder usage
|
||||
if !strings.Contains(code, "new TextEncoder().encode(") {
|
||||
t.Error("Missing TextEncoder in serialize")
|
||||
if !strings.Contains(code, "const arpackTextEncoder = new TextEncoder();") {
|
||||
t.Error("Missing shared TextEncoder helper")
|
||||
}
|
||||
|
||||
// Check length prefix
|
||||
if !strings.Contains(code, "view.setUint16(pos, _slen") {
|
||||
t.Error("Missing string length prefix in serialize")
|
||||
}
|
||||
if !strings.Contains(code, `arpackEnsureUint16Length(_slen`) {
|
||||
t.Error("Missing string length guard in serialize")
|
||||
}
|
||||
|
||||
// Check TextDecoder usage
|
||||
if !strings.Contains(code, "new TextDecoder().decode(") {
|
||||
t.Error("Missing TextDecoder in deserialize")
|
||||
if !strings.Contains(code, "const arpackTextDecoder = new TextDecoder();") {
|
||||
t.Error("Missing shared TextDecoder helper")
|
||||
}
|
||||
if !strings.Contains(code, "arpackTextDecoder.decode(") {
|
||||
t.Error("Missing shared TextDecoder in deserialize")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateTypeScript_LengthAndRangeHelpers(t *testing.T) {
|
||||
schema := parser.Schema{
|
||||
Messages: []parser.Message{
|
||||
{
|
||||
PackageName: "test",
|
||||
Name: "LengthAndQuant",
|
||||
Fields: []parser.Field{
|
||||
{Name: "Name", Kind: parser.KindPrimitive, Primitive: parser.KindString},
|
||||
{
|
||||
Name: "Items",
|
||||
Kind: parser.KindSlice,
|
||||
Elem: &parser.Field{
|
||||
Kind: parser.KindPrimitive,
|
||||
Primitive: parser.KindUint8,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Ratio",
|
||||
Kind: parser.KindPrimitive,
|
||||
Primitive: parser.KindFloat32,
|
||||
Quant: &parser.QuantInfo{Min: 0, Max: 1, Bits: 8},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
src, err := GenerateTypeScriptSchema(schema, "Test")
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateTypeScriptSchema: %v", err)
|
||||
}
|
||||
|
||||
code := string(src)
|
||||
|
||||
if !strings.Contains(code, "function arpackEnsureUint16Length(length: number, context: string): number") {
|
||||
t.Error("Missing uint16 length helper")
|
||||
}
|
||||
if !strings.Contains(code, "function arpackEnsureQuantizedRange(value: number, min: number, max: number, context: string): void") {
|
||||
t.Error("Missing quantized range helper")
|
||||
}
|
||||
if !strings.Contains(code, `arpackEnsureUint16Length(this.items.length, "slice length for Items")`) {
|
||||
t.Error("Missing slice length guard")
|
||||
}
|
||||
if !strings.Contains(code, `arpackEnsureUint16Length(_slen`) {
|
||||
t.Error("Missing string length helper call")
|
||||
}
|
||||
if !strings.Contains(code, `arpackEnsureQuantizedRange(this.ratio, 0, 1, "Ratio");`) {
|
||||
t.Error("Missing quantized range helper call")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user