From ab3666e6caad79315fddf0f8f38c13c7a10cc23a Mon Sep 17 00:00:00 2001 From: Bryan Duxbury Date: Tue, 1 Sep 2009 23:03:47 +0000 Subject: [PATCH] THRIFT-409. java: Add "union" to Thrift This patch introduces new IDL syntax for creating Unions, explicityly single-valued structs. While the parser changes are portable, this patch only includes the actual generated code changes for the Java library. Other libraries can continue to generate a struct with the same fields and remain compatible until they are able to implement the full shebang. git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@810300 13f79535-47bb-0310-9956-ffa450edef68 --- compiler/cpp/src/generate/t_generator.cc | 2 +- compiler/cpp/src/generate/t_java_generator.cc | 456 ++++++++++++++++-- compiler/cpp/src/parse/t_struct.h | 9 + compiler/cpp/src/thriftl.ll | 2 +- compiler/cpp/src/thrifty.yy | 21 +- lib/java/build.xml | 2 + lib/java/src/org/apache/thrift/TBase.java | 6 +- lib/java/src/org/apache/thrift/TUnion.java | 192 ++++++++ .../org/apache/thrift/test/UnionTest.java | 131 +++++ test/DebugProtoTest.thrift | 15 + 10 files changed, 796 insertions(+), 40 deletions(-) create mode 100644 lib/java/src/org/apache/thrift/TUnion.java create mode 100644 lib/java/test/org/apache/thrift/test/UnionTest.java diff --git a/compiler/cpp/src/generate/t_generator.cc b/compiler/cpp/src/generate/t_generator.cc index 38c053c5..3a2c002c 100644 --- a/compiler/cpp/src/generate/t_generator.cc +++ b/compiler/cpp/src/generate/t_generator.cc @@ -49,7 +49,7 @@ void t_generator::generate_program() { vector consts = program_->get_consts(); generate_consts(consts); - // Generate structs and exceptions in declared order + // Generate structs, exceptions, and unions in declared order vector objects = program_->get_objects(); vector::iterator o_iter; for (o_iter = objects.begin(); o_iter != objects.end(); ++o_iter) { diff --git a/compiler/cpp/src/generate/t_java_generator.cc b/compiler/cpp/src/generate/t_java_generator.cc index b7641b4b..427e6546 100644 --- a/compiler/cpp/src/generate/t_java_generator.cc +++ b/compiler/cpp/src/generate/t_java_generator.cc @@ -74,6 +74,7 @@ class t_java_generator : public t_oop_generator { void generate_typedef (t_typedef* ttypedef); void generate_enum (t_enum* tenum); void generate_struct (t_struct* tstruct); + void generate_union (t_struct* tunion); void generate_xception(t_struct* txception); void generate_service (t_service* tservice); @@ -116,6 +117,20 @@ class t_java_generator : public t_oop_generator { void generate_service_server (t_service* tservice); void generate_process_function (t_service* tservice, t_function* tfunction); + void generate_java_union(t_struct* tstruct); + void generate_union_constructor(ofstream& out, t_struct* tstruct); + void generate_union_getters_and_setters(ofstream& out, t_struct* tstruct); + void generate_union_abstract_methods(ofstream& out, t_struct* tstruct); + void generate_check_type(ofstream& out, t_struct* tstruct); + void generate_read_value(ofstream& out, t_struct* tstruct); + void generate_write_value(ofstream& out, t_struct* tstruct); + void generate_get_field_desc(ofstream& out, t_struct* tstruct); + void generate_get_struct_desc(ofstream& out, t_struct* tstruct); + void generate_get_field_name(ofstream& out, t_struct* tstruct); + + void generate_union_comparisons(ofstream& out, t_struct* tstruct); + void generate_union_hashcode(ofstream& out, t_struct* tstruct); + /** * Serialization constructs */ @@ -196,14 +211,17 @@ class t_java_generator : public t_oop_generator { std::string java_package(); std::string java_type_imports(); std::string java_thrift_imports(); - std::string type_name(t_type* ttype, bool in_container=false, bool in_init=false); + std::string type_name(t_type* ttype, bool in_container=false, bool in_init=false, bool skip_generic=false); std::string base_type_name(t_base_type* tbase, bool in_container=false); std::string declare_field(t_field* tfield, bool init=false); std::string function_signature(t_function* tfunction, std::string prefix=""); std::string argument_list(t_struct* tstruct); std::string type_to_enum(t_type* ttype); std::string get_enum_class_name(t_type* type); - + void generate_struct_desc(ofstream& out, t_struct* tstruct); + void generate_field_descs(ofstream& out, t_struct* tstruct); + void generate_field_name_constants(ofstream& out, t_struct* tstruct); + bool type_can_be_null(t_type* ttype) { ttype = get_true_type(ttype); @@ -294,6 +312,7 @@ string t_java_generator::java_type_imports() { "import java.util.HashSet;\n" + "import java.util.Collections;\n" + "import java.util.BitSet;\n" + + "import java.util.Arrays;\n" + "import org.slf4j.Logger;\n" + "import org.slf4j.LoggerFactory;\n\n"; } @@ -581,13 +600,17 @@ string t_java_generator::render_const_value(ofstream& out, string name, t_type* } /** - * Generates a struct definition for a thrift data type. This is a class - * with data members, read(), write(), and an inner Isset class. + * Generates a struct definition for a thrift data type. This will be a TBase + * implementor. * * @param tstruct The struct definition */ void t_java_generator::generate_struct(t_struct* tstruct) { - generate_java_struct(tstruct, false); + if (tstruct->is_union()) { + generate_java_union(tstruct); + } else { + generate_java_struct(tstruct, false); + } } /** @@ -624,6 +647,345 @@ void t_java_generator::generate_java_struct(t_struct* tstruct, f_struct.close(); } +/** + * Java union definition. + * + * @param tstruct The struct definition + */ +void t_java_generator::generate_java_union(t_struct* tstruct) { + // Make output file + string f_struct_name = package_dir_+"/"+(tstruct->get_name())+".java"; + ofstream f_struct; + f_struct.open(f_struct_name.c_str()); + + f_struct << + autogen_comment() << + java_package() << + java_type_imports() << + java_thrift_imports(); + + generate_java_doc(f_struct, tstruct); + + bool is_final = (tstruct->annotations_.find("final") != tstruct->annotations_.end()); + + indent(f_struct) << + "public " << (is_final ? "final " : "") << "class " << tstruct->get_name() + << " extends TUnion "; + + if (is_comparable(tstruct)) { + f_struct << "implements Comparable<" << type_name(tstruct) << "> "; + } + + scope_up(f_struct); + + generate_struct_desc(f_struct, tstruct); + generate_field_descs(f_struct, tstruct); + + f_struct << endl; + + generate_field_name_constants(f_struct, tstruct); + + f_struct << endl; + + generate_java_meta_data_map(f_struct, tstruct); + + generate_union_constructor(f_struct, tstruct); + + f_struct << endl; + + generate_union_abstract_methods(f_struct, tstruct); + + f_struct << endl; + + generate_union_getters_and_setters(f_struct, tstruct); + + f_struct << endl; + + generate_union_comparisons(f_struct, tstruct); + + f_struct << endl; + + generate_union_hashcode(f_struct, tstruct); + + f_struct << endl; + + scope_down(f_struct); + + f_struct.close(); +} + +void t_java_generator::generate_union_constructor(ofstream& out, t_struct* tstruct) { + indent(out) << "public " << type_name(tstruct) << "() {" << endl; + indent(out) << " super();" << endl; + indent(out) << "}" << endl << endl; + + indent(out) << "public " << type_name(tstruct) << "(int setField, Object value) {" << endl; + indent(out) << " super(setField, value);" << endl; + indent(out) << "}" << endl << endl; + + indent(out) << "public " << type_name(tstruct) << "(" << type_name(tstruct) << " other) {" << endl; + indent(out) << " super(other);" << endl; + indent(out) << "}" << endl; + + indent(out) << "public " << tstruct->get_name() << " deepCopy() {" << endl; + indent(out) << " return new " << tstruct->get_name() << "(this);" << endl; + indent(out) << "}" << endl << endl; + + // generate "constructors" for each field + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + indent(out) << "public static " << type_name(tstruct) << " " << (*m_iter)->get_name() << "(" << type_name((*m_iter)->get_type()) << " value) {" << endl; + indent(out) << " " << type_name(tstruct) << " x = new " << type_name(tstruct) << "();" << endl; + indent(out) << " x.set" << get_cap_name((*m_iter)->get_name()) << "(value);" << endl; + indent(out) << " return x;" << endl; + indent(out) << "}" << endl << endl; + } +} + +void t_java_generator::generate_union_getters_and_setters(ofstream& out, t_struct* tstruct) { + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + + bool first = true; + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + if (first) { + first = false; + } else { + out << endl; + } + + t_field* field = (*m_iter); + + generate_java_doc(out, field); + indent(out) << "public " << type_name(field->get_type()) << " get" << get_cap_name(field->get_name()) << "() {" << endl; + indent(out) << " if (getSetField() == " << upcase_string(field->get_name()) << ") {" << endl; + indent(out) << " return (" << type_name(field->get_type(), true) << ")getFieldValue();" << endl; + indent(out) << " } else {" << endl; + indent(out) << " throw new RuntimeException(\"Cannot get field '" << field->get_name() + << "' because union is currently set to \" + getFieldDesc(getSetField()).name);" << endl; + indent(out) << " }" << endl; + indent(out) << "}" << endl; + + out << endl; + + generate_java_doc(out, field); + indent(out) << "public void set" << get_cap_name(field->get_name()) << "(" << type_name(field->get_type()) << " value) {" << endl; + if (type_can_be_null(field->get_type())) { + indent(out) << " if (value == null) throw new NullPointerException();" << endl; + } + indent(out) << " setField_ = " << upcase_string(field->get_name()) << ";" << endl; + indent(out) << " value_ = value;" << endl; + indent(out) << "}" << endl; + } +} + +void t_java_generator::generate_union_abstract_methods(ofstream& out, t_struct* tstruct) { + generate_check_type(out, tstruct); + out << endl; + generate_read_value(out, tstruct); + out << endl; + generate_write_value(out, tstruct); + out << endl; + generate_get_field_desc(out, tstruct); + out << endl; + generate_get_struct_desc(out, tstruct); +} + +void t_java_generator::generate_check_type(ofstream& out, t_struct* tstruct) { + indent(out) << "@Override" << endl; + indent(out) << "protected void checkType(short setField, Object value) throws ClassCastException {" << endl; + indent_up(); + + indent(out) << "switch (setField) {" << endl; + indent_up(); + + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + t_field* field = (*m_iter); + + indent(out) << "case " << upcase_string(field->get_name()) << ":" << endl; + indent(out) << " if (value instanceof " << type_name(field->get_type(), true, false, true) << ") {" << endl; + indent(out) << " break;" << endl; + indent(out) << " }" << endl; + indent(out) << " throw new ClassCastException(\"Was expecting value of type " + << type_name(field->get_type(), true, false) << " for field '" << field->get_name() + << "', but got \" + value.getClass().getSimpleName());" << endl; + // do the real check here + } + + indent(out) << "default:" << endl; + indent(out) << " throw new IllegalArgumentException(\"Unknown field id \" + setField);" << endl; + + indent_down(); + indent(out) << "}" << endl; + + indent_down(); + indent(out) << "}" << endl; +} + +void t_java_generator::generate_read_value(ofstream& out, t_struct* tstruct) { + indent(out) << "@Override" << endl; + indent(out) << "protected Object readValue(TProtocol iprot, TField field) throws TException {" << endl; + + indent_up(); + + indent(out) << "switch (field.id) {" << endl; + indent_up(); + + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + t_field* field = (*m_iter); + + indent(out) << "case " << upcase_string(field->get_name()) << ":" << endl; + indent_up(); + indent(out) << "if (field.type == " << upcase_string(field->get_name()) << "_FIELD_DESC.type) {" << endl; + indent_up(); + indent(out) << type_name(field->get_type(), true, false) << " " << field->get_name() << ";" << endl; + generate_deserialize_field(out, field, ""); + indent(out) << "return " << field->get_name() << ";" << endl; + indent_down(); + indent(out) << "} else {" << endl; + indent(out) << " TProtocolUtil.skip(iprot, field.type);" << endl; + indent(out) << " return null;" << endl; + indent(out) << "}" << endl; + indent_down(); + } + + indent(out) << "default:" << endl; + indent(out) << " TProtocolUtil.skip(iprot, field.type);" << endl; + indent(out) << " return null;" << endl; + + indent_down(); + indent(out) << "}" << endl; + + indent_down(); + indent(out) << "}" << endl; +} + +void t_java_generator::generate_write_value(ofstream& out, t_struct* tstruct) { + indent(out) << "@Override" << endl; + indent(out) << "protected void writeValue(TProtocol oprot, short setField, Object value) throws TException {" << endl; + + indent_up(); + + indent(out) << "switch (setField) {" << endl; + indent_up(); + + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + t_field* field = (*m_iter); + + indent(out) << "case " << upcase_string(field->get_name()) << ":" << endl; + indent_up(); + indent(out) << type_name(field->get_type(), true, false) << " " << field->get_name() + << " = (" << type_name(field->get_type(), true, false) << ")getFieldValue();" << endl; + generate_serialize_field(out, field, ""); + indent(out) << "return;" << endl; + indent_down(); + } + + indent(out) << "default:" << endl; + indent(out) << " throw new IllegalStateException(\"Cannot write union with unknown field \" + setField);" << endl; + + indent_down(); + indent(out) << "}" << endl; + + indent_down(); + + + + indent(out) << "}" << endl; +} + +void t_java_generator::generate_get_field_desc(ofstream& out, t_struct* tstruct) { + indent(out) << "@Override" << endl; + indent(out) << "protected TField getFieldDesc(int setField) {" << endl; + indent_up(); + + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + + indent(out) << "switch (setField) {" << endl; + indent_up(); + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + t_field* field = (*m_iter); + indent(out) << "case " << upcase_string(field->get_name()) << ":" << endl; + indent(out) << " return " << upcase_string(field->get_name()) << "_FIELD_DESC;" << endl; + } + + indent(out) << "default:" << endl; + indent(out) << " throw new IllegalArgumentException(\"Unknown field id \" + setField);" << endl; + + indent_down(); + indent(out) << "}" << endl; + + indent_down(); + indent(out) << "}" << endl; +} + +void t_java_generator::generate_get_struct_desc(ofstream& out, t_struct* tstruct) { + indent(out) << "@Override" << endl; + indent(out) << "protected TStruct getStructDesc() {" << endl; + indent(out) << " return STRUCT_DESC;" << endl; + indent(out) << "}" << endl; +} + +void t_java_generator::generate_union_comparisons(ofstream& out, t_struct* tstruct) { + // equality + indent(out) << "public boolean equals(Object other) {" << endl; + indent(out) << " if (other instanceof " << tstruct->get_name() << ") {" << endl; + indent(out) << " return equals((" << tstruct->get_name() << ")other);" << endl; + indent(out) << " } else {" << endl; + indent(out) << " return false;" << endl; + indent(out) << " }" << endl; + indent(out) << "}" << endl; + + out << endl; + + indent(out) << "public boolean equals(" << tstruct->get_name() << " other) {" << endl; + indent(out) << " return getSetField() == other.getSetField() && ((value_ instanceof byte[]) ? " << endl; + indent(out) << " Arrays.equals((byte[])getFieldValue(), (byte[])other.getFieldValue()) : getFieldValue().equals(other.getFieldValue()));" << endl; + indent(out) << "}" << endl; + out << endl; + + if (is_comparable(tstruct)) { + indent(out) << "@Override" << endl; + indent(out) << "public int compareTo(" << type_name(tstruct) << " other) {" << endl; + indent(out) << " int lastComparison = TBaseHelper.compareTo(getSetField(), other.getSetField());" << endl; + indent(out) << " if (lastComparison != 0) {" << endl; + indent(out) << " return TBaseHelper.compareTo((Comparable)getFieldValue(), (Comparable)other.getFieldValue());" << endl; + indent(out) << " }" << endl; + indent(out) << " return lastComparison;" << endl; + indent(out) << "}" << endl; + out << endl; + } +} + +void t_java_generator::generate_union_hashcode(ofstream& out, t_struct* tstruct) { + if (gen_hash_code_) { + indent(out) << "@Override" << endl; + indent(out) << "public int hashCode() {" << endl; + indent(out) << " return new HashCodeBuilder().append(getSetField()).append(getFieldValue()).toHashCode();" << endl; + indent(out) << "}"; + } else { + indent(out) << "/**" << endl; + indent(out) << " * If you'd like this to perform more respectably, use the hashcode generator option." << endl; + indent(out) << " */" << endl; + indent(out) << "@Override" << endl; + indent(out) << "public int hashCode() {" << endl; + indent(out) << " return 0;" << endl; + indent(out) << "}" << endl; + } +} + /** * Java struct definition. This has various parameters, as it could be * generated standalone or inside another class as a helper. If it @@ -660,21 +1022,14 @@ void t_java_generator::generate_java_struct_definition(ofstream &out, scope_up(out); - indent(out) << - "private static final TStruct STRUCT_DESC = new TStruct(\"" << tstruct->get_name() << "\");" << endl; + generate_struct_desc(out, tstruct); // Members are public for -java, private for -javabean const vector& members = tstruct->get_members(); vector::const_iterator m_iter; - for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { - indent(out) << - "private static final TField " << constant_name((*m_iter)->get_name()) << - "_FIELD_DESC = new TField(\"" << (*m_iter)->get_name() << "\", " << - type_to_enum((*m_iter)->get_type()) << ", " << - "(short)" << (*m_iter)->get_key() << ");" << endl; - } - + generate_field_descs(out, tstruct); + out << endl; for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { @@ -685,9 +1040,9 @@ void t_java_generator::generate_java_struct_definition(ofstream &out, indent(out) << "public "; } out << declare_field(*m_iter, false) << endl; - - indent(out) << "public static final int " << upcase_string((*m_iter)->get_name()) << " = " << (*m_iter)->get_key() << ";" << endl; } + + generate_field_name_constants(out, tstruct); // isset data if (members.size() > 0) { @@ -712,13 +1067,6 @@ void t_java_generator::generate_java_struct_definition(ofstream &out, } generate_java_meta_data_map(out, tstruct); - - // Static initializer to populate global class to struct metadata map - indent(out) << "static {" << endl; - indent_up(); - indent(out) << "FieldMetaData.addStructMetaDataMap(" << type_name(tstruct) << ".class, metaDataMap);" << endl; - indent_down(); - indent(out) << "}" << endl << endl; // Default constructor indent(out) << @@ -803,7 +1151,11 @@ void t_java_generator::generate_java_struct_definition(ofstream &out, indent(out) << "}" << endl << endl; // clone method, so that you can deep copy an object when you don't know its class. - indent(out) << "@Override" << endl; + indent(out) << "public " << tstruct->get_name() << " deepCopy() {" << endl; + indent(out) << " return new " << tstruct->get_name() << "(this);" << endl; + indent(out) << "}" << endl << endl; + + indent(out) << "@Deprecated" << endl; indent(out) << "public " << tstruct->get_name() << " clone() {" << endl; indent(out) << " return new " << tstruct->get_name() << "(this);" << endl; indent(out) << "}" << endl << endl; @@ -1562,6 +1914,13 @@ void t_java_generator::generate_java_meta_data_map(ofstream& out, } indent_down(); indent(out) << "}});" << endl << endl; + + // Static initializer to populate global class to struct metadata map + indent(out) << "static {" << endl; + indent_up(); + indent(out) << "FieldMetaData.addStructMetaDataMap(" << type_name(tstruct) << ".class, metaDataMap);" << endl; + indent_down(); + indent(out) << "}" << endl << endl; } /** @@ -2588,7 +2947,7 @@ void t_java_generator::generate_serialize_list_element(ofstream& out, * @param container Is the type going inside a container? * @return Java type name, i.e. HashMap */ -string t_java_generator::type_name(t_type* ttype, bool in_container, bool in_init) { +string t_java_generator::type_name(t_type* ttype, bool in_container, bool in_init, bool skip_generic) { // In Java typedefs are just resolved to their real type ttype = get_true_type(ttype); string prefix; @@ -2604,25 +2963,25 @@ string t_java_generator::type_name(t_type* ttype, bool in_container, bool in_ini } else { prefix = "Map"; } - return prefix + "<" + + return prefix + (skip_generic ? "" : "<" + type_name(tmap->get_key_type(), true) + "," + - type_name(tmap->get_val_type(), true) + ">"; + type_name(tmap->get_val_type(), true) + ">"); } else if (ttype->is_set()) { t_set* tset = (t_set*) ttype; if (in_init) { - prefix = "HashSet<"; + prefix = "HashSet"; } else { - prefix = "Set<"; + prefix = "Set"; } - return prefix + type_name(tset->get_elem_type(), true) + ">"; + return prefix + (skip_generic ? "" : "<" + type_name(tset->get_elem_type(), true) + ">"); } else if (ttype->is_list()) { t_list* tlist = (t_list*) ttype; if (in_init) { - prefix = "ArrayList<"; + prefix = "ArrayList"; } else { - prefix = "List<"; + prefix = "List"; } - return prefix + type_name(tlist->get_elem_type(), true) + ">"; + return prefix + (skip_generic ? "" : "<" + type_name(tlist->get_elem_type(), true) + ">"); } // Check for namespacing @@ -3023,6 +3382,34 @@ std::string t_java_generator::get_enum_class_name(t_type* type) { return package + type->get_name(); } +void t_java_generator::generate_struct_desc(ofstream& out, t_struct* tstruct) { + indent(out) << + "private static final TStruct STRUCT_DESC = new TStruct(\"" << tstruct->get_name() << "\");" << endl; +} + +void t_java_generator::generate_field_descs(ofstream& out, t_struct* tstruct) { + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + indent(out) << + "private static final TField " << constant_name((*m_iter)->get_name()) << + "_FIELD_DESC = new TField(\"" << (*m_iter)->get_name() << "\", " << + type_to_enum((*m_iter)->get_type()) << ", " << + "(short)" << (*m_iter)->get_key() << ");" << endl; + } +} + +void t_java_generator::generate_field_name_constants(ofstream& out, t_struct* tstruct) { + // Members are public for -java, private for -javabean + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + indent(out) << "public static final int " << upcase_string((*m_iter)->get_name()) << " = " << (*m_iter)->get_key() << ";" << endl; + } +} + bool t_java_generator::is_comparable(t_struct* tstruct) { const vector& members = tstruct->get_members(); vector::const_iterator m_iter; @@ -3066,3 +3453,4 @@ THRIFT_REGISTER_GENERATOR(java, "Java", " nocamel: Do not use CamelCase field accessors with beans.\n" " hashcode: Generate quality hashCode methods.\n" ); + diff --git a/compiler/cpp/src/parse/t_struct.h b/compiler/cpp/src/parse/t_struct.h index 7980f803..76c24f23 100644 --- a/compiler/cpp/src/parse/t_struct.h +++ b/compiler/cpp/src/parse/t_struct.h @@ -58,6 +58,10 @@ class t_struct : public t_type { is_xception_ = is_xception; } + void set_union(bool is_union) { + is_union_ = is_union; + } + void set_xsd_all(bool xsd_all) { xsd_all_ = xsd_all; } @@ -95,6 +99,10 @@ class t_struct : public t_type { bool is_xception() const { return is_xception_; } + + bool is_union() const { + return is_union_; + } virtual std::string get_fingerprint_material() const { std::string rv = "{"; @@ -120,6 +128,7 @@ class t_struct : public t_type { members_type members_; members_type members_in_id_order_; bool is_xception_; + bool is_union_; bool xsd_all_; }; diff --git a/compiler/cpp/src/thriftl.ll b/compiler/cpp/src/thriftl.ll index c563f75d..2a824d13 100644 --- a/compiler/cpp/src/thriftl.ll +++ b/compiler/cpp/src/thriftl.ll @@ -119,6 +119,7 @@ literal_begin (['\"]) "oneway" { return tok_oneway; } "typedef" { return tok_typedef; } "struct" { return tok_struct; } +"union" { return tok_union; } "exception" { return tok_xception; } "extends" { return tok_extends; } "throws" { return tok_throws; } @@ -197,7 +198,6 @@ literal_begin (['\"]) "volatile" { thrift_reserved_keyword(yytext); } "while" { thrift_reserved_keyword(yytext); } "with" { thrift_reserved_keyword(yytext); } -"union" { thrift_reserved_keyword(yytext); } "yield" { thrift_reserved_keyword(yytext); } {intconstant} { diff --git a/compiler/cpp/src/thrifty.yy b/compiler/cpp/src/thrifty.yy index 026e5c3a..8f6e167c 100644 --- a/compiler/cpp/src/thrifty.yy +++ b/compiler/cpp/src/thrifty.yy @@ -42,6 +42,8 @@ */ int y_field_val = -1; int g_arglist = 0; +const int struct_is_struct = 0; +const int struct_is_union = 1; %} @@ -148,6 +150,7 @@ int g_arglist = 0; %token tok_const %token tok_required %token tok_optional +%token tok_union /** * Grammar nodes @@ -193,6 +196,7 @@ int g_arglist = 0; %type ConstMap %type ConstMapContents +%type StructHead %type Struct %type Xception %type Service @@ -679,11 +683,22 @@ ConstMapContents: $$->set_map(); } +StructHead: + tok_struct + { + $$ = struct_is_struct; + } +| tok_union + { + $$ = struct_is_union; + } + Struct: - tok_struct tok_identifier XsdAll '{' FieldList '}' TypeAnnotations + StructHead tok_identifier XsdAll '{' FieldList '}' TypeAnnotations { pdebug("Struct -> tok_struct tok_identifier { FieldList }"); $5->set_xsd_all($3); + $5->set_union($1 == struct_is_union); $$ = $5; $$->set_name($2); if ($7 != NULL) { @@ -691,7 +706,7 @@ Struct: delete $7; } } - + XsdAll: tok_xsd_all { @@ -1137,4 +1152,4 @@ TypeAnnotation: $$->val = $3; } -%% +%% \ No newline at end of file diff --git a/lib/java/build.xml b/lib/java/build.xml index de9b018a..dbbaf6b5 100644 --- a/lib/java/build.xml +++ b/lib/java/build.xml @@ -180,6 +180,8 @@ classpathref="test.classpath" failonerror="true" /> + diff --git a/lib/java/src/org/apache/thrift/TBase.java b/lib/java/src/org/apache/thrift/TBase.java index 7c8978a2..3c3b12fe 100644 --- a/lib/java/src/org/apache/thrift/TBase.java +++ b/lib/java/src/org/apache/thrift/TBase.java @@ -19,13 +19,15 @@ package org.apache.thrift; +import java.io.Serializable; + import org.apache.thrift.protocol.TProtocol; /** * Generic base interface for generated Thrift objects. * */ -public interface TBase extends Cloneable { +public interface TBase extends Serializable { /** * Reads the TObject from the given input protocol. @@ -63,4 +65,6 @@ public interface TBase extends Cloneable { * @param fieldId The field's id tag as found in the IDL. */ public void setFieldValue(int fieldId, Object value); + + public TBase deepCopy(); } diff --git a/lib/java/src/org/apache/thrift/TUnion.java b/lib/java/src/org/apache/thrift/TUnion.java new file mode 100644 index 00000000..9375475f --- /dev/null +++ b/lib/java/src/org/apache/thrift/TUnion.java @@ -0,0 +1,192 @@ +package org.apache.thrift; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.thrift.protocol.TField; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolException; +import org.apache.thrift.protocol.TStruct; + +public abstract class TUnion implements TBase { + + protected Object value_; + protected int setField_; + + protected TUnion() { + setField_ = 0; + value_ = null; + } + + protected TUnion(int setField, Object value) { + setFieldValue(setField, value); + } + + protected TUnion(TUnion other) { + if (!other.getClass().equals(this.getClass())) { + throw new ClassCastException(); + } + setField_ = other.setField_; + value_ = deepCopyObject(other.value_); + } + + private static Object deepCopyObject(Object o) { + if (o instanceof TBase) { + return ((TBase)o).deepCopy(); + } else if (o instanceof byte[]) { + byte[] other_val = (byte[])o; + byte[] this_val = new byte[other_val.length]; + System.arraycopy(other_val, 0, this_val, 0, other_val.length); + return this_val; + } else if (o instanceof List) { + return deepCopyList((List)o); + } else if (o instanceof Set) { + return deepCopySet((Set)o); + } else if (o instanceof Map) { + return deepCopyMap((Map)o); + } else { + return o; + } + } + + private static Map deepCopyMap(Map map) { + Map copy = new HashMap(); + for (Map.Entry entry : map.entrySet()) { + copy.put(deepCopyObject(entry.getKey()), deepCopyObject(entry.getValue())); + } + return copy; + } + + private static Set deepCopySet(Set set) { + Set copy = new HashSet(); + for (Object o : set) { + copy.add(deepCopyObject(o)); + } + return copy; + } + + private static List deepCopyList(List list) { + List copy = new ArrayList(list.size()); + for (Object o : list) { + copy.add(deepCopyObject(o)); + } + return copy; + } + + public int getSetField() { + return setField_; + } + + public Object getFieldValue() { + return value_; + } + + public Object getFieldValue(int fieldId) { + if (fieldId != setField_) { + throw new IllegalArgumentException("Cannot get the value of field " + fieldId + " because union's set field is " + setField_); + } + + return getFieldValue(); + } + + public boolean isSet() { + return setField_ != 0; + } + + public boolean isSet(int fieldId) { + return setField_ == fieldId; + } + + public void read(TProtocol iprot) throws TException { + setField_ = 0; + value_ = null; + + iprot.readStructBegin(); + + TField field = iprot.readFieldBegin(); + + value_ = readValue(iprot, field); + if (value_ != null) { + setField_ = field.id; + } + + iprot.readFieldEnd(); + // this is so that we will eat the stop byte. we could put a check here to + // make sure that it actually *is* the stop byte, but it's faster to do it + // this way. + iprot.readFieldBegin(); + iprot.readStructEnd(); + } + + public void setFieldValue(int fieldId, Object value) { + checkType((short)fieldId, value); + setField_ = (short)fieldId; + value_ = value; + } + + public void write(TProtocol oprot) throws TException { + if (getSetField() == 0 || getFieldValue() == null) { + throw new TProtocolException("Cannot write a TUnion with no set value!"); + } + oprot.writeStructBegin(getStructDesc()); + oprot.writeFieldBegin(getFieldDesc(setField_)); + writeValue(oprot, (short)setField_, value_); + oprot.writeFieldEnd(); + oprot.writeFieldStop(); + oprot.writeStructEnd(); + } + + /** + * Implementation should be generated so that we can efficiently type check + * various values. + * @param setField + * @param value + */ + protected abstract void checkType(short setField, Object value) throws ClassCastException; + + /** + * Implementation should be generated to read the right stuff from the wire + * based on the field header. + * @param field + * @return + */ + protected abstract Object readValue(TProtocol iprot, TField field) throws TException; + + protected abstract void writeValue(TProtocol oprot, short setField, Object value) throws TException; + + protected abstract TStruct getStructDesc(); + + protected abstract TField getFieldDesc(int setField); + + @Override + public String toString() { + Object v = getFieldValue(); + String vStr = null; + if (v instanceof byte[]) { + vStr = bytesToStr((byte[])v); + } else { + vStr = v.toString(); + } + return "<" + this.getClass().getSimpleName() + " " + getFieldDesc(getSetField()).name + ":" + vStr + ">"; + } + + private static String bytesToStr(byte[] bytes) { + StringBuilder sb = new StringBuilder(); + int size = Math.min(bytes.length, 128); + for (int i = 0; i < size; i++) { + if (i != 0) { + sb.append(" "); + } + String digit = Integer.toHexString(bytes[i]); + sb.append(digit.length() > 1 ? digit : "0" + digit); + } + if (bytes.length > 128) { + sb.append(" ..."); + } + return sb.toString(); + } +} diff --git a/lib/java/test/org/apache/thrift/test/UnionTest.java b/lib/java/test/org/apache/thrift/test/UnionTest.java new file mode 100644 index 00000000..85be6993 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/UnionTest.java @@ -0,0 +1,131 @@ +package org.apache.thrift.test; + +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.transport.TMemoryBuffer; + +import thrift.test.Empty; +import thrift.test.StructWithAUnion; +import thrift.test.TestUnion; + +public class UnionTest { + + /** + * @param args + */ + public static void main(String[] args) throws Exception { + testBasic(); + testEquality(); + testSerialization(); + } + + public static void testBasic() throws Exception { + TestUnion union = new TestUnion(); + + if (union.isSet()) { + throw new RuntimeException("new union with default constructor counts as set!"); + } + + if (union.getFieldValue() != null) { + throw new RuntimeException("unset union didn't return null for value"); + } + + union = new TestUnion(TestUnion.I32_FIELD, 25); + + if ((Integer)union.getFieldValue() != 25) { + throw new RuntimeException("set i32 field didn't come out as planned"); + } + + if ((Integer)union.getFieldValue(TestUnion.I32_FIELD) != 25) { + throw new RuntimeException("set i32 field didn't come out of TBase getFieldValue"); + } + + try { + union.getFieldValue(TestUnion.STRING_FIELD); + throw new RuntimeException("was expecting an exception around wrong set field"); + } catch (IllegalArgumentException e) { + // cool! + } + + System.out.println(union); + + union = new TestUnion(); + union.setI32_field(1); + if (union.getI32_field() != 1) { + throw new RuntimeException("didn't get the right value for i32 field!"); + } + + try { + union.getString_field(); + throw new RuntimeException("should have gotten an exception"); + } catch (Exception e) { + // sweet + } + + } + + + public static void testEquality() throws Exception { + TestUnion union = new TestUnion(TestUnion.I32_FIELD, 25); + + TestUnion otherUnion = new TestUnion(TestUnion.STRING_FIELD, "blah!!!"); + + if (union.equals(otherUnion)) { + throw new RuntimeException("shouldn't be equal"); + } + + otherUnion = new TestUnion(TestUnion.I32_FIELD, 400); + + if (union.equals(otherUnion)) { + throw new RuntimeException("shouldn't be equal"); + } + + otherUnion = new TestUnion(TestUnion.OTHER_I32_FIELD, 25); + + if (union.equals(otherUnion)) { + throw new RuntimeException("shouldn't be equal"); + } + } + + + public static void testSerialization() throws Exception { + TestUnion union = new TestUnion(TestUnion.I32_FIELD, 25); + + TMemoryBuffer buf = new TMemoryBuffer(0); + TProtocol proto = new TBinaryProtocol(buf); + + union.write(proto); + + TestUnion u2 = new TestUnion(); + + u2.read(proto); + + if (!u2.equals(union)) { + throw new RuntimeException("serialization fails!"); + } + + StructWithAUnion swau = new StructWithAUnion(u2); + + buf = new TMemoryBuffer(0); + proto = new TBinaryProtocol(buf); + + swau.write(proto); + + StructWithAUnion swau2 = new StructWithAUnion(); + if (swau2.equals(swau)) { + throw new RuntimeException("objects match before they are supposed to!"); + } + swau2.read(proto); + if (!swau2.equals(swau)) { + throw new RuntimeException("objects don't match when they are supposed to!"); + } + + // this should NOT throw an exception. + buf = new TMemoryBuffer(0); + proto = new TBinaryProtocol(buf); + + swau.write(proto); + new Empty().read(proto); + + } +} diff --git a/test/DebugProtoTest.thrift b/test/DebugProtoTest.thrift index d3d25802..9ba60a26 100644 --- a/test/DebugProtoTest.thrift +++ b/test/DebugProtoTest.thrift @@ -250,4 +250,19 @@ struct ReverseOrderStruct { service ReverseOrderService { void myMethod(4: string first, 3: i16 second, 2: i32 third, 1: i64 fourth); +} + +union TestUnion { + /** + * A doc string + */ + 1: string string_field; + 2: i32 i32_field; + 3: OneOfEach struct_field; + 4: list struct_list; + 5: i32 other_i32_field; +} + +struct StructWithAUnion { + 1: TestUnion test_union; } \ No newline at end of file -- 2.17.1