THRIFT-1115 python TBase class for dynamic (de)serialization, and __slots__ option for memory savings
Patch: Will Pierce

git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1169492 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/test/ThriftTest.thrift b/test/ThriftTest.thrift
index b6cd939..17b0295 100644
--- a/test/ThriftTest.thrift
+++ b/test/ThriftTest.thrift
@@ -208,3 +208,21 @@
   3000: VersioningTestV2 vertwo3000,
   4000: list<i32> big_numbers
 }
+
+struct NestedListsI32x2 {
+  1: list<list<i32>> integerlist
+}
+struct NestedListsI32x3 {
+  1: list<list<list<i32>>> integerlist
+}
+struct NestedMixedx2 {
+  1: list<set<i32>> int_set_list
+  2: map<i32,set<string>> map_int_strset
+  3: list<map<i32,set<string>>> map_int_strset_list
+}
+struct ListBonks {
+  1: list<Bonk> bonk
+}
+struct NestedListsBonk {
+  1: list<list<list<Bonk>>> bonk
+}
diff --git a/test/py/Makefile.am b/test/py/Makefile.am
index 63b7a89..2317ef6 100644
--- a/test/py/Makefile.am
+++ b/test/py/Makefile.am
@@ -19,22 +19,30 @@
 
 THRIFT = $(top_srcdir)/compiler/cpp/thrift
 
-py_unit_tests =                                 \
-        SerializationTest.py                    \
-        TestEof.py                              \
-        TestSyntax.py                           \
-        RunClientServer.py
+py_unit_tests = RunClientServer.py
 
 thrift_gen =                                    \
         gen-py/ThriftTest/__init__.py           \
-        gen-py/DebugProtoTest/__init__.py
+        gen-py/DebugProtoTest/__init__.py \
+        gen-py-default/ThriftTest/__init__.py           \
+        gen-py-default/DebugProtoTest/__init__.py \
+        gen-py-slots/ThriftTest/__init__.py           \
+        gen-py-slots/DebugProtoTest/__init__.py \
+        gen-py-newstyle/ThriftTest/__init__.py           \
+        gen-py-newstyle/DebugProtoTest/__init__.py \
+        gen-py-newstyleslots/ThriftTest/__init__.py           \
+        gen-py-newstyleslots/DebugProtoTest/__init__.py \
+        gen-py-dynamic/ThriftTest/__init__.py           \
+        gen-py-dynamic/DebugProtoTest/__init__.py \
+        gen-py-dynamicslots/ThriftTest/__init__.py           \
+        gen-py-dynamicslots/DebugProtoTest/__init__.py
 
 helper_scripts=                                 \
         TestClient.py                           \
         TestServer.py
 
 check_SCRIPTS=                                  \
-        $(thrift_gen)                           \
+        $(thrift_gen) \
         $(py_unit_tests)                        \
         $(helper_scripts)
 
@@ -42,7 +50,29 @@
 
 
 gen-py/%/__init__.py: ../%.thrift
-	$(THRIFT) --gen py $<
+	$(THRIFT) --gen py  $<
+	test -d gen-py-default || mkdir gen-py-default
+	$(THRIFT) --gen py -out gen-py-default $<
+
+gen-py-slots/%/__init__.py: ../%.thrift
+	test -d gen-py-slots || mkdir gen-py-slots
+	$(THRIFT) --gen py:slots -out gen-py-slots $<
+
+gen-py-newstyle/%/__init__.py: ../%.thrift
+	test -d gen-py-newstyle || mkdir gen-py-newstyle
+	$(THRIFT) --gen py:new_style -out gen-py-newstyle $<
+
+gen-py-newstyleslots/%/__init__.py: ../%.thrift
+	test -d gen-py-newstyleslots || mkdir gen-py-newstyleslots
+	$(THRIFT) --gen py:new_style,slots -out gen-py-newstyleslots $<
+
+gen-py-dynamic/%/__init__.py: ../%.thrift
+	test -d gen-py-dynamic || mkdir gen-py-dynamic
+	$(THRIFT) --gen py:dynamic -out gen-py-dynamic $<
+
+gen-py-dynamicslots/%/__init__.py: ../%.thrift
+	test -d gen-py-dynamicslots || mkdir gen-py-dynamicslots
+	$(THRIFT) --gen py:dynamic,slots -out gen-py-dynamicslots $<
 
 clean-local:
-	$(RM) -r gen-py
+	$(RM) -r gen-py gen-py-slots gen-py-default gen-py-newstyle gen-py-newstyleslots gen-py-dynamic gen-py-dynamicslots
diff --git a/test/py/RunClientServer.py b/test/py/RunClientServer.py
index 633856f..8a7fda6 100755
--- a/test/py/RunClientServer.py
+++ b/test/py/RunClientServer.py
@@ -28,6 +28,9 @@
 from optparse import OptionParser
 
 parser = OptionParser()
+parser.add_option('--genpydirs', type='string', dest='genpydirs',
+    default='default,slots,newstyle,newstyleslots,dynamic,dynamicslots',
+    help='directory extensions for generated code, used as suffixes for \"gen-py-*\" added sys.path for individual tests')
 parser.add_option("--port", type="int", dest="port", default=9090,
     help="port number for server to listen on")
 parser.add_option('-v', '--verbose', action="store_const", 
@@ -39,11 +42,15 @@
 parser.set_defaults(verbose=1)
 options, args = parser.parse_args()
 
+generated_dirs = []
+for gp_dir in options.genpydirs.split(','):
+  generated_dirs.append('gen-py-%s' % (gp_dir))
+
+SCRIPTS = ['SerializationTest.py', 'TestEof.py', 'TestSyntax.py', 'TestSocket.py']
 FRAMED = ["TNonblockingServer"]
 SKIP_ZLIB = ['TNonblockingServer', 'THttpServer']
 SKIP_SSL = ['TNonblockingServer', 'THttpServer']
-EXTRA_DELAY = ['TProcessPoolServer']
-EXTRA_SLEEP = 3.5
+EXTRA_DELAY = dict(TProcessPoolServer=3.5)
 
 PROTOS= [
     'accel',
@@ -85,11 +92,21 @@
 def relfile(fname):
     return os.path.join(os.path.dirname(__file__), fname)
 
-def runTest(server_class, proto, port, use_zlib, use_ssl):
+def runScriptTest(genpydir, script):
+  """Run one standalone test script with --genpydir=<genpydir>; raise Exception if it exits non-zero."""
+  script_args = [sys.executable, relfile(script), '--genpydir=%s' % genpydir]
+  print '\nTesting script: %s\n----' % (' '.join(script_args))
+  # subprocess.call both starts the script and waits for it, so the script runs exactly once.
+  ret = subprocess.call(script_args)
+  if ret != 0:
+    raise Exception("Script subprocess failed, retcode=%d, args: %s" % (ret, ' '.join(script_args)))
+  
+def runServiceTest(genpydir, server_class, proto, port, use_zlib, use_ssl):
   # Build command line arguments
   server_args = [sys.executable, relfile('TestServer.py') ]
   cli_args = [sys.executable, relfile('TestClient.py') ]
   for which in (server_args, cli_args):
+    which.append('--genpydir=%s' % genpydir)
     which.append('--proto=%s' % proto) # accel, binary or compact
     which.append('--port=%d' % port) # default to 9090
     if use_zlib:
@@ -110,7 +127,7 @@
   if options.verbose > 0:
     print 'Testing server %s: %s' % (server_class, ' '.join(server_args))
   serverproc = subprocess.Popen(server_args)
-  time.sleep(0.2)
+  time.sleep(0.15)
   try:
     if options.verbose > 0:
       print 'Testing client: %s' % (' '.join(cli_args))
@@ -124,29 +141,47 @@
       print 'FAIL: Server process (%s) failed with retcode %d' % (' '.join(server_args), serverproc.returncode)
       raise Exception('Server subprocess %s died, args: %s' % (server_class, ' '.join(server_args)))
     else:
-      if server_class in EXTRA_DELAY:
-        if options.verbose > 0:
-          print 'Giving %s (proto=%s,zlib=%s,ssl=%s) an extra %d seconds for child processes to terminate via alarm' % (server_class,
-                proto, use_zlib, use_ssl, EXTRA_SLEEP)
-        time.sleep(EXTRA_SLEEP)
+      extra_sleep = EXTRA_DELAY.get(server_class, 0)  # per-server grace period in seconds; 0 means none
+      if extra_sleep > 0:
+        if options.verbose > 0:  # verbosity gates only the message, never the sleep itself
+          print 'Giving %s (proto=%s,zlib=%s,ssl=%s) an extra %d seconds for child processes to terminate via alarm' % (server_class, proto, use_zlib, use_ssl, extra_sleep)
+        time.sleep(extra_sleep)
       os.kill(serverproc.pid, signal.SIGKILL)
   # wait for shutdown
-  time.sleep(0.1)
+  time.sleep(0.05)
 
 test_count = 0
+# run tests without a client/server first
+print '----------------'
+print ' Executing individual test scripts with various generated code directories'
+print ' Directories to be tested: ' + ', '.join(generated_dirs)
+print ' Scripts to be tested: ' + ', '.join(SCRIPTS)
+print '----------------'
+for genpydir in generated_dirs:
+  for script in SCRIPTS:
+    runScriptTest(genpydir, script)
+  
+print '----------------'
+print ' Executing Client/Server tests with various generated code directories'
+print ' Servers to be tested: ' + ', '.join(SERVERS)
+print ' Directories to be tested: ' + ', '.join(generated_dirs)
+print ' Protocols to be tested: ' + ', '.join(PROTOS)
+print ' Options to be tested: ZLIB(yes/no), SSL(yes/no)'
+print '----------------'
 for try_server in SERVERS:
-  for try_proto in PROTOS:
-    for with_zlib in (False, True):
-      # skip any servers that don't work with the Zlib transport
-      if with_zlib and try_server in SKIP_ZLIB:
-        continue
-      for with_ssl in (False, True):
-        # skip any servers that don't work with SSL
-        if with_ssl and try_server in SKIP_SSL:
+  for genpydir in generated_dirs:
+    for try_proto in PROTOS:
+      for with_zlib in (False, True):
+        # skip any servers that don't work with the Zlib transport
+        if with_zlib and try_server in SKIP_ZLIB:
           continue
-        test_count += 1
-        if options.verbose > 0:
-          print '\nTest run #%d:  Server=%s,  Proto=%s,  zlib=%s,  SSL=%s' % (test_count, try_server, try_proto, with_zlib, with_ssl)
-        runTest(try_server, try_proto, options.port, with_zlib, with_ssl)
-        if options.verbose > 0:
-          print 'OK: Finished  %s / %s proto / zlib=%s / SSL=%s.   %d combinations tested.' % (try_server, try_proto, with_zlib, with_ssl, test_count)
+        for with_ssl in (False, True):
+          # skip any servers that don't work with SSL
+          if with_ssl and try_server in SKIP_SSL:
+            continue
+          test_count += 1
+          if options.verbose > 0:
+            print '\nTest run #%d:  (includes %s) Server=%s,  Proto=%s,  zlib=%s,  SSL=%s' % (test_count, genpydir, try_server, try_proto, with_zlib, with_ssl)
+          runServiceTest(genpydir, try_server, try_proto, options.port, with_zlib, with_ssl)
+          if options.verbose > 0:
+            print 'OK: Finished (includes %s)  %s / %s proto / zlib=%s / SSL=%s.   %d combinations tested.' % (genpydir, try_server, try_proto, with_zlib, with_ssl, test_count)
diff --git a/test/py/SerializationTest.py b/test/py/SerializationTest.py
index 3ba76fb..0664146 100755
--- a/test/py/SerializationTest.py
+++ b/test/py/SerializationTest.py
@@ -20,7 +20,12 @@
 #
 
 import sys, glob
-sys.path.insert(0, './gen-py')
+from optparse import OptionParser
+parser = OptionParser()
+parser.add_option('--genpydir', type='string', dest='genpydir', default='gen-py')
+options, args = parser.parse_args()
+del sys.argv[1:] # clean up hack so unittest doesn't complain
+sys.path.insert(0, options.genpydir)
 sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
 
 from ThriftTest.ttypes import *
@@ -119,28 +124,86 @@
           byte_list_map={0 : [], 1 : [1], 2 : [1, 2]},
           )
 
+      self.nested_lists_i32x2 = NestedListsI32x2(
+                                              [
+                                                [ 1, 1, 2 ],
+                                                [ 2, 7, 9 ],
+                                                [ 3, 5, 8 ]
+                                              ]
+                                            )
+
+      self.nested_lists_i32x3 = NestedListsI32x3(
+                                              [
+                                                [
+                                                  [ 2, 7, 9 ],
+                                                  [ 3, 5, 8 ]
+                                                ],
+                                                [
+                                                  [ 1, 1, 2 ],
+                                                  [ 1, 4, 9 ]
+                                                ]
+                                              ]
+                                            )
+
+      self.nested_mixedx2 = NestedMixedx2( int_set_list=[
+                                            set([1,2,3]),
+                                            set([1,4,9]),
+                                            set([1,2,3,5,8,13,21]),
+                                            set([-1, 0, 1])
+                                            ],
+                                            # note, the sets below are sets of chars, since the strings are iterated
+                                            map_int_strset={ 10:set('abc'), 20:set('def'), 30:set('GHI') },
+                                            map_int_strset_list=[
+                                                                 { 10:set('abc'), 20:set('def'), 30:set('GHI') },
+                                                                 { 100:set('lmn'), 200:set('opq'), 300:set('RST') },
+                                                                 { 1000:set('uvw'), 2000:set('wxy'), 3000:set('XYZ') }
+                                                                 ]
+                                          )
+
+      self.nested_lists_bonk = NestedListsBonk(
+                                              [
+                                                [
+                                                  [
+                                                    Bonk(message='inner A first', type=1),
+                                                    Bonk(message='inner A second', type=1)
+                                                  ],
+                                                  [
+                                                  Bonk(message='inner B first', type=2),
+                                                  Bonk(message='inner B second', type=2)
+                                                  ]
+                                                ]
+                                              ]
+                                            )
+
+      self.list_bonks = ListBonks(
+                                    [
+                                      Bonk(message='inner A', type=1),
+                                      Bonk(message='inner B', type=2),
+                                      Bonk(message='inner C', type=0)
+                                    ]
+                                  )
 
   def _serialize(self, obj):
-      trans = TTransport.TMemoryBuffer()
-      prot = self.protocol_factory.getProtocol(trans)
-      obj.write(prot)
-      return trans.getvalue()
+    trans = TTransport.TMemoryBuffer()
+    prot = self.protocol_factory.getProtocol(trans)
+    obj.write(prot)
+    return trans.getvalue()
 
   def _deserialize(self, objtype, data):
-      prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
-      ret = objtype()
-      ret.read(prot)
-      return ret
+    prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
+    ret = objtype()
+    ret.read(prot)
+    return ret
 
   def testForwards(self):
-      obj = self._deserialize(VersioningTestV2, self._serialize(self.v1obj))
-      self.assertEquals(obj.begin_in_both, self.v1obj.begin_in_both)
-      self.assertEquals(obj.end_in_both, self.v1obj.end_in_both)
+    obj = self._deserialize(VersioningTestV2, self._serialize(self.v1obj))
+    self.assertEquals(obj.begin_in_both, self.v1obj.begin_in_both)
+    self.assertEquals(obj.end_in_both, self.v1obj.end_in_both)
 
   def testBackwards(self):
-      obj = self._deserialize(VersioningTestV1, self._serialize(self.v2obj))
-      self.assertEquals(obj.begin_in_both, self.v2obj.begin_in_both)
-      self.assertEquals(obj.end_in_both, self.v2obj.end_in_both)
+    obj = self._deserialize(VersioningTestV1, self._serialize(self.v2obj))
+    self.assertEquals(obj.begin_in_both, self.v2obj.begin_in_both)
+    self.assertEquals(obj.end_in_both, self.v2obj.end_in_both)
 
   def testSerializeV1(self):
     obj = self._deserialize(VersioningTestV1, self._serialize(self.v1obj))
@@ -152,20 +215,57 @@
 
   def testBools(self):
     self.assertNotEquals(self.bools, self.bools_flipped)
+    self.assertNotEquals(self.bools, self.v1obj)
     obj = self._deserialize(Bools, self._serialize(self.bools))
     self.assertEquals(obj, self.bools)
     obj = self._deserialize(Bools, self._serialize(self.bools_flipped))
     self.assertEquals(obj, self.bools_flipped)
+    rep = repr(self.bools)
+    self.assertTrue(len(rep) > 0)
 
   def testLargeDeltas(self):
     # test large field deltas (meaningful in CompactProto only)
     obj = self._deserialize(LargeDeltas, self._serialize(self.large_deltas))
     self.assertEquals(obj, self.large_deltas)
+    rep = repr(self.large_deltas)
+    self.assertTrue(len(rep) > 0)
+
+  def testNestedListsI32x2(self):
+    obj = self._deserialize(NestedListsI32x2, self._serialize(self.nested_lists_i32x2))
+    self.assertEquals(obj, self.nested_lists_i32x2)
+    rep = repr(self.nested_lists_i32x2)
+    self.assertTrue(len(rep) > 0)
+
+  def testNestedListsI32x3(self):
+    obj = self._deserialize(NestedListsI32x3, self._serialize(self.nested_lists_i32x3))
+    self.assertEquals(obj, self.nested_lists_i32x3)
+    rep = repr(self.nested_lists_i32x3)
+    self.assertTrue(len(rep) > 0)
+
+  def testNestedMixedx2(self):
+    obj = self._deserialize(NestedMixedx2, self._serialize(self.nested_mixedx2))
+    self.assertEquals(obj, self.nested_mixedx2)
+    rep = repr(self.nested_mixedx2)
+    self.assertTrue(len(rep) > 0)
+
+  def testNestedListsBonk(self):
+    obj = self._deserialize(NestedListsBonk, self._serialize(self.nested_lists_bonk))
+    self.assertEquals(obj, self.nested_lists_bonk)
+    rep = repr(self.nested_lists_bonk)
+    self.assertTrue(len(rep) > 0)
+
+  def testListBonks(self):
+    obj = self._deserialize(ListBonks, self._serialize(self.list_bonks))
+    self.assertEquals(obj, self.list_bonks)
+    rep = repr(self.list_bonks)
+    self.assertTrue(len(rep) > 0)
 
   def testCompactStruct(self):
     # test large field deltas (meaningful in CompactProto only)
     obj = self._deserialize(CompactProtoTestStruct, self._serialize(self.compact_struct))
     self.assertEquals(obj, self.compact_struct)
+    rep = repr(self.compact_struct)
+    self.assertTrue(len(rep) > 0)
 
 class NormalBinaryTest(AbstractTest):
   protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()
diff --git a/test/py/TestClient.py b/test/py/TestClient.py
index 6429ec3..e5d4326 100755
--- a/test/py/TestClient.py
+++ b/test/py/TestClient.py
@@ -20,23 +20,16 @@
 #
 
 import sys, glob
-sys.path.insert(0, './gen-py')
 sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
 
-from ThriftTest import ThriftTest
-from ThriftTest.ttypes import *
-from thrift.transport import TTransport
-from thrift.transport import TSocket
-from thrift.transport import THttpClient
-from thrift.transport import TZlibTransport
-from thrift.protocol import TBinaryProtocol
-from thrift.protocol import TCompactProtocol
 import unittest
 import time
 from optparse import OptionParser
 
-
 parser = OptionParser()
+parser.add_option('--genpydir', type='string', dest='genpydir',
+                  default='gen-py',
+                  help='include this local directory in sys.path for locating generated code')
 parser.add_option("--port", type="int", dest="port",
     help="connect to server at port")
 parser.add_option("--host", type="string", dest="host",
@@ -60,6 +53,17 @@
 parser.set_defaults(framed=False, http_path=None, verbose=1, host='localhost', port=9090, proto='binary')
 options, args = parser.parse_args()
 
+sys.path.insert(0, options.genpydir)
+
+from ThriftTest import ThriftTest
+from ThriftTest.ttypes import *
+from thrift.transport import TTransport
+from thrift.transport import TSocket
+from thrift.transport import THttpClient
+from thrift.transport import TZlibTransport
+from thrift.protocol import TBinaryProtocol
+from thrift.protocol import TCompactProtocol
+
 class AbstractTest(unittest.TestCase):
   def setUp(self):
     if options.http_path:
@@ -176,6 +180,9 @@
     except Xception, x:
       self.assertEqual(x.errorCode, 1001)
       self.assertEqual(x.message, 'Xception')
+      # ensure exception's repr method works
+      x_repr = repr(x)
+      self.assertEqual(x_repr, 'Xception(errorCode=1001, message=\'Xception\')')
 
     try:
       self.client.testException("throw_undeclared")
@@ -225,4 +232,4 @@
         self.createTests()
 
 if __name__ == "__main__":
-  OwnArgsTestProgram(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2))
+  OwnArgsTestProgram(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=1))
diff --git a/test/py/TestEof.py b/test/py/TestEof.py
index 7ff0b42..a9d81f1 100755
--- a/test/py/TestEof.py
+++ b/test/py/TestEof.py
@@ -20,7 +20,12 @@
 #
 
 import sys, glob
-sys.path.insert(0, './gen-py')
+from optparse import OptionParser
+parser = OptionParser()
+parser.add_option('--genpydir', type='string', dest='genpydir', default='gen-py')
+options, args = parser.parse_args()
+del sys.argv[1:] # clean up hack so unittest doesn't complain
+sys.path.insert(0, options.genpydir)
 sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
 
 from ThriftTest import ThriftTest
diff --git a/test/py/TestServer.py b/test/py/TestServer.py
index fa62765..6f4af44 100755
--- a/test/py/TestServer.py
+++ b/test/py/TestServer.py
@@ -20,24 +20,13 @@
 #
 from __future__ import division
 import sys, glob, time
-sys.path.insert(0, './gen-py')
 sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
 from optparse import OptionParser
 
-from ThriftTest import ThriftTest
-from ThriftTest.ttypes import *
-from thrift.transport import TTransport
-from thrift.transport import TSocket
-from thrift.transport import TZlibTransport
-from thrift.protocol import TBinaryProtocol
-from thrift.protocol import TCompactProtocol
-from thrift.server import TServer, TNonblockingServer, THttpServer
-
-PROT_FACTORIES = {'binary': TBinaryProtocol.TBinaryProtocolFactory,
-    'accel': TBinaryProtocol.TBinaryProtocolAcceleratedFactory,
-    'compact': TCompactProtocol.TCompactProtocolFactory}
-
 parser = OptionParser()
+parser.add_option('--genpydir', type='string', dest='genpydir',
+                  default='gen-py',
+                  help='include this local directory in sys.path for locating generated code')
 parser.add_option("--port", type="int", dest="port",
     help="port number for server to listen on")
 parser.add_option("--zlib", action="store_true", dest="zlib",
@@ -55,6 +44,21 @@
 parser.set_defaults(port=9090, verbose=1, proto='binary')
 options, args = parser.parse_args()
 
+sys.path.insert(0, options.genpydir)
+
+from ThriftTest import ThriftTest
+from ThriftTest.ttypes import *
+from thrift.transport import TTransport
+from thrift.transport import TSocket
+from thrift.transport import TZlibTransport
+from thrift.protocol import TBinaryProtocol
+from thrift.protocol import TCompactProtocol
+from thrift.server import TServer, TNonblockingServer, THttpServer
+
+PROT_FACTORIES = {'binary': TBinaryProtocol.TBinaryProtocolFactory,
+    'accel': TBinaryProtocol.TBinaryProtocolAcceleratedFactory,
+    'compact': TCompactProtocol.TCompactProtocolFactory}
+
 class TestHandler:
 
   def testVoid(self):
@@ -105,7 +109,7 @@
       x.message = str
       raise x
     elif str == "throw_undeclared":
-      raise ValueError("foo")
+      raise ValueError("Exception test PASSES.")
 
   def testOneway(self, seconds):
     if options.verbose > 1:
@@ -206,7 +210,10 @@
         worker.terminate()
       if options.verbose > 0:
         print 'Requesting server to stop()'
-      server.stop()
+      try:
+        server.stop()
+      except Exception, x:  # best-effort stop; SystemExit/KeyboardInterrupt still propagate
+        if options.verbose > 0: print 'Ignoring error from server.stop(): %r' % (x,)
     signal.signal(signal.SIGALRM, clean_shutdown)
     signal.alarm(2)
   set_alarm()
diff --git a/test/py/TestSocket.py b/test/py/TestSocket.py
index 2f7353f..b9bdf27 100755
--- a/test/py/TestSocket.py
+++ b/test/py/TestSocket.py
@@ -20,7 +20,12 @@
 #
 
 import sys, glob
-sys.path.insert(0, './gen-py')
+from optparse import OptionParser
+parser = OptionParser()
+parser.add_option('--genpydir', type='string', dest='genpydir', default='gen-py')
+options, args = parser.parse_args()
+del sys.argv[1:] # clean up hack so unittest doesn't complain
+sys.path.insert(0, options.genpydir)
 sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
 
 from ThriftTest import ThriftTest
diff --git a/test/py/TestSyntax.py b/test/py/TestSyntax.py
index df67d48..9f71cf5 100755
--- a/test/py/TestSyntax.py
+++ b/test/py/TestSyntax.py
@@ -20,7 +20,12 @@
 #
 
 import sys, glob
-sys.path.insert(0, './gen-py')
+from optparse import OptionParser
+parser = OptionParser()
+parser.add_option('--genpydir', type='string', dest='genpydir', default='gen-py')
+options, args = parser.parse_args()
+del sys.argv[1:] # clean up hack so unittest doesn't complain
+sys.path.insert(0, options.genpydir)
 sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
 
 # Just import these generated files to make sure they are syntactically valid