From 33e190cd150c326ac833c435f975c2e737cff74f Mon Sep 17 00:00:00 2001 From: Bryan Duxbury Date: Tue, 16 Feb 2010 21:19:01 +0000 Subject: [PATCH] THRIFT-697. Union support in Ruby git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@910700 13f79535-47bb-0310-9956-ffa450edef68 --- compiler/cpp/src/generate/t_rb_generator.cc | 115 +++++++++-- lib/rb/ext/struct.c | 185 +++++++++++++----- lib/rb/ext/struct.h | 5 +- lib/rb/ext/thrift_native.c | 6 +- lib/rb/lib/thrift.rb | 2 + .../protocol/binary_protocol_accelerated.rb | 6 +- lib/rb/lib/thrift/struct.rb | 108 +--------- lib/rb/lib/thrift/struct_union.rb | 126 ++++++++++++ lib/rb/lib/thrift/types.rb | 2 +- lib/rb/lib/thrift/union.rb | 128 ++++++++++++ lib/rb/spec/ThriftSpec.thrift | 32 +++ .../spec/binary_protocol_accelerated_spec.rb | 31 +-- lib/rb/spec/binary_protocol_spec_shared.rb | 4 +- lib/rb/spec/union_spec.rb | 145 ++++++++++++++ 14 files changed, 713 insertions(+), 182 deletions(-) create mode 100644 lib/rb/lib/thrift/struct_union.rb create mode 100644 lib/rb/lib/thrift/union.rb create mode 100644 lib/rb/spec/union_spec.rb diff --git a/compiler/cpp/src/generate/t_rb_generator.cc b/compiler/cpp/src/generate/t_rb_generator.cc index 55db0f7e..47fd8ec5 100644 --- a/compiler/cpp/src/generate/t_rb_generator.cc +++ b/compiler/cpp/src/generate/t_rb_generator.cc @@ -68,6 +68,7 @@ class t_rb_generator : public t_oop_generator { void generate_enum (t_enum* tenum); void generate_const (t_const* tconst); void generate_struct (t_struct* tstruct); + void generate_union (t_struct* tunion); void generate_xception (t_struct* txception); void generate_service (t_service* tservice); @@ -79,11 +80,15 @@ class t_rb_generator : public t_oop_generator { void generate_rb_struct(std::ofstream& out, t_struct* tstruct, bool is_exception); void generate_rb_struct_required_validator(std::ofstream& out, t_struct* tstruct); + void generate_rb_union(std::ofstream& out, t_struct* tstruct, bool is_exception); + void generate_rb_union_validator(std::ofstream& out, t_struct* tstruct); void generate_rb_function_helpers(t_function* tfunction); void generate_rb_simple_constructor(std::ofstream& out, t_struct* tstruct); void generate_rb_simple_exception_constructor(std::ofstream& out, t_struct* tstruct); void generate_field_constants (std::ofstream& out, t_struct* tstruct); - void generate_accessors (std::ofstream& out, t_struct* tstruct); + void generate_field_constructors (std::ofstream& out, t_struct* tstruct); + void generate_struct_accessors (std::ofstream& out, t_struct* tstruct); + void generate_union_accessors (std::ofstream& out, t_struct* tstruct); void generate_field_defns (std::ofstream& out, t_struct* tstruct); void generate_field_data (std::ofstream& out, t_type* field_type, const std::string& field_name, t_const_value* field_value, bool optional); @@ -461,7 +466,11 @@ string t_rb_generator::render_const_value(t_type* type, t_const_value* value) { * Generates a ruby struct */ void t_rb_generator::generate_struct(t_struct* tstruct) { - generate_rb_struct(f_types_, tstruct, false); + if (tstruct->is_union()) { + generate_rb_union(f_types_, tstruct, false); + } else { + generate_rb_struct(f_types_, tstruct, false); + } } /** @@ -486,14 +495,14 @@ void t_rb_generator::generate_rb_struct(std::ofstream& out, t_struct* tstruct, b out << endl; indent_up(); - indent(out) << "include ::Thrift::Struct" << endl; + indent(out) << "include ::Thrift::Struct, ::Thrift::Struct_Union" << endl; if (is_exception) { generate_rb_simple_exception_constructor(out, tstruct); } generate_field_constants(out, tstruct); - generate_accessors(out, tstruct); + generate_struct_accessors(out, tstruct); generate_field_defns(out, tstruct); generate_rb_struct_required_validator(out, tstruct); @@ -501,6 +510,53 @@ void t_rb_generator::generate_rb_struct(std::ofstream& out, t_struct* tstruct, b indent(out) << "end" << endl << endl; } + +/** + * Generates a ruby union + */ +void t_rb_generator::generate_rb_union(std::ofstream& out, t_struct* tstruct, bool is_exception = false) { + generate_rdoc(out, tstruct); + indent(out) << "class " << type_name(tstruct) << " < ::Thrift::Union" << endl; + + indent_up(); + indent(out) << "include ::Thrift::Struct_Union" << endl; + + generate_field_constructors(out, tstruct); + + generate_field_constants(out, tstruct); + generate_union_accessors(out, tstruct); + generate_field_defns(out, tstruct); + generate_rb_union_validator(out, tstruct); + + indent_down(); + indent(out) << "end" << endl << endl; +} + +void t_rb_generator::generate_field_constructors(std::ofstream& out, t_struct* tstruct) { + + indent(out) << "class << self" << endl; + indent_up(); + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (f_iter != fields.begin()) { + out << endl; + } + std::string field_name = (*f_iter)->get_name(); + + indent(out) << "def " << field_name << "(val)" << endl; + indent(out) << " " << tstruct->get_name() << ".new(:" << field_name << ", val)" << endl; + indent(out) << "end" << endl; + } + + indent_down(); + indent(out) << "end" << endl; + + out << endl; +} + void t_rb_generator::generate_rb_simple_exception_constructor(std::ofstream& out, t_struct* tstruct) { const vector& members = tstruct->get_members(); @@ -537,7 +593,7 @@ void t_rb_generator::generate_field_constants(std::ofstream& out, t_struct* tstr out << endl; } -void t_rb_generator::generate_accessors(std::ofstream& out, t_struct* tstruct) { +void t_rb_generator::generate_struct_accessors(std::ofstream& out, t_struct* tstruct) { const vector& members = tstruct->get_members(); vector::const_iterator m_iter; @@ -550,6 +606,19 @@ void t_rb_generator::generate_accessors(std::ofstream& out, t_struct* tstruct) { } } +void t_rb_generator::generate_union_accessors(std::ofstream& out, t_struct* tstruct) { + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + + if (members.size() > 0) { + indent(out) << "field_accessor self"; + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + out << ", :" << (*m_iter)->get_name(); + } + out << endl; + } +} + void t_rb_generator::generate_field_defns(std::ofstream& out, t_struct* tstruct) { const vector& fields = tstruct->get_members(); vector::const_iterator f_iter; @@ -1080,7 +1149,7 @@ void t_rb_generator::generate_rb_struct_required_validator(std::ofstream& out, t_struct* tstruct) { indent(out) << "def validate" << endl; indent_up(); - + const vector& fields = tstruct->get_members(); vector::const_iterator f_iter; @@ -1096,11 +1165,11 @@ void t_rb_generator::generate_rb_struct_required_validator(std::ofstream& out, out << endl; } } - + // if field is an enum, check that its value is valid for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { t_field* field = (*f_iter); - + if (field->get_type()->is_enum()){ indent(out) << "unless @" << field->get_name() << ".nil? || " << full_type_name(field->get_type()) << "::VALID_VALUES.include?(@" << field->get_name() << ")" << endl; indent_up(); @@ -1108,11 +1177,35 @@ void t_rb_generator::generate_rb_struct_required_validator(std::ofstream& out, indent_down(); indent(out) << "end" << endl; } - } - + } + + indent_down(); + indent(out) << "end" << endl << endl; +} + +void t_rb_generator::generate_rb_union_validator(std::ofstream& out, + t_struct* tstruct) { + indent(out) << "def validate" << endl; + indent_up(); + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + + indent(out) << "raise(StandardError, 'Union fields are not set.') if get_set_field.nil? || get_value.nil?" << endl; + + // if field is an enum, check that its value is valid + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + const t_field* field = (*f_iter); + + if (field->get_type()->is_enum()){ + indent(out) << "if get_set_field == :" << field->get_name() << endl; + indent(out) << " raise ::Thrift::ProtocolException.new(::Thrift::ProtocolException::UNKNOWN, 'Invalid value of field " << field->get_name() << "!') unless " << full_type_name(field->get_type()) << "::VALID_VALUES.include?(get_value)" << endl; + indent(out) << "end" << endl; + } + } + indent_down(); indent(out) << "end" << endl << endl; - } THRIFT_REGISTER_GENERATOR(rb, "Ruby", ""); diff --git a/lib/rb/ext/struct.c b/lib/rb/ext/struct.c index 7429fb1a..d459ddb7 100644 --- a/lib/rb/ext/struct.c +++ b/lib/rb/ext/struct.c @@ -45,29 +45,18 @@ strlcpy (char *dst, const char *src, size_t dst_sz) static native_proto_method_table *mt; static native_proto_method_table *default_mt; -// static VALUE last_proto_class = Qnil; + +VALUE thrift_union_class; + +ID setfield_id; +ID setvalue_id; + +ID to_s_method_id; +ID name_to_id_method_id; #define IS_CONTAINER(ttype) ((ttype) == TTYPE_MAP || (ttype) == TTYPE_LIST || (ttype) == TTYPE_SET) #define STRUCT_FIELDS(obj) rb_const_get(CLASS_OF(obj), fields_const_id) -// static void set_native_proto_function_pointers(VALUE protocol) { -// VALUE method_table_object = rb_const_get(CLASS_OF(protocol), rb_intern("@native_method_table")); -// // TODO: check nil? -// Data_Get_Struct(method_table_object, native_proto_method_table, mt); -// } - -// static void check_native_proto_method_table(VALUE protocol) { -// VALUE protoclass = CLASS_OF(protocol); -// if (protoclass != last_proto_class) { -// last_proto_class = protoclass; -// if (rb_funcall(protocol, native_qmark_method_id, 0) == Qtrue) { -// set_native_proto_function_pointers(protocol); -// } else { -// mt = default_mt; -// } -// } -// } - //------------------------------------------- // Writing section //------------------------------------------- @@ -275,62 +264,62 @@ static void set_default_proto_function_pointers() { // end default protocol methods - +static VALUE rb_thrift_union_write (VALUE self, VALUE protocol); static VALUE rb_thrift_struct_write(VALUE self, VALUE protocol); static void write_anything(int ttype, VALUE value, VALUE protocol, VALUE field_info); VALUE get_field_value(VALUE obj, VALUE field_name) { char name_buf[RSTRING_LEN(field_name) + 1]; - + name_buf[0] = '@'; strlcpy(&name_buf[1], RSTRING_PTR(field_name), sizeof(name_buf)); VALUE value = rb_ivar_get(obj, rb_intern(name_buf)); - + return value; } static void write_container(int ttype, VALUE field_info, VALUE value, VALUE protocol) { int sz, i; - + if (ttype == TTYPE_MAP) { VALUE keys; VALUE key; VALUE val; Check_Type(value, T_HASH); - + VALUE key_info = rb_hash_aref(field_info, key_sym); VALUE keytype_value = rb_hash_aref(key_info, type_sym); int keytype = FIX2INT(keytype_value); - + VALUE value_info = rb_hash_aref(field_info, value_sym); VALUE valuetype_value = rb_hash_aref(value_info, type_sym); int valuetype = FIX2INT(valuetype_value); - + keys = rb_funcall(value, keys_method_id, 0); - + sz = RARRAY_LEN(keys); - + mt->write_map_begin(protocol, keytype_value, valuetype_value, INT2FIX(sz)); - + for (i = 0; i < sz; i++) { key = rb_ary_entry(keys, i); val = rb_hash_aref(value, key); - + if (IS_CONTAINER(keytype)) { write_container(keytype, key_info, key, protocol); } else { write_anything(keytype, key, protocol, key_info); } - + if (IS_CONTAINER(valuetype)) { write_container(valuetype, value_info, val, protocol); } else { write_anything(valuetype, val, protocol, value_info); } } - + mt->write_map_end(protocol); } else if (ttype == TTYPE_LIST) { Check_Type(value, T_ARRAY); @@ -340,7 +329,7 @@ static void write_container(int ttype, VALUE field_info, VALUE value, VALUE prot VALUE element_type_info = rb_hash_aref(field_info, element_sym); VALUE element_type_value = rb_hash_aref(element_type_info, type_sym); int element_type = FIX2INT(element_type_value); - + mt->write_list_begin(protocol, element_type_value, INT2FIX(sz)); for (i = 0; i < sz; ++i) { VALUE val = rb_ary_entry(value, i); @@ -370,9 +359,9 @@ static void write_container(int ttype, VALUE field_info, VALUE value, VALUE prot VALUE element_type_info = rb_hash_aref(field_info, element_sym); VALUE element_type_value = rb_hash_aref(element_type_info, type_sym); int element_type = FIX2INT(element_type_value); - + mt->write_set_begin(protocol, element_type_value, INT2FIX(sz)); - + for (i = 0; i < sz; i++) { VALUE val = rb_ary_entry(items, i); if (IS_CONTAINER(element_type)) { @@ -381,7 +370,7 @@ static void write_container(int ttype, VALUE field_info, VALUE value, VALUE prot write_anything(element_type, val, protocol, element_type_info); } } - + mt->write_set_end(protocol); } else { rb_raise(rb_eNotImpError, "can't write container of type: %d", ttype); @@ -406,7 +395,11 @@ static void write_anything(int ttype, VALUE value, VALUE protocol, VALUE field_i } else if (IS_CONTAINER(ttype)) { write_container(ttype, field_info, value, protocol); } else if (ttype == TTYPE_STRUCT) { - rb_thrift_struct_write(value, protocol); + if (rb_obj_is_kind_of(value, thrift_union_class)) { + rb_thrift_union_write(value, protocol); + } else { + rb_thrift_struct_write(value, protocol); + } } else { rb_raise(rb_eNotImpError, "Unknown type for binary_encoding: %d", ttype); } @@ -423,24 +416,27 @@ static VALUE rb_thrift_struct_write(VALUE self, VALUE protocol) { // iterate through all the fields here VALUE struct_fields = STRUCT_FIELDS(self); + VALUE struct_field_ids_unordered = rb_funcall(struct_fields, keys_method_id, 0); VALUE struct_field_ids_ordered = rb_funcall(struct_field_ids_unordered, sort_method_id, 0); int i = 0; for (i=0; i < RARRAY_LEN(struct_field_ids_ordered); i++) { VALUE field_id = rb_ary_entry(struct_field_ids_ordered, i); + VALUE field_info = rb_hash_aref(struct_fields, field_id); VALUE ttype_value = rb_hash_aref(field_info, type_sym); int ttype = FIX2INT(ttype_value); VALUE field_name = rb_hash_aref(field_info, name_sym); + VALUE field_value = get_field_value(self, field_name); if (!NIL_P(field_value)) { mt->write_field_begin(protocol, field_name, ttype_value, field_id); - + write_anything(ttype, field_value, protocol, field_info); - + mt->write_field_end(protocol); } } @@ -457,6 +453,7 @@ static VALUE rb_thrift_struct_write(VALUE self, VALUE protocol) { // Reading section //------------------------------------------- +static VALUE rb_thrift_union_read(VALUE self, VALUE protocol); static VALUE rb_thrift_struct_read(VALUE self, VALUE protocol); static void set_field_value(VALUE obj, VALUE field_name, VALUE value) { @@ -488,7 +485,12 @@ static VALUE read_anything(VALUE protocol, int ttype, VALUE field_info) { } else if (ttype == TTYPE_STRUCT) { VALUE klass = rb_hash_aref(field_info, class_sym); result = rb_class_new_instance(0, NULL, klass); - rb_thrift_struct_read(result, protocol); + + if (rb_obj_is_kind_of(result, thrift_union_class)) { + rb_thrift_union_read(result, protocol); + } else { + rb_thrift_struct_read(result, protocol); + } } else if (ttype == TTYPE_MAP) { int i; @@ -524,7 +526,6 @@ static VALUE read_anything(VALUE protocol, int ttype, VALUE field_info) { rb_ary_push(result, read_anything(protocol, element_ttype, rb_hash_aref(field_info, element_sym))); } - mt->read_list_end(protocol); } else if (ttype == TTYPE_SET) { VALUE items; @@ -539,7 +540,6 @@ static VALUE read_anything(VALUE protocol, int ttype, VALUE field_info) { rb_ary_push(items, read_anything(protocol, element_ttype, rb_hash_aref(field_info, element_sym))); } - mt->read_set_end(protocol); result = rb_class_new_instance(1, &items, rb_cSet); @@ -597,13 +597,110 @@ static VALUE rb_thrift_struct_read(VALUE self, VALUE protocol) { return Qnil; } + +// -------------------------------- +// Union section +// -------------------------------- + +static VALUE rb_thrift_union_read(VALUE self, VALUE protocol) { + // read struct begin + mt->read_struct_begin(protocol); + + VALUE struct_fields = STRUCT_FIELDS(self); + + VALUE field_header = mt->read_field_begin(protocol); + VALUE field_type_value = rb_ary_entry(field_header, 1); + int field_type = FIX2INT(field_type_value); + + // make sure we got a type we expected + VALUE field_info = rb_hash_aref(struct_fields, rb_ary_entry(field_header, 2)); + + if (!NIL_P(field_info)) { + int specified_type = FIX2INT(rb_hash_aref(field_info, type_sym)); + if (field_type == specified_type) { + // read the value + VALUE name = rb_hash_aref(field_info, name_sym); + rb_iv_set(self, "@setfield", ID2SYM(rb_intern(RSTRING_PTR(name)))); + rb_iv_set(self, "@value", read_anything(protocol, field_type, field_info)); + } else { + rb_funcall(protocol, skip_method_id, 1, field_type_value); + } + } else { + rb_funcall(protocol, skip_method_id, 1, field_type_value); + } + + // read field end + mt->read_field_end(protocol); + + field_header = mt->read_field_begin(protocol); + field_type_value = rb_ary_entry(field_header, 1); + field_type = FIX2INT(field_type_value); + + if (field_type != TTYPE_STOP) { + rb_raise(rb_eRuntimeError, "too many fields in union!"); + } + + // read field end + mt->read_field_end(protocol); + + // read struct end + mt->read_struct_end(protocol); + + // call validate + rb_funcall(self, validate_method_id, 0); + + return Qnil; +} + +static VALUE rb_thrift_union_write(VALUE self, VALUE protocol) { + // call validate + rb_funcall(self, validate_method_id, 0); + + // write struct begin + mt->write_struct_begin(protocol, rb_class_name(CLASS_OF(self))); + + VALUE struct_fields = STRUCT_FIELDS(self); + + VALUE setfield = rb_ivar_get(self, setfield_id); + VALUE setvalue = rb_ivar_get(self, setvalue_id); + VALUE field_id = rb_funcall(self, name_to_id_method_id, 1, rb_funcall(setfield, to_s_method_id, 0)); + + VALUE field_info = rb_hash_aref(struct_fields, field_id); + + VALUE ttype_value = rb_hash_aref(field_info, type_sym); + int ttype = FIX2INT(ttype_value); + + mt->write_field_begin(protocol, setfield, ttype_value, field_id); + + write_anything(ttype, setvalue, protocol, field_info); + + mt->write_field_end(protocol); + + mt->write_field_stop(protocol); + + // write struct end + mt->write_struct_end(protocol); + + return Qnil; +} + void Init_struct() { VALUE struct_module = rb_const_get(thrift_module, rb_intern("Struct")); rb_define_method(struct_module, "write", rb_thrift_struct_write, 1); rb_define_method(struct_module, "read", rb_thrift_struct_read, 1); + thrift_union_class = rb_const_get(thrift_module, rb_intern("Union")); + + rb_define_method(thrift_union_class, "write", rb_thrift_union_write, 1); + rb_define_method(thrift_union_class, "read", rb_thrift_union_read, 1); + + setfield_id = rb_intern("@setfield"); + setvalue_id = rb_intern("@value"); + + to_s_method_id = rb_intern("to_s"); + name_to_id_method_id = rb_intern("name_to_id"); + set_default_proto_function_pointers(); mt = default_mt; -} - +} \ No newline at end of file diff --git a/lib/rb/ext/struct.h b/lib/rb/ext/struct.h index 37b1b35b..48ccef8b 100644 --- a/lib/rb/ext/struct.h +++ b/lib/rb/ext/struct.h @@ -17,6 +17,7 @@ * under the License. */ + #include #include @@ -41,7 +42,7 @@ typedef struct native_proto_method_table { VALUE (*write_field_stop)(VALUE); VALUE (*write_message_begin)(VALUE, VALUE, VALUE, VALUE); VALUE (*write_message_end)(VALUE); - + VALUE (*read_message_begin)(VALUE); VALUE (*read_message_end)(VALUE); VALUE (*read_field_begin)(VALUE); @@ -61,7 +62,7 @@ typedef struct native_proto_method_table { VALUE (*read_string)(VALUE); VALUE (*read_struct_begin)(VALUE); VALUE (*read_struct_end)(VALUE); - } native_proto_method_table; void Init_struct(); +void Init_union(); diff --git a/lib/rb/ext/thrift_native.c b/lib/rb/ext/thrift_native.c index effa202c..09b9fe49 100644 --- a/lib/rb/ext/thrift_native.c +++ b/lib/rb/ext/thrift_native.c @@ -111,7 +111,7 @@ void Init_thrift_native() { thrift_types_module = rb_const_get(thrift_module, rb_intern("Types")); rb_cSet = rb_const_get(rb_cObject, rb_intern("Set")); protocol_exception_class = rb_const_get(thrift_module, rb_intern("ProtocolException")); - + // Init ttype constants TTYPE_BOOL = FIX2INT(rb_const_get(thrift_types_module, rb_intern("BOOL"))); TTYPE_BYTE = FIX2INT(rb_const_get(thrift_types_module, rb_intern("BYTE"))); @@ -171,13 +171,13 @@ void Init_thrift_native() { write_method_id = rb_intern("write"); read_all_method_id = rb_intern("read_all"); native_qmark_method_id = rb_intern("native?"); - + // constant ids fields_const_id = rb_intern("FIELDS"); transport_ivar_id = rb_intern("@trans"); strict_read_ivar_id = rb_intern("@strict_read"); strict_write_ivar_id = rb_intern("@strict_write"); - + // cached symbols type_sym = ID2SYM(rb_intern("type")); name_sym = ID2SYM(rb_intern("name")); diff --git a/lib/rb/lib/thrift.rb b/lib/rb/lib/thrift.rb index 4d4e130a..02d67b8b 100644 --- a/lib/rb/lib/thrift.rb +++ b/lib/rb/lib/thrift.rb @@ -28,6 +28,8 @@ require 'thrift/types' require 'thrift/processor' require 'thrift/client' require 'thrift/struct' +require 'thrift/union' +require 'thrift/struct_union' # serializer require 'thrift/serializer/serializer' diff --git a/lib/rb/lib/thrift/protocol/binary_protocol_accelerated.rb b/lib/rb/lib/thrift/protocol/binary_protocol_accelerated.rb index eaf64f6b..70ea652c 100644 --- a/lib/rb/lib/thrift/protocol/binary_protocol_accelerated.rb +++ b/lib/rb/lib/thrift/protocol/binary_protocol_accelerated.rb @@ -29,7 +29,11 @@ See MemoryBuffer and BufferedTransport for examples. module Thrift class BinaryProtocolAcceleratedFactory < BaseProtocolFactory def get_protocol(trans) - BinaryProtocolAccelerated.new(trans) + if (defined? BinaryProtocolAccelerated) + BinaryProtocolAccelerated.new(trans) + else + BinaryProtocol.new(trans) + end end end end diff --git a/lib/rb/lib/thrift/struct.rb b/lib/rb/lib/thrift/struct.rb index dfc8a2fe..9e52073b 100644 --- a/lib/rb/lib/thrift/struct.rb +++ b/lib/rb/lib/thrift/struct.rb @@ -65,26 +65,7 @@ module Thrift end fields_with_default_values end - - def name_to_id(name) - names_to_ids = self.class.instance_variable_get("@names_to_ids") - unless names_to_ids - names_to_ids = {} - struct_fields.each do |fid, field_def| - names_to_ids[field_def[:name]] = fid - end - self.class.instance_variable_set("@names_to_ids", names_to_ids) - end - names_to_ids[name] - end - - def each_field - struct_fields.keys.sort.each do |fid| - data = struct_fields[fid] - yield fid, data - end - end - + def inspect(skip_optional_nulls = true) fields = [] each_field do |fid, field_info| @@ -115,7 +96,8 @@ module Thrift each_field do |fid, field_info| name = field_info[:name] type = field_info[:type] - if (value = instance_variable_get("@#{name}")) + value = instance_variable_get("@#{name}") + unless value.nil? if is_container? type oprot.write_field_begin(name, type, fid) write_container(oprot, value, field_info) @@ -210,89 +192,5 @@ module Thrift iprot.skip(ftype) end end - - def read_field(iprot, field = {}) - case field[:type] - when Types::STRUCT - value = field[:class].new - value.read(iprot) - when Types::MAP - key_type, val_type, size = iprot.read_map_begin - value = {} - size.times do - k = read_field(iprot, field_info(field[:key])) - v = read_field(iprot, field_info(field[:value])) - value[k] = v - end - iprot.read_map_end - when Types::LIST - e_type, size = iprot.read_list_begin - value = Array.new(size) do |n| - read_field(iprot, field_info(field[:element])) - end - iprot.read_list_end - when Types::SET - e_type, size = iprot.read_set_begin - value = Set.new - size.times do - element = read_field(iprot, field_info(field[:element])) - value << element - end - iprot.read_set_end - else - value = iprot.read_type(field[:type]) - end - value - end - - def write_data(oprot, value, field) - if is_container? field[:type] - write_container(oprot, value, field) - else - oprot.write_type(field[:type], value) - end - end - - def write_container(oprot, value, field = {}) - case field[:type] - when Types::MAP - oprot.write_map_begin(field[:key][:type], field[:value][:type], value.size) - value.each do |k, v| - write_data(oprot, k, field[:key]) - write_data(oprot, v, field[:value]) - end - oprot.write_map_end - when Types::LIST - oprot.write_list_begin(field[:element][:type], value.size) - value.each do |elem| - write_data(oprot, elem, field[:element]) - end - oprot.write_list_end - when Types::SET - oprot.write_set_begin(field[:element][:type], value.size) - value.each do |v,| # the , is to preserve compatibility with the old Hash-style sets - write_data(oprot, v, field[:element]) - end - oprot.write_set_end - else - raise "Not a container type: #{field[:type]}" - end - end - - CONTAINER_TYPES = [] - CONTAINER_TYPES[Types::LIST] = true - CONTAINER_TYPES[Types::MAP] = true - CONTAINER_TYPES[Types::SET] = true - def is_container?(type) - CONTAINER_TYPES[type] - end - - def field_info(field) - { :type => field[:type], - :class => field[:class], - :key => field[:key], - :value => field[:value], - :element => field[:element] } - end end end diff --git a/lib/rb/lib/thrift/struct_union.rb b/lib/rb/lib/thrift/struct_union.rb new file mode 100644 index 00000000..9a5903f1 --- /dev/null +++ b/lib/rb/lib/thrift/struct_union.rb @@ -0,0 +1,126 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +require 'set' + +module Thrift + module Struct_Union + def name_to_id(name) + names_to_ids = self.class.instance_variable_get("@names_to_ids") + unless names_to_ids + names_to_ids = {} + struct_fields.each do |fid, field_def| + names_to_ids[field_def[:name]] = fid + end + self.class.instance_variable_set("@names_to_ids", names_to_ids) + end + names_to_ids[name] + end + + def each_field + struct_fields.keys.sort.each do |fid| + data = struct_fields[fid] + yield fid, data + end + end + + def read_field(iprot, field = {}) + case field[:type] + when Types::STRUCT + value = field[:class].new + value.read(iprot) + when Types::MAP + key_type, val_type, size = iprot.read_map_begin + value = {} + size.times do + k = read_field(iprot, field_info(field[:key])) + v = read_field(iprot, field_info(field[:value])) + value[k] = v + end + iprot.read_map_end + when Types::LIST + e_type, size = iprot.read_list_begin + value = Array.new(size) do |n| + read_field(iprot, field_info(field[:element])) + end + iprot.read_list_end + when Types::SET + e_type, size = iprot.read_set_begin + value = Set.new + size.times do + element = read_field(iprot, field_info(field[:element])) + value << element + end + iprot.read_set_end + else + value = iprot.read_type(field[:type]) + end + value + end + + def write_data(oprot, value, field) + if is_container? field[:type] + write_container(oprot, value, field) + else + oprot.write_type(field[:type], value) + end + end + + def write_container(oprot, value, field = {}) + case field[:type] + when Types::MAP + oprot.write_map_begin(field[:key][:type], field[:value][:type], value.size) + value.each do |k, v| + write_data(oprot, k, field[:key]) + write_data(oprot, v, field[:value]) + end + oprot.write_map_end + when Types::LIST + oprot.write_list_begin(field[:element][:type], value.size) + value.each do |elem| + write_data(oprot, elem, field[:element]) + end + oprot.write_list_end + when Types::SET + oprot.write_set_begin(field[:element][:type], value.size) + value.each do |v,| # the , is to preserve compatibility with the old Hash-style sets + write_data(oprot, v, field[:element]) + end + oprot.write_set_end + else + raise "Not a container type: #{field[:type]}" + end + end + + CONTAINER_TYPES = [] + CONTAINER_TYPES[Types::LIST] = true + CONTAINER_TYPES[Types::MAP] = true + CONTAINER_TYPES[Types::SET] = true + def is_container?(type) + CONTAINER_TYPES[type] + end + + def field_info(field) + { :type => field[:type], + :class => field[:class], + :key => field[:key], + :value => field[:value], + :element => field[:element] } + end + end +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/types.rb b/lib/rb/lib/thrift/types.rb index 20e4ca2c..cac52691 100644 --- a/lib/rb/lib/thrift/types.rb +++ b/lib/rb/lib/thrift/types.rb @@ -57,7 +57,7 @@ module Thrift when Types::STRING String when Types::STRUCT - Struct + [Struct, Union] when Types::MAP Hash when Types::SET diff --git a/lib/rb/lib/thrift/union.rb b/lib/rb/lib/thrift/union.rb new file mode 100644 index 00000000..0b41ed49 --- /dev/null +++ b/lib/rb/lib/thrift/union.rb @@ -0,0 +1,128 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class Union + def initialize(name=nil, value=nil) + if name + if value.nil? + raise Exception, "Union #{self.class} cannot be instantiated with setfield and nil value!" + end + + Thrift.check_type(value, struct_fields[name_to_id(name.to_s)], name) if Thrift.type_checking + elsif !value.nil? + raise Exception, "Value provided, but no name!" + end + @setfield = name + @value = value + end + + def inspect + "<#{self.class} #{@setfield}: #{@value}>" + end + + def read(iprot) + iprot.read_struct_begin + fname, ftype, fid = iprot.read_field_begin + handle_message(iprot, fid, ftype) + iprot.read_field_end + + fname, ftype, fid = iprot.read_field_begin + raise "Too many fields for union" unless (ftype == Types::STOP) + + iprot.read_struct_end + validate + end + + def write(oprot) + validate + oprot.write_struct_begin(self.class.name) + + fid = self.name_to_id(@setfield.to_s) + + field_info = struct_fields[fid] + type = field_info[:type] + if is_container? type + oprot.write_field_begin(@setfield, type, fid) + write_container(oprot, @value, field_info) + oprot.write_field_end + else + oprot.write_field(@setfield, type, fid, @value) + end + + oprot.write_field_stop + oprot.write_struct_end + end + + def ==(other) + other != nil && @setfield == other.get_set_field && @value == other.get_value + end + + def eql?(other) + self.class == other.class && self == other + end + + def hash + [self.class.name, @setfield, @value].hash + end + + def self.field_accessor(klass, *fields) + fields.each do |field| + klass.send :define_method, "#{field}" do + if field == @setfield + @value + else + raise RuntimeError, "#{field} is not union's set field." + end + end + + klass.send :define_method, "#{field}=" do |value| + Thrift.check_type(value, klass::FIELDS.values.find {|f| f[:name].to_s == field.to_s }, field) if Thrift.type_checking + @setfield = field + @value = value + end + end + end + + # get the symbol that indicates what the currently set field type is. + def get_set_field + @setfield + end + + # get the current value of this union, regardless of what the set field is. + # generally, you should only use this method when you don't know in advance + # what field to expect. + def get_value + @value + end + + protected + + def handle_message(iprot, fid, ftype) + field = struct_fields[fid] + if field and field[:type] == ftype + @value = read_field(iprot, field) + name = field[:name].to_sym + @setfield = name + else + iprot.skip(ftype) + end + end + end +end \ No newline at end of file diff --git a/lib/rb/spec/ThriftSpec.thrift b/lib/rb/spec/ThriftSpec.thrift index fe5a8aae..f5c8c09a 100644 --- a/lib/rb/spec/ThriftSpec.thrift +++ b/lib/rb/spec/ThriftSpec.thrift @@ -42,6 +42,38 @@ struct Hello { 1: string greeting = "hello world" } +union My_union { + 1: bool im_true, + 2: byte a_bite, + 3: i16 integer16, + 4: i32 integer32, + 5: i64 integer64, + 6: double double_precision, + 7: string some_characters, + 8: i32 other_i32 +} + +struct Struct_with_union { + 1: My_union fun_union + 2: i32 integer32 + 3: string some_characters +} + +enum SomeEnum { + ONE + TWO +} + +union TestUnion { + /** + * A doc string + */ + 1: string string_field; + 2: i32 i32_field; + 3: i32 other_i32_field; + 4: SomeEnum enum_field; +} + struct Foo { 1: i32 simple = 53, 2: string words = "words", diff --git a/lib/rb/spec/binary_protocol_accelerated_spec.rb b/lib/rb/spec/binary_protocol_accelerated_spec.rb index 48c22e48..b8518c88 100644 --- a/lib/rb/spec/binary_protocol_accelerated_spec.rb +++ b/lib/rb/spec/binary_protocol_accelerated_spec.rb @@ -20,22 +20,27 @@ require File.dirname(__FILE__) + '/spec_helper' require File.dirname(__FILE__) + '/binary_protocol_spec_shared' -class ThriftBinaryProtocolAcceleratedSpec < Spec::ExampleGroup - include Thrift +if defined? Thrift::BinaryProtocolAccelerated - describe Thrift::BinaryProtocolAccelerated do - # since BinaryProtocolAccelerated should be directly equivalent to - # BinaryProtocol, we don't need any custom specs! - it_should_behave_like 'a binary protocol' + class ThriftBinaryProtocolAcceleratedSpec < Spec::ExampleGroup + include Thrift - def protocol_class - BinaryProtocolAccelerated + describe Thrift::BinaryProtocolAccelerated do + # since BinaryProtocolAccelerated should be directly equivalent to + # BinaryProtocol, we don't need any custom specs! + it_should_behave_like 'a binary protocol' + + def protocol_class + BinaryProtocolAccelerated + end end - end - describe BinaryProtocolAcceleratedFactory do - it "should create a BinaryProtocolAccelerated" do - BinaryProtocolAcceleratedFactory.new.get_protocol(mock("MockTransport")).should be_instance_of(BinaryProtocolAccelerated) + describe BinaryProtocolAcceleratedFactory do + it "should create a BinaryProtocolAccelerated" do + BinaryProtocolAcceleratedFactory.new.get_protocol(mock("MockTransport")).should be_instance_of(BinaryProtocolAccelerated) + end end end -end +else + puts "skipping BinaryProtocolAccelerated spec because it is not defined." +end \ No newline at end of file diff --git a/lib/rb/spec/binary_protocol_spec_shared.rb b/lib/rb/spec/binary_protocol_spec_shared.rb index 84f59206..28da7608 100644 --- a/lib/rb/spec/binary_protocol_spec_shared.rb +++ b/lib/rb/spec/binary_protocol_spec_shared.rb @@ -349,9 +349,9 @@ shared_examples_for 'a binary protocol' do # first block firstblock.call(client) - + processor.process(serverproto, serverproto) - + # second block secondblock.call(client) ensure diff --git a/lib/rb/spec/union_spec.rb b/lib/rb/spec/union_spec.rb new file mode 100644 index 00000000..48352884 --- /dev/null +++ b/lib/rb/spec/union_spec.rb @@ -0,0 +1,145 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +class ThriftUnionSpec < Spec::ExampleGroup + include Thrift + include SpecNamespace + + describe Union do + it "should return nil value in unset union" do + union = My_union.new + union.get_set_field.should == nil + union.get_value.should == nil + end + + it "should set a field and be accessible through get_value and the named field accessor" do + union = My_union.new + union.integer32 = 25 + union.get_set_field.should == :integer32 + union.get_value.should == 25 + union.integer32.should == 25 + end + + it "should work correctly when instantiated with static field constructors" do + union = My_union.integer32(5) + union.get_set_field.should == :integer32 + union.integer32.should == 5 + end + + it "should raise for wrong set field" do + union = My_union.new + union.integer32 = 25 + lambda { union.some_characters }.should raise_error(RuntimeError, "some_characters is not union's set field.") + end + + it "should not be equal to nil" do + union = My_union.new + union.should_not == nil + end + + it "should not equate two different unions, i32 vs. string" do + union = My_union.new(:integer32, 25) + other_union = My_union.new(:some_characters, "blah!") + union.should_not == other_union + end + + it "should properly reset setfield and setvalue" do + union = My_union.new(:integer32, 25) + union.get_set_field.should == :integer32 + union.some_characters = "blah!" + union.get_set_field.should == :some_characters + union.get_value.should == "blah!" + lambda { union.integer32 }.should raise_error(RuntimeError, "integer32 is not union's set field.") + end + + it "should not equate two different unions with different values" do + union = My_union.new(:integer32, 25) + other_union = My_union.new(:integer32, 400) + union.should_not == other_union + end + + it "should not equate two different unions with different fields" do + union = My_union.new(:integer32, 25) + other_union = My_union.new(:other_i32, 25) + union.should_not == other_union + end + + it "should inspect properly" do + union = My_union.new(:integer32, 25) + union.inspect.should == "" + end + + it "should not allow setting with instance_variable_set" do + union = My_union.new(:integer32, 27) + union.instance_variable_set(:@some_characters, "hallo!") + union.get_set_field.should == :integer32 + union.get_value.should == 27 + lambda { union.some_characters }.should raise_error(RuntimeError, "some_characters is not union's set field.") + end + + it "should serialize correctly" do + trans = Thrift::MemoryBufferTransport.new + proto = Thrift::BinaryProtocol.new(trans) + + union = My_union.new(:integer32, 25) + union.write(proto) + + other_union = My_union.new(:integer32, 25) + other_union.read(proto) + other_union.should == union + end + + it "should raise when validating unset union" do + union = My_union.new + lambda { union.validate }.should raise_error(StandardError, "Union fields are not set.") + + other_union = My_union.new(:integer32, 1) + lambda { other_union.validate }.should_not raise_error(StandardError, "Union fields are not set.") + end + + it "should validate an enum field properly" do + union = TestUnion.new(:enum_field, 3) + union.get_set_field.should == :enum_field + lambda { union.validate }.should raise_error(ProtocolException, "Invalid value of field enum_field!") + + other_union = TestUnion.new(:enum_field, 1) + lambda { other_union.validate }.should_not raise_error(ProtocolException, "Invalid value of field enum_field!") + end + + it "should properly serialize and match structs with a union" do + union = My_union.new(:integer32, 26) + swu = Struct_with_union.new(:fun_union => union) + + trans = Thrift::MemoryBufferTransport.new + proto = Thrift::CompactProtocol.new(trans) + + swu.write(proto) + + other_union = My_union.new(:some_characters, "hello there") + swu2 = Struct_with_union.new(:fun_union => other_union) + + swu2.should_not == swu + + swu2.read(proto) + swu2.should == swu + end + end +end -- 2.17.1