THRIFT-710. java: TBinaryProtocol should access buffers directly when possible
authorBryan Duxbury <bryanduxbury@apache.org>
Tue, 2 Mar 2010 18:49:02 +0000 (18:49 +0000)
committerBryan Duxbury <bryanduxbury@apache.org>
Tue, 2 Mar 2010 18:49:02 +0000 (18:49 +0000)
This patch makes TBinaryProtocol use direct buffer access in the relevant methods. Performance testing indicates as much as 2x speed boost, though your mileage may vary.

git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@918147 13f79535-47bb-0310-9956-ffa450edef68

lib/java/build.xml
lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java
lib/java/test/org/apache/thrift/test/ProtocolTestBase.java [new file with mode: 0644]
lib/java/test/org/apache/thrift/test/SerializationBenchmark.java
lib/java/test/org/apache/thrift/test/TBinaryProtocolTest.java [new file with mode: 0644]
lib/java/test/org/apache/thrift/test/TCompactProtocolTest.java

index afe5afb..e8e5d78 100644 (file)
       classpathref="test.classpath" failonerror="true" />
     <java classname="org.apache.thrift.test.TCompactProtocolTest"
       classpathref="test.classpath" failonerror="true" />
+    <java classname="org.apache.thrift.test.TBinaryProtocolTest"
+      classpathref="test.classpath" failonerror="true" />
     <java classname="org.apache.thrift.test.IdentityTest"
       classpathref="test.classpath" failonerror="true" />
     <java classname="org.apache.thrift.test.EqualityTest"
index e9bd8b7..83c85e1 100644 (file)
@@ -244,41 +244,75 @@ public class TBinaryProtocol extends TProtocol {
 
   private byte[] bin = new byte[1];
   public byte readByte() throws TException {
+    if (trans_.getBytesRemainingInBuffer() >= 1) {
+      byte b = trans_.getBuffer()[trans_.getBufferPosition()];
+      trans_.consumeBuffer(1);
+      return b;
+    }
     readAll(bin, 0, 1);
     return bin[0];
   }
 
   private byte[] i16rd = new byte[2];
   public short readI16() throws TException {
-    readAll(i16rd, 0, 2);
+    byte[] buf = i16rd;
+    int off = 0;
+
+    if (trans_.getBytesRemainingInBuffer() >= 2) {
+      buf = trans_.getBuffer();
+      off = trans_.getBufferPosition();
+      trans_.consumeBuffer(2);
+    } else {
+      readAll(i16rd, 0, 2);
+    }
+
     return
       (short)
-      (((i16rd[0] & 0xff) << 8) |
-       ((i16rd[1] & 0xff)));
+      (((buf[off] & 0xff) << 8) |
+       ((buf[off+1] & 0xff)));
   }
 
   private byte[] i32rd = new byte[4];
   public int readI32() throws TException {
-    readAll(i32rd, 0, 4);
+    byte[] buf = i32rd;
+    int off = 0;
+
+    if (trans_.getBytesRemainingInBuffer() >= 4) {
+      buf = trans_.getBuffer();
+      off = trans_.getBufferPosition();
+      trans_.consumeBuffer(4);
+    } else {
+      readAll(i32rd, 0, 4);
+    }
     return
-      ((i32rd[0] & 0xff) << 24) |
-      ((i32rd[1] & 0xff) << 16) |
-      ((i32rd[2] & 0xff) <<  8) |
-      ((i32rd[3] & 0xff));
+      ((buf[off] & 0xff) << 24) |
+      ((buf[off+1] & 0xff) << 16) |
+      ((buf[off+2] & 0xff) <<  8) |
+      ((buf[off+3] & 0xff));
   }
 
   private byte[] i64rd = new byte[8];
   public long readI64() throws TException {
-    readAll(i64rd, 0, 8);
+    byte[] buf = i64rd;
+    int off = 0;
+
+    if (trans_.getBytesRemainingInBuffer() >= 8) {
+      buf = trans_.getBuffer();
+      off = trans_.getBufferPosition();
+      trans_.consumeBuffer(8);
+    } else {
+      readAll(i64rd, 0, 8);
+    }
+
     return
-      ((long)(i64rd[0] & 0xff) << 56) |
-      ((long)(i64rd[1] & 0xff) << 48) |
-      ((long)(i64rd[2] & 0xff) << 40) |
-      ((long)(i64rd[3] & 0xff) << 32) |
-      ((long)(i64rd[4] & 0xff) << 24) |
-      ((long)(i64rd[5] & 0xff) << 16) |
-      ((long)(i64rd[6] & 0xff) <<  8) |
-      ((long)(i64rd[7] & 0xff));
+      ((long)(buf[off]   & 0xff) << 56) |
+      ((long)(buf[off+1] & 0xff) << 48) |
+      ((long)(buf[off+2] & 0xff) << 40) |
+      ((long)(buf[off+3] & 0xff) << 32) |
+      ((long)(buf[off+4] & 0xff) << 24) |
+      ((long)(buf[off+5] & 0xff) << 16) |
+      ((long)(buf[off+6] & 0xff) <<  8) |
+      ((long)(buf[off+7] & 0xff));
   }
 
   public double readDouble() throws TException {
@@ -287,6 +321,17 @@ public class TBinaryProtocol extends TProtocol {
 
   public String readString() throws TException {
     int size = readI32();
+
+    if (trans_.getBytesRemainingInBuffer() >= size) {
+      try {
+        String s = new String(trans_.getBuffer(), trans_.getBufferPosition(), size, "UTF-8");
+        trans_.consumeBuffer(size);
+        return s;
+      } catch (UnsupportedEncodingException e) {
+        throw new TException("JVM DOES NOT SUPPORT UTF-8");
+      }
+    }
+
     return readStringBody(size);
   }
 
diff --git a/lib/java/test/org/apache/thrift/test/ProtocolTestBase.java b/lib/java/test/org/apache/thrift/test/ProtocolTestBase.java
new file mode 100644 (file)
index 0000000..205f4fe
--- /dev/null
@@ -0,0 +1,416 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.thrift.test;
+
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.thrift.TBase;
+import org.apache.thrift.TDeserializer;
+import org.apache.thrift.TException;
+import org.apache.thrift.TSerializer;
+import org.apache.thrift.protocol.TBinaryProtocol;
+import org.apache.thrift.protocol.TField;
+import org.apache.thrift.protocol.TMessage;
+import org.apache.thrift.protocol.TMessageType;
+import org.apache.thrift.protocol.TProtocol;
+import org.apache.thrift.protocol.TProtocolFactory;
+import org.apache.thrift.protocol.TStruct;
+import org.apache.thrift.protocol.TType;
+import org.apache.thrift.transport.TMemoryBuffer;
+
+import thrift.test.CompactProtoTestStruct;
+import thrift.test.HolyMoley;
+import thrift.test.Nesting;
+import thrift.test.OneOfEach;
+import thrift.test.Srv;
+
+public abstract class ProtocolTestBase {
+  
+  protected abstract TProtocolFactory getFactory();
+
+  public void main() throws Exception {
+    testNakedByte();
+    for (int i = 0; i < 128; i++) {
+      testByteField((byte)i);
+      testByteField((byte)-i);
+    }
+    
+    for (int s : Arrays.asList(0, 1, 7, 150, 15000, 0x7fff, -1, -7, -150, -15000, -0x7fff)) {
+      testNakedI16((short)s);
+      testI16Field((short)s);
+    }
+
+    for (int i : Arrays.asList(0, 1, 7, 150, 15000, 31337, 0xffff, 0xffffff, -1, -7, -150, -15000, -0xffff, -0xffffff)) {
+      testNakedI32(i);
+      testI32Field(i);
+    }
+
+    testNakedI64(0);
+    testI64Field(0);
+    for (int i = 0; i < 62; i++) {
+      testNakedI64(1L << i);
+      testNakedI64(-(1L << i));
+      testI64Field(1L << i);
+      testI64Field(-(1L << i));
+    }
+
+    testDouble();
+
+    for (String s : Arrays.asList("", "short", "borderlinetiny", "a bit longer than the smallest possible")) {
+      testNakedString(s);
+      testStringField(s);
+    }
+
+    for (byte[] b : Arrays.asList(new byte[0], new byte[]{0,1,2,3,4,5,6,7,8,9,10}, new byte[]{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14}, new byte[128])) {
+      testNakedBinary(b);
+      testBinaryField(b);
+    }
+
+    testSerialization(OneOfEach.class, Fixtures.oneOfEach);
+    testSerialization(Nesting.class, Fixtures.nesting);
+    testSerialization(HolyMoley.class, Fixtures.holyMoley);
+    testSerialization(CompactProtoTestStruct.class, Fixtures.compactProtoTestStruct);
+
+    testMessage();
+
+    testServerRequest();
+
+    testTDeserializer();
+  }
+
+  public void testNakedByte() throws Exception {
+    TMemoryBuffer buf = new TMemoryBuffer(0);
+    TProtocol proto = getFactory().getProtocol(buf);
+    proto.writeByte((byte)123);
+    byte out = proto.readByte();
+    if (out != 123) {
+      throw new RuntimeException("Byte was supposed to be " + (byte)123 + " but was " + out);
+    }
+  }
+
+  public void testByteField(final byte b) throws Exception {
+    testStructField(new StructFieldTestCase(TType.BYTE, (short)15) {
+      public void writeMethod(TProtocol proto) throws TException {
+        proto.writeByte(b);
+      }
+      
+      public void readMethod(TProtocol proto) throws TException {
+        byte result = proto.readByte();
+        if (result != b) {
+          throw new RuntimeException("Byte was supposed to be " + (byte)b + " but was " + result);
+        }
+      }
+    });
+  }
+
+  public void testNakedI16(short n) throws Exception {
+    TMemoryBuffer buf = new TMemoryBuffer(0);
+    TProtocol proto = getFactory().getProtocol(buf);
+    proto.writeI16(n);
+    // System.out.println(buf.inspect());
+    int out = proto.readI16();
+    if (out != n) {
+      throw new RuntimeException("I16 was supposed to be " + n + " but was " + out);
+    }
+  }
+
+  public void testI16Field(final short n) throws Exception {
+    testStructField(new StructFieldTestCase(TType.I16, (short)15) {
+      public void writeMethod(TProtocol proto) throws TException {
+        proto.writeI16(n);
+      }
+
+      public void readMethod(TProtocol proto) throws TException {
+        short result = proto.readI16();
+        if (result != n) {
+          throw new RuntimeException("I16 was supposed to be " + n + " but was " + result);
+        }
+      }
+    });
+  }
+
+  public void testNakedI32(int n) throws Exception {
+    TMemoryBuffer buf = new TMemoryBuffer(0);
+    TProtocol proto = getFactory().getProtocol(buf);
+    proto.writeI32(n);
+    // System.out.println(buf.inspect());
+    int out = proto.readI32();
+    if (out != n) {
+      throw new RuntimeException("I32 was supposed to be " + n + " but was " + out);
+    }
+  }
+
+  public void testI32Field(final int n) throws Exception {
+    testStructField(new StructFieldTestCase(TType.I32, (short)15) {
+      public void writeMethod(TProtocol proto) throws TException {
+        proto.writeI32(n);
+      }
+
+      public void readMethod(TProtocol proto) throws TException {
+        int result = proto.readI32();
+        if (result != n) {
+          throw new RuntimeException("I32 was supposed to be " + n + " but was " + result);
+        }
+      }
+    });
+  }
+
+  public void testNakedI64(long n) throws Exception {
+    TMemoryBuffer buf = new TMemoryBuffer(0);
+    TProtocol proto = getFactory().getProtocol(buf);
+    proto.writeI64(n);
+    // System.out.println(buf.inspect());
+    long out = proto.readI64();
+    if (out != n) {
+      throw new RuntimeException("I64 was supposed to be " + n + " but was " + out);
+    }
+  }
+
+  public void testI64Field(final long n) throws Exception {
+    testStructField(new StructFieldTestCase(TType.I64, (short)15) {
+      public void writeMethod(TProtocol proto) throws TException {
+        proto.writeI64(n);
+      }
+
+      public void readMethod(TProtocol proto) throws TException {
+        long result = proto.readI64();
+        if (result != n) {
+          throw new RuntimeException("I64 was supposed to be " + n + " but was " + result);
+        }
+      }
+    });
+  }
+
+  public void testDouble() throws Exception {
+    TMemoryBuffer buf = new TMemoryBuffer(1000);
+    TProtocol proto = getFactory().getProtocol(buf);
+    proto.writeDouble(123.456);
+    double out = proto.readDouble();
+    if (out != 123.456) {
+      throw new RuntimeException("Double was supposed to be " + 123.456 + " but was " + out);
+    }
+  }
+
+  public void testNakedString(String str) throws Exception {
+    TMemoryBuffer buf = new TMemoryBuffer(0);
+    TProtocol proto = getFactory().getProtocol(buf);
+    proto.writeString(str);
+    // System.out.println(buf.inspect());
+    String out = proto.readString();
+    if (!str.equals(out)) {
+      throw new RuntimeException("String was supposed to be '" + str + "' but was '" + out + "'");
+    }
+  }
+  
+  public void testStringField(final String str) throws Exception {
+    testStructField(new StructFieldTestCase(TType.STRING, (short)15) {
+      public void writeMethod(TProtocol proto) throws TException {
+        proto.writeString(str);
+      }
+      
+      public void readMethod(TProtocol proto) throws TException {
+        String result = proto.readString();
+        if (!result.equals(str)) {
+          throw new RuntimeException("String was supposed to be " + str + " but was " + result);
+        }
+      }
+    });
+  }
+
+  public void testNakedBinary(byte[] data) throws Exception {
+    TMemoryBuffer buf = new TMemoryBuffer(0);
+    TProtocol proto = getFactory().getProtocol(buf);
+    proto.writeBinary(data);
+    // System.out.println(buf.inspect());
+    byte[] out = proto.readBinary();
+    if (!Arrays.equals(data, out)) {
+      throw new RuntimeException("Binary was supposed to be '" + data + "' but was '" + out + "'");
+    }
+  }
+
+  public void testBinaryField(final byte[] data) throws Exception {
+    testStructField(new StructFieldTestCase(TType.STRING, (short)15) {
+      public void writeMethod(TProtocol proto) throws TException {
+        proto.writeBinary(data);
+      }
+      
+      public void readMethod(TProtocol proto) throws TException {
+        byte[] result = proto.readBinary();
+        if (!Arrays.equals(data, result)) {
+          throw new RuntimeException("Binary was supposed to be '" + bytesToString(data) + "' but was '" + bytesToString(result) + "'");
+        }
+      }
+    });
+    
+  }
+
+  public <T extends TBase> void testSerialization(Class<T> klass, T obj) throws Exception {
+    TMemoryBuffer buf = new TMemoryBuffer(0);
+    TBinaryProtocol binproto = new TBinaryProtocol(buf);
+
+    try {
+      obj.write(binproto);
+      // System.out.println("Size in binary protocol: " + buf.length());
+
+      buf = new TMemoryBuffer(0);
+      TProtocol proto = getFactory().getProtocol(buf);
+
+      obj.write(proto);
+      System.out.println("Size in " +  proto.getClass().getSimpleName() + ": " + buf.length());
+      // System.out.println(buf.inspect());
+
+      T objRead = klass.newInstance();
+      objRead.read(proto);
+      if (!obj.equals(objRead)) {
+        System.out.println("Expected: " + obj.toString());
+        System.out.println("Actual: " + objRead.toString());
+        // System.out.println(buf.inspect());
+        throw new RuntimeException("Objects didn't match!");
+      }
+    } catch (Exception e) {
+      System.out.println(buf.inspect());
+      throw e;
+    }
+  }
+
+  public void testMessage() throws Exception {
+    List<TMessage> msgs = Arrays.asList(new TMessage[]{
+      new TMessage("short message name", TMessageType.CALL, 0),
+      new TMessage("1", TMessageType.REPLY, 12345),
+      new TMessage("loooooooooooooooooooooooooooooooooong", TMessageType.EXCEPTION, 1 << 16),
+      new TMessage("Janky", TMessageType.CALL, 0),
+    });
+
+    for (TMessage msg : msgs) {
+      TMemoryBuffer buf = new TMemoryBuffer(0);
+      TProtocol proto = getFactory().getProtocol(buf);
+      TMessage output = null;
+
+      proto.writeMessageBegin(msg);
+      proto.writeMessageEnd();
+
+      output = proto.readMessageBegin();
+
+      if (!msg.equals(output)) {
+        throw new RuntimeException("Message was supposed to be " + msg + " but was " + output);
+      }
+    }
+  }
+
+  public void testServerRequest() throws Exception {
+    Srv.Iface handler = new Srv.Iface() {
+      public int Janky(int i32arg) throws TException {
+        return i32arg * 2;
+      }
+
+      public int primitiveMethod() throws TException {
+        return 0;
+      }
+
+      public CompactProtoTestStruct structMethod() throws TException {
+        return null;
+      }
+
+      public void voidMethod() throws TException {
+      }
+
+      public void methodWithDefaultArgs(int something) throws TException {
+      }
+    };
+
+    Srv.Processor testProcessor = new Srv.Processor(handler);
+
+    TMemoryBuffer clientOutTrans = new TMemoryBuffer(0);
+    TProtocol clientOutProto = getFactory().getProtocol(clientOutTrans);
+    TMemoryBuffer clientInTrans = new TMemoryBuffer(0);
+    TProtocol clientInProto = getFactory().getProtocol(clientInTrans);
+
+    Srv.Client testClient = new Srv.Client(clientInProto, clientOutProto);
+
+    testClient.send_Janky(1);
+    // System.out.println(clientOutTrans.inspect());
+    testProcessor.process(clientOutProto, clientInProto);
+    // System.out.println(clientInTrans.inspect());
+    int result = testClient.recv_Janky();
+    if (result != 2) {
+      throw new RuntimeException("Got an unexpected result: " + result);
+    }
+  }
+
+  private void testTDeserializer() throws TException {
+    TSerializer ser = new TSerializer(getFactory());
+    byte[] bytes = ser.serialize(Fixtures.compactProtoTestStruct);
+
+    TDeserializer deser = new TDeserializer(getFactory());
+    CompactProtoTestStruct cpts = new CompactProtoTestStruct();
+    deser.deserialize(cpts, bytes);
+
+    if (!Fixtures.compactProtoTestStruct.equals(cpts)) {
+      throw new RuntimeException(Fixtures.compactProtoTestStruct + " and " + cpts + " do not match!");
+    }
+  }
+
+  //
+  // Helper methods
+  //
+
+  private static String bytesToString(byte[] bytes) {
+    String s = "";
+    for (int i = 0; i < bytes.length; i++) {
+      s += Integer.toHexString((int)bytes[i]) + " ";
+    }
+    return s;
+  }
+
+  private void testStructField(StructFieldTestCase testCase) throws Exception {
+    TMemoryBuffer buf = new TMemoryBuffer(0);
+    TProtocol proto = getFactory().getProtocol(buf);
+
+    TField field = new TField("test_field", testCase.type_, testCase.id_);
+    proto.writeStructBegin(new TStruct("test_struct"));
+    proto.writeFieldBegin(field);
+    testCase.writeMethod(proto);
+    proto.writeFieldEnd();
+    proto.writeStructEnd();
+
+    // System.out.println(buf.inspect());
+
+    proto.readStructBegin();
+    TField readField = proto.readFieldBegin();
+    // TODO: verify the field is as expected
+    if (!field.equals(readField)) {
+      throw new RuntimeException("Expected " + field + " but got " + readField);
+    }
+    testCase.readMethod(proto);
+    proto.readStructEnd();
+  }
+
+  public static abstract class StructFieldTestCase {
+    byte type_;
+    short id_;
+    public StructFieldTestCase(byte type, short id) {
+      type_ = type;
+      id_ = id;
+    }
+
+    public abstract void writeMethod(TProtocol proto) throws TException;
+    public abstract void readMethod(TProtocol proto) throws TException;
+  }
+}
index 92c96d3..9ba7102 100644 (file)
 
 package org.apache.thrift.test;
 
-import java.io.ByteArrayInputStream;
+import org.apache.thrift.TBase;
+import org.apache.thrift.protocol.TBinaryProtocol;
+import org.apache.thrift.protocol.TProtocol;
+import org.apache.thrift.protocol.TProtocolFactory;
+import org.apache.thrift.transport.TMemoryBuffer;
+import org.apache.thrift.transport.TMemoryInputTransport;
+import org.apache.thrift.transport.TTransport;
+import org.apache.thrift.transport.TTransportException;
 
-import org.apache.thrift.*;
-import org.apache.thrift.protocol.*;
-import org.apache.thrift.transport.*;
-
-import thrift.test.*;
+import thrift.test.OneOfEach;
 
 public class SerializationBenchmark {
   private final static int HOW_MANY = 10000000;
@@ -55,7 +58,7 @@ public class SerializationBenchmark {
     }
     long endTime = System.currentTimeMillis();
     
-    System.out.println("Test time: " + (endTime - startTime) + " ms");
+    System.out.println("Serialization test time: " + (endTime - startTime) + " ms");
   }
   
   public static <T extends TBase> void testDeserialization(TProtocolFactory factory, T object, Class<T> klass) throws Exception {
@@ -63,14 +66,14 @@ public class SerializationBenchmark {
     object.write(factory.getProtocol(buf));
     byte[] serialized = new byte[100*1024];
     buf.read(serialized, 0, 100*1024);
-    
+
     long startTime = System.currentTimeMillis();
     for (int i = 0; i < HOW_MANY; i++) {
       T o2 = klass.newInstance();
-      o2.read(factory.getProtocol(new TIOStreamTransport(new ByteArrayInputStream(serialized))));
+      o2.read(factory.getProtocol(new TMemoryInputTransport(serialized)));
     }
     long endTime = System.currentTimeMillis();
-    
-    System.out.println("Test time: " + (endTime - startTime) + " ms");
+
+    System.out.println("Deserialization test time: " + (endTime - startTime) + " ms");
   }
 }
\ No newline at end of file
diff --git a/lib/java/test/org/apache/thrift/test/TBinaryProtocolTest.java b/lib/java/test/org/apache/thrift/test/TBinaryProtocolTest.java
new file mode 100644 (file)
index 0000000..71839fe
--- /dev/null
@@ -0,0 +1,17 @@
+package org.apache.thrift.test;
+
+import org.apache.thrift.protocol.TBinaryProtocol;
+import org.apache.thrift.protocol.TProtocolFactory;
+
+public class TBinaryProtocolTest extends ProtocolTestBase {
+
+  public static void main(String[] args) throws Exception {
+    new TBinaryProtocolTest().main();
+  }
+  
+  @Override
+  protected TProtocolFactory getFactory() {
+    return new TBinaryProtocol.Factory();
+  }
+
+}
index 86ea57c..1642c42 100755 (executable)
 
 package org.apache.thrift.test;
 
-import java.util.Arrays;
-import java.util.List;
-
-import org.apache.thrift.TBase;
-import org.apache.thrift.TDeserializer;
-import org.apache.thrift.TException;
-import org.apache.thrift.TSerializer;
-import org.apache.thrift.protocol.TBinaryProtocol;
 import org.apache.thrift.protocol.TCompactProtocol;
-import org.apache.thrift.protocol.TField;
-import org.apache.thrift.protocol.TMessage;
-import org.apache.thrift.protocol.TMessageType;
-import org.apache.thrift.protocol.TProtocol;
 import org.apache.thrift.protocol.TProtocolFactory;
-import org.apache.thrift.protocol.TStruct;
-import org.apache.thrift.protocol.TType;
-import org.apache.thrift.transport.TMemoryBuffer;
-
-import thrift.test.CompactProtoTestStruct;
-import thrift.test.HolyMoley;
-import thrift.test.Nesting;
-import thrift.test.OneOfEach;
-import thrift.test.Srv;
-
-public class TCompactProtocolTest {
-
-  static TProtocolFactory factory = new TCompactProtocol.Factory();
-
-  public static void main(String[] args) throws Exception {
-    testNakedByte();
-    for (int i = 0; i < 128; i++) {
-      testByteField((byte)i);
-      testByteField((byte)-i);
-    }
-    
-    testNakedI16((short)0);
-    testNakedI16((short)1);
-    testNakedI16((short)15000);
-    testNakedI16((short)0x7fff);
-    testNakedI16((short)-1);
-    testNakedI16((short)-15000);
-    testNakedI16((short)-0x7fff);
-    
-    testI16Field((short)0);
-    testI16Field((short)1);
-    testI16Field((short)7);
-    testI16Field((short)150);
-    testI16Field((short)15000);
-    testI16Field((short)0x7fff);
-    testI16Field((short)-1);
-    testI16Field((short)-7);
-    testI16Field((short)-150);
-    testI16Field((short)-15000);
-    testI16Field((short)-0x7fff);
-    
-    testNakedI32(0);
-    testNakedI32(1);
-    testNakedI32(15000);
-    testNakedI32(0xffff);
-    testNakedI32(-1);
-    testNakedI32(-15000);
-    testNakedI32(-0xffff);
-    
-    testI32Field(0);
-    testI32Field(1);
-    testI32Field(7);
-    testI32Field(150);
-    testI32Field(15000);
-    testI32Field(31337);
-    testI32Field(0xffff);
-    testI32Field(0xffffff);
-    testI32Field(-1);
-    testI32Field(-7);
-    testI32Field(-150);
-    testI32Field(-15000);
-    testI32Field(-0xffff);
-    testI32Field(-0xffffff);
-    
-    testNakedI64(0);
-    for (int i = 0; i < 62; i++) {
-      testNakedI64(1L << i);
-      testNakedI64(-(1L << i));
-    }
-
-    testI64Field(0);
-    for (int i = 0; i < 62; i++) {
-      testI64Field(1L << i);
-      testI64Field(-(1L << i));
-    }
-
-    testDouble();
-    
-    testNakedString("");
-    testNakedString("short");
-    testNakedString("borderlinetiny");
-    testNakedString("a bit longer than the smallest possible");
-    
-    testStringField("");
-    testStringField("short");
-    testStringField("borderlinetiny");
-    testStringField("a bit longer than the smallest possible");
-    
-    testNakedBinary(new byte[]{});
-    testNakedBinary(new byte[]{0,1,2,3,4,5,6,7,8,9,10});
-    testNakedBinary(new byte[]{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14});
-    testNakedBinary(new byte[128]);
-    
-    testBinaryField(new byte[]{});
-    testBinaryField(new byte[]{0,1,2,3,4,5,6,7,8,9,10});
-    testBinaryField(new byte[]{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14});
-    testBinaryField(new byte[128]);
-    
-    testSerialization(OneOfEach.class, Fixtures.oneOfEach);
-    testSerialization(Nesting.class, Fixtures.nesting);
-    testSerialization(HolyMoley.class, Fixtures.holyMoley);
-    testSerialization(CompactProtoTestStruct.class, Fixtures.compactProtoTestStruct);
-    
-    testMessage();
-    
-    testServerRequest();
-    
-    testTDeserializer();
-  }
-  
-  public static void testNakedByte() throws Exception {
-    TMemoryBuffer buf = new TMemoryBuffer(0);
-    TProtocol proto = factory.getProtocol(buf);
-    proto.writeByte((byte)123);
-    byte out = proto.readByte();
-    if (out != 123) {
-      throw new RuntimeException("Byte was supposed to be " + (byte)123 + " but was " + out);
-    }
-  }
-  
-  public static void testByteField(final byte b) throws Exception {
-    testStructField(new StructFieldTestCase(TType.BYTE, (short)15) {
-      public void writeMethod(TProtocol proto) throws TException {
-        proto.writeByte(b);
-      }
-      
-      public void readMethod(TProtocol proto) throws TException {
-        byte result = proto.readByte();
-        if (result != b) {
-          throw new RuntimeException("Byte was supposed to be " + (byte)b + " but was " + result);
-        }
-      }
-    });
-  }
 
-  public static void testNakedI16(short n) throws Exception {
-    TMemoryBuffer buf = new TMemoryBuffer(0);
-    TProtocol proto = factory.getProtocol(buf);
-    proto.writeI16(n);
-    // System.out.println(buf.inspect());
-    int out = proto.readI16();
-    if (out != n) {
-      throw new RuntimeException("I16 was supposed to be " + n + " but was " + out);
-    }
-  }
-
-  public static void testI16Field(final short n) throws Exception {
-    testStructField(new StructFieldTestCase(TType.I16, (short)15) {
-      public void writeMethod(TProtocol proto) throws TException {
-        proto.writeI16(n);
-      }
-      
-      public void readMethod(TProtocol proto) throws TException {
-        short result = proto.readI16();
-        if (result != n) {
-          throw new RuntimeException("I16 was supposed to be " + n + " but was " + result);
-        }
-      }
-    });
-  }
-  
-  public static void testNakedI32(int n) throws Exception {
-    TMemoryBuffer buf = new TMemoryBuffer(0);
-    TProtocol proto = factory.getProtocol(buf);
-    proto.writeI32(n);
-    // System.out.println(buf.inspect());
-    int out = proto.readI32();
-    if (out != n) {
-      throw new RuntimeException("I32 was supposed to be " + n + " but was " + out);
-    }
-  }
-  
-  public static void testI32Field(final int n) throws Exception {
-    testStructField(new StructFieldTestCase(TType.I32, (short)15) {
-      public void writeMethod(TProtocol proto) throws TException {
-        proto.writeI32(n);
-      }
-      
-      public void readMethod(TProtocol proto) throws TException {
-        int result = proto.readI32();
-        if (result != n) {
-          throw new RuntimeException("I32 was supposed to be " + n + " but was " + result);
-        }
-      }
-    });
-    
-  }
-
-  public static void testNakedI64(long n) throws Exception {
-    TMemoryBuffer buf = new TMemoryBuffer(0);
-    TProtocol proto = factory.getProtocol(buf);
-    proto.writeI64(n);
-    // System.out.println(buf.inspect());
-    long out = proto.readI64();
-    if (out != n) {
-      throw new RuntimeException("I64 was supposed to be " + n + " but was " + out);
-    }
-  }
-  
-  public static void testI64Field(final long n) throws Exception {
-    testStructField(new StructFieldTestCase(TType.I64, (short)15) {
-      public void writeMethod(TProtocol proto) throws TException {
-        proto.writeI64(n);
-      }
-      
-      public void readMethod(TProtocol proto) throws TException {
-        long result = proto.readI64();
-        if (result != n) {
-          throw new RuntimeException("I64 was supposed to be " + n + " but was " + result);
-        }
-      }
-    });
-  }
-    
-  public static void testDouble() throws Exception {
-    TMemoryBuffer buf = new TMemoryBuffer(1000);
-    TProtocol proto = factory.getProtocol(buf);
-    proto.writeDouble(123.456);
-    double out = proto.readDouble();
-    if (out != 123.456) {
-      throw new RuntimeException("Double was supposed to be " + 123.456 + " but was " + out);
-    }
-  }
-    
-  public static void testNakedString(String str) throws Exception {
-    TMemoryBuffer buf = new TMemoryBuffer(0);
-    TProtocol proto = factory.getProtocol(buf);
-    proto.writeString(str);
-    // System.out.println(buf.inspect());
-    String out = proto.readString();
-    if (!str.equals(out)) {
-      throw new RuntimeException("String was supposed to be '" + str + "' but was '" + out + "'");
-    }
-  }
+public class TCompactProtocolTest extends ProtocolTestBase {
   
-  public static void testStringField(final String str) throws Exception {
-    testStructField(new StructFieldTestCase(TType.STRING, (short)15) {
-      public void writeMethod(TProtocol proto) throws TException {
-        proto.writeString(str);
-      }
-      
-      public void readMethod(TProtocol proto) throws TException {
-        String result = proto.readString();
-        if (!result.equals(str)) {
-          throw new RuntimeException("String was supposed to be " + str + " but was " + result);
-        }
-      }
-    });
-  }
-
-  public static void testNakedBinary(byte[] data) throws Exception {
-    TMemoryBuffer buf = new TMemoryBuffer(0);
-    TProtocol proto = factory.getProtocol(buf);
-    proto.writeBinary(data);
-    // System.out.println(buf.inspect());
-    byte[] out = proto.readBinary();
-    if (!Arrays.equals(data, out)) {
-      throw new RuntimeException("Binary was supposed to be '" + data + "' but was '" + out + "'");
-    }
-  }
-
-  public static void testBinaryField(final byte[] data) throws Exception {
-    testStructField(new StructFieldTestCase(TType.STRING, (short)15) {
-      public void writeMethod(TProtocol proto) throws TException {
-        proto.writeBinary(data);
-      }
-      
-      public void readMethod(TProtocol proto) throws TException {
-        byte[] result = proto.readBinary();
-        if (!Arrays.equals(data, result)) {
-          throw new RuntimeException("Binary was supposed to be '" + bytesToString(data) + "' but was '" + bytesToString(result) + "'");
-        }
-      }
-    });
-    
-  }
-
-  public static <T extends TBase> void testSerialization(Class<T> klass, T obj) throws Exception {
-    TMemoryBuffer buf = new TMemoryBuffer(0);
-    TBinaryProtocol binproto = new TBinaryProtocol(buf);
-    
-    try {
-      obj.write(binproto);
-      // System.out.println("Size in binary protocol: " + buf.length());
-    
-      buf = new TMemoryBuffer(0);
-      TProtocol proto = factory.getProtocol(buf);
-    
-      obj.write(proto);
-      System.out.println("Size in compact protocol: " + buf.length());
-      // System.out.println(buf.inspect());
-    
-      T objRead = klass.newInstance();
-      objRead.read(proto);
-      if (!obj.equals(objRead)) {
-        System.out.println("Expected: " + obj.toString());
-        System.out.println("Actual: " + objRead.toString());
-        // System.out.println(buf.inspect());
-        throw new RuntimeException("Objects didn't match!");
-      }
-    } catch (Exception e) {
-      System.out.println(buf.inspect());
-      throw e;
-    }
-  }
-
-  public static void testMessage() throws Exception {
-    List<TMessage> msgs = Arrays.asList(new TMessage[]{
-      new TMessage("short message name", TMessageType.CALL, 0),
-      new TMessage("1", TMessageType.REPLY, 12345),
-      new TMessage("loooooooooooooooooooooooooooooooooong", TMessageType.EXCEPTION, 1 << 16),
-      new TMessage("Janky", TMessageType.CALL, 0),
-    });
-    
-    for (TMessage msg : msgs) {
-      TMemoryBuffer buf = new TMemoryBuffer(0);
-      TProtocol proto = factory.getProtocol(buf);
-      TMessage output = null;
-      
-      proto.writeMessageBegin(msg);
-      proto.writeMessageEnd();
-
-      output = proto.readMessageBegin();
-
-      if (!msg.equals(output)) {
-        throw new RuntimeException("Message was supposed to be " + msg + " but was " + output);
-      }
-    }
-  }
-
-  public static void testServerRequest() throws Exception {
-    Srv.Iface handler = new Srv.Iface() {
-      public int Janky(int i32arg) throws TException {
-        return i32arg * 2;
-      }
-
-      public int primitiveMethod() throws TException {
-        return 0;
-      }
-
-      public CompactProtoTestStruct structMethod() throws TException {
-        return null;
-      }
-
-      public void voidMethod() throws TException {
-      }
-
-      public void methodWithDefaultArgs(int something) throws TException {
-      }
-    };
-    
-    Srv.Processor testProcessor = new Srv.Processor(handler);
-
-    TMemoryBuffer clientOutTrans = new TMemoryBuffer(0);
-    TProtocol clientOutProto = factory.getProtocol(clientOutTrans);
-    TMemoryBuffer clientInTrans = new TMemoryBuffer(0);
-    TProtocol clientInProto = factory.getProtocol(clientInTrans);
-    
-    Srv.Client testClient = new Srv.Client(clientInProto, clientOutProto);
-    
-    testClient.send_Janky(1);
-    // System.out.println(clientOutTrans.inspect());
-    testProcessor.process(clientOutProto, clientInProto);
-    // System.out.println(clientInTrans.inspect());
-    int result = testClient.recv_Janky();
-    if (result != 2) {
-      throw new RuntimeException("Got an unexpected result: " + result);
-    }
-  }
-
-  //
-  // Helper methods
-  //
-  
-  private static String bytesToString(byte[] bytes) {
-    String s = "";
-    for (int i = 0; i < bytes.length; i++) {
-      s += Integer.toHexString((int)bytes[i]) + " ";
-    }
-    return s;
+  public static void main(String[] args) throws Exception {
+    new TCompactProtocolTest().main();
   }
 
-  private static void testStructField(StructFieldTestCase testCase) throws Exception {
-    TMemoryBuffer buf = new TMemoryBuffer(0);
-    TProtocol proto = factory.getProtocol(buf);
-    
-    TField field = new TField("test_field", testCase.type_, testCase.id_);
-    proto.writeStructBegin(new TStruct("test_struct"));
-    proto.writeFieldBegin(field);
-    testCase.writeMethod(proto);
-    proto.writeFieldEnd();
-    proto.writeStructEnd();
-    
-    // System.out.println(buf.inspect());
-
-    proto.readStructBegin();
-    TField readField = proto.readFieldBegin();
-    // TODO: verify the field is as expected
-    if (!field.equals(readField)) {
-      throw new RuntimeException("Expected " + field + " but got " + readField);
-    }
-    testCase.readMethod(proto);
-    proto.readStructEnd();
-  }
-  
-  public static abstract class StructFieldTestCase {
-    byte type_;
-    short id_;
-    public StructFieldTestCase(byte type, short id) {
-      type_ = type;
-      id_ = id;
-    }
-    
-    public abstract void writeMethod(TProtocol proto) throws TException;
-    public abstract void readMethod(TProtocol proto) throws TException;
-  }
   
-  private static void testTDeserializer() throws TException {
-    TSerializer ser = new TSerializer(new TCompactProtocol.Factory());
-    byte[] bytes = ser.serialize(Fixtures.compactProtoTestStruct);
-    
-    TDeserializer deser = new TDeserializer(new TCompactProtocol.Factory());
-    CompactProtoTestStruct cpts = new CompactProtoTestStruct();
-    deser.deserialize(cpts, bytes);
-    
-    if (!Fixtures.compactProtoTestStruct.equals(cpts)) {
-      throw new RuntimeException(Fixtures.compactProtoTestStruct + " and " + cpts + " do not match!");
-    }
+  @Override
+  protected TProtocolFactory getFactory() {
+    return new TCompactProtocol.Factory();
   }
 }
\ No newline at end of file