THRIFT-212. python: Make TFramedTransport implement CReadableTransport
authorDavid Reiss <dreiss@apache.org>
Sat, 31 Jan 2009 21:39:25 +0000 (21:39 +0000)
committerDavid Reiss <dreiss@apache.org>
Sat, 31 Jan 2009 21:39:25 +0000 (21:39 +0000)
This involved adding a few methods to provide lower-level access to the
internal read buffer.  This will allow us to use TBinaryProtocolAccelerated
with TFramedTransport.

git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@739632 13f79535-47bb-0310-9956-ffa450edef68

lib/py/src/transport/TTransport.py
test/py/SerializationTest.py

index ddb368f..67f97ba 100644 (file)
@@ -228,7 +228,7 @@ class TFramedTransportFactory:
     return framed
 
 
-class TFramedTransport(TTransportBase):
+class TFramedTransport(TTransportBase, CReadableTransport):
 
   """Class that wraps another transport and frames its I/O when writing."""
 
@@ -274,3 +274,18 @@ class TFramedTransport(TTransportBase):
     buf = pack("!i", wsz) + wout
     self.__trans.write(buf)
     self.__trans.flush()
+
+  # Implement the CReadableTransport interface.
+  @property
+  def cstringio_buf(self):
+    return self.__rbuf
+
+  def cstringio_refill(self, prefix, reqlen):
+    # self.__rbuf will already be empty here because fastbinary doesn't
+    # ask for a refill until the previous buffer is empty.  Therefore,
+    # we can start reading new frames immediately.
+    while len(prefix) < reqlen:
+      readFrame()
+      prefix += self.__rbuf.getvalue()
+    self.__rbuf = StringIO(prefix)
+    return self.__rbuf
index 4be8b8c..a99bce6 100755 (executable)
@@ -64,12 +64,49 @@ class AcceleratedBinaryTest(AbstractTest):
   protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory()
 
 
+class AcceleratedFramedTest(unittest.TestCase):
+  def testSplit(self):
+    """Test FramedTransport and BinaryProtocolAccelerated
+
+    Tests that TBinaryProtocolAccelerated and TFramedTransport
+    play nicely together when a read spans a frame"""
+
+    protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory()
+    bigstring = "".join(chr(byte) for byte in range(ord("a"), ord("z")+1))
+
+    databuf = TTransport.TMemoryBuffer()
+    prot = protocol_factory.getProtocol(databuf)
+    prot.writeI32(42)
+    prot.writeString(bigstring)
+    prot.writeI16(24)
+    data = databuf.getvalue()
+    cutpoint = len(data)/2
+    parts = [ data[:cutpoint], data[cutpoint:] ]
+
+    framed_buffer = TTransport.TMemoryBuffer()
+    framed_writer = TTransport.TFramedTransport(framed_buffer)
+    for part in parts:
+      framed_writer.write(part)
+      framed_writer.flush()
+    self.assertEquals(len(framed_buffer.getvalue()), len(data) + 8)
+
+    # Recreate framed_buffer so we can read from it.
+    framed_buffer = TTransport.TMemoryBuffer(framed_buffer.getvalue())
+    framed_reader = TTransport.TFramedTransport(framed_buffer)
+    prot = protocol_factory.getProtocol(framed_reader)
+    self.assertEqual(prot.readI32(), 42)
+    self.assertEqual(prot.readString(), bigstring)
+    self.assertEqual(prot.readI16(), 24)
+
+
+
 def suite():
   suite = unittest.TestSuite()
   loader = unittest.TestLoader()
 
   suite.addTest(loader.loadTestsFromTestCase(NormalBinaryTest))
   suite.addTest(loader.loadTestsFromTestCase(AcceleratedBinaryTest))
+  suite.addTest(loader.loadTestsFromTestCase(AcceleratedFramedTest))
   return suite
 
 if __name__ == "__main__":