From 8a0f8d1a2021394c552497324c9a4d3d0ed7f61c Mon Sep 17 00:00:00 2001 From: Jens Geyer Date: Tue, 10 Sep 2013 21:30:41 +0200 Subject: [PATCH] THRIFT-2174 Deserializing JSON fails in specific cases Patch: Jens Geyer --- lib/go/thrift/json_protocol.go | 130 +++++++++++++++----------- lib/go/thrift/simple_json_protocol.go | 34 +------ 2 files changed, 83 insertions(+), 81 deletions(-) diff --git a/lib/go/thrift/json_protocol.go b/lib/go/thrift/json_protocol.go index 5e8453a1..957d8ed8 100644 --- a/lib/go/thrift/json_protocol.go +++ b/lib/go/thrift/json_protocol.go @@ -100,7 +100,11 @@ func (p *TJSONProtocol) WriteFieldBegin(name string, typeId TType, id int16) err if e := p.OutputObjectBegin(); e != nil { return e } - if e := p.WriteString(p.TypeIdToString(typeId)); e != nil { + s, e1 := p.TypeIdToString(typeId) + if e1 != nil { + return e1 + } + if e := p.WriteString(s); e != nil { return e } return nil @@ -116,10 +120,18 @@ func (p *TJSONProtocol) WriteMapBegin(keyType TType, valueType TType, size int) if e := p.OutputListBegin(); e != nil { return e } - if e := p.WriteString(p.TypeIdToString(keyType)); e != nil { + s, e1 := p.TypeIdToString(keyType) + if e1 != nil { + return e1 + } + if e := p.WriteString(s); e != nil { return e } - if e := p.WriteString(p.TypeIdToString(valueType)); e != nil { + s, e1 = p.TypeIdToString(valueType) + if e1 != nil { + return e1 + } + if e := p.WriteString(s); e != nil { return e } return p.WriteI64(int64(size)) @@ -250,7 +262,10 @@ func (p *TJSONProtocol) ReadFieldBegin() (string, TType, int16, error) { return "", STOP, fieldId, err } sType, err := p.ReadString() - fType := p.StringToTypeId(sType) + if err != nil { + return "", STOP, fieldId, err + } + fType, err := p.StringToTypeId(sType) return "", fType, fieldId, err } @@ -265,14 +280,20 @@ func (p *TJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int // read keyType sKeyType, e := p.ReadString() - keyType = p.StringToTypeId(sKeyType) + if e != nil { + return keyType, valueType, size, e + } + keyType, e = p.StringToTypeId(sKeyType) if e != nil { return keyType, valueType, size, e } // read valueType sValueType, e := p.ReadString() - valueType = p.StringToTypeId(sValueType) + if e != nil { + return keyType, valueType, size, e + } + valueType, e = p.StringToTypeId(sValueType) if e != nil { return keyType, valueType, size, e } @@ -436,7 +457,11 @@ func (p *TJSONProtocol) OutputElemListBegin(elemType TType, size int) error { if e := p.OutputListBegin(); e != nil { return e } - if e := p.WriteString(p.TypeIdToString(elemType)); e != nil { + s, e1 := p.TypeIdToString(elemType) + if e1 != nil { + return e1 + } + if e := p.WriteString(s); e != nil { return e } if e := p.WriteI64(int64(size)); e != nil { @@ -445,13 +470,15 @@ func (p *TJSONProtocol) OutputElemListBegin(elemType TType, size int) error { return nil } - func (p *TJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error) { if isNull, e := p.ParseListBegin(); isNull || e != nil { return VOID, 0, e } sElemType, err := p.ReadString() - elemType = p.StringToTypeId(sElemType) + if err != nil { + return VOID, size, err + } + elemType, err = p.StringToTypeId(sElemType) if err != nil { return elemType, size, err } @@ -465,7 +492,10 @@ func (p *TJSONProtocol) readElemListBegin() (elemType TType, size int, e error) return VOID, 0, e } sElemType, err := p.ReadString() - elemType = p.StringToTypeId(sElemType) + if err != nil { + return VOID, size, err + } + elemType, err = p.StringToTypeId(sElemType) if err != nil { return elemType, size, err } @@ -478,7 +508,11 @@ func (p *TJSONProtocol) writeElemListBegin(elemType TType, size int) error { if e := p.OutputListBegin(); e != nil { return e } - if e := p.OutputString(p.TypeIdToString(elemType)); e != nil { + s, e1 := p.TypeIdToString(elemType) + if e1 != nil { + return e1 + } + if e := p.OutputString(s); e != nil { return e } if e := p.OutputI64(int64(size)); e != nil { @@ -487,70 +521,62 @@ func (p *TJSONProtocol) writeElemListBegin(elemType TType, size int) error { return nil } -func (p *TJSONProtocol) TypeIdToString(fieldType TType) string { +func (p *TJSONProtocol) TypeIdToString(fieldType TType) (string, error) { switch byte(fieldType) { - case STOP: - return "stp" - case VOID: - return "v" case BOOL: - return "tf" + return "tf", nil case BYTE: - return "i8" - case DOUBLE: - return "dbl" + return "i8", nil case I16: - return "i16" + return "i16", nil case I32: - return "i32" + return "i32", nil case I64: - return "i64" + return "i64", nil + case DOUBLE: + return "dbl", nil case STRING: - return "str" + return "str", nil case STRUCT: - return "rec" + return "rec", nil case MAP: - return "map" + return "map", nil case SET: - return "set" + return "set", nil case LIST: - return "lst" - case UTF16: - return "str" + return "lst", nil } - return "" + + e := fmt.Errorf("Unknown fieldType: %d", int(fieldType)) + return "", NewTProtocolExceptionWithType(INVALID_DATA, e) } -func (p *TJSONProtocol) StringToTypeId(fieldType string) TType { +func (p *TJSONProtocol) StringToTypeId(fieldType string) (TType, error) { switch fieldType { - case "stp": - return TType(STOP) - case "v": - return TType(VOID) case "tf": - return TType(BOOL) + return TType(BOOL), nil case "i8": - return TType(BYTE) - case "dbl": - return TType(DOUBLE) - case "16": - return TType(I16) + return TType(BYTE), nil + case "i16": + return TType(I16), nil case "i32": - return TType(I32) + return TType(I32), nil case "i64": - return TType(I64) + return TType(I64), nil + case "dbl": + return TType(DOUBLE), nil case "str": - return TType(STRING) + return TType(STRING), nil case "rec": - return TType(STRUCT) + return TType(STRUCT), nil case "map": - return TType(MAP) + return TType(MAP), nil case "set": - return TType(SET) + return TType(SET), nil case "lst": - return TType(LIST) - case "u16": - return TType(UTF16) + return TType(LIST), nil } - return TType(STOP) + + e := fmt.Errorf("Unknown type identifier: %s", fieldType) + return TType(STOP), NewTProtocolExceptionWithType(INVALID_DATA, e) } diff --git a/lib/go/thrift/simple_json_protocol.go b/lib/go/thrift/simple_json_protocol.go index 9d0f68fb..3755a2d9 100644 --- a/lib/go/thrift/simple_json_protocol.go +++ b/lib/go/thrift/simple_json_protocol.go @@ -322,9 +322,6 @@ func (p *TSimpleJSONProtocol) ReadFieldBegin() (string, TType, int16, error) { if err := p.ParsePreValue(); err != nil { return "", STOP, 0, err } - if p.reader.Buffered() < 1 { - return "", STOP, 0, nil - } b, _ := p.reader.Peek(1) if len(b) > 0 { switch b[0] { @@ -482,11 +479,7 @@ func (p *TSimpleJSONProtocol) ReadString() (string, error) { return v, err } var b []byte - if p.reader.Buffered() >= len(JSON_NULL) { - b, _ = p.reader.Peek(len(JSON_NULL)) - } else { - b, _ = p.reader.Peek(1) - } + b, _ = p.reader.Peek(len(JSON_NULL)) if len(b) > 0 && b[0] == JSON_QUOTE { p.reader.ReadByte() value, err := p.ParseStringBody() @@ -732,9 +725,6 @@ func (p *TSimpleJSONProtocol) ParsePreValue() error { return NewTProtocolException(e) } cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) - if p.reader.Buffered() < 1 { - return nil - } b, _ := p.reader.Peek(1) switch cxt { case _CONTEXT_IN_LIST: @@ -813,7 +803,7 @@ func (p *TSimpleJSONProtocol) ParsePostValue() error { } func (p *TSimpleJSONProtocol) readNonSignificantWhitespace() error { - for p.reader.Buffered() > 0 { + for { b, _ := p.reader.Peek(1) if len(b) < 1 { return nil @@ -950,11 +940,7 @@ func (p *TSimpleJSONProtocol) ParseObjectStart() (bool, error) { return false, err } var b []byte - if p.reader.Buffered() >= len(JSON_NULL) { - b, _ = p.reader.Peek(len(JSON_NULL)) - } else if p.reader.Buffered() >= 1 { - b, _ = p.reader.Peek(1) - } + b, _ = p.reader.Peek(len(JSON_NULL)) if len(b) > 0 && b[0] == JSON_LBRACE[0] { p.reader.ReadByte() p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_FIRST)) @@ -997,11 +983,7 @@ func (p *TSimpleJSONProtocol) ParseListBegin() (isNull bool, err error) { return false, e } var b []byte - if p.reader.Buffered() >= len(JSON_NULL) { - b, err = p.reader.Peek(len(JSON_NULL)) - } else { - b, err = p.reader.Peek(1) - } + b, err = p.reader.Peek(len(JSON_NULL)) if err != nil { return false, err } @@ -1134,7 +1116,7 @@ func (p *TSimpleJSONProtocol) readSingleValue() (interface{}, TType, error) { func (p *TSimpleJSONProtocol) readIfNull() (bool, error) { cont := true - for p.reader.Buffered() > 0 && cont { + for cont { b, _ := p.reader.Peek(1) if len(b) < 1 { return false, nil @@ -1150,9 +1132,6 @@ func (p *TSimpleJSONProtocol) readIfNull() (bool, error) { break } } - if p.reader.Buffered() == 0 { - return false, nil - } b, _ := p.reader.Peek(len(JSON_NULL)) if string(b) == string(JSON_NULL) { p.reader.Read(b[0:len(JSON_NULL)]) @@ -1162,9 +1141,6 @@ func (p *TSimpleJSONProtocol) readIfNull() (bool, error) { } func (p *TSimpleJSONProtocol) readQuoteIfNext() { - if p.reader.Buffered() < 1 { - return - } b, _ := p.reader.Peek(1) if len(b) > 0 && b[0] == JSON_QUOTE { p.reader.ReadByte() -- 2.17.1