From 382fc3043cba33fea1a919e4e6bfeac0cb9c22aa Mon Sep 17 00:00:00 2001 From: David Reiss Date: Sat, 25 Aug 2007 18:01:30 +0000 Subject: [PATCH] Thrift: Native-code Binary Protocol encoder. Summary: Merging a patch from Ben Maurer. This adds a python extension (i.e., a C module) that encodes Python thrift structs into the standard binary protocol much faster than our generated Python code. Also added by-value equality comparison to thrift structs (to help with testing). Cleaned up some trailing whitespace too. Reviewed By: mcslee, dreiss Test Plan: Recompiled Thrift. Thrifted a bunch of IDLs and compared the generated Python output. Looked at the extension module a lot. test/FastBinaryTest.py Revert Plan: ok git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@665224 13f79535-47bb-0310-9956-ffa450edef68 --- compiler/cpp/src/generate/t_py_generator.cc | 250 +++- compiler/cpp/src/generate/t_py_generator.h | 11 +- lib/py/setup.py | 7 +- lib/py/src/protocol/TBinaryProtocol.py | 34 +- lib/py/src/protocol/__init__.py | 2 +- lib/py/src/protocol/fastbinary.c | 1146 +++++++++++++++++++ lib/py/src/transport/TTransport.py | 49 +- test/DebugProtoTest.thrift | 27 + test/FastbinaryTest.py | 190 +++ 9 files changed, 1651 insertions(+), 65 deletions(-) create mode 100644 lib/py/src/protocol/fastbinary.c create mode 100755 test/FastbinaryTest.py diff --git a/compiler/cpp/src/generate/t_py_generator.cc b/compiler/cpp/src/generate/t_py_generator.cc index aeabeefa..5ed25674 100644 --- a/compiler/cpp/src/generate/t_py_generator.cc +++ b/compiler/cpp/src/generate/t_py_generator.cc @@ -8,6 +8,7 @@ #include #include #include +#include #include "t_py_generator.h" using namespace std; @@ -48,7 +49,10 @@ void t_py_generator::init_generator() { f_types_ << py_autogen_comment() << endl << py_imports() << endl << - render_includes() << endl; + render_includes() << endl << + "from thrift.transport import TTransport" << endl << + "from thrift.protocol import fastbinary" << endl << + 
"from thrift.protocol import TBinaryProtocol" << endl; f_consts_ << py_autogen_comment() << endl << @@ -118,7 +122,7 @@ void t_py_generator::generate_enum(t_enum* tenum) { f_types_ << "class " << tenum->get_name() << ":" << endl; indent_up(); - + vector constants = tenum->get_constants(); vector::iterator c_iter; int value = -1; @@ -144,7 +148,7 @@ void t_py_generator::generate_const(t_const* tconst) { t_type* type = tconst->get_type(); string name = tconst->get_name(); t_const_value* value = tconst->get_value(); - + indent(f_consts_) << name << " = " << render_const_value(type, value); f_consts_ << endl << endl; } @@ -268,7 +272,7 @@ void t_py_generator::generate_struct(t_struct* tstruct) { * @param txception The struct definition */ void t_py_generator::generate_xception(t_struct* txception) { - generate_py_struct(txception, true); + generate_py_struct(txception, true); } /** @@ -279,6 +283,19 @@ void t_py_generator::generate_py_struct(t_struct* tstruct, generate_py_struct_definition(f_types_, tstruct, is_exception); } +/** + * Comparator to sort fields in ascending order by key. + * Make this a functor instead of a function to help GCC inline it. + * The arguments are (const) references to const pointers to const t_fields. + * Unfortunately, we cannot declare it within the function. Boo! + * http://www.open-std.org/jtc1/sc22/open/n2356/ (paragraph 9). + */ +struct FieldKeyCompare { + bool operator()(t_field const * const & a, t_field const * const & b) { + return a->get_key() < b->get_key(); + } +}; + /** * Generates a struct definition for a thrift data type. 
This is nothing in PHP * where the objects are all just associative arrays (unless of course we @@ -290,8 +307,11 @@ void t_py_generator::generate_py_struct_definition(ofstream& out, t_struct* tstruct, bool is_exception, bool is_result) { + const vector& members = tstruct->get_members(); - vector::const_iterator m_iter; + vector::const_iterator m_iter; + vector sorted_members(members); + std::sort(sorted_members.begin(), sorted_members.end(), FieldKeyCompare()); out << "class " << tstruct->get_name(); @@ -304,6 +324,53 @@ void t_py_generator::generate_py_struct_definition(ofstream& out, out << endl; + /* + Here we generate the structure specification for the fastbinary codec. + These specifications have the following structure: + thrift_spec -> tuple of item_spec + item_spec -> None | (tag, type_enum, name, spec_args, default) + tag -> integer + type_enum -> TType.I32 | TType.STRING | TType.STRUCT | ... + name -> string_literal + default -> None # Handled by __init__ + spec_args -> None # For simple types + | (type_enum, spec_args) # Value type for list/set + | (type_enum, spec_args, type_enum, spec_args) + # Key and value for map + | (class_name, spec_args_ptr) # For struct/exception + class_name -> identifier # Basically a pointer to the class + spec_args_ptr -> expression # just class_name.spec_args + + TODO(dreiss): Consider making this work for structs with negative tags. 
+ */ + + if (sorted_members.empty() || (sorted_members[0]->get_key() >= 0)) { + indent(out) << "thrift_spec = (" << endl; + indent_up(); + + int sorted_keys_pos = 0; + for (m_iter = sorted_members.begin(); m_iter != sorted_members.end(); ++m_iter) { + + for (; sorted_keys_pos != (*m_iter)->get_key(); sorted_keys_pos++) { + indent(out) << "None, # " << sorted_keys_pos << endl; + } + + indent(out) << "(" << (*m_iter)->get_key() << ", " + << type_to_enum((*m_iter)->get_type()) << ", " + << "'" << (*m_iter)->get_name() << "'" << ", " + << type_to_spec_args((*m_iter)->get_type()) << ", " + << "None" << ", " + << ")," + << " # " << sorted_keys_pos + << endl; + + sorted_keys_pos ++; + } + + indent_down(); + indent(out) << ")" << endl << endl; + } else { + /* Structs with negative tags cannot be described by a position-indexed tuple. Still emit the attribute (as None) so that "Class.thrift_spec" references produced by type_to_spec_args and the guards in the generated read()/write() never hit a missing attribute. */ + indent(out) << "thrift_spec = None" << endl << endl; + } + out << indent() << "def __init__(self, d=None):" << endl; indent_up(); @@ -330,9 +397,10 @@ void t_py_generator::generate_py_struct_definition(ofstream& out, } indent_down(); - + out << endl; + generate_py_struct_reader(out, tstruct); generate_py_struct_writer(out, tstruct); @@ -346,6 +414,24 @@ void t_py_generator::generate_py_struct_definition(ofstream& out, indent() << " return repr(self.__dict__)" << endl << endl; + // Equality and inequality methods that compare by value + out << + indent() << "def __eq__(self, other):" << endl; + indent_up(); + out << + indent() << "return isinstance(other, self.__class__) and " + "self.__dict__ == other.__dict__" << endl; + indent_down(); + out << endl; + + out << + indent() << "def __ne__(self, other):" << endl; + indent_up(); + out << + indent() << "return not (self == other)" << endl; + indent_down(); + out << endl; + indent_down(); } @@ -360,15 +446,26 @@ void t_py_generator::generate_py_struct_reader(ofstream& out, indent(out) << "def read(self, iprot):" << endl; indent_up(); - + + /* Take the C fast path only when a spec exists; structs without one (thrift_spec is None) fall through to the pure-Python reader below. */ + indent(out) << + "if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated " + "and isinstance(iprot.trans, TTransport.CReadableTransport) " + "and self.thrift_spec is not None:" << endl; + indent_up(); + + indent(out) << + 
"fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec))" << endl; indent(out) << - "iprot.readStructBegin()" << endl; + "return" << endl; + indent_down(); + + indent(out) << + "iprot.readStructBegin()" << endl; // Loop over reading in fields indent(out) << "while True:" << endl; indent_up(); - + // Read beginning field marker indent(out) << "(fname, ftype, fid) = iprot.readFieldBegin()" << endl; @@ -380,10 +477,10 @@ void t_py_generator::generate_py_struct_reader(ofstream& out, indent(out) << "break" << endl; indent_down(); - + // Switch statement on the field we are reading bool first = true; - + // Generate deserialization code for known cases for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { if (first) { @@ -405,18 +502,18 @@ void t_py_generator::generate_py_struct_reader(ofstream& out, indent() << " iprot.skip(ftype)" << endl; indent_down(); } - + // In the default case we skip the field out << indent() << "else:" << endl << indent() << " iprot.skip(ftype)" << endl; - + // Read field end marker indent(out) << "iprot.readFieldEnd()" << endl; - + indent_down(); - + indent(out) << "iprot.readStructEnd()" << endl; @@ -433,7 +530,17 @@ void t_py_generator::generate_py_struct_writer(ofstream& out, indent(out) << "def write(self, oprot):" << endl; indent_up(); - + + /* Same guard as the reader: without a spec, use the pure-Python writer. */ + indent(out) << + "if oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated " + "and self.thrift_spec is not None:" << endl; + indent_up(); + + indent(out) << + "oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec)))" << endl; + indent(out) << + "return" << endl; + indent_down(); + indent(out) << "oprot.writeStructBegin('" << name << "')" << endl; @@ -487,8 +594,11 @@ void t_py_generator::generate_service(t_service* tservice) { } f_service_ << - "from ttypes import *" << endl << + "from ttypes import *" << endl << "from thrift.Thrift import TProcessor" << endl << + "from thrift.transport import TTransport" << endl << + "from thrift.protocol import fastbinary" << 
endl << + "from thrift.protocol import TBinaryProtocol" << endl << endl; // Generate the three main parts of the service (well, two for now in PHP) @@ -560,7 +670,7 @@ void t_py_generator::generate_service_interface(t_service* tservice) { "class Iface" << extends_if << ":" << endl; indent_up(); vector functions = tservice->get_functions(); - vector::iterator f_iter; + vector::iterator f_iter; for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { f_service_ << indent() << "def " << function_signature(*f_iter) << ":" << endl << @@ -606,7 +716,7 @@ void t_py_generator::generate_service_client(t_service* tservice) { // Generate client method implementations vector functions = tservice->get_functions(); - vector::const_iterator f_iter; + vector::const_iterator f_iter; for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { t_struct* arg_struct = (*f_iter)->get_arglist(); const vector& fields = arg_struct->get_members(); @@ -651,27 +761,27 @@ void t_py_generator::generate_service_client(t_service* tservice) { // Serialize the request header f_service_ << indent() << "self._oprot.writeMessageBegin('" << (*f_iter)->get_name() << "', TMessageType.CALL, self._seqid)" << endl; - + f_service_ << indent() << "args = " << argsname << "()" << endl; - + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { f_service_ << indent() << "args." 
<< (*fld_iter)->get_name() << " = " << (*fld_iter)->get_name() << endl; } - + // Write to the stream f_service_ << indent() << "args.write(self._oprot)" << endl << indent() << "self._oprot.writeMessageEnd()" << endl << - indent() << "self._oprot.trans.flush()" << endl; + indent() << "self._oprot.trans.flush()" << endl; indent_down(); if (!(*f_iter)->is_async()) { std::string resultname = (*f_iter)->get_name() + "_result"; t_struct noargs(program_); - + t_function recv_function((*f_iter)->get_returntype(), string("recv_") + (*f_iter)->get_name(), &noargs); @@ -719,12 +829,12 @@ void t_py_generator::generate_service_client(t_service* tservice) { } else { f_service_ << indent() << "raise TApplicationException(TApplicationException.MISSING_RESULT, \"" << (*f_iter)->get_name() << " failed: unknown result\");" << endl; - } + } // Close function indent_down(); - f_service_ << endl; - } + f_service_ << endl; + } } indent_down(); @@ -739,7 +849,7 @@ void t_py_generator::generate_service_client(t_service* tservice) { */ void t_py_generator::generate_service_remote(t_service* tservice) { vector functions = tservice->get_functions(); - vector::iterator f_iter; + vector::iterator f_iter; string f_remote_name = package_dir_+"/"+service_name_+"-remote"; ofstream f_remote; @@ -759,7 +869,7 @@ void t_py_generator::generate_service_remote(t_service* tservice) { f_remote << "import " << service_name_ << endl << - "from ttypes import *" << endl << + "from ttypes import *" << endl << endl; f_remote << @@ -782,11 +892,11 @@ void t_py_generator::generate_service_remote(t_service* tservice) { } else { f_remote << ", "; } - f_remote << + f_remote << args[i]->get_type()->get_name() << " " << args[i]->get_name(); } f_remote << ")'" << endl; - } + } f_remote << " print ''" << endl << " sys.exit(0)" << endl << @@ -838,7 +948,7 @@ void t_py_generator::generate_service_remote(t_service* tservice) { "client = " << service_name_ << ".Client(protocol)" << endl << "transport.open()" << endl << endl; 
- + // Generate the dispatch methods bool first = true; @@ -868,15 +978,15 @@ void t_py_generator::generate_service_remote(t_service* tservice) { } } f_remote << "))" << endl; - + f_remote << endl; } f_remote << "transport.close()" << endl; - + // Close service file f_remote.close(); - + // Make file executable, love that bitwise OR action chmod(f_remote_name.c_str(), S_IRUSR | @@ -896,7 +1006,7 @@ void t_py_generator::generate_service_remote(t_service* tservice) { void t_py_generator::generate_service_server(t_service* tservice) { // Generate the dispatch methods vector functions = tservice->get_functions(); - vector::iterator f_iter; + vector::iterator f_iter; string extends = ""; string extends_processor = ""; @@ -924,10 +1034,10 @@ void t_py_generator::generate_service_server(t_service* tservice) { for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { f_service_ << indent() << "self._processMap[\"" << (*f_iter)->get_name() << "\"] = Processor.process_" << (*f_iter)->get_name() << endl; - } + } indent_down(); f_service_ << endl; - + // Generate the server implementation indent(f_service_) << "def process(self, iprot, oprot):" << endl; @@ -1005,7 +1115,7 @@ void t_py_generator::generate_process_function(t_service* tservice, indent() << "try:" << endl; indent_up(); } - + // Generate the function call t_struct* arg_struct = tfunction->get_arglist(); const std::vector& fields = arg_struct->get_members(); @@ -1090,7 +1200,7 @@ void t_py_generator::generate_deserialize_field(ofstream &out, } else if (type->is_base_type() || type->is_enum()) { indent(out) << name << " = iprot."; - + if (type->is_base_type()) { t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); switch (tbase) { @@ -1098,7 +1208,7 @@ void t_py_generator::generate_deserialize_field(ofstream &out, throw "compiler error: cannot serialize void field in a struct: " + name; break; - case t_base_type::TYPE_STRING: + case t_base_type::TYPE_STRING: out << "readString();"; break; case 
t_base_type::TYPE_BOOL: @@ -1130,7 +1240,7 @@ void t_py_generator::generate_deserialize_field(ofstream &out, } else { printf("DO NOT KNOW HOW TO DESERIALIZE FIELD '%s' TYPE '%s'\n", tfield->get_name().c_str(), type->get_name().c_str()); - } + } } /** @@ -1155,7 +1265,7 @@ void t_py_generator::generate_deserialize_container(ofstream &out, string ktype = tmp("_ktype"); string vtype = tmp("_vtype"); string etype = tmp("_etype"); - + t_field fsize(g_type_i32, size); t_field fktype(g_type_byte, ktype); t_field fvtype(g_type_byte, vtype); @@ -1180,9 +1290,9 @@ void t_py_generator::generate_deserialize_container(ofstream &out, string i = tmp("_i"); indent(out) << "for " << i << " in xrange(" << size << "):" << endl; - + indent_up(); - + if (ttype->is_map()) { generate_deserialize_map_element(out, (t_map*)ttype, prefix); } else if (ttype->is_set()) { @@ -1190,7 +1300,7 @@ void t_py_generator::generate_deserialize_container(ofstream &out, } else if (ttype->is_list()) { generate_deserialize_list_element(out, (t_list*)ttype, prefix); } - + indent_down(); // Read container end @@ -1269,7 +1379,7 @@ void t_py_generator::generate_serialize_field(ofstream &out, throw "CANNOT GENERATE SERIALIZE CODE FOR void TYPE: " + prefix + tfield->get_name(); } - + if (type->is_struct() || type->is_xception()) { generate_serialize_struct(out, (t_struct*)type, @@ -1284,7 +1394,7 @@ void t_py_generator::generate_serialize_field(ofstream &out, indent(out) << "oprot."; - + if (type->is_base_type()) { t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); switch (tbase) { @@ -1365,27 +1475,27 @@ void t_py_generator::generate_serialize_container(ofstream &out, if (ttype->is_map()) { string kiter = tmp("kiter"); string viter = tmp("viter"); - indent(out) << + indent(out) << "for " << kiter << "," << viter << " in " << prefix << ".items():" << endl; indent_up(); generate_serialize_map_element(out, (t_map*)ttype, kiter, viter); indent_down(); } else if (ttype->is_set()) { string iter = 
tmp("iter"); - indent(out) << + indent(out) << "for " << iter << " in " << prefix << ":" << endl; indent_up(); generate_serialize_set_element(out, (t_set*)ttype, iter); indent_down(); } else if (ttype->is_list()) { string iter = tmp("iter"); - indent(out) << + indent(out) << "for " << iter << " in " << prefix << ":" << endl; indent_up(); generate_serialize_list_element(out, (t_list*)ttype, iter); indent_down(); } - + if (ttype->is_map()) { indent(out) << "oprot.writeMapEnd()" << endl; @@ -1500,7 +1610,7 @@ string t_py_generator::type_name(t_type* ttype) { */ string t_py_generator::type_to_enum(t_type* type) { type = get_true_type(type); - + if (type->is_base_type()) { t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); switch (tbase) { @@ -1535,3 +1645,37 @@ string t_py_generator::type_to_enum(t_type* type) { throw "INVALID TYPE IN type_to_enum: " + type->get_name(); } + +/** See the comment inside generate_py_struct_definition for what this is. */ +string t_py_generator::type_to_spec_args(t_type* ttype) { + while (ttype->is_typedef()) { + ttype = ((t_typedef*)ttype)->get_type(); + } + + if (ttype->is_base_type() || ttype->is_enum()) { + return "None"; + } else if (ttype->is_struct() || ttype->is_xception()) { + return "(" + type_name(ttype) + ", " + type_name(ttype) + ".thrift_spec)"; + } else if (ttype->is_map()) { + return "(" + + type_to_enum(((t_map*)ttype)->get_key_type()) + "," + + type_to_spec_args(((t_map*)ttype)->get_key_type()) + "," + + type_to_enum(((t_map*)ttype)->get_val_type()) + "," + + type_to_spec_args(((t_map*)ttype)->get_val_type()) + + ")"; + + } else if (ttype->is_set()) { + return "(" + + type_to_enum(((t_set*)ttype)->get_elem_type()) + "," + + type_to_spec_args(((t_set*)ttype)->get_elem_type()) + + ")"; + + } else if (ttype->is_list()) { + return "(" + + type_to_enum(((t_list*)ttype)->get_elem_type()) + "," + + type_to_spec_args(((t_list*)ttype)->get_elem_type()) + + ")"; + } + + throw "INVALID TYPE IN type_to_spec_args: " + 
ttype->get_name(); +} diff --git a/compiler/cpp/src/generate/t_py_generator.h b/compiler/cpp/src/generate/t_py_generator.h index 87b3f473..b301d361 100644 --- a/compiler/cpp/src/generate/t_py_generator.h +++ b/compiler/cpp/src/generate/t_py_generator.h @@ -72,18 +72,18 @@ class t_py_generator : public t_oop_generator { */ void generate_deserialize_field (std::ofstream &out, - t_field* tfield, + t_field* tfield, std::string prefix="", bool inclass=false); - + void generate_deserialize_struct (std::ofstream &out, t_struct* tstruct, std::string prefix=""); - + void generate_deserialize_container (std::ofstream &out, t_type* ttype, std::string prefix=""); - + void generate_deserialize_set_element (std::ofstream &out, t_set* tset, std::string prefix=""); @@ -133,6 +133,7 @@ class t_py_generator : public t_oop_generator { std::string function_signature(t_function* tfunction, std::string prefix=""); std::string argument_list(t_struct* tstruct); std::string type_to_enum(t_type* ttype); + std::string type_to_spec_args(t_type* ttype); private: @@ -141,7 +142,7 @@ class t_py_generator : public t_oop_generator { */ std::ofstream f_types_; - std::ofstream f_consts_; + std::ofstream f_consts_; std::ofstream f_service_; std::string package_dir_; diff --git a/lib/py/setup.py b/lib/py/setup.py index 8ff1645e..582a985a 100644 --- a/lib/py/setup.py +++ b/lib/py/setup.py @@ -6,7 +6,11 @@ # See accompanying file LICENSE or visit the Thrift site at: # http://developers.facebook.com/thrift/ -from distutils.core import setup +from distutils.core import setup, Extension + +fastbinarymod = Extension('thrift.protocol.fastbinary', + sources = ['src/protocol/fastbinary.c'], + ) setup(name = 'Thrift', version = '1.0', @@ -16,5 +20,6 @@ setup(name = 'Thrift', url = 'http://code.facebook.com/thrift', packages = ['thrift', 'thrift.protocol', 'thrift.transport', 'thrift.server'], package_dir = {'thrift' : 'src'}, + ext_modules = [fastbinarymod], ) diff --git a/lib/py/src/protocol/TBinaryProtocol.py 
b/lib/py/src/protocol/TBinaryProtocol.py index 6ae0c867..3fd6b02a 100644 --- a/lib/py/src/protocol/TBinaryProtocol.py +++ b/lib/py/src/protocol/TBinaryProtocol.py @@ -77,7 +77,7 @@ class TBinaryProtocol(TProtocolBase): self.writeByte(1) else: self.writeByte(0) - + def writeByte(self, byte): buff = pack("!b", byte) self.trans.write(buff) @@ -89,7 +89,7 @@ class TBinaryProtocol(TProtocolBase): def writeI32(self, i32): buff = pack("!i", i32) self.trans.write(buff) - + def writeI64(self, i64): buff = pack("!q", i64) self.trans.write(buff) @@ -199,6 +199,7 @@ class TBinaryProtocol(TProtocolBase): str = self.trans.readAll(len) return str + class TBinaryProtocolFactory: def __init__(self, strictRead=False, strictWrite=True): self.strictRead = strictRead @@ -207,3 +208,32 @@ class TBinaryProtocolFactory: def getProtocol(self, trans): prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite) return prot + + +class TBinaryProtocolAccelerated(TBinaryProtocol): + + """C-Accelerated version of TBinaryProtocol. + + This class does not override any of TBinaryProtocol's methods, + but the generated code recognizes it directly and will call into + our C module to do the encoding, bypassing this object entirely. + We inherit from TBinaryProtocol so that the normal TBinaryProtocol + encoding can happen if the fastbinary module doesn't work for some + reason. (TODO(dreiss): Make this happen sanely.) + + In order to take advantage of the C module, just use + TBinaryProtocolAccelerated instead of TBinaryProtocol. + + NOTE: This code was contributed by an external developer. + The internal Thrift team has reviewed and tested it, + but we cannot guarantee that it is production-ready. + Please feel free to report bugs and/or success stories + to the public mailing list. 
+ """ + + pass + + +class TBinaryProtocolAcceleratedFactory: + def getProtocol(self, trans): + return TBinaryProtocolAccelerated(trans) diff --git a/lib/py/src/protocol/__init__.py b/lib/py/src/protocol/__init__.py index bcc981db..11ae3a79 100644 --- a/lib/py/src/protocol/__init__.py +++ b/lib/py/src/protocol/__init__.py @@ -6,4 +6,4 @@ # See accompanying file LICENSE or visit the Thrift site at: # http://developers.facebook.com/thrift/ -__all__ = ['TProtocol', 'TBinaryProtocol'] +__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary'] diff --git a/lib/py/src/protocol/fastbinary.c b/lib/py/src/protocol/fastbinary.c new file mode 100644 index 00000000..cfc504e9 --- /dev/null +++ b/lib/py/src/protocol/fastbinary.c @@ -0,0 +1,1146 @@ +// Copyright (c) 2006- Facebook +// Distributed under the Thrift Software License +// +// See accompanying file LICENSE or visit the Thrift site at: +// http://developers.facebook.com/thrift/ +// +// NOTE: This code was contributed by an external developer. +// The internal Thrift team has reviewed and tested it, +// but we cannot guarantee that it is production-ready. +// Please feel free to report bugs and/or success stories +// to the public mailing list. + +#include +#include "cStringIO.h" +#include +#include +#include + +// TODO(dreiss): defval appears to be unused. Look into removing it. +// TODO(dreiss): Make parse_spec_args recursive, and cache the output +// permanently in the object. (Malloc and orphan.) +// TODO(dreiss): Why do we need cStringIO for reading, why not just char*? +// Can cStringIO let us work with a BufferedTransport? +// TODO(dreiss): Don't ignore the rv from cwrite (maybe). + +/* ====== BEGIN UTILITIES ====== */ + +#define INIT_OUTBUF_SIZE 128 + +// Stolen out of TProtocol.h. +// It would be a huge pain to have both get this from one place. 
+typedef enum TType { + T_STOP = 0, + T_VOID = 1, + T_BOOL = 2, + T_BYTE = 3, + T_I08 = 3, + T_I16 = 6, + T_I32 = 8, + T_U64 = 9, + T_I64 = 10, + T_DOUBLE = 4, + T_STRING = 11, + T_UTF7 = 11, + T_STRUCT = 12, + T_MAP = 13, + T_SET = 14, + T_LIST = 15, + T_UTF8 = 16, + T_UTF16 = 17 +} TType; + +// Same comment as the enum. Sorry. +#if __BYTE_ORDER == __BIG_ENDIAN +# define ntohll(n) (n) +# define htonll(n) (n) +#elif __BYTE_ORDER == __LITTLE_ENDIAN +# if defined(__GNUC__) && defined(__GLIBC__) +# include +# define ntohll(n) bswap_64(n) +# define htonll(n) bswap_64(n) +# else /* GNUC & GLIBC */ +# define ntohll(n) ( (((unsigned long long)ntohl(n)) << 32) + ntohl(n >> 32) ) +# define htonll(n) ( (((unsigned long long)htonl(n)) << 32) + htonl(n >> 32) ) +# endif /* GNUC & GLIBC */ +#else /* __BYTE_ORDER */ +# error "Can't define htonll or ntohll!" +#endif + +// Doing a benchmark shows that interning actually makes a difference, amazingly. +#define INTERN_STRING(value) _intern_ ## value + +#define INT_CONV_ERROR_OCCURRED(v) ( ((v) == -1) && PyErr_Occurred() ) +#define CHECK_RANGE(v, min, max) ( ((v) <= (max)) && ((v) >= (min)) ) + +/** + * A cache of the spec_args for a set or list, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { + TType element_type; + PyObject* typeargs; +} SetListTypeArgs; + +/** + * A cache of the spec_args for a map, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { + TType ktag; + TType vtag; + PyObject* ktypeargs; + PyObject* vtypeargs; +} MapTypeArgs; + +/** + * A cache of the spec_args for a struct, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { + PyObject* klass; + PyObject* spec; +} StructTypeArgs; + +/** + * A cache of the item spec from a struct specification, + * so we don't have to keep calling PyTuple_GET_ITEM. 
+ */ +typedef struct { + int tag; + TType type; + PyObject* attrname; + PyObject* typeargs; + PyObject* defval; +} StructItemSpec; + +/** + * A cache of the two key attributes of a CReadableTransport, + * so we don't have to keep calling PyObject_GetAttr. + */ +typedef struct { + PyObject* stringiobuf; + PyObject* refill_callable; +} DecodeBuffer; + +/** Pointer to interned string to speed up attribute lookup. */ +static PyObject* INTERN_STRING(cstringio_buf); +/** Pointer to interned string to speed up attribute lookup. */ +static PyObject* INTERN_STRING(cstringio_refill); + +static inline bool +check_ssize_t_32(Py_ssize_t len) { + // error from getting the int + if (INT_CONV_ERROR_OCCURRED(len)) { + return false; + } + if (!CHECK_RANGE(len, 0, INT32_MAX)) { + PyErr_SetString(PyExc_OverflowError, "string size out of range"); + return false; + } + return true; +} + +static inline bool +parse_pyint(PyObject* o, int32_t* ret, int32_t min, int32_t max) { + long val = PyInt_AsLong(o); + + if (INT_CONV_ERROR_OCCURRED(val)) { + return false; + } + if (!CHECK_RANGE(val, min, max)) { + PyErr_SetString(PyExc_OverflowError, "int out of range"); + return false; + } + + *ret = (int32_t) val; + return true; +} + + +/* --- FUNCTIONS TO PARSE STRUCT SPECIFICATOINS --- */ + +static bool +parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs) { + if (PyTuple_Size(typeargs) != 2) { + PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for list/set type args"); + return false; + } + + dest->element_type = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)); + if (INT_CONV_ERROR_OCCURRED(dest->element_type)) { + return false; + } + + dest->typeargs = PyTuple_GET_ITEM(typeargs, 1); + + return true; +} + +static bool +parse_map_args(MapTypeArgs* dest, PyObject* typeargs) { + if (PyTuple_Size(typeargs) != 4) { + PyErr_SetString(PyExc_TypeError, "expecting 4 arguments for typeargs to map"); + return false; + } + + dest->ktag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)); + if 
(INT_CONV_ERROR_OCCURRED(dest->ktag)) { + return false; + } + + dest->vtag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 2)); + if (INT_CONV_ERROR_OCCURRED(dest->vtag)) { + return false; + } + + dest->ktypeargs = PyTuple_GET_ITEM(typeargs, 1); + dest->vtypeargs = PyTuple_GET_ITEM(typeargs, 3); + + return true; +} + +static bool +parse_struct_args(StructTypeArgs* dest, PyObject* typeargs) { + if (PyTuple_Size(typeargs) != 2) { + PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for struct args"); + return false; + } + + dest->klass = PyTuple_GET_ITEM(typeargs, 0); + dest->spec = PyTuple_GET_ITEM(typeargs, 1); + + return true; +} + +static int +parse_struct_item_spec(StructItemSpec* dest, PyObject* spec_tuple) { + + // i'd like to use ParseArgs here, but it seems to be a bottleneck. + if (PyTuple_Size(spec_tuple) != 5) { + PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for spec tuple"); + return false; + } + + dest->tag = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 0)); + if (INT_CONV_ERROR_OCCURRED(dest->tag)) { + return false; + } + + dest->type = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 1)); + if (INT_CONV_ERROR_OCCURRED(dest->type)) { + return false; + } + + dest->attrname = PyTuple_GET_ITEM(spec_tuple, 2); + dest->typeargs = PyTuple_GET_ITEM(spec_tuple, 3); + dest->defval = PyTuple_GET_ITEM(spec_tuple, 4); + return true; +} + +/* ====== END UTILITIES ====== */ + + +/* ====== BEGIN WRITING FUNCTIONS ====== */ + +/* --- LOW-LEVEL WRITING FUNCTIONS --- */ + +static void writeByte(PyObject* outbuf, int8_t val) { + int8_t net = val; + PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int8_t)); +} + +static void writeI16(PyObject* outbuf, int16_t val) { + int16_t net = (int16_t)htons(val); + PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int16_t)); +} + +static void writeI32(PyObject* outbuf, int32_t val) { + int32_t net = (int32_t)htonl(val); + PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int32_t)); +} + +static void writeI64(PyObject* outbuf, int64_t 
val) { + int64_t net = (int64_t)htonll(val); + PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int64_t)); +} + +static void writeDouble(PyObject* outbuf, double dub) { + // Unfortunately, bitwise_cast doesn't work in C. Bad C! + union { + double f; + int64_t t; + } transfer; + transfer.f = dub; + writeI64(outbuf, transfer.t); +} + + +/* --- MAIN RECURSIVE OUTPUT FUCNTION -- */ + +static int +output_val(PyObject* output, PyObject* value, TType type, PyObject* typeargs) { + /* + * Refcounting Strategy: + * + * We assume that elements of the thrift_spec tuple are not going to be + * mutated, so we don't ref count those at all. Other than that, we try to + * keep a reference to all the user-created objects while we work with them. + * output_val assumes that a reference is already held. The *caller* is + * responsible for handling references + */ + + switch (type) { + + case T_BOOL: { + int v = PyObject_IsTrue(value); + if (v == -1) { + return false; + } + + writeByte(output, (int8_t) v); + break; + } + case T_I08: { + int32_t val; + + if (!parse_pyint(value, &val, INT8_MIN, INT8_MAX)) { + return false; + } + + writeByte(output, (int8_t) val); + break; + } + case T_I16: { + int32_t val; + + if (!parse_pyint(value, &val, INT16_MIN, INT16_MAX)) { + return false; + } + + writeI16(output, (int16_t) val); + break; + } + case T_I32: { + int32_t val; + + if (!parse_pyint(value, &val, INT32_MIN, INT32_MAX)) { + return false; + } + + writeI32(output, val); + break; + } + case T_I64: { + int64_t nval = PyLong_AsLongLong(value); + + if (INT_CONV_ERROR_OCCURRED(nval)) { + return false; + } + + if (!CHECK_RANGE(nval, INT64_MIN, INT64_MAX)) { + PyErr_SetString(PyExc_OverflowError, "int out of range"); + return false; + } + + writeI64(output, nval); + break; + } + + case T_DOUBLE: { + double nval = PyFloat_AsDouble(value); + if (nval == -1.0 && PyErr_Occurred()) { + return false; + } + + writeDouble(output, nval); + break; + } + + case T_STRING: { + Py_ssize_t len = 
PyString_Size(value); + + if (!check_ssize_t_32(len)) { + return false; + } + + writeI32(output, (int32_t) len); + PycStringIO->cwrite(output, PyString_AsString(value), (int32_t) len); + break; + } + + case T_LIST: + case T_SET: { + Py_ssize_t len; + SetListTypeArgs parsedargs; + PyObject *item; + PyObject *iterator; + + if (!parse_set_list_args(&parsedargs, typeargs)) { + return false; + } + + len = PyObject_Length(value); + + if (!check_ssize_t_32(len)) { + return false; + } + + writeByte(output, parsedargs.element_type); + writeI32(output, (int32_t) len); + + iterator = PyObject_GetIter(value); + if (iterator == NULL) { + return false; + } + + while ((item = PyIter_Next(iterator))) { + if (!output_val(output, item, parsedargs.element_type, parsedargs.typeargs)) { + Py_DECREF(item); + Py_DECREF(iterator); + return false; + } + Py_DECREF(item); + } + + Py_DECREF(iterator); + + if (PyErr_Occurred()) { + return false; + } + + break; + } + + case T_MAP: { + PyObject *k, *v; + /* PyDict_Next takes a Py_ssize_t*; an int here is the wrong width on LP64 platforms. */ + Py_ssize_t pos = 0; + Py_ssize_t len; + + MapTypeArgs parsedargs; + + len = PyDict_Size(value); + if (!check_ssize_t_32(len)) { + return false; + } + + if (!parse_map_args(&parsedargs, typeargs)) { + return false; + } + + writeByte(output, parsedargs.ktag); + writeByte(output, parsedargs.vtag); + writeI32(output, len); + + // TODO(bmaurer): should support any mapping, not just dicts + while (PyDict_Next(value, &pos, &k, &v)) { + // TODO(dreiss): Think hard about whether these INCREFs actually + // turn any unsafe scenarios into safe scenarios. + Py_INCREF(k); + Py_INCREF(v); + + if (!output_val(output, k, parsedargs.ktag, parsedargs.ktypeargs) + || !output_val(output, v, parsedargs.vtag, parsedargs.vtypeargs)) { + Py_DECREF(k); + Py_DECREF(v); + return false; + } + + /* Balance the INCREFs above on the success path too; otherwise every key and value leaks one reference per encode. */ + Py_DECREF(k); + Py_DECREF(v); + } + break; + } + + // TODO(dreiss): Consider breaking this out as a function + // the way we did for decode_struct. 
+ case T_STRUCT: { + StructTypeArgs parsedargs; + Py_ssize_t nspec; + Py_ssize_t i; + + if (!parse_struct_args(&parsedargs, typeargs)) { + return false; + } + + nspec = PyTuple_Size(parsedargs.spec); + + if (nspec == -1) { + return false; + } + + for (i = 0; i < nspec; i++) { + StructItemSpec parsedspec; + PyObject* spec_tuple; + PyObject* instval = NULL; + + spec_tuple = PyTuple_GET_ITEM(parsedargs.spec, i); + if (spec_tuple == Py_None) { + continue; + } + + if (!parse_struct_item_spec (&parsedspec, spec_tuple)) { + return false; + } + + instval = PyObject_GetAttr(value, parsedspec.attrname); + + if (!instval) { + return false; + } + + if (instval == Py_None) { + Py_DECREF(instval); + continue; + } + + writeByte(output, (int8_t) parsedspec.type); + writeI16(output, parsedspec.tag); + + if (!output_val(output, instval, parsedspec.type, parsedspec.typeargs)) { + Py_DECREF(instval); + return false; + } + + Py_DECREF(instval); + } + + writeByte(output, (int8_t)T_STOP); + break; + } + + case T_STOP: + case T_VOID: + case T_UTF16: + case T_UTF8: + case T_U64: + default: + PyErr_SetString(PyExc_TypeError, "Unexpected TType"); + return false; + + } + + return true; +} + + +/* --- TOP-LEVEL WRAPPER FOR OUTPUT -- */ + +static PyObject * +encode_binary(PyObject *self, PyObject *args) { + PyObject* enc_obj; + PyObject* type_args; + PyObject* buf; + PyObject* ret = NULL; + + if (!PyArg_ParseTuple(args, "OO", &enc_obj, &type_args)) { + return NULL; + } + + buf = PycStringIO->NewOutput(INIT_OUTBUF_SIZE); + if (output_val(buf, enc_obj, T_STRUCT, type_args)) { + ret = PycStringIO->cgetvalue(buf); + } + + Py_DECREF(buf); + return ret; +} + +/* ====== END WRITING FUNCTIONS ====== */ + + +/* ====== BEGIN READING FUNCTIONS ====== */ + +/* --- LOW-LEVEL READING FUNCTIONS --- */ + +static void +free_decodebuf(DecodeBuffer* d) { + Py_XDECREF(d->stringiobuf); + Py_XDECREF(d->refill_callable); +} + +static bool +decode_buffer_from_obj(DecodeBuffer* dest, PyObject* obj) { + 
dest->stringiobuf = PyObject_GetAttr(obj, INTERN_STRING(cstringio_buf));
+  if (!dest->stringiobuf) {
+    return false;
+  }
+
+  if (!PycStringIO_InputCheck(dest->stringiobuf)) {
+    free_decodebuf(dest);
+    PyErr_SetString(PyExc_TypeError, "expecting stringio input");
+    return false;
+  }
+
+  dest->refill_callable = PyObject_GetAttr(obj, INTERN_STRING(cstringio_refill));
+
+  if(!dest->refill_callable) {
+    free_decodebuf(dest);
+    return false;
+  }
+
+  if (!PyCallable_Check(dest->refill_callable)) {
+    free_decodebuf(dest);
+    PyErr_SetString(PyExc_TypeError, "expecting callable");
+    return false;
+  }
+
+  return true;
+}
+
+/*
+ * Points *output at `len` bytes read from input->stringiobuf.
+ *
+ * On a short read it calls input->refill_callable with the bytes that were
+ * actually read and the requested length, replaces input->stringiobuf with
+ * the object the callable returns, and retries the read once.
+ *
+ * Returns true when exactly `len` bytes are available, false otherwise.
+ */
+static bool readBytes(DecodeBuffer* input, char** output, int len) {
+  int read;
+
+  // TODO(dreiss): Don't fear the malloc.  Think about taking a copy of
+  //               the partial read instead of forcing the transport
+  //               to prepend it to its buffer.
+
+  read = PycStringIO->cread(input->stringiobuf, output, len);
+
+  if (read == len) {
+    return true;
+  } else if (read == -1) {
+    return false;
+  } else {
+    PyObject* newiobuf;
+
+    // using building functions as this is a rare codepath
+    // "s#" consumes (pointer, length): only `read` bytes at *output are
+    // valid, and the trailing "i" is the length that was requested.
+    newiobuf = PyObject_CallFunction(
+        input->refill_callable, "s#i", *output, read, len, NULL);
+    if (newiobuf == NULL) {
+      return false;
+    }
+
+    // must do this *AFTER* the call so that we don't deref the io buffer
+    Py_CLEAR(input->stringiobuf);
+    input->stringiobuf = newiobuf;
+
+    read = PycStringIO->cread(input->stringiobuf, output, len);
+
+    if (read == len) {
+      return true;
+    } else if (read == -1) {
+      return false;
+    } else {
+      // TODO(dreiss): This could be a valid code path for big binary blobs.
+      PyErr_SetString(PyExc_TypeError,
+          "refill claimed to have refilled the buffer, but didn't!!");
+      return false;
+    }
+  }
+}
+
+/* Reads one byte; returns -1 if readBytes fails. */
+static int8_t readByte(DecodeBuffer* input) {
+  char* buf;
+  if (!readBytes(input, &buf, sizeof(int8_t))) {
+    return -1;
+  }
+
+  return *(int8_t*) buf;
+}
+
+/* Reads a 16-bit integer and converts it with ntohs; returns -1 if readBytes fails. */
+static int16_t readI16(DecodeBuffer* input) {
+  char* buf;
+  if (!readBytes(input, &buf, sizeof(int16_t))) {
+    return -1;
+  }
+
+  return (int16_t) ntohs(*(int16_t*) buf);
+}
+
+/* Reads a 32-bit integer and converts it with ntohl; returns -1 if readBytes fails. */
+static int32_t readI32(DecodeBuffer* input) {
+  char* buf;
+  if (!readBytes(input, &buf, sizeof(int32_t))) {
+    return -1;
+  }
+  return (int32_t) ntohl(*(int32_t*) buf);
+}
+
+
+/* Reads a 64-bit integer and converts it with ntohll; returns -1 if readBytes fails. */
+static int64_t readI64(DecodeBuffer* input) {
+  char* buf;
+  if (!readBytes(input, &buf, sizeof(int64_t))) {
+    return -1;
+  }
+
+  return (int64_t) ntohll(*(int64_t*) buf);
+}
+
+/* Reads an int64 with readI64 and returns its bits reinterpreted as a double. */
+static double readDouble(DecodeBuffer* input) {
+  union {
+    int64_t f;
+    double t;
+  } transfer;
+
+  transfer.f = readI64(input);
+  // -1 is also a legitimate bit pattern, so only treat it as a failure when
+  // an exception is actually pending.
+  if (transfer.f == -1 && PyErr_Occurred()) {
+    return -1;
+  }
+  return transfer.t;
+}
+
+/* Reads one type byte and checks that it equals `expected`. */
+static bool
+checkTypeByte(DecodeBuffer* input, TType expected) {
+  TType got = readByte(input);
+
+  if (got == -1 && PyErr_Occurred()) {
+    // The read itself failed; keep that exception instead of replacing it.
+    return false;
+  }
+
+  if (expected != got) {
+    PyErr_SetString(PyExc_TypeError, "got wrong ttype while reading field");
+    return false;
+  }
+  return true;
+}
+
+/* Consumes and discards one encoded value of the given type. */
+static bool
+skip(DecodeBuffer* input, TType type) {
+#define SKIPBYTES(n) \
+  do { \
+    if (!readBytes(input, &dummy_buf, (n))) { \
+      return false; \
+    } \
+  } while(0)
+
+  char* dummy_buf;
+
+  switch (type) {
+
+  case T_BOOL:
+  case T_I08: SKIPBYTES(1); break;
+  case T_I16: SKIPBYTES(2); break;
+  case T_I32: SKIPBYTES(4); break;
+  case T_I64:
+  case T_DOUBLE: SKIPBYTES(8); break;
+
+  case T_STRING: {
+    // TODO(dreiss): Find out if these check_ssize_t32s are really necessary.
+    int len = readI32(input);
+    if (!check_ssize_t_32(len)) {
+      return false;
+    }
+    SKIPBYTES(len);
+    break;
+  }
+
+  case T_LIST:
+  case T_SET: {
+    TType etype;
+    int len, i;
+
+    etype = readByte(input);
+    if (etype == -1) {
+      return false;
+    }
+
+    len = readI32(input);
+    if (!check_ssize_t_32(len)) {
+      return false;
+    }
+
+    for (i = 0; i < len; i++) {
+      if (!skip(input, etype)) {
+        return false;
+      }
+    }
+    break;
+  }
+
+  case T_MAP: {
+    TType ktype, vtype;
+    int len, i;
+
+    ktype = readByte(input);
+    if (ktype == -1) {
+      return false;
+    }
+
+    vtype = readByte(input);
+    if (vtype == -1) {
+      return false;
+    }
+
+    len = readI32(input);
+    if (!check_ssize_t_32(len)) {
+      return false;
+    }
+
+    for (i = 0; i < len; i++) {
+      if (!(skip(input, ktype) && skip(input, vtype))) {
+        return false;
+      }
+    }
+    break;
+  }
+
+  case T_STRUCT: {
+    while (true) {
+      TType type;
+
+      type = readByte(input);
+      if (type == -1) {
+        return false;
+      }
+
+      if (type == T_STOP)
+        break;
+
+      SKIPBYTES(2); // tag
+      if (!skip(input, type)) {
+        return false;
+      }
+    }
+    break;
+  }
+
+  case T_STOP:
+  case T_VOID:
+  case T_UTF16:
+  case T_UTF8:
+  case T_U64:
+  default:
+    PyErr_SetString(PyExc_TypeError, "Unexpected TType");
+    return false;
+
+  }
+
+  // Every case that skipped its value successfully breaks out of the switch
+  // and lands here, so this is the success path.
+  return true;
+
+#undef SKIPBYTES
+}
+
+
+/* --- HELPER FUNCTION FOR DECODE_VAL --- */
+
+static PyObject*
+decode_val(DecodeBuffer* input, TType type, PyObject* typeargs);
+
+/*
+ * Reads struct fields from `input` until a T_STOP type byte and stores each
+ * decoded value on `output` under the attribute name given by the matching
+ * entry of `spec_seq`, which is indexed by field tag.  Fields whose tag is
+ * out of range, or whose spec entry is None, are skipped.
+ *
+ * Returns true on success, false on failure.
+ */
+static bool
+decode_struct(DecodeBuffer* input, PyObject* output, PyObject* spec_seq) {
+  int spec_seq_len = PyTuple_Size(spec_seq);
+  if (spec_seq_len == -1) {
+    return false;
+  }
+
+  while (true) {
+    TType type;
+    int16_t tag;
+    PyObject* item_spec;
+    PyObject* fieldval = NULL;
+    StructItemSpec parsedspec;
+
+    type = readByte(input);
+    if (type == -1 && PyErr_Occurred()) {
+      // The read failed; do not go on to interpret the error value as a type.
+      return false;
+    }
+    if (type == T_STOP) {
+      break;
+    }
+    tag = readI16(input);
+    if (INT_CONV_ERROR_OCCURRED(tag)) {
+      return false;
+    }
+
+    if (tag >= 0 && tag < spec_seq_len) {
+      item_spec = PyTuple_GET_ITEM(spec_seq, tag);
+    } else {
+      item_spec = Py_None;
+    }
+
+    if (item_spec == Py_None) {
+      if (!skip(input, type)) {
+        return false;
+      }
+      // The unknown field has been consumed; there is no spec to parse for
+      // it, so move on to the next field.
+      continue;
+    }
+
+    if (!parse_struct_item_spec(&parsedspec, item_spec)) {
+      return false;
+    }
+    if (parsedspec.type != type) {
+      PyErr_SetString(PyExc_TypeError, "struct field had wrong type while reading");
+      return false;
+    }
+
+    fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs);
+    if (fieldval == NULL) {
+      return false;
+    }
+
+    if (PyObject_SetAttr(output, parsedspec.attrname, fieldval) == -1) {
+      Py_DECREF(fieldval);
+      return false;
+    }
+    Py_DECREF(fieldval);
+  }
+  return true;
+}
+
+
+/* --- MAIN RECURSIVE INPUT FUNCTION --- */
+
+// Returns a new reference.
+static PyObject*
+decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) {
+  switch (type) {
+
+  case T_BOOL: {
+    int8_t v = readByte(input);
+    if (INT_CONV_ERROR_OCCURRED(v)) {
+      return NULL;
+    }
+
+    switch (v) {
+    case 0: Py_RETURN_FALSE;
+    case 1: Py_RETURN_TRUE;
+    // Don't laugh.  This is a potentially serious issue.
+    default: PyErr_SetString(PyExc_TypeError, "boolean out of range"); return NULL;
+    }
+    break;
+  }
+  case T_I08: {
+    int8_t v = readByte(input);
+    if (INT_CONV_ERROR_OCCURRED(v)) {
+      return NULL;
+    }
+
+    return PyInt_FromLong(v);
+  }
+  case T_I16: {
+    int16_t v = readI16(input);
+    if (INT_CONV_ERROR_OCCURRED(v)) {
+      return NULL;
+    }
+    return PyInt_FromLong(v);
+  }
+  case T_I32: {
+    int32_t v = readI32(input);
+    if (INT_CONV_ERROR_OCCURRED(v)) {
+      return NULL;
+    }
+    return PyInt_FromLong(v);
+  }
+
+  case T_I64: {
+    int64_t v = readI64(input);
+    if (INT_CONV_ERROR_OCCURRED(v)) {
+      return NULL;
+    }
+    // TODO(dreiss): Find out if we can take this fastpath always when
+    //               sizeof(long) == sizeof(long long).
+    if (CHECK_RANGE(v, LONG_MIN, LONG_MAX)) {
+      return PyInt_FromLong((long) v);
+    }
+
+    return PyLong_FromLongLong(v);
+  }
+
+  case T_DOUBLE: {
+    double v = readDouble(input);
+    if (v == -1.0 && PyErr_Occurred()) {
+      return NULL;
+    }
+    return PyFloat_FromDouble(v);
+  }
+
+  case T_STRING: {
+    Py_ssize_t len = readI32(input);
+    char* buf;
+    // Validate the length prefix before using it as a read size, the same way
+    // the list, set and map cases validate theirs.
+    if (!check_ssize_t_32(len)) {
+      return NULL;
+    }
+    if (!readBytes(input, &buf, len)) {
+      return NULL;
+    }
+
+    return PyString_FromStringAndSize(buf, len);
+  }
+
+  case T_LIST:
+  case T_SET: {
+    SetListTypeArgs parsedargs;
+    int32_t len;
+    PyObject* ret = NULL;
+    int i;
+
+    if (!parse_set_list_args(&parsedargs, typeargs)) {
+      return NULL;
+    }
+
+    if (!checkTypeByte(input, parsedargs.element_type)) {
+      return NULL;
+    }
+
+    len = readI32(input);
+    if (!check_ssize_t_32(len)) {
+      return NULL;
+    }
+
+    ret = PyList_New(len);
+    if (!ret) {
+      return NULL;
+    }
+
+    for (i = 0; i < len; i++) {
+      PyObject* item = decode_val(input, parsedargs.element_type, parsedargs.typeargs);
+      if (!item) {
+        Py_DECREF(ret);
+        return NULL;
+      }
+      // PyList_SET_ITEM takes over the reference to item.
+      PyList_SET_ITEM(ret, i, item);
+    }
+
+    // TODO(dreiss): Consider biting the bullet and making two separate cases
+    //               for list and set, avoiding this post facto conversion.
+    if (type == T_SET) {
+      PyObject* setret;
+#if (PY_VERSION_HEX < 0x02050000)
+      // hack needed for older versions
+      setret = PyObject_CallFunctionObjArgs((PyObject*)&PySet_Type, ret, NULL);
+#else
+      // official version
+      setret = PySet_New(ret);
+#endif
+      Py_DECREF(ret);
+      return setret;
+    }
+    return ret;
+  }
+
+  case T_MAP: {
+    int32_t len;
+    int i;
+    MapTypeArgs parsedargs;
+    PyObject* ret = NULL;
+
+    if (!parse_map_args(&parsedargs, typeargs)) {
+      return NULL;
+    }
+
+    if (!checkTypeByte(input, parsedargs.ktag)) {
+      return NULL;
+    }
+    if (!checkTypeByte(input, parsedargs.vtag)) {
+      return NULL;
+    }
+
+    len = readI32(input);
+    if (!check_ssize_t_32(len)) {
+      return NULL;
+    }
+
+    ret = PyDict_New();
+    if (!ret) {
+      goto error;
+    }
+
+    for (i = 0; i < len; i++) {
+      PyObject* k = NULL;
+      PyObject* v = NULL;
+      k = decode_val(input, parsedargs.ktag, parsedargs.ktypeargs);
+      if (k == NULL) {
+        goto loop_error;
+      }
+      v = decode_val(input, parsedargs.vtag, parsedargs.vtypeargs);
+      if (v == NULL) {
+        goto loop_error;
+      }
+      if (PyDict_SetItem(ret, k, v) == -1) {
+        goto loop_error;
+      }
+
+      // PyDict_SetItem does not take over our references; drop them here.
+      Py_DECREF(k);
+      Py_DECREF(v);
+      continue;
+
+      // Yuck!  Destructors, anyone?
+      loop_error:
+      Py_XDECREF(k);
+      Py_XDECREF(v);
+      goto error;
+    }
+
+    return ret;
+
+    error:
+    Py_XDECREF(ret);
+    return NULL;
+  }
+
+  case T_STRUCT: {
+    StructTypeArgs parsedargs;
+    if (!parse_struct_args(&parsedargs, typeargs)) {
+      return NULL;
+    }
+
+    // Instantiate the struct class with no arguments, then fill it in.
+    PyObject* ret = PyObject_CallObject(parsedargs.klass, NULL);
+    if (!ret) {
+      return NULL;
+    }
+
+    if (!decode_struct(input, ret, parsedargs.spec)) {
+      Py_DECREF(ret);
+      return NULL;
+    }
+
+    return ret;
+  }
+
+  case T_STOP:
+  case T_VOID:
+  case T_UTF16:
+  case T_UTF8:
+  case T_U64:
+  default:
+    PyErr_SetString(PyExc_TypeError, "Unexpected TType");
+    return NULL;
+  }
+}
+
+
+/* --- TOP-LEVEL WRAPPER FOR INPUT -- */
+
+/*
+ * Python-callable entry point taking three objects: the object to populate,
+ * the transport to read from, and the struct's type arguments.  Builds a
+ * DecodeBuffer from the transport's cstringio_buf and cstringio_refill
+ * attributes and decodes one struct into the output object.
+ *
+ * Returns None on success, NULL on failure.
+ */
+static PyObject*
+decode_binary(PyObject *self, PyObject *args) {
+  PyObject* output_obj = NULL;
+  PyObject* transport = NULL;
+  PyObject* typeargs = NULL;
+  StructTypeArgs parsedargs;
+  // Zero-initialized so free_decodebuf is safe on a partially filled buffer.
+  DecodeBuffer input = {};
+
+  if (!PyArg_ParseTuple(args, "OOO", &output_obj, &transport, &typeargs)) {
+    return NULL;
+  }
+
+  if (!parse_struct_args(&parsedargs, typeargs)) {
+    return NULL;
+  }
+
+  if (!decode_buffer_from_obj(&input, transport)) {
+    return NULL;
+  }
+
+  if (!decode_struct(&input, output_obj, parsedargs.spec)) {
+    free_decodebuf(&input);
+    return NULL;
+  }
+
+  free_decodebuf(&input);
+
+  Py_RETURN_NONE;
+}
+
+/* ====== END READING FUNCTIONS ====== */
+
+
+/* -- PYTHON MODULE SETUP STUFF --- */
+
+static PyMethodDef ThriftFastBinaryMethods[] = {
+
+  {"encode_binary",  encode_binary, METH_VARARGS, ""},
+  {"decode_binary",  decode_binary, METH_VARARGS, ""},
+
+  {NULL, NULL, 0, NULL}        /* Sentinel */
+};
+
+/* Module init: interns the attribute names used by decode_buffer_from_obj,
+ * imports the cStringIO C API, and registers the method table. */
+PyMODINIT_FUNC
+initfastbinary(void) {
+#define INIT_INTERN_STRING(value) \
+  do { \
+    INTERN_STRING(value) = PyString_InternFromString(#value); \
+    if(!INTERN_STRING(value)) return; \
+  } while(0)
+
+  INIT_INTERN_STRING(cstringio_buf);
+  INIT_INTERN_STRING(cstringio_refill);
+#undef INIT_INTERN_STRING
+
+  PycString_IMPORT;
+  if (PycStringIO == NULL) return;
+
+  (void)
Py_InitModule("thrift.protocol.fastbinary", ThriftFastBinaryMethods);
+}
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index 3c18221e..0f5bfdce 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -55,6 +55,34 @@ class TTransportBase:
   def flush(self):
     pass
 
+# This class should be thought of as an interface.
+class CReadableTransport:
+  """base class for transports that are readable from C"""
+
+  # TODO(dreiss): Think about changing this interface to allow us to use
+  #               a (Python, not c) StringIO instead, because it allows
+  #               you to write after reading.
+
+  # NOTE: This is a classic class, so properties will NOT work
+  #       correctly for setting.
+  @property
+  def cstringio_buf(self):
+    """A cStringIO buffer that contains the current chunk we are reading."""
+    pass
+
+  def cstringio_refill(self, partialread, reqlen):
+    """Refills cstringio_buf.
+
+    Returns the currently used buffer (which can but need not be the same as
+    the old cstringio_buf). partialread is what the C code has read from the
+    buffer, and should be inserted into the buffer before any more reads.  The
+    return value must be a new, not borrowed reference.  Something along the
+    lines of self._buf should be fine.
+
+    If reqlen bytes can't be read, throw EOFError.
+    """
+    pass
+
 class TServerTransportBase:
   """Base class for Thrift server transports."""
 
@@ -112,8 +140,14 @@ class TBufferedTransport(TTransportBase):
     self.__trans.flush()
     self.__buf = StringIO()
 
-class TMemoryBuffer(TTransportBase):
-  """Wraps a string object as a TTransport"""
+class TMemoryBuffer(TTransportBase, CReadableTransport):
+  """Wraps a cStringIO object as a TTransport.
+
+  NOTE: Unlike the C++ version of this class, you cannot write to it
+        then immediately read from it.  If you want to read from a
+        TMemoryBuffer, you must pass a string to the constructor.
+  TODO(dreiss): Make this work like the C++ version.
+  """
+
   def __init__(self, value=None):
     """value -- a value to read from for stringio
@@ -146,6 +180,15 @@ class TMemoryBuffer(TTransportBase):
   def getvalue(self):
     return self._buffer.getvalue()
 
+  # Implement the CReadableTransport interface.
+  @property
+  def cstringio_buf(self):
+    return self._buffer
+
+  def cstringio_refill(self, partialread, reqlen):
+    # only one shot at reading...
+    raise EOFError()
+
 class TFramedTransportFactory:
   """Factory transport that builds framed transports"""
 
@@ -193,7 +236,7 @@ class TFramedTransport(TTransportBase):
     buff = self.__trans.readAll(4)
     sz, = unpack('!i', buff)
     self.__rbuf = self.__trans.readAll(sz)
-    
+
   def write(self, buf):
     if self.__wbuf == None:
       return self.__trans.write(buf)
diff --git a/test/DebugProtoTest.thrift b/test/DebugProtoTest.thrift
index ac3b9b41..bbd86df0 100644
--- a/test/DebugProtoTest.thrift
+++ b/test/DebugProtoTest.thrift
@@ -36,3 +36,30 @@ struct HolyMoley {
   2: set<list<string>> contain,
   3: map<string,list<Bonk>> bonks,
 }
+
+struct Backwards {
+  2: i32 first_tag2,
+  1: i32 second_tag1,
+}
+
+struct Empty {
+}
+
+struct Wrapper {
+  1: Empty foo
+}
+
+struct RandomStuff {
+  1: i32 a,
+  2: i32 b,
+  3: i32 c,
+  4: i32 d,
+  5: list<i32> myintlist,
+  6: map<i32,Wrapper> maps,
+  7: i64 bigint,
+  8: double triple,
+}
+
+service Srv {
+  i32 Janky(i32 arg)
+}
diff --git a/test/FastbinaryTest.py b/test/FastbinaryTest.py
new file mode 100755
index 00000000..0918002f
--- /dev/null
+++ b/test/FastbinaryTest.py
@@ -0,0 +1,190 @@
+#!/usr/bin/env python
+r"""
+thrift -py DebugProtoTest.thrift
+./FastbinaryTest.py
+"""
+
+# TODO(dreiss): Test error cases.  Check for memory leaks.
+
+import sys
+sys.path.append('./gen-py')
+
+import math
+from DebugProtoTest import Srv
+from DebugProtoTest.ttypes import *
+from thrift.transport import TTransport
+from thrift.protocol import TBinaryProtocol
+
+import timeit
+from cStringIO import StringIO
+from copy import deepcopy
+from pprint import pprint
+
+# Transport stub handed to the protocols in doBenchmark; it only overrides
+# __init__ and isOpen.
+class TDevNullTransport(TTransport.TTransportBase):
+  def __init__(self):
+    pass
+  def isOpen(self):
+    return True
+
+# --- Fixture objects shared by doTest and doBenchmark ---
+
+ooe1 = OneOfEach()
+ooe1.im_true = True;
+ooe1.im_false = False;
+ooe1.a_bite = 0xd6;
+ooe1.integer16 = 27000;
+ooe1.integer32 = 1<<24;
+ooe1.integer64 = 6000 * 1000 * 1000;
+ooe1.double_precision = math.pi;
+ooe1.some_characters = "Debug THIS!";
+ooe1.zomg_unicode = "\xd7\n\a\t";
+
+ooe2 = OneOfEach();
+ooe2.integer16 = 16;
+ooe2.integer32 = 32;
+ooe2.integer64 = 64;
+ooe2.double_precision = (math.sqrt(5)+1)/2;
+ooe2.some_characters = ":R (me going \"rrrr\")";
+ooe2.zomg_unicode = "\xd3\x80\xe2\x85\xae\xce\x9d\x20"\
+                    "\xd0\x9d\xce\xbf\xe2\x85\xbf\xd0\xbe"\
+                    "\xc9\xa1\xd0\xb3\xd0\xb0\xcf\x81\xe2\x84\x8e"\
+                    "\x20\xce\x91\x74\x74\xce\xb1\xe2\x85\xbd\xce\xba"\
+                    "\xc7\x83\xe2\x80\xbc";
+
+hm = HolyMoley({"big":[], "contain":set(), "bonks":{}})
+hm.big.append(ooe1)
+hm.big.append(ooe2)
+hm.big[0].a_bite = 0x22;
+hm.big[1].a_bite = 0x22;
+
+hm.contain.add(("and a one", "and a two"))
+hm.contain.add(("then a one, two", "three!", "FOUR!"))
+hm.contain.add(())
+
+hm.bonks["nothing"] = [];
+hm.bonks["something"] = [
+  Bonk({"type":1, "message":"Wait."}),
+  Bonk({"type":2, "message":"What?"}),
+]
+hm.bonks["poe"] = [
+  Bonk({"type":3, "message":"quoth"}),
+  Bonk({"type":4, "message":"the raven"}),
+  Bonk({"type":5, "message":"nevermore"}),
+]
+
+rs = RandomStuff()
+rs.a = 1
+rs.b = 2
+rs.c = 3
+rs.myintlist = range(20)
+rs.maps = {1:Wrapper({"foo":Empty()}),2:Wrapper({"foo":Empty()})}
+rs.bigint = 124523452435L
+rs.triple = 3.14
+
+# The two keys were swapped between these constructors, which left both
+# fixtures with no field set.
+# NOTE(review): assumes the generated field names are "success" on
+# Janky_result and "arg" on Janky_args -- confirm against gen-py.
+my_zero = Srv.Janky_result({"success":5})
+my_nega = Srv.Janky_args({"arg":6})
+
+def checkWrite(o):
+  """Writes o through the accelerated and the plain binary protocol and
+  prints both serialized forms if they differ."""
+  trans_fast = TTransport.TMemoryBuffer()
+  trans_slow = TTransport.TMemoryBuffer()
+  prot_fast = TBinaryProtocol.TBinaryProtocolAccelerated(trans_fast)
+  prot_slow = TBinaryProtocol.TBinaryProtocol(trans_slow)
+
+  o.write(prot_fast)
+  o.write(prot_slow)
+  ORIG = trans_slow.getvalue()
+  MINE = trans_fast.getvalue()
+  if ORIG != MINE:
+    print "mine: %s\norig: %s" % (repr(MINE), repr(ORIG))
+
+def checkRead(o):
+  """Writes o with the plain protocol, reads it back through the accelerated
+  protocol into a fresh instance, and prints both if they compare unequal."""
+  prot = TBinaryProtocol.TBinaryProtocol(TTransport.TMemoryBuffer())
+  o.write(prot)
+  prot = TBinaryProtocol.TBinaryProtocolAccelerated(
+           TTransport.TMemoryBuffer(
+             prot.trans.getvalue()))
+  c = o.__class__()
+  c.read(prot)
+  if c != o:
+    print "copy: "
+    pprint(eval(repr(c)))
+    print "orig: "
+    pprint(eval(repr(o)))
+
+
+def doTest():
+  """Runs the write/read comparisons over the fixture objects."""
+  checkWrite(hm)
+  # hm is read back with its set field emptied.
+  no_set = deepcopy(hm)
+  no_set.contain = set()
+  checkRead(no_set)
+  checkWrite(rs)
+  checkRead(rs)
+  checkWrite(my_zero)
+  checkRead(my_zero)
+  checkRead(Backwards({"first_tag2":4, "second_tag1":2}))
+  try:
+    checkWrite(my_nega)
+    print "Hey, did this get fixed?"
+  except AttributeError:
+    # Sorry, doesn't work with negative tags.
+    pass
+
+  # One case where the serialized form changes, but only superficially.
+  o = Backwards({"first_tag2":4, "second_tag1":2})
+  trans_fast = TTransport.TMemoryBuffer()
+  trans_slow = TTransport.TMemoryBuffer()
+  prot_fast = TBinaryProtocol.TBinaryProtocolAccelerated(trans_fast)
+  prot_slow = TBinaryProtocol.TBinaryProtocol(trans_slow)
+
+  o.write(prot_fast)
+  o.write(prot_slow)
+  ORIG = trans_slow.getvalue()
+  MINE = trans_fast.getvalue()
+  if ORIG == MINE:
+    print "That shouldn't happen."
+
+
+  # Round trip in the other direction: accelerated write, plain read.
+  prot = TBinaryProtocol.TBinaryProtocolAccelerated(TTransport.TMemoryBuffer())
+  o.write(prot)
+  prot = TBinaryProtocol.TBinaryProtocol(
+           TTransport.TMemoryBuffer(
+             prot.trans.getvalue()))
+  c = o.__class__()
+  c.read(prot)
+  if c != o:
+    print "copy: "
+    pprint(eval(repr(c)))
+    print "orig: "
+    pprint(eval(repr(o)))
+
+
+
+def doBenchmark():
+  """Times hm.write and rs.write under the plain and accelerated protocols."""
+
+  iters = 25000
+
+  # %s is filled with "" or "Accelerated" to pick the protocol class.
+  setup = """
+from __main__ import hm, rs, TDevNullTransport
+from thrift.protocol import TBinaryProtocol
+trans = TDevNullTransport()
+prot = TBinaryProtocol.TBinaryProtocol%s(trans)
+"""
+
+  setup_fast = setup % "Accelerated"
+  setup_slow = setup % ""
+
+  print "Starting Benchmarks"
+
+  print "HolyMoley Standard = %f" % \
+      timeit.Timer('hm.write(prot)', setup_slow).timeit(number=iters)
+  print "HolyMoley Acceler. = %f" % \
+      timeit.Timer('hm.write(prot)', setup_fast).timeit(number=iters)
+
+  print "FastStruct Standard = %f" % \
+      timeit.Timer('rs.write(prot)', setup_slow).timeit(number=iters)
+  print "FastStruct Acceler. = %f" % \
+      timeit.Timer('rs.write(prot)', setup_fast).timeit(number=iters)
+
+
+
+doTest()
+doBenchmark()
+
-- 
2.17.1