From d58ccec66090afbbef68471cb635ad731ef03319 Mon Sep 17 00:00:00 2001 From: Bryan Duxbury Date: Wed, 26 May 2010 16:34:48 +0000 Subject: [PATCH] THRIFT-768. java: Async client for Java This patch adds an implementation of a fully-asynchronous client that makes use of NIO. Stubs for the async method calls are generated along with the existing synchronous ones. git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@948492 13f79535-47bb-0310-9956-ffa450edef68 --- compiler/cpp/src/generate/t_java_generator.cc | 243 +++++++++++++++++- .../apache/thrift/TByteArrayOutputStream.java | 1 - .../thrift/async/AsyncMethodCallback.java | 38 +++ .../org/apache/thrift/async/TAsyncClient.java | 84 ++++++ .../thrift/async/TAsyncClientFactory.java | 25 ++ .../thrift/async/TAsyncClientManager.java | 109 ++++++++ .../apache/thrift/async/TAsyncMethodCall.java | 201 +++++++++++++++ .../thrift/transport/TFramedTransport.java | 38 ++- .../thrift/transport/TMemoryBuffer.java | 12 +- .../thrift/transport/TNonblockingSocket.java | 20 +- .../thrift/async/TestTAsyncClientManager.java | 184 +++++++++++++ .../thrift/protocol/ProtocolTestBase.java | 4 + test/DebugProtoTest.thrift | 8 +- 13 files changed, 927 insertions(+), 40 deletions(-) create mode 100644 lib/java/src/org/apache/thrift/async/AsyncMethodCallback.java create mode 100644 lib/java/src/org/apache/thrift/async/TAsyncClient.java create mode 100644 lib/java/src/org/apache/thrift/async/TAsyncClientFactory.java create mode 100644 lib/java/src/org/apache/thrift/async/TAsyncClientManager.java create mode 100644 lib/java/src/org/apache/thrift/async/TAsyncMethodCall.java create mode 100644 lib/java/test/org/apache/thrift/async/TestTAsyncClientManager.java diff --git a/compiler/cpp/src/generate/t_java_generator.cc b/compiler/cpp/src/generate/t_java_generator.cc index f1eb5661..2db3ca3e 100644 --- a/compiler/cpp/src/generate/t_java_generator.cc +++ b/compiler/cpp/src/generate/t_java_generator.cc @@ -115,8 +115,10 @@ class t_java_generator : public t_oop_generator { std::string isset_field_id(t_field* field); void generate_service_interface (t_service* tservice); + void generate_service_async_interface(t_service* tservice); void generate_service_helpers (t_service* tservice); void generate_service_client (t_service* tservice); + void generate_service_async_client(t_service* tservice); void generate_service_server (t_service* tservice); void generate_process_function (t_service* tservice, t_function* tfunction); @@ -215,13 +217,16 @@ class t_java_generator : public t_oop_generator { std::string base_type_name(t_base_type* tbase, bool in_container=false); std::string declare_field(t_field* tfield, bool init=false); std::string function_signature(t_function* tfunction, std::string prefix=""); - std::string argument_list(t_struct* tstruct); + std::string function_signature_async(t_function* tfunction, bool use_base_method = false, std::string prefix=""); + std::string argument_list(t_struct* tstruct, bool include_types = true); + std::string async_function_call_arglist(t_function* tfunc, bool use_base_method = true, bool include_types = true); + std::string async_argument_list(t_function* tfunct, t_struct* tstruct, t_type* ttype, bool include_types=false); std::string type_to_enum(t_type* ttype); std::string get_enum_class_name(t_type* type); void generate_struct_desc(ofstream& out, t_struct* tstruct); void generate_field_descs(ofstream& out, t_struct* tstruct); void generate_field_name_constants(ofstream& out, t_struct* tstruct); - + bool type_can_be_null(t_type* ttype) { ttype = get_true_type(ttype); @@ -330,7 +335,9 @@ string t_java_generator::java_thrift_imports() { return string() + "import org.apache.thrift.*;\n" + + "import org.apache.thrift.async.*;\n" + "import org.apache.thrift.meta_data.*;\n" + + "import org.apache.thrift.transport.*;\n" + "import org.apache.thrift.protocol.*;\n\n"; } @@ -2133,7 +2140,9 @@ void t_java_generator::generate_service(t_service* tservice) { // Generate the three main parts of the service generate_service_interface(tservice); + generate_service_async_interface(tservice); generate_service_client(tservice); + generate_service_async_client(tservice); generate_service_server(tservice); generate_service_helpers(tservice); @@ -2164,13 +2173,29 @@ void t_java_generator::generate_service_interface(t_service* tservice) { vector::iterator f_iter; for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { generate_java_doc(f_service_, *f_iter); - indent(f_service_) << "public " << function_signature(*f_iter) << ";" << - endl << endl; + indent(f_service_) << "public " << function_signature(*f_iter) << ";" << endl << endl; } indent_down(); - f_service_ << - indent() << "}" << endl << - endl; + f_service_ << indent() << "}" << endl << endl; +} + +void t_java_generator::generate_service_async_interface(t_service* tservice) { + string extends = ""; + string extends_iface = ""; + if (tservice->get_extends() != NULL) { + extends = type_name(tservice->get_extends()); + extends_iface = " extends " + extends + " .AsyncIface"; + } + + f_service_ << indent() << "public interface AsyncIface" << extends_iface << " {" << endl << endl; + indent_up(); + vector functions = tservice->get_functions(); + vector::iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + indent(f_service_) << "public " << function_signature_async(*f_iter, true) << " throws TException;" << endl << endl; + } + indent_down(); + f_service_ << indent() << "}" << endl << endl; } /** @@ -2404,6 +2429,138 @@ void t_java_generator::generate_service_client(t_service* tservice) { "}" << endl; } +void t_java_generator::generate_service_async_client(t_service* tservice) { + string extends = "TAsyncClient"; + string extends_client = ""; + if (tservice->get_extends() != NULL) { + extends = type_name(tservice->get_extends()) + ".AsyncClient"; + // extends_client = " extends " + extends + ".AsyncClient"; + } + + indent(f_service_) << + "public static class AsyncClient extends " << extends << " implements AsyncIface {" << endl; + indent_up(); + + // Factory method + indent(f_service_) << "public static class Factory implements TAsyncClientFactory {" << endl; + indent(f_service_) << " private TAsyncClientManager clientManager;" << endl; + indent(f_service_) << " private TProtocolFactory protocolFactory;" << endl; + indent(f_service_) << " public Factory(TAsyncClientManager clientManager, TProtocolFactory protocolFactory) {" << endl; + indent(f_service_) << " this.clientManager = clientManager;" << endl; + indent(f_service_) << " this.protocolFactory = protocolFactory;" << endl; + indent(f_service_) << " }" << endl; + indent(f_service_) << " public AsyncClient getAsyncClient(TNonblockingTransport transport) {" << endl; + indent(f_service_) << " return new AsyncClient(protocolFactory, clientManager, transport);" << endl; + indent(f_service_) << " }" << endl; + indent(f_service_) << "}" << endl << endl; + + indent(f_service_) << "public AsyncClient(TProtocolFactory protocolFactory, TAsyncClientManager clientManager, TNonblockingTransport transport) {" << endl; + indent(f_service_) << " super(protocolFactory, clientManager, transport);" << endl; + indent(f_service_) << "}" << endl << endl; + + // Generate client method implementations + vector functions = tservice->get_functions(); + vector::const_iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + string funname = (*f_iter)->get_name(); + t_type* ret_type = (*f_iter)->get_returntype(); + t_struct* arg_struct = (*f_iter)->get_arglist(); + string funclassname = funname + "_call"; + const vector& fields = arg_struct->get_members(); + const std::vector& xceptions = (*f_iter)->get_xceptions()->get_members(); + vector::const_iterator fld_iter; + string args_name = (*f_iter)->get_name() + "_args"; + string result_name = (*f_iter)->get_name() + "_result"; + + // Main method body + indent(f_service_) << "public " << function_signature_async(*f_iter, false) << " throws TException {" << endl; + indent(f_service_) << " checkReady();" << endl; + indent(f_service_) << " " << funclassname << " method_call = new " + funclassname + "(" << async_argument_list(*f_iter, arg_struct, ret_type) << ", this, protocolFactory, transport);" << endl; + indent(f_service_) << " manager.call(method_call);" << endl; + indent(f_service_) << "}" << endl; + + f_service_ << endl; + + // TAsyncMethod object for this function call + indent(f_service_) << "public static class " + funclassname + " extends TAsyncMethodCall {" << endl; + indent_up(); + + // Member variables + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + indent(f_service_) << "private " + type_name((*fld_iter)->get_type()) + " " + (*fld_iter)->get_name() + ";" << endl; + } + + // NOTE since we use a new Client instance to deserialize, let's keep seqid to 0 for now + // indent(f_service_) << "private int seqid;" << endl << endl; + + // Constructor + indent(f_service_) << "public " + funclassname + "(" + async_argument_list(*f_iter, arg_struct, ret_type, true) << ", TAsyncClient client, TProtocolFactory protocolFactory, TNonblockingTransport transport) throws TException {" << endl; + indent(f_service_) << " super(client, protocolFactory, transport, resultHandler, " << ((*f_iter)->is_oneway() ? "true" : "false") << ");" << endl; + + // Assign member variables + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + indent(f_service_) << " this." + (*fld_iter)->get_name() + " = " + (*fld_iter)->get_name() + ";" << endl; + } + + indent(f_service_) << "}" << endl << endl; + + indent(f_service_) << "public void write_args(TProtocol prot) throws TException {" << endl; + indent_up(); + + // Serialize request + // NOTE we are leaving seqid as 0, for now (see above) + f_service_ << + indent() << "prot.writeMessageBegin(new TMessage(\"" << funname << "\", TMessageType.CALL, 0));" << endl << + indent() << args_name << " args = new " << args_name << "();" << endl; + + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + f_service_ << indent() << "args.set" << get_cap_name((*fld_iter)->get_name()) << "(" << (*fld_iter)->get_name() << ");" << endl; + } + + f_service_ << + indent() << "args.write(prot);" << endl << + indent() << "prot.writeMessageEnd();" << endl; + + indent_down(); + indent(f_service_) << "}" << endl << endl; + + // Return method + indent(f_service_) << "public " + type_name(ret_type) + " getResult() throws "; + vector::const_iterator x_iter; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + f_service_ << type_name((*x_iter)->get_type(), false, false) + ", "; + } + f_service_ << "TException {" << endl; + + indent_up(); + f_service_ << + indent() << "if (getState() != State.RESPONSE_READ) {" << endl << + indent() << " throw new IllegalStateException(\"Method call not finished!\");" << endl << + indent() << "}" << endl << + indent() << "TMemoryInputTransport memoryTransport = new TMemoryInputTransport(getFrameBuffer().array());" << endl << + indent() << "TProtocol prot = client.getProtocolFactory().getProtocol(memoryTransport);" << endl; + if (!(*f_iter)->is_oneway()) { + indent(f_service_); + if (!ret_type->is_void()) { + f_service_ << "return "; + } + f_service_ << "(new Client(prot)).recv_" + funname + "();" << endl; + } + + // Close function + indent_down(); + indent(f_service_) << "}" << endl; + + // Close class + indent_down(); + indent(f_service_) << "}" << endl << endl; + } + + // Close AsyncClient + scope_down(f_service_); + f_service_ << endl; +} + /** * Generates a service server definition. * @@ -3246,10 +3403,49 @@ string t_java_generator::function_signature(t_function* tfunction, return result; } +/** + * Renders a function signature of the form 'void name(args, resultHandler)' + * + * @params tfunction Function definition + * @return String of rendered function definition + */ +string t_java_generator::function_signature_async(t_function* tfunction, bool use_base_method, string prefix) { + std::string arglist = async_function_call_arglist(tfunction, use_base_method, true); + + std::string ret_type = ""; + if (use_base_method) { + ret_type += "AsyncClient."; + } + ret_type += tfunction->get_name() + "_call"; + + std::string result = prefix + "void " + tfunction->get_name() + "(" + arglist + ")"; + return result; +} + +string t_java_generator::async_function_call_arglist(t_function* tfunc, bool use_base_method, bool include_types) { + std::string arglist = ""; + if (tfunc->get_arglist()->get_members().size() > 0) { + arglist = argument_list(tfunc->get_arglist(), include_types) + ", "; + } + + std::string ret_type = ""; + if (use_base_method) { + ret_type += "AsyncClient."; + } + ret_type += tfunc->get_name() + "_call"; + + if (include_types) { + arglist += "AsyncMethodCallback<" + ret_type + "> "; + } + arglist += "resultHandler"; + + return arglist; +} + /** * Renders a comma separated field list, with type names */ -string t_java_generator::argument_list(t_struct* tstruct) { +string t_java_generator::argument_list(t_struct* tstruct, bool include_types) { string result = ""; const vector& fields = tstruct->get_members(); @@ -3261,8 +3457,37 @@ string t_java_generator::argument_list(t_struct* tstruct) { } else { result += ", "; } - result += type_name((*f_iter)->get_type()) + " " + (*f_iter)->get_name(); + if (include_types) { + result += type_name((*f_iter)->get_type()) + " "; + } + result += (*f_iter)->get_name(); + } + return result; +} + +string t_java_generator::async_argument_list(t_function* tfunct, t_struct* tstruct, t_type* ttype, bool include_types) { + string result = ""; + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + } else { + result += ", "; + } + if (include_types) { + result += type_name((*f_iter)->get_type()) + " "; + } + result += (*f_iter)->get_name(); + } + if (!first) { + result += ", "; + } + if (include_types) { + result += "AsyncMethodCallback<" + tfunct->get_name() + "_call" + "> "; } + result += "resultHandler"; return result; } diff --git a/lib/java/src/org/apache/thrift/TByteArrayOutputStream.java b/lib/java/src/org/apache/thrift/TByteArrayOutputStream.java index e35fbcb7..9ed83c0a 100644 --- a/lib/java/src/org/apache/thrift/TByteArrayOutputStream.java +++ b/lib/java/src/org/apache/thrift/TByteArrayOutputStream.java @@ -35,7 +35,6 @@ public class TByteArrayOutputStream extends ByteArrayOutputStream { super(); } - public byte[] get() { return buf; } diff --git a/lib/java/src/org/apache/thrift/async/AsyncMethodCallback.java b/lib/java/src/org/apache/thrift/async/AsyncMethodCallback.java new file mode 100644 index 00000000..b8cd9ed6 --- /dev/null +++ b/lib/java/src/org/apache/thrift/async/AsyncMethodCallback.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.thrift.async; + +public interface AsyncMethodCallback { + /** + * This method will be called when the remote side has completed invoking + * your method call and the result is fully read. For oneway method calls, + * this method will be called as soon as we have completed writing out the + * request. + * @param response + */ + public void onComplete(T response); + + /** + * This method will be called when there is an unexpected clientside + * exception. This does not include application-defined exceptions that + * appear in the IDL, but rather things like IOExceptions. + * @param throwable + */ + public void onError(Throwable throwable); +} diff --git a/lib/java/src/org/apache/thrift/async/TAsyncClient.java b/lib/java/src/org/apache/thrift/async/TAsyncClient.java new file mode 100644 index 00000000..2e8dea3a --- /dev/null +++ b/lib/java/src/org/apache/thrift/async/TAsyncClient.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.thrift.async; + +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TNonblockingTransport; + +public abstract class TAsyncClient { + protected final TProtocolFactory protocolFactory; + protected final TNonblockingTransport transport; + protected final TAsyncClientManager manager; + private TAsyncMethodCall currentMethod; + private Throwable error; + + public TAsyncClient(TProtocolFactory protocolFactory, TAsyncClientManager manager, TNonblockingTransport transport) { + this.protocolFactory = protocolFactory; + this.manager = manager; + this.transport = transport; + } + + public TProtocolFactory getProtocolFactory() { + return protocolFactory; + } + + /** + * Is the client in an error state? + * @return + */ + public boolean hasError() { + return error != null; + } + + /** + * Get the client's error - returns null if no error + * @return + */ + public Throwable getError() { + return error; + } + + protected void checkReady() { + // Ensure we are not currently executing a method + if (currentMethod != null) { + throw new IllegalStateException("Client is currently executing another method: " + currentMethod.getClass().getName()); + } + + // Ensure we're not in an error state + if (error != null) { + throw new IllegalStateException("Client has an error!", error); + } + } + + /** + * Called by delegate method when finished + */ + protected void onComplete() { + currentMethod = null; + } + + /** + * Called by delegate method on error + */ + protected void onError(Throwable throwable) { + transport.close(); + currentMethod = null; + error = throwable; + } +} diff --git a/lib/java/src/org/apache/thrift/async/TAsyncClientFactory.java b/lib/java/src/org/apache/thrift/async/TAsyncClientFactory.java new file mode 100644 index 00000000..28feb73d --- /dev/null +++ b/lib/java/src/org/apache/thrift/async/TAsyncClientFactory.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.thrift.async; + +import org.apache.thrift.transport.TNonblockingTransport; + +public interface TAsyncClientFactory { + public T getAsyncClient(TNonblockingTransport transport); +} diff --git a/lib/java/src/org/apache/thrift/async/TAsyncClientManager.java b/lib/java/src/org/apache/thrift/async/TAsyncClientManager.java new file mode 100644 index 00000000..8636bc8f --- /dev/null +++ b/lib/java/src/org/apache/thrift/async/TAsyncClientManager.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.thrift.async; + +import java.io.IOException; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.spi.SelectorProvider; +import java.util.Iterator; +import java.util.concurrent.ConcurrentLinkedQueue; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Contains selector thread which transitions method call objects + */ +public class TAsyncClientManager { + private static final Logger LOGGER = LoggerFactory.getLogger(TAsyncClientManager.class.getName()); + + private final SelectThread selectThread; + private final ConcurrentLinkedQueue pendingCalls = new ConcurrentLinkedQueue(); + + public TAsyncClientManager() throws IOException { + this.selectThread = new SelectThread(); + selectThread.start(); + } + + public void call(TAsyncMethodCall method) { + pendingCalls.add(method); + selectThread.getSelector().wakeup(); + } + + public void stop() { + selectThread.finish(); + } + + private class SelectThread extends Thread { + private final Selector selector; + private volatile boolean running; + + public SelectThread() throws IOException { + this.selector = SelectorProvider.provider().openSelector(); + this.running = true; + // We don't want to hold up the JVM when shutting down + setDaemon(true); + } + + public Selector getSelector() { + return selector; + } + + public void finish() { + running = false; + selector.wakeup(); + } + + public void run() { + while (running) { + try { + selector.select(); + } catch (IOException e) { + LOGGER.error("Caught IOException in TAsyncClientManager!", e); + } + + // Handle any ready channels calls + Iterator keys = selector.selectedKeys().iterator(); + while (keys.hasNext()) { + SelectionKey key = keys.next(); + keys.remove(); + if (!key.isValid()) { + // this should only have happened if the method call experienced an + // error and the key was cancelled. just skip it. + continue; + } + TAsyncMethodCall method = (TAsyncMethodCall)key.attachment(); + method.transition(key); + } + + // Start any new calls + TAsyncMethodCall methodCall; + while ((methodCall = pendingCalls.poll()) != null) { + try { + SelectionKey key = methodCall.registerWithSelector(selector); + methodCall.transition(key); + } catch (IOException e) { + LOGGER.warn("Caught IOException in TAsyncClientManager!", e); + } + } + } + } + } +} diff --git a/lib/java/src/org/apache/thrift/async/TAsyncMethodCall.java b/lib/java/src/org/apache/thrift/async/TAsyncMethodCall.java new file mode 100644 index 00000000..e1300878 --- /dev/null +++ b/lib/java/src/org/apache/thrift/async/TAsyncMethodCall.java @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.thrift.async; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; + +import org.apache.thrift.TException; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TFramedTransport; +import org.apache.thrift.transport.TMemoryBuffer; +import org.apache.thrift.transport.TNonblockingTransport; +import org.apache.thrift.transport.TTransportException; + +/** + * Encapsulates an async method call + * Need to generate: + * - private void write_args(TProtocol protocol) + * - public T getResult() throws , , ... + * @param + */ +public abstract class TAsyncMethodCall { + public static enum State { + WRITING_REQUEST_SIZE, + WRITING_REQUEST_BODY, + READING_RESPONSE_SIZE, + READING_RESPONSE_BODY, + RESPONSE_READ, + ERROR; + } + + private static final int INITIAL_MEMORY_BUFFER_SIZE = 128; + + protected final TNonblockingTransport transport; + private final TProtocolFactory protocolFactory; + protected final TAsyncClient client; + private final AsyncMethodCallback callback; + private final boolean isOneway; + + private ByteBuffer sizeBuffer; + private final byte[] sizeBufferArray = new byte[4]; + + private ByteBuffer frameBuffer; + private State state; + + protected TAsyncMethodCall(TAsyncClient client, TProtocolFactory protocolFactory, TNonblockingTransport transport, AsyncMethodCallback callback, boolean isOneway) throws TException { + this.transport = transport; + this.callback = callback; + this.protocolFactory = protocolFactory; + this.client = client; + this.isOneway = isOneway; + + this.state = State.WRITING_REQUEST_SIZE; + prepareMethodCall(); + } + + protected State getState() { + return state; + } + + protected abstract void write_args(TProtocol protocol) throws TException; + + private void prepareMethodCall() throws TException { + TMemoryBuffer memoryBuffer = new TMemoryBuffer(INITIAL_MEMORY_BUFFER_SIZE); + TProtocol protocol = protocolFactory.getProtocol(memoryBuffer); + write_args(protocol); + + int length = memoryBuffer.length(); + frameBuffer = ByteBuffer.wrap(memoryBuffer.getArray(), 0, length); + + TFramedTransport.encodeFrameSize(length, sizeBufferArray); + sizeBuffer = ByteBuffer.wrap(sizeBufferArray); + } + + SelectionKey registerWithSelector(Selector sel) throws IOException { + SelectionKey key = transport.registerSelector(sel, SelectionKey.OP_WRITE); + key.attach(this); + return key; + } + + protected ByteBuffer getFrameBuffer() { + return frameBuffer; + } + + /** + * Transition to next state, doing whatever work is required. Since this + * method is only called by the selector thread, we can make changes to our + * select interests without worrying about concurrency. + * @param key + */ + protected void transition(SelectionKey key) { + // Ensure key is valid + if (!key.isValid()) { + key.cancel(); + Exception e = new TTransportException("Selection key not valid!"); + client.onError(e); + callback.onError(e); + return; + } + + // Transition function + try { + switch (state) { + case WRITING_REQUEST_SIZE: + doWritingRequestSize(); + break; + case WRITING_REQUEST_BODY: + doWritingRequestBody(key); + break; + case READING_RESPONSE_SIZE: + doReadingResponseSize(); + break; + case READING_RESPONSE_BODY: + doReadingResponseBody(key); + break; + case RESPONSE_READ: + case ERROR: + throw new IllegalStateException("Method call in state " + state + + " but selector called transition method. Seems like a bug..."); + } + } catch (Throwable e) { + state = State.ERROR; + key.cancel(); + key.attach(null); + client.onError(e); + callback.onError(e); + } + } + + private void doReadingResponseBody(SelectionKey key) throws IOException { + if (transport.read(frameBuffer) < 0) { + throw new IOException("Read call frame failed"); + } + if (frameBuffer.remaining() == 0) { + cleanUpAndFireCallback(key); + } + } + + private void cleanUpAndFireCallback(SelectionKey key) { + state = State.RESPONSE_READ; + key.interestOps(0); + // this ensures that the TAsyncMethod instance doesn't hang around + key.attach(null); + key.cancel(); + client.onComplete(); + callback.onComplete((T)this); + } + + private void doReadingResponseSize() throws IOException { + if (transport.read(sizeBuffer) < 0) { + throw new IOException("Read call frame size failed"); + } + if (sizeBuffer.remaining() == 0) { + state = State.READING_RESPONSE_BODY; + frameBuffer = ByteBuffer.allocate(TFramedTransport.decodeFrameSize(sizeBufferArray)); + } + } + + private void doWritingRequestBody(SelectionKey key) throws IOException { + if (transport.write(frameBuffer) < 0) { + throw new IOException("Write call frame failed"); + } + if (frameBuffer.remaining() == 0) { + if (isOneway) { + cleanUpAndFireCallback(key); + } else { + state = State.READING_RESPONSE_SIZE; + sizeBuffer.rewind(); // Prepare to read incoming frame size + key.interestOps(SelectionKey.OP_READ); + } + } + } + + private void doWritingRequestSize() throws IOException { + if (transport.write(sizeBuffer) < 0) { + throw new IOException("Write call frame size failed"); + } + if (sizeBuffer.remaining() == 0) { + state = State.WRITING_REQUEST_BODY; + } + } +} diff --git a/lib/java/src/org/apache/thrift/transport/TFramedTransport.java b/lib/java/src/org/apache/thrift/transport/TFramedTransport.java index fab9c9b6..32483ee1 100644 --- a/lib/java/src/org/apache/thrift/transport/TFramedTransport.java +++ b/lib/java/src/org/apache/thrift/transport/TFramedTransport.java @@ -23,7 +23,7 @@ import org.apache.thrift.TByteArrayOutputStream; /** * TFramedTransport is a buffered TTransport that ensures a fully read message - * every time by preceeding messages with a 4-byte frame size. + * every time by preceeding messages with a 4-byte frame size. */ public class TFramedTransport extends TTransport { @@ -58,6 +58,7 @@ public class TFramedTransport extends TTransport { maxLength_ = maxLength; } + @Override public TTransport getTransport(TTransport base) { return new TFramedTransport(base, maxLength_); } @@ -122,14 +123,11 @@ public class TFramedTransport extends TTransport { readBuffer_.consumeBuffer(len); } - private final byte[] i32rd = new byte[4]; + private final byte[] i32buf = new byte[4]; + private void readFrame() throws TTransportException { - transport_.readAll(i32rd, 0, 4); - int size = - ((i32rd[0] & 0xff) << 24) | - ((i32rd[1] & 0xff) << 16) | - ((i32rd[2] & 0xff) << 8) | - ((i32rd[3] & 0xff)); + transport_.readAll(i32buf, 0, 4); + int size = decodeFrameSize(i32buf); if (size < 0) { throw new TTransportException("Read a negative frame size (" + size + ")!"); @@ -148,18 +146,30 @@ public class TFramedTransport extends TTransport { writeBuffer_.write(buf, off, len); } + @Override public void flush() throws TTransportException { byte[] buf = writeBuffer_.get(); int len = writeBuffer_.len(); writeBuffer_.reset(); - byte[] i32out = new byte[4]; - i32out[0] = (byte)(0xff & (len >> 24)); - i32out[1] = (byte)(0xff & (len >> 16)); - i32out[2] = (byte)(0xff & (len >> 8)); - i32out[3] = (byte)(0xff & (len)); - transport_.write(i32out, 0, 4); + encodeFrameSize(len, i32buf); + transport_.write(i32buf, 0, 4); transport_.write(buf, 0, len); transport_.flush(); } + + public static final void encodeFrameSize(final int frameSize, final byte[] buf) { + buf[0] = (byte)(0xff & (frameSize >> 24)); + buf[1] = (byte)(0xff & (frameSize >> 16)); + buf[2] = (byte)(0xff & (frameSize >> 8)); + buf[3] = (byte)(0xff & (frameSize)); + } + + public static final int decodeFrameSize(final byte[] buf) { + return + ((buf[0] & 0xff) << 24) | + ((buf[1] & 0xff) << 16) | + ((buf[2] & 0xff) << 8) | + ((buf[3] & 0xff)); + } } diff --git a/lib/java/src/org/apache/thrift/transport/TMemoryBuffer.java b/lib/java/src/org/apache/thrift/transport/TMemoryBuffer.java index 886fcbf6..9b906db3 100644 --- a/lib/java/src/org/apache/thrift/transport/TMemoryBuffer.java +++ b/lib/java/src/org/apache/thrift/transport/TMemoryBuffer.java @@ -24,12 +24,12 @@ import java.io.UnsupportedEncodingException; /** * Memory buffer-based implementation of the TTransport interface. - * */ public class TMemoryBuffer extends TTransport { - /** - * + * Create a TMemoryBuffer with an initial buffer size of size. The + * internal buffer will grow as necessary to accomodate the size of the data + * being written to it. */ public TMemoryBuffer(int size) { arr_ = new TByteArrayOutputStream(size); @@ -90,9 +90,13 @@ public class TMemoryBuffer extends TTransport { // Position to read next byte from private int pos_; - + public int length() { return arr_.size(); } + + public byte[] getArray() { + return arr_.get(); + } } diff --git a/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java b/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java index bc2d5396..313ef85a 100644 --- a/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java +++ b/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java @@ -21,6 +21,7 @@ package org.apache.thrift.transport; import java.io.IOException; +import java.net.InetSocketAddress; import java.net.Socket; import java.net.SocketException; import java.nio.ByteBuffer; @@ -41,19 +42,20 @@ public class TNonblockingSocket extends TNonblockingTransport { private Socket socket_ = null; /** - * Remote host - */ - private String host_ = null; - - /** - * Remote port + * Socket timeout */ - private int port_ = 0; + private int timeout_ = 0; /** - * Socket timeout + * Create a new nonblocking socket transport connected to host:port. + * @param host + * @param port + * @throws TTransportException + * @throws IOException */ - private int timeout_ = 0; + public TNonblockingSocket(String host, int port) throws TTransportException, IOException { + this(SocketChannel.open(new InetSocketAddress(host, port))); + } /** * Constructor that takes an already created socket. diff --git a/lib/java/test/org/apache/thrift/async/TestTAsyncClientManager.java b/lib/java/test/org/apache/thrift/async/TestTAsyncClientManager.java new file mode 100644 index 00000000..5c8ff76a --- /dev/null +++ b/lib/java/test/org/apache/thrift/async/TestTAsyncClientManager.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.thrift.async; + +import java.util.concurrent.atomic.AtomicBoolean; + +import junit.framework.TestCase; + +import org.apache.thrift.TException; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.server.TNonblockingServer; +import org.apache.thrift.transport.TNonblockingServerSocket; +import org.apache.thrift.transport.TNonblockingSocket; + +import thrift.test.CompactProtoTestStruct; +import thrift.test.Srv; +import thrift.test.Srv.Iface; +import thrift.test.Srv.AsyncClient.Janky_call; +import thrift.test.Srv.AsyncClient.onewayMethod_call; +import thrift.test.Srv.AsyncClient.voidMethod_call; + +public class TestTAsyncClientManager extends TestCase { + private static abstract class FailureLessCallback implements AsyncMethodCallback { + @Override + public void onError(Throwable throwable) { + throwable.printStackTrace(); + fail("unexpected error " + throwable); + } + } + + public class SrvHandler implements Iface { + @Override + public int Janky(int arg) throws TException { + return 0; + } + + @Override + public void methodWithDefaultArgs(int something) throws TException { + } + + @Override + public int primitiveMethod() throws TException { + return 0; + } + + @Override + public CompactProtoTestStruct structMethod() throws TException { + return null; + } + + @Override + public void voidMethod() throws TException { + } + + @Override + public void onewayMethod() throws TException { + } + } + + public void testIt() throws Exception { + // put up a server + final TNonblockingServer s = new TNonblockingServer(new Srv.Processor(new SrvHandler()), new TNonblockingServerSocket(12345)); + new Thread(new Runnable() { + @Override + public void run() { + s.serve(); + } + }).start(); + Thread.sleep(1000); + + // set up async client manager + TAsyncClientManager acm = new TAsyncClientManager(); + + // connect an async client + TNonblockingSocket clientSock = new TNonblockingSocket("localhost", 12345); + Srv.AsyncClient client = new Srv.AsyncClient(new TBinaryProtocol.Factory(), acm, clientSock); + + final Object o = new Object(); + + // make a standard method call + final AtomicBoolean jankyReturned = new AtomicBoolean(false); + client.Janky(1, new FailureLessCallback() { + @Override + public void onComplete(Janky_call response) { + try { + assertEquals(0, response.getResult()); + jankyReturned.set(true); + } catch (TException e) { + fail("unexpected exception: " + e); + } + synchronized(o) { + o.notifyAll(); + } + } + }); + + synchronized(o) { + o.wait(100000); + } + assertTrue(jankyReturned.get()); + + // make a void method call + final AtomicBoolean voidMethodReturned = new AtomicBoolean(false); + client.voidMethod(new FailureLessCallback() { + @Override + public void onComplete(voidMethod_call response) { + try { + response.getResult(); + voidMethodReturned.set(true); + } catch (TException e) { + fail("unexpected exception " + e); + } + synchronized (o) { + o.notifyAll(); + } + } + }); + + synchronized(o) { + o.wait(1000); + } + assertTrue(voidMethodReturned.get()); + + // make a oneway method call + final AtomicBoolean onewayReturned = new AtomicBoolean(false); + client.onewayMethod(new FailureLessCallback() { + @Override + public void onComplete(onewayMethod_call response) { + try { + response.getResult(); + onewayReturned.set(true); + } catch (TException e) { + fail("unexpected exception " + e); + } + synchronized(o) { + o.notifyAll(); + } + } + }); + synchronized(o) { + o.wait(1000); + } + + assertTrue(onewayReturned.get()); + + // make another standard method call + final AtomicBoolean voidAfterOnewayReturned = new AtomicBoolean(false); + client.voidMethod(new FailureLessCallback() { + @Override + public void onComplete(voidMethod_call response) { + try { + response.getResult(); + voidAfterOnewayReturned.set(true); + } catch (TException e) { + fail("unexpected exception " + e); + } + synchronized(o) { + o.notifyAll(); + } + } + }); + synchronized(o) { + o.wait(1000); + } + + assertTrue(voidAfterOnewayReturned.get()); + } +} diff --git a/lib/java/test/org/apache/thrift/protocol/ProtocolTestBase.java b/lib/java/test/org/apache/thrift/protocol/ProtocolTestBase.java index 365cef7f..da0de057 100644 --- a/lib/java/test/org/apache/thrift/protocol/ProtocolTestBase.java +++ b/lib/java/test/org/apache/thrift/protocol/ProtocolTestBase.java @@ -305,6 +305,10 @@ public abstract class ProtocolTestBase extends TestCase { public void methodWithDefaultArgs(int something) throws TException { } + + @Override + public void onewayMethod() throws TException { + } }; Srv.Processor testProcessor = new Srv.Processor(handler); diff --git a/test/DebugProtoTest.thrift b/test/DebugProtoTest.thrift index dbce93ed..5e361d21 100644 --- a/test/DebugProtoTest.thrift +++ b/test/DebugProtoTest.thrift @@ -228,14 +228,16 @@ service ServiceForExceptionWithAMap { service Srv { i32 Janky(1: i32 arg); - + // return type only methods - + void voidMethod(); i32 primitiveMethod(); CompactProtoTestStruct structMethod(); - + void methodWithDefaultArgs(1: i32 something = MYCONST); + + oneway void onewayMethod(); } service Inherited extends Srv { -- 2.17.1