From 74421273ad97359402556428f23afacfe31ce978 Mon Sep 17 00:00:00 2001 From: David Reiss Date: Fri, 7 Nov 2008 23:09:31 +0000 Subject: [PATCH] THRIFT-67. python: Add TNonblockingServer This TNonblockingServer is very similar to the C++ implementation. It assumes the framed transport, but it uses select instead of libevent. git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@712306 13f79535-47bb-0310-9956-ffa450edef68 --- lib/py/src/server/TNonblockingServer.py | 291 ++++++++++++++++++++++++ lib/py/src/server/__init__.py | 2 +- test/py/RunClientServer.py | 10 +- test/py/TestClient.py | 39 ++-- test/py/TestServer.py | 28 ++- 5 files changed, 349 insertions(+), 21 deletions(-) create mode 100644 lib/py/src/server/TNonblockingServer.py diff --git a/lib/py/src/server/TNonblockingServer.py b/lib/py/src/server/TNonblockingServer.py new file mode 100644 index 00000000..a588fe34 --- /dev/null +++ b/lib/py/src/server/TNonblockingServer.py @@ -0,0 +1,291 @@ +"""Implementation of non-blocking server. + +The main idea of the server is reciving and sending requests +only from main thread. + +It also makes thread pool server in tasks terms, not connections. +""" +import threading +import socket +import Queue +import select +import struct +import logging + +from thrift.transport import TTransport +from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory + +__all__ = ['TNonblockingServer'] + +class Worker(threading.Thread): + """Worker is a small helper to process incoming connection.""" + def __init__(self, queue): + threading.Thread.__init__(self) + self.queue = queue + + def run(self): + """Process queries from task queue, stop if processor is None.""" + while True: + try: + processor, iprot, oprot, otrans, callback = self.queue.get() + if processor is None: + break + processor.process(iprot, oprot) + callback(True, otrans.getvalue()) + except Exception: + logging.exception("Exception while processing request") + callback(False, '') + +WAIT_LEN = 0 +WAIT_MESSAGE = 1 +WAIT_PROCESS = 2 +SEND_ANSWER = 3 +CLOSED = 4 + +def locked(func): + "Decorator which locks self.lock." + def nested(self, *args, **kwargs): + self.lock.acquire() + try: + return func(self, *args, **kwargs) + finally: + self.lock.release() + return nested + +def socket_exception(func): + "Decorator close object on socket.error." + def read(self, *args, **kwargs): + try: + return func(self, *args, **kwargs) + except socket.error: + self.close() + return read + +class Connection: + """Basic class is represented connection. + + It can be in state: + WAIT_LEN --- connection is reading request len. + WAIT_MESSAGE --- connection is reading request. + WAIT_PROCESS --- connection has just read whole request and + waits for call ready routine. + SEND_ANSWER --- connection is sending answer string (including length + of answer). + CLOSED --- socket was closed and connection should be deleted. + """ + def __init__(self, new_socket, wake_up): + self.socket = new_socket + self.socket.setblocking(False) + self.status = WAIT_LEN + self.len = 0 + self.message = '' + self.lock = threading.Lock() + self.wake_up = wake_up + + def _read_len(self): + """Reads length of request. + + It's really paranoic routine and it may be replaced by + self.socket.recv(4).""" + read = self.socket.recv(4 - len(self.message)) + if len(read) == 0: + # if we read 0 bytes and self.message is empty, it means client close + # connection + if len(self.message) != 0: + logging.error("can't read frame size from socket") + self.close() + return + self.message += read + if len(self.message) == 4: + self.len, = struct.unpack('!i', self.message) + if self.len < 0: + logging.error("negative frame size, it seems client"\ + " doesn't use FramedTransport") + self.close() + elif self.len == 0: + logging.error("empty frame, it's really strange") + self.close() + else: + self.message = '' + self.status = WAIT_MESSAGE + + @socket_exception + def read(self): + """Reads data from stream and switch state.""" + assert self.status in (WAIT_LEN, WAIT_MESSAGE) + if self.status == WAIT_LEN: + self._read_len() + # go back to the main loop here for simplicity instead of + # falling through, even though there is a good chance that + # the message is already available + elif self.status == WAIT_MESSAGE: + read = self.socket.recv(self.len - len(self.message)) + if len(read) == 0: + logging.error("can't read frame from socket (get %d of %d bytes)" % + (len(self.message), self.len)) + self.close() + return + self.message += read + if len(self.message) == self.len: + self.status = WAIT_PROCESS + + @socket_exception + def write(self): + """Writes data from socket and switch state.""" + assert self.status == SEND_ANSWER + sent = self.socket.send(self.message) + if sent == len(self.message): + self.status = WAIT_LEN + self.message = '' + self.len = 0 + else: + self.message = self.message[sent:] + + @locked + def ready(self, all_ok, message): + """Callback function for switching state and waking up main thread. + + This function is the only function witch can be called asynchronous. + + The ready can switch Connection to three states: + WAIT_LEN if request was async. + SEND_ANSWER if request was processed in normal way. + CLOSED if request throws unexpected exception. + + The one wakes up main thread. + """ + assert self.status == WAIT_PROCESS + if not all_ok: + self.close() + self.wake_up() + return + self.len = '' + self.message = struct.pack('!i', len(message)) + message + if len(message) == 0: + # it was async request, do not write answer + self.status = WAIT_LEN + else: + self.status = SEND_ANSWER + self.wake_up() + + @locked + def is_writeable(self): + "Returns True if connection should be added to write list of select." + return self.status == SEND_ANSWER + + # it's not necessary, but... + @locked + def is_readable(self): + "Returns True if connection should be added to read list of select." + return self.status in (WAIT_LEN, WAIT_MESSAGE) + + @locked + def is_closed(self): + "Returns True if connection is closed." + return self.status == CLOSED + + def fileno(self): + "Returns the file descriptor of the associated socket." + return self.socket.fileno() + + def close(self): + "Closes connection" + self.status = CLOSED + self.socket.close() + +class TNonblockingServer: + """Non-blocking server.""" + def __init__(self, processor, lsocket, inputProtocolFactory=None, + outputProtocolFactory=None, threads=10): + self.processor = processor + self.socket = lsocket + self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory() + self.out_protocol = outputProtocolFactory or self.in_protocol + self.threads = int(threads) + self.clients = {} + self.tasks = Queue.Queue() + self._read, self._write = socket.socketpair() + self.prepared = False + + def setNumThreads(self, num): + """Set the number of worker threads that should be created.""" + # implement ThreadPool interface + assert not self.prepared, "You can't change number of threads for working server" + self.threads = num + + def prepare(self): + """Prepares server for serve requests.""" + self.socket.listen() + for _ in xrange(self.threads): + thread = Worker(self.tasks) + thread.setDaemon(True) + thread.start() + self.prepared = True + + def wake_up(self): + """Wake up main thread. + + The server usualy waits in select call in we should terminate one. + The simplest way is using socketpair. + + Select always wait to read from the first socket of socketpair. + + In this case, we can just write anything to the second socket from + socketpair.""" + self._write.send('1') + + def _select(self): + """Does select on open connections.""" + readable = [self.socket.handle.fileno(), self._read.fileno()] + writable = [] + for i, connection in self.clients.items(): + if connection.is_readable(): + readable.append(connection.fileno()) + if connection.is_writeable(): + writable.append(connection.fileno()) + if connection.is_closed(): + del self.clients[i] + return select.select(readable, writable, readable) + + def handle(self): + """Handle requests. + + WARNING! You must call prepare BEFORE calling handle. + """ + assert self.prepared, "You have to call prepare before handle" + rset, wset, xset = self._select() + for readable in rset: + if readable == self._read.fileno(): + # don't care i just need to clean readable flag + self._read.recv(1024) + elif readable == self.socket.handle.fileno(): + client = self.socket.accept().handle + self.clients[client.fileno()] = Connection(client, self.wake_up) + else: + connection = self.clients[readable] + connection.read() + if connection.status == WAIT_PROCESS: + itransport = TTransport.TMemoryBuffer(connection.message) + otransport = TTransport.TMemoryBuffer() + iprot = self.in_protocol.getProtocol(itransport) + oprot = self.out_protocol.getProtocol(otransport) + self.tasks.put([self.processor, iprot, oprot, + otransport, connection.ready]) + for writeable in wset: + self.clients[writeable].write() + for oob in xset: + self.clients[oob].close() + del self.clients[oob] + + def close(self): + """Closes the server.""" + for _ in xrange(self.threads): + self.tasks.put([None, None, None, None, None]) + self.socket.close() + self.prepared = False + + def serve(self): + """Serve forever.""" + self.prepare() + while True: + self.handle() diff --git a/lib/py/src/server/__init__.py b/lib/py/src/server/__init__.py index b4b46a1f..f017abd4 100644 --- a/lib/py/src/server/__init__.py +++ b/lib/py/src/server/__init__.py @@ -4,4 +4,4 @@ # See accompanying file LICENSE or visit the Thrift site at: # http://developers.facebook.com/thrift/ -__all__ = ['TServer'] +__all__ = ['TServer', 'TNonblockingServer'] diff --git a/test/py/RunClientServer.py b/test/py/RunClientServer.py index cbff3729..48eadb66 100755 --- a/test/py/RunClientServer.py +++ b/test/py/RunClientServer.py @@ -9,12 +9,16 @@ import signal def relfile(fname): return os.path.join(os.path.dirname(__file__), fname) +FRAMED = ["TNonblockingServer"] + def runTest(server_class): print "Testing ", server_class serverproc = subprocess.Popen([sys.executable, relfile("TestServer.py"), server_class]) try: - - ret = subprocess.call([sys.executable, relfile("TestClient.py")]) + argv = [sys.executable, relfile("TestClient.py")] + if server_class in FRAMED: + argv.append('--framed') + ret = subprocess.call(argv) if ret != 0: raise Exception("subprocess failed") finally: @@ -25,4 +29,4 @@ def runTest(server_class): time.sleep(5) map(runTest, ["TForkingServer", "TThreadPoolServer", - "TThreadedServer", "TSimpleServer"]) + "TThreadedServer", "TSimpleServer", "TNonblockingServer"]) diff --git a/test/py/TestClient.py b/test/py/TestClient.py index fb0133a5..78dc80a1 100755 --- a/test/py/TestClient.py +++ b/test/py/TestClient.py @@ -15,24 +15,29 @@ from optparse import OptionParser parser = OptionParser() - -parser.add_option("--port", type="int", dest="port", default=9090) -parser.add_option("--host", type="string", dest="host", default='localhost') -parser.add_option("--framed-input", action="store_true", dest="framed_input") -parser.add_option("--framed-output", action="store_false", dest="framed_output") - -(options, args) = parser.parse_args() +parser.set_defaults(framed=False, verbose=1, host='localhost', port=9090) +parser.add_option("--port", type="int", dest="port", + help="connect to server at port") +parser.add_option("--host", type="string", dest="host", + help="connect to server") +parser.add_option("--framed", action="store_true", dest="framed", + help="use framed transport") +parser.add_option('-v', '--verbose', action="store_const", + dest="verbose", const=2, + help="verbose output") +parser.add_option('-q', '--quiet', action="store_const", + dest="verbose", const=0, + help="minimal output") + +options, args = parser.parse_args() class AbstractTest(unittest.TestCase): - def setUp(self): - global options - socket = TSocket.TSocket(options.host, options.port) # Frame or buffer depending upon args - if options.framed_input or options.framed_output: - self.transport = TTransport.TFramedTransport(socket, options.framed_input, options.framed_output) + if options.framed: + self.transport = TTransport.TFramedTransport(socket) else: self.transport = TTransport.TBufferedTransport(socket) @@ -113,5 +118,13 @@ def suite(): suite.addTest(loader.loadTestsFromTestCase(AcceleratedBinaryTest)) return suite +class OwnArgsTestProgram(unittest.TestProgram): + def parseArgs(self, argv): + if args: + self.testNames = args + else: + self.testNames = (self.defaultTest,) + self.createTests() + if __name__ == "__main__": - unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2)) + OwnArgsTestProgram(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2)) diff --git a/test/py/TestServer.py b/test/py/TestServer.py index 0247bc27..a7bf6d06 100755 --- a/test/py/TestServer.py +++ b/test/py/TestServer.py @@ -9,7 +9,7 @@ from ThriftTest.ttypes import * from thrift.transport import TTransport from thrift.transport import TSocket from thrift.protocol import TBinaryProtocol -from thrift.server import TServer +from thrift.server import TServer, TNonblockingServer class TestHandler: @@ -59,13 +59,33 @@ class TestHandler: time.sleep(seconds) print 'done sleeping' + def testNest(self, thing): + return thing + + def testMap(self, thing): + return thing + + def testSet(self, thing): + return thing + + def testList(self, thing): + return thing + + def testEnum(self, thing): + return thing + + def testTypedef(self, thing): + return thing + handler = TestHandler() processor = ThriftTest.Processor(handler) transport = TSocket.TServerSocket(9090) tfactory = TTransport.TBufferedTransportFactory() pfactory = TBinaryProtocol.TBinaryProtocolFactory() -ServerClass = getattr(TServer, sys.argv[1]) - -server = ServerClass(processor, transport, tfactory, pfactory) +if sys.argv[1] == "TNonblockingServer": + server = TNonblockingServer.TNonblockingServer(processor, transport) +else: + ServerClass = getattr(TServer, sys.argv[1]) + server = ServerClass(processor, transport, tfactory, pfactory) server.serve() -- 2.17.1