THRIFT-2174 Deserializing JSON fails in specific cases
authorJens Geyer <jensg@apache.org>
Tue, 10 Sep 2013 19:30:41 +0000 (21:30 +0200)
committerJens Geyer <jensg@apache.org>
Tue, 10 Sep 2013 19:34:42 +0000 (21:34 +0200)
Patch: Jens Geyer

lib/go/thrift/json_protocol.go
lib/go/thrift/simple_json_protocol.go

index 5e8453a..957d8ed 100644 (file)
@@ -100,7 +100,11 @@ func (p *TJSONProtocol) WriteFieldBegin(name string, typeId TType, id int16) err
        if e := p.OutputObjectBegin(); e != nil {
                return e
        }
-       if e := p.WriteString(p.TypeIdToString(typeId)); e != nil {
+       s, e1 := p.TypeIdToString(typeId)
+       if e1 != nil {
+               return e1
+       }
+       if e := p.WriteString(s); e != nil {
                return e
        }
        return nil
@@ -116,10 +120,18 @@ func (p *TJSONProtocol) WriteMapBegin(keyType TType, valueType TType, size int)
        if e := p.OutputListBegin(); e != nil {
                return e
        }
-       if e := p.WriteString(p.TypeIdToString(keyType)); e != nil {
+       s, e1 := p.TypeIdToString(keyType)
+       if e1 != nil {
+               return e1
+       }
+       if e := p.WriteString(s); e != nil {
                return e
        }
-       if e := p.WriteString(p.TypeIdToString(valueType)); e != nil {
+       s, e1 = p.TypeIdToString(valueType)
+       if e1 != nil {
+               return e1
+       }
+       if e := p.WriteString(s); e != nil {
                return e
        }
        return p.WriteI64(int64(size))
@@ -250,7 +262,10 @@ func (p *TJSONProtocol) ReadFieldBegin() (string, TType, int16, error) {
                return "", STOP, fieldId, err
        }
        sType, err := p.ReadString()
-       fType := p.StringToTypeId(sType)
+       if err != nil {
+               return "", STOP, fieldId, err
+       }
+       fType, err := p.StringToTypeId(sType)
        return "", fType, fieldId, err
 }
 
@@ -265,14 +280,20 @@ func (p *TJSONProtocol) ReadMapBegin() (keyType TType, valueType TType, size int
 
        // read keyType
        sKeyType, e := p.ReadString()
-       keyType = p.StringToTypeId(sKeyType)
+       if e != nil {
+               return keyType, valueType, size, e
+       }
+       keyType, e = p.StringToTypeId(sKeyType)
        if e != nil {
                return keyType, valueType, size, e
        }
 
        // read valueType
        sValueType, e := p.ReadString()
-       valueType = p.StringToTypeId(sValueType)
+       if e != nil {
+               return keyType, valueType, size, e
+       }
+       valueType, e = p.StringToTypeId(sValueType)
        if e != nil {
                return keyType, valueType, size, e
        }
@@ -436,7 +457,11 @@ func (p *TJSONProtocol) OutputElemListBegin(elemType TType, size int) error {
        if e := p.OutputListBegin(); e != nil {
                return e
        }
-       if e := p.WriteString(p.TypeIdToString(elemType)); e != nil {
+       s, e1 := p.TypeIdToString(elemType)
+       if e1 != nil {
+               return e1
+       }
+       if e := p.WriteString(s); e != nil {
                return e
        }
        if e := p.WriteI64(int64(size)); e != nil {
@@ -445,13 +470,15 @@ func (p *TJSONProtocol) OutputElemListBegin(elemType TType, size int) error {
        return nil
 }
 
-
 func (p *TJSONProtocol) ParseElemListBegin() (elemType TType, size int, e error) {
        if isNull, e := p.ParseListBegin(); isNull || e != nil {
                return VOID, 0, e
        }
        sElemType, err := p.ReadString()
-       elemType = p.StringToTypeId(sElemType)
+       if err != nil {
+               return VOID, size, err
+       }
+       elemType, err = p.StringToTypeId(sElemType)
        if err != nil {
                return elemType, size, err
        }
@@ -465,7 +492,10 @@ func (p *TJSONProtocol) readElemListBegin() (elemType TType, size int, e error)
                return VOID, 0, e
        }
        sElemType, err := p.ReadString()
-       elemType = p.StringToTypeId(sElemType)
+       if err != nil {
+               return VOID, size, err
+       }
+       elemType, err = p.StringToTypeId(sElemType)
        if err != nil {
                return elemType, size, err
        }
@@ -478,7 +508,11 @@ func (p *TJSONProtocol) writeElemListBegin(elemType TType, size int) error {
        if e := p.OutputListBegin(); e != nil {
                return e
        }
-       if e := p.OutputString(p.TypeIdToString(elemType)); e != nil {
+       s, e1 := p.TypeIdToString(elemType)
+       if e1 != nil {
+               return e1
+       }
+       if e := p.OutputString(s); e != nil {
                return e
        }
        if e := p.OutputI64(int64(size)); e != nil {
@@ -487,70 +521,62 @@ func (p *TJSONProtocol) writeElemListBegin(elemType TType, size int) error {
        return nil
 }
 
-func (p *TJSONProtocol) TypeIdToString(fieldType TType) string {
+func (p *TJSONProtocol) TypeIdToString(fieldType TType) (string, error) {
        switch byte(fieldType) {
-       case STOP:
-               return "stp"
-       case VOID:
-               return "v"
        case BOOL:
-               return "tf"
+               return "tf", nil
        case BYTE:
-               return "i8"
-       case DOUBLE:
-               return "dbl"
+               return "i8", nil
        case I16:
-               return "i16"
+               return "i16", nil
        case I32:
-               return "i32"
+               return "i32", nil
        case I64:
-               return "i64"
+               return "i64", nil
+       case DOUBLE:
+               return "dbl", nil
        case STRING:
-               return "str"
+               return "str", nil
        case STRUCT:
-               return "rec"
+               return "rec", nil
        case MAP:
-               return "map"
+               return "map", nil
        case SET:
-               return "set"
+               return "set", nil
        case LIST:
-               return "lst"
-       case UTF16:
-               return "str"
+               return "lst", nil
        }
-       return ""
+
+       e := fmt.Errorf("Unknown fieldType: %d", int(fieldType))
+       return "", NewTProtocolExceptionWithType(INVALID_DATA, e)
 }
 
-func (p *TJSONProtocol) StringToTypeId(fieldType string) TType {
+func (p *TJSONProtocol) StringToTypeId(fieldType string) (TType, error) {
        switch fieldType {
-       case "stp":
-               return TType(STOP)
-       case "v":
-               return TType(VOID)
        case "tf":
-               return TType(BOOL)
+               return TType(BOOL), nil
        case "i8":
-               return TType(BYTE)
-       case "dbl":
-               return TType(DOUBLE)
-       case "16":
-               return TType(I16)
+               return TType(BYTE), nil
+       case "i16":
+               return TType(I16), nil
        case "i32":
-               return TType(I32)
+               return TType(I32), nil
        case "i64":
-               return TType(I64)
+               return TType(I64), nil
+       case "dbl":
+               return TType(DOUBLE), nil
        case "str":
-               return TType(STRING)
+               return TType(STRING), nil
        case "rec":
-               return TType(STRUCT)
+               return TType(STRUCT), nil
        case "map":
-               return TType(MAP)
+               return TType(MAP), nil
        case "set":
-               return TType(SET)
+               return TType(SET), nil
        case "lst":
-               return TType(LIST)
-       case "u16":
-               return TType(UTF16)
+               return TType(LIST), nil
        }
-       return TType(STOP)
+
+       e := fmt.Errorf("Unknown type identifier: %s", fieldType)
+       return TType(STOP), NewTProtocolExceptionWithType(INVALID_DATA, e)
 }
index 9d0f68f..3755a2d 100644 (file)
@@ -322,9 +322,6 @@ func (p *TSimpleJSONProtocol) ReadFieldBegin() (string, TType, int16, error) {
        if err := p.ParsePreValue(); err != nil {
                return "", STOP, 0, err
        }
-       if p.reader.Buffered() < 1 {
-               return "", STOP, 0, nil
-       }
        b, _ := p.reader.Peek(1)
        if len(b) > 0 {
                switch b[0] {
@@ -482,11 +479,7 @@ func (p *TSimpleJSONProtocol) ReadString() (string, error) {
                return v, err
        }
        var b []byte
-       if p.reader.Buffered() >= len(JSON_NULL) {
-               b, _ = p.reader.Peek(len(JSON_NULL))
-       } else {
-               b, _ = p.reader.Peek(1)
-       }
+       b, _ = p.reader.Peek(len(JSON_NULL))
        if len(b) > 0 && b[0] == JSON_QUOTE {
                p.reader.ReadByte()
                value, err := p.ParseStringBody()
@@ -732,9 +725,6 @@ func (p *TSimpleJSONProtocol) ParsePreValue() error {
                return NewTProtocolException(e)
        }
        cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
-       if p.reader.Buffered() < 1 {
-               return nil
-       }
        b, _ := p.reader.Peek(1)
        switch cxt {
        case _CONTEXT_IN_LIST:
@@ -813,7 +803,7 @@ func (p *TSimpleJSONProtocol) ParsePostValue() error {
 }
 
 func (p *TSimpleJSONProtocol) readNonSignificantWhitespace() error {
-       for p.reader.Buffered() > 0 {
+       for {
                b, _ := p.reader.Peek(1)
                if len(b) < 1 {
                        return nil
@@ -950,11 +940,7 @@ func (p *TSimpleJSONProtocol) ParseObjectStart() (bool, error) {
                return false, err
        }
        var b []byte
-       if p.reader.Buffered() >= len(JSON_NULL) {
-               b, _ = p.reader.Peek(len(JSON_NULL))
-       } else if p.reader.Buffered() >= 1 {
-               b, _ = p.reader.Peek(1)
-       }
+       b, _ = p.reader.Peek(len(JSON_NULL))
        if len(b) > 0 && b[0] == JSON_LBRACE[0] {
                p.reader.ReadByte()
                p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_FIRST))
@@ -997,11 +983,7 @@ func (p *TSimpleJSONProtocol) ParseListBegin() (isNull bool, err error) {
                return false, e
        }
        var b []byte
-       if p.reader.Buffered() >= len(JSON_NULL) {
-               b, err = p.reader.Peek(len(JSON_NULL))
-       } else {
-               b, err = p.reader.Peek(1)
-       }
+       b, err = p.reader.Peek(len(JSON_NULL))
        if err != nil {
                return false, err
        }
@@ -1134,7 +1116,7 @@ func (p *TSimpleJSONProtocol) readSingleValue() (interface{}, TType, error) {
 
 func (p *TSimpleJSONProtocol) readIfNull() (bool, error) {
        cont := true
-       for p.reader.Buffered() > 0 && cont {
+       for cont {
                b, _ := p.reader.Peek(1)
                if len(b) < 1 {
                        return false, nil
@@ -1150,9 +1132,6 @@ func (p *TSimpleJSONProtocol) readIfNull() (bool, error) {
                        break
                }
        }
-       if p.reader.Buffered() == 0 {
-               return false, nil
-       }
        b, _ := p.reader.Peek(len(JSON_NULL))
        if string(b) == string(JSON_NULL) {
                p.reader.Read(b[0:len(JSON_NULL)])
@@ -1162,9 +1141,6 @@ func (p *TSimpleJSONProtocol) readIfNull() (bool, error) {
 }
 
 func (p *TSimpleJSONProtocol) readQuoteIfNext() {
-       if p.reader.Buffered() < 1 {
-               return
-       }
        b, _ := p.reader.Peek(1)
        if len(b) > 0 && b[0] == JSON_QUOTE {
                p.reader.ReadByte()