Added read/write I16, U16 and Bool methods to TProtocol

Modified code generation to define structs and I/O methods for function argument lists, and to generate server process-call implementations


git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@664749 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/compiler/src/cpp_generator.py b/compiler/src/cpp_generator.py
index 834877e..f67ca00 100644
--- a/compiler/src/cpp_generator.py
+++ b/compiler/src/cpp_generator.py
@@ -154,7 +154,7 @@
         return ttype.name;
 
     elif isinstance(ttype, Function):
-        return typeToCTypeDeclaration(ttype.resultType)+ " "+ttype.name+"("+string.join([typeToCTypeDeclaration(arg) for arg in ttype.argFieldList], ", ")+")"
+        return typeToCTypeDeclaration(ttype.resultType)+ " "+ttype.name+"("+string.join([typeToCTypeDeclaration(arg) for arg in ttype.argsStruct.fieldList], ", ")+")"
 
     elif isinstance(ttype, Field):
         return typeToCTypeDeclaration(ttype.type)+ " "+ttype.name
@@ -162,15 +162,13 @@
     else:
         raise Exception, "Unknown type "+str(ttype)
 
-def writeTypeDefDefinition(cfile, typedef):
+def toTypeDefDefinition(typedef):
 
-    cfile.writeln("typedef "+typeToCTypeDeclaration(typedef.definitionType)+" "+typedef.name+";")
+    return "typedef "+typeToCTypeDeclaration(typedef.definitionType)+" "+typedef.name+";"
 
-def writeEnumDefinition(cfile, enum):
+def toEnumDefinition(enum):
 
-    cfile.write("enum "+enum.name+" ");
-
-    cfile.beginBlock();
+    result = "enum "+enum.name+" {\n"
 
     first = True
 
@@ -178,40 +176,44 @@
         if first:
             first = False
         else:
-            cfile.writeln(",")
-        cfile.write(ed.name+" = "+str(ed.id))
+            result+= ",\n"
+        result+= "    "+ed.name+" = "+str(ed.id)
 
-    cfile.writeln()
-    cfile.endBlock(";");
+    result+= "\n};\n"
 
-def writeStructDefinition(cfile, struct):
+    return result
 
-    cfile.write("struct "+struct.name+" ");
 
-    cfile.beginBlock()
+def toStructDefinition(struct):
+
+    result = "struct "+struct.name+" {\n"
 
     for field in struct.fieldList:
-        cfile.writeln(typeToCTypeDeclaration(field)+";")
+        result += "    "+typeToCTypeDeclaration(field)+";\n"
 
-    cfile.endBlock(";")
+    result+= "};\n"
 
+    return result
 
-CPP_DEFINITION_WRITER_MAP = {
-    TypeDef : writeTypeDefDefinition,
-    Enum : writeEnumDefinition,
-    Struct : writeStructDefinition,
+CPP_DEFINITION_MAP = {
+    TypeDef : toTypeDefDefinition,
+    Enum : toEnumDefinition,
+    Struct : toStructDefinition,
     Service : None
     }
     
-def writeDefinitions(cfile, definitions):
+def toDefinitions(definitions):
+
+    result = ""
+
     for definition in definitions:
 
-        writer = CPP_DEFINITION_WRITER_MAP[type(definition)]
+        writer = CPP_DEFINITION_MAP[type(definition)]
 
         if writer:
-            writer(cfile, definition)
+            result+= writer(definition)+"\n"
 
-        cfile.writeln()
+    return result
 
 CPP_THRIFT_NS = "facebook::thrift"
 
@@ -225,11 +227,11 @@
 ${functionDeclarations}};
 """)
 
-def writeServiceInterfaceDeclaration(cfile, service, debugp=None):
+def toServiceInterfaceDeclaration(service, debugp=None):
 
     functionDeclarations = string.join([CPP_INTERFACE_FUNCTION_DECLARATION.substitute(service=service.name, functionDeclaration=typeToCTypeDeclaration(function)) for function in service.functionList], "")
 
-    cfile.write(CPP_INTERFACE_DECLARATION.substitute(service=service.name, functionDeclarations=functionDeclarations))
+    return CPP_INTERFACE_DECLARATION.substitute(service=service.name, functionDeclarations=functionDeclarations)
 
 CPP_SP = Template("boost::shared_ptr<${klass}> ")
 
@@ -245,7 +247,22 @@
 CPP_TRANSPORT = CPP_TRANSPORT_NS+"::TTransport"
 CPP_TRANSPORTP = CPP_SP.substitute(klass=CPP_TRANSPORT)
 
-CPP_SERVER_FUNCTION_DECLARATION = Template("""    void process_${function}("""+CPP_TRANSPORTP+""" _itrans, """+CPP_TRANSPORTP+""" _otrans);
+CPP_SERVER_FUNCTION_DECLARATION = Template("""    void process_${function}("""+CPP_TRANSPORTP+""" itrans, """+CPP_TRANSPORTP+""" otrans);
+""")
+
+CPP_SERVER_FUNCTION_DEFINITION = Template("""
+void ${service}ServerIf::process_${function}("""+CPP_TRANSPORTP+""" itrans, """+CPP_TRANSPORTP+""" otrans) {
+
+    uint32_t xfer = 0;
+
+${argsStructDeclaration}
+${argsStructReader}    
+${resultDeclaration}
+${functionCall}
+${resultWriter}
+
+    otrans->flush();
+}
 """)
 
 CPP_PROTOCOL_TSTOP = CPP_PROTOCOL_NS+"::T_STOP"
@@ -280,6 +297,12 @@
     if isinstance(ttype, PrimitiveType):
 	return CPP_TTYPE_MAP[ttype]
 
+    elif isinstance(ttype, Enum):
+	return CPP_TTYPE_MAP[I32_TYPE]
+
+    elif isinstance(ttype, TypeDef):
+	return toWireType(toCanonicalType(ttype))
+
     elif isinstance(ttype, Struct) or isinstance(ttype, CollectionType):
 	return CPP_TTYPE_MAP[type(ttype)]
 
@@ -300,11 +323,11 @@
 ${functionDeclarations}};
 """)
 
-def writeServerDeclaration(cfile, service, debugp=None):
+def toServerDeclaration(service, debugp=None):
 
     functionDeclarations = string.join([CPP_SERVER_FUNCTION_DECLARATION.substitute(function=function.name) for function in service.functionList], "")
 
-    cfile.write(CPP_SERVER_DECLARATION.substitute(service=service.name, functionDeclarations=functionDeclarations))
+    return CPP_SERVER_DECLARATION.substitute(service=service.name, functionDeclarations=functionDeclarations)
     
 CPP_CLIENT_FUNCTION_DECLARATION = Template("""    ${functionDeclaration};
 """)
@@ -326,16 +349,51 @@
     """+CPP_PROTOCOLP+""" _oprot;
 };""")
 
-def writeClientDeclaration(cfile, service, debugp=None):
+def toServerServiceDefinition(service, debugp=None):
+
+    result = ""
+
+    for function in service.functionList:
+
+	if len(function.argsStruct.fieldList) > 0:
+	    argsStructDeclaration = "    "+typeToCTypeDeclaration(function.argsStruct)+" __args;\n"
+	    argsStructReader = "    "+toReaderCall("__args", function.argsStruct, "_iprot")+";\n"
+	else:
+	    argsStructDeclaration = ""
+	    argsStructReader = ""
+
+	functionCall = "    "
+	resultDeclaration = ""
+	resultWriter = ""
+
+	if toCanonicalType(function.resultType) != VOID_TYPE:
+	    resultDeclaration = "    "+typeToCTypeDeclaration(function.resultType)+" __result;\n"
+	    functionCall+= "__result = "
+
+	functionCall+= function.name+"("+string.join(["__args."+arg.name for arg in function.argsStruct.fieldList], ", ")+");\n"
+
+	if toCanonicalType(function.resultType) != VOID_TYPE:
+	    resultWriter = "    "+toWriterCall("__result", function.resultType, "_oprot")+";"
+
+	result+= CPP_SERVER_FUNCTION_DEFINITION.substitute(service=service.name, function=function.name,
+							   argsStructDeclaration=argsStructDeclaration, argsStructReader=argsStructReader, 
+							   functionCall=functionCall,
+							   resultDeclaration=resultDeclaration, resultWriter=resultWriter)
+
+    return result
+
+def toServerDefinition(program, debugp=None):
+
+    return string.join([toServerServiceDefinition(service) for service in program.serviceMap.values()], "\n")
+
+def toClientDeclaration(service, debugp=None):
 
     functionDeclarations = string.join([CPP_CLIENT_FUNCTION_DECLARATION.substitute(functionDeclaration=typeToCTypeDeclaration(function)) for function in service.functionList], "")
 
-    cfile.writeln(CPP_CLIENT_DECLARATION.substitute(service=service.name, functionDeclarations=functionDeclarations))
+    return CPP_CLIENT_DECLARATION.substitute(service=service.name, functionDeclarations=functionDeclarations)+"\n"
 
-def writeServiceDeclaration(cfile, service, debugp=None):
-    writeServiceInterfaceDeclaration(cfile, service, debugp)
-    writeServerDeclaration(cfile, service, debugp)
-    writeClientDeclaration(cfile, service, debugp)
+def toServiceDeclaration(service, debugp=None):
+    return toServiceInterfaceDeclaration(service, debugp) + toServerDeclaration(service, debugp) + toClientDeclaration(service, debugp)
 
 def toGenDir(filename, suffix="cpp-gen", debugp=None):
 
@@ -388,7 +446,7 @@
 
     cfile.writeln(CPP_TYPES_HEADER.substitute(source=basename, date=time.ctime()))
 
-    writeDefinitions(cfile, program.definitions)
+    cfile.write(toDefinitions(program.definitions))
 
     cfile.writeln(CPP_TYPES_FOOTER.substitute(source=basename))
 
@@ -432,7 +490,7 @@
 
     for service in services:
 
-        writeServiceDeclaration(cfile, service)
+        cfile.write(toServiceDeclaration(service))
 
     cfile.writeln(CPP_SERVICES_FOOTER.substitute(source=basename))
 
@@ -514,46 +572,46 @@
     else:
         raise Exception, "Unknown type "+str(ttype)
 
-def toReaderCall(value, ttype):
+def toReaderCall(value, ttype, reader="iprot"):
 
     suffix = typeToIOMethodSuffix(ttype)
 
     if isinstance(ttype, PrimitiveType):
-        return "xfer += iprot->read"+suffix+"(itrans, "+value+")"
+        return "xfer += "+reader+"->read"+suffix+"(itrans, "+value+")"
 
     elif isinstance(ttype, CollectionType):
-        return "xfer+= read_"+suffix+"(iprot, itrans, "+value+")"
+        return "xfer+= read_"+suffix+"("+reader+", itrans, "+value+")"
 
     elif isinstance(ttype, Struct):
-        return "xfer+= read_"+suffix+"(iprot, itrans, "+value+")"
+        return "xfer+= read_"+suffix+"("+reader+", itrans, "+value+")"
 
     elif isinstance(ttype, TypeDef):
-        return toReaderCall("reinterpret_cast<"+typeToCTypeDeclaration(ttype.definitionType)+"&>("+value+")", ttype.definitionType)
+        return toReaderCall("reinterpret_cast<"+typeToCTypeDeclaration(ttype.definitionType)+"&>("+value+")", ttype.definitionType, reader)
 
     elif isinstance(ttype, Enum):
-        return toReaderCall("reinterpret_cast<"+typeToCTypeDeclaration(I32_TYPE)+"&>("+value+")", I32_TYPE)
+        return toReaderCall("reinterpret_cast<"+typeToCTypeDeclaration(I32_TYPE)+"&>("+value+")", I32_TYPE, reader)
 
     else:
         raise Exception, "Unknown type "+str(ttype)
 
-def toWriterCall(value, ttype):
+def toWriterCall(value, ttype, writer="oprot"):
 
     suffix = typeToIOMethodSuffix(ttype)
 
     if isinstance(ttype, PrimitiveType):
-        return "xfer+= oprot->write"+suffix+"(otrans, "+value+")"
+        return "xfer+= "+writer+"->write"+suffix+"(otrans, "+value+")"
 
     elif isinstance(ttype, CollectionType):
-        return "xfer+= write_"+suffix+"(oprot, otrans, "+value+")"
+        return "xfer+= write_"+suffix+"("+writer+", otrans, "+value+")"
 
     elif isinstance(ttype, Struct):
-        return "xfer+= write_"+suffix+"(oprot, otrans, "+value+")"
+        return "xfer+= write_"+suffix+"("+writer+", otrans, "+value+")"
 
     elif isinstance(ttype, TypeDef):
-        return toWriterCall("reinterpret_cast<const "+typeToCTypeDeclaration(ttype.definitionType)+"&>("+value+")", ttype.definitionType)
+        return toWriterCall("reinterpret_cast<const "+typeToCTypeDeclaration(ttype.definitionType)+"&>("+value+")", ttype.definitionType, writer)
 
     elif isinstance(ttype, Enum):
-        return toWriterCall("reinterpret_cast<const "+typeToCTypeDeclaration(I32_TYPE)+"&>("+value+")", I32_TYPE)
+        return toWriterCall("reinterpret_cast<const "+typeToCTypeDeclaration(I32_TYPE)+"&>("+value+")", I32_TYPE, writer)
 
     else:
         raise Exception, "Unknown type "+str(ttype)
@@ -681,7 +739,7 @@
 
     std::string name;
     """+CPP_PROTOCOL_TTYPE+""" type;
-    uint16_t id;
+    int16_t id;
     uint32_t xfer = 0;
 
     while(true) {
@@ -830,8 +888,12 @@
     elif isinstance(ttype, Function):
 	result = toOrderedIOList(ttype.resultType, result)
 
-	for arg in ttype.argFieldList:
-	    result = toOrderedIOList(arg.type, result)
+	# skip the args struct itself and order only the argument types; the args
+	# struct must not be referenced yet, because its definition is emitted inline
+	# with the implementation rather than in the types header
+	
+	for field in ttype.argsStruct.fieldList:
+	    result = toOrderedIOList(field.type, result)
 
     else:
 	raise Exception, "Unsupported thrift type: "+str(ttype)
@@ -840,7 +902,7 @@
 
 def toIOMethodImplementations(program):
     
-    # get orderede list of all types that need marshallers:
+    # get ordered list of all types that need marshallers:
 
     iolist = toOrderedIOList(program)
 
@@ -850,6 +912,20 @@
 	result+= toReaderDefinition(ttype)
 	result+= toWriterDefinition(ttype)
 
+    # for all function argument lists we need to create both struct definitions
+    # and io methods.  We keep the struct definitions local, since they aren't part of the service
+    # API
+    # Note that we don't need to do a depth-first traverse of arg structs since they can only include fields
+    # we've already seen
+
+    for service in program.serviceMap.values():
+	for function in service.functionList:
+	    if len(function.argsStruct.fieldList) == 0:
+		continue
+	    result+= toStructDefinition(function.argsStruct)
+	    result+=toReaderDefinition(function.argsStruct)
+	    result+=toWriterDefinition(function.argsStruct)
+
     return result;
 
 def toImplementationSourceName(filename, genDir=None, debugp=None):
@@ -881,6 +957,8 @@
 
     cfile.write(toIOMethodImplementations(program))
 
+    cfile.write(toServerDefinition(program))
+
     cfile.writeln(CPP_IMPL_FOOTER.substitute(source=basename))
 
     cfile.close()
diff --git a/compiler/src/parser.py b/compiler/src/parser.py
index 1f45f86..d2fb0fd 100644
--- a/compiler/src/parser.py
+++ b/compiler/src/parser.py
@@ -300,10 +300,10 @@
 
     def assignId(field, currentId, ids):
 	'Finds the next available id number for a field'
-	id= currentId - 1
+	id = currentId - 1
 
 	while id in ids:
-	    id -= 1
+	    id-= 1
 	    
 	field.id = id
 
@@ -315,7 +315,7 @@
 	
     currentId = 0
 	
-    for fields in fieldList:
+    for field in fieldList:
 	if not field.id:
 	    currentId = assignId(field, currentId, ids)
 	
@@ -333,16 +333,16 @@
 
 class Function(Definition):
 
-    def __init__(self, symbols, name, resultType, argFieldList):
+    def __init__(self, symbols, name, resultType, argsStruct):
 	Definition.__init__(self, symbols, name)
 	self.resultType = resultType
-	self.argFieldList = argFieldList
+	self.argsStruct = argsStruct
 
     def validate(self):
-	validateFieldList(self.argFieldList)
+	validateFieldList(self.argsStruct.fieldList)
     
     def __str__(self):
-	return self.name+"("+string.join(map(lambda a: str(a), self.argFieldList), ", ")+") => "+str(self.resultType)
+	return self.name+"("+string.join(map(str, self.argsStruct.fieldList), ", ")+") => "+str(self.resultType) # argsStruct is the synthesized <name>_args Struct; format its fields
 
 class Service(Definition):
 
@@ -499,7 +499,7 @@
 		except ErrorException, e:
 		    errors+= e.errors
 
-		for field in function.argFieldList:
+		for field in function.argsStruct.fieldList:
 		    try:
 			field.type = self.getType(function, field)
 		    except ErrorException, e:
@@ -662,10 +662,6 @@
 	except ErrorException, e:
 	    self.errors+= e.errors
 
-#    def p_definition_or_referencye_type_1(self, p):
-#       XXX need to all typedef struct foo foo_t by allowing references
-#	pass
-	    
     def p_enum(self, p):
 	'enum : ENUM ID LBRACE enumdeflist RBRACE'
 	self.pdebug("p_enum", p)
@@ -728,7 +724,7 @@
     def p_function(self, p):
 	'function : functiontype functionmodifiers ID LPAREN fieldlist RPAREN'
 	self.pdebug("p_function", p)
-	p[0] = Function(p, p[3], p[1], p[5])
+	p[0] = Function(p, p[3], p[1], Struct(p, p[3]+"_args", p[5]))
 	try:
 	    p[0].validate()
 	except ErrorException, e: