import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.spi.SelectorProvider;
-import java.util.HashSet;
+import java.util.Comparator;
import java.util.Iterator;
-import java.util.Set;
+import java.util.TreeSet;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeoutException;
+import org.apache.commons.lang.ObjectUtils;
import org.apache.thrift.TException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
*/
public class TAsyncClientManager {
private static final Logger LOGGER = LoggerFactory.getLogger(TAsyncClientManager.class.getName());
-
+
private final SelectThread selectThread;
private final ConcurrentLinkedQueue<TAsyncMethodCall> pendingCalls = new ConcurrentLinkedQueue<TAsyncMethodCall>();
}
public void call(TAsyncMethodCall method) throws TException {
+ if (!isRunning()) {
+ throw new TException("SelectThread is not running");
+ }
method.prepareMethodCall();
pendingCalls.add(method);
selectThread.getSelector().wakeup();
public void stop() {
selectThread.finish();
}
-
+
+ public boolean isRunning() {
+ return selectThread.isAlive();
+ }
+
private class SelectThread extends Thread {
- // Selector waits at most SELECT_TIME milliseconds before waking
- private static final long SELECT_TIME = 5;
-
private final Selector selector;
private volatile boolean running;
- private final Set<TAsyncMethodCall> timeoutWatchSet = new HashSet<TAsyncMethodCall>();
+ private final TreeSet<TAsyncMethodCall> timeoutWatchSet = new TreeSet<TAsyncMethodCall>(new TAsyncMethodCallTimeoutComparator());
public SelectThread() throws IOException {
this.selector = SelectorProvider.provider().openSelector();
this.running = true;
+ this.setName("TAsyncClientManager#SelectorThread " + this.getId());
+
// We don't want to hold up the JVM when shutting down
setDaemon(true);
}
while (running) {
try {
try {
- selector.select(SELECT_TIME);
+ if (timeoutWatchSet.size() == 0) {
+ // No timeouts, so select indefinitely
+ selector.select();
+ } else {
+ // We have a timeout pending, so calculate the time until then and select appropriately
+ long nextTimeout = timeoutWatchSet.first().getTimeoutTimestamp();
+ long selectTime = nextTimeout - System.currentTimeMillis();
+ if (selectTime > 0) {
+ // Next timeout is in the future, select and wake up then
+ selector.select(selectTime);
+ } else {
+ // Next timeout is now or in past, select immediately so we can time out
+ selector.selectNow();
+ }
+ }
} catch (IOException e) {
LOGGER.error("Caught IOException in TAsyncClientManager!", e);
}
-
transitionMethods();
- timeoutIdleMethods();
+ timeoutMethods();
startPendingMethods();
- } catch (Throwable throwable) {
- LOGGER.error("Ignoring uncaught exception in SelectThread", throwable);
+ } catch (Exception exception) {
+ LOGGER.error("Ignoring uncaught exception in SelectThread", exception);
}
}
}
}
// Timeout any existing method calls
- private void timeoutIdleMethods() {
+ private void timeoutMethods() {
Iterator<TAsyncMethodCall> iterator = timeoutWatchSet.iterator();
+ long currentTime = System.currentTimeMillis();
while (iterator.hasNext()) {
TAsyncMethodCall methodCall = iterator.next();
- long clientTimeout = methodCall.getClient().getTimeout();
- long timeElapsed = System.currentTimeMillis() - methodCall.getLastTransitionTime();
-
- if (timeElapsed > clientTimeout) {
+ if (currentTime >= methodCall.getTimeoutTimestamp()) {
iterator.remove();
- methodCall.onError(new TimeoutException("Operation " +
- methodCall.getClass() + " timed out after " + timeElapsed +
- " milliseconds."));
+ methodCall.onError(new TimeoutException("Operation " + methodCall.getClass() + " timed out after " + (currentTime - methodCall.getStartTime()) + " ms."));
+ } else {
+ break;
}
}
}
// Catch registration errors. method will catch transition errors and cleanup.
try {
methodCall.start(selector);
-
+
// If timeout specified and first transition went smoothly, add to timeout watch set
TAsyncClient client = methodCall.getClient();
if (client.hasTimeout() && !client.hasError()) {
timeoutWatchSet.add(methodCall);
}
- } catch (Throwable e) {
- LOGGER.warn("Caught throwable in TAsyncClientManager!", e);
- methodCall.onError(e);
+ } catch (Exception exception) {
+ LOGGER.warn("Caught exception in TAsyncClientManager!", exception);
+ methodCall.onError(exception);
}
}
}
}
+
+ // Comparator used in TreeSet
+ private static class TAsyncMethodCallTimeoutComparator implements Comparator<TAsyncMethodCall> {
+ @Override
+ public int compare(TAsyncMethodCall left, TAsyncMethodCall right) {
+ if (left.getTimeoutTimestamp() == right.getTimeoutTimestamp()) {
+ return (int)(left.getSequenceId() - right.getSequenceId());
+ } else {
+ return (int)(left.getTimeoutTimestamp() - right.getTimeoutTimestamp());
+ }
+ }
+ }
+
}
*/
package org.apache.thrift.async;
+import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.ArrayList;
import thrift.test.Srv.AsyncClient.voidMethod_call;
public class TestTAsyncClientManager extends TestCase {
- private static void fail(Throwable throwable) {
- StringWriter sink = new StringWriter();
- throwable.printStackTrace(new PrintWriter(sink, true));
- fail("unexpected error " + sink.toString());
+
+ private THsHaServer server_;
+ private Thread serverThread_;
+ private TAsyncClientManager clientManager_;
+
+ public void setUp() throws Exception {
+ server_ = new THsHaServer(new Srv.Processor(new SrvHandler()), new TNonblockingServerSocket(ServerTestBase.PORT));
+ serverThread_ = new Thread(new Runnable() {
+ public void run() {
+ server_.serve();
+ }
+ });
+ serverThread_.start();
+ clientManager_ = new TAsyncClientManager();
+ Thread.sleep(500);
}
-
- private static abstract class FailureLessCallback<T extends TAsyncMethodCall> implements AsyncMethodCallback<T> {
- @Override
- public void onError(Throwable throwable) {
- fail(throwable);
+
+ public void tearDown() throws Exception {
+ server_.stop();
+ clientManager_.stop();
+ serverThread_.join();
+ }
+
+ public void testBasicCall() throws Exception {
+ Srv.AsyncClient client = getClient();
+ basicCall(client);
+ }
+
+ public void testBasicCallWithTimeout() throws Exception {
+ Srv.AsyncClient client = getClient();
+ client.setTimeout(5000);
+ basicCall(client);
+ }
+
+ public void testTimeoutCall() throws Exception {
+ final CountDownLatch latch = new CountDownLatch(1);
+ Srv.AsyncClient client = getClient();
+ client.setTimeout(100);
+ client.primitiveMethod(new AsyncMethodCallback<primitiveMethod_call>() {
+ @Override
+ public void onError(Exception exception) {
+ try {
+ if (!(exception instanceof TimeoutException)) {
+ StringWriter sink = new StringWriter();
+ exception.printStackTrace(new PrintWriter(sink, true));
+ fail("expected TimeoutException but got " + sink.toString());
+ }
+ } finally {
+ latch.countDown();
+ }
+ }
+
+ @Override
+ public void onComplete(primitiveMethod_call response) {
+ try {
+ fail("Should not have finished timed out call.");
+ } finally {
+ latch.countDown();
+ }
+ }
+ });
+ latch.await(2, TimeUnit.SECONDS);
+ assertTrue(client.hasError());
+ assertTrue(client.getError() instanceof TimeoutException);
+ }
+
+ public void testVoidCall() throws Exception {
+ final CountDownLatch latch = new CountDownLatch(1);
+ final AtomicBoolean returned = new AtomicBoolean(false);
+ Srv.AsyncClient client = getClient();
+ client.voidMethod(new FailureLessCallback<Srv.AsyncClient.voidMethod_call>() {
+ @Override
+ public void onComplete(voidMethod_call response) {
+ try {
+ response.getResult();
+ returned.set(true);
+ } catch (TException e) {
+ fail(e);
+ } finally {
+ latch.countDown();
+ }
+ }
+ });
+ latch.await(1, TimeUnit.SECONDS);
+ assertTrue(returned.get());
+ }
+
+ public void testOnewayCall() throws Exception {
+ final CountDownLatch latch = new CountDownLatch(1);
+ final AtomicBoolean returned = new AtomicBoolean(false);
+ Srv.AsyncClient client = getClient();
+ client.onewayMethod(new FailureLessCallback<onewayMethod_call>() {
+ @Override
+ public void onComplete(onewayMethod_call response) {
+ try {
+ response.getResult();
+ returned.set(true);
+ } catch (TException e) {
+ fail(e);
+ } finally {
+ latch.countDown();
+ }
+ }
+ });
+ latch.await(1, TimeUnit.SECONDS);
+ assertTrue(returned.get());
+ }
+
+ public void testParallelCalls() throws Exception {
+ // make multiple calls with deserialization in the selector thread (repro Eric's issue)
+ int numThreads = 50;
+ int numCallsPerThread = 100;
+ List<JankyRunnable> runnables = new ArrayList<JankyRunnable>();
+ List<Thread> threads = new ArrayList<Thread>();
+ for (int i = 0; i < numThreads; i++) {
+ JankyRunnable runnable = new JankyRunnable(numCallsPerThread);
+ Thread thread = new Thread(runnable);
+ thread.start();
+ threads.add(thread);
+ runnables.add(runnable);
}
+ for (Thread thread : threads) {
+ thread.join();
+ }
+ int numSuccesses = 0;
+ for (JankyRunnable runnable : runnables) {
+ numSuccesses += runnable.getNumSuccesses();
+ }
+ assertEquals(numThreads * numCallsPerThread, numSuccesses);
+ }
+
+ private Srv.AsyncClient getClient() throws IOException {
+ TNonblockingSocket clientSocket = new TNonblockingSocket(ServerTestBase.HOST, ServerTestBase.PORT);
+ return new Srv.AsyncClient(new TBinaryProtocol.Factory(), clientManager_, clientSocket);
}
-
+
+ private void basicCall(Srv.AsyncClient client) throws Exception {
+ final CountDownLatch latch = new CountDownLatch(1);
+ final AtomicBoolean returned = new AtomicBoolean(false);
+ client.Janky(1, new FailureLessCallback<Srv.AsyncClient.Janky_call>() {
+ @Override
+ public void onComplete(Janky_call response) {
+ try {
+ assertEquals(3, response.getResult());
+ returned.set(true);
+ } catch (TException e) {
+ fail(e);
+ } finally {
+ latch.countDown();
+ }
+ }
+
+ @Override
+ public void onError(Exception exception) {
+ try {
+ StringWriter sink = new StringWriter();
+ exception.printStackTrace(new PrintWriter(sink, true));
+ fail("unexpected onError with exception " + sink.toString());
+ } finally {
+ latch.countDown();
+ }
+ }
+ });
+ latch.await(100, TimeUnit.SECONDS);
+ assertTrue(returned.get());
+ }
+
public class SrvHandler implements Iface {
+ // Use this method for a standard call testing
@Override
public int Janky(int arg) throws TException {
assertEquals(1, arg);
return 3;
}
- @Override
- public void methodWithDefaultArgs(int something) throws TException {
- }
-
- // Using this method for timeout testing
+ // Using this method for timeout testing - sleeps for 1 second before returning
@Override
public int primitiveMethod() throws TException {
try {
}
return 0;
}
+
+ @Override
+ public void methodWithDefaultArgs(int something) throws TException { }
@Override
public CompactProtoTestStruct structMethod() throws TException {
public void onewayMethod() throws TException {
}
}
-
- public class JankyRunnable implements Runnable {
- private TAsyncClientManager acm_;
+
+ private static abstract class FailureLessCallback<T extends TAsyncMethodCall> implements AsyncMethodCallback<T> {
+ @Override
+ public void onError(Exception exception) {
+ fail(exception);
+ }
+ }
+
+ private static void fail(Exception exception) {
+ StringWriter sink = new StringWriter();
+ exception.printStackTrace(new PrintWriter(sink, true));
+ fail("unexpected error " + sink.toString());
+ }
+
+ private class JankyRunnable implements Runnable {
private int numCalls_;
private int numSuccesses_ = 0;
private Srv.AsyncClient client_;
- private TNonblockingSocket clientSocket_;
- public JankyRunnable(TAsyncClientManager acm, int numCalls) throws Exception {
- this.acm_ = acm;
- this.numCalls_ = numCalls;
- this.clientSocket_ = new TNonblockingSocket(ServerTestBase.HOST, ServerTestBase.PORT);
- this.client_ = new Srv.AsyncClient(new TBinaryProtocol.Factory(), acm_, clientSocket_);
- this.client_.setTimeout(20000);
+ public JankyRunnable(int numCalls) throws Exception {
+ numCalls_ = numCalls;
+ client_ = getClient();
+ client_.setTimeout(20000);
}
public int getNumSuccesses() {
try {
// connect an async client
final CountDownLatch latch = new CountDownLatch(1);
- final AtomicBoolean jankyReturned = new AtomicBoolean(false);
+ final AtomicBoolean returned = new AtomicBoolean(false);
client_.Janky(1, new AsyncMethodCallback<Srv.AsyncClient.Janky_call>() {
-
+
@Override
public void onComplete(Janky_call response) {
try {
assertEquals(3, response.getResult());
- jankyReturned.set(true);
+ returned.set(true);
latch.countDown();
} catch (TException e) {
latch.countDown();
}
@Override
- public void onError(Throwable throwable) {
+ public void onError(Exception exception) {
try {
StringWriter sink = new StringWriter();
- throwable.printStackTrace(new PrintWriter(sink, true));
+ exception.printStackTrace(new PrintWriter(sink, true));
fail("unexpected onError on iteration " + iteration + ": " + sink.toString());
} finally {
latch.countDown();
boolean calledBack = latch.await(30, TimeUnit.SECONDS);
assertTrue("wasn't called back in time on iteration " + iteration, calledBack);
- assertTrue("onComplete not called on iteration " + iteration, jankyReturned.get());
+ assertTrue("onComplete not called on iteration " + iteration, returned.get());
this.numSuccesses_++;
} catch (Exception e) {
fail(e);
}
}
}
-
- public void standardCallTest(Srv.AsyncClient client) throws Exception {
- final CountDownLatch latch = new CountDownLatch(1);
- final AtomicBoolean jankyReturned = new AtomicBoolean(false);
- client.Janky(1, new FailureLessCallback<Srv.AsyncClient.Janky_call>() {
- @Override
- public void onComplete(Janky_call response) {
- try {
- assertEquals(3, response.getResult());
- jankyReturned.set(true);
- } catch (TException e) {
- fail(e);
- } finally {
- latch.countDown();
- }
- }
- });
-
- latch.await(100, TimeUnit.SECONDS);
- assertTrue(jankyReturned.get());
- }
-
- public void testIt() throws Exception {
- // put up a server
- final THsHaServer s = new THsHaServer(new Srv.Processor(new SrvHandler()),
- new TNonblockingServerSocket(ServerTestBase.PORT));
- new Thread(new Runnable() {
- @Override
- public void run() {
- s.serve();
- }
- }).start();
- Thread.sleep(1000);
-
- // set up async client manager
- TAsyncClientManager acm = new TAsyncClientManager();
-
- // connect an async client
- TNonblockingSocket clientSock = new TNonblockingSocket(
- ServerTestBase.HOST, ServerTestBase.PORT);
- Srv.AsyncClient client = new Srv.AsyncClient(new TBinaryProtocol.Factory(), acm, clientSock);
-
- // make a standard method call
- standardCallTest(client);
-
- // make a standard method call that succeeds within timeout
- assertFalse(s.isStopped());
- client.setTimeout(5000);
- standardCallTest(client);
-
- // make a void method call
- assertFalse(s.isStopped());
- final CountDownLatch voidLatch = new CountDownLatch(1);
- final AtomicBoolean voidMethodReturned = new AtomicBoolean(false);
- client.voidMethod(new FailureLessCallback<Srv.AsyncClient.voidMethod_call>() {
- @Override
- public void onComplete(voidMethod_call response) {
- try {
- response.getResult();
- voidMethodReturned.set(true);
- } catch (TException e) {
- fail(e);
- } finally {
- voidLatch.countDown();
- }
- }
- });
- voidLatch.await(1, TimeUnit.SECONDS);
- assertTrue(voidMethodReturned.get());
-
- // make a oneway method call
- assertFalse(s.isStopped());
- final CountDownLatch onewayLatch = new CountDownLatch(1);
- final AtomicBoolean onewayReturned = new AtomicBoolean(false);
- client.onewayMethod(new FailureLessCallback<onewayMethod_call>() {
- @Override
- public void onComplete(onewayMethod_call response) {
- try {
- response.getResult();
- onewayReturned.set(true);
- } catch (TException e) {
- fail(e);
- } finally {
- onewayLatch.countDown();
- }
- }
- });
- onewayLatch.await(1, TimeUnit.SECONDS);
- assertTrue(onewayReturned.get());
-
- // make another standard method call
- assertFalse(s.isStopped());
- final CountDownLatch voidAfterOnewayLatch = new CountDownLatch(1);
- final AtomicBoolean voidAfterOnewayReturned = new AtomicBoolean(false);
- client.voidMethod(new FailureLessCallback<voidMethod_call>() {
- @Override
- public void onComplete(voidMethod_call response) {
- try {
- response.getResult();
- voidAfterOnewayReturned.set(true);
- } catch (TException e) {
- fail(e);
- } finally {
- voidAfterOnewayLatch.countDown();
- }
- }
- });
- voidAfterOnewayLatch.await(1, TimeUnit.SECONDS);
- assertTrue(voidAfterOnewayReturned.get());
-
- // make multiple calls with deserialization in the selector thread (repro Eric's issue)
- assertFalse(s.isStopped());
- int numThreads = 50;
- int numCallsPerThread = 100;
- List<JankyRunnable> runnables = new ArrayList<JankyRunnable>();
- List<Thread> threads = new ArrayList<Thread>();
- for (int i = 0; i < numThreads; i++) {
- JankyRunnable runnable = new JankyRunnable(acm, numCallsPerThread);
- Thread thread = new Thread(runnable);
- thread.start();
- threads.add(thread);
- runnables.add(runnable);
- }
- for (Thread thread : threads) {
- thread.join();
- }
- int numSuccesses = 0;
- for (JankyRunnable runnable : runnables) {
- numSuccesses += runnable.getNumSuccesses();
- }
- assertEquals(numThreads * numCallsPerThread, numSuccesses);
-
- // check that timeouts work
- assertFalse(s.isStopped());
- assertTrue(clientSock.isOpen());
- final CountDownLatch timeoutLatch = new CountDownLatch(1);
- client.setTimeout(100);
- client.primitiveMethod(new AsyncMethodCallback<primitiveMethod_call>() {
-
- @Override
- public void onError(Throwable throwable) {
- try {
- if (!(throwable instanceof TimeoutException)) {
- StringWriter sink = new StringWriter();
- throwable.printStackTrace(new PrintWriter(sink, true));
- fail("expected TimeoutException but got " + sink.toString());
- }
- } finally {
- timeoutLatch.countDown();
- }
- }
-
- @Override
- public void onComplete(primitiveMethod_call response) {
- try {
- fail("should not have finished timed out call.");
- } finally {
- timeoutLatch.countDown();
- }
- }
-
- });
- timeoutLatch.await(2, TimeUnit.SECONDS);
- assertTrue(client.hasError());
- assertTrue(client.getError() instanceof TimeoutException);
-
- // error closes socket and make sure isOpen reflects that
- assertFalse(clientSock.isOpen());
- }
-}
+}
\ No newline at end of file