From: Bryan Duxbury Date: Tue, 20 Sep 2011 22:53:31 +0000 (+0000) Subject: THRIFT-1339. java: Extend Tuple Protocol to TUnions X-Git-Tag: 0.8.0~78 X-Git-Url: https://source.supwisdom.com/gerrit/gitweb?a=commitdiff_plain;h=18784d7ccc323b960a301109c926bffc8616cd33;p=common%2Fthrift.git THRIFT-1339. java: Extend Tuple Protocol to TUnions This patch implements TupleProtocol (and general Scheme support) to TUnion descendants. Patch: Armaan Sarkar git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1173418 13f79535-47bb-0310-9956-ffa450edef68 --- diff --git a/compiler/cpp/src/generate/t_java_generator.cc b/compiler/cpp/src/generate/t_java_generator.cc index db33e934..85090b83 100644 --- a/compiler/cpp/src/generate/t_java_generator.cc +++ b/compiler/cpp/src/generate/t_java_generator.cc @@ -142,8 +142,10 @@ public: void generate_union_is_set_methods(ofstream& out, t_struct* tstruct); void generate_union_abstract_methods(ofstream& out, t_struct* tstruct); void generate_check_type(ofstream& out, t_struct* tstruct); - void generate_read_value(ofstream& out, t_struct* tstruct); - void generate_write_value(ofstream& out, t_struct* tstruct); + void generate_standard_scheme_read_value(ofstream& out, t_struct* tstruct); + void generate_standard_scheme_write_value(ofstream& out, t_struct* tstruct); + void generate_tuple_scheme_read_value(ofstream& out, t_struct* tstruct); + void generate_tuple_scheme_write_value(ofstream& out, t_struct* tstruct); void generate_get_field_desc(ofstream& out, t_struct* tstruct); void generate_get_struct_desc(ofstream& out, t_struct* tstruct); void generate_get_field_name(ofstream& out, t_struct* tstruct); @@ -906,9 +908,13 @@ void t_java_generator::generate_union_is_set_methods(ofstream& out, t_struct* ts void t_java_generator::generate_union_abstract_methods(ofstream& out, t_struct* tstruct) { generate_check_type(out, tstruct); out << endl; - generate_read_value(out, tstruct); + generate_standard_scheme_read_value(out, tstruct); out << endl; - generate_write_value(out, tstruct); + generate_standard_scheme_write_value(out, tstruct); + out << endl; + generate_tuple_scheme_read_value(out, tstruct); + out << endl; + generate_tuple_scheme_write_value(out, tstruct); out << endl; generate_get_field_desc(out, tstruct); out << endl; @@ -954,9 +960,9 @@ void t_java_generator::generate_check_type(ofstream& out, t_struct* tstruct) { indent(out) << "}" << endl; } -void t_java_generator::generate_read_value(ofstream& out, t_struct* tstruct) { +void t_java_generator::generate_standard_scheme_read_value(ofstream& out, t_struct* tstruct) { indent(out) << "@Override" << endl; - indent(out) << "protected Object readValue(org.apache.thrift.protocol.TProtocol iprot, org.apache.thrift.protocol.TField field) throws org.apache.thrift.TException {" << endl; + indent(out) << "protected Object standardSchemeReadValue(org.apache.thrift.protocol.TProtocol iprot, org.apache.thrift.protocol.TField field) throws org.apache.thrift.TException {" << endl; indent_up(); @@ -995,8 +1001,7 @@ void t_java_generator::generate_read_value(ofstream& out, t_struct* tstruct) { indent_down(); indent(out) << "} else {" << endl; - indent_up(); - indent(out) << "org.apache.thrift.protocol.TProtocolUtil.skip(iprot, field.type);" << endl; + indent_up(); indent(out) << "return null;" << endl; indent_down(); indent(out) << "}" << endl; @@ -1005,9 +1010,9 @@ void t_java_generator::generate_read_value(ofstream& out, t_struct* tstruct) { indent(out) << "}" << endl; } -void t_java_generator::generate_write_value(ofstream& out, t_struct* tstruct) { +void t_java_generator::generate_standard_scheme_write_value(ofstream& out, t_struct* tstruct) { indent(out) << "@Override" << endl; - indent(out) << "protected void writeValue(org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException {" << endl; + indent(out) << "protected void standardSchemeWriteValue(org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException {" << endl; indent_up(); @@ -1040,6 +1045,83 @@ void t_java_generator::generate_write_value(ofstream& out, t_struct* tstruct) { indent(out) << "}" << endl; } +void t_java_generator::generate_tuple_scheme_read_value(ofstream& out, t_struct* tstruct) { + indent(out) << "@Override" << endl; + indent(out) << "protected Object tupleSchemeReadValue(org.apache.thrift.protocol.TProtocol iprot, short fieldID) throws org.apache.thrift.TException {" << endl; + + indent_up(); + + indent(out) << "_Fields setField = _Fields.findByThriftId(fieldID);" << endl; + indent(out) << "if (setField != null) {" << endl; + indent_up(); + indent(out) << "switch (setField) {" << endl; + indent_up(); + + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + t_field* field = (*m_iter); + + indent(out) << "case " << constant_name(field->get_name()) << ":" << endl; + indent_up(); + indent(out) << type_name(field->get_type(), true, false) << " " << field->get_name() << ";" << endl; + generate_deserialize_field(out, field, ""); + indent(out) << "return " << field->get_name() << ";" << endl; + indent_down(); + } + + indent(out) << "default:" << endl; + indent(out) << " throw new IllegalStateException(\"setField wasn't null, but didn't match any of the case statements!\");" << endl; + + indent_down(); + indent(out) << "}" << endl; + + indent_down(); + indent(out) << "} else {" << endl; + indent_up(); + indent(out) << "return null;" << endl; + indent_down(); + indent(out) << "}" << endl; + indent_down(); + indent(out) << "}" << endl; +} + +void t_java_generator::generate_tuple_scheme_write_value(ofstream& out, t_struct* tstruct) { + indent(out) << "@Override" << endl; + indent(out) << "protected void tupleSchemeWriteValue(org.apache.thrift.protocol.TProtocol oprot) throws org.apache.thrift.TException {" << endl; + + indent_up(); + + indent(out) << "switch (setField_) {" << endl; + indent_up(); + + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + t_field* field = (*m_iter); + + indent(out) << "case " << constant_name(field->get_name()) << ":" << endl; + indent_up(); + indent(out) << type_name(field->get_type(), true, false) << " " << field->get_name() + << " = (" << type_name(field->get_type(), true, false) << ")value_;" << endl; + generate_serialize_field(out, field, ""); + indent(out) << "return;" << endl; + indent_down(); + } + + indent(out) << "default:" << endl; + indent(out) << " throw new IllegalStateException(\"Cannot write union with unknown field \" + setField_);" << endl; + + indent_down(); + indent(out) << "}" << endl; + + indent_down(); + + indent(out) << "}" << endl; +} + void t_java_generator::generate_get_field_desc(ofstream& out, t_struct* tstruct) { indent(out) << "@Override" << endl; indent(out) << "protected org.apache.thrift.protocol.TField getFieldDesc(_Fields setField) {" << endl; diff --git a/lib/java/src/org/apache/thrift/TUnion.java b/lib/java/src/org/apache/thrift/TUnion.java index 240163f2..0173f9b9 100644 --- a/lib/java/src/org/apache/thrift/TUnion.java +++ b/lib/java/src/org/apache/thrift/TUnion.java @@ -25,10 +25,15 @@ import java.util.Map; import java.util.Set; import java.nio.ByteBuffer; +import org.apache.thrift.TUnion.TUnionStandardScheme; import org.apache.thrift.protocol.TField; import org.apache.thrift.protocol.TProtocol; import org.apache.thrift.protocol.TProtocolException; import org.apache.thrift.protocol.TStruct; +import org.apache.thrift.scheme.IScheme; +import org.apache.thrift.scheme.SchemeFactory; +import org.apache.thrift.scheme.StandardScheme; +import org.apache.thrift.scheme.TupleScheme; public abstract class TUnion, F extends TFieldIdEnum> implements TBase { @@ -39,6 +44,12 @@ public abstract class TUnion, F extends TFieldIdEnum> impl setField_ = null; value_ = null; } + + private static final Map, SchemeFactory> schemes = new HashMap, SchemeFactory>(); + static { + schemes.put(StandardScheme.class, new TUnionStandardSchemeFactory()); + schemes.put(TupleScheme.class, new TUnionTupleSchemeFactory()); + } protected TUnion(F setField, Object value) { setFieldValue(setField, value); @@ -125,24 +136,7 @@ public abstract class TUnion, F extends TFieldIdEnum> impl } public void read(TProtocol iprot) throws TException { - setField_ = null; - value_ = null; - - iprot.readStructBegin(); - - TField field = iprot.readFieldBegin(); - - value_ = readValue(iprot, field); - if (value_ != null) { - setField_ = enumForId(field.id); - } - - iprot.readFieldEnd(); - // this is so that we will eat the stop byte. we could put a check here to - // make sure that it actually *is* the stop byte, but it's faster to do it - // this way. - iprot.readFieldBegin(); - iprot.readStructEnd(); + schemes.get(iprot.getScheme()).getScheme().read(iprot, this); } public void setFieldValue(F fieldId, Object value) { @@ -156,15 +150,7 @@ public abstract class TUnion, F extends TFieldIdEnum> impl } public void write(TProtocol oprot) throws TException { - if (getSetField() == null || getFieldValue() == null) { - throw new TProtocolException("Cannot write a TUnion with no set value!"); - } - oprot.writeStructBegin(getStructDesc()); - oprot.writeFieldBegin(getFieldDesc(setField_)); - writeValue(oprot); - oprot.writeFieldEnd(); - oprot.writeFieldStop(); - oprot.writeStructEnd(); + schemes.get(oprot.getScheme()).getScheme().write(oprot, this); } /** @@ -181,9 +167,11 @@ public abstract class TUnion, F extends TFieldIdEnum> impl * @param field * @return read Object based on the field header, as specified by the argument. */ - protected abstract Object readValue(TProtocol iprot, TField field) throws TException; - - protected abstract void writeValue(TProtocol oprot) throws TException; + protected abstract Object standardSchemeReadValue(TProtocol iprot, TField field) throws TException; + protected abstract void standardSchemeWriteValue(TProtocol oprot) throws TException; + + protected abstract Object tupleSchemeReadValue(TProtocol iprot, short fieldID) throws TException; + protected abstract void tupleSchemeWriteValue(TProtocol oprot) throws TException; protected abstract TStruct getStructDesc(); @@ -216,4 +204,77 @@ public abstract class TUnion, F extends TFieldIdEnum> impl this.setField_ = null; this.value_ = null; } + + private static class TUnionStandardSchemeFactory implements SchemeFactory { + public TUnionStandardScheme getScheme() { + return new TUnionStandardScheme(); + } + } + + public static class TUnionStandardScheme extends StandardScheme { + + @Override + public void read(TProtocol iprot, TUnion struct) throws TException { + struct.setField_ = null; + struct.value_ = null; + + iprot.readStructBegin(); + + TField field = iprot.readFieldBegin(); + + struct.value_ = struct.standardSchemeReadValue(iprot, field); + if (struct.value_ != null) { + struct.setField_ = struct.enumForId(field.id); + } + + iprot.readFieldEnd(); + // this is so that we will eat the stop byte. we could put a check here to + // make sure that it actually *is* the stop byte, but it's faster to do it + // this way. + iprot.readFieldBegin(); + iprot.readStructEnd(); + } + + @Override + public void write(TProtocol oprot, TUnion struct) throws TException { + if (struct.getSetField() == null || struct.getFieldValue() == null) { + throw new TProtocolException("Cannot write a TUnion with no set value!"); + } + oprot.writeStructBegin(struct.getStructDesc()); + oprot.writeFieldBegin(struct.getFieldDesc(struct.setField_)); + struct.standardSchemeWriteValue(oprot); + oprot.writeFieldEnd(); + oprot.writeFieldStop(); + oprot.writeStructEnd(); + } + } + + private static class TUnionTupleSchemeFactory implements SchemeFactory { + public TUnionStandardScheme getScheme() { + return new TUnionStandardScheme(); + } + } + + public static class TUnionTupleScheme extends TupleScheme { + + @Override + public void read(TProtocol iprot, TUnion struct) throws TException { + struct.setField_ = null; + struct.value_ = null; + short fieldID = iprot.readI16(); + struct.value_ = struct.tupleSchemeReadValue(iprot, fieldID); + if (struct.value_ != null) { + struct.setField_ = struct.enumForId(fieldID); + } + } + + @Override + public void write(TProtocol oprot, TUnion struct) throws TException { + if (struct.getSetField() == null || struct.getFieldValue() == null) { + throw new TProtocolException("Cannot write a TUnion with no set value!"); + } + oprot.writeI16(struct.setField_.getThriftFieldId()); + struct.tupleSchemeWriteValue(oprot); + } + } } diff --git a/lib/java/test/org/apache/thrift/TestTUnion.java b/lib/java/test/org/apache/thrift/TestTUnion.java index e9d9825c..f1e6f0e1 100644 --- a/lib/java/test/org/apache/thrift/TestTUnion.java +++ b/lib/java/test/org/apache/thrift/TestTUnion.java @@ -34,6 +34,7 @@ import junit.framework.TestCase; import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TTupleProtocol; import org.apache.thrift.transport.TMemoryBuffer; import thrift.test.ComparableUnion; @@ -185,6 +186,41 @@ public class TestTUnion extends TestCase { swau.write(proto); new Empty().read(proto); } + + public void testTupleProtocolSerialization () throws Exception { + TestUnion union = new TestUnion(TestUnion._Fields.I32_FIELD, 25); + union.setI32_set(Collections.singleton(42)); + + TMemoryBuffer buf = new TMemoryBuffer(0); + TProtocol proto = new TTupleProtocol(buf); + + union.write(proto); + + TestUnion u2 = new TestUnion(); + + u2.read(proto); + + assertEquals(u2, union); + + StructWithAUnion swau = new StructWithAUnion(u2); + + buf = new TMemoryBuffer(0); + proto = new TBinaryProtocol(buf); + + swau.write(proto); + + StructWithAUnion swau2 = new StructWithAUnion(); + assertFalse(swau2.equals(swau)); + swau2.read(proto); + assertEquals(swau2, swau); + + // this should NOT throw an exception. + buf = new TMemoryBuffer(0); + proto = new TTupleProtocol(buf); + + swau.write(proto); + new Empty().read(proto); + } public void testSkip() throws Exception { TestUnion tu = TestUnion.string_field("string");