package generator import ( "github.com/edmand46/arpack/parser" "fmt" "go/format" "strings" ) func GenerateGo(messages []parser.Message, pkgName string) ([]byte, error) { return GenerateGoSchema(parser.Schema{Messages: messages}, pkgName) } func GenerateGoSchema(schema parser.Schema, pkgName string) ([]byte, error) { messages := schema.Messages var b strings.Builder b.WriteString("// Code generated by arpack. DO NOT EDIT.\n\n") fmt.Fprintf(&b, "package %s\n\n", pkgName) b.WriteString("import (\n") b.WriteString("\t\"encoding/binary\"\n") b.WriteString("\t\"errors\"\n") if needsMathImport(messages) { b.WriteString("\t\"math\"\n") } b.WriteString(")\n\n") for _, msg := range messages { if err := writeGoMessage(&b, msg); err != nil { return nil, fmt.Errorf("message %s: %w", msg.Name, err) } } src := b.String() formatted, err := format.Source([]byte(src)) if err != nil { return []byte(src), fmt.Errorf("go/format: %w\n\nSource:\n%s", err, src) } return formatted, nil } func writeGoMessage(b *strings.Builder, msg parser.Message) error { segs := segmentFields(msg.Fields) fmt.Fprintf(b, "func (m *%s) Marshal(buf []byte) []byte {\n", msg.Name) for i, seg := range segs { if seg.single != nil { if err := writeGoMarshalField(b, "m", *seg.single, "\t"); err != nil { return err } } else { writeGoBoolGroupMarshal(b, "m", seg.bools, i, "\t") } } b.WriteString("\treturn buf\n}\n\n") fmt.Fprintf(b, "func (m *%s) Unmarshal(data []byte) (int, error) {\n", msg.Name) minSize := packedMinWireSize(msg.Fields) fmt.Fprintf(b, "\tif len(data) < %d {\n", minSize) fmt.Fprintf(b, "\t\treturn 0, errors.New(\"arpack: buffer too short for %s\")\n", msg.Name) b.WriteString("\t}\n") b.WriteString("\toffset := 0\n") for i, seg := range segs { if seg.single != nil { if err := writeGoUnmarshalField(b, "m", *seg.single, "\t"); err != nil { return err } } else { writeGoBoolGroupUnmarshal(b, "m", seg.bools, i, "\t") } } b.WriteString("\treturn offset, nil\n}\n\n") return nil } func writeGoBoolGroupMarshal(b *strings.Builder, recv string, bools []parser.Field, groupIdx int, indent string) { varName := fmt.Sprintf("_boolByte%d", groupIdx) fmt.Fprintf(b, "%svar %s uint8\n", indent, varName) for bit, f := range bools { fmt.Fprintf(b, "%sif %s.%s { %s |= 1 << %d }\n", indent, recv, f.Name, varName, bit) } fmt.Fprintf(b, "%sbuf = append(buf, %s)\n", indent, varName) } func writeGoBoolGroupUnmarshal(b *strings.Builder, recv string, bools []parser.Field, groupIdx int, indent string) { varName := fmt.Sprintf("_boolByte%d", groupIdx) fmt.Fprintf(b, "%sif len(data) < offset+1 { return 0, errors.New(\"arpack: buffer too short\") }\n", indent) fmt.Fprintf(b, "%s%s := data[offset]; offset++\n", indent, varName) for bit, f := range bools { expr := fmt.Sprintf("%s&(1<<%d) != 0", varName, bit) fmt.Fprintf(b, "%s%s.%s = %s\n", indent, recv, f.Name, goUnmarshalValueExpr(expr, f)) } } func writeGoMarshalField(b *strings.Builder, recv string, f parser.Field, indent string) error { access := recv + "." + f.Name switch f.Kind { case parser.KindPrimitive: return writeGoMarshalPrimitive(b, access, f, indent) case parser.KindNested: fmt.Fprintf(b, "%sbuf = %s.Marshal(buf)\n", indent, access) case parser.KindFixedArray: fmt.Fprintf(b, "%sfor _i%s := 0; _i%s < %d; _i%s++ {\n", indent, f.Name, f.Name, f.FixedLen, f.Name) elemField := parser.Field{ Name: f.Name + "[_i" + f.Name + "]", Kind: f.Elem.Kind, Primitive: f.Elem.Primitive, NamedType: f.Elem.NamedType, Quant: f.Elem.Quant, TypeName: f.Elem.TypeName, Elem: f.Elem.Elem, FixedLen: f.Elem.FixedLen, } if err := writeGoMarshalField(b, recv, elemField, indent+"\t"); err != nil { return err } fmt.Fprintf(b, "%s}\n", indent) case parser.KindSlice: fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint16(buf, uint16(len(%s)))\n", indent, access) fmt.Fprintf(b, "%sfor _i%s := range %s {\n", indent, f.Name, access) elemField := parser.Field{ Name: f.Name + "[_i" + f.Name + "]", Kind: f.Elem.Kind, Primitive: f.Elem.Primitive, NamedType: f.Elem.NamedType, Quant: f.Elem.Quant, TypeName: f.Elem.TypeName, Elem: f.Elem.Elem, FixedLen: f.Elem.FixedLen, } if err := writeGoMarshalField(b, recv, elemField, indent+"\t"); err != nil { return err } fmt.Fprintf(b, "%s}\n", indent) } return nil } func writeGoMarshalPrimitive(b *strings.Builder, access string, f parser.Field, indent string) error { if f.Quant != nil { return writeGoMarshalQuant(b, access, f, indent) } valueExpr := goMarshalValueExpr(access, f) switch f.Primitive { case parser.KindFloat32: fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint32(buf, math.Float32bits(%s))\n", indent, valueExpr) case parser.KindFloat64: fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint64(buf, math.Float64bits(%s))\n", indent, valueExpr) case parser.KindInt8: fmt.Fprintf(b, "%sbuf = append(buf, uint8(%s))\n", indent, valueExpr) case parser.KindUint8: fmt.Fprintf(b, "%sbuf = append(buf, %s)\n", indent, valueExpr) case parser.KindBool: fmt.Fprintf(b, "%sif %s { buf = append(buf, 1) } else { buf = append(buf, 0) }\n", indent, valueExpr) case parser.KindInt16: fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint16(buf, uint16(%s))\n", indent, valueExpr) case parser.KindUint16: fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint16(buf, %s)\n", indent, valueExpr) case parser.KindInt32: fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint32(buf, uint32(%s))\n", indent, valueExpr) case parser.KindUint32: fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint32(buf, %s)\n", indent, valueExpr) case parser.KindInt64: fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint64(buf, uint64(%s))\n", indent, valueExpr) case parser.KindUint64: fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint64(buf, %s)\n", indent, valueExpr) case parser.KindString: fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint16(buf, uint16(len(%s)))\n", indent, valueExpr) fmt.Fprintf(b, "%sbuf = append(buf, %s...)\n", indent, valueExpr) } return nil } func writeGoMarshalQuant(b *strings.Builder, access string, f parser.Field, indent string) error { q := f.Quant varName := "_q" + sanitizeVarName(access) valueExpr := goMarshalValueExpr(access, f) if q.Bits == 8 { fmt.Fprintf(b, "%s%s := uint8((%s - (%g)) / (%g - (%g)) * %g)\n", indent, varName, valueExpr, q.Min, q.Max, q.Min, q.MaxUint()) fmt.Fprintf(b, "%sbuf = append(buf, %s)\n", indent, varName) } else { fmt.Fprintf(b, "%s%s := uint16((%s - (%g)) / (%g - (%g)) * %g)\n", indent, varName, valueExpr, q.Min, q.Max, q.Min, q.MaxUint()) fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint16(buf, %s)\n", indent, varName) } return nil } func writeGoUnmarshalField(b *strings.Builder, recv string, f parser.Field, indent string) error { access := recv + "." + f.Name switch f.Kind { case parser.KindPrimitive: return writeGoUnmarshalPrimitive(b, access, f, indent) case parser.KindNested: nVar := "_n" + sanitizeVarName(f.Name) fmt.Fprintf(b, "%s%s, _err := %s.Unmarshal(data[offset:])\n", indent, nVar, access) fmt.Fprintf(b, "%sif _err != nil { return 0, _err }\n", indent) fmt.Fprintf(b, "%soffset += %s\n", indent, nVar) case parser.KindFixedArray: iVar := "_i" + f.Name fmt.Fprintf(b, "%sfor %s := 0; %s < %d; %s++ {\n", indent, iVar, iVar, f.FixedLen, iVar) elemField := parser.Field{ Name: f.Name + "[" + iVar + "]", Kind: f.Elem.Kind, Primitive: f.Elem.Primitive, NamedType: f.Elem.NamedType, Quant: f.Elem.Quant, TypeName: f.Elem.TypeName, Elem: f.Elem.Elem, FixedLen: f.Elem.FixedLen, } if err := writeGoUnmarshalField(b, recv, elemField, indent+"\t"); err != nil { return err } fmt.Fprintf(b, "%s}\n", indent) case parser.KindSlice: lenVar := "_len" + sanitizeVarName(f.Name) fmt.Fprintf(b, "%sif len(data) < offset+2 { return 0, errors.New(\"arpack: buffer too short\") }\n", indent) fmt.Fprintf(b, "%s%s := int(binary.LittleEndian.Uint16(data[offset:]))\n", indent, lenVar) fmt.Fprintf(b, "%soffset += 2\n", indent) fmt.Fprintf(b, "%s%s = make(%s, %s)\n", indent, access, f.GoTypeName(), lenVar) iVar := "_i" + sanitizeVarName(f.Name) fmt.Fprintf(b, "%sfor %s := 0; %s < %s; %s++ {\n", indent, iVar, iVar, lenVar, iVar) elemField := parser.Field{ Name: f.Name + "[" + iVar + "]", Kind: f.Elem.Kind, Primitive: f.Elem.Primitive, NamedType: f.Elem.NamedType, Quant: f.Elem.Quant, TypeName: f.Elem.TypeName, Elem: f.Elem.Elem, FixedLen: f.Elem.FixedLen, } if err := writeGoUnmarshalField(b, recv, elemField, indent+"\t"); err != nil { return err } fmt.Fprintf(b, "%s}\n", indent) } return nil } func writeGoUnmarshalPrimitive(b *strings.Builder, access string, f parser.Field, indent string) error { if f.Quant != nil { return writeGoUnmarshalQuant(b, access, f, indent) } switch f.Primitive { case parser.KindFloat32: fmt.Fprintf(b, "%sif len(data) < offset+4 { return 0, errors.New(\"arpack: buffer too short\") }\n", indent) fmt.Fprintf(b, "%s%s = %s\n", indent, access, goUnmarshalValueExpr("math.Float32frombits(binary.LittleEndian.Uint32(data[offset:]))", f)) fmt.Fprintf(b, "%soffset += 4\n", indent) case parser.KindFloat64: fmt.Fprintf(b, "%sif len(data) < offset+8 { return 0, errors.New(\"arpack: buffer too short\") }\n", indent) fmt.Fprintf(b, "%s%s = %s\n", indent, access, goUnmarshalValueExpr("math.Float64frombits(binary.LittleEndian.Uint64(data[offset:]))", f)) fmt.Fprintf(b, "%soffset += 8\n", indent) case parser.KindInt8: fmt.Fprintf(b, "%sif len(data) < offset+1 { return 0, errors.New(\"arpack: buffer too short\") }\n", indent) fmt.Fprintf(b, "%s%s = %s\n", indent, access, goUnmarshalValueExpr("int8(data[offset])", f)) fmt.Fprintf(b, "%soffset += 1\n", indent) case parser.KindUint8: fmt.Fprintf(b, "%sif len(data) < offset+1 { return 0, errors.New(\"arpack: buffer too short\") }\n", indent) fmt.Fprintf(b, "%s%s = %s\n", indent, access, goUnmarshalValueExpr("data[offset]", f)) fmt.Fprintf(b, "%soffset += 1\n", indent) case parser.KindBool: fmt.Fprintf(b, "%sif len(data) < offset+1 { return 0, errors.New(\"arpack: buffer too short\") }\n", indent) fmt.Fprintf(b, "%s%s = %s\n", indent, access, goUnmarshalValueExpr("data[offset] != 0", f)) fmt.Fprintf(b, "%soffset += 1\n", indent) case parser.KindInt16: fmt.Fprintf(b, "%sif len(data) < offset+2 { return 0, errors.New(\"arpack: buffer too short\") }\n", indent) fmt.Fprintf(b, "%s%s = %s\n", indent, access, goUnmarshalValueExpr("int16(binary.LittleEndian.Uint16(data[offset:]))", f)) fmt.Fprintf(b, "%soffset += 2\n", indent) case parser.KindUint16: fmt.Fprintf(b, "%sif len(data) < offset+2 { return 0, errors.New(\"arpack: buffer too short\") }\n", indent) fmt.Fprintf(b, "%s%s = %s\n", indent, access, goUnmarshalValueExpr("binary.LittleEndian.Uint16(data[offset:])", f)) fmt.Fprintf(b, "%soffset += 2\n", indent) case parser.KindInt32: fmt.Fprintf(b, "%sif len(data) < offset+4 { return 0, errors.New(\"arpack: buffer too short\") }\n", indent) fmt.Fprintf(b, "%s%s = %s\n", indent, access, goUnmarshalValueExpr("int32(binary.LittleEndian.Uint32(data[offset:]))", f)) fmt.Fprintf(b, "%soffset += 4\n", indent) case parser.KindUint32: fmt.Fprintf(b, "%sif len(data) < offset+4 { return 0, errors.New(\"arpack: buffer too short\") }\n", indent) fmt.Fprintf(b, "%s%s = %s\n", indent, access, goUnmarshalValueExpr("binary.LittleEndian.Uint32(data[offset:])", f)) fmt.Fprintf(b, "%soffset += 4\n", indent) case parser.KindInt64: fmt.Fprintf(b, "%sif len(data) < offset+8 { return 0, errors.New(\"arpack: buffer too short\") }\n", indent) fmt.Fprintf(b, "%s%s = %s\n", indent, access, goUnmarshalValueExpr("int64(binary.LittleEndian.Uint64(data[offset:]))", f)) fmt.Fprintf(b, "%soffset += 8\n", indent) case parser.KindUint64: fmt.Fprintf(b, "%sif len(data) < offset+8 { return 0, errors.New(\"arpack: buffer too short\") }\n", indent) fmt.Fprintf(b, "%s%s = %s\n", indent, access, goUnmarshalValueExpr("binary.LittleEndian.Uint64(data[offset:])", f)) fmt.Fprintf(b, "%soffset += 8\n", indent) case parser.KindString: lenVar := "_slen" + sanitizeVarName(access) fmt.Fprintf(b, "%sif len(data) < offset+2 { return 0, errors.New(\"arpack: buffer too short\") }\n", indent) fmt.Fprintf(b, "%s%s := int(binary.LittleEndian.Uint16(data[offset:]))\n", indent, lenVar) fmt.Fprintf(b, "%soffset += 2\n", indent) fmt.Fprintf(b, "%sif len(data) < offset+%s { return 0, errors.New(\"arpack: buffer too short\") }\n", indent, lenVar) fmt.Fprintf(b, "%s%s = %s\n", indent, access, goUnmarshalValueExpr("string(data[offset : offset+"+lenVar+"])", f)) fmt.Fprintf(b, "%soffset += %s\n", indent, lenVar) } return nil } func writeGoUnmarshalQuant(b *strings.Builder, access string, f parser.Field, indent string) error { q := f.Quant varName := "_q" + sanitizeVarName(access) maxUint := q.MaxUint() if q.Bits == 8 { fmt.Fprintf(b, "%sif len(data) < offset+1 { return 0, errors.New(\"arpack: buffer too short\") }\n", indent) fmt.Fprintf(b, "%s%s := data[offset]\n", indent, varName) fmt.Fprintf(b, "%soffset += 1\n", indent) if f.Primitive == parser.KindFloat32 { expr := fmt.Sprintf("float32(%s) / %g * (%g - (%g)) + (%g)", varName, maxUint, q.Max, q.Min, q.Min) fmt.Fprintf(b, "%s%s = %s\n", indent, access, goUnmarshalValueExpr(expr, f)) } else { expr := fmt.Sprintf("float64(%s) / %g * (%g - (%g)) + (%g)", varName, maxUint, q.Max, q.Min, q.Min) fmt.Fprintf(b, "%s%s = %s\n", indent, access, goUnmarshalValueExpr(expr, f)) } } else { fmt.Fprintf(b, "%sif len(data) < offset+2 { return 0, errors.New(\"arpack: buffer too short\") }\n", indent) fmt.Fprintf(b, "%s%s := binary.LittleEndian.Uint16(data[offset:])\n", indent, varName) fmt.Fprintf(b, "%soffset += 2\n", indent) if f.Primitive == parser.KindFloat32 { expr := fmt.Sprintf("float32(%s) / %g * (%g - (%g)) + (%g)", varName, maxUint, q.Max, q.Min, q.Min) fmt.Fprintf(b, "%s%s = %s\n", indent, access, goUnmarshalValueExpr(expr, f)) } else { expr := fmt.Sprintf("float64(%s) / %g * (%g - (%g)) + (%g)", varName, maxUint, q.Max, q.Min, q.Min) fmt.Fprintf(b, "%s%s = %s\n", indent, access, goUnmarshalValueExpr(expr, f)) } } return nil } func goMarshalValueExpr(access string, f parser.Field) string { if f.NamedType == "" { return access } return f.GoPrimitiveTypeName() + "(" + access + ")" } func goUnmarshalValueExpr(expr string, f parser.Field) string { if f.NamedType == "" { return expr } return f.NamedType + "(" + expr + ")" } func sanitizeVarName(s string) string { var b strings.Builder for _, c := range s { if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') { b.WriteRune(c) } else { b.WriteRune('_') } } return b.String() } func needsMathImport(messages []parser.Message) bool { for _, msg := range messages { for _, f := range msg.Fields { if needsMathField(f) { return true } } } return false } func needsMathField(f parser.Field) bool { switch f.Kind { case parser.KindPrimitive: return f.Quant == nil && (f.Primitive == parser.KindFloat32 || f.Primitive == parser.KindFloat64) case parser.KindFixedArray, parser.KindSlice: if f.Elem != nil { return needsMathField(*f.Elem) } } return false }