diff --git a/go/vt/vtgate/evalengine/arena.go b/go/vt/vtgate/evalengine/arena.go index 590dc3b02c7..9236b29dcd1 100644 --- a/go/vt/vtgate/evalengine/arena.go +++ b/go/vt/vtgate/evalengine/arena.go @@ -61,6 +61,35 @@ func (a *Arena) newEvalDecimal(dec decimal.Decimal, m, d int32) *evalDecimal { return a.newEvalDecimalWithPrec(dec.Clamp(m-d, d), d) } +<<<<<<< HEAD +======= +func (a *Arena) newEvalEnum(raw []byte, values *EnumSetValues) *evalEnum { + if cap(a.aEnum) > len(a.aEnum) { + a.aEnum = a.aEnum[:len(a.aEnum)+1] + } else { + a.aEnum = append(a.aEnum, evalEnum{}) + } + val := &a.aEnum[len(a.aEnum)-1] + s := string(raw) + val.string = s + val.value = valueIdx(values, s) + return val +} + +func (a *Arena) newEvalSet(raw []byte, values *EnumSetValues) *evalSet { + if cap(a.aSet) > len(a.aSet) { + a.aSet = a.aSet[:len(a.aSet)+1] + } else { + a.aSet = append(a.aSet, evalSet{}) + } + val := &a.aSet[len(a.aSet)-1] + s := string(raw) + val.string = s + val.set = evalSetBits(values, s) + return val +} + +>>>>>>> 06def14056 (Fix crash in the evalengine (#17487)) func (a *Arena) newEvalBool(b bool) *evalInt64 { if b { return a.newEvalInt64(1) diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 201b2816f45..7b89dafee00 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -716,6 +716,20 @@ func TestCompilerSingle(t *testing.T) { expression: `WEEK(timestamp '2024-01-01 10:34:58', 1)`, result: `INT64(1)`, }, + { + expression: `column0 + 1`, + values: []sqltypes.Value{sqltypes.MakeTrusted(sqltypes.Enum, []byte("foo"))}, + // Returns 0, as unknown enums evaluate here to -1. We have this test to + // exercise the path to push enums onto the stack. + result: `FLOAT64(0)`, + }, + { + expression: `column0 + 1`, + values: []sqltypes.Value{sqltypes.MakeTrusted(sqltypes.Set, []byte("foo"))}, + // Returns 1, as unknown sets evaluate here to 0. We have this test to + // exercise the path to push sets onto the stack. + result: `FLOAT64(1)`, + }, } tz, _ := time.LoadLocation("Europe/Madrid")