blob: 6152911142be60c2e8bb1a03194544b129631b72 [file] [log] [blame]
Gavin McDonald0b75e1a2010-10-28 02:12:01 +00001#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20import logging
21import sys
22import os
23import traceback
24import threading
25import Queue
26
27from thrift.Thrift import TProcessor
28from thrift.transport import TTransport
29from thrift.protocol import TBinaryProtocol
30
class TServer:
    """Base interface for a server, which must have a serve method.

    Three constructor forms are supported by all servers:

    1) (processor, serverTransport)
    2) (processor, serverTransport, transportFactory, protocolFactory)
    3) (processor, serverTransport,
        inputTransportFactory, outputTransportFactory,
        inputProtocolFactory, outputProtocolFactory)

    Form 1 wraps client connections with ``TTransport.TTransportFactoryBase``
    and uses ``TBinaryProtocol.TBinaryProtocolFactory`` in both directions.
    Form 2 uses the single transport factory and the single protocol factory
    for both input and output.
    """

    def __init__(self, *args):
        """Initialize from one of the three argument forms listed above.

        Raises:
            TypeError: if the number of arguments is not 2, 4 or 6.
        """
        if len(args) == 2:
            self.__initArgs__(args[0], args[1],
                              TTransport.TTransportFactoryBase(),
                              TTransport.TTransportFactoryBase(),
                              TBinaryProtocol.TBinaryProtocolFactory(),
                              TBinaryProtocol.TBinaryProtocolFactory())
        elif len(args) == 4:
            # Same factory for both directions.
            self.__initArgs__(args[0], args[1], args[2], args[2],
                              args[3], args[3])
        elif len(args) == 6:
            self.__initArgs__(args[0], args[1], args[2], args[3],
                              args[4], args[5])
        else:
            # Fail fast: an unsupported count used to leave the server without
            # any attributes, surfacing only later as an AttributeError.
            raise TypeError("%s() takes 2, 4 or 6 arguments (%d given)"
                            % (self.__class__.__name__, len(args)))

    def __initArgs__(self, processor, serverTransport,
                     inputTransportFactory, outputTransportFactory,
                     inputProtocolFactory, outputProtocolFactory):
        """Store the processor, listening transport and the four factories."""
        self.processor = processor
        self.serverTransport = serverTransport
        self.inputTransportFactory = inputTransportFactory
        self.outputTransportFactory = outputTransportFactory
        self.inputProtocolFactory = inputProtocolFactory
        self.outputProtocolFactory = outputProtocolFactory

    def serve(self):
        """Run the server. The base implementation does nothing."""
        pass
65
class TSimpleServer(TServer):
    """Simple single-threaded server that just pumps around one transport.

    Clients are served one at a time: the next client is not accepted until
    the current client's session has ended.
    """

    def __init__(self, *args):
        """Arguments are passed unchanged to TServer.__init__."""
        TServer.__init__(self, *args)

    def serve(self):
        """Listen, then serve accepted clients sequentially, forever.

        Exceptions raised by accept() or while building a client's
        transports/protocols are not caught and end serve().
        """
        self.serverTransport.listen()
        while True:
            client = self.serverTransport.accept()
            itrans = self.inputTransportFactory.getTransport(client)
            otrans = self.outputTransportFactory.getTransport(client)
            iprot = self.inputProtocolFactory.getProtocol(itrans)
            oprot = self.outputProtocolFactory.getProtocol(otrans)
            try:
                try:
                    while True:
                        self.processor.process(iprot, oprot)
                except TTransport.TTransportException:
                    # Treated as the end of this client's session (e.g. the
                    # client closed the connection); not logged.
                    pass
                except Exception:
                    # sys.exc_info() instead of "except Exception, x" keeps
                    # the syntax valid on both Python 2 and 3.
                    logging.exception(sys.exc_info()[1])
            finally:
                # Release the connection even if a non-Exception
                # (e.g. KeyboardInterrupt) escapes process().
                itrans.close()
                otrans.close()
91
class TThreadedServer(TServer):
    """Threaded server that spawns a new thread per each connection."""

    def __init__(self, *args):
        """Arguments are passed unchanged to TServer.__init__."""
        TServer.__init__(self, *args)

    def serve(self):
        """Listen, then start one thread running handle() per accepted client.

        KeyboardInterrupt is re-raised so Ctrl-C stops the accept loop; any
        other error while accepting or starting a thread is logged and the
        loop continues.
        """
        self.serverTransport.listen()
        while True:
            try:
                client = self.serverTransport.accept()
                worker = threading.Thread(target=self.handle, args=(client,))
                worker.start()
            except KeyboardInterrupt:
                # Before Python 2.5 KeyboardInterrupt subclassed Exception,
                # so it must be re-raised explicitly to avoid swallowing it.
                raise
            except Exception:
                logging.exception(sys.exc_info()[1])

    def handle(self, client):
        """Process requests from *client* until its session ends.

        Runs on the per-connection thread started by serve(). The client's
        transports are closed however the session ends.
        """
        itrans = self.inputTransportFactory.getTransport(client)
        otrans = self.outputTransportFactory.getTransport(client)
        iprot = self.inputProtocolFactory.getProtocol(itrans)
        oprot = self.outputProtocolFactory.getProtocol(otrans)
        try:
            try:
                while True:
                    self.processor.process(iprot, oprot)
            except TTransport.TTransportException:
                # Treated as the end of the session (e.g. client hung up).
                pass
            except Exception:
                # sys.exc_info() keeps the syntax valid on Python 2 and 3.
                logging.exception(sys.exc_info()[1])
        finally:
            itrans.close()
            otrans.close()
126
class TThreadPoolServer(TServer):
    """Server with a fixed size pool of threads which service requests."""

    def __init__(self, *args):
        """Arguments are passed unchanged to TServer.__init__."""
        TServer.__init__(self, *args)
        # Accepted clients waiting for a free worker thread.
        self.clients = Queue.Queue()
        # Number of worker threads serve() starts.
        self.threads = 10

    def setNumThreads(self, num):
        """Set the number of worker threads that should be created.

        Only takes effect if called before serve(), which reads it once.
        """
        self.threads = num

    def serveThread(self):
        """Loop around getting clients from the shared queue and process them."""
        while True:
            try:
                client = self.clients.get()
                self.serveClient(client)
            except Exception:
                # sys.exc_info() keeps the syntax valid on Python 2 and 3.
                logging.exception(sys.exc_info()[1])

    def serveClient(self, client):
        """Process input/output from a client for as long as possible.

        The client's transports are closed however the session ends.
        """
        itrans = self.inputTransportFactory.getTransport(client)
        otrans = self.outputTransportFactory.getTransport(client)
        iprot = self.inputProtocolFactory.getProtocol(itrans)
        oprot = self.outputProtocolFactory.getProtocol(otrans)
        try:
            try:
                while True:
                    self.processor.process(iprot, oprot)
            except TTransport.TTransportException:
                # Treated as the end of the session (e.g. client hung up).
                pass
            except Exception:
                logging.exception(sys.exc_info()[1])
        finally:
            itrans.close()
            otrans.close()

    def serve(self):
        """Start a fixed number of worker threads and put clients into a queue.

        NOTE(review): the workers loop forever and are not daemon threads,
        so the process may not exit after serve() raises; consider daemon
        workers if clean shutdown matters.
        """
        for _ in range(self.threads):
            try:
                worker = threading.Thread(target=self.serveThread)
                worker.start()
            except Exception:
                logging.exception(sys.exc_info()[1])

        # Pump the socket for clients
        self.serverTransport.listen()
        while True:
            try:
                client = self.serverTransport.accept()
                self.clients.put(client)
            except KeyboardInterrupt:
                # Match TThreadedServer: before Python 2.5 KeyboardInterrupt
                # subclassed Exception and would otherwise be swallowed here.
                raise
            except Exception:
                logging.exception(sys.exc_info()[1])
183
184
class TForkingServer(TServer):
    """A Thrift server that forks a new process for each client connection.

    This is more scalable than the threaded server as it does not cause
    GIL contention.

    Note that this has different semantics from the threading server.
    Specifically, updates to shared variables will no longer be shared.
    It will also not work on Windows.

    This code is heavily inspired by SocketServer.ForkingMixIn in the
    Python stdlib.
    """

    def __init__(self, *args):
        """Arguments are passed unchanged to TServer.__init__."""
        TServer.__init__(self, *args)
        # Pids of forked children that have not been reaped yet.
        self.children = []

    def _try_close(self, trans):
        """Close *trans*, logging (instead of raising) any IOError."""
        try:
            trans.close()
        except IOError:
            # sys.exc_info() keeps the syntax valid on Python 2 and 3.
            logging.warning(sys.exc_info()[1], exc_info=True)

    def _close_client(self, client):
        """Close the input and output transports wrapping *client*."""
        itrans = self.inputTransportFactory.getTransport(client)
        otrans = self.outputTransportFactory.getTransport(client)
        self._try_close(itrans)
        self._try_close(otrans)

    def _process_client(self, client):
        """Run the processor on *client* until a transport error; return 0.

        Any other exception propagates after the transports are closed.
        """
        itrans = self.inputTransportFactory.getTransport(client)
        otrans = self.outputTransportFactory.getTransport(client)
        iprot = self.inputProtocolFactory.getProtocol(itrans)
        oprot = self.outputProtocolFactory.getProtocol(otrans)
        try:
            try:
                while True:
                    self.processor.process(iprot, oprot)
            except TTransport.TTransportException:
                # Treated as a normal end of session (e.g. client hung up).
                return 0
        finally:
            self._try_close(itrans)
            self._try_close(otrans)

    def _serve_child(self, client):
        """Serve *client* inside the forked child, then exit; never returns.

        Every path ends in os._exit(), so the child can never fall back into
        the accept loop in serve() or unwind through the parent's stack.
        Exit status is 0 when the session ended with a transport exception
        and 1 on any other failure (which is logged).
        """
        ecode = 1
        try:
            try:
                ecode = self._process_client(client)
            except Exception:
                logging.exception(sys.exc_info()[1])
        finally:
            os._exit(ecode)

    def serve(self):
        """Accept clients forever, forking one child process per connection.

        Exceptions raised by accept() are not caught and end serve(); other
        per-connection errors in the parent are logged (transport errors are
        ignored) and the loop continues.
        """
        self.serverTransport.listen()
        while True:
            client = self.serverTransport.accept()
            try:
                try:
                    pid = os.fork()
                    if not pid:
                        # Child: handles this connection and exits.
                        self._serve_child(client)
                    # Parent: add before collect, otherwise you race w/ waitpid.
                    self.children.append(pid)
                    self.collect_children()
                finally:
                    # Reached only in the parent (the child always _exits).
                    # The parent must close its copy of the socket or the
                    # connection may not get closed promptly; this also
                    # releases the connection when fork() itself failed.
                    self._close_client(client)
            except TTransport.TTransportException:
                pass
            except Exception:
                logging.exception(sys.exc_info()[1])

    def collect_children(self):
        """Reap exited children without blocking, forgetting their pids."""
        while self.children:
            try:
                pid = os.waitpid(0, os.WNOHANG)[0]
            except os.error:
                pid = None

            if not pid:
                # Nothing (more) to reap right now, or waitpid failed.
                break
            # waitpid(0, ...) may reap a child in our process group that this
            # server did not fork; list.remove() would raise ValueError then.
            if pid in self.children:
                self.children.remove(pid)
268 break
269
270