From 4ac459ff3a447786220d01402756b70eb21329a1 Mon Sep 17 00:00:00 2001 From: Mark Slee Date: Wed, 25 Oct 2006 21:39:01 +0000 Subject: [PATCH] Fix python server bugs and go to new protocol wraps transport model Reviewed By: ccheever git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@664849 13f79535-47bb-0310-9956-ffa450edef68 --- lib/py/src/Thrift.py | 2 +- lib/py/src/protocol/TBinaryProtocol.py | 170 +++++++++++++------------ lib/py/src/protocol/TProtocol.py | 134 +++++++++---------- lib/py/src/server/TServer.py | 40 ++++-- 4 files changed, 187 insertions(+), 159 deletions(-) diff --git a/lib/py/src/Thrift.py b/lib/py/src/Thrift.py index 0c4a458d..1be84e29 100644 --- a/lib/py/src/Thrift.py +++ b/lib/py/src/Thrift.py @@ -2,5 +2,5 @@ class TProcessor: """Base class for procsessor, which works on two streams.""" - def process(itrans, otrans): + def process(iprot, oprot): pass diff --git a/lib/py/src/protocol/TBinaryProtocol.py b/lib/py/src/protocol/TBinaryProtocol.py index 25f3218a..8275e347 100644 --- a/lib/py/src/protocol/TBinaryProtocol.py +++ b/lib/py/src/protocol/TBinaryProtocol.py @@ -5,164 +5,172 @@ class TBinaryProtocol(TProtocolBase): """Binary implementation of the Thrift protocol driver.""" - def writeMessageBegin(self, otrans, name, type, seqid): - self.writeString(otrans, name) - self.writeByte(otrans, type) - self.writeI32(otrans, seqid) + def __init__(self, itrans, otrans=None): + TProtocolBase.__init__(self, itrans, otrans) - def writeMessageEnd(self, otrans): + def writeMessageBegin(self, name, type, seqid): + self.writeString(name) + self.writeByte(type) + self.writeI32(seqid) + + def writeMessageEnd(self): pass - def writeStructBegin(self, otrans, name): + def writeStructBegin(self, name): pass - def writeStructEnd(self, otrans): + def writeStructEnd(self): pass - def writeFieldBegin(self, otrans, name, type, id): - self.writeByte(otrans, type) - self.writeI16(otrans, id) + def writeFieldBegin(self, name, type, id): + self.writeByte(type) + self.writeI16(id) - def writeFieldEnd(self, otrans): + def writeFieldEnd(self): pass - def writeFieldStop(self, otrans): - self.writeByte(otrans, TType.STOP); + def writeFieldStop(self): + self.writeByte(TType.STOP); - def writeMapBegin(self, otrans, ktype, vtype, size): - self.writeByte(otrans, ktype) - self.writeByte(otrans, vtype) - self.writeI32(otrans, size) + def writeMapBegin(self, ktype, vtype, size): + self.writeByte(ktype) + self.writeByte(vtype) + self.writeI32(size) - def writeMapEnd(self, otrans): + def writeMapEnd(self): pass - def writeListBegin(self, otrans, etype, size): - self.writeByte(otrans, etype) - self.writeI32(otrans, size) + def writeListBegin(self, etype, size): + self.writeByte(etype) + self.writeI32(size) - def writeListEnd(self, otrans): + def writeListEnd(self): pass - def writeSetBegin(self, otrans, etype, size): - self.writeByte(otrans, etype) - self.writeI32(otrans, size) + def writeSetBegin(self, etype, size): + self.writeByte(etype) + self.writeI32(size) - def writeSetEnd(self, otrans): + def writeSetEnd(self): pass - def writeBool(self, otrans, bool): + def writeBool(self, bool): if bool: - self.writeByte(otrans, 1) + self.writeByte(1) else: - self.writeByte(otrans, 0) + self.writeByte(0) - def writeByte(self, otrans, byte): + def writeByte(self, byte): buff = pack("!b", byte) - otrans.write(buff) + self.otrans.write(buff) - def writeI16(self, otrans, i16): + def writeI16(self, i16): buff = pack("!h", i16) - otrans.write(buff) + self.otrans.write(buff) - def writeI32(self, otrans, i32): + def writeI32(self, i32): buff = pack("!i", i32) - otrans.write(buff) + self.otrans.write(buff) - def writeI64(self, otrans, i64): + def writeI64(self, i64): buff = pack("!q", i64) - otrans.write(buff) + self.otrans.write(buff) - def writeDouble(self, otrans, dub): + def writeDouble(self, dub): buff = pack("!d", dub) - otrans.write(buff) + self.otrans.write(buff) - def writeString(self, otrans, str): - self.writeI32(otrans, len(str)) - otrans.write(str) + def writeString(self, str): + self.writeI32(len(str)) + self.otrans.write(str) - def readMessageBegin(self, itrans): - name = self.readString(itrans) - type = self.readByte(itrans) - seqid = self.readI32(itrans) + def readMessageBegin(self): + name = self.readString() + type = self.readByte() + seqid = self.readI32() return (name, type, seqid) - def readMessageEnd(self, itrans): + def readMessageEnd(self): pass - def readStructBegin(self, itrans): + def readStructBegin(self): pass - def readStructEnd(self, itrans): + def readStructEnd(self): pass - def readFieldBegin(self, itrans): - type = self.readByte(itrans) + def readFieldBegin(self): + type = self.readByte() if type == TType.STOP: return (None, type, 0) - id = self.readI16(itrans) + id = self.readI16() return (None, type, id) - def readFieldEnd(self, itrans): + def readFieldEnd(self): pass - def readMapBegin(self, itrans): - ktype = self.readByte(itrans) - vtype = self.readByte(itrans) - size = self.readI32(itrans) + def readMapBegin(self): + ktype = self.readByte() + vtype = self.readByte() + size = self.readI32() return (ktype, vtype, size) - def readMapEnd(self, itrans): + def readMapEnd(self): pass - def readListBegin(self, itrans): - etype = self.readByte(itrans) - size = self.readI32(itrans) + def readListBegin(self): + etype = self.readByte() + size = self.readI32() return (etype, size) - def readListEnd(self, itrans): + def readListEnd(self): pass - def readSetBegin(self, itrans): - etype = self.readByte(itrans) - size = self.readI32(itrans) + def readSetBegin(self): + etype = self.readByte() + size = self.readI32() return (etype, size) - def readSetEnd(self, itrans): + def readSetEnd(self): pass - def readBool(self, itrans): - byte = self.readByte(itrans) + def readBool(self): + byte = self.readByte() if byte == 0: return False return True - def readByte(self, itrans): - buff = itrans.readAll(1) + def readByte(self): + buff = self.itrans.readAll(1) val, = unpack('!b', buff) return val - def readI16(self, itrans): - buff = itrans.readAll(2) + def readI16(self): + buff = self.itrans.readAll(2) val, = unpack('!h', buff) return val - def readI32(self, itrans): - buff = itrans.readAll(4) + def readI32(self): + buff = self.itrans.readAll(4) val, = unpack('!i', buff) return val - def readI64(self, itrans): - buff = itrans.readAll(8) + def readI64(self): + buff = self.itrans.readAll(8) val, = unpack('!q', buff) return val - def readDouble(self, itrans): - buff = itrans.readAll(8) + def readDouble(self): + buff = self.itrans.readAll(8) val, = unpack('!d', buff) return val - def readString(self, itrans): - len = self.readI32(itrans) - str = itrans.readAll(len) + def readString(self): + len = self.readI32() + str = self.itrans.readAll(len) return str + +class TBinaryProtocolFactory: + def getIOProtocols(self, itrans, otrans): + prot = TBinaryProtocol(itrans, otrans) + return (prot, prot) diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py index 0b480d30..cc9517cf 100644 --- a/lib/py/src/protocol/TProtocol.py +++ b/lib/py/src/protocol/TProtocol.py @@ -25,165 +25,171 @@ class TProtocolBase: """Base class for Thrift protocol driver.""" - def writeMessageBegin(self, otrans, name, type, seqid): + def __init__(self, itrans, otrans=None): + self.itrans = self.otrans = itrans + if otrans != None: + self.otrans = otrans + + def writeMessageBegin(self, name, type, seqid): pass - def writeMessageEnd(self, otrans): + def writeMessageEnd(self): pass - def writeStructBegin(self, otrans, name): + def writeStructBegin(self, name): pass - def writeStructEnd(self, otrans): + def writeStructEnd(self): pass - def writeFieldBegin(self, otrans, name, type, id): + def writeFieldBegin(self, name, type, id): pass - def writeFieldEnd(self, otrans): + def writeFieldEnd(self): pass - def writeFieldStop(self, otrans): + def writeFieldStop(self): pass - def writeMapBegin(self, otrans, ktype, vtype, size): + def writeMapBegin(self, ktype, vtype, size): pass - def writeMapEnd(self, otrans): + def writeMapEnd(self): pass - def writeListBegin(self, otrans, etype, size): + def writeListBegin(self, etype, size): pass - def writeListEnd(self, otrans): + def writeListEnd(self): pass - def writeSetBegin(self, otrans, etype, size): + def writeSetBegin(self, etype, size): pass - def writeSetEnd(self, otrans): + def writeSetEnd(self): pass - def writeBool(self, otrans, bool): + def writeBool(self, bool): pass - def writeByte(self, otrans, byte): + def writeByte(self, byte): pass - def writeI16(self, otrans, i16): + def writeI16(self, i16): pass - def writeI32(self, otrans, i32): + def writeI32(self, i32): pass - def writeI64(self, otrans, i64): + def writeI64(self, i64): pass - def writeDouble(self, otrans, dub): + def writeDouble(self, dub): pass - def writeString(self, otrans, str): + def writeString(self, str): pass - def readMessageBegin(self, itrans): + def readMessageBegin(self): pass - def readMessageEnd(self, itrans): + def readMessageEnd(self): pass - def readStructBegin(self, itrans): + def readStructBegin(self): pass - def readStructEnd(self, itrans): + def readStructEnd(self): pass - def readFieldBegin(self, itrans): + def readFieldBegin(self): pass - def readFieldEnd(self, itrans): + def readFieldEnd(self): pass - def readMapBegin(self, itrans): + def readMapBegin(self): pass - def readMapEnd(self, itrans): + def readMapEnd(self): pass - def readListBegin(self, itrans): + def readListBegin(self): pass - def readListEnd(self, itrans): + def readListEnd(self): pass - def readSetBegin(self, itrans): + def readSetBegin(self): pass - def readSetEnd(self, itrans): + def readSetEnd(self): pass - def readBool(self, itrans): + def readBool(self): pass - def readByte(self, itrans): + def readByte(self): pass - def readI16(self, itrans): + def readI16(self): pass - def readI32(self, itrans): + def readI32(self): pass - def readI64(self, itrans): + def readI64(self): pass - def readDouble(self, itrans): + def readDouble(self): pass - def readString(self, itrans): + def readString(self): pass - def skip(self, itrans, type): + def skip(self, type): if type == TType.STOP: return elif type == TType.BOOL: - self.readBool(itrans) + self.readBool() elif type == TType.BYTE: - self.readByte(itrans) + self.readByte() elif type == TType.I16: - self.readI16(itrans) + self.readI16() elif type == TType.I32: - self.readI32(itrans) + self.readI32() elif type == TType.I64: - self.readI64(itrans) + self.readI64() elif type == TType.DOUBLE: - self.readDouble(itrans) + self.readDouble() elif type == TType.STRING: - self.readString(itrans) + self.readString() elif type == TType.STRUCT: - name = self.readStructBegin(itrans) + name = self.readStructBegin() while True: - (name, type, id) = self.readFieldBegin(itrans) + (name, type, id) = self.readFieldBegin() if type == TType.STOP: break - self.skip(itrans, type) - self.readFieldEnd(itrans) - self.readStructEnd(itrans) + self.skip(type) + self.readFieldEnd() + self.readStructEnd() elif type == TType.MAP: - (ktype, vtype, size) = self.readMapBegin(itrans) + (ktype, vtype, size) = self.readMapBegin() for i in range(size): - self.skip(itrans, ktype) - self.skip(itrans, vtype) - self.readMapEnd(itrans) + self.skip(ktype) + self.skip(vtype) + self.readMapEnd() elif type == TType.SET: - (etype, size) = self.readSetBegin(itrans) + (etype, size) = self.readSetBegin() for i in range(size): - self.skip(itrans, etype) - self.readSetEnd(itrans) + self.skip(etype) + self.readSetEnd() elif type == TType.LIST: - (etype, size) = self.readListBegin(itrans) + (etype, size) = self.readListBegin() for i in range(size): - self.skip(itrans, etype) - self.readListEnd(itrans) - + self.skip(etype) + self.readListEnd() - +class TProtocolFactory: + def getIOProtocols(self, itrans, otrans): + pass diff --git a/lib/py/src/server/TServer.py b/lib/py/src/server/TServer.py index 56ee9c08..75142642 100644 --- a/lib/py/src/server/TServer.py +++ b/lib/py/src/server/TServer.py @@ -5,18 +5,23 @@ import Queue from thrift.Thrift import TProcessor from thrift.transport import TTransport +from thrift.protocol import TBinaryProtocol class TServer: """Base interface for a server, which must have a serve method.""" - def __init__(self, processor, serverTransport, transportFactory=None): + def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None): self.processor = processor self.serverTransport = serverTransport if transportFactory == None: self.transportFactory = TTransport.TTransportFactoryBase() else: self.transportFactory = transportFactory + if protocolFactory == None: + self.protocolFactory = TBinaryProtocol.TBinaryProtocolFactory() + else: + self.protocolFactory = protocolFactory def serve(self): pass @@ -25,31 +30,32 @@ class TSimpleServer(TServer): """Simple single-threaded server that just pumps around one transport.""" - def __init__(self, processor, serverTransport, transportFactory=None): - TServer.__init__(self, processor, serverTransport, transportFactory) + def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None): + TServer.__init__(self, processor, serverTransport, transportFactory, protocolFactory) def serve(self): self.serverTransport.listen() while True: client = self.serverTransport.accept() - (input, output) = self.transportFactory.getIOTransports(client) + (itrans, otrans) = self.transportFactory.getIOTransports(client) + (iprot, oprot) = self.protocolFactory.getIOProtocols(itrans, otrans) try: while True: - self.processor.process(input, output) + self.processor.process(iprot, oprot) except TTransport.TTransportException, tx: pass except Exception, x: print '%s, %s, %s' % (type(x), x, traceback.format_exc()) - input.close() - output.close() + itrans.close() + otrans.close() class TThreadedServer(TServer): """Threaded server that spawns a new thread per each connection.""" - def __init__(self, processor, serverTransport, transportFactory=None): - TServer.__init__(self, processor, serverTransport, transportFactory) + def __init__(self, processor, serverTransport, transportFactory=None, protocolFactory=None): + TServer.__init__(self, processor, serverTransport, transportFactory, protocolFactory) def serve(self): self.serverTransport.listen() @@ -62,15 +68,19 @@ class TThreadedServer(TServer): print '%s, %s, %s,' % (type(x), x, traceback.format_exc()) def handle(self, client): - (input, output) = self.transportFactory.getIOTransports(client) + (itrans, otrans) = self.transportFactory.getIOTransports(client) + (iprot, oprot) = self.protocolFactory.getIOProtocols(itrans, otrans) try: while True: - self.processor.process(input, output) + self.processor.process(iprot, oprot) except TTransport.TTransportException, tx: pass except Exception, x: print '%s, %s, %s' % (type(x), x, traceback.format_exc()) + itrans.close() + otrans.close() + class TThreadPoolServer(TServer): """Server with a fixed size pool of threads which service requests.""" @@ -95,15 +105,19 @@ class TThreadPoolServer(TServer): def serveClient(self, client): """Process input/output from a client for as long as possible""" - (input, output) = self.transportFactory.getIOTransports(client) + (itrans, otrans) = self.transportFactory.getIOTransports(client) + (iprot, oprot) = self.protocolFactory.getIOProtocols(itrans, otrans) try: while True: - self.processor.process(input, output) + self.processor.process(iprot, oprot) except TTransport.TTransportException, tx: pass except Exception, x: print '%s, %s, %s' % (type(x), x, traceback.format_exc()) + itrans.close() + otrans.close() + def serve(self): """Start a fixed number of worker threads and put client into a queue""" for i in range(self.threads): -- 2.17.1