From e71115be6caa2f3afd6fa092a09fd41c2c355691 Mon Sep 17 00:00:00 2001 From: David Reiss Date: Wed, 6 Oct 2010 17:09:56 +0000 Subject: [PATCH] THRIFT-922. cpp: Templatize binary and compact protocol Convert TBinaryProtocol and TCompactProtocol to template classes, taking the transport class as a template parameter. This allows them to make non-virtual calls when using the template, improving serialization performance. git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@1005136 13f79535-47bb-0310-9956-ffa450edef68 --- lib/cpp/Makefile.am | 4 +- lib/cpp/src/protocol/TBinaryProtocol.cpp | 403 ---------------- lib/cpp/src/protocol/TBinaryProtocol.h | 166 ++++--- lib/cpp/src/protocol/TBinaryProtocol.tcc | 451 ++++++++++++++++++ lib/cpp/src/protocol/TCompactProtocol.h | 67 +-- ...mpactProtocol.cpp => TCompactProtocol.tcc} | 282 +++++++---- lib/cpp/src/protocol/TDebugProtocol.h | 3 + lib/cpp/src/protocol/TJSONProtocol.cpp | 1 + lib/cpp/src/protocol/TJSONProtocol.h | 1 + lib/cpp/src/protocol/TProtocol.h | 2 - lib/cpp/test/Benchmark.cpp | 10 +- lib/cpp/test/Makefile.am | 9 + lib/cpp/test/SpecializationTest.cpp | 108 +++++ test/cpp/src/TestClient.cpp | 5 +- 14 files changed, 889 insertions(+), 623 deletions(-) delete mode 100644 lib/cpp/src/protocol/TBinaryProtocol.cpp create mode 100644 lib/cpp/src/protocol/TBinaryProtocol.tcc rename lib/cpp/src/protocol/{TCompactProtocol.cpp => TCompactProtocol.tcc} (67%) create mode 100644 lib/cpp/test/SpecializationTest.cpp diff --git a/lib/cpp/Makefile.am b/lib/cpp/Makefile.am index 344e3301..c15cbe0f 100644 --- a/lib/cpp/Makefile.am +++ b/lib/cpp/Makefile.am @@ -48,8 +48,6 @@ libthrift_la_SOURCES = src/Thrift.cpp \ src/concurrency/ThreadManager.cpp \ src/concurrency/TimerManager.cpp \ src/concurrency/Util.cpp \ - src/protocol/TBinaryProtocol.cpp \ - src/protocol/TCompactProtocol.cpp \ src/protocol/TDebugProtocol.cpp \ src/protocol/TDenseProtocol.cpp \ src/protocol/TJSONProtocol.cpp \ @@ -108,7 +106,9 @@ include_concurrency_HEADERS = \ include_protocoldir = $(include_thriftdir)/protocol include_protocol_HEADERS = \ src/protocol/TBinaryProtocol.h \ + src/protocol/TBinaryProtocol.tcc \ src/protocol/TCompactProtocol.h \ + src/protocol/TCompactProtocol.tcc \ src/protocol/TDenseProtocol.h \ src/protocol/TDebugProtocol.h \ src/protocol/TBase64Utils.h \ diff --git a/lib/cpp/src/protocol/TBinaryProtocol.cpp b/lib/cpp/src/protocol/TBinaryProtocol.cpp deleted file mode 100644 index 39c189ca..00000000 --- a/lib/cpp/src/protocol/TBinaryProtocol.cpp +++ /dev/null @@ -1,403 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include "TBinaryProtocol.h" - -#include - -using std::string; - -namespace apache { namespace thrift { namespace protocol { - -uint32_t TBinaryProtocol::writeMessageBegin(const std::string& name, - const TMessageType messageType, - const int32_t seqid) { - if (strict_write_) { - int32_t version = (VERSION_1) | ((int32_t)messageType); - uint32_t wsize = 0; - wsize += writeI32(version); - wsize += writeString(name); - wsize += writeI32(seqid); - return wsize; - } else { - uint32_t wsize = 0; - wsize += writeString(name); - wsize += writeByte((int8_t)messageType); - wsize += writeI32(seqid); - return wsize; - } -} - -uint32_t TBinaryProtocol::writeMessageEnd() { - return 0; -} - -uint32_t TBinaryProtocol::writeStructBegin(const char* name) { - return 0; -} - -uint32_t TBinaryProtocol::writeStructEnd() { - return 0; -} - -uint32_t TBinaryProtocol::writeFieldBegin(const char* name, - const TType fieldType, - const int16_t fieldId) { - uint32_t wsize = 0; - wsize += writeByte((int8_t)fieldType); - wsize += writeI16(fieldId); - return wsize; -} - -uint32_t TBinaryProtocol::writeFieldEnd() { - return 0; -} - -uint32_t TBinaryProtocol::writeFieldStop() { - return - writeByte((int8_t)T_STOP); -} - -uint32_t TBinaryProtocol::writeMapBegin(const TType keyType, - const TType valType, - const uint32_t size) { - uint32_t wsize = 0; - wsize += writeByte((int8_t)keyType); - wsize += writeByte((int8_t)valType); - wsize += writeI32((int32_t)size); - return wsize; -} - -uint32_t TBinaryProtocol::writeMapEnd() { - return 0; -} - -uint32_t TBinaryProtocol::writeListBegin(const TType elemType, - const uint32_t size) { - uint32_t wsize = 0; - wsize += writeByte((int8_t) elemType); - wsize += writeI32((int32_t)size); - return wsize; -} - -uint32_t TBinaryProtocol::writeListEnd() { - return 0; -} - -uint32_t TBinaryProtocol::writeSetBegin(const TType elemType, - const uint32_t size) { - uint32_t wsize = 0; - wsize += writeByte((int8_t)elemType); - wsize += writeI32((int32_t)size); - return wsize; -} - -uint32_t TBinaryProtocol::writeSetEnd() { - return 0; -} - -uint32_t TBinaryProtocol::writeBool(const bool value) { - uint8_t tmp = value ? 1 : 0; - trans_->write(&tmp, 1); - return 1; -} - -uint32_t TBinaryProtocol::writeByte(const int8_t byte) { - trans_->write((uint8_t*)&byte, 1); - return 1; -} - -uint32_t TBinaryProtocol::writeI16(const int16_t i16) { - int16_t net = (int16_t)htons(i16); - trans_->write((uint8_t*)&net, 2); - return 2; -} - -uint32_t TBinaryProtocol::writeI32(const int32_t i32) { - int32_t net = (int32_t)htonl(i32); - trans_->write((uint8_t*)&net, 4); - return 4; -} - -uint32_t TBinaryProtocol::writeI64(const int64_t i64) { - int64_t net = (int64_t)htonll(i64); - trans_->write((uint8_t*)&net, 8); - return 8; -} - -uint32_t TBinaryProtocol::writeDouble(const double dub) { - BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t)); - BOOST_STATIC_ASSERT(std::numeric_limits::is_iec559); - - uint64_t bits = bitwise_cast(dub); - bits = htonll(bits); - trans_->write((uint8_t*)&bits, 8); - return 8; -} - - -uint32_t TBinaryProtocol::writeString(const string& str) { - uint32_t size = str.size(); - uint32_t result = writeI32((int32_t)size); - if (size > 0) { - trans_->write((uint8_t*)str.data(), size); - } - return result + size; -} - -uint32_t TBinaryProtocol::writeBinary(const string& str) { - return TBinaryProtocol::writeString(str); -} - -/** - * Reading functions - */ - -uint32_t TBinaryProtocol::readMessageBegin(std::string& name, - TMessageType& messageType, - int32_t& seqid) { - uint32_t result = 0; - int32_t sz; - result += readI32(sz); - - if (sz < 0) { - // Check for correct version number - int32_t version = sz & VERSION_MASK; - if (version != VERSION_1) { - throw TProtocolException(TProtocolException::BAD_VERSION, "Bad version identifier"); - } - messageType = (TMessageType)(sz & 0x000000ff); - result += readString(name); - result += readI32(seqid); - } else { - if (strict_read_) { - throw TProtocolException(TProtocolException::BAD_VERSION, "No version identifier... old protocol client in strict mode?"); - } else { - // Handle pre-versioned input - int8_t type; - result += readStringBody(name, sz); - result += readByte(type); - messageType = (TMessageType)type; - result += readI32(seqid); - } - } - return result; -} - -uint32_t TBinaryProtocol::readMessageEnd() { - return 0; -} - -uint32_t TBinaryProtocol::readStructBegin(string& name) { - name = ""; - return 0; -} - -uint32_t TBinaryProtocol::readStructEnd() { - return 0; -} - -uint32_t TBinaryProtocol::readFieldBegin(string& name, - TType& fieldType, - int16_t& fieldId) { - uint32_t result = 0; - int8_t type; - result += readByte(type); - fieldType = (TType)type; - if (fieldType == T_STOP) { - fieldId = 0; - return result; - } - result += readI16(fieldId); - return result; -} - -uint32_t TBinaryProtocol::readFieldEnd() { - return 0; -} - -uint32_t TBinaryProtocol::readMapBegin(TType& keyType, - TType& valType, - uint32_t& size) { - int8_t k, v; - uint32_t result = 0; - int32_t sizei; - result += readByte(k); - keyType = (TType)k; - result += readByte(v); - valType = (TType)v; - result += readI32(sizei); - if (sizei < 0) { - throw TProtocolException(TProtocolException::NEGATIVE_SIZE); - } else if (container_limit_ && sizei > container_limit_) { - throw TProtocolException(TProtocolException::SIZE_LIMIT); - } - size = (uint32_t)sizei; - return result; -} - -uint32_t TBinaryProtocol::readMapEnd() { - return 0; -} - -uint32_t TBinaryProtocol::readListBegin(TType& elemType, - uint32_t& size) { - int8_t e; - uint32_t result = 0; - int32_t sizei; - result += readByte(e); - elemType = (TType)e; - result += readI32(sizei); - if (sizei < 0) { - throw TProtocolException(TProtocolException::NEGATIVE_SIZE); - } else if (container_limit_ && sizei > container_limit_) { - throw TProtocolException(TProtocolException::SIZE_LIMIT); - } - size = (uint32_t)sizei; - return result; -} - -uint32_t TBinaryProtocol::readListEnd() { - return 0; -} - -uint32_t TBinaryProtocol::readSetBegin(TType& elemType, - uint32_t& size) { - int8_t e; - uint32_t result = 0; - int32_t sizei; - result += readByte(e); - elemType = (TType)e; - result += readI32(sizei); - if (sizei < 0) { - throw TProtocolException(TProtocolException::NEGATIVE_SIZE); - } else if (container_limit_ && sizei > container_limit_) { - throw TProtocolException(TProtocolException::SIZE_LIMIT); - } - size = (uint32_t)sizei; - return result; -} - -uint32_t TBinaryProtocol::readSetEnd() { - return 0; -} - -uint32_t TBinaryProtocol::readBool(bool& value) { - uint8_t b[1]; - trans_->readAll(b, 1); - value = *(int8_t*)b != 0; - return 1; -} - -uint32_t TBinaryProtocol::readByte(int8_t& byte) { - uint8_t b[1]; - trans_->readAll(b, 1); - byte = *(int8_t*)b; - return 1; -} - -uint32_t TBinaryProtocol::readI16(int16_t& i16) { - uint8_t b[2]; - trans_->readAll(b, 2); - i16 = *(int16_t*)b; - i16 = (int16_t)ntohs(i16); - return 2; -} - -uint32_t TBinaryProtocol::readI32(int32_t& i32) { - uint8_t b[4]; - trans_->readAll(b, 4); - i32 = *(int32_t*)b; - i32 = (int32_t)ntohl(i32); - return 4; -} - -uint32_t TBinaryProtocol::readI64(int64_t& i64) { - uint8_t b[8]; - trans_->readAll(b, 8); - i64 = *(int64_t*)b; - i64 = (int64_t)ntohll(i64); - return 8; -} - -uint32_t TBinaryProtocol::readDouble(double& dub) { - BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t)); - BOOST_STATIC_ASSERT(std::numeric_limits::is_iec559); - - uint64_t bits; - uint8_t b[8]; - trans_->readAll(b, 8); - bits = *(uint64_t*)b; - bits = ntohll(bits); - dub = bitwise_cast(bits); - return 8; -} - -uint32_t TBinaryProtocol::readString(string& str) { - uint32_t result; - int32_t size; - result = readI32(size); - return result + readStringBody(str, size); -} - -uint32_t TBinaryProtocol::readBinary(string& str) { - return TBinaryProtocol::readString(str); -} - -uint32_t TBinaryProtocol::readStringBody(string& str, int32_t size) { - uint32_t result = 0; - - // Catch error cases - if (size < 0) { - throw TProtocolException(TProtocolException::NEGATIVE_SIZE); - } - if (string_limit_ > 0 && size > string_limit_) { - throw TProtocolException(TProtocolException::SIZE_LIMIT); - } - - // Catch empty string case - if (size == 0) { - str = ""; - return result; - } - - // Try to borrow first - const uint8_t* borrow_buf; - uint32_t got = size; - if ((borrow_buf = trans_->borrow(NULL, &got))) { - str.assign((const char*)borrow_buf, size); - trans_->consume(size); - return size; - } - - // Use the heap here to prevent stack overflow for v. large strings - if (size > string_buf_size_ || string_buf_ == NULL) { - void* new_string_buf = std::realloc(string_buf_, (uint32_t)size); - if (new_string_buf == NULL) { - throw TProtocolException(TProtocolException::UNKNOWN, "Out of memory in TBinaryProtocol::readString"); - } - string_buf_ = (uint8_t*)new_string_buf; - string_buf_size_ = size; - } - trans_->readAll(string_buf_, size); - str = string((char*)string_buf_, size); - return (uint32_t)size; -} - -}}} // apache::thrift::protocol diff --git a/lib/cpp/src/protocol/TBinaryProtocol.h b/lib/cpp/src/protocol/TBinaryProtocol.h index 45c58427..ca45294e 100644 --- a/lib/cpp/src/protocol/TBinaryProtocol.h +++ b/lib/cpp/src/protocol/TBinaryProtocol.h @@ -32,15 +32,18 @@ namespace apache { namespace thrift { namespace protocol { * binary format, essentially just spitting out the raw bytes. * */ -class TBinaryProtocol : public TVirtualProtocol { +template +class TBinaryProtocolT + : public TVirtualProtocol< TBinaryProtocolT > { protected: static const int32_t VERSION_MASK = 0xffff0000; static const int32_t VERSION_1 = 0x80010000; // VERSION_2 (0x80020000) is taken by TDenseProtocol. public: - TBinaryProtocol(boost::shared_ptr trans) : - TVirtualProtocol(trans), + TBinaryProtocolT(boost::shared_ptr trans) : + TVirtualProtocol< TBinaryProtocolT >(trans), + trans_(trans.get()), string_limit_(0), container_limit_(0), strict_read_(false), @@ -48,12 +51,13 @@ class TBinaryProtocol : public TVirtualProtocol { string_buf_(NULL), string_buf_size_(0) {} - TBinaryProtocol(boost::shared_ptr trans, - int32_t string_limit, - int32_t container_limit, - bool strict_read, - bool strict_write) : - TVirtualProtocol(trans), + TBinaryProtocolT(boost::shared_ptr trans, + int32_t string_limit, + int32_t container_limit, + bool strict_read, + bool strict_write) : + TVirtualProtocol< TBinaryProtocolT >(trans), + trans_(trans.get()), string_limit_(string_limit), container_limit_(container_limit), strict_read_(strict_read), @@ -61,7 +65,7 @@ class TBinaryProtocol : public TVirtualProtocol { string_buf_(NULL), string_buf_size_(0) {} - ~TBinaryProtocol() { + ~TBinaryProtocolT() { if (string_buf_ != NULL) { std::free(string_buf_); string_buf_size_ = 0; @@ -85,113 +89,111 @@ class TBinaryProtocol : public TVirtualProtocol { * Writing functions. */ - uint32_t writeMessageBegin(const std::string& name, - const TMessageType messageType, - const int32_t seqid); + /*ol*/ uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid); - uint32_t writeMessageEnd(); + /*ol*/ uint32_t writeMessageEnd(); - uint32_t writeStructBegin(const char* name); + inline uint32_t writeStructBegin(const char* name); - uint32_t writeStructEnd(); + inline uint32_t writeStructEnd(); - uint32_t writeFieldBegin(const char* name, - const TType fieldType, - const int16_t fieldId); + inline uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId); - uint32_t writeFieldEnd(); + inline uint32_t writeFieldEnd(); - uint32_t writeFieldStop(); + inline uint32_t writeFieldStop(); - uint32_t writeMapBegin(const TType keyType, - const TType valType, - const uint32_t size); + inline uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size); - uint32_t writeMapEnd(); + inline uint32_t writeMapEnd(); - uint32_t writeListBegin(const TType elemType, - const uint32_t size); + inline uint32_t writeListBegin(const TType elemType, const uint32_t size); - uint32_t writeListEnd(); + inline uint32_t writeListEnd(); - uint32_t writeSetBegin(const TType elemType, - const uint32_t size); + inline uint32_t writeSetBegin(const TType elemType, const uint32_t size); - uint32_t writeSetEnd(); + inline uint32_t writeSetEnd(); - uint32_t writeBool(const bool value); + inline uint32_t writeBool(const bool value); - uint32_t writeByte(const int8_t byte); + inline uint32_t writeByte(const int8_t byte); - uint32_t writeI16(const int16_t i16); + inline uint32_t writeI16(const int16_t i16); - uint32_t writeI32(const int32_t i32); + inline uint32_t writeI32(const int32_t i32); - uint32_t writeI64(const int64_t i64); + inline uint32_t writeI64(const int64_t i64); - uint32_t writeDouble(const double dub); + inline uint32_t writeDouble(const double dub); - uint32_t writeString(const std::string& str); + inline uint32_t writeString(const std::string& str); - uint32_t writeBinary(const std::string& str); + inline uint32_t writeBinary(const std::string& str); /** * Reading functions */ - uint32_t readMessageBegin(std::string& name, - TMessageType& messageType, - int32_t& seqid); + /*ol*/ uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid); - uint32_t readMessageEnd(); + /*ol*/ uint32_t readMessageEnd(); - uint32_t readStructBegin(std::string& name); + inline uint32_t readStructBegin(std::string& name); - uint32_t readStructEnd(); + inline uint32_t readStructEnd(); - uint32_t readFieldBegin(std::string& name, - TType& fieldType, - int16_t& fieldId); + inline uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId); - uint32_t readFieldEnd(); + inline uint32_t readFieldEnd(); - uint32_t readMapBegin(TType& keyType, - TType& valType, - uint32_t& size); + inline uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size); - uint32_t readMapEnd(); + inline uint32_t readMapEnd(); - uint32_t readListBegin(TType& elemType, - uint32_t& size); + inline uint32_t readListBegin(TType& elemType, uint32_t& size); - uint32_t readListEnd(); + inline uint32_t readListEnd(); - uint32_t readSetBegin(TType& elemType, - uint32_t& size); + inline uint32_t readSetBegin(TType& elemType, uint32_t& size); - uint32_t readSetEnd(); + inline uint32_t readSetEnd(); - uint32_t readBool(bool& value); + inline uint32_t readBool(bool& value); - uint32_t readByte(int8_t& byte); + inline uint32_t readByte(int8_t& byte); - uint32_t readI16(int16_t& i16); + inline uint32_t readI16(int16_t& i16); - uint32_t readI32(int32_t& i32); + inline uint32_t readI32(int32_t& i32); - uint32_t readI64(int64_t& i64); + inline uint32_t readI64(int64_t& i64); - uint32_t readDouble(double& dub); + inline uint32_t readDouble(double& dub); - uint32_t readString(std::string& str); + inline uint32_t readString(std::string& str); - uint32_t readBinary(std::string& str); + inline uint32_t readBinary(std::string& str); protected: uint32_t readStringBody(std::string& str, int32_t sz); + Transport_* trans_; + int32_t string_limit_; int32_t container_limit_; @@ -206,24 +208,28 @@ class TBinaryProtocol : public TVirtualProtocol { }; +typedef TBinaryProtocolT TBinaryProtocol; + /** * Constructs binary protocol handlers */ -class TBinaryProtocolFactory : public TProtocolFactory { +template +class TBinaryProtocolFactoryT : public TProtocolFactory { public: - TBinaryProtocolFactory() : + TBinaryProtocolFactoryT() : string_limit_(0), container_limit_(0), strict_read_(false), strict_write_(true) {} - TBinaryProtocolFactory(int32_t string_limit, int32_t container_limit, bool strict_read, bool strict_write) : + TBinaryProtocolFactoryT(int32_t string_limit, int32_t container_limit, + bool strict_read, bool strict_write) : string_limit_(string_limit), container_limit_(container_limit), strict_read_(strict_read), strict_write_(strict_write) {} - virtual ~TBinaryProtocolFactory() {} + virtual ~TBinaryProtocolFactoryT() {} void setStringSizeLimit(int32_t string_limit) { string_limit_ = string_limit; @@ -239,7 +245,19 @@ class TBinaryProtocolFactory : public TProtocolFactory { } boost::shared_ptr getProtocol(boost::shared_ptr trans) { - return boost::shared_ptr(new TBinaryProtocol(trans, string_limit_, container_limit_, strict_read_, strict_write_)); + boost::shared_ptr specific_trans = + boost::dynamic_pointer_cast(trans); + TProtocol* prot; + if (specific_trans) { + prot = new TBinaryProtocolT(specific_trans, string_limit_, + container_limit_, strict_read_, + strict_write_); + } else { + prot = new TBinaryProtocol(trans, string_limit_, container_limit_, + strict_read_, strict_write_); + } + + return boost::shared_ptr(prot); } private: @@ -250,6 +268,10 @@ class TBinaryProtocolFactory : public TProtocolFactory { }; +typedef TBinaryProtocolFactoryT TBinaryProtocolFactory; + }}} // apache::thrift::protocol +#include "TBinaryProtocol.tcc" + #endif // #ifndef _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_ diff --git a/lib/cpp/src/protocol/TBinaryProtocol.tcc b/lib/cpp/src/protocol/TBinaryProtocol.tcc new file mode 100644 index 00000000..1433a4f0 --- /dev/null +++ b/lib/cpp/src/protocol/TBinaryProtocol.tcc @@ -0,0 +1,451 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TBINARYPROTOCOL_TCC_ +#define _THRIFT_PROTOCOL_TBINARYPROTOCOL_TCC_ 1 + +#include "TBinaryProtocol.h" + +#include + + +namespace apache { namespace thrift { namespace protocol { + +template +uint32_t TBinaryProtocolT::writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) { + if (this->strict_write_) { + int32_t version = (VERSION_1) | ((int32_t)messageType); + uint32_t wsize = 0; + wsize += writeI32(version); + wsize += writeString(name); + wsize += writeI32(seqid); + return wsize; + } else { + uint32_t wsize = 0; + wsize += writeString(name); + wsize += writeByte((int8_t)messageType); + wsize += writeI32(seqid); + return wsize; + } +} + +template +uint32_t TBinaryProtocolT::writeMessageEnd() { + return 0; +} + +template +uint32_t TBinaryProtocolT::writeStructBegin(const char* name) { + return 0; +} + +template +uint32_t TBinaryProtocolT::writeStructEnd() { + return 0; +} + +template +uint32_t TBinaryProtocolT::writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) { + uint32_t wsize = 0; + wsize += writeByte((int8_t)fieldType); + wsize += writeI16(fieldId); + return wsize; +} + +template +uint32_t TBinaryProtocolT::writeFieldEnd() { + return 0; +} + +template +uint32_t TBinaryProtocolT::writeFieldStop() { + return + writeByte((int8_t)T_STOP); +} + +template +uint32_t TBinaryProtocolT::writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) { + uint32_t wsize = 0; + wsize += writeByte((int8_t)keyType); + wsize += writeByte((int8_t)valType); + wsize += writeI32((int32_t)size); + return wsize; +} + +template +uint32_t TBinaryProtocolT::writeMapEnd() { + return 0; +} + +template +uint32_t TBinaryProtocolT::writeListBegin(const TType elemType, + const uint32_t size) { + uint32_t wsize = 0; + wsize += writeByte((int8_t) elemType); + wsize += writeI32((int32_t)size); + return wsize; +} + +template +uint32_t TBinaryProtocolT::writeListEnd() { + return 0; +} + +template +uint32_t TBinaryProtocolT::writeSetBegin(const TType elemType, + const uint32_t size) { + uint32_t wsize = 0; + wsize += writeByte((int8_t)elemType); + wsize += writeI32((int32_t)size); + return wsize; +} + +template +uint32_t TBinaryProtocolT::writeSetEnd() { + return 0; +} + +template +uint32_t TBinaryProtocolT::writeBool(const bool value) { + uint8_t tmp = value ? 1 : 0; + this->trans_->write(&tmp, 1); + return 1; +} + +template +uint32_t TBinaryProtocolT::writeByte(const int8_t byte) { + this->trans_->write((uint8_t*)&byte, 1); + return 1; +} + +template +uint32_t TBinaryProtocolT::writeI16(const int16_t i16) { + int16_t net = (int16_t)htons(i16); + this->trans_->write((uint8_t*)&net, 2); + return 2; +} + +template +uint32_t TBinaryProtocolT::writeI32(const int32_t i32) { + int32_t net = (int32_t)htonl(i32); + this->trans_->write((uint8_t*)&net, 4); + return 4; +} + +template +uint32_t TBinaryProtocolT::writeI64(const int64_t i64) { + int64_t net = (int64_t)htonll(i64); + this->trans_->write((uint8_t*)&net, 8); + return 8; +} + +template +uint32_t TBinaryProtocolT::writeDouble(const double dub) { + BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t)); + BOOST_STATIC_ASSERT(std::numeric_limits::is_iec559); + + uint64_t bits = bitwise_cast(dub); + bits = htonll(bits); + this->trans_->write((uint8_t*)&bits, 8); + return 8; +} + + +template +uint32_t TBinaryProtocolT::writeString(const std::string& str) { + uint32_t size = str.size(); + uint32_t result = writeI32((int32_t)size); + if (size > 0) { + this->trans_->write((uint8_t*)str.data(), size); + } + return result + size; +} + +template +uint32_t TBinaryProtocolT::writeBinary(const std::string& str) { + return TBinaryProtocolT::writeString(str); +} + +/** + * Reading functions + */ + +template +uint32_t TBinaryProtocolT::readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) { + uint32_t result = 0; + int32_t sz; + result += readI32(sz); + + if (sz < 0) { + // Check for correct version number + int32_t version = sz & VERSION_MASK; + if (version != VERSION_1) { + throw TProtocolException(TProtocolException::BAD_VERSION, "Bad version identifier"); + } + messageType = (TMessageType)(sz & 0x000000ff); + result += readString(name); + result += readI32(seqid); + } else { + if (this->strict_read_) { + throw TProtocolException(TProtocolException::BAD_VERSION, "No version identifier... old protocol client in strict mode?"); + } else { + // Handle pre-versioned input + int8_t type; + result += readStringBody(name, sz); + result += readByte(type); + messageType = (TMessageType)type; + result += readI32(seqid); + } + } + return result; +} + +template +uint32_t TBinaryProtocolT::readMessageEnd() { + return 0; +} + +template +uint32_t TBinaryProtocolT::readStructBegin(std::string& name) { + name = ""; + return 0; +} + +template +uint32_t TBinaryProtocolT::readStructEnd() { + return 0; +} + +template +uint32_t TBinaryProtocolT::readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId) { + uint32_t result = 0; + int8_t type; + result += readByte(type); + fieldType = (TType)type; + if (fieldType == T_STOP) { + fieldId = 0; + return result; + } + result += readI16(fieldId); + return result; +} + +template +uint32_t TBinaryProtocolT::readFieldEnd() { + return 0; +} + +template +uint32_t TBinaryProtocolT::readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) { + int8_t k, v; + uint32_t result = 0; + int32_t sizei; + result += readByte(k); + keyType = (TType)k; + result += readByte(v); + valType = (TType)v; + result += readI32(sizei); + if (sizei < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (this->container_limit_ && sizei > this->container_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + return result; +} + +template +uint32_t TBinaryProtocolT::readMapEnd() { + return 0; +} + +template +uint32_t TBinaryProtocolT::readListBegin(TType& elemType, + uint32_t& size) { + int8_t e; + uint32_t result = 0; + int32_t sizei; + result += readByte(e); + elemType = (TType)e; + result += readI32(sizei); + if (sizei < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (this->container_limit_ && sizei > this->container_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + return result; +} + +template +uint32_t TBinaryProtocolT::readListEnd() { + return 0; +} + +template +uint32_t TBinaryProtocolT::readSetBegin(TType& elemType, + uint32_t& size) { + int8_t e; + uint32_t result = 0; + int32_t sizei; + result += readByte(e); + elemType = (TType)e; + result += readI32(sizei); + if (sizei < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (this->container_limit_ && sizei > this->container_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + return result; +} + +template +uint32_t TBinaryProtocolT::readSetEnd() { + return 0; +} + +template +uint32_t TBinaryProtocolT::readBool(bool& value) { + uint8_t b[1]; + this->trans_->readAll(b, 1); + value = *(int8_t*)b != 0; + return 1; +} + +template +uint32_t TBinaryProtocolT::readByte(int8_t& byte) { + uint8_t b[1]; + this->trans_->readAll(b, 1); + byte = *(int8_t*)b; + return 1; +} + +template +uint32_t TBinaryProtocolT::readI16(int16_t& i16) { + uint8_t b[2]; + this->trans_->readAll(b, 2); + i16 = *(int16_t*)b; + i16 = (int16_t)ntohs(i16); + return 2; +} + +template +uint32_t TBinaryProtocolT::readI32(int32_t& i32) { + uint8_t b[4]; + this->trans_->readAll(b, 4); + i32 = *(int32_t*)b; + i32 = (int32_t)ntohl(i32); + return 4; +} + +template +uint32_t TBinaryProtocolT::readI64(int64_t& i64) { + uint8_t b[8]; + this->trans_->readAll(b, 8); + i64 = *(int64_t*)b; + i64 = (int64_t)ntohll(i64); + return 8; +} + +template +uint32_t TBinaryProtocolT::readDouble(double& dub) { + BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t)); + BOOST_STATIC_ASSERT(std::numeric_limits::is_iec559); + + uint64_t bits; + uint8_t b[8]; + this->trans_->readAll(b, 8); + bits = *(uint64_t*)b; + bits = ntohll(bits); + dub = bitwise_cast(bits); + return 8; +} + +template +uint32_t TBinaryProtocolT::readString(std::string& str) { + uint32_t result; + int32_t size; + result = readI32(size); + return result + readStringBody(str, size); +} + +template +uint32_t TBinaryProtocolT::readBinary(std::string& str) { + return TBinaryProtocolT::readString(str); +} + +template +uint32_t TBinaryProtocolT::readStringBody(std::string& str, + int32_t size) { + uint32_t result = 0; + + // Catch error cases + if (size < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } + if (this->string_limit_ > 0 && size > this->string_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + + // Catch empty string case + if (size == 0) { + str = ""; + return result; + } + + // Try to borrow first + const uint8_t* borrow_buf; + uint32_t got = size; + if ((borrow_buf = this->trans_->borrow(NULL, &got))) { + str.assign((const char*)borrow_buf, size); + this->trans_->consume(size); + return size; + } + + // Use the heap here to prevent stack overflow for v. large strings + if (size > this->string_buf_size_ || this->string_buf_ == NULL) { + void* new_string_buf = std::realloc(this->string_buf_, (uint32_t)size); + if (new_string_buf == NULL) { + throw TProtocolException(TProtocolException::UNKNOWN, + "Out of memory in TBinaryProtocolT::readString"); + } + this->string_buf_ = (uint8_t*)new_string_buf; + this->string_buf_size_ = size; + } + this->trans_->readAll(this->string_buf_, size); + str = std::string((char*)this->string_buf_, size); + return (uint32_t)size; +} + +}}} // apache::thrift::protocol + +#endif // #ifndef _THRIFT_PROTOCOL_TBINARYPROTOCOL_TCC_ diff --git a/lib/cpp/src/protocol/TCompactProtocol.h b/lib/cpp/src/protocol/TCompactProtocol.h index 77c44544..2150cde9 100644 --- a/lib/cpp/src/protocol/TCompactProtocol.h +++ b/lib/cpp/src/protocol/TCompactProtocol.h @@ -30,7 +30,9 @@ namespace apache { namespace thrift { namespace protocol { /** * C++ Implementation of the Compact Protocol as described in THRIFT-110 */ -class TCompactProtocol : public TVirtualProtocol { +template +class TCompactProtocolT + : public TVirtualProtocol< TCompactProtocolT > { protected: static const int8_t PROTOCOL_ID = 0x82; @@ -39,6 +41,8 @@ class TCompactProtocol : public TVirtualProtocol { static const int8_t TYPE_MASK = 0xE0; // 1110 0000 static const int32_t TYPE_SHIFT_AMOUNT = 5; + Transport_* trans_; + /** * (Writing) If we encounter a boolean field begin, save the TField here * so it can have the value incorporated. @@ -66,27 +70,10 @@ class TCompactProtocol : public TVirtualProtocol { std::stack lastField_; int16_t lastFieldId_; - enum Types { - CT_STOP = 0x00, - CT_BOOLEAN_TRUE = 0x01, - CT_BOOLEAN_FALSE = 0x02, - CT_BYTE = 0x03, - CT_I16 = 0x04, - CT_I32 = 0x05, - CT_I64 = 0x06, - CT_DOUBLE = 0x07, - CT_BINARY = 0x08, - CT_LIST = 0x09, - CT_SET = 0x0A, - CT_MAP = 0x0B, - CT_STRUCT = 0x0C, - }; - - static const int8_t TTypeToCType[16]; - public: - TCompactProtocol(boost::shared_ptr trans) : - TVirtualProtocol(trans), + TCompactProtocolT(boost::shared_ptr trans) : + TVirtualProtocol< TCompactProtocolT >(trans), + trans_(trans.get()), lastFieldId_(0), string_limit_(0), string_buf_(NULL), @@ -96,10 +83,11 @@ class TCompactProtocol : public TVirtualProtocol { boolValue_.hasBoolValue = false; } - TCompactProtocol(boost::shared_ptr trans, - int32_t string_limit, - int32_t container_limit) : - TVirtualProtocol(trans), + TCompactProtocolT(boost::shared_ptr trans, + int32_t string_limit, + int32_t container_limit) : + TVirtualProtocol< TCompactProtocolT >(trans), + trans_(trans.get()), lastFieldId_(0), string_limit_(string_limit), string_buf_(NULL), @@ -109,7 +97,7 @@ class TCompactProtocol : public TVirtualProtocol { boolValue_.hasBoolValue = false; } - ~TCompactProtocol() { + ~TCompactProtocolT() { free(string_buf_); } @@ -244,20 +232,23 @@ class TCompactProtocol : public TVirtualProtocol { int32_t container_limit_; }; +typedef TCompactProtocolT TCompactProtocol; + /** * Constructs compact protocol handlers */ -class TCompactProtocolFactory : public TProtocolFactory { +template +class TCompactProtocolFactoryT : public TProtocolFactory { public: - TCompactProtocolFactory() : + TCompactProtocolFactoryT() : string_limit_(0), container_limit_(0) {} - TCompactProtocolFactory(int32_t string_limit, int32_t container_limit) : + TCompactProtocolFactoryT(int32_t string_limit, int32_t container_limit) : string_limit_(string_limit), container_limit_(container_limit) {} - virtual ~TCompactProtocolFactory() {} + virtual ~TCompactProtocolFactoryT() {} void setStringSizeLimit(int32_t string_limit) { string_limit_ = string_limit; @@ -268,7 +259,17 @@ class TCompactProtocolFactory : public TProtocolFactory { } boost::shared_ptr getProtocol(boost::shared_ptr trans) { - return boost::shared_ptr(new TCompactProtocol(trans, string_limit_, container_limit_)); + boost::shared_ptr specific_trans = + boost::dynamic_pointer_cast(trans); + TProtocol* prot; + if (specific_trans) { + prot = new TCompactProtocolT(specific_trans, string_limit_, + container_limit_); + } else { + prot = new TCompactProtocol(trans, string_limit_, container_limit_); + } + + return boost::shared_ptr(prot); } private: @@ -277,6 +278,10 @@ class TCompactProtocolFactory : public TProtocolFactory { }; +typedef TCompactProtocolFactoryT TCompactProtocolFactory; + }}} // apache::thrift::protocol +#include "TCompactProtocol.tcc" + #endif diff --git a/lib/cpp/src/protocol/TCompactProtocol.cpp b/lib/cpp/src/protocol/TCompactProtocol.tcc similarity index 67% rename from lib/cpp/src/protocol/TCompactProtocol.cpp rename to lib/cpp/src/protocol/TCompactProtocol.tcc index ce2ee54d..24481462 100644 --- a/lib/cpp/src/protocol/TCompactProtocol.cpp +++ b/lib/cpp/src/protocol/TCompactProtocol.tcc @@ -16,10 +16,9 @@ * specific language governing permissions and limitations * under the License. */ +#ifndef _THRIFT_PROTOCOL_TCOMPACTPROTOCOL_TCC_ +#define _THRIFT_PROTOCOL_TCOMPACTPROTOCOL_TCC_ 1 -#include "TCompactProtocol.h" - -#include #include /* @@ -33,7 +32,7 @@ # error "Unable to determine the behavior of a signed right shift" #endif #if SIGNED_RIGHT_SHIFT_IS != ARITHMETIC_RIGHT_SHIFT -# error "TCompactProtocol currenly only works if a signed right shift is arithmetic" +# error "TCompactProtocol currently only works if a signed right shift is arithmetic" #endif #ifdef __GNUC__ @@ -44,29 +43,51 @@ namespace apache { namespace thrift { namespace protocol { -const int8_t TCompactProtocol::TTypeToCType[16] = { - CT_STOP, // T_STOP - 0, // unused - CT_BOOLEAN_TRUE, // T_BOOL - CT_BYTE, // T_BYTE - CT_DOUBLE, // T_DOUBLE - 0, // unused - CT_I16, // T_I16 - 0, // unused - CT_I32, // T_I32 - 0, // unused - CT_I64, // T_I64 - CT_BINARY, // T_STRING - CT_STRUCT, // T_STRUCT - CT_MAP, // T_MAP - CT_SET, // T_SET - CT_LIST, // T_LIST - }; - - -uint32_t TCompactProtocol::writeMessageBegin(const std::string& name, - const TMessageType messageType, - const int32_t seqid) { +namespace detail { namespace compact { + +enum Types { + CT_STOP = 0x00, + CT_BOOLEAN_TRUE = 0x01, + CT_BOOLEAN_FALSE = 0x02, + CT_BYTE = 0x03, + CT_I16 = 0x04, + CT_I32 = 0x05, + CT_I64 = 0x06, + CT_DOUBLE = 0x07, + CT_BINARY = 0x08, + CT_LIST = 0x09, + CT_SET = 0x0A, + CT_MAP = 0x0B, + CT_STRUCT = 0x0C, +}; + +const int8_t TTypeToCType[16] = { + CT_STOP, // T_STOP + 0, // unused + CT_BOOLEAN_TRUE, // T_BOOL + CT_BYTE, // T_BYTE + CT_DOUBLE, // T_DOUBLE + 0, // unused + CT_I16, // T_I16 + 0, // unused + CT_I32, // T_I32 + 0, // unused + CT_I64, // T_I64 + CT_BINARY, // T_STRING + CT_STRUCT, // T_STRUCT + CT_MAP, // T_MAP + CT_SET, // T_SET + CT_LIST, // T_LIST +}; + +}} // end detail::compact namespace + + +template +uint32_t TCompactProtocolT::writeMessageBegin( + const std::string& name, + const TMessageType messageType, + const int32_t seqid) { uint32_t wsize = 0; wsize += writeByte(PROTOCOL_ID); wsize += writeByte((VERSION_N & VERSION_MASK) | (((int32_t)messageType << TYPE_SHIFT_AMOUNT) & TYPE_MASK)); @@ -81,9 +102,10 @@ uint32_t TCompactProtocol::writeMessageBegin(const std::string& name, * then the field id will be encoded in the 4 MSB as a delta. Otherwise, the * field id will follow the type header as a zigzag varint. */ -uint32_t TCompactProtocol::writeFieldBegin(const char* name, - const TType fieldType, - const int16_t fieldId) { +template +uint32_t TCompactProtocolT::writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) { if (fieldType == T_BOOL) { booleanField_.name = name; booleanField_.fieldType = fieldType; @@ -97,7 +119,8 @@ uint32_t TCompactProtocol::writeFieldBegin(const char* name, /** * Write the STOP symbol so we know there are no more fields in this struct. */ -uint32_t TCompactProtocol::writeFieldStop() { +template +uint32_t TCompactProtocolT::writeFieldStop() { return writeByte(T_STOP); } @@ -106,7 +129,8 @@ uint32_t TCompactProtocol::writeFieldStop() { * use it as an opportunity to put special placeholder markers on the field * stack so we can get the field id deltas correct. */ -uint32_t TCompactProtocol::writeStructBegin(const char* name) { +template +uint32_t TCompactProtocolT::writeStructBegin(const char* name) { lastField_.push(lastFieldId_); lastFieldId_ = 0; return 0; @@ -117,7 +141,8 @@ uint32_t TCompactProtocol::writeStructBegin(const char* name) { * this as an opportunity to pop the last field from the current struct off * of the field stack. */ -uint32_t TCompactProtocol::writeStructEnd() { +template +uint32_t TCompactProtocolT::writeStructEnd() { lastFieldId_ = lastField_.top(); lastField_.pop(); return 0; @@ -126,16 +151,18 @@ uint32_t TCompactProtocol::writeStructEnd() { /** * Write a List header. */ -uint32_t TCompactProtocol::writeListBegin(const TType elemType, - const uint32_t size) { +template +uint32_t TCompactProtocolT::writeListBegin(const TType elemType, + const uint32_t size) { return writeCollectionBegin(elemType, size); } /** * Write a set header. */ -uint32_t TCompactProtocol::writeSetBegin(const TType elemType, - const uint32_t size) { +template +uint32_t TCompactProtocolT::writeSetBegin(const TType elemType, + const uint32_t size) { return writeCollectionBegin(elemType, size); } @@ -143,9 +170,10 @@ uint32_t TCompactProtocol::writeSetBegin(const TType elemType, * Write a map header. If the map is empty, omit the key and value type * headers, as we don't need any additional information to skip it. */ -uint32_t TCompactProtocol::writeMapBegin(const TType keyType, - const TType valType, - const uint32_t size) { +template +uint32_t TCompactProtocolT::writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) { uint32_t wsize = 0; if (size == 0) { @@ -163,7 +191,8 @@ uint32_t TCompactProtocol::writeMapBegin(const TType keyType, * right type header is for the value and then write the field header. * Otherwise, write a single byte. */ -uint32_t TCompactProtocol::writeBool(const bool value) { +template +uint32_t TCompactProtocolT::writeBool(const bool value) { uint32_t wsize = 0; if (booleanField_.name != NULL) { @@ -171,16 +200,19 @@ uint32_t TCompactProtocol::writeBool(const bool value) { wsize += writeFieldBeginInternal(booleanField_.name, booleanField_.fieldType, booleanField_.fieldId, - value ? CT_BOOLEAN_TRUE : CT_BOOLEAN_FALSE); + value ? detail::compact::CT_BOOLEAN_TRUE : + detail::compact::CT_BOOLEAN_FALSE); booleanField_.name = NULL; } else { // we're not part of a field, so just write the value - wsize += writeByte(value ? CT_BOOLEAN_TRUE : CT_BOOLEAN_FALSE); + wsize += writeByte(value ? detail::compact::CT_BOOLEAN_TRUE : + detail::compact::CT_BOOLEAN_FALSE); } return wsize; } -uint32_t TCompactProtocol::writeByte(const int8_t byte) { +template +uint32_t TCompactProtocolT::writeByte(const int8_t byte) { trans_->write((uint8_t*)&byte, 1); return 1; } @@ -188,28 +220,32 @@ uint32_t TCompactProtocol::writeByte(const int8_t byte) { /** * Write an i16 as a zigzag varint. */ -uint32_t TCompactProtocol::writeI16(const int16_t i16) { +template +uint32_t TCompactProtocolT::writeI16(const int16_t i16) { return writeVarint32(i32ToZigzag(i16)); } /** * Write an i32 as a zigzag varint. */ -uint32_t TCompactProtocol::writeI32(const int32_t i32) { +template +uint32_t TCompactProtocolT::writeI32(const int32_t i32) { return writeVarint32(i32ToZigzag(i32)); } /** * Write an i64 as a zigzag varint. */ -uint32_t TCompactProtocol::writeI64(const int64_t i64) { +template +uint32_t TCompactProtocolT::writeI64(const int64_t i64) { return writeVarint64(i64ToZigzag(i64)); } /** * Write a double to the wire as 8 bytes. */ -uint32_t TCompactProtocol::writeDouble(const double dub) { +template +uint32_t TCompactProtocolT::writeDouble(const double dub) { BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t)); BOOST_STATIC_ASSERT(std::numeric_limits::is_iec559); @@ -222,11 +258,13 @@ uint32_t TCompactProtocol::writeDouble(const double dub) { /** * Write a string to the wire with a varint size preceeding. */ -uint32_t TCompactProtocol::writeString(const std::string& str) { +template +uint32_t TCompactProtocolT::writeString(const std::string& str) { return writeBinary(str); } -uint32_t TCompactProtocol::writeBinary(const std::string& str) { +template +uint32_t TCompactProtocolT::writeBinary(const std::string& str) { uint32_t ssize = str.size(); uint32_t wsize = writeVarint32(ssize) + ssize; trans_->write((uint8_t*)str.data(), ssize); @@ -242,10 +280,12 @@ uint32_t TCompactProtocol::writeBinary(const std::string& str) { * 'type override' of the type header. This is used specifically in the * boolean field case. */ -int32_t TCompactProtocol::writeFieldBeginInternal(const char* name, - const TType fieldType, - const int16_t fieldId, - int8_t typeOverride) { +template +int32_t TCompactProtocolT::writeFieldBeginInternal( + const char* name, + const TType fieldType, + const int16_t fieldId, + int8_t typeOverride) { uint32_t wsize = 0; // if there's a type override, use that. @@ -269,7 +309,9 @@ int32_t TCompactProtocol::writeFieldBeginInternal(const char* name, * Abstract method for writing the start of lists and sets. List and sets on * the wire differ only by the type indicator. */ -uint32_t TCompactProtocol::writeCollectionBegin(int8_t elemType, int32_t size) { +template +uint32_t TCompactProtocolT::writeCollectionBegin(int8_t elemType, + int32_t size) { uint32_t wsize = 0; if (size <= 14) { wsize += writeByte(size << 4 | getCompactType(elemType)); @@ -283,7 +325,8 @@ uint32_t TCompactProtocol::writeCollectionBegin(int8_t elemType, int32_t size) { /** * Write an i32 as a varint. Results in 1-5 bytes on the wire. */ -uint32_t TCompactProtocol::writeVarint32(uint32_t n) { +template +uint32_t TCompactProtocolT::writeVarint32(uint32_t n) { uint8_t buf[5]; uint32_t wsize = 0; @@ -303,7 +346,8 @@ uint32_t TCompactProtocol::writeVarint32(uint32_t n) { /** * Write an i64 as a varint. Results in 1-10 bytes on the wire. */ -uint32_t TCompactProtocol::writeVarint64(uint64_t n) { +template +uint32_t TCompactProtocolT::writeVarint64(uint64_t n) { uint8_t buf[10]; uint32_t wsize = 0; @@ -324,7 +368,8 @@ uint32_t TCompactProtocol::writeVarint64(uint64_t n) { * Convert l into a zigzag long. This allows negative numbers to be * represented compactly as a varint. */ -uint64_t TCompactProtocol::i64ToZigzag(const int64_t l) { +template +uint64_t TCompactProtocolT::i64ToZigzag(const int64_t l) { return (l << 1) ^ (l >> 63); } @@ -332,15 +377,17 @@ uint64_t TCompactProtocol::i64ToZigzag(const int64_t l) { * Convert n into a zigzag int. This allows negative numbers to be * represented compactly as a varint. */ -uint32_t TCompactProtocol::i32ToZigzag(const int32_t n) { +template +uint32_t TCompactProtocolT::i32ToZigzag(const int32_t n) { return (n << 1) ^ (n >> 31); } /** - * Given a TType value, find the appropriate TCompactProtocol.Type value + * Given a TType value, find the appropriate detail::compact::Types value */ -int8_t TCompactProtocol::getCompactType(int8_t ttype) { - return TTypeToCType[ttype]; +template +int8_t TCompactProtocolT::getCompactType(int8_t ttype) { + return detail::compact::TTypeToCType[ttype]; } // @@ -350,9 +397,11 @@ int8_t TCompactProtocol::getCompactType(int8_t ttype) { /** * Read a message header. */ -uint32_t TCompactProtocol::readMessageBegin(std::string& name, - TMessageType& messageType, - int32_t& seqid) { +template +uint32_t TCompactProtocolT::readMessageBegin( + std::string& name, + TMessageType& messageType, + int32_t& seqid) { uint32_t rsize = 0; int8_t protocolId; int8_t versionAndType; @@ -380,7 +429,8 @@ uint32_t TCompactProtocol::readMessageBegin(std::string& name, * Read a struct begin. There's nothing on the wire for this, but it is our * opportunity to push a new struct begin marker on the field stack. */ -uint32_t TCompactProtocol::readStructBegin(std::string& name) { +template +uint32_t TCompactProtocolT::readStructBegin(std::string& name) { name = ""; lastField_.push(lastFieldId_); lastFieldId_ = 0; @@ -391,7 +441,8 @@ uint32_t TCompactProtocol::readStructBegin(std::string& name) { * Doesn't actually consume any wire data, just removes the last field for * this struct from the field stack. */ -uint32_t TCompactProtocol::readStructEnd() { +template +uint32_t TCompactProtocolT::readStructEnd() { lastFieldId_ = lastField_.top(); lastField_.pop(); return 0; @@ -400,9 +451,10 @@ uint32_t TCompactProtocol::readStructEnd() { /** * Read a field header off the wire. */ -uint32_t TCompactProtocol::readFieldBegin(std::string& name, - TType& fieldType, - int16_t& fieldId) { +template +uint32_t TCompactProtocolT::readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId) { uint32_t rsize = 0; int8_t byte; int8_t type; @@ -428,10 +480,12 @@ uint32_t TCompactProtocol::readFieldBegin(std::string& name, fieldType = getTType(type); // if this happens to be a boolean field, the value is encoded in the type - if (type == CT_BOOLEAN_TRUE || type == CT_BOOLEAN_FALSE) { + if (type == detail::compact::CT_BOOLEAN_TRUE || + type == detail::compact::CT_BOOLEAN_FALSE) { // save the boolean value in a special instance variable. boolValue_.hasBoolValue = true; - boolValue_.boolValue = (type == CT_BOOLEAN_TRUE ? true : false); + boolValue_.boolValue = + (type == detail::compact::CT_BOOLEAN_TRUE ? true : false); } // push the new field onto the field stack so we can keep the deltas going. @@ -444,9 +498,10 @@ uint32_t TCompactProtocol::readFieldBegin(std::string& name, * and value type. This means that 0-length maps will yield TMaps without the * "correct" types. */ -uint32_t TCompactProtocol::readMapBegin(TType& keyType, - TType& valType, - uint32_t& size) { +template +uint32_t TCompactProtocolT::readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) { uint32_t rsize = 0; int8_t kvType = 0; int32_t msize = 0; @@ -474,8 +529,9 @@ uint32_t TCompactProtocol::readMapBegin(TType& keyType, * of the element type header will be 0xF, and a varint will follow with the * true size. */ -uint32_t TCompactProtocol::readListBegin(TType& elemType, - uint32_t& size) { +template +uint32_t TCompactProtocolT::readListBegin(TType& elemType, + uint32_t& size) { int8_t size_and_type; uint32_t rsize = 0; int32_t lsize; @@ -505,8 +561,9 @@ uint32_t TCompactProtocol::readListBegin(TType& elemType, * of the element type header will be 0xF, and a varint will follow with the * true size. */ -uint32_t TCompactProtocol::readSetBegin(TType& elemType, - uint32_t& size) { +template +uint32_t TCompactProtocolT::readSetBegin(TType& elemType, + uint32_t& size) { return readListBegin(elemType, size); } @@ -515,7 +572,8 @@ uint32_t TCompactProtocol::readSetBegin(TType& elemType, * already have been read during readFieldBegin, so we'll just consume the * pre-stored value. Otherwise, read a byte. */ -uint32_t TCompactProtocol::readBool(bool& value) { +template +uint32_t TCompactProtocolT::readBool(bool& value) { if (boolValue_.hasBoolValue == true) { value = boolValue_.boolValue; boolValue_.hasBoolValue = false; @@ -523,7 +581,7 @@ uint32_t TCompactProtocol::readBool(bool& value) { } else { int8_t val; readByte(val); - value = (val == CT_BOOLEAN_TRUE); + value = (val == detail::compact::CT_BOOLEAN_TRUE); return 1; } } @@ -531,7 +589,8 @@ uint32_t TCompactProtocol::readBool(bool& value) { /** * Read a single byte off the wire. Nothing interesting here. */ -uint32_t TCompactProtocol::readByte(int8_t& byte) { +template +uint32_t TCompactProtocolT::readByte(int8_t& byte) { uint8_t b[1]; trans_->readAll(b, 1); byte = *(int8_t*)b; @@ -541,7 +600,8 @@ uint32_t TCompactProtocol::readByte(int8_t& byte) { /** * Read an i16 from the wire as a zigzag varint. */ -uint32_t TCompactProtocol::readI16(int16_t& i16) { +template +uint32_t TCompactProtocolT::readI16(int16_t& i16) { int32_t value; uint32_t rsize = readVarint32(value); i16 = (int16_t)zigzagToI32(value); @@ -551,7 +611,8 @@ uint32_t TCompactProtocol::readI16(int16_t& i16) { /** * Read an i32 from the wire as a zigzag varint. */ -uint32_t TCompactProtocol::readI32(int32_t& i32) { +template +uint32_t TCompactProtocolT::readI32(int32_t& i32) { int32_t value; uint32_t rsize = readVarint32(value); i32 = zigzagToI32(value); @@ -561,7 +622,8 @@ uint32_t TCompactProtocol::readI32(int32_t& i32) { /** * Read an i64 from the wire as a zigzag varint. */ -uint32_t TCompactProtocol::readI64(int64_t& i64) { +template +uint32_t TCompactProtocolT::readI64(int64_t& i64) { int64_t value; uint32_t rsize = readVarint64(value); i64 = zigzagToI64(value); @@ -571,7 +633,8 @@ uint32_t TCompactProtocol::readI64(int64_t& i64) { /** * No magic here - just read a double off the wire. */ -uint32_t TCompactProtocol::readDouble(double& dub) { +template +uint32_t TCompactProtocolT::readDouble(double& dub) { BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t)); BOOST_STATIC_ASSERT(std::numeric_limits::is_iec559); @@ -584,14 +647,16 @@ uint32_t TCompactProtocol::readDouble(double& dub) { return 8; } -uint32_t TCompactProtocol::readString(std::string& str) { +template +uint32_t TCompactProtocolT::readString(std::string& str) { return readBinary(str); } /** * Read a byte[] from the wire. */ -uint32_t TCompactProtocol::readBinary(std::string& str) { +template +uint32_t TCompactProtocolT::readBinary(std::string& str) { int32_t rsize = 0; int32_t size; @@ -629,7 +694,8 @@ uint32_t TCompactProtocol::readBinary(std::string& str) { * Read an i32 from the wire as a varint. The MSB of each byte is set * if there is another byte to follow. This can read up to 5 bytes. */ -uint32_t TCompactProtocol::readVarint32(int32_t& i32) { +template +uint32_t TCompactProtocolT::readVarint32(int32_t& i32) { int64_t val; uint32_t rsize = readVarint64(val); i32 = (int32_t)val; @@ -640,7 +706,8 @@ uint32_t TCompactProtocol::readVarint32(int32_t& i32) { * Read an i64 from the wire as a proper varint. The MSB of each byte is set * if there is another byte to follow. This can read up to 10 bytes. */ -uint32_t TCompactProtocol::readVarint64(int64_t& i64) { +template +uint32_t TCompactProtocolT::readVarint64(int64_t& i64) { uint32_t rsize = 0; uint64_t val = 0; int shift = 0; @@ -689,43 +756,46 @@ uint32_t TCompactProtocol::readVarint64(int64_t& i64) { /** * Convert from zigzag int to int. */ -int32_t TCompactProtocol::zigzagToI32(uint32_t n) { +template +int32_t TCompactProtocolT::zigzagToI32(uint32_t n) { return (n >> 1) ^ -(n & 1); } /** * Convert from zigzag long to long. */ -int64_t TCompactProtocol::zigzagToI64(uint64_t n) { +template +int64_t TCompactProtocolT::zigzagToI64(uint64_t n) { return (n >> 1) ^ -(n & 1); } -TType TCompactProtocol::getTType(int8_t type) { +template +TType TCompactProtocolT::getTType(int8_t type) { switch (type) { case T_STOP: return T_STOP; - case CT_BOOLEAN_FALSE: - case CT_BOOLEAN_TRUE: + case detail::compact::CT_BOOLEAN_FALSE: + case detail::compact::CT_BOOLEAN_TRUE: return T_BOOL; - case CT_BYTE: + case detail::compact::CT_BYTE: return T_BYTE; - case CT_I16: + case detail::compact::CT_I16: return T_I16; - case CT_I32: + case detail::compact::CT_I32: return T_I32; - case CT_I64: + case detail::compact::CT_I64: return T_I64; - case CT_DOUBLE: + case detail::compact::CT_DOUBLE: return T_DOUBLE; - case CT_BINARY: + case detail::compact::CT_BINARY: return T_STRING; - case CT_LIST: + case detail::compact::CT_LIST: return T_LIST; - case CT_SET: + case detail::compact::CT_SET: return T_SET; - case CT_MAP: + case detail::compact::CT_MAP: return T_MAP; - case CT_STRUCT: + case detail::compact::CT_STRUCT: return T_STRUCT; default: throw TException("don't know what type: " + type); @@ -734,3 +804,5 @@ TType TCompactProtocol::getTType(int8_t type) { } }}} // apache::thrift::protocol + +#endif // _THRIFT_PROTOCOL_TCOMPACTPROTOCOL_TCC_ diff --git a/lib/cpp/src/protocol/TDebugProtocol.h b/lib/cpp/src/protocol/TDebugProtocol.h index 1efcbd0a..3f7877c0 100644 --- a/lib/cpp/src/protocol/TDebugProtocol.h +++ b/lib/cpp/src/protocol/TDebugProtocol.h @@ -59,6 +59,7 @@ class TDebugProtocol : public TVirtualProtocol { public: TDebugProtocol(boost::shared_ptr trans) : TVirtualProtocol(trans) + , trans_(trans.get()) , string_limit_(DEFAULT_STRING_LIMIT) , string_prefix_size_(DEFAULT_STRING_PREFIX_SIZE) { @@ -140,6 +141,8 @@ class TDebugProtocol : public TVirtualProtocol { static std::string fieldTypeName(TType type); + TTransport* trans_; + int32_t string_limit_; int32_t string_prefix_size_; diff --git a/lib/cpp/src/protocol/TJSONProtocol.cpp b/lib/cpp/src/protocol/TJSONProtocol.cpp index ed2f518b..9859c0f8 100644 --- a/lib/cpp/src/protocol/TJSONProtocol.cpp +++ b/lib/cpp/src/protocol/TJSONProtocol.cpp @@ -358,6 +358,7 @@ public: TJSONProtocol::TJSONProtocol(boost::shared_ptr ptrans) : TVirtualProtocol(ptrans), + trans_(ptrans.get()), context_(new TJSONContext()), reader_(*ptrans) { } diff --git a/lib/cpp/src/protocol/TJSONProtocol.h b/lib/cpp/src/protocol/TJSONProtocol.h index cd42f5ea..b3a66679 100644 --- a/lib/cpp/src/protocol/TJSONProtocol.h +++ b/lib/cpp/src/protocol/TJSONProtocol.h @@ -291,6 +291,7 @@ class TJSONProtocol : public TVirtualProtocol { }; private: + TTransport* trans_; std::stack > contexts_; boost::shared_ptr context_; diff --git a/lib/cpp/src/protocol/TProtocol.h b/lib/cpp/src/protocol/TProtocol.h index 6bf7e3bb..4b05de20 100644 --- a/lib/cpp/src/protocol/TProtocol.h +++ b/lib/cpp/src/protocol/TProtocol.h @@ -645,11 +645,9 @@ class TProtocol { protected: TProtocol(boost::shared_ptr ptrans): ptrans_(ptrans) { - trans_ = ptrans.get(); } boost::shared_ptr ptrans_; - TTransport* trans_; private: TProtocol() {} diff --git a/lib/cpp/test/Benchmark.cpp b/lib/cpp/test/Benchmark.cpp index 4a0eae96..a9859d87 100644 --- a/lib/cpp/test/Benchmark.cpp +++ b/lib/cpp/test/Benchmark.cpp @@ -19,12 +19,10 @@ #include #include -#include -#include -#include +#include "transport/TBufferTransports.h" +#include "protocol/TBinaryProtocol.h" #include "gen-cpp/DebugProtoTest_types.h" #include -#include #include class Timer { @@ -76,7 +74,7 @@ int main() { for (int i = 0; i < num; i ++) { buf->resetBuffer(); - TBinaryProtocol prot(buf); + TBinaryProtocolT prot(buf); ooe.write(&prot); } cout << "Write: " << num / (1000 * timer.frame()) << " kHz" << endl; @@ -95,7 +93,7 @@ int main() { OneOfEach ooe2; shared_ptr buf2(new TMemoryBuffer(data, datasize)); //buf2->resetBuffer(data, datasize); - TBinaryProtocol prot(buf2); + TBinaryProtocolT prot(buf2); ooe2.read(&prot); //cout << apache::thrift::ThriftDebugString(ooe2) << endl << endl; diff --git a/lib/cpp/test/Makefile.am b/lib/cpp/test/Makefile.am index c478cced..536796f0 100644 --- a/lib/cpp/test/Makefile.am +++ b/lib/cpp/test/Makefile.am @@ -47,6 +47,7 @@ check_PROGRAMS = \ DebugProtoTest \ JSONProtoTest \ OptionalRequiredTest \ + SpecializationTest \ AllProtocolsTest \ UnitTests @@ -114,6 +115,14 @@ OptionalRequiredTest_SOURCES = \ OptionalRequiredTest_LDADD = libtestgencpp.la +# +# SpecializationTest +# +SpecializationTest_SOURCES = \ + SpecializationTest.cpp + +SpecializationTest_LDADD = libtestgencpp.la + # # Common thrift code generation rules diff --git a/lib/cpp/test/SpecializationTest.cpp b/lib/cpp/test/SpecializationTest.cpp new file mode 100644 index 00000000..954585a5 --- /dev/null +++ b/lib/cpp/test/SpecializationTest.cpp @@ -0,0 +1,108 @@ +#include +#include +#include +#include +#include + +using std::cout; +using std::endl; +using namespace thrift::test::debug; +using namespace apache::thrift::transport; +using namespace apache::thrift::protocol; + +typedef TBinaryProtocolT MyProtocol; +//typedef TBinaryProtocolT MyProtocol; + +int main() { + + OneOfEach ooe; + ooe.im_true = true; + ooe.im_false = false; + ooe.a_bite = 0xd6; + ooe.integer16 = 27000; + ooe.integer32 = 1<<24; + ooe.integer64 = (uint64_t)6000 * 1000 * 1000; + ooe.double_precision = M_PI; + ooe.some_characters = "JSON THIS! \"\1"; + ooe.zomg_unicode = "\xd7\n\a\t"; + ooe.base64 = "\1\2\3\255"; + + Nesting n; + n.my_ooe = ooe; + n.my_ooe.integer16 = 16; + n.my_ooe.integer32 = 32; + n.my_ooe.integer64 = 64; + n.my_ooe.double_precision = (std::sqrt(5)+1)/2; + n.my_ooe.some_characters = ":R (me going \"rrrr\")"; + n.my_ooe.zomg_unicode = "\xd3\x80\xe2\x85\xae\xce\x9d\x20" + "\xd0\x9d\xce\xbf\xe2\x85\xbf\xd0\xbe\xc9\xa1\xd0\xb3\xd0\xb0\xcf\x81\xe2\x84\x8e" + "\x20\xce\x91\x74\x74\xce\xb1\xe2\x85\xbd\xce\xba\xc7\x83\xe2\x80\xbc"; + n.my_bonk.type = 31337; + n.my_bonk.message = "I am a bonk... xor!"; + + HolyMoley hm; + + hm.big.push_back(ooe); + hm.big.push_back(n.my_ooe); + hm.big[0].a_bite = 0x22; + hm.big[1].a_bite = 0x33; + + std::vector stage1; + stage1.push_back("and a one"); + stage1.push_back("and a two"); + hm.contain.insert(stage1); + stage1.clear(); + stage1.push_back("then a one, two"); + stage1.push_back("three!"); + stage1.push_back("FOUR!!"); + hm.contain.insert(stage1); + stage1.clear(); + hm.contain.insert(stage1); + + std::vector stage2; + hm.bonks["nothing"] = stage2; + stage2.resize(stage2.size()+1); + stage2.back().type = 1; + stage2.back().message = "Wait."; + stage2.resize(stage2.size()+1); + stage2.back().type = 2; + stage2.back().message = "What?"; + hm.bonks["something"] = stage2; + stage2.clear(); + stage2.resize(stage2.size()+1); + stage2.back().type = 3; + stage2.back().message = "quoth"; + stage2.resize(stage2.size()+1); + stage2.back().type = 4; + stage2.back().message = "the raven"; + stage2.resize(stage2.size()+1); + stage2.back().type = 5; + stage2.back().message = "nevermore"; + hm.bonks["poe"] = stage2; + + boost::shared_ptr buffer(new TMemoryBuffer()); + boost::shared_ptr proto(new MyProtocol(buffer)); + + cout << "Testing ooe" << endl; + + ooe.write(proto.get()); + OneOfEach ooe2; + ooe2.read(proto.get()); + + assert(ooe == ooe2); + + + cout << "Testing hm" << endl; + + hm.write(proto.get()); + HolyMoley hm2; + hm2.read(proto.get()); + + assert(hm == hm2); + + hm2.big[0].a_bite = 0xFF; + + assert(hm != hm2); + + return 0; +} diff --git a/test/cpp/src/TestClient.cpp b/test/cpp/src/TestClient.cpp index 4764a4c3..8ae819a9 100644 --- a/test/cpp/src/TestClient.cpp +++ b/test/cpp/src/TestClient.cpp @@ -75,7 +75,7 @@ int main(int argc, char** argv) { } - shared_ptr transport; + shared_ptr transport; shared_ptr socket(new TSocket(host, port)); @@ -87,7 +87,8 @@ int main(int argc, char** argv) { transport = bufferedSocket; } - shared_ptr protocol(new TBinaryProtocol(transport)); + shared_ptr< TBinaryProtocolT > protocol( + new TBinaryProtocolT(transport)); ThriftTestClient testClient(protocol); uint64_t time_min = 0; -- 2.17.1