From 15522724f623b98f5e97b442baa47fa28975881e Mon Sep 17 00:00:00 2001 From: Roger Meier Date: Sun, 12 May 2013 20:45:13 +0200 Subject: [PATCH] THRIFT-1034 Java: Add a TNonblockingMultiFetchClient Patch: Xing Jin --- .../thrift/TNonblockingMultiFetchClient.java | 396 ++++++++++++++++++ .../thrift/TNonblockingMultiFetchStats.java | 80 ++++ 2 files changed, 476 insertions(+) create mode 100755 lib/java/src/org/apache/thrift/TNonblockingMultiFetchClient.java create mode 100755 lib/java/src/org/apache/thrift/TNonblockingMultiFetchStats.java diff --git a/lib/java/src/org/apache/thrift/TNonblockingMultiFetchClient.java b/lib/java/src/org/apache/thrift/TNonblockingMultiFetchClient.java new file mode 100755 index 00000000..efa846c2 --- /dev/null +++ b/lib/java/src/org/apache/thrift/TNonblockingMultiFetchClient.java @@ -0,0 +1,396 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.thrift; + +import org.apache.log4j.Logger; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.FutureTask; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + + +/** + * This class uses a single thread to set up non-blocking sockets to a set + * of remote servers (hostname and port pairs), and sends a same request to + * all these servers. It then fetches responses from servers. + * + * Parameters: + * int maxRecvBufBytesPerServer - an upper limit for receive buffer size + * per server (in byte). If a response from a server exceeds this limit, the + * client will not allocate memory or read response data for it. + * + * int fetchTimeoutSeconds - time limit for fetching responses from all + * servers (in second). After the timeout, the fetch job is stopped and + * available responses are returned. + * + * ByteBuffer requestBuf - request message that is sent to all servers. + * + * Output: + * Responses are stored in an array of ByteBuffers. Index of elements in + * this array corresponds to index of servers in the server list. Content in + * a ByteBuffer may be in one of the following forms: + * 1. First 4 bytes form an integer indicating length of following data, + * then followed by the data. + * 2. First 4 bytes form an integer indicating length of following data, + * then followed by nothing - this happens when the response data size + * exceeds maxRecvBufBytesPerServer, and the client will not read any + * response data. + * 3. No data in the ByteBuffer - this happens when the server does not + * return any response within fetchTimeoutSeconds. + * + * In some special cases (no servers are given, fetchTimeoutSeconds less + * than or equal to 0, requestBuf is null), the return is null. + * + * Note: + * It assumes all remote servers are TNonblockingServers and use + * TFramedTransport. + * + */ +public class TNonblockingMultiFetchClient { + private static final Logger LOG = Logger.getLogger( + TNonblockingMultiFetchClient.class); + + // if the size of the response msg exceeds this limit (in byte), we will + // not read the msg + private int maxRecvBufBytesPerServer; + + // time limit for fetching data from all servers (in second) + private int fetchTimeoutSeconds; + + // store request that will be sent to servers + private ByteBuffer requestBuf; + private ByteBuffer requestBufDuplication; + + // a list of remote servers + private List servers; + + // store fetch results + private TNonblockingMultiFetchStats stats; + private ByteBuffer[] recvBuf; + + public TNonblockingMultiFetchClient(int maxRecvBufBytesPerServer, + int fetchTimeoutSeconds, ByteBuffer requestBuf, + List servers) { + this.maxRecvBufBytesPerServer = maxRecvBufBytesPerServer; + this.fetchTimeoutSeconds = fetchTimeoutSeconds; + this.requestBuf = requestBuf; + this.servers = servers; + + stats = new TNonblockingMultiFetchStats(); + recvBuf = null; + } + + public synchronized int getMaxRecvBufBytesPerServer() { + return maxRecvBufBytesPerServer; + } + + public synchronized int getFetchTimeoutSeconds() { + return fetchTimeoutSeconds; + } + + /** + * return a duplication of requestBuf, so that requestBuf will not + * be modified by others. + */ + public synchronized ByteBuffer getRequestBuf() { + if (requestBuf == null) { + return null; + } else { + if (requestBufDuplication == null) { + requestBufDuplication = requestBuf.duplicate(); + } + return requestBufDuplication; + } + } + + public synchronized List getServerList() { + if (servers == null) { + return null; + } + return Collections.unmodifiableList(servers); + } + + public synchronized TNonblockingMultiFetchStats getFetchStats() { + return stats; + } + + /** + * main entry function for fetching from servers + */ + public synchronized ByteBuffer[] fetch() { + // clear previous results + recvBuf = null; + stats.clear(); + + if (servers == null || servers.size() == 0 || + requestBuf == null || fetchTimeoutSeconds <= 0) { + return recvBuf; + } + + ExecutorService executor = Executors.newSingleThreadExecutor(); + MultiFetch multiFetch = new MultiFetch(); + FutureTask task = new FutureTask(multiFetch, null); + executor.execute(task); + try { + task.get(fetchTimeoutSeconds, TimeUnit.SECONDS); + } catch(InterruptedException ie) { + // attempt to cancel execution of the task. + task.cancel(true); + LOG.error("interrupted during fetch: "+ie.toString()); + } catch(ExecutionException ee) { + // attempt to cancel execution of the task. + task.cancel(true); + LOG.error("exception during fetch: "+ee.toString()); + } catch(TimeoutException te) { + // attempt to cancel execution of the task. + task.cancel(true); + LOG.error("timeout for fetch: "+te.toString()); + } + + executor.shutdownNow(); + multiFetch.close(); + return recvBuf; + } + + /** + * Private class that does real fetch job. + * Users are not allowed to directly use this class, as its run() + * function may run forever. + */ + private class MultiFetch implements Runnable { + private Selector selector; + + /** + * main entry function for fetching. + * + * Server responses are stored in TNonblocingMultiFetchClient.recvBuf, + * and fetch statistics is in TNonblockingMultiFetchClient.stats. + * + * Sanity check for parameters has been done in + * TNonblockingMultiFetchClient before calling this function. + */ + public void run() { + long t1 = System.currentTimeMillis(); + + int numTotalServers = servers.size(); + stats.setNumTotalServers(numTotalServers); + + // buffer for receiving response from servers + recvBuf = new ByteBuffer[numTotalServers]; + // buffer for sending request + ByteBuffer sendBuf[] = new ByteBuffer[numTotalServers]; + long numBytesRead[] = new long[numTotalServers]; + int frameSize[] = new int[numTotalServers]; + boolean hasReadFrameSize[] = new boolean[numTotalServers]; + + try { + selector = Selector.open(); + } catch (IOException e) { + LOG.error("selector opens error: "+e.toString()); + return; + } + + for (int i = 0; i < numTotalServers; i++) { + // create buffer to send request to server. + sendBuf[i] = requestBuf.duplicate(); + // create buffer to read response's frame size from server + recvBuf[i] = ByteBuffer.allocate(4); + stats.incTotalRecvBufBytes(4); + + InetSocketAddress server = servers.get(i); + SocketChannel s = null; + SelectionKey key = null; + try { + s = SocketChannel.open(); + s.configureBlocking(false); + // now this method is non-blocking + s.connect(server); + key = s.register(selector, s.validOps()); + // attach index of the key + key.attach(i); + } catch (Exception e) { + stats.incNumConnectErrorServers(); + String err = String.format("set up socket to server %s error: %s", + server.toString(), e.toString()); + LOG.error(err); + // free resource + if (s != null) { + try {s.close();} catch (Exception ex) {} + } + if (key != null) { + key.cancel(); + } + } + } + + // wait for events + while (stats.getNumReadCompletedServers() + + stats.getNumConnectErrorServers() < stats.getNumTotalServers()) { + // if the thread is interrupted (e.g., task is cancelled) + if (Thread.currentThread().isInterrupted()) { + return; + } + + try{ + selector.select(); + } catch (Exception e) { + LOG.error("selector selects error: "+e.toString()); + continue; + } + + Iterator it = selector.selectedKeys().iterator(); + while (it.hasNext()) { + SelectionKey selKey = it.next(); + it.remove(); + + // get previously attached index + int index = (Integer)selKey.attachment(); + + if (selKey.isValid() && selKey.isConnectable()) { + // if this socket throws an exception (e.g., connection refused), + // print error msg and skip it. + try { + SocketChannel sChannel = (SocketChannel)selKey.channel(); + sChannel.finishConnect(); + } catch (Exception e) { + stats.incNumConnectErrorServers(); + String err = String.format("socket %d connects to server %s " + + "error: %s", + index, servers.get(index).toString(), e.toString()); + LOG.error(err); + } + } + + if (selKey.isValid() && selKey.isWritable()) { + if (sendBuf[index].hasRemaining()) { + // if this socket throws an exception, print error msg and + // skip it. + try { + SocketChannel sChannel = (SocketChannel)selKey.channel(); + sChannel.write(sendBuf[index]); + } catch (Exception e) { + String err = String.format("socket %d writes to server %s " + + "error: %s", + index, servers.get(index).toString(), e.toString()); + LOG.error(err); + } + } + } + + if (selKey.isValid() && selKey.isReadable()) { + // if this socket throws an exception, print error msg and + // skip it. + try { + SocketChannel sChannel = (SocketChannel)selKey.channel(); + int bytesRead = sChannel.read(recvBuf[index]); + + if (bytesRead > 0) { + numBytesRead[index] += bytesRead; + + if (!hasReadFrameSize[index] && + recvBuf[index].remaining()==0) { + // if the frame size has been read completely, then prepare + // to read the actual frame. + frameSize[index] = recvBuf[index].getInt(0); + + if (frameSize[index] <= 0) { + stats.incNumInvalidFrameSize(); + String err = String.format("Read an invalid frame size %d" + + " from %s. Does the server use TFramedTransport? ", + frameSize[index], servers.get(index).toString()); + LOG.error(err); + sChannel.close(); + continue; + } + + if (frameSize[index] + 4 > stats.getMaxResponseBytes()) { + stats.setMaxResponseBytes(frameSize[index]+4); + } + + if (frameSize[index] + 4 > maxRecvBufBytesPerServer) { + stats.incNumOverflowedRecvBuf(); + String err = String.format("Read frame size %d from %s," + + " total buffer size would exceed limit %d", + frameSize[index], servers.get(index).toString(), + maxRecvBufBytesPerServer); + LOG.error(err); + sChannel.close(); + continue; + } + + // reallocate buffer for actual frame data + recvBuf[index] = ByteBuffer.allocate(frameSize[index] + 4); + recvBuf[index].putInt(frameSize[index]); + + stats.incTotalRecvBufBytes(frameSize[index]); + hasReadFrameSize[index] = true; + } + + if (hasReadFrameSize[index] && + numBytesRead[index] >= frameSize[index]+4) { + // has read all data + sChannel.close(); + stats.incNumReadCompletedServers(); + long t2 = System.currentTimeMillis(); + stats.setReadTime(t2-t1); + } + } + } catch (Exception e) { + String err = String.format("socket %d reads from server %s " + + "error: %s", + index, servers.get(index).toString(), e.toString()); + LOG.error(err); + } + } + } + } + } + + /** + * dispose any resource allocated + */ + public void close() { + try { + if (selector.isOpen()) { + Iterator it = selector.keys().iterator(); + while (it.hasNext()) { + SelectionKey selKey = it.next(); + SocketChannel sChannel = (SocketChannel)selKey.channel(); + sChannel.close(); + } + + selector.close(); + } + } catch (IOException e) { + LOG.error("free resource error: "+e.toString()); + } + } + } +} \ No newline at end of file diff --git a/lib/java/src/org/apache/thrift/TNonblockingMultiFetchStats.java b/lib/java/src/org/apache/thrift/TNonblockingMultiFetchStats.java new file mode 100755 index 00000000..90b86208 --- /dev/null +++ b/lib/java/src/org/apache/thrift/TNonblockingMultiFetchStats.java @@ -0,0 +1,80 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.thrift; + +/** + * This class keeps track of statistics for TNonblockinMultiFetchClient. + */ +public class TNonblockingMultiFetchStats { + private int numTotalServers; + private int numReadCompletedServers; + private int numConnectErrorServers; + private int totalRecvBufBytes; + private int maxResponseBytes; + private int numOverflowedRecvBuf; + private int numInvalidFrameSize; + // time from the beginning of fetch() function to the reading finish + // time of the last socket (in milli-second) + private long readTime; + + public TNonblockingMultiFetchStats() { + clear(); + } + + public void clear() { + numTotalServers = 0; + numReadCompletedServers = 0; + numConnectErrorServers = 0; + totalRecvBufBytes = 0; + maxResponseBytes = 0; + numOverflowedRecvBuf = 0; + numInvalidFrameSize = 0; + readTime = 0; + } + + public String toString() { + String stats = String.format("numTotalServers=%d, " + + "numReadCompletedServers=%d, numConnectErrorServers=%d, " + + "numUnresponsiveServers=%d, totalRecvBufBytes=%fM, " + + "maxResponseBytes=%d, numOverflowedRecvBuf=%d, " + + "numInvalidFrameSize=%d, readTime=%dms", + numTotalServers, numReadCompletedServers, numConnectErrorServers, + (numTotalServers-numReadCompletedServers-numConnectErrorServers), + totalRecvBufBytes/1024.0/1024, maxResponseBytes, numOverflowedRecvBuf, + numInvalidFrameSize, readTime); + return stats; + } + + public void setNumTotalServers(int val) { numTotalServers = val; } + public void setMaxResponseBytes(int val) { maxResponseBytes = val; } + public void setReadTime(long val) { readTime = val; } + public void incNumReadCompletedServers() { numReadCompletedServers++; } + public void incNumConnectErrorServers() { numConnectErrorServers++; } + public void incNumOverflowedRecvBuf() { numOverflowedRecvBuf++; } + public void incTotalRecvBufBytes(int val) { totalRecvBufBytes += val; } + public void incNumInvalidFrameSize() { numInvalidFrameSize++; } + + public int getMaxResponseBytes() { return maxResponseBytes; } + public int getNumReadCompletedServers() { return numReadCompletedServers; } + public int getNumConnectErrorServers() { return numConnectErrorServers; } + public int getNumTotalServers() { return numTotalServers; } + public int getNumOverflowedRecvBuf() { return numOverflowedRecvBuf;} + public int getTotalRecvBufBytes() { return totalRecvBufBytes;} + public int getNumInvalidFrameSize() { return numInvalidFrameSize; } + public long getReadTime() { return readTime; } +} -- 2.17.1