Thrift TTransportFactory model for servers

Summary: Servers need to wrap the transports they accept (e.g. in buffered transports) in a user-definable way. Use a factory pattern so the user can supply the server with an object that defines this wrapping behavior.


git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@664792 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/cpp/Makefile.am b/lib/cpp/Makefile.am
index b06e06f..554cb90 100644
--- a/lib/cpp/Makefile.am
+++ b/lib/cpp/Makefile.am
@@ -1,7 +1,8 @@
 lib_LTLIBRARIES = libthrift.la

-common_cxxflags = -Isrc $(BOOST_CPPFLAGS)
+# -Wall is a compiler warning option, so it belongs in the compile flags only.
+common_cxxflags = -Wall -Isrc $(BOOST_CPPFLAGS)
 common_ldflags = $(BOOST_LDFLAGS)

 # Define the source file for the module

@@ -54,7 +55,9 @@
                           src/transport/TServerTransport.h \
                           src/transport/TSocket.h \
                           src/transport/TTransport.h \
-                         src/transport/TTransportException.h
+                         src/transport/TTransportException.h \
+                         src/transport/TTransportFactory.h \
+                         src/transport/TBufferedTransportFactory.h

 include_serverdir = $(include_thriftdir)/server
 include_server_HEADERS = \
diff --git a/lib/cpp/src/TProcessor.h b/lib/cpp/src/TProcessor.h
index 4cbcd65..f905b1d 100644
--- a/lib/cpp/src/TProcessor.h
+++ b/lib/cpp/src/TProcessor.h
@@ -23,7 +23,8 @@
  public:
   virtual ~TProcessor() {}
   virtual bool process(shared_ptr<TTransport> in, shared_ptr<TTransport> out) = 0;
-  virtual bool process(shared_ptr<TTransport> io) { return process(io, io); }
+  // Shorthand for process(io, io). Not virtual: subclasses override process(in, out).
+  bool process(shared_ptr<TTransport> io) { return process(io, io); }
  protected:
   TProcessor() {}
 };
diff --git a/lib/cpp/src/server/TServer.h b/lib/cpp/src/server/TServer.h
index d53223f..c19302f 100644
--- a/lib/cpp/src/server/TServer.h
+++ b/lib/cpp/src/server/TServer.h
@@ -1,7 +1,9 @@
-#ifndef T_SERVER_H
-#define T_SERVER_H
+#ifndef _THRIFT_SERVER_TSERVER_H_
+#define _THRIFT_SERVER_TSERVER_H_ 1

 #include <TProcessor.h>
+#include <transport/TServerTransport.h>
+#include <transport/TTransportFactory.h>
 #include <concurrency/Thread.h>

 #include <boost/shared_ptr.hpp>
@@ -9,6 +11,7 @@
 namespace facebook { namespace thrift { namespace server { 

 using namespace facebook::thrift;
+using namespace facebook::thrift::transport;
 using namespace boost;

 class TServerOptions;
@@ -24,10 +27,28 @@
   virtual void run() = 0;
   
 protected:
-  TServer(shared_ptr<TProcessor> processor, shared_ptr<TServerOptions> options) :
+  // A null transportFactory falls back to a pass-through TTransportFactory,
+  // so transportFactory_ is never null.
+  TServer(shared_ptr<TProcessor> processor,
+          shared_ptr<TServerTransport> serverTransport,
+          shared_ptr<TTransportFactory> transportFactory,
+          shared_ptr<TServerOptions> options) :
+    processor_(processor),
+    serverTransport_(serverTransport),
+    transportFactory_(transportFactory ? transportFactory :
+                      shared_ptr<TTransportFactory>(new TTransportFactory())),
+    options_(options) {}
+
+  // Leaves serverTransport_ null and installs a pass-through TTransportFactory.
+  TServer(shared_ptr<TProcessor> processor,
+          shared_ptr<TServerOptions> options) :
-    processor_(processor), options_(options) {}
-  
+    processor_(processor),
+    transportFactory_(new TTransportFactory()),
+    options_(options) {}
+
   shared_ptr<TProcessor> processor_;
+  shared_ptr<TServerTransport> serverTransport_;
+  shared_ptr<TTransportFactory> transportFactory_;
   shared_ptr<TServerOptions> options_;
 };
  
@@ -35,12 +56,12 @@
   * Class to encapsulate all generic server options.
   */
 class TServerOptions {
-public:
+ public:
   // TODO(mcslee): Fill in getters/setters here
-protected:
+ protected:
   // TODO(mcslee): Fill data members in here
 };

 }}} // facebook::thrift::server

-#endif
+#endif // #ifndef _THRIFT_SERVER_TSERVER_H_
diff --git a/lib/cpp/src/server/TSimpleServer.cc b/lib/cpp/src/server/TSimpleServer.cc
index 2ad5145..041a52f 100644
--- a/lib/cpp/src/server/TSimpleServer.cc
+++ b/lib/cpp/src/server/TSimpleServer.cc
@@ -1,5 +1,4 @@
 #include "server/TSimpleServer.h"
-#include "transport/TBufferedTransport.h"
 #include "transport/TTransportException.h"
 #include <string>
 #include <iostream>
@@ -15,6 +14,7 @@
 void TSimpleServer::run() {

   shared_ptr<TTransport> client;
+  pair<shared_ptr<TTransport>,shared_ptr<TTransport> > io;

   try {
     // Start the server listening
@@ -25,26 +25,33 @@
   }

   // Fetch client from server
-  while (true) {
-    try {
+  try {
+    while (true) {
       client = serverTransport_->accept();
-      if (client != NULL) {
-        // Process for as long as we can keep the processor happy!
-        shared_ptr<TBufferedTransport> bufferedClient(new TBufferedTransport(client));
-        while (processor_->process(bufferedClient)) {}
-      }
-    } catch (TTransportException& ttx) {
-      if (client != NULL) {
+      if (client == NULL) {
+        // Skip a NULL client instead of wrapping and dereferencing it below.
+        continue;
+      }
+      // Let the configured factory build the input/output transports.
+      io = transportFactory_->getIOTransports(client);
+      try {
+        // Process for as long as we can keep the processor happy!
+        while (processor_->process(io.first, io.second)) {}
+      } catch (TTransportException& ttx) {
         cerr << "TSimpleServer client died: " << ttx.getMessage() << endl;
       }
-    }
-  
-    // Clean up the client
-    if (client != NULL) {
-
-      // Ensure no resource leaks
-      client->close();
-     }
+      // Clean up the client. The factory may return one object for both
+      // directions, or the client itself, so close each object only once.
+      io.first->close();
+      if (io.second != io.first) {
+        io.second->close();
+      }
+      if (client != io.first && client != io.second) {
+        client->close();
+      }
+    }
+  } catch (TTransportException& ttx) {
+    cerr << "TServerTransport died on accept: " << ttx.getMessage() << endl;
   }

   // TODO(mcslee): Could this be a timeout case? Or always the real thing?
diff --git a/lib/cpp/src/server/TSimpleServer.h b/lib/cpp/src/server/TSimpleServer.h
index a8242d4..973ba30 100644
--- a/lib/cpp/src/server/TSimpleServer.h
+++ b/lib/cpp/src/server/TSimpleServer.h
@@ -1,5 +1,5 @@
-#ifndef T_SIMPLE_SERVER_H
-#define T_SIMPLE_SERVER_H
+#ifndef _THRIFT_SERVER_TSIMPLESERVER_H_
+#define _THRIFT_SERVER_TSIMPLESERVER_H_ 1

 #include "server/TServer.h"
 #include "transport/TServerTransport.h"
@@ -17,18 +17,17 @@
 class TSimpleServer : public TServer {
  public:
+  // Single-threaded: run() serves each accepted client before accepting the next.
   TSimpleServer(shared_ptr<TProcessor> processor,
-                shared_ptr<TServerOptions> options,
-                shared_ptr<TServerTransport> serverTransport) :
-    TServer(processor, options), serverTransport_(serverTransport) {}
+                shared_ptr<TServerTransport> serverTransport,
+                shared_ptr<TTransportFactory> transportFactory,
+                shared_ptr<TServerOptions> options) :
+    TServer(processor, serverTransport, transportFactory, options) {}
     
   ~TSimpleServer() {}

   void run();
-
- protected:
-  shared_ptr<TServerTransport> serverTransport_;
 };

 }}} // facebook::thrift::server

-#endif
+#endif // #ifndef _THRIFT_SERVER_TSIMPLESERVER_H_
diff --git a/lib/cpp/src/server/TThreadPoolServer.cc b/lib/cpp/src/server/TThreadPoolServer.cc
index d53d174..1eab53d 100644
--- a/lib/cpp/src/server/TThreadPoolServer.cc
+++ b/lib/cpp/src/server/TThreadPoolServer.cc
@@ -1,5 +1,4 @@
 #include "server/TThreadPoolServer.h"
-#include "transport/TBufferedTransport.h"
 #include "transport/TTransportException.h"
 #include "concurrency/Thread.h"
 #include "concurrency/ThreadManager.h"
@@ -15,54 +14,59 @@
 class TThreadPoolServer::Task: public Runnable {
     
   shared_ptr<TProcessor> _processor;
-  shared_ptr<TTransport> _transport;
-  shared_ptr<TBufferedTransport> _bufferedTransport;
+  shared_ptr<TTransport> _input;
+  shared_ptr<TTransport> _output;
     
 public:
     
   Task(shared_ptr<TProcessor> processor,
-       shared_ptr<TTransport> transport) :
+       shared_ptr<TTransport> input,
+       shared_ptr<TTransport> output) :
     _processor(processor),
-    _transport(transport),
-    _bufferedTransport(new TBufferedTransport(transport)) {
+    _input(input),
+    _output(output) {
   }

   ~Task() {}
     
-  void run() {
-      
+  // Serves one client until process() returns false or throws, then closes it.
+  void run() {
     while(true) {
-	
       try {
-	_processor->process(_bufferedTransport);
-	
+        // Stop once process() returns false, as TSimpleServer's loop does.
+        if (!_processor->process(_input, _output)) {
+          break;
+        }
       } catch (TTransportException& ttx) {
-	
-	break;
-	
+        break;
       } catch(...) {
-	
-	break;
+        break;
       }
     }
-    
-    _bufferedTransport->close();
+    // The factory may return one object for both input and output; close it once.
+    _input->close();
+    if (_output != _input) {
+      _output->close();
+    }
   }
 };
  
 TThreadPoolServer::TThreadPoolServer(shared_ptr<TProcessor> processor,
-				     shared_ptr<TServerOptions> options,
-				     shared_ptr<TServerTransport> serverTransport,
-				     shared_ptr<ThreadManager> threadManager) :
-  TServer(processor, options), 
-  serverTransport_(serverTransport), 
+                                     shared_ptr<TServerTransport> serverTransport,
+                                     shared_ptr<TTransportFactory> transportFactory,
+                                     shared_ptr<ThreadManager> threadManager,
+                                     shared_ptr<TServerOptions> options) :
+  TServer(processor, serverTransport, transportFactory, options),
   threadManager_(threadManager) {
 }
-    
+
 TThreadPoolServer::~TThreadPoolServer() {}

 void TThreadPoolServer::run() {

+  shared_ptr<TTransport> client;
+  pair<shared_ptr<TTransport>,shared_ptr<TTransport> > io;
+
   try {
     // Start the server listening
     serverTransport_->listen();
@@ -71,15 +75,17 @@
     return;
   }
   
-  // Fetch client from server
-  
-  while (true) {
-    
+  while (true) {
     try {
-      
-      threadManager_->add(shared_ptr<TThreadPoolServer::Task>(new TThreadPoolServer::Task(processor_, 
-											  shared_ptr<TTransport>(serverTransport_->accept()))));
-      
+      // Fetch client from server
+      client = serverTransport_->accept();
+      if (client == NULL) {
+        continue;
+      }
+      // Make IO transports
+      io = transportFactory_->getIOTransports(client);
+      // Add to threadmanager pool
+      threadManager_->add(shared_ptr<TThreadPoolServer::Task>(new TThreadPoolServer::Task(processor_, io.first, io.second)));
     } catch (TTransportException& ttx) {
       break;
     }
diff --git a/lib/cpp/src/server/TThreadPoolServer.h b/lib/cpp/src/server/TThreadPoolServer.h
index 827491d..34b216c 100644
--- a/lib/cpp/src/server/TThreadPoolServer.h
+++ b/lib/cpp/src/server/TThreadPoolServer.h
@@ -1,5 +1,5 @@
-#ifndef T_THREADPOOL_SERVER_H
-#define T_THREADPOOL_SERVER_H
+#ifndef _THRIFT_SERVER_TTHREADPOOLSERVER_H_
+#define _THRIFT_SERVER_TTHREADPOOLSERVER_H_ 1

 #include <concurrency/ThreadManager.h>
 #include <server/TServer.h>
@@ -19,9 +19,10 @@
   class Task;
   
   TThreadPoolServer(shared_ptr<TProcessor> processor,
-		    shared_ptr<TServerOptions> options,
 		    shared_ptr<TServerTransport> serverTransport,
-		    shared_ptr<ThreadManager> threadManager);
+		    shared_ptr<TTransportFactory> transportFactory,
+		    shared_ptr<ThreadManager> threadManager,
+		    shared_ptr<TServerOptions> options);

   virtual ~TThreadPoolServer();

@@ -29,11 +30,10 @@

 protected:
-
-  shared_ptr<TServerTransport> serverTransport_;
+  // Pool to which run() adds one Task per accepted client connection.
   shared_ptr<ThreadManager> threadManager_;
   
 };

 }}} // facebook::thrift::server

-#endif
+#endif // #ifndef _THRIFT_SERVER_TTHREADPOOLSERVER_H_
diff --git a/lib/cpp/src/transport/TBufferedTransportFactory.h b/lib/cpp/src/transport/TBufferedTransportFactory.h
new file mode 100644
index 0000000..c6e87b1
--- /dev/null
+++ b/lib/cpp/src/transport/TBufferedTransportFactory.h
@@ -0,0 +1,33 @@
+#ifndef _THRIFT_TRANSPORT_TBUFFEREDTRANSPORTFACTORY_H_
+#define _THRIFT_TRANSPORT_TBUFFEREDTRANSPORTFACTORY_H_ 1
+
+#include <transport/TTransportFactory.h>
+#include <transport/TBufferedTransport.h>
+#include <boost/shared_ptr.hpp>
+#include <utility>
+
+namespace facebook { namespace thrift { namespace transport {
+
+/**
+ * Transport factory that wraps each transport in a TBufferedTransport.
+ *
+ * @author Mark Slee <mcslee@facebook.com>
+ */
+class TBufferedTransportFactory : public TTransportFactory {
+ public:
+  TBufferedTransportFactory() {}
+
+  virtual ~TBufferedTransportFactory() {}
+
+  /**
+   * Wraps trans in one TBufferedTransport, returned as both input and output.
+   */
+  virtual std::pair<boost::shared_ptr<TTransport>, boost::shared_ptr<TTransport> > getIOTransports(boost::shared_ptr<TTransport> trans) {
+    boost::shared_ptr<TTransport> buffered(new TBufferedTransport(trans));
+    return std::make_pair(buffered, buffered);
+  }
+};
+
+}}} // facebook::thrift::transport
+
+#endif // #ifndef _THRIFT_TRANSPORT_TBUFFEREDTRANSPORTFACTORY_H_
diff --git a/lib/cpp/src/transport/TTransport.h b/lib/cpp/src/transport/TTransport.h
index 19a2cb6..d65d25b 100644
--- a/lib/cpp/src/transport/TTransport.h
+++ b/lib/cpp/src/transport/TTransport.h
@@ -1,7 +1,7 @@
 #ifndef _THRIFT_TRANSPORT_TTRANSPORT_H_
 #define _THRIFT_TRANSPORT_TTRANSPORT_H_ 1

-#include "TTransportException.h"
+#include <transport/TTransportException.h>  // same root-relative style as the other Thrift headers
 #include <string>

 namespace facebook { namespace thrift { namespace transport { 
diff --git a/lib/cpp/src/transport/TTransportFactory.h b/lib/cpp/src/transport/TTransportFactory.h
new file mode 100644
index 0000000..abd1048
--- /dev/null
+++ b/lib/cpp/src/transport/TTransportFactory.h
@@ -0,0 +1,33 @@
+#ifndef _THRIFT_TRANSPORT_TTRANSPORTFACTORY_H_
+#define _THRIFT_TRANSPORT_TTRANSPORTFACTORY_H_ 1
+
+#include <transport/TTransport.h>
+#include <boost/shared_ptr.hpp>
+#include <utility>
+
+namespace facebook { namespace thrift { namespace transport {
+
+/**
+ * Generic factory class to make an input and output transport out of a
+ * source transport. Commonly used inside servers to make input and output
+ * streams out of raw clients.
+ *
+ * @author Mark Slee <mcslee@facebook.com>
+ */
+class TTransportFactory {
+ public:
+  TTransportFactory() {}
+  virtual ~TTransportFactory() {}
+
+  /**
+   * Returns the (input, output) transports to use for trans. Override to wrap
+   * it; this default returns trans itself as both input and output.
+   */
+  virtual std::pair<boost::shared_ptr<TTransport>, boost::shared_ptr<TTransport> > getIOTransports(boost::shared_ptr<TTransport> trans) {
+    return std::make_pair(trans, trans);
+  }
+};
+
+}}} // facebook::thrift::transport
+
+#endif // #ifndef _THRIFT_TRANSPORT_TTRANSPORTFACTORY_H_