From 0fd37f08716758b283010abfa5162eb2c1aee2ad Mon Sep 17 00:00:00 2001 From: Bryan Duxbury Date: Tue, 8 Feb 2011 17:26:37 +0000 Subject: [PATCH] THRIFT-447. java: Make an abstract base Client class so we can generate less code This patch introduces a handful of abstract, non-generated classes that allow us to generate much less code for service implementations. git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1068487 13f79535-47bb-0310-9956-ffa450edef68 --- compiler/cpp/src/generate/t_java_generator.cc | 250 ++++-------------- .../org/apache/thrift/ProcessFunction.java | 46 ++++ .../src/org/apache/thrift/TBaseProcessor.java | 37 +++ .../src/org/apache/thrift/TServiceClient.java | 48 +++- 4 files changed, 184 insertions(+), 197 deletions(-) create mode 100644 lib/java/src/org/apache/thrift/ProcessFunction.java create mode 100644 lib/java/src/org/apache/thrift/TBaseProcessor.java diff --git a/compiler/cpp/src/generate/t_java_generator.cc b/compiler/cpp/src/generate/t_java_generator.cc index b56d7207..e84fd1d5 100644 --- a/compiler/cpp/src/generate/t_java_generator.cc +++ b/compiler/cpp/src/generate/t_java_generator.cc @@ -2267,13 +2267,15 @@ void t_java_generator::generate_service_helpers(t_service* tservice) { void t_java_generator::generate_service_client(t_service* tservice) { string extends = ""; string extends_client = ""; - if (tservice->get_extends() != NULL) { + if (tservice->get_extends() == NULL) { + extends_client = "org.apache.thrift.TServiceClient"; + } else { extends = type_name(tservice->get_extends()); - extends_client = " extends " + extends + ".Client"; + extends_client = extends + ".Client"; } indent(f_service_) << - "public static class Client" << extends_client << " implements org.apache.thrift.TServiceClient, Iface {" << endl; + "public static class Client extends " << extends_client << " implements Iface {" << endl; indent_up(); indent(f_service_) << "public static class Factory implements org.apache.thrift.TServiceClientFactory {" << endl; @@ -2296,49 +2298,14 @@ void t_java_generator::generate_service_client(t_service* tservice) { "public Client(org.apache.thrift.protocol.TProtocol prot)" << endl; scope_up(f_service_); indent(f_service_) << - "this(prot, prot);" << endl; + "super(prot, prot);" << endl; scope_down(f_service_); f_service_ << endl; indent(f_service_) << - "public Client(org.apache.thrift.protocol.TProtocol iprot, org.apache.thrift.protocol.TProtocol oprot)" << endl; - scope_up(f_service_); - if (extends.empty()) { - f_service_ << - indent() << "iprot_ = iprot;" << endl << - indent() << "oprot_ = oprot;" << endl; - } else { - f_service_ << - indent() << "super(iprot, oprot);" << endl; - } - scope_down(f_service_); - f_service_ << endl; - - if (extends.empty()) { - f_service_ << - indent() << "protected org.apache.thrift.protocol.TProtocol iprot_;" << endl << - indent() << "protected org.apache.thrift.protocol.TProtocol oprot_;" << endl << - endl << - indent() << "protected int seqid_;" << endl << - endl; - - indent(f_service_) << - "public org.apache.thrift.protocol.TProtocol getInputProtocol()" << endl; - scope_up(f_service_); - indent(f_service_) << - "return this.iprot_;" << endl; - scope_down(f_service_); - f_service_ << endl; - - indent(f_service_) << - "public org.apache.thrift.protocol.TProtocol getOutputProtocol()" << endl; - scope_up(f_service_); - indent(f_service_) << - "return this.oprot_;" << endl; - scope_down(f_service_); - f_service_ << endl; - - } + "public Client(org.apache.thrift.protocol.TProtocol iprot, org.apache.thrift.protocol.TProtocol oprot) {" << endl; + indent(f_service_) << " super(iprot, oprot);" << endl; + indent(f_service_) << "}" << endl << endl; // Generate client method implementations vector functions = tservice->get_functions(); @@ -2393,19 +2360,14 @@ void t_java_generator::generate_service_client(t_service* tservice) { scope_up(f_service_); // Serialize the request - f_service_ << - indent() << "oprot_.writeMessageBegin(new org.apache.thrift.protocol.TMessage(\"" << funname << "\", org.apache.thrift.protocol.TMessageType.CALL, ++seqid_));" << endl << - indent() << argsname << " args = new " << argsname << "();" << endl; + indent(f_service_) << argsname << " args = new " << argsname << "();" << endl; for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { f_service_ << indent() << "args.set" << get_cap_name((*fld_iter)->get_name()) << "(" << (*fld_iter)->get_name() << ");" << endl; } - f_service_ << - indent() << "args.write(oprot_);" << endl << - indent() << "oprot_.writeMessageEnd();" << endl << - indent() << "oprot_.getTransport().flush();" << endl; + indent(f_service_) << "sendBase(\"" << funname << "\", args);" << endl; scope_down(f_service_); f_service_ << endl; @@ -2424,18 +2386,8 @@ void t_java_generator::generate_service_client(t_service* tservice) { scope_up(f_service_); f_service_ << - indent() << "org.apache.thrift.protocol.TMessage msg = iprot_.readMessageBegin();" << endl << - indent() << "if (msg.type == org.apache.thrift.protocol.TMessageType.EXCEPTION) {" << endl << - indent() << " org.apache.thrift.TApplicationException x = org.apache.thrift.TApplicationException.read(iprot_);" << endl << - indent() << " iprot_.readMessageEnd();" << endl << - indent() << " throw x;" << endl << - indent() << "}" << endl << - indent() << "if (msg.seqid != seqid_) {" << endl << - indent() << " throw new org.apache.thrift.TApplicationException(org.apache.thrift.TApplicationException.BAD_SEQUENCE_ID, \"" << (*f_iter)->get_name() << " failed: out of sequence response\");" << endl << - indent() << "}" << endl << indent() << resultname << " result = new " << resultname << "();" << endl << - indent() << "result.read(iprot_);" << endl << - indent() << "iprot_.readMessageEnd();" << endl; + indent() << "receiveBase(result, \"" << funname << "\");" << endl; // Careful, only return _result if not a void function if (!(*f_iter)->get_returntype()->is_void()) { @@ -2620,83 +2572,36 @@ void t_java_generator::generate_service_server(t_service* tservice) { // Extends stuff string extends = ""; string extends_processor = ""; - if (tservice->get_extends() != NULL) { + if (tservice->get_extends() == NULL) { + extends_processor = "org.apache.thrift.TBaseProcessor"; + } else { extends = type_name(tservice->get_extends()); - extends_processor = " extends " + extends + ".Processor"; + extends_processor = extends + ".Processor"; } // Generate the header portion indent(f_service_) << - "public static class Processor" << extends_processor << " implements org.apache.thrift.TProcessor {" << endl; + "public static class Processor extends " << extends_processor << " implements org.apache.thrift.TProcessor {" << endl; indent_up(); indent(f_service_) << "private static final Logger LOGGER = LoggerFactory.getLogger(Processor.class.getName());" << endl; - indent(f_service_) << - "public Processor(Iface iface)" << endl; - scope_up(f_service_); - if (!extends.empty()) { - f_service_ << - indent() << "super(iface);" << endl; - } - f_service_ << - indent() << "iface_ = iface;" << endl; - - for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { - f_service_ << - indent() << "processMap_.put(\"" << (*f_iter)->get_name() << "\", new " << (*f_iter)->get_name() << "());" << endl; - } - - scope_down(f_service_); - f_service_ << endl; - - if (extends.empty()) { - f_service_ << - indent() << "protected static interface ProcessFunction {" << endl << - indent() << " public void process(int seqid, org.apache.thrift.protocol.TProtocol iprot, org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException;" << endl << - indent() << "}" << endl << - endl; - } + indent(f_service_) << "public Processor(I iface) {" << endl; + indent(f_service_) << " super(iface, getProcessMap(new HashMap>()));" << endl; + indent(f_service_) << "}" << endl << endl; - f_service_ << - indent() << "private Iface iface_;" << endl; + indent(f_service_) << "protected Processor(I iface, Map> processMap) {" << endl; + indent(f_service_) << " super(iface, getProcessMap(processMap));" << endl; + indent(f_service_) << "}" << endl << endl; - if (extends.empty()) { - f_service_ << - indent() << "protected final HashMap processMap_ = new HashMap();" << endl; + indent(f_service_) << "private static Map> getProcessMap(Map> processMap) {" << endl; + indent_up(); + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + indent(f_service_) << "processMap.put(\"" << (*f_iter)->get_name() << "\", new " << (*f_iter)->get_name() << "());" << endl; } - - f_service_ << endl; - - // Generate the server implementation - indent(f_service_) << - "public boolean process(org.apache.thrift.protocol.TProtocol iprot, org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException" << endl; - scope_up(f_service_); - - f_service_ << - indent() << "org.apache.thrift.protocol.TMessage msg = iprot.readMessageBegin();" << endl; - - // TODO(mcslee): validate message, was the seqid etc. legit? - - f_service_ << - indent() << "ProcessFunction fn = processMap_.get(msg.name);" << endl << - indent() << "if (fn == null) {" << endl << - indent() << " org.apache.thrift.protocol.TProtocolUtil.skip(iprot, org.apache.thrift.protocol.TType.STRUCT);" << endl << - indent() << " iprot.readMessageEnd();" << endl << - indent() << " org.apache.thrift.TApplicationException x = new org.apache.thrift.TApplicationException(org.apache.thrift.TApplicationException.UNKNOWN_METHOD, \"Invalid method name: '\"+msg.name+\"'\");" << endl << - indent() << " oprot.writeMessageBegin(new org.apache.thrift.protocol.TMessage(msg.name, org.apache.thrift.protocol.TMessageType.EXCEPTION, msg.seqid));" << endl << - indent() << " x.write(oprot);" << endl << - indent() << " oprot.writeMessageEnd();" << endl << - indent() << " oprot.getTransport().flush();" << endl << - indent() << " return true;" << endl << - indent() << "}" << endl << - indent() << "fn.process(msg.seqid, iprot, oprot);" << endl; - - f_service_ << - indent() << "return true;" << endl; - - scope_down(f_service_); - f_service_ << endl; + indent(f_service_) << "return processMap;" << endl; + indent_down(); + indent(f_service_) << "}" << endl << endl; // Generate the process subfunctions for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { @@ -2704,9 +2609,7 @@ void t_java_generator::generate_service_server(t_service* tservice) { } indent_down(); - indent(f_service_) << - "}" << endl << - endl; + indent(f_service_) << "}" << endl << endl; } /** @@ -2742,53 +2645,36 @@ void t_java_generator::generate_function_helpers(t_function* tfunction) { */ void t_java_generator::generate_process_function(t_service* tservice, t_function* tfunction) { + string argsname = tfunction->get_name() + "_args"; + string resultname = tfunction->get_name() + "_result"; + if (tfunction->is_oneway()) { + resultname = "org.apache.thrift.TBase"; + } + (void) tservice; // Open class indent(f_service_) << - "private class " << tfunction->get_name() << " implements ProcessFunction {" << endl; + "private static class " << tfunction->get_name() << " extends org.apache.thrift.ProcessFunction {" << endl; indent_up(); - // Open function - indent(f_service_) << - "public void process(int seqid, org.apache.thrift.protocol.TProtocol iprot, org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException" << endl; - scope_up(f_service_); + indent(f_service_) << "public " << tfunction->get_name() << "() {" << endl; + indent(f_service_) << " super(\"" << tfunction->get_name() << "\");" << endl; + indent(f_service_) << "}" << endl << endl; - string argsname = tfunction->get_name() + "_args"; - string resultname = tfunction->get_name() + "_result"; + indent(f_service_) << "protected " << argsname << " getEmptyArgsInstance() {" << endl; + indent(f_service_) << " return new " << argsname << "();" << endl; + indent(f_service_) << "}" << endl << endl; - f_service_ << - indent() << argsname << " args = new " << argsname << "();" << endl << - indent() << "try {" << endl; - indent_up(); - f_service_ << - indent() << "args.read(iprot);" << endl; - indent_down(); - f_service_ << - indent() << "} catch (org.apache.thrift.protocol.TProtocolException e) {" << endl; + indent(f_service_) << "protected " << resultname << " getResult(I iface, " << argsname << " args) throws org.apache.thrift.TException {" << endl; indent_up(); - f_service_ << - indent() << "iprot.readMessageEnd();" << endl << - indent() << "org.apache.thrift.TApplicationException x = new org.apache.thrift.TApplicationException(org.apache.thrift.TApplicationException.PROTOCOL_ERROR, e.getMessage());" << endl << - indent() << "oprot.writeMessageBegin(new org.apache.thrift.protocol.TMessage(\"" << tfunction->get_name() << "\", org.apache.thrift.protocol.TMessageType.EXCEPTION, seqid));" << endl << - indent() << "x.write(oprot);" << endl << - indent() << "oprot.writeMessageEnd();" << endl << - indent() << "oprot.getTransport().flush();" << endl << - indent() << "return;" << endl; - indent_down(); - f_service_ << indent() << "}" << endl; - f_service_ << - indent() << "iprot.readMessageEnd();" << endl; + if (!tfunction->is_oneway()) { + indent(f_service_) << resultname << " result = new " << resultname << "();" << endl; + } t_struct* xs = tfunction->get_xceptions(); const std::vector& xceptions = xs->get_members(); vector::const_iterator x_iter; - // Declare result for non oneway function - if (!tfunction->is_oneway()) { - f_service_ << - indent() << resultname << " result = new " << resultname << "();" << endl; - } - // Try block for a function with exceptions if (xceptions.size() > 0) { f_service_ << @@ -2800,13 +2686,13 @@ void t_java_generator::generate_process_function(t_service* tservice, t_struct* arg_struct = tfunction->get_arglist(); const std::vector& fields = arg_struct->get_members(); vector::const_iterator f_iter; - f_service_ << indent(); + if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) { f_service_ << "result.success = "; } f_service_ << - "iface_." << tfunction->get_name() << "("; + "iface." << tfunction->get_name() << "("; bool first = true; for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { if (first) { @@ -2839,42 +2725,18 @@ void t_java_generator::generate_process_function(t_service* tservice, f_service_ << "}"; } } - f_service_ << " catch (Throwable th) {" << endl; - indent_up(); - f_service_ << - indent() << "LOGGER.error(\"Internal error processing " << tfunction->get_name() << "\", th);" << endl << - indent() << "org.apache.thrift.TApplicationException x = new org.apache.thrift.TApplicationException(org.apache.thrift.TApplicationException.INTERNAL_ERROR, \"Internal error processing " << tfunction->get_name() << "\");" << endl << - indent() << "oprot.writeMessageBegin(new org.apache.thrift.protocol.TMessage(\"" << tfunction->get_name() << "\", org.apache.thrift.protocol.TMessageType.EXCEPTION, seqid));" << endl << - indent() << "x.write(oprot);" << endl << - indent() << "oprot.writeMessageEnd();" << endl << - indent() << "oprot.getTransport().flush();" << endl << - indent() << "return;" << endl; - indent_down(); - f_service_ << indent() << "}" << endl; + f_service_ << endl; } - // Shortcut out here for oneway functions if (tfunction->is_oneway()) { - f_service_ << - indent() << "return;" << endl; - scope_down(f_service_); - - // Close class - indent_down(); - f_service_ << - indent() << "}" << endl << - endl; - return; + indent(f_service_) << "return null;" << endl; + } else { + indent(f_service_) << "return result;" << endl; } - - f_service_ << - indent() << "oprot.writeMessageBegin(new org.apache.thrift.protocol.TMessage(\"" << tfunction->get_name() << "\", org.apache.thrift.protocol.TMessageType.REPLY, seqid));" << endl << - indent() << "result.write(oprot);" << endl << - indent() << "oprot.writeMessageEnd();" << endl << - indent() << "oprot.getTransport().flush();" << endl; + indent_down(); + indent(f_service_) << "}"; // Close function - scope_down(f_service_); f_service_ << endl; // Close class diff --git a/lib/java/src/org/apache/thrift/ProcessFunction.java b/lib/java/src/org/apache/thrift/ProcessFunction.java new file mode 100644 index 00000000..e0cdc7b8 --- /dev/null +++ b/lib/java/src/org/apache/thrift/ProcessFunction.java @@ -0,0 +1,46 @@ +/** + * + */ +package org.apache.thrift; + +import org.apache.thrift.protocol.TMessage; +import org.apache.thrift.protocol.TMessageType; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolException; + +public abstract class ProcessFunction { + private final String methodName; + + public ProcessFunction(String methodName) { + this.methodName = methodName; + } + + public final void process(int seqid, TProtocol iprot, TProtocol oprot, I iface) throws TException { + T args = getEmptyArgsInstance(); + try { + args.read(iprot); + } catch (TProtocolException e) { + iprot.readMessageEnd(); + TApplicationException x = new TApplicationException(TApplicationException.PROTOCOL_ERROR, e.getMessage()); + oprot.writeMessageBegin(new TMessage(getMethodName(), TMessageType.EXCEPTION, seqid)); + x.write(oprot); + oprot.writeMessageEnd(); + oprot.getTransport().flush(); + return; + } + iprot.readMessageEnd(); + TBase result = getResult(iface, args); + oprot.writeMessageBegin(new TMessage(getMethodName(), TMessageType.REPLY, seqid)); + result.write(oprot); + oprot.writeMessageEnd(); + oprot.getTransport().flush(); + } + + protected abstract TBase getResult(I iface, T args) throws TException; + + protected abstract T getEmptyArgsInstance(); + + public String getMethodName() { + return methodName; + } +} \ No newline at end of file diff --git a/lib/java/src/org/apache/thrift/TBaseProcessor.java b/lib/java/src/org/apache/thrift/TBaseProcessor.java new file mode 100644 index 00000000..f93b1336 --- /dev/null +++ b/lib/java/src/org/apache/thrift/TBaseProcessor.java @@ -0,0 +1,37 @@ +package org.apache.thrift; + +import java.util.Map; + +import org.apache.thrift.protocol.TMessage; +import org.apache.thrift.protocol.TMessageType; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolUtil; +import org.apache.thrift.protocol.TType; + +public abstract class TBaseProcessor implements TProcessor { + private final I iface; + private final Map> processMap; + + protected TBaseProcessor(I iface, Map> processFunctionMap) { + this.iface = iface; + this.processMap = processFunctionMap; + } + + @Override + public boolean process(TProtocol in, TProtocol out) throws TException { + TMessage msg = in.readMessageBegin(); + ProcessFunction fn = processMap.get(msg.name); + if (fn == null) { + TProtocolUtil.skip(in, TType.STRUCT); + in.readMessageEnd(); + TApplicationException x = new TApplicationException(TApplicationException.UNKNOWN_METHOD, "Invalid method name: '"+msg.name+"'"); + out.writeMessageBegin(new TMessage(msg.name, TMessageType.EXCEPTION, msg.seqid)); + x.write(out); + out.writeMessageEnd(); + out.getTransport().flush(); + return true; + } + fn.process(msg.seqid, in, out, iface); + return true; + } +} diff --git a/lib/java/src/org/apache/thrift/TServiceClient.java b/lib/java/src/org/apache/thrift/TServiceClient.java index ee07b782..c70e66f8 100644 --- a/lib/java/src/org/apache/thrift/TServiceClient.java +++ b/lib/java/src/org/apache/thrift/TServiceClient.java @@ -19,21 +19,63 @@ package org.apache.thrift; +import org.apache.thrift.protocol.TMessage; +import org.apache.thrift.protocol.TMessageType; import org.apache.thrift.protocol.TProtocol; /** * A TServiceClient is used to communicate with a TService implementation * across protocols and transports. */ -public interface TServiceClient { +public abstract class TServiceClient { + public TServiceClient(TProtocol prot) { + this(prot, prot); + } + + public TServiceClient(TProtocol iprot, TProtocol oprot) { + iprot_ = iprot; + oprot_ = oprot; + } + + protected TProtocol iprot_; + protected TProtocol oprot_; + + protected int seqid_; + /** * Get the TProtocol being used as the input (read) protocol. * @return */ - public TProtocol getInputProtocol(); + public TProtocol getInputProtocol() { + return this.iprot_; + } + /** * Get the TProtocol being used as the output (write) protocol. * @return */ - public TProtocol getOutputProtocol(); + public TProtocol getOutputProtocol() { + return this.oprot_; + } + + protected void sendBase(String methodName, TBase args) throws TException { + oprot_.writeMessageBegin(new TMessage(methodName, TMessageType.CALL, ++seqid_)); + args.write(oprot_); + oprot_.writeMessageEnd(); + oprot_.getTransport().flush(); + } + + protected void receiveBase(TBase result, String methodName) throws TException { + TMessage msg = iprot_.readMessageBegin(); + if (msg.type == TMessageType.EXCEPTION) { + TApplicationException x = TApplicationException.read(iprot_); + iprot_.readMessageEnd(); + throw x; + } + if (msg.seqid != seqid_) { + throw new TApplicationException(TApplicationException.BAD_SEQUENCE_ID, methodName + " failed: out of sequence response"); + } + result.read(iprot_); + iprot_.readMessageEnd(); + } } -- 2.17.1