Files
arpack/generator/go.go
T

389 lines
15 KiB
Go
Raw Permalink Normal View History

2026-03-19 14:52:12 +03:00
package generator
import (
2026-03-23 09:47:14 +03:00
"github.com/edmand46/arpack/parser"
2026-03-19 14:52:12 +03:00
"fmt"
"go/format"
"strings"
)
// GenerateGo is a convenience entry point for callers that only have a list of
// messages: it wraps them in a parser.Schema and delegates to GenerateGoSchema.
func GenerateGo(messages []parser.Message, pkgName string) ([]byte, error) {
	schema := parser.Schema{Messages: messages}
	return GenerateGoSchema(schema, pkgName)
}
// GenerateGoSchema renders Marshal/Unmarshal methods for every message in
// schema into a single gofmt-formatted Go source file in package pkgName.
//
// The import list is derived from what the generated code actually uses,
// because the Go compiler rejects unused imports:
//   - "encoding/binary" only when some field is encoded through it
//     (see goNeedsBinaryImport);
//   - "errors" only when at least one message is generated (every generated
//     Unmarshal starts with an errors.New length check);
//   - "math" only when needsMathImport reports a non-quantized float.
//
// If go/format rejects the generated text, the unformatted source is returned
// together with an error that embeds it, to make the broken output easy to
// inspect. A per-message generation error is wrapped with the message name.
func GenerateGoSchema(schema parser.Schema, pkgName string) ([]byte, error) {
	messages := schema.Messages

	var imports []string
	if goNeedsBinaryImport(messages) {
		imports = append(imports, "encoding/binary")
	}
	if len(messages) > 0 {
		imports = append(imports, "errors")
	}
	if needsMathImport(messages) {
		imports = append(imports, "math")
	}

	var b strings.Builder
	b.WriteString("// Code generated by arpack. DO NOT EDIT.\n\n")
	fmt.Fprintf(&b, "package %s\n\n", pkgName)
	if len(imports) > 0 {
		b.WriteString("import (\n")
		for _, path := range imports {
			fmt.Fprintf(&b, "\t%q\n", path)
		}
		b.WriteString(")\n\n")
	}

	for _, msg := range messages {
		if err := writeGoMessage(&b, msg); err != nil {
			return nil, fmt.Errorf("message %s: %w", msg.Name, err)
		}
	}

	src := b.String()
	formatted, err := format.Source([]byte(src))
	if err != nil {
		return []byte(src), fmt.Errorf("go/format: %w\n\nSource:\n%s", err, src)
	}
	return formatted, nil
}

// goNeedsBinaryImport reports whether the code generated for any field of any
// message references encoding/binary.
func goNeedsBinaryImport(messages []parser.Message) bool {
	for _, msg := range messages {
		for _, f := range msg.Fields {
			if goNeedsBinaryField(f) {
				return true
			}
		}
	}
	return false
}

// goNeedsBinaryField reports whether the code emitted for f by the Go
// marshal/unmarshal writers calls encoding/binary. Slices always carry a
// uint16 length prefix. Fixed arrays depend on their element type. Nested
// messages delegate to their own methods. Primitives need it unless they are
// encoded as a single byte (int8, uint8, bool, or 8-bit quantized floats).
func goNeedsBinaryField(f parser.Field) bool {
	switch f.Kind {
	case parser.KindSlice:
		return true
	case parser.KindFixedArray:
		return f.Elem != nil && goNeedsBinaryField(*f.Elem)
	case parser.KindPrimitive:
		if f.Quant != nil {
			// 8-bit quantization appends a single byte; anything else uses
			// AppendUint16 / Uint16.
			return f.Quant.Bits != 8
		}
		switch f.Primitive {
		case parser.KindFloat32, parser.KindFloat64,
			parser.KindInt16, parser.KindUint16,
			parser.KindInt32, parser.KindUint32,
			parser.KindInt64, parser.KindUint64,
			parser.KindString:
			return true
		}
	}
	return false
}
// writeGoMessage appends the generated Marshal and Unmarshal methods for msg
// to b.
//
// Fields are walked in the segments returned by segmentFields. A segment with
// a single field goes through the per-field writers. Any other segment holds
// a run of bools that is handled by the bool-group writers, which receive the
// segment index so their temporaries get unique names.
//
// The generated Unmarshal first rejects inputs shorter than
// packedMinWireSize(msg.Fields), then decodes field by field while advancing
// offset, and finally returns the number of bytes consumed. The first error
// from a field writer aborts generation and is returned unchanged.
func writeGoMessage(b *strings.Builder, msg parser.Message) error {
	segs := segmentFields(msg.Fields)

	fmt.Fprintf(b, "func (m *%s) Marshal(buf []byte) []byte {\n", msg.Name)
	for groupIdx, seg := range segs {
		if seg.single == nil {
			writeGoBoolGroupMarshal(b, "m", seg.bools, groupIdx, "\t")
			continue
		}
		if err := writeGoMarshalField(b, "m", *seg.single, "\t"); err != nil {
			return err
		}
	}
	b.WriteString("\treturn buf\n}\n\n")

	fmt.Fprintf(b, "func (m *%s) Unmarshal(data []byte) (int, error) {\n", msg.Name)
	fmt.Fprintf(b, "\tif len(data) < %d {\n\t\treturn 0, errors.New(\"arpack: buffer too short for %s\")\n\t}\n",
		packedMinWireSize(msg.Fields), msg.Name)
	b.WriteString("\toffset := 0\n")
	for groupIdx, seg := range segs {
		if seg.single == nil {
			writeGoBoolGroupUnmarshal(b, "m", seg.bools, groupIdx, "\t")
			continue
		}
		if err := writeGoUnmarshalField(b, "m", *seg.single, "\t"); err != nil {
			return err
		}
	}
	b.WriteString("\treturn offset, nil\n}\n\n")
	return nil
}
// writeGoBoolGroupMarshal emits code that packs a run of bool fields into one
// byte, where the i-th field of the group sets bit i, and then appends that
// byte to buf. The temporary is named _boolByte<groupIdx> so that several
// groups can coexist in one generated function.
func writeGoBoolGroupMarshal(b *strings.Builder, recv string, bools []parser.Field, groupIdx int, indent string) {
	acc := "_boolByte" + fmt.Sprint(groupIdx)
	b.WriteString(indent + "var " + acc + " uint8\n")
	for i := range bools {
		fmt.Fprintf(b, "%sif %s.%s { %s |= 1 << %d }\n", indent, recv, bools[i].Name, acc, i)
	}
	b.WriteString(indent + "buf = append(buf, " + acc + ")\n")
}
// writeGoBoolGroupUnmarshal emits the inverse of writeGoBoolGroupMarshal: it
// checks that one byte is available, reads it into _boolByte<groupIdx>,
// advances offset, and assigns bit i to the i-th field of the group. The
// assigned value is converted to the field's named type when it has one.
func writeGoBoolGroupUnmarshal(b *strings.Builder, recv string, bools []parser.Field, groupIdx int, indent string) {
	acc := "_boolByte" + fmt.Sprint(groupIdx)
	b.WriteString(indent + "if len(data) < offset+1 { return 0, errors.New(\"arpack: buffer too short\") }\n")
	b.WriteString(indent + acc + " := data[offset]; offset++\n")
	for i := range bools {
		bitSet := acc + "&(1<<" + fmt.Sprint(i) + ") != 0"
		b.WriteString(indent + recv + "." + bools[i].Name + " = " + goUnmarshalValueExpr(bitSet, bools[i]) + "\n")
	}
}
// writeGoMarshalField emits code that appends field f of receiver recv to buf.
//
// Primitive fields are delegated to writeGoMarshalPrimitive, and nested
// messages call their own Marshal method. Fixed arrays loop FixedLen times.
// Slices first write their length as a little-endian uint16 and then loop
// over their elements. Both container kinds recurse with a synthetic element
// field whose Name is the indexed expression (e.g. "Items[_iItems]").
//
// Loop variables are built from sanitizeVarName(f.Name). For nested containers,
// f.Name is itself an index expression such as "Grid[_iGrid]", which would
// not be a valid Go identifier when concatenated verbatim. For plain field
// names the variable is unchanged ("_i" + name).
func writeGoMarshalField(b *strings.Builder, recv string, f parser.Field, indent string) error {
	access := recv + "." + f.Name
	// elemField describes a single element of the container f, indexed by iVar.
	// Only called for container kinds, where f.Elem is set.
	elemField := func(iVar string) parser.Field {
		return parser.Field{
			Name:      f.Name + "[" + iVar + "]",
			Kind:      f.Elem.Kind,
			Primitive: f.Elem.Primitive,
			NamedType: f.Elem.NamedType,
			Quant:     f.Elem.Quant,
			TypeName:  f.Elem.TypeName,
			Elem:      f.Elem.Elem,
			FixedLen:  f.Elem.FixedLen,
		}
	}

	switch f.Kind {
	case parser.KindPrimitive:
		return writeGoMarshalPrimitive(b, access, f, indent)
	case parser.KindNested:
		fmt.Fprintf(b, "%sbuf = %s.Marshal(buf)\n", indent, access)
	case parser.KindFixedArray:
		iVar := "_i" + sanitizeVarName(f.Name)
		fmt.Fprintf(b, "%sfor %s := 0; %s < %d; %s++ {\n", indent, iVar, iVar, f.FixedLen, iVar)
		if err := writeGoMarshalField(b, recv, elemField(iVar), indent+"\t"); err != nil {
			return err
		}
		fmt.Fprintf(b, "%s}\n", indent)
	case parser.KindSlice:
		iVar := "_i" + sanitizeVarName(f.Name)
		// NOTE(review): the length prefix is a uint16 and no range check is
		// emitted, so a slice longer than 65535 elements gets a wrapped prefix.
		fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint16(buf, uint16(len(%s)))\n", indent, access)
		fmt.Fprintf(b, "%sfor %s := range %s {\n", indent, iVar, access)
		if err := writeGoMarshalField(b, recv, elemField(iVar), indent+"\t"); err != nil {
			return err
		}
		fmt.Fprintf(b, "%s}\n", indent)
	}
	return nil
}
// writeGoMarshalPrimitive emits the append statement(s) for a scalar field.
// Quantized fields are delegated to writeGoMarshalQuant. Multi-byte integers
// and floats are written little-endian, and floats are written through their
// IEEE-754 bit patterns. Bools become a single 0/1 byte. Strings get a uint16
// length prefix followed by their raw bytes. An unrecognized primitive kind
// emits nothing.
func writeGoMarshalPrimitive(b *strings.Builder, access string, f parser.Field, indent string) error {
	if f.Quant != nil {
		return writeGoMarshalQuant(b, access, f, indent)
	}
	v := goMarshalValueExpr(access, f)

	var stmts []string
	switch f.Primitive {
	case parser.KindFloat32:
		stmts = []string{"buf = binary.LittleEndian.AppendUint32(buf, math.Float32bits(" + v + "))"}
	case parser.KindFloat64:
		stmts = []string{"buf = binary.LittleEndian.AppendUint64(buf, math.Float64bits(" + v + "))"}
	case parser.KindInt8:
		stmts = []string{"buf = append(buf, uint8(" + v + "))"}
	case parser.KindUint8:
		stmts = []string{"buf = append(buf, " + v + ")"}
	case parser.KindBool:
		stmts = []string{"if " + v + " { buf = append(buf, 1) } else { buf = append(buf, 0) }"}
	case parser.KindInt16:
		stmts = []string{"buf = binary.LittleEndian.AppendUint16(buf, uint16(" + v + "))"}
	case parser.KindUint16:
		stmts = []string{"buf = binary.LittleEndian.AppendUint16(buf, " + v + ")"}
	case parser.KindInt32:
		stmts = []string{"buf = binary.LittleEndian.AppendUint32(buf, uint32(" + v + "))"}
	case parser.KindUint32:
		stmts = []string{"buf = binary.LittleEndian.AppendUint32(buf, " + v + ")"}
	case parser.KindInt64:
		stmts = []string{"buf = binary.LittleEndian.AppendUint64(buf, uint64(" + v + "))"}
	case parser.KindUint64:
		stmts = []string{"buf = binary.LittleEndian.AppendUint64(buf, " + v + ")"}
	case parser.KindString:
		stmts = []string{
			"buf = binary.LittleEndian.AppendUint16(buf, uint16(len(" + v + ")))",
			"buf = append(buf, " + v + "...)",
		}
	}

	for _, stmt := range stmts {
		b.WriteString(indent)
		b.WriteString(stmt)
		b.WriteByte('\n')
	}
	return nil
}
// writeGoMarshalQuant emits code that quantizes the float field at access by
// mapping [q.Min, q.Max] linearly onto [0, q.MaxUint()]. An 8-bit
// quantization appends a single byte. Any other bit count is written as a
// little-endian uint16.
//
// The scaled value is clamped to [0, MaxUint] before the integer conversion.
// The Go spec makes out-of-range float-to-integer conversions
// implementation-dependent, and in practice they wrap. Without the clamp, a
// value slightly outside [Min, Max] would decode near the opposite end of the
// range. The generated test !(x >= 0) also maps NaN to 0. In-range values are
// converted (truncated) exactly as before, so their encoded bytes are
// unchanged.
func writeGoMarshalQuant(b *strings.Builder, access string, f parser.Field, indent string) error {
	q := f.Quant
	suffix := sanitizeVarName(access)
	scaledVar := "_qf" + suffix
	varName := "_q" + suffix
	valueExpr := goMarshalValueExpr(access, f)
	maxUint := q.MaxUint()

	// Scale in the field's own float type, then clamp into the target range.
	fmt.Fprintf(b, "%s%s := (%s - (%g)) / (%g - (%g)) * %g\n",
		indent, scaledVar, valueExpr, q.Min, q.Max, q.Min, maxUint)
	fmt.Fprintf(b, "%sif !(%s >= 0) { %s = 0 } else if %s > %g { %s = %g }\n",
		indent, scaledVar, scaledVar, scaledVar, maxUint, scaledVar, maxUint)

	if q.Bits == 8 {
		fmt.Fprintf(b, "%s%s := uint8(%s)\n", indent, varName, scaledVar)
		fmt.Fprintf(b, "%sbuf = append(buf, %s)\n", indent, varName)
	} else {
		fmt.Fprintf(b, "%s%s := uint16(%s)\n", indent, varName, scaledVar)
		fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint16(buf, %s)\n", indent, varName)
	}
	return nil
}
// writeGoUnmarshalField emits code that decodes field f of receiver recv from
// data starting at offset, advancing offset past the consumed bytes.
//
// Primitive fields are delegated to writeGoUnmarshalPrimitive. Nested
// messages call their own Unmarshal on data[offset:] and add the returned
// byte count, propagating any error. Fixed arrays loop FixedLen times. Slices
// read a little-endian uint16 length (after a bounds check), allocate the
// slice with make, and then loop. Both container kinds recurse with a
// synthetic element field whose Name is the indexed expression.
//
// All generated temporaries (loop indices, lengths, nested byte counts) are
// built from sanitizeVarName(f.Name). For nested containers, f.Name is an index
// expression such as "Grid[_iGrid]", which is not a valid identifier when
// concatenated verbatim.
func writeGoUnmarshalField(b *strings.Builder, recv string, f parser.Field, indent string) error {
	access := recv + "." + f.Name
	// elemField describes a single element of the container f, indexed by iVar.
	// Only called for container kinds, where f.Elem is set.
	elemField := func(iVar string) parser.Field {
		return parser.Field{
			Name:      f.Name + "[" + iVar + "]",
			Kind:      f.Elem.Kind,
			Primitive: f.Elem.Primitive,
			NamedType: f.Elem.NamedType,
			Quant:     f.Elem.Quant,
			TypeName:  f.Elem.TypeName,
			Elem:      f.Elem.Elem,
			FixedLen:  f.Elem.FixedLen,
		}
	}

	switch f.Kind {
	case parser.KindPrimitive:
		return writeGoUnmarshalPrimitive(b, access, f, indent)
	case parser.KindNested:
		nVar := "_n" + sanitizeVarName(f.Name)
		fmt.Fprintf(b, "%s%s, _err := %s.Unmarshal(data[offset:])\n", indent, nVar, access)
		fmt.Fprintf(b, "%sif _err != nil { return 0, _err }\n", indent)
		fmt.Fprintf(b, "%soffset += %s\n", indent, nVar)
	case parser.KindFixedArray:
		iVar := "_i" + sanitizeVarName(f.Name)
		fmt.Fprintf(b, "%sfor %s := 0; %s < %d; %s++ {\n", indent, iVar, iVar, f.FixedLen, iVar)
		if err := writeGoUnmarshalField(b, recv, elemField(iVar), indent+"\t"); err != nil {
			return err
		}
		fmt.Fprintf(b, "%s}\n", indent)
	case parser.KindSlice:
		lenVar := "_len" + sanitizeVarName(f.Name)
		fmt.Fprintf(b, "%sif len(data) < offset+2 { return 0, errors.New(\"arpack: buffer too short\") }\n", indent)
		fmt.Fprintf(b, "%s%s := int(binary.LittleEndian.Uint16(data[offset:]))\n", indent, lenVar)
		fmt.Fprintf(b, "%soffset += 2\n", indent)
		fmt.Fprintf(b, "%s%s = make(%s, %s)\n", indent, access, f.GoTypeName(), lenVar)
		iVar := "_i" + sanitizeVarName(f.Name)
		fmt.Fprintf(b, "%sfor %s := 0; %s < %s; %s++ {\n", indent, iVar, iVar, lenVar, iVar)
		if err := writeGoUnmarshalField(b, recv, elemField(iVar), indent+"\t"); err != nil {
			return err
		}
		fmt.Fprintf(b, "%s}\n", indent)
	}
	return nil
}
// writeGoUnmarshalPrimitive emits code that decodes one scalar field from
// data at offset. Quantized fields are delegated to writeGoUnmarshalQuant.
//
// Strings read a uint16 length prefix, bounds-check the payload, and copy it
// into a new string. Every other primitive has a fixed wire width, and the
// emitted code checks that width, assigns the little-endian decoded value
// (converted to the field's named type if any), and advances offset. An
// unrecognized primitive kind emits nothing.
func writeGoUnmarshalPrimitive(b *strings.Builder, access string, f parser.Field, indent string) error {
	if f.Quant != nil {
		return writeGoUnmarshalQuant(b, access, f, indent)
	}
	const tooShort = `{ return 0, errors.New("arpack: buffer too short") }`

	if f.Primitive == parser.KindString {
		n := "_slen" + sanitizeVarName(access)
		payload := "string(data[offset : offset+" + n + "])"
		b.WriteString(indent + "if len(data) < offset+2 " + tooShort + "\n")
		b.WriteString(indent + n + " := int(binary.LittleEndian.Uint16(data[offset:]))\n")
		b.WriteString(indent + "offset += 2\n")
		b.WriteString(indent + "if len(data) < offset+" + n + " " + tooShort + "\n")
		b.WriteString(indent + access + " = " + goUnmarshalValueExpr(payload, f) + "\n")
		b.WriteString(indent + "offset += " + n + "\n")
		return nil
	}

	var width int
	var decode string
	switch f.Primitive {
	case parser.KindFloat32:
		width, decode = 4, "math.Float32frombits(binary.LittleEndian.Uint32(data[offset:]))"
	case parser.KindFloat64:
		width, decode = 8, "math.Float64frombits(binary.LittleEndian.Uint64(data[offset:]))"
	case parser.KindInt8:
		width, decode = 1, "int8(data[offset])"
	case parser.KindUint8:
		width, decode = 1, "data[offset]"
	case parser.KindBool:
		width, decode = 1, "data[offset] != 0"
	case parser.KindInt16:
		width, decode = 2, "int16(binary.LittleEndian.Uint16(data[offset:]))"
	case parser.KindUint16:
		width, decode = 2, "binary.LittleEndian.Uint16(data[offset:])"
	case parser.KindInt32:
		width, decode = 4, "int32(binary.LittleEndian.Uint32(data[offset:]))"
	case parser.KindUint32:
		width, decode = 4, "binary.LittleEndian.Uint32(data[offset:])"
	case parser.KindInt64:
		width, decode = 8, "int64(binary.LittleEndian.Uint64(data[offset:]))"
	case parser.KindUint64:
		width, decode = 8, "binary.LittleEndian.Uint64(data[offset:])"
	default:
		return nil
	}

	fmt.Fprintf(b, "%sif len(data) < offset+%d %s\n", indent, width, tooShort)
	fmt.Fprintf(b, "%s%s = %s\n", indent, access, goUnmarshalValueExpr(decode, f))
	fmt.Fprintf(b, "%soffset += %d\n", indent, width)
	return nil
}
// writeGoUnmarshalQuant emits the inverse of the quantized encoding. It reads
// one byte (8-bit) or a little-endian uint16 (any other bit count) and maps
// it back linearly onto [q.Min, q.Max] as raw/MaxUint*(Max-Min)+Min. The
// arithmetic is done in float32 for float32 fields and in float64 otherwise.
func writeGoUnmarshalQuant(b *strings.Builder, access string, f parser.Field, indent string) error {
	q := f.Quant
	raw := "_q" + sanitizeVarName(access)
	scale := q.MaxUint()

	width, read := 2, "binary.LittleEndian.Uint16(data[offset:])"
	if q.Bits == 8 {
		width, read = 1, "data[offset]"
	}
	floatType := "float64"
	if f.Primitive == parser.KindFloat32 {
		floatType = "float32"
	}

	fmt.Fprintf(b, "%sif len(data) < offset+%d { return 0, errors.New(\"arpack: buffer too short\") }\n", indent, width)
	fmt.Fprintf(b, "%s%s := %s\n", indent, raw, read)
	fmt.Fprintf(b, "%soffset += %d\n", indent, width)
	decoded := fmt.Sprintf("%s(%s) / %g * (%g - (%g)) + (%g)", floatType, raw, scale, q.Max, q.Min, q.Min)
	fmt.Fprintf(b, "%s%s = %s\n", indent, access, goUnmarshalValueExpr(decoded, f))
	return nil
}
// goMarshalValueExpr returns the expression to encode for a field access. For
// fields declared with a named type, the access is converted to the
// underlying primitive Go type (f.GoPrimitiveTypeName()) so it can be passed
// to the encoding calls. Otherwise the access is returned as is.
func goMarshalValueExpr(access string, f parser.Field) string {
	if f.NamedType != "" {
		return fmt.Sprintf("%s(%s)", f.GoPrimitiveTypeName(), access)
	}
	return access
}
// goUnmarshalValueExpr wraps a decoded expression in a conversion to the
// field's named type, when it has one, so the generated assignment
// type-checks. Fields without a named type get the expression unchanged.
func goUnmarshalValueExpr(expr string, f parser.Field) string {
	if named := f.NamedType; named != "" {
		return fmt.Sprintf("%s(%s)", named, expr)
	}
	return expr
}
// sanitizeVarName turns an arbitrary access path (e.g. "m.Items[_iItems]")
// into an identifier fragment. ASCII letters and digits are kept, and every
// other rune becomes '_'. Each byte of invalid UTF-8 is also replaced by a
// single '_'.
func sanitizeVarName(s string) string {
	return strings.Map(func(r rune) rune {
		switch {
		case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', '0' <= r && r <= '9':
			return r
		}
		return '_'
	}, s)
}
// needsMathImport reports whether any field of any message makes the
// generated code reference the math package (see needsMathField).
func needsMathImport(messages []parser.Message) bool {
	for i := range messages {
		fields := messages[i].Fields
		for j := range fields {
			if needsMathField(fields[j]) {
				return true
			}
		}
	}
	return false
}
// needsMathField reports whether the code generated for f uses the math
// package. Only non-quantized float32/float64 primitives are encoded through
// math.FloatXXbits / math.FloatXXfrombits. Containers are checked through
// their element type, and quantized floats, nested messages and every other
// kind do not use math.
func needsMathField(f parser.Field) bool {
	if f.Kind == parser.KindPrimitive {
		if f.Quant != nil {
			return false
		}
		return f.Primitive == parser.KindFloat32 || f.Primitive == parser.KindFloat64
	}
	isContainer := f.Kind == parser.KindFixedArray || f.Kind == parser.KindSlice
	if isContainer && f.Elem != nil {
		return needsMathField(*f.Elem)
	}
	return false
}