Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8851908207 | |||
| c1890216c5 | |||
| 281af49d27 | |||
| 3b543e9b63 | |||
| cebe84bce1 |
@@ -32,8 +32,8 @@ jobs:
|
|||||||
- name: Download dependencies
|
- name: Download dependencies
|
||||||
run: go mod download
|
run: go mod download
|
||||||
|
|
||||||
- name: Run unit tests
|
- name: Run full Go test suite
|
||||||
run: go test -v ./parser/... ./generator/...
|
run: make test
|
||||||
|
|
||||||
- name: Run benchmarks (short)
|
- name: Run benchmarks (short)
|
||||||
run: go test -bench=. -benchtime=100ms -run=^$ ./benchmarks/...
|
run: go test -bench=. -benchtime=100ms -run=^$ ./benchmarks/...
|
||||||
@@ -70,8 +70,10 @@ jobs:
|
|||||||
- name: Test code generation
|
- name: Test code generation
|
||||||
run: |
|
run: |
|
||||||
go run ./cmd/arpack -in testdata/sample.go -out-go /tmp/gen-go -out-ts /tmp/gen-ts
|
go run ./cmd/arpack -in testdata/sample.go -out-go /tmp/gen-go -out-ts /tmp/gen-ts
|
||||||
|
go run ./cmd/arpack -in testdata/lua/sample.go -out-lua /tmp/gen-lua
|
||||||
test -f /tmp/gen-go/sample_gen.go
|
test -f /tmp/gen-go/sample_gen.go
|
||||||
test -f /tmp/gen-ts/Sample.gen.ts
|
test -f /tmp/gen-ts/Sample.gen.ts
|
||||||
|
test -f /tmp/gen-lua/sample_gen.lua
|
||||||
|
|
||||||
e2e:
|
e2e:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
UNITY_ASSETS := benchmarks/unity/Assets
|
UNITY_ASSETS := benchmarks/unity/Assets
|
||||||
|
|
||||||
.PHONY: bench-image generate bench size gen-unity
|
.PHONY: test bench-image generate bench size gen-unity
|
||||||
|
|
||||||
IMAGE := arpack-bench
|
IMAGE := arpack-bench
|
||||||
|
|
||||||
|
test:
|
||||||
|
go test ./...
|
||||||
|
|
||||||
bench-image:
|
bench-image:
|
||||||
docker build -f Dockerfile.bench -t $(IMAGE) .
|
docker build -f Dockerfile.bench -t $(IMAGE) .
|
||||||
|
|
||||||
|
|||||||
@@ -67,6 +67,17 @@ arpack -in messages.go -out-lua ./defold/scripts/messages
|
|||||||
- TypeScript: `{Name}.gen.ts`
|
- TypeScript: `{Name}.gen.ts`
|
||||||
- Lua: `{name}_gen.lua` (snake_case for Lua `require()` compatibility)
|
- Lua: `{name}_gen.lua` (snake_case for Lua `require()` compatibility)
|
||||||
|
|
||||||
|
## v1 Contract
|
||||||
|
|
||||||
|
ArPack `v1` intentionally supports a narrow schema model:
|
||||||
|
|
||||||
|
- Input is a single Go source file.
|
||||||
|
- Message types must be defined in that same file.
|
||||||
|
- External package types, pointers, and platform-dependent integer aliases (`int`, `uint`, `uintptr`) are not supported.
|
||||||
|
- Wire format is stable within `v1.x` for unchanged schemas.
|
||||||
|
|
||||||
|
This is a deliberate product boundary for predictable code generation and cross-language compatibility.
|
||||||
|
|
||||||
## Schema Definition
|
## Schema Definition
|
||||||
|
|
||||||
Messages are defined as Go structs in a single `.go` file:
|
Messages are defined as Go structs in a single `.go` file:
|
||||||
@@ -116,6 +127,8 @@ type MoveMessage struct {
|
|||||||
| `[N]T` | N × sizeof(T) | ✓ |
|
| `[N]T` | N × sizeof(T) | ✓ |
|
||||||
| `[]T` | 2-byte length prefix + N × sizeof(T) | ✓ |
|
| `[]T` | 2-byte length prefix + N × sizeof(T) | ✓ |
|
||||||
|
|
||||||
|
**Note:** platform-dependent `int`, `uint`, and `uintptr` are not supported. Use explicit widths like `int32`, `uint32`, `int64`, or `uint64`.
|
||||||
|
|
||||||
**Note:** `int64`/`uint64` are not supported in Lua target. LuaJIT (used by Defold) represents numbers as double-precision floats, which can only safely represent integers up to 2^53. Use `int32`/`uint32` instead.
|
**Note:** `int64`/`uint64` are not supported in Lua target. LuaJIT (used by Defold) represents numbers as double-precision floats, which can only safely represent integers up to 2^53. Use `int32`/`uint32` instead.
|
||||||
|
|
||||||
### Float Quantization
|
### Float Quantization
|
||||||
@@ -135,6 +148,8 @@ Y float32 `pack:"min=0,max=1,bits=8"` // 1 byte instead of 4
|
|||||||
|
|
||||||
Values are linearly mapped: `encoded = (value - min) / (max - min) * maxUint`.
|
Values are linearly mapped: `encoded = (value - min) / (max - min) * maxUint`.
|
||||||
|
|
||||||
|
Quantized values must stay within the declared `[min, max]` range. Generated serializers fail fast on out-of-range or `NaN` inputs instead of silently truncating them.
|
||||||
|
|
||||||
## Generated Code
|
## Generated Code
|
||||||
|
|
||||||
### Go
|
### Go
|
||||||
@@ -146,6 +161,8 @@ func (m *MoveMessage) Unmarshal(data []byte) (int, error)
|
|||||||
|
|
||||||
`Marshal` appends to the buffer and returns it. `Unmarshal` reads from the buffer and returns bytes consumed.
|
`Marshal` appends to the buffer and returns it. `Unmarshal` reads from the buffer and returns bytes consumed.
|
||||||
|
|
||||||
|
**Failure behavior:** generated `Marshal` panics if a string/slice exceeds the `uint16` wire limit or if a quantized value is outside its declared range.
|
||||||
|
|
||||||
### C#
|
### C#
|
||||||
|
|
||||||
```csharp
|
```csharp
|
||||||
@@ -155,6 +172,8 @@ public static unsafe int Deserialize(byte* buffer, out MoveMessage msg)
|
|||||||
|
|
||||||
Uses unsafe pointers for zero-copy serialization. Returns bytes written/consumed.
|
Uses unsafe pointers for zero-copy serialization. Returns bytes written/consumed.
|
||||||
|
|
||||||
|
**Failure behavior:** generated `Serialize` throws if a string/slice exceeds the `uint16` wire limit or if a quantized value is outside its declared range.
|
||||||
|
|
||||||
### TypeScript
|
### TypeScript
|
||||||
|
|
||||||
```typescript
|
```typescript
|
||||||
@@ -177,6 +196,8 @@ Uses native DataView API for browser-compatible serialization with zero dependen
|
|||||||
|
|
||||||
**Note:** TypeScript field names are converted to camelCase (e.g., `PlayerID` → `playerId`).
|
**Note:** TypeScript field names are converted to camelCase (e.g., `PlayerID` → `playerId`).
|
||||||
|
|
||||||
|
**Failure behavior:** generated `serialize(...)` throws `RangeError` if a string/slice exceeds the `uint16` wire limit or if a quantized value is outside its declared range.
|
||||||
|
|
||||||
### Lua
|
### Lua
|
||||||
|
|
||||||
```lua
|
```lua
|
||||||
@@ -198,8 +219,11 @@ Uses pure Lua with inline helper functions for byte manipulation. Compatible wit
|
|||||||
|
|
||||||
**Requirements:** The generated Lua code requires the [BitOp library](https://bitop.luajit.org/) for bit manipulation. This library is included in LuaJIT (used by Defold).
|
**Requirements:** The generated Lua code requires the [BitOp library](https://bitop.luajit.org/) for bit manipulation. This library is included in LuaJIT (used by Defold).
|
||||||
|
|
||||||
**Limitations:**
|
**Limitations:**
|
||||||
- Lua target does not support `int64`/`uint64` types. Use `int32`/`uint32` instead. This is because LuaJIT represents numbers as double-precision floats, which can only safely represent integers up to 2^53.
|
- Lua target does not support `int64`/`uint64` types. Use `int32`/`uint32` instead. This is because LuaJIT represents numbers as double-precision floats, which can only safely represent integers up to 2^53.
|
||||||
|
- Variable-length fields use `uint16` length prefixes, so `string` byte length and `[]T` element count must not exceed `65535`. Serialization raises an error if the limit is exceeded.
|
||||||
|
- Quantized values must stay within the declared `[min, max]` range. Serialization raises a Lua error on out-of-range or `NaN` inputs.
|
||||||
|
- Deserialization raises Lua errors on malformed or truncated input. If you need a recoverable boundary, wrap decode calls in `pcall(...)`.
|
||||||
- Generated file uses snake_case naming (e.g., `messages_gen.lua`) for proper Lua `require()` resolution.
|
- Generated file uses snake_case naming (e.g., `messages_gen.lua`) for proper Lua `require()` resolution.
|
||||||
|
|
||||||
## Wire Format
|
## Wire Format
|
||||||
@@ -210,6 +234,24 @@ Uses pure Lua with inline helper functions for byte manipulation. Compatible wit
|
|||||||
- Booleans packed as bitfields (LSB first, up to 8 per byte)
|
- Booleans packed as bitfields (LSB first, up to 8 per byte)
|
||||||
- Quantized floats stored as `uint8` or `uint16`
|
- Quantized floats stored as `uint8` or `uint16`
|
||||||
|
|
||||||
|
## Compatibility Guarantees
|
||||||
|
|
||||||
|
Within `v1.x`, the following are considered compatibility guarantees for a fixed schema:
|
||||||
|
|
||||||
|
- Same field declaration order produces the same wire layout.
|
||||||
|
- Go, C#, TypeScript, and Lua generators produce identical wire bytes for supported types.
|
||||||
|
- `string` and `[]T` always use `uint16` length prefixes.
|
||||||
|
- Consecutive `bool` fields are bit-packed in declaration order, least-significant bit first.
|
||||||
|
- Enum fields use their declared underlying integer type on the wire.
|
||||||
|
|
||||||
|
The following are breaking changes:
|
||||||
|
|
||||||
|
- changing field order
|
||||||
|
- changing a field type
|
||||||
|
- changing quantization parameters
|
||||||
|
- changing enum underlying types
|
||||||
|
- changing how booleans are grouped or how lengths are encoded
|
||||||
|
|
||||||
## Benchmarks
|
## Benchmarks
|
||||||
|
|
||||||
### Go Results (M3 Max)
|
### Go Results (M3 Max)
|
||||||
@@ -252,9 +294,20 @@ make gen-unity
|
|||||||
## Running Tests
|
## Running Tests
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Unit tests
|
# Full test suite
|
||||||
go test ./parser/... ./generator/...
|
make test
|
||||||
|
|
||||||
# End-to-end cross-language tests
|
# Benchmarks
|
||||||
go test ./e2e/...
|
go test ./benchmarks/... -bench=. -benchmem
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
- `unknown type "..."`
|
||||||
|
The field type is not a supported primitive and is not defined in the same schema file.
|
||||||
|
- `external package types not supported`
|
||||||
|
Copy the wire-facing type definition into the schema file instead of referencing another package.
|
||||||
|
- `... exceeds uint16 limit`
|
||||||
|
A `string` encoded to more than `65535` bytes, or a slice contains more than `65535` elements.
|
||||||
|
- `quantized value out of range`
|
||||||
|
The runtime value does not satisfy the declared `pack:"min=...,max=..."` bounds.
|
||||||
|
|||||||
@@ -8,11 +8,27 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func arpackEnsureUint16Length(length int, context string) uint16 {
|
||||||
|
if length > 65535 {
|
||||||
|
panic("arpack: " + context + " exceeds uint16 limit")
|
||||||
|
}
|
||||||
|
return uint16(length)
|
||||||
|
}
|
||||||
|
|
||||||
|
func arpackEnsureQuantizedRange(value float64, min float64, max float64, context string) {
|
||||||
|
if value != value || value < min || value > max {
|
||||||
|
panic("arpack: quantized value out of range for " + context)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Vector3) Marshal(buf []byte) []byte {
|
func (m *Vector3) Marshal(buf []byte) []byte {
|
||||||
|
arpackEnsureQuantizedRange(float64(m.X), -500, 500, "X")
|
||||||
_qm_X := uint16((m.X - (-500)) / (500 - (-500)) * 65535)
|
_qm_X := uint16((m.X - (-500)) / (500 - (-500)) * 65535)
|
||||||
buf = binary.LittleEndian.AppendUint16(buf, _qm_X)
|
buf = binary.LittleEndian.AppendUint16(buf, _qm_X)
|
||||||
|
arpackEnsureQuantizedRange(float64(m.Y), -500, 500, "Y")
|
||||||
_qm_Y := uint16((m.Y - (-500)) / (500 - (-500)) * 65535)
|
_qm_Y := uint16((m.Y - (-500)) / (500 - (-500)) * 65535)
|
||||||
buf = binary.LittleEndian.AppendUint16(buf, _qm_Y)
|
buf = binary.LittleEndian.AppendUint16(buf, _qm_Y)
|
||||||
|
arpackEnsureQuantizedRange(float64(m.Z), -500, 500, "Z")
|
||||||
_qm_Z := uint16((m.Z - (-500)) / (500 - (-500)) * 65535)
|
_qm_Z := uint16((m.Z - (-500)) / (500 - (-500)) * 65535)
|
||||||
buf = binary.LittleEndian.AppendUint16(buf, _qm_Z)
|
buf = binary.LittleEndian.AppendUint16(buf, _qm_Z)
|
||||||
return buf
|
return buf
|
||||||
@@ -49,7 +65,7 @@ func (m *MoveMessage) Marshal(buf []byte) []byte {
|
|||||||
for _iVelocity := 0; _iVelocity < 3; _iVelocity++ {
|
for _iVelocity := 0; _iVelocity < 3; _iVelocity++ {
|
||||||
buf = binary.LittleEndian.AppendUint32(buf, math.Float32bits(m.Velocity[_iVelocity]))
|
buf = binary.LittleEndian.AppendUint32(buf, math.Float32bits(m.Velocity[_iVelocity]))
|
||||||
}
|
}
|
||||||
buf = binary.LittleEndian.AppendUint16(buf, uint16(len(m.Waypoints)))
|
buf = binary.LittleEndian.AppendUint16(buf, arpackEnsureUint16Length(len(m.Waypoints), "slice length for Waypoints"))
|
||||||
for _iWaypoints := range m.Waypoints {
|
for _iWaypoints := range m.Waypoints {
|
||||||
buf = m.Waypoints[_iWaypoints].Marshal(buf)
|
buf = m.Waypoints[_iWaypoints].Marshal(buf)
|
||||||
}
|
}
|
||||||
@@ -65,7 +81,7 @@ func (m *MoveMessage) Marshal(buf []byte) []byte {
|
|||||||
_boolByte4 |= 1 << 2
|
_boolByte4 |= 1 << 2
|
||||||
}
|
}
|
||||||
buf = append(buf, _boolByte4)
|
buf = append(buf, _boolByte4)
|
||||||
buf = binary.LittleEndian.AppendUint16(buf, uint16(len(m.Name)))
|
buf = binary.LittleEndian.AppendUint16(buf, arpackEnsureUint16Length(len(m.Name), "string length for Name"))
|
||||||
buf = append(buf, m.Name...)
|
buf = append(buf, m.Name...)
|
||||||
return buf
|
return buf
|
||||||
}
|
}
|
||||||
@@ -130,12 +146,12 @@ func (m *SpawnMessage) Marshal(buf []byte) []byte {
|
|||||||
buf = binary.LittleEndian.AppendUint64(buf, m.EntityID)
|
buf = binary.LittleEndian.AppendUint64(buf, m.EntityID)
|
||||||
buf = m.Position.Marshal(buf)
|
buf = m.Position.Marshal(buf)
|
||||||
buf = binary.LittleEndian.AppendUint16(buf, uint16(m.Health))
|
buf = binary.LittleEndian.AppendUint16(buf, uint16(m.Health))
|
||||||
buf = binary.LittleEndian.AppendUint16(buf, uint16(len(m.Tags)))
|
buf = binary.LittleEndian.AppendUint16(buf, arpackEnsureUint16Length(len(m.Tags), "slice length for Tags"))
|
||||||
for _iTags := range m.Tags {
|
for _iTags := range m.Tags {
|
||||||
buf = binary.LittleEndian.AppendUint16(buf, uint16(len(m.Tags[_iTags])))
|
buf = binary.LittleEndian.AppendUint16(buf, arpackEnsureUint16Length(len(m.Tags[_iTags]), "string length for Tags[_iTags]"))
|
||||||
buf = append(buf, m.Tags[_iTags]...)
|
buf = append(buf, m.Tags[_iTags]...)
|
||||||
}
|
}
|
||||||
buf = binary.LittleEndian.AppendUint16(buf, uint16(len(m.Data)))
|
buf = binary.LittleEndian.AppendUint16(buf, arpackEnsureUint16Length(len(m.Data), "slice length for Data"))
|
||||||
for _iData := range m.Data {
|
for _iData := range m.Data {
|
||||||
buf = append(buf, m.Data[_iData])
|
buf = append(buf, m.Data[_iData])
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,26 @@ using System.Text;
|
|||||||
|
|
||||||
namespace Arpack.Messages
|
namespace Arpack.Messages
|
||||||
{
|
{
|
||||||
|
internal static class ArpackGenerated
|
||||||
|
{
|
||||||
|
internal static ushort EnsureU16Length(int length, string context)
|
||||||
|
{
|
||||||
|
if (length > 65535)
|
||||||
|
{
|
||||||
|
throw new InvalidOperationException("arpack: " + context + " exceeds uint16 limit");
|
||||||
|
}
|
||||||
|
return (ushort)length;
|
||||||
|
}
|
||||||
|
|
||||||
|
internal static void EnsureQuantizedRange(double value, double min, double max, string context)
|
||||||
|
{
|
||||||
|
if (double.IsNaN(value) || value < min || value > max)
|
||||||
|
{
|
||||||
|
throw new ArgumentOutOfRangeException(context, "arpack: quantized value out of range for " + context);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public enum Opcode : ushort
|
public enum Opcode : ushort
|
||||||
{
|
{
|
||||||
Unknown = 0,
|
Unknown = 0,
|
||||||
@@ -23,8 +43,11 @@ namespace Arpack.Messages
|
|||||||
public int Serialize(byte* buffer)
|
public int Serialize(byte* buffer)
|
||||||
{
|
{
|
||||||
byte* ptr = buffer;
|
byte* ptr = buffer;
|
||||||
|
ArpackGenerated.EnsureQuantizedRange(X, -500, 500, "X");
|
||||||
*(ushort*)ptr = (ushort)((X - (-500f)) / (500f - (-500f)) * 65535f); ptr += 2;
|
*(ushort*)ptr = (ushort)((X - (-500f)) / (500f - (-500f)) * 65535f); ptr += 2;
|
||||||
|
ArpackGenerated.EnsureQuantizedRange(Y, -500, 500, "Y");
|
||||||
*(ushort*)ptr = (ushort)((Y - (-500f)) / (500f - (-500f)) * 65535f); ptr += 2;
|
*(ushort*)ptr = (ushort)((Y - (-500f)) / (500f - (-500f)) * 65535f); ptr += 2;
|
||||||
|
ArpackGenerated.EnsureQuantizedRange(Z, -500, 500, "Z");
|
||||||
*(ushort*)ptr = (ushort)((Z - (-500f)) / (500f - (-500f)) * 65535f); ptr += 2;
|
*(ushort*)ptr = (ushort)((Z - (-500f)) / (500f - (-500f)) * 65535f); ptr += 2;
|
||||||
return (int)(ptr - buffer);
|
return (int)(ptr - buffer);
|
||||||
}
|
}
|
||||||
@@ -59,7 +82,7 @@ namespace Arpack.Messages
|
|||||||
{
|
{
|
||||||
*(float*)ptr = Velocity[_iVelocity]; ptr += 4;
|
*(float*)ptr = Velocity[_iVelocity]; ptr += 4;
|
||||||
}
|
}
|
||||||
*(ushort*)ptr = (ushort)(Waypoints?.Length ?? 0); ptr += 2;
|
ushort _lenWaypoints = ArpackGenerated.EnsureU16Length(Waypoints?.Length ?? 0, "slice length for Waypoints"); *(ushort*)ptr = _lenWaypoints; ptr += 2;
|
||||||
if (Waypoints != null)
|
if (Waypoints != null)
|
||||||
{
|
{
|
||||||
for (int _iWaypoints = 0; _iWaypoints < Waypoints.Length; _iWaypoints++)
|
for (int _iWaypoints = 0; _iWaypoints < Waypoints.Length; _iWaypoints++)
|
||||||
@@ -74,7 +97,7 @@ namespace Arpack.Messages
|
|||||||
if (Ghost) _boolByte4 |= (byte)(1 << 2);
|
if (Ghost) _boolByte4 |= (byte)(1 << 2);
|
||||||
*ptr = _boolByte4; ptr++;
|
*ptr = _boolByte4; ptr++;
|
||||||
int _slenName = Name != null ? Encoding.UTF8.GetByteCount(Name) : 0;
|
int _slenName = Name != null ? Encoding.UTF8.GetByteCount(Name) : 0;
|
||||||
*(ushort*)ptr = (ushort)_slenName; ptr += 2;
|
*(ushort*)ptr = ArpackGenerated.EnsureU16Length(_slenName, "string length for Name"); ptr += 2;
|
||||||
if (Name != null && _slenName > 0)
|
if (Name != null && _slenName > 0)
|
||||||
{
|
{
|
||||||
fixed (char* _charsName = Name)
|
fixed (char* _charsName = Name)
|
||||||
@@ -128,13 +151,13 @@ namespace Arpack.Messages
|
|||||||
*(ulong*)ptr = EntityID; ptr += 8;
|
*(ulong*)ptr = EntityID; ptr += 8;
|
||||||
ptr += Position.Serialize(ptr);
|
ptr += Position.Serialize(ptr);
|
||||||
*(short*)ptr = Health; ptr += 2;
|
*(short*)ptr = Health; ptr += 2;
|
||||||
*(ushort*)ptr = (ushort)(Tags?.Length ?? 0); ptr += 2;
|
ushort _lenTags = ArpackGenerated.EnsureU16Length(Tags?.Length ?? 0, "slice length for Tags"); *(ushort*)ptr = _lenTags; ptr += 2;
|
||||||
if (Tags != null)
|
if (Tags != null)
|
||||||
{
|
{
|
||||||
for (int _iTags = 0; _iTags < Tags.Length; _iTags++)
|
for (int _iTags = 0; _iTags < Tags.Length; _iTags++)
|
||||||
{
|
{
|
||||||
int _slenTags__iTags_ = Tags[_iTags] != null ? Encoding.UTF8.GetByteCount(Tags[_iTags]) : 0;
|
int _slenTags__iTags_ = Tags[_iTags] != null ? Encoding.UTF8.GetByteCount(Tags[_iTags]) : 0;
|
||||||
*(ushort*)ptr = (ushort)_slenTags__iTags_; ptr += 2;
|
*(ushort*)ptr = ArpackGenerated.EnsureU16Length(_slenTags__iTags_, "string length for Tags[_iTags]"); ptr += 2;
|
||||||
if (Tags[_iTags] != null && _slenTags__iTags_ > 0)
|
if (Tags[_iTags] != null && _slenTags__iTags_ > 0)
|
||||||
{
|
{
|
||||||
fixed (char* _charsTags__iTags_ = Tags[_iTags])
|
fixed (char* _charsTags__iTags_ = Tags[_iTags])
|
||||||
@@ -145,7 +168,7 @@ namespace Arpack.Messages
|
|||||||
ptr += _slenTags__iTags_;
|
ptr += _slenTags__iTags_;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
*(ushort*)ptr = (ushort)(Data?.Length ?? 0); ptr += 2;
|
ushort _lenData = ArpackGenerated.EnsureU16Length(Data?.Length ?? 0, "slice length for Data"); *(ushort*)ptr = _lenData; ptr += 2;
|
||||||
if (Data != null)
|
if (Data != null)
|
||||||
{
|
{
|
||||||
for (int _iData = 0; _iData < Data.Length; _iData++)
|
for (int _iData = 0; _iData < Data.Length; _iData++)
|
||||||
|
|||||||
+34
-5
@@ -16,7 +16,6 @@ import (
|
|||||||
|
|
||||||
const samplePath = "../testdata/sample.go"
|
const samplePath = "../testdata/sample.go"
|
||||||
|
|
||||||
// TestE2E_CrossLanguage
|
|
||||||
func TestE2E_CrossLanguage(t *testing.T) {
|
func TestE2E_CrossLanguage(t *testing.T) {
|
||||||
schema, err := parser.ParseSchemaFile(samplePath)
|
schema, err := parser.ParseSchemaFile(samplePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -48,6 +47,16 @@ func TestE2E_CrossLanguage(t *testing.T) {
|
|||||||
}
|
}
|
||||||
csDir := buildCSHarness(t, csSrc)
|
csDir := buildCSHarness(t, csSrc)
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run("Wire/Go_EQ_CS/"+tc.name, func(t *testing.T) {
|
||||||
|
goHex := strings.TrimSpace(runHarness(t, goDir, "go", "ser", tc.typ, ""))
|
||||||
|
csHex := strings.TrimSpace(runHarness(t, csDir, "cs", "ser", tc.typ, ""))
|
||||||
|
if goHex != csHex {
|
||||||
|
t.Fatalf("wire drift between Go and C# for %s:\ngo=%s\ncs=%s", tc.typ, goHex, csHex)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
t.Run("Go_to_CS/"+tc.name, func(t *testing.T) {
|
t.Run("Go_to_CS/"+tc.name, func(t *testing.T) {
|
||||||
hex := runHarness(t, goDir, "go", "ser", tc.typ, "")
|
hex := runHarness(t, goDir, "go", "ser", tc.typ, "")
|
||||||
@@ -71,6 +80,16 @@ func TestE2E_CrossLanguage(t *testing.T) {
|
|||||||
}
|
}
|
||||||
tsDir := buildTSHarness(t, tsSrc)
|
tsDir := buildTSHarness(t, tsSrc)
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run("Wire/Go_EQ_TS/"+tc.name, func(t *testing.T) {
|
||||||
|
goHex := strings.TrimSpace(runHarness(t, goDir, "go", "ser", tc.typ, ""))
|
||||||
|
tsHex := strings.TrimSpace(runHarness(t, tsDir, "ts", "ser", tc.typ, ""))
|
||||||
|
if goHex != tsHex {
|
||||||
|
t.Fatalf("wire drift between Go and TS for %s:\ngo=%s\nts=%s", tc.typ, goHex, tsHex)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
t.Run("Go_to_TS/"+tc.name, func(t *testing.T) {
|
t.Run("Go_to_TS/"+tc.name, func(t *testing.T) {
|
||||||
hex := runHarness(t, goDir, "go", "ser", tc.typ, "")
|
hex := runHarness(t, goDir, "go", "ser", tc.typ, "")
|
||||||
@@ -149,6 +168,16 @@ func TestE2E_CrossLanguage(t *testing.T) {
|
|||||||
{"EnvelopeMessage", "EnvelopeMessage", 0},
|
{"EnvelopeMessage", "EnvelopeMessage", 0},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, tc := range luaCases {
|
||||||
|
t.Run("Wire/Go_EQ_Lua/"+tc.name, func(t *testing.T) {
|
||||||
|
goHex := strings.TrimSpace(runHarness(t, goDir, "go", "ser", tc.typ, ""))
|
||||||
|
luaHex := strings.TrimSpace(runHarness(t, luaDir, "lua", "ser", tc.typ, ""))
|
||||||
|
if goHex != luaHex {
|
||||||
|
t.Fatalf("wire drift between Go and Lua for %s:\ngo=%s\nlua=%s", tc.typ, goHex, luaHex)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
for _, tc := range luaCases {
|
for _, tc := range luaCases {
|
||||||
t.Run("Go_to_Lua/"+tc.name, func(t *testing.T) {
|
t.Run("Go_to_Lua/"+tc.name, func(t *testing.T) {
|
||||||
hex := runHarness(t, goDir, "go", "ser", tc.typ, "")
|
hex := runHarness(t, goDir, "go", "ser", tc.typ, "")
|
||||||
@@ -878,10 +907,10 @@ end
|
|||||||
local function serializeMoveMessage()
|
local function serializeMoveMessage()
|
||||||
local msg = messages.new_move_message()
|
local msg = messages.new_move_message()
|
||||||
msg.position = messages.new_vector3()
|
msg.position = messages.new_vector3()
|
||||||
msg.position.x = 10.0
|
msg.position.x = 50.0
|
||||||
msg.position.y = 20.0
|
msg.position.y = -100.0
|
||||||
msg.position.z = 30.0
|
msg.position.z = 0.0
|
||||||
msg.velocity = {1.0, 2.0, 3.0}
|
msg.velocity = {1.5, -2.5, 0.0}
|
||||||
msg.waypoints = {}
|
msg.waypoints = {}
|
||||||
local wp = messages.new_vector3()
|
local wp = messages.new_vector3()
|
||||||
wp.x = 10.0
|
wp.x = 10.0
|
||||||
|
|||||||
+33
-2
@@ -14,6 +14,8 @@ func GenerateCSharp(messages []parser.Message, namespace string) ([]byte, error)
|
|||||||
func GenerateCSharpSchema(schema parser.Schema, namespace string) ([]byte, error) {
|
func GenerateCSharpSchema(schema parser.Schema, namespace string) ([]byte, error) {
|
||||||
messages := schema.Messages
|
messages := schema.Messages
|
||||||
var b strings.Builder
|
var b strings.Builder
|
||||||
|
needsLengthGuards := schemaNeedsLengthGuards(messages)
|
||||||
|
needsQuantGuards := schemaNeedsQuantRangeGuards(messages)
|
||||||
|
|
||||||
b.WriteString("// <auto-generated> arpack </auto-generated>\n")
|
b.WriteString("// <auto-generated> arpack </auto-generated>\n")
|
||||||
b.WriteString("// Code generated by arpack. DO NOT EDIT.\n")
|
b.WriteString("// Code generated by arpack. DO NOT EDIT.\n")
|
||||||
@@ -27,6 +29,30 @@ func GenerateCSharpSchema(schema parser.Schema, namespace string) ([]byte, error
|
|||||||
|
|
||||||
fmt.Fprintf(&b, "namespace %s\n{\n", namespace)
|
fmt.Fprintf(&b, "namespace %s\n{\n", namespace)
|
||||||
|
|
||||||
|
if needsLengthGuards || needsQuantGuards {
|
||||||
|
b.WriteString(" internal static class ArpackGenerated\n {\n")
|
||||||
|
if needsLengthGuards {
|
||||||
|
b.WriteString(" internal static ushort EnsureU16Length(int length, string context)\n")
|
||||||
|
b.WriteString(" {\n")
|
||||||
|
b.WriteString(" if (length > 65535)\n")
|
||||||
|
b.WriteString(" {\n")
|
||||||
|
b.WriteString(" throw new InvalidOperationException(\"arpack: \" + context + \" exceeds uint16 limit\");\n")
|
||||||
|
b.WriteString(" }\n")
|
||||||
|
b.WriteString(" return (ushort)length;\n")
|
||||||
|
b.WriteString(" }\n\n")
|
||||||
|
}
|
||||||
|
if needsQuantGuards {
|
||||||
|
b.WriteString(" internal static void EnsureQuantizedRange(double value, double min, double max, string context)\n")
|
||||||
|
b.WriteString(" {\n")
|
||||||
|
b.WriteString(" if (double.IsNaN(value) || value < min || value > max)\n")
|
||||||
|
b.WriteString(" {\n")
|
||||||
|
b.WriteString(" throw new ArgumentOutOfRangeException(context, \"arpack: quantized value out of range for \" + context);\n")
|
||||||
|
b.WriteString(" }\n")
|
||||||
|
b.WriteString(" }\n")
|
||||||
|
}
|
||||||
|
b.WriteString(" }\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
enumNames := make(map[string]struct{}, len(schema.Enums))
|
enumNames := make(map[string]struct{}, len(schema.Enums))
|
||||||
for _, enum := range schema.Enums {
|
for _, enum := range schema.Enums {
|
||||||
enumNames[enum.Name] = struct{}{}
|
enumNames[enum.Name] = struct{}{}
|
||||||
@@ -151,7 +177,9 @@ func writeCSharpSerializeField(b *strings.Builder, f parser.Field, indent string
|
|||||||
}
|
}
|
||||||
fmt.Fprintf(b, "%s}\n", indent)
|
fmt.Fprintf(b, "%s}\n", indent)
|
||||||
case parser.KindSlice:
|
case parser.KindSlice:
|
||||||
fmt.Fprintf(b, "%s*(ushort*)ptr = (ushort)(%s?.Length ?? 0); ptr += 2;\n", indent, f.Name)
|
lenVar := "_len" + sanitizeVarName(f.Name)
|
||||||
|
fmt.Fprintf(b, "%sushort %s = ArpackGenerated.EnsureU16Length(%s?.Length ?? 0, %q); *(ushort*)ptr = %s; ptr += 2;\n",
|
||||||
|
indent, lenVar, f.Name, lengthContext(f), lenVar)
|
||||||
fmt.Fprintf(b, "%sif (%s != null)\n%s{\n", indent, f.Name, indent)
|
fmt.Fprintf(b, "%sif (%s != null)\n%s{\n", indent, f.Name, indent)
|
||||||
iVar := "_i" + f.Name
|
iVar := "_i" + f.Name
|
||||||
fmt.Fprintf(b, "%s for (int %s = 0; %s < %s.Length; %s++)\n%s {\n",
|
fmt.Fprintf(b, "%s for (int %s = 0; %s < %s.Length; %s++)\n%s {\n",
|
||||||
@@ -212,7 +240,8 @@ func writeCSharpSerializePrimitive(
|
|||||||
lenVar := "_slen" + sanitizeVarName(access)
|
lenVar := "_slen" + sanitizeVarName(access)
|
||||||
fmt.Fprintf(b, "%sint %s = %s != null ? Encoding.UTF8.GetByteCount(%s) : 0;\n",
|
fmt.Fprintf(b, "%sint %s = %s != null ? Encoding.UTF8.GetByteCount(%s) : 0;\n",
|
||||||
indent, lenVar, valueExpr, valueExpr)
|
indent, lenVar, valueExpr, valueExpr)
|
||||||
fmt.Fprintf(b, "%s*(ushort*)ptr = (ushort)%s; ptr += 2;\n", indent, lenVar)
|
fmt.Fprintf(b, "%s*(ushort*)ptr = ArpackGenerated.EnsureU16Length(%s, %q); ptr += 2;\n",
|
||||||
|
indent, lenVar, lengthContext(f))
|
||||||
fmt.Fprintf(b, "%sif (%s != null && %s > 0)\n%s{\n", indent, valueExpr, lenVar, indent)
|
fmt.Fprintf(b, "%sif (%s != null && %s > 0)\n%s{\n", indent, valueExpr, lenVar, indent)
|
||||||
fmt.Fprintf(b, "%s fixed (char* _chars%s = %s)\n%s {\n",
|
fmt.Fprintf(b, "%s fixed (char* _chars%s = %s)\n%s {\n",
|
||||||
indent, sanitizeVarName(access), valueExpr, indent)
|
indent, sanitizeVarName(access), valueExpr, indent)
|
||||||
@@ -228,6 +257,8 @@ func writeCSharpSerializePrimitive(
|
|||||||
func writeCSharpSerializeQuant(b *strings.Builder, access string, f parser.Field, indent string) error {
|
func writeCSharpSerializeQuant(b *strings.Builder, access string, f parser.Field, indent string) error {
|
||||||
q := f.Quant
|
q := f.Quant
|
||||||
maxUint := q.MaxUint()
|
maxUint := q.MaxUint()
|
||||||
|
fmt.Fprintf(b, "%sArpackGenerated.EnsureQuantizedRange(%s, %g, %g, %q);\n",
|
||||||
|
indent, access, q.Min, q.Max, quantContext(f))
|
||||||
if q.Bits == 8 {
|
if q.Bits == 8 {
|
||||||
fmt.Fprintf(b, "%s*ptr = (byte)((%s - (%gf)) / (%gf - (%gf)) * %gf); ptr += 1;\n",
|
fmt.Fprintf(b, "%s*ptr = (byte)((%s - (%gf)) / (%gf - (%gf)) * %gf); ptr += 1;\n",
|
||||||
indent, access, q.Min, q.Max, q.Min, maxUint)
|
indent, access, q.Min, q.Max, q.Min, maxUint)
|
||||||
|
|||||||
@@ -254,10 +254,109 @@ func TestGenerateCSharp_Output(t *testing.T) {
|
|||||||
if !strings.Contains(code, "public Opcode Code;") {
|
if !strings.Contains(code, "public Opcode Code;") {
|
||||||
t.Error("EnvelopeMessage.Code should use generated enum type")
|
t.Error("EnvelopeMessage.Code should use generated enum type")
|
||||||
}
|
}
|
||||||
|
if !strings.Contains(code, "internal static class ArpackGenerated") {
|
||||||
|
t.Error("missing shared ArpackGenerated helper class")
|
||||||
|
}
|
||||||
|
if !strings.Contains(code, "EnsureU16Length") {
|
||||||
|
t.Error("missing uint16 length guard helper")
|
||||||
|
}
|
||||||
|
if !strings.Contains(code, "EnsureQuantizedRange") {
|
||||||
|
t.Error("missing quantized range guard helper")
|
||||||
|
}
|
||||||
|
|
||||||
t.Logf("Generated C# (%d bytes):\n%s", len(src), code)
|
t.Logf("Generated C# (%d bytes):\n%s", len(src), code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGenerateGo_RuntimeGuards(t *testing.T) {
|
||||||
|
schemaSrc := `package messages
|
||||||
|
|
||||||
|
type Quantized struct {
|
||||||
|
Value float32 ` + "`" + `pack:"min=0,max=1,bits=8"` + "`" + `
|
||||||
|
}
|
||||||
|
|
||||||
|
type LengthLimited struct {
|
||||||
|
Name string
|
||||||
|
Items []uint8
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
|
schema, err := parser.ParseSchemaSource(schemaSrc)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseSchemaSource: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
src, err := GenerateGoSchema(schema, "messages")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateGoSchema: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "messages.go"), []byte(schemaSrc), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "messages_arpack.go"), src, 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
runtimeTests := `package messages
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func expectPanicContaining(t *testing.T, want string, fn func()) {
|
||||||
|
t.Helper()
|
||||||
|
defer func() {
|
||||||
|
r := recover()
|
||||||
|
if r == nil {
|
||||||
|
t.Fatalf("expected panic containing %q, got nil", want)
|
||||||
|
}
|
||||||
|
if !strings.Contains(r.(string), want) {
|
||||||
|
t.Fatalf("expected panic containing %q, got %v", want, r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
fn()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLengthGuard_String(t *testing.T) {
|
||||||
|
expectPanicContaining(t, "string length for Name exceeds uint16 limit", func() {
|
||||||
|
msg := LengthLimited{Name: strings.Repeat("a", 65536)}
|
||||||
|
_ = msg.Marshal(nil)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLengthGuard_Slice(t *testing.T) {
|
||||||
|
expectPanicContaining(t, "slice length for Items exceeds uint16 limit", func() {
|
||||||
|
msg := LengthLimited{Items: make([]uint8, 65536)}
|
||||||
|
_ = msg.Marshal(nil)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuantizedRangeGuard(t *testing.T) {
|
||||||
|
expectPanicContaining(t, "quantized value out of range for Value", func() {
|
||||||
|
msg := Quantized{Value: 1.5}
|
||||||
|
_ = msg.Marshal(nil)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "guards_test.go"), []byte(runtimeTests), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
goMod := "module messages\n\ngo 1.21\n"
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "go.mod"), []byte(goMod), 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("go", "test", "./...")
|
||||||
|
cmd.Dir = dir
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("go test failed:\n%s", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBoolPacking_GoCode(t *testing.T) {
|
func TestBoolPacking_GoCode(t *testing.T) {
|
||||||
msgs, err := parser.ParseFile(samplePath)
|
msgs, err := parser.ParseFile(samplePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
+25
-2
@@ -15,6 +15,8 @@ func GenerateGo(messages []parser.Message, pkgName string) ([]byte, error) {
|
|||||||
func GenerateGoSchema(schema parser.Schema, pkgName string) ([]byte, error) {
|
func GenerateGoSchema(schema parser.Schema, pkgName string) ([]byte, error) {
|
||||||
messages := schema.Messages
|
messages := schema.Messages
|
||||||
var b strings.Builder
|
var b strings.Builder
|
||||||
|
needsLengthGuards := schemaNeedsLengthGuards(messages)
|
||||||
|
needsQuantGuards := schemaNeedsQuantRangeGuards(messages)
|
||||||
|
|
||||||
b.WriteString("// Code generated by arpack. DO NOT EDIT.\n\n")
|
b.WriteString("// Code generated by arpack. DO NOT EDIT.\n\n")
|
||||||
fmt.Fprintf(&b, "package %s\n\n", pkgName)
|
fmt.Fprintf(&b, "package %s\n\n", pkgName)
|
||||||
@@ -27,6 +29,23 @@ func GenerateGoSchema(schema parser.Schema, pkgName string) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
b.WriteString(")\n\n")
|
b.WriteString(")\n\n")
|
||||||
|
|
||||||
|
if needsLengthGuards {
|
||||||
|
b.WriteString("func arpackEnsureUint16Length(length int, context string) uint16 {\n")
|
||||||
|
b.WriteString("\tif length > 65535 {\n")
|
||||||
|
b.WriteString("\t\tpanic(\"arpack: \" + context + \" exceeds uint16 limit\")\n")
|
||||||
|
b.WriteString("\t}\n")
|
||||||
|
b.WriteString("\treturn uint16(length)\n")
|
||||||
|
b.WriteString("}\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if needsQuantGuards {
|
||||||
|
b.WriteString("func arpackEnsureQuantizedRange(value float64, min float64, max float64, context string) {\n")
|
||||||
|
b.WriteString("\tif value != value || value < min || value > max {\n")
|
||||||
|
b.WriteString("\t\tpanic(\"arpack: quantized value out of range for \" + context)\n")
|
||||||
|
b.WriteString("\t}\n")
|
||||||
|
b.WriteString("}\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
for _, msg := range messages {
|
for _, msg := range messages {
|
||||||
if err := writeGoMessage(&b, msg); err != nil {
|
if err := writeGoMessage(&b, msg); err != nil {
|
||||||
return nil, fmt.Errorf("message %s: %w", msg.Name, err)
|
return nil, fmt.Errorf("message %s: %w", msg.Name, err)
|
||||||
@@ -120,7 +139,8 @@ func writeGoMarshalField(b *strings.Builder, recv string, f parser.Field, indent
|
|||||||
}
|
}
|
||||||
fmt.Fprintf(b, "%s}\n", indent)
|
fmt.Fprintf(b, "%s}\n", indent)
|
||||||
case parser.KindSlice:
|
case parser.KindSlice:
|
||||||
fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint16(buf, uint16(len(%s)))\n", indent, access)
|
fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint16(buf, arpackEnsureUint16Length(len(%s), %q))\n",
|
||||||
|
indent, access, lengthContext(f))
|
||||||
fmt.Fprintf(b, "%sfor _i%s := range %s {\n", indent, f.Name, access)
|
fmt.Fprintf(b, "%sfor _i%s := range %s {\n", indent, f.Name, access)
|
||||||
elemField := parser.Field{
|
elemField := parser.Field{
|
||||||
Name: f.Name + "[_i" + f.Name + "]",
|
Name: f.Name + "[_i" + f.Name + "]",
|
||||||
@@ -169,7 +189,8 @@ func writeGoMarshalPrimitive(b *strings.Builder, access string, f parser.Field,
|
|||||||
case parser.KindUint64:
|
case parser.KindUint64:
|
||||||
fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint64(buf, %s)\n", indent, valueExpr)
|
fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint64(buf, %s)\n", indent, valueExpr)
|
||||||
case parser.KindString:
|
case parser.KindString:
|
||||||
fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint16(buf, uint16(len(%s)))\n", indent, valueExpr)
|
fmt.Fprintf(b, "%sbuf = binary.LittleEndian.AppendUint16(buf, arpackEnsureUint16Length(len(%s), %q))\n",
|
||||||
|
indent, valueExpr, lengthContext(f))
|
||||||
fmt.Fprintf(b, "%sbuf = append(buf, %s...)\n", indent, valueExpr)
|
fmt.Fprintf(b, "%sbuf = append(buf, %s...)\n", indent, valueExpr)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -179,6 +200,8 @@ func writeGoMarshalQuant(b *strings.Builder, access string, f parser.Field, inde
|
|||||||
q := f.Quant
|
q := f.Quant
|
||||||
varName := "_q" + sanitizeVarName(access)
|
varName := "_q" + sanitizeVarName(access)
|
||||||
valueExpr := goMarshalValueExpr(access, f)
|
valueExpr := goMarshalValueExpr(access, f)
|
||||||
|
fmt.Fprintf(b, "%sarpackEnsureQuantizedRange(float64(%s), %g, %g, %q)\n",
|
||||||
|
indent, valueExpr, q.Min, q.Max, quantContext(f))
|
||||||
if q.Bits == 8 {
|
if q.Bits == 8 {
|
||||||
fmt.Fprintf(b, "%s%s := uint8((%s - (%g)) / (%g - (%g)) * %g)\n",
|
fmt.Fprintf(b, "%s%s := uint8((%s - (%g)) / (%g - (%g)) * %g)\n",
|
||||||
indent, varName, valueExpr, q.Min, q.Max, q.Min, q.MaxUint())
|
indent, varName, valueExpr, q.Min, q.Max, q.Min, q.MaxUint())
|
||||||
|
|||||||
+21
-2
@@ -69,6 +69,20 @@ func writeLuaHelpers(b *strings.Builder) {
|
|||||||
b.WriteString(" end\n")
|
b.WriteString(" end\n")
|
||||||
b.WriteString("end\n\n")
|
b.WriteString("end\n\n")
|
||||||
|
|
||||||
|
b.WriteString("local function ensure_u16_limit(n, context)\n")
|
||||||
|
b.WriteString(" if n < 0 or n > 65535 then\n")
|
||||||
|
b.WriteString(" error(string.format(\"arpack: %s exceeds uint16 limit: %d\", context, n))\n")
|
||||||
|
b.WriteString(" end\n")
|
||||||
|
b.WriteString(" return n\n")
|
||||||
|
b.WriteString("end\n\n")
|
||||||
|
|
||||||
|
b.WriteString("local function ensure_quant_range(value, min, max, context)\n")
|
||||||
|
b.WriteString(" if value ~= value or value < min or value > max then\n")
|
||||||
|
b.WriteString(" error(string.format(\"arpack: quantized value out of range for %s\", context))\n")
|
||||||
|
b.WriteString(" end\n")
|
||||||
|
b.WriteString(" return value\n")
|
||||||
|
b.WriteString("end\n\n")
|
||||||
|
|
||||||
b.WriteString("local function read_u8(data, offset)\n")
|
b.WriteString("local function read_u8(data, offset)\n")
|
||||||
b.WriteString(" if offset > #data then error(\"arpack: buffer too short for u8\") end\n")
|
b.WriteString(" if offset > #data then error(\"arpack: buffer too short for u8\") end\n")
|
||||||
b.WriteString(" return string.byte(data, offset), 1\n")
|
b.WriteString(" return string.byte(data, offset), 1\n")
|
||||||
@@ -304,6 +318,7 @@ func writeLuaHelpers(b *strings.Builder) {
|
|||||||
|
|
||||||
b.WriteString("local function write_string(s)\n")
|
b.WriteString("local function write_string(s)\n")
|
||||||
b.WriteString(" local len = #s\n")
|
b.WriteString(" local len = #s\n")
|
||||||
|
b.WriteString(" ensure_u16_limit(len, \"string length\")\n")
|
||||||
b.WriteString(" return write_u16_le(len) .. s\n")
|
b.WriteString(" return write_u16_le(len) .. s\n")
|
||||||
b.WriteString("end\n\n")
|
b.WriteString("end\n\n")
|
||||||
}
|
}
|
||||||
@@ -430,6 +445,7 @@ func writeLuaSerializeField(b *strings.Builder, recv string, f parser.Field, ind
|
|||||||
case parser.KindSlice:
|
case parser.KindSlice:
|
||||||
lenVar := "_len_" + strings.ToLower(f.Name)
|
lenVar := "_len_" + strings.ToLower(f.Name)
|
||||||
fmt.Fprintf(b, "%slocal %s = #(%s or {})\n", indent, lenVar, access)
|
fmt.Fprintf(b, "%slocal %s = #(%s or {})\n", indent, lenVar, access)
|
||||||
|
fmt.Fprintf(b, "%s%s = ensure_u16_limit(%s, %q)\n", indent, lenVar, lenVar, "slice length for "+luaFieldName(f.Name))
|
||||||
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u16_le(%s)\n", indent, lenVar)
|
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u16_le(%s)\n", indent, lenVar)
|
||||||
iVar := "_i_" + strings.ToLower(f.Name)
|
iVar := "_i_" + strings.ToLower(f.Name)
|
||||||
fmt.Fprintf(b, "%sfor %s = 1, %s do\n", indent, iVar, lenVar)
|
fmt.Fprintf(b, "%sfor %s = 1, %s do\n", indent, iVar, lenVar)
|
||||||
@@ -497,8 +513,11 @@ func writeLuaSerializeQuant(b *strings.Builder, access string, f parser.Field, i
|
|||||||
q := f.Quant
|
q := f.Quant
|
||||||
maxUint := q.MaxUint()
|
maxUint := q.MaxUint()
|
||||||
varName := "_q_" + sanitizeLuaVarName(access)
|
varName := "_q_" + sanitizeLuaVarName(access)
|
||||||
fmt.Fprintf(b, "%slocal %s = math.floor(((%s - (%g)) / (%g - (%g))) * %g + 0.5)\n",
|
valueVar := "_quant_value_" + sanitizeLuaVarName(access)
|
||||||
indent, varName, access, q.Min, q.Max, q.Min, maxUint)
|
fmt.Fprintf(b, "%slocal %s = ensure_quant_range(%s, %g, %g, %q)\n",
|
||||||
|
indent, valueVar, access, q.Min, q.Max, quantContext(f))
|
||||||
|
fmt.Fprintf(b, "%slocal %s = math.floor(((%s - (%g)) / (%g - (%g))) * %g)\n",
|
||||||
|
indent, varName, valueVar, q.Min, q.Max, q.Min, maxUint)
|
||||||
if q.Bits == 8 {
|
if q.Bits == 8 {
|
||||||
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u8(%s)\n", indent, varName)
|
fmt.Fprintf(b, "%spart_idx = part_idx + 1; parts[part_idx] = write_u8(%s)\n", indent, varName)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
+212
-2
@@ -281,8 +281,14 @@ func TestGenerateLua_QuantizedFloat(t *testing.T) {
|
|||||||
|
|
||||||
luaStr := string(lua)
|
luaStr := string(lua)
|
||||||
|
|
||||||
if !strings.Contains(luaStr, "math.floor") {
|
if !strings.Contains(luaStr, "math.floor(((_quant_value_msg_position - (-500)) / (500 - (-500))) * 65535)") {
|
||||||
t.Error("Missing math.floor for quantization")
|
t.Error("Missing truncating quantization code for Lua")
|
||||||
|
}
|
||||||
|
if !strings.Contains(luaStr, `ensure_quant_range(msg.position, -500, 500, "Position")`) {
|
||||||
|
t.Error("Missing quantized range guard for Lua")
|
||||||
|
}
|
||||||
|
if strings.Contains(luaStr, "math.floor(((msg.position - (-500)) / (500 - (-500))) * 65535 + 0.5)") {
|
||||||
|
t.Error("Lua quantization should not round to nearest")
|
||||||
}
|
}
|
||||||
if !strings.Contains(luaStr, "write_u16_le") {
|
if !strings.Contains(luaStr, "write_u16_le") {
|
||||||
t.Error("Missing u16 write for 16-bit quantization")
|
t.Error("Missing u16 write for 16-bit quantization")
|
||||||
@@ -334,6 +340,8 @@ func TestLuaHelpersGenerated(t *testing.T) {
|
|||||||
"local bit = require('bit')",
|
"local bit = require('bit')",
|
||||||
"buffer too short for u8",
|
"buffer too short for u8",
|
||||||
"buffer too short for bool",
|
"buffer too short for bool",
|
||||||
|
"local function ensure_u16_limit(n, context)",
|
||||||
|
"local function ensure_quant_range(value, min, max, context)",
|
||||||
"local function write_u8(n)",
|
"local function write_u8(n)",
|
||||||
"buffer too short for u16",
|
"buffer too short for u16",
|
||||||
"local function write_u16_le(n)",
|
"local function write_u16_le(n)",
|
||||||
@@ -453,6 +461,10 @@ func TestGenerateLua_BoundsChecks(t *testing.T) {
|
|||||||
t.Error("Missing check_bounds function")
|
t.Error("Missing check_bounds function")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(luaStr, "ensure_u16_limit") {
|
||||||
|
t.Error("Missing uint16 overflow helper")
|
||||||
|
}
|
||||||
|
|
||||||
// Check that read_u16_le has bounds check
|
// Check that read_u16_le has bounds check
|
||||||
if !strings.Contains(luaStr, "buffer too short for u16") {
|
if !strings.Contains(luaStr, "buffer too short for u16") {
|
||||||
t.Error("Missing bounds check in read_u16_le")
|
t.Error("Missing bounds check in read_u16_le")
|
||||||
@@ -489,6 +501,146 @@ func TestGenerateLua_BoundsChecks(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGenerateLua_LengthOverflowGuards(t *testing.T) {
|
||||||
|
schema := parser.Schema{
|
||||||
|
Messages: []parser.Message{
|
||||||
|
{
|
||||||
|
Name: "LengthLimited",
|
||||||
|
Fields: []parser.Field{
|
||||||
|
{Name: "Name", Kind: parser.KindPrimitive, Primitive: parser.KindString},
|
||||||
|
{
|
||||||
|
Name: "Items",
|
||||||
|
Kind: parser.KindSlice,
|
||||||
|
Elem: &parser.Field{
|
||||||
|
Kind: parser.KindPrimitive,
|
||||||
|
Primitive: parser.KindUint8,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
lua, err := GenerateLuaSchema(schema, "test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateLuaSchema failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
luaStr := string(lua)
|
||||||
|
|
||||||
|
if !strings.Contains(luaStr, `ensure_u16_limit(len, "string length")`) {
|
||||||
|
t.Error("Missing string length overflow guard")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(luaStr, `ensure_u16_limit(_len_items, "slice length for items")`) {
|
||||||
|
t.Error("Missing slice length overflow guard")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateLua_RuntimeLengthLimits(t *testing.T) {
|
||||||
|
if _, err := exec.LookPath("luajit"); err != nil {
|
||||||
|
t.Skip("luajit not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
schema := parser.Schema{
|
||||||
|
Messages: []parser.Message{
|
||||||
|
{
|
||||||
|
Name: "LengthLimited",
|
||||||
|
Fields: []parser.Field{
|
||||||
|
{Name: "Name", Kind: parser.KindPrimitive, Primitive: parser.KindString},
|
||||||
|
{
|
||||||
|
Name: "Items",
|
||||||
|
Kind: parser.KindSlice,
|
||||||
|
Elem: &parser.Field{
|
||||||
|
Kind: parser.KindPrimitive,
|
||||||
|
Primitive: parser.KindUint8,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
lua, err := GenerateLuaSchema(schema, "messages")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateLuaSchema failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
modulePath := filepath.Join(dir, "messages_gen.lua")
|
||||||
|
if err := os.WriteFile(modulePath, lua, 0o600); err != nil {
|
||||||
|
t.Fatalf("write module: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
scriptPath := filepath.Join(dir, "check.lua")
|
||||||
|
script := `local messages = require("messages_gen")
|
||||||
|
|
||||||
|
local function emit(label, ok, value)
|
||||||
|
if ok then
|
||||||
|
print(label .. ":OK")
|
||||||
|
else
|
||||||
|
print(label .. ":" .. tostring(value))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
local msg = messages.new_length_limited()
|
||||||
|
|
||||||
|
local ok, res = pcall(messages.serialize_length_limited, msg)
|
||||||
|
emit("EMPTY", ok, res)
|
||||||
|
|
||||||
|
msg.name = string.rep("a", 65535)
|
||||||
|
ok, res = pcall(messages.serialize_length_limited, msg)
|
||||||
|
emit("STR_MAX", ok, res)
|
||||||
|
|
||||||
|
msg.name = string.rep("a", 65536)
|
||||||
|
ok, res = pcall(messages.serialize_length_limited, msg)
|
||||||
|
emit("STR_OVER", ok, res)
|
||||||
|
|
||||||
|
msg.name = ""
|
||||||
|
msg.items = {}
|
||||||
|
for i = 1, 65535 do
|
||||||
|
msg.items[i] = 0
|
||||||
|
end
|
||||||
|
ok, res = pcall(messages.serialize_length_limited, msg)
|
||||||
|
emit("SLICE_MAX", ok, res)
|
||||||
|
|
||||||
|
msg.items[65536] = 0
|
||||||
|
ok, res = pcall(messages.serialize_length_limited, msg)
|
||||||
|
emit("SLICE_OVER", ok, res)
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(scriptPath, []byte(script), 0o600); err != nil {
|
||||||
|
t.Fatalf("write script: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("luajit", "check.lua")
|
||||||
|
cmd.Dir = dir
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("luajit failed: %v\n%s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := strings.Split(strings.TrimSpace(string(out)), "\n")
|
||||||
|
if len(lines) != 5 {
|
||||||
|
t.Fatalf("expected 5 output lines, got %d: %q", len(lines), string(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
if lines[0] != "EMPTY:OK" {
|
||||||
|
t.Fatalf("expected empty serialization to succeed, got %q", lines[0])
|
||||||
|
}
|
||||||
|
if lines[1] != "STR_MAX:OK" {
|
||||||
|
t.Fatalf("expected 65535-byte string serialization to succeed, got %q", lines[1])
|
||||||
|
}
|
||||||
|
if !strings.Contains(lines[2], "string length exceeds uint16 limit") {
|
||||||
|
t.Fatalf("expected string overflow guard, got %q", lines[2])
|
||||||
|
}
|
||||||
|
if lines[3] != "SLICE_MAX:OK" {
|
||||||
|
t.Fatalf("expected 65535-element slice serialization to succeed, got %q", lines[3])
|
||||||
|
}
|
||||||
|
if !strings.Contains(lines[4], "slice length for items exceeds uint16 limit") {
|
||||||
|
t.Fatalf("expected slice overflow guard, got %q", lines[4])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGenerateLua_RuntimeFloatEdgeCases(t *testing.T) {
|
func TestGenerateLua_RuntimeFloatEdgeCases(t *testing.T) {
|
||||||
if _, err := exec.LookPath("luajit"); err != nil {
|
if _, err := exec.LookPath("luajit"); err != nil {
|
||||||
t.Skip("luajit not found")
|
t.Skip("luajit not found")
|
||||||
@@ -555,3 +707,61 @@ print(bytes_to_hex(messages.serialize_float_edges(msg)))
|
|||||||
t.Fatalf("subnormal roundtrip mismatch: %s", lines[1])
|
t.Fatalf("subnormal roundtrip mismatch: %s", lines[1])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGenerateLua_RuntimeQuantizedRangeGuard(t *testing.T) {
|
||||||
|
if _, err := exec.LookPath("luajit"); err != nil {
|
||||||
|
t.Skip("luajit not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
schema := parser.Schema{
|
||||||
|
Messages: []parser.Message{
|
||||||
|
{
|
||||||
|
Name: "WithQuantized",
|
||||||
|
Fields: []parser.Field{
|
||||||
|
{
|
||||||
|
Name: "Position",
|
||||||
|
Kind: parser.KindPrimitive,
|
||||||
|
Primitive: parser.KindFloat32,
|
||||||
|
Quant: &parser.QuantInfo{Min: -500, Max: 500, Bits: 16},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
lua, err := GenerateLuaSchema(schema, "messages")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateLuaSchema failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "messages_gen.lua"), lua, 0o600); err != nil {
|
||||||
|
t.Fatalf("write module: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
script := `local messages = require("messages_gen")
|
||||||
|
local msg = messages.new_with_quantized()
|
||||||
|
msg.position = 501
|
||||||
|
local ok, res = pcall(messages.serialize_with_quantized, msg)
|
||||||
|
if ok then
|
||||||
|
print("OK")
|
||||||
|
else
|
||||||
|
print(res)
|
||||||
|
end
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(filepath.Join(dir, "check.lua"), []byte(script), 0o600); err != nil {
|
||||||
|
t.Fatalf("write script: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("luajit", "check.lua")
|
||||||
|
cmd.Dir = dir
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("luajit failed: %v\n%s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := strings.TrimSpace(string(out))
|
||||||
|
if !strings.Contains(got, "quantized value out of range for Position") {
|
||||||
|
t.Fatalf("expected quantized range guard, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,76 @@
|
|||||||
|
package generator
|
||||||
|
|
||||||
|
import "github.com/edmand46/arpack/parser"
|
||||||
|
|
||||||
|
func lengthContext(f parser.Field) string {
|
||||||
|
switch {
|
||||||
|
case f.Kind == parser.KindSlice:
|
||||||
|
if f.Name != "" {
|
||||||
|
return "slice length for " + f.Name
|
||||||
|
}
|
||||||
|
return "slice length"
|
||||||
|
case f.Kind == parser.KindPrimitive && f.Primitive == parser.KindString:
|
||||||
|
if f.Name != "" {
|
||||||
|
return "string length for " + f.Name
|
||||||
|
}
|
||||||
|
return "string length"
|
||||||
|
default:
|
||||||
|
return "length"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func quantContext(f parser.Field) string {
|
||||||
|
if f.Name != "" {
|
||||||
|
return f.Name
|
||||||
|
}
|
||||||
|
return "value"
|
||||||
|
}
|
||||||
|
|
||||||
|
func schemaNeedsLengthGuards(messages []parser.Message) bool {
|
||||||
|
for _, msg := range messages {
|
||||||
|
for _, f := range msg.Fields {
|
||||||
|
if fieldNeedsLengthGuard(f) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func fieldNeedsLengthGuard(f parser.Field) bool {
|
||||||
|
switch f.Kind {
|
||||||
|
case parser.KindPrimitive:
|
||||||
|
return f.Primitive == parser.KindString
|
||||||
|
case parser.KindFixedArray, parser.KindSlice:
|
||||||
|
if f.Kind == parser.KindSlice {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if f.Elem != nil {
|
||||||
|
return fieldNeedsLengthGuard(*f.Elem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func schemaNeedsQuantRangeGuards(messages []parser.Message) bool {
|
||||||
|
for _, msg := range messages {
|
||||||
|
for _, f := range msg.Fields {
|
||||||
|
if fieldNeedsQuantRangeGuard(f) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func fieldNeedsQuantRangeGuard(f parser.Field) bool {
|
||||||
|
switch f.Kind {
|
||||||
|
case parser.KindPrimitive:
|
||||||
|
return f.Quant != nil
|
||||||
|
case parser.KindFixedArray, parser.KindSlice:
|
||||||
|
if f.Elem != nil {
|
||||||
|
return fieldNeedsQuantRangeGuard(*f.Elem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
+45
-11
@@ -16,10 +16,31 @@ func GenerateTypeScript(messages []parser.Message, namespace string) ([]byte, er
|
|||||||
func GenerateTypeScriptSchema(schema parser.Schema, namespace string) ([]byte, error) {
|
func GenerateTypeScriptSchema(schema parser.Schema, namespace string) ([]byte, error) {
|
||||||
messages := schema.Messages
|
messages := schema.Messages
|
||||||
var b strings.Builder
|
var b strings.Builder
|
||||||
|
needsLengthGuards := schemaNeedsLengthGuards(messages)
|
||||||
|
needsQuantGuards := schemaNeedsQuantRangeGuards(messages)
|
||||||
|
|
||||||
b.WriteString("// <auto-generated> arpack </auto-generated>\n")
|
b.WriteString("// <auto-generated> arpack </auto-generated>\n")
|
||||||
b.WriteString("// Code generated by arpack. DO NOT EDIT.\n\n")
|
b.WriteString("// Code generated by arpack. DO NOT EDIT.\n\n")
|
||||||
|
|
||||||
|
if needsLengthGuards {
|
||||||
|
b.WriteString("const arpackTextEncoder = new TextEncoder();\n")
|
||||||
|
b.WriteString("const arpackTextDecoder = new TextDecoder();\n\n")
|
||||||
|
b.WriteString("function arpackEnsureUint16Length(length: number, context: string): number {\n")
|
||||||
|
b.WriteString(" if (length > 65535) {\n")
|
||||||
|
b.WriteString(" throw new RangeError(\"arpack: \" + context + \" exceeds uint16 limit\");\n")
|
||||||
|
b.WriteString(" }\n")
|
||||||
|
b.WriteString(" return length;\n")
|
||||||
|
b.WriteString("}\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if needsQuantGuards {
|
||||||
|
b.WriteString("function arpackEnsureQuantizedRange(value: number, min: number, max: number, context: string): void {\n")
|
||||||
|
b.WriteString(" if (Number.isNaN(value) || value < min || value > max) {\n")
|
||||||
|
b.WriteString(" throw new RangeError(\"arpack: quantized value out of range for \" + context);\n")
|
||||||
|
b.WriteString(" }\n")
|
||||||
|
b.WriteString("}\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
enumNames := make(map[string]struct{}, len(schema.Enums))
|
enumNames := make(map[string]struct{}, len(schema.Enums))
|
||||||
for _, enum := range schema.Enums {
|
for _, enum := range schema.Enums {
|
||||||
enumNames[enum.Name] = struct{}{}
|
enumNames[enum.Name] = struct{}{}
|
||||||
@@ -148,7 +169,10 @@ func writeTSSerializeField(b *strings.Builder, recv string, f parser.Field, inde
|
|||||||
}
|
}
|
||||||
fmt.Fprintf(b, "%s}\n", indent)
|
fmt.Fprintf(b, "%s}\n", indent)
|
||||||
case parser.KindSlice:
|
case parser.KindSlice:
|
||||||
fmt.Fprintf(b, "%sview.setUint16(pos, %s.length, true);\n", indent, access)
|
lenVar := "_len" + sanitizeVarName(access)
|
||||||
|
fmt.Fprintf(b, "%sconst %s = arpackEnsureUint16Length(%s.length, %q);\n",
|
||||||
|
indent, lenVar, access, lengthContext(f))
|
||||||
|
fmt.Fprintf(b, "%sview.setUint16(pos, %s, true);\n", indent, lenVar)
|
||||||
fmt.Fprintf(b, "%spos += 2;\n", indent)
|
fmt.Fprintf(b, "%spos += 2;\n", indent)
|
||||||
iVar := "_i" + f.Name
|
iVar := "_i" + f.Name
|
||||||
fmt.Fprintf(b, "%sfor (const %s of %s) {\n", indent, iVar, access)
|
fmt.Fprintf(b, "%sfor (const %s of %s) {\n", indent, iVar, access)
|
||||||
@@ -222,8 +246,11 @@ func writeTSSerializePrimitiveElement(b *strings.Builder, access string, f parse
|
|||||||
fmt.Fprintf(b, "%spos += 8;\n", indent)
|
fmt.Fprintf(b, "%spos += 8;\n", indent)
|
||||||
case parser.KindString:
|
case parser.KindString:
|
||||||
lenVar := "_slen" + sanitizeVarName(access)
|
lenVar := "_slen" + sanitizeVarName(access)
|
||||||
fmt.Fprintf(b, "%sconst %s = new TextEncoder().encode(%s);\n", indent, lenVar, valueExpr)
|
guardVar := "_slenChecked" + sanitizeVarName(access)
|
||||||
fmt.Fprintf(b, "%sview.setUint16(pos, %s.length, true);\n", indent, lenVar)
|
fmt.Fprintf(b, "%sconst %s = arpackTextEncoder.encode(%s);\n", indent, lenVar, valueExpr)
|
||||||
|
fmt.Fprintf(b, "%sconst %s = arpackEnsureUint16Length(%s.length, %q);\n",
|
||||||
|
indent, guardVar, lenVar, lengthContext(f))
|
||||||
|
fmt.Fprintf(b, "%sview.setUint16(pos, %s, true);\n", indent, guardVar)
|
||||||
fmt.Fprintf(b, "%spos += 2;\n", indent)
|
fmt.Fprintf(b, "%spos += 2;\n", indent)
|
||||||
fmt.Fprintf(b, "%snew Uint8Array(view.buffer, pos, %s.length).set(%s);\n", indent, lenVar, lenVar)
|
fmt.Fprintf(b, "%snew Uint8Array(view.buffer, pos, %s.length).set(%s);\n", indent, lenVar, lenVar)
|
||||||
fmt.Fprintf(b, "%spos += %s.length;\n", indent, lenVar)
|
fmt.Fprintf(b, "%spos += %s.length;\n", indent, lenVar)
|
||||||
@@ -273,8 +300,11 @@ func writeTSSerializePrimitive(b *strings.Builder, access string, f parser.Field
|
|||||||
fmt.Fprintf(b, "%spos += 8;\n", indent)
|
fmt.Fprintf(b, "%spos += 8;\n", indent)
|
||||||
case parser.KindString:
|
case parser.KindString:
|
||||||
lenVar := "_slen" + sanitizeVarName(access)
|
lenVar := "_slen" + sanitizeVarName(access)
|
||||||
fmt.Fprintf(b, "%sconst %s = new TextEncoder().encode(%s);\n", indent, lenVar, valueExpr)
|
guardVar := "_slenChecked" + sanitizeVarName(access)
|
||||||
fmt.Fprintf(b, "%sview.setUint16(pos, %s.length, true);\n", indent, lenVar)
|
fmt.Fprintf(b, "%sconst %s = arpackTextEncoder.encode(%s);\n", indent, lenVar, valueExpr)
|
||||||
|
fmt.Fprintf(b, "%sconst %s = arpackEnsureUint16Length(%s.length, %q);\n",
|
||||||
|
indent, guardVar, lenVar, lengthContext(f))
|
||||||
|
fmt.Fprintf(b, "%sview.setUint16(pos, %s, true);\n", indent, guardVar)
|
||||||
fmt.Fprintf(b, "%spos += 2;\n", indent)
|
fmt.Fprintf(b, "%spos += 2;\n", indent)
|
||||||
fmt.Fprintf(b, "%snew Uint8Array(view.buffer, pos, %s.length).set(%s);\n", indent, lenVar, lenVar)
|
fmt.Fprintf(b, "%snew Uint8Array(view.buffer, pos, %s.length).set(%s);\n", indent, lenVar, lenVar)
|
||||||
fmt.Fprintf(b, "%spos += %s.length;\n", indent, lenVar)
|
fmt.Fprintf(b, "%spos += %s.length;\n", indent, lenVar)
|
||||||
@@ -286,13 +316,15 @@ func writeTSSerializeQuant(b *strings.Builder, access string, f parser.Field, in
|
|||||||
q := f.Quant
|
q := f.Quant
|
||||||
maxUint := q.MaxUint()
|
maxUint := q.MaxUint()
|
||||||
varName := "_q" + sanitizeVarName(access)
|
varName := "_q" + sanitizeVarName(access)
|
||||||
|
fmt.Fprintf(b, "%sarpackEnsureQuantizedRange(%s, %g, %g, %q);\n",
|
||||||
|
indent, access, q.Min, q.Max, quantContext(f))
|
||||||
if q.Bits == 8 {
|
if q.Bits == 8 {
|
||||||
fmt.Fprintf(b, "%sconst %s = Math.round((%s - (%g)) / (%g - (%g)) * %g);\n",
|
fmt.Fprintf(b, "%sconst %s = Math.trunc((%s - (%g)) / (%g - (%g)) * %g);\n",
|
||||||
indent, varName, access, q.Min, q.Max, q.Min, maxUint)
|
indent, varName, access, q.Min, q.Max, q.Min, maxUint)
|
||||||
fmt.Fprintf(b, "%sview.setUint8(pos, %s);\n", indent, varName)
|
fmt.Fprintf(b, "%sview.setUint8(pos, %s);\n", indent, varName)
|
||||||
fmt.Fprintf(b, "%spos += 1;\n", indent)
|
fmt.Fprintf(b, "%spos += 1;\n", indent)
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprintf(b, "%sconst %s = Math.round((%s - (%g)) / (%g - (%g)) * %g);\n",
|
fmt.Fprintf(b, "%sconst %s = Math.trunc((%s - (%g)) / (%g - (%g)) * %g);\n",
|
||||||
indent, varName, access, q.Min, q.Max, q.Min, maxUint)
|
indent, varName, access, q.Min, q.Max, q.Min, maxUint)
|
||||||
fmt.Fprintf(b, "%sview.setUint16(pos, %s, true);\n", indent, varName)
|
fmt.Fprintf(b, "%sview.setUint16(pos, %s, true);\n", indent, varName)
|
||||||
fmt.Fprintf(b, "%spos += 2;\n", indent)
|
fmt.Fprintf(b, "%spos += 2;\n", indent)
|
||||||
@@ -304,13 +336,15 @@ func writeTSSerializeQuantElement(b *strings.Builder, access string, f parser.Fi
|
|||||||
q := f.Quant
|
q := f.Quant
|
||||||
maxUint := q.MaxUint()
|
maxUint := q.MaxUint()
|
||||||
varName := "_q" + sanitizeVarName(access)
|
varName := "_q" + sanitizeVarName(access)
|
||||||
|
fmt.Fprintf(b, "%sarpackEnsureQuantizedRange(%s, %g, %g, %q);\n",
|
||||||
|
indent, access, q.Min, q.Max, quantContext(f))
|
||||||
if q.Bits == 8 {
|
if q.Bits == 8 {
|
||||||
fmt.Fprintf(b, "%sconst %s = Math.round((%s - (%g)) / (%g - (%g)) * %g);\n",
|
fmt.Fprintf(b, "%sconst %s = Math.trunc((%s - (%g)) / (%g - (%g)) * %g);\n",
|
||||||
indent, varName, access, q.Min, q.Max, q.Min, maxUint)
|
indent, varName, access, q.Min, q.Max, q.Min, maxUint)
|
||||||
fmt.Fprintf(b, "%sview.setUint8(pos, %s);\n", indent, varName)
|
fmt.Fprintf(b, "%sview.setUint8(pos, %s);\n", indent, varName)
|
||||||
fmt.Fprintf(b, "%spos += 1;\n", indent)
|
fmt.Fprintf(b, "%spos += 1;\n", indent)
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprintf(b, "%sconst %s = Math.round((%s - (%g)) / (%g - (%g)) * %g);\n",
|
fmt.Fprintf(b, "%sconst %s = Math.trunc((%s - (%g)) / (%g - (%g)) * %g);\n",
|
||||||
indent, varName, access, q.Min, q.Max, q.Min, maxUint)
|
indent, varName, access, q.Min, q.Max, q.Min, maxUint)
|
||||||
fmt.Fprintf(b, "%sview.setUint16(pos, %s, true);\n", indent, varName)
|
fmt.Fprintf(b, "%sview.setUint16(pos, %s, true);\n", indent, varName)
|
||||||
fmt.Fprintf(b, "%spos += 2;\n", indent)
|
fmt.Fprintf(b, "%spos += 2;\n", indent)
|
||||||
@@ -438,7 +472,7 @@ func writeTSDeserializePrimitiveElement(b *strings.Builder, access string, f par
|
|||||||
lenVar := "_slen" + sanitizeVarName(access)
|
lenVar := "_slen" + sanitizeVarName(access)
|
||||||
fmt.Fprintf(b, "%sconst %s = view.getUint16(pos, true);\n", indent, lenVar)
|
fmt.Fprintf(b, "%sconst %s = view.getUint16(pos, true);\n", indent, lenVar)
|
||||||
fmt.Fprintf(b, "%spos += 2;\n", indent)
|
fmt.Fprintf(b, "%spos += 2;\n", indent)
|
||||||
expr := fmt.Sprintf("new TextDecoder().decode(new Uint8Array(view.buffer, pos, %s))", lenVar)
|
expr := fmt.Sprintf("arpackTextDecoder.decode(new Uint8Array(view.buffer, pos, %s))", lenVar)
|
||||||
fmt.Fprintf(b, "%s%s = %s;\n", indent, access, tsDeserializeValueExpr(expr, f, enumNames))
|
fmt.Fprintf(b, "%s%s = %s;\n", indent, access, tsDeserializeValueExpr(expr, f, enumNames))
|
||||||
fmt.Fprintf(b, "%spos += %s;\n", indent, lenVar)
|
fmt.Fprintf(b, "%spos += %s;\n", indent, lenVar)
|
||||||
}
|
}
|
||||||
@@ -499,7 +533,7 @@ func writeTSDeserializePrimitive(b *strings.Builder, access string, f parser.Fie
|
|||||||
lenVar := "_slen" + sanitizeVarName(access)
|
lenVar := "_slen" + sanitizeVarName(access)
|
||||||
fmt.Fprintf(b, "%sconst %s = view.getUint16(pos, true);\n", indent, lenVar)
|
fmt.Fprintf(b, "%sconst %s = view.getUint16(pos, true);\n", indent, lenVar)
|
||||||
fmt.Fprintf(b, "%spos += 2;\n", indent)
|
fmt.Fprintf(b, "%spos += 2;\n", indent)
|
||||||
expr := fmt.Sprintf("new TextDecoder().decode(new Uint8Array(view.buffer, pos, %s))", lenVar)
|
expr := fmt.Sprintf("arpackTextDecoder.decode(new Uint8Array(view.buffer, pos, %s))", lenVar)
|
||||||
fmt.Fprintf(b, "%s%s = %s;\n", indent, access, tsDeserializeValueExpr(expr, f, enumNames))
|
fmt.Fprintf(b, "%s%s = %s;\n", indent, access, tsDeserializeValueExpr(expr, f, enumNames))
|
||||||
fmt.Fprintf(b, "%spos += %s;\n", indent, lenVar)
|
fmt.Fprintf(b, "%spos += %s;\n", indent, lenVar)
|
||||||
}
|
}
|
||||||
|
|||||||
+74
-8
@@ -98,14 +98,20 @@ func TestGenerateTypeScript_QuantizedFloats(t *testing.T) {
|
|||||||
code := string(src)
|
code := string(src)
|
||||||
|
|
||||||
// Check 8-bit quantization (using camelCase field names)
|
// Check 8-bit quantization (using camelCase field names)
|
||||||
if !strings.Contains(code, "Math.round((this.q8 - (0)) / (100 - (0)) * 255)") {
|
if !strings.Contains(code, "Math.trunc((this.q8 - (0)) / (100 - (0)) * 255)") {
|
||||||
t.Error("Missing 8-bit quantization code")
|
t.Error("Missing 8-bit quantization code")
|
||||||
}
|
}
|
||||||
|
if !strings.Contains(code, `arpackEnsureQuantizedRange(this.q8, 0, 100, "Q8");`) {
|
||||||
|
t.Error("Missing 8-bit quantized range guard")
|
||||||
|
}
|
||||||
|
|
||||||
// Check 16-bit quantization (using camelCase field names)
|
// Check 16-bit quantization (using camelCase field names)
|
||||||
if !strings.Contains(code, "Math.round((this.q16 - (-500)) / (500 - (-500)) * 65535)") {
|
if !strings.Contains(code, "Math.trunc((this.q16 - (-500)) / (500 - (-500)) * 65535)") {
|
||||||
t.Error("Missing 16-bit quantization code")
|
t.Error("Missing 16-bit quantization code")
|
||||||
}
|
}
|
||||||
|
if !strings.Contains(code, `arpackEnsureQuantizedRange(this.q16, -500, 500, "Q16");`) {
|
||||||
|
t.Error("Missing 16-bit quantized range guard")
|
||||||
|
}
|
||||||
|
|
||||||
// Check deserialization with dequantization
|
// Check deserialization with dequantization
|
||||||
if !strings.Contains(code, "/ 255 * (100 - (0)) + (0)") {
|
if !strings.Contains(code, "/ 255 * (100 - (0)) + (0)") {
|
||||||
@@ -286,8 +292,11 @@ func TestGenerateTypeScript_Slices(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check length prefix in serialize (using camelCase field name)
|
// Check length prefix in serialize (using camelCase field name)
|
||||||
if !strings.Contains(code, "view.setUint16(pos, this.items.length, true);") {
|
if !strings.Contains(code, `arpackEnsureUint16Length(this.items.length, "slice length for Items")`) {
|
||||||
t.Error("Missing slice length prefix in serialize")
|
t.Error("Missing slice length guard in serialize")
|
||||||
|
}
|
||||||
|
if !strings.Contains(code, "view.setUint16(pos, _lenthis_items, true);") {
|
||||||
|
t.Error("Missing guarded slice length prefix in serialize")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check length reading in deserialize
|
// Check length reading in deserialize
|
||||||
@@ -385,17 +394,74 @@ func TestGenerateTypeScript_Strings(t *testing.T) {
|
|||||||
code := string(src)
|
code := string(src)
|
||||||
|
|
||||||
// Check TextEncoder usage
|
// Check TextEncoder usage
|
||||||
if !strings.Contains(code, "new TextEncoder().encode(") {
|
if !strings.Contains(code, "const arpackTextEncoder = new TextEncoder();") {
|
||||||
t.Error("Missing TextEncoder in serialize")
|
t.Error("Missing shared TextEncoder helper")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check length prefix
|
// Check length prefix
|
||||||
if !strings.Contains(code, "view.setUint16(pos, _slen") {
|
if !strings.Contains(code, "view.setUint16(pos, _slen") {
|
||||||
t.Error("Missing string length prefix in serialize")
|
t.Error("Missing string length prefix in serialize")
|
||||||
}
|
}
|
||||||
|
if !strings.Contains(code, `arpackEnsureUint16Length(_slen`) {
|
||||||
|
t.Error("Missing string length guard in serialize")
|
||||||
|
}
|
||||||
|
|
||||||
// Check TextDecoder usage
|
// Check TextDecoder usage
|
||||||
if !strings.Contains(code, "new TextDecoder().decode(") {
|
if !strings.Contains(code, "const arpackTextDecoder = new TextDecoder();") {
|
||||||
t.Error("Missing TextDecoder in deserialize")
|
t.Error("Missing shared TextDecoder helper")
|
||||||
|
}
|
||||||
|
if !strings.Contains(code, "arpackTextDecoder.decode(") {
|
||||||
|
t.Error("Missing shared TextDecoder in deserialize")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateTypeScript_LengthAndRangeHelpers(t *testing.T) {
|
||||||
|
schema := parser.Schema{
|
||||||
|
Messages: []parser.Message{
|
||||||
|
{
|
||||||
|
PackageName: "test",
|
||||||
|
Name: "LengthAndQuant",
|
||||||
|
Fields: []parser.Field{
|
||||||
|
{Name: "Name", Kind: parser.KindPrimitive, Primitive: parser.KindString},
|
||||||
|
{
|
||||||
|
Name: "Items",
|
||||||
|
Kind: parser.KindSlice,
|
||||||
|
Elem: &parser.Field{
|
||||||
|
Kind: parser.KindPrimitive,
|
||||||
|
Primitive: parser.KindUint8,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Ratio",
|
||||||
|
Kind: parser.KindPrimitive,
|
||||||
|
Primitive: parser.KindFloat32,
|
||||||
|
Quant: &parser.QuantInfo{Min: 0, Max: 1, Bits: 8},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
src, err := GenerateTypeScriptSchema(schema, "Test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateTypeScriptSchema: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
code := string(src)
|
||||||
|
|
||||||
|
if !strings.Contains(code, "function arpackEnsureUint16Length(length: number, context: string): number") {
|
||||||
|
t.Error("Missing uint16 length helper")
|
||||||
|
}
|
||||||
|
if !strings.Contains(code, "function arpackEnsureQuantizedRange(value: number, min: number, max: number, context: string): void") {
|
||||||
|
t.Error("Missing quantized range helper")
|
||||||
|
}
|
||||||
|
if !strings.Contains(code, `arpackEnsureUint16Length(this.items.length, "slice length for Items")`) {
|
||||||
|
t.Error("Missing slice length guard")
|
||||||
|
}
|
||||||
|
if !strings.Contains(code, `arpackEnsureUint16Length(_slen`) {
|
||||||
|
t.Error("Missing string length helper call")
|
||||||
|
}
|
||||||
|
if !strings.Contains(code, `arpackEnsureQuantizedRange(this.ratio, 0, 1, "Ratio");`) {
|
||||||
|
t.Error("Missing quantized range helper call")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+67
-7
@@ -55,6 +55,8 @@ func parseASTFile(fset *token.FileSet, f *ast.File) (Schema, error) {
|
|||||||
|
|
||||||
knownStructs := map[string]bool{}
|
knownStructs := map[string]bool{}
|
||||||
namedPrimitives := map[string]PrimitiveKind{}
|
namedPrimitives := map[string]PrimitiveKind{}
|
||||||
|
unsupportedNamedPrimitives := map[string]string{}
|
||||||
|
unresolvedNamedPrimitives := map[string]string{}
|
||||||
var enumOrder []string
|
var enumOrder []string
|
||||||
|
|
||||||
for _, decl := range f.Decls {
|
for _, decl := range f.Decls {
|
||||||
@@ -73,8 +75,14 @@ func parseASTFile(fset *token.FileSet, f *ast.File) (Schema, error) {
|
|||||||
case *ast.StructType:
|
case *ast.StructType:
|
||||||
knownStructs[typeSpec.Name.Name] = true
|
knownStructs[typeSpec.Name.Name] = true
|
||||||
case *ast.Ident:
|
case *ast.Ident:
|
||||||
|
switch t.Name {
|
||||||
|
case "int", "uint", "uintptr":
|
||||||
|
unsupportedNamedPrimitives[typeSpec.Name.Name] = t.Name
|
||||||
|
continue
|
||||||
|
}
|
||||||
primKind, isPrimitive := goPrimitiveKind(t.Name)
|
primKind, isPrimitive := goPrimitiveKind(t.Name)
|
||||||
if !isPrimitive {
|
if !isPrimitive {
|
||||||
|
unresolvedNamedPrimitives[typeSpec.Name.Name] = t.Name
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
namedPrimitives[typeSpec.Name.Name] = primKind
|
namedPrimitives[typeSpec.Name.Name] = primKind
|
||||||
@@ -85,6 +93,19 @@ func parseASTFile(fset *token.FileSet, f *ast.File) (Schema, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for changed := true; changed; {
|
||||||
|
changed = false
|
||||||
|
for name, target := range unresolvedNamedPrimitives {
|
||||||
|
if _, ok := unsupportedNamedPrimitives[name]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if baseType, ok := unsupportedNamedPrimitives[target]; ok {
|
||||||
|
unsupportedNamedPrimitives[name] = baseType
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
info, err := typeCheckFile(fset, f)
|
info, err := typeCheckFile(fset, f)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Schema{}, err
|
return Schema{}, err
|
||||||
@@ -119,7 +140,14 @@ func parseASTFile(fset *token.FileSet, f *ast.File) (Schema, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
msg, err := parseStruct(pkgName, typeSpec.Name.Name, structType, knownStructs, namedPrimitives)
|
msg, err := parseStruct(
|
||||||
|
pkgName,
|
||||||
|
typeSpec.Name.Name,
|
||||||
|
structType,
|
||||||
|
knownStructs,
|
||||||
|
namedPrimitives,
|
||||||
|
unsupportedNamedPrimitives,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Schema{}, fmt.Errorf("struct %s: %w", typeSpec.Name.Name, err)
|
return Schema{}, fmt.Errorf("struct %s: %w", typeSpec.Name.Name, err)
|
||||||
}
|
}
|
||||||
@@ -187,6 +215,7 @@ func parseStruct(
|
|||||||
st *ast.StructType,
|
st *ast.StructType,
|
||||||
knownStructs map[string]bool,
|
knownStructs map[string]bool,
|
||||||
namedPrimitives map[string]PrimitiveKind,
|
namedPrimitives map[string]PrimitiveKind,
|
||||||
|
unsupportedNamedPrimitives map[string]string,
|
||||||
) (Message, error) {
|
) (Message, error) {
|
||||||
msg := Message{PackageName: pkg, Name: name}
|
msg := Message{PackageName: pkg, Name: name}
|
||||||
|
|
||||||
@@ -202,7 +231,14 @@ func parseStruct(
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, fieldName := range astField.Names {
|
for _, fieldName := range astField.Names {
|
||||||
field, err := parseFieldType(fieldName.Name, astField.Type, rawTag, knownStructs, namedPrimitives)
|
field, err := parseFieldType(
|
||||||
|
fieldName.Name,
|
||||||
|
astField.Type,
|
||||||
|
rawTag,
|
||||||
|
knownStructs,
|
||||||
|
namedPrimitives,
|
||||||
|
unsupportedNamedPrimitives,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Message{}, fmt.Errorf("field %s: %w", fieldName.Name, err)
|
return Message{}, fmt.Errorf("field %s: %w", fieldName.Name, err)
|
||||||
}
|
}
|
||||||
@@ -219,14 +255,22 @@ func parseFieldType(
|
|||||||
rawTag string,
|
rawTag string,
|
||||||
knownStructs map[string]bool,
|
knownStructs map[string]bool,
|
||||||
namedPrimitives map[string]PrimitiveKind,
|
namedPrimitives map[string]PrimitiveKind,
|
||||||
|
unsupportedNamedPrimitives map[string]string,
|
||||||
) (Field, error) {
|
) (Field, error) {
|
||||||
switch t := expr.(type) {
|
switch t := expr.(type) {
|
||||||
case *ast.Ident:
|
case *ast.Ident:
|
||||||
return parsePrimitiveOrNested(name, t.Name, rawTag, knownStructs, namedPrimitives)
|
return parsePrimitiveOrNested(
|
||||||
|
name,
|
||||||
|
t.Name,
|
||||||
|
rawTag,
|
||||||
|
knownStructs,
|
||||||
|
namedPrimitives,
|
||||||
|
unsupportedNamedPrimitives,
|
||||||
|
)
|
||||||
|
|
||||||
case *ast.ArrayType:
|
case *ast.ArrayType:
|
||||||
if t.Len == nil {
|
if t.Len == nil {
|
||||||
elem, err := parseFieldType("", t.Elt, rawTag, knownStructs, namedPrimitives)
|
elem, err := parseFieldType("", t.Elt, rawTag, knownStructs, namedPrimitives, unsupportedNamedPrimitives)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Field{}, fmt.Errorf("slice element: %w", err)
|
return Field{}, fmt.Errorf("slice element: %w", err)
|
||||||
}
|
}
|
||||||
@@ -243,7 +287,7 @@ func parseFieldType(
|
|||||||
return Field{}, fmt.Errorf("array length: %w", err)
|
return Field{}, fmt.Errorf("array length: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
elem, err := parseFieldType("", t.Elt, rawTag, knownStructs, namedPrimitives)
|
elem, err := parseFieldType("", t.Elt, rawTag, knownStructs, namedPrimitives, unsupportedNamedPrimitives)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Field{}, fmt.Errorf("array element: %w", err)
|
return Field{}, fmt.Errorf("array element: %w", err)
|
||||||
}
|
}
|
||||||
@@ -271,9 +315,25 @@ func parsePrimitiveOrNested(
|
|||||||
rawTag string,
|
rawTag string,
|
||||||
knownStructs map[string]bool,
|
knownStructs map[string]bool,
|
||||||
namedPrimitives map[string]PrimitiveKind,
|
namedPrimitives map[string]PrimitiveKind,
|
||||||
|
unsupportedNamedPrimitives map[string]string,
|
||||||
) (Field, error) {
|
) (Field, error) {
|
||||||
|
switch typeName {
|
||||||
|
case "int", "uint", "uintptr":
|
||||||
|
return Field{}, fmt.Errorf(
|
||||||
|
"platform-dependent type %q is not supported; use int32/int64, uint32/uint64, or fixed-size integer IDs instead",
|
||||||
|
typeName,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
primKind, isPrimitive := goPrimitiveKind(typeName)
|
primKind, isPrimitive := goPrimitiveKind(typeName)
|
||||||
if !isPrimitive {
|
if !isPrimitive {
|
||||||
|
if baseType, ok := unsupportedNamedPrimitives[typeName]; ok {
|
||||||
|
return Field{}, fmt.Errorf(
|
||||||
|
"type %q aliases unsupported platform-dependent %q; use int32/int64, uint32/uint64, or fixed-size integer IDs instead",
|
||||||
|
typeName,
|
||||||
|
baseType,
|
||||||
|
)
|
||||||
|
}
|
||||||
if namedPrimitive, ok := namedPrimitives[typeName]; ok {
|
if namedPrimitive, ok := namedPrimitives[typeName]; ok {
|
||||||
return buildPrimitiveField(name, typeName, namedPrimitive, rawTag)
|
return buildPrimitiveField(name, typeName, namedPrimitive, rawTag)
|
||||||
}
|
}
|
||||||
@@ -385,7 +445,7 @@ func goPrimitiveKind(name string) (PrimitiveKind, bool) {
|
|||||||
return KindInt8, true
|
return KindInt8, true
|
||||||
case "int16":
|
case "int16":
|
||||||
return KindInt16, true
|
return KindInt16, true
|
||||||
case "int32", "int":
|
case "int32":
|
||||||
return KindInt32, true
|
return KindInt32, true
|
||||||
case "int64":
|
case "int64":
|
||||||
return KindInt64, true
|
return KindInt64, true
|
||||||
@@ -393,7 +453,7 @@ func goPrimitiveKind(name string) (PrimitiveKind, bool) {
|
|||||||
return KindUint8, true
|
return KindUint8, true
|
||||||
case "uint16":
|
case "uint16":
|
||||||
return KindUint16, true
|
return KindUint16, true
|
||||||
case "uint32", "uint":
|
case "uint32":
|
||||||
return KindUint32, true
|
return KindUint32, true
|
||||||
case "uint64":
|
case "uint64":
|
||||||
return KindUint64, true
|
return KindUint64, true
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package parser
|
package parser
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -266,3 +267,92 @@ type Msg struct {
|
|||||||
t.Fatal("expected error for unknown nested type, got nil")
|
t.Fatal("expected error for unknown nested type, got nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnsupportedPlatformDependentIntTypes(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
src string
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "direct int field",
|
||||||
|
src: `package p
|
||||||
|
type Msg struct {
|
||||||
|
X int
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
wantErr: `platform-dependent type "int" is not supported`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "direct uint field",
|
||||||
|
src: `package p
|
||||||
|
type Msg struct {
|
||||||
|
X uint
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
wantErr: `platform-dependent type "uint" is not supported`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "direct uintptr field",
|
||||||
|
src: `package p
|
||||||
|
type Msg struct {
|
||||||
|
X uintptr
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
wantErr: `platform-dependent type "uintptr" is not supported`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "alias of int",
|
||||||
|
src: `package p
|
||||||
|
type Counter int
|
||||||
|
type Msg struct {
|
||||||
|
X Counter
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
wantErr: `type "Counter" aliases unsupported platform-dependent "int"`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "alias of uint",
|
||||||
|
src: `package p
|
||||||
|
type Counter uint
|
||||||
|
type Msg struct {
|
||||||
|
X Counter
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
wantErr: `type "Counter" aliases unsupported platform-dependent "uint"`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "alias of uintptr",
|
||||||
|
src: `package p
|
||||||
|
type Handle uintptr
|
||||||
|
type Msg struct {
|
||||||
|
X Handle
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
wantErr: `type "Handle" aliases unsupported platform-dependent "uintptr"`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "transitive alias of int",
|
||||||
|
src: `package p
|
||||||
|
type Base int
|
||||||
|
type Counter Base
|
||||||
|
type Msg struct {
|
||||||
|
X Counter
|
||||||
|
}
|
||||||
|
`,
|
||||||
|
wantErr: `type "Counter" aliases unsupported platform-dependent "int"`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
_, err := ParseSource(tc.src)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), tc.wantErr) {
|
||||||
|
t.Fatalf("expected error containing %q, got %v", tc.wantErr, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Vendored
+39
@@ -0,0 +1,39 @@
|
|||||||
|
package lua
|
||||||
|
|
||||||
|
type Vector3 struct {
|
||||||
|
X float32 `pack:"min=-500,max=500,bits=16"`
|
||||||
|
Y float32 `pack:"min=-500,max=500,bits=16"`
|
||||||
|
Z float32 `pack:"min=-500,max=500,bits=16"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Opcode uint16
|
||||||
|
|
||||||
|
const (
|
||||||
|
OpcodeUnknown Opcode = iota
|
||||||
|
OpcodeAuthorize
|
||||||
|
OpcodeJoinRoom
|
||||||
|
)
|
||||||
|
|
||||||
|
type MoveMessage struct {
|
||||||
|
Position Vector3
|
||||||
|
Velocity [3]float32
|
||||||
|
Waypoints []Vector3
|
||||||
|
PlayerID uint32
|
||||||
|
Active bool
|
||||||
|
Visible bool
|
||||||
|
Ghost bool
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
type SpawnMessage struct {
|
||||||
|
EntityID uint32
|
||||||
|
Position Vector3
|
||||||
|
Health int16
|
||||||
|
Tags []string
|
||||||
|
Data []uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
type EnvelopeMessage struct {
|
||||||
|
Code Opcode
|
||||||
|
Counter uint8
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user