From: Aditya Agarwal Date: Tue, 6 Feb 2007 01:14:33 +0000 (+0000) Subject: -- Protocol and transport factories now wrap around a single protocol/transport X-Git-Tag: 0.2.0~1497 X-Git-Url: https://source.supwisdom.com/gerrit/gitweb?a=commitdiff_plain;h=5c468196dc7a1d6035df83fb656ab3f214657421;p=common%2Fthrift.git -- Protocol and transport factories now wrap around a single protocol/transport Summary: - This is an analagous to the C++ change made in r31441 Reviewed By: cheever, mcslee git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@664975 13f79535-47bb-0310-9956-ffa450edef68 --- diff --git a/lib/py/src/protocol/TBinaryProtocol.py b/lib/py/src/protocol/TBinaryProtocol.py index 8275e347..7fdfdda4 100644 --- a/lib/py/src/protocol/TBinaryProtocol.py +++ b/lib/py/src/protocol/TBinaryProtocol.py @@ -5,8 +5,8 @@ class TBinaryProtocol(TProtocolBase): """Binary implementation of the Thrift protocol driver.""" - def __init__(self, itrans, otrans=None): - TProtocolBase.__init__(self, itrans, otrans) + def __init__(self, trans): + TProtocolBase.__init__(self, trans) def writeMessageBegin(self, name, type, seqid): self.writeString(name) @@ -62,27 +62,27 @@ class TBinaryProtocol(TProtocolBase): def writeByte(self, byte): buff = pack("!b", byte) - self.otrans.write(buff) + self.trans.write(buff) def writeI16(self, i16): buff = pack("!h", i16) - self.otrans.write(buff) + self.trans.write(buff) def writeI32(self, i32): buff = pack("!i", i32) - self.otrans.write(buff) + self.trans.write(buff) def writeI64(self, i64): buff = pack("!q", i64) - self.otrans.write(buff) + self.trans.write(buff) def writeDouble(self, dub): buff = pack("!d", dub) - self.otrans.write(buff) + self.trans.write(buff) def writeString(self, str): self.writeI32(len(str)) - self.otrans.write(str) + self.trans.write(str) def readMessageBegin(self): name = self.readString() @@ -141,36 +141,36 @@ class TBinaryProtocol(TProtocolBase): return True def readByte(self): - buff = self.itrans.readAll(1) + buff = self.trans.readAll(1) val, = unpack('!b', buff) return val def readI16(self): - buff = self.itrans.readAll(2) + buff = self.trans.readAll(2) val, = unpack('!h', buff) return val def readI32(self): - buff = self.itrans.readAll(4) + buff = self.trans.readAll(4) val, = unpack('!i', buff) return val def readI64(self): - buff = self.itrans.readAll(8) + buff = self.trans.readAll(8) val, = unpack('!q', buff) return val def readDouble(self): - buff = self.itrans.readAll(8) + buff = self.trans.readAll(8) val, = unpack('!d', buff) return val def readString(self): len = self.readI32() - str = self.itrans.readAll(len) + str = self.trans.readAll(len) return str class TBinaryProtocolFactory: - def getIOProtocols(self, itrans, otrans): - prot = TBinaryProtocol(itrans, otrans) - return (prot, prot) + def getProtocol(self, trans): + prot = TBinaryProtocol(trans) + return prot diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py index cc9517cf..15206b0d 100644 --- a/lib/py/src/protocol/TProtocol.py +++ b/lib/py/src/protocol/TProtocol.py @@ -25,10 +25,8 @@ class TProtocolBase: """Base class for Thrift protocol driver.""" - def __init__(self, itrans, otrans=None): - self.itrans = self.otrans = itrans - if otrans != None: - self.otrans = otrans + def __init__(self, trans): + self.trans = trans def writeMessageBegin(self, name, type, seqid): pass @@ -191,5 +189,5 @@ class TProtocolBase: self.readListEnd() class TProtocolFactory: - def getIOProtocols(self, itrans, otrans): + def getProtocol(self, trans): pass diff --git a/lib/py/src/server/TServer.py b/lib/py/src/server/TServer.py index 5b9e1fd6..48a4fcbe 100644 --- a/lib/py/src/server/TServer.py +++ b/lib/py/src/server/TServer.py @@ -11,17 +11,34 @@ class TServer: """Base interface for a server, which must have a serve method.""" - def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None): + """ 3 constructors for all servers: + 1) (processor, serverTransport) + 2) (processor, serverTransport, transportFactory, protocolFactory) + 3) (processor, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory)""" + def __init__(self, *args): + print args + if (len(args) == 2): + self.__initArgs__(args[0], args[1], + TTransport.TTransportFactoryBase(), + TTransport.TTransportFactoryBase(), + TBinaryProtocol.TBinaryProtocolFactory(), + TBinaryProtocol.TBinaryProtocolFactory()) + elif (len(args) == 4): + self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3]) + elif (len(args) == 6): + self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5]) + + def __initArgs__(self, processor, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory): self.processor = processor self.serverTransport = serverTransport - if transportFactory == None: - self.transportFactory = TTransport.TTransportFactoryBase() - else: - self.transportFactory = transportFactory - if protocolFactory == None: - self.protocolFactory = TBinaryProtocol.TBinaryProtocolFactory() - else: - self.protocolFactory = protocolFactory + self.inputTransportFactory = inputTransportFactory + self.outputTransportFactory = outputTransportFactory + self.inputProtocolFactory = inputProtocolFactory + self.outputProtocolFactory = outputProtocolFactory def serve(self): pass @@ -30,15 +47,17 @@ class TSimpleServer(TServer): """Simple single-threaded server that just pumps around one transport.""" - def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None): - TServer.__init__(self, processor, serverTransport, transportFactory, protocolFactory) + def __init__(self, *args): + TServer.__init__(self, *args) def serve(self): self.serverTransport.listen() while True: client = self.serverTransport.accept() - (itrans, otrans) = self.transportFactory.getIOTransports(client) - (iprot, oprot) = self.protocolFactory.getIOProtocols(itrans, otrans) + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.oututProtocolFactory.getProtocol(otrans) try: while True: self.processor.process(iprot, oprot) @@ -54,8 +73,8 @@ class TThreadedServer(TServer): """Threaded server that spawns a new thread per each connection.""" - def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None): - TServer.__init__(self, processor, serverTransport, transportFactory, protocolFactory) + def __init__(self, *args): + TServer.__init__(self, *args) def serve(self): self.serverTransport.listen() @@ -68,8 +87,10 @@ class TThreadedServer(TServer): print '%s, %s, %s,' % (type(x), x, traceback.format_exc()) def handle(self, client): - (itrans, otrans) = self.transportFactory.getIOTransports(client) - (iprot, oprot) = self.protocolFactory.getIOProtocols(itrans, otrans) + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.oututProtocolFactory.getProtocol(otrans) try: while True: self.processor.process(iprot, oprot) @@ -85,8 +106,8 @@ class TThreadPoolServer(TServer): """Server with a fixed size pool of threads which service requests.""" - def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None): - TServer.__init__(self, processor, serverTransport, transportFactory, protocolFactory) + def __init__(self, *args): + TServer.__init__(self, *args) self.clients = Queue.Queue() self.threads = 10 @@ -105,8 +126,10 @@ class TThreadPoolServer(TServer): def serveClient(self, client): """Process input/output from a client for as long as possible""" - (itrans, otrans) = self.transportFactory.getIOTransports(client) - (iprot, oprot) = self.protocolFactory.getIOProtocols(itrans, otrans) + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.oututProtocolFactory.getProtocol(otrans) try: while True: self.processor.process(iprot, oprot) diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py index e6e202b1..502b3270 100644 --- a/lib/py/src/transport/TTransport.py +++ b/lib/py/src/transport/TTransport.py @@ -55,16 +55,16 @@ class TTransportFactoryBase: """Base class for a Transport Factory""" - def getIOTransports(self, trans): - return (trans, trans) + def getTransport(self, trans): + return trans class TBufferedTransportFactory: """Factory transport that builds buffered transports""" - def getIOTransports(self, trans): + def getTransport(self, trans): buffered = TBufferedTransport(trans) - return (buffered, buffered) + return buffered class TBufferedTransport(TTransportBase): @@ -99,9 +99,9 @@ class TFramedTransportFactory: """Factory transport that builds framed transports""" - def getIOTransports(self, trans): + def getTransport(self, trans): framed = TFramedTransport(trans) - return (framed, framed) + return framed class TFramedTransport(TTransportBase):