From bf4fd1996323e104b79532587b5dd392d92a83fa Mon Sep 17 00:00:00 2001 From: Marc Slemko Date: Tue, 15 Aug 2006 21:29:39 +0000 Subject: [PATCH] Modified C++ code-gen to create default constructors for all non-string primitives so that auto variable instances of structs aren't populated with garbage. This matters because, given thrift's loosey-goosey argument and result lists, structs may only be sparsely filled. git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@664757 13f79535-47bb-0310-9956-ffa450edef68 --- compiler/src/cpp_generator.py | 92 +++++++++++++++++++++++------------ compiler/src/parser.py | 4 +- test/ThriftTest.thrift | 8 +-- test/cpp/Makefile | 12 ++--- test/cpp/src/TestClient.cc | 43 +++++++++++----- test/cpp/src/TestServer.cc | 54 +++++++++++++++++--- 6 files changed, 147 insertions(+), 66 deletions(-) diff --git a/compiler/src/cpp_generator.py b/compiler/src/cpp_generator.py index 8003f1e7..fd25e17c 100644 --- a/compiler/src/cpp_generator.py +++ b/compiler/src/cpp_generator.py @@ -217,15 +217,28 @@ def toStructDefinition(struct): result = "struct "+struct.name+" {\n" - for field in struct.fieldList: - if toCanonicalType(field.type) != VOID_TYPE: - result += " "+toCTypeDeclaration(field)+";\n" + # Create constructor defaults for primitive types - result+= " struct {\n" + ctorValues = string.join([field.name+"(0)" for field in struct.fieldList if isinstance(toCanonicalType(field.type), PrimitiveType) and toCanonicalType(field.type) not in [STRING_TYPE, UTF8_TYPE, UTF16_TYPE, VOID_TYPE]], ", ") - for field in struct.fieldList: - result+= " bool "+field.name+";\n" - result+= " } __isset;\n" + if len(ctorValues) > 0: + result+= " "+struct.name+"() : "+ctorValues+ " {}\n" + + # Field declarations + + result+= string.join([" "+toCTypeDeclaration(field)+";\n" for field in struct.fieldList if toCanonicalType(field.type) != VOID_TYPE], "") + + # is-field-set struct and ctor + + ctorValues = string.join([field.name+"(false)" for field in struct.fieldList if toCanonicalType(field.type) != VOID_TYPE], ", ") + + if len(ctorValues) > 0: + result+= " struct __isset {\n" + result+= " __isset() : "+ctorValues+" {}\n" + result+= string.join([" bool "+field.name+";\n" for field in struct.fieldList if toCanonicalType(field.type) != VOID_TYPE], "") + result+= " } __isset;\n" + + # bring it on home result+= "};\n" @@ -337,7 +350,7 @@ void ${service}ServerIf::process_${function}(uint32_t seqid, """+CPP_TRANSPORTP+ ${resultStructDeclaration}; - ${functionCall}; + ${functionCall} _oprot->writeMessageBegin(otrans, \"${function}\", """+CPP_PROTOCOL_REPLY+""", seqid); @@ -446,11 +459,8 @@ ${argsToStruct}; _iprot->readMessageEnd(_itrans); - if(__result.__isset.success) { - ${success} - } ${error} else { - throw """+CPP_EXCEPTION+"""(\"${function} failed: unknown result"); - } +${returnResult} + throw """+CPP_EXCEPTION+"""(\"${function} failed: unknown result"); } """) @@ -483,11 +493,9 @@ def toServerFunctionDefinition(servicePrefix, function, debugp=None): resultStructWriter = toWriterCall("__result", function.resultStruct, "_oprot") - - - if function.returnType() != VOID_TYPE: + if toCanonicalType(function.returnType()) != VOID_TYPE: functionCallPrefix= "__result.success = " - functionCallSuffix = "__result.__isset.success = true;" + functionCallSuffix = "\n __result.__isset.success = true;" else: functionCallPrefix = "" functionCallSuffix = "" @@ -497,9 +505,10 @@ def toServerFunctionDefinition(servicePrefix, function, debugp=None): exceptions = function.exceptions() if len(exceptions) > 0: - functionCallPrefix= "try {"+functionCallPrefix + functionCallPrefix= "try {\n "+functionCallPrefix - functionCallSuffix = functionCallSuffix+"} "+string.join(["catch("+toCTypeDeclaration(exceptions[ix].type)+"& e"+str(ix)+") {__result."+exceptions[ix].name+" = e"+str(ix)+"; __result.__isset."+exceptions[ix].name+" = true;}" for ix in range(len(exceptions))], "") + functionCallSuffix+= "\n }"+string.join([" catch("+toCTypeDeclaration(exceptions[ix].type)+"& e"+str(ix)+") {\n __result."+exceptions[ix].name+" = e"+str(ix)+";\n __result.__isset."+exceptions[ix].name+" = true;\n }" + for ix in range(len(exceptions))], "") functionCall = functionCallPrefix+functionCall+functionCallSuffix @@ -520,7 +529,7 @@ def toServerServiceDefinition(service, debugp=None): result+= toServerFunctionDefinition(service.name, function, debugp) - callProcessSwitch = " if"+string.join(["(name.compare(\""+function.name+"\") == 0) { process_"+function.name+"(seqid, itrans, otrans);}" for function in service.functionList], "\n else if")+"\n else {throw "+CPP_EXCEPTION+"(\"Unknown function name \\\"\"+name+\"\\\"\");}" + callProcessSwitch = " if"+string.join(["(name.compare(\""+function.name+"\") == 0) { process_"+function.name+"(seqid, itrans, otrans);\n}" for function in service.functionList], "\n else if")+" else {throw "+CPP_EXCEPTION+"(\"Unknown function name \\\"\"+name+\"\\\"\");}" result+= CPP_SERVER_PROCESS_DEFINITION.substitute(service=service.name, callProcessSwitch=callProcessSwitch) @@ -539,6 +548,8 @@ def toClientDeclaration(service, debugp=None): def toClientFunctionDefinition(servicePrefix, function, debugp=None): """Converts a thrift service method declaration to a client stub implementation""" + isVoid = toCanonicalType(function.returnType()) == VOID_TYPE + returnDeclaration = toCTypeDeclaration(function.returnType()) argsDeclaration = string.join([toCTypeDeclaration(function.args()[ix].type)+" __arg"+str(ix) for ix in range(len(function.args()))], ", ") @@ -558,18 +569,27 @@ def toClientFunctionDefinition(servicePrefix, function, debugp=None): resultStructReader = toReaderCall("__result", function.resultStruct, "_iprot", "_itrans") - if(toCanonicalType(function.returnType()) != VOID_TYPE): - - success = "return __result.success;" - else: - success = "" - exceptions = function.exceptions() + """ void return and non-void returns require very different arrangments. For void returns, we don't actually expect + anything to have been sent from the remote side, therefore we need to check for explicit exception returns first, + then, if none were encountered, return. For void we test for success first - since this is the most likey result - + and then check for exceptions. In both cases, we need to handle the case where there are no specified exceptions. """ + if len(exceptions) > 0: - error = "else if "+string.join(["(__result.__isset."+exception.name+") { throw __result."+exception.name+";}" for exception in exceptions], "else if") + errors= ["if(__result.__isset."+exception.name+") {\n throw __result."+exception.name+";\n }" for exception in exceptions] + else: + errors = [] + + if not isVoid: + returnResult = " if(__result.__isset.success) {\n return __result.success;\n}" + if len(errors) > 0: + returnResult+= " else "+string.join(errors, " else ") else: - error = "" + if len(errors) > 0: + returnResult= " "+string.join(errors, " else ")+" else {\n return;\n }\n" + else: + returnResult=" return;\n" return CPP_CLIENT_FUNCTION_DEFINITION.substitute(service=servicePrefix, function=function.name, @@ -581,8 +601,7 @@ def toClientFunctionDefinition(servicePrefix, function, debugp=None): argsToStruct=argsToStruct, resultStructDeclaration=resultStructDeclaration, resultStructReader=resultStructReader, - success=success, - error=error) + returnResult=returnResult) def toClientServiceDefinition(service, debugp=None): """Converts a thrift service definition to a client stub implementation""" @@ -1025,6 +1044,8 @@ def toStructReaderDefinition(struct): fieldSwitch="" for field in fieldList: + if toCanonicalType(field.type) == VOID_TYPE: + continue; fieldSwitch+= " case "+str(field.id)+": " fieldSwitch+= toReaderCall("value."+field.name, field.type)+"; value.__isset."+field.name+" = true; break;\n" @@ -1042,7 +1063,8 @@ def toStructWriterDefinition(struct): type=toWireType(toCanonicalType(field.type)), id=field.id, fieldWriterCall=toWriterCall("value."+field.name, field.type))+";" - for field in struct.fieldList] + for field in struct.fieldList if toCanonicalType(field.type) != VOID_TYPE + ] fieldWriterCalls = " "+string.join(fieldWriterCalls, "\n ") @@ -1064,6 +1086,7 @@ ${fieldWriterCalls} def toResultStructReaderDefinition(struct): """Converts internal results struct to a reader function definition""" + return toStructReaderDefinition(struct) def toResultStructWriterDefinition(struct): @@ -1071,7 +1094,12 @@ def toResultStructWriterDefinition(struct): suffix = typeToIOMethodSuffix(struct) - fieldWriterCalls = ["if(value.__isset."+field.name+") { "+toWriterCall("value."+field.name, field.type)+";}" for field in struct.fieldList] + fieldWriterCalls = ["if(value.__isset."+field.name+") { "+ + CPP_WRITE_FIELD_DEFINITION.substitute(name=field.name, + type=toWireType(toCanonicalType(field.type)), + id=field.id, + fieldWriterCall=toWriterCall("value."+field.name, field.type))+";}" + for field in struct.fieldList if toCanonicalType(field.type) != VOID_TYPE] fieldWriterCalls = " "+string.join(fieldWriterCalls, "\n else ") diff --git a/compiler/src/parser.py b/compiler/src/parser.py index 8bd96f2e..3b0bdc17 100644 --- a/compiler/src/parser.py +++ b/compiler/src/parser.py @@ -122,8 +122,8 @@ class PrimitiveType(Type): STOP_TYPE = PrimitiveType("stop") VOID_TYPE = PrimitiveType("void") BOOL_TYPE = PrimitiveType("bool") -STRING_TYPE =PrimitiveType("utf7") -UTF7_TYPE = PrimitiveType("utf7") +STRING_TYPE = PrimitiveType("utf7") +UTF7_TYPE = STRING_TYPE UTF8_TYPE = PrimitiveType("utf8") UTF16_TYPE = PrimitiveType("utf16") BYTE_TYPE = PrimitiveType("u08") diff --git a/test/ThriftTest.thrift b/test/ThriftTest.thrift index ac0d95c0..50021ba3 100644 --- a/test/ThriftTest.thrift +++ b/test/ThriftTest.thrift @@ -33,12 +33,12 @@ struct Insanity list xtructs = 1 } -exception Exception { +exception Xception { u32 errorCode, string message } -exception Xception { +exception Xception2 { u32 errorCode, Xtruct struct_thing } @@ -74,9 +74,9 @@ service ThriftTest /* Exception specifier */ - Xtruct testException(string arg) throws(Exception err1), + void testException(string arg) throws(Xception err1), /* Multiple exceptions specifier */ - Xtruct testMultiException(string arg0, string arg1) throws(Exception err1=1, Xception err2) + Xtruct testMultiException(string arg0, string arg1) throws(Xception err1=1, Xception2 err2) } diff --git a/test/cpp/Makefile b/test/cpp/Makefile index daaa15ea..cec9699b 100644 --- a/test/cpp/Makefile +++ b/test/cpp/Makefile @@ -22,18 +22,12 @@ include_paths = $(thrift_home)/include/thrift \ include_flags = $(patsubst %,-I%, $(include_paths)) # Tools -THRIFT = python ../../compiler/src/thrift.py ~/ws/thrift/dev/test/ThriftTest.thrift --cpp +THRIFT = python ../../compiler/src/thrift.py ~/ws/thrift/dev/test/ThriftTest.thrift CC = g++ LD = g++ # Compiler flags -LIBS = ../../lib/cpp/src/server/TSimpleServer.cc \ - ../../lib/cpp/src/protocol/TBinaryProtocol.cc \ - ../../lib/cpp/src/transport/TBufferedTransport.cc \ - ../../lib/cpp/src/transport/TChunkedTransport.cc \ - ../../lib/cpp/src/transport/TServerSocket.cc \ - ../../lib/cpp/src/transport/TSocket.cc -DCFL = -Wall -O3 -g -I../cpp-gen $(include_flags) $(LIBS) +DCFL = -Wall -O3 -g -I../cpp-gen $(include_flags) -L$(thrift_home)/lib -lthrift CFL = -Wall -O3 -I../cpp-gen $(include_flags) -L$(thrift_home)/lib -lthrift all: server client @@ -41,7 +35,7 @@ all: server client debug: server-debug client-debug stubs: ../ThriftTest.thrift - $(THRIFT) -cpp ../ThriftTest.thrift + $(THRIFT) --cpp ../ThriftTest.thrift server-debug: stubs g++ -o TestServer $(DCFL) src/TestServer.cc ../cpp-gen/ThriftTest.cc diff --git a/test/cpp/src/TestClient.cc b/test/cpp/src/TestClient.cc index e977442c..42272d31 100644 --- a/test/cpp/src/TestClient.cc +++ b/test/cpp/src/TestClient.cc @@ -13,6 +13,7 @@ using namespace std; using namespace facebook::thrift; using namespace facebook::thrift::protocol; using namespace facebook::thrift::transport; +using namespace thrift::test; //extern uint32_t g_socket_syscalls; @@ -96,8 +97,8 @@ int main(int argc, char** argv) { * I64 TEST */ printf("testI64(-34359738368)"); - int64_t i64 = testClient.testI64(-34359738368); - printf(" = %ld\n", i64); + int64_t i64 = testClient.testI64(-34359738368LL); + printf(" = %lld\n", i64); /** * STRUCT TEST @@ -109,7 +110,7 @@ int main(int argc, char** argv) { out.i32_thing = -3; out.i64_thing = -5; Xtruct in = testClient.testStruct(out); - printf(" = {\"%s\", %d, %d, %ld}\n", + printf(" = {\"%s\", %d, %d, %lld}\n", in.string_thing.c_str(), (int)in.byte_thing, in.i32_thing, @@ -125,7 +126,7 @@ int main(int argc, char** argv) { out2.i32_thing = 5; Xtruct2 in2 = testClient.testNest(out2); in = in2.struct_thing; - printf(" = {%d, {\"%s\", %d, %d, %ld}, %d}\n", + printf(" = {%d, {\"%s\", %d, %d, %lld}, %d}\n", in2.byte_thing, in.string_thing.c_str(), (int)in.byte_thing, @@ -256,8 +257,8 @@ int main(int argc, char** argv) { * TYPEDEF TEST */ printf("testTypedef(309858235082523)"); - UserId uid = testClient.testTypedef(309858235082523); - printf(" = %ld\n", uid); + UserId uid = testClient.testTypedef(309858235082523LL); + printf(" = %lld\n", uid); /** * NESTED MAP TEST @@ -292,7 +293,7 @@ int main(int argc, char** argv) { printf(" = {"); map >::const_iterator i_iter; for (i_iter = whoa.begin(); i_iter != whoa.end(); ++i_iter) { - printf("%ld => {", i_iter->first); + printf("%lld => {", i_iter->first); map::const_iterator i2_iter; for (i2_iter = i_iter->second.begin(); i2_iter != i_iter->second.end(); @@ -302,7 +303,7 @@ int main(int argc, char** argv) { map::const_iterator um; printf("{"); for (um = userMap.begin(); um != userMap.end(); ++um) { - printf("%d => %ld, ", um->first, um->second); + printf("%d => %lld, ", um->first, um->second); } printf("}, "); @@ -310,7 +311,7 @@ int main(int argc, char** argv) { list::const_iterator x; printf("{"); for (x = xtructs.begin(); x != xtructs.end(); ++x) { - printf("{\"%s\", %d, %d, %ld}, ", + printf("{\"%s\", %d, %d, %lld}, ", x->string_thing.c_str(), (int)x->byte_thing, x->i32_thing, @@ -324,9 +325,29 @@ int main(int argc, char** argv) { } printf("}\n"); - uint64_t stop = now(); - printf("Total time: %lu us\n", stop-start); + /* test multi exception */ + try { + Xtruct result = testClient.testMultiException("Xception", "test 1"); + + } catch(Xception& e) { + printf("testClient.testMulticException(\"Xception\", \"test 1\") => {%u, \"%s\"}\n", e.errorCode, e.message.c_str()); + } + + try { + Xtruct result = testClient.testMultiException("Xception2", "test 2"); + + } catch(Xception2& e) { + printf("testClient.testMultiException(\"Xception2\", \"test 2\") => {%u, {\"%s\"}}\n", e.errorCode, e.struct_thing.string_thing.c_str()); + } + + Xtruct result = testClient.testMultiException("success", "test 3"); + + printf("testClient.testMultiException(\"success\", \"test 3\") => {{\"%s\"}}\n", result.string_thing.c_str()); + + uint64_t stop = now(); + printf("Total time: %llu us\n", stop-start); + bufferedSocket->close(); } diff --git a/test/cpp/src/TestServer.cc b/test/cpp/src/TestServer.cc index 27a983c6..4df37af1 100644 --- a/test/cpp/src/TestServer.cc +++ b/test/cpp/src/TestServer.cc @@ -17,6 +17,8 @@ using namespace facebook::thrift::protocol; using namespace facebook::thrift::transport; using namespace facebook::thrift::server; +using namespace thrift::test; + class TestServer : public ThriftTestServerIf { public: TestServer(shared_ptr protocol) : @@ -42,12 +44,12 @@ class TestServer : public ThriftTestServerIf { } int64_t testI64(int64_t thing) { - printf("testI64(%ld)\n", thing); + printf("testI64(%lld)\n", thing); return thing; } Xtruct testStruct(Xtruct thing) { - printf("testStruct({\"%s\", %d, %d, %ld})\n", + printf("testStruct({\"%s\", %d, %d, %lld})\n", thing.string_thing.c_str(), (int)thing.byte_thing, thing.i32_thing, @@ -57,7 +59,7 @@ class TestServer : public ThriftTestServerIf { Xtruct2 testNest(Xtruct2 nest) { Xtruct thing = nest.struct_thing; - printf("testNest({%d, {\"%s\", %d, %d, %ld}, %d})\n", + printf("testNest({%d, {\"%s\", %d, %d, %lld}, %d})\n", (int)nest.byte_thing, thing.string_thing.c_str(), (int)thing.byte_thing, @@ -121,7 +123,7 @@ class TestServer : public ThriftTestServerIf { } UserId testTypedef(UserId thing) { - printf("testTypedef(%ld)\n", thing); + printf("testTypedef(%lld)\n", thing); return thing; } @@ -181,7 +183,7 @@ class TestServer : public ThriftTestServerIf { printf(" = {"); map >::const_iterator i_iter; for (i_iter = insane.begin(); i_iter != insane.end(); ++i_iter) { - printf("%ld => {", i_iter->first); + printf("%lld => {", i_iter->first); map::const_iterator i2_iter; for (i2_iter = i_iter->second.begin(); i2_iter != i_iter->second.end(); @@ -191,7 +193,7 @@ class TestServer : public ThriftTestServerIf { map::const_iterator um; printf("{"); for (um = userMap.begin(); um != userMap.end(); ++um) { - printf("%d => %ld, ", um->first, um->second); + printf("%d => %lld, ", um->first, um->second); } printf("}, "); @@ -199,7 +201,7 @@ class TestServer : public ThriftTestServerIf { list::const_iterator x; printf("{"); for (x = xtructs.begin(); x != xtructs.end(); ++x) { - printf("{\"%s\", %d, %d, %ld}, ", + printf("{\"%s\", %d, %d, %lld}, ", x->string_thing.c_str(), (int)x->byte_thing, x->i32_thing, @@ -227,6 +229,42 @@ class TestServer : public ThriftTestServerIf { return hello; } + + virtual void testException(std::string arg) throw(struct thrift::test::Xception) { + printf("testException(%s)\n", arg.c_str()); + if(arg.compare("Xception") == 0) { + Xception e; + e.errorCode = 1001; + e.message = "This is an Xception"; + throw e; + } else { + Xtruct result; + result.string_thing = arg; + return; + } + } + + virtual struct Xtruct testMultiException(std::string arg0, std::string arg1) throw(struct Xception, struct Xception2) { + + printf("testException(%s, %s)\n", arg0.c_str(), arg1.c_str()); + + if(arg0.compare("Xception") == 0) { + Xception e; + e.errorCode = 1001; + e.message = "This is an Xception"; + throw e; + + } else if(arg0.compare("Xception2") == 0) { + Xception2 e; + e.errorCode = 2002; + e.struct_thing.string_thing = "This is an Xception2"; + throw e; + } else { + Xtruct result; + result.string_thing = arg1; + return result; + } + } }; int main(int argc, char **argv) { @@ -241,7 +279,7 @@ int main(int argc, char **argv) { usage << argv[0] << " [--port=] [--server-type=] [--protocol-type=] [--workers=]" << endl << - "\t\tserver-type\t\ttype of server, \"simple-server\" or \"thread-pool\". Default is " << serverType << endl << + "\t\tserver-type\t\ttype of server, \"simple\" or \"thread-pool\". Default is " << serverType << endl << "\t\tprotocol-type\t\ttype of protocol, \"binary\", \"ascii\", or \"xml\". Default is " << protocolType << endl << -- 2.17.1