THRIFT-2421: Tree/Recursive struct support in thrift
Client: cpp
Patch:  Dave Watson

Github Pull Request: This closes #84
----
commit b6134cedf292845e5ed01052919894df6b561bf2
Date:   2014-03-20T18:12:04Z

    Recursive structs support in parser

    A common complaint is that, unlike protobuf, Thrift cannot easily express trees or other recursive structures. This diff loosens the parser to allow structs to be used before they are defined (a lazily resolved typedef serves as the forward declaration).
    This change alone is enough to make recursive types work for some dynamic languages (I tried PHP; it works out of the box!)

    Other languages need forward declarations, or a way to box types, to make this work (e.g. C++ needs both forward declarations and a way to hold struct fields by pointer, which this patch provides through the `cpp.ref` field annotation).
diff --git a/compiler/cpp/Makefile.am b/compiler/cpp/Makefile.am
index 47725bd..0013121 100644
--- a/compiler/cpp/Makefile.am
+++ b/compiler/cpp/Makefile.am
@@ -49,6 +49,7 @@
                  src/parse/t_enum.h \
                  src/parse/t_enum_value.h \
                  src/parse/t_typedef.h \
+                 src/parse/t_typedef.cc \
                  src/parse/t_container.h \
                  src/parse/t_list.h \
                  src/parse/t_set.h \
diff --git a/compiler/cpp/src/generate/t_cpp_generator.cc b/compiler/cpp/src/generate/t_cpp_generator.cc
index 298096d..f6cee8d 100755
--- a/compiler/cpp/src/generate/t_cpp_generator.cc
+++ b/compiler/cpp/src/generate/t_cpp_generator.cc
@@ -99,6 +99,7 @@
 
   void generate_typedef(t_typedef* ttypedef);
   void generate_enum(t_enum* tenum);
+  void generate_forward_declaration(t_struct* tstruct);
   void generate_struct(t_struct* tstruct) {
     generate_cpp_struct(tstruct, false);
   }
@@ -112,12 +113,15 @@
   void print_const_value(std::ofstream& out, std::string name, t_type* type, t_const_value* value);
   std::string render_const_value(std::ofstream& out, std::string name, t_type* type, t_const_value* value);
 
-  void generate_struct_definition    (std::ofstream& out, t_struct* tstruct,
+  void generate_struct_declaration    (std::ofstream& out, t_struct* tstruct,
                                       bool is_exception=false,
                                       bool pointers=false,
                                       bool read=true,
                                       bool write=true,
                                       bool swap=false);
+  void generate_struct_definition   (std::ofstream& out, t_struct* tstruct);
+  void generate_copy_constructor     (std::ofstream& out, t_struct* tstruct);
+  void generate_assignment_operator  (std::ofstream& out, t_struct* tstruct);
   void generate_struct_fingerprint   (std::ofstream& out, t_struct* tstruct, bool is_definition);
   void generate_struct_reader        (std::ofstream& out, t_struct* tstruct, bool pointers=false);
   void generate_struct_writer        (std::ofstream& out, t_struct* tstruct, bool pointers=false);
@@ -152,8 +156,9 @@
 
   void generate_deserialize_struct       (std::ofstream& out,
                                           t_struct*   tstruct,
-                                          std::string prefix="");
-
+                                          std::string prefix="",
+					  bool pointer=false);
+  
   void generate_deserialize_container    (std::ofstream& out,
                                           t_type*     ttype,
                                           std::string prefix="");
@@ -179,7 +184,8 @@
 
   void generate_serialize_struct         (std::ofstream& out,
                                           t_struct*   tstruct,
-                                          std::string prefix="");
+                                          std::string prefix="",
+					  bool pointer=false);
 
   void generate_serialize_container      (std::ofstream& out,
                                           t_type*     ttype,
@@ -228,6 +234,10 @@
   void generate_local_reflection(std::ofstream& out, t_type* ttype, bool is_definition);
   void generate_local_reflection_pointer(std::ofstream& out, t_type* ttype);
 
+  bool is_reference(t_field* tfield) {  // true when the field has a "cpp.ref" annotation
+    return tfield->annotations_.find("cpp.ref") != tfield->annotations_.end();
+  }
+
   bool is_complex_type(t_type* ttype) {
     ttype = get_true_type(ttype);
 
@@ -778,6 +788,13 @@
   return render.str();
 }
 
+// Emits "class <Name>;" into f_types_ so later declarations can name the
+// struct before its definition appears (needed for recursive cpp.ref fields).
+void t_cpp_generator::generate_forward_declaration(t_struct* tstruct) {
+  indent(f_types_) << "class " << tstruct->get_name() << ";" << endl;
+  f_types_ << endl;
+}
+
 /**
  * Generates a struct definition for a thrift data type. This is a class
  * with data members and a read/write() function, plus a mirroring isset
@@ -786,8 +803,9 @@
  * @param tstruct The struct definition
  */
 void t_cpp_generator::generate_cpp_struct(t_struct* tstruct, bool is_exception) {
-  generate_struct_definition(f_types_, tstruct, is_exception,
+  generate_struct_declaration(f_types_, tstruct, is_exception,
                              false, true, true, true);
+  generate_struct_definition(f_types_impl_, tstruct);
   generate_struct_fingerprint(f_types_impl_, tstruct, true);
   generate_local_reflection(f_types_, tstruct, false);
   generate_local_reflection(f_types_impl_, tstruct, true);
@@ -797,6 +815,116 @@
   generate_struct_reader(out, tstruct);
   generate_struct_writer(out, tstruct);
   generate_struct_swap(f_types_impl_, tstruct);
+  generate_copy_constructor(f_types_impl_, tstruct);
+  generate_assignment_operator(f_types_impl_, tstruct);
+}
+
+/**
+ * Returns true if the C++ class generated for tstruct has an __isset member,
+ * i.e. if at least one field is not T_REQUIRED.
+ * NOTE(review): mirrors the rule assumed in generate_struct_declaration;
+ * keep the two in sync.
+ */
+static bool has_isset_member(t_struct* tstruct) {
+  const vector<t_field*>& members = tstruct->get_members();
+  vector<t_field*>::const_iterator f_iter;
+  for (f_iter = members.begin(); f_iter != members.end(); ++f_iter) {
+    if ((*f_iter)->get_req() != t_field::T_REQUIRED) {
+      return true;
+    }
+  }
+  return false;
+}
+
+/**
+ * Generates the out-of-line copy constructor declared for non-pointer
+ * structs. Plain fields are copied by assignment. cpp.ref fields are owning
+ * raw pointers, so the pointee is deep-copied; a NULL source stays NULL
+ * instead of being dereferenced. The isset flags are copied as well.
+ *
+ * @param out     Output stream
+ * @param tstruct The struct
+ */
+void t_cpp_generator::generate_copy_constructor(
+  ofstream& out,
+  t_struct* tstruct) {
+  std::string tmp_name = tmp("other");
+
+  indent(out) << tstruct->get_name() << "::" <<
+    tstruct->get_name() << "(const " << tstruct->get_name() <<
+    "& " << tmp_name << ") {" << endl;
+  indent_up();
+
+  const vector<t_field*>& members = tstruct->get_members();
+  vector<t_field*>::const_iterator f_iter;
+  for (f_iter = members.begin(); f_iter != members.end(); ++f_iter) {
+    const std::string& name = (*f_iter)->get_name();
+    if (is_reference(*f_iter)) {
+      std::string type = type_name((*f_iter)->get_type());
+      indent(out) << name << " = " << tmp_name << "." << name << " ? new " <<
+        type << "(*" << tmp_name << "." << name << ") : NULL;" << endl;
+    } else {
+      indent(out) << name << " = " << tmp_name << "." << name << ";" << endl;
+    }
+  }
+  // A user-declared copy constructor copies nothing implicitly, so the
+  // isset flags must be copied by hand.
+  if (has_isset_member(tstruct)) {
+    indent(out) << "__isset = " << tmp_name << ".__isset;" << endl;
+  }
+
+  indent_down();
+  indent(out) << "}" << endl;
+}
+
+/**
+ * Generates the out-of-line copy-assignment operator declared for
+ * non-pointer structs. Self-assignment returns early. For cpp.ref fields the
+ * existing pointee is reused when present, a new one is allocated when not,
+ * and a NULL source frees the pointee and stores NULL.
+ *
+ * @param out     Output stream
+ * @param tstruct The struct
+ */
+void t_cpp_generator::generate_assignment_operator(
+  ofstream& out,
+  t_struct* tstruct) {
+  std::string tmp_name = tmp("other");
+
+  indent(out) << tstruct->get_name() << "& " << tstruct->get_name() <<
+    "::operator=(const " << tstruct->get_name() <<
+    "& " << tmp_name << ") {" << endl;
+  indent_up();
+
+  // Emitted once, before any field is touched.
+  indent(out) << "if (this == &" << tmp_name << ") return *this;" << endl;
+
+  const vector<t_field*>& members = tstruct->get_members();
+  vector<t_field*>::const_iterator f_iter;
+  for (f_iter = members.begin(); f_iter != members.end(); ++f_iter) {
+    const std::string& name = (*f_iter)->get_name();
+    if (is_reference(*f_iter)) {
+      std::string type = type_name((*f_iter)->get_type());
+      indent(out) << "if (" << tmp_name << "." << name << " == NULL) {" << endl;
+      indent(out) << "  delete " << name << ";" << endl;
+      indent(out) << "  " << name << " = NULL;" << endl;
+      indent(out) << "} else if (" << name << ") {" << endl;
+      indent(out) << "  *" << name << " = *" << tmp_name << "." << name << ";" << endl;
+      indent(out) << "} else {" << endl;
+      indent(out) << "  " << name << " = new " << type << "(*" << tmp_name << "." <<
+        name << ");" << endl;
+      indent(out) << "}" << endl;
+    } else {
+      indent(out) << name << " = " << tmp_name << "." << name << ";" << endl;
+    }
+  }
+  if (has_isset_member(tstruct)) {
+    indent(out) << "__isset = " << tmp_name << ".__isset;" << endl;
+  }
+
+  indent(out) << "return *this;" << endl;
+  indent_down();
+  indent(out) << "}" << endl;
 }
 
 /**
@@ -805,7 +886,7 @@
  * @param out Output stream
  * @param tstruct The struct
  */
-void t_cpp_generator::generate_struct_definition(ofstream& out,
+void t_cpp_generator::generate_struct_declaration(ofstream& out,
                                                  t_struct* tstruct,
                                                  bool is_exception,
                                                  bool pointers,
@@ -882,15 +963,23 @@
   generate_struct_fingerprint(out, tstruct, false);
 
   if (!pointers) {
+    // Copy constructor
+    indent(out) << 
+      tstruct->get_name() << "(const " << tstruct->get_name() << "&);" << endl;
+
+    // Assignment Operator
+    indent(out) << tstruct->get_name() << "& operator=(const " << tstruct->get_name() << "&);" << endl;
+
     // Default constructor
     indent(out) <<
       tstruct->get_name() << "()";
+    
 
     bool init_ctor = false;
 
     for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
       t_type* t = get_true_type((*m_iter)->get_type());
-      if (t->is_base_type() || t->is_enum()) {
+      if (t->is_base_type() || t->is_enum() || is_reference(*m_iter)) {
         string dval;
         if (t->is_enum()) {
           dval += "(" + type_name(t) + ")";
@@ -907,7 +996,7 @@
         } else {
           out << ", " << (*m_iter)->get_name() << "(" << dval << ")";
         }
-      }
+      } 
     }
     out << " {" << endl;
     indent_up();
@@ -929,7 +1018,7 @@
   if (tstruct->annotations_.find("final") == tstruct->annotations_.end()) {
     out <<
       endl <<
-      indent() << "virtual ~" << tstruct->get_name() << "() throw() {}" << endl << endl;
+      indent() << "virtual ~" << tstruct->get_name() << "() throw();" << endl;
   }
 
   // Pointer to this structure's reflection local typespec.
@@ -942,7 +1031,7 @@
   // Declare all fields
   for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
     indent(out) <<
-      declare_field(*m_iter, false, pointers && !(*m_iter)->get_type()->is_xception(), !read) << endl;
+      declare_field(*m_iter, false, (pointers && !(*m_iter)->get_type()->is_xception()) || is_reference(*m_iter), !read) << endl;
   }
 
   // Add the __isset data member if we need it, using the definition from above
@@ -961,19 +1050,7 @@
       endl <<
       indent() << "void __set_" << (*m_iter)->get_name() <<
         "(" << type_name((*m_iter)->get_type(), false, true);
-    out << " val) {" << endl << indent() <<
-      indent() << (*m_iter)->get_name() << " = val;" << endl;
-
-    // assume all fields are required except optional fields.
-    // for optional fields change __isset.name to true
-    bool is_optional = (*m_iter)->get_req() == t_field::T_OPTIONAL;
-    if (is_optional) {
-      out <<
-        indent() <<
-        indent() << "__isset." << (*m_iter)->get_name() << " = true;" << endl;
-    }
-    out <<
-      indent()<< "}" << endl;
+    out << " val);" << endl;
   }
   out << endl;
 
@@ -1059,6 +1136,73 @@
   }
 }
 
+/**
+ * Generates the out-of-line member definitions for a struct whose class
+ * body is emitted by generate_struct_declaration: the destructor (which
+ * deletes cpp.ref fields) and one __set_<field>() setter per field.
+ *
+ * @param out     Output stream
+ * @param tstruct The struct
+ */
+void t_cpp_generator::generate_struct_definition(ofstream& out,
+                                                 t_struct* tstruct) {
+  // Get members
+  vector<t_field*>::const_iterator m_iter;
+  const vector<t_field*>& members = tstruct->get_members();
+
+  // Destructor. Like its declaration, it is only emitted for structs
+  // without the "final" annotation.
+  // NOTE(review): cpp.ref fields of a "final" struct are never deleted.
+  if (tstruct->annotations_.find("final") == tstruct->annotations_.end()) {
+    out <<
+      endl <<
+      indent() << tstruct->get_name() << "::~" << tstruct->get_name() << "() throw() {" << endl;
+    indent_up();
+
+    for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
+      if (is_reference(*m_iter)) {
+        out << indent() <<
+          "delete " << (*m_iter)->get_name() << ";" << endl;
+      }
+    }
+
+    indent_down();
+    out << indent() << "}" << endl << endl;
+  }
+
+  // Create a setter function for each field
+  for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
+    out <<
+      endl <<
+      indent() << "void " << tstruct->get_name() << "::__set_" << (*m_iter)->get_name() <<
+        "(" << type_name((*m_iter)->get_type(), false, true);
+    out << " val) {" << endl;
+    indent_up();
+    if (is_reference((*m_iter))) {
+      // cpp.ref fields own a heap-allocated pointee: reuse it if present,
+      // otherwise allocate a copy of val.
+      std::string type = type_name((*m_iter)->get_type());
+      indent(out) << "if (" << (*m_iter)->get_name() << ") {" << endl;
+      indent(out) << "  *" << (*m_iter)->get_name() << " = val;" << endl;
+      indent(out) << "} else {" << endl;
+      indent(out) << "  " << (*m_iter)->get_name() << " = new " << type << "(val);" << endl;
+      indent(out) << "}" << endl;
+    } else {
+      indent(out) << (*m_iter)->get_name() << " = val;" << endl;
+    }
+
+    // assume all fields are required except optional fields.
+    // for optional fields change __isset.name to true. This is emitted
+    // before indent_down() so it is indented as part of the setter body.
+    if ((*m_iter)->get_req() == t_field::T_OPTIONAL) {
+      indent(out) << "__isset." << (*m_iter)->get_name() << " = true;" << endl;
+    }
+    indent_down();
+    indent(out) << "}" << endl;
+  }
+  out << endl;
+}
+
 /**
  * Writes the fingerprint of a struct to either the header or implementation.
  *
@@ -1751,11 +1886,11 @@
 
     // TODO(dreiss): Why is this stuff not in generate_function_helpers?
     ts->set_name(tservice->get_name() + "_" + (*f_iter)->get_name() + "_args");
-    generate_struct_definition(f_header_, ts, false);
+    generate_struct_declaration(f_header_, ts, false);
     generate_struct_reader(out, ts);
     generate_struct_writer(out, ts);
     ts->set_name(tservice->get_name() + "_" + (*f_iter)->get_name() + "_pargs");
-    generate_struct_definition(f_header_, ts, false, true, false, true);
+    generate_struct_declaration(f_header_, ts, false, true, false, true);
     generate_struct_writer(out, ts, true);
     ts->set_name(name_orig);
 
@@ -3206,12 +3341,12 @@
     result.append(*f_iter);
   }
 
-  generate_struct_definition(f_header_, &result, false);
+  generate_struct_declaration(f_header_, &result, false);
   generate_struct_reader(out, &result);
   generate_struct_result_writer(out, &result);
 
   result.set_name(tservice->get_name() + "_" + tfunction->get_name() + "_presult");
-  generate_struct_definition(f_header_, &result, false, true, true, gen_cob_style_);
+  generate_struct_declaration(f_header_, &result, false, true, true, gen_cob_style_);
   generate_struct_reader(out, &result, true);
   if (gen_cob_style_) {
     generate_struct_writer(out, &result, true);
@@ -3870,7 +4005,7 @@
   string name = prefix + tfield->get_name() + suffix;
 
   if (type->is_struct() || type->is_xception()) {
-    generate_deserialize_struct(out, (t_struct*)type, name);
+    generate_deserialize_struct(out, (t_struct*)type, name, is_reference(tfield));
   } else if (type->is_container()) {
     generate_deserialize_container(out, type, name);
   } else if (type->is_base_type()) {
@@ -3932,10 +4067,47 @@
  */
 void t_cpp_generator::generate_deserialize_struct(ofstream& out,
                                                   t_struct* tstruct,
-                                                  string prefix) {
-  (void) tstruct;
-  indent(out) <<
-    "xfer += " << prefix << ".read(iprot);" << endl;
+                                                  string prefix,
+                                                  bool pointer) {
+  if (!pointer) {
+    indent(out) <<
+      "xfer += " << prefix << ".read(iprot);" << endl;
+    return;
+  }
+
+  // cpp.ref field: the member is an owning pointer. Allocate the pointee
+  // on demand, then read into it.
+  indent(out) << "if (!" << prefix << ") { " << endl;
+  indent(out) << "  " << prefix << " = new " << type_name(tstruct) << ";" << endl;
+  indent(out) << "}" << endl;
+  indent(out) <<
+    "xfer += " << prefix << "->read(iprot);" << endl;
+
+  // A struct read back with none of its optional/default fields set stands
+  // for a NULL pointer (generate_serialize_struct writes NULL as an empty
+  // struct). The pointee is owned, so it is deleted rather than leaked.
+  // Required fields have no __isset flag and cannot be tested; if the struct
+  // has no other fields the object is always kept.
+  // NOTE(review): the required-field rule is assumed from
+  // generate_struct_declaration; keep the two in sync.
+  string any_set;
+  const vector<t_field*>& members = tstruct->get_members();
+  vector<t_field*>::const_iterator f_iter;
+  for (f_iter = members.begin(); f_iter != members.end(); ++f_iter) {
+    if ((*f_iter)->get_req() == t_field::T_REQUIRED) {
+      continue;
+    }
+    if (!any_set.empty()) {
+      any_set += " || ";
+    }
+    any_set += prefix + "->__isset." + (*f_iter)->get_name();
+  }
+  if (!any_set.empty()) {
+    indent(out) << "if (!(" << any_set << ")) {" << endl;
+    indent(out) << "  delete " << prefix << ";" << endl;
+    indent(out) << "  " << prefix << " = NULL;" << endl;
+    indent(out) << "}" << endl;
+  }
 }
 
 void t_cpp_generator::generate_deserialize_container(ofstream& out,
@@ -4088,7 +4239,8 @@
   if (type->is_struct() || type->is_xception()) {
     generate_serialize_struct(out,
                               (t_struct*)type,
-                              name);
+                              name, 
+			      is_reference(tfield));
   } else if (type->is_container()) {
     generate_serialize_container(out, type, name);
   } else if (type->is_base_type() || type->is_enum()) {
@@ -4151,10 +4303,24 @@
  */
 void t_cpp_generator::generate_serialize_struct(ofstream& out,
                                                 t_struct* tstruct,
-                                                string prefix) {
-  (void) tstruct;
-  indent(out) <<
-    "xfer += " << prefix << ".write(oprot);" << endl;
+                                                string prefix,
+                                                bool pointer) {
+  if (!pointer) {
+    indent(out) <<
+      "xfer += " << prefix << ".write(oprot);" << endl;
+    return;
+  }
+
+  // cpp.ref field: a NULL pointer is written as an empty struct, which the
+  // reader maps back to NULL (see generate_deserialize_struct). The field
+  // stop must precede writeStructEnd(), and every write is added to xfer.
+  indent(out) << "if (" << prefix << ") {" << endl;
+  indent(out) << "  xfer += " << prefix << "->write(oprot);" << endl;
+  indent(out) << "} else {" << endl;
+  indent(out) << "  xfer += oprot->writeStructBegin(\"" << tstruct->get_name() << "\");" << endl;
+  indent(out) << "  xfer += oprot->writeFieldStop();" << endl;
+  indent(out) << "  xfer += oprot->writeStructEnd();" << endl;
+  indent(out) << "}" << endl;
 }
 
 void t_cpp_generator::generate_serialize_container(ofstream& out,
diff --git a/compiler/cpp/src/generate/t_generator.cc b/compiler/cpp/src/generate/t_generator.cc
index de33fd4..f04c65a 100644
--- a/compiler/cpp/src/generate/t_generator.cc
+++ b/compiler/cpp/src/generate/t_generator.cc
@@ -47,8 +47,12 @@
 
   // Generate structs, exceptions, and unions in declared order
   vector<t_struct*> objects = program_->get_objects();
+
   vector<t_struct*>::iterator o_iter;
   for (o_iter = objects.begin(); o_iter != objects.end(); ++o_iter) {
+    generate_forward_declaration(*o_iter);
+  }
+  for (o_iter = objects.begin(); o_iter != objects.end(); ++o_iter) {
     if ((*o_iter)->is_xception()) {
       generate_xception(*o_iter);
     } else {
diff --git a/compiler/cpp/src/generate/t_generator.h b/compiler/cpp/src/generate/t_generator.h
index e33263a..d5cf835 100644
--- a/compiler/cpp/src/generate/t_generator.h
+++ b/compiler/cpp/src/generate/t_generator.h
@@ -110,6 +110,7 @@
   }
   virtual void generate_struct   (t_struct*   tstruct)   = 0;
   virtual void generate_service  (t_service*  tservice)  = 0;
+  virtual void generate_forward_declaration (t_struct*) {}  // Called for every struct/exception before any is generated; no-op by default.
   virtual void generate_xception (t_struct*   txception) {
     // By default exceptions are the same as structs
     generate_struct(txception);
diff --git a/compiler/cpp/src/parse/t_program.h b/compiler/cpp/src/parse/t_program.h
index 96a8a5c..5cf2738 100644
--- a/compiler/cpp/src/parse/t_program.h
+++ b/compiler/cpp/src/parse/t_program.h
@@ -235,7 +235,7 @@
   }
 
   // Scope accessor
-  t_scope* scope() {
+  t_scope* scope() const {  // const so it can be called through a const t_program*
     return scope_;
   }
 
diff --git a/compiler/cpp/src/parse/t_type.h b/compiler/cpp/src/parse/t_type.h
index 74686b0..b85f2da 100644
--- a/compiler/cpp/src/parse/t_type.h
+++ b/compiler/cpp/src/parse/t_type.h
@@ -66,6 +66,10 @@
     return program_;
   }
 
+  const t_program* get_program() const {  // const overload, used by t_typedef::get_type() const
+    return program_;
+  }
+
   t_type* get_true_type();
 
   // Return a string that uniquely identifies this type
diff --git a/compiler/cpp/src/parse/t_typedef.cc b/compiler/cpp/src/parse/t_typedef.cc
new file mode 100644
index 0000000..ddbe749
--- /dev/null
+++ b/compiler/cpp/src/parse/t_typedef.cc
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <cstdio>
+#include <cstdlib>
+
+#include "t_typedef.h"
+#include "t_program.h"
+
+/**
+ * Returns the aliased type. A typedef created by the lazy constructor has
+ * no type_; its symbolic name is then looked up in the owning program's
+ * scope on each call. An unknown name is fatal: a message is written to
+ * stderr and the compiler exits with status 1.
+ */
+t_type* t_typedef::get_type() const {
+  if (type_ != NULL) {
+    return type_;
+  }
+  t_type* type = get_program()->scope()->get_type(symbolic_);
+  if (type == NULL) {
+    fprintf(stderr, "Type \"%s\" not defined\n", symbolic_.c_str());
+    exit(1);
+  }
+  return type;
+}
diff --git a/compiler/cpp/src/parse/t_typedef.h b/compiler/cpp/src/parse/t_typedef.h
index 4c77d97..1bea4c9 100644
--- a/compiler/cpp/src/parse/t_typedef.h
+++ b/compiler/cpp/src/parse/t_typedef.h
@@ -32,16 +32,26 @@
  */
 class t_typedef : public t_type {
  public:
-  t_typedef(t_program* program, t_type* type, std::string symbolic) :
+  t_typedef(t_program* program, t_type* type, const std::string& symbolic) :
     t_type(program, symbolic),
     type_(type),
-    symbolic_(symbolic) {}
+    symbolic_(symbolic),
+    seen_(false) {}
+
+  /**
+   * Creates a placeholder for a type that is referenced before (or
+   * without) being defined, e.g. a forward or recursive reference.
+   * type_ stays NULL and get_type() resolves symbolic_ by name later.
+   */
+  t_typedef(t_program* program, const std::string& symbolic) :
+    t_type(program, symbolic),
+    type_(NULL),
+    symbolic_(symbolic),
+    seen_(false) {}
 
   ~t_typedef() {}
 
-  t_type* get_type() const {
-    return type_;
-  }
+  t_type* get_type() const;
 
   const std::string& get_symbolic() const {
     return symbolic_;
@@ -52,19 +62,26 @@
   }
 
   virtual std::string get_fingerprint_material() const {
-    return type_->get_fingerprint_material();
+    if (seen_) {
+      return "";  // cycle back to this typedef: contributes nothing
+    }
+    seen_ = true;  // recursion guard for self-referential types
+    std::string material = get_type()->get_fingerprint_material();
+    seen_ = false;
+    return material;
   }
 
   virtual void generate_fingerprint() {
     t_type::generate_fingerprint();
-    if (!type_->has_fingerprint()) {
-      type_->generate_fingerprint();
+    if (!get_type()->has_fingerprint()) {  // get_type() also resolves lazy typedefs
+      get_type()->generate_fingerprint();
     }
   }
 
  private:
   t_type* type_;
   std::string symbolic_;
+  mutable bool seen_;  // true while get_fingerprint_material() is recursing through this typedef
 };
 
 #endif
diff --git a/compiler/cpp/src/thrifty.yy b/compiler/cpp/src/thrifty.yy
index 7c4a90b..e2467df 100644
--- a/compiler/cpp/src/thrifty.yy
+++ b/compiler/cpp/src/thrifty.yy
@@ -1087,8 +1087,12 @@
         // Lookup the identifier in the current scope
         $$ = g_scope->get_type($1);
         if ($$ == NULL) {
-          yyerror("Type \"%s\" has not been defined.", $1);
-          exit(1);
+          /*
+           * Either this type is not declared yet, or it is never declared.
+           * Either way, accept it as a lazily resolved typedef for now; an
+           * undefined name is reported when the type is first resolved.
+           */
+          $$ = new t_typedef(g_program, $1);
         }
       }
     }