From: Bryan Duxbury Date: Thu, 12 Nov 2009 20:52:25 +0000 (+0000) Subject: THRIFT-623. java: Use a Java enum to represent field ids in generated structs X-Git-Tag: 0.2.0~8 X-Git-Url: https://source.supwisdom.com/gerrit/gitweb?a=commitdiff_plain;h=aa9fb5dc9de5f3cfbe086e6df8e7c6d3640c272c;p=common%2Fthrift.git THRIFT-623. java: Use a Java enum to represent field ids in generated structs git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@835538 13f79535-47bb-0310-9956-ffa450edef68 --- diff --git a/compiler/cpp/src/generate/t_java_generator.cc b/compiler/cpp/src/generate/t_java_generator.cc index 230d0d04..e3f54496 100644 --- a/compiler/cpp/src/generate/t_java_generator.cc +++ b/compiler/cpp/src/generate/t_java_generator.cc @@ -96,7 +96,6 @@ class t_java_generator : public t_oop_generator { void generate_java_struct_writer(std::ofstream& out, t_struct* tstruct); void generate_java_struct_tostring(std::ofstream& out, t_struct* tstruct); void generate_java_meta_data_map(std::ofstream& out, t_struct* tstruct); - void generate_java_field_name_map(std::ofstream& out, t_struct* tstruct); void generate_field_value_meta_data(std::ofstream& out, t_type* type); std::string get_java_type_string(t_type* type); void generate_reflection_setters(std::ostringstream& out, t_type* type, std::string field_name, std::string cap_name); @@ -309,8 +308,10 @@ string t_java_generator::java_type_imports() { "import java.util.ArrayList;\n" + "import java.util.Map;\n" + "import java.util.HashMap;\n" + + "import java.util.EnumMap;\n" + "import java.util.Set;\n" + "import java.util.HashSet;\n" + + "import java.util.EnumSet;\n" + "import java.util.Collections;\n" + "import java.util.BitSet;\n" + "import java.util.Arrays;\n" + @@ -671,7 +672,7 @@ void t_java_generator::generate_java_union(t_struct* tstruct) { indent(f_struct) << "public " << (is_final ? "final " : "") << "class " << tstruct->get_name() - << " extends TUnion "; + << " extends TUnion<" << tstruct->get_name() << "._Fields> "; if (is_comparable(tstruct)) { f_struct << "implements Comparable<" << type_name(tstruct) << "> "; @@ -690,8 +691,6 @@ void t_java_generator::generate_java_union(t_struct* tstruct) { generate_java_meta_data_map(f_struct, tstruct); - generate_java_field_name_map(f_struct, tstruct); - generate_union_constructor(f_struct, tstruct); f_struct << endl; @@ -722,7 +721,7 @@ void t_java_generator::generate_union_constructor(ofstream& out, t_struct* tstru indent(out) << " super();" << endl; indent(out) << "}" << endl << endl; - indent(out) << "public " << type_name(tstruct) << "(int setField, Object value) {" << endl; + indent(out) << "public " << type_name(tstruct) << "(_Fields setField, Object value) {" << endl; indent(out) << " super(setField, value);" << endl; indent(out) << "}" << endl << endl; @@ -762,7 +761,7 @@ void t_java_generator::generate_union_getters_and_setters(ofstream& out, t_struc generate_java_doc(out, field); indent(out) << "public " << type_name(field->get_type()) << " get" << get_cap_name(field->get_name()) << "() {" << endl; - indent(out) << " if (getSetField() == " << upcase_string(field->get_name()) << ") {" << endl; + indent(out) << " if (getSetField() == _Fields." << constant_name(field->get_name()) << ") {" << endl; indent(out) << " return (" << type_name(field->get_type(), true) << ")getFieldValue();" << endl; indent(out) << " } else {" << endl; indent(out) << " throw new RuntimeException(\"Cannot get field '" << field->get_name() @@ -777,7 +776,7 @@ void t_java_generator::generate_union_getters_and_setters(ofstream& out, t_struc if (type_can_be_null(field->get_type())) { indent(out) << " if (value == null) throw new NullPointerException();" << endl; } - indent(out) << " setField_ = " << upcase_string(field->get_name()) << ";" << endl; + indent(out) << " setField_ = _Fields." << constant_name(field->get_name()) << ";" << endl; indent(out) << " value_ = value;" << endl; indent(out) << "}" << endl; } @@ -793,11 +792,16 @@ void t_java_generator::generate_union_abstract_methods(ofstream& out, t_struct* generate_get_field_desc(out, tstruct); out << endl; generate_get_struct_desc(out, tstruct); + out << endl; + indent(out) << "@Override" << endl; + indent(out) << "protected _Fields enumForId(short id) {" << endl; + indent(out) << " return _Fields.findByThriftIdOrThrow(id);" << endl; + indent(out) << "}" << endl; } void t_java_generator::generate_check_type(ofstream& out, t_struct* tstruct) { indent(out) << "@Override" << endl; - indent(out) << "protected void checkType(short setField, Object value) throws ClassCastException {" << endl; + indent(out) << "protected void checkType(_Fields setField, Object value) throws ClassCastException {" << endl; indent_up(); indent(out) << "switch (setField) {" << endl; @@ -809,7 +813,7 @@ void t_java_generator::generate_check_type(ofstream& out, t_struct* tstruct) { for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { t_field* field = (*m_iter); - indent(out) << "case " << upcase_string(field->get_name()) << ":" << endl; + indent(out) << "case " << constant_name(field->get_name()) << ":" << endl; indent(out) << " if (value instanceof " << type_name(field->get_type(), true, false, true) << ") {" << endl; indent(out) << " break;" << endl; indent(out) << " }" << endl; @@ -835,7 +839,7 @@ void t_java_generator::generate_read_value(ofstream& out, t_struct* tstruct) { indent_up(); - indent(out) << "switch (field.id) {" << endl; + indent(out) << "switch (_Fields.findByThriftId(field.id)) {" << endl; indent_up(); const vector& members = tstruct->get_members(); @@ -844,9 +848,9 @@ void t_java_generator::generate_read_value(ofstream& out, t_struct* tstruct) { for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { t_field* field = (*m_iter); - indent(out) << "case " << upcase_string(field->get_name()) << ":" << endl; + indent(out) << "case " << constant_name(field->get_name()) << ":" << endl; indent_up(); - indent(out) << "if (field.type == " << upcase_string(field->get_name()) << "_FIELD_DESC.type) {" << endl; + indent(out) << "if (field.type == " << constant_name(field->get_name()) << "_FIELD_DESC.type) {" << endl; indent_up(); indent(out) << type_name(field->get_type(), true, false) << " " << field->get_name() << ";" << endl; generate_deserialize_field(out, field, ""); @@ -872,7 +876,7 @@ void t_java_generator::generate_read_value(ofstream& out, t_struct* tstruct) { void t_java_generator::generate_write_value(ofstream& out, t_struct* tstruct) { indent(out) << "@Override" << endl; - indent(out) << "protected void writeValue(TProtocol oprot, short setField, Object value) throws TException {" << endl; + indent(out) << "protected void writeValue(TProtocol oprot, _Fields setField, Object value) throws TException {" << endl; indent_up(); @@ -885,7 +889,7 @@ void t_java_generator::generate_write_value(ofstream& out, t_struct* tstruct) { for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { t_field* field = (*m_iter); - indent(out) << "case " << upcase_string(field->get_name()) << ":" << endl; + indent(out) << "case " << constant_name(field->get_name()) << ":" << endl; indent_up(); indent(out) << type_name(field->get_type(), true, false) << " " << field->get_name() << " = (" << type_name(field->get_type(), true, false) << ")getFieldValue();" << endl; @@ -909,7 +913,7 @@ void t_java_generator::generate_write_value(ofstream& out, t_struct* tstruct) { void t_java_generator::generate_get_field_desc(ofstream& out, t_struct* tstruct) { indent(out) << "@Override" << endl; - indent(out) << "protected TField getFieldDesc(int setField) {" << endl; + indent(out) << "protected TField getFieldDesc(_Fields setField) {" << endl; indent_up(); const vector& members = tstruct->get_members(); @@ -920,8 +924,8 @@ void t_java_generator::generate_get_field_desc(ofstream& out, t_struct* tstruct) for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { t_field* field = (*m_iter); - indent(out) << "case " << upcase_string(field->get_name()) << ":" << endl; - indent(out) << " return " << upcase_string(field->get_name()) << "_FIELD_DESC;" << endl; + indent(out) << "case " << constant_name(field->get_name()) << ":" << endl; + indent(out) << " return " << constant_name(field->get_name()) << "_FIELD_DESC;" << endl; } indent(out) << "default:" << endl; @@ -1020,7 +1024,7 @@ void t_java_generator::generate_java_struct_definition(ofstream &out, if (is_exception) { out << "extends Exception "; } - out << "implements TBase, java.io.Serializable, Cloneable"; + out << "implements TBase<" << tstruct->get_name() << "._Fields>, java.io.Serializable, Cloneable"; if (is_comparable(tstruct)) { out << ", Comparable<" << type_name(tstruct) << ">"; @@ -1036,8 +1040,10 @@ void t_java_generator::generate_java_struct_definition(ofstream &out, const vector& members = tstruct->get_members(); vector::const_iterator m_iter; + out << endl; + generate_field_descs(out, tstruct); - + out << endl; for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { @@ -1050,8 +1056,10 @@ void t_java_generator::generate_java_struct_definition(ofstream &out, out << declare_field(*m_iter, false) << endl; } + out << endl; + generate_field_name_constants(out, tstruct); - + // isset data if (members.size() > 0) { out << endl; @@ -1078,8 +1086,6 @@ void t_java_generator::generate_java_struct_definition(ofstream &out, bool all_optional_members = true; - generate_java_field_name_map(out, tstruct); - // Default constructor indent(out) << "public " << tstruct->get_name() << "() {" << endl; @@ -1389,38 +1395,37 @@ void t_java_generator::generate_java_struct_reader(ofstream& out, "}" << endl; // Switch statement on the field we are reading - indent(out) << - "switch (field.id)" << endl; + indent(out) << "_Fields fieldId = _Fields.findByThriftId(field.id);" << endl; + indent(out) << "if (fieldId == null) {" << endl; + indent(out) << " TProtocolUtil.skip(iprot, field.type);" << endl; + indent(out) << "} else {" << endl; + indent_up(); - scope_up(out); + indent(out) << "switch (fieldId)" << endl; - // Generate deserialization code for known cases - for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { - indent(out) << - "case " << upcase_string((*f_iter)->get_name()) << ":" << endl; - indent_up(); - indent(out) << - "if (field.type == " << type_to_enum((*f_iter)->get_type()) << ") {" << endl; - indent_up(); + scope_up(out); - generate_deserialize_field(out, *f_iter, "this."); - generate_isset_set(out, *f_iter); - indent_down(); - out << - indent() << "} else { " << endl << - indent() << " TProtocolUtil.skip(iprot, field.type);" << endl << - indent() << "}" << endl << - indent() << "break;" << endl; - indent_down(); - } + // Generate deserialization code for known cases + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + indent(out) << + "case " << constant_name((*f_iter)->get_name()) << ":" << endl; + indent_up(); + indent(out) << + "if (field.type == " << type_to_enum((*f_iter)->get_type()) << ") {" << endl; + indent_up(); - // In the default case we skip the field + generate_deserialize_field(out, *f_iter, "this."); + generate_isset_set(out, *f_iter); + indent_down(); out << - indent() << "default:" << endl << + indent() << "} else { " << endl << indent() << " TProtocolUtil.skip(iprot, field.type);" << endl << - indent() << " break;" << endl; + indent() << "}" << endl << + indent() << "break;" << endl; + indent_down(); + } - scope_down(out); + scope_down(out); // Read field end marker indent(out) << @@ -1428,8 +1433,11 @@ void t_java_generator::generate_java_struct_reader(ofstream& out, scope_down(out); + indent_down(); + indent(out) << "}" << endl; + out << - indent() << "iprot.readStructEnd();" << endl << endl; + indent() << "iprot.readStructEnd();" << endl; // in non-beans style, check for required fields of primitive type // (which can be checked here but not in the general validate method) @@ -1625,7 +1633,7 @@ void t_java_generator::generate_java_struct_result_writer(ofstream& out, } void t_java_generator::generate_reflection_getters(ostringstream& out, t_type* type, string field_name, string cap_name) { - indent(out) << "case " << upcase_string(field_name) << ":" << endl; + indent(out) << "case " << constant_name(field_name) << ":" << endl; indent_up(); if (type->is_base_type() && !type->is_string()) { @@ -1640,7 +1648,7 @@ void t_java_generator::generate_reflection_getters(ostringstream& out, t_type* t } void t_java_generator::generate_reflection_setters(ostringstream& out, t_type* type, string field_name, string cap_name) { - indent(out) << "case " << upcase_string(field_name) << ":" << endl; + indent(out) << "case " << constant_name(field_name) << ":" << endl; indent_up(); indent(out) << "if (value == null) {" << endl; indent(out) << " unset" << get_cap_name(field_name) << "();" << endl; @@ -1674,36 +1682,29 @@ void t_java_generator::generate_generic_field_getters_setters(std::ofstream& out // create the setter - indent(out) << "public void setFieldValue(int fieldID, Object value) {" << endl; - indent_up(); - - indent(out) << "switch (fieldID) {" << endl; - + + indent(out) << "public void setFieldValue(_Fields field, Object value) {" << endl; + indent(out) << " switch (field) {" << endl; out << setter_stream.str(); + indent(out) << " }" << endl; + indent(out) << "}" << endl << endl; - indent(out) << "default:" << endl; - indent(out) << " throw new IllegalArgumentException(\"Field \" + fieldID + \" doesn't exist!\");" << endl; - - indent(out) << "}" << endl; - - indent_down(); + indent(out) << "public void setFieldValue(int fieldID, Object value) {" << endl; + indent(out) << " setFieldValue(_Fields.findByThriftIdOrThrow(fieldID), value);" << endl; indent(out) << "}" << endl << endl; // create the getter - indent(out) << "public Object getFieldValue(int fieldID) {" << endl; + indent(out) << "public Object getFieldValue(_Fields field) {" << endl; indent_up(); - - indent(out) << "switch (fieldID) {" << endl; - + indent(out) << "switch (field) {" << endl; out << getter_stream.str(); - - indent(out) << "default:" << endl; - indent(out) << " throw new IllegalArgumentException(\"Field \" + fieldID + \" doesn't exist!\");" << endl; - indent(out) << "}" << endl; - + indent(out) << "throw new IllegalStateException();" << endl; indent_down(); + indent(out) << "}" << endl << endl; + indent(out) << "public Object getFieldValue(int fieldId) {" << endl; + indent(out) << " return getFieldValue(_Fields.findByThriftIdOrThrow(fieldId));" << endl; indent(out) << "}" << endl << endl; } @@ -1713,26 +1714,27 @@ void t_java_generator::generate_generic_isset_method(std::ofstream& out, t_struc vector::const_iterator f_iter; // create the isSet method - indent(out) << "// Returns true if field corresponding to fieldID is set (has been asigned a value) and false otherwise" << endl; - indent(out) << "public boolean isSet(int fieldID) {" << endl; + indent(out) << "/** Returns true if field corresponding to fieldID is set (has been asigned a value) and false otherwise */" << endl; + indent(out) << "public boolean isSet(_Fields field) {" << endl; indent_up(); - indent(out) << "switch (fieldID) {" << endl; + indent(out) << "switch (field) {" << endl; for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { t_field* field = *f_iter; - indent(out) << "case " << upcase_string(field->get_name()) << ":" << endl; + indent(out) << "case " << constant_name(field->get_name()) << ":" << endl; indent_up(); indent(out) << "return " << generate_isset_check(field) << ";" << endl; indent_down(); } - indent(out) << "default:" << endl; - indent(out) << " throw new IllegalArgumentException(\"Field \" + fieldID + \" doesn't exist!\");" << endl; - indent(out) << "}" << endl; - + indent(out) << "throw new IllegalStateException();" << endl; indent_down(); indent(out) << "}" << endl << endl; + + indent(out) << "public boolean isSet(int fieldID) {" << endl; + indent(out) << " return isSet(_Fields.findByThriftIdOrThrow(fieldID));" << endl; + indent(out) << "}" << endl << endl; } /** @@ -1861,7 +1863,7 @@ void t_java_generator::generate_java_bean_boilerplate(ofstream& out, indent(out) << "}" << endl << endl; // isSet method - indent(out) << "// Returns true if field " << field_name << " is set (has been asigned a value) and false otherwise" << endl; + indent(out) << "/** Returns true if field " << field_name << " is set (has been asigned a value) and false otherwise */" << endl; indent(out) << "public boolean is" << get_cap_name("set") << cap_name << "() {" << endl; indent_up(); if (type_can_be_null(type)) { @@ -1979,14 +1981,14 @@ void t_java_generator::generate_java_meta_data_map(ofstream& out, vector::const_iterator f_iter; // Static Map with fieldID -> FieldMetaData mappings - indent(out) << "public static final Map metaDataMap = Collections.unmodifiableMap(new HashMap() {{" << endl; + indent(out) << "public static final Map<_Fields, FieldMetaData> metaDataMap = Collections.unmodifiableMap(new EnumMap<_Fields, FieldMetaData>(_Fields.class) {{" << endl; // Populate map indent_up(); for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { t_field* field = *f_iter; std::string field_name = field->get_name(); - indent(out) << "put(" << upcase_string(field_name) << ", new FieldMetaData(\"" << field_name << "\", "; + indent(out) << "put(_Fields." << constant_name(field_name) << ", new FieldMetaData(\"" << field_name << "\", "; // Set field requirement type (required, optional, etc.) if (field->get_req() == t_field::T_REQUIRED) { @@ -1997,7 +1999,7 @@ void t_java_generator::generate_java_meta_data_map(ofstream& out, out << "TFieldRequirementType.DEFAULT, "; } - // Create value meta data + // Create value meta data generate_field_value_meta_data(out, field->get_type()); out << "));" << endl; } @@ -2012,30 +2014,6 @@ void t_java_generator::generate_java_meta_data_map(ofstream& out, indent(out) << "}" << endl << endl; } -/** - * Generates a static map from field names to field IDs - * - * @param tstruct The struct definition - */ -void t_java_generator::generate_java_field_name_map(ofstream& out, - t_struct* tstruct) { - const vector& fields = tstruct->get_members(); - vector::const_iterator f_iter; - - // Static Map with fieldName -> fieldID - indent(out) << "public static final Map fieldNameMap = Collections.unmodifiableMap(new HashMap() {{" << endl; - - // Populate map - indent_up(); - for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { - t_field* field = *f_iter; - std::string field_name = field->get_name(); - indent(out) << "put(\"" << field->get_name() << "\", new Integer(" << upcase_string(field->get_name()) << "));" << endl; - } - indent_down(); - indent(out) << "}});" << endl << endl; -} - /** * Returns a string with the java representation of the given thrift type * (e.g. for the type struct it returns "TType.STRUCT") @@ -3514,13 +3492,78 @@ void t_java_generator::generate_field_descs(ofstream& out, t_struct* tstruct) { } void t_java_generator::generate_field_name_constants(ofstream& out, t_struct* tstruct) { - // Members are public for -java, private for -javabean + indent(out) << "/** The set of fields this struct contains, along with convenience methods for finding and manipulating them. */" << endl; + indent(out) << "public enum _Fields implements TFieldIdEnum {" << endl; + + indent_up(); + bool first = true; const vector& members = tstruct->get_members(); vector::const_iterator m_iter; - for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { - indent(out) << "public static final int " << upcase_string((*m_iter)->get_name()) << " = " << (*m_iter)->get_key() << ";" << endl; + if (!first) { + out << "," << endl; + } + first = false; + generate_java_doc(out, *m_iter); + indent(out) << constant_name((*m_iter)->get_name()) << "((short)" << (*m_iter)->get_key() << ", \"" << (*m_iter)->get_name() << "\")"; } + + out << ";" << endl << endl; + + indent(out) << "private static final Map byId = new HashMap();" << endl; + indent(out) << "private static final Map byName = new HashMap();" << endl; + out << endl; + + indent(out) << "static {" << endl; + indent(out) << " for (_Fields field : EnumSet.allOf(_Fields.class)) {" << endl; + indent(out) << " byId.put((int)field._thriftId, field);" << endl; + indent(out) << " byName.put(field.getFieldName(), field);" << endl; + indent(out) << " }" << endl; + indent(out) << "}" << endl << endl; + + indent(out) << "/**" << endl; + indent(out) << " * Find the _Fields constant that matches fieldId, or null if its not found." << endl; + indent(out) << " */" << endl; + indent(out) << "public static _Fields findByThriftId(int fieldId) {" << endl; + indent(out) << " return byId.get(fieldId);" << endl; + indent(out) << "}" << endl << endl; + + indent(out) << "/**" << endl; + indent(out) << " * Find the _Fields constant that matches fieldId, throwing an exception" << endl; + indent(out) << " * if it is not found." << endl; + indent(out) << " */" << endl; + indent(out) << "public static _Fields findByThriftIdOrThrow(int fieldId) {" << endl; + indent(out) << " _Fields fields = findByThriftId(fieldId);" << endl; + indent(out) << " if (fields == null) throw new IllegalArgumentException(\"Field \" + fieldId + \" doesn't exist!\");" << endl; + indent(out) << " return fields;" << endl; + indent(out) << "}" << endl << endl; + + indent(out) << "/**" << endl; + indent(out) << " * Find the _Fields constant that matches name, or null if its not found." << endl; + indent(out) << " */" << endl; + indent(out) << "public static _Fields findByName(String name) {" << endl; + indent(out) << " return byName.get(name);" << endl; + indent(out) << "}" << endl << endl; + + indent(out) << "private final short _thriftId;" << endl; + indent(out) << "private final String _fieldName;" << endl << endl; + + indent(out) << "_Fields(short thriftId, String fieldName) {" << endl; + indent(out) << " _thriftId = thriftId;" << endl; + indent(out) << " _fieldName = fieldName;" << endl; + indent(out) << "}" << endl << endl; + + indent(out) << "public short getThriftFieldId() {" << endl; + indent(out) << " return _thriftId;" << endl; + indent(out) << "}" << endl << endl; + + indent(out) << "public String getFieldName() {" << endl; + indent(out) << " return _fieldName;" << endl; + indent(out) << "}" << endl; + + indent_down(); + + indent(out) << "}" << endl; } bool t_java_generator::is_comparable(t_struct* tstruct) { diff --git a/lib/java/src/org/apache/thrift/TBase.java b/lib/java/src/org/apache/thrift/TBase.java index 3c3b12fe..bfa0abea 100644 --- a/lib/java/src/org/apache/thrift/TBase.java +++ b/lib/java/src/org/apache/thrift/TBase.java @@ -27,7 +27,7 @@ import org.apache.thrift.protocol.TProtocol; * Generic base interface for generated Thrift objects. * */ -public interface TBase extends Serializable { +public interface TBase extends Serializable { /** * Reads the TObject from the given input protocol. @@ -48,23 +48,49 @@ public interface TBase extends Serializable { * * @param fieldId The field's id tag as found in the IDL. */ + @Deprecated public boolean isSet(int fieldId); + /** + * Check if a field is currently set or unset. + * + * @param field + */ + public boolean isSet(F field); + /** * Get a field's value by id. Primitive types will be wrapped in the * appropriate "boxed" types. * * @param fieldId The field's id tag as found in the IDL. */ + @Deprecated public Object getFieldValue(int fieldId); + /** + * Get a field's value by field variable. Primitive types will be wrapped in + * the appropriate "boxed" types. + * + * @param field + */ + public Object getFieldValue(F field); + /** * Set a field's value by id. Primitive types must be "boxed" in the * appropriate object wrapper type. * * @param fieldId The field's id tag as found in the IDL. */ + @Deprecated public void setFieldValue(int fieldId, Object value); - public TBase deepCopy(); + /** + * Set a field's value by field variable. Primitive types must be "boxed" in + * the appropriate object wrapper type. + * + * @param field + */ + public void setFieldValue(F field, Object value); + + public TBase deepCopy(); } diff --git a/lib/java/src/org/apache/thrift/TDeserializer.java b/lib/java/src/org/apache/thrift/TDeserializer.java index 7b7d51dc..750ea48a 100644 --- a/lib/java/src/org/apache/thrift/TDeserializer.java +++ b/lib/java/src/org/apache/thrift/TDeserializer.java @@ -29,6 +29,7 @@ import org.apache.thrift.protocol.TProtocolFactory; import org.apache.thrift.protocol.TProtocolUtil; import org.apache.thrift.protocol.TType; import org.apache.thrift.transport.TIOStreamTransport; +import org.apache.thrift.TFieldIdEnum; /** * Generic utility for easily deserializing objects from a byte array or Java @@ -92,7 +93,7 @@ public class TDeserializer { * @param fieldIdPath The FieldId's that define a path tb * @throws TException */ - public void partialDeserialize(TBase tb, byte[] bytes, int ... fieldIdPath) throws TException { + public void partialDeserialize(TBase tb, byte[] bytes, TFieldIdEnum ... fieldIdPath) throws TException { // if there are no elements in the path, then the user is looking for the // regular deserialize method // TODO: it might be nice not to have to do this check every time to save @@ -116,11 +117,11 @@ public class TDeserializer { // we can stop searching if we either see a stop or we go past the field // id we're looking for (since fields should now be serialized in asc // order). - if (field.type == TType.STOP || field.id > fieldIdPath[curPathIndex]) { + if (field.type == TType.STOP || field.id > fieldIdPath[curPathIndex].getThriftFieldId()) { return; } - if (field.id != fieldIdPath[curPathIndex]) { + if (field.id != fieldIdPath[curPathIndex].getThriftFieldId()) { // Not the field we're looking for. Skip field. TProtocolUtil.skip(iprot, field.type); iprot.readFieldEnd(); diff --git a/lib/java/src/org/apache/thrift/TFieldIdEnum.java b/lib/java/src/org/apache/thrift/TFieldIdEnum.java new file mode 100644 index 00000000..6bcc9f23 --- /dev/null +++ b/lib/java/src/org/apache/thrift/TFieldIdEnum.java @@ -0,0 +1,16 @@ +package org.apache.thrift; + +/** + * Interface for all generated struct Fields objects. + */ +public interface TFieldIdEnum { + /** + * Get the Thrift field id for the named field. + */ + public short getThriftFieldId(); + + /** + * Get the field's name, exactly as in the IDL. + */ + public String getFieldName(); +} diff --git a/lib/java/src/org/apache/thrift/TUnion.java b/lib/java/src/org/apache/thrift/TUnion.java index 9375475f..219669f0 100644 --- a/lib/java/src/org/apache/thrift/TUnion.java +++ b/lib/java/src/org/apache/thrift/TUnion.java @@ -12,28 +12,28 @@ import org.apache.thrift.protocol.TProtocol; import org.apache.thrift.protocol.TProtocolException; import org.apache.thrift.protocol.TStruct; -public abstract class TUnion implements TBase { +public abstract class TUnion implements TBase { protected Object value_; - protected int setField_; - + protected F setField_; + protected TUnion() { - setField_ = 0; + setField_ = null; value_ = null; } - protected TUnion(int setField, Object value) { + protected TUnion(F setField, Object value) { setFieldValue(setField, value); } - protected TUnion(TUnion other) { + protected TUnion(TUnion other) { if (!other.getClass().equals(this.getClass())) { throw new ClassCastException(); } setField_ = other.setField_; value_ = deepCopyObject(other.value_); } - + private static Object deepCopyObject(Object o) { if (o instanceof TBase) { return ((TBase)o).deepCopy(); @@ -52,7 +52,7 @@ public abstract class TUnion implements TBase { return o; } } - + private static Map deepCopyMap(Map map) { Map copy = new HashMap(); for (Map.Entry entry : map.entrySet()) { @@ -77,15 +77,15 @@ public abstract class TUnion implements TBase { return copy; } - public int getSetField() { + public F getSetField() { return setField_; } - + public Object getFieldValue() { return value_; } - - public Object getFieldValue(int fieldId) { + + public Object getFieldValue(F fieldId) { if (fieldId != setField_) { throw new IllegalArgumentException("Cannot get the value of field " + fieldId + " because union's set field is " + setField_); } @@ -93,16 +93,24 @@ public abstract class TUnion implements TBase { return getFieldValue(); } + public Object getFieldValue(int fieldId) { + return getFieldValue(enumForId((short)fieldId)); + } + public boolean isSet() { - return setField_ != 0; + return setField_ != null; } - public boolean isSet(int fieldId) { + public boolean isSet(F fieldId) { return setField_ == fieldId; } + public boolean isSet(int fieldId) { + return isSet(enumForId((short)fieldId)); + } + public void read(TProtocol iprot) throws TException { - setField_ = 0; + setField_ = null; value_ = null; iprot.readStructBegin(); @@ -111,7 +119,7 @@ public abstract class TUnion implements TBase { value_ = readValue(iprot, field); if (value_ != null) { - setField_ = field.id; + setField_ = enumForId(field.id); } iprot.readFieldEnd(); @@ -122,19 +130,23 @@ public abstract class TUnion implements TBase { iprot.readStructEnd(); } - public void setFieldValue(int fieldId, Object value) { - checkType((short)fieldId, value); - setField_ = (short)fieldId; + public void setFieldValue(F fieldId, Object value) { + checkType(fieldId, value); + setField_ = fieldId; value_ = value; } + public void setFieldValue(int fieldId, Object value) { + setFieldValue(enumForId((short)fieldId), value); + } + public void write(TProtocol oprot) throws TException { - if (getSetField() == 0 || getFieldValue() == null) { + if (getSetField() == null || getFieldValue() == null) { throw new TProtocolException("Cannot write a TUnion with no set value!"); } oprot.writeStructBegin(getStructDesc()); oprot.writeFieldBegin(getFieldDesc(setField_)); - writeValue(oprot, (short)setField_, value_); + writeValue(oprot, setField_, value_); oprot.writeFieldEnd(); oprot.writeFieldStop(); oprot.writeStructEnd(); @@ -146,7 +158,7 @@ public abstract class TUnion implements TBase { * @param setField * @param value */ - protected abstract void checkType(short setField, Object value) throws ClassCastException; + protected abstract void checkType(F setField, Object value) throws ClassCastException; /** * Implementation should be generated to read the right stuff from the wire @@ -156,11 +168,13 @@ public abstract class TUnion implements TBase { */ protected abstract Object readValue(TProtocol iprot, TField field) throws TException; - protected abstract void writeValue(TProtocol oprot, short setField, Object value) throws TException; + protected abstract void writeValue(TProtocol oprot, F setField, Object value) throws TException; protected abstract TStruct getStructDesc(); - protected abstract TField getFieldDesc(int setField); + protected abstract TField getFieldDesc(F setField); + + protected abstract F enumForId(short id); @Override public String toString() { diff --git a/lib/java/src/org/apache/thrift/meta_data/FieldMetaData.java b/lib/java/src/org/apache/thrift/meta_data/FieldMetaData.java index 3e90a8b9..b634291b 100644 --- a/lib/java/src/org/apache/thrift/meta_data/FieldMetaData.java +++ b/lib/java/src/org/apache/thrift/meta_data/FieldMetaData.java @@ -22,6 +22,7 @@ package org.apache.thrift.meta_data; import java.util.HashMap; import java.util.Map; import org.apache.thrift.TBase; +import org.apache.thrift.TFieldIdEnum; /** * This class is used to store meta data about thrift fields. Every field in a @@ -32,10 +33,10 @@ public class FieldMetaData implements java.io.Serializable { public final String fieldName; public final byte requirementType; public final FieldValueMetaData valueMetaData; - private static Map, Map> structMap; + private static Map, Map> structMap; static { - structMap = new HashMap, Map>(); + structMap = new HashMap, Map>(); } public FieldMetaData(String name, byte req, FieldValueMetaData vMetaData){ @@ -44,7 +45,7 @@ public class FieldMetaData implements java.io.Serializable { this.valueMetaData = vMetaData; } - public static void addStructMetaDataMap(Class sClass, Map map){ + public static void addStructMetaDataMap(Class sClass, Map map){ structMap.put(sClass, map); } @@ -54,7 +55,7 @@ public class FieldMetaData implements java.io.Serializable { * * @param sClass The TBase class for which the metadata map is requested */ - public static Map getStructMetaDataMap(Class sClass){ + public static Map getStructMetaDataMap(Class sClass){ if (!structMap.containsKey(sClass)){ // Load class if it hasn't been loaded try{ sClass.newInstance(); diff --git a/lib/java/test/org/apache/thrift/test/MetaDataTest.java b/lib/java/test/org/apache/thrift/test/MetaDataTest.java index 386ec9b8..8bb9f2ca 100644 --- a/lib/java/test/org/apache/thrift/test/MetaDataTest.java +++ b/lib/java/test/org/apache/thrift/test/MetaDataTest.java @@ -28,43 +28,43 @@ import org.apache.thrift.meta_data.MapMetaData; import org.apache.thrift.meta_data.SetMetaData; import org.apache.thrift.meta_data.StructMetaData; import org.apache.thrift.protocol.TType; +import org.apache.thrift.TFieldIdEnum; import thrift.test.*; public class MetaDataTest { - public static void main(String[] args) throws Exception { - Map mdMap = CrazyNesting.metaDataMap; - + Map mdMap = CrazyNesting.metaDataMap; + // Check for struct fields existence if (mdMap.size() != 3) throw new RuntimeException("metadata map contains wrong number of entries!"); - if (!mdMap.containsKey(CrazyNesting.SET_FIELD) || !mdMap.containsKey(CrazyNesting.LIST_FIELD) || !mdMap.containsKey(CrazyNesting.STRING_FIELD)) + if (!mdMap.containsKey(CrazyNesting._Fields.SET_FIELD) || !mdMap.containsKey(CrazyNesting._Fields.LIST_FIELD) || !mdMap.containsKey(CrazyNesting._Fields.STRING_FIELD)) throw new RuntimeException("metadata map doesn't contain entry for a struct field!"); - + // Check for struct fields contents - if (!mdMap.get(CrazyNesting.STRING_FIELD).fieldName.equals("string_field") || - !mdMap.get(CrazyNesting.LIST_FIELD).fieldName.equals("list_field") || - !mdMap.get(CrazyNesting.SET_FIELD).fieldName.equals("set_field")) + if (!mdMap.get(CrazyNesting._Fields.STRING_FIELD).fieldName.equals("string_field") || + !mdMap.get(CrazyNesting._Fields.LIST_FIELD).fieldName.equals("list_field") || + !mdMap.get(CrazyNesting._Fields.SET_FIELD).fieldName.equals("set_field")) throw new RuntimeException("metadata map contains a wrong fieldname"); - if (mdMap.get(CrazyNesting.STRING_FIELD).requirementType != TFieldRequirementType.DEFAULT || - mdMap.get(CrazyNesting.LIST_FIELD).requirementType != TFieldRequirementType.REQUIRED || - mdMap.get(CrazyNesting.SET_FIELD).requirementType != TFieldRequirementType.OPTIONAL) + if (mdMap.get(CrazyNesting._Fields.STRING_FIELD).requirementType != TFieldRequirementType.DEFAULT || + mdMap.get(CrazyNesting._Fields.LIST_FIELD).requirementType != TFieldRequirementType.REQUIRED || + mdMap.get(CrazyNesting._Fields.SET_FIELD).requirementType != TFieldRequirementType.OPTIONAL) throw new RuntimeException("metadata map contains the wrong requirement type for a field"); - if (mdMap.get(CrazyNesting.STRING_FIELD).valueMetaData.type != TType.STRING || - mdMap.get(CrazyNesting.LIST_FIELD).valueMetaData.type != TType.LIST || - mdMap.get(CrazyNesting.SET_FIELD).valueMetaData.type != TType.SET) + if (mdMap.get(CrazyNesting._Fields.STRING_FIELD).valueMetaData.type != TType.STRING || + mdMap.get(CrazyNesting._Fields.LIST_FIELD).valueMetaData.type != TType.LIST || + mdMap.get(CrazyNesting._Fields.SET_FIELD).valueMetaData.type != TType.SET) throw new RuntimeException("metadata map contains the wrong requirement type for a field"); - + // Check nested structures - if (!mdMap.get(CrazyNesting.LIST_FIELD).valueMetaData.isContainer()) + if (!mdMap.get(CrazyNesting._Fields.LIST_FIELD).valueMetaData.isContainer()) throw new RuntimeException("value metadata for a list is stored as non-container!"); - if (mdMap.get(CrazyNesting.LIST_FIELD).valueMetaData.isStruct()) + if (mdMap.get(CrazyNesting._Fields.LIST_FIELD).valueMetaData.isStruct()) throw new RuntimeException("value metadata for a list is stored as a struct!"); - if (((MapMetaData)((ListMetaData)((SetMetaData)((MapMetaData)((MapMetaData)((ListMetaData)mdMap.get(CrazyNesting.LIST_FIELD).valueMetaData).elemMetaData).valueMetaData).valueMetaData).elemMetaData).elemMetaData).keyMetaData.type != TType.STRUCT) + if (((MapMetaData)((ListMetaData)((SetMetaData)((MapMetaData)((MapMetaData)((ListMetaData)mdMap.get(CrazyNesting._Fields.LIST_FIELD).valueMetaData).elemMetaData).valueMetaData).valueMetaData).elemMetaData).elemMetaData).keyMetaData.type != TType.STRUCT) throw new RuntimeException("metadata map contains wrong type for a value in a deeply nested structure"); - if (((StructMetaData)((MapMetaData)((ListMetaData)((SetMetaData)((MapMetaData)((MapMetaData)((ListMetaData)mdMap.get(CrazyNesting.LIST_FIELD).valueMetaData).elemMetaData).valueMetaData).valueMetaData).elemMetaData).elemMetaData).keyMetaData).structClass != Insanity.class) + if (((StructMetaData)((MapMetaData)((ListMetaData)((SetMetaData)((MapMetaData)((MapMetaData)((ListMetaData)mdMap.get(CrazyNesting._Fields.LIST_FIELD).valueMetaData).elemMetaData).valueMetaData).valueMetaData).elemMetaData).elemMetaData).keyMetaData).structClass != Insanity.class) throw new RuntimeException("metadata map contains wrong class for a struct in a deeply nested structure"); - + // Check that FieldMetaData contains a map with metadata for all generated struct classes if (FieldMetaData.getStructMetaDataMap(CrazyNesting.class) == null || FieldMetaData.getStructMetaDataMap(Insanity.class) == null || @@ -74,12 +74,8 @@ public class MetaDataTest { FieldMetaData.getStructMetaDataMap(Insanity.class) != Insanity.metaDataMap) throw new RuntimeException("global metadata map contains wrong entry for a loaded struct"); - Map fnMap = CrazyNesting.fieldNameMap; - if (fnMap.size() != 3) { - throw new RuntimeException("Field Name Map contains wrong number of entries!"); - } - for (Map.Entry mdEntry : mdMap.entrySet()) { - if (!fnMap.get(mdEntry.getValue().fieldName).equals(mdEntry.getKey())) { + for (Map.Entry mdEntry : mdMap.entrySet()) { + if (!CrazyNesting._Fields.findByName(mdEntry.getValue().fieldName).equals(mdEntry.getKey())) { throw new RuntimeException("Field name map contained invalid Name <-> ID mapping"); } } diff --git a/lib/java/test/org/apache/thrift/test/PartialDeserializeTest.java b/lib/java/test/org/apache/thrift/test/PartialDeserializeTest.java index d88a686d..a7fa59b0 100644 --- a/lib/java/test/org/apache/thrift/test/PartialDeserializeTest.java +++ b/lib/java/test/org/apache/thrift/test/PartialDeserializeTest.java @@ -24,6 +24,7 @@ import org.apache.thrift.TBase; import org.apache.thrift.TDeserializer; import org.apache.thrift.TException; import org.apache.thrift.TSerializer; +import org.apache.thrift.TFieldIdEnum; import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.protocol.TCompactProtocol; import org.apache.thrift.protocol.TJSONProtocol; @@ -48,7 +49,7 @@ public class PartialDeserializeTest { // 1:Union // 1.3:OneOfEach OneOfEach Level3OneOfEach = Fixtures.oneOfEach; - TestUnion Level2TestUnion = new TestUnion(TestUnion.STRUCT_FIELD, Level3OneOfEach); + TestUnion Level2TestUnion = new TestUnion(TestUnion._Fields.STRUCT_FIELD, Level3OneOfEach); StructWithAUnion Level1SWU = new StructWithAUnion(Level2TestUnion); Backwards bw = new Backwards(2, 1); @@ -59,20 +60,20 @@ public class PartialDeserializeTest { testPartialDeserialize(factory, Level1SWU, new StructWithAUnion(), Level1SWU); //Level 2 test - testPartialDeserialize(factory, Level1SWU, new TestUnion(), Level2TestUnion, StructWithAUnion.TEST_UNION); + testPartialDeserialize(factory, Level1SWU, new TestUnion(), Level2TestUnion, StructWithAUnion._Fields.TEST_UNION); //Level 3 on 3rd field test - testPartialDeserialize(factory, Level1SWU, new OneOfEach(), Level3OneOfEach, StructWithAUnion.TEST_UNION, TestUnion.STRUCT_FIELD); + testPartialDeserialize(factory, Level1SWU, new OneOfEach(), Level3OneOfEach, StructWithAUnion._Fields.TEST_UNION, TestUnion._Fields.STRUCT_FIELD); //Test early termination when traversed path Field.id exceeds the one being searched for - testPartialDeserialize(factory, Level1SWU, new OneOfEach(), new OneOfEach(), StructWithAUnion.TEST_UNION, TestUnion.I32_FIELD); - + testPartialDeserialize(factory, Level1SWU, new OneOfEach(), new OneOfEach(), StructWithAUnion._Fields.TEST_UNION, TestUnion._Fields.I32_FIELD); + //Test that readStructBegin isn't called on primitive - testPartialDeserialize(factory, pts, new Backwards(), bw, PrimitiveThenStruct.BW); + testPartialDeserialize(factory, pts, new Backwards(), bw, PrimitiveThenStruct._Fields.BW); } } - public static void testPartialDeserialize(TProtocolFactory protocolFactory, TBase input, TBase output, TBase expected, int ... fieldIdPath) throws TException { + public static void testPartialDeserialize(TProtocolFactory protocolFactory, TBase input, TBase output, TBase expected, TFieldIdEnum ... fieldIdPath) throws TException { byte[] record = new TSerializer(protocolFactory).serialize(input); new TDeserializer(protocolFactory).partialDeserialize(output, record, fieldIdPath); if(!output.equals(expected)) diff --git a/lib/java/test/org/apache/thrift/test/ReadStruct.java b/lib/java/test/org/apache/thrift/test/ReadStruct.java index 2dc042c5..ef36f4d0 100644 --- a/lib/java/test/org/apache/thrift/test/ReadStruct.java +++ b/lib/java/test/org/apache/thrift/test/ReadStruct.java @@ -35,28 +35,27 @@ public class ReadStruct { System.out.println("usage: java -cp build/classes org.apache.thrift.test.ReadStruct filename proto_factory_class"); System.out.println("Read in an instance of CompactProtocolTestStruct from 'file', making sure that it is equivalent to Fixtures.compactProtoTestStruct. Use a protocol from 'proto_factory_class'."); } - + TTransport trans = new TIOStreamTransport(new BufferedInputStream(new FileInputStream(args[0]))); - + TProtocolFactory factory = (TProtocolFactory)Class.forName(args[1]).newInstance(); - + TProtocol proto = factory.getProtocol(trans); - + CompactProtoTestStruct cpts = new CompactProtoTestStruct(); - - for (Integer fid : CompactProtoTestStruct.metaDataMap.keySet()) { + + for (CompactProtoTestStruct._Fields fid : CompactProtoTestStruct.metaDataMap.keySet()) { cpts.setFieldValue(fid, null); } - + cpts.read(proto); - + if (cpts.equals(Fixtures.compactProtoTestStruct)) { System.out.println("Object verified successfully!"); } else { System.out.println("Object failed verification!"); System.out.println("Expected: " + Fixtures.compactProtoTestStruct + " but got " + cpts); } - } } diff --git a/lib/java/test/org/apache/thrift/test/UnionTest.java b/lib/java/test/org/apache/thrift/test/UnionTest.java index 04716c64..cb69063a 100644 --- a/lib/java/test/org/apache/thrift/test/UnionTest.java +++ b/lib/java/test/org/apache/thrift/test/UnionTest.java @@ -32,18 +32,18 @@ public class UnionTest { throw new RuntimeException("unset union didn't return null for value"); } - union = new TestUnion(TestUnion.I32_FIELD, 25); + union = new TestUnion(TestUnion._Fields.I32_FIELD, 25); if ((Integer)union.getFieldValue() != 25) { throw new RuntimeException("set i32 field didn't come out as planned"); } - if ((Integer)union.getFieldValue(TestUnion.I32_FIELD) != 25) { + if ((Integer)union.getFieldValue(TestUnion._Fields.I32_FIELD) != 25) { throw new RuntimeException("set i32 field didn't come out of TBase getFieldValue"); } try { - union.getFieldValue(TestUnion.STRING_FIELD); + union.getFieldValue(TestUnion._Fields.STRING_FIELD); throw new RuntimeException("was expecting an exception around wrong set field"); } catch (IllegalArgumentException e) { // cool! @@ -73,21 +73,21 @@ public class UnionTest { public static void testEquality() throws Exception { - TestUnion union = new TestUnion(TestUnion.I32_FIELD, 25); + TestUnion union = new TestUnion(TestUnion._Fields.I32_FIELD, 25); - TestUnion otherUnion = new TestUnion(TestUnion.STRING_FIELD, "blah!!!"); + TestUnion otherUnion = new TestUnion(TestUnion._Fields.STRING_FIELD, "blah!!!"); if (union.equals(otherUnion)) { throw new RuntimeException("shouldn't be equal"); } - otherUnion = new TestUnion(TestUnion.I32_FIELD, 400); + otherUnion = new TestUnion(TestUnion._Fields.I32_FIELD, 400); if (union.equals(otherUnion)) { throw new RuntimeException("shouldn't be equal"); } - otherUnion = new TestUnion(TestUnion.OTHER_I32_FIELD, 25); + otherUnion = new TestUnion(TestUnion._Fields.OTHER_I32_FIELD, 25); if (union.equals(otherUnion)) { throw new RuntimeException("shouldn't be equal"); @@ -96,7 +96,7 @@ public class UnionTest { public static void testSerialization() throws Exception { - TestUnion union = new TestUnion(TestUnion.I32_FIELD, 25); + TestUnion union = new TestUnion(TestUnion._Fields.I32_FIELD, 25); TMemoryBuffer buf = new TMemoryBuffer(0); TProtocol proto = new TBinaryProtocol(buf);