From f4eec7a57b4c6ab08a545033fd3840586510ae8b Mon Sep 17 00:00:00 2001 From: Roger Meier Date: Sun, 11 Sep 2011 18:16:21 +0000 Subject: [PATCH] THRIFT-1115 python TBase class for dynamic (de)serialization, and __slots__ option for memory savings Patch: Will Pierce git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1169492 13f79535-47bb-0310-9956-ffa450edef68 --- .gitignore | 1 + compiler/cpp/src/generate/t_py_generator.cc | 203 +++++++++++++++----- lib/py/src/Thrift.py | 19 ++ lib/py/src/protocol/TBase.py | 72 +++++++ lib/py/src/protocol/TCompactProtocol.py | 19 ++ lib/py/src/protocol/TProtocol.py | 199 +++++++++++++++++++ lib/py/src/protocol/__init__.py | 2 +- test/ThriftTest.thrift | 18 ++ test/py/Makefile.am | 48 ++++- test/py/RunClientServer.py | 83 +++++--- test/py/SerializationTest.py | 130 +++++++++++-- test/py/TestClient.py | 29 +-- test/py/TestEof.py | 7 +- test/py/TestServer.py | 39 ++-- test/py/TestSocket.py | 7 +- test/py/TestSyntax.py | 7 +- 16 files changed, 758 insertions(+), 125 deletions(-) create mode 100644 lib/py/src/protocol/TBase.py diff --git a/.gitignore b/.gitignore index bfb61d2f..082f7b73 100644 --- a/.gitignore +++ b/.gitignore @@ -179,6 +179,7 @@ /test/py/Makefile /test/py/Makefile.in /test/py/gen-py +/test/py/gen-py-* /test/py.twisted/Makefile /test/py.twisted/Makefile.in /test/py.twisted/_trial_temp/ diff --git a/compiler/cpp/src/generate/t_py_generator.cc b/compiler/cpp/src/generate/t_py_generator.cc index 6a82bd7e..34acba4e 100644 --- a/compiler/cpp/src/generate/t_py_generator.cc +++ b/compiler/cpp/src/generate/t_py_generator.cc @@ -52,12 +52,44 @@ class t_py_generator : public t_generator { iter = parsed_options.find("new_style"); gen_newstyle_ = (iter != parsed_options.end()); + iter = parsed_options.find("slots"); + gen_slots_ = (iter != parsed_options.end()); + + iter = parsed_options.find("dynamic"); + gen_dynamic_ = (iter != parsed_options.end()); + + if (gen_dynamic_) { + gen_newstyle_ = 0; // dynamic is newstyle + gen_dynbaseclass_ = "TBase"; + gen_dynbaseclass_exc_ = "TExceptionBase"; + import_dynbase_ = "from thrift.protocol.TBase import TBase, TExceptionBase\n"; + } + + iter = parsed_options.find("dynbase"); + if (iter != parsed_options.end()) { + gen_dynbase_ = true; + gen_dynbaseclass_ = (iter->second); + } + + iter = parsed_options.find("dynexc"); + if (iter != parsed_options.end()) { + gen_dynbaseclass_exc_ = (iter->second); + } + + iter = parsed_options.find("dynimport"); + if (iter != parsed_options.end()) { + gen_dynbase_ = true; + import_dynbase_ = (iter->second); + } + iter = parsed_options.find("twisted"); gen_twisted_ = (iter != parsed_options.end()); iter = parsed_options.find("utf8strings"); gen_utf8strings_ = (iter != parsed_options.end()); + copy_options_ = option_string; + if (gen_twisted_){ out_dir_base_ = "gen-py.twisted"; } else { @@ -214,17 +246,32 @@ class t_py_generator : public t_generator { private: /** - * True iff we should generate new-style classes. + * True if we should generate new-style classes. */ bool gen_newstyle_; + /** + * True if we should generate dynamic style classes. + */ + bool gen_dynamic_; + + bool gen_dynbase_; + std::string gen_dynbaseclass_; + std::string gen_dynbaseclass_exc_; + + std::string import_dynbase_; + + bool gen_slots_; + + std::string copy_options_; + /** - * True iff we should generate Twisted-friendly RPC services. + * True if we should generate Twisted-friendly RPC services. */ bool gen_twisted_; /** - * True iff strings should be encoded using utf-8. + * True if strings should be encoded using utf-8. */ bool gen_utf8strings_; @@ -325,13 +372,19 @@ string t_py_generator::render_includes() { * Renders all the imports necessary to use the accelerated TBinaryProtocol */ string t_py_generator::render_fastbinary_includes() { - return - "from thrift.transport import TTransport\n" - "from thrift.protocol import TBinaryProtocol, TProtocol\n" - "try:\n" - " from thrift.protocol import fastbinary\n" - "except:\n" - " fastbinary = None\n"; + string hdr = ""; + if (gen_dynamic_) { + hdr += std::string(import_dynbase_); + } else { + hdr += + "from thrift.transport import TTransport\n" + "from thrift.protocol import TBinaryProtocol, TProtocol\n" + "try:\n" + " from thrift.protocol import fastbinary\n" + "except:\n" + " fastbinary = None\n"; + } + return hdr; } /** @@ -343,6 +396,8 @@ string t_py_generator::py_autogen_comment() { "# Autogenerated by Thrift Compiler (" + THRIFT_VERSION + ")\n" + "#\n" + "# DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING\n" + + "#\n" + + "# options string: " + copy_options_ + "\n" + "#\n"; } @@ -351,7 +406,7 @@ string t_py_generator::py_autogen_comment() { */ string t_py_generator::py_imports() { return - string("from thrift.Thrift import *"); + string("from thrift.Thrift import TType, TMessageType"); } /** @@ -384,6 +439,7 @@ void t_py_generator::generate_enum(t_enum* tenum) { f_types_ << "class " << tenum->get_name() << (gen_newstyle_ ? "(object)" : "") << + (gen_dynamic_ ? "(" + gen_dynbaseclass_ + ")" : "") << ":" << endl; indent_up(); generate_python_docstring(f_types_, tenum); @@ -575,12 +631,19 @@ void t_py_generator::generate_py_struct_definition(ofstream& out, out << std::endl << "class " << tstruct->get_name(); if (is_exception) { - out << "(Exception)"; - } else if (gen_newstyle_) { - out << "(object)"; + if (gen_dynamic_) { + out << "(" << gen_dynbaseclass_exc_ << ")"; + } else { + out << "(Exception)"; + } + } else { + if (gen_newstyle_) { + out << "(object)"; + } else if (gen_dynamic_) { + out << "(" << gen_dynbaseclass_ << ")"; + } } - out << - ":" << endl; + out << ":" << endl; indent_up(); generate_python_docstring(out, tstruct); @@ -606,6 +669,17 @@ void t_py_generator::generate_py_struct_definition(ofstream& out, TODO(dreiss): Consider making this work for structs with negative tags. */ + if (gen_slots_) { + indent(out) << "__slots__ = [ " << endl; + indent_up(); + for (m_iter = sorted_members.begin(); m_iter != sorted_members.end(); ++m_iter) { + indent(out) << "'" << (*m_iter)->get_name() << "'," << endl; + } + indent_down(); + indent(out) << " ]" << endl << endl; + + } + // TODO(dreiss): Look into generating an empty tuple instead of None // for structures with no members. // TODO(dreiss): Test encoding of structs where some inner structs @@ -672,8 +746,10 @@ void t_py_generator::generate_py_struct_definition(ofstream& out, out << endl; } - generate_py_struct_reader(out, tstruct); - generate_py_struct_writer(out, tstruct); + if (!gen_dynamic_) { + generate_py_struct_reader(out, tstruct); + generate_py_struct_writer(out, tstruct); + } // For exceptions only, generate a __str__ method. This is // because when raised exceptions are printed to the console, __repr__ @@ -685,31 +761,61 @@ void t_py_generator::generate_py_struct_definition(ofstream& out, endl; } - // Printing utilities so that on the command line thrift - // structs look pretty like dictionaries - out << - indent() << "def __repr__(self):" << endl << - indent() << " L = ['%s=%r' % (key, value)" << endl << - indent() << " for key, value in self.__dict__.iteritems()]" << endl << - indent() << " return '%s(%s)' % (self.__class__.__name__, ', '.join(L))" << endl << - endl; + if (!gen_slots_) { + // Printing utilities so that on the command line thrift + // structs look pretty like dictionaries + out << + indent() << "def __repr__(self):" << endl << + indent() << " L = ['%s=%r' % (key, value)" << endl << + indent() << " for key, value in self.__dict__.iteritems()]" << endl << + indent() << " return '%s(%s)' % (self.__class__.__name__, ', '.join(L))" << endl << + endl; - // Equality and inequality methods that compare by value - out << - indent() << "def __eq__(self, other):" << endl; - indent_up(); - out << - indent() << "return isinstance(other, self.__class__) and " - "self.__dict__ == other.__dict__" << endl; - indent_down(); - out << endl; + // Equality and inequality methods that compare by value + out << + indent() << "def __eq__(self, other):" << endl; + indent_up(); + out << + indent() << "return isinstance(other, self.__class__) and " + "self.__dict__ == other.__dict__" << endl; + indent_down(); + out << endl; - out << - indent() << "def __ne__(self, other):" << endl; - indent_up(); - out << - indent() << "return not (self == other)" << endl; - indent_down(); + out << + indent() << "def __ne__(self, other):" << endl; + indent_up(); + + out << + indent() << "return not (self == other)" << endl; + indent_down(); + } else if (!gen_dynamic_) { + // no base class available to implement __eq__ and __repr__ and __ne__ for us + // so we must provide one that uses __slots__ + out << + indent() << "def __repr__(self):" << endl << + indent() << " L = ['%s=%r' % (key, getattr(self, key))" << endl << + indent() << " for key in self.__slots__]" << endl << + indent() << " return '%s(%s)' % (self.__class__.__name__, ', '.join(L))" << endl << + endl; + + // Equality method that compares each attribute by value and type, walking __slots__ + out << + indent() << "def __eq__(self, other):" << endl << + indent() << " if not isinstance(other, self.__class__):" << endl << + indent() << " return False" << endl << + indent() << " for attr in self.__slots__:" << endl << + indent() << " my_val = getattr(self, attr)" << endl << + indent() << " other_val = getattr(other, attr)" << endl << + indent() << " if my_val != other_val:" << endl << + indent() << " return False" << endl << + indent() << " return True" << endl << + endl; + + out << + indent() << "def __ne__(self, other):" << endl << + indent() << " return not (self == other)" << endl << + endl; + } indent_down(); } @@ -984,7 +1090,7 @@ void t_py_generator::generate_service_interface(t_service* tservice) { } else { if (gen_twisted_) { extends_if = "(Interface)"; - } else if (gen_newstyle_) { + } else if (gen_newstyle_ || gen_dynamic_) { extends_if = "(object)"; } } @@ -1031,8 +1137,8 @@ void t_py_generator::generate_service_client(t_service* tservice) { extends_client = extends + ".Client, "; } } else { - if (gen_twisted_ && gen_newstyle_) { - extends_client = "(object)"; + if (gen_twisted_ && (gen_newstyle_ || gen_dynamic_)) { + extends_client = "(object)"; } } @@ -2388,6 +2494,11 @@ string t_py_generator::type_to_spec_args(t_type* ttype) { THRIFT_REGISTER_GENERATOR(py, "Python", " new_style: Generate new-style classes.\n" \ " twisted: Generate Twisted-friendly RPC services.\n" \ -" utf8strings: Encode/decode strings using utf8 in the generated code.\n" -) +" utf8strings: Encode/decode strings using utf8 in the generated code.\n" \ +" slots: Generate code using slots for instance members.\n" \ +" dynamic: Generate dynamic code, less code generated but slower.\n" \ +" dynbase=CLS Derive generated classes from class CLS instead of TBase.\n" \ +" dynexc=CLS Derive generated exceptions from CLS instead of TExceptionBase.\n" \ +" dynimport='from foo.bar import CLS'\n" \ +" Add an import line to generated code to find the dynbase class.\n") diff --git a/lib/py/src/Thrift.py b/lib/py/src/Thrift.py index af6f58da..1d271fcf 100644 --- a/lib/py/src/Thrift.py +++ b/lib/py/src/Thrift.py @@ -38,6 +38,25 @@ class TType: UTF8 = 16 UTF16 = 17 + _VALUES_TO_NAMES = ( 'STOP', + 'VOID', + 'BOOL', + 'BYTE', + 'DOUBLE', + None, + 'I16', + None, + 'I32', + None, + 'I64', + 'STRING', + 'STRUCT', + 'MAP', + 'SET', + 'LIST', + 'UTF8', + 'UTF16' ) + class TMessageType: CALL = 1 REPLY = 2 diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py new file mode 100644 index 00000000..e675c7dc --- /dev/null +++ b/lib/py/src/protocol/TBase.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from thrift.Thrift import * +from thrift.protocol import TBinaryProtocol +from thrift.transport import TTransport + +try: + from thrift.protocol import fastbinary +except: + fastbinary = None + +class TBase(object): + __slots__ = [] + + def __repr__(self): + L = ['%s=%r' % (key, getattr(self, key)) + for key in self.__slots__ ] + return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + for attr in self.__slots__: + my_val = getattr(self, attr) + other_val = getattr(other, attr) + if my_val != other_val: + return False + return True + + def __ne__(self, other): + return not (self == other) + + def read(self, iprot): + if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None and fastbinary is not None: + fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec)) + return + iprot.readStruct(self, self.thrift_spec) + + def write(self, oprot): + if oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and self.thrift_spec is not None and fastbinary is not None: + oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) + return + oprot.writeStruct(self, self.thrift_spec) + +class TExceptionBase(Exception): + # old style class so python2.4 can raise exceptions derived from this + # This can't inherit from TBase because of that limitation. + __slots__ = [] + + __repr__ = TBase.__repr__.im_func + __eq__ = TBase.__eq__.im_func + __ne__ = TBase.__ne__.im_func + read = TBase.read.im_func + write = TBase.write.im_func + diff --git a/lib/py/src/protocol/TCompactProtocol.py b/lib/py/src/protocol/TCompactProtocol.py index 6d57aeba..016a3317 100644 --- a/lib/py/src/protocol/TCompactProtocol.py +++ b/lib/py/src/protocol/TCompactProtocol.py @@ -1,3 +1,22 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + from TProtocol import * from struct import pack, unpack diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py index be3cb140..7338ff68 100644 --- a/lib/py/src/protocol/TProtocol.py +++ b/lib/py/src/protocol/TProtocol.py @@ -200,6 +200,205 @@ class TProtocolBase: self.skip(etype) self.readListEnd() + # tuple of: ( 'reader method' name, is_container boolean, 'writer_method' name ) + _TTYPE_HANDLERS = ( + (None, None, False), # 0 == TType,STOP + (None, None, False), # 1 == TType.VOID # TODO: handle void? + ('readBool', 'writeBool', False), # 2 == TType.BOOL + ('readByte', 'writeByte', False), # 3 == TType.BYTE and I08 + ('readDouble', 'writeDouble', False), # 4 == TType.DOUBLE + (None, None, False), # 5, undefined + ('readI16', 'writeI16', False), # 6 == TType.I16 + (None, None, False), # 7, undefined + ('readI32', 'writeI32', False), # 8 == TType.I32 + (None, None, False), # 9, undefined + ('readI64', 'writeI64', False), # 10 == TType.I64 + ('readString', 'writeString', False), # 11 == TType.STRING and UTF7 + ('readContainerStruct', 'writeContainerStruct', True), # 12 == TType.STRUCT + ('readContainerMap', 'writeContainerMap', True), # 13 == TType.MAP + ('readContainerSet', 'writeContainerSet', True), # 14 == TType.SET + ('readContainerList', 'writeContainerList', True), # 15 == TType.LIST + (None, None, False), # 16 == TType.UTF8 # TODO: handle utf8 types? + (None, None, False)# 17 == TType.UTF16 # TODO: handle utf16 types? + ) + + def readFieldByTType(self, ttype, spec): + try: + (r_handler, w_handler, is_container) = self._TTYPE_HANDLERS[ttype] + except IndexError: + raise TProtocolException(type=TProtocolException.INVALID_DATA, + message='Invalid field type %d' % (ttype)) + if r_handler is None: + raise TProtocolException(type=TProtocolException.INVALID_DATA, + message='Invalid field type %d' % (ttype)) + reader = getattr(self, r_handler) + if not is_container: + return reader() + return reader(spec) + + def readContainerList(self, spec): + results = [] + ttype, tspec = spec[0], spec[1] + r_handler = self._TTYPE_HANDLERS[ttype][0] + reader = getattr(self, r_handler) + (list_type, list_len) = self.readListBegin() + if tspec is None: + # list values are simple types + for idx in xrange(list_len): + results.append(reader()) + else: + # this is like an inlined readFieldByTType + container_reader = self._TTYPE_HANDLERS[list_type][0] + val_reader = getattr(self, container_reader) + for idx in xrange(list_len): + val = val_reader(tspec) + results.append(val) + self.readListEnd() + return results + + def readContainerSet(self, spec): + results = set() + ttype, tspec = spec[0], spec[1] + r_handler = self._TTYPE_HANDLERS[ttype][0] + reader = getattr(self, r_handler) + (set_type, set_len) = self.readSetBegin() + if tspec is None: + # set members are simple types + for idx in xrange(set_len): + results.add(reader()) + else: + container_reader = self._TTYPE_HANDLERS[set_type][0] + val_reader = getattr(self, container_reader) + for idx in xrange(set_len): + results.add(val_reader(tspec)) + self.readSetEnd() + return results + + def readContainerStruct(self, spec): + (obj_class, obj_spec) = spec + obj = obj_class() + obj.read(self) + return obj + + def readContainerMap(self, spec): + results = dict() + key_ttype, key_spec = spec[0], spec[1] + val_ttype, val_spec = spec[2], spec[3] + (map_ktype, map_vtype, map_len) = self.readMapBegin() + # TODO: compare types we just decoded with thrift_spec and abort/skip if types disagree + key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0]) + val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0]) + # list values are simple types + for idx in xrange(map_len): + if key_spec is None: + k_val = key_reader() + else: + k_val = self.readFieldByTType(key_ttype, key_spec) + if val_spec is None: + v_val = val_reader() + else: + v_val = self.readFieldByTType(val_ttype, val_spec) + # this raises a TypeError with unhashable keys types. i.e. d=dict(); d[[0,1]] = 2 fails + results[k_val] = v_val + self.readMapEnd() + return results + + def readStruct(self, obj, thrift_spec): + self.readStructBegin() + while True: + (fname, ftype, fid) = self.readFieldBegin() + if ftype == TType.STOP: + break + try: + field = thrift_spec[fid] + except IndexError: + self.skip(ftype) + else: + if field is not None and ftype == field[1]: + fname = field[2] + fspec = field[3] + val = self.readFieldByTType(ftype, fspec) + setattr(obj, fname, val) + else: + self.skip(ftype) + self.readFieldEnd() + self.readStructEnd() + + def writeContainerStruct(self, val, spec): + val.write(self) + + def writeContainerList(self, val, spec): + self.writeListBegin(spec[0], len(val)) + r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] + e_writer = getattr(self, w_handler) + if not is_container: + for elem in val: + e_writer(elem) + else: + for elem in val: + e_writer(elem, spec[1]) + self.writeListEnd() + + def writeContainerSet(self, val, spec): + self.writeSetBegin(spec[0], len(val)) + r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] + e_writer = getattr(self, w_handler) + if not is_container: + for elem in val: + e_writer(elem) + else: + for elem in val: + e_writer(elem, spec[1]) + self.writeSetEnd() + + def writeContainerMap(self, val, spec): + k_type = spec[0] + v_type = spec[2] + ignore, ktype_name, k_is_container = self._TTYPE_HANDLERS[k_type] + ignore, vtype_name, v_is_container = self._TTYPE_HANDLERS[v_type] + k_writer = getattr(self, ktype_name) + v_writer = getattr(self, vtype_name) + self.writeMapBegin(k_type, v_type, len(val)) + for m_key, m_val in val.iteritems(): + if not k_is_container: + k_writer(m_key) + else: + k_writer(m_key, spec[1]) + if not v_is_container: + v_writer(m_val) + else: + v_writer(m_val, spec[3]) + self.writeMapEnd() + + def writeStruct(self, obj, thrift_spec): + self.writeStructBegin(obj.__class__.__name__) + for field in thrift_spec: + if field is None: + continue + fname = field[2] + val = getattr(obj, fname) + if val is None: + # skip writing out unset fields + continue + fid = field[0] + ftype = field[1] + fspec = field[3] + # get the writer method for this value + self.writeFieldBegin(fname, ftype, fid) + self.writeFieldByTType(ftype, val, fspec) + self.writeFieldEnd() + self.writeFieldStop() + self.writeStructEnd() + + def writeFieldByTType(self, ttype, val, spec): + r_handler, w_handler, is_container = self._TTYPE_HANDLERS[ttype] + writer = getattr(self, w_handler) + if is_container: + writer(val, spec) + else: + writer(val) + class TProtocolFactory: def getProtocol(self, trans): pass + diff --git a/lib/py/src/protocol/__init__.py b/lib/py/src/protocol/__init__.py index 01bfe18e..d53359b2 100644 --- a/lib/py/src/protocol/__init__.py +++ b/lib/py/src/protocol/__init__.py @@ -17,4 +17,4 @@ # under the License. # -__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary'] +__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary', 'TBase'] diff --git a/test/ThriftTest.thrift b/test/ThriftTest.thrift index b6cd939a..17b0295c 100644 --- a/test/ThriftTest.thrift +++ b/test/ThriftTest.thrift @@ -208,3 +208,21 @@ struct LargeDeltas { 3000: VersioningTestV2 vertwo3000, 4000: list big_numbers } + +struct NestedListsI32x2 { + 1: list> integerlist +} +struct NestedListsI32x3 { + 1: list>> integerlist +} +struct NestedMixedx2 { + 1: list> int_set_list + 2: map> map_int_strset + 3: list>> map_int_strset_list +} +struct ListBonks { + 1: list bonk +} +struct NestedListsBonk { + 1: list>> bonk +} diff --git a/test/py/Makefile.am b/test/py/Makefile.am index 63b7a890..2317ef61 100644 --- a/test/py/Makefile.am +++ b/test/py/Makefile.am @@ -19,22 +19,30 @@ THRIFT = $(top_srcdir)/compiler/cpp/thrift -py_unit_tests = \ - SerializationTest.py \ - TestEof.py \ - TestSyntax.py \ - RunClientServer.py +py_unit_tests = RunClientServer.py thrift_gen = \ gen-py/ThriftTest/__init__.py \ - gen-py/DebugProtoTest/__init__.py + gen-py/DebugProtoTest/__init__.py \ + gen-py-default/ThriftTest/__init__.py \ + gen-py-default/DebugProtoTest/__init__.py \ + gen-py-slots/ThriftTest/__init__.py \ + gen-py-slots/DebugProtoTest/__init__.py \ + gen-py-newstyle/ThriftTest/__init__.py \ + gen-py-newstyle/DebugProtoTest/__init__.py \ + gen-py-newstyleslots/ThriftTest/__init__.py \ + gen-py-newstyleslots/DebugProtoTest/__init__.py \ + gen-py-dynamic/ThriftTest/__init__.py \ + gen-py-dynamic/DebugProtoTest/__init__.py \ + gen-py-dynamicslots/ThriftTest/__init__.py \ + gen-py-dynamicslots/DebugProtoTest/__init__.py helper_scripts= \ TestClient.py \ TestServer.py check_SCRIPTS= \ - $(thrift_gen) \ + $(thrift_gen) \ $(py_unit_tests) \ $(helper_scripts) @@ -42,7 +50,29 @@ TESTS= $(py_unit_tests) gen-py/%/__init__.py: ../%.thrift - $(THRIFT) --gen py $< + $(THRIFT) --gen py $< + test -d gen-py-default || mkdir gen-py-default + $(THRIFT) --gen py -out gen-py-default $< + +gen-py-slots/%/__init__.py: ../%.thrift + test -d gen-py-slots || mkdir gen-py-slots + $(THRIFT) --gen py:slots -out gen-py-slots $< + +gen-py-newstyle/%/__init__.py: ../%.thrift + test -d gen-py-newstyle || mkdir gen-py-newstyle + $(THRIFT) --gen py:new_style -out gen-py-newstyle $< + +gen-py-newstyleslots/%/__init__.py: ../%.thrift + test -d gen-py-newstyleslots || mkdir gen-py-newstyleslots + $(THRIFT) --gen py:new_style,slots -out gen-py-newstyleslots $< + +gen-py-dynamic/%/__init__.py: ../%.thrift + test -d gen-py-dynamic || mkdir gen-py-dynamic + $(THRIFT) --gen py:dynamic -out gen-py-dynamic $< + +gen-py-dynamicslots/%/__init__.py: ../%.thrift + test -d gen-py-dynamicslots || mkdir gen-py-dynamicslots + $(THRIFT) --gen py:dynamic,slots -out gen-py-dynamicslots $< clean-local: - $(RM) -r gen-py + $(RM) -r gen-py gen-py-slots gen-py-default gen-py-newstyle gen-py-newstyleslots gen-py-dynamic gen-py-dynamicslots diff --git a/test/py/RunClientServer.py b/test/py/RunClientServer.py index 633856f5..8a7fda64 100755 --- a/test/py/RunClientServer.py +++ b/test/py/RunClientServer.py @@ -28,6 +28,9 @@ import signal from optparse import OptionParser parser = OptionParser() +parser.add_option('--genpydirs', type='string', dest='genpydirs', + default='default,slots,newstyle,newstyleslots,dynamic,dynamicslots', + help='directory extensions for generated code, used as suffixes for \"gen-py-*\" added sys.path for individual tests') parser.add_option("--port", type="int", dest="port", default=9090, help="port number for server to listen on") parser.add_option('-v', '--verbose', action="store_const", @@ -39,11 +42,15 @@ parser.add_option('-q', '--quiet', action="store_const", parser.set_defaults(verbose=1) options, args = parser.parse_args() +generated_dirs = [] +for gp_dir in options.genpydirs.split(','): + generated_dirs.append('gen-py-%s' % (gp_dir)) + +SCRIPTS = ['SerializationTest.py', 'TestEof.py', 'TestSyntax.py', 'TestSocket.py'] FRAMED = ["TNonblockingServer"] SKIP_ZLIB = ['TNonblockingServer', 'THttpServer'] SKIP_SSL = ['TNonblockingServer', 'THttpServer'] -EXTRA_DELAY = ['TProcessPoolServer'] -EXTRA_SLEEP = 3.5 +EXTRA_DELAY = dict(TProcessPoolServer=3.5) PROTOS= [ 'accel', @@ -85,11 +92,21 @@ if len(args) == 1: def relfile(fname): return os.path.join(os.path.dirname(__file__), fname) -def runTest(server_class, proto, port, use_zlib, use_ssl): +def runScriptTest(genpydir, script): + script_args = [sys.executable, relfile(script) ] + script_args.append('--genpydir=%s' % genpydir) + serverproc = subprocess.Popen(script_args) + print '\nTesting script: %s\n----' % (' '.join(script_args)) + ret = subprocess.call(script_args) + if ret != 0: + raise Exception("Script subprocess failed, retcode=%d, args: %s" % (ret, ' '.join(script_args))) + +def runServiceTest(genpydir, server_class, proto, port, use_zlib, use_ssl): # Build command line arguments server_args = [sys.executable, relfile('TestServer.py') ] cli_args = [sys.executable, relfile('TestClient.py') ] for which in (server_args, cli_args): + which.append('--genpydir=%s' % genpydir) which.append('--proto=%s' % proto) # accel, binary or compact which.append('--port=%d' % port) # default to 9090 if use_zlib: @@ -110,7 +127,7 @@ def runTest(server_class, proto, port, use_zlib, use_ssl): if options.verbose > 0: print 'Testing server %s: %s' % (server_class, ' '.join(server_args)) serverproc = subprocess.Popen(server_args) - time.sleep(0.2) + time.sleep(0.15) try: if options.verbose > 0: print 'Testing client: %s' % (' '.join(cli_args)) @@ -124,29 +141,47 @@ def runTest(server_class, proto, port, use_zlib, use_ssl): print 'FAIL: Server process (%s) failed with retcode %d' % (' '.join(server_args), serverproc.returncode) raise Exception('Server subprocess %s died, args: %s' % (server_class, ' '.join(server_args))) else: - if server_class in EXTRA_DELAY: - if options.verbose > 0: - print 'Giving %s (proto=%s,zlib=%s,ssl=%s) an extra %d seconds for child processes to terminate via alarm' % (server_class, - proto, use_zlib, use_ssl, EXTRA_SLEEP) - time.sleep(EXTRA_SLEEP) + extra_sleep = EXTRA_DELAY.get(server_class, 0) + if extra_sleep > 0 and options.verbose > 0: + print 'Giving %s (proto=%s,zlib=%s,ssl=%s) an extra %d seconds for child processes to terminate via alarm' % (server_class, + proto, use_zlib, use_ssl, extra_sleep) + time.sleep(extra_sleep) os.kill(serverproc.pid, signal.SIGKILL) # wait for shutdown - time.sleep(0.1) + time.sleep(0.05) test_count = 0 +# run tests without a client/server first +print '----------------' +print ' Executing individual test scripts with various generated code directories' +print ' Directories to be tested: ' + ', '.join(generated_dirs) +print ' Scripts to be tested: ' + ', '.join(SCRIPTS) +print '----------------' +for genpydir in generated_dirs: + for script in SCRIPTS: + runScriptTest(genpydir, script) + +print '----------------' +print ' Executing Client/Server tests with various generated code directories' +print ' Servers to be tested: ' + ', '.join(SERVERS) +print ' Directories to be tested: ' + ', '.join(generated_dirs) +print ' Protocols to be tested: ' + ', '.join(PROTOS) +print ' Options to be tested: ZLIB(yes/no), SSL(yes/no)' +print '----------------' for try_server in SERVERS: - for try_proto in PROTOS: - for with_zlib in (False, True): - # skip any servers that don't work with the Zlib transport - if with_zlib and try_server in SKIP_ZLIB: - continue - for with_ssl in (False, True): - # skip any servers that don't work with SSL - if with_ssl and try_server in SKIP_SSL: + for genpydir in generated_dirs: + for try_proto in PROTOS: + for with_zlib in (False, True): + # skip any servers that don't work with the Zlib transport + if with_zlib and try_server in SKIP_ZLIB: continue - test_count += 1 - if options.verbose > 0: - print '\nTest run #%d: Server=%s, Proto=%s, zlib=%s, SSL=%s' % (test_count, try_server, try_proto, with_zlib, with_ssl) - runTest(try_server, try_proto, options.port, with_zlib, with_ssl) - if options.verbose > 0: - print 'OK: Finished %s / %s proto / zlib=%s / SSL=%s. %d combinations tested.' % (try_server, try_proto, with_zlib, with_ssl, test_count) + for with_ssl in (False, True): + # skip any servers that don't work with SSL + if with_ssl and try_server in SKIP_SSL: + continue + test_count += 1 + if options.verbose > 0: + print '\nTest run #%d: (includes %s) Server=%s, Proto=%s, zlib=%s, SSL=%s' % (test_count, genpydir, try_server, try_proto, with_zlib, with_ssl) + runServiceTest(genpydir, try_server, try_proto, options.port, with_zlib, with_ssl) + if options.verbose > 0: + print 'OK: Finished (includes %s) %s / %s proto / zlib=%s / SSL=%s. %d combinations tested.' % (genpydir, try_server, try_proto, with_zlib, with_ssl, test_count) diff --git a/test/py/SerializationTest.py b/test/py/SerializationTest.py index 3ba76fba..06641461 100755 --- a/test/py/SerializationTest.py +++ b/test/py/SerializationTest.py @@ -20,7 +20,12 @@ # import sys, glob -sys.path.insert(0, './gen-py') +from optparse import OptionParser +parser = OptionParser() +parser.add_option('--genpydir', type='string', dest='genpydir', default='gen-py') +options, args = parser.parse_args() +del sys.argv[1:] # clean up hack so unittest doesn't complain +sys.path.insert(0, options.genpydir) sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0]) from ThriftTest.ttypes import * @@ -119,28 +124,86 @@ class AbstractTest(unittest.TestCase): byte_list_map={0 : [], 1 : [1], 2 : [1, 2]}, ) + self.nested_lists_i32x2 = NestedListsI32x2( + [ + [ 1, 1, 2 ], + [ 2, 7, 9 ], + [ 3, 5, 8 ] + ] + ) + + self.nested_lists_i32x3 = NestedListsI32x3( + [ + [ + [ 2, 7, 9 ], + [ 3, 5, 8 ] + ], + [ + [ 1, 1, 2 ], + [ 1, 4, 9 ] + ] + ] + ) + + self.nested_mixedx2 = NestedMixedx2( int_set_list=[ + set([1,2,3]), + set([1,4,9]), + set([1,2,3,5,8,13,21]), + set([-1, 0, 1]) + ], + # note, the sets below are sets of chars, since the strings are iterated + map_int_strset={ 10:set('abc'), 20:set('def'), 30:set('GHI') }, + map_int_strset_list=[ + { 10:set('abc'), 20:set('def'), 30:set('GHI') }, + { 100:set('lmn'), 200:set('opq'), 300:set('RST') }, + { 1000:set('uvw'), 2000:set('wxy'), 3000:set('XYZ') } + ] + ) + + self.nested_lists_bonk = NestedListsBonk( + [ + [ + [ + Bonk(message='inner A first', type=1), + Bonk(message='inner A second', type=1) + ], + [ + Bonk(message='inner B first', type=2), + Bonk(message='inner B second', type=2) + ] + ] + ] + ) + + self.list_bonks = ListBonks( + [ + Bonk(message='inner A', type=1), + Bonk(message='inner B', type=2), + Bonk(message='inner C', type=0) + ] + ) def _serialize(self, obj): - trans = TTransport.TMemoryBuffer() - prot = self.protocol_factory.getProtocol(trans) - obj.write(prot) - return trans.getvalue() + trans = TTransport.TMemoryBuffer() + prot = self.protocol_factory.getProtocol(trans) + obj.write(prot) + return trans.getvalue() def _deserialize(self, objtype, data): - prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data)) - ret = objtype() - ret.read(prot) - return ret + prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data)) + ret = objtype() + ret.read(prot) + return ret def testForwards(self): - obj = self._deserialize(VersioningTestV2, self._serialize(self.v1obj)) - self.assertEquals(obj.begin_in_both, self.v1obj.begin_in_both) - self.assertEquals(obj.end_in_both, self.v1obj.end_in_both) + obj = self._deserialize(VersioningTestV2, self._serialize(self.v1obj)) + self.assertEquals(obj.begin_in_both, self.v1obj.begin_in_both) + self.assertEquals(obj.end_in_both, self.v1obj.end_in_both) def testBackwards(self): - obj = self._deserialize(VersioningTestV1, self._serialize(self.v2obj)) - self.assertEquals(obj.begin_in_both, self.v2obj.begin_in_both) - self.assertEquals(obj.end_in_both, self.v2obj.end_in_both) + obj = self._deserialize(VersioningTestV1, self._serialize(self.v2obj)) + self.assertEquals(obj.begin_in_both, self.v2obj.begin_in_both) + self.assertEquals(obj.end_in_both, self.v2obj.end_in_both) def testSerializeV1(self): obj = self._deserialize(VersioningTestV1, self._serialize(self.v1obj)) @@ -152,20 +215,57 @@ class AbstractTest(unittest.TestCase): def testBools(self): self.assertNotEquals(self.bools, self.bools_flipped) + self.assertNotEquals(self.bools, self.v1obj) obj = self._deserialize(Bools, self._serialize(self.bools)) self.assertEquals(obj, self.bools) obj = self._deserialize(Bools, self._serialize(self.bools_flipped)) self.assertEquals(obj, self.bools_flipped) + rep = repr(self.bools) + self.assertTrue(len(rep) > 0) def testLargeDeltas(self): # test large field deltas (meaningful in CompactProto only) obj = self._deserialize(LargeDeltas, self._serialize(self.large_deltas)) self.assertEquals(obj, self.large_deltas) + rep = repr(self.large_deltas) + self.assertTrue(len(rep) > 0) + + def testNestedListsI32x2(self): + obj = self._deserialize(NestedListsI32x2, self._serialize(self.nested_lists_i32x2)) + self.assertEquals(obj, self.nested_lists_i32x2) + rep = repr(self.nested_lists_i32x2) + self.assertTrue(len(rep) > 0) + + def testNestedListsI32x3(self): + obj = self._deserialize(NestedListsI32x3, self._serialize(self.nested_lists_i32x3)) + self.assertEquals(obj, self.nested_lists_i32x3) + rep = repr(self.nested_lists_i32x3) + self.assertTrue(len(rep) > 0) + + def testNestedMixedx2(self): + obj = self._deserialize(NestedMixedx2, self._serialize(self.nested_mixedx2)) + self.assertEquals(obj, self.nested_mixedx2) + rep = repr(self.nested_mixedx2) + self.assertTrue(len(rep) > 0) + + def testNestedListsBonk(self): + obj = self._deserialize(NestedListsBonk, self._serialize(self.nested_lists_bonk)) + self.assertEquals(obj, self.nested_lists_bonk) + rep = repr(self.nested_lists_bonk) + self.assertTrue(len(rep) > 0) + + def testListBonks(self): + obj = self._deserialize(ListBonks, self._serialize(self.list_bonks)) + self.assertEquals(obj, self.list_bonks) + rep = repr(self.list_bonks) + self.assertTrue(len(rep) > 0) def testCompactStruct(self): # test large field deltas (meaningful in CompactProto only) obj = self._deserialize(CompactProtoTestStruct, self._serialize(self.compact_struct)) self.assertEquals(obj, self.compact_struct) + rep = repr(self.compact_struct) + self.assertTrue(len(rep) > 0) class NormalBinaryTest(AbstractTest): protocol_factory = TBinaryProtocol.TBinaryProtocolFactory() diff --git a/test/py/TestClient.py b/test/py/TestClient.py index 6429ec37..e5d43269 100755 --- a/test/py/TestClient.py +++ b/test/py/TestClient.py @@ -20,23 +20,16 @@ # import sys, glob -sys.path.insert(0, './gen-py') sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0]) -from ThriftTest import ThriftTest -from ThriftTest.ttypes import * -from thrift.transport import TTransport -from thrift.transport import TSocket -from thrift.transport import THttpClient -from thrift.transport import TZlibTransport -from thrift.protocol import TBinaryProtocol -from thrift.protocol import TCompactProtocol import unittest import time from optparse import OptionParser - parser = OptionParser() +parser.add_option('--genpydir', type='string', dest='genpydir', + default='gen-py', + help='include this local directory in sys.path for locating generated code') parser.add_option("--port", type="int", dest="port", help="connect to server at port") parser.add_option("--host", type="string", dest="host", @@ -60,6 +53,17 @@ parser.add_option('--proto', dest="proto", type="string", parser.set_defaults(framed=False, http_path=None, verbose=1, host='localhost', port=9090, proto='binary') options, args = parser.parse_args() +sys.path.insert(0, options.genpydir) + +from ThriftTest import ThriftTest +from ThriftTest.ttypes import * +from thrift.transport import TTransport +from thrift.transport import TSocket +from thrift.transport import THttpClient +from thrift.transport import TZlibTransport +from thrift.protocol import TBinaryProtocol +from thrift.protocol import TCompactProtocol + class AbstractTest(unittest.TestCase): def setUp(self): if options.http_path: @@ -176,6 +180,9 @@ class AbstractTest(unittest.TestCase): except Xception, x: self.assertEqual(x.errorCode, 1001) self.assertEqual(x.message, 'Xception') + # ensure exception's repr method works + x_repr = repr(x) + self.assertEqual(x_repr, 'Xception(errorCode=1001, message=\'Xception\')') try: self.client.testException("throw_undeclared") @@ -225,4 +232,4 @@ class OwnArgsTestProgram(unittest.TestProgram): self.createTests() if __name__ == "__main__": - OwnArgsTestProgram(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2)) + OwnArgsTestProgram(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=1)) diff --git a/test/py/TestEof.py b/test/py/TestEof.py index 7ff0b427..a9d81f1a 100755 --- a/test/py/TestEof.py +++ b/test/py/TestEof.py @@ -20,7 +20,12 @@ # import sys, glob -sys.path.insert(0, './gen-py') +from optparse import OptionParser +parser = OptionParser() +parser.add_option('--genpydir', type='string', dest='genpydir', default='gen-py') +options, args = parser.parse_args() +del sys.argv[1:] # clean up hack so unittest doesn't complain +sys.path.insert(0, options.genpydir) sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0]) from ThriftTest import ThriftTest diff --git a/test/py/TestServer.py b/test/py/TestServer.py index fa627650..6f4af440 100755 --- a/test/py/TestServer.py +++ b/test/py/TestServer.py @@ -20,24 +20,13 @@ # from __future__ import division import sys, glob, time -sys.path.insert(0, './gen-py') sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0]) from optparse import OptionParser -from ThriftTest import ThriftTest -from ThriftTest.ttypes import * -from thrift.transport import TTransport -from thrift.transport import TSocket -from thrift.transport import TZlibTransport -from thrift.protocol import TBinaryProtocol -from thrift.protocol import TCompactProtocol -from thrift.server import TServer, TNonblockingServer, THttpServer - -PROT_FACTORIES = {'binary': TBinaryProtocol.TBinaryProtocolFactory, - 'accel': TBinaryProtocol.TBinaryProtocolAcceleratedFactory, - 'compact': TCompactProtocol.TCompactProtocolFactory} - parser = OptionParser() +parser.add_option('--genpydir', type='string', dest='genpydir', + default='gen-py', + help='include this local directory in sys.path for locating generated code') parser.add_option("--port", type="int", dest="port", help="port number for server to listen on") parser.add_option("--zlib", action="store_true", dest="zlib", @@ -55,6 +44,21 @@ parser.add_option('--proto', dest="proto", type="string", parser.set_defaults(port=9090, verbose=1, proto='binary') options, args = parser.parse_args() +sys.path.insert(0, options.genpydir) + +from ThriftTest import ThriftTest +from ThriftTest.ttypes import * +from thrift.transport import TTransport +from thrift.transport import TSocket +from thrift.transport import TZlibTransport +from thrift.protocol import TBinaryProtocol +from thrift.protocol import TCompactProtocol +from thrift.server import TServer, TNonblockingServer, THttpServer + +PROT_FACTORIES = {'binary': TBinaryProtocol.TBinaryProtocolFactory, + 'accel': TBinaryProtocol.TBinaryProtocolAcceleratedFactory, + 'compact': TCompactProtocol.TCompactProtocolFactory} + class TestHandler: def testVoid(self): @@ -105,7 +109,7 @@ class TestHandler: x.message = str raise x elif str == "throw_undeclared": - raise ValueError("foo") + raise ValueError("Exception test PASSES.") def testOneway(self, seconds): if options.verbose > 1: @@ -206,7 +210,10 @@ elif server_type == "TProcessPoolServer": worker.terminate() if options.verbose > 0: print 'Requesting server to stop()' - server.stop() + try: + server.stop() + except: + pass signal.signal(signal.SIGALRM, clean_shutdown) signal.alarm(2) set_alarm() diff --git a/test/py/TestSocket.py b/test/py/TestSocket.py index 2f7353fb..b9bdf27e 100755 --- a/test/py/TestSocket.py +++ b/test/py/TestSocket.py @@ -20,7 +20,12 @@ # import sys, glob -sys.path.insert(0, './gen-py') +from optparse import OptionParser +parser = OptionParser() +parser.add_option('--genpydir', type='string', dest='genpydir', default='gen-py') +options, args = parser.parse_args() +del sys.argv[1:] # clean up hack so unittest doesn't complain +sys.path.insert(0, options.genpydir) sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0]) from ThriftTest import ThriftTest diff --git a/test/py/TestSyntax.py b/test/py/TestSyntax.py index df67d485..9f71cf58 100755 --- a/test/py/TestSyntax.py +++ b/test/py/TestSyntax.py @@ -20,7 +20,12 @@ # import sys, glob -sys.path.insert(0, './gen-py') +from optparse import OptionParser +parser = OptionParser() +parser.add_option('--genpydir', type='string', dest='genpydir', default='gen-py') +options, args = parser.parse_args() +del sys.argv[1:] # clean up hack so unittest doesn't complain +sys.path.insert(0, options.genpydir) sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0]) # Just import these generated files to make sure they are syntactically valid -- 2.17.1