From: Jens Geyer Date: Thu, 1 May 2014 23:30:13 +0000 (+0200) Subject: THRIFT-2502 Optimize go implementations of binary and compact protocols for speed X-Git-Url: https://source.supwisdom.com/gerrit/gitweb?a=commitdiff_plain;h=0997250744112ec0519d5f67cda92b2b87fb0063;p=common%2Fthrift.git THRIFT-2502 Optimize go implementations of binary and compact protocols for speed Client: Go Patch: Aleksey Pesternikov This closes #110 commit 7ece8e6f16f7ff46cda4b896215d595ac986d332 Author: Aleksey Pesternikov Date: 2014-04-26T17:45:12Z simplify buffered transport by reusing bufio commit 814b661d7e5c3c27ad4035a42925eae619447ee3 Author: Aleksey Pesternikov Date: 2014-04-26T18:05:12Z zero-initialize buffers in framed transport commit 0f576138e24fae8e7f8d210cfb480889a41d1d9a Author: Aleksey Pesternikov Date: 2014-04-26T19:19:39Z do not buffer the whole frame while reading in framed transport reuse frame header buffer commit 4db9b65458eb34e1b1676dba76d1e664c6339a57 Author: Aleksey Pesternikov Date: 2014-04-26T19:43:07Z enforce max frame size in framed transport commit 58ecc23ec1a2176f7dc5db7a658a51817dc626e6 Author: Aleksey Pesternikov Date: 2014-04-27T00:31:16Z microbenchmarks for serialization/deserialization (binary,compact)x(memoryBuffer,Stream,framedMemoryBuffer)x(bool,byte,i16,i32,i64,double,string,binary) commit 156116f484db513251e0e6c65942466ed5a8142c Author: Aleksey Pesternikov Date: 2014-04-27T00:32:09Z Merge branch 'go_microbench' into go_simplify_and_optimize commit 1c27c0913cf5a8c0352afff1dae9e9fc9f758409 Author: Aleksey Pesternikov Date: 2014-04-27T22:45:52Z do not allocate buffer in TBinaryProtocol.WriteByte commit 86addfb0585e04c648cde1b9cb1566d7976f8cda Author: Aleksey Pesternikov Date: 2014-04-27T23:46:12Z no extra alloc in double marshaling test commit 98ac62b0a80d4f27dce736b561005953cb915a90 Author: Aleksey Pesternikov Date: 2014-04-27T23:46:50Z Merge branch 'go_microbench' into go_simplify_and_optimize commit 76c26624578a5455cacd08bb0167444748aaa41d Author: Aleksey Pesternikov Date: 2014-04-28T12:48:41Z optimized ReadByte, WriteByte commit 5a79d81d326582dbbdcf523ebc0180390ac24497 Author: Aleksey Pesternikov Date: 2014-04-28T13:03:26Z optimized WriteString commit f6d4a9aa65434831cbd2993148fa12c12b2a342c Author: Aleksey Pesternikov Date: 2014-04-28T15:35:06Z compact protocol optimization commit 03bdb6b9f097a47ef54826483867c23d49374ac0 Author: Aleksey Pesternikov Date: 2014-04-28T16:08:52Z cache reader and writer to aviod interface conversions --- diff --git a/lib/go/thrift/binary_protocol.go b/lib/go/thrift/binary_protocol.go index b57b528d..abbe0bc6 100644 --- a/lib/go/thrift/binary_protocol.go +++ b/lib/go/thrift/binary_protocol.go @@ -27,10 +27,13 @@ import ( ) type TBinaryProtocol struct { - trans TTransport - strictRead bool - strictWrite bool - buffer [8]byte + trans TRichTransport + origTransport TTransport + reader io.Reader + writer io.Writer + strictRead bool + strictWrite bool + buffer [64]byte } type TBinaryProtocolFactory struct { @@ -43,7 +46,15 @@ func NewTBinaryProtocolTransport(t TTransport) *TBinaryProtocol { } func NewTBinaryProtocol(t TTransport, strictRead, strictWrite bool) *TBinaryProtocol { - return &TBinaryProtocol{trans: t, strictRead: strictRead, strictWrite: strictWrite} + p := &TBinaryProtocol{origTransport: t, strictRead: strictRead, strictWrite: strictWrite} + if et, ok := t.(TRichTransport); ok { + p.trans = et + } else { + p.trans = NewTRichTransport(t) + } + p.reader = p.trans + p.writer = p.trans + return p } func NewTBinaryProtocolFactoryDefault() *TBinaryProtocolFactory { @@ -171,29 +182,28 @@ func (p *TBinaryProtocol) WriteBool(value bool) error { } func (p *TBinaryProtocol) WriteByte(value byte) error { - v := []byte{value} - _, e := p.trans.Write(v) + e := p.trans.WriteByte(value) return NewTProtocolException(e) } func (p *TBinaryProtocol) WriteI16(value int16) error { v := p.buffer[0:2] binary.BigEndian.PutUint16(v, uint16(value)) - _, e := p.trans.Write(v) + _, e := p.writer.Write(v) return NewTProtocolException(e) } func (p *TBinaryProtocol) WriteI32(value int32) error { v := p.buffer[0:4] binary.BigEndian.PutUint32(v, uint32(value)) - _, e := p.trans.Write(v) + _, e := p.writer.Write(v) return NewTProtocolException(e) } func (p *TBinaryProtocol) WriteI64(value int64) error { - v := p.buffer[:] + v := p.buffer[0:8] binary.BigEndian.PutUint64(v, uint64(value)) - _, err := p.trans.Write(v) + _, err := p.writer.Write(v) return NewTProtocolException(err) } @@ -202,7 +212,12 @@ func (p *TBinaryProtocol) WriteDouble(value float64) error { } func (p *TBinaryProtocol) WriteString(value string) error { - return p.WriteBinary([]byte(value)) + e := p.WriteI32(int32(len(value))) + if e != nil { + return e + } + _, err := p.trans.WriteString(value) + return NewTProtocolException(err) } func (p *TBinaryProtocol) WriteBinary(value []byte) error { @@ -210,7 +225,7 @@ func (p *TBinaryProtocol) WriteBinary(value []byte) error { if e != nil { return e } - _, err := p.trans.Write(value) + _, err := p.writer.Write(value) return NewTProtocolException(err) } @@ -362,9 +377,7 @@ func (p *TBinaryProtocol) ReadBool() (bool, error) { } func (p *TBinaryProtocol) ReadByte() (value byte, err error) { - buf := p.buffer[0:1] - err = p.readAll(buf) - return buf[0], err + return p.trans.ReadByte() } func (p *TBinaryProtocol) ReadI16() (value int16, err error) { @@ -423,11 +436,11 @@ func (p *TBinaryProtocol) Skip(fieldType TType) (err error) { } func (p *TBinaryProtocol) Transport() TTransport { - return p.trans + return p.origTransport } func (p *TBinaryProtocol) readAll(buf []byte) error { - _, err := io.ReadFull(p.trans, buf) + _, err := io.ReadFull(p.reader, buf) return NewTProtocolException(err) } @@ -435,8 +448,12 @@ func (p *TBinaryProtocol) readStringBody(size int) (value string, err error) { if size < 0 { return "", nil } - isize := int(size) - buf := make([]byte, isize) + var buf []byte + if size <= len(p.buffer) { + buf = p.buffer[0:size] + } else { + buf = make([]byte, size) + } _, e := io.ReadFull(p.trans, buf) return string(buf), NewTProtocolException(e) } diff --git a/lib/go/thrift/buffered_transport.go b/lib/go/thrift/buffered_transport.go index b92261c5..d258b700 100644 --- a/lib/go/thrift/buffered_transport.go +++ b/lib/go/thrift/buffered_transport.go @@ -19,19 +19,17 @@ package thrift +import ( + "bufio" +) + type TBufferedTransportFactory struct { size int } -type TBuffer struct { - buffer []byte - pos, limit int -} - type TBufferedTransport struct { - tp TTransport - rbuf *TBuffer - wbuf *TBuffer + bufio.ReadWriter + tp TTransport } func (p *TBufferedTransportFactory) GetTransport(trans TTransport) TTransport { @@ -43,9 +41,13 @@ func NewTBufferedTransportFactory(bufferSize int) *TBufferedTransportFactory { } func NewTBufferedTransport(trans TTransport, bufferSize int) *TBufferedTransport { - rb := &TBuffer{buffer: make([]byte, bufferSize)} - wb := &TBuffer{buffer: make([]byte, bufferSize), limit: bufferSize} - return &TBufferedTransport{tp: trans, rbuf: rb, wbuf: wb} + return &TBufferedTransport{ + ReadWriter: bufio.ReadWriter{ + Reader: bufio.NewReaderSize(trans, bufferSize), + Writer: bufio.NewWriterSize(trans, bufferSize), + }, + tp: trans, + } } func (p *TBufferedTransport) IsOpen() bool { @@ -60,56 +62,9 @@ func (p *TBufferedTransport) Close() (err error) { return p.tp.Close() } -func (p *TBufferedTransport) Read(buf []byte) (n int, err error) { - rbuf := p.rbuf - if rbuf.pos == rbuf.limit { // no more data to read from buffer - rbuf.pos = 0 - // read data, fill buffer - rbuf.limit, err = p.tp.Read(rbuf.buffer) - if err != nil { - return 0, err - } - } - n = copy(buf, rbuf.buffer[rbuf.pos:rbuf.limit]) - rbuf.pos += n - return n, nil -} - -func (p *TBufferedTransport) Write(buf []byte) (n int, err error) { - wbuf := p.wbuf - remaining := len(buf) - - for remaining > 0 { - if wbuf.pos+remaining > wbuf.limit { // buffer is full, flush buffer - if err := p.Flush(); err != nil { - return n, err - } - } - copied := copy(wbuf.buffer[wbuf.pos:], buf[n:]) - - wbuf.pos += copied - n += copied - remaining -= copied - } - - return n, nil -} - func (p *TBufferedTransport) Flush() error { - start := 0 - wbuf := p.wbuf - for start < wbuf.pos { - n, err := p.tp.Write(wbuf.buffer[start:wbuf.pos]) - if err != nil { - return err - } - start += n + if err := p.ReadWriter.Flush(); err != nil { + return err } - - wbuf.pos = 0 return p.tp.Flush() } - -func (p *TBufferedTransport) Peek() bool { - return p.rbuf.pos < p.rbuf.limit || p.tp.Peek() -} diff --git a/lib/go/thrift/compact_protocol.go b/lib/go/thrift/compact_protocol.go index f89fc2f3..14bf62d2 100644 --- a/lib/go/thrift/compact_protocol.go +++ b/lib/go/thrift/compact_protocol.go @@ -24,7 +24,6 @@ import ( "fmt" "io" "math" - "strings" ) const ( @@ -84,7 +83,8 @@ func (p *TCompactProtocolFactory) GetProtocol(trans TTransport) TProtocol { } type TCompactProtocol struct { - trans TTransport + trans TRichTransport + origTransport TTransport // Used to keep track of the last field for the current and previous structs, // so we can do the delta stuff. @@ -93,17 +93,28 @@ type TCompactProtocol struct { // If we encounter a boolean field begin, save the TField here so it can // have the value incorporated. - booleanField *field + booleanFieldName string + booleanFieldId int16 + booleanFieldPending bool // If we read a field header, and it's a boolean field, save the boolean // value here so that readBool can use it. boolValue bool boolValueIsNotNull bool + buffer [64]byte } // Create a TCompactProtocol given a TTransport func NewTCompactProtocol(trans TTransport) *TCompactProtocol { - return &TCompactProtocol{trans: trans, lastField: []int{}} + p := &TCompactProtocol{origTransport: trans, lastField: []int{}} + if et, ok := trans.(TRichTransport); ok { + p.trans = et + } else { + p.trans = NewTRichTransport(trans) + } + + return p + } // @@ -113,11 +124,11 @@ func NewTCompactProtocol(trans TTransport) *TCompactProtocol { // Write a message header to the wire. Compact Protocol messages contain the // protocol version so we can migrate forwards in the future if need be. func (p *TCompactProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error { - _, err := p.writeByteDirect(COMPACT_PROTOCOL_ID) + err := p.writeByteDirect(COMPACT_PROTOCOL_ID) if err != nil { return NewTProtocolException(err) } - _, err = p.writeByteDirect((COMPACT_VERSION & COMPACT_VERSION_MASK) | ((byte(typeId) << COMPACT_TYPE_SHIFT_AMOUNT) & COMPACT_TYPE_MASK)) + err = p.writeByteDirect((COMPACT_VERSION & COMPACT_VERSION_MASK) | ((byte(typeId) << COMPACT_TYPE_SHIFT_AMOUNT) & COMPACT_TYPE_MASK)) if err != nil { return NewTProtocolException(err) } @@ -153,7 +164,7 @@ func (p *TCompactProtocol) WriteStructEnd() error { func (p *TCompactProtocol) WriteFieldBegin(name string, typeId TType, id int16) error { if typeId == BOOL { // we want to possibly include the value, so we'll wait. - p.booleanField = newField(name, typeId, int(id)) + p.booleanFieldName, p.booleanFieldId, p.booleanFieldPending = name, id, true return nil } _, err := p.writeFieldBeginInternal(name, typeId, id, 0xFF) @@ -178,20 +189,20 @@ func (p *TCompactProtocol) writeFieldBeginInternal(name string, typeId TType, id written := 0 if fieldId > p.lastFieldId && fieldId-p.lastFieldId <= 15 { // write them together - written, err := p.writeByteDirect(byte((fieldId-p.lastFieldId)<<4) | typeToWrite) + err := p.writeByteDirect(byte((fieldId-p.lastFieldId)<<4) | typeToWrite) if err != nil { - return written, err + return 0, err } } else { // write them separate - n, err := p.writeByteDirect(typeToWrite) + err := p.writeByteDirect(typeToWrite) if err != nil { - return n, err + return 0, err } err = p.WriteI16(id) - written = n + 2 + written = 1 + 2 if err != nil { - return written, err + return 0, err } } @@ -203,20 +214,20 @@ func (p *TCompactProtocol) writeFieldBeginInternal(name string, typeId TType, id func (p *TCompactProtocol) WriteFieldEnd() error { return nil } func (p *TCompactProtocol) WriteFieldStop() error { - _, err := p.writeByteDirect(STOP) + err := p.writeByteDirect(STOP) return NewTProtocolException(err) } func (p *TCompactProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error { if size == 0 { - _, err := p.writeByteDirect(0) + err := p.writeByteDirect(0) return NewTProtocolException(err) } _, err := p.writeVarint32(int32(size)) if err != nil { return NewTProtocolException(err) } - _, err = p.writeByteDirect(byte(p.getCompactType(keyType))<<4 | byte(p.getCompactType(valueType))) + err = p.writeByteDirect(byte(p.getCompactType(keyType))<<4 | byte(p.getCompactType(valueType))) return NewTProtocolException(err) } @@ -243,20 +254,20 @@ func (p *TCompactProtocol) WriteBool(value bool) error { if value { v = byte(COMPACT_BOOLEAN_TRUE) } - if p.booleanField != nil { + if p.booleanFieldPending { // we haven't written the field header yet - _, err := p.writeFieldBeginInternal(p.booleanField.Name(), p.booleanField.TypeId(), int16(p.booleanField.Id()), v) - p.booleanField = nil + _, err := p.writeFieldBeginInternal(p.booleanFieldName, BOOL, p.booleanFieldId, v) + p.booleanFieldPending = false return NewTProtocolException(err) } // we're not part of a field, so just write the value. - _, err := p.writeByteDirect(v) + err := p.writeByteDirect(v) return NewTProtocolException(err) } // Write a byte. Nothing to see here! func (p *TCompactProtocol) WriteByte(value byte) error { - _, err := p.writeByteDirect(value) + err := p.writeByteDirect(value) return NewTProtocolException(err) } @@ -280,7 +291,7 @@ func (p *TCompactProtocol) WriteI64(value int64) error { // Write a double to the wire as 8 bytes. func (p *TCompactProtocol) WriteDouble(value float64) error { - buf := make([]byte, 8) + buf := p.buffer[0:8] binary.LittleEndian.PutUint64(buf, math.Float64bits(value)) _, err := p.trans.Write(buf) return NewTProtocolException(err) @@ -288,9 +299,14 @@ func (p *TCompactProtocol) WriteDouble(value float64) error { // Write a string to the wire with a varint size preceeding. func (p *TCompactProtocol) WriteString(value string) error { - buf := make([]byte, len(value)) - strings.NewReader(value).Read(buf) - return p.WriteBinary(buf) + _, e := p.writeVarint32(int32(len(value))) + if e != nil { + return NewTProtocolException(e) + } + if len(value) > 0 { + } + _, e = p.trans.WriteString(value) + return e } // Write a byte array, using a varint for the size. @@ -365,7 +381,7 @@ func (p *TCompactProtocol) ReadFieldBegin() (name string, typeId TType, id int16 // if it's a stop, then we can return immediately, as the struct is over. if (t & 0x0f) == STOP { - return "", STOP, 0,nil + return "", STOP, 0, nil } // mask off the 4 MSB of the type header. it could contain a field id delta. @@ -476,12 +492,11 @@ func (p *TCompactProtocol) ReadBool() (value bool, err error) { // Read a single byte off the wire. Nothing interesting here. func (p *TCompactProtocol) ReadByte() (value byte, err error) { - buf := []byte{0} - _, e := io.ReadFull(p.trans, buf) - if e != nil { - return 0, NewTProtocolException(e) + value, err = p.trans.ReadByte() + if err != nil { + return 0, NewTProtocolException(err) } - return buf[0], nil + return } // Read an i16 from the wire as a zigzag varint. @@ -512,7 +527,7 @@ func (p *TCompactProtocol) ReadI64() (value int64, err error) { // No magic here - just read a double off the wire. func (p *TCompactProtocol) ReadDouble() (value float64, err error) { - longBits := make([]byte, 8) + longBits := p.buffer[0:8] _, e := io.ReadFull(p.trans, longBits) if e != nil { return 0.0, NewTProtocolException(e) @@ -522,18 +537,31 @@ func (p *TCompactProtocol) ReadDouble() (value float64, err error) { // Reads a []byte (via readBinary), and then UTF-8 decodes it. func (p *TCompactProtocol) ReadString() (value string, err error) { - v, e := p.ReadBinary() - return string(v), NewTProtocolException(e) + length, e := p.readVarint32() + if e != nil { + return "", NewTProtocolException(e) + } + if length == 0 { + return "", nil + } + var buf []byte + if length <= int32(len(p.buffer)) { + buf = p.buffer[0:length] + } else { + buf = make([]byte, length) + } + _, e = io.ReadFull(p.trans, buf) + return string(buf), NewTProtocolException(e) } // Read a []byte from the wire. func (p *TCompactProtocol) ReadBinary() (value []byte, err error) { length, e := p.readVarint32() if e != nil { - return []byte{}, NewTProtocolException(e) + return nil, NewTProtocolException(e) } if length == 0 { - return []byte{}, nil + return nil, nil //nil == empty slice } buf := make([]byte, length) @@ -550,7 +578,7 @@ func (p *TCompactProtocol) Skip(fieldType TType) (err error) { } func (p *TCompactProtocol) Transport() TTransport { - return p.trans + return p.origTransport } // @@ -561,20 +589,20 @@ func (p *TCompactProtocol) Transport() TTransport { // the wire differ only by the type indicator. func (p *TCompactProtocol) writeCollectionBegin(elemType TType, size int) (int, error) { if size <= 14 { - return p.writeByteDirect(byte(int32(size<<4) | int32(p.getCompactType(elemType)))) + return 1, p.writeByteDirect(byte(int32(size<<4) | int32(p.getCompactType(elemType)))) } - n, err := p.writeByteDirect(0xf0 | byte(p.getCompactType(elemType))) + err := p.writeByteDirect(0xf0 | byte(p.getCompactType(elemType))) if err != nil { - return n, err + return 0, err } m, err := p.writeVarint32(int32(size)) - return n + m, err + return 1 + m, err } // Write an i32 as a varint. Results in 1-5 bytes on the wire. // TODO(pomack): make a permanent buffer like writeVarint64? func (p *TCompactProtocol) writeVarint32(n int32) (int, error) { - i32buf := make([]byte, 5) + i32buf := p.buffer[0:5] idx := 0 for { if (n & ^0x7F) == 0 { @@ -596,7 +624,7 @@ func (p *TCompactProtocol) writeVarint32(n int32) (int, error) { // Write an i64 as a varint. Results in 1-10 bytes on the wire. func (p *TCompactProtocol) writeVarint64(n int64) (int, error) { - varint64out := make([]byte, 10) + varint64out := p.buffer[0:10] idx := 0 for { if (n & ^0x7F) == 0 { @@ -635,13 +663,13 @@ func (p *TCompactProtocol) fixedInt64ToBytes(n int64, buf []byte) { // Writes a byte without any possiblity of all that field header nonsense. // Used internally by other writing methods that know they need to write a byte. -func (p *TCompactProtocol) writeByteDirect(b byte) (int, error) { - return p.trans.Write([]byte{b}) +func (p *TCompactProtocol) writeByteDirect(b byte) error { + return p.trans.WriteByte(b) } // Writes a byte without any possiblity of all that field header nonsense. func (p *TCompactProtocol) writeIntAsByteDirect(n int) (int, error) { - return p.writeByteDirect(byte(n)) + return 1, p.writeByteDirect(byte(n)) } // diff --git a/lib/go/thrift/framed_transport.go b/lib/go/thrift/framed_transport.go index d1af0287..bfecbe83 100644 --- a/lib/go/thrift/framed_transport.go +++ b/lib/go/thrift/framed_transport.go @@ -20,33 +20,43 @@ package thrift import ( + "bufio" "bytes" "encoding/binary" + "fmt" "io" ) +const DEFAULT_MAX_LENGTH = 16384000 + type TFramedTransport struct { - transport TTransport - writeBuffer *bytes.Buffer - readBuffer *bytes.Buffer + transport TTransport + buf bytes.Buffer + reader *bufio.Reader + frameSize int //Current remaining size of the frame. if ==0 read next frame header + buffer [4]byte + maxLength int } type tFramedTransportFactory struct { - factory TTransportFactory + factory TTransportFactory + maxLength int } func NewTFramedTransportFactory(factory TTransportFactory) TTransportFactory { - return &tFramedTransportFactory{factory: factory} + return &tFramedTransportFactory{factory: factory, maxLength: DEFAULT_MAX_LENGTH} } func (p *tFramedTransportFactory) GetTransport(base TTransport) TTransport { - return NewTFramedTransport(p.factory.GetTransport(base)) + return NewTFramedTransportMaxLength(p.factory.GetTransport(base), p.maxLength) } func NewTFramedTransport(transport TTransport) *TFramedTransport { - writeBuf := make([]byte, 0, 1024) - readBuf := make([]byte, 0, 1024) - return &TFramedTransport{transport: transport, writeBuffer: bytes.NewBuffer(writeBuf), readBuffer: bytes.NewBuffer(readBuf)} + return &TFramedTransport{transport: transport, reader: bufio.NewReader(transport), maxLength: DEFAULT_MAX_LENGTH} +} + +func NewTFramedTransportMaxLength(transport TTransport, maxLength int) *TFramedTransport { + return &TFramedTransport{transport: transport, reader: bufio.NewReader(transport), maxLength: maxLength} } func (p *TFramedTransport) Open() error { @@ -57,44 +67,69 @@ func (p *TFramedTransport) IsOpen() bool { return p.transport.IsOpen() } -func (p *TFramedTransport) Peek() bool { - return p.transport.Peek() -} - func (p *TFramedTransport) Close() error { return p.transport.Close() } -func (p *TFramedTransport) Read(buf []byte) (int, error) { - if p.readBuffer.Len() > 0 { - got, err := p.readBuffer.Read(buf) - if got > 0 { - return got, NewTTransportExceptionFromError(err) +func (p *TFramedTransport) Read(buf []byte) (l int, err error) { + if p.frameSize == 0 { + p.frameSize, err = p.readFrameHeader() + if err != nil { + return } } - - // Read another frame of data - p.readFrame() - - got, err := p.readBuffer.Read(buf) + if p.frameSize < len(buf) { + return 0, NewTTransportExceptionFromError(fmt.Errorf("Not enought frame size %d to read %d bytes", p.frameSize, len(buf))) + } + got, err := p.reader.Read(buf) + p.frameSize = p.frameSize - got + //sanity check + if p.frameSize < 0 { + return 0, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "Negative frame size") + } return got, NewTTransportExceptionFromError(err) } +func (p *TFramedTransport) ReadByte() (c byte, err error) { + if p.frameSize == 0 { + p.frameSize, err = p.readFrameHeader() + if err != nil { + return + } + } + if p.frameSize < 1 { + return 0, NewTTransportExceptionFromError(fmt.Errorf("Not enought frame size %d to read %d bytes", p.frameSize, 1)) + } + c, err = p.reader.ReadByte() + if err == nil { + p.frameSize-- + } + return +} + func (p *TFramedTransport) Write(buf []byte) (int, error) { - n, err := p.writeBuffer.Write(buf) + n, err := p.buf.Write(buf) return n, NewTTransportExceptionFromError(err) } +func (p *TFramedTransport) WriteByte(c byte) error { + return p.buf.WriteByte(c) +} + +func (p *TFramedTransport) WriteString(s string) (n int, err error) { + return p.buf.WriteString(s) +} + func (p *TFramedTransport) Flush() error { - size := p.writeBuffer.Len() - buf := []byte{0, 0, 0, 0} + size := p.buf.Len() + buf := p.buffer[:4] binary.BigEndian.PutUint32(buf, uint32(size)) _, err := p.transport.Write(buf) if err != nil { return NewTTransportExceptionFromError(err) } if size > 0 { - if n, err := p.writeBuffer.WriteTo(p.transport); err != nil { + if n, err := p.buf.WriteTo(p.transport); err != nil { print("Error while flushing write buffer of size ", size, " to transport, only wrote ", n, " bytes: ", err.Error(), "\n") return NewTTransportExceptionFromError(err) } @@ -103,22 +138,14 @@ func (p *TFramedTransport) Flush() error { return NewTTransportExceptionFromError(err) } -func (p *TFramedTransport) readFrame() (int, error) { - buf := []byte{0, 0, 0, 0} - if _, err := io.ReadFull(p.transport, buf); err != nil { +func (p *TFramedTransport) readFrameHeader() (int, error) { + buf := p.buffer[:4] + if _, err := io.ReadFull(p.reader, buf); err != nil { return 0, err } size := int(binary.BigEndian.Uint32(buf)) - if size < 0 { - return 0, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "Read a negative frame size ("+string(size)+")") - } - if size == 0 { - return 0, nil - } - buf2 := make([]byte, size) - if n, err := io.ReadFull(p.transport, buf2); err != nil { - return n, err + if size < 0 || size > p.maxLength { + return 0, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, fmt.Sprintf("Incorrect frame size (%d)", size)) } - p.readBuffer = bytes.NewBuffer(buf2) return size, nil } diff --git a/lib/go/thrift/http_client.go b/lib/go/thrift/http_client.go index 9f609928..cff5ea54 100644 --- a/lib/go/thrift/http_client.go +++ b/lib/go/thrift/http_client.go @@ -153,11 +153,23 @@ func (p *THttpClient) Read(buf []byte) (int, error) { return n, NewTTransportExceptionFromError(err) } +func (p *THttpClient) ReadByte() (c byte, err error) { + return readByte(p.response.Body) +} + func (p *THttpClient) Write(buf []byte) (int, error) { n, err := p.requestBuffer.Write(buf) return n, err } +func (p *THttpClient) WriteByte(c byte) error { + return p.requestBuffer.WriteByte(c) +} + +func (p *THttpClient) WriteString(s string) (n int, err error) { + return p.requestBuffer.WriteString(s) +} + func (p *THttpClient) Flush() error { client := &http.Client{} req, err := http.NewRequest("POST", p.url.String(), p.requestBuffer) diff --git a/lib/go/thrift/iostream_transport.go b/lib/go/thrift/iostream_transport.go index 64b2958f..17fc969f 100644 --- a/lib/go/thrift/iostream_transport.go +++ b/lib/go/thrift/iostream_transport.go @@ -26,8 +26,8 @@ import ( // StreamTransport is a Transport made of an io.Reader and/or an io.Writer type StreamTransport struct { - Reader io.Reader - Writer io.Writer + io.Reader + io.Writer isReadWriter bool } @@ -103,9 +103,9 @@ func (p *StreamTransport) Open() error { return nil } -func (p *StreamTransport) Peek() bool { - return p.IsOpen() -} +// func (p *StreamTransport) Peek() bool { +// return p.IsOpen() +// } // Closes both the input and output streams. func (p *StreamTransport) Close() error { @@ -134,24 +134,6 @@ func (p *StreamTransport) Close() error { return nil } -// Reads from the underlying input stream if not null. -func (p *StreamTransport) Read(buf []byte) (int, error) { - if p.Reader == nil { - return 0, NewTTransportException(NOT_OPEN, "Cannot read from null inputStream") - } - n, err := p.Reader.Read(buf) - return n, NewTTransportExceptionFromError(err) -} - -// Writes to the underlying output stream if not null. -func (p *StreamTransport) Write(buf []byte) (int, error) { - if p.Writer == nil { - return 0, NewTTransportException(NOT_OPEN, "Cannot write to null outputStream") - } - n, err := p.Writer.Write(buf) - return n, NewTTransportExceptionFromError(err) -} - // Flushes the underlying output stream if not null. func (p *StreamTransport) Flush() error { if p.Writer == nil { @@ -166,3 +148,27 @@ func (p *StreamTransport) Flush() error { } return nil } + +func (p *StreamTransport) ReadByte() (c byte, err error) { + f, ok := p.Reader.(io.ByteReader) + if ok { + return f.ReadByte() + } + return readByte(p.Reader) +} + +func (p *StreamTransport) WriteByte(c byte) error { + f, ok := p.Writer.(io.ByteWriter) + if ok { + return f.WriteByte(c) + } + return writeByte(p.Writer, c) +} + +func (p *StreamTransport) WriteString(s string) (n int, err error) { + f, ok := p.Writer.(stringWriter) + if ok { + return f.WriteString(s) + } + return p.Writer.Write([]byte(s)) +} diff --git a/lib/go/thrift/lowlevel_benchmarks_test.go b/lib/go/thrift/lowlevel_benchmarks_test.go new file mode 100644 index 00000000..a5094ae9 --- /dev/null +++ b/lib/go/thrift/lowlevel_benchmarks_test.go @@ -0,0 +1,396 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bytes" + "testing" +) + +var binaryProtoF = NewTBinaryProtocolFactoryDefault() +var compactProtoF = NewTCompactProtocolFactory() + +var buf = bytes.NewBuffer(make([]byte, 0, 1024)) + +var tfv = []TTransportFactory{ + NewTMemoryBufferTransportFactory(1024), + NewStreamTransportFactory(buf, buf, true), + NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024)), +} + +func BenchmarkBinaryBool_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBool(b, p, trans) + } +} + +func BenchmarkBinaryByte_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteByte(b, p, trans) + } +} + +func BenchmarkBinaryI16_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI16(b, p, trans) + } +} + +func BenchmarkBinaryI32_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI32(b, p, trans) + } +} +func BenchmarkBinaryI64_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI64(b, p, trans) + } +} +func BenchmarkBinaryDouble_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteDouble(b, p, trans) + } +} +func BenchmarkBinaryString_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteString(b, p, trans) + } +} +func BenchmarkBinaryBinary_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBinary(b, p, trans) + } +} + +func BenchmarkBinaryBool_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBool(b, p, trans) + } +} + +func BenchmarkBinaryByte_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteByte(b, p, trans) + } +} + +func BenchmarkBinaryI16_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI16(b, p, trans) + } +} + +func BenchmarkBinaryI32_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI32(b, p, trans) + } +} +func BenchmarkBinaryI64_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI64(b, p, trans) + } +} +func BenchmarkBinaryDouble_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteDouble(b, p, trans) + } +} +func BenchmarkBinaryString_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteString(b, p, trans) + } +} +func BenchmarkBinaryBinary_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBinary(b, p, trans) + } +} + +func BenchmarkBinaryBool_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBool(b, p, trans) + } +} + +func BenchmarkBinaryByte_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteByte(b, p, trans) + } +} + +func BenchmarkBinaryI16_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI16(b, p, trans) + } +} + +func BenchmarkBinaryI32_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI32(b, p, trans) + } +} +func BenchmarkBinaryI64_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI64(b, p, trans) + } +} +func BenchmarkBinaryDouble_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteDouble(b, p, trans) + } +} +func BenchmarkBinaryString_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteString(b, p, trans) + } +} +func BenchmarkBinaryBinary_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := binaryProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBinary(b, p, trans) + } +} + +func BenchmarkCompactBool_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBool(b, p, trans) + } +} + +func BenchmarkCompactByte_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteByte(b, p, trans) + } +} + +func BenchmarkCompactI16_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI16(b, p, trans) + } +} + +func BenchmarkCompactI32_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI32(b, p, trans) + } +} +func BenchmarkCompactI64_0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI64(b, p, trans) + } +} +func BenchmarkCompactDouble0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteDouble(b, p, trans) + } +} +func BenchmarkCompactString0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteString(b, p, trans) + } +} +func BenchmarkCompactBinary0(b *testing.B) { + trans := tfv[0].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBinary(b, p, trans) + } +} + +func BenchmarkCompactBool_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBool(b, p, trans) + } +} + +func BenchmarkCompactByte_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteByte(b, p, trans) + } +} + +func BenchmarkCompactI16_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI16(b, p, trans) + } +} + +func BenchmarkCompactI32_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI32(b, p, trans) + } +} +func BenchmarkCompactI64_1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI64(b, p, trans) + } +} +func BenchmarkCompactDouble1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteDouble(b, p, trans) + } +} +func BenchmarkCompactString1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteString(b, p, trans) + } +} +func BenchmarkCompactBinary1(b *testing.B) { + trans := tfv[1].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBinary(b, p, trans) + } +} + +func BenchmarkCompactBool_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBool(b, p, trans) + } +} + +func BenchmarkCompactByte_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteByte(b, p, trans) + } +} + +func BenchmarkCompactI16_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI16(b, p, trans) + } +} + +func BenchmarkCompactI32_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI32(b, p, trans) + } +} +func BenchmarkCompactI64_2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteI64(b, p, trans) + } +} +func BenchmarkCompactDouble2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteDouble(b, p, trans) + } +} +func BenchmarkCompactString2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteString(b, p, trans) + } +} +func BenchmarkCompactBinary2(b *testing.B) { + trans := tfv[2].GetTransport(nil) + p := compactProtoF.GetProtocol(trans) + for i := 0; i < b.N; i++ { + ReadWriteBinary(b, p, trans) + } +} diff --git a/lib/go/thrift/protocol_test.go b/lib/go/thrift/protocol_test.go index d88afedc..67048fe8 100644 --- a/lib/go/thrift/protocol_test.go +++ b/lib/go/thrift/protocol_test.go @@ -183,7 +183,7 @@ func ReadWriteProtocolTest(t *testing.T, protocolFactory TProtocolFactory) { } -func ReadWriteBool(t *testing.T, p TProtocol, trans TTransport) { +func ReadWriteBool(t testing.TB, p TProtocol, trans TTransport) { thetype := TType(BOOL) thelen := len(BOOL_VALUES) err := p.WriteListBegin(thetype, thelen) @@ -229,7 +229,7 @@ func ReadWriteBool(t *testing.T, p TProtocol, trans TTransport) { } } -func ReadWriteByte(t *testing.T, p TProtocol, trans TTransport) { +func ReadWriteByte(t testing.TB, p TProtocol, trans TTransport) { thetype := TType(BYTE) thelen := len(BYTE_VALUES) err := p.WriteListBegin(thetype, thelen) @@ -278,7 +278,7 @@ func ReadWriteByte(t *testing.T, p TProtocol, trans TTransport) { } } -func ReadWriteI16(t *testing.T, p TProtocol, trans TTransport) { +func ReadWriteI16(t testing.TB, p TProtocol, trans TTransport) { thetype := TType(I16) thelen := len(INT16_VALUES) p.WriteListBegin(thetype, thelen) @@ -315,7 +315,7 @@ func ReadWriteI16(t *testing.T, p TProtocol, trans TTransport) { } } -func ReadWriteI32(t *testing.T, p TProtocol, trans TTransport) { +func ReadWriteI32(t testing.TB, p TProtocol, trans TTransport) { thetype := TType(I32) thelen := len(INT32_VALUES) p.WriteListBegin(thetype, thelen) @@ -351,7 +351,7 @@ func ReadWriteI32(t *testing.T, p TProtocol, trans TTransport) { } } -func ReadWriteI64(t *testing.T, p TProtocol, trans TTransport) { +func ReadWriteI64(t testing.TB, p TProtocol, trans TTransport) { thetype := TType(I64) thelen := len(INT64_VALUES) p.WriteListBegin(thetype, thelen) @@ -387,7 +387,7 @@ func ReadWriteI64(t *testing.T, p TProtocol, trans TTransport) { } } -func ReadWriteDouble(t *testing.T, p TProtocol, trans TTransport) { +func ReadWriteDouble(t testing.TB, p TProtocol, trans TTransport) { thetype := TType(DOUBLE) thelen := len(DOUBLE_VALUES) p.WriteListBegin(thetype, thelen) @@ -396,13 +396,9 @@ func ReadWriteDouble(t *testing.T, p TProtocol, trans TTransport) { } p.WriteListEnd() p.Flush() - wrotebuffer := "" - if memtrans, ok := trans.(*TMemoryBuffer); ok { - wrotebuffer = memtrans.String() - } thetype2, thelen2, err := p.ReadListBegin() if err != nil { - t.Errorf("%s: %T %T %q Error reading list: %q, wrote: %v", "ReadWriteDouble", p, trans, err, DOUBLE_VALUES, wrotebuffer) + t.Errorf("%s: %T %T %q Error reading list: %q", "ReadWriteDouble", p, trans, err, DOUBLE_VALUES) } if thetype != thetype2 { t.Errorf("%s: %T %T type %s != type %s", "ReadWriteDouble", p, trans, thetype, thetype2) @@ -429,7 +425,7 @@ func ReadWriteDouble(t *testing.T, p TProtocol, trans TTransport) { } } -func ReadWriteString(t *testing.T, p TProtocol, trans TTransport) { +func ReadWriteString(t testing.TB, p TProtocol, trans TTransport) { thetype := TType(STRING) thelen := len(STRING_VALUES) p.WriteListBegin(thetype, thelen) @@ -465,7 +461,7 @@ func ReadWriteString(t *testing.T, p TProtocol, trans TTransport) { } } -func ReadWriteBinary(t *testing.T, p TProtocol, trans TTransport) { +func ReadWriteBinary(t testing.TB, p TProtocol, trans TTransport) { v := protocol_bdata p.WriteBinary(v) p.Flush() diff --git a/lib/go/thrift/rich_transport.go b/lib/go/thrift/rich_transport.go new file mode 100644 index 00000000..c409ae05 --- /dev/null +++ b/lib/go/thrift/rich_transport.go @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "io" +) + +type RichTransport struct { + TTransport +} + +// Wraps Transport to provide TRichTransport interface +func NewTRichTransport(trans TTransport) *RichTransport { + return &RichTransport{trans} +} + +func (r *RichTransport) ReadByte() (c byte, err error) { + return readByte(r.TTransport) +} + +func (r *RichTransport) WriteByte(c byte) error { + return writeByte(r.TTransport, c) +} + +func (r *RichTransport) WriteString(s string) (n int, err error) { + return r.Write([]byte(s)) +} + +func readByte(r io.Reader) (c byte, err error) { + v := [1]byte{0} + if _, err := r.Read(v[0:1]); err != nil { + return 0, err + } + return v[0], nil +} + +func writeByte(w io.Writer, c byte) error { + v := [1]byte{c} + _, err := w.Write(v[0:1]) + return err +} diff --git a/lib/go/thrift/rich_transport_test.go b/lib/go/thrift/rich_transport_test.go new file mode 100644 index 00000000..32411672 --- /dev/null +++ b/lib/go/thrift/rich_transport_test.go @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package thrift + +import ( + "bytes" + "reflect" + "testing" +) + +func TestEnsureTransportsAreRich(t *testing.T) { + buf := bytes.NewBuffer(make([]byte, 0, 1024)) + + transports := []TTransportFactory{ + NewTMemoryBufferTransportFactory(1024), + NewStreamTransportFactory(buf, buf, true), + NewTFramedTransportFactory(NewTMemoryBufferTransportFactory(1024)), + NewTHttpPostClientTransportFactory("http://127.0.0.1"), + } + for _, tf := range transports { + trans := tf.GetTransport(nil) + _, ok := trans.(TRichTransport) + if !ok { + t.Errorf("Transport %s does not implement TRichTransport interface", reflect.ValueOf(trans)) + } + } +} diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go index 521394cf..ffbfb766 100644 --- a/lib/go/thrift/simple_server.go +++ b/lib/go/thrift/simple_server.go @@ -167,7 +167,7 @@ func (p *TSimpleServer) processRequest(client TTransport) error { } else if err != nil { return err } - if !ok || !inputProtocol.Transport().Peek() { + if !ok { break } } diff --git a/lib/go/thrift/transport.go b/lib/go/thrift/transport.go index 44823dd5..8c0622db 100644 --- a/lib/go/thrift/transport.go +++ b/lib/go/thrift/transport.go @@ -40,7 +40,20 @@ type TTransport interface { // Returns true if the transport is open IsOpen() bool +} + +type stringWriter interface { + WriteString(s string) (n int, err error) +} - // Returns true if there is more data to be read or the remote side is still open - Peek() bool +// This is "enchanced" transport with extra capabilities. You need to use one of these +// to construct protocol. +// Notably, TSocket does not implement this interface, and it is always a mistake to use +// TSocket directly in protocol. +type TRichTransport interface { + io.ReadWriter + io.ByteReader + io.ByteWriter + stringWriter + Flusher }