From 6972041392314d526584e733781ca382a960b295 Mon Sep 17 00:00:00 2001 From: Bryan Duxbury Date: Tue, 3 Jan 2012 17:32:30 +0000 Subject: [PATCH] THRIFT-1480. py: remove tabs, adjust whitespace and address PEP8 warnings This patch addresses a host of PEP8 lint problems. Patch: Will Pierce git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1226890 13f79535-47bb-0310-9956-ffa450edef68 --- lib/py/src/TSCons.py | 8 +- lib/py/src/TSerialization.py | 10 ++- lib/py/src/Thrift.py | 43 +++++----- lib/py/src/protocol/TBase.py | 27 ++++-- lib/py/src/protocol/TBinaryProtocol.py | 13 +-- lib/py/src/protocol/TCompactProtocol.py | 34 +++++--- lib/py/src/protocol/TProtocol.py | 56 ++++++------ lib/py/src/server/THttpServer.py | 21 +++-- lib/py/src/server/TNonblockingServer.py | 94 +++++++++++--------- lib/py/src/server/TProcessPoolServer.py | 29 +++---- lib/py/src/server/TServer.py | 37 ++++---- lib/py/src/transport/THttpClient.py | 28 +++--- lib/py/src/transport/TSSLSocket.py | 90 ++++++++++++------- lib/py/src/transport/TSocket.py | 33 ++++--- lib/py/src/transport/TTransport.py | 25 +++--- lib/py/src/transport/TTwisted.py | 6 +- lib/py/src/transport/TZlibTransport.py | 109 +++++++++++------------- lib/py/src/transport/__init__.py | 2 +- 18 files changed, 366 insertions(+), 299 deletions(-) diff --git a/lib/py/src/TSCons.py b/lib/py/src/TSCons.py index 24046256..da8d2833 100644 --- a/lib/py/src/TSCons.py +++ b/lib/py/src/TSCons.py @@ -20,14 +20,16 @@ from os import path from SCons.Builder import Builder + def scons_env(env, add=''): opath = path.dirname(path.abspath('$TARGET')) lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE' - cppbuild = Builder(action = lstr) - env.Append(BUILDERS = {'ThriftCpp' : cppbuild}) + cppbuild = Builder(action=lstr) + env.Append(BUILDERS={'ThriftCpp': cppbuild}) + def gen_cpp(env, dir, file): scons_env(env) suffixes = ['_types.h', '_types.cpp'] targets = map(lambda s: 'gen-cpp/' + file + s, suffixes) - return env.ThriftCpp(targets, dir+file+'.thrift') + return env.ThriftCpp(targets, dir + file + '.thrift') diff --git a/lib/py/src/TSerialization.py b/lib/py/src/TSerialization.py index b19f98aa..8a58d89d 100644 --- a/lib/py/src/TSerialization.py +++ b/lib/py/src/TSerialization.py @@ -20,15 +20,19 @@ from protocol import TBinaryProtocol from transport import TTransport -def serialize(thrift_object, protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()): + +def serialize(thrift_object, + protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()): transport = TTransport.TMemoryBuffer() protocol = protocol_factory.getProtocol(transport) thrift_object.write(protocol) return transport.getvalue() -def deserialize(base, buf, protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()): + +def deserialize(base, + buf, + protocol_factory=TBinaryProtocol.TBinaryProtocolFactory()): transport = TTransport.TMemoryBuffer(buf) protocol = protocol_factory.getProtocol(transport) base.read(protocol) return base - diff --git a/lib/py/src/Thrift.py b/lib/py/src/Thrift.py index 1d271fcf..707a8ccc 100644 --- a/lib/py/src/Thrift.py +++ b/lib/py/src/Thrift.py @@ -19,6 +19,7 @@ import sys + class TType: STOP = 0 VOID = 1 @@ -38,7 +39,7 @@ class TType: UTF8 = 16 UTF16 = 17 - _VALUES_TO_NAMES = ( 'STOP', + _VALUES_TO_NAMES = ('STOP', 'VOID', 'BOOL', 'BYTE', @@ -48,46 +49,48 @@ class TType: None, 'I32', None, - 'I64', - 'STRING', - 'STRUCT', - 'MAP', - 'SET', - 'LIST', - 'UTF8', - 'UTF16' ) + 'I64', + 'STRING', + 'STRUCT', + 'MAP', + 'SET', + 'LIST', + 'UTF8', + 'UTF16') + class TMessageType: - CALL = 1 + CALL = 1 REPLY = 2 EXCEPTION = 3 ONEWAY = 4 -class TProcessor: +class TProcessor: """Base class for procsessor, which works on two streams.""" def process(iprot, oprot): pass -class TException(Exception): +class TException(Exception): """Base class for all thrift exceptions.""" # BaseException.message is deprecated in Python v[2.6,3.0) - if (2,6,0) <= sys.version_info < (3,0): + if (2, 6, 0) <= sys.version_info < (3, 0): def _get_message(self): - return self._message + return self._message + def _set_message(self, message): - self._message = message + self._message = message message = property(_get_message, _set_message) def __init__(self, message=None): Exception.__init__(self, message) self.message = message -class TApplicationException(TException): +class TApplicationException(TException): """Application level thrift exceptions.""" UNKNOWN = 0 @@ -127,12 +130,12 @@ class TApplicationException(TException): break if fid == 1: if ftype == TType.STRING: - self.message = iprot.readString(); + self.message = iprot.readString() else: iprot.skip(ftype) elif fid == 2: if ftype == TType.I32: - self.type = iprot.readI32(); + self.type = iprot.readI32() else: iprot.skip(ftype) else: @@ -142,11 +145,11 @@ class TApplicationException(TException): def write(self, oprot): oprot.writeStructBegin('TApplicationException') - if self.message != None: + if self.message is not None: oprot.writeFieldBegin('message', TType.STRING, 1) oprot.writeString(self.message) oprot.writeFieldEnd() - if self.type != None: + if self.type is not None: oprot.writeFieldBegin('type', TType.I32, 2) oprot.writeI32(self.type) oprot.writeFieldEnd() diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py index e675c7dc..6cbd5f39 100644 --- a/lib/py/src/protocol/TBase.py +++ b/lib/py/src/protocol/TBase.py @@ -26,12 +26,13 @@ try: except: fastbinary = None + class TBase(object): __slots__ = [] def __repr__(self): L = ['%s=%r' % (key, getattr(self, key)) - for key in self.__slots__ ] + for key in self.__slots__] return '%s(%s)' % (self.__class__.__name__, ', '.join(L)) def __eq__(self, other): @@ -43,30 +44,38 @@ class TBase(object): if my_val != other_val: return False return True - + def __ne__(self, other): return not (self == other) - + def read(self, iprot): - if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None and fastbinary is not None: - fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec)) + if (iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and + isinstance(iprot.trans, TTransport.CReadableTransport) and + self.thrift_spec is not None and + fastbinary is not None): + fastbinary.decode_binary(self, + iprot.trans, + (self.__class__, self.thrift_spec)) return iprot.readStruct(self, self.thrift_spec) def write(self, oprot): - if oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and self.thrift_spec is not None and fastbinary is not None: - oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) + if (oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and + self.thrift_spec is not None and + fastbinary is not None): + oprot.trans.write( + fastbinary.encode_binary(self, (self.__class__, self.thrift_spec))) return oprot.writeStruct(self, self.thrift_spec) + class TExceptionBase(Exception): # old style class so python2.4 can raise exceptions derived from this # This can't inherit from TBase because of that limitation. __slots__ = [] - + __repr__ = TBase.__repr__.im_func __eq__ = TBase.__eq__.im_func __ne__ = TBase.__ne__.im_func read = TBase.read.im_func write = TBase.write.im_func - diff --git a/lib/py/src/protocol/TBinaryProtocol.py b/lib/py/src/protocol/TBinaryProtocol.py index 50c6aa89..6fdd08c2 100644 --- a/lib/py/src/protocol/TBinaryProtocol.py +++ b/lib/py/src/protocol/TBinaryProtocol.py @@ -20,8 +20,8 @@ from TProtocol import * from struct import pack, unpack -class TBinaryProtocol(TProtocolBase): +class TBinaryProtocol(TProtocolBase): """Binary implementation of the Thrift protocol driver.""" # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be @@ -68,7 +68,7 @@ class TBinaryProtocol(TProtocolBase): pass def writeFieldStop(self): - self.writeByte(TType.STOP); + self.writeByte(TType.STOP) def writeMapBegin(self, ktype, vtype, size): self.writeByte(ktype) @@ -127,13 +127,16 @@ class TBinaryProtocol(TProtocolBase): if sz < 0: version = sz & TBinaryProtocol.VERSION_MASK if version != TBinaryProtocol.VERSION_1: - raise TProtocolException(type=TProtocolException.BAD_VERSION, message='Bad version in readMessageBegin: %d' % (sz)) + raise TProtocolException( + type=TProtocolException.BAD_VERSION, + message='Bad version in readMessageBegin: %d' % (sz)) type = sz & TBinaryProtocol.TYPE_MASK name = self.readString() seqid = self.readI32() else: if self.strictRead: - raise TProtocolException(type=TProtocolException.BAD_VERSION, message='No protocol version header') + raise TProtocolException(type=TProtocolException.BAD_VERSION, + message='No protocol version header') name = self.trans.readAll(sz) type = self.readByte() seqid = self.readI32() @@ -231,7 +234,6 @@ class TBinaryProtocolFactory: class TBinaryProtocolAccelerated(TBinaryProtocol): - """C-Accelerated version of TBinaryProtocol. This class does not override any of TBinaryProtocol's methods, @@ -250,7 +252,6 @@ class TBinaryProtocolAccelerated(TBinaryProtocol): Please feel free to report bugs and/or success stories to the public mailing list. """ - pass diff --git a/lib/py/src/protocol/TCompactProtocol.py b/lib/py/src/protocol/TCompactProtocol.py index 016a3317..cdec6077 100644 --- a/lib/py/src/protocol/TCompactProtocol.py +++ b/lib/py/src/protocol/TCompactProtocol.py @@ -32,6 +32,7 @@ CONTAINER_READ = 6 VALUE_READ = 7 BOOL_READ = 8 + def make_helper(v_from, container): def helper(func): def nested(self, *args, **kwargs): @@ -42,12 +43,15 @@ def make_helper(v_from, container): writer = make_helper(VALUE_WRITE, CONTAINER_WRITE) reader = make_helper(VALUE_READ, CONTAINER_READ) + def makeZigZag(n, bits): return (n << 1) ^ (n >> (bits - 1)) + def fromZigZag(n): return (n >> 1) ^ -(n & 1) + def writeVarint(trans, n): out = [] while True: @@ -59,6 +63,7 @@ def writeVarint(trans, n): n = n >> 7 trans.write(''.join(map(chr, out))) + def readVarint(trans): result = 0 shift = 0 @@ -70,6 +75,7 @@ def readVarint(trans): return result shift += 7 + class CompactType: STOP = 0x00 TRUE = 0x01 @@ -86,7 +92,7 @@ class CompactType: STRUCT = 0x0C CTYPES = {TType.STOP: CompactType.STOP, - TType.BOOL: CompactType.TRUE, # used for collection + TType.BOOL: CompactType.TRUE, # used for collection TType.BYTE: CompactType.BYTE, TType.I16: CompactType.I16, TType.I32: CompactType.I32, @@ -106,8 +112,9 @@ TTYPES[CompactType.FALSE] = TType.BOOL del k del v + class TCompactProtocol(TProtocolBase): - "Compact implementation of the Thrift protocol driver." + """Compact implementation of the Thrift protocol driver.""" PROTOCOL_ID = 0x82 VERSION = 1 @@ -217,18 +224,18 @@ class TCompactProtocol(TProtocolBase): def writeBool(self, bool): if self.state == BOOL_WRITE: - if bool: - ctype = CompactType.TRUE - else: - ctype = CompactType.FALSE - self.__writeFieldHeader(ctype, self.__bool_fid) + if bool: + ctype = CompactType.TRUE + else: + ctype = CompactType.FALSE + self.__writeFieldHeader(ctype, self.__bool_fid) elif self.state == CONTAINER_WRITE: - if bool: - self.__writeByte(CompactType.TRUE) - else: - self.__writeByte(CompactType.FALSE) + if bool: + self.__writeByte(CompactType.TRUE) + else: + self.__writeByte(CompactType.FALSE) else: - raise AssertionError, "Invalid state in compact protocol" + raise AssertionError("Invalid state in compact protocol") writeByte = writer(__writeByte) writeI16 = writer(__writeI16) @@ -364,7 +371,8 @@ class TCompactProtocol(TProtocolBase): elif self.state == CONTAINER_READ: return self.__readByte() == CompactType.TRUE else: - raise AssertionError, "Invalid state in compact protocol: %d" % self.state + raise AssertionError("Invalid state in compact protocol: %d" % + self.state) readByte = reader(__readByte) __readI16 = __readZigZag diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py index 7338ff68..51772291 100644 --- a/lib/py/src/protocol/TProtocol.py +++ b/lib/py/src/protocol/TProtocol.py @@ -19,8 +19,8 @@ from thrift.Thrift import * -class TProtocolException(TException): +class TProtocolException(TException): """Custom Protocol Exception class""" UNKNOWN = 0 @@ -33,8 +33,8 @@ class TProtocolException(TException): TException.__init__(self, message) self.type = type -class TProtocolBase: +class TProtocolBase: """Base class for Thrift protocol driver.""" def __init__(self, trans): @@ -200,26 +200,26 @@ class TProtocolBase: self.skip(etype) self.readListEnd() - # tuple of: ( 'reader method' name, is_container boolean, 'writer_method' name ) + # tuple of: ( 'reader method' name, is_container bool, 'writer_method' name ) _TTYPE_HANDLERS = ( - (None, None, False), # 0 == TType,STOP - (None, None, False), # 1 == TType.VOID # TODO: handle void? - ('readBool', 'writeBool', False), # 2 == TType.BOOL - ('readByte', 'writeByte', False), # 3 == TType.BYTE and I08 - ('readDouble', 'writeDouble', False), # 4 == TType.DOUBLE - (None, None, False), # 5, undefined - ('readI16', 'writeI16', False), # 6 == TType.I16 - (None, None, False), # 7, undefined - ('readI32', 'writeI32', False), # 8 == TType.I32 - (None, None, False), # 9, undefined - ('readI64', 'writeI64', False), # 10 == TType.I64 - ('readString', 'writeString', False), # 11 == TType.STRING and UTF7 - ('readContainerStruct', 'writeContainerStruct', True), # 12 == TType.STRUCT - ('readContainerMap', 'writeContainerMap', True), # 13 == TType.MAP - ('readContainerSet', 'writeContainerSet', True), # 14 == TType.SET - ('readContainerList', 'writeContainerList', True), # 15 == TType.LIST - (None, None, False), # 16 == TType.UTF8 # TODO: handle utf8 types? - (None, None, False)# 17 == TType.UTF16 # TODO: handle utf16 types? + (None, None, False), # 0 TType.STOP + (None, None, False), # 1 TType.VOID # TODO: handle void? + ('readBool', 'writeBool', False), # 2 TType.BOOL + ('readByte', 'writeByte', False), # 3 TType.BYTE and I08 + ('readDouble', 'writeDouble', False), # 4 TType.DOUBLE + (None, None, False), # 5 undefined + ('readI16', 'writeI16', False), # 6 TType.I16 + (None, None, False), # 7 undefined + ('readI32', 'writeI32', False), # 8 TType.I32 + (None, None, False), # 9 undefined + ('readI64', 'writeI64', False), # 10 TType.I64 + ('readString', 'writeString', False), # 11 TType.STRING and UTF7 + ('readContainerStruct', 'writeContainerStruct', True), # 12 *.STRUCT + ('readContainerMap', 'writeContainerMap', True), # 13 TType.MAP + ('readContainerSet', 'writeContainerSet', True), # 14 TType.SET + ('readContainerList', 'writeContainerList', True), # 15 TType.LIST + (None, None, False), # 16 TType.UTF8 # TODO: handle utf8 types? + (None, None, False) # 17 TType.UTF16 # TODO: handle utf16 types? ) def readFieldByTType(self, ttype, spec): @@ -270,7 +270,7 @@ class TProtocolBase: container_reader = self._TTYPE_HANDLERS[set_type][0] val_reader = getattr(self, container_reader) for idx in xrange(set_len): - results.add(val_reader(tspec)) + results.add(val_reader(tspec)) self.readSetEnd() return results @@ -279,13 +279,14 @@ class TProtocolBase: obj = obj_class() obj.read(self) return obj - + def readContainerMap(self, spec): results = dict() key_ttype, key_spec = spec[0], spec[1] val_ttype, val_spec = spec[2], spec[3] (map_ktype, map_vtype, map_len) = self.readMapBegin() - # TODO: compare types we just decoded with thrift_spec and abort/skip if types disagree + # TODO: compare types we just decoded with thrift_spec and + # abort/skip if types disagree key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0]) val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0]) # list values are simple types @@ -298,7 +299,8 @@ class TProtocolBase: v_val = val_reader() else: v_val = self.readFieldByTType(val_ttype, val_spec) - # this raises a TypeError with unhashable keys types. i.e. d=dict(); d[[0,1]] = 2 fails + # this raises a TypeError with unhashable keys types + # i.e. this fails: d=dict(); d[[0,1]] = 2 results[k_val] = v_val self.readMapEnd() return results @@ -329,7 +331,7 @@ class TProtocolBase: def writeContainerList(self, val, spec): self.writeListBegin(spec[0], len(val)) - r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] + r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]] e_writer = getattr(self, w_handler) if not is_container: for elem in val: @@ -398,7 +400,7 @@ class TProtocolBase: else: writer(val) + class TProtocolFactory: def getProtocol(self, trans): pass - diff --git a/lib/py/src/server/THttpServer.py b/lib/py/src/server/THttpServer.py index 3047d9c0..be54bab9 100644 --- a/lib/py/src/server/THttpServer.py +++ b/lib/py/src/server/THttpServer.py @@ -22,6 +22,7 @@ import BaseHTTPServer from thrift.server import TServer from thrift.transport import TTransport + class ResponseException(Exception): """Allows handlers to override the HTTP response @@ -39,16 +40,19 @@ class THttpServer(TServer.TServer): """A simple HTTP-based Thrift server This class is not very performant, but it is useful (for example) for - acting as a mock version of an Apache-based PHP Thrift endpoint.""" - - def __init__(self, processor, server_address, - inputProtocolFactory, outputProtocolFactory = None, - server_class = BaseHTTPServer.HTTPServer): + acting as a mock version of an Apache-based PHP Thrift endpoint. + """ + def __init__(self, + processor, + server_address, + inputProtocolFactory, + outputProtocolFactory=None, + server_class=BaseHTTPServer.HTTPServer): """Set up protocol factories and HTTP server. See BaseHTTPServer for server_address. - See TServer for protocol factories.""" - + See TServer for protocol factories. + """ if outputProtocolFactory is None: outputProtocolFactory = inputProtocolFactory @@ -62,7 +66,8 @@ class THttpServer(TServer.TServer): # Don't care about the request path. itrans = TTransport.TFileObjectTransport(self.rfile) otrans = TTransport.TFileObjectTransport(self.wfile) - itrans = TTransport.TBufferedTransport(itrans, int(self.headers['Content-Length'])) + itrans = TTransport.TBufferedTransport( + itrans, int(self.headers['Content-Length'])) otrans = TTransport.TMemoryBuffer() iprot = thttpserver.inputProtocolFactory.getProtocol(itrans) oprot = thttpserver.outputProtocolFactory.getProtocol(otrans) diff --git a/lib/py/src/server/TNonblockingServer.py b/lib/py/src/server/TNonblockingServer.py index ea348a0b..cd90b4fc 100644 --- a/lib/py/src/server/TNonblockingServer.py +++ b/lib/py/src/server/TNonblockingServer.py @@ -18,10 +18,11 @@ # """Implementation of non-blocking server. -The main idea of the server is reciving and sending requests -only from main thread. +The main idea of the server is to receive and send requests +only from the main thread. -It also makes thread pool server in tasks terms, not connections. +The thread poool should be sized for concurrent tasks, not +maximum connections """ import threading import socket @@ -35,8 +36,10 @@ from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory __all__ = ['TNonblockingServer'] + class Worker(threading.Thread): """Worker is a small helper to process incoming connection.""" + def __init__(self, queue): threading.Thread.__init__(self) self.queue = queue @@ -60,8 +63,9 @@ WAIT_PROCESS = 2 SEND_ANSWER = 3 CLOSED = 4 + def locked(func): - "Decorator which locks self.lock." + """Decorator which locks self.lock.""" def nested(self, *args, **kwargs): self.lock.acquire() try: @@ -70,8 +74,9 @@ def locked(func): self.lock.release() return nested + def socket_exception(func): - "Decorator close object on socket.error." + """Decorator close object on socket.error.""" def read(self, *args, **kwargs): try: return func(self, *args, **kwargs) @@ -79,16 +84,17 @@ def socket_exception(func): self.close() return read + class Connection: """Basic class is represented connection. - + It can be in state: WAIT_LEN --- connection is reading request len. WAIT_MESSAGE --- connection is reading request. - WAIT_PROCESS --- connection has just read whole request and - waits for call ready routine. + WAIT_PROCESS --- connection has just read whole request and + waits for call ready routine. SEND_ANSWER --- connection is sending answer string (including length - of answer). + of answer). CLOSED --- socket was closed and connection should be deleted. """ def __init__(self, new_socket, wake_up): @@ -102,13 +108,13 @@ class Connection: def _read_len(self): """Reads length of request. - - It's really paranoic routine and it may be replaced by - self.socket.recv(4).""" + + It's a safer alternative to self.socket.recv(4) + """ read = self.socket.recv(4 - len(self.message)) if len(read) == 0: - # if we read 0 bytes and self.message is empty, it means client close - # connection + # if we read 0 bytes and self.message is empty, then + # the client closed the connection if len(self.message) != 0: logging.error("can't read frame size from socket") self.close() @@ -117,8 +123,8 @@ class Connection: if len(self.message) == 4: self.len, = struct.unpack('!i', self.message) if self.len < 0: - logging.error("negative frame size, it seems client"\ - " doesn't use FramedTransport") + logging.error("negative frame size, it seems client " + "doesn't use FramedTransport") self.close() elif self.len == 0: logging.error("empty frame, it's really strange") @@ -139,8 +145,8 @@ class Connection: elif self.status == WAIT_MESSAGE: read = self.socket.recv(self.len - len(self.message)) if len(read) == 0: - logging.error("can't read frame from socket (get %d of %d bytes)" % - (len(self.message), self.len)) + logging.error("can't read frame from socket (get %d of " + "%d bytes)" % (len(self.message), self.len)) self.close() return self.message += read @@ -162,14 +168,14 @@ class Connection: @locked def ready(self, all_ok, message): """Callback function for switching state and waking up main thread. - + This function is the only function witch can be called asynchronous. - + The ready can switch Connection to three states: WAIT_LEN if request was oneway. SEND_ANSWER if request was processed in normal way. CLOSED if request throws unexpected exception. - + The one wakes up main thread. """ assert self.status == WAIT_PROCESS @@ -189,33 +195,39 @@ class Connection: @locked def is_writeable(self): - "Returns True if connection should be added to write list of select." + """Return True if connection should be added to write list of select""" return self.status == SEND_ANSWER # it's not necessary, but... @locked def is_readable(self): - "Returns True if connection should be added to read list of select." + """Return True if connection should be added to read list of select""" return self.status in (WAIT_LEN, WAIT_MESSAGE) @locked def is_closed(self): - "Returns True if connection is closed." + """Returns True if connection is closed.""" return self.status == CLOSED def fileno(self): - "Returns the file descriptor of the associated socket." + """Returns the file descriptor of the associated socket.""" return self.socket.fileno() def close(self): - "Closes connection" + """Closes connection""" self.status = CLOSED self.socket.close() + class TNonblockingServer: """Non-blocking server.""" - def __init__(self, processor, lsocket, inputProtocolFactory=None, - outputProtocolFactory=None, threads=10): + + def __init__(self, + processor, + lsocket, + inputProtocolFactory=None, + outputProtocolFactory=None, + threads=10): self.processor = processor self.socket = lsocket self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory() @@ -229,7 +241,7 @@ class TNonblockingServer: def setNumThreads(self, num): """Set the number of worker threads that should be created.""" # implement ThreadPool interface - assert not self.prepared, "You can't change number of threads for working server" + assert not self.prepared, "Can't change number of threads after start" self.threads = num def prepare(self): @@ -243,14 +255,15 @@ class TNonblockingServer: def wake_up(self): """Wake up main thread. - + The server usualy waits in select call in we should terminate one. The simplest way is using socketpair. - + Select always wait to read from the first socket of socketpair. - + In this case, we can just write anything to the second socket from - socketpair.""" + socketpair. + """ self._write.send('1') def _select(self): @@ -265,21 +278,22 @@ class TNonblockingServer: if connection.is_closed(): del self.clients[i] return select.select(readable, writable, readable) - + def handle(self): """Handle requests. - - WARNING! You must call prepare BEFORE calling handle. + + WARNING! You must call prepare() BEFORE calling handle() """ assert self.prepared, "You have to call prepare before handle" rset, wset, xset = self._select() for readable in rset: if readable == self._read.fileno(): # don't care i just need to clean readable flag - self._read.recv(1024) + self._read.recv(1024) elif readable == self.socket.handle.fileno(): client = self.socket.accept().handle - self.clients[client.fileno()] = Connection(client, self.wake_up) + self.clients[client.fileno()] = Connection(client, + self.wake_up) else: connection = self.clients[readable] connection.read() @@ -288,7 +302,7 @@ class TNonblockingServer: otransport = TTransport.TMemoryBuffer() iprot = self.in_protocol.getProtocol(itransport) oprot = self.out_protocol.getProtocol(otransport) - self.tasks.put([self.processor, iprot, oprot, + self.tasks.put([self.processor, iprot, oprot, otransport, connection.ready]) for writeable in wset: self.clients[writeable].write() @@ -302,7 +316,7 @@ class TNonblockingServer: self.tasks.put([None, None, None, None, None]) self.socket.close() self.prepared = False - + def serve(self): """Serve forever.""" self.prepare() diff --git a/lib/py/src/server/TProcessPoolServer.py b/lib/py/src/server/TProcessPoolServer.py index 7ed814a8..7a695a88 100644 --- a/lib/py/src/server/TProcessPoolServer.py +++ b/lib/py/src/server/TProcessPoolServer.py @@ -24,15 +24,14 @@ from multiprocessing import Process, Value, Condition, reduction from TServer import TServer from thrift.transport.TTransport import TTransportException + class TProcessPoolServer(TServer): + """Server with a fixed size pool of worker subprocesses to service requests - """ - Server with a fixed size pool of worker subprocesses which service requests. Note that if you need shared state between the handlers - it's up to you! Written by Dvir Volk, doat.com """ - - def __init__(self, * args): + def __init__(self, *args): TServer.__init__(self, *args) self.numWorkers = 10 self.workers = [] @@ -50,12 +49,11 @@ class TProcessPoolServer(TServer): self.numWorkers = num def workerProcess(self): - """Loop around getting clients from the shared queue and process them.""" - + """Loop getting clients from the shared queue and process them""" if self.postForkCallback: self.postForkCallback() - while self.isRunning.value == True: + while self.isRunning.value: try: client = self.serverTransport.accept() self.serveClient(client) @@ -82,17 +80,15 @@ class TProcessPoolServer(TServer): itrans.close() otrans.close() - def serve(self): - """Start a fixed number of worker threads and put client into a queue""" - - #this is a shared state that can tell the workers to exit when set as false + """Start workers and put into queue""" + # this is a shared state that can tell the workers to exit when False self.isRunning.value = True - #first bind and listen to the port + # first bind and listen to the port self.serverTransport.listen() - #fork the children + # fork the children for i in range(self.numWorkers): try: w = Process(target=self.workerProcess) @@ -102,16 +98,14 @@ class TProcessPoolServer(TServer): except Exception, x: logging.exception(x) - #wait until the condition is set by stop() - + # wait until the condition is set by stop() while True: - self.stopCondition.acquire() try: self.stopCondition.wait() break except (SystemExit, KeyboardInterrupt): - break + break except Exception, x: logging.exception(x) @@ -122,4 +116,3 @@ class TProcessPoolServer(TServer): self.stopCondition.acquire() self.stopCondition.notify() self.stopCondition.release() - diff --git a/lib/py/src/server/TServer.py b/lib/py/src/server/TServer.py index 8456e2d4..2f24842c 100644 --- a/lib/py/src/server/TServer.py +++ b/lib/py/src/server/TServer.py @@ -17,27 +17,28 @@ # under the License. # +import Queue import logging -import sys import os -import traceback +import sys import threading -import Queue +import traceback from thrift.Thrift import TProcessor -from thrift.transport import TTransport from thrift.protocol import TBinaryProtocol +from thrift.transport import TTransport -class TServer: - """Base interface for a server, which must have a serve method.""" +class TServer: + """Base interface for a server, which must have a serve() method. - """ 3 constructors for all servers: + Three constructors for all servers: 1) (processor, serverTransport) 2) (processor, serverTransport, transportFactory, protocolFactory) 3) (processor, serverTransport, inputTransportFactory, outputTransportFactory, - inputProtocolFactory, outputProtocolFactory)""" + inputProtocolFactory, outputProtocolFactory) + """ def __init__(self, *args): if (len(args) == 2): self.__initArgs__(args[0], args[1], @@ -63,8 +64,8 @@ class TServer: def serve(self): pass -class TSimpleServer(TServer): +class TSimpleServer(TServer): """Simple single-threaded server that just pumps around one transport.""" def __init__(self, *args): @@ -89,8 +90,8 @@ class TSimpleServer(TServer): itrans.close() otrans.close() -class TThreadedServer(TServer): +class TThreadedServer(TServer): """Threaded server that spawns a new thread per each connection.""" def __init__(self, *args, **kwargs): @@ -102,7 +103,7 @@ class TThreadedServer(TServer): while True: try: client = self.serverTransport.accept() - t = threading.Thread(target = self.handle, args=(client,)) + t = threading.Thread(target=self.handle, args=(client,)) t.setDaemon(self.daemon) t.start() except KeyboardInterrupt: @@ -126,8 +127,8 @@ class TThreadedServer(TServer): itrans.close() otrans.close() -class TThreadPoolServer(TServer): +class TThreadPoolServer(TServer): """Server with a fixed size pool of threads which service requests.""" def __init__(self, *args, **kwargs): @@ -170,7 +171,7 @@ class TThreadPoolServer(TServer): """Start a fixed number of worker threads and put client into a queue""" for i in range(self.threads): try: - t = threading.Thread(target = self.serveThread) + t = threading.Thread(target=self.serveThread) t.setDaemon(self.daemon) t.start() except Exception, x: @@ -187,9 +188,8 @@ class TThreadPoolServer(TServer): class TForkingServer(TServer): + """A Thrift server that forks a new process for each request - """A Thrift server that forks a new process for each request""" - """ This is more scalable than the threaded server as it does not cause GIL contention. @@ -200,7 +200,6 @@ class TForkingServer(TServer): This code is heavily inspired by SocketServer.ForkingMixIn in the Python stdlib. """ - def __init__(self, *args): TServer.__init__(self, *args) self.children = [] @@ -212,14 +211,13 @@ class TForkingServer(TServer): except IOError, e: logging.warning(e, exc_info=True) - self.serverTransport.listen() while True: client = self.serverTransport.accept() try: pid = os.fork() - if pid: # parent + if pid: # parent # add before collect, otherwise you race w/ waitpid self.children.append(pid) self.collect_children() @@ -258,7 +256,6 @@ class TForkingServer(TServer): except Exception, x: logging.exception(x) - def collect_children(self): while self.children: try: @@ -270,5 +267,3 @@ class TForkingServer(TServer): self.children.remove(pid) else: break - - diff --git a/lib/py/src/transport/THttpClient.py b/lib/py/src/transport/THttpClient.py index 50269785..ad94d112 100644 --- a/lib/py/src/transport/THttpClient.py +++ b/lib/py/src/transport/THttpClient.py @@ -17,16 +17,17 @@ # under the License. # -from TTransport import * -from cStringIO import StringIO - -import urlparse import httplib -import warnings import socket +import urlparse +import warnings -class THttpClient(TTransportBase): +from cStringIO import StringIO + +from TTransport import * + +class THttpClient(TTransportBase): """Http implementation of TTransport base.""" def __init__(self, uri_or_host, port=None, path=None): @@ -35,10 +36,13 @@ class THttpClient(TTransportBase): THttpClient(host, port, path) - deprecated THttpClient(uri) - Only the second supports https.""" - + Only the second supports https. + """ if port is not None: - warnings.warn("Please use the THttpClient('http://host:port/path') syntax", DeprecationWarning, stacklevel=2) + warnings.warn( + "Please use the THttpClient('http://host:port/path') syntax", + DeprecationWarning, + stacklevel=2) self.host = uri_or_host self.port = port assert path @@ -71,7 +75,7 @@ class THttpClient(TTransportBase): self.__http = None def isOpen(self): - return self.__http != None + return self.__http is not None def setTimeout(self, ms): if not hasattr(socket, 'getdefaulttimeout'): @@ -80,7 +84,7 @@ class THttpClient(TTransportBase): if ms is None: self.__timeout = None else: - self.__timeout = ms/1000.0 + self.__timeout = ms / 1000.0 def read(self, sz): return self.__http.file.read(sz) @@ -100,7 +104,7 @@ class THttpClient(TTransportBase): def flush(self): if self.isOpen(): self.close() - self.open(); + self.open() # Pull data out of buffer data = self.__wbuf.getvalue() diff --git a/lib/py/src/transport/TSSLSocket.py b/lib/py/src/transport/TSSLSocket.py index be358448..6d79ac6a 100644 --- a/lib/py/src/transport/TSSLSocket.py +++ b/lib/py/src/transport/TSSLSocket.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. # + import os import socket import ssl @@ -23,28 +24,35 @@ import ssl from thrift.transport import TSocket from thrift.transport.TTransport import TTransportException + class TSSLSocket(TSocket.TSocket): """ SSL implementation of client-side TSocket This class creates outbound sockets wrapped using the python standard ssl module for encrypted connections. - + The protocol used is set using the class variable SSL_VERSION, which must be one of ssl.PROTOCOL_* and defaults to ssl.PROTOCOL_TLSv1 for greatest security. """ SSL_VERSION = ssl.PROTOCOL_TLSv1 - def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, unix_socket=None): - """ - @param validate: Set to False to disable SSL certificate validation entirely. + def __init__(self, + host='localhost', + port=9090, + validate=True, + ca_certs=None, + unix_socket=None): + """Create SSL TSocket + + @param validate: Set to False to disable SSL certificate validation @type validate: bool @param ca_certs: Filename to the Certificate Authority pem file, possibly a file downloaded from: http://curl.haxx.se/ca/cacert.pem This is passed to the ssl_wrap function as the 'ca_certs' parameter. @type ca_certs: str - + Raises an IOError exception if validate is True and the ca_certs file is None, not present or unreadable. """ @@ -58,18 +66,23 @@ class TSSLSocket(TSocket.TSocket): self.ca_certs = ca_certs if validate: if ca_certs is None or not os.access(ca_certs, os.R_OK): - raise IOError('Certificate Authority ca_certs file "%s" is not readable, cannot validate SSL certificates.' % (ca_certs)) + raise IOError('Certificate Authority ca_certs file "%s" ' + 'is not readable, cannot validate SSL ' + 'certificates.' % (ca_certs)) TSocket.TSocket.__init__(self, host, port, unix_socket) def open(self): try: res0 = self._resolveAddr() for res in res0: - sock_family, sock_type= res[0:2] + sock_family, sock_type = res[0:2] ip_port = res[4] plain_sock = socket.socket(sock_family, sock_type) - self.handle = ssl.wrap_socket(plain_sock, ssl_version=self.SSL_VERSION, - do_handshake_on_connect=True, ca_certs=self.ca_certs, cert_reqs=self.cert_reqs) + self.handle = ssl.wrap_socket(plain_sock, + ssl_version=self.SSL_VERSION, + do_handshake_on_connect=True, + ca_certs=self.ca_certs, + cert_reqs=self.cert_reqs) self.handle.settimeout(self._timeout) try: self.handle.connect(ip_port) @@ -84,7 +97,8 @@ class TSSLSocket(TSocket.TSocket): message = 'Could not connect to secure socket %s' % self._unix_socket else: message = 'Could not connect to %s:%d' % (self.host, self.port) - raise TTransportException(type=TTransportException.NOT_OPEN, message=message) + raise TTransportException(type=TTransportException.NOT_OPEN, + message=message) if self.validate: self._validate_cert() @@ -93,13 +107,15 @@ class TSSLSocket(TSocket.TSocket): commonName of the certificate to ensure it matches the hostname we used to make this connection. Does not support subjectAltName records in certificates. - - raises TTransportException if the certificate fails validation.""" + + raises TTransportException if the certificate fails validation. + """ cert = self.handle.getpeercert() self.peercert = cert if 'subject' not in cert: - raise TTransportException(type=TTransportException.NOT_OPEN, - message='No SSL certificate found from %s:%s' % (self.host, self.port)) + raise TTransportException( + type=TTransportException.NOT_OPEN, + message='No SSL certificate found from %s:%s' % (self.host, self.port)) fields = cert['subject'] for field in fields: # ensure structure we get back is what we expect @@ -115,29 +131,38 @@ class TSSLSocket(TSocket.TSocket): if certhost == self.host: # success, cert commonName matches desired hostname self.is_valid = True - return + return else: - raise TTransportException(type=TTransportException.UNKNOWN, - message='Host name we connected to "%s" doesn\'t match certificate provided commonName "%s"' % (self.host, certhost)) - raise TTransportException(type=TTransportException.UNKNOWN, - message='Could not validate SSL certificate from host "%s". Cert=%s' % (self.host, cert)) + raise TTransportException( + type=TTransportException.UNKNOWN, + message='Hostname we connected to "%s" doesn\'t match certificate ' + 'provided commonName "%s"' % (self.host, certhost)) + raise TTransportException( + type=TTransportException.UNKNOWN, + message='Could not validate SSL certificate from ' + 'host "%s". Cert=%s' % (self.host, cert)) + class TSSLServerSocket(TSocket.TServerSocket): - """ - SSL implementation of TServerSocket + """SSL implementation of TServerSocket This uses the ssl module's wrap_socket() method to provide SSL negotiated encryption. """ SSL_VERSION = ssl.PROTOCOL_TLSv1 - def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None): + def __init__(self, + host=None, + port=9090, + certfile='cert.pem', + unix_socket=None): """Initialize a TSSLServerSocket - - @param certfile: The filename of the server certificate file, defaults to cert.pem + + @param certfile: filename of the server certificate, defaults to cert.pem @type certfile: str - @param host: The hostname or IP to bind the listen socket to, i.e. 'localhost' for only allowing - local network connections. Pass None to bind to all interfaces. + @param host: The hostname or IP to bind the listen socket to, + i.e. 'localhost' for only allowing local network connections. + Pass None to bind to all interfaces. @type host: str @param port: The port to listen on for inbound connections. @type port: int @@ -147,10 +172,11 @@ class TSSLServerSocket(TSocket.TServerSocket): def setCertfile(self, certfile): """Set or change the server certificate file used to wrap new connections. - - @param certfile: The filename of the server certificate, i.e. '/etc/certs/server.pem' + + @param certfile: The filename of the server certificate, + i.e. '/etc/certs/server.pem' @type certfile: str - + Raises an IOError exception if the certfile is not present or unreadable. """ if not os.access(certfile, os.R_OK): @@ -166,11 +192,11 @@ class TSSLServerSocket(TSocket.TServerSocket): # failed handshake/ssl wrap, close socket to client plain_client.close() # raise ssl_exc - # We can't raise the exception, because it kills most TServer derived serve() - # methods. + # We can't raise the exception, because it kills most TServer derived + # serve() methods. # Instead, return None, and let the TServer instance deal with it in # other exception handling. (but TSimpleServer dies anyway) - return None + return None result = TSocket.TSocket() result.setHandle(client) return result diff --git a/lib/py/src/transport/TSocket.py b/lib/py/src/transport/TSocket.py index 4e0e1874..9e2b3849 100644 --- a/lib/py/src/transport/TSocket.py +++ b/lib/py/src/transport/TSocket.py @@ -17,24 +17,33 @@ # under the License. # -from TTransport import * -import os import errno +import os import socket import sys +from TTransport import * + + class TSocketBase(TTransportBase): def _resolveAddr(self): if self._unix_socket is not None: - return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, self._unix_socket)] + return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, + self._unix_socket)] else: - return socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE | socket.AI_ADDRCONFIG) + return socket.getaddrinfo(self.host, + self.port, + socket.AF_UNSPEC, + socket.SOCK_STREAM, + 0, + socket.AI_PASSIVE | socket.AI_ADDRCONFIG) def close(self): if self.handle: self.handle.close() self.handle = None + class TSocket(TSocketBase): """Socket implementation of TTransport base.""" @@ -46,7 +55,6 @@ class TSocket(TSocketBase): @param unix_socket(str) The filename of a unix socket to connect to. (host and port will be ignored.) """ - self.host = host self.port = port self.handle = None @@ -63,7 +71,7 @@ class TSocket(TSocketBase): if ms is None: self._timeout = None else: - self._timeout = ms/1000.0 + self._timeout = ms / 1000.0 if self.handle is not None: self.handle.settimeout(self._timeout) @@ -87,7 +95,8 @@ class TSocket(TSocketBase): message = 'Could not connect to socket %s' % self._unix_socket else: message = 'Could not connect to %s:%d' % (self.host, self.port) - raise TTransportException(type=TTransportException.NOT_OPEN, message=message) + raise TTransportException(type=TTransportException.NOT_OPEN, + message=message) def read(self, sz): try: @@ -105,24 +114,28 @@ class TSocket(TSocketBase): else: raise if len(buff) == 0: - raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket read 0 bytes') + raise TTransportException(type=TTransportException.END_OF_FILE, + message='TSocket read 0 bytes') return buff def write(self, buff): if not self.handle: - raise TTransportException(type=TTransportException.NOT_OPEN, message='Transport not open') + raise TTransportException(type=TTransportException.NOT_OPEN, + message='Transport not open') sent = 0 have = len(buff) while sent < have: plus = self.handle.send(buff) if plus == 0: - raise TTransportException(type=TTransportException.END_OF_FILE, message='TSocket sent 0 bytes') + raise TTransportException(type=TTransportException.END_OF_FILE, + message='TSocket sent 0 bytes') sent += plus buff = buff[plus:] def flush(self): pass + class TServerSocket(TSocketBase, TServerTransportBase): """Socket implementation of TServerTransport base.""" diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py index 12e51a9b..4481371a 100644 --- a/lib/py/src/transport/TTransport.py +++ b/lib/py/src/transport/TTransport.py @@ -18,11 +18,11 @@ # from cStringIO import StringIO -from struct import pack,unpack +from struct import pack, unpack from thrift.Thrift import TException -class TTransportException(TException): +class TTransportException(TException): """Custom Transport Exception class""" UNKNOWN = 0 @@ -35,8 +35,8 @@ class TTransportException(TException): TException.__init__(self, message) self.type = type -class TTransportBase: +class TTransportBase: """Base class for Thrift transport layer.""" def isOpen(self): @@ -55,7 +55,7 @@ class TTransportBase: buff = '' have = 0 while (have < sz): - chunk = self.read(sz-have) + chunk = self.read(sz - have) have += len(chunk) buff += chunk @@ -70,6 +70,7 @@ class TTransportBase: def flush(self): pass + # This class should be thought of as an interface. class CReadableTransport: """base class for transports that are readable from C""" @@ -98,8 +99,8 @@ class CReadableTransport: """ pass -class TServerTransportBase: +class TServerTransportBase: """Base class for Thrift server transports.""" def listen(self): @@ -111,15 +112,15 @@ class TServerTransportBase: def close(self): pass -class TTransportFactoryBase: +class TTransportFactoryBase: """Base class for a Transport Factory""" def getTransport(self, trans): return trans -class TBufferedTransportFactory: +class TBufferedTransportFactory: """Factory transport that builds buffered transports""" def getTransport(self, trans): @@ -127,17 +128,15 @@ class TBufferedTransportFactory: return buffered -class TBufferedTransport(TTransportBase,CReadableTransport): - +class TBufferedTransport(TTransportBase, CReadableTransport): """Class that wraps another transport and buffers its I/O. The implementation uses a (configurable) fixed-size read buffer but buffers all writes until a flush is performed. """ - DEFAULT_BUFFER = 4096 - def __init__(self, trans, rbuf_size = DEFAULT_BUFFER): + def __init__(self, trans, rbuf_size=DEFAULT_BUFFER): self.__trans = trans self.__wbuf = StringIO() self.__rbuf = StringIO("") @@ -188,6 +187,7 @@ class TBufferedTransport(TTransportBase,CReadableTransport): self.__rbuf = StringIO(retstring) return self.__rbuf + class TMemoryBuffer(TTransportBase, CReadableTransport): """Wraps a cStringIO object as a TTransport. @@ -237,8 +237,8 @@ class TMemoryBuffer(TTransportBase, CReadableTransport): # only one shot at reading... raise EOFError() -class TFramedTransportFactory: +class TFramedTransportFactory: """Factory transport that builds framed transports""" def getTransport(self, trans): @@ -247,7 +247,6 @@ class TFramedTransportFactory: class TFramedTransport(TTransportBase, CReadableTransport): - """Class that wraps another transport and frames its I/O when writing.""" def __init__(self, trans,): diff --git a/lib/py/src/transport/TTwisted.py b/lib/py/src/transport/TTwisted.py index b6dcb4e0..3ce3eb22 100644 --- a/lib/py/src/transport/TTwisted.py +++ b/lib/py/src/transport/TTwisted.py @@ -16,6 +16,9 @@ # specific language governing permissions and limitations # under the License. # + +from cStringIO import StringIO + from zope.interface import implements, Interface, Attribute from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \ connectionDone @@ -25,7 +28,6 @@ from twisted.python import log from twisted.web import server, resource, http from thrift.transport import TTransport -from cStringIO import StringIO class TMessageSenderTransport(TTransport.TTransportBase): @@ -79,7 +81,7 @@ class ThriftClientProtocol(basic.Int32StringReceiver): self.started.callback(self.client) def connectionLost(self, reason=connectionDone): - for k,v in self.client._reqs.iteritems(): + for k, v in self.client._reqs.iteritems(): tex = TTransport.TTransportException( type=TTransport.TTransportException.END_OF_FILE, message='Connection closed') diff --git a/lib/py/src/transport/TZlibTransport.py b/lib/py/src/transport/TZlibTransport.py index 784d4e1e..a2f42a5d 100644 --- a/lib/py/src/transport/TZlibTransport.py +++ b/lib/py/src/transport/TZlibTransport.py @@ -16,50 +16,49 @@ # specific language governing permissions and limitations # under the License. # -''' -TZlibTransport provides a compressed transport and transport factory + +"""TZlibTransport provides a compressed transport and transport factory class, using the python standard library zlib module to implement data compression. -''' +""" from __future__ import division import zlib from cStringIO import StringIO from TTransport import TTransportBase, CReadableTransport + class TZlibTransportFactory(object): - ''' - Factory transport that builds zlib compressed transports. - + """Factory transport that builds zlib compressed transports. + This factory caches the last single client/transport that it was passed and returns the same TZlibTransport object that was created. - + This caching means the TServer class will get the _same_ transport object for both input and output transports from this factory. (For non-threaded scenarios only, since the cache only holds one object) - + The purpose of this caching is to allocate only one TZlibTransport where only one is really needed (since it must have separate read/write buffers), and makes the statistics from getCompSavings() and getCompRatio() easier to understand. - ''' - + """ # class scoped cache of last transport given and zlibtransport returned _last_trans = None _last_z = None def getTransport(self, trans, compresslevel=9): - '''Wrap a transport , trans, with the TZlibTransport + """Wrap a transport, trans, with the TZlibTransport compressed transport class, returning a new transport to the caller. - + @param compresslevel: The zlib compression level, ranging from 0 (no compression) to 9 (best compression). Defaults to 9. @type compresslevel: int - + This method returns a TZlibTransport which wraps the passed C{trans} TTransport derived instance. - ''' + """ if trans == self._last_trans: return self._last_z ztrans = TZlibTransport(trans, compresslevel) @@ -69,27 +68,24 @@ class TZlibTransportFactory(object): class TZlibTransport(TTransportBase, CReadableTransport): - ''' - Class that wraps a transport with zlib, compressing writes + """Class that wraps a transport with zlib, compressing writes and decompresses reads, using the python standard library zlib module. - ''' - + """ # Read buffer size for the python fastbinary C extension, # the TBinaryProtocolAccelerated class. DEFAULT_BUFFSIZE = 4096 def __init__(self, trans, compresslevel=9): - ''' - Create a new TZlibTransport, wrapping C{trans}, another + """Create a new TZlibTransport, wrapping C{trans}, another TTransport derived object. - + @param trans: A thrift transport object, i.e. a TSocket() object. @type trans: TTransport @param compresslevel: The zlib compression level, ranging from 0 (no compression) to 9 (best compression). Default is 9. @type compresslevel: int - ''' + """ self.__trans = trans self.compresslevel = compresslevel self.__rbuf = StringIO() @@ -98,49 +94,45 @@ class TZlibTransport(TTransportBase, CReadableTransport): self._init_stats() def _reinit_buffers(self): - ''' - Internal method to initialize/reset the internal StringIO objects + """Internal method to initialize/reset the internal StringIO objects for read and write buffers. - ''' + """ self.__rbuf = StringIO() self.__wbuf = StringIO() def _init_stats(self): - ''' - Internal method to reset the internal statistics counters + """Internal method to reset the internal statistics counters for compression ratios and bandwidth savings. - ''' + """ self.bytes_in = 0 self.bytes_out = 0 self.bytes_in_comp = 0 self.bytes_out_comp = 0 def _init_zlib(self): - ''' - Internal method for setting up the zlib compression and + """Internal method for setting up the zlib compression and decompression objects. - ''' + """ self._zcomp_read = zlib.decompressobj() self._zcomp_write = zlib.compressobj(self.compresslevel) def getCompRatio(self): - ''' - Get the current measured compression ratios (in,out) from + """Get the current measured compression ratios (in,out) from this transport. - - Returns a tuple of: + + Returns a tuple of: (inbound_compression_ratio, outbound_compression_ratio) - + The compression ratios are computed as: compressed / uncompressed E.g., data that compresses by 10x will have a ratio of: 0.10 and data that compresses to half of ts original size will have a ratio of 0.5 - + None is returned if no bytes have yet been processed in a particular direction. - ''' + """ r_percent, w_percent = (None, None) if self.bytes_in > 0: r_percent = self.bytes_in_comp / self.bytes_in @@ -149,23 +141,22 @@ class TZlibTransport(TTransportBase, CReadableTransport): return (r_percent, w_percent) def getCompSavings(self): - ''' - Get the current count of saved bytes due to data + """Get the current count of saved bytes due to data compression. - + Returns a tuple of: (inbound_saved_bytes, outbound_saved_bytes) - + Note: if compression is actually expanding your data (only likely with very tiny thrift objects), then the values returned will be negative. - ''' + """ r_saved = self.bytes_in - self.bytes_in_comp w_saved = self.bytes_out - self.bytes_out_comp return (r_saved, w_saved) def isOpen(self): - '''Return the underlying transport's open status''' + """Return the underlying transport's open status""" return self.__trans.isOpen() def open(self): @@ -174,25 +165,24 @@ class TZlibTransport(TTransportBase, CReadableTransport): return self.__trans.open() def listen(self): - '''Invoke the underlying transport's listen() method''' + """Invoke the underlying transport's listen() method""" self.__trans.listen() def accept(self): - '''Accept connections on the underlying transport''' + """Accept connections on the underlying transport""" return self.__trans.accept() def close(self): - '''Close the underlying transport,''' + """Close the underlying transport,""" self._reinit_buffers() self._init_zlib() return self.__trans.close() def read(self, sz): - ''' - Read up to sz bytes from the decompressed bytes buffer, and + """Read up to sz bytes from the decompressed bytes buffer, and read from the underlying transport if the decompression buffer is empty. - ''' + """ ret = self.__rbuf.read(sz) if len(ret) > 0: return ret @@ -204,10 +194,9 @@ class TZlibTransport(TTransportBase, CReadableTransport): return ret def readComp(self, sz): - ''' - Read compressed data from the underlying transport, then + """Read compressed data from the underlying transport, then decompress it and append it to the internal StringIO read buffer - ''' + """ zbuf = self.__trans.read(sz) zbuf = self._zcomp_read.unconsumed_tail + zbuf buf = self._zcomp_read.decompress(zbuf) @@ -220,17 +209,15 @@ class TZlibTransport(TTransportBase, CReadableTransport): return True def write(self, buf): - ''' - Write some bytes, putting them into the internal write + """Write some bytes, putting them into the internal write buffer for eventual compression. - ''' + """ self.__wbuf.write(buf) def flush(self): - ''' - Flush any queued up data in the write buffer and ensure the + """Flush any queued up data in the write buffer and ensure the compression buffer is flushed out to the underlying transport - ''' + """ wout = self.__wbuf.getvalue() if len(wout) > 0: zbuf = self._zcomp_write.compress(wout) @@ -247,11 +234,11 @@ class TZlibTransport(TTransportBase, CReadableTransport): @property def cstringio_buf(self): - '''Implement the CReadableTransport interface''' + """Implement the CReadableTransport interface""" return self.__rbuf def cstringio_refill(self, partialread, reqlen): - '''Implement the CReadableTransport interface for refill''' + """Implement the CReadableTransport interface for refill""" retstring = partialread if reqlen < self.DEFAULT_BUFFSIZE: retstring += self.read(self.DEFAULT_BUFFSIZE) diff --git a/lib/py/src/transport/__init__.py b/lib/py/src/transport/__init__.py index 46e54fe6..c9596d9a 100644 --- a/lib/py/src/transport/__init__.py +++ b/lib/py/src/transport/__init__.py @@ -17,4 +17,4 @@ # under the License. # -__all__ = ['TTransport', 'TSocket', 'THttpClient','TZlibTransport'] +__all__ = ['TTransport', 'TSocket', 'THttpClient', 'TZlibTransport'] -- 2.17.1