THRIFT-697. Union support in Ruby

git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@910700 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/rb/ext/struct.c b/lib/rb/ext/struct.c
index 7429fb1..d459ddb 100644
--- a/lib/rb/ext/struct.c
+++ b/lib/rb/ext/struct.c
@@ -45,29 +45,18 @@
 
 static native_proto_method_table *mt;
 static native_proto_method_table *default_mt;
-// static VALUE last_proto_class = Qnil;
+
+VALUE thrift_union_class;  // the "Union" constant of thrift_module; assigned in Init_struct()
+
+ID setfield_id;  // interned "@setfield" ivar name; assigned in Init_struct()
+ID setvalue_id;  // interned "@value" ivar name; assigned in Init_struct()
+
+ID to_s_method_id;  // interned "to_s" method name; assigned in Init_struct()
+ID name_to_id_method_id;  // interned "name_to_id" method name; assigned in Init_struct()
 
 #define IS_CONTAINER(ttype) ((ttype) == TTYPE_MAP || (ttype) == TTYPE_LIST || (ttype) == TTYPE_SET)
 #define STRUCT_FIELDS(obj) rb_const_get(CLASS_OF(obj), fields_const_id)
 
-// static void set_native_proto_function_pointers(VALUE protocol) {
-//   VALUE method_table_object = rb_const_get(CLASS_OF(protocol), rb_intern("@native_method_table"));
-//   // TODO: check nil?
-//   Data_Get_Struct(method_table_object, native_proto_method_table, mt);
-// }
-
-// static void check_native_proto_method_table(VALUE protocol) {
-//   VALUE protoclass = CLASS_OF(protocol);
-//   if (protoclass != last_proto_class) {
-//     last_proto_class = protoclass;
-//     if (rb_funcall(protocol, native_qmark_method_id, 0) == Qtrue) {
-//       set_native_proto_function_pointers(protocol);
-//     } else {
-//       mt = default_mt;
-//     }
-//   }
-// }
-
 //-------------------------------------------
 // Writing section
 //-------------------------------------------
@@ -275,62 +264,62 @@
 
 // end default protocol methods
 
-
+static VALUE rb_thrift_union_write (VALUE self, VALUE protocol);
 static VALUE rb_thrift_struct_write(VALUE self, VALUE protocol);
 static void write_anything(int ttype, VALUE value, VALUE protocol, VALUE field_info);
 
 VALUE get_field_value(VALUE obj, VALUE field_name) {
   char name_buf[RSTRING_LEN(field_name) + 1];
-  
+  // Builds "@<field_name>" in name_buf and returns that ivar of obj. NOTE(review): '@' + name + NUL needs RSTRING_LEN + 2 bytes but name_buf has RSTRING_LEN + 1, and strlcpy is given sizeof(name_buf) for a destination that starts at offset 1 - confirm and fix in the target file.
   name_buf[0] = '@';
   strlcpy(&name_buf[1], RSTRING_PTR(field_name), sizeof(name_buf));
 
   VALUE value = rb_ivar_get(obj, rb_intern(name_buf));
-  
+
   return value;
 }
 
 static void write_container(int ttype, VALUE field_info, VALUE value, VALUE protocol) {
   int sz, i;
-  
+
   if (ttype == TTYPE_MAP) {
     VALUE keys;
     VALUE key;
     VALUE val;
 
     Check_Type(value, T_HASH);
-    
+
     VALUE key_info = rb_hash_aref(field_info, key_sym);
     VALUE keytype_value = rb_hash_aref(key_info, type_sym);
     int keytype = FIX2INT(keytype_value);
-    
+
     VALUE value_info = rb_hash_aref(field_info, value_sym);
     VALUE valuetype_value = rb_hash_aref(value_info, type_sym);
     int valuetype = FIX2INT(valuetype_value);
-    
+
     keys = rb_funcall(value, keys_method_id, 0);
-    
+
     sz = RARRAY_LEN(keys);
-    
+
     mt->write_map_begin(protocol, keytype_value, valuetype_value, INT2FIX(sz));
-    
+
     for (i = 0; i < sz; i++) {
       key = rb_ary_entry(keys, i);
       val = rb_hash_aref(value, key);
-      
+
       if (IS_CONTAINER(keytype)) {
         write_container(keytype, key_info, key, protocol);
       } else {
         write_anything(keytype, key, protocol, key_info);
       }
-      
+
       if (IS_CONTAINER(valuetype)) {
         write_container(valuetype, value_info, val, protocol);
       } else {
         write_anything(valuetype, val, protocol, value_info);
       }
     }
-    
+
     mt->write_map_end(protocol);
   } else if (ttype == TTYPE_LIST) {
     Check_Type(value, T_ARRAY);
@@ -340,7 +329,7 @@
     VALUE element_type_info = rb_hash_aref(field_info, element_sym);
     VALUE element_type_value = rb_hash_aref(element_type_info, type_sym);
     int element_type = FIX2INT(element_type_value);
-    
+
     mt->write_list_begin(protocol, element_type_value, INT2FIX(sz));
     for (i = 0; i < sz; ++i) {
       VALUE val = rb_ary_entry(value, i);
@@ -370,9 +359,9 @@
     VALUE element_type_info = rb_hash_aref(field_info, element_sym);
     VALUE element_type_value = rb_hash_aref(element_type_info, type_sym);
     int element_type = FIX2INT(element_type_value);
-    
+
     mt->write_set_begin(protocol, element_type_value, INT2FIX(sz));
-    
+
     for (i = 0; i < sz; i++) {
       VALUE val = rb_ary_entry(items, i);
       if (IS_CONTAINER(element_type)) {
@@ -381,7 +370,7 @@
         write_anything(element_type, val, protocol, element_type_info);
       }
     }
-    
+
     mt->write_set_end(protocol);
   } else {
     rb_raise(rb_eNotImpError, "can't write container of type: %d", ttype);
@@ -406,7 +395,11 @@
   } else if (IS_CONTAINER(ttype)) {
     write_container(ttype, field_info, value, protocol);
   } else if (ttype == TTYPE_STRUCT) {
-    rb_thrift_struct_write(value, protocol);
+    if (rb_obj_is_kind_of(value, thrift_union_class)) {
+      rb_thrift_union_write(value, protocol);
+    } else {
+      rb_thrift_struct_write(value, protocol);
+    }
   } else {
     rb_raise(rb_eNotImpError, "Unknown type for binary_encoding: %d", ttype);
   }
@@ -423,24 +416,27 @@
 
   // iterate through all the fields here
   VALUE struct_fields = STRUCT_FIELDS(self);
+
   VALUE struct_field_ids_unordered = rb_funcall(struct_fields, keys_method_id, 0);
   VALUE struct_field_ids_ordered = rb_funcall(struct_field_ids_unordered, sort_method_id, 0);
 
   int i = 0;
   for (i=0; i < RARRAY_LEN(struct_field_ids_ordered); i++) {
     VALUE field_id = rb_ary_entry(struct_field_ids_ordered, i);
+
     VALUE field_info = rb_hash_aref(struct_fields, field_id);
 
     VALUE ttype_value = rb_hash_aref(field_info, type_sym);
     int ttype = FIX2INT(ttype_value);
     VALUE field_name = rb_hash_aref(field_info, name_sym);
+
     VALUE field_value = get_field_value(self, field_name);
 
     if (!NIL_P(field_value)) {
       mt->write_field_begin(protocol, field_name, ttype_value, field_id);
-      
+
       write_anything(ttype, field_value, protocol, field_info);
-      
+
       mt->write_field_end(protocol);
     }
   }
@@ -457,6 +453,7 @@
 // Reading section
 //-------------------------------------------
 
+static VALUE rb_thrift_union_read(VALUE self, VALUE protocol);
 static VALUE rb_thrift_struct_read(VALUE self, VALUE protocol);
 
 static void set_field_value(VALUE obj, VALUE field_name, VALUE value) {
@@ -488,7 +485,12 @@
   } else if (ttype == TTYPE_STRUCT) {
     VALUE klass = rb_hash_aref(field_info, class_sym);
     result = rb_class_new_instance(0, NULL, klass);
-    rb_thrift_struct_read(result, protocol);
+
+    if (rb_obj_is_kind_of(result, thrift_union_class)) {
+      rb_thrift_union_read(result, protocol);
+    } else {
+      rb_thrift_struct_read(result, protocol);
+    }
   } else if (ttype == TTYPE_MAP) {
     int i;
 
@@ -524,7 +526,6 @@
       rb_ary_push(result, read_anything(protocol, element_ttype, rb_hash_aref(field_info, element_sym)));
     }
 
-
     mt->read_list_end(protocol);
   } else if (ttype == TTYPE_SET) {
     VALUE items;
@@ -539,7 +540,6 @@
       rb_ary_push(items, read_anything(protocol, element_ttype, rb_hash_aref(field_info, element_sym)));
     }
 
-
     mt->read_set_end(protocol);
 
     result = rb_class_new_instance(1, &items, rb_cSet);
@@ -597,13 +597,110 @@
   return Qnil;
 }
 
+
+// --------------------------------
+// Union section
+// --------------------------------
+
+static VALUE rb_thrift_union_read(VALUE self, VALUE protocol) {  // Reads one field from protocol into self's @setfield/@value, requires a stop marker next, validates self; returns Qnil.
+  // consume the struct header
+  mt->read_struct_begin(protocol);
+
+  VALUE struct_fields = STRUCT_FIELDS(self);  // FIELDS constant of self's class
+
+  VALUE field_header = mt->read_field_begin(protocol);
+  VALUE field_type_value = rb_ary_entry(field_header, 1);  // header entry 1 is used as the field type
+  int field_type = FIX2INT(field_type_value);
+
+  // look up the field's metadata, using header entry 2 as the key into FIELDS. NOTE(review): this first header is not checked against TTYPE_STOP - confirm behaviour when no field is present.
+  VALUE field_info = rb_hash_aref(struct_fields, rb_ary_entry(field_header, 2));
+
+  if (!NIL_P(field_info)) {
+    int specified_type = FIX2INT(rb_hash_aref(field_info, type_sym));
+    if (field_type == specified_type) {
+      // record the field's name as a symbol in @setfield, then read its value into @value
+      VALUE name = rb_hash_aref(field_info, name_sym);
+      rb_iv_set(self, "@setfield", ID2SYM(rb_intern(RSTRING_PTR(name))));
+      rb_iv_set(self, "@value", read_anything(protocol, field_type, field_info));
+    } else {
+      rb_funcall(protocol, skip_method_id, 1, field_type_value);  // type in FIELDS differs from the type read: call the protocol's skip_method_id method instead of reading
+    }
+  } else {
+    rb_funcall(protocol, skip_method_id, 1, field_type_value);  // no FIELDS entry for this key: call the protocol's skip_method_id method instead of reading
+  }
+
+  // close the field that was just read or skipped
+  mt->read_field_end(protocol);
+
+  field_header = mt->read_field_begin(protocol);  // a union carries one field, so the next header must be the stop marker
+  field_type_value = rb_ary_entry(field_header, 1);
+  field_type = FIX2INT(field_type_value);
+
+  if (field_type != TTYPE_STOP) {
+    rb_raise(rb_eRuntimeError, "too many fields in union!");
+  }
+
+  // read field end (called here after the stop header as well)
+  mt->read_field_end(protocol);
+
+  // read struct end
+  mt->read_struct_end(protocol);
+
+  // invoke self's validate method now that reading is complete
+  rb_funcall(self, validate_method_id, 0);
+
+  return Qnil;
+}
+
+static VALUE rb_thrift_union_write(VALUE self, VALUE protocol) {  // Validates self, then writes its single set field followed by a field-stop to protocol; returns Qnil.
+  // invoke self's validate method before anything is written
+  rb_funcall(self, validate_method_id, 0);
+
+  // write struct begin, passing the name of self's class
+  mt->write_struct_begin(protocol, rb_class_name(CLASS_OF(self)));
+
+  VALUE struct_fields = STRUCT_FIELDS(self);  // FIELDS constant of self's class
+
+  VALUE setfield = rb_ivar_get(self, setfield_id);  // self's @setfield
+  VALUE setvalue = rb_ivar_get(self, setvalue_id);  // self's @value
+  VALUE field_id = rb_funcall(self, name_to_id_method_id, 1, rb_funcall(setfield, to_s_method_id, 0));  // self.name_to_id(setfield.to_s)
+
+  VALUE field_info = rb_hash_aref(struct_fields, field_id);  // NOTE(review): not nil-checked before use below - presumably validate guarantees a known field; confirm
+
+  VALUE ttype_value = rb_hash_aref(field_info, type_sym);
+  int ttype = FIX2INT(ttype_value);
+
+  mt->write_field_begin(protocol, setfield, ttype_value, field_id);  // the name argument is setfield itself, not its to_s string
+
+  write_anything(ttype, setvalue, protocol, field_info);
+
+  mt->write_field_end(protocol);
+
+  mt->write_field_stop(protocol);  // exactly one field is written, then the stop marker
+
+  // write struct end
+  mt->write_struct_end(protocol);
+
+  return Qnil;
+}
+
 void Init_struct() {
   VALUE struct_module = rb_const_get(thrift_module, rb_intern("Struct"));
 
   rb_define_method(struct_module, "write", rb_thrift_struct_write, 1);
   rb_define_method(struct_module, "read", rb_thrift_struct_read, 1);
 
+  thrift_union_class = rb_const_get(thrift_module, rb_intern("Union"));  // also consulted by write_anything/read_anything to choose the union code path
+  // define one-argument "write" and "read" methods on the Union class, backed by the native functions
+  rb_define_method(thrift_union_class, "write", rb_thrift_union_write, 1);
+  rb_define_method(thrift_union_class, "read", rb_thrift_union_read, 1);
+  // cache the interned ivar and method names used by the union read/write code
+  setfield_id = rb_intern("@setfield");
+  setvalue_id = rb_intern("@value");
+
+  to_s_method_id = rb_intern("to_s");
+  name_to_id_method_id = rb_intern("name_to_id");
+
   set_default_proto_function_pointers();
   mt = default_mt;
-}
-
+}
\ No newline at end of file
diff --git a/lib/rb/ext/struct.h b/lib/rb/ext/struct.h
index 37b1b35..48ccef8 100644
--- a/lib/rb/ext/struct.h
+++ b/lib/rb/ext/struct.h
@@ -17,6 +17,7 @@
  * under the License.
  */
 
+
 #include <stdbool.h>
 #include <ruby.h>
 
@@ -41,7 +42,7 @@
   VALUE (*write_field_stop)(VALUE);
   VALUE (*write_message_begin)(VALUE, VALUE, VALUE, VALUE);
   VALUE (*write_message_end)(VALUE);
-  
+
   VALUE (*read_message_begin)(VALUE);
   VALUE (*read_message_end)(VALUE);
   VALUE (*read_field_begin)(VALUE);
@@ -61,7 +62,7 @@
   VALUE (*read_string)(VALUE);
   VALUE (*read_struct_begin)(VALUE);
   VALUE (*read_struct_end)(VALUE);
-  
 } native_proto_method_table;
 
 void Init_struct();
+void Init_union();  // NOTE(review): no Init_union definition is visible in this patch (union setup happens in Init_struct) - confirm one exists or drop this prototype
diff --git a/lib/rb/ext/thrift_native.c b/lib/rb/ext/thrift_native.c
index effa202..09b9fe4 100644
--- a/lib/rb/ext/thrift_native.c
+++ b/lib/rb/ext/thrift_native.c
@@ -111,7 +111,7 @@
   thrift_types_module = rb_const_get(thrift_module, rb_intern("Types"));
   rb_cSet = rb_const_get(rb_cObject, rb_intern("Set"));
   protocol_exception_class = rb_const_get(thrift_module, rb_intern("ProtocolException"));
-  
+
   // Init ttype constants
   TTYPE_BOOL = FIX2INT(rb_const_get(thrift_types_module, rb_intern("BOOL")));
   TTYPE_BYTE = FIX2INT(rb_const_get(thrift_types_module, rb_intern("BYTE")));
@@ -171,13 +171,13 @@
   write_method_id = rb_intern("write");
   read_all_method_id = rb_intern("read_all");
   native_qmark_method_id = rb_intern("native?");
-  
+
   // constant ids
   fields_const_id = rb_intern("FIELDS");
   transport_ivar_id = rb_intern("@trans");
   strict_read_ivar_id = rb_intern("@strict_read");
   strict_write_ivar_id = rb_intern("@strict_write");  
-  
+
   // cached symbols
   type_sym = ID2SYM(rb_intern("type"));
   name_sym = ID2SYM(rb_intern("name"));