From: Gavin McDonald Date: Thu, 28 Oct 2010 02:12:01 +0000 (+0000) Subject: Thrift now a TLP - INFRA-3116 X-Git-Url: https://source.supwisdom.com/gerrit/gitweb?a=commitdiff_plain;h=0b75e1ac7643787e201fd62628823e6d51ca6353;p=common%2Fthrift.git Thrift now a TLP - INFRA-3116 git-svn-id: https://svn.apache.org/repos/asf/thrift/branches/0.1.x@1028168 13f79535-47bb-0310-9956-ffa450edef68 --- 0b75e1ac7643787e201fd62628823e6d51ca6353 diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..793d695b --- /dev/null +++ b/.gitignore @@ -0,0 +1,88 @@ +/Makefile +/Makefile.in +/aclocal.m4 +/autom4te.cache +/autoscan.log +/compiler/cpp/.deps +/compiler/cpp/Makefile +/compiler/cpp/Makefile.in +/compiler/cpp/thrift +/compiler/cpp/*.o +/compiler/cpp/thriftl.cc +/compiler/cpp/thrifty.cc +/compiler/cpp/thrifty.h +/compiler/cpp/version.h +/config.* +/configure +/configure.lineno +/configure.scan +/depcomp +/if/Makefile +/if/Makefile.in +/install-sh +/lib/Makefile +/lib/Makefile.in +/lib/erl/ebin +/lib/cpp/.deps +/lib/cpp/.libs +/lib/cpp/Makefile +/lib/cpp/Makefile.in +/lib/cpp/concurrency_test +/lib/cpp/*.o +/lib/cpp/*.la +/lib/cpp/*.lo +/lib/cpp/*.pc +/lib/csharp/Makefile +/lib/csharp/Makefile.in +/lib/java/Makefile +/lib/java/Makefile.in +/lib/java/build +/lib/java/gen-java +/lib/java/gen-javabean +/lib/java/libthrift.jar +/lib/perl/MANIFEST +/lib/perl/Makefile +/lib/perl/Makefile.in +/lib/perl/Makefile-perl.mk +/lib/perl/blib +/lib/perl/pm_to_blib +/lib/perl/test/Makefile +/lib/perl/test/Makefile.in +/lib/perl/test/gen-perl +/lib/py/Makefile +/lib/py/Makefile.in +/lib/py/build +/lib/rb/Makefile +/lib/rb/Makefile.in +/libtool +/ltmain.sh +/missing +/stamp-h1 +/test/.deps +/test/.libs +/test/*.o +/test/*.la +/test/*.lo +/test/Benchmark +/test/DebugProtoTest +/test/JSONProtoTest +/test/TFDTransportTest +/test/TPipedTransportTest +/test/UnitTests +/test/Makefile +/test/Makefile.in +/test/OptionalRequiredTest +/test/ReflectionTest +/test/gen-cpp +/test/java/Makefile +/test/java/Makefile.am +/test/java/Makefile.in +/test/java/build +/test/java/gen-java +/test/java/thrifttest.jar +/test/py/Makefile +/test/py/Makefile.in +/test/py/gen-py +/test/rb/Makefile +/test/rb/Makefile.in +/ylwrap diff --git a/CHANGES b/CHANGES new file mode 100644 index 00000000..e4d81933 --- /dev/null +++ b/CHANGES @@ -0,0 +1,35 @@ +Thrift Changelog + +Version 0.1.0 RC1 / Unreleased + +Compatibility Breaking Changes: + C++: + * It's quite possible that regenerating code and rebuilding will be + required. Make sure your headers match your libs! + + Java: + + Python: + + Ruby: + * Generated files now have underscored names [THRIFT-421] + * The library has been rearranged to be more Ruby-like [THRIFT-276] + + Erlang: + * Generated code will have to be regenerated, and the new code will + have to be deployed atomically with the new library code [THRIFT-136] + + +New Features and Bug Fixes: + C++: + * Support for TCompactProtocol [THRIFT-333] + + Java: + * Support for TCompactProtocol [THRIFT-110] + + Python: + * Support for Twisted [THRIFT-148] + + Ruby: + * Support for TCompactProtocol [THRIFT-332] + diff --git a/CONTRIBUTORS b/CONTRIBUTORS new file mode 100644 index 00000000..fd954f8d --- /dev/null +++ b/CONTRIBUTORS @@ -0,0 +1,77 @@ +Chad Walters +-TJSONProtocol for C++ and Java + +Nitay +-Support for "make check" + +William Morgan +-Miscellaneous Ruby improvements + +Ben Maurer +-Restructuring the way Autoconf is used + +Patrick Collison +-Smalltalk bindings + +Dave Simpson +-Better support for connection tracking in the C++ server +-Miscellaneous fixes + +Igor Afanasyev +-Perl HttpClient and bugfixes + +Todd Berman +-MinGW port of the compiler +-C# bindings +-MS build task + +---------------- +Release 20070917 +---------------- + +Dave Engberg +-JavaBean/JavaDoc enhancements + +Andrew McGeachie +-Cocoa/Objective-C support + +Ben Maurer +-Python performance enhancements, fastbinary support + +Andrew Lutomirski +-Added optional/required keywords for C++ objects +-Added comparison operators for C++ structs + +Johan Oskarsson +-Java findbugs compliance fixes + +Paul Saab +-IPv6 support for TSocket in C++/Python + +Kevin Clark +-Significant overhaul of Ruby code generation and servers + +Simon Forman +-TProcessorFactory abstraction for Java servers + +Jake Luciani +-Perl code generation, libraries, test code + +David Reiss +-strings.h include fix for bzero +-endianness fixes on TBinaryProtocol double serialization +-improved ntohll,htonll implementation + +Dan Li +-Java TestServer and Tutorial Fixes + +Kevin Ko +-Fix for unnecessary std::string copy construction in Protocol/Exception + +Paul Querna +-Autoconf error message fix for libevent detection +-clock_gettime implementation for OSX + +---------------- +Release 20070401 +---------------- diff --git a/DISCLAIMER b/DISCLAIMER new file mode 100644 index 00000000..de7e7ea2 --- /dev/null +++ b/DISCLAIMER @@ -0,0 +1,6 @@ +Apache Thrift is an effort undergoing incubation at The Apache Software Foundation (ASF), +sponsored by the Incubator PMC. Incubation is required of all newly accepted projects +until a further review indicates that the infrastructure, communications, and decision +making process have stabilized in a manner consistent with other successful ASF projects. +While incubation status is not necessarily a reflection of the completeness or stability +of the code, it does indicate that the project has yet to be fully endorsed by the ASF. diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..d6456956 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Makefile.am b/Makefile.am new file mode 100644 index 00000000..bd566010 --- /dev/null +++ b/Makefile.am @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +ACLOCAL_AMFLAGS = -I ./aclocal + +SUBDIRS = compiler/cpp lib test + +dist-hook: + find $(distdir) -name '.[a-zA-Z0-9]*' | xargs rm -rf + +EXTRA_DIST = bootstrap.sh cleanup.sh print_version.sh doc tutorial contrib \ + CONTRIBUTORS LICENSE CHANGES DISCLAIMER NOTICE diff --git a/NEWS b/NEWS new file mode 100644 index 00000000..2369f61d --- /dev/null +++ b/NEWS @@ -0,0 +1,79 @@ +Release Notes for Thrift 20080411 + +.equals and .hashCode() for Java scturcts (developed by dreiss). + +Improvments to the C++ TSocketPool (developed by akhil). + +PHP (de)serialization extension (developed by dweatherford). + +Add fb303 to contrib (developed by Facebook). + +TJSONProtocol for C++ and Java (contributed by Chad Walters of Powerset). + +Support for "make check" and better tests (contributed by Nitay). + +Smalltalk support (contributed by Patrick Collison). + +Dave Simpson +Better support for connection tracking in the C++ server (contributed by +Dave Simpson of Powerset). + +Perl HttpClient (contributed by Igor Afanasyev of Evernote). + +C# support (contributed by Todd Berman of imeem). + +MinGW port of the compiler (contributed by Todd Berman of imeem). + +Tons of small improvements and bug fixes. + + +Release Notes for Thrift 20070917 + +TBinaryProtocol now includes a protocol version number in messaged. +This is a non-backwards-compatible change. Please see the +TBinaryProtocol constructor for strictRead_ and strictWrite_. + +Add binary type to support non-text "strings" in Java. + +TSocketPool for C++ (developed by jsobel). + +Syntax highlighting for vim and emacs (developed by hzhao and martin). + +Perl support (contributed by Jake Luciani). + +Erlang support (developed by cpiro). + +Ruby API overhaul (contributed by Kevin Clark of Powerset). + +IPv6 support in C++ and Python (contributed by Paul Saab of Powerset). + +Read/Write locks (developed by boz). + +OCaml support (developed by iproctor). + +Human-readable strings from Thrift structures in C++ (developed by dreiss). + +Haskell support (developed by iproctor). + +Support for optional fields in C++ (contributed by Andy Lutomirsky). + +Support for operator== for Thrift structures (contributed by Andy Lutomirsky). + +Python/C module for fast (de)serialization (contributed by Ben Maurer). + +Limited reflection for C++ services (developed by dreiss). + +Python library installation defaults to /usr, override with PY_PREFIX. + +Support for Javabean-style Java classes (contributed by Dave Engberg). + +TDenseProtocol for C++ (experimental way to shrink structures) +(developed by dreiss). + +Cocoa/Objective C support (contributed by Andrew McGeachie). + +Thrift now builds without libevent. + +TZlibTransport for C++ (compress serialized structures) (developed by dreiss). + +Tons of small improvements and bug fixes. diff --git a/NOTICE b/NOTICE new file mode 100644 index 00000000..b5186316 --- /dev/null +++ b/NOTICE @@ -0,0 +1,26 @@ +Apache Thrift +Copyright 2006-2009 The Apache Software Foundation, et al. + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +Some files in this distribution are distributed under different terms +from the rest of Apache Thrift. Please see individual files for +license information. + +In addition, the following unlabelled files are distributed under +specific terms. Please see the "doc" directory for the text of their +licenses. + + lib/rb/setup.rb: GNU Lesser General Public License 2.1 (lgpl-2.1.txt) + lib/ocaml/OCamlMakefile: GNU Lesser General Public License 2.1 (lgpl-2.1.txt) + lib/ocaml/README-OCamlMakefile: GNU Lesser General Public License 2.1 (lgpl-2.1.txt) + lib/erl/build/beamver: MIT License (otp-base-license.txt) + lib/erl/build/buildtargets.mk: MIT License (otp-base-license.txt) + lib/erl/build/colors.mk: MIT License (otp-base-license.txt) + lib/erl/build/docs.mk: MIT License (otp-base-license.txt) + lib/erl/build/mime.types: MIT License (otp-base-license.txt) + lib/erl/build/otp.mk: MIT License (otp-base-license.txt) + lib/erl/build/otp_subdir.mk: MIT License (otp-base-license.txt) + lib/erl/build/raw_test.mk: MIT License (otp-base-license.txt) + lib/erl/src/Makefile: MIT License (otp-base-license.txt) diff --git a/README b/README new file mode 100644 index 00000000..adf8af47 --- /dev/null +++ b/README @@ -0,0 +1,137 @@ +Apache Thrift (an Apache Incubator project) + +Last Modified: 2009-Jan-30 + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Introduction +============ + +Thrift is a lightweight, language-independent software stack with an +associated code generation mechanism for RPC. Thrift provides clean +abstractions for data transport, data serialization, and application +level processing. The code generation system takes a simple definition +language as its input and generates code across programming languages that +uses the abstracted stack to build interoperable RPC clients and servers. + +Thrift is specifically designed to support non-atomic version changes +across client and server code. + +For more details on Thrift's design and implementation, take a gander at +the Thrift whitepaper included in this distribution or at the README files +in your particular subdirectory of interest. + +Heirarchy +========= + +thrift/ + + compiler/ + Contains the Thrift compiler, implemented in C++. + + lib/ + Contains the Thrift software library implementation, subdivided by + language of implementation. + + cpp/ + java/ + php/ + py/ + rb/ + + test/ + + Contains sample Thrift files and test code across the target programming + languages. + + tutorial/ + + Contains a basic tutorial that will teach you how to develop software + using Thrift. + +Requirements +============ + +See http://wiki.apache.org/thrift/ThriftRequirements for +an up-to-date list of build requirements. + +Resources +========= + +More information about Thrift can be obtained on the Thrift webpage at: + + http://incubator.apache.org/thrift + +Acknowledgments +=============== + +Thrift was inspired by pillar, a lightweight RPC tool written by Adam D'Angelo, +and also by Google's protocol buffers. + +Installation +============ + +If you are building from the first time out of the source repository, you will +need to generate the configure scripts. (This is not necessary if you +downloaded a tarball.) From the top directory, do: + + ./bootstrap.sh + +Once the configure scripts are generated, thrift can be configured. +From the top directory, do: + + ./configure + +You may need to specify the location of the boost files explicitly. +If you installed boost in /usr/local, you would run configure as follows: + + ./configure --with-boost=/usr/local + +Note that by default the thrift C++ library is typically built with debugging +symbols included. If you want to customize these options you should use the +CXXFLAGS option in configure, as such: + + ./configure CXXFLAGS='-g -O2' + ./configure CFLAGS='-g -O2' + ./configure CPPFLAGS='-DDEBUG_MY_FEATURE' + +Run ./configure --help to see other configuration options + +Please be aware that the Python library will ignore the --prefix option +and just install wherever Python's distutils puts it (usually along +the lines of /usr/lib/pythonX.Y/site-packages/). If you need to control +where the Python modules are installed, set the PY_PREFIX variable. +(DESTDIR is respected for Python and C++.) + +Make thrift: + + make + +From the top directory, become superuser and do: + + make install + +Note that some language packages must be installed manually using build tools +better suited to those languages (at the time of this writing, this applies +to Java, Ruby, PHP). + +Look for the README file in the lib// folder for more details on the +installation of each language library package. diff --git a/aclocal/ax_boost_base.m4 b/aclocal/ax_boost_base.m4 new file mode 100644 index 00000000..e56bb738 --- /dev/null +++ b/aclocal/ax_boost_base.m4 @@ -0,0 +1,198 @@ +##### http://autoconf-archive.cryp.to/ax_boost_base.html +# +# SYNOPSIS +# +# AX_BOOST_BASE([MINIMUM-VERSION]) +# +# DESCRIPTION +# +# Test for the Boost C++ libraries of a particular version (or newer) +# +# If no path to the installed boost library is given the macro +# searchs under /usr, /usr/local, /opt and /opt/local and evaluates +# the $BOOST_ROOT environment variable. Further documentation is +# available at . +# +# This macro calls: +# +# AC_SUBST(BOOST_CPPFLAGS) / AC_SUBST(BOOST_LDFLAGS) +# +# And sets: +# +# HAVE_BOOST +# +# LAST MODIFICATION +# +# 2007-07-28 +# +# COPYLEFT +# +# Copyright (c) 2007 Thomas Porschberg +# +# Copying and distribution of this file, with or without +# modification, are permitted in any medium without royalty provided +# the copyright notice and this notice are preserved. + +AC_DEFUN([AX_BOOST_BASE], +[ +AC_ARG_WITH([boost], + AS_HELP_STRING([--with-boost@<:@=DIR@:>@], [use boost (default is yes) - it is possible to specify the root directory for boost (optional)]), + [ + if test "$withval" = "no"; then + want_boost="no" + elif test "$withval" = "yes"; then + want_boost="yes" + ac_boost_path="" + else + want_boost="yes" + ac_boost_path="$withval" + fi + ], + [want_boost="yes"]) + +if test "x$want_boost" = "xyes"; then + boost_lib_version_req=ifelse([$1], ,1.20.0,$1) + boost_lib_version_req_shorten=`expr $boost_lib_version_req : '\([[0-9]]*\.[[0-9]]*\)'` + boost_lib_version_req_major=`expr $boost_lib_version_req : '\([[0-9]]*\)'` + boost_lib_version_req_minor=`expr $boost_lib_version_req : '[[0-9]]*\.\([[0-9]]*\)'` + boost_lib_version_req_sub_minor=`expr $boost_lib_version_req : '[[0-9]]*\.[[0-9]]*\.\([[0-9]]*\)'` + if test "x$boost_lib_version_req_sub_minor" = "x" ; then + boost_lib_version_req_sub_minor="0" + fi + WANT_BOOST_VERSION=`expr $boost_lib_version_req_major \* 100000 \+ $boost_lib_version_req_minor \* 100 \+ $boost_lib_version_req_sub_minor` + AC_MSG_CHECKING(for boostlib >= $boost_lib_version_req) + succeeded=no + + dnl first we check the system location for boost libraries + dnl this location ist chosen if boost libraries are installed with the --layout=system option + dnl or if you install boost with RPM + if test "$ac_boost_path" != ""; then + BOOST_LDFLAGS="-L$ac_boost_path/lib" + BOOST_CPPFLAGS="-I$ac_boost_path/include" + else + for ac_boost_path_tmp in /usr /usr/local /opt /opt/local ; do + if test -d "$ac_boost_path_tmp/include/boost" && test -r "$ac_boost_path_tmp/include/boost"; then + BOOST_LDFLAGS="-L$ac_boost_path_tmp/lib" + BOOST_CPPFLAGS="-I$ac_boost_path_tmp/include" + break; + fi + done + fi + + CPPFLAGS_SAVED="$CPPFLAGS" + CPPFLAGS="$CPPFLAGS $BOOST_CPPFLAGS" + export CPPFLAGS + + LDFLAGS_SAVED="$LDFLAGS" + LDFLAGS="$LDFLAGS $BOOST_LDFLAGS" + export LDFLAGS + + AC_LANG_PUSH(C++) + AC_COMPILE_IFELSE([AC_LANG_PROGRAM([[ + @%:@include + ]], [[ + #if BOOST_VERSION >= $WANT_BOOST_VERSION + // Everything is okay + #else + # error Boost version is too old + #endif + ]])],[ + AC_MSG_RESULT(yes) + succeeded=yes + found_system=yes + ],[ + ]) + AC_LANG_POP([C++]) + + + + dnl if we found no boost with system layout we search for boost libraries + dnl built and installed without the --layout=system option or for a staged(not installed) version + if test "x$succeeded" != "xyes"; then + _version=0 + if test "$ac_boost_path" != ""; then + BOOST_LDFLAGS="-L$ac_boost_path/lib" + if test -d "$ac_boost_path" && test -r "$ac_boost_path"; then + for i in `ls -d $ac_boost_path/include/boost-* 2>/dev/null`; do + _version_tmp=`echo $i | sed "s#$ac_boost_path##" | sed 's/\/include\/boost-//' | sed 's/_/./'` + V_CHECK=`expr $_version_tmp \> $_version` + if test "$V_CHECK" = "1" ; then + _version=$_version_tmp + fi + VERSION_UNDERSCORE=`echo $_version | sed 's/\./_/'` + BOOST_CPPFLAGS="-I$ac_boost_path/include/boost-$VERSION_UNDERSCORE" + done + fi + else + for ac_boost_path in /usr /usr/local /opt /opt/local ; do + if test -d "$ac_boost_path" && test -r "$ac_boost_path"; then + for i in `ls -d $ac_boost_path/include/boost-* 2>/dev/null`; do + _version_tmp=`echo $i | sed "s#$ac_boost_path##" | sed 's/\/include\/boost-//' | sed 's/_/./'` + V_CHECK=`expr $_version_tmp \> $_version` + if test "$V_CHECK" = "1" ; then + _version=$_version_tmp + best_path=$ac_boost_path + fi + done + fi + done + + VERSION_UNDERSCORE=`echo $_version | sed 's/\./_/'` + BOOST_CPPFLAGS="-I$best_path/include/boost-$VERSION_UNDERSCORE" + BOOST_LDFLAGS="-L$best_path/lib" + + if test "x$BOOST_ROOT" != "x"; then + if test -d "$BOOST_ROOT" && test -r "$BOOST_ROOT" && test -d "$BOOST_ROOT/stage/lib" && test -r "$BOOST_ROOT/stage/lib"; then + version_dir=`expr //$BOOST_ROOT : '.*/\(.*\)'` + stage_version=`echo $version_dir | sed 's/boost_//' | sed 's/_/./g'` + stage_version_shorten=`expr $stage_version : '\([[0-9]]*\.[[0-9]]*\)'` + V_CHECK=`expr $stage_version_shorten \>\= $_version` + if test "$V_CHECK" = "1" ; then + AC_MSG_NOTICE(We will use a staged boost library from $BOOST_ROOT) + BOOST_CPPFLAGS="-I$BOOST_ROOT" + BOOST_LDFLAGS="-L$BOOST_ROOT/stage/lib" + fi + fi + fi + fi + + CPPFLAGS="$CPPFLAGS $BOOST_CPPFLAGS" + export CPPFLAGS + LDFLAGS="$LDFLAGS $BOOST_LDFLAGS" + export LDFLAGS + + AC_LANG_PUSH(C++) + AC_COMPILE_IFELSE([AC_LANG_PROGRAM([[ + @%:@include + ]], [[ + #if BOOST_VERSION >= $WANT_BOOST_VERSION + // Everything is okay + #else + # error Boost version is too old + #endif + ]])],[ + AC_MSG_RESULT(yes) + succeeded=yes + found_system=yes + ],[ + ]) + AC_LANG_POP([C++]) + fi + + if test "$succeeded" != "yes" ; then + if test "$_version" = "0" ; then + AC_MSG_ERROR([[We could not detect the boost libraries (version $boost_lib_version_req_shorten or higher). If you have a staged boost library (still not installed) please specify \$BOOST_ROOT in your environment and do not give a PATH to --with-boost option. If you are sure you have boost installed, then check your version number looking in . See http://randspringer.de/boost for more documentation.]]) + else + AC_MSG_NOTICE([Your boost libraries seems to old (version $_version).]) + fi + else + AC_SUBST(BOOST_CPPFLAGS) + AC_SUBST(BOOST_LDFLAGS) + AC_DEFINE(HAVE_BOOST,,[define if the Boost library is available]) + fi + + CPPFLAGS="$CPPFLAGS_SAVED" + LDFLAGS="$LDFLAGS_SAVED" +fi + +]) diff --git a/aclocal/ax_javac_and_java.m4 b/aclocal/ax_javac_and_java.m4 new file mode 100644 index 00000000..3c8577f4 --- /dev/null +++ b/aclocal/ax_javac_and_java.m4 @@ -0,0 +1,84 @@ +dnl @synopsis AX_JAVAC_AND_JAVA +dnl +dnl Test for the presence of a JDK. +dnl +dnl If "JAVA" is defined in the environment, that will be the only +dnl java command tested. Otherwise, a hard-coded list will be used. +dnl Similarly for "JAVAC". +dnl +dnl This macro does not currenly support testing for a particular +dnl Java version, the presence of a particular class, testing for +dnl only one of "java" and "javac", or compiling or running +dnl user-provided Java code. +dnl +dnl After AX_JAVAC_AND_JAVA runs, the shell variables "success" and +dnl "ax_javac_and_java" are set to "yes" or "no", and "JAVAC" and +dnl "JAVA" are set to the appropriate commands. +dnl +dnl @category Java +dnl @version 2009-02-09 +dnl @license AllPermissive +dnl +dnl Copyright (C) 2009 David Reiss +dnl Copying and distribution of this file, with or without modification, +dnl are permitted in any medium without royalty provided the copyright +dnl notice and this notice are preserved. + + +AC_DEFUN([AX_JAVAC_AND_JAVA], + [ + + dnl Hard-coded default commands to test. + JAVAC_PROGS="javac,jikes,gcj -C" + JAVA_PROGS="java,kaffe" + + dnl Allow the user to specify an alternative. + if test -n "$JAVAC" ; then + JAVAC_PROGS="$JAVAC" + fi + if test -n "$JAVA" ; then + JAVA_PROGS="$JAVA" + fi + + AC_MSG_CHECKING(for javac and java) + + echo "public class configtest_ax_javac_and_java { public static void main(String args@<:@@:>@) { } }" > configtest_ax_javac_and_java.java + success=no + oIFS="$IFS" + + IFS="," + for JAVAC in $JAVAC_PROGS ; do + IFS="$oIFS" + + echo "Running \"$JAVAC configtest_ax_javac_and_java.java\"" >&AS_MESSAGE_LOG_FD + if $JAVAC configtest_ax_javac_and_java.java >&AS_MESSAGE_LOG_FD 2>&1 ; then + + IFS="," + for JAVA in $JAVA_PROGS ; do + IFS="$oIFS" + + echo "Running \"$JAVA configtest_ax_javac_and_java\"" >&AS_MESSAGE_LOG_FD + if $JAVA configtest_ax_javac_and_java >&AS_MESSAGE_LOG_FD 2>&1 ; then + success=yes + break 2 + fi + + done + + fi + + done + + rm -f configtest_ax_javac_and_java.java configtest_ax_javac_and_java.class + + if test "$success" != "yes" ; then + AC_MSG_RESULT(no) + JAVAC="" + JAVA="" + else + AC_MSG_RESULT(yes) + fi + + ax_javac_and_java="$success" + + ]) diff --git a/aclocal/ax_lib_event.m4 b/aclocal/ax_lib_event.m4 new file mode 100644 index 00000000..3a48156a --- /dev/null +++ b/aclocal/ax_lib_event.m4 @@ -0,0 +1,173 @@ +dnl @synopsis AX_LIB_EVENT([MINIMUM-VERSION]) +dnl +dnl Test for the libevent library of a particular version (or newer). +dnl +dnl If no path to the installed libevent is given, the macro will first try +dnl using no -I or -L flags, then searches under /usr, /usr/local, /opt, +dnl and /opt/libevent. +dnl If these all fail, it will try the $LIBEVENT_ROOT environment variable. +dnl +dnl This macro requires that #include works and defines u_char. +dnl +dnl This macro calls: +dnl AC_SUBST(LIBEVENT_CPPFLAGS) +dnl AC_SUBST(LIBEVENT_LDFLAGS) +dnl AC_SUBST(LIBEVENT_LIBS) +dnl +dnl And (if libevent is found): +dnl AC_DEFINE(HAVE_LIBEVENT) +dnl +dnl It also leaves the shell variables "success" and "ax_have_libevent" +dnl set to "yes" or "no". +dnl +dnl NOTE: This macro does not currently work for cross-compiling, +dnl but it can be easily modified to allow it. (grep "cross"). +dnl +dnl @category InstalledPackages +dnl @category C +dnl @version 2007-09-12 +dnl @license AllPermissive +dnl +dnl Copyright (C) 2009 David Reiss +dnl Copying and distribution of this file, with or without modification, +dnl are permitted in any medium without royalty provided the copyright +dnl notice and this notice are preserved. + +dnl Input: ax_libevent_path, WANT_LIBEVENT_VERSION +dnl Output: success=yes/no +AC_DEFUN([AX_LIB_EVENT_DO_CHECK], + [ + # Save our flags. + CPPFLAGS_SAVED="$CPPFLAGS" + LDFLAGS_SAVED="$LDFLAGS" + LIBS_SAVED="$LIBS" + LD_LIBRARY_PATH_SAVED="$LD_LIBRARY_PATH" + + # Set our flags if we are checking a specific directory. + if test -n "$ax_libevent_path" ; then + LIBEVENT_CPPFLAGS="-I$ax_libevent_path/include" + LIBEVENT_LDFLAGS="-L$ax_libevent_path/lib" + LD_LIBRARY_PATH="$ax_libevent_path/lib:$LD_LIBRARY_PATH" + else + LIBEVENT_CPPFLAGS="" + LIBEVENT_LDFLAGS="" + fi + + # Required flag for libevent. + LIBEVENT_LIBS="-levent" + + # Prepare the environment for compilation. + CPPFLAGS="$CPPFLAGS $LIBEVENT_CPPFLAGS" + LDFLAGS="$LDFLAGS $LIBEVENT_LDFLAGS" + LIBS="$LIBS $LIBEVENT_LIBS" + export CPPFLAGS + export LDFLAGS + export LIBS + export LD_LIBRARY_PATH + + success=no + + # Compile, link, and run the program. This checks: + # - event.h is available for including. + # - event_get_version() is available for linking. + # - The event version string is lexicographically greater + # than the required version. + AC_LANG_PUSH([C]) + dnl This can be changed to AC_LINK_IFELSE if you are cross-compiling, + dnl but then the version cannot be checked. + AC_RUN_IFELSE([AC_LANG_PROGRAM([[ + #include + #include + ]], [[ + const char* lib_version = event_get_version(); + const char* wnt_version = "$WANT_LIBEVENT_VERSION"; + for (;;) { + /* If we reached the end of the want version. We have it. */ + if (*wnt_version == '\0') { + return 0; + } + /* If the want version continues but the lib version does not, */ + /* we are missing a letter. We don't have it. */ + if (*lib_version == '\0') { + return 1; + } + /* If we have greater than what we want. We have it. */ + if (*lib_version > *wnt_version) { + return 0; + } + /* If we have less, we don't. */ + if (*lib_version < *wnt_version) { + return 1; + } + lib_version++; + wnt_version++; + } + return 0; + ]])], [ + success=yes + ]) + AC_LANG_POP([C]) + + # Restore flags. + CPPFLAGS="$CPPFLAGS_SAVED" + LDFLAGS="$LDFLAGS_SAVED" + LIBS="$LIBS_SAVED" + LD_LIBRARY_PATH="$LD_LIBRARY_PATH_SAVED" + ]) + + +AC_DEFUN([AX_LIB_EVENT], + [ + + dnl Allow search path to be overridden on the command line. + AC_ARG_WITH([libevent], + AS_HELP_STRING([--with-libevent@<:@=DIR@:>@], [use libevent (default is yes) - it is possible to specify an alternate root directory for libevent]), + [ + if test "x$withval" = "xno"; then + want_libevent="no" + elif test "x$withval" = "xyes"; then + want_libevent="yes" + ax_libevent_path="" + else + want_libevent="yes" + ax_libevent_path="$withval" + fi + ], + [ want_libevent="yes" ; ax_libevent_path="" ]) + + + if test "$want_libevent" = "yes"; then + WANT_LIBEVENT_VERSION=ifelse([$1], ,1.2,$1) + + AC_MSG_CHECKING(for libevent >= $WANT_LIBEVENT_VERSION) + + # Run tests. + if test -n "$ax_libevent_path"; then + AX_LIB_EVENT_DO_CHECK + else + for ax_libevent_path in "" /usr /usr/local /opt /opt/local /opt/libevent "$LIBEVENT_ROOT" ; do + AX_LIB_EVENT_DO_CHECK + if test "$success" = "yes"; then + break; + fi + done + fi + + if test "$success" != "yes" ; then + AC_MSG_RESULT(no) + LIBEVENT_CPPFLAGS="" + LIBEVENT_LDFLAGS="" + LIBEVENT_LIBS="" + else + AC_MSG_RESULT(yes) + AC_DEFINE(HAVE_LIBEVENT,,[define if libevent is available]) + fi + + ax_have_libevent="$success" + + AC_SUBST(LIBEVENT_CPPFLAGS) + AC_SUBST(LIBEVENT_LDFLAGS) + AC_SUBST(LIBEVENT_LIBS) + fi + + ]) diff --git a/aclocal/ax_lib_zlib.m4 b/aclocal/ax_lib_zlib.m4 new file mode 100644 index 00000000..8c10ab41 --- /dev/null +++ b/aclocal/ax_lib_zlib.m4 @@ -0,0 +1,173 @@ +dnl @synopsis AX_LIB_ZLIB([MINIMUM-VERSION]) +dnl +dnl Test for the libz library of a particular version (or newer). +dnl +dnl If no path to the installed zlib is given, the macro will first try +dnl using no -I or -L flags, then searches under /usr, /usr/local, /opt, +dnl and /opt/zlib. +dnl If these all fail, it will try the $ZLIB_ROOT environment variable. +dnl +dnl This macro calls: +dnl AC_SUBST(ZLIB_CPPFLAGS) +dnl AC_SUBST(ZLIB_LDFLAGS) +dnl AC_SUBST(ZLIB_LIBS) +dnl +dnl And (if zlib is found): +dnl AC_DEFINE(HAVE_ZLIB) +dnl +dnl It also leaves the shell variables "success" and "ax_have_zlib" +dnl set to "yes" or "no". +dnl +dnl NOTE: This macro does not currently work for cross-compiling, +dnl but it can be easily modified to allow it. (grep "cross"). +dnl +dnl @category InstalledPackages +dnl @category C +dnl @version 2007-09-12 +dnl @license AllPermissive +dnl +dnl Copyright (C) 2009 David Reiss +dnl Copying and distribution of this file, with or without modification, +dnl are permitted in any medium without royalty provided the copyright +dnl notice and this notice are preserved. + +dnl Input: ax_zlib_path, WANT_ZLIB_VERSION +dnl Output: success=yes/no +AC_DEFUN([AX_LIB_ZLIB_DO_CHECK], + [ + # Save our flags. + CPPFLAGS_SAVED="$CPPFLAGS" + LDFLAGS_SAVED="$LDFLAGS" + LIBS_SAVED="$LIBS" + LD_LIBRARY_PATH_SAVED="$LD_LIBRARY_PATH" + + # Set our flags if we are checking a specific directory. + if test -n "$ax_zlib_path" ; then + ZLIB_CPPFLAGS="-I$ax_zlib_path/include" + ZLIB_LDFLAGS="-L$ax_zlib_path/lib" + LD_LIBRARY_PATH="$ax_zlib_path/lib:$LD_LIBRARY_PATH" + else + ZLIB_CPPFLAGS="" + ZLIB_LDFLAGS="" + fi + + # Required flag for zlib. + ZLIB_LIBS="-lz" + + # Prepare the environment for compilation. + CPPFLAGS="$CPPFLAGS $ZLIB_CPPFLAGS" + LDFLAGS="$LDFLAGS $ZLIB_LDFLAGS" + LIBS="$LIBS $ZLIB_LIBS" + export CPPFLAGS + export LDFLAGS + export LIBS + export LD_LIBRARY_PATH + + success=no + + # Compile, link, and run the program. This checks: + # - zlib.h is available for including. + # - zlibVersion() is available for linking. + # - ZLIB_VERNUM is greater than or equal to the desired version. + # - ZLIB_VERSION (defined in zlib.h) matches zlibVersion() + # (defined in the library). + AC_LANG_PUSH([C]) + dnl This can be changed to AC_LINK_IFELSE if you are cross-compiling. + AC_RUN_IFELSE([AC_LANG_PROGRAM([[ + #include + #if ZLIB_VERNUM >= 0x$WANT_ZLIB_VERSION + #else + # error zlib is too old + #endif + ]], [[ + const char* lib_version = zlibVersion(); + const char* hdr_version = ZLIB_VERSION; + for (;;) { + if (*lib_version != *hdr_version) { + /* If this happens, your zlib header doesn't match your zlib */ + /* library. That is really bad. */ + return 1; + } + if (*lib_version == '\0') { + break; + } + lib_version++; + hdr_version++; + } + return 0; + ]])], [ + success=yes + ]) + AC_LANG_POP([C]) + + # Restore flags. + CPPFLAGS="$CPPFLAGS_SAVED" + LDFLAGS="$LDFLAGS_SAVED" + LIBS="$LIBS_SAVED" + LD_LIBRARY_PATH="$LD_LIBRARY_PATH_SAVED" + ]) + + +AC_DEFUN([AX_LIB_ZLIB], + [ + + dnl Allow search path to be overridden on the command line. + AC_ARG_WITH([zlib], + AS_HELP_STRING([--with-zlib@<:@=DIR@:>@], [use zlib (default is yes) - it is possible to specify an alternate root directory for zlib]), + [ + if test "x$withval" = "xno"; then + want_zlib="no" + elif test "x$withval" = "xyes"; then + want_zlib="yes" + ax_zlib_path="" + else + want_zlib="yes" + ax_zlib_path="$withval" + fi + ], + [want_zlib="yes" ; ax_zlib_path="" ]) + + + if test "$want_zlib" = "yes"; then + # Parse out the version. + zlib_version_req=ifelse([$1], ,1.2.3,$1) + zlib_version_req_major=`expr $zlib_version_req : '\([[0-9]]*\)'` + zlib_version_req_minor=`expr $zlib_version_req : '[[0-9]]*\.\([[0-9]]*\)'` + zlib_version_req_patch=`expr $zlib_version_req : '[[0-9]]*\.[[0-9]]*\.\([[0-9]]*\)'` + if test -z "$zlib_version_req_patch" ; then + zlib_version_req_patch="0" + fi + WANT_ZLIB_VERSION=`expr $zlib_version_req_major \* 1000 \+ $zlib_version_req_minor \* 100 \+ $zlib_version_req_patch \* 10` + + AC_MSG_CHECKING(for zlib >= $zlib_version_req) + + # Run tests. + if test -n "$ax_zlib_path"; then + AX_LIB_ZLIB_DO_CHECK + else + for ax_zlib_path in "" /usr /usr/local /opt /opt/zlib "$ZLIB_ROOT" ; do + AX_LIB_ZLIB_DO_CHECK + if test "$success" = "yes"; then + break; + fi + done + fi + + if test "$success" != "yes" ; then + AC_MSG_RESULT(no) + ZLIB_CPPFLAGS="" + ZLIB_LDFLAGS="" + ZLIB_LIBS="" + else + AC_MSG_RESULT(yes) + AC_DEFINE(HAVE_ZLIB,,[define if zlib is available]) + fi + + ax_have_zlib="$success" + + AC_SUBST(ZLIB_CPPFLAGS) + AC_SUBST(ZLIB_LDFLAGS) + AC_SUBST(ZLIB_LIBS) + fi + + ]) diff --git a/aclocal/ax_signed_right_shift.m4 b/aclocal/ax_signed_right_shift.m4 new file mode 100644 index 00000000..01952338 --- /dev/null +++ b/aclocal/ax_signed_right_shift.m4 @@ -0,0 +1,127 @@ +dnl @synopsis AX_SIGNED_RIGHT_SHIFT +dnl +dnl Tests the behavior of a right shift on a negative signed int. +dnl +dnl This macro calls: +dnl AC_DEFINE(SIGNED_RIGHT_SHIFT_IS) +dnl AC_DEFINE(ARITHMETIC_RIGHT_SHIFT) +dnl AC_DEFINE(LOGICAL_RIGHT_SHIFT) +dnl AC_DEFINE(UNKNOWN_RIGHT_SHIFT) +dnl +dnl SIGNED_RIGHT_SHIFT_IS will be equal to one of the other macros. +dnl It also leaves the shell variables "ax_signed_right_shift" +dnl set to "arithmetic", "logical", or "unknown". +dnl +dnl NOTE: This macro does not work for cross-compiling. +dnl +dnl @category C +dnl @version 2009-03-25 +dnl @license AllPermissive +dnl +dnl Copyright (C) 2009 David Reiss +dnl Copying and distribution of this file, with or without modification, +dnl are permitted in any medium without royalty provided the copyright +dnl notice and this notice are preserved. + +AC_DEFUN([AX_SIGNED_RIGHT_SHIFT], + [ + + AC_MSG_CHECKING(the behavior of a signed right shift) + + success_arithmetic=no + AC_RUN_IFELSE([AC_LANG_PROGRAM([[]], [[ + return + /* 0xffffffff */ + -1 >> 1 != -1 || + -1 >> 2 != -1 || + -1 >> 3 != -1 || + -1 >> 4 != -1 || + -1 >> 8 != -1 || + -1 >> 16 != -1 || + -1 >> 24 != -1 || + -1 >> 31 != -1 || + /* 0x80000000 */ + (-2147483647 - 1) >> 1 != -1073741824 || + (-2147483647 - 1) >> 2 != -536870912 || + (-2147483647 - 1) >> 3 != -268435456 || + (-2147483647 - 1) >> 4 != -134217728 || + (-2147483647 - 1) >> 8 != -8388608 || + (-2147483647 - 1) >> 16 != -32768 || + (-2147483647 - 1) >> 24 != -128 || + (-2147483647 - 1) >> 31 != -1 || + /* 0x90800000 */ + -1870659584 >> 1 != -935329792 || + -1870659584 >> 2 != -467664896 || + -1870659584 >> 3 != -233832448 || + -1870659584 >> 4 != -116916224 || + -1870659584 >> 8 != -7307264 || + -1870659584 >> 16 != -28544 || + -1870659584 >> 24 != -112 || + -1870659584 >> 31 != -1 || + 0; + ]])], [ + success_arithmetic=yes + ]) + + + success_logical=no + AC_RUN_IFELSE([AC_LANG_PROGRAM([[]], [[ + return + /* 0xffffffff */ + -1 >> 1 != (signed)((unsigned)-1 >> 1) || + -1 >> 2 != (signed)((unsigned)-1 >> 2) || + -1 >> 3 != (signed)((unsigned)-1 >> 3) || + -1 >> 4 != (signed)((unsigned)-1 >> 4) || + -1 >> 8 != (signed)((unsigned)-1 >> 8) || + -1 >> 16 != (signed)((unsigned)-1 >> 16) || + -1 >> 24 != (signed)((unsigned)-1 >> 24) || + -1 >> 31 != (signed)((unsigned)-1 >> 31) || + /* 0x80000000 */ + (-2147483647 - 1) >> 1 != (signed)((unsigned)(-2147483647 - 1) >> 1) || + (-2147483647 - 1) >> 2 != (signed)((unsigned)(-2147483647 - 1) >> 2) || + (-2147483647 - 1) >> 3 != (signed)((unsigned)(-2147483647 - 1) >> 3) || + (-2147483647 - 1) >> 4 != (signed)((unsigned)(-2147483647 - 1) >> 4) || + (-2147483647 - 1) >> 8 != (signed)((unsigned)(-2147483647 - 1) >> 8) || + (-2147483647 - 1) >> 16 != (signed)((unsigned)(-2147483647 - 1) >> 16) || + (-2147483647 - 1) >> 24 != (signed)((unsigned)(-2147483647 - 1) >> 24) || + (-2147483647 - 1) >> 31 != (signed)((unsigned)(-2147483647 - 1) >> 31) || + /* 0x90800000 */ + -1870659584 >> 1 != (signed)((unsigned)-1870659584 >> 1) || + -1870659584 >> 2 != (signed)((unsigned)-1870659584 >> 2) || + -1870659584 >> 3 != (signed)((unsigned)-1870659584 >> 3) || + -1870659584 >> 4 != (signed)((unsigned)-1870659584 >> 4) || + -1870659584 >> 8 != (signed)((unsigned)-1870659584 >> 8) || + -1870659584 >> 16 != (signed)((unsigned)-1870659584 >> 16) || + -1870659584 >> 24 != (signed)((unsigned)-1870659584 >> 24) || + -1870659584 >> 31 != (signed)((unsigned)-1870659584 >> 31) || + 0; + ]])], [ + success_logical=yes + ]) + + + AC_DEFINE([ARITHMETIC_RIGHT_SHIFT], 1, [Possible value for SIGNED_RIGHT_SHIFT_IS]) + AC_DEFINE([LOGICAL_RIGHT_SHIFT], 2, [Possible value for SIGNED_RIGHT_SHIFT_IS]) + AC_DEFINE([UNKNOWN_RIGHT_SHIFT], 3, [Possible value for SIGNED_RIGHT_SHIFT_IS]) + + if test "$success_arithmetic" = "yes" && test "$success_logica" = "yes" ; then + AC_MSG_ERROR("Right shift appears to be both arithmetic and logical!") + elif test "$success_arithmetic" = "yes" ; then + ax_signed_right_shift=arithmetic + AC_DEFINE([SIGNED_RIGHT_SHIFT_IS], 1, + [Indicates the effect of the right shift operator + on negative signed integers]) + elif test "$success_logical" = "yes" ; then + ax_signed_right_shift=logical + AC_DEFINE([SIGNED_RIGHT_SHIFT_IS], 2, + [Indicates the effect of the right shift operator + on negative signed integers]) + else + ax_signed_right_shift=unknown + AC_DEFINE([SIGNED_RIGHT_SHIFT_IS], 3, + [Indicates the effect of the right shift operator + on negative signed integers]) + fi + + AC_MSG_RESULT($ax_signed_right_shift) + ]) diff --git a/aclocal/ax_thrift_internal.m4 b/aclocal/ax_thrift_internal.m4 new file mode 100644 index 00000000..979bec6b --- /dev/null +++ b/aclocal/ax_thrift_internal.m4 @@ -0,0 +1,39 @@ +dnl @synopsis AX_THRIFT_GEN(SHORT_LANGUAGE, LONG_LANGUAGE, DEFAULT) +dnl @synopsis AX_THRIFT_LIB(SHORT_LANGUAGE, LONG_LANGUAGE, DEFAULT) +dnl +dnl Allow a particular language generator to be disabled. +dnl Allow a particular language library to be disabled. +dnl +dnl These macros have poor error handling and are poorly documented. +dnl They are intended only for internal use by the Thrift compiler. +dnl +dnl @version 2008-02-20 +dnl @license AllPermissive +dnl +dnl Copyright (C) 2009 David Reiss +dnl Copying and distribution of this file, with or without modification, +dnl are permitted in any medium without royalty provided the copyright +dnl notice and this notice are preserved. + +AC_DEFUN([AX_THRIFT_GEN], + [ + AC_ARG_ENABLE([gen-$1], + AC_HELP_STRING([--enable-gen-$1], [enable the $2 compiler @<:@default=$3@:>@]), + [ax_thrift_gen_$1="$enableval"], + [ax_thrift_gen_$1=$3] + ) + dnl I'd like to run the AM_CONDITIONAL here, but automake likes + dnl all AM_CONDITIONALs to be nice and explicit in configure.ac. + dnl AM_CONDITIONAL([THRIFT_GEN_$1], [test "$ax_thrift_gen_$1" = "yes"]) + ]) + +AC_DEFUN([AX_THRIFT_LIB], + [ + AC_ARG_WITH($1, + AC_HELP_STRING([--with-$1], [build the $2 library @<:@default=$3@:>@]), + [with_$1="$withval"], + [with_$1=$3] + ) + dnl What we do here is going to vary from library to library, + dnl so we can't really generalize (yet!). + ]) diff --git a/bootstrap.sh b/bootstrap.sh new file mode 100755 index 00000000..45ac8d52 --- /dev/null +++ b/bootstrap.sh @@ -0,0 +1,35 @@ +#!/bin/sh + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +./cleanup.sh + +autoscan || exit 1 +aclocal -I ./aclocal || exit 1 +autoheader || exit 1 + +if libtoolize --version 1 >/dev/null 2>/dev/null; then + libtoolize --automake || exit 1 +elif glibtoolize --version 1 >/dev/null 2>/dev/null; then + glibtoolize --automake || exit 1 +fi + +autoconf +automake -ac --add-missing --foreign || exit 1 diff --git a/cleanup.sh b/cleanup.sh new file mode 100755 index 00000000..0ea55bb9 --- /dev/null +++ b/cleanup.sh @@ -0,0 +1,58 @@ +#!/bin/sh + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +topsrcdir="`dirname $0`" +cd "$topsrcdir" + +make -k clean >/dev/null 2>&1 +make -k distclean >/dev/null 2>&1 +find . -name Makefile.in -exec rm -f {} \; +rm -rf \ +AUTHORS \ +ChangeLog \ +INSTALL \ +Makefile \ +Makefile.in \ +Makefile.orig \ +aclocal.m4 \ +autom4te.cache \ +autoscan.log \ +config.guess \ +config.h \ +config.hin \ +config.hin~ \ +config.log \ +config.status \ +config.status.lineno \ +config.sub \ +configure \ +configure.lineno \ +configure.scan \ +depcomp \ +.deps \ +install-sh \ +.libs \ +libtool \ +ltmain.sh \ +missing \ +ylwrap \ +if/gen-* \ +test/gen-* diff --git a/compiler/cpp/Makefile.am b/compiler/cpp/Makefile.am new file mode 100644 index 00000000..3838facf --- /dev/null +++ b/compiler/cpp/Makefile.am @@ -0,0 +1,136 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +AM_YFLAGS = -d +BUILT_SOURCES = + +bin_PROGRAMS = thrift + +thrift_OBJDIR = obj + +thrift_SOURCES = src/thrifty.yy \ + src/thriftl.ll \ + src/main.cc \ + src/md5.c \ + src/generate/t_generator.cc \ + src/globals.h \ + src/main.h \ + src/platform.h \ + src/md5.h \ + src/parse/t_doc.h \ + src/parse/t_type.h \ + src/parse/t_base_type.h \ + src/parse/t_enum.h \ + src/parse/t_enum_value.h \ + src/parse/t_typedef.h \ + src/parse/t_container.h \ + src/parse/t_list.h \ + src/parse/t_set.h \ + src/parse/t_map.h \ + src/parse/t_struct.h \ + src/parse/t_field.h \ + src/parse/t_service.h \ + src/parse/t_function.h \ + src/parse/t_program.h \ + src/parse/t_scope.h \ + src/parse/t_const.h \ + src/parse/t_const_value.h \ + src/generate/t_generator.h \ + src/generate/t_oop_generator.h + +if THRIFT_GEN_cpp +thrift_SOURCES += src/generate/t_cpp_generator.cc +endif +if THRIFT_GEN_java +thrift_SOURCES += src/generate/t_java_generator.cc +endif +if THRIFT_GEN_csharp +thrift_SOURCES += src/generate/t_csharp_generator.cc +endif +if THRIFT_GEN_py +thrift_SOURCES += src/generate/t_py_generator.cc +endif +if THRIFT_GEN_rb +thrift_SOURCES += src/generate/t_rb_generator.cc +endif +if THRIFT_GEN_perl +thrift_SOURCES += src/generate/t_perl_generator.cc +endif +if THRIFT_GEN_php +thrift_SOURCES += src/generate/t_php_generator.cc +endif +if THRIFT_GEN_erl +thrift_SOURCES += src/generate/t_erl_generator.cc +endif +if THRIFT_GEN_cocoa +thrift_SOURCES += src/generate/t_cocoa_generator.cc +endif +if THRIFT_GEN_st +thrift_SOURCES += src/generate/t_st_generator.cc +endif +if THRIFT_GEN_ocaml +thrift_SOURCES += src/generate/t_ocaml_generator.cc +endif +if THRIFT_GEN_hs +thrift_SOURCES += src/generate/t_hs_generator.cc +endif +if THRIFT_GEN_xsd +thrift_SOURCES += src/generate/t_xsd_generator.cc +endif +if THRIFT_GEN_html +thrift_SOURCES += src/generate/t_html_generator.cc +endif + +thrift_CXXFLAGS = -Wall -I$(srcdir)/src $(BOOST_CPPFLAGS) +thrift_LDFLAGS = -Wall $(BOOST_LDFLAGS) + +thrift_LDADD = @LEXLIB@ + +EXTRA_DIST = README + +clean-local: + $(RM) thriftl.cc thrifty.cc thrifty.h version.h + +src/main.cc: version.h + +# Adding this to BUILT_SOURCES will cause version.h to be +# regenerated on every "make all" or "make check", which is +# necessary because it changes whenever we "svn up" or similar. +# Ideally, we would like this to be regenerated whenever the +# compiler is rebuilt, but every way we could think of to do +# that caused unnecessary rebuilds of the compiler. +BUILT_SOURCES += regen_version_h + +THRIFT_VERSION=$(shell /bin/sh $(top_srcdir)/print_version.sh -v) +THRIFT_REVISION=$(shell /bin/sh $(top_srcdir)/print_version.sh -r) + +regen_version_h: + @printf "Regenerating version.h... " + @TMPFILE=`mktemp ./version_h.tmp_XXXXXX` ; \ + echo "// AUTOGENERATED, DO NOT EDIT" > $$TMPFILE ; \ + echo '#define THRIFT_VERSION "$(THRIFT_VERSION)"' >> $$TMPFILE ; \ + echo '#define THRIFT_REVISION "$(THRIFT_REVISION)"' >> $$TMPFILE ; \ + if cmp $$TMPFILE version.h >/dev/null ; \ + then \ + rm -f $$TMPFILE ; \ + echo "No changes." ; \ + else \ + mv $$TMPFILE version.h ; \ + echo "Updated." ; \ + fi diff --git a/compiler/cpp/README b/compiler/cpp/README new file mode 100644 index 00000000..fb100a82 --- /dev/null +++ b/compiler/cpp/README @@ -0,0 +1,39 @@ +Thrift Code Compiler + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Thrift Code Compiler +==================== + +This compiler takes thrift files as input and generates output code across +various programming languages. To build and install it, do this: + + ./bootstrap.sh + ./configure + make + sudo make install + +It requires some form of LEX and YACC to be installed, which should be +picked up by autoconf. + +Not much else to report here. You'll have to look at the code to get your +questions answered. Or just run the executable after you build and take +a look at the usage message. diff --git a/compiler/cpp/src/generate/t_cocoa_generator.cc b/compiler/cpp/src/generate/t_cocoa_generator.cc new file mode 100644 index 00000000..48c853c2 --- /dev/null +++ b/compiler/cpp/src/generate/t_cocoa_generator.cc @@ -0,0 +1,2059 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include "t_oop_generator.h" +#include "platform.h" +using namespace std; + + +/** + * Objective-C code generator. + * + * mostly copy/pasting/tweaking from mcslee's work. + */ +class t_cocoa_generator : public t_oop_generator { + public: + t_cocoa_generator( + t_program* program, + const std::map& parsed_options, + const std::string& option_string) + : t_oop_generator(program) + { + out_dir_base_ = "gen-cocoa"; + } + + /** + * Init and close methods + */ + + void init_generator(); + void close_generator(); + + void generate_consts(std::vector consts); + + /** + * Program-level generation functions + */ + + void generate_typedef (t_typedef* ttypedef); + void generate_enum (t_enum* tenum); + void generate_struct (t_struct* tstruct); + void generate_xception(t_struct* txception); + void generate_service (t_service* tservice); + + void print_const_value(std::ofstream& out, std::string name, t_type* type, t_const_value* value); + std::string render_const_value(std::string name, t_type* type, t_const_value* value, + bool containerize_it=false); + + void generate_cocoa_struct(t_struct* tstruct, bool is_exception); + void generate_cocoa_struct_interface(std::ofstream& out, t_struct* tstruct, bool is_xception=false); + void generate_cocoa_struct_implementation(std::ofstream& out, t_struct* tstruct, bool is_xception=false, bool is_result=false); + void generate_cocoa_struct_initializer_signature(std::ofstream& out, + t_struct* tstruct); + void generate_cocoa_struct_field_accessor_declarations(std::ofstream& out, + t_struct* tstruct, + bool is_exception); + void generate_cocoa_struct_field_accessor_implementations(std::ofstream& out, + t_struct* tstruct, + bool is_exception); + void generate_cocoa_struct_reader(std::ofstream& out, t_struct* tstruct); + void generate_cocoa_struct_result_writer(std::ofstream& out, t_struct* tstruct); + void generate_cocoa_struct_writer(std::ofstream& out, t_struct* tstruct); + void generate_cocoa_struct_description(std::ofstream& out, t_struct* tstruct); + + std::string function_result_helper_struct_type(t_function* tfunction); + void generate_function_helpers(t_function* tfunction); + + /** + * Service-level generation functions + */ + + void generate_cocoa_service_protocol (std::ofstream& out, t_service* tservice); + void generate_cocoa_service_client_interface (std::ofstream& out, t_service* tservice); + void generate_cocoa_service_client_implementation (std::ofstream& out, t_service* tservice); + void generate_cocoa_service_helpers (t_service* tservice); + void generate_service_client (t_service* tservice); + void generate_service_server (t_service* tservice); + void generate_process_function (t_service* tservice, t_function* tfunction); + + /** + * Serialization constructs + */ + + void generate_deserialize_field (std::ofstream& out, + t_field* tfield, + std::string fieldName); + + void generate_deserialize_struct (std::ofstream& out, + t_struct* tstruct, + std::string prefix=""); + + void generate_deserialize_container (std::ofstream& out, + t_type* ttype, + std::string prefix=""); + + void generate_deserialize_set_element (std::ofstream& out, + t_set* tset, + std::string prefix=""); + + void generate_deserialize_map_element (std::ofstream& out, + t_map* tmap, + std::string prefix=""); + + void generate_deserialize_list_element (std::ofstream& out, + t_list* tlist, + std::string prefix=""); + + void generate_serialize_field (std::ofstream& out, + t_field* tfield, + std::string prefix=""); + + void generate_serialize_struct (std::ofstream& out, + t_struct* tstruct, + std::string fieldName=""); + + void generate_serialize_container (std::ofstream& out, + t_type* ttype, + std::string prefix=""); + + void generate_serialize_map_element (std::ofstream& out, + t_map* tmap, + std::string iter, + std::string map); + + void generate_serialize_set_element (std::ofstream& out, + t_set* tmap, + std::string iter); + + void generate_serialize_list_element (std::ofstream& out, + t_list* tlist, + std::string index, + std::string listName); + + /** + * Helper rendering functions + */ + + std::string cocoa_prefix(); + std::string cocoa_imports(); + std::string cocoa_thrift_imports(); + std::string type_name(t_type* ttype, bool class_ref=false); + std::string base_type_name(t_base_type* tbase); + std::string declare_field(t_field* tfield); + std::string function_signature(t_function* tfunction); + std::string argument_list(t_struct* tstruct); + std::string type_to_enum(t_type* ttype); + std::string format_string_for_type(t_type* type); + std::string call_field_setter(t_field* tfield, std::string fieldName); + std::string containerize(t_type * ttype, std::string fieldName); + std::string decontainerize(t_field * tfield, std::string fieldName); + + bool type_can_be_null(t_type* ttype) { + ttype = get_true_type(ttype); + + return + ttype->is_container() || + ttype->is_struct() || + ttype->is_xception() || + ttype->is_string(); + } + + private: + + std::string cocoa_prefix_; + std::string constants_declarations_; + + /** + * File streams + */ + + std::ofstream f_header_; + std::ofstream f_impl_; + +}; + + +/** + * Prepares for file generation by opening up the necessary file output + * streams. + */ +void t_cocoa_generator::init_generator() { + // Make output directory + MKDIR(get_out_dir().c_str()); + cocoa_prefix_ = program_->get_namespace("cocoa"); + + // we have a .h header file... + string f_header_name = program_name_+".h"; + string f_header_fullname = get_out_dir()+f_header_name; + f_header_.open(f_header_fullname.c_str()); + + f_header_ << + autogen_comment() << + endl; + + f_header_ << + cocoa_imports() << + cocoa_thrift_imports(); + + // ...and a .m implementation file + string f_impl_name = get_out_dir()+program_name_+".m"; + f_impl_.open(f_impl_name.c_str()); + + f_impl_ << + autogen_comment() << + endl; + + f_impl_ << + cocoa_imports() << + cocoa_thrift_imports() << + "#import \"" << f_header_name << "\"" << endl << + endl; + +} + +/** + * Prints standard Cocoa imports + * + * @return List of imports for Cocoa libraries + */ +string t_cocoa_generator::cocoa_imports() { + return + string() + + "#import \n" + + "\n"; +} + +/** + * Prints thrift runtime imports + * + * @return List of imports necessary for thrift runtime + */ +string t_cocoa_generator::cocoa_thrift_imports() { + string result = string() + + "#import \n" + + "#import \n" + + "#import \n" + + "\n"; + + // Include other Thrift includes + const vector& includes = program_->get_includes(); + for (size_t i = 0; i < includes.size(); ++i) { + result += "#import \"" + includes[i]->get_name() + ".h\"" + "\n"; + } + result += "\n"; + + return result; +} + + +/** + * Finish up generation. + */ +void t_cocoa_generator::close_generator() +{ + // stick our constants declarations at the end of the header file + // since they refer to things we are defining. + f_header_ << constants_declarations_ << endl; +} + +/** + * Generates a typedef. This is just a simple 1-liner in objective-c + * + * @param ttypedef The type definition + */ +void t_cocoa_generator::generate_typedef(t_typedef* ttypedef) { + f_header_ << + indent() << "typedef " << type_name(ttypedef->get_type()) << " " << cocoa_prefix_ << ttypedef->get_symbolic() << ";" << endl << + endl; +} + +/** + * Generates code for an enumerated type. In Objective-C, this is + * essentially the same as the thrift definition itself, using the + * enum keyword in Objective-C. For namespace purposes, the name of + * the enum plus an underscore is prefixed onto each element. + * + * @param tenum The enumeration + */ +void t_cocoa_generator::generate_enum(t_enum* tenum) { + f_header_ << + indent() << "enum " << cocoa_prefix_ << tenum->get_name() << " {" << endl; + indent_up(); + + vector constants = tenum->get_constants(); + vector::iterator c_iter; + bool first = true; + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + if (first) { + first = false; + } else { + f_header_ << + "," << endl; + } + f_header_ << + indent() << tenum->get_name() << "_" << (*c_iter)->get_name(); + if ((*c_iter)->has_value()) { + f_header_ << + " = " << (*c_iter)->get_value(); + } + } + + indent_down(); + f_header_ << + endl << + "};" << endl << + endl; +} + +/** + * Generates a class that holds all the constants. Primitive values + * could have been placed outside this class, but I just put + * everything in for consistency. + */ +void t_cocoa_generator::generate_consts(std::vector consts) { + std::ostringstream const_interface; + string constants_class_name = cocoa_prefix_ + program_name_ + "Constants"; + + const_interface << "@interface " << constants_class_name << " "; + scope_up(const_interface); + scope_down(const_interface); + + // getter method for each constant defined. + vector::iterator c_iter; + for (c_iter = consts.begin(); c_iter != consts.end(); ++c_iter) { + string name = (*c_iter)->get_name(); + t_type* type = (*c_iter)->get_type(); + const_interface << + "+ (" << type_name(type) << ") " << name << ";" << endl; + } + + const_interface << "@end"; + + // this gets spit into the header file in ::close_generator + constants_declarations_ = const_interface.str(); + + // static variables in the .m hold all constant values + for (c_iter = consts.begin(); c_iter != consts.end(); ++c_iter) { + string name = (*c_iter)->get_name(); + t_type* type = (*c_iter)->get_type(); + f_impl_ << + "static " << type_name(type) << " " << cocoa_prefix_ << name; + if (!type->is_container() && !type->is_struct()) { + f_impl_ << " = " << render_const_value(name, type, (*c_iter)->get_value()); + } + f_impl_ << ";" << endl; + } + f_impl_ << endl; + + f_impl_ << "@implementation " << constants_class_name << endl; + + // initialize complex constants when the class is loaded + f_impl_ << "+ (void) initialize "; + scope_up(f_impl_); + + for (c_iter = consts.begin(); c_iter != consts.end(); ++c_iter) { + if ((*c_iter)->get_type()->is_container() || + (*c_iter)->get_type()->is_struct()) { + string name = (*c_iter)->get_name(); + f_impl_ << indent() << name << " = " << render_const_value(name, + (*c_iter)->get_type(), + (*c_iter)->get_value()); + f_impl_ << ";" << endl; + } + } + scope_down(f_impl_); + + // getter method for each constant + for (c_iter = consts.begin(); c_iter != consts.end(); ++c_iter) { + string name = (*c_iter)->get_name(); + t_type* type = (*c_iter)->get_type(); + f_impl_ << + "+ (" << type_name(type) << ") " << name; + scope_up(f_impl_); + indent(f_impl_) << "return " << name << ";" << endl; + scope_down(f_impl_); + } + + f_impl_ << "@end" << endl << endl; +} + + +/** + * Generates a struct definition for a thrift data type. This is a class + * with protected data members, read(), write(), and getters and setters. + * + * @param tstruct The struct definition + */ +void t_cocoa_generator::generate_struct(t_struct* tstruct) { + generate_cocoa_struct_interface(f_header_, tstruct, false); + generate_cocoa_struct_implementation(f_impl_, tstruct, false); +} + +/** + * Exceptions are structs, but they inherit from NSException + * + * @param tstruct The struct definition + */ +void t_cocoa_generator::generate_xception(t_struct* txception) { + generate_cocoa_struct_interface(f_header_, txception, true); + generate_cocoa_struct_implementation(f_impl_, txception, true); +} + + +/** + * Generate the interface for a struct + * + * @param tstruct The struct definition + */ +void t_cocoa_generator::generate_cocoa_struct_interface(ofstream &out, + t_struct* tstruct, + bool is_exception) { + out << "@interface " << cocoa_prefix_ << tstruct->get_name() << " : "; + + if (is_exception) { + out << "NSException "; + } else { + out << "NSObject "; + } + + scope_up(out); + + // members are protected. this is redundant, but explicit. + // f_header_ << endl << "@protected:" << endl; + + const vector& members = tstruct->get_members(); + + // member varialbes + vector::const_iterator m_iter; + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + out << indent() << declare_field(*m_iter) << endl; + } + + if (members.size() > 0) { + out << endl; + // isset fields + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + indent(out) << + "BOOL __" << (*m_iter)->get_name() << "_isset;" << endl; + } + } + + scope_down(out); + out << endl; + + // initializer for all fields + if (!members.empty()) { + generate_cocoa_struct_initializer_signature(out, tstruct); + out << ";" << endl; + } + out << endl; + + // read and write + out << "- (void) read: (id ) inProtocol;" << endl; + out << "- (void) write: (id ) outProtocol;" << endl; + out << endl; + + // getters and setters + generate_cocoa_struct_field_accessor_declarations(out, tstruct, is_exception); + + out << "@end" << endl << endl; +} + + +/** + * Generate signature for initializer of struct with a parameter for + * each field. + */ +void t_cocoa_generator::generate_cocoa_struct_initializer_signature(ofstream &out, + t_struct* tstruct) { + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + indent(out) << "- (id) initWith"; + for (m_iter = members.begin(); m_iter != members.end(); ) { + if (m_iter == members.begin()) { + out << capitalize((*m_iter)->get_name()); + } else { + out << (*m_iter)->get_name(); + } + out << ": (" << type_name((*m_iter)->get_type()) << ") " << + (*m_iter)->get_name(); + ++m_iter; + if (m_iter != members.end()) { + out << " "; + } + } +} + + +/** + * Generate getter and setter declarations for all fields, plus an + * IsSet getter. + */ +void t_cocoa_generator::generate_cocoa_struct_field_accessor_declarations(ofstream &out, + t_struct* tstruct, + bool is_exception) { + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + out << indent() << "- (" << type_name((*m_iter)->get_type()) << ") " << decapitalize((*m_iter)->get_name()) << ";" << endl; + out << indent() << "- (void) set" << capitalize((*m_iter)->get_name()) << + ": (" << type_name((*m_iter)->get_type()) << ") " << (*m_iter)->get_name() << ";" << endl; + out << indent() << "- (BOOL) " << (*m_iter)->get_name() << "IsSet;" << endl << endl; + } +} + + +/** + * Generate struct implementation. + * + * @param tstruct The struct definition + * @param is_exception Is this an exception? + * @param is_result If this is a result it needs a different writer + */ +void t_cocoa_generator::generate_cocoa_struct_implementation(ofstream &out, + t_struct* tstruct, + bool is_exception, + bool is_result) { + indent(out) << + "@implementation " << cocoa_prefix_ << tstruct->get_name() << endl; + + // exceptions need to call the designated initializer on NSException + if (is_exception) { + out << indent() << "- (id) init" << endl; + scope_up(out); + out << indent() << "return [super initWithName: @\"" << tstruct->get_name() << + "\" reason: @\"unknown\" userInfo: nil];" << endl; + scope_down(out); + } + + // initializer with all fields as params + const vector& members = tstruct->get_members(); + if (!members.empty()) { + generate_cocoa_struct_initializer_signature(out, tstruct); + out << endl; + scope_up(out); + if (is_exception) { + out << indent() << "self = [self init];" << endl; + } else { + out << indent() << "self = [super init];" << endl; + } + + vector::const_iterator m_iter; + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + t_type* t = get_true_type((*m_iter)->get_type()); + out << indent() << "__" << (*m_iter)->get_name() << " = "; + if (type_can_be_null(t)) { + out << "[" << (*m_iter)->get_name() << " retain];" << endl; + } else { + out << (*m_iter)->get_name() << ";" << endl; + } + out << indent() << "__" << (*m_iter)->get_name() << "_isset = YES;" << endl; + } + + out << indent() << "return self;" << endl; + scope_down(out); + out << endl; + } + + // dealloc + if (!members.empty()) { + out << "- (void) dealloc" << endl; + scope_up(out); + + vector::const_iterator m_iter; + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + t_type* t = get_true_type((*m_iter)->get_type()); + if (type_can_be_null(t)) { + indent(out) << "[__" << (*m_iter)->get_name() << " release];" << endl; + } + } + + out << indent() << "[super dealloc];" << endl; + scope_down(out); + out << endl; + } + + // the rest of the methods + generate_cocoa_struct_field_accessor_implementations(out, tstruct, is_exception); + generate_cocoa_struct_reader(out, tstruct); + if (is_result) { + generate_cocoa_struct_result_writer(out, tstruct); + } else { + generate_cocoa_struct_writer(out, tstruct); + } + generate_cocoa_struct_description(out, tstruct); + + out << "@end" << endl << endl; +} + + +/** + * Generates a function to read all the fields of the struct. + * + * @param tstruct The struct definition + */ +void t_cocoa_generator::generate_cocoa_struct_reader(ofstream& out, + t_struct* tstruct) { + out << + "- (void) read: (id ) inProtocol" << endl; + scope_up(out); + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + + // Declare stack tmp variables + indent(out) << "NSString * fieldName;" << endl; + indent(out) << "int fieldType;" << endl; + indent(out) << "int fieldID;" << endl; + out << endl; + + indent(out) << "[inProtocol readStructBeginReturningName: NULL];" << endl; + + // Loop over reading in fields + indent(out) << + "while (true)" << endl; + scope_up(out); + + // Read beginning field marker + indent(out) << + "[inProtocol readFieldBeginReturningName: &fieldName type: &fieldType fieldID: &fieldID];" << endl; + + // Check for field STOP marker and break + indent(out) << + "if (fieldType == TType_STOP) { " << endl; + indent_up(); + indent(out) << + "break;" << endl; + indent_down(); + indent(out) << + "}" << endl; + + // Switch statement on the field we are reading + indent(out) << + "switch (fieldID)" << endl; + + scope_up(out); + + // Generate deserialization code for known cases + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + indent(out) << + "case " << (*f_iter)->get_key() << ":" << endl; + indent_up(); + indent(out) << + "if (fieldType == " << type_to_enum((*f_iter)->get_type()) << ") {" << endl; + indent_up(); + + generate_deserialize_field(out, *f_iter, "fieldValue"); + indent(out) << call_field_setter(*f_iter, "fieldValue") << endl; + // if this is an allocated field, release it since the struct + // is now retaining it + if (type_can_be_null((*f_iter)->get_type())) { + // deserialized strings are autorelease, so don't release them + if (!(get_true_type((*f_iter)->get_type())->is_string())) { + indent(out) << "[fieldValue release];" << endl; + } + } + + indent_down(); + out << + indent() << "} else { " << endl << + indent() << " [TProtocolUtil skipType: fieldType onProtocol: inProtocol];" << endl << + indent() << "}" << endl << + indent() << "break;" << endl; + indent_down(); + } + + // In the default case we skip the field + out << + indent() << "default:" << endl << + indent() << " [TProtocolUtil skipType: fieldType onProtocol: inProtocol];" << endl << + indent() << " break;" << endl; + + scope_down(out); + + // Read field end marker + indent(out) << + "[inProtocol readFieldEnd];" << endl; + + scope_down(out); + + out << + indent() << "[inProtocol readStructEnd];" << endl; + + indent_down(); + out << + indent() << "}" << endl << + endl; +} + +/** + * Generates a function to write all the fields of the struct + * + * @param tstruct The struct definition + */ +void t_cocoa_generator::generate_cocoa_struct_writer(ofstream& out, + t_struct* tstruct) { + out << + indent() << "- (void) write: (id ) outProtocol {" << endl; + indent_up(); + + string name = tstruct->get_name(); + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + + out << + indent() << "[outProtocol writeStructBeginWithName: @\"" << name << "\"];" << endl; + + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + out << + indent() << "if (__" << (*f_iter)->get_name() << "_isset) {" << endl; + indent_up(); + bool null_allowed = type_can_be_null((*f_iter)->get_type()); + if (null_allowed) { + out << + indent() << "if (__" << (*f_iter)->get_name() << " != nil) {" << endl; + indent_up(); + } + + indent(out) << "[outProtocol writeFieldBeginWithName: @\"" << + (*f_iter)->get_name() << "\" type: " << type_to_enum((*f_iter)->get_type()) << + " fieldID: " << (*f_iter)->get_key() << "];" << endl; + + // Write field contents + generate_serialize_field(out, *f_iter, "__"+(*f_iter)->get_name()); + + // Write field closer + indent(out) << + "[outProtocol writeFieldEnd];" << endl; + + if (null_allowed) { + scope_down(out); + } + scope_down(out); + } + // Write the struct map + out << + indent() << "[outProtocol writeFieldStop];" << endl << + indent() << "[outProtocol writeStructEnd];" << endl; + + indent_down(); + out << + indent() << "}" << endl << + endl; +} + +/** + * Generates a function to write all the fields of the struct, which + * is a function result. These fields are only written if they are + * set, and only one of them can be set at a time. + * + * @param tstruct The struct definition + */ +void t_cocoa_generator::generate_cocoa_struct_result_writer(ofstream& out, + t_struct* tstruct) { + out << + indent() << "- (void) write: (id ) outProtocol {" << endl; + indent_up(); + + string name = tstruct->get_name(); + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + + out << + indent() << "[outProtocol writeStructBeginWithName: @\"" << name << "\"];" << endl; + + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + out << + endl << + indent() << "if "; + } else { + out << + " else if "; + } + + out << + "(__" << (*f_iter)->get_name() << "_isset) {" << endl; + indent_up(); + + bool null_allowed = type_can_be_null((*f_iter)->get_type()); + if (null_allowed) { + out << + indent() << "if (__" << (*f_iter)->get_name() << " != nil) {" << endl; + indent_up(); + } + + indent(out) << "[outProtocol writeFieldBeginWithName: @\"" << + (*f_iter)->get_name() << "\" type: " << type_to_enum((*f_iter)->get_type()) << + " fieldID: " << (*f_iter)->get_key() << "];" << endl; + + // Write field contents + generate_serialize_field(out, *f_iter, "__"+(*f_iter)->get_name()); + + // Write field closer + indent(out) << + "[outProtocol writeFieldEnd];" << endl; + + if (null_allowed) { + indent_down(); + indent(out) << "}" << endl; + } + + indent_down(); + indent(out) << "}"; + } + // Write the struct map + out << + endl << + indent() << "[outProtocol writeFieldStop];" << endl << + indent() << "[outProtocol writeStructEnd];" << endl; + + indent_down(); + out << + indent() << "}" << endl << + endl; +} + +/** + * Generate property accessor methods for all fields in the struct. + * getter, setter, isset getter. + * + * @param tstruct The struct definition + */ +void t_cocoa_generator::generate_cocoa_struct_field_accessor_implementations(ofstream& out, + t_struct* tstruct, + bool is_exception) { + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + t_field* field = *f_iter; + t_type* type = get_true_type(field->get_type()); + std::string field_name = field->get_name(); + std::string cap_name = field_name; + cap_name[0] = toupper(cap_name[0]); + + // Simple getter + indent(out) << "- (" << type_name(type) << ") "; + out << field_name << " {" << endl; + indent_up(); + if (!type_can_be_null(type)) { + indent(out) << "return __" << field_name << ";" << endl; + } else { + indent(out) << "return [[__" << field_name << " retain] autorelease];" << endl; + } + indent_down(); + indent(out) << "}" << endl << endl; + + // Simple setter + indent(out) << "- (void) set" << cap_name << ": (" << type_name(type) << + ") " << field_name << " {" << endl; + indent_up(); + if (!type_can_be_null(type)) { + indent(out) << "__" << field_name << " = " << field_name << ";" << endl; + } else { + indent(out) << "[" << field_name << " retain];" << endl; + indent(out) << "[__" << field_name << " release];" << endl; + indent(out) << "__" << field_name << " = " << field_name << ";" << endl; + } + indent(out) << "__" << field_name << "_isset = YES;" << endl; + indent_down(); + indent(out) << "}" << endl << endl; + + // IsSet + indent(out) << "- (BOOL) " << field_name << "IsSet {" << endl; + indent_up(); + indent(out) << "return __" << field_name << "_isset;" << endl; + indent_down(); + indent(out) << "}" << endl << endl; + + // Unsetter - do we need this? + indent(out) << "- (void) unset" << cap_name << " {" << endl; + indent_up(); + if (type_can_be_null(type)) { + indent(out) << "[__" << field_name << " release];" << endl; + indent(out) << "__" << field_name << " = nil;" << endl; + } + indent(out) << "__" << field_name << "_isset = NO;" << endl; + indent_down(); + indent(out) << "}" << endl << endl; + } +} + +/** + * Generates a description method for the given struct + * + * @param tstruct The struct definition + */ +void t_cocoa_generator::generate_cocoa_struct_description(ofstream& out, + t_struct* tstruct) { + out << + indent() << "- (NSString *) description {" << endl; + indent_up(); + + out << + indent() << "NSMutableString * ms = [NSMutableString stringWithString: @\"" << + tstruct->get_name() << "(\"];" << endl; + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + indent(out) << "[ms appendString: @\"" << (*f_iter)->get_name() << ":\"];" << endl; + } else { + indent(out) << "[ms appendString: @\"," << (*f_iter)->get_name() << ":\"];" << endl; + } + t_type* ttype = (*f_iter)->get_type(); + indent(out) << "[ms appendFormat: @\"" << format_string_for_type(ttype) << "\", __" << + (*f_iter)->get_name() << "];" << endl; + } + out << + indent() << "[ms appendString: @\")\"];" << endl << + indent() << "return [ms copy];" << endl; + + indent_down(); + indent(out) << "}" << endl << + endl; +} + + +/** + * Generates a thrift service. In Objective-C this consists of a + * protocol definition, a client interface and a client implementation. + * + * @param tservice The service definition + */ +void t_cocoa_generator::generate_service(t_service* tservice) { + generate_cocoa_service_protocol(f_header_, tservice); + generate_cocoa_service_client_interface(f_header_, tservice); + generate_cocoa_service_helpers(tservice); + generate_cocoa_service_client_implementation(f_impl_, tservice); +} + + +/** + * Generates structs for all the service return types + * + * @param tservice The service + */ +void t_cocoa_generator::generate_cocoa_service_helpers(t_service* tservice) { + vector functions = tservice->get_functions(); + vector::iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + generate_function_helpers(*f_iter); + } +} + +string t_cocoa_generator::function_result_helper_struct_type(t_function* tfunction) { + return capitalize(tfunction->get_name()) + "Result_"; +} + + +/** + * Generates a struct and helpers for a function. + * + * @param tfunction The function + */ +void t_cocoa_generator::generate_function_helpers(t_function* tfunction) { + if (tfunction->is_oneway()) { + return; + } + + // create a result struct with a success field of the return type, + // and a field for each type of exception thrown + t_struct result(program_, function_result_helper_struct_type(tfunction)); + t_field success(tfunction->get_returntype(), "success", 0); + if (!tfunction->get_returntype()->is_void()) { + result.append(&success); + } + + t_struct* xs = tfunction->get_xceptions(); + const vector& fields = xs->get_members(); + vector::const_iterator f_iter; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + result.append(*f_iter); + } + + // generate the result struct + generate_cocoa_struct_interface(f_impl_, &result, false); + generate_cocoa_struct_implementation(f_impl_, &result, false, true); +} + +/** + * Generates a service protocol definition. + * + * @param tservice The service to generate a protocol definition for + */ +void t_cocoa_generator::generate_cocoa_service_protocol(ofstream& out, + t_service* tservice) { + out << "@protocol " << cocoa_prefix_ << tservice->get_name() << " " << endl; + + vector functions = tservice->get_functions(); + vector::iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + out << "- " << function_signature(*f_iter) << ";" << + " // throws "; + t_struct* xs = (*f_iter)->get_xceptions(); + const std::vector& xceptions = xs->get_members(); + vector::const_iterator x_iter; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + out << type_name((*x_iter)->get_type()) + ", "; + } + out << "TException" << endl; + } + out << "@end" << endl << endl; +} + + +/** + * Generates a service client interface definition. + * + * @param tservice The service to generate a client interface definition for + */ +void t_cocoa_generator::generate_cocoa_service_client_interface(ofstream& out, + t_service* tservice) { + out << "@interface " << cocoa_prefix_ << tservice->get_name() << "Client : NSObject <" << + cocoa_prefix_ << tservice->get_name() << "> "; + + scope_up(out); + out << indent() << "id inProtocol;" << endl; + out << indent() << "id outProtocol;" << endl; + scope_down(out); + + out << "- (id) initWithProtocol: (id ) protocol;" << endl; + out << "- (id) initWithInProtocol: (id ) inProtocol outProtocol: (id ) outProtocol;" << endl; + out << "@end" << endl << endl; +} + + +/** + * Generates a service client implementation. + * + * @param tservice The service to generate an implementation for + */ +void t_cocoa_generator::generate_cocoa_service_client_implementation(ofstream& out, + t_service* tservice) { + out << "@implementation " << cocoa_prefix_ << tservice->get_name() << "Client" << endl; + + // initializers + out << "- (id) initWithProtocol: (id ) protocol" << endl; + scope_up(out); + out << indent() << "return [self initWithInProtocol: protocol outProtocol: protocol];" << endl; + scope_down(out); + out << endl; + + out << "- (id) initWithInProtocol: (id ) anInProtocol outProtocol: (id ) anOutProtocol" << endl; + scope_up(out); + out << indent() << "[super init];" << endl; + out << indent() << "inProtocol = [anInProtocol retain];" << endl; + out << indent() << "outProtocol = [anOutProtocol retain];" << endl; + out << indent() << "return self;" << endl; + scope_down(out); + out << endl; + + // dealloc + out << "- (void) dealloc" << endl; + scope_up(out); + out << indent() << "[inProtocol release];" << endl; + out << indent() << "[outProtocol release];" << endl; + out << indent() << "[super dealloc];" << endl; + scope_down(out); + out << endl; + + // generate client method implementations + vector functions = tservice->get_functions(); + vector::const_iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + string funname = (*f_iter)->get_name(); + + t_function send_function(g_type_void, + string("send_") + (*f_iter)->get_name(), + (*f_iter)->get_arglist()); + + string argsname = (*f_iter)->get_name() + "_args"; + + // Open function + indent(out) << + "- " << function_signature(&send_function) << endl; + scope_up(out); + + // Serialize the request + out << + indent() << "[outProtocol writeMessageBeginWithName: @\"" << funname << "\"" << + " type: TMessageType_CALL" << + " sequenceID: 0];" << endl; + + out << + indent() << "[outProtocol writeStructBeginWithName: @\"" << argsname << "\"];" << endl; + + // write out function parameters + t_struct* arg_struct = (*f_iter)->get_arglist(); + const vector& fields = arg_struct->get_members(); + vector::const_iterator fld_iter; + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + string fieldName = (*fld_iter)->get_name(); + if (type_can_be_null((*fld_iter)->get_type())) { + out << indent() << "if (" << fieldName << " != nil)"; + scope_up(out); + } + out << + indent() << "[outProtocol writeFieldBeginWithName: @\"" << fieldName << "\"" + " type: " << type_to_enum((*fld_iter)->get_type()) << + " fieldID: " << (*fld_iter)->get_key() << "];" << endl; + + generate_serialize_field(out, *fld_iter, fieldName); + + out << + indent() << "[outProtocol writeFieldEnd];" << endl; + + if (type_can_be_null((*fld_iter)->get_type())) { + scope_down(out); + } + } + + out << + indent() << "[outProtocol writeFieldStop];" << endl; + out << + indent() << "[outProtocol writeStructEnd];" << endl; + + out << + indent() << "[outProtocol writeMessageEnd];" << endl << + indent() << "[[outProtocol transport] flush];" << endl; + + scope_down(out); + out << endl; + + if (!(*f_iter)->is_oneway()) { + t_struct noargs(program_); + t_function recv_function((*f_iter)->get_returntype(), + string("recv_") + (*f_iter)->get_name(), + &noargs, + (*f_iter)->get_xceptions()); + // Open function + indent(out) << + "- " << function_signature(&recv_function) << endl; + scope_up(out); + + // TODO(mcslee): Message validation here, was the seqid etc ok? + + // check for an exception + out << + indent() << "int msgType = 0;" << endl << + indent() << "[inProtocol readMessageBeginReturningName: nil type: &msgType sequenceID: NULL];" << endl << + indent() << "if (msgType == TMessageType_EXCEPTION) {" << endl << + indent() << " TApplicationException * x = [TApplicationException read: inProtocol];" << endl << + indent() << " [inProtocol readMessageEnd];" << endl << + indent() << " @throw x;" << endl << + indent() << "}" << endl; + + // FIXME - could optimize here to reduce creation of temporary objects. + string resultname = function_result_helper_struct_type(*f_iter); + out << + indent() << cocoa_prefix_ << resultname << " * result = [[[" << cocoa_prefix_ << + resultname << " alloc] init] autorelease];" << endl; + indent(out) << "[result read: inProtocol];" << endl; + indent(out) << "[inProtocol readMessageEnd];" << endl; + + // Careful, only return _result if not a void function + if (!(*f_iter)->get_returntype()->is_void()) { + out << + indent() << "if ([result successIsSet]) {" << endl << + indent() << " return [result success];" << endl << + indent() << "}" << endl; + } + + t_struct* xs = (*f_iter)->get_xceptions(); + const std::vector& xceptions = xs->get_members(); + vector::const_iterator x_iter; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + out << + indent() << "if ([result " << (*x_iter)->get_name() << "IsSet]) {" << endl << + indent() << " @throw [result " << (*x_iter)->get_name() << "];" << endl << + indent() << "}" << endl; + } + + // If you get here it's an exception, unless a void function + if ((*f_iter)->get_returntype()->is_void()) { + indent(out) << + "return;" << endl; + } else { + out << + indent() << "@throw [TApplicationException exceptionWithType: TApplicationException_MISSING_RESULT" << endl << + indent() << " reason: @\"" << (*f_iter)->get_name() << " failed: unknown result\"];" << endl; + } + + // Close function + scope_down(out); + out << endl; + } + + // Open function + indent(out) << + "- " << function_signature(*f_iter) << endl; + scope_up(out); + indent(out) << + "[self send_" << funname; + + // Declare the function arguments + bool first = true; + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + if (first) { + first = false; + } else { + out << " "; + } + out << ": " << (*fld_iter)->get_name(); + } + out << "];" << endl; + + if (!(*f_iter)->is_oneway()) { + out << indent(); + if (!(*f_iter)->get_returntype()->is_void()) { + out << "return "; + } + out << + "[self recv_" << funname << "];" << endl; + } + scope_down(out); + out << endl; + } + + indent_down(); + + out << "@end" << endl << endl; +} + + +/** + * Deserializes a field of any type. + * + * @param tfield The field + * @param fieldName The variable name for this field + */ +void t_cocoa_generator::generate_deserialize_field(ofstream& out, + t_field* tfield, + string fieldName) { + t_type* type = get_true_type(tfield->get_type()); + + if (type->is_void()) { + throw "CANNOT GENERATE DESERIALIZE CODE FOR void TYPE: " + + tfield->get_name(); + } + + if (type->is_struct() || type->is_xception()) { + generate_deserialize_struct(out, + (t_struct*)type, + fieldName); + } else if (type->is_container()) { + generate_deserialize_container(out, type, fieldName); + } else if (type->is_base_type() || type->is_enum()) { + indent(out) << + type_name(type) << " " << fieldName << " = [inProtocol "; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "compiler error: cannot serialize void field in a struct: " + + tfield->get_name(); + break; + case t_base_type::TYPE_STRING: + if (((t_base_type*)type)->is_binary()) { + out << "readBinary];"; + } else { + out << "readString];"; + } + break; + case t_base_type::TYPE_BOOL: + out << "readBool];"; + break; + case t_base_type::TYPE_BYTE: + out << "readByte];"; + break; + case t_base_type::TYPE_I16: + out << "readI16];"; + break; + case t_base_type::TYPE_I32: + out << "readI32];"; + break; + case t_base_type::TYPE_I64: + out << "readI64];"; + break; + case t_base_type::TYPE_DOUBLE: + out << "readDouble];"; + break; + default: + throw "compiler error: no Objective-C name for base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + out << "readI32];"; + } + out << + endl; + } else { + printf("DO NOT KNOW HOW TO DESERIALIZE FIELD '%s' TYPE '%s'\n", + tfield->get_name().c_str(), type_name(type).c_str()); + } +} + +/** + * Generates an unserializer for a struct, allocates the struct and invokes read: + */ +void t_cocoa_generator::generate_deserialize_struct(ofstream& out, + t_struct* tstruct, + string fieldName) { + indent(out) << type_name(tstruct) << fieldName << " = [[" << + type_name(tstruct, true) << " alloc] init];" << endl; + indent(out) << "[" << fieldName << " read: inProtocol];" << endl; +} + +/** + * Deserializes a container by reading its size and then iterating + */ +void t_cocoa_generator::generate_deserialize_container(ofstream& out, + t_type* ttype, + string fieldName) { + string size = tmp("_size"); + indent(out) << "int " << size << ";" << endl; + + // Declare variables, read header + if (ttype->is_map()) { + indent(out) + << "[inProtocol readMapBeginReturningKeyType: NULL valueType: NULL size: &" << + size << "];" << endl; + indent(out) << "NSMutableDictionary * " << fieldName << + " = [[NSMutableDictionary alloc] initWithCapacity: " << size << "];" << endl; + } else if (ttype->is_set()) { + indent(out) + << "[inProtocol readSetBeginReturningElementType: NULL size: &" << size << "];" << endl; + indent(out) << "NSMutableSet * " << fieldName << + " = [[NSMutableSet alloc] initWithCapacity: " << size << "];" << endl; + } else if (ttype->is_list()) { + indent(out) + << "[inProtocol readListBeginReturningElementType: NULL size: &" << size << "];" << endl; + indent(out) << "NSMutableArray * " << fieldName << + " = [[NSMutableArray alloc] initWithCapacity: " << size << "];" << endl; + } + // FIXME - the code above does not verify that the element types of + // the containers being read match the element types of the + // containers we are reading into. Does that matter? + + // For loop iterates over elements + string i = tmp("_i"); + indent(out) << "int " << i << ";" << endl << + indent() << "for (" << i << " = 0; " << + i << " < " << size << "; " << + "++" << i << ")" << endl; + + scope_up(out); + + if (ttype->is_map()) { + generate_deserialize_map_element(out, (t_map*)ttype, fieldName); + } else if (ttype->is_set()) { + generate_deserialize_set_element(out, (t_set*)ttype, fieldName); + } else if (ttype->is_list()) { + generate_deserialize_list_element(out, (t_list*)ttype, fieldName); + } + + scope_down(out); + + // Read container end + if (ttype->is_map()) { + indent(out) << "[inProtocol readMapEnd];" << endl; + } else if (ttype->is_set()) { + indent(out) << "[inProtocol readSetEnd];" << endl; + } else if (ttype->is_list()) { + indent(out) << "[inProtocol readListEnd];" << endl; + } + +} + + +/** + * Take a variable of a given type and wrap it in code to make it + * suitable for putting into a container, if necessary. Basically, + * wrap scaler primitives in NSNumber objects. + */ +string t_cocoa_generator::containerize(t_type * ttype, + string fieldName) +{ + // FIXME - optimize here to avoid autorelease pool? + ttype = get_true_type(ttype); + if (ttype->is_enum()) { + return "[NSNumber numberWithInt: " + fieldName + "]"; + } else if (ttype->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)ttype)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "can't containerize void"; + case t_base_type::TYPE_BOOL: + return "[NSNumber numberWithBool: " + fieldName + "]"; + case t_base_type::TYPE_BYTE: + return "[NSNumber numberWithUnsignedChar: " + fieldName + "]"; + case t_base_type::TYPE_I16: + return "[NSNumber numberWithShort: " + fieldName + "]"; + case t_base_type::TYPE_I32: + return "[NSNumber numberWithLong: " + fieldName + "]"; + case t_base_type::TYPE_I64: + return "[NSNumber numberWithLongLong: " + fieldName + "]"; + case t_base_type::TYPE_DOUBLE: + return "[NSNumber numberWithDouble: " + fieldName + "]"; + default: + break; + } + } + + // do nothing + return fieldName; +} + + +/** + * Generates code to deserialize a map element + */ +void t_cocoa_generator::generate_deserialize_map_element(ofstream& out, + t_map* tmap, + string fieldName) { + string key = tmp("_key"); + string val = tmp("_val"); + t_field fkey(tmap->get_key_type(), key); + t_field fval(tmap->get_val_type(), val); + + generate_deserialize_field(out, &fkey, key); + generate_deserialize_field(out, &fval, val); + + indent(out) << + "[" << fieldName << " setObject: " << containerize(fval.get_type(), val) << + " forKey: " << containerize(fkey.get_type(), key) << "];" << endl; +} + +/** + * Deserializes a set element + */ +void t_cocoa_generator::generate_deserialize_set_element(ofstream& out, + t_set* tset, + string fieldName) { + string elem = tmp("_elem"); + t_field felem(tset->get_elem_type(), elem); + + generate_deserialize_field(out, &felem, elem); + + indent(out) << + "[" << fieldName << " addObject: " << containerize(felem.get_type(), elem) << "];" << endl; +} + +/** + * Deserializes a list element + */ +void t_cocoa_generator::generate_deserialize_list_element(ofstream& out, + t_list* tlist, + string fieldName) { + string elem = tmp("_elem"); + t_field felem(tlist->get_elem_type(), elem); + + generate_deserialize_field(out, &felem, elem); + + indent(out) << + "[" << fieldName << " addObject: " << containerize(felem.get_type(), elem) << "];" << endl; +} + + +/** + * Serializes a field of any type. + * + * @param tfield The field to serialize + * @param fieldName Name to of the variable holding the field + */ +void t_cocoa_generator::generate_serialize_field(ofstream& out, + t_field* tfield, + string fieldName) { + t_type* type = get_true_type(tfield->get_type()); + + // Do nothing for void types + if (type->is_void()) { + throw "CANNOT GENERATE SERIALIZE CODE FOR void TYPE: " + + tfield->get_name(); + } + + if (type->is_struct() || type->is_xception()) { + generate_serialize_struct(out, + (t_struct*)type, + fieldName); + } else if (type->is_container()) { + generate_serialize_container(out, + type, + fieldName); + } else if (type->is_base_type() || type->is_enum()) { + indent(out) << + "[outProtocol "; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw + "compiler error: cannot serialize void field in a struct: " + fieldName; + break; + case t_base_type::TYPE_STRING: + if (((t_base_type*)type)->is_binary()) { + out << "writeBinary: " << fieldName << "];"; + } else { + out << "writeString: " << fieldName << "];"; + } + break; + case t_base_type::TYPE_BOOL: + out << "writeBool: " << fieldName << "];"; + break; + case t_base_type::TYPE_BYTE: + out << "writeByte: " << fieldName << "];"; + break; + case t_base_type::TYPE_I16: + out << "writeI16: " << fieldName << "];"; + break; + case t_base_type::TYPE_I32: + out << "writeI32: " << fieldName << "];"; + break; + case t_base_type::TYPE_I64: + out << "writeI64: " << fieldName << "];"; + break; + case t_base_type::TYPE_DOUBLE: + out << "writeDouble: " << fieldName << "];"; + break; + default: + throw "compiler error: no Java name for base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + out << "writeI32: " << fieldName << "];"; + } + out << endl; + } else { + printf("DO NOT KNOW HOW TO SERIALIZE FIELD '%s' TYPE '%s'\n", + tfield->get_name().c_str(), + type_name(type).c_str()); + } +} + +/** + * Serialize a struct. + * + * @param tstruct The struct to serialize + * @param fieldName Name of variable holding struct + */ +void t_cocoa_generator::generate_serialize_struct(ofstream& out, + t_struct* tstruct, + string fieldName) { + out << + indent() << "[" << fieldName << " write: outProtocol];" << endl; +} + +/** + * Serializes a container by writing its size then the elements. + * + * @param ttype The type of container + * @param fieldName Name of variable holding container + */ +void t_cocoa_generator::generate_serialize_container(ofstream& out, + t_type* ttype, + string fieldName) { + scope_up(out); + + if (ttype->is_map()) { + indent(out) << + "[outProtocol writeMapBeginWithKeyType: " << + type_to_enum(((t_map*)ttype)->get_key_type()) << " valueType: " << + type_to_enum(((t_map*)ttype)->get_val_type()) << " size: [" << + fieldName << " count]];" << endl; + } else if (ttype->is_set()) { + indent(out) << + "[outProtocol writeSetBeginWithElementType: " << + type_to_enum(((t_set*)ttype)->get_elem_type()) << " size: [" << + fieldName << " count]];" << endl; + } else if (ttype->is_list()) { + indent(out) << + "[outProtocol writeListBeginWithElementType: " << + type_to_enum(((t_list*)ttype)->get_elem_type()) << " size: [" << + fieldName << " count]];" << endl; + } + + string iter = tmp("_iter"); + string key; + if (ttype->is_map()) { + key = tmp("key"); + indent(out) << "NSEnumerator * " << iter << " = [" << fieldName << " keyEnumerator];" << endl; + indent(out) << "id " << key << ";" << endl; + indent(out) << "while ((" << key << " = [" << iter << " nextObject]))" << endl; + } else if (ttype->is_set()) { + key = tmp("obj"); + indent(out) << "NSEnumerator * " << iter << " = [" << fieldName << " objectEnumerator];" << endl; + indent(out) << "id " << key << ";" << endl; + indent(out) << "while ((" << key << " = [" << iter << " nextObject]))" << endl; + } else if (ttype->is_list()) { + key = tmp("i"); + indent(out) << "int " << key << ";" << endl; + indent(out) << + "for (" << key << " = 0; " << key << " < [" << fieldName << " count]; " << key << "++)" << endl; + } + + scope_up(out); + + if (ttype->is_map()) { + generate_serialize_map_element(out, (t_map*)ttype, key, fieldName); + } else if (ttype->is_set()) { + generate_serialize_set_element(out, (t_set*)ttype, key); + } else if (ttype->is_list()) { + generate_serialize_list_element(out, (t_list*)ttype, key, fieldName); + } + + scope_down(out); + + if (ttype->is_map()) { + indent(out) << + "[outProtocol writeMapEnd];" << endl; + } else if (ttype->is_set()) { + indent(out) << + "[outProtocol writeSetEnd];" << endl; + } else if (ttype->is_list()) { + indent(out) << + "[outProtocol writeListEnd];" << endl; + } + + scope_down(out); +} + +/** + * Given a field variable name, wrap it in code that converts it to a + * primitive type, if necessary. + */ +string t_cocoa_generator::decontainerize(t_field * tfield, + string fieldName) +{ + t_type * ttype = get_true_type(tfield->get_type()); + if (ttype->is_enum()) { + return "[" + fieldName + " intValue]"; + } else if (ttype->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)ttype)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "can't decontainerize void"; + case t_base_type::TYPE_BOOL: + return "[" + fieldName + " boolValue]"; + case t_base_type::TYPE_BYTE: + return "[" + fieldName + " unsignedCharValue]"; + case t_base_type::TYPE_I16: + return "[" + fieldName + " shortValue]"; + case t_base_type::TYPE_I32: + return "[" + fieldName + " longValue]"; + case t_base_type::TYPE_I64: + return "[" + fieldName + " longLongValue]"; + case t_base_type::TYPE_DOUBLE: + return "[" + fieldName + " doubleValue]"; + default: + break; + } + } + + // do nothing + return fieldName; +} + + +/** + * Serializes the members of a map. + */ +void t_cocoa_generator::generate_serialize_map_element(ofstream& out, + t_map* tmap, + string key, + string mapName) { + t_field kfield(tmap->get_key_type(), key); + generate_serialize_field(out, &kfield, decontainerize(&kfield, key)); + t_field vfield(tmap->get_val_type(), "[" + mapName + " objectForKey: " + key + "]"); + generate_serialize_field(out, &vfield, decontainerize(&vfield, vfield.get_name())); +} + +/** + * Serializes the members of a set. + */ +void t_cocoa_generator::generate_serialize_set_element(ofstream& out, + t_set* tset, + string elementName) { + t_field efield(tset->get_elem_type(), elementName); + generate_serialize_field(out, &efield, decontainerize(&efield, elementName)); +} + +/** + * Serializes the members of a list. + */ +void t_cocoa_generator::generate_serialize_list_element(ofstream& out, + t_list* tlist, + string index, + string listName) { + t_field efield(tlist->get_elem_type(), "[" + listName + " objectAtIndex: " + index + "]"); + generate_serialize_field(out, &efield, decontainerize(&efield, efield.get_name())); +} + + +/** + * Returns an Objective-C name + * + * @param ttype The type + * @param class_ref Do we want a Class reference istead of a type reference? + * @return Java type name, i.e. HashMap + */ +string t_cocoa_generator::type_name(t_type* ttype, bool class_ref) { + if (ttype->is_typedef()) { + return cocoa_prefix_ + ttype->get_name(); + } + + string result; + if (ttype->is_base_type()) { + return base_type_name((t_base_type*)ttype); + } else if (ttype->is_enum()) { + return "int"; + } else if (ttype->is_map()) { + result = "NSDictionary"; + } else if (ttype->is_set()) { + result = "NSSet"; + } else if (ttype->is_list()) { + result = "NSArray"; + } else { + // Check for prefix + t_program* program = ttype->get_program(); + if (program != NULL) { + result = program->get_namespace("cocoa") + ttype->get_name(); + } else { + result = ttype->get_name(); + } + } + + if (!class_ref) { + result += " *"; + } + return result; +} + +/** + * Returns the Objective-C type that corresponds to the thrift type. + * + * @param tbase The base type + */ +string t_cocoa_generator::base_type_name(t_base_type* type) { + t_base_type::t_base tbase = type->get_base(); + + switch (tbase) { + case t_base_type::TYPE_VOID: + return "void"; + case t_base_type::TYPE_STRING: + if (type->is_binary()) { + return "NSData *"; + } else { + return "NSString *"; + } + case t_base_type::TYPE_BOOL: + return "BOOL"; + case t_base_type::TYPE_BYTE: + return "uint8_t"; + case t_base_type::TYPE_I16: + return"int16_t"; + case t_base_type::TYPE_I32: + return "int32_t"; + case t_base_type::TYPE_I64: + return"int64_t"; + case t_base_type::TYPE_DOUBLE: + return "double"; + default: + throw "compiler error: no objective-c name for base type " + t_base_type::t_base_name(tbase); + } +} + + +/** + * Spit out code that evaluates to the specified constant value. + */ +string t_cocoa_generator::render_const_value(string name, + t_type* type, + t_const_value* value, + bool containerize_it) { + type = get_true_type(type); + std::ostringstream render; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_STRING: + render << "@\"" << get_escaped_string(value) << '"'; + break; + case t_base_type::TYPE_BOOL: + render << ((value->get_integer() > 0) ? "YES" : "NO"); + break; + case t_base_type::TYPE_BYTE: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + case t_base_type::TYPE_I64: + render << value->get_integer(); + break; + case t_base_type::TYPE_DOUBLE: + if (value->get_type() == t_const_value::CV_INTEGER) { + render << value->get_integer(); + } else { + render << value->get_double(); + } + break; + default: + throw "compiler error: no const of base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + render << value->get_integer(); + } else if (type->is_struct() || type->is_xception()) { + render << "[[" << type_name(type, true) << " alloc] initWith"; + const vector& fields = ((t_struct*)type)->get_members(); + vector::const_iterator f_iter; + const map& val = value->get_map(); + map::const_iterator v_iter; + bool first = true; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + t_type* field_type = NULL; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if ((*f_iter)->get_name() == v_iter->first->get_string()) { + field_type = (*f_iter)->get_type(); + } + } + if (field_type == NULL) { + throw "type error: " + type->get_name() + " has no field " + v_iter->first->get_string(); + } + if (first) { + render << capitalize(v_iter->first->get_string()); + first = false; + } else { + render << " " << v_iter->first->get_string(); + } + render << ": " << render_const_value(name, field_type, v_iter->second); + } + render << "]"; + } else if (type->is_map()) { + render << "[[NSDictionary alloc] initWithObjectsAndKeys: "; + t_type* ktype = ((t_map*)type)->get_key_type(); + t_type* vtype = ((t_map*)type)->get_val_type(); + const map& val = value->get_map(); + map::const_iterator v_iter; + bool first = true; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + string key = render_const_value(name, ktype, v_iter->first, true); + string val = render_const_value(name, vtype, v_iter->second, true); + if (first) { + first = false; + } else { + render << ", "; + } + render << val << ", " << key; + } + render << ", nil]"; + } else if (type->is_list()) { + render << "[[NSArray alloc] initWithObjects: "; + t_type * etype = ((t_list*)type)->get_elem_type(); + const vector& val = value->get_list(); + bool first = true; + vector::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + if (first) { + first = false; + } else { + render << ", "; + } + render << render_const_value(name, etype, *v_iter, true); + } + render << ", nil]"; + } else if (type->is_set()) { + render << "[[NSSet alloc] initWithObjects: "; + t_type * etype = ((t_set*)type)->get_elem_type(); + const vector& val = value->get_list(); + bool first = true; + vector::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + if (first) { + first = false; + } else { + render << ", "; + } + render << render_const_value(name, etype, *v_iter, true); + } + render << ", nil]"; + } else { + throw "don't know how to render constant for type: " + type->get_name(); + } + + if (containerize_it) { + return containerize(type, render.str()); + } + + return render.str(); +} + + +/** + * Declares a field. + * + * @param ttype The type + */ +string t_cocoa_generator::declare_field(t_field* tfield) { + return type_name(tfield->get_type()) + " __" + tfield->get_name() + ";"; +} + +/** + * Renders a function signature + * + * @param tfunction Function definition + * @return String of rendered function definition + */ +string t_cocoa_generator::function_signature(t_function* tfunction) { + t_type* ttype = tfunction->get_returntype(); + std::string result = + "(" + type_name(ttype) + ") " + tfunction->get_name() + argument_list(tfunction->get_arglist()); + return result; +} + + +/** + * Renders a colon separated list of types and names, suitable for an + * objective-c parameter list + */ +string t_cocoa_generator::argument_list(t_struct* tstruct) { + string result = ""; + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + } else { + result += " "; + } + result += ": (" + type_name((*f_iter)->get_type()) + ") " + (*f_iter)->get_name(); + } + return result; +} + + +/** + * Converts the parse type to an Objective-C enum string for the given type. + */ +string t_cocoa_generator::type_to_enum(t_type* type) { + type = get_true_type(type); + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "NO T_VOID CONSTRUCT"; + case t_base_type::TYPE_STRING: + return "TType_STRING"; + case t_base_type::TYPE_BOOL: + return "TType_BOOL"; + case t_base_type::TYPE_BYTE: + return "TType_BYTE"; + case t_base_type::TYPE_I16: + return "TType_I16"; + case t_base_type::TYPE_I32: + return "TType_I32"; + case t_base_type::TYPE_I64: + return "TType_I64"; + case t_base_type::TYPE_DOUBLE: + return "TType_DOUBLE"; + } + } else if (type->is_enum()) { + return "TType_I32"; + } else if (type->is_struct() || type->is_xception()) { + return "TType_STRUCT"; + } else if (type->is_map()) { + return "TType_MAP"; + } else if (type->is_set()) { + return "TType_SET"; + } else if (type->is_list()) { + return "TType_LIST"; + } + + throw "INVALID TYPE IN type_to_enum: " + type->get_name(); +} + + +/** + * Returns a format string specifier for the supplied parse type. + */ +string t_cocoa_generator::format_string_for_type(t_type* type) { + type = get_true_type(type); + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "NO T_VOID CONSTRUCT"; + case t_base_type::TYPE_STRING: + return "\\\"%@\\\""; + case t_base_type::TYPE_BOOL: + return "%i"; + case t_base_type::TYPE_BYTE: + return "%i"; + case t_base_type::TYPE_I16: + return "%hi"; + case t_base_type::TYPE_I32: + return "%i"; + case t_base_type::TYPE_I64: + return "%qi"; + case t_base_type::TYPE_DOUBLE: + return "%f"; + } + } else if (type->is_enum()) { + return "%i"; + } else if (type->is_struct() || type->is_xception()) { + return "%@"; + } else if (type->is_map()) { + return "%@"; + } else if (type->is_set()) { + return "%@"; + } else if (type->is_list()) { + return "%@"; + } + + throw "INVALID TYPE IN format_string_for_type: " + type->get_name(); +} + +/** + * Generate a call to a field's setter. + * + * @param tfield Field the setter is being called on + * @param fieldName Name of variable to pass to setter + */ + +string t_cocoa_generator::call_field_setter(t_field* tfield, string fieldName) { + return "[self set" + capitalize(tfield->get_name()) + ": " + fieldName + "];"; +} + + +THRIFT_REGISTER_GENERATOR(cocoa, "Cocoa", ""); diff --git a/compiler/cpp/src/generate/t_cpp_generator.cc b/compiler/cpp/src/generate/t_cpp_generator.cc new file mode 100644 index 00000000..67a4bd4e --- /dev/null +++ b/compiler/cpp/src/generate/t_cpp_generator.cc @@ -0,0 +1,3003 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +#include +#include +#include +#include +#include + +#include + +#include "platform.h" +#include "t_oop_generator.h" +using namespace std; + + +/** + * C++ code generator. This is legitimacy incarnate. + * + */ +class t_cpp_generator : public t_oop_generator { + public: + t_cpp_generator( + t_program* program, + const std::map& parsed_options, + const std::string& option_string) + : t_oop_generator(program) + { + std::map::const_iterator iter; + + iter = parsed_options.find("dense"); + gen_dense_ = (iter != parsed_options.end()); + + iter = parsed_options.find("include_prefix"); + use_include_prefix_ = (iter != parsed_options.end()); + + out_dir_base_ = "gen-cpp"; + } + + /** + * Init and close methods + */ + + void init_generator(); + void close_generator(); + + void generate_consts(std::vector consts); + + /** + * Program-level generation functions + */ + + void generate_typedef(t_typedef* ttypedef); + void generate_enum(t_enum* tenum); + void generate_struct(t_struct* tstruct) { + generate_cpp_struct(tstruct, false); + } + void generate_xception(t_struct* txception) { + generate_cpp_struct(txception, true); + } + void generate_cpp_struct(t_struct* tstruct, bool is_exception); + + void generate_service(t_service* tservice); + + void print_const_value(std::ofstream& out, std::string name, t_type* type, t_const_value* value); + std::string render_const_value(std::ofstream& out, std::string name, t_type* type, t_const_value* value); + + void generate_struct_definition (std::ofstream& out, t_struct* tstruct, bool is_exception=false, bool pointers=false, bool read=true, bool write=true); + void generate_struct_fingerprint (std::ofstream& out, t_struct* tstruct, bool is_definition); + void generate_struct_reader (std::ofstream& out, t_struct* tstruct, bool pointers=false); + void generate_struct_writer (std::ofstream& out, t_struct* tstruct, bool pointers=false); + void generate_struct_result_writer (std::ofstream& out, t_struct* tstruct, bool pointers=false); + + /** + * Service-level generation functions + */ + + void generate_service_interface (t_service* tservice); + void generate_service_null (t_service* tservice); + void generate_service_multiface (t_service* tservice); + void generate_service_helpers (t_service* tservice); + void generate_service_client (t_service* tservice); + void generate_service_processor (t_service* tservice); + void generate_service_skeleton (t_service* tservice); + void generate_process_function (t_service* tservice, t_function* tfunction); + void generate_function_helpers (t_service* tservice, t_function* tfunction); + + /** + * Serialization constructs + */ + + void generate_deserialize_field (std::ofstream& out, + t_field* tfield, + std::string prefix="", + std::string suffix=""); + + void generate_deserialize_struct (std::ofstream& out, + t_struct* tstruct, + std::string prefix=""); + + void generate_deserialize_container (std::ofstream& out, + t_type* ttype, + std::string prefix=""); + + void generate_deserialize_set_element (std::ofstream& out, + t_set* tset, + std::string prefix=""); + + void generate_deserialize_map_element (std::ofstream& out, + t_map* tmap, + std::string prefix=""); + + void generate_deserialize_list_element (std::ofstream& out, + t_list* tlist, + std::string prefix, + bool push_back, + std::string index); + + void generate_serialize_field (std::ofstream& out, + t_field* tfield, + std::string prefix="", + std::string suffix=""); + + void generate_serialize_struct (std::ofstream& out, + t_struct* tstruct, + std::string prefix=""); + + void generate_serialize_container (std::ofstream& out, + t_type* ttype, + std::string prefix=""); + + void generate_serialize_map_element (std::ofstream& out, + t_map* tmap, + std::string iter); + + void generate_serialize_set_element (std::ofstream& out, + t_set* tmap, + std::string iter); + + void generate_serialize_list_element (std::ofstream& out, + t_list* tlist, + std::string iter); + + /** + * Helper rendering functions + */ + + std::string namespace_prefix(std::string ns); + std::string namespace_open(std::string ns); + std::string namespace_close(std::string ns); + std::string type_name(t_type* ttype, bool in_typedef=false, bool arg=false); + std::string base_type_name(t_base_type::t_base tbase); + std::string declare_field(t_field* tfield, bool init=false, bool pointer=false, bool constant=false, bool reference=false); + std::string function_signature(t_function* tfunction, std::string prefix="", bool name_params=true); + std::string argument_list(t_struct* tstruct, bool name_params=true); + std::string type_to_enum(t_type* ttype); + std::string local_reflection_name(const char*, t_type* ttype, bool external=false); + + // These handles checking gen_dense_ and checking for duplicates. + void generate_local_reflection(std::ofstream& out, t_type* ttype, bool is_definition); + void generate_local_reflection_pointer(std::ofstream& out, t_type* ttype); + + bool is_complex_type(t_type* ttype) { + ttype = get_true_type(ttype); + + return + ttype->is_container() || + ttype->is_struct() || + ttype->is_xception() || + (ttype->is_base_type() && (((t_base_type*)ttype)->get_base() == t_base_type::TYPE_STRING)); + } + + void set_use_include_prefix(bool use_include_prefix) { + use_include_prefix_ = use_include_prefix; + } + + private: + /** + * Returns the include prefix to use for a file generated by program, or the + * empty string if no include prefix should be used. + */ + std::string get_include_prefix(const t_program& program) const; + + /** + * True iff we should generate local reflection metadata for TDenseProtocol. + */ + bool gen_dense_; + + /** + * True iff we should use a path prefix in our #include statements for other + * thrift-generated header files. + */ + bool use_include_prefix_; + + /** + * Strings for namespace, computed once up front then used directly + */ + + std::string ns_open_; + std::string ns_close_; + + /** + * File streams, stored here to avoid passing them as parameters to every + * function. + */ + + std::ofstream f_types_; + std::ofstream f_types_impl_; + std::ofstream f_header_; + std::ofstream f_service_; + + /** + * When generating local reflections, make sure we don't generate duplicates. + */ + std::set reflected_fingerprints_; +}; + + +/** + * Prepares for file generation by opening up the necessary file output + * streams. + * + * @param tprogram The program to generate + */ +void t_cpp_generator::init_generator() { + // Make output directory + MKDIR(get_out_dir().c_str()); + + // Make output file + string f_types_name = get_out_dir()+program_name_+"_types.h"; + f_types_.open(f_types_name.c_str()); + + string f_types_impl_name = get_out_dir()+program_name_+"_types.cpp"; + f_types_impl_.open(f_types_impl_name.c_str()); + + // Print header + f_types_ << + autogen_comment(); + f_types_impl_ << + autogen_comment(); + + // Start ifndef + f_types_ << + "#ifndef " << program_name_ << "_TYPES_H" << endl << + "#define " << program_name_ << "_TYPES_H" << endl << + endl; + + // Include base types + f_types_ << + "#include " << endl << + "#include " << endl << + "#include " << endl << + endl; + + // Include other Thrift includes + const vector& includes = program_->get_includes(); + for (size_t i = 0; i < includes.size(); ++i) { + f_types_ << + "#include \"" << get_include_prefix(*(includes[i])) << + includes[i]->get_name() << "_types.h\"" << endl; + } + f_types_ << endl; + + // Include custom headers + const vector& cpp_includes = program_->get_cpp_includes(); + for (size_t i = 0; i < cpp_includes.size(); ++i) { + if (cpp_includes[i][0] == '<') { + f_types_ << + "#include " << cpp_includes[i] << endl; + } else { + f_types_ << + "#include \"" << cpp_includes[i] << "\"" << endl; + } + } + f_types_ << + endl; + + // Include the types file + f_types_impl_ << + "#include \"" << get_include_prefix(*get_program()) << program_name_ << + "_types.h\"" << endl << + endl; + + // If we are generating local reflection metadata, we need to include + // the definition of TypeSpec. + if (gen_dense_) { + f_types_impl_ << + "#include " << endl << + endl; + } + + // Open namespace + ns_open_ = namespace_open(program_->get_namespace("cpp")); + ns_close_ = namespace_close(program_->get_namespace("cpp")); + + f_types_ << + ns_open_ << endl << + endl; + + f_types_impl_ << + ns_open_ << endl << + endl; +} + +/** + * Closes the output files. + */ +void t_cpp_generator::close_generator() { + // Close namespace + f_types_ << + ns_close_ << endl << + endl; + f_types_impl_ << + ns_close_ << endl; + + // Close ifndef + f_types_ << + "#endif" << endl; + + // Close output file + f_types_.close(); + f_types_impl_.close(); +} + +/** + * Generates a typedef. This is just a simple 1-liner in C++ + * + * @param ttypedef The type definition + */ +void t_cpp_generator::generate_typedef(t_typedef* ttypedef) { + f_types_ << + indent() << "typedef " << type_name(ttypedef->get_type(), true) << " " << ttypedef->get_symbolic() << ";" << endl << + endl; +} + +/** + * Generates code for an enumerated type. In C++, this is essentially the same + * as the thrift definition itself, using the enum keyword in C++. + * + * @param tenum The enumeration + */ +void t_cpp_generator::generate_enum(t_enum* tenum) { + f_types_ << + indent() << "enum " << tenum->get_name() << " {" << endl; + indent_up(); + + vector constants = tenum->get_constants(); + vector::iterator c_iter; + bool first = true; + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + if (first) { + first = false; + } else { + f_types_ << + "," << endl; + } + f_types_ << + indent() << (*c_iter)->get_name(); + if ((*c_iter)->has_value()) { + f_types_ << + " = " << (*c_iter)->get_value(); + } + } + + indent_down(); + f_types_ << + endl << + "};" << endl << + endl; + + generate_local_reflection(f_types_, tenum, false); + generate_local_reflection(f_types_impl_, tenum, true); +} + +/** + * Generates a class that holds all the constants. + */ +void t_cpp_generator::generate_consts(std::vector consts) { + string f_consts_name = get_out_dir()+program_name_+"_constants.h"; + ofstream f_consts; + f_consts.open(f_consts_name.c_str()); + + string f_consts_impl_name = get_out_dir()+program_name_+"_constants.cpp"; + ofstream f_consts_impl; + f_consts_impl.open(f_consts_impl_name.c_str()); + + // Print header + f_consts << + autogen_comment(); + f_consts_impl << + autogen_comment(); + + // Start ifndef + f_consts << + "#ifndef " << program_name_ << "_CONSTANTS_H" << endl << + "#define " << program_name_ << "_CONSTANTS_H" << endl << + endl << + "#include \"" << get_include_prefix(*get_program()) << program_name_ << + "_types.h\"" << endl << + endl << + ns_open_ << endl << + endl; + + f_consts_impl << + "#include \"" << get_include_prefix(*get_program()) << program_name_ << + "_constants.h\"" << endl << + endl << + ns_open_ << endl << + endl; + + f_consts << + "class " << program_name_ << "Constants {" << endl << + " public:" << endl << + " " << program_name_ << "Constants();" << endl << + endl; + indent_up(); + vector::iterator c_iter; + for (c_iter = consts.begin(); c_iter != consts.end(); ++c_iter) { + string name = (*c_iter)->get_name(); + t_type* type = (*c_iter)->get_type(); + f_consts << + indent() << type_name(type) << " " << name << ";" << endl; + } + indent_down(); + f_consts << + "};" << endl; + + f_consts_impl << + "const " << program_name_ << "Constants g_" << program_name_ << "_constants;" << endl << + endl << + program_name_ << "Constants::" << program_name_ << "Constants() {" << endl; + indent_up(); + for (c_iter = consts.begin(); c_iter != consts.end(); ++c_iter) { + print_const_value(f_consts_impl, + (*c_iter)->get_name(), + (*c_iter)->get_type(), + (*c_iter)->get_value()); + } + indent_down(); + indent(f_consts_impl) << + "}" << endl; + + f_consts << + endl << + "extern const " << program_name_ << "Constants g_" << program_name_ << "_constants;" << endl << + endl << + ns_close_ << endl << + endl << + "#endif" << endl; + f_consts.close(); + + f_consts_impl << + endl << + ns_close_ << endl << + endl; +} + +/** + * Prints the value of a constant with the given type. Note that type checking + * is NOT performed in this function as it is always run beforehand using the + * validate_types method in main.cc + */ +void t_cpp_generator::print_const_value(ofstream& out, string name, t_type* type, t_const_value* value) { + type = get_true_type(type); + if (type->is_base_type()) { + string v2 = render_const_value(out, name, type, value); + indent(out) << name << " = " << v2 << ";" << endl << + endl; + } else if (type->is_enum()) { + indent(out) << name << " = (" << type_name(type) << ")" << value->get_integer() << ";" << endl << + endl; + } else if (type->is_struct() || type->is_xception()) { + const vector& fields = ((t_struct*)type)->get_members(); + vector::const_iterator f_iter; + const map& val = value->get_map(); + map::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + t_type* field_type = NULL; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if ((*f_iter)->get_name() == v_iter->first->get_string()) { + field_type = (*f_iter)->get_type(); + } + } + if (field_type == NULL) { + throw "type error: " + type->get_name() + " has no field " + v_iter->first->get_string(); + } + string val = render_const_value(out, name, field_type, v_iter->second); + indent(out) << name << "." << v_iter->first->get_string() << " = " << val << ";" << endl; + indent(out) << name << ".__isset." << v_iter->first->get_string() << " = true;" << endl; + } + out << endl; + } else if (type->is_map()) { + t_type* ktype = ((t_map*)type)->get_key_type(); + t_type* vtype = ((t_map*)type)->get_val_type(); + const map& val = value->get_map(); + map::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + string key = render_const_value(out, name, ktype, v_iter->first); + string val = render_const_value(out, name, vtype, v_iter->second); + indent(out) << name << ".insert(std::make_pair(" << key << ", " << val << "));" << endl; + } + out << endl; + } else if (type->is_list()) { + t_type* etype = ((t_list*)type)->get_elem_type(); + const vector& val = value->get_list(); + vector::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + string val = render_const_value(out, name, etype, *v_iter); + indent(out) << name << ".push_back(" << val << ");" << endl; + } + out << endl; + } else if (type->is_set()) { + t_type* etype = ((t_set*)type)->get_elem_type(); + const vector& val = value->get_list(); + vector::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + string val = render_const_value(out, name, etype, *v_iter); + indent(out) << name << ".insert(" << val << ");" << endl; + } + out << endl; + } else { + throw "INVALID TYPE IN print_const_value: " + type->get_name(); + } +} + +/** + * + */ +string t_cpp_generator::render_const_value(ofstream& out, string name, t_type* type, t_const_value* value) { + std::ostringstream render; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_STRING: + render << '"' << get_escaped_string(value) << '"'; + break; + case t_base_type::TYPE_BOOL: + render << ((value->get_integer() > 0) ? "true" : "false"); + break; + case t_base_type::TYPE_BYTE: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + render << value->get_integer(); + break; + case t_base_type::TYPE_I64: + render << value->get_integer() << "LL"; + break; + case t_base_type::TYPE_DOUBLE: + if (value->get_type() == t_const_value::CV_INTEGER) { + render << value->get_integer(); + } else { + render << value->get_double(); + } + break; + default: + throw "compiler error: no const of base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + render << "(" << type_name(type) << ")" << value->get_integer(); + } else { + string t = tmp("tmp"); + indent(out) << type_name(type) << " " << t << ";" << endl; + print_const_value(out, t, type, value); + render << t; + } + + return render.str(); +} + +/** + * Generates a struct definition for a thrift data type. This is a class + * with data members and a read/write() function, plus a mirroring isset + * inner class. + * + * @param tstruct The struct definition + */ +void t_cpp_generator::generate_cpp_struct(t_struct* tstruct, bool is_exception) { + generate_struct_definition(f_types_, tstruct, is_exception); + generate_struct_fingerprint(f_types_impl_, tstruct, true); + generate_local_reflection(f_types_, tstruct, false); + generate_local_reflection(f_types_impl_, tstruct, true); + generate_local_reflection_pointer(f_types_impl_, tstruct); + generate_struct_reader(f_types_impl_, tstruct); + generate_struct_writer(f_types_impl_, tstruct); +} + +/** + * Writes the struct definition into the header file + * + * @param out Output stream + * @param tstruct The struct + */ +void t_cpp_generator::generate_struct_definition(ofstream& out, + t_struct* tstruct, + bool is_exception, + bool pointers, + bool read, + bool write) { + string extends = ""; + if (is_exception) { + extends = " : public apache::thrift::TException"; + } + + // Open struct def + out << + indent() << "class " << tstruct->get_name() << extends << " {" << endl << + indent() << " public:" << endl << + endl; + indent_up(); + + // Put the fingerprint up top for all to see. + generate_struct_fingerprint(out, tstruct, false); + + // Get members + vector::const_iterator m_iter; + const vector& members = tstruct->get_members(); + + if (!pointers) { + // Default constructor + indent(out) << + tstruct->get_name() << "()"; + + bool init_ctor = false; + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + t_type* t = get_true_type((*m_iter)->get_type()); + if (t->is_base_type()) { + string dval; + if (t->is_enum()) { + dval += "(" + type_name(t) + ")"; + } + dval += t->is_string() ? "\"\"" : "0"; + t_const_value* cv = (*m_iter)->get_value(); + if (cv != NULL) { + dval = render_const_value(out, (*m_iter)->get_name(), t, cv); + } + if (!init_ctor) { + init_ctor = true; + out << " : "; + out << (*m_iter)->get_name() << "(" << dval << ")"; + } else { + out << ", " << (*m_iter)->get_name() << "(" << dval << ")"; + } + } + } + out << " {" << endl; + indent_up(); + // TODO(dreiss): When everything else in Thrift is perfect, + // do more of these in the initializer list. + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + t_type* t = get_true_type((*m_iter)->get_type()); + + if (!t->is_base_type()) { + t_const_value* cv = (*m_iter)->get_value(); + if (cv != NULL) { + print_const_value(out, (*m_iter)->get_name(), t, cv); + } + } + } + scope_down(out); + } + + if (tstruct->annotations_.find("final") == tstruct->annotations_.end()) { + out << + endl << + indent() << "virtual ~" << tstruct->get_name() << "() throw() {}" << endl << endl; + } + + // Pointer to this structure's reflection local typespec. + if (gen_dense_) { + indent(out) << + "static apache::thrift::reflection::local::TypeSpec* local_reflection;" << + endl << endl; + } + + // Declare all fields + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + indent(out) << + declare_field(*m_iter, false, pointers && !(*m_iter)->get_type()->is_xception(), !read) << endl; + } + + // Isset struct has boolean fields, but only for non-required fields. + bool has_nonrequired_fields = false; + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + if ((*m_iter)->get_req() != t_field::T_REQUIRED) + has_nonrequired_fields = true; + } + + if (has_nonrequired_fields && (!pointers || read)) { + out << + endl << + indent() << "struct __isset {" << endl; + indent_up(); + + indent(out) << + "__isset() : "; + bool first = true; + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + if ((*m_iter)->get_req() == t_field::T_REQUIRED) { + continue; + } + if (first) { + first = false; + out << + (*m_iter)->get_name() << "(false)"; + } else { + out << + ", " << (*m_iter)->get_name() << "(false)"; + } + } + out << " {}" << endl; + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + if ((*m_iter)->get_req() != t_field::T_REQUIRED) { + indent(out) << + "bool " << (*m_iter)->get_name() << ";" << endl; + } + } + + indent_down(); + indent(out) << + "} __isset;" << endl; + } + + out << endl; + + if (!pointers) { + // Generate an equality testing operator. Make it inline since the compiler + // will do a better job than we would when deciding whether to inline it. + out << + indent() << "bool operator == (const " << tstruct->get_name() << " & " << + (members.size() > 0 ? "rhs" : "/* rhs */") << ") const" << endl; + scope_up(out); + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + // Most existing Thrift code does not use isset or optional/required, + // so we treat "default" fields as required. + if ((*m_iter)->get_req() != t_field::T_OPTIONAL) { + out << + indent() << "if (!(" << (*m_iter)->get_name() + << " == rhs." << (*m_iter)->get_name() << "))" << endl << + indent() << " return false;" << endl; + } else { + out << + indent() << "if (__isset." << (*m_iter)->get_name() + << " != rhs.__isset." << (*m_iter)->get_name() << ")" << endl << + indent() << " return false;" << endl << + indent() << "else if (__isset." << (*m_iter)->get_name() << " && !(" + << (*m_iter)->get_name() << " == rhs." << (*m_iter)->get_name() + << "))" << endl << + indent() << " return false;" << endl; + } + } + indent(out) << "return true;" << endl; + scope_down(out); + out << + indent() << "bool operator != (const " << tstruct->get_name() << " &rhs) const {" << endl << + indent() << " return !(*this == rhs);" << endl << + indent() << "}" << endl << endl; + + // Generate the declaration of a less-than operator. This must be + // implemented by the application developer if they wish to use it. (They + // will get a link error if they try to use it without an implementation.) + out << + indent() << "bool operator < (const " + << tstruct->get_name() << " & ) const;" << endl << endl; + } + if (read) { + out << + indent() << "uint32_t read(apache::thrift::protocol::TProtocol* iprot);" << endl; + } + if (write) { + out << + indent() << "uint32_t write(apache::thrift::protocol::TProtocol* oprot) const;" << endl; + } + out << endl; + + indent_down(); + indent(out) << + "};" << endl << + endl; +} + +/** + * Writes the fingerprint of a struct to either the header or implementation. + * + * @param out Output stream + * @param tstruct The struct + */ +void t_cpp_generator::generate_struct_fingerprint(ofstream& out, + t_struct* tstruct, + bool is_definition) { + string stat, nspace, comment; + if (is_definition) { + stat = ""; + nspace = tstruct->get_name() + "::"; + comment = " "; + } else { + stat = "static "; + nspace = ""; + comment = "; // "; + } + + if (tstruct->has_fingerprint()) { + out << + indent() << stat << "const char* " << nspace + << "ascii_fingerprint" << comment << "= \"" << + tstruct->get_ascii_fingerprint() << "\";" << endl << + indent() << stat << "const uint8_t " << nspace << + "binary_fingerprint[" << t_type::fingerprint_len << "]" << comment << "= {"; + const char* comma = ""; + for (int i = 0; i < t_type::fingerprint_len; i++) { + out << comma << "0x" << t_struct::byte_to_hex(tstruct->get_binary_fingerprint()[i]); + comma = ","; + } + out << "};" << endl << endl; + } +} + +/** + * Writes the local reflection of a type (either declaration or definition). + */ +void t_cpp_generator::generate_local_reflection(std::ofstream& out, + t_type* ttype, + bool is_definition) { + if (!gen_dense_) { + return; + } + ttype = get_true_type(ttype); + assert(ttype->has_fingerprint()); + string key = ttype->get_ascii_fingerprint() + (is_definition ? "-defn" : "-decl"); + // Note that we have generated this fingerprint. If we already did, bail out. + if (!reflected_fingerprints_.insert(key).second) { + return; + } + // Let each program handle its own structures. + if (ttype->get_program() != NULL && ttype->get_program() != program_) { + return; + } + + // Do dependencies. + if (ttype->is_list()) { + generate_local_reflection(out, ((t_list*)ttype)->get_elem_type(), is_definition); + } else if (ttype->is_set()) { + generate_local_reflection(out, ((t_set*)ttype)->get_elem_type(), is_definition); + } else if (ttype->is_map()) { + generate_local_reflection(out, ((t_map*)ttype)->get_key_type(), is_definition); + generate_local_reflection(out, ((t_map*)ttype)->get_val_type(), is_definition); + } else if (ttype->is_struct() || ttype->is_xception()) { + // Hacky hacky. For efficiency and convenience, we need a dummy "T_STOP" + // type at the end of our typespec array. Unfortunately, there is no + // T_STOP type, so we use the global void type, and special case it when + // generating its typespec. + + const vector& members = ((t_struct*)ttype)->get_sorted_members(); + vector::const_iterator m_iter; + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + generate_local_reflection(out, (**m_iter).get_type(), is_definition); + } + generate_local_reflection(out, g_type_void, is_definition); + + // For definitions of structures, do the arrays of metas and field specs also. + if (is_definition) { + out << + indent() << "apache::thrift::reflection::local::FieldMeta" << endl << + indent() << local_reflection_name("metas", ttype) <<"[] = {" << endl; + indent_up(); + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + indent(out) << "{ " << (*m_iter)->get_key() << ", " << + (((*m_iter)->get_req() == t_field::T_OPTIONAL) ? "true" : "false") << + " }," << endl; + } + // Zero for the T_STOP marker. + indent(out) << "{ 0, false }" << endl << "};" << endl; + indent_down(); + + out << + indent() << "apache::thrift::reflection::local::TypeSpec*" << endl << + indent() << local_reflection_name("specs", ttype) <<"[] = {" << endl; + indent_up(); + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + indent(out) << "&" << + local_reflection_name("typespec", (*m_iter)->get_type(), true) << "," << endl; + } + indent(out) << "&" << + local_reflection_name("typespec", g_type_void) << "," << endl; + indent_down(); + indent(out) << "};" << endl; + } + } + + out << + indent() << "// " << ttype->get_fingerprint_material() << endl << + indent() << (is_definition ? "" : "extern ") << + "apache::thrift::reflection::local::TypeSpec" << endl << + local_reflection_name("typespec", ttype) << + (is_definition ? "(" : ";") << endl; + + if (!is_definition) { + out << endl; + return; + } + + indent_up(); + + if (ttype->is_void()) { + indent(out) << "apache::thrift::protocol::T_STOP"; + } else { + indent(out) << type_to_enum(ttype); + } + + if (ttype->is_struct()) { + out << "," << endl << + indent() << type_name(ttype) << "::binary_fingerprint," << endl << + indent() << local_reflection_name("metas", ttype) << "," << endl << + indent() << local_reflection_name("specs", ttype); + } else if (ttype->is_list()) { + out << "," << endl << + indent() << "&" << local_reflection_name("typespec", ((t_list*)ttype)->get_elem_type()) << "," << endl << + indent() << "NULL"; + } else if (ttype->is_set()) { + out << "," << endl << + indent() << "&" << local_reflection_name("typespec", ((t_set*)ttype)->get_elem_type()) << "," << endl << + indent() << "NULL"; + } else if (ttype->is_map()) { + out << "," << endl << + indent() << "&" << local_reflection_name("typespec", ((t_map*)ttype)->get_key_type()) << "," << endl << + indent() << "&" << local_reflection_name("typespec", ((t_map*)ttype)->get_val_type()); + } + + out << ");" << endl << endl; + + indent_down(); +} + +/** + * Writes the structure's static pointer to its local reflection typespec + * into the implementation file. + */ +void t_cpp_generator::generate_local_reflection_pointer(std::ofstream& out, + t_type* ttype) { + if (!gen_dense_) { + return; + } + indent(out) << + "apache::thrift::reflection::local::TypeSpec* " << + ttype->get_name() << "::local_reflection = " << endl << + indent() << " &" << local_reflection_name("typespec", ttype) << ";" << + endl << endl; +} + +/** + * Makes a helper function to gen a struct reader. + * + * @param out Stream to write to + * @param tstruct The struct + */ +void t_cpp_generator::generate_struct_reader(ofstream& out, + t_struct* tstruct, + bool pointers) { + indent(out) << + "uint32_t " << tstruct->get_name() << "::read(apache::thrift::protocol::TProtocol* iprot) {" << endl; + indent_up(); + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + + // Declare stack tmp variables + out << + endl << + indent() << "uint32_t xfer = 0;" << endl << + indent() << "std::string fname;" << endl << + indent() << "apache::thrift::protocol::TType ftype;" << endl << + indent() << "int16_t fid;" << endl << + endl << + indent() << "xfer += iprot->readStructBegin(fname);" << endl << + endl << + indent() << "using apache::thrift::protocol::TProtocolException;" << endl << + endl; + + // Required variables aren't in __isset, so we need tmp vars to check them. + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if ((*f_iter)->get_req() == t_field::T_REQUIRED) + indent(out) << "bool isset_" << (*f_iter)->get_name() << " = false;" << endl; + } + out << endl; + + + // Loop over reading in fields + indent(out) << + "while (true)" << endl; + scope_up(out); + + // Read beginning field marker + indent(out) << + "xfer += iprot->readFieldBegin(fname, ftype, fid);" << endl; + + // Check for field STOP marker + out << + indent() << "if (ftype == apache::thrift::protocol::T_STOP) {" << endl << + indent() << " break;" << endl << + indent() << "}" << endl; + + // Switch statement on the field we are reading + indent(out) << + "switch (fid)" << endl; + + scope_up(out); + + // Generate deserialization code for known cases + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + indent(out) << + "case " << (*f_iter)->get_key() << ":" << endl; + indent_up(); + indent(out) << + "if (ftype == " << type_to_enum((*f_iter)->get_type()) << ") {" << endl; + indent_up(); + + const char *isset_prefix = + ((*f_iter)->get_req() != t_field::T_REQUIRED) ? "this->__isset." : "isset_"; + +#if 0 + // This code throws an exception if the same field is encountered twice. + // We've decided to leave it out for performance reasons. + // TODO(dreiss): Generate this code and "if" it out to make it easier + // for people recompiling thrift to include it. + out << + indent() << "if (" << isset_prefix << (*f_iter)->get_name() << ")" << endl << + indent() << " throw TProtocolException(TProtocolException::INVALID_DATA);" << endl; +#endif + + if (pointers && !(*f_iter)->get_type()->is_xception()) { + generate_deserialize_field(out, *f_iter, "(*(this->", "))"); + } else { + generate_deserialize_field(out, *f_iter, "this->"); + } + out << + indent() << isset_prefix << (*f_iter)->get_name() << " = true;" << endl; + indent_down(); + out << + indent() << "} else {" << endl << + indent() << " xfer += iprot->skip(ftype);" << endl << + // TODO(dreiss): Make this an option when thrift structs + // have a common base class. + // indent() << " throw TProtocolException(TProtocolException::INVALID_DATA);" << endl << + indent() << "}" << endl << + indent() << "break;" << endl; + indent_down(); + } + + // In the default case we skip the field + out << + indent() << "default:" << endl << + indent() << " xfer += iprot->skip(ftype);" << endl << + indent() << " break;" << endl; + + scope_down(out); + + // Read field end marker + indent(out) << + "xfer += iprot->readFieldEnd();" << endl; + + scope_down(out); + + out << + endl << + indent() << "xfer += iprot->readStructEnd();" << endl; + + // Throw if any required fields are missing. + // We do this after reading the struct end so that + // there might possibly be a chance of continuing. + out << endl; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if ((*f_iter)->get_req() == t_field::T_REQUIRED) + out << + indent() << "if (!isset_" << (*f_iter)->get_name() << ')' << endl << + indent() << " throw TProtocolException(TProtocolException::INVALID_DATA);" << endl; + } + + indent(out) << "return xfer;" << endl; + + indent_down(); + indent(out) << + "}" << endl << endl; +} + +/** + * Generates the write function. + * + * @param out Stream to write to + * @param tstruct The struct + */ +void t_cpp_generator::generate_struct_writer(ofstream& out, + t_struct* tstruct, + bool pointers) { + string name = tstruct->get_name(); + const vector& fields = tstruct->get_sorted_members(); + vector::const_iterator f_iter; + + indent(out) << + "uint32_t " << tstruct->get_name() << "::write(apache::thrift::protocol::TProtocol* oprot) const {" << endl; + indent_up(); + + out << + indent() << "uint32_t xfer = 0;" << endl; + + indent(out) << + "xfer += oprot->writeStructBegin(\"" << name << "\");" << endl; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if ((*f_iter)->get_req() == t_field::T_OPTIONAL) { + indent(out) << "if (this->__isset." << (*f_iter)->get_name() << ") {" << endl; + indent_up(); + } + // Write field header + out << + indent() << "xfer += oprot->writeFieldBegin(" << + "\"" << (*f_iter)->get_name() << "\", " << + type_to_enum((*f_iter)->get_type()) << ", " << + (*f_iter)->get_key() << ");" << endl; + // Write field contents + if (pointers) { + generate_serialize_field(out, *f_iter, "(*(this->", "))"); + } else { + generate_serialize_field(out, *f_iter, "this->"); + } + // Write field closer + indent(out) << + "xfer += oprot->writeFieldEnd();" << endl; + if ((*f_iter)->get_req() == t_field::T_OPTIONAL) { + indent_down(); + indent(out) << '}' << endl; + } + } + + // Write the struct map + out << + indent() << "xfer += oprot->writeFieldStop();" << endl << + indent() << "xfer += oprot->writeStructEnd();" << endl << + indent() << "return xfer;" << endl; + + indent_down(); + indent(out) << + "}" << endl << + endl; +} + +/** + * Struct writer for result of a function, which can have only one of its + * fields set and does a conditional if else look up into the __isset field + * of the struct. + * + * @param out Output stream + * @param tstruct The result struct + */ +void t_cpp_generator::generate_struct_result_writer(ofstream& out, + t_struct* tstruct, + bool pointers) { + string name = tstruct->get_name(); + const vector& fields = tstruct->get_sorted_members(); + vector::const_iterator f_iter; + + indent(out) << + "uint32_t " << tstruct->get_name() << "::write(apache::thrift::protocol::TProtocol* oprot) const {" << endl; + indent_up(); + + out << + endl << + indent() << "uint32_t xfer = 0;" << endl << + endl; + + indent(out) << + "xfer += oprot->writeStructBegin(\"" << name << "\");" << endl; + + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + out << + endl << + indent() << "if "; + } else { + out << + " else if "; + } + + out << "(this->__isset." << (*f_iter)->get_name() << ") {" << endl; + + indent_up(); + + // Write field header + out << + indent() << "xfer += oprot->writeFieldBegin(" << + "\"" << (*f_iter)->get_name() << "\", " << + type_to_enum((*f_iter)->get_type()) << ", " << + (*f_iter)->get_key() << ");" << endl; + // Write field contents + if (pointers) { + generate_serialize_field(out, *f_iter, "(*(this->", "))"); + } else { + generate_serialize_field(out, *f_iter, "this->"); + } + // Write field closer + indent(out) << "xfer += oprot->writeFieldEnd();" << endl; + + indent_down(); + indent(out) << "}"; + } + + // Write the struct map + out << + endl << + indent() << "xfer += oprot->writeFieldStop();" << endl << + indent() << "xfer += oprot->writeStructEnd();" << endl << + indent() << "return xfer;" << endl; + + indent_down(); + indent(out) << + "}" << endl << + endl; +} + +/** + * Generates a thrift service. In C++, this comprises an entirely separate + * header and source file. The header file defines the methods and includes + * the data types defined in the main header file, and the implementation + * file contains implementations of the basic printer and default interfaces. + * + * @param tservice The service definition + */ +void t_cpp_generator::generate_service(t_service* tservice) { + string svcname = tservice->get_name(); + + // Make output files + string f_header_name = get_out_dir()+svcname+".h"; + f_header_.open(f_header_name.c_str()); + + // Print header file includes + f_header_ << + autogen_comment(); + f_header_ << + "#ifndef " << svcname << "_H" << endl << + "#define " << svcname << "_H" << endl << + endl << + "#include " << endl << + "#include \"" << get_include_prefix(*get_program()) << program_name_ << + "_types.h\"" << endl; + + t_service* extends_service = tservice->get_extends(); + if (extends_service != NULL) { + f_header_ << + "#include \"" << get_include_prefix(*(extends_service->get_program())) << + extends_service->get_name() << ".h\"" << endl; + } + + f_header_ << + endl << + ns_open_ << endl << + endl; + + // Service implementation file includes + string f_service_name = get_out_dir()+svcname+".cpp"; + f_service_.open(f_service_name.c_str()); + f_service_ << + autogen_comment(); + f_service_ << + "#include \"" << get_include_prefix(*get_program()) << svcname << ".h\"" << + endl << + endl << + ns_open_ << endl << + endl; + + // Generate all the components + generate_service_interface(tservice); + generate_service_null(tservice); + generate_service_helpers(tservice); + generate_service_client(tservice); + generate_service_processor(tservice); + generate_service_multiface(tservice); + generate_service_skeleton(tservice); + + // Close the namespace + f_service_ << + ns_close_ << endl << + endl; + f_header_ << + ns_close_ << endl << + endl; + f_header_ << + "#endif" << endl; + + // Close the files + f_service_.close(); + f_header_.close(); +} + +/** + * Generates helper functions for a service. Basically, this generates types + * for all the arguments and results to functions. + * + * @param tservice The service to generate a header definition for + */ +void t_cpp_generator::generate_service_helpers(t_service* tservice) { + vector functions = tservice->get_functions(); + vector::iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + t_struct* ts = (*f_iter)->get_arglist(); + string name_orig = ts->get_name(); + + ts->set_name(tservice->get_name() + "_" + (*f_iter)->get_name() + "_args"); + generate_struct_definition(f_header_, ts, false); + generate_struct_reader(f_service_, ts); + generate_struct_writer(f_service_, ts); + ts->set_name(tservice->get_name() + "_" + (*f_iter)->get_name() + "_pargs"); + generate_struct_definition(f_header_, ts, false, true, false, true); + generate_struct_writer(f_service_, ts, true); + ts->set_name(name_orig); + + generate_function_helpers(tservice, *f_iter); + } +} + +/** + * Generates a service interface definition. + * + * @param tservice The service to generate a header definition for + */ +void t_cpp_generator::generate_service_interface(t_service* tservice) { + string extends = ""; + if (tservice->get_extends() != NULL) { + extends = " : virtual public " + type_name(tservice->get_extends()) + "If"; + } + f_header_ << + "class " << service_name_ << "If" << extends << " {" << endl << + " public:" << endl; + indent_up(); + f_header_ << + indent() << "virtual ~" << service_name_ << "If() {}" << endl; + + vector functions = tservice->get_functions(); + vector::iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + f_header_ << + indent() << "virtual " << function_signature(*f_iter) << " = 0;" << endl; + } + indent_down(); + f_header_ << + "};" << endl << endl; +} + +/** + * Generates a null implementation of the service. + * + * @param tservice The service to generate a header definition for + */ +void t_cpp_generator::generate_service_null(t_service* tservice) { + string extends = ""; + if (tservice->get_extends() != NULL) { + extends = " , virtual public " + type_name(tservice->get_extends()) + "Null"; + } + f_header_ << + "class " << service_name_ << "Null : virtual public " << service_name_ << "If" << extends << " {" << endl << + " public:" << endl; + indent_up(); + f_header_ << + indent() << "virtual ~" << service_name_ << "Null() {}" << endl; + vector functions = tservice->get_functions(); + vector::iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + f_header_ << + indent() << function_signature(*f_iter, "", false) << " {" << endl; + indent_up(); + t_type* returntype = (*f_iter)->get_returntype(); + if (returntype->is_void()) { + f_header_ << + indent() << "return;" << endl; + } else if (is_complex_type(returntype)) { + f_header_ << + indent() << "return;" << endl; + } else { + t_field returnfield(returntype, "_return"); + f_header_ << + indent() << declare_field(&returnfield, true) << endl << + indent() << "return _return;" << endl; + } + indent_down(); + f_header_ << + indent() << "}" << endl; + } + indent_down(); + f_header_ << + "};" << endl << endl; +} + + +/** + * Generates a multiface, which is a single server that just takes a set + * of objects implementing the interface and calls them all, returning the + * value of the last one to be called. + * + * @param tservice The service to generate a multiserver for. + */ +void t_cpp_generator::generate_service_multiface(t_service* tservice) { + // Generate the dispatch methods + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + string extends = ""; + string extends_multiface = ""; + if (tservice->get_extends() != NULL) { + extends = type_name(tservice->get_extends()); + extends_multiface = ", public " + extends + "Multiface"; + } + + string list_type = string("std::vector >"; + + // Generate the header portion + f_header_ << + "class " << service_name_ << "Multiface : " << + "virtual public " << service_name_ << "If" << + extends_multiface << " {" << endl << + " public:" << endl; + indent_up(); + f_header_ << + indent() << service_name_ << "Multiface(" << list_type << "& ifaces) : ifaces_(ifaces) {" << endl; + if (!extends.empty()) { + f_header_ << + indent() << " std::vector >::iterator iter;" << endl << + indent() << " for (iter = ifaces.begin(); iter != ifaces.end(); ++iter) {" << endl << + indent() << " " << extends << "Multiface::add(*iter);" << endl << + indent() << " }" << endl; + } + f_header_ << + indent() << "}" << endl << + indent() << "virtual ~" << service_name_ << "Multiface() {}" << endl; + indent_down(); + + // Protected data members + f_header_ << + " protected:" << endl; + indent_up(); + f_header_ << + indent() << list_type << " ifaces_;" << endl << + indent() << service_name_ << "Multiface() {}" << endl << + indent() << "void add(boost::shared_ptr<" << service_name_ << "If> iface) {" << endl; + if (!extends.empty()) { + f_header_ << + indent() << " " << extends << "Multiface::add(iface);" << endl; + } + f_header_ << + indent() << " ifaces_.push_back(iface);" << endl << + indent() << "}" << endl; + indent_down(); + + f_header_ << + indent() << " public:" << endl; + indent_up(); + + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + t_struct* arglist = (*f_iter)->get_arglist(); + const vector& args = arglist->get_members(); + vector::const_iterator a_iter; + + string call = string("ifaces_[i]->") + (*f_iter)->get_name() + "("; + bool first = true; + if (is_complex_type((*f_iter)->get_returntype())) { + call += "_return"; + first = false; + } + for (a_iter = args.begin(); a_iter != args.end(); ++a_iter) { + if (first) { + first = false; + } else { + call += ", "; + } + call += (*a_iter)->get_name(); + } + call += ")"; + + f_header_ << + indent() << function_signature(*f_iter) << " {" << endl; + indent_up(); + f_header_ << + indent() << "uint32_t sz = ifaces_.size();" << endl << + indent() << "for (uint32_t i = 0; i < sz; ++i) {" << endl; + if (!(*f_iter)->get_returntype()->is_void()) { + f_header_ << + indent() << " if (i == sz - 1) {" << endl; + if (is_complex_type((*f_iter)->get_returntype())) { + f_header_ << + indent() << " " << call << ";" << endl << + indent() << " return;" << endl; + } else { + f_header_ << + indent() << " return " << call << ";" << endl; + } + f_header_ << + indent() << " } else {" << endl << + indent() << " " << call << ";" << endl << + indent() << " }" << endl; + } else { + f_header_ << + indent() << " " << call << ";" << endl; + } + + f_header_ << + indent() << "}" << endl; + + indent_down(); + f_header_ << + indent() << "}" << endl << + endl; + } + + indent_down(); + f_header_ << + indent() << "};" << endl << + endl; +} + +/** + * Generates a service client definition. + * + * @param tservice The service to generate a server for. + */ +void t_cpp_generator::generate_service_client(t_service* tservice) { + string extends = ""; + string extends_client = ""; + if (tservice->get_extends() != NULL) { + extends = type_name(tservice->get_extends()); + extends_client = ", public " + extends + "Client"; + } + + // Generate the header portion + f_header_ << + "class " << service_name_ << "Client : " << + "virtual public " << service_name_ << "If" << + extends_client << " {" << endl << + " public:" << endl; + + indent_up(); + f_header_ << + indent() << service_name_ << "Client(boost::shared_ptr prot) :" << endl; + if (extends.empty()) { + f_header_ << + indent() << " piprot_(prot)," << endl << + indent() << " poprot_(prot) {" << endl << + indent() << " iprot_ = prot.get();" << endl << + indent() << " oprot_ = prot.get();" << endl << + indent() << "}" << endl; + } else { + f_header_ << + indent() << " " << extends << "Client(prot, prot) {}" << endl; + } + + f_header_ << + indent() << service_name_ << "Client(boost::shared_ptr iprot, boost::shared_ptr oprot) :" << endl; + if (extends.empty()) { + f_header_ << + indent() << " piprot_(iprot)," << endl << + indent() << " poprot_(oprot) {" << endl << + indent() << " iprot_ = iprot.get();" << endl << + indent() << " oprot_ = oprot.get();" << endl << + indent() << "}" << endl; + } else { + f_header_ << + indent() << " " << extends << "Client(iprot, oprot) {}" << endl; + } + + // Generate getters for the protocols. + f_header_ << + indent() << "boost::shared_ptr getInputProtocol() {" << endl << + indent() << " return piprot_;" << endl << + indent() << "}" << endl; + + f_header_ << + indent() << "boost::shared_ptr getOutputProtocol() {" << endl << + indent() << " return poprot_;" << endl << + indent() << "}" << endl; + + vector functions = tservice->get_functions(); + vector::const_iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + t_function send_function(g_type_void, + string("send_") + (*f_iter)->get_name(), + (*f_iter)->get_arglist()); + indent(f_header_) << function_signature(*f_iter) << ";" << endl; + indent(f_header_) << function_signature(&send_function) << ";" << endl; + if (!(*f_iter)->is_oneway()) { + t_struct noargs(program_); + t_function recv_function((*f_iter)->get_returntype(), + string("recv_") + (*f_iter)->get_name(), + &noargs); + indent(f_header_) << function_signature(&recv_function) << ";" << endl; + } + } + indent_down(); + + if (extends.empty()) { + f_header_ << + " protected:" << endl; + indent_up(); + f_header_ << + indent() << "boost::shared_ptr piprot_;" << endl << + indent() << "boost::shared_ptr poprot_;" << endl << + indent() << "apache::thrift::protocol::TProtocol* iprot_;" << endl << + indent() << "apache::thrift::protocol::TProtocol* oprot_;" << endl; + indent_down(); + } + + f_header_ << + "};" << endl << + endl; + + string scope = service_name_ + "Client::"; + + // Generate client method implementations + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + string funname = (*f_iter)->get_name(); + + // Open function + indent(f_service_) << + function_signature(*f_iter, scope) << endl; + scope_up(f_service_); + indent(f_service_) << + "send_" << funname << "("; + + // Get the struct of function call params + t_struct* arg_struct = (*f_iter)->get_arglist(); + + // Declare the function arguments + const vector& fields = arg_struct->get_members(); + vector::const_iterator fld_iter; + bool first = true; + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + if (first) { + first = false; + } else { + f_service_ << ", "; + } + f_service_ << (*fld_iter)->get_name(); + } + f_service_ << ");" << endl; + + if (!(*f_iter)->is_oneway()) { + f_service_ << indent(); + if (!(*f_iter)->get_returntype()->is_void()) { + if (is_complex_type((*f_iter)->get_returntype())) { + f_service_ << "recv_" << funname << "(_return);" << endl; + } else { + f_service_ << "return recv_" << funname << "();" << endl; + } + } else { + f_service_ << + "recv_" << funname << "();" << endl; + } + } + scope_down(f_service_); + f_service_ << endl; + + // Function for sending + t_function send_function(g_type_void, + string("send_") + (*f_iter)->get_name(), + (*f_iter)->get_arglist()); + + // Open the send function + indent(f_service_) << + function_signature(&send_function, scope) << endl; + scope_up(f_service_); + + // Function arguments and results + string argsname = tservice->get_name() + "_" + (*f_iter)->get_name() + "_pargs"; + string resultname = tservice->get_name() + "_" + (*f_iter)->get_name() + "_presult"; + + // Serialize the request + f_service_ << + indent() << "int32_t cseqid = 0;" << endl << + indent() << "oprot_->writeMessageBegin(\"" << (*f_iter)->get_name() << "\", apache::thrift::protocol::T_CALL, cseqid);" << endl << + endl << + indent() << argsname << " args;" << endl; + + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + f_service_ << + indent() << "args." << (*fld_iter)->get_name() << " = &" << (*fld_iter)->get_name() << ";" << endl; + } + + f_service_ << + indent() << "args.write(oprot_);" << endl << + endl << + indent() << "oprot_->writeMessageEnd();" << endl << + indent() << "oprot_->getTransport()->flush();" << endl << + indent() << "oprot_->getTransport()->writeEnd();" << endl; + + scope_down(f_service_); + f_service_ << endl; + + // Generate recv function only if not an oneway function + if (!(*f_iter)->is_oneway()) { + t_struct noargs(program_); + t_function recv_function((*f_iter)->get_returntype(), + string("recv_") + (*f_iter)->get_name(), + &noargs); + // Open function + indent(f_service_) << + function_signature(&recv_function, scope) << endl; + scope_up(f_service_); + + f_service_ << + endl << + indent() << "int32_t rseqid = 0;" << endl << + indent() << "std::string fname;" << endl << + indent() << "apache::thrift::protocol::TMessageType mtype;" << endl << + endl << + indent() << "iprot_->readMessageBegin(fname, mtype, rseqid);" << endl << + indent() << "if (mtype == apache::thrift::protocol::T_EXCEPTION) {" << endl << + indent() << " apache::thrift::TApplicationException x;" << endl << + indent() << " x.read(iprot_);" << endl << + indent() << " iprot_->readMessageEnd();" << endl << + indent() << " iprot_->getTransport()->readEnd();" << endl << + indent() << " throw x;" << endl << + indent() << "}" << endl << + indent() << "if (mtype != apache::thrift::protocol::T_REPLY) {" << endl << + indent() << " iprot_->skip(apache::thrift::protocol::T_STRUCT);" << endl << + indent() << " iprot_->readMessageEnd();" << endl << + indent() << " iprot_->getTransport()->readEnd();" << endl << + indent() << " throw apache::thrift::TApplicationException(apache::thrift::TApplicationException::INVALID_MESSAGE_TYPE);" << endl << + indent() << "}" << endl << + indent() << "if (fname.compare(\"" << (*f_iter)->get_name() << "\") != 0) {" << endl << + indent() << " iprot_->skip(apache::thrift::protocol::T_STRUCT);" << endl << + indent() << " iprot_->readMessageEnd();" << endl << + indent() << " iprot_->getTransport()->readEnd();" << endl << + indent() << " throw apache::thrift::TApplicationException(apache::thrift::TApplicationException::WRONG_METHOD_NAME);" << endl << + indent() << "}" << endl; + + if (!(*f_iter)->get_returntype()->is_void() && + !is_complex_type((*f_iter)->get_returntype())) { + t_field returnfield((*f_iter)->get_returntype(), "_return"); + f_service_ << + indent() << declare_field(&returnfield) << endl; + } + + f_service_ << + indent() << resultname << " result;" << endl; + + if (!(*f_iter)->get_returntype()->is_void()) { + f_service_ << + indent() << "result.success = &_return;" << endl; + } + + f_service_ << + indent() << "result.read(iprot_);" << endl << + indent() << "iprot_->readMessageEnd();" << endl << + indent() << "iprot_->getTransport()->readEnd();" << endl << + endl; + + // Careful, only look for _result if not a void function + if (!(*f_iter)->get_returntype()->is_void()) { + if (is_complex_type((*f_iter)->get_returntype())) { + f_service_ << + indent() << "if (result.__isset.success) {" << endl << + indent() << " // _return pointer has now been filled" << endl << + indent() << " return;" << endl << + indent() << "}" << endl; + } else { + f_service_ << + indent() << "if (result.__isset.success) {" << endl << + indent() << " return _return;" << endl << + indent() << "}" << endl; + } + } + + t_struct* xs = (*f_iter)->get_xceptions(); + const std::vector& xceptions = xs->get_members(); + vector::const_iterator x_iter; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + f_service_ << + indent() << "if (result.__isset." << (*x_iter)->get_name() << ") {" << endl << + indent() << " throw result." << (*x_iter)->get_name() << ";" << endl << + indent() << "}" << endl; + } + + // We only get here if we are a void function + if ((*f_iter)->get_returntype()->is_void()) { + indent(f_service_) << + "return;" << endl; + } else { + f_service_ << + indent() << "throw apache::thrift::TApplicationException(apache::thrift::TApplicationException::MISSING_RESULT, \"" << (*f_iter)->get_name() << " failed: unknown result\");" << endl; + } + + // Close function + scope_down(f_service_); + f_service_ << endl; + } + } +} + +/** + * Generates a service server definition. + * + * @param tservice The service to generate a server for. + */ +void t_cpp_generator::generate_service_processor(t_service* tservice) { + // Generate the dispatch methods + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + string extends = ""; + string extends_processor = ""; + if (tservice->get_extends() != NULL) { + extends = type_name(tservice->get_extends()); + extends_processor = ", public " + extends + "Processor"; + } + + // Generate the header portion + f_header_ << + "class " << service_name_ << "Processor : " << + "virtual public apache::thrift::TProcessor" << + extends_processor << " {" << endl; + + // Protected data members + f_header_ << + " protected:" << endl; + indent_up(); + f_header_ << + indent() << "boost::shared_ptr<" << service_name_ << "If> iface_;" << endl; + f_header_ << + indent() << "virtual bool process_fn(apache::thrift::protocol::TProtocol* iprot, apache::thrift::protocol::TProtocol* oprot, std::string& fname, int32_t seqid);" << endl; + indent_down(); + + // Process function declarations + f_header_ << + " private:" << endl; + indent_up(); + f_header_ << + indent() << "std::map processMap_;" << endl; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + indent(f_header_) << + "void process_" << (*f_iter)->get_name() << "(int32_t seqid, apache::thrift::protocol::TProtocol* iprot, apache::thrift::protocol::TProtocol* oprot);" << endl; + } + indent_down(); + + indent_up(); + string declare_map = ""; + indent_up(); + + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + declare_map += indent(); + declare_map += "processMap_[\""; + declare_map += (*f_iter)->get_name(); + declare_map += "\"] = &"; + declare_map += service_name_; + declare_map += "Processor::process_"; + declare_map += (*f_iter)->get_name(); + declare_map += ";\n"; + } + indent_down(); + + f_header_ << + " public:" << endl << + indent() << service_name_ << "Processor(boost::shared_ptr<" << service_name_ << "If> iface) :" << endl; + if (extends.empty()) { + f_header_ << + indent() << " iface_(iface) {" << endl; + } else { + f_header_ << + indent() << " " << extends << "Processor(iface)," << endl << + indent() << " iface_(iface) {" << endl; + } + f_header_ << + declare_map << + indent() << "}" << endl << + endl << + indent() << "virtual bool process(boost::shared_ptr piprot, boost::shared_ptr poprot);" << endl << + indent() << "virtual ~" << service_name_ << "Processor() {}" << endl; + indent_down(); + f_header_ << + "};" << endl << endl; + + // Generate the server implementation + f_service_ << + "bool " << service_name_ << "Processor::process(boost::shared_ptr piprot, boost::shared_ptr poprot) {" << endl; + indent_up(); + + f_service_ << + endl << + indent() << "apache::thrift::protocol::TProtocol* iprot = piprot.get();" << endl << + indent() << "apache::thrift::protocol::TProtocol* oprot = poprot.get();" << endl << + indent() << "std::string fname;" << endl << + indent() << "apache::thrift::protocol::TMessageType mtype;" << endl << + indent() << "int32_t seqid;" << endl << + endl << + indent() << "iprot->readMessageBegin(fname, mtype, seqid);" << endl << + endl << + indent() << "if (mtype != apache::thrift::protocol::T_CALL && mtype != apache::thrift::protocol::T_ONEWAY) {" << endl << + indent() << " iprot->skip(apache::thrift::protocol::T_STRUCT);" << endl << + indent() << " iprot->readMessageEnd();" << endl << + indent() << " iprot->getTransport()->readEnd();" << endl << + indent() << " apache::thrift::TApplicationException x(apache::thrift::TApplicationException::INVALID_MESSAGE_TYPE);" << endl << + indent() << " oprot->writeMessageBegin(fname, apache::thrift::protocol::T_EXCEPTION, seqid);" << endl << + indent() << " x.write(oprot);" << endl << + indent() << " oprot->writeMessageEnd();" << endl << + indent() << " oprot->getTransport()->flush();" << endl << + indent() << " oprot->getTransport()->writeEnd();" << endl << + indent() << " return true;" << endl << + indent() << "}" << endl << + endl << + indent() << "return process_fn(iprot, oprot, fname, seqid);" << + endl; + + indent_down(); + f_service_ << + indent() << "}" << endl << + endl; + + f_service_ << + "bool " << service_name_ << "Processor::process_fn(apache::thrift::protocol::TProtocol* iprot, apache::thrift::protocol::TProtocol* oprot, std::string& fname, int32_t seqid) {" << endl; + indent_up(); + + // HOT: member function pointer map + f_service_ << + indent() << "std::map::iterator pfn;" << endl << + indent() << "pfn = processMap_.find(fname);" << endl << + indent() << "if (pfn == processMap_.end()) {" << endl; + if (extends.empty()) { + f_service_ << + indent() << " iprot->skip(apache::thrift::protocol::T_STRUCT);" << endl << + indent() << " iprot->readMessageEnd();" << endl << + indent() << " iprot->getTransport()->readEnd();" << endl << + indent() << " apache::thrift::TApplicationException x(apache::thrift::TApplicationException::UNKNOWN_METHOD, \"Invalid method name: '\"+fname+\"'\");" << endl << + indent() << " oprot->writeMessageBegin(fname, apache::thrift::protocol::T_EXCEPTION, seqid);" << endl << + indent() << " x.write(oprot);" << endl << + indent() << " oprot->writeMessageEnd();" << endl << + indent() << " oprot->getTransport()->flush();" << endl << + indent() << " oprot->getTransport()->writeEnd();" << endl << + indent() << " return true;" << endl; + } else { + f_service_ << + indent() << " return " << extends << "Processor::process_fn(iprot, oprot, fname, seqid);" << endl; + } + f_service_ << + indent() << "}" << endl << + indent() << "(this->*(pfn->second))(seqid, iprot, oprot);" << endl << + indent() << "return true;" << endl; + + indent_down(); + f_service_ << + "}" << endl << + endl; + + // Generate the process subfunctions + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + generate_process_function(tservice, *f_iter); + } +} + +/** + * Generates a struct and helpers for a function. + * + * @param tfunction The function + */ +void t_cpp_generator::generate_function_helpers(t_service* tservice, + t_function* tfunction) { + if (tfunction->is_oneway()) { + return; + } + + t_struct result(program_, tservice->get_name() + "_" + tfunction->get_name() + "_result"); + t_field success(tfunction->get_returntype(), "success", 0); + if (!tfunction->get_returntype()->is_void()) { + result.append(&success); + } + + t_struct* xs = tfunction->get_xceptions(); + const vector& fields = xs->get_members(); + vector::const_iterator f_iter; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + result.append(*f_iter); + } + + generate_struct_definition(f_header_, &result, false); + generate_struct_reader(f_service_, &result); + generate_struct_result_writer(f_service_, &result); + + result.set_name(tservice->get_name() + "_" + tfunction->get_name() + "_presult"); + generate_struct_definition(f_header_, &result, false, true, true, false); + generate_struct_reader(f_service_, &result, true); + +} + +/** + * Generates a process function definition. + * + * @param tfunction The function to write a dispatcher for + */ +void t_cpp_generator::generate_process_function(t_service* tservice, + t_function* tfunction) { + // Open function + f_service_ << + "void " << tservice->get_name() << "Processor::" << + "process_" << tfunction->get_name() << + "(int32_t seqid, apache::thrift::protocol::TProtocol* iprot, apache::thrift::protocol::TProtocol* oprot)" << endl; + scope_up(f_service_); + + string argsname = tservice->get_name() + "_" + tfunction->get_name() + "_args"; + string resultname = tservice->get_name() + "_" + tfunction->get_name() + "_result"; + + f_service_ << + indent() << argsname << " args;" << endl << + indent() << "args.read(iprot);" << endl << + indent() << "iprot->readMessageEnd();" << endl << + indent() << "iprot->getTransport()->readEnd();" << endl << + endl; + + t_struct* xs = tfunction->get_xceptions(); + const std::vector& xceptions = xs->get_members(); + vector::const_iterator x_iter; + + // Declare result + if (!tfunction->is_oneway()) { + f_service_ << + indent() << resultname << " result;" << endl; + } + + // Try block for functions with exceptions + f_service_ << + indent() << "try {" << endl; + indent_up(); + + // Generate the function call + t_struct* arg_struct = tfunction->get_arglist(); + const std::vector& fields = arg_struct->get_members(); + vector::const_iterator f_iter; + + bool first = true; + f_service_ << indent(); + if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) { + if (is_complex_type(tfunction->get_returntype())) { + first = false; + f_service_ << "iface_->" << tfunction->get_name() << "(result.success"; + } else { + f_service_ << "result.success = iface_->" << tfunction->get_name() << "("; + } + } else { + f_service_ << + "iface_->" << tfunction->get_name() << "("; + } + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + } else { + f_service_ << ", "; + } + f_service_ << "args." << (*f_iter)->get_name(); + } + f_service_ << ");" << endl; + + // Set isset on success field + if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) { + f_service_ << + indent() << "result.__isset.success = true;" << endl; + } + + indent_down(); + f_service_ << indent() << "}"; + + if (!tfunction->is_oneway()) { + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + f_service_ << " catch (" << type_name((*x_iter)->get_type()) << " &" << (*x_iter)->get_name() << ") {" << endl; + if (!tfunction->is_oneway()) { + indent_up(); + f_service_ << + indent() << "result." << (*x_iter)->get_name() << " = " << (*x_iter)->get_name() << ";" << endl << + indent() << "result.__isset." << (*x_iter)->get_name() << " = true;" << endl; + indent_down(); + f_service_ << indent() << "}"; + } else { + f_service_ << "}"; + } + } + } + + f_service_ << " catch (const std::exception& e) {" << endl; + + if (!tfunction->is_oneway()) { + indent_up(); + f_service_ << + indent() << "apache::thrift::TApplicationException x(e.what());" << endl << + indent() << "oprot->writeMessageBegin(\"" << tfunction->get_name() << "\", apache::thrift::protocol::T_EXCEPTION, seqid);" << endl << + indent() << "x.write(oprot);" << endl << + indent() << "oprot->writeMessageEnd();" << endl << + indent() << "oprot->getTransport()->flush();" << endl << + indent() << "oprot->getTransport()->writeEnd();" << endl << + indent() << "return;" << endl; + indent_down(); + } + f_service_ << indent() << "}" << endl; + + // Shortcut out here for oneway functions + if (tfunction->is_oneway()) { + f_service_ << + indent() << "return;" << endl; + indent_down(); + f_service_ << "}" << endl << + endl; + return; + } + + // Serialize the result into a struct + f_service_ << + endl << + indent() << "oprot->writeMessageBegin(\"" << tfunction->get_name() << "\", apache::thrift::protocol::T_REPLY, seqid);" << endl << + indent() << "result.write(oprot);" << endl << + indent() << "oprot->writeMessageEnd();" << endl << + indent() << "oprot->getTransport()->flush();" << endl << + indent() << "oprot->getTransport()->writeEnd();" << endl; + + // Close function + scope_down(f_service_); + f_service_ << endl; +} + +/** + * Generates a skeleton file of a server + * + * @param tservice The service to generate a server for. + */ +void t_cpp_generator::generate_service_skeleton(t_service* tservice) { + string svcname = tservice->get_name(); + + // Service implementation file includes + string f_skeleton_name = get_out_dir()+svcname+"_server.skeleton.cpp"; + + string ns = namespace_prefix(tservice->get_program()->get_namespace("cpp")); + + ofstream f_skeleton; + f_skeleton.open(f_skeleton_name.c_str()); + f_skeleton << + "// This autogenerated skeleton file illustrates how to build a server." << endl << + "// You should copy it to another filename to avoid overwriting it." << endl << + endl << + "#include \"" << get_include_prefix(*get_program()) << svcname << ".h\"" << endl << + "#include " << endl << + "#include " << endl << + "#include " << endl << + "#include " << endl << + endl << + "using namespace apache::thrift;" << endl << + "using namespace apache::thrift::protocol;" << endl << + "using namespace apache::thrift::transport;" << endl << + "using namespace apache::thrift::server;" << endl << + endl << + "using boost::shared_ptr;" << endl << + endl; + + if (!ns.empty()) { + f_skeleton << + "using namespace " << string(ns, 0, ns.size()-2) << ";" << endl << + endl; + } + + f_skeleton << + "class " << svcname << "Handler : virtual public " << svcname << "If {" << endl << + " public:" << endl; + indent_up(); + f_skeleton << + indent() << svcname << "Handler() {" << endl << + indent() << " // Your initialization goes here" << endl << + indent() << "}" << endl << + endl; + + vector functions = tservice->get_functions(); + vector::iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + f_skeleton << + indent() << function_signature(*f_iter) << " {" << endl << + indent() << " // Your implementation goes here" << endl << + indent() << " printf(\"" << (*f_iter)->get_name() << "\\n\");" << endl << + indent() << "}" << endl << + endl; + } + + indent_down(); + f_skeleton << + "};" << endl << + endl; + + f_skeleton << + indent() << "int main(int argc, char **argv) {" << endl; + indent_up(); + f_skeleton << + indent() << "int port = 9090;" << endl << + indent() << "shared_ptr<" << svcname << "Handler> handler(new " << svcname << "Handler());" << endl << + indent() << "shared_ptr processor(new " << svcname << "Processor(handler));" << endl << + indent() << "shared_ptr serverTransport(new TServerSocket(port));" << endl << + indent() << "shared_ptr transportFactory(new TBufferedTransportFactory());" << endl << + indent() << "shared_ptr protocolFactory(new TBinaryProtocolFactory());" << endl << + endl << + indent() << "TSimpleServer server(processor, serverTransport, transportFactory, protocolFactory);" << endl << + indent() << "server.serve();" << endl << + indent() << "return 0;" << endl; + indent_down(); + f_skeleton << + "}" << endl << + endl; + + // Close the files + f_skeleton.close(); +} + +/** + * Deserializes a field of any type. + */ +void t_cpp_generator::generate_deserialize_field(ofstream& out, + t_field* tfield, + string prefix, + string suffix) { + t_type* type = get_true_type(tfield->get_type()); + + if (type->is_void()) { + throw "CANNOT GENERATE DESERIALIZE CODE FOR void TYPE: " + + prefix + tfield->get_name(); + } + + string name = prefix + tfield->get_name() + suffix; + + if (type->is_struct() || type->is_xception()) { + generate_deserialize_struct(out, (t_struct*)type, name); + } else if (type->is_container()) { + generate_deserialize_container(out, type, name); + } else if (type->is_base_type()) { + indent(out) << + "xfer += iprot->"; + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "compiler error: cannot serialize void field in a struct: " + name; + break; + case t_base_type::TYPE_STRING: + if (((t_base_type*)type)->is_binary()) { + out << "readBinary(" << name << ");"; + } + else { + out << "readString(" << name << ");"; + } + break; + case t_base_type::TYPE_BOOL: + out << "readBool(" << name << ");"; + break; + case t_base_type::TYPE_BYTE: + out << "readByte(" << name << ");"; + break; + case t_base_type::TYPE_I16: + out << "readI16(" << name << ");"; + break; + case t_base_type::TYPE_I32: + out << "readI32(" << name << ");"; + break; + case t_base_type::TYPE_I64: + out << "readI64(" << name << ");"; + break; + case t_base_type::TYPE_DOUBLE: + out << "readDouble(" << name << ");"; + break; + default: + throw "compiler error: no C++ reader for base type " + t_base_type::t_base_name(tbase) + name; + } + out << + endl; + } else if (type->is_enum()) { + string t = tmp("ecast"); + out << + indent() << "int32_t " << t << ";" << endl << + indent() << "xfer += iprot->readI32(" << t << ");" << endl << + indent() << name << " = (" << type_name(type) << ")" << t << ";" << endl; + } else { + printf("DO NOT KNOW HOW TO DESERIALIZE FIELD '%s' TYPE '%s'\n", + tfield->get_name().c_str(), type_name(type).c_str()); + } +} + +/** + * Generates an unserializer for a variable. This makes two key assumptions, + * first that there is a const char* variable named data that points to the + * buffer for deserialization, and that there is a variable protocol which + * is a reference to a TProtocol serialization object. + */ +void t_cpp_generator::generate_deserialize_struct(ofstream& out, + t_struct* tstruct, + string prefix) { + indent(out) << + "xfer += " << prefix << ".read(iprot);" << endl; +} + +void t_cpp_generator::generate_deserialize_container(ofstream& out, + t_type* ttype, + string prefix) { + scope_up(out); + + string size = tmp("_size"); + string ktype = tmp("_ktype"); + string vtype = tmp("_vtype"); + string etype = tmp("_etype"); + + t_container* tcontainer = (t_container*)ttype; + bool use_push = tcontainer->has_cpp_name(); + + indent(out) << + prefix << ".clear();" << endl << + indent() << "uint32_t " << size << ";" << endl; + + // Declare variables, read header + if (ttype->is_map()) { + out << + indent() << "apache::thrift::protocol::TType " << ktype << ";" << endl << + indent() << "apache::thrift::protocol::TType " << vtype << ";" << endl << + indent() << "iprot->readMapBegin(" << + ktype << ", " << vtype << ", " << size << ");" << endl; + } else if (ttype->is_set()) { + out << + indent() << "apache::thrift::protocol::TType " << etype << ";" << endl << + indent() << "iprot->readSetBegin(" << + etype << ", " << size << ");" << endl; + } else if (ttype->is_list()) { + out << + indent() << "apache::thrift::protocol::TType " << etype << ";" << endl << + indent() << "iprot->readListBegin(" << + etype << ", " << size << ");" << endl; + if (!use_push) { + indent(out) << prefix << ".resize(" << size << ");" << endl; + } + } + + + // For loop iterates over elements + string i = tmp("_i"); + out << + indent() << "uint32_t " << i << ";" << endl << + indent() << "for (" << i << " = 0; " << i << " < " << size << "; ++" << i << ")" << endl; + + scope_up(out); + + if (ttype->is_map()) { + generate_deserialize_map_element(out, (t_map*)ttype, prefix); + } else if (ttype->is_set()) { + generate_deserialize_set_element(out, (t_set*)ttype, prefix); + } else if (ttype->is_list()) { + generate_deserialize_list_element(out, (t_list*)ttype, prefix, use_push, i); + } + + scope_down(out); + + // Read container end + if (ttype->is_map()) { + indent(out) << "iprot->readMapEnd();" << endl; + } else if (ttype->is_set()) { + indent(out) << "iprot->readSetEnd();" << endl; + } else if (ttype->is_list()) { + indent(out) << "iprot->readListEnd();" << endl; + } + + scope_down(out); +} + + +/** + * Generates code to deserialize a map + */ +void t_cpp_generator::generate_deserialize_map_element(ofstream& out, + t_map* tmap, + string prefix) { + string key = tmp("_key"); + string val = tmp("_val"); + t_field fkey(tmap->get_key_type(), key); + t_field fval(tmap->get_val_type(), val); + + out << + indent() << declare_field(&fkey) << endl; + + generate_deserialize_field(out, &fkey); + indent(out) << + declare_field(&fval, false, false, false, true) << " = " << + prefix << "[" << key << "];" << endl; + + generate_deserialize_field(out, &fval); +} + +void t_cpp_generator::generate_deserialize_set_element(ofstream& out, + t_set* tset, + string prefix) { + string elem = tmp("_elem"); + t_field felem(tset->get_elem_type(), elem); + + indent(out) << + declare_field(&felem) << endl; + + generate_deserialize_field(out, &felem); + + indent(out) << + prefix << ".insert(" << elem << ");" << endl; +} + +void t_cpp_generator::generate_deserialize_list_element(ofstream& out, + t_list* tlist, + string prefix, + bool use_push, + string index) { + if (use_push) { + string elem = tmp("_elem"); + t_field felem(tlist->get_elem_type(), elem); + indent(out) << declare_field(&felem) << endl; + generate_deserialize_field(out, &felem); + indent(out) << prefix << ".push_back(" << elem << ");" << endl; + } else { + t_field felem(tlist->get_elem_type(), prefix + "[" + index + "]"); + generate_deserialize_field(out, &felem); + } +} + + +/** + * Serializes a field of any type. + * + * @param tfield The field to serialize + * @param prefix Name to prepend to field name + */ +void t_cpp_generator::generate_serialize_field(ofstream& out, + t_field* tfield, + string prefix, + string suffix) { + t_type* type = get_true_type(tfield->get_type()); + + string name = prefix + tfield->get_name() + suffix; + + // Do nothing for void types + if (type->is_void()) { + throw "CANNOT GENERATE SERIALIZE CODE FOR void TYPE: " + name; + } + + + + if (type->is_struct() || type->is_xception()) { + generate_serialize_struct(out, + (t_struct*)type, + name); + } else if (type->is_container()) { + generate_serialize_container(out, type, name); + } else if (type->is_base_type() || type->is_enum()) { + + indent(out) << + "xfer += oprot->"; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw + "compiler error: cannot serialize void field in a struct: " + name; + break; + case t_base_type::TYPE_STRING: + if (((t_base_type*)type)->is_binary()) { + out << "writeBinary(" << name << ");"; + } + else { + out << "writeString(" << name << ");"; + } + break; + case t_base_type::TYPE_BOOL: + out << "writeBool(" << name << ");"; + break; + case t_base_type::TYPE_BYTE: + out << "writeByte(" << name << ");"; + break; + case t_base_type::TYPE_I16: + out << "writeI16(" << name << ");"; + break; + case t_base_type::TYPE_I32: + out << "writeI32(" << name << ");"; + break; + case t_base_type::TYPE_I64: + out << "writeI64(" << name << ");"; + break; + case t_base_type::TYPE_DOUBLE: + out << "writeDouble(" << name << ");"; + break; + default: + throw "compiler error: no C++ writer for base type " + t_base_type::t_base_name(tbase) + name; + } + } else if (type->is_enum()) { + out << "writeI32((int32_t)" << name << ");"; + } + out << endl; + } else { + printf("DO NOT KNOW HOW TO SERIALIZE FIELD '%s' TYPE '%s'\n", + name.c_str(), + type_name(type).c_str()); + } +} + +/** + * Serializes all the members of a struct. + * + * @param tstruct The struct to serialize + * @param prefix String prefix to attach to all fields + */ +void t_cpp_generator::generate_serialize_struct(ofstream& out, + t_struct* tstruct, + string prefix) { + indent(out) << + "xfer += " << prefix << ".write(oprot);" << endl; +} + +void t_cpp_generator::generate_serialize_container(ofstream& out, + t_type* ttype, + string prefix) { + scope_up(out); + + if (ttype->is_map()) { + indent(out) << + "xfer += oprot->writeMapBegin(" << + type_to_enum(((t_map*)ttype)->get_key_type()) << ", " << + type_to_enum(((t_map*)ttype)->get_val_type()) << ", " << + prefix << ".size());" << endl; + } else if (ttype->is_set()) { + indent(out) << + "xfer += oprot->writeSetBegin(" << + type_to_enum(((t_set*)ttype)->get_elem_type()) << ", " << + prefix << ".size());" << endl; + } else if (ttype->is_list()) { + indent(out) << + "xfer += oprot->writeListBegin(" << + type_to_enum(((t_list*)ttype)->get_elem_type()) << ", " << + prefix << ".size());" << endl; + } + + string iter = tmp("_iter"); + out << + indent() << type_name(ttype) << "::const_iterator " << iter << ";" << endl << + indent() << "for (" << iter << " = " << prefix << ".begin(); " << iter << " != " << prefix << ".end(); ++" << iter << ")" << endl; + scope_up(out); + if (ttype->is_map()) { + generate_serialize_map_element(out, (t_map*)ttype, iter); + } else if (ttype->is_set()) { + generate_serialize_set_element(out, (t_set*)ttype, iter); + } else if (ttype->is_list()) { + generate_serialize_list_element(out, (t_list*)ttype, iter); + } + scope_down(out); + + if (ttype->is_map()) { + indent(out) << + "xfer += oprot->writeMapEnd();" << endl; + } else if (ttype->is_set()) { + indent(out) << + "xfer += oprot->writeSetEnd();" << endl; + } else if (ttype->is_list()) { + indent(out) << + "xfer += oprot->writeListEnd();" << endl; + } + + scope_down(out); +} + +/** + * Serializes the members of a map. + * + */ +void t_cpp_generator::generate_serialize_map_element(ofstream& out, + t_map* tmap, + string iter) { + t_field kfield(tmap->get_key_type(), iter + "->first"); + generate_serialize_field(out, &kfield, ""); + + t_field vfield(tmap->get_val_type(), iter + "->second"); + generate_serialize_field(out, &vfield, ""); +} + +/** + * Serializes the members of a set. + */ +void t_cpp_generator::generate_serialize_set_element(ofstream& out, + t_set* tset, + string iter) { + t_field efield(tset->get_elem_type(), "(*" + iter + ")"); + generate_serialize_field(out, &efield, ""); +} + +/** + * Serializes the members of a list. + */ +void t_cpp_generator::generate_serialize_list_element(ofstream& out, + t_list* tlist, + string iter) { + t_field efield(tlist->get_elem_type(), "(*" + iter + ")"); + generate_serialize_field(out, &efield, ""); +} + +/** + * Makes a :: prefix for a namespace + * + * @param ns The namepsace, w/ periods in it + * @return Namespaces + */ +string t_cpp_generator::namespace_prefix(string ns) { + if (ns.size() == 0) { + return ""; + } + string result = ""; + string::size_type loc; + while ((loc = ns.find(".")) != string::npos) { + result += ns.substr(0, loc); + result += "::"; + ns = ns.substr(loc+1); + } + if (ns.size() > 0) { + result += ns + "::"; + } + return result; +} + +/** + * Opens namespace. + * + * @param ns The namepsace, w/ periods in it + * @return Namespaces + */ +string t_cpp_generator::namespace_open(string ns) { + if (ns.size() == 0) { + return ""; + } + string result = ""; + string separator = ""; + string::size_type loc; + while ((loc = ns.find(".")) != string::npos) { + result += separator; + result += "namespace "; + result += ns.substr(0, loc); + result += " {"; + separator = " "; + ns = ns.substr(loc+1); + } + if (ns.size() > 0) { + result += separator + "namespace " + ns + " {"; + } + return result; +} + +/** + * Closes namespace. + * + * @param ns The namepsace, w/ periods in it + * @return Namespaces + */ +string t_cpp_generator::namespace_close(string ns) { + if (ns.size() == 0) { + return ""; + } + string result = "}"; + string::size_type loc; + while ((loc = ns.find(".")) != string::npos) { + result += "}"; + ns = ns.substr(loc+1); + } + result += " // namespace"; + return result; +} + +/** + * Returns a C++ type name + * + * @param ttype The type + * @return String of the type name, i.e. std::set + */ +string t_cpp_generator::type_name(t_type* ttype, bool in_typedef, bool arg) { + if (ttype->is_base_type()) { + string bname = base_type_name(((t_base_type*)ttype)->get_base()); + if (!arg) { + return bname; + } + + if (((t_base_type*)ttype)->get_base() == t_base_type::TYPE_STRING) { + return "const " + bname + "&"; + } else { + return "const " + bname; + } + } + + // Check for a custom overloaded C++ name + if (ttype->is_container()) { + string cname; + + t_container* tcontainer = (t_container*) ttype; + if (tcontainer->has_cpp_name()) { + cname = tcontainer->get_cpp_name(); + } else if (ttype->is_map()) { + t_map* tmap = (t_map*) ttype; + cname = "std::map<" + + type_name(tmap->get_key_type(), in_typedef) + ", " + + type_name(tmap->get_val_type(), in_typedef) + "> "; + } else if (ttype->is_set()) { + t_set* tset = (t_set*) ttype; + cname = "std::set<" + type_name(tset->get_elem_type(), in_typedef) + "> "; + } else if (ttype->is_list()) { + t_list* tlist = (t_list*) ttype; + cname = "std::vector<" + type_name(tlist->get_elem_type(), in_typedef) + "> "; + } + + if (arg) { + return "const " + cname + "&"; + } else { + return cname; + } + } + + string class_prefix; + if (in_typedef && (ttype->is_struct() || ttype->is_xception())) { + class_prefix = "class "; + } + + // Check if it needs to be namespaced + string pname; + t_program* program = ttype->get_program(); + if (program != NULL && program != program_) { + pname = + class_prefix + + namespace_prefix(program->get_namespace("cpp")) + + ttype->get_name(); + } else { + pname = class_prefix + ttype->get_name(); + } + + if (arg) { + if (is_complex_type(ttype)) { + return "const " + pname + "&"; + } else { + return "const " + pname; + } + } else { + return pname; + } +} + +/** + * Returns the C++ type that corresponds to the thrift type. + * + * @param tbase The base type + * @return Explicit C++ type, i.e. "int32_t" + */ +string t_cpp_generator::base_type_name(t_base_type::t_base tbase) { + switch (tbase) { + case t_base_type::TYPE_VOID: + return "void"; + case t_base_type::TYPE_STRING: + return "std::string"; + case t_base_type::TYPE_BOOL: + return "bool"; + case t_base_type::TYPE_BYTE: + return "int8_t"; + case t_base_type::TYPE_I16: + return "int16_t"; + case t_base_type::TYPE_I32: + return "int32_t"; + case t_base_type::TYPE_I64: + return "int64_t"; + case t_base_type::TYPE_DOUBLE: + return "double"; + default: + throw "compiler error: no C++ base type name for base type " + t_base_type::t_base_name(tbase); + } +} + +/** + * Declares a field, which may include initialization as necessary. + * + * @param ttype The type + * @return Field declaration, i.e. int x = 0; + */ +string t_cpp_generator::declare_field(t_field* tfield, bool init, bool pointer, bool constant, bool reference) { + // TODO(mcslee): do we ever need to initialize the field? + string result = ""; + if (constant) { + result += "const "; + } + result += type_name(tfield->get_type()); + if (pointer) { + result += "*"; + } + if (reference) { + result += "&"; + } + result += " " + tfield->get_name(); + if (init) { + t_type* type = get_true_type(tfield->get_type()); + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + break; + case t_base_type::TYPE_STRING: + result += " = \"\""; + break; + case t_base_type::TYPE_BOOL: + result += " = false"; + break; + case t_base_type::TYPE_BYTE: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + case t_base_type::TYPE_I64: + result += " = 0"; + break; + case t_base_type::TYPE_DOUBLE: + result += " = (double)0"; + break; + default: + throw "compiler error: no C++ initializer for base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + result += " = (" + type_name(type) + ")0"; + } + } + if (!reference) { + result += ";"; + } + return result; +} + +/** + * Renders a function signature of the form 'type name(args)' + * + * @param tfunction Function definition + * @return String of rendered function definition + */ +string t_cpp_generator::function_signature(t_function* tfunction, + string prefix, + bool name_params) { + t_type* ttype = tfunction->get_returntype(); + t_struct* arglist = tfunction->get_arglist(); + + if (is_complex_type(ttype)) { + bool empty = arglist->get_members().size() == 0; + return + "void " + prefix + tfunction->get_name() + + "(" + type_name(ttype) + (name_params ? "& _return" : "& /* _return */") + + (empty ? "" : (", " + argument_list(arglist, name_params))) + ")"; + } else { + return + type_name(ttype) + " " + prefix + tfunction->get_name() + + "(" + argument_list(arglist, name_params) + ")"; + } +} + +/** + * Renders a field list + * + * @param tstruct The struct definition + * @return Comma sepearated list of all field names in that struct + */ +string t_cpp_generator::argument_list(t_struct* tstruct, bool name_params) { + string result = ""; + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + } else { + result += ", "; + } + result += type_name((*f_iter)->get_type(), false, true) + " " + + (name_params ? (*f_iter)->get_name() : "/* " + (*f_iter)->get_name() + " */"); + } + return result; +} + +/** + * Converts the parse type to a C++ enum string for the given type. + * + * @param type Thrift Type + * @return String of C++ code to definition of that type constant + */ +string t_cpp_generator::type_to_enum(t_type* type) { + type = get_true_type(type); + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "NO T_VOID CONSTRUCT"; + case t_base_type::TYPE_STRING: + return "apache::thrift::protocol::T_STRING"; + case t_base_type::TYPE_BOOL: + return "apache::thrift::protocol::T_BOOL"; + case t_base_type::TYPE_BYTE: + return "apache::thrift::protocol::T_BYTE"; + case t_base_type::TYPE_I16: + return "apache::thrift::protocol::T_I16"; + case t_base_type::TYPE_I32: + return "apache::thrift::protocol::T_I32"; + case t_base_type::TYPE_I64: + return "apache::thrift::protocol::T_I64"; + case t_base_type::TYPE_DOUBLE: + return "apache::thrift::protocol::T_DOUBLE"; + } + } else if (type->is_enum()) { + return "apache::thrift::protocol::T_I32"; + } else if (type->is_struct()) { + return "apache::thrift::protocol::T_STRUCT"; + } else if (type->is_xception()) { + return "apache::thrift::protocol::T_STRUCT"; + } else if (type->is_map()) { + return "apache::thrift::protocol::T_MAP"; + } else if (type->is_set()) { + return "apache::thrift::protocol::T_SET"; + } else if (type->is_list()) { + return "apache::thrift::protocol::T_LIST"; + } + + throw "INVALID TYPE IN type_to_enum: " + type->get_name(); +} + +/** + * Returns the symbol name of the local reflection of a type. + */ +string t_cpp_generator::local_reflection_name(const char* prefix, t_type* ttype, bool external) { + ttype = get_true_type(ttype); + + // We have to use the program name as part of the identifier because + // if two thrift "programs" are compiled into one actual program + // you would get a symbol collison if they both defined list. + // trlo = Thrift Reflection LOcal. + string prog; + string name; + string nspace; + + // TODO(dreiss): Would it be better to pregenerate the base types + // and put them in Thrift.{h,cpp} ? + + if (ttype->is_base_type()) { + prog = program_->get_name(); + name = ttype->get_ascii_fingerprint(); + } else if (ttype->is_enum()) { + assert(ttype->get_program() != NULL); + prog = ttype->get_program()->get_name(); + name = ttype->get_ascii_fingerprint(); + } else if (ttype->is_container()) { + prog = program_->get_name(); + name = ttype->get_ascii_fingerprint(); + } else { + assert(ttype->is_struct() || ttype->is_xception()); + assert(ttype->get_program() != NULL); + prog = ttype->get_program()->get_name(); + name = ttype->get_ascii_fingerprint(); + } + + if (external && + ttype->get_program() != NULL && + ttype->get_program() != program_) { + nspace = namespace_prefix(ttype->get_program()->get_namespace("cpp")); + } + + return nspace + "trlo_" + prefix + "_" + prog + "_" + name; +} + +string t_cpp_generator::get_include_prefix(const t_program& program) const { + string include_prefix = program.get_include_prefix(); + if (!use_include_prefix_ || + (include_prefix.size() > 0 && include_prefix[0] == '/')) { + // if flag is turned off or this is absolute path, return empty prefix + return ""; + } + + string::size_type last_slash = string::npos; + if ((last_slash = include_prefix.rfind("/")) != string::npos) { + return include_prefix.substr(0, last_slash) + "/" + out_dir_base_ + "/"; + } + + return ""; +} + + +THRIFT_REGISTER_GENERATOR(cpp, "C++", +" dense: Generate type specifications for the dense protocol.\n" +" include_prefix: Use full include paths in generated files.\n" +); diff --git a/compiler/cpp/src/generate/t_csharp_generator.cc b/compiler/cpp/src/generate/t_csharp_generator.cc new file mode 100644 index 00000000..5a910eab --- /dev/null +++ b/compiler/cpp/src/generate/t_csharp_generator.cc @@ -0,0 +1,1700 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include + +#include "platform.h" +#include "t_oop_generator.h" +using namespace std; + + +class t_csharp_generator : public t_oop_generator +{ + public: + t_csharp_generator( + t_program* program, + const std::map& parsed_options, + const std::string& option_string) + : t_oop_generator(program) + { + out_dir_base_ = "gen-csharp"; + } + void init_generator(); + void close_generator(); + + void generate_consts(std::vector consts); + + void generate_typedef (t_typedef* ttypedef); + void generate_enum (t_enum* tenum); + void generate_struct (t_struct* tstruct); + void generate_xception (t_struct* txception); + void generate_service (t_service* tservice); + void generate_property(ofstream& out, t_field* tfield, bool isPublic); + bool print_const_value (std::ofstream& out, std::string name, t_type* type, t_const_value* value, bool in_static, bool defval=false, bool needtype=false); + std::string render_const_value(std::ofstream& out, std::string name, t_type* type, t_const_value* value); + void print_const_constructor(std::ofstream& out, std::vector consts); + void print_const_def_value(std::ofstream& out, std::string name, t_type* type, t_const_value* value); + + void generate_csharp_struct(t_struct* tstruct, bool is_exception); + void generate_csharp_struct_definition(std::ofstream& out, t_struct* tstruct, bool is_xception=false, bool in_class=false, bool is_result=false); + void generate_csharp_struct_reader(std::ofstream& out, t_struct* tstruct); + void generate_csharp_struct_result_writer(std::ofstream& out, t_struct* tstruct); + void generate_csharp_struct_writer(std::ofstream& out, t_struct* tstruct); + void generate_csharp_struct_tostring(std::ofstream& out, t_struct* tstruct); + + void generate_function_helpers(t_function* tfunction); + void generate_service_interface (t_service* tservice); + void generate_service_helpers (t_service* tservice); + void generate_service_client (t_service* tservice); + void generate_service_server (t_service* tservice); + void generate_process_function (t_service* tservice, t_function* function); + + void generate_deserialize_field (std::ofstream& out, t_field* tfield, std::string prefix=""); + void generate_deserialize_struct (std::ofstream& out, t_struct* tstruct, std::string prefix=""); + void generate_deserialize_container (std::ofstream& out, t_type* ttype, std::string prefix=""); + void generate_deserialize_set_element (std::ofstream& out, t_set* tset, std::string prefix=""); + void generate_deserialize_map_element (std::ofstream& out, t_map* tmap, std::string prefix=""); + void generate_deserialize_list_element (std::ofstream& out, t_list* list, std::string prefix=""); + void generate_serialize_field (std::ofstream& out, t_field* tfield, std::string prefix=""); + void generate_serialize_struct (std::ofstream& out, t_struct* tstruct, std::string prefix=""); + void generate_serialize_container (std::ofstream& out, t_type* ttype, std::string prefix=""); + void generate_serialize_map_element (std::ofstream& out, t_map* tmap, std::string iter, std::string map); + void generate_serialize_set_element (std::ofstream& out, t_set* tmap, std::string iter); + void generate_serialize_list_element (std::ofstream& out, t_list* tlist, std::string iter); + + void start_csharp_namespace (std::ofstream& out); + void end_csharp_namespace (std::ofstream& out); + + std::string csharp_type_usings(); + std::string csharp_thrift_usings(); + + std::string type_name(t_type* ttype, bool in_countainer=false, bool in_init=false); + std::string base_type_name(t_base_type* tbase, bool in_container=false); + std::string declare_field(t_field* tfield, bool init=false); + std::string function_signature(t_function* tfunction, std::string prefix=""); + std::string argument_list(t_struct* tstruct); + std::string type_to_enum(t_type* ttype); + std::string prop_name(t_field* tfield); + + bool type_can_be_null(t_type* ttype) { + while (ttype->is_typedef()) { + ttype = ((t_typedef*)ttype)->get_type(); + } + + return ttype->is_container() || + ttype->is_struct() || + ttype->is_xception() || + ttype->is_string(); + } + + private: + std::string namespace_name_; + std::ofstream f_service_; + std::string namespace_dir_; +}; + + +void t_csharp_generator::init_generator() { + MKDIR(get_out_dir().c_str()); + namespace_name_ = program_->get_namespace("csharp"); + + string dir = namespace_name_; + string subdir = get_out_dir().c_str(); + string::size_type loc; + + while ((loc = dir.find(".")) != string::npos) { + subdir = subdir + "/" + dir.substr(0, loc); + MKDIR(subdir.c_str()); + dir = dir.substr(loc + 1); + } + if (dir.size() > 0) { + subdir = subdir + "/" + dir; + MKDIR(subdir.c_str()); + } + + namespace_dir_ = subdir; +} + +void t_csharp_generator::start_csharp_namespace(ofstream& out) { + if (!namespace_name_.empty()) { + out << + "namespace " << namespace_name_ << "\n"; + scope_up(out); + } +} + +void t_csharp_generator::end_csharp_namespace(ofstream& out) { + if (!namespace_name_.empty()) { + scope_down(out); + } +} + +string t_csharp_generator::csharp_type_usings() { + return string() + + "using System;\n" + + "using System.Collections;\n" + + "using System.Collections.Generic;\n" + + "using System.Text;\n" + + "using System.IO;\n" + + "using Thrift;\n" + + "using Thrift.Collections;\n"; +} + +string t_csharp_generator::csharp_thrift_usings() { + return string() + + "using Thrift.Protocol;\n" + + "using Thrift.Transport;\n"; +} + +void t_csharp_generator::close_generator() { } +void t_csharp_generator::generate_typedef(t_typedef* ttypedef) {} + +void t_csharp_generator::generate_enum(t_enum* tenum) { + string f_enum_name = namespace_dir_+"/" + (tenum->get_name())+".cs"; + ofstream f_enum; + f_enum.open(f_enum_name.c_str()); + + f_enum << + autogen_comment() << endl; + + start_csharp_namespace(f_enum); + + indent(f_enum) << + "public enum " << tenum->get_name() << "\n"; + scope_up(f_enum); + + vector constants = tenum->get_constants(); + vector::iterator c_iter; + int value = -1; + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) + { + if ((*c_iter)->has_value()) { + value = (*c_iter)->get_value(); + } else { + ++value; + } + + indent(f_enum) << + (*c_iter)->get_name() << + " = " << value << "," << endl; + } + + scope_down(f_enum); + + end_csharp_namespace(f_enum); + + f_enum.close(); +} + +void t_csharp_generator::generate_consts(std::vector consts) { + if (consts.empty()){ + return; + } + string f_consts_name = namespace_dir_ + "/Constants.cs"; + ofstream f_consts; + f_consts.open(f_consts_name.c_str()); + + f_consts << + autogen_comment() << + csharp_type_usings() << endl; + + start_csharp_namespace(f_consts); + + indent(f_consts) << + "public class Constants" << endl; + scope_up(f_consts); + + vector::iterator c_iter; + bool need_static_constructor = false; + for (c_iter = consts.begin(); c_iter != consts.end(); ++c_iter) { + if (print_const_value(f_consts, (*c_iter)->get_name(), (*c_iter)->get_type(), (*c_iter)->get_value(), false)) { + need_static_constructor = true; + } + } + + if (need_static_constructor) { + print_const_constructor(f_consts, consts); + } + + scope_down(f_consts); + end_csharp_namespace(f_consts); + f_consts.close(); +} + +void t_csharp_generator::print_const_def_value(std::ofstream& out, string name, t_type* type, t_const_value* value) +{ + if (type->is_struct() || type->is_xception()) { + const vector& fields = ((t_struct*)type)->get_members(); + vector::const_iterator f_iter; + const map& val = value->get_map(); + map::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + t_type* field_type = NULL; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if ((*f_iter)->get_name() == v_iter->first->get_string()) { + field_type = (*f_iter)->get_type(); + } + } + if (field_type == NULL) { + throw "type error: " + type->get_name() + " has no field " + v_iter->first->get_string(); + } + string val = render_const_value(out, name, field_type, v_iter->second); + indent(out) << name << "." << v_iter->first->get_string() << " = " << val << ";" << endl; + } + } else if (type->is_map()) { + t_type* ktype = ((t_map*)type)->get_key_type(); + t_type* vtype = ((t_map*)type)->get_val_type(); + const map& val = value->get_map(); + map::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + string key = render_const_value(out, name, ktype, v_iter->first); + string val = render_const_value(out, name, vtype, v_iter->second); + indent(out) << name << "[" << key << "]" << " = " << val << ";" << endl; + } + } else if (type->is_list() || type->is_set()) { + t_type* etype; + if (type->is_list()) { + etype = ((t_list*)type)->get_elem_type(); + } else { + etype = ((t_set*)type)->get_elem_type(); + } + + const vector& val = value->get_list(); + vector::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + string val = render_const_value(out, name, etype, *v_iter); + indent(out) << name << ".Add(" << val << ");" << endl; + } + } +} + +void t_csharp_generator::print_const_constructor(std::ofstream& out, std::vector consts) { + indent(out) << "static Constants()" << endl; + scope_up(out); + vector::iterator c_iter; + for (c_iter = consts.begin(); c_iter != consts.end(); ++c_iter) { + string name = (*c_iter)->get_name(); + t_type* type = (*c_iter)->get_type(); + t_const_value* value = (*c_iter)->get_value(); + + print_const_def_value(out, name, type, value); + } + scope_down(out); +} + + +//it seems like all that methods that call this are using in_static to be the opposite of what it would imply +bool t_csharp_generator::print_const_value(std::ofstream& out, string name, t_type* type, t_const_value* value, bool in_static, bool defval, bool needtype) { + indent(out); + bool need_static_construction = !in_static; + if (!defval || needtype) { + out << + (in_static ? "" : "public static ") << + type_name(type) << " "; + } + if (type->is_base_type()) { + string v2 = render_const_value(out, name, type, value); + out << name << " = " << v2 << ";" << endl; + need_static_construction = false; + } else if (type->is_enum()) { + out << name << " = (" << type_name(type, false, true) << ")" << value->get_integer() << ";" << endl; + need_static_construction = false; + } else if (type->is_struct() || type->is_xception()) { + out << name << " = new " << type_name(type) << "();" << endl; + } else if (type->is_map()) { + out << name << " = new " << type_name(type, true, true) << "();" << endl; + } else if (type->is_list() || type->is_set()) { + out << name << " = new " << type_name(type) << "();" << endl; + } + + if (defval && !type->is_base_type() && !type->is_enum()) { + print_const_def_value(out, name, type, value); + } + + return need_static_construction; +} + +std::string t_csharp_generator::render_const_value(ofstream& out, string name, t_type* type, t_const_value* value) { + std::ostringstream render; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_STRING: + render << '"' << get_escaped_string(value) << '"'; + break; + case t_base_type::TYPE_BOOL: + render << ((value->get_integer() > 0) ? "true" : "false"); + break; + case t_base_type::TYPE_BYTE: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + case t_base_type::TYPE_I64: + render << value->get_integer(); + break; + case t_base_type::TYPE_DOUBLE: + if (value->get_type() == t_const_value::CV_INTEGER) { + render << value->get_integer(); + } else { + render << value->get_double(); + } + break; + default: + throw "compiler error: no const of base type " + tbase; + } + } else if (type->is_enum()) { + render << "(" << type->get_name() << ")" << value->get_integer(); + } else { + string t = tmp("tmp"); + print_const_value(out, t, type, value, true, true, true); + render << t; + } + + return render.str(); +} + +void t_csharp_generator::generate_struct(t_struct* tstruct) { + generate_csharp_struct(tstruct, false); +} + +void t_csharp_generator::generate_xception(t_struct* txception) { + generate_csharp_struct(txception, true); +} + +void t_csharp_generator::generate_csharp_struct(t_struct* tstruct, bool is_exception) { + string f_struct_name = namespace_dir_ + "/" + (tstruct->get_name()) + ".cs"; + ofstream f_struct; + + f_struct.open(f_struct_name.c_str()); + + f_struct << + autogen_comment() << + csharp_type_usings() << + csharp_thrift_usings(); + + generate_csharp_struct_definition(f_struct, tstruct, is_exception); + + f_struct.close(); +} + +void t_csharp_generator::generate_csharp_struct_definition(ofstream &out, t_struct* tstruct, bool is_exception, bool in_class, bool is_result) { + + if (!in_class) { + start_csharp_namespace(out); + } + + out << endl; + indent(out) << "[Serializable]" << endl; + bool is_final = (tstruct->annotations_.find("final") != tstruct->annotations_.end()); + + indent(out) << "public " << (is_final ? "sealed " : "") << "class " << tstruct->get_name() << " : "; + + if (is_exception) { + out << "Exception, "; + } + out << "TBase"; + + out << endl; + + scope_up(out); + + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + + //make private members with public Properties + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + indent(out) << + "private " << declare_field(*m_iter, false) << endl; + } + out << endl; + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + generate_property(out, *m_iter, true); + } + + if (members.size() > 0) { + out << + endl << + indent() << "public Isset __isset;" << endl << + indent() << "[Serializable]" << endl << + indent() << "public struct Isset {" << endl; + indent_up(); + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + indent(out) << + "public bool " << (*m_iter)->get_name() << ";" << endl; + } + + indent_down(); + indent(out) << "}" << endl << endl; + } + + indent(out) << + "public " << tstruct->get_name() << "() {" << endl; + indent_up(); + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + t_type* t = (*m_iter)->get_type(); + while (t->is_typedef()) { + t = ((t_typedef*)t)->get_type(); + } + if ((*m_iter)->get_value() != NULL) { + print_const_value(out, "this." + (*m_iter)->get_name(), t, (*m_iter)->get_value(), true, true); + } + } + + indent_down(); + indent(out) << "}" << endl << endl; + + generate_csharp_struct_reader(out, tstruct); + if (is_result) { + generate_csharp_struct_result_writer(out, tstruct); + } else { + generate_csharp_struct_writer(out, tstruct); + } + generate_csharp_struct_tostring(out, tstruct); + scope_down(out); + out << endl; + + if (!in_class) + { + end_csharp_namespace(out); + } +} + +void t_csharp_generator::generate_csharp_struct_reader(ofstream& out, t_struct* tstruct) { + indent(out) << + "public void Read (TProtocol iprot)" << endl; + scope_up(out); + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + + indent(out) << + "TField field;" << endl << + indent() << "iprot.ReadStructBegin();" << endl; + + indent(out) << + "while (true)" << endl; + scope_up(out); + + indent(out) << + "field = iprot.ReadFieldBegin();" << endl; + + indent(out) << + "if (field.Type == TType.Stop) { " << endl; + indent_up(); + indent(out) << + "break;" << endl; + indent_down(); + indent(out) << + "}" << endl; + + indent(out) << + "switch (field.ID)" << endl; + + scope_up(out); + + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + indent(out) << + "case " << (*f_iter)->get_key() << ":" << endl; + indent_up(); + indent(out) << + "if (field.Type == " << type_to_enum((*f_iter)->get_type()) << ") {" << endl; + indent_up(); + + generate_deserialize_field(out, *f_iter, "this."); + indent(out) << + "this.__isset." << (*f_iter)->get_name() << " = true;" << endl; + indent_down(); + out << + indent() << "} else { " << endl << + indent() << " TProtocolUtil.Skip(iprot, field.Type);" << endl << + indent() << "}" << endl << + indent() << "break;" << endl; + indent_down(); + } + + indent(out) << + "default: " << endl; + indent_up(); + indent(out) << "TProtocolUtil.Skip(iprot, field.Type);" << endl; + indent(out) << "break;" << endl; + indent_down(); + + scope_down(out); + + indent(out) << + "iprot.ReadFieldEnd();" << endl; + + scope_down(out); + + indent(out) << + "iprot.ReadStructEnd();" << endl; + + indent_down(); + + indent(out) << "}" << endl << endl; + +} + +void t_csharp_generator::generate_csharp_struct_writer(ofstream& out, t_struct* tstruct) { + out << + indent() << "public void Write(TProtocol oprot) {" << endl; + indent_up(); + + string name = tstruct->get_name(); + const vector& fields = tstruct->get_sorted_members(); + vector::const_iterator f_iter; + + indent(out) << + "TStruct struc = new TStruct(\"" << name << "\");" << endl; + indent(out) << + "oprot.WriteStructBegin(struc);" << endl; + + if (fields.size() > 0) { + indent(out) << "TField field = new TField();" << endl; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + bool null_allowed = type_can_be_null((*f_iter)->get_type()); + if (null_allowed) { + indent(out) << + "if (this." << (*f_iter)->get_name() << " != null && __isset." << (*f_iter)->get_name() << ") {" << endl; + indent_up(); + } + else + { + indent(out) << + "if (__isset." << (*f_iter)->get_name() << ") {" << endl; + indent_up(); + } + + indent(out) << + "field.Name = \"" << (*f_iter)->get_name() << "\";" << endl; + indent(out) << + "field.Type = " << type_to_enum((*f_iter)->get_type()) << ";" << endl; + indent(out) << + "field.ID = " << (*f_iter)->get_key() << ";" << endl; + indent(out) << + "oprot.WriteFieldBegin(field);" << endl; + + generate_serialize_field(out, *f_iter, "this."); + + indent(out) << + "oprot.WriteFieldEnd();" << endl; + + indent_down(); + indent(out) << "}" << endl; + } + } + + indent(out) << + "oprot.WriteFieldStop();" << endl; + indent(out) << + "oprot.WriteStructEnd();" << endl; + + indent_down(); + + indent(out) << + "}" << endl << endl; +} + +void t_csharp_generator::generate_csharp_struct_result_writer(ofstream& out, t_struct* tstruct) { + indent(out) << + "public void Write(TProtocol oprot) {" << endl; + indent_up(); + + string name = tstruct->get_name(); + const vector& fields = tstruct->get_sorted_members(); + vector::const_iterator f_iter; + + indent(out) << + "TStruct struc = new TStruct(\"" << name << "\");" << endl; + indent(out) << + "oprot.WriteStructBegin(struc);" << endl; + + if (fields.size() > 0) { + indent(out) << "TField field = new TField();" << endl; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + out << + endl << indent() << "if "; + } else { + out << + " else if "; + } + + out << + "(this.__isset." << (*f_iter)->get_name() << ") {" << endl; + indent_up(); + + bool null_allowed = type_can_be_null((*f_iter)->get_type()); + if (null_allowed) { + indent(out) << + "if (this." << (*f_iter)->get_name() << " != null) {" << endl; + indent_up(); + } + + indent(out) << + "field.Name = \"" << (*f_iter)->get_name() << "\";" << endl; + indent(out) << + "field.Type = " << type_to_enum((*f_iter)->get_type()) << ";" << endl; + indent(out) << + "field.ID = " << (*f_iter)->get_key() << ";" << endl; + indent(out) << + "oprot.WriteFieldBegin(field);" << endl; + + generate_serialize_field(out, *f_iter, "this."); + + indent(out) << + "oprot.WriteFieldEnd();" << endl; + + if (null_allowed) { + indent_down(); + indent(out) << "}" << endl; + } + + indent_down(); + indent(out) << "}"; + } + } + + out << + endl << + indent() << "oprot.WriteFieldStop();" << endl << + indent() << "oprot.WriteStructEnd();" << endl; + + indent_down(); + + indent(out) << + "}" << endl << endl; +} + +void t_csharp_generator::generate_csharp_struct_tostring(ofstream& out, t_struct* tstruct) { + indent(out) << + "public override string ToString() {" << endl; + indent_up(); + + indent(out) << + "StringBuilder sb = new StringBuilder(\"" << tstruct->get_name() << "(\");" << endl; + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + + bool first = true; + + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + indent(out) << + "sb.Append(\"" << (*f_iter)->get_name() << ": \");" << endl; + } else { + indent(out) << + "sb.Append(\"," << (*f_iter)->get_name() << ": \");" << endl; + } + t_type* ttype = (*f_iter)->get_type(); + if (ttype->is_xception() || ttype->is_struct()) { + indent(out) << + "sb.Append(this." << (*f_iter)->get_name() << "== null ? \"\" : "<< "this." << (*f_iter)->get_name() << ".ToString());" << endl; + } else { + indent(out) << + "sb.Append(this." << (*f_iter)->get_name() << ");" << endl; + } + } + + indent(out) << + "sb.Append(\")\");" << endl; + indent(out) << + "return sb.ToString();" << endl; + + indent_down(); + indent(out) << "}" << endl << endl; +} + +void t_csharp_generator::generate_service(t_service* tservice) { + string f_service_name = namespace_dir_ + "/" + service_name_ + ".cs"; + f_service_.open(f_service_name.c_str()); + + f_service_ << + autogen_comment() << + csharp_type_usings() << + csharp_thrift_usings(); + + start_csharp_namespace(f_service_); + + indent(f_service_) << + "public class " << service_name_ << " {" << endl; + indent_up(); + + generate_service_interface(tservice); + generate_service_client(tservice); + generate_service_server(tservice); + generate_service_helpers(tservice); + + indent_down(); + + indent(f_service_) << + "}" << endl; + end_csharp_namespace(f_service_); + f_service_.close(); +} + +void t_csharp_generator::generate_service_interface(t_service* tservice) { + string extends = ""; + string extends_iface = ""; + if (tservice->get_extends() != NULL) { + extends = type_name(tservice->get_extends()); + extends_iface = " : " + extends + ".Iface"; + } + + indent(f_service_) << + "public interface Iface" << extends_iface << " {" << endl; + indent_up(); + vector functions = tservice->get_functions(); + vector::iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) + { + indent(f_service_) << + function_signature(*f_iter) << ";" << endl; + } + indent_down(); + f_service_ << + indent() << "}" << endl << endl; +} + +void t_csharp_generator::generate_service_helpers(t_service* tservice) { + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + t_struct* ts = (*f_iter)->get_arglist(); + generate_csharp_struct_definition(f_service_, ts, false, true); + generate_function_helpers(*f_iter); + } +} + +void t_csharp_generator::generate_service_client(t_service* tservice) { + string extends = ""; + string extends_client = ""; + if (tservice->get_extends() != NULL) { + extends = type_name(tservice->get_extends()); + extends_client = extends + ".Client, "; + } + + indent(f_service_) << + "public class Client : " << extends_client << "Iface {" << endl; + indent_up(); + indent(f_service_) << + "public Client(TProtocol prot) : this(prot, prot)" << endl; + scope_up(f_service_); + scope_down(f_service_); + f_service_ << endl; + + indent(f_service_) << + "public Client(TProtocol iprot, TProtocol oprot)"; + if (!extends.empty()) { + f_service_ << " : base(iprot, oprot)"; + } + f_service_ << endl; + + scope_up(f_service_); + if (extends.empty()) { + f_service_ << + indent() << "iprot_ = iprot;" << endl << + indent() << "oprot_ = oprot;" << endl; + } + scope_down(f_service_); + + f_service_ << endl; + + if (extends.empty()) { + f_service_ << + indent() << "protected TProtocol iprot_;" << endl << + indent() << "protected TProtocol oprot_;" << endl << + indent() << "protected int seqid_;" << endl << endl; + + f_service_ << indent() << "public TProtocol InputProtocol" << endl; + scope_up(f_service_); + indent(f_service_) << "get { return iprot_; }" << endl; + scope_down(f_service_); + + f_service_ << indent() << "public TProtocol OutputProtocol" << endl; + scope_up(f_service_); + indent(f_service_) << "get { return oprot_; }" << endl; + scope_down(f_service_); + f_service_ << endl << endl; + } + + vector functions = tservice->get_functions(); + vector::const_iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + string funname = (*f_iter)->get_name(); + + indent(f_service_) << + "public " << function_signature(*f_iter) << endl; + scope_up(f_service_); + indent(f_service_) << + "send_" << funname << "("; + + t_struct* arg_struct = (*f_iter)->get_arglist(); + + const vector& fields = arg_struct->get_members(); + vector::const_iterator fld_iter; + bool first = true; + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + if (first) { + first = false; + } else { + f_service_ << ", "; + } + f_service_ << (*fld_iter)->get_name(); + } + f_service_ << ");" << endl; + + if (!(*f_iter)->is_oneway()) { + f_service_ << indent(); + if (!(*f_iter)->get_returntype()->is_void()) { + f_service_ << "return "; + } + f_service_ << + "recv_" << funname << "();" << endl; + } + scope_down(f_service_); + f_service_ << endl; + + t_function send_function(g_type_void, + string("send_") + (*f_iter)->get_name(), + (*f_iter)->get_arglist()); + + string argsname = (*f_iter)->get_name() + "_args"; + + indent(f_service_) << + "public " << function_signature(&send_function) << endl; + scope_up(f_service_); + + f_service_ << + indent() << "oprot_.WriteMessageBegin(new TMessage(\"" << funname << "\", TMessageType.Call, seqid_));" << endl << + indent() << argsname << " args = new " << argsname << "();" << endl; + + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + f_service_ << + indent() << "args." << prop_name(*fld_iter) << " = " << (*fld_iter)->get_name() << ";" << endl; + } + + f_service_ << + indent() << "args.Write(oprot_);" << endl << + indent() << "oprot_.WriteMessageEnd();" << endl << + indent() << "oprot_.Transport.Flush();" << endl; + + scope_down(f_service_); + f_service_ << endl; + + if (!(*f_iter)->is_oneway()) { + string resultname = (*f_iter)->get_name() + "_result"; + + t_struct noargs(program_); + t_function recv_function((*f_iter)->get_returntype(), + string("recv_") + (*f_iter)->get_name(), + &noargs, + (*f_iter)->get_xceptions()); + indent(f_service_) << + "public " << function_signature(&recv_function) << endl; + scope_up(f_service_); + + f_service_ << + indent() << "TMessage msg = iprot_.ReadMessageBegin();" << endl << + indent() << "if (msg.Type == TMessageType.Exception) {" << endl; + indent_up(); + f_service_ << + indent() << "TApplicationException x = TApplicationException.Read(iprot_);" << endl << + indent() << "iprot_.ReadMessageEnd();" << endl << + indent() << "throw x;" << endl; + indent_down(); + f_service_ << + indent() << "}" << endl << + indent() << resultname << " result = new " << resultname << "();" << endl << + indent() << "result.Read(iprot_);" << endl << + indent() << "iprot_.ReadMessageEnd();" << endl; + + if (!(*f_iter)->get_returntype()->is_void()) { + f_service_ << + indent() << "if (result.__isset.success) {" << endl << + indent() << " return result.Success;" << endl << + indent() << "}" << endl; + } + + t_struct *xs = (*f_iter)->get_xceptions(); + + const std::vector& xceptions = xs->get_members(); + vector::const_iterator x_iter; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + f_service_ << + indent() << "if (result.__isset." << (*x_iter)->get_name() << ") {" << endl << + indent() << " throw result." << prop_name(*x_iter) << ";" << endl << + indent() << "}" << endl; + } + + if ((*f_iter)->get_returntype()->is_void()) { + indent(f_service_) << + "return;" << endl; + } else { + f_service_ << + indent() << "throw new TApplicationException(TApplicationException.ExceptionType.MissingResult, \"" << (*f_iter)->get_name() << " failed: unknown result\");" << endl; + } + + scope_down(f_service_); + f_service_ << endl; + } + } + + indent_down(); + indent(f_service_) << + "}" << endl; +} + +void t_csharp_generator::generate_service_server(t_service* tservice) { + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + string extends = ""; + string extends_processor = ""; + if (tservice->get_extends() != NULL) { + extends = type_name(tservice->get_extends()); + extends_processor = extends + ".Processor, "; + } + + indent(f_service_) << + "public class Processor : " << extends_processor << "TProcessor {" << endl; + indent_up(); + + indent(f_service_) << + "public Processor(Iface iface)" ; + if (!extends.empty()) { + f_service_ << " : base(iface)"; + } + f_service_ << endl; + scope_up(f_service_); + f_service_ << + indent() << "iface_ = iface;" << endl; + + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + f_service_ << + indent() << "processMap_[\"" << (*f_iter)->get_name() << "\"] = " << (*f_iter)->get_name() << "_Process;" << endl; + } + + scope_down(f_service_); + f_service_ << endl; + + if (extends.empty()) { + f_service_ << + indent() << "protected delegate void ProcessFunction(int seqid, TProtocol iprot, TProtocol oprot);" << endl; + } + + f_service_ << + indent() << "private Iface iface_;" << endl; + + if (extends.empty()) { + f_service_ << + indent() << "protected Dictionary processMap_ = new Dictionary();" << endl; + } + + f_service_ << endl; + + if (extends.empty()) { + indent(f_service_) << + "public bool Process(TProtocol iprot, TProtocol oprot)" << endl; + } + else + { + indent(f_service_) << + "public new bool Process(TProtocol iprot, TProtocol oprot)" << endl; + } + scope_up(f_service_); + + f_service_ << indent() << "try" << endl; + scope_up(f_service_); + + f_service_ << + indent() << "TMessage msg = iprot.ReadMessageBegin();" << endl; + + f_service_ << + indent() << "ProcessFunction fn;" << endl << + indent() << "processMap_.TryGetValue(msg.Name, out fn);" << endl << + indent() << "if (fn == null) {" << endl << + indent() << " TProtocolUtil.Skip(iprot, TType.Struct);" << endl << + indent() << " iprot.ReadMessageEnd();" << endl << + indent() << " TApplicationException x = new TApplicationException (TApplicationException.ExceptionType.UnknownMethod, \"Invalid method name: '\" + msg.Name + \"'\");" << endl << + indent() << " oprot.WriteMessageBegin(new TMessage(msg.Name, TMessageType.Exception, msg.SeqID));" << endl << + indent() << " x.Write(oprot);" << endl << + indent() << " oprot.WriteMessageEnd();" << endl << + indent() << " oprot.Transport.Flush();" << endl << + indent() << " return true;" << endl << + indent() << "}" << endl << + indent() << "fn(msg.SeqID, iprot, oprot);" << endl; + + scope_down(f_service_); + + f_service_ << + indent() << "catch (IOException)" << endl; + scope_up(f_service_); + f_service_ << + indent() << "return false;" << endl; + scope_down(f_service_); + + f_service_ << + indent() << "return true;" << endl; + + scope_down(f_service_); + f_service_ << endl; + + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) + { + generate_process_function(tservice, *f_iter); + } + + indent_down(); + indent(f_service_) << + "}" << endl << endl; +} + +void t_csharp_generator::generate_function_helpers(t_function* tfunction) { + if (tfunction->is_oneway()) { + return; + } + + t_struct result(program_, tfunction->get_name() + "_result"); + t_field success(tfunction->get_returntype(), "success", 0); + if (!tfunction->get_returntype()->is_void()) { + result.append(&success); + } + + t_struct *xs = tfunction->get_xceptions(); + const vector& fields = xs->get_members(); + vector::const_iterator f_iter; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + result.append(*f_iter); + } + + generate_csharp_struct_definition(f_service_, &result, false, true, true); +} + +void t_csharp_generator::generate_process_function(t_service* tservice, t_function* tfunction) { + indent(f_service_) << + "public void " << tfunction->get_name() << "_Process(int seqid, TProtocol iprot, TProtocol oprot)" << endl; + scope_up(f_service_); + + string argsname = tfunction->get_name() + "_args"; + string resultname = tfunction->get_name() + "_result"; + + f_service_ << + indent() << argsname << " args = new " << argsname << "();" << endl << + indent() << "args.Read(iprot);" << endl << + indent() << "iprot.ReadMessageEnd();" << endl; + + t_struct* xs = tfunction->get_xceptions(); + const std::vector& xceptions = xs->get_members(); + vector::const_iterator x_iter; + + if (!tfunction->is_oneway()) { + f_service_ << + indent() << resultname << " result = new " << resultname << "();" << endl; + } + + if (xceptions.size() > 0) { + f_service_ << + indent() << "try {" << endl; + indent_up(); + } + + t_struct* arg_struct = tfunction->get_arglist(); + const std::vector& fields = arg_struct->get_members(); + vector::const_iterator f_iter; + + f_service_ << indent(); + if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) { + f_service_ << "result.Success = "; + } + f_service_ << + "iface_." << tfunction->get_name() << "("; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + } else { + f_service_ << ", "; + } + f_service_ << "args." << prop_name(*f_iter); + } + f_service_ << ");" << endl; + + if (!tfunction->is_oneway() && xceptions.size() > 0) { + indent_down(); + f_service_ << indent() << "}"; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + f_service_ << " catch (" << type_name((*x_iter)->get_type(), false, false) << " " << (*x_iter)->get_name() << ") {" << endl; + if (!tfunction->is_oneway()) { + indent_up(); + f_service_ << + indent() << "result." << prop_name(*x_iter) << " = " << (*x_iter)->get_name() << ";" << endl; + indent_down(); + f_service_ << indent() << "}"; + } else { + f_service_ << "}"; + } + } + f_service_ << endl; + } + + if (tfunction->is_oneway()) { + f_service_ << + indent() << "return;" << endl; + scope_down(f_service_); + + return; + } + + f_service_ << + indent() << "oprot.WriteMessageBegin(new TMessage(\"" << tfunction->get_name() << "\", TMessageType.Reply, seqid)); " << endl << + indent() << "result.Write(oprot);" << endl << + indent() << "oprot.WriteMessageEnd();" << endl << + indent() << "oprot.Transport.Flush();" << endl; + + scope_down(f_service_); + + f_service_ << endl; +} + +void t_csharp_generator::generate_deserialize_field(ofstream& out, t_field* tfield, string prefix) { + t_type* type = tfield->get_type(); + while(type->is_typedef()) { + type = ((t_typedef*)type)->get_type(); + } + + if (type->is_void()) { + throw "CANNOT GENERATE DESERIALIZE CODE FOR void TYPE: " + prefix + tfield->get_name(); + } + + string name = prefix + tfield->get_name(); + + if (type->is_struct() || type->is_xception()) { + generate_deserialize_struct(out, (t_struct*)type, name); + } else if (type->is_container()) { + generate_deserialize_container(out, type, name); + } else if (type->is_base_type() || type->is_enum()) { + indent(out) << + name << " = "; + + if (type->is_enum()) + { + out << "(" << type_name(type, false, true) << ")"; + } + + out << "iprot."; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "compiler error: cannot serialize void field in a struct: " + name; + break; + case t_base_type::TYPE_STRING: + if (((t_base_type*)type)->is_binary()) { + out << "ReadBinary();"; + } else { + out << "ReadString();"; + } + break; + case t_base_type::TYPE_BOOL: + out << "ReadBool();"; + break; + case t_base_type::TYPE_BYTE: + out << "ReadByte();"; + break; + case t_base_type::TYPE_I16: + out << "ReadI16();"; + break; + case t_base_type::TYPE_I32: + out << "ReadI32();"; + break; + case t_base_type::TYPE_I64: + out << "ReadI64();"; + break; + case t_base_type::TYPE_DOUBLE: + out << "ReadDouble();"; + break; + default: + throw "compiler error: no C# name for base type " + tbase; + } + } else if (type->is_enum()) { + out << "ReadI32();"; + } + out << endl; + } else { + printf("DO NOT KNOW HOW TO DESERIALIZE FIELD '%s' TYPE '%s'\n", tfield->get_name().c_str(), type_name(type).c_str()); + } +} + +void t_csharp_generator::generate_deserialize_struct(ofstream& out, t_struct* tstruct, string prefix) { + out << + indent() << prefix << " = new " << type_name(tstruct) << "();" << endl << + indent() << prefix << ".Read(iprot);" << endl; +} + +void t_csharp_generator::generate_deserialize_container(ofstream& out, t_type* ttype, string prefix) { + scope_up(out); + + string obj; + + if (ttype->is_map()) { + obj = tmp("_map"); + } else if (ttype->is_set()) { + obj = tmp("_set"); + } else if (ttype->is_list()) { + obj = tmp("_list"); + } + + indent(out) << + prefix << " = new " << type_name(ttype, false, true) << "();" <is_map()) { + out << + indent() << "TMap " << obj << " = iprot.ReadMapBegin();" << endl; + } else if (ttype->is_set()) { + out << + indent() << "TSet " << obj << " = iprot.ReadSetBegin();" << endl; + } else if (ttype->is_list()) { + out << + indent() << "TList " << obj << " = iprot.ReadListBegin();" << endl; + } + + string i = tmp("_i"); + indent(out) << + "for( int " << i << " = 0; " << i << " < " << obj << ".Count" << "; " << "++" << i << ")" << endl; + scope_up(out); + + if (ttype->is_map()) { + generate_deserialize_map_element(out, (t_map*)ttype, prefix); + } else if (ttype->is_set()) { + generate_deserialize_set_element(out, (t_set*)ttype, prefix); + } else if (ttype->is_list()) { + generate_deserialize_list_element(out, (t_list*)ttype, prefix); + } + + scope_down(out); + + if (ttype->is_map()) { + indent(out) << "iprot.ReadMapEnd();" << endl; + } else if (ttype->is_set()) { + indent(out) << "iprot.ReadSetEnd();" << endl; + } else if (ttype->is_list()) { + indent(out) << "iprot.ReadListEnd();" << endl; + } + + scope_down(out); +} + +void t_csharp_generator::generate_deserialize_map_element(ofstream& out, t_map* tmap, string prefix) { + string key = tmp("_key"); + string val = tmp("_val"); + + t_field fkey(tmap->get_key_type(), key); + t_field fval(tmap->get_val_type(), val); + + indent(out) << + declare_field(&fkey) << endl; + indent(out) << + declare_field(&fval) << endl; + + generate_deserialize_field(out, &fkey); + generate_deserialize_field(out, &fval); + + indent(out) << + prefix << "[" << key << "] = " << val << ";" << endl; +} + +void t_csharp_generator::generate_deserialize_set_element(ofstream& out, t_set* tset, string prefix) { + string elem = tmp("_elem"); + t_field felem(tset->get_elem_type(), elem); + + indent(out) << + declare_field(&felem, true) << endl; + + generate_deserialize_field(out, &felem); + + indent(out) << + prefix << ".Add(" << elem << ");" << endl; +} + +void t_csharp_generator::generate_deserialize_list_element(ofstream& out, t_list* tlist, string prefix) { + string elem = tmp("_elem"); + t_field felem(tlist->get_elem_type(), elem); + + indent(out) << + declare_field(&felem, true) << endl; + + generate_deserialize_field(out, &felem); + + indent(out) << + prefix << ".Add(" << elem << ");" << endl; +} + +void t_csharp_generator::generate_serialize_field(ofstream& out, t_field* tfield, string prefix) { + t_type* type = tfield->get_type(); + while (type->is_typedef()) { + type = ((t_typedef*)type)->get_type(); + } + + if (type->is_void()) { + throw "CANNOT GENERATE SERIALIZE CODE FOR void TYPE: " + prefix + tfield->get_name(); + } + + if (type->is_struct() || type->is_xception()) { + generate_serialize_struct(out, (t_struct*)type, prefix + tfield->get_name()); + } else if (type->is_container()) { + generate_serialize_container(out, type, prefix + tfield->get_name()); + } else if (type->is_base_type() || type->is_enum()) { + string name = prefix + tfield->get_name(); + indent(out) << + "oprot."; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + + switch(tbase) { + case t_base_type::TYPE_VOID: + throw "compiler error: cannot serialize void field in a struct: " + name; + break; + case t_base_type::TYPE_STRING: + if (((t_base_type*)type)->is_binary()) { + out << "WriteBinary("; + } else { + out << "WriteString("; + } + out << name << ");"; + break; + case t_base_type::TYPE_BOOL: + out << "WriteBool(" << name << ");"; + break; + case t_base_type::TYPE_BYTE: + out << "WriteByte(" << name << ");"; + break; + case t_base_type::TYPE_I16: + out << "WriteI16(" << name << ");"; + break; + case t_base_type::TYPE_I32: + out << "WriteI32(" << name << ");"; + break; + case t_base_type::TYPE_I64: + out << "WriteI64(" << name << ");"; + break; + case t_base_type::TYPE_DOUBLE: + out << "WriteDouble(" << name << ");"; + break; + default: + throw "compiler error: no C# name for base type " + tbase; + } + } else if (type->is_enum()) { + out << "WriteI32((int)" << name << ");"; + } + out << endl; + } else { + printf("DO NOT KNOW HOW TO SERIALIZE '%s%s' TYPE '%s'\n", + prefix.c_str(), + tfield->get_name().c_str(), + type_name(type).c_str()); + } +} + +void t_csharp_generator::generate_serialize_struct(ofstream& out, t_struct* tstruct, string prefix) { + out << + indent() << prefix << ".Write(oprot);" << endl; +} + +void t_csharp_generator::generate_serialize_container(ofstream& out, t_type* ttype, string prefix) { + scope_up(out); + + if (ttype->is_map()) { + indent(out) << + "oprot.WriteMapBegin(new TMap(" << + type_to_enum(((t_map*)ttype)->get_key_type()) << ", " << + type_to_enum(((t_map*)ttype)->get_val_type()) << ", " << + prefix << ".Count));" << endl; + } else if (ttype->is_set()) { + indent(out) << + "oprot.WriteSetBegin(new TSet(" << + type_to_enum(((t_set*)ttype)->get_elem_type()) << ", " << + prefix << ".Count));" << endl; + } else if (ttype->is_list()) { + indent(out) << + "oprot.WriteListBegin(new TList(" << + type_to_enum(((t_list*)ttype)->get_elem_type()) << ", " << + prefix << ".Count));" << endl; + } + + string iter = tmp("_iter"); + if (ttype->is_map()) { + indent(out) << + "foreach (" << + type_name(((t_map*)ttype)->get_key_type()) << " " << iter << + " in " << + prefix << ".Keys)"; + } else if (ttype->is_set()) { + indent(out) << + "foreach (" << + type_name(((t_set*)ttype)->get_elem_type()) << " " << iter << + " in " << + prefix << ")"; + } else if (ttype->is_list()) { + indent(out) << + "foreach (" << + type_name(((t_list*)ttype)->get_elem_type()) << " " << iter << + " in " << + prefix << ")"; + } + + out << endl; + scope_up(out); + + if (ttype->is_map()) { + generate_serialize_map_element(out, (t_map*)ttype, iter, prefix); + } else if (ttype->is_set()) { + generate_serialize_set_element(out, (t_set*)ttype, iter); + } else if (ttype->is_list()) { + generate_serialize_list_element(out, (t_list*)ttype, iter); + } + + if (ttype->is_map()) { + indent(out) << "oprot.WriteMapEnd();" << endl; + } else if (ttype->is_set()) { + indent(out) << "oprot.WriteSetEnd();" << endl; + } else if (ttype->is_list()) { + indent(out) << "oprot.WriteListEnd();" << endl; + } + + scope_down(out); + scope_down(out); +} + +void t_csharp_generator::generate_serialize_map_element(ofstream& out, t_map* tmap, string iter, string map) { + t_field kfield(tmap->get_key_type(), iter); + generate_serialize_field(out, &kfield, ""); + t_field vfield(tmap->get_val_type(), map + "[" + iter + "]"); + generate_serialize_field(out, &vfield, ""); +} + +void t_csharp_generator::generate_serialize_set_element(ofstream& out, t_set* tset, string iter) { + t_field efield(tset->get_elem_type(), iter); + generate_serialize_field(out, &efield, ""); +} + +void t_csharp_generator::generate_serialize_list_element(ofstream& out, t_list* tlist, string iter) { + t_field efield(tlist->get_elem_type(), iter); + generate_serialize_field(out, &efield, ""); +} + +void t_csharp_generator::generate_property(ofstream& out, t_field* tfield, bool isPublic) { + indent(out) << (isPublic ? "public " : "private ") << type_name(tfield->get_type()) + << " " << prop_name(tfield) << endl; + scope_up(out); + indent(out) << "get" << endl; + scope_up(out); + indent(out) << "return " << tfield->get_name() << ";" << endl; + scope_down(out); + indent(out) << "set" << endl; + scope_up(out); + indent(out) << "__isset." << tfield->get_name() << " = true;" << endl; + indent(out) << "this." << tfield->get_name() << " = value;" << endl; + scope_down(out); + scope_down(out); + out << endl; +} + +std::string t_csharp_generator::prop_name(t_field* tfield) { + string name (tfield->get_name()); + name[0] = toupper(name[0]); + return name; +} + +string t_csharp_generator::type_name(t_type* ttype, bool in_container, bool in_init) { + while (ttype->is_typedef()) { + ttype = ((t_typedef*)ttype)->get_type(); + } + + if (ttype->is_base_type()) { + return base_type_name((t_base_type*)ttype, in_container); + } else if (ttype->is_map()) { + t_map *tmap = (t_map*) ttype; + return "Dictionary<" + type_name(tmap->get_key_type(), true) + + ", " + type_name(tmap->get_val_type(), true) + ">"; + } else if (ttype->is_set()) { + t_set* tset = (t_set*) ttype; + return "THashSet<" + type_name(tset->get_elem_type(), true) + ">"; + } else if (ttype->is_list()) { + t_list* tlist = (t_list*) ttype; + return "List<" + type_name(tlist->get_elem_type(), true) + ">"; + } + + t_program* program = ttype->get_program(); + if (program != NULL && program != program_) { + string ns = program->get_namespace("csharp"); + if (!ns.empty()) { + return ns + "." + ttype->get_name(); + } + } + + return ttype->get_name(); +} + +string t_csharp_generator::base_type_name(t_base_type* tbase, bool in_container) { + switch (tbase->get_base()) { + case t_base_type::TYPE_VOID: + return "void"; + case t_base_type::TYPE_STRING: + if (tbase->is_binary()) { + return "byte[]"; + } else { + return "string"; + } + case t_base_type::TYPE_BOOL: + return "bool"; + case t_base_type::TYPE_BYTE: + return "byte"; + case t_base_type::TYPE_I16: + return "short"; + case t_base_type::TYPE_I32: + return "int"; + case t_base_type::TYPE_I64: + return "long"; + case t_base_type::TYPE_DOUBLE: + return "double"; + default: + throw "compiler error: no C# name for base type " + tbase->get_base(); + } +} + +string t_csharp_generator::declare_field(t_field* tfield, bool init) { + string result = type_name(tfield->get_type()) + " " + tfield->get_name(); + if (init) { + t_type* ttype = tfield->get_type(); + while (ttype->is_typedef()) { + ttype = ((t_typedef*)ttype)->get_type(); + } + if (ttype->is_base_type() && tfield->get_value() != NULL) { + ofstream dummy; + result += " = " + render_const_value(dummy, tfield->get_name(), ttype, tfield->get_value()); + } else if (ttype->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)ttype)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "NO T_VOID CONSTRUCT"; + case t_base_type::TYPE_STRING: + result += " = null"; + break; + case t_base_type::TYPE_BOOL: + result += " = false"; + break; + case t_base_type::TYPE_BYTE: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + case t_base_type::TYPE_I64: + result += " = 0"; + break; + case t_base_type::TYPE_DOUBLE: + result += " = (double)0"; + break; + } + } else if (ttype->is_enum()) { + result += " = (" + type_name(ttype, false, true) + ")0"; + } else if (ttype->is_container()) { + result += " = new " + type_name(ttype, false, true) + "()"; + } else { + result += " = new " + type_name(ttype, false, true) + "()"; + } + } + return result + ";"; +} + +string t_csharp_generator::function_signature(t_function* tfunction, string prefix) { + t_type* ttype = tfunction->get_returntype(); + return type_name(ttype) + " " + prefix + tfunction->get_name() + "(" + argument_list(tfunction->get_arglist()) + ")"; +} + +string t_csharp_generator::argument_list(t_struct* tstruct) { + string result = ""; + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + } else { + result += ", "; + } + result += type_name((*f_iter)->get_type()) + " " + (*f_iter)->get_name(); + } + return result; +} + +string t_csharp_generator::type_to_enum(t_type* type) { + while (type->is_typedef()) { + type = ((t_typedef*)type)->get_type(); + } + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "NO T_VOID CONSTRUCT"; + case t_base_type::TYPE_STRING: + return "TType.String"; + case t_base_type::TYPE_BOOL: + return "TType.Bool"; + case t_base_type::TYPE_BYTE: + return "TType.Byte"; + case t_base_type::TYPE_I16: + return "TType.I16"; + case t_base_type::TYPE_I32: + return "TType.I32"; + case t_base_type::TYPE_I64: + return "TType.I64"; + case t_base_type::TYPE_DOUBLE: + return "TType.Double"; + } + } else if (type->is_enum()) { + return "TType.I32"; + } else if (type->is_struct() || type->is_xception()) { + return "TType.Struct"; + } else if (type->is_map()) { + return "TType.Map"; + } else if (type->is_set()) { + return "TType.Set"; + } else if (type->is_list()) { + return "TType.List"; + } + + throw "INVALID TYPE IN type_to_enum: " + type->get_name(); +} + + +THRIFT_REGISTER_GENERATOR(csharp, "C#", ""); diff --git a/compiler/cpp/src/generate/t_erl_generator.cc b/compiler/cpp/src/generate/t_erl_generator.cc new file mode 100644 index 00000000..0aff4f39 --- /dev/null +++ b/compiler/cpp/src/generate/t_erl_generator.cc @@ -0,0 +1,932 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include "t_generator.h" +#include "platform.h" + +using namespace std; + + +/** + * Erlang code generator. + * + */ +class t_erl_generator : public t_generator { + public: + t_erl_generator( + t_program* program, + const std::map& parsed_options, + const std::string& option_string) + : t_generator(program) + { + program_name_[0] = tolower(program_name_[0]); + service_name_[0] = tolower(service_name_[0]); + out_dir_base_ = "gen-erl"; + } + + /** + * Init and close methods + */ + + void init_generator(); + void close_generator(); + + /** + * Program-level generation functions + */ + + void generate_typedef (t_typedef* ttypedef); + void generate_enum (t_enum* tenum); + void generate_const (t_const* tconst); + void generate_struct (t_struct* tstruct); + void generate_xception (t_struct* txception); + void generate_service (t_service* tservice); + + std::string render_const_value(t_type* type, t_const_value* value); + + /** + * Struct generation code + */ + + void generate_erl_struct(t_struct* tstruct, bool is_exception); + void generate_erl_struct_definition(std::ostream& out, std::ostream& hrl_out, t_struct* tstruct, bool is_xception=false, bool is_result=false); + void generate_erl_struct_info(std::ostream& out, t_struct* tstruct); + void generate_erl_function_helpers(t_function* tfunction); + + /** + * Service-level generation functions + */ + + void generate_service_helpers (t_service* tservice); + void generate_service_interface (t_service* tservice); + void generate_function_info (t_service* tservice, t_function* tfunction); + + /** + * Helper rendering functions + */ + + std::string erl_autogen_comment(); + std::string erl_imports(); + std::string render_includes(); + std::string declare_field(t_field* tfield); + std::string type_name(t_type* ttype); + + std::string function_signature(t_function* tfunction, std::string prefix=""); + + + std::string argument_list(t_struct* tstruct); + std::string type_to_enum(t_type* ttype); + std::string generate_type_term(t_type* ttype, bool expand_structs); + std::string type_module(t_type* ttype); + + std::string capitalize(std::string in) { + in[0] = toupper(in[0]); + return in; + } + + std::string uncapitalize(std::string in) { + in[0] = tolower(in[0]); + return in; + } + + private: + + /** + * add function to export list + */ + + void export_function(t_function* tfunction, std::string prefix=""); + void export_string(std::string name, int num); + + void export_types_function(t_function* tfunction, std::string prefix=""); + void export_types_string(std::string name, int num); + + /** + * write out headers and footers for hrl files + */ + + void hrl_header(std::ostream& out, std::string name); + void hrl_footer(std::ostream& out, std::string name); + + /** + * stuff to spit out at the top of generated files + */ + + bool export_lines_first_; + std::ostringstream export_lines_; + + bool export_types_lines_first_; + std::ostringstream export_types_lines_; + + /** + * File streams + */ + + std::ostringstream f_types_; + std::ofstream f_types_file_; + std::ofstream f_types_hrl_file_; + + std::ofstream f_consts_; + std::ostringstream f_service_; + std::ofstream f_service_file_; + std::ofstream f_service_hrl_; + +}; + + +/** + * UI for file generation by opening up the necessary file output + * streams. + * + * @param tprogram The program to generate + */ +void t_erl_generator::init_generator() { + // Make output directory + MKDIR(get_out_dir().c_str()); + + // setup export lines + export_lines_first_ = true; + export_types_lines_first_ = true; + + // types files + string f_types_name = get_out_dir()+program_name_+"_types.erl"; + string f_types_hrl_name = get_out_dir()+program_name_+"_types.hrl"; + + f_types_file_.open(f_types_name.c_str()); + f_types_hrl_file_.open(f_types_hrl_name.c_str()); + + hrl_header(f_types_hrl_file_, program_name_ + "_types"); + + f_types_file_ << + erl_autogen_comment() << endl << + "-module(" << program_name_ << "_types)." << endl << + erl_imports() << endl; + + f_types_file_ << + "-include(\"" << program_name_ << "_types.hrl\")." << endl << + endl; + + f_types_hrl_file_ << render_includes() << endl; + + // consts file + string f_consts_name = get_out_dir()+program_name_+"_constants.hrl"; + f_consts_.open(f_consts_name.c_str()); + + f_consts_ << + erl_autogen_comment() << endl << + erl_imports() << endl << + "-include(\"" << program_name_ << "_types.hrl\")." << endl << + endl; +} + +/** + * Boilerplate at beginning and end of header files + */ +void t_erl_generator::hrl_header(ostream& out, string name) { + out << "-ifndef(_" << name << "_included)." << endl << + "-define(_" << name << "_included, yeah)." << endl; +} + +void t_erl_generator::hrl_footer(ostream& out, string name) { + out << "-endif." << endl; +} + +/** + * Renders all the imports necessary for including another Thrift program + */ +string t_erl_generator::render_includes() { + const vector& includes = program_->get_includes(); + string result = ""; + for (size_t i = 0; i < includes.size(); ++i) { + result += "-include(\"" + includes[i]->get_name() + "_types.hrl\").\n"; + } + if (includes.size() > 0) { + result += "\n"; + } + return result; +} + +/** + * Autogen'd comment + */ +string t_erl_generator::erl_autogen_comment() { + return + std::string("%%\n") + + "%% Autogenerated by Thrift\n" + + "%%\n" + + "%% DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING\n" + + "%%\n"; +} + +/** + * Prints standard thrift imports + */ +string t_erl_generator::erl_imports() { + return ""; +} + +/** + * Closes the type files + */ +void t_erl_generator::close_generator() { + // Close types file + export_types_string("struct_info", 1); + + f_types_file_ << "-export([" << export_types_lines_.str() << "])." << endl; + f_types_file_ << f_types_.str(); + f_types_file_ << "struct_info('i am a dummy struct') -> undefined." << endl; + + hrl_footer(f_types_hrl_file_, string("BOGUS")); + + f_types_file_.close(); + f_types_hrl_file_.close(); + f_consts_.close(); +} + +/** + * Generates a typedef. no op + * + * @param ttypedef The type definition + */ +void t_erl_generator::generate_typedef(t_typedef* ttypedef) { +} + +/** + * Generates code for an enumerated type. Done using a class to scope + * the values. + * + * @param tenum The enumeration + */ +void t_erl_generator::generate_enum(t_enum* tenum) { + vector constants = tenum->get_constants(); + vector::iterator c_iter; + + int value = -1; + + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + if ((*c_iter)->has_value()) { + value = (*c_iter)->get_value(); + } else { + ++value; + } + + string name = capitalize((*c_iter)->get_name()); + + f_types_hrl_file_ << + indent() << "-define(" << program_name_ << "_" << name << ", " << value << ")."<< endl; + } + + f_types_hrl_file_ << endl; +} + +/** + * Generate a constant value + */ +void t_erl_generator::generate_const(t_const* tconst) { + t_type* type = tconst->get_type(); + string name = capitalize(tconst->get_name()); + t_const_value* value = tconst->get_value(); + + f_consts_ << "-define(" << program_name_ << "_" << name << ", " << render_const_value(type, value) << ")." << endl << endl; +} + +/** + * Prints the value of a constant with the given type. Note that type checking + * is NOT performed in this function as it is always run beforehand using the + * validate_types method in main.cc + */ +string t_erl_generator::render_const_value(t_type* type, t_const_value* value) { + type = get_true_type(type); + std::ostringstream out; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_STRING: + out << '"' << get_escaped_string(value) << '"'; + break; + case t_base_type::TYPE_BOOL: + out << (value->get_integer() > 0 ? "true" : "false"); + break; + case t_base_type::TYPE_BYTE: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + case t_base_type::TYPE_I64: + out << value->get_integer(); + break; + case t_base_type::TYPE_DOUBLE: + if (value->get_type() == t_const_value::CV_INTEGER) { + out << value->get_integer(); + } else { + out << value->get_double(); + } + break; + default: + throw "compiler error: no const of base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + indent(out) << value->get_integer(); + + } else if (type->is_struct() || type->is_xception()) { + out << "#" << type->get_name() << "{"; + const vector& fields = ((t_struct*)type)->get_members(); + vector::const_iterator f_iter; + const map& val = value->get_map(); + map::const_iterator v_iter; + + bool first = true; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + t_type* field_type = NULL; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if ((*f_iter)->get_name() == v_iter->first->get_string()) { + field_type = (*f_iter)->get_type(); + } + } + if (field_type == NULL) { + throw "type error: " + type->get_name() + " has no field " + v_iter->first->get_string(); + } + + if (first) { + first = false; + } else { + out << ","; + } + out << v_iter->first->get_string(); + out << " = "; + out << render_const_value(field_type, v_iter->second); + } + indent_down(); + indent(out) << "}"; + + } else if (type->is_map()) { + t_type* ktype = ((t_map*)type)->get_key_type(); + t_type* vtype = ((t_map*)type)->get_val_type(); + const map& val = value->get_map(); + map::const_iterator v_iter; + + bool first = true; + out << "dict:from_list(["; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + if (first) { + first=false; + } else { + out << ","; + } + out << "(" + << render_const_value(ktype, v_iter->first) << "," + << render_const_value(vtype, v_iter->second) << ")"; + } + out << "])"; + + } else if (type->is_set()) { + t_type* etype; + etype = ((t_set*)type)->get_elem_type(); + + bool first = true; + const vector& val = value->get_list(); + vector::const_iterator v_iter; + out << "sets:from_list(["; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + if (first) { + first=false; + } else { + out << ","; + } + out << "(" << render_const_value(etype, *v_iter) << ",true)"; + } + out << "])"; + + } else if (type->is_list()) { + t_type* etype; + etype = ((t_list*)type)->get_elem_type(); + out << "["; + + bool first = true; + const vector& val = value->get_list(); + vector::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + if (first) { + first=false; + } else { + out << ","; + } + out << render_const_value(etype, *v_iter); + } + out << "]"; + } else { + throw "CANNOT GENERATE CONSTANT FOR TYPE: " + type->get_name(); + } + return out.str(); +} + +/** + * Generates a struct + */ +void t_erl_generator::generate_struct(t_struct* tstruct) { + generate_erl_struct(tstruct, false); +} + +/** + * Generates a struct definition for a thrift exception. Basically the same + * as a struct but extends the Exception class. + * + * @param txception The struct definition + */ +void t_erl_generator::generate_xception(t_struct* txception) { + generate_erl_struct(txception, true); +} + +/** + * Generates a struct + */ +void t_erl_generator::generate_erl_struct(t_struct* tstruct, + bool is_exception) { + generate_erl_struct_definition(f_types_, f_types_hrl_file_, tstruct, is_exception); +} + +/** + * Generates a struct definition for a thrift data type. + * + * @param tstruct The struct definition + */ +void t_erl_generator::generate_erl_struct_definition(ostream& out, + ostream& hrl_out, + t_struct* tstruct, + bool is_exception, + bool is_result) +{ + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + + indent(out) << "%% struct " << type_name(tstruct) << endl; + + if (is_exception) { + } + + out << endl; + + if (members.size() > 0) { + indent(out) << "% -record(" << type_name(tstruct) << ", {"; + indent(hrl_out) << "-record(" << type_name(tstruct) << ", {"; + + bool first = true; + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + if (first) { + first = false; + } else { + out << ", "; + hrl_out << ", "; + } + std::string name = uncapitalize((*m_iter)->get_name()); + out << name; + hrl_out << name; + } + out << "})." << endl; + hrl_out << "})." << endl; + } else { // no members; explicit comment + indent(out) << "% -record(" << type_name(tstruct) << ", {})." << endl; + indent(hrl_out) << "-record(" << type_name(tstruct) << ", {})." << endl; + } + + out << endl; + hrl_out << endl; + + + generate_erl_struct_info(out, tstruct); +} + +/** + * Generates the read method for a struct + */ +void t_erl_generator::generate_erl_struct_info(ostream& out, + t_struct* tstruct) { + string name = type_name(tstruct); + + indent(out) << "struct_info('" << name << "') ->" << endl; + indent_up(); + + out << indent() << generate_type_term(tstruct, true) << ";" << endl; + + indent_down(); + out << endl; +} + + +/** + * Generates a thrift service. + * + * @param tservice The service definition + */ +void t_erl_generator::generate_service(t_service* tservice) { + // somehow this point is reached before the constructor and it's not downcased yet + // ...awesome + service_name_[0] = tolower(service_name_[0]); + + string f_service_hrl_name = get_out_dir()+service_name_+"_thrift.hrl"; + string f_service_name = get_out_dir()+service_name_+"_thrift.erl"; + f_service_file_.open(f_service_name.c_str()); + f_service_hrl_.open(f_service_hrl_name.c_str()); + + // Reset service text aggregating stream streams + f_service_.str(""); + export_lines_.str(""); + export_lines_first_ = true; + + hrl_header(f_service_hrl_, service_name_); + + if (tservice->get_extends() != NULL) { + f_service_hrl_ << "-include(\"" << + uncapitalize(tservice->get_extends()->get_name()) << "_thrift.hrl\"). % inherit " << endl; + } + + f_service_hrl_ << + "-include(\"" << program_name_ << "_types.hrl\")." << endl << + endl; + + // Generate the three main parts of the service (well, two for now in PHP) + generate_service_helpers(tservice); // cpiro: New Erlang Order + + generate_service_interface(tservice); + + // indent_down(); + + f_service_file_ << + erl_autogen_comment() << endl << + "-module(" << service_name_ << "_thrift)." << endl << + "-behaviour(thrift_service)." << endl << endl << + erl_imports() << endl; + + f_service_file_ << "-include(\"" << uncapitalize(tservice->get_name()) << "_thrift.hrl\")." << endl << endl; + + f_service_file_ << "-export([" << export_lines_.str() << "])." << endl << endl; + + f_service_file_ << f_service_.str(); + + hrl_footer(f_service_hrl_, f_service_name); + + // Close service file + f_service_file_.close(); + f_service_hrl_.close(); +} + +/** + * Generates helper functions for a service. + * + * @param tservice The service to generate a header definition for + */ +void t_erl_generator::generate_service_helpers(t_service* tservice) { + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + // indent(f_service_) << + // "% HELPER FUNCTIONS AND STRUCTURES" << endl << endl; + + export_string("struct_info", 1); + + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + generate_erl_function_helpers(*f_iter); + } + f_service_ << "struct_info('i am a dummy struct') -> undefined." << endl; +} + +/** + * Generates a struct and helpers for a function. + * + * @param tfunction The function + */ +void t_erl_generator::generate_erl_function_helpers(t_function* tfunction) { +} + +/** + * Generates a service interface definition. + * + * @param tservice The service to generate a header definition for + */ +void t_erl_generator::generate_service_interface(t_service* tservice) { + + export_string("function_info", 2); + + vector functions = tservice->get_functions(); + vector::iterator f_iter; + f_service_ << "%%% interface" << endl; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + f_service_ << + indent() << "% " << function_signature(*f_iter) << endl; + + generate_function_info(tservice, *f_iter); + } + + // Inheritance - pass unknown functions to base class + if (tservice->get_extends() != NULL) { + indent(f_service_) << "function_info(Function, InfoType) ->" << endl; + indent_up(); + indent(f_service_) << uncapitalize(tservice->get_extends()->get_name()) + << "_thrift:function_info(Function, InfoType)." << endl; + indent_down(); + } else { + // Dummy function_info so we don't worry about the ;s + indent(f_service_) << "function_info(xxx, dummy) -> dummy." << endl; + } + + indent(f_service_) << endl; +} + + +/** + * Generates a function_info(FunctionName, params_type) and + * function_info(FunctionName, reply_type) + */ +void t_erl_generator::generate_function_info(t_service* tservice, + t_function* tfunction) { + + string name_atom = "'" + tfunction->get_name() + "'"; + + + + t_struct* xs = tfunction->get_xceptions(); + t_struct* arg_struct = tfunction->get_arglist(); + + // function_info(Function, params_type): + indent(f_service_) << + "function_info(" << name_atom << ", params_type) ->" << endl; + indent_up(); + + indent(f_service_) << generate_type_term(arg_struct, true) << ";" << endl; + + indent_down(); + + // function_info(Function, reply_type): + indent(f_service_) << + "function_info(" << name_atom << ", reply_type) ->" << endl; + indent_up(); + + if (!tfunction->get_returntype()->is_void()) + indent(f_service_) << + generate_type_term(tfunction->get_returntype(), false) << ";" << endl; + else if (tfunction->is_oneway()) + indent(f_service_) << "oneway_void;" << endl; + else + indent(f_service_) << "{struct, []}" << ";" << endl; + indent_down(); + + // function_info(Function, exceptions): + indent(f_service_) << + "function_info(" << name_atom << ", exceptions) ->" << endl; + indent_up(); + indent(f_service_) << generate_type_term(xs, true) << ";" << endl; + indent_down(); +} + + +/** + * Declares a field, which may include initialization as necessary. + * + * @param ttype The type + */ +string t_erl_generator::declare_field(t_field* tfield) { // TODO + string result = "@" + tfield->get_name(); + t_type* type = get_true_type(tfield->get_type()); + if (tfield->get_value() != NULL) { + result += " = " + render_const_value(type, tfield->get_value()); + } else { + result += " = nil"; + } + return result; +} + +/** + * Renders a function signature of the form 'type name(args)' + * + * @param tfunction Function definition + * @return String of rendered function definition + */ +string t_erl_generator::function_signature(t_function* tfunction, + string prefix) { + return + prefix + tfunction->get_name() + + "(This" + capitalize(argument_list(tfunction->get_arglist())) + ")"; +} + +/** + * Add a function to the exports list + */ +void t_erl_generator::export_string(string name, int num) { + if (export_lines_first_) { + export_lines_first_ = false; + } else { + export_lines_ << ", "; + } + export_lines_ << name << "/" << num; +} + +void t_erl_generator::export_types_function(t_function* tfunction, + string prefix) { + + export_types_string(prefix + tfunction->get_name(), + 1 // This + + ((tfunction->get_arglist())->get_members()).size() + ); +} + +void t_erl_generator::export_types_string(string name, int num) { + if (export_types_lines_first_) { + export_types_lines_first_ = false; + } else { + export_types_lines_ << ", "; + } + export_types_lines_ << name << "/" << num; +} + +void t_erl_generator::export_function(t_function* tfunction, + string prefix) { + + export_string(prefix + tfunction->get_name(), + 1 // This + + ((tfunction->get_arglist())->get_members()).size() + ); +} + + +/** + * Renders a field list + */ +string t_erl_generator::argument_list(t_struct* tstruct) { + string result = ""; + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + result += ", "; // initial comma to compensate for initial This + } else { + result += ", "; + } + result += capitalize((*f_iter)->get_name()); + } + return result; +} + +string t_erl_generator::type_name(t_type* ttype) { + string prefix = ""; + string name = ttype->get_name(); + + if (ttype->is_struct() || ttype->is_xception() || ttype->is_service()) { + name = uncapitalize(ttype->get_name()); + } + + return prefix + name; +} + +/** + * Converts the parse type to a Erlang "type" (macro for int constants) + */ +string t_erl_generator::type_to_enum(t_type* type) { + type = get_true_type(type); + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "NO T_VOID CONSTRUCT"; + case t_base_type::TYPE_STRING: + return "?tType_STRING"; + case t_base_type::TYPE_BOOL: + return "?tType_BOOL"; + case t_base_type::TYPE_BYTE: + return "?tType_BYTE"; + case t_base_type::TYPE_I16: + return "?tType_I16"; + case t_base_type::TYPE_I32: + return "?tType_I32"; + case t_base_type::TYPE_I64: + return "?tType_I64"; + case t_base_type::TYPE_DOUBLE: + return "?tType_DOUBLE"; + } + } else if (type->is_enum()) { + return "?tType_I32"; + } else if (type->is_struct() || type->is_xception()) { + return "?tType_STRUCT"; + } else if (type->is_map()) { + return "?tType_MAP"; + } else if (type->is_set()) { + return "?tType_SET"; + } else if (type->is_list()) { + return "?tType_LIST"; + } + + throw "INVALID TYPE IN type_to_enum: " + type->get_name(); +} + + +/** + * Generate an Erlang term which represents a thrift type + */ +std::string t_erl_generator::generate_type_term(t_type* type, + bool expand_structs) { + type = get_true_type(type); + + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "NO T_VOID CONSTRUCT"; + case t_base_type::TYPE_STRING: + return "string"; + case t_base_type::TYPE_BOOL: + return "bool"; + case t_base_type::TYPE_BYTE: + return "byte"; + case t_base_type::TYPE_I16: + return "i16"; + case t_base_type::TYPE_I32: + return "i32"; + case t_base_type::TYPE_I64: + return "i64"; + case t_base_type::TYPE_DOUBLE: + return "double"; + } + } else if (type->is_enum()) { + return "i32"; + } else if (type->is_struct() || type->is_xception()) { + if (expand_structs) { + // Convert to format: {struct, [{Fid, TypeTerm}, {Fid, TypeTerm}...]} + std::stringstream ret; + + + ret << "{struct, ["; + + int first = true; + const vector& fields = ((t_struct*)type)->get_members(); + vector::const_iterator f_iter; + + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + // Comma separate the tuples + if (!first) ret << "," << endl << indent(); + first = false; + + ret << "{" << (*f_iter)->get_key() << ", " << + generate_type_term((*f_iter)->get_type(), false) << "}"; + } + + ret << "]}" << endl; + + return ret.str(); + } else { + return "{struct, {'" + type_module(type) + "', '" + type_name(type) + "'}}"; + } + } else if (type->is_map()) { + // {map, KeyType, ValType} + t_type *key_type = ((t_map*)type)->get_key_type(); + t_type *val_type = ((t_map*)type)->get_val_type(); + + return "{map, " + generate_type_term(key_type, false) + ", " + + generate_type_term(val_type, false) + "}"; + + } else if (type->is_set()) { + t_type *elem_type = ((t_set*)type)->get_elem_type(); + + return "{set, " + generate_type_term(elem_type, false) + "}"; + + } else if (type->is_list()) { + t_type *elem_type = ((t_list*)type)->get_elem_type(); + + return "{list, " + generate_type_term(elem_type, false) + "}"; + } + + throw "INVALID TYPE IN type_to_enum: " + type->get_name(); +} + +std::string t_erl_generator::type_module(t_type* ttype) { + return uncapitalize(ttype->get_program()->get_name()) + "_types"; +} + +THRIFT_REGISTER_GENERATOR(erl, "Erlang", ""); diff --git a/compiler/cpp/src/generate/t_generator.cc b/compiler/cpp/src/generate/t_generator.cc new file mode 100644 index 00000000..38c053c5 --- /dev/null +++ b/compiler/cpp/src/generate/t_generator.cc @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "t_generator.h" +using namespace std; + +/** + * Top level program generation function. Calls the generator subclass methods + * for preparing file streams etc. then iterates over all the parts of the + * program to perform the correct actions. + * + * @param program The thrift program to compile into C++ source + */ +void t_generator::generate_program() { + // Initialize the generator + init_generator(); + + // Generate enums + vector enums = program_->get_enums(); + vector::iterator en_iter; + for (en_iter = enums.begin(); en_iter != enums.end(); ++en_iter) { + generate_enum(*en_iter); + } + + // Generate typedefs + vector typedefs = program_->get_typedefs(); + vector::iterator td_iter; + for (td_iter = typedefs.begin(); td_iter != typedefs.end(); ++td_iter) { + generate_typedef(*td_iter); + } + + // Generate constants + vector consts = program_->get_consts(); + generate_consts(consts); + + // Generate structs and exceptions in declared order + vector objects = program_->get_objects(); + vector::iterator o_iter; + for (o_iter = objects.begin(); o_iter != objects.end(); ++o_iter) { + if ((*o_iter)->is_xception()) { + generate_xception(*o_iter); + } else { + generate_struct(*o_iter); + } + } + + // Generate services + vector services = program_->get_services(); + vector::iterator sv_iter; + for (sv_iter = services.begin(); sv_iter != services.end(); ++sv_iter) { + service_name_ = get_service_name(*sv_iter); + generate_service(*sv_iter); + } + + // Close the generator + close_generator(); +} + +string t_generator::escape_string(const string &in) const { + string result = ""; + for (string::const_iterator it = in.begin(); it < in.end(); it++) { + std::map::const_iterator res = escape_.find(*it); + if (res != escape_.end()) { + result.append(res->second); + } else { + result.push_back(*it); + } + } + return result; +} + +void t_generator::generate_consts(vector consts) { + vector::iterator c_iter; + for (c_iter = consts.begin(); c_iter != consts.end(); ++c_iter) { + generate_const(*c_iter); + } +} + +void t_generator::generate_docstring_comment(ofstream& out, + const string& comment_start, + const string& line_prefix, + const string& contents, + const string& comment_end) { + if (comment_start != "") indent(out) << comment_start; + stringstream docs(contents, ios_base::in); + while (!docs.eof()) { + char line[1024]; + docs.getline(line, 1024); + if (strlen(line) > 0 || !docs.eof()) { // skip the empty last line + indent(out) << line_prefix << line << std::endl; + } + } + if (comment_end != "") indent(out) << comment_end; +} + + +void t_generator_registry::register_generator(t_generator_factory* factory) { + gen_map_t& the_map = get_generator_map(); + if (the_map.find(factory->get_short_name()) != the_map.end()) { + failure("Duplicate generators for language \"%s\"!\n", factory->get_short_name().c_str()); + } + the_map[factory->get_short_name()] = factory; +} + +t_generator* t_generator_registry::get_generator(t_program* program, + const string& options) { + string::size_type colon = options.find(':'); + string language = options.substr(0, colon); + + map parsed_options; + if (colon != string::npos) { + string::size_type pos = colon+1; + while (pos != string::npos && pos < options.size()) { + string::size_type next_pos = options.find(',', pos); + string option = options.substr(pos, next_pos-pos); + pos = ((next_pos == string::npos) ? next_pos : next_pos+1); + + string::size_type separator = option.find('='); + string key, value; + if (separator == string::npos) { + key = option; + value = ""; + } else { + key = option.substr(0, separator); + value = option.substr(separator+1); + } + + parsed_options[key] = value; + } + } + + gen_map_t& the_map = get_generator_map(); + gen_map_t::iterator iter = the_map.find(language); + + if (iter == the_map.end()) { + return NULL; + } + + return iter->second->get_generator(program, parsed_options, options); +} + +t_generator_registry::gen_map_t& t_generator_registry::get_generator_map() { + // http://www.parashift.com/c++-faq-lite/ctors.html#faq-10.12 + static gen_map_t* the_map = new gen_map_t(); + return *the_map; +} + +t_generator_factory::t_generator_factory( + const std::string& short_name, + const std::string& long_name, + const std::string& documentation) + : short_name_(short_name) + , long_name_(long_name) + , documentation_(documentation) +{ + t_generator_registry::register_generator(this); +} diff --git a/compiler/cpp/src/generate/t_generator.h b/compiler/cpp/src/generate/t_generator.h new file mode 100644 index 00000000..7514fb16 --- /dev/null +++ b/compiler/cpp/src/generate/t_generator.h @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_GENERATOR_H +#define T_GENERATOR_H + +#include +#include +#include +#include +#include "parse/t_program.h" +#include "globals.h" + +/** + * Base class for a thrift code generator. This class defines the basic + * routines for code generation and contains the top level method that + * dispatches code generation across various components. + * + */ +class t_generator { + public: + t_generator(t_program* program) { + tmp_ = 0; + indent_ = 0; + program_ = program; + program_name_ = get_program_name(program); + escape_['\n'] = "\\n"; + escape_['\r'] = "\\r"; + escape_['\t'] = "\\t"; + escape_['"'] = "\\\""; + escape_['\\'] = "\\\\"; + } + + virtual ~t_generator() {} + + /** + * Framework generator method that iterates over all the parts of a program + * and performs general actions. This is implemented by the base class and + * should not normally be overwritten in the subclasses. + */ + virtual void generate_program(); + + const t_program* get_program() const { return program_; } + + void generate_docstring_comment(std::ofstream& out, + const std::string& comment_start, + const std::string& line_prefix, + const std::string& contents, + const std::string& comment_end); + + /** + * Escape string to use one in generated sources. + */ + virtual std::string escape_string(const std::string &in) const; + + std::string get_escaped_string(t_const_value* constval) { + return escape_string(constval->get_string()); + } + + protected: + + /** + * Optional methods that may be imlemented by subclasses to take necessary + * steps at the beginning or end of code generation. + */ + + virtual void init_generator() {} + virtual void close_generator() {} + + virtual void generate_consts(std::vector consts); + + /** + * Pure virtual methods implemented by the generator subclasses. + */ + + virtual void generate_typedef (t_typedef* ttypedef) = 0; + virtual void generate_enum (t_enum* tenum) = 0; + virtual void generate_const (t_const* tconst) {} + virtual void generate_struct (t_struct* tstruct) = 0; + virtual void generate_service (t_service* tservice) = 0; + virtual void generate_xception (t_struct* txception) { + // By default exceptions are the same as structs + generate_struct(txception); + } + + /** + * Method to get the program name, may be overridden + */ + virtual std::string get_program_name(t_program* tprogram) { + return tprogram->get_name(); + } + + /** + * Method to get the service name, may be overridden + */ + virtual std::string get_service_name(t_service* tservice) { + return tservice->get_name(); + } + + /** + * Get the current output directory + */ + virtual std::string get_out_dir() const { + return program_->get_out_path() + out_dir_base_ + "/"; + } + + /** + * Creates a unique temporary variable name, which is just "name" with a + * number appended to it (i.e. name35) + */ + std::string tmp(std::string name) { + std::ostringstream out; + out << name << tmp_++; + return out.str(); + } + + /** + * Indentation level modifiers + */ + + void indent_up(){ + ++indent_; + } + + void indent_down() { + --indent_; + } + + /** + * Indentation print function + */ + std::string indent() { + std::string ind = ""; + int i; + for (i = 0; i < indent_; ++i) { + ind += " "; + } + return ind; + } + + /** + * Indentation utility wrapper + */ + std::ostream& indent(std::ostream &os) { + return os << indent(); + } + + /** + * Capitalization helpers + */ + std::string capitalize(std::string in) { + in[0] = toupper(in[0]); + return in; + } + std::string decapitalize(std::string in) { + in[0] = tolower(in[0]); + return in; + } + std::string lowercase(std::string in) { + for (size_t i = 0; i < in.size(); ++i) { + in[i] = tolower(in[i]); + } + return in; + } + std::string underscore(std::string in) { + in[0] = tolower(in[0]); + for (size_t i = 1; i < in.size(); ++i) { + if (isupper(in[i])) { + in[i] = tolower(in[i]); + in.insert(i, "_"); + } + } + return in; + } + + /** + * Get the true type behind a series of typedefs. + */ + static t_type* get_true_type(t_type* type) { + while (type->is_typedef()) { + type = ((t_typedef*)type)->get_type(); + } + return type; + } + + protected: + /** + * The program being generated + */ + t_program* program_; + + /** + * Quick accessor for formatted program name that is currently being + * generated. + */ + std::string program_name_; + + /** + * Quick accessor for formatted service name that is currently being + * generated. + */ + std::string service_name_; + + /** + * Output type-specifc directory name ("gen-*") + */ + std::string out_dir_base_; + + /** + * Map of characters to escape in string literals. + */ + std::map escape_; + + private: + /** + * Current code indentation level + */ + int indent_; + + /** + * Temporary variable counter, for making unique variable names + */ + int tmp_; +}; + + +/** + * A factory for producing generator classes of a particular language. + * + * This class is also responsible for: + * - Registering itself with the generator registry. + * - Providing documentation for the generators it produces. + */ +class t_generator_factory { + public: + t_generator_factory(const std::string& short_name, + const std::string& long_name, + const std::string& documentation); + + virtual ~t_generator_factory() {} + + virtual t_generator* get_generator( + // The program to generate. + t_program* program, + // Note: parsed_options will not exist beyond the call to get_generator. + const std::map& parsed_options, + // Note: option_string might not exist beyond the call to get_generator. + const std::string& option_string) + = 0; + + std::string get_short_name() { return short_name_; } + std::string get_long_name() { return long_name_; } + std::string get_documentation() { return documentation_; } + + private: + std::string short_name_; + std::string long_name_; + std::string documentation_; +}; + +template +class t_generator_factory_impl : public t_generator_factory { + public: + t_generator_factory_impl(const std::string& short_name, + const std::string& long_name, + const std::string& documentation) + : t_generator_factory(short_name, long_name, documentation) + {} + + virtual t_generator* get_generator( + t_program* program, + const std::map& parsed_options, + const std::string& option_string) { + return new generator(program, parsed_options, option_string); + } +}; + +class t_generator_registry { + public: + static void register_generator(t_generator_factory* factory); + + static t_generator* get_generator(t_program* program, + const std::string& options); + + typedef std::map gen_map_t; + static gen_map_t& get_generator_map(); + + private: + t_generator_registry(); + t_generator_registry(const t_generator_registry&); +}; + +#define THRIFT_REGISTER_GENERATOR(language, long_name, doc) \ + class t_##language##_generator_factory_impl \ + : public t_generator_factory_impl \ + { \ + public: \ + t_##language##_generator_factory_impl() \ + : t_generator_factory_impl( \ + #language, long_name, doc) \ + {} \ + }; \ + static t_##language##_generator_factory_impl _registerer; + +#endif diff --git a/compiler/cpp/src/generate/t_hs_generator.cc b/compiler/cpp/src/generate/t_hs_generator.cc new file mode 100644 index 00000000..c8fda774 --- /dev/null +++ b/compiler/cpp/src/generate/t_hs_generator.cc @@ -0,0 +1,1445 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include "t_oop_generator.h" +#include "platform.h" +using namespace std; + + +/** + * Haskell code generator. + * + */ +class t_hs_generator : public t_oop_generator { + public: + t_hs_generator( + t_program* program, + const std::map& parsed_options, + const std::string& option_string) + : t_oop_generator(program) + { + out_dir_base_ = "gen-hs"; + } + + /** + * Init and close methods + */ + + void init_generator(); + void close_generator(); + + /** + * Program-level generation functions + */ + void generate_typedef (t_typedef* ttypedef); + void generate_enum (t_enum* tenum); + void generate_const (t_const* tconst); + void generate_struct (t_struct* tstruct); + void generate_xception (t_struct* txception); + void generate_service (t_service* tservice); + + std::string render_const_value(t_type* type, t_const_value* value); + + /** + * Struct generation code + */ + + void generate_hs_struct(t_struct* tstruct, bool is_exception); + void generate_hs_struct_definition(std::ofstream &out,t_struct* tstruct, bool is_xception=false,bool helper=false); + void generate_hs_struct_reader(std::ofstream& out, t_struct* tstruct); + void generate_hs_struct_writer(std::ofstream& out, t_struct* tstruct); + void generate_hs_function_helpers(t_function* tfunction); + + /** + * Service-level generation functions + */ + + void generate_service_helpers (t_service* tservice); + void generate_service_interface (t_service* tservice); + void generate_service_client (t_service* tservice); + void generate_service_server (t_service* tservice); + void generate_process_function (t_service* tservice, t_function* tfunction); + + /** + * Serialization constructs + */ + + void generate_deserialize_field (std::ofstream &out, + t_field* tfield, + std::string prefix); + + void generate_deserialize_struct (std::ofstream &out, + t_struct* tstruct); + + void generate_deserialize_container (std::ofstream &out, + t_type* ttype); + + void generate_deserialize_set_element (std::ofstream &out, + t_set* tset); + + + void generate_deserialize_list_element (std::ofstream &out, + t_list* tlist, + std::string prefix=""); + void generate_deserialize_type (std::ofstream &out, + t_type* type); + + void generate_serialize_field (std::ofstream &out, + t_field* tfield, + std::string name= ""); + + void generate_serialize_struct (std::ofstream &out, + t_struct* tstruct, + std::string prefix=""); + + void generate_serialize_container (std::ofstream &out, + t_type* ttype, + std::string prefix=""); + + void generate_serialize_map_element (std::ofstream &out, + t_map* tmap, + std::string kiter, + std::string viter); + + void generate_serialize_set_element (std::ofstream &out, + t_set* tmap, + std::string iter); + + void generate_serialize_list_element (std::ofstream &out, + t_list* tlist, + std::string iter); + + /** + * Helper rendering functions + */ + + std::string hs_autogen_comment(); + std::string hs_imports(); + std::string type_name(t_type* ttype); + std::string function_type(t_function* tfunc, bool options = false, bool io = false, bool method = false); + std::string type_to_enum(t_type* ttype); + std::string render_hs_type(t_type* type, bool needs_parens = true); + + + private: + + /** + * File streams + */ + + std::ofstream f_types_; + std::ofstream f_consts_; + std::ofstream f_service_; + std::ofstream f_iface_; + std::ofstream f_client_; + +}; + + +/** + * Prepares for file generation by opening up the necessary file output + * streams. + * + * @param tprogram The program to generate + */ +void t_hs_generator::init_generator() { + // Make output directory + MKDIR(get_out_dir().c_str()); + + // Make output file + + string pname = capitalize(program_name_); + string f_types_name = get_out_dir()+pname+"_Types.hs"; + f_types_.open(f_types_name.c_str()); + + string f_consts_name = get_out_dir()+pname+"_Consts.hs"; + f_consts_.open(f_consts_name.c_str()); + + // Print header + f_types_ << + hs_autogen_comment() << endl << + "module " << pname <<"_Types where" << endl << + hs_imports() << endl; + + f_consts_ << + hs_autogen_comment() << endl << + "module " << pname <<"_Consts where" << endl << + hs_imports() << endl << + "import " << pname<<"_Types"<< endl; + +} + + +/** + * Autogen'd comment + */ +string t_hs_generator::hs_autogen_comment() { + return + std::string("-----------------------------------------------------------------\n") + + "-- Autogenerated by Thrift --\n" + + "-- --\n" + + "-- DO NOT EDIT UNLESS YOU ARE SURE YOU KNOW WHAT YOU ARE DOING --\n" + + "-----------------------------------------------------------------\n"; +} + +/** + * Prints standard thrift imports + */ +string t_hs_generator::hs_imports() { + return "import Thrift\nimport Data.Typeable ( Typeable )\nimport Control.Exception\nimport qualified Data.Map as Map\nimport qualified Data.Set as Set\nimport Data.Int"; +} + +/** + * Closes the type files + */ +void t_hs_generator::close_generator() { + // Close types file + f_types_.close(); + f_consts_.close(); +} + +/** + * Generates a typedef. Ez. + * + * @param ttypedef The type definition + */ +void t_hs_generator::generate_typedef(t_typedef* ttypedef) { + f_types_ << + indent() << "type "<< capitalize(ttypedef->get_symbolic()) << " = " << render_hs_type(ttypedef->get_type(), false) << endl << endl; +} + +/** + * Generates code for an enumerated type. + * the values. + * + * @param tenum The enumeration + */ +void t_hs_generator::generate_enum(t_enum* tenum) { + indent(f_types_) << "data "<get_name())<<" = "; + indent_up(); + vector constants = tenum->get_constants(); + vector::iterator c_iter; + bool first = true; + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + string name = capitalize((*c_iter)->get_name()); + if(first) + first=false; + else + f_types_ << "|"; + f_types_ << name; + } + indent(f_types_) << "deriving (Show,Eq, Typeable, Ord)" << endl; + indent_down(); + + int value = -1; + indent(f_types_) << "instance Enum " << capitalize(tenum->get_name()) << " where" << endl; + indent_up(); + indent(f_types_) << "fromEnum t = case t of" << endl; + indent_up(); + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + if ((*c_iter)->has_value()) { + value = (*c_iter)->get_value(); + } else { + ++value; + } + string name = capitalize((*c_iter)->get_name()); + + f_types_ << + indent() << name << " -> " << value << endl; + } + indent_down(); + + indent(f_types_) << "toEnum t = case t of" << endl; + indent_up(); + for(c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + if ((*c_iter)->has_value()) { + value = (*c_iter)->get_value(); + } else { + ++value; + } + string name = capitalize((*c_iter)->get_name()); + + f_types_ << + indent() << value << " -> " << name << endl; + } + indent(f_types_) << "_ -> throw ThriftException" << endl; + indent_down(); + indent_down(); +} + +/** + * Generate a constant value + */ +void t_hs_generator::generate_const(t_const* tconst) { + t_type* type = tconst->get_type(); + string name = decapitalize(tconst->get_name()); + t_const_value* value = tconst->get_value(); + + indent(f_consts_) << name << " :: " << render_hs_type(type, false) << endl; + indent(f_consts_) << name << " = " << render_const_value(type, value) << endl << endl; +} + +/** + * Prints the value of a constant with the given type. Note that type checking + * is NOT performed in this function as it is always run beforehand using the + * validate_types method in main.cc + */ +string t_hs_generator::render_const_value(t_type* type, t_const_value* value) { + type = get_true_type(type); + std::ostringstream out; + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_STRING: + out << '"' << get_escaped_string(value) << '"'; + break; + case t_base_type::TYPE_BOOL: + out << (value->get_integer() > 0 ? "True" : "False"); + break; + case t_base_type::TYPE_BYTE: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + case t_base_type::TYPE_I64: + out << value->get_integer(); + break; + case t_base_type::TYPE_DOUBLE: + if (value->get_type() == t_const_value::CV_INTEGER) { + out << value->get_integer(); + } else { + out << value->get_double(); + } + break; + default: + throw "compiler error: no const of base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + t_enum* tenum = (t_enum*)type; + vector constants = tenum->get_constants(); + vector::iterator c_iter; + int val = -1; + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + if ((*c_iter)->has_value()) { + val = (*c_iter)->get_value(); + } else { + ++val; + } + if(val == value->get_integer()){ + indent(out) << capitalize((*c_iter)->get_name()); + break; + } + } + } else if (type->is_struct() || type->is_xception()) { + string cname = type_name(type); + indent(out) << cname << "{"; + const vector& fields = ((t_struct*)type)->get_members(); + vector::const_iterator f_iter; + const map& val = value->get_map(); + map::const_iterator v_iter; + bool first = true; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + t_type* field_type = NULL; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if ((*f_iter)->get_name() == v_iter->first->get_string()) { + field_type = (*f_iter)->get_type(); + } + } + if (field_type == NULL) { + throw "type error: " + type->get_name() + " has no field " + v_iter->first->get_string(); + } + string fname = v_iter->first->get_string(); + if(first) + first=false; + else + out << ","; + out << "f_" << cname << "_" << fname << " = Just (" << render_const_value(field_type, v_iter->second) << ")"; + + } + indent(out) << "}"; + } else if (type->is_map()) { + t_type* ktype = ((t_map*)type)->get_key_type(); + t_type* vtype = ((t_map*)type)->get_val_type(); + const map& val = value->get_map(); + map::const_iterator v_iter; + out << "(Map.fromList ["; + bool first=true; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + string key = render_const_value(ktype, v_iter->first); + string val = render_const_value(vtype, v_iter->second); + if(first) + first=false; + else + out << ","; + out << "(" << key << ","<< val << ")"; + } + out << "])"; + } else if (type->is_list() || type->is_set()) { + t_type* etype; + + if (type->is_list()) { + etype = ((t_list*) type)->get_elem_type(); + } else { + etype = ((t_set*) type)->get_elem_type(); + } + + const vector& val = value->get_list(); + vector::const_iterator v_iter; + bool first = true; + + if (type->is_set()) + out << "(Set.fromList "; + + out << "["; + + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + if(first) + first=false; + else + out << ","; + out << render_const_value(etype, *v_iter); + } + + out << "]"; + if (type->is_set()) + out << ")"; + } else { + throw "CANNOT GENERATE CONSTANT FOR TYPE: " + type->get_name(); + } + return out.str(); +} + +/** + * Generates a "struct" + */ +void t_hs_generator::generate_struct(t_struct* tstruct) { + generate_hs_struct(tstruct, false); +} + +/** + * Generates a struct definition for a thrift exception. Basically the same + * as a struct, but also has an exception declaration. + * + * @param txception The struct definition + */ +void t_hs_generator::generate_xception(t_struct* txception) { + generate_hs_struct(txception, true); +} + +/** + * Generates a Haskell struct + */ +void t_hs_generator::generate_hs_struct(t_struct* tstruct, + bool is_exception) { + generate_hs_struct_definition(f_types_,tstruct, is_exception,false); +} + +/** + * Generates a struct definition for a thrift data type. + * + * @param tstruct The struct definition + */ +void t_hs_generator::generate_hs_struct_definition(ofstream& out, + t_struct* tstruct, + bool is_exception, + bool helper) { + string tname = type_name(tstruct); + string name = tstruct->get_name(); + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + + indent(out) << "data "< 0) { + out << "{"; + bool first=true; + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + if(first) + first=false; + else + out << ","; + string mname = (*m_iter)->get_name(); + out << "f_" << tname << "_" << mname << " :: Maybe " << render_hs_type((*m_iter)->get_type()); + } + out << "}"; + } + + out << " deriving (Show,Eq,Ord,Typeable)" << endl; + if (is_exception) out << "instance Exception " << tname << endl; + generate_hs_struct_writer(out, tstruct); + + generate_hs_struct_reader(out, tstruct); + //f_struct_.close(); +} + + + +/** + * Generates the read method for a struct + */ +void t_hs_generator::generate_hs_struct_reader(ofstream& out, t_struct* tstruct) { + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + string sname = type_name(tstruct); + string str = tmp("_str"); + string t = tmp("_t"); + string id = tmp("_id"); + + indent(out) << "read_" << sname << "_fields iprot rec = do" << endl; + indent_up(); // do + + // Read beginning field marker + indent(out) << "(_," << t <<","<get_key() << " -> "; + out << "if " << t <<" == " << type_to_enum((*f_iter)->get_type()) << " then do" << endl; + indent_up(); // if + indent(out) << "s <- "; + generate_deserialize_field(out, *f_iter,str); + out << endl; + indent(out) << "read_"<get_name()) <<"=Just s}" << endl; + out << + indent() << "else do" << endl; + indent_up(); + indent(out) << "skip iprot "<< t << endl; + indent(out) << "read_"< do" << endl; + indent_up(); + indent(out) << "skip iprot "<get_name()) << "=Nothing"; + } + out << "})" << endl; + indent(out) << "readStructEnd iprot" << endl; + indent(out) << "return rec" << endl; + indent_down(); +} + +void t_hs_generator::generate_hs_struct_writer(ofstream& out, + t_struct* tstruct) { + string name = type_name(tstruct); + const vector& fields = tstruct->get_sorted_members(); + vector::const_iterator f_iter; + string str = tmp("_str"); + string f = tmp("_f"); + + indent(out) << + "write_"<get_name(); + indent(out) << + "case f_" << name << "_" << mname << " rec of {Nothing -> return (); Just _v -> do" << endl; + indent_up(); + indent(out) << "writeFieldBegin oprot (\""<< (*f_iter)->get_name()<<"\"," + <get_type())<<"," + <<(*f_iter)->get_key()<<")" << endl; + + // Write field contents + out << indent(); + generate_serialize_field(out, *f_iter, "_v"); + out << endl; + // Write field closer + indent(out) << "writeFieldEnd oprot}" << endl; + indent_down(); + } + + // Write the struct map + out << + indent() << "writeFieldStop oprot" << endl << + indent() << "writeStructEnd oprot" << endl; + + indent_down(); +} + +/** + * Generates a thrift service. + * + * @param tservice The service definition + */ +void t_hs_generator::generate_service(t_service* tservice) { + string f_service_name = get_out_dir()+capitalize(service_name_)+".hs"; + f_service_.open(f_service_name.c_str()); + + f_service_ << + hs_autogen_comment() << endl << + "module " << capitalize(service_name_) << " where" << endl << + hs_imports() << endl; + + + if(tservice->get_extends()){ + f_service_ << + "import qualified " << capitalize(tservice->get_extends()->get_name()) << endl; + } + + + f_service_ << + "import " << capitalize(program_name_) << "_Types" << endl << + "import qualified " << capitalize(service_name_) << "_Iface as Iface" << endl; + + + // Generate the three main parts of the service + generate_service_helpers(tservice); + generate_service_interface(tservice); + generate_service_client(tservice); + generate_service_server(tservice); + + + // Close service file + f_service_.close(); +} + +/** + * Generates helper functions for a service. + * + * @param tservice The service to generate a header definition for + */ +void t_hs_generator::generate_service_helpers(t_service* tservice) { + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + indent(f_service_) << + "-- HELPER FUNCTIONS AND STRUCTURES --" << endl << endl; + + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + t_struct* ts = (*f_iter)->get_arglist(); + generate_hs_struct_definition(f_service_,ts, false); + generate_hs_function_helpers(*f_iter); + } +} + +/** + * Generates a struct and helpers for a function. + * + * @param tfunction The function + */ +void t_hs_generator::generate_hs_function_helpers(t_function* tfunction) { + t_struct result(program_, decapitalize(tfunction->get_name()) + "_result"); + t_field success(tfunction->get_returntype(), "success", 0); + if (!tfunction->get_returntype()->is_void()) { + result.append(&success); + } + + t_struct* xs = tfunction->get_xceptions(); + const vector& fields = xs->get_members(); + vector::const_iterator f_iter; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + result.append(*f_iter); + } + generate_hs_struct_definition(f_service_,&result, false); +} + +/** + * Generates a service interface definition. + * + * @param tservice The service to generate a header definition for + */ +void t_hs_generator::generate_service_interface(t_service* tservice) { + string f_iface_name = get_out_dir()+capitalize(service_name_)+"_Iface.hs"; + f_iface_.open(f_iface_name.c_str()); + indent(f_iface_) << "module " << capitalize(service_name_) << "_Iface where" << endl; + + indent(f_iface_) << + hs_imports() << endl << + "import " << capitalize(program_name_) << "_Types" << endl << + endl; + + if (tservice->get_extends() != NULL) { + string extends = type_name(tservice->get_extends()); + indent(f_iface_) << "import " << extends <<"_Iface" << endl; + indent(f_iface_) << "class "<< extends << "_Iface a => " << capitalize(service_name_) << "_Iface a where" << endl; + } else { + f_iface_ << indent() << "class " << capitalize(service_name_) << "_Iface a where" << endl; + } + indent_up(); + + vector functions = tservice->get_functions(); + vector::iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + string ft = function_type(*f_iter,true,true,true); + f_iface_ << + indent() << decapitalize((*f_iter)->get_name()) << " :: a -> " << ft << endl; + } + indent_down(); + f_iface_.close(); + +} + +/** + * Generates a service client definition. Note that in Haskell, the client doesn't implement iface. This is because + * The client does not (and should not have to) deal with arguments being Nothing. + * + * @param tservice The service to generate a server for. + */ +void t_hs_generator::generate_service_client(t_service* tservice) { + string f_client_name = get_out_dir()+capitalize(service_name_)+"_Client.hs"; + f_client_.open(f_client_name.c_str()); + + vector functions = tservice->get_functions(); + vector::const_iterator f_iter; + + string extends = ""; + string exports=""; + bool first = true; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + if(first) + first=false; + else + exports+=","; + string funname = (*f_iter)->get_name(); + exports+=funname; + } + indent(f_client_) << "module " << capitalize(service_name_) << "_Client("<get_extends() != NULL) { + extends = type_name(tservice->get_extends()); + indent(f_client_) << "import " << extends << "_Client" << endl; + } + indent(f_client_) << "import Data.IORef" << endl; + indent(f_client_) << hs_imports() << endl; + indent(f_client_) << "import " << capitalize(program_name_) << "_Types" << endl; + indent(f_client_) << "import " << capitalize(service_name_) << endl; + // DATS RITE A GLOBAL VAR + indent(f_client_) << "seqid = newIORef 0" << endl; + + + // Generate client method implementations + + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + t_struct* arg_struct = (*f_iter)->get_arglist(); + const vector& fields = arg_struct->get_members(); + vector::const_iterator fld_iter; + string funname = (*f_iter)->get_name(); + + string fargs = ""; + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + fargs+= " arg_" + decapitalize((*fld_iter)->get_name()); + } + + // Open function + indent(f_client_) << funname << " (ip,op)" << fargs << " = do" << endl; + indent_up(); + indent(f_client_) << "send_" << funname << " op" << fargs; + + f_client_ << endl; + + if (!(*f_iter)->is_oneway()) { + f_client_ << indent(); + f_client_ << + "recv_" << funname << " ip" << endl; + } + indent_down(); + + indent(f_client_) << + "send_" << funname << " op" << fargs << " = do" << endl; + indent_up(); + indent(f_client_) << "seq <- seqid" << endl; + indent(f_client_) << "seqn <- readIORef seq" << endl; + std::string argsname = capitalize((*f_iter)->get_name() + "_args"); + + // Serialize the request header + f_client_ << + indent() << "writeMessageBegin op (\"" << (*f_iter)->get_name() << "\", M_CALL, seqn)" << endl; + f_client_ << indent() << "write_" << argsname << " op ("<get_name() << "=Just arg_" << (*fld_iter)->get_name(); + } + f_client_ << "})" << endl; + + // Write to the stream + f_client_ << + indent() << "writeMessageEnd op" << endl << + indent() << "tFlush (getTransport op)" << endl; + + indent_down(); + + if (!(*f_iter)->is_oneway()) { + std::string resultname = capitalize((*f_iter)->get_name() + "_result"); + t_struct noargs(program_); + + std::string funname = string("recv_") + (*f_iter)->get_name(); + + t_function recv_function((*f_iter)->get_returntype(), + funname, + &noargs); + // Open function + f_client_ << + indent() << funname << " ip = do" << endl; + indent_up(); // fun + + // TODO(mcslee): Validate message reply here, seq ids etc. + + f_client_ << + indent() << "(fname, mtype, rseqid) <- readMessageBegin ip" << endl; + f_client_ << + indent() << "if mtype == M_EXCEPTION then do" << endl << + indent() << " x <- readAppExn ip" << endl << + indent() << " readMessageEnd ip" << endl; + f_client_ << + indent() << " throw x" << endl; + f_client_ << + indent() << " else return ()" << endl; + + t_struct* xs = (*f_iter)->get_xceptions(); + const std::vector& xceptions = xs->get_members(); + + f_client_ << + indent() << "res <- read_" << resultname << " ip" << endl; + f_client_ << + indent() << "readMessageEnd ip" << endl; + + // Careful, only return _result if not a void function + if (!(*f_iter)->get_returntype()->is_void()) { + f_client_ << + indent() << "case f_" << resultname << "_success res of" << endl; + indent_up(); // case + indent(f_client_) << "Just v -> return v" << endl; + indent(f_client_) << "Nothing -> do" << endl; + indent_up(); // none + } + + + vector::const_iterator x_iter; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + f_client_ << + indent() << "case f_"<< resultname << "_" << (*x_iter)->get_name() << " res of" << endl; + indent_up(); //case + indent(f_client_) << "Nothing -> return ()" << endl; + indent(f_client_) << "Just _v -> throw _v" << endl; + indent_down(); //-case + } + + // Careful, only return _result if not a void function + if ((*f_iter)->get_returntype()->is_void()) { + indent(f_client_) << + "return ()" << endl; + } else { + f_client_ << + indent() << "throw (AppExn AE_MISSING_RESULT \"" << (*f_iter)->get_name() << " failed: unknown result\")" << endl; + indent_down(); //-none + indent_down(); //-case + } + + // Close function + indent_down(); //-fun + } + } + f_client_.close(); + + +} + +/** + * Generates a service server definition. + * + * @param tservice The service to generate a server for. + */ +void t_hs_generator::generate_service_server(t_service* tservice) { + // Generate the dispatch methods + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + // Generate the process subfunctions + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + generate_process_function(tservice, *f_iter); + } + + + indent(f_service_) << "proc handler (iprot,oprot) (name,typ,seqid) = case name of" << endl; + indent_up(); + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + string fname = (*f_iter)->get_name(); + indent(f_service_) << "\""< process_" << decapitalize(fname) << " (seqid,iprot,oprot,handler)" << endl; + } + indent(f_service_) << "_ -> "; + if(tservice->get_extends() != NULL){ + f_service_ << type_name(tservice->get_extends()) << ".proc handler (iprot,oprot) (name,typ,seqid)" << endl; + } else { + f_service_ << "do" << endl; + indent_up(); + indent(f_service_) << "skip iprot T_STRUCT" << endl; + indent(f_service_) << "readMessageEnd iprot" << endl; + indent(f_service_) << "writeMessageBegin oprot (name,M_EXCEPTION,seqid)" << endl; + indent(f_service_) << "writeAppExn oprot (AppExn AE_UNKNOWN_METHOD (\"Unknown function \" ++ name))" << endl; + indent(f_service_) << "writeMessageEnd oprot" << endl; + indent(f_service_) << "tFlush (getTransport oprot)" << endl; + indent_down(); + } + indent_down(); + + // Generate the server implementation + indent(f_service_) << + "process handler (iprot, oprot) = do" << endl; + indent_up(); + + f_service_ << + indent() << "(name, typ, seqid) <- readMessageBegin iprot" << endl; + f_service_ << indent() << "proc handler (iprot,oprot) (name,typ,seqid)" << endl; + indent(f_service_) << "return True" << endl; + indent_down(); + +} + +/** + * Generates a process function definition. + * + * @param tfunction The function to write a dispatcher for + */ +void t_hs_generator::generate_process_function(t_service* tservice, + t_function* tfunction) { + // Open function + indent(f_service_) << + "process_" << tfunction->get_name() << " (seqid, iprot, oprot, handler) = do" << endl; + indent_up(); + + string argsname = capitalize(tfunction->get_name()) + "_args"; + string resultname = capitalize(tfunction->get_name()) + "_result"; + + // Generate the function call + t_struct* arg_struct = tfunction->get_arglist(); + const std::vector& fields = arg_struct->get_members(); + vector::const_iterator f_iter; + + + f_service_ << + indent() << "args <- read_" << argsname << " iprot" << endl; + f_service_ << + indent() << "readMessageEnd iprot" << endl; + + t_struct* xs = tfunction->get_xceptions(); + const std::vector& xceptions = xs->get_members(); + vector::const_iterator x_iter; + int n = xceptions.size(); + if (!tfunction->is_oneway()){ + if(!tfunction->get_returntype()->is_void()){ + n++; + } + indent(f_service_) << "rs <- return (" << resultname; + + for(int i=0; i 0) { + for(unsigned int i=0;iis_oneway() && !tfunction->get_returntype()->is_void()){ + f_service_ << "res <- "; + } + f_service_ << "Iface." << tfunction->get_name() << " handler"; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + f_service_ << " (f_" << argsname << "_" << (*f_iter)->get_name() << " args)"; + } + + + if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()){ + f_service_ << endl; + indent(f_service_) << "return rs{f_"<is_oneway()){ + f_service_ << endl; + indent(f_service_) << "return rs"; + } + f_service_ << ")" << endl; + indent_down(); + + if (xceptions.size() > 0 && !tfunction->is_oneway()) { + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + indent(f_service_) << "(\\e -> " <is_oneway()){ + f_service_ << + indent() << "return rs{f_"<get_name() << " =Just e}"; + } else { + indent(f_service_) << "return ()"; + } + f_service_ << "))" << endl; + indent_down(); + indent_down(); + } + } + + + + // Shortcut out here for oneway functions + if (tfunction->is_oneway()) { + f_service_ << + indent() << "return ()" << endl; + indent_down(); + return; + } + + f_service_ << + indent() << "writeMessageBegin oprot (\"" << tfunction->get_name() << "\", M_REPLY, seqid);" << endl << + indent() << "write_"<get_type(); + generate_deserialize_type(out,type); +} + + +/** + * Deserializes a field of any type. + */ +void t_hs_generator::generate_deserialize_type(ofstream &out, + t_type* type){ + type = get_true_type(type); + + if (type->is_void()) { + throw "CANNOT GENERATE DESERIALIZE CODE FOR void TYPE"; + } + + + if (type->is_struct() || type->is_xception()) { + generate_deserialize_struct(out, + (t_struct*)type); + } else if (type->is_container()) { + generate_deserialize_container(out, type); + } else if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "compiler error: cannot serialize void field in a struct"; + break; + case t_base_type::TYPE_STRING: + out << "readString"; + break; + case t_base_type::TYPE_BOOL: + out << "readBool"; + break; + case t_base_type::TYPE_BYTE: + out << "readByte"; + break; + case t_base_type::TYPE_I16: + out << "readI16"; + break; + case t_base_type::TYPE_I32: + out << "readI32"; + break; + case t_base_type::TYPE_I64: + out << "readI64"; + break; + case t_base_type::TYPE_DOUBLE: + out << "readDouble"; + break; + default: + throw "compiler error: no PHP name for base type " + t_base_type::t_base_name(tbase); + } + out << " iprot"; + } else if (type->is_enum()) { + string ename = capitalize(type->get_name()); + out << "(do {i <- readI32 iprot; return (toEnum i :: " << ename << ")})"; + } else { + printf("DO NOT KNOW HOW TO DESERIALIZE TYPE '%s'\n", + type->get_name().c_str()); + } +} + + +/** + * Generates an unserializer for a struct, calling read() + */ +void t_hs_generator::generate_deserialize_struct(ofstream &out, + t_struct* tstruct) { + string name = capitalize(tstruct->get_name()); + out << "(read_" << name << " iprot)"; + +} + +/** + * Serialize a container by writing out the header followed by + * data and then a footer. + */ +void t_hs_generator::generate_deserialize_container(ofstream &out, + t_type* ttype) { + string size = tmp("_size"); + string ktype = tmp("_ktype"); + string vtype = tmp("_vtype"); + string etype = tmp("_etype"); + string con = tmp("_con"); + + t_field fsize(g_type_i32, size); + t_field fktype(g_type_byte, ktype); + t_field fvtype(g_type_byte, vtype); + t_field fetype(g_type_byte, etype); + + // Declare variables, read header + if (ttype->is_map()) { + out << "(let {f 0 = return []; f n = do {k <- "; + generate_deserialize_type(out,((t_map*)ttype)->get_key_type()); + out << "; v <- "; + generate_deserialize_type(out,((t_map*)ttype)->get_val_type()); + out << ";r <- f (n-1); return $ (k,v):r}} in do {("<is_set()) { + out << "(let {f 0 = return []; f n = do {v <- "; + generate_deserialize_type(out,((t_map*)ttype)->get_key_type()); + out << ";r <- f (n-1); return $ v:r}} in do {("<is_list()) { + out << "(let {f 0 = return []; f n = do {v <- "; + generate_deserialize_type(out,((t_map*)ttype)->get_key_type()); + out << ";r <- f (n-1); return $ v:r}} in do {("<get_type()); + + // Do nothing for void types + if (type->is_void()) { + throw "CANNOT GENERATE SERIALIZE CODE FOR void TYPE: " + + tfield->get_name(); + } + + if(name.length() == 0){ + name = decapitalize(tfield->get_name()); + } + + if (type->is_struct() || type->is_xception()) { + generate_serialize_struct(out, + (t_struct*)type, + name); + } else if (type->is_container()) { + generate_serialize_container(out, + type, + name); + } else if (type->is_base_type() || type->is_enum()) { + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw + "compiler error: cannot serialize void field in a struct: " + name; + break; + case t_base_type::TYPE_STRING: + out << "writeString oprot " << name; + break; + case t_base_type::TYPE_BOOL: + out << "writeBool oprot " << name; + break; + case t_base_type::TYPE_BYTE: + out << "writeByte oprot " << name; + break; + case t_base_type::TYPE_I16: + out << "writeI16 oprot " << name; + break; + case t_base_type::TYPE_I32: + out << "writeI32 oprot " << name; + break; + case t_base_type::TYPE_I64: + out << "writeI64 oprot " << name; + break; + case t_base_type::TYPE_DOUBLE: + out << "writeDouble oprot " << name; + break; + default: + throw "compiler error: no hs name for base type " + t_base_type::t_base_name(tbase); + } + + } else if (type->is_enum()) { + string ename = capitalize(type->get_name()); + out << "writeI32 oprot (fromEnum "<< name << ")"; + } + + } else { + printf("DO NOT KNOW HOW TO SERIALIZE FIELD '%s' TYPE '%s'\n", + tfield->get_name().c_str(), + type->get_name().c_str()); + } +} + +/** + * Serializes all the members of a struct. + * + * @param tstruct The struct to serialize + * @param prefix String prefix to attach to all fields + */ +void t_hs_generator::generate_serialize_struct(ofstream &out, + t_struct* tstruct, + string prefix) { + out << "write_" << type_name(tstruct) << " oprot " << prefix; +} + +void t_hs_generator::generate_serialize_container(ofstream &out, + t_type* ttype, + string prefix) { + if (ttype->is_map()) { + string k = tmp("_kiter"); + string v = tmp("_viter"); + out << "(let {f [] = return (); f (("<get_key_type())<<","<< type_to_enum(((t_map*)ttype)->get_val_type())<<",Map.size " << prefix << "); f (Map.toList " << prefix << ");writeMapEnd oprot})"; + } else if (ttype->is_set()) { + string v = tmp("_viter"); + out << "(let {f [] = return (); f ("<get_elem_type())<<",Set.size " << prefix << "); f (Set.toList " << prefix << ");writeSetEnd oprot})"; + } else if (ttype->is_list()) { + string v = tmp("_viter"); + out << "(let {f [] = return (); f ("<get_elem_type())<<",length " << prefix << "); f " << prefix << ";writeListEnd oprot})"; + } + +} + +/** + * Serializes the members of a map. + * + */ +void t_hs_generator::generate_serialize_map_element(ofstream &out, + t_map* tmap, + string kiter, + string viter) { + t_field kfield(tmap->get_key_type(), kiter); + out << "do {"; + generate_serialize_field(out, &kfield); + out << ";"; + t_field vfield(tmap->get_val_type(), viter); + generate_serialize_field(out, &vfield); + out << "}"; +} + +/** + * Serializes the members of a set. + */ +void t_hs_generator::generate_serialize_set_element(ofstream &out, + t_set* tset, + string iter) { + t_field efield(tset->get_elem_type(), iter); + generate_serialize_field(out, &efield); +} + +/** + * Serializes the members of a list. + */ +void t_hs_generator::generate_serialize_list_element(ofstream &out, + t_list* tlist, + string iter) { + t_field efield(tlist->get_elem_type(), iter); + generate_serialize_field(out, &efield); +} + + +string t_hs_generator::function_type(t_function* tfunc, bool options, bool io, bool method){ + string result=""; + + const vector& fields = tfunc->get_arglist()->get_members(); + vector::const_iterator f_iter; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if(options) result += "Maybe "; + result += render_hs_type((*f_iter)->get_type(), options); + result += " -> "; + } + if(fields.empty() && !method){ + result += "() -> "; + } + if(io) result += "IO "; + result += render_hs_type(tfunc->get_returntype(), io); + return result; +} + + +string t_hs_generator::type_name(t_type* ttype) { + string prefix = ""; + t_program* program = ttype->get_program(); + if (program != NULL && program != program_) { + if (!ttype->is_service()) { + prefix = capitalize(program->get_name()) + "_Types."; + } + } + + string name = ttype->get_name(); + if(ttype->is_service()){ + name = capitalize(name); + } else { + name = capitalize(name); + } + return prefix + name; +} + +/** + * Converts the parse type to a Protocol.t_type enum + */ +string t_hs_generator::type_to_enum(t_type* type) { + type = get_true_type(type); + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + return "T_VOID"; + case t_base_type::TYPE_STRING: + return "T_STRING"; + case t_base_type::TYPE_BOOL: + return "T_BOOL"; + case t_base_type::TYPE_BYTE: + return "T_BYTE"; + case t_base_type::TYPE_I16: + return "T_I16"; + case t_base_type::TYPE_I32: + return "T_I32"; + case t_base_type::TYPE_I64: + return "T_I64"; + case t_base_type::TYPE_DOUBLE: + return "T_DOUBLE"; + } + } else if (type->is_enum()) { + return "T_I32"; + } else if (type->is_struct() || type->is_xception()) { + return "T_STRUCT"; + } else if (type->is_map()) { + return "T_MAP"; + } else if (type->is_set()) { + return "T_SET"; + } else if (type->is_list()) { + return "T_LIST"; + } + + throw "INVALID TYPE IN type_to_enum: " + type->get_name(); +} + +/** + * Converts the parse type to an haskell type + */ +string t_hs_generator::render_hs_type(t_type* type, bool needs_parens) { + type = get_true_type(type); + string type_repr; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + return "()"; + case t_base_type::TYPE_STRING: + return "String"; + case t_base_type::TYPE_BOOL: + return "Bool"; + case t_base_type::TYPE_BYTE: + return "Int"; + case t_base_type::TYPE_I16: + return "Int"; + case t_base_type::TYPE_I32: + return "Int"; + case t_base_type::TYPE_I64: + return "Int64"; + case t_base_type::TYPE_DOUBLE: + return "Double"; + } + } else if (type->is_enum()) { + return capitalize(((t_enum*)type)->get_name()); + } else if (type->is_struct() || type->is_xception()) { + return type_name((t_struct*)type); + } else if (type->is_map()) { + t_type* ktype = ((t_map*)type)->get_key_type(); + t_type* vtype = ((t_map*)type)->get_val_type(); + + type_repr = "Map.Map " + render_hs_type(ktype, true) + " " + render_hs_type(vtype, true); + } else if (type->is_set()) { + t_type* etype = ((t_set*)type)->get_elem_type(); + + type_repr = "Set.Set " + render_hs_type(etype, true) ; + } else if (type->is_list()) { + t_type* etype = ((t_list*)type)->get_elem_type(); + return "[" + render_hs_type(etype, false) + "]"; + } else { + throw "INVALID TYPE IN type_to_enum: " + type->get_name(); + } + + return needs_parens ? "(" + type_repr + ")" : type_repr; +} + + +THRIFT_REGISTER_GENERATOR(hs, "Haskell", ""); diff --git a/compiler/cpp/src/generate/t_html_generator.cc b/compiler/cpp/src/generate/t_html_generator.cc new file mode 100644 index 00000000..ad1c4cbc --- /dev/null +++ b/compiler/cpp/src/generate/t_html_generator.cc @@ -0,0 +1,637 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include "t_generator.h" +#include "platform.h" +using namespace std; + + +/** + * HTML code generator + * + * mostly copy/pasting/tweaking from mcslee's work. + */ +class t_html_generator : public t_generator { + public: + t_html_generator( + t_program* program, + const std::map& parsed_options, + const std::string& option_string) + : t_generator(program) + { + out_dir_base_ = "gen-html"; + escape_.clear(); + escape_['&'] = "&"; + escape_['<'] = "<"; + escape_['>'] = ">"; + escape_['"'] = """; + escape_['\''] = "'"; + } + + void generate_program(); + void generate_program_toc(); + void generate_program_toc_row(t_program* tprog); + void generate_program_toc_rows(t_program* tprog, + std::vector& finished); + void generate_index(); + void generate_css(); + + /** + * Program-level generation functions + */ + + void generate_typedef (t_typedef* ttypedef); + void generate_enum (t_enum* tenum); + void generate_const (t_const* tconst); + void generate_struct (t_struct* tstruct); + void generate_service (t_service* tservice); + void generate_xception(t_struct* txception); + + void print_doc (t_doc* tdoc); + int print_type (t_type* ttype); + void print_const_value(t_const_value* tvalue); + + std::ofstream f_out_; +}; + +/** + * Emits the Table of Contents links at the top of the module's page + */ +void t_html_generator::generate_program_toc() { + f_out_ << "" + << "" << endl; + generate_program_toc_row(program_); + f_out_ << "
ModuleServicesData typesConstants
" << endl; +} + + +/** + * Recurses through from the provided program and generates a ToC row + * for each discovered program exactly once by maintaining the list of + * completed rows in 'finished' + */ +void t_html_generator::generate_program_toc_rows(t_program* tprog, + std::vector& finished) { + for (vector::iterator iter = finished.begin(); + iter != finished.end(); iter++) { + if (tprog->get_path() == (*iter)->get_path()) { + return; + } + } + finished.push_back(tprog); + generate_program_toc_row(tprog); + vector includes = tprog->get_includes(); + for (vector::iterator iter = includes.begin(); + iter != includes.end(); iter++) { + generate_program_toc_rows(*iter, finished); + } +} + +/** + * Emits the Table of Contents links at the top of the module's page + */ +void t_html_generator::generate_program_toc_row(t_program* tprog) { + string fname = tprog->get_name() + ".html"; + f_out_ << "" << endl << "" << tprog->get_name() << ""; + if (!tprog->get_services().empty()) { + vector services = tprog->get_services(); + vector::iterator sv_iter; + for (sv_iter = services.begin(); sv_iter != services.end(); ++sv_iter) { + string name = get_service_name(*sv_iter); + f_out_ << "" << name + << "
" << endl; + f_out_ << "
    " << endl; + map fn_html; + vector functions = (*sv_iter)->get_functions(); + vector::iterator fn_iter; + for (fn_iter = functions.begin(); fn_iter != functions.end(); ++fn_iter) { + string fn_name = (*fn_iter)->get_name(); + string html = "
  • " + fn_name + "
  • "; + fn_html.insert(pair(fn_name, html)); + } + for (map::iterator html_iter = fn_html.begin(); + html_iter != fn_html.end(); html_iter++) { + f_out_ << html_iter->second << endl; + } + f_out_ << "
" << endl; + } + } + f_out_ << "" << endl << ""; + map data_types; + if (!tprog->get_enums().empty()) { + vector enums = tprog->get_enums(); + vector::iterator en_iter; + for (en_iter = enums.begin(); en_iter != enums.end(); ++en_iter) { + string name = (*en_iter)->get_name(); + // f_out_ << "" << name + // << "
" << endl; + string html = "" + name + + ""; + data_types.insert(pair(name, html)); + } + } + if (!tprog->get_typedefs().empty()) { + vector typedefs = tprog->get_typedefs(); + vector::iterator td_iter; + for (td_iter = typedefs.begin(); td_iter != typedefs.end(); ++td_iter) { + string name = (*td_iter)->get_symbolic(); + // f_out_ << "" << name + // << "
" << endl; + string html = "" + name + + ""; + data_types.insert(pair(name, html)); + } + } + if (!tprog->get_objects().empty()) { + vector objects = tprog->get_objects(); + vector::iterator o_iter; + for (o_iter = objects.begin(); o_iter != objects.end(); ++o_iter) { + string name = (*o_iter)->get_name(); + //f_out_ << "" << name + //<< "
" << endl; + string html = "" + name + + ""; + data_types.insert(pair(name, html)); + } + } + for (map::iterator dt_iter = data_types.begin(); + dt_iter != data_types.end(); dt_iter++) { + f_out_ << dt_iter->second << "
" << endl; + } + f_out_ << "" << endl << ""; + if (!tprog->get_consts().empty()) { + map const_html; + vector consts = tprog->get_consts(); + vector::iterator con_iter; + for (con_iter = consts.begin(); con_iter != consts.end(); ++con_iter) { + string name = (*con_iter)->get_name(); + string html ="" + name + ""; + const_html.insert(pair(name, html)); + } + for (map::iterator con_iter = const_html.begin(); + con_iter != const_html.end(); con_iter++) { + f_out_ << con_iter->second << "
" << endl; + } + } + f_out_ << "
" << endl << ""; +} + +/** + * Prepares for file generation by opening up the necessary file output + * stream. + */ +void t_html_generator::generate_program() { + // Make output directory + MKDIR(get_out_dir().c_str()); + string fname = get_out_dir() + program_->get_name() + ".html"; + f_out_.open(fname.c_str()); + f_out_ << "" << endl; + f_out_ << "" + << endl; + f_out_ << "Thrift module: " << program_->get_name() + << "" << endl << "

Thrift module: " + << program_->get_name() << "

" << endl; + + print_doc(program_); + + generate_program_toc(); + + if (!program_->get_consts().empty()) { + f_out_ << "

Constants

" << endl; + vector consts = program_->get_consts(); + f_out_ << ""; + f_out_ << "" << endl; + generate_consts(consts); + f_out_ << "
ConstantTypeValue
"; + } + + if (!program_->get_enums().empty()) { + f_out_ << "

Enumerations

" << endl; + // Generate enums + vector enums = program_->get_enums(); + vector::iterator en_iter; + for (en_iter = enums.begin(); en_iter != enums.end(); ++en_iter) { + generate_enum(*en_iter); + } + } + + if (!program_->get_typedefs().empty()) { + f_out_ << "

Type declarations

" << endl; + // Generate typedefs + vector typedefs = program_->get_typedefs(); + vector::iterator td_iter; + for (td_iter = typedefs.begin(); td_iter != typedefs.end(); ++td_iter) { + generate_typedef(*td_iter); + } + } + + if (!program_->get_objects().empty()) { + f_out_ << "

Data structures

" << endl; + // Generate structs and exceptions in declared order + vector objects = program_->get_objects(); + vector::iterator o_iter; + for (o_iter = objects.begin(); o_iter != objects.end(); ++o_iter) { + if ((*o_iter)->is_xception()) { + generate_xception(*o_iter); + } else { + generate_struct(*o_iter); + } + } + } + + if (!program_->get_services().empty()) { + f_out_ << "

Services

" << endl; + // Generate services + vector services = program_->get_services(); + vector::iterator sv_iter; + for (sv_iter = services.begin(); sv_iter != services.end(); ++sv_iter) { + service_name_ = get_service_name(*sv_iter); + generate_service(*sv_iter); + } + } + + f_out_ << "" << endl; + f_out_.close(); + + generate_index(); + generate_css(); +} + +/** + * Emits the index.html file for the recursive set of Thrift programs + */ +void t_html_generator::generate_index() { + string index_fname = get_out_dir() + "index.html"; + f_out_.open(index_fname.c_str()); + f_out_ << "" << endl; + f_out_ << "" + << endl; + f_out_ << "All Thrift declarations" + << endl << "

All Thrift declarations

" << endl; + f_out_ << "" + << "" << endl; + vector programs; + generate_program_toc_rows(program_, programs); + f_out_ << "
ModuleServicesData typesConstants
" << endl; + f_out_ << "" << endl; + f_out_.close(); +} + +void t_html_generator::generate_css() { + string css_fname = get_out_dir() + "style.css"; + f_out_.open(css_fname.c_str()); + f_out_ << "/* Auto-generated CSS for generated Thrift docs */" << endl; + f_out_ << + "body { font-family: Tahoma, sans-serif; }" << endl; + f_out_ << + "pre { background-color: #dddddd; padding: 6px; }" << endl; + f_out_ << + "h3,h4 { padding-top: 0px; margin-top: 0px; }" << endl; + f_out_ << + "div.definition { border: 1px solid gray; margin: 10px; padding: 10px; }" << endl; + f_out_ << + "div.extends { margin: -0.5em 0 1em 5em }" << endl; + f_out_ << + "table { border: 1px solid grey; border-collapse: collapse; }" << endl; + f_out_ << + "td { border: 1px solid grey; padding: 1px 6px; vertical-align: top; }" << endl; + f_out_ << + "th { border: 1px solid black; background-color: #bbbbbb;" << endl << + " text-align: left; padding: 1px 6px; }" << endl; + f_out_.close(); +} + +/** + * If the provided documentable object has documentation attached, this + * will emit it to the output stream in HTML format. + */ +void t_html_generator::print_doc(t_doc* tdoc) { + if (tdoc->has_doc()) { + string doc = tdoc->get_doc(); + size_t index; + while ((index = doc.find_first_of("\r\n")) != string::npos) { + if (index == 0) { + f_out_ << "

" << endl; + } else { + f_out_ << doc.substr(0, index) << endl; + } + if (index + 1 < doc.size() && doc.at(index) != doc.at(index + 1) && + (doc.at(index + 1) == '\r' || doc.at(index + 1) == '\n')) { + index++; + } + doc = doc.substr(index + 1); + } + f_out_ << doc << "
"; + } +} + +/** + * Prints out the provided type in HTML + */ +int t_html_generator::print_type(t_type* ttype) { + int len = 0; + f_out_ << ""; + if (ttype->is_container()) { + if (ttype->is_list()) { + f_out_ << "list<"; + len = 6 + print_type(((t_list*)ttype)->get_elem_type()); + f_out_ << ">"; + } else if (ttype->is_set()) { + f_out_ << "set<"; + len = 5 + print_type(((t_set*)ttype)->get_elem_type()); + f_out_ << ">"; + } else if (ttype->is_map()) { + f_out_ << "map<"; + len = 5 + print_type(((t_map*)ttype)->get_key_type()); + f_out_ << ", "; + len += print_type(((t_map*)ttype)->get_val_type()); + f_out_ << ">"; + } + } else if (ttype->is_base_type()) { + f_out_ << ttype->get_name(); + len = ttype->get_name().size(); + } else { + string prog_name = ttype->get_program()->get_name(); + string type_name = ttype->get_name(); + f_out_ << "is_typedef()) { + f_out_ << "Typedef_"; + } else if (ttype->is_struct() || ttype->is_xception()) { + f_out_ << "Struct_"; + } else if (ttype->is_enum()) { + f_out_ << "Enum_"; + } else if (ttype->is_service()) { + f_out_ << "Svc_"; + } + f_out_ << type_name << "\">"; + len = type_name.size(); + if (ttype->get_program() != program_) { + f_out_ << prog_name << "."; + len += prog_name.size() + 1; + } + f_out_ << type_name << ""; + } + f_out_ << ""; + return len; +} + +/** + * Prints out an HTML representation of the provided constant value + */ +void t_html_generator::print_const_value(t_const_value* tvalue) { + bool first = true; + switch (tvalue->get_type()) { + case t_const_value::CV_INTEGER: + f_out_ << tvalue->get_integer(); + break; + case t_const_value::CV_DOUBLE: + f_out_ << tvalue->get_double(); + break; + case t_const_value::CV_STRING: + f_out_ << '"' << get_escaped_string(tvalue) << '"'; + break; + case t_const_value::CV_MAP: + { + f_out_ << "{ "; + map map_elems = tvalue->get_map(); + map::iterator map_iter; + for (map_iter = map_elems.begin(); map_iter != map_elems.end(); + map_iter++) { + if (!first) { + f_out_ << ", "; + } + first = false; + print_const_value(map_iter->first); + f_out_ << " = "; + print_const_value(map_iter->second); + } + f_out_ << " }"; + } + break; + case t_const_value::CV_LIST: + { + f_out_ << "{ "; + vector list_elems = tvalue->get_list();; + vector::iterator list_iter; + for (list_iter = list_elems.begin(); list_iter != list_elems.end(); + list_iter++) { + if (!first) { + f_out_ << ", "; + } + first = false; + print_const_value(*list_iter); + } + f_out_ << " }"; + } + break; + default: + f_out_ << "UNKNOWN"; + break; + } +} + +/** + * Generates a typedef. + * + * @param ttypedef The type definition + */ +void t_html_generator::generate_typedef(t_typedef* ttypedef) { + string name = ttypedef->get_name(); + f_out_ << "

"; + f_out_ << "

Typedef: " << name + << "

" << endl; + f_out_ << "

Base type: "; + print_type(ttypedef->get_type()); + f_out_ << "

" << endl; + print_doc(ttypedef); + f_out_ << "
" << endl; +} + +/** + * Generates code for an enumerated type. + * + * @param tenum The enumeration + */ +void t_html_generator::generate_enum(t_enum* tenum) { + string name = tenum->get_name(); + f_out_ << "
"; + f_out_ << "

Enumeration: " << name + << "

" << endl; + print_doc(tenum); + vector values = tenum->get_constants(); + vector::iterator val_iter; + f_out_ << "
" << endl; + for (val_iter = values.begin(); val_iter != values.end(); ++val_iter) { + f_out_ << "" << endl; + } + f_out_ << "
"; + f_out_ << (*val_iter)->get_name(); + f_out_ << ""; + f_out_ << (*val_iter)->get_value(); + f_out_ << "
" << endl; +} + +/** + * Generates a constant value + */ +void t_html_generator::generate_const(t_const* tconst) { + string name = tconst->get_name(); + f_out_ << "" << name + << ""; + print_type(tconst->get_type()); + f_out_ << ""; + print_const_value(tconst->get_value()); + f_out_ << ""; + if (tconst->has_doc()) { + f_out_ << "
"; + print_doc(tconst); + f_out_ << "
"; + } +} + +/** + * Generates a struct definition for a thrift data type. + * + * @param tstruct The struct definition + */ +void t_html_generator::generate_struct(t_struct* tstruct) { + string name = tstruct->get_name(); + f_out_ << "
"; + f_out_ << "

"; + if (tstruct->is_xception()) { + f_out_ << "Exception: "; + } else { + f_out_ << "Struct: "; + } + f_out_ << name << "

" << endl; + vector members = tstruct->get_members(); + vector::iterator mem_iter = members.begin(); + f_out_ << ""; + f_out_ << "" + << endl; + for ( ; mem_iter != members.end(); mem_iter++) { + f_out_ << "" << endl; + } + f_out_ << "
FieldTypeRequiredDefault value
" << (*mem_iter)->get_name() << ""; + print_type((*mem_iter)->get_type()); + f_out_ << ""; + if ((*mem_iter)->get_req() != t_field::T_OPTIONAL) { + f_out_ << "yes"; + } else { + f_out_ << "no"; + } + f_out_ << ""; + t_const_value* default_val = (*mem_iter)->get_value(); + if (default_val != NULL) { + print_const_value(default_val); + } + f_out_ << "

"; + print_doc(tstruct); + f_out_ << "
"; +} + +/** + * Exceptions are special structs + * + * @param tstruct The struct definition + */ +void t_html_generator::generate_xception(t_struct* txception) { + generate_struct(txception); +} + +/** + * Generates the HTML block for a Thrift service. + * + * @param tservice The service definition + */ +void t_html_generator::generate_service(t_service* tservice) { + f_out_ << "

Service: " + << service_name_ << "

" << endl; + + if (tservice->get_extends()) { + f_out_ << "
extends "; + print_type(tservice->get_extends()); + f_out_ << "
\n"; + } + print_doc(tservice); + vector functions = tservice->get_functions(); + vector::iterator fn_iter = functions.begin(); + for ( ; fn_iter != functions.end(); fn_iter++) { + string fn_name = (*fn_iter)->get_name(); + f_out_ << "
"; + f_out_ << "

Function: " << service_name_ << "." << fn_name + << "

" << endl; + f_out_ << "
";
+    int offset = print_type((*fn_iter)->get_returntype());
+    bool first = true;
+    f_out_ << " " << fn_name << "(";
+    offset += fn_name.size() + 2;
+    vector args = (*fn_iter)->get_arglist()->get_members();
+    vector::iterator arg_iter = args.begin();
+    if (arg_iter != args.end()) {
+      for ( ; arg_iter != args.end(); arg_iter++) {
+	if (!first) {
+	  f_out_ << "," << endl;
+	  for (int i = 0; i < offset; ++i) {
+	    f_out_ << " ";
+	  }
+	}
+	first = false;
+	print_type((*arg_iter)->get_type());
+	f_out_ << " " << (*arg_iter)->get_name();
+	if ((*arg_iter)->get_value() != NULL) {
+	  f_out_ << " = ";
+	  print_const_value((*arg_iter)->get_value());
+	}
+      }
+    }
+    f_out_ << ")" << endl;
+    first = true;
+    vector excepts = (*fn_iter)->get_xceptions()->get_members();
+    vector::iterator ex_iter = excepts.begin();
+    if (ex_iter != excepts.end()) {
+      f_out_ << "    throws ";
+      for ( ; ex_iter != excepts.end(); ex_iter++) {
+	if (!first) {
+	  f_out_ << ", ";
+	}
+	first = false;
+	print_type((*ex_iter)->get_type());
+      }
+      f_out_ << endl;
+    }
+    f_out_ << "
"; + print_doc(*fn_iter); + f_out_ << "
"; + } +} + +THRIFT_REGISTER_GENERATOR(html, "HTML", ""); diff --git a/compiler/cpp/src/generate/t_java_generator.cc b/compiler/cpp/src/generate/t_java_generator.cc new file mode 100644 index 00000000..3ec816fd --- /dev/null +++ b/compiler/cpp/src/generate/t_java_generator.cc @@ -0,0 +1,3008 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "platform.h" +#include "t_oop_generator.h" +using namespace std; + + +/** + * Java code generator. + * + */ +class t_java_generator : public t_oop_generator { + public: + t_java_generator( + t_program* program, + const std::map& parsed_options, + const std::string& option_string) + : t_oop_generator(program) + { + std::map::const_iterator iter; + + iter = parsed_options.find("beans"); + bean_style_ = (iter != parsed_options.end()); + + iter = parsed_options.find("nocamel"); + nocamel_style_ = (iter != parsed_options.end()); + + iter = parsed_options.find("hashcode"); + gen_hash_code_ = (iter != parsed_options.end()); + + out_dir_base_ = (bean_style_ ? "gen-javabean" : "gen-java"); + } + + /** + * Init and close methods + */ + + void init_generator(); + void close_generator(); + + void generate_consts(std::vector consts); + + /** + * Program-level generation functions + */ + + void generate_typedef (t_typedef* ttypedef); + void generate_enum (t_enum* tenum); + void generate_struct (t_struct* tstruct); + void generate_xception(t_struct* txception); + void generate_service (t_service* tservice); + + void print_const_value(std::ofstream& out, std::string name, t_type* type, t_const_value* value, bool in_static, bool defval=false); + std::string render_const_value(std::ofstream& out, std::string name, t_type* type, t_const_value* value); + + /** + * Service-level generation functions + */ + + void generate_java_struct(t_struct* tstruct, bool is_exception); + + void generate_java_struct_definition(std::ofstream& out, t_struct* tstruct, bool is_xception=false, bool in_class=false, bool is_result=false); + void generate_java_struct_equality(std::ofstream& out, t_struct* tstruct); + void generate_java_struct_reader(std::ofstream& out, t_struct* tstruct); + void generate_java_validator(std::ofstream& out, t_struct* tstruct); + void generate_java_struct_result_writer(std::ofstream& out, t_struct* tstruct); + void generate_java_struct_writer(std::ofstream& out, t_struct* tstruct); + void generate_java_struct_tostring(std::ofstream& out, t_struct* tstruct); + void generate_java_meta_data_map(std::ofstream& out, t_struct* tstruct); + void generate_field_value_meta_data(std::ofstream& out, t_type* type); + std::string get_java_type_string(t_type* type); + void generate_reflection_setters(std::ostringstream& out, t_type* type, std::string field_name, std::string cap_name); + void generate_reflection_getters(std::ostringstream& out, t_type* type, std::string field_name, std::string cap_name); + void generate_generic_field_getters_setters(std::ofstream& out, t_struct* tstruct); + void generate_generic_isset_method(std::ofstream& out, t_struct* tstruct); + void generate_java_bean_boilerplate(std::ofstream& out, t_struct* tstruct); + + void generate_function_helpers(t_function* tfunction); + std::string get_cap_name(std::string name); + std::string generate_isset_check(t_field* field); + std::string generate_isset_check(std::string field); + void generate_isset_set(ofstream& out, t_field* field); + + void generate_service_interface (t_service* tservice); + void generate_service_helpers (t_service* tservice); + void generate_service_client (t_service* tservice); + void generate_service_server (t_service* tservice); + void generate_process_function (t_service* tservice, t_function* tfunction); + + /** + * Serialization constructs + */ + + void generate_deserialize_field (std::ofstream& out, + t_field* tfield, + std::string prefix=""); + + void generate_deserialize_struct (std::ofstream& out, + t_struct* tstruct, + std::string prefix=""); + + void generate_deserialize_container (std::ofstream& out, + t_type* ttype, + std::string prefix=""); + + void generate_deserialize_set_element (std::ofstream& out, + t_set* tset, + std::string prefix=""); + + void generate_deserialize_map_element (std::ofstream& out, + t_map* tmap, + std::string prefix=""); + + void generate_deserialize_list_element (std::ofstream& out, + t_list* tlist, + std::string prefix=""); + + void generate_serialize_field (std::ofstream& out, + t_field* tfield, + std::string prefix=""); + + void generate_serialize_struct (std::ofstream& out, + t_struct* tstruct, + std::string prefix=""); + + void generate_serialize_container (std::ofstream& out, + t_type* ttype, + std::string prefix=""); + + void generate_serialize_map_element (std::ofstream& out, + t_map* tmap, + std::string iter, + std::string map); + + void generate_serialize_set_element (std::ofstream& out, + t_set* tmap, + std::string iter); + + void generate_serialize_list_element (std::ofstream& out, + t_list* tlist, + std::string iter); + + void generate_java_doc (std::ofstream& out, + t_doc* tdoc); + + void generate_java_doc (std::ofstream& out, + t_function* tdoc); + + void generate_deep_copy_container(std::ofstream& out, std::string source_name_p1, std::string source_name_p2, std::string result_name, t_type* type); + void generate_deep_copy_non_container(std::ofstream& out, std::string source_name, std::string dest_name, t_type* type); + + /** + * Helper rendering functions + */ + + std::string java_package(); + std::string java_type_imports(); + std::string java_thrift_imports(); + std::string type_name(t_type* ttype, bool in_container=false, bool in_init=false); + std::string base_type_name(t_base_type* tbase, bool in_container=false); + std::string declare_field(t_field* tfield, bool init=false); + std::string function_signature(t_function* tfunction, std::string prefix=""); + std::string argument_list(t_struct* tstruct); + std::string type_to_enum(t_type* ttype); + std::string get_enum_class_name(t_type* type); + + bool type_can_be_null(t_type* ttype) { + ttype = get_true_type(ttype); + + return + ttype->is_container() || + ttype->is_struct() || + ttype->is_xception() || + ttype->is_string(); + } + + std::string constant_name(std::string name); + + private: + + /** + * File streams + */ + + std::string package_name_; + std::ofstream f_service_; + std::string package_dir_; + + bool bean_style_; + bool nocamel_style_; + bool gen_hash_code_; + +}; + + +/** + * Prepares for file generation by opening up the necessary file output + * streams. + * + * @param tprogram The program to generate + */ +void t_java_generator::init_generator() { + // Make output directory + MKDIR(get_out_dir().c_str()); + package_name_ = program_->get_namespace("java"); + + string dir = package_name_; + string subdir = get_out_dir(); + string::size_type loc; + while ((loc = dir.find(".")) != string::npos) { + subdir = subdir + "/" + dir.substr(0, loc); + MKDIR(subdir.c_str()); + dir = dir.substr(loc+1); + } + if (dir.size() > 0) { + subdir = subdir + "/" + dir; + MKDIR(subdir.c_str()); + } + + package_dir_ = subdir; +} + +/** + * Packages the generated file + * + * @return String of the package, i.e. "package org.apache.thriftdemo;" + */ +string t_java_generator::java_package() { + if (!package_name_.empty()) { + return string("package ") + package_name_ + ";\n\n"; + } + return ""; +} + +/** + * Prints standard java imports + * + * @return List of imports for Java types that are used in here + */ +string t_java_generator::java_type_imports() { + string hash_builder; + if (gen_hash_code_) { + hash_builder = "import org.apache.commons.lang.builder.HashCodeBuilder;\n"; + } + + return + string() + + hash_builder + + "import java.util.List;\n" + + "import java.util.ArrayList;\n" + + "import java.util.Map;\n" + + "import java.util.HashMap;\n" + + "import java.util.Set;\n" + + "import java.util.HashSet;\n" + + "import java.util.Collections;\n\n"; +} + +/** + * Prints standard java imports + * + * @return List of imports necessary for thrift + */ +string t_java_generator::java_thrift_imports() { + return + string() + + "import org.apache.thrift.*;\n" + + "import org.apache.thrift.meta_data.*;\n" + + "import org.apache.thrift.protocol.*;\n\n"; +} + +/** + * Nothing in Java + */ +void t_java_generator::close_generator() {} + +/** + * Generates a typedef. This is not done in Java, since it does + * not support arbitrary name replacements, and it'd be a wacky waste + * of overhead to make wrapper classes. + * + * @param ttypedef The type definition + */ +void t_java_generator::generate_typedef(t_typedef* ttypedef) {} + +/** + * Enums are a class with a set of static constants. + * + * @param tenum The enumeration + */ +void t_java_generator::generate_enum(t_enum* tenum) { + // Make output file + string f_enum_name = package_dir_+"/"+(tenum->get_name())+".java"; + ofstream f_enum; + f_enum.open(f_enum_name.c_str()); + + // Comment and package it + f_enum << + autogen_comment() << + java_package() << endl; + + // Add java imports + f_enum << string() + + "import java.util.Set;\n" + + "import java.util.HashSet;\n" + + "import java.util.Collections;\n" + + "import org.apache.thrift.IntRangeSet;\n" + + "import java.util.Map;\n" + + "import java.util.HashMap;\n" << endl; + + f_enum << + "public class " << tenum->get_name() << " "; + scope_up(f_enum); + + vector constants = tenum->get_constants(); + vector::iterator c_iter; + int value = -1; + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + if ((*c_iter)->has_value()) { + value = (*c_iter)->get_value(); + } else { + ++value; + } + + indent(f_enum) << + "public static final int " << (*c_iter)->get_name() << + " = " << value << ";" << endl; + } + + // Create a static Set with all valid values for this enum + f_enum << endl; + indent(f_enum) << "public static final IntRangeSet VALID_VALUES = new IntRangeSet("; + indent_up(); + bool first = true; + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + // populate set + if ((*c_iter)->has_value()) { + f_enum << (first ? "" : ", ") << (*c_iter)->get_name(); + first = false; + } + } + indent_down(); + f_enum << ");" << endl; + + indent(f_enum) << "public static final Map VALUES_TO_NAMES = new HashMap() {{" << endl; + + indent_up(); + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + indent(f_enum) << "put(" << (*c_iter)->get_name() << ", \"" << (*c_iter)->get_name() <<"\");" << endl; + } + indent_down(); + + + indent(f_enum) << "}};" << endl; + + scope_down(f_enum); + + f_enum.close(); +} + +/** + * Generates a class that holds all the constants. + */ +void t_java_generator::generate_consts(std::vector consts) { + if (consts.empty()) { + return; + } + + string f_consts_name = package_dir_+"/Constants.java"; + ofstream f_consts; + f_consts.open(f_consts_name.c_str()); + + // Print header + f_consts << + autogen_comment() << + java_package() << + java_type_imports(); + + f_consts << + "public class Constants {" << endl << + endl; + indent_up(); + vector::iterator c_iter; + for (c_iter = consts.begin(); c_iter != consts.end(); ++c_iter) { + print_const_value(f_consts, + (*c_iter)->get_name(), + (*c_iter)->get_type(), + (*c_iter)->get_value(), + false); + } + indent_down(); + indent(f_consts) << + "}" << endl; + f_consts.close(); +} + + +/** + * Prints the value of a constant with the given type. Note that type checking + * is NOT performed in this function as it is always run beforehand using the + * validate_types method in main.cc + */ +void t_java_generator::print_const_value(std::ofstream& out, string name, t_type* type, t_const_value* value, bool in_static, bool defval) { + type = get_true_type(type); + + indent(out); + if (!defval) { + out << + (in_static ? "" : "public static final ") << + type_name(type) << " "; + } + if (type->is_base_type()) { + string v2 = render_const_value(out, name, type, value); + out << name << " = " << v2 << ";" << endl << endl; + } else if (type->is_enum()) { + out << name << " = " << value->get_integer() << ";" << endl << endl; + } else if (type->is_struct() || type->is_xception()) { + const vector& fields = ((t_struct*)type)->get_members(); + vector::const_iterator f_iter; + const map& val = value->get_map(); + map::const_iterator v_iter; + out << name << " = new " << type_name(type, false, true) << "();" << endl; + if (!in_static) { + indent(out) << "static {" << endl; + indent_up(); + } + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + t_type* field_type = NULL; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if ((*f_iter)->get_name() == v_iter->first->get_string()) { + field_type = (*f_iter)->get_type(); + } + } + if (field_type == NULL) { + throw "type error: " + type->get_name() + " has no field " + v_iter->first->get_string(); + } + string val = render_const_value(out, name, field_type, v_iter->second); + indent(out) << name << "."; + std::string cap_name = get_cap_name(v_iter->first->get_string()); + out << "set" << cap_name << "(" << val << ");" << endl; + } + if (!in_static) { + indent_down(); + indent(out) << "}" << endl; + } + out << endl; + } else if (type->is_map()) { + out << name << " = new " << type_name(type, false, true) << "();" << endl; + if (!in_static) { + indent(out) << "static {" << endl; + indent_up(); + } + t_type* ktype = ((t_map*)type)->get_key_type(); + t_type* vtype = ((t_map*)type)->get_val_type(); + const map& val = value->get_map(); + map::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + string key = render_const_value(out, name, ktype, v_iter->first); + string val = render_const_value(out, name, vtype, v_iter->second); + indent(out) << name << ".put(" << key << ", " << val << ");" << endl; + } + if (!in_static) { + indent_down(); + indent(out) << "}" << endl; + } + out << endl; + } else if (type->is_list() || type->is_set()) { + out << name << " = new " << type_name(type, false, true) << "();" << endl; + if (!in_static) { + indent(out) << "static {" << endl; + indent_up(); + } + t_type* etype; + if (type->is_list()) { + etype = ((t_list*)type)->get_elem_type(); + } else { + etype = ((t_set*)type)->get_elem_type(); + } + const vector& val = value->get_list(); + vector::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + string val = render_const_value(out, name, etype, *v_iter); + indent(out) << name << ".add(" << val << ");" << endl; + } + if (!in_static) { + indent_down(); + indent(out) << "}" << endl; + } + out << endl; + } else { + throw "compiler error: no const of type " + type->get_name(); + } +} + +string t_java_generator::render_const_value(ofstream& out, string name, t_type* type, t_const_value* value) { + type = get_true_type(type); + std::ostringstream render; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_STRING: + render << '"' << get_escaped_string(value) << '"'; + break; + case t_base_type::TYPE_BOOL: + render << ((value->get_integer() > 0) ? "true" : "false"); + break; + case t_base_type::TYPE_BYTE: + render << "(byte)" << value->get_integer(); + break; + case t_base_type::TYPE_I16: + render << "(short)" << value->get_integer(); + break; + case t_base_type::TYPE_I32: + render << value->get_integer(); + break; + case t_base_type::TYPE_I64: + render << value->get_integer() << "L"; + break; + case t_base_type::TYPE_DOUBLE: + if (value->get_type() == t_const_value::CV_INTEGER) { + render << "(double)" << value->get_integer(); + } else { + render << value->get_double(); + } + break; + default: + throw "compiler error: no const of base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + render << value->get_integer(); + } else { + string t = tmp("tmp"); + print_const_value(out, t, type, value, true); + render << t; + } + + return render.str(); +} + +/** + * Generates a struct definition for a thrift data type. This is a class + * with data members, read(), write(), and an inner Isset class. + * + * @param tstruct The struct definition + */ +void t_java_generator::generate_struct(t_struct* tstruct) { + generate_java_struct(tstruct, false); +} + +/** + * Exceptions are structs, but they inherit from Exception + * + * @param tstruct The struct definition + */ +void t_java_generator::generate_xception(t_struct* txception) { + generate_java_struct(txception, true); +} + + +/** + * Java struct definition. + * + * @param tstruct The struct definition + */ +void t_java_generator::generate_java_struct(t_struct* tstruct, + bool is_exception) { + // Make output file + string f_struct_name = package_dir_+"/"+(tstruct->get_name())+".java"; + ofstream f_struct; + f_struct.open(f_struct_name.c_str()); + + f_struct << + autogen_comment() << + java_package() << + java_type_imports() << + java_thrift_imports(); + + generate_java_struct_definition(f_struct, + tstruct, + is_exception); + f_struct.close(); +} + +/** + * Java struct definition. This has various parameters, as it could be + * generated standalone or inside another class as a helper. If it + * is a helper than it is a static class. + * + * @param tstruct The struct definition + * @param is_exception Is this an exception? + * @param in_class If inside a class, needs to be static class + * @param is_result If this is a result it needs a different writer + */ +void t_java_generator::generate_java_struct_definition(ofstream &out, + t_struct* tstruct, + bool is_exception, + bool in_class, + bool is_result) { + generate_java_doc(out, tstruct); + + bool is_final = (tstruct->annotations_.find("final") != tstruct->annotations_.end()); + + indent(out) << + "public " << (is_final ? "final " : "") << + (in_class ? "static " : "") << "class " << tstruct->get_name() << " "; + + if (is_exception) { + out << "extends Exception "; + } + out << "implements TBase, java.io.Serializable, Cloneable "; + + scope_up(out); + + indent(out) << + "private static final TStruct STRUCT_DESC = new TStruct(\"" << tstruct->get_name() << "\");" << endl; + + // Members are public for -java, private for -javabean + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + indent(out) << + "private static final TField " << constant_name((*m_iter)->get_name()) << + "_FIELD_DESC = new TField(\"" << (*m_iter)->get_name() << "\", " << + type_to_enum((*m_iter)->get_type()) << ", " << + "(short)" << (*m_iter)->get_key() << ");" << endl; + } + + out << endl; + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + if (bean_style_) { + indent(out) << "private "; + } else { + generate_java_doc(out, *m_iter); + indent(out) << "public "; + } + out << declare_field(*m_iter, false) << endl; + + indent(out) << "public static final int " << upcase_string((*m_iter)->get_name()) << " = " << (*m_iter)->get_key() << ";" << endl; + } + + // Inner Isset class + if (members.size() > 0) { + out << + endl << + indent() << "private final Isset __isset = new Isset();" << endl << + indent() << "private static final class Isset implements java.io.Serializable {" << endl; + indent_up(); + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + if (!type_can_be_null((*m_iter)->get_type())){ + indent(out) << + "public boolean " << (*m_iter)->get_name() << " = false;" << endl; + } + } + indent_down(); + out << + indent() << "}" << endl << + endl; + } + + generate_java_meta_data_map(out, tstruct); + + // Static initializer to populate global class to struct metadata map + indent(out) << "static {" << endl; + indent_up(); + indent(out) << "FieldMetaData.addStructMetaDataMap(" << type_name(tstruct) << ".class, metaDataMap);" << endl; + indent_down(); + indent(out) << "}" << endl << endl; + + // Default constructor + indent(out) << + "public " << tstruct->get_name() << "() {" << endl; + indent_up(); + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + t_type* t = get_true_type((*m_iter)->get_type()); + if ((*m_iter)->get_value() != NULL) { + print_const_value(out, "this." + (*m_iter)->get_name(), t, (*m_iter)->get_value(), true, true); + } + } + indent_down(); + indent(out) << "}" << endl << endl; + + + if (!members.empty()) { + // Full constructor for all fields + indent(out) << + "public " << tstruct->get_name() << "(" << endl; + indent_up(); + for (m_iter = members.begin(); m_iter != members.end(); ) { + indent(out) << type_name((*m_iter)->get_type()) << " " << + (*m_iter)->get_name(); + ++m_iter; + if (m_iter != members.end()) { + out << "," << endl; + } + } + out << ")" << endl; + indent_down(); + indent(out) << "{" << endl; + indent_up(); + indent(out) << "this();" << endl; + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + indent(out) << "this." << (*m_iter)->get_name() << " = " << + (*m_iter)->get_name() << ";" << endl; + generate_isset_set(out, (*m_iter)); + } + indent_down(); + indent(out) << "}" << endl << endl; + } + + // copy constructor + indent(out) << "/**" << endl; + indent(out) << " * Performs a deep copy on other." << endl; + indent(out) << " */" << endl; + indent(out) << "public " << tstruct->get_name() << "(" << tstruct->get_name() << " other) {" << endl; + indent_up(); + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + t_field* field = (*m_iter); + std::string field_name = field->get_name(); + t_type* type = field->get_type(); + bool can_be_null = type_can_be_null(type); + + if (!can_be_null) { + indent(out) << "__isset." << field_name << " = other.__isset." << field_name << ";" << endl; + } + + if (can_be_null) { + indent(out) << "if (other." << generate_isset_check(field) << ") {" << endl; + indent_up(); + } + + if (type->is_container()) { + generate_deep_copy_container(out, "other", field_name, "__this__" + field_name, type); + indent(out) << "this." << field_name << " = __this__" << field_name << ";" << endl; + } else { + indent(out) << "this." << field_name << " = "; + generate_deep_copy_non_container(out, "other." + field_name, field_name, type); + out << ";" << endl; + } + + if (can_be_null) { + indent_down(); + indent(out) << "}" << endl; + } + } + + indent_down(); + indent(out) << "}" << endl << endl; + + // clone method, so that you can deep copy an object when you don't know its class. + indent(out) << "@Override" << endl; + indent(out) << "public " << tstruct->get_name() << " clone() {" << endl; + indent(out) << " return new " << tstruct->get_name() << "(this);" << endl; + indent(out) << "}" << endl << endl; + + generate_java_bean_boilerplate(out, tstruct); + generate_generic_field_getters_setters(out, tstruct); + generate_generic_isset_method(out, tstruct); + + generate_java_struct_equality(out, tstruct); + + generate_java_struct_reader(out, tstruct); + if (is_result) { + generate_java_struct_result_writer(out, tstruct); + } else { + generate_java_struct_writer(out, tstruct); + } + generate_java_struct_tostring(out, tstruct); + generate_java_validator(out, tstruct); + scope_down(out); + out << endl; +} + +/** + * Generates equals methods and a hashCode method for a structure. + * + * @param tstruct The struct definition + */ +void t_java_generator::generate_java_struct_equality(ofstream& out, + t_struct* tstruct) { + out << indent() << "@Override" << endl << + indent() << "public boolean equals(Object that) {" << endl; + indent_up(); + out << + indent() << "if (that == null)" << endl << + indent() << " return false;" << endl << + indent() << "if (that instanceof " << tstruct->get_name() << ")" << endl << + indent() << " return this.equals((" << tstruct->get_name() << ")that);" << endl << + indent() << "return false;" << endl; + scope_down(out); + out << endl; + + out << + indent() << "public boolean equals(" << tstruct->get_name() << " that) {" << endl; + indent_up(); + out << + indent() << "if (that == null)" << endl << + indent() << " return false;" << endl; + + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + out << endl; + + t_type* t = get_true_type((*m_iter)->get_type()); + // Most existing Thrift code does not use isset or optional/required, + // so we treat "default" fields as required. + bool is_optional = (*m_iter)->get_req() == t_field::T_OPTIONAL; + bool can_be_null = type_can_be_null(t); + string name = (*m_iter)->get_name(); + + string this_present = "true"; + string that_present = "true"; + string unequal; + + if (is_optional || can_be_null) { + this_present += " && this." + generate_isset_check(*m_iter); + that_present += " && that." + generate_isset_check(*m_iter); + } + + out << + indent() << "boolean this_present_" << name << " = " + << this_present << ";" << endl << + indent() << "boolean that_present_" << name << " = " + << that_present << ";" << endl << + indent() << "if (" << "this_present_" << name + << " || that_present_" << name << ") {" << endl; + indent_up(); + out << + indent() << "if (!(" << "this_present_" << name + << " && that_present_" << name << "))" << endl << + indent() << " return false;" << endl; + + if (t->is_base_type() && ((t_base_type*)t)->is_binary()) { + unequal = "!java.util.Arrays.equals(this." + name + ", that." + name + ")"; + } else if (can_be_null) { + unequal = "!this." + name + ".equals(that." + name + ")"; + } else { + unequal = "this." + name + " != that." + name; + } + + out << + indent() << "if (" << unequal << ")" << endl << + indent() << " return false;" << endl; + + scope_down(out); + } + out << endl; + indent(out) << "return true;" << endl; + scope_down(out); + out << endl; + + out << indent() << "@Override" << endl << + indent() << "public int hashCode() {" << endl; + indent_up(); + if (gen_hash_code_) { + indent(out) << "HashCodeBuilder builder = new HashCodeBuilder();" << endl; + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + out << endl; + + t_type* t = get_true_type((*m_iter)->get_type()); + bool is_optional = (*m_iter)->get_req() == t_field::T_OPTIONAL; + bool can_be_null = type_can_be_null(t); + string name = (*m_iter)->get_name(); + + string present = "true"; + + if (is_optional || can_be_null) { + present += " && (" + generate_isset_check(*m_iter) + ")"; + } + + out << + indent() << "boolean present_" << name << " = " + << present << ";" << endl << + indent() << "builder.append(present_" << name << ");" << endl << + indent() << "if (present_" << name << ")" << endl << + indent() << " builder.append(" << name << ");" << endl; + } + + out << endl; + indent(out) << "return builder.toHashCode();" << endl; + } else { + indent(out) << "return 0;" << endl; + } + indent_down(); + indent(out) << "}" << endl << endl; +} + +/** + * Generates a function to read all the fields of the struct. + * + * @param tstruct The struct definition + */ +void t_java_generator::generate_java_struct_reader(ofstream& out, + t_struct* tstruct) { + out << + indent() << "public void read(TProtocol iprot) throws TException {" << endl; + indent_up(); + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + + // Declare stack tmp variables and read struct header + out << + indent() << "TField field;" << endl << + indent() << "iprot.readStructBegin();" << endl; + + // Loop over reading in fields + indent(out) << + "while (true)" << endl; + scope_up(out); + + // Read beginning field marker + indent(out) << + "field = iprot.readFieldBegin();" << endl; + + // Check for field STOP marker and break + indent(out) << + "if (field.type == TType.STOP) { " << endl; + indent_up(); + indent(out) << + "break;" << endl; + indent_down(); + indent(out) << + "}" << endl; + + // Switch statement on the field we are reading + indent(out) << + "switch (field.id)" << endl; + + scope_up(out); + + // Generate deserialization code for known cases + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + indent(out) << + "case " << upcase_string((*f_iter)->get_name()) << ":" << endl; + indent_up(); + indent(out) << + "if (field.type == " << type_to_enum((*f_iter)->get_type()) << ") {" << endl; + indent_up(); + + generate_deserialize_field(out, *f_iter, "this."); + generate_isset_set(out, *f_iter); + indent_down(); + out << + indent() << "} else { " << endl << + indent() << " TProtocolUtil.skip(iprot, field.type);" << endl << + indent() << "}" << endl << + indent() << "break;" << endl; + indent_down(); + } + + // In the default case we skip the field + out << + indent() << "default:" << endl << + indent() << " TProtocolUtil.skip(iprot, field.type);" << endl << + indent() << " break;" << endl; + + scope_down(out); + + // Read field end marker + indent(out) << + "iprot.readFieldEnd();" << endl; + + scope_down(out); + + out << + indent() << "iprot.readStructEnd();" << endl << endl; + + // in non-beans style, check for required fields of primitive type + // (which can be checked here but not in the general validate method) + if (!bean_style_){ + out << endl << indent() << "// check for required fields of primitive type, which can't be checked in the validate method" << endl; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if ((*f_iter)->get_req() == t_field::T_REQUIRED && !type_can_be_null((*f_iter)->get_type())) { + out << + indent() << "if (!__isset." << (*f_iter)->get_name() << ") {" << endl << + indent() << " throw new TProtocolException(\"Required field '" << (*f_iter)->get_name() << "' was not found in serialized data! Struct: \" + toString());" << endl << + indent() << "}" << endl; + } + } + } + + // performs various checks (e.g. check that all required fields are set) + indent(out) << "validate();" << endl; + + indent_down(); + out << + indent() << "}" << endl << + endl; +} + +// generates java method to perform various checks +// (e.g. check that all required fields are set) +void t_java_generator::generate_java_validator(ofstream& out, + t_struct* tstruct){ + indent(out) << "public void validate() throws TException {" << endl; + indent_up(); + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + + out << indent() << "// check for required fields" << endl; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if ((*f_iter)->get_req() == t_field::T_REQUIRED) { + if (bean_style_) { + out << + indent() << "if (!" << generate_isset_check(*f_iter) << ") {" << endl << + indent() << " throw new TProtocolException(\"Required field '" << (*f_iter)->get_name() << "' is unset! Struct:\" + toString());" << endl << + indent() << "}" << endl << endl; + } else{ + if (type_can_be_null((*f_iter)->get_type())) { + indent(out) << "if (" << (*f_iter)->get_name() << " == null) {" << endl; + indent(out) << " throw new TProtocolException(\"Required field '" << (*f_iter)->get_name() << "' was not present! Struct: \" + toString());" << endl; + indent(out) << "}" << endl; + } else { + indent(out) << "// alas, we cannot check '" << (*f_iter)->get_name() << "' because it's a primitive and you chose the non-beans generator." << endl; + } + } + } + } + + // check that fields of type enum have valid values + out << indent() << "// check that fields of type enum have valid values" << endl; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + t_field* field = (*f_iter); + t_type* type = field->get_type(); + // if field is an enum, check that its value is valid + if (type->is_enum()){ + indent(out) << "if (" << generate_isset_check(field) << " && !" << get_enum_class_name(type) << ".VALID_VALUES.contains(" << field->get_name() << ")){" << endl; + indent_up(); + indent(out) << "throw new TProtocolException(\"The field '" << field->get_name() << "' has been assigned the invalid value \" + " << field->get_name() << ");" << endl; + indent_down(); + indent(out) << "}" << endl; + } + } + + indent_down(); + indent(out) << "}" << endl << endl; +} + +/** + * Generates a function to write all the fields of the struct + * + * @param tstruct The struct definition + */ +void t_java_generator::generate_java_struct_writer(ofstream& out, + t_struct* tstruct) { + out << + indent() << "public void write(TProtocol oprot) throws TException {" << endl; + indent_up(); + + string name = tstruct->get_name(); + const vector& fields = tstruct->get_sorted_members(); + vector::const_iterator f_iter; + + // performs various checks (e.g. check that all required fields are set) + indent(out) << "validate();" << endl << endl; + + indent(out) << "oprot.writeStructBegin(STRUCT_DESC);" << endl; + + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + bool null_allowed = type_can_be_null((*f_iter)->get_type()); + if (null_allowed) { + out << + indent() << "if (this." << (*f_iter)->get_name() << " != null) {" << endl; + indent_up(); + } + bool optional = bean_style_ && (*f_iter)->get_req() == t_field::T_OPTIONAL; + if (optional) { + indent(out) << "if (" << generate_isset_check((*f_iter)) << ") {" << endl; + indent_up(); + } + + indent(out) << "oprot.writeFieldBegin(" << constant_name((*f_iter)->get_name()) << "_FIELD_DESC);" << endl; + + // Write field contents + generate_serialize_field(out, *f_iter, "this."); + + // Write field closer + indent(out) << + "oprot.writeFieldEnd();" << endl; + + if (optional) { + indent_down(); + indent(out) << "}" << endl; + } + if (null_allowed) { + indent_down(); + indent(out) << "}" << endl; + } + } + // Write the struct map + out << + indent() << "oprot.writeFieldStop();" << endl << + indent() << "oprot.writeStructEnd();" << endl; + + indent_down(); + out << + indent() << "}" << endl << + endl; +} + +/** + * Generates a function to write all the fields of the struct, + * which is a function result. These fields are only written + * if they are set in the Isset array, and only one of them + * can be set at a time. + * + * @param tstruct The struct definition + */ +void t_java_generator::generate_java_struct_result_writer(ofstream& out, + t_struct* tstruct) { + out << + indent() << "public void write(TProtocol oprot) throws TException {" << endl; + indent_up(); + + string name = tstruct->get_name(); + const vector& fields = tstruct->get_sorted_members(); + vector::const_iterator f_iter; + + indent(out) << "oprot.writeStructBegin(STRUCT_DESC);" << endl; + + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + out << + endl << + indent() << "if "; + } else { + out << " else if "; + } + + out << "(this." << generate_isset_check(*f_iter) << ") {" << endl; + + indent_up(); + + indent(out) << "oprot.writeFieldBegin(" << constant_name((*f_iter)->get_name()) << "_FIELD_DESC);" << endl; + + // Write field contents + generate_serialize_field(out, *f_iter, "this."); + + // Write field closer + indent(out) << + "oprot.writeFieldEnd();" << endl; + + indent_down(); + indent(out) << "}"; + } + // Write the struct map + out << + endl << + indent() << "oprot.writeFieldStop();" << endl << + indent() << "oprot.writeStructEnd();" << endl; + + indent_down(); + out << + indent() << "}" << endl << + endl; +} + +void t_java_generator::generate_reflection_getters(ostringstream& out, t_type* type, string field_name, string cap_name) { + indent(out) << "case " << upcase_string(field_name) << ":" << endl; + indent_up(); + + if (type->is_base_type() && !type->is_string()) { + t_base_type* base_type = (t_base_type*)type; + + indent(out) << "return new " << type_name(type, true, false) << "(" << (base_type->is_bool() ? "is" : "get") << cap_name << "());" << endl << endl; + } else { + indent(out) << "return get" << cap_name << "();" << endl << endl; + } + + indent_down(); +} + +void t_java_generator::generate_reflection_setters(ostringstream& out, t_type* type, string field_name, string cap_name) { + indent(out) << "case " << upcase_string(field_name) << ":" << endl; + indent_up(); + indent(out) << "if (value == null) {" << endl; + indent(out) << " unset" << get_cap_name(field_name) << "();" << endl; + indent(out) << "} else {" << endl; + indent(out) << " set" << cap_name << "((" << type_name(type, true, false) << ")value);" << endl; + indent(out) << "}" << endl; + indent(out) << "break;" << endl << endl; + + indent_down(); +} + +void t_java_generator::generate_generic_field_getters_setters(std::ofstream& out, t_struct* tstruct) { + + std::ostringstream getter_stream; + std::ostringstream setter_stream; + + // build up the bodies of both the getter and setter at once + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + t_field* field = *f_iter; + t_type* type = get_true_type(field->get_type()); + std::string field_name = field->get_name(); + std::string cap_name = get_cap_name(field_name); + + indent_up(); + generate_reflection_setters(setter_stream, type, field_name, cap_name); + generate_reflection_getters(getter_stream, type, field_name, cap_name); + indent_down(); + } + + + // create the setter + indent(out) << "public void setFieldValue(int fieldID, Object value) {" << endl; + indent_up(); + + indent(out) << "switch (fieldID) {" << endl; + + out << setter_stream.str(); + + indent(out) << "default:" << endl; + indent(out) << " throw new IllegalArgumentException(\"Field \" + fieldID + \" doesn't exist!\");" << endl; + + indent(out) << "}" << endl; + + indent_down(); + indent(out) << "}" << endl << endl; + + // create the getter + indent(out) << "public Object getFieldValue(int fieldID) {" << endl; + indent_up(); + + indent(out) << "switch (fieldID) {" << endl; + + out << getter_stream.str(); + + indent(out) << "default:" << endl; + indent(out) << " throw new IllegalArgumentException(\"Field \" + fieldID + \" doesn't exist!\");" << endl; + + indent(out) << "}" << endl; + + indent_down(); + + indent(out) << "}" << endl << endl; +} + +// Creates a generic isSet method that takes the field number as argument +void t_java_generator::generate_generic_isset_method(std::ofstream& out, t_struct* tstruct){ + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + + // create the isSet method + indent(out) << "// Returns true if field corresponding to fieldID is set (has been asigned a value) and false otherwise" << endl; + indent(out) << "public boolean isSet(int fieldID) {" << endl; + indent_up(); + indent(out) << "switch (fieldID) {" << endl; + + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + t_field* field = *f_iter; + indent(out) << "case " << upcase_string(field->get_name()) << ":" << endl; + indent_up(); + indent(out) << "return " << generate_isset_check(field) << ";" << endl; + indent_down(); + } + + indent(out) << "default:" << endl; + indent(out) << " throw new IllegalArgumentException(\"Field \" + fieldID + \" doesn't exist!\");" << endl; + + indent(out) << "}" << endl; + + indent_down(); + indent(out) << "}" << endl << endl; +} + +/** + * Generates a set of Java Bean boilerplate functions (setters, getters, etc.) + * for the given struct. + * + * @param tstruct The struct definition + */ +void t_java_generator::generate_java_bean_boilerplate(ofstream& out, + t_struct* tstruct) { + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + t_field* field = *f_iter; + t_type* type = get_true_type(field->get_type()); + std::string field_name = field->get_name(); + std::string cap_name = get_cap_name(field_name); + + if (type->is_container()) { + // Method to return the size of the collection + indent(out) << "public int get" << cap_name; + out << get_cap_name("size() {") << endl; + + indent_up(); + indent(out) << "return (this." << field_name << " == null) ? 0 : " << + "this." << field_name << ".size();" << endl; + indent_down(); + indent(out) << "}" << endl << endl; + } + + if (type->is_set() || type->is_list()) { + + t_type* element_type; + if (type->is_set()) { + element_type = ((t_set*)type)->get_elem_type(); + } else { + element_type = ((t_list*)type)->get_elem_type(); + } + + // Iterator getter for sets and lists + indent(out) << "public java.util.Iterator<" << + type_name(element_type, true, false) << "> get" << cap_name; + out << get_cap_name("iterator() {") << endl; + + indent_up(); + indent(out) << "return (this." << field_name << " == null) ? null : " << + "this." << field_name << ".iterator();" << endl; + indent_down(); + indent(out) << "}" << endl << endl; + + // Add to set or list, create if the set/list is null + indent(out); + out << "public void add" << get_cap_name("to"); + out << cap_name << "(" << type_name(element_type) << " elem) {" << endl; + + indent_up(); + indent(out) << "if (this." << field_name << " == null) {" << endl; + indent_up(); + indent(out) << "this." << field_name << " = new " << type_name(type, false, true) << + "();" << endl; + indent_down(); + indent(out) << "}" << endl; + indent(out) << "this." << field_name << ".add(elem);" << endl; + indent_down(); + indent(out) << "}" << endl << endl; + + } else if (type->is_map()) { + // Put to map + t_type* key_type = ((t_map*)type)->get_key_type(); + t_type* val_type = ((t_map*)type)->get_val_type(); + + indent(out); + out << "public void put" << get_cap_name("to"); + out << cap_name << "(" << type_name(key_type) << " key, " + << type_name(val_type) << " val) {" << endl; + + indent_up(); + indent(out) << "if (this." << field_name << " == null) {" << endl; + indent_up(); + indent(out) << "this." << field_name << " = new " << + type_name(type, false, true) << "();" << endl; + indent_down(); + indent(out) << "}" << endl; + indent(out) << "this." << field_name << ".put(key, val);" << endl; + indent_down(); + indent(out) << "}" << endl << endl; + } + + // Simple getter + generate_java_doc(out, field); + indent(out) << "public " << type_name(type); + if (type->is_base_type() && + ((t_base_type*)type)->get_base() == t_base_type::TYPE_BOOL) { + out << " is"; + } else { + out << " get"; + } + out << cap_name << "() {" << endl; + indent_up(); + indent(out) << "return this." << field_name << ";" << endl; + indent_down(); + indent(out) << "}" << endl << endl; + + // Simple setter + generate_java_doc(out, field); + indent(out) << "public void set" << cap_name << "(" << type_name(type) << + " " << field_name << ") {" << endl; + indent_up(); + indent(out) << "this." << field_name << " = " << field_name << ";" << + endl; + generate_isset_set(out, field); + + indent_down(); + indent(out) << "}" << endl << endl; + + // Unsetter + indent(out) << "public void unset" << cap_name << "() {" << endl; + indent_up(); + if (type_can_be_null(type)) { + indent(out) << "this." << field_name << " = null;" << endl; + } else { + indent(out) << "this.__isset." << field_name << " = false;" << endl; + } + indent_down(); + indent(out) << "}" << endl << endl; + + // isSet method + indent(out) << "// Returns true if field " << field_name << " is set (has been asigned a value) and false otherwise" << endl; + indent(out) << "public boolean is" << get_cap_name("set") << cap_name << "() {" << endl; + indent_up(); + if (type_can_be_null(type)) { + indent(out) << "return this." << field_name << " != null;" << endl; + } else { + indent(out) << "return this.__isset." << field_name << ";" << endl; + } + indent_down(); + indent(out) << "}" << endl << endl; + + if(!bean_style_) { + indent(out) << "public void set" << cap_name << get_cap_name("isSet") << "(boolean value) {" << endl; + indent_up(); + if (type_can_be_null(type)) { + indent(out) << "if (!value) {" << endl; + indent(out) << " this." << field_name << " = null;" << endl; + indent(out) << "}" << endl; + } else { + indent(out) << "this.__isset." << field_name << " = value;" << endl; + } + indent_down(); + indent(out) << "}" << endl << endl; + } + } +} + +/** + * Generates a toString() method for the given struct + * + * @param tstruct The struct definition + */ +void t_java_generator::generate_java_struct_tostring(ofstream& out, + t_struct* tstruct) { + out << indent() << "@Override" << endl << + indent() << "public String toString() {" << endl; + indent_up(); + + out << + indent() << "StringBuilder sb = new StringBuilder(\"" << tstruct->get_name() << "(\");" << endl; + out << indent() << "boolean first = true;" << endl << endl; + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + bool could_be_unset = (*f_iter)->get_req() == t_field::T_OPTIONAL; + if(could_be_unset) { + indent(out) << "if (" << generate_isset_check(*f_iter) << ") {" << endl; + indent_up(); + } + + t_field* field = (*f_iter); + + if (!first) { + indent(out) << "if (!first) sb.append(\", \");" << endl; + } + indent(out) << "sb.append(\"" << (*f_iter)->get_name() << ":\");" << endl; + bool can_be_null = type_can_be_null(field->get_type()); + if (can_be_null) { + indent(out) << "if (this." << (*f_iter)->get_name() << " == null) {" << endl; + indent(out) << " sb.append(\"null\");" << endl; + indent(out) << "} else {" << endl; + indent_up(); + } + + if (field->get_type()->is_base_type() && ((t_base_type*)(field->get_type()))->is_binary()) { + indent(out) << " int __" << field->get_name() << "_size = Math.min(this." << field->get_name() << ".length, 128);" << endl; + indent(out) << " for (int i = 0; i < __" << field->get_name() << "_size; i++) {" << endl; + indent(out) << " if (i != 0) sb.append(\" \");" << endl; + indent(out) << " sb.append(Integer.toHexString(this." << field->get_name() << "[i]).length() > 1 ? Integer.toHexString(this." << field->get_name() << "[i]).substring(Integer.toHexString(this." << field->get_name() << "[i]).length() - 2).toUpperCase() : \"0\" + Integer.toHexString(this." << field->get_name() << "[i]).toUpperCase());" <get_name() << ".length > 128) sb.append(\" ...\");" << endl; + } else if(field->get_type()->is_enum()) { + indent(out) << "String " << field->get_name() << "_name = " << get_enum_class_name(field->get_type()) << ".VALUES_TO_NAMES.get(this." << (*f_iter)->get_name() << ");"<< endl; + indent(out) << "if (" << field->get_name() << "_name != null) {" << endl; + indent(out) << " sb.append(" << field->get_name() << "_name);" << endl; + indent(out) << " sb.append(\" (\");" << endl; + indent(out) << "}" << endl; + indent(out) << "sb.append(this." << field->get_name() << ");" << endl; + indent(out) << "if (" << field->get_name() << "_name != null) {" << endl; + indent(out) << " sb.append(\")\");" << endl; + indent(out) << "}" << endl; + } else { + indent(out) << "sb.append(this." << (*f_iter)->get_name() << ");" << endl; + } + + if (can_be_null) { + indent_down(); + indent(out) << "}" << endl; + } + indent(out) << "first = false;" << endl; + + if(could_be_unset) { + indent_down(); + indent(out) << "}" << endl; + } + first = false; + } + out << + indent() << "sb.append(\")\");" << endl << + indent() << "return sb.toString();" << endl; + + indent_down(); + indent(out) << "}" << endl << + endl; +} + +/** + * Generates a static map with meta data to store information such as fieldID to + * fieldName mapping + * + * @param tstruct The struct definition + */ +void t_java_generator::generate_java_meta_data_map(ofstream& out, + t_struct* tstruct) { + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + + // Static Map with fieldID -> FieldMetaData mappings + indent(out) << "public static final Map metaDataMap = Collections.unmodifiableMap(new HashMap() {{" << endl; + + // Populate map + indent_up(); + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + t_field* field = *f_iter; + std::string field_name = field->get_name(); + indent(out) << "put(" << upcase_string(field_name) << ", new FieldMetaData(\"" << field_name << "\", "; + + // Set field requirement type (required, optional, etc.) + if (field->get_req() == t_field::T_REQUIRED) { + out << "TFieldRequirementType.REQUIRED, "; + } else if (field->get_req() == t_field::T_OPTIONAL) { + out << "TFieldRequirementType.OPTIONAL, "; + } else { + out << "TFieldRequirementType.DEFAULT, "; + } + + // Create value meta data + generate_field_value_meta_data(out, field->get_type()); + out << "));" << endl; + } + indent_down(); + indent(out) << "}});" << endl << endl; +} + +/** + * Returns a string with the java representation of the given thrift type + * (e.g. for the type struct it returns "TType.STRUCT") + */ +std::string t_java_generator::get_java_type_string(t_type* type) { + if (type->is_list()){ + return "TType.LIST"; + } else if (type->is_map()) { + return "TType.MAP"; + } else if (type->is_set()) { + return "TType.SET"; + } else if (type->is_struct() || type->is_xception()) { + return "TType.STRUCT"; + } else if (type->is_enum()) { + return "TType.I32"; + } else if (type->is_typedef()) { + return get_java_type_string(((t_typedef*)type)->get_type()); + } else if (type->is_base_type()) { + switch (((t_base_type*)type)->get_base()) { + case t_base_type::TYPE_VOID : return "TType.VOID"; break; + case t_base_type::TYPE_STRING : return "TType.STRING"; break; + case t_base_type::TYPE_BOOL : return "TType.BOOL"; break; + case t_base_type::TYPE_BYTE : return "TType.BYTE"; break; + case t_base_type::TYPE_I16 : return "TType.I16"; break; + case t_base_type::TYPE_I32 : return "TType.I32"; break; + case t_base_type::TYPE_I64 : return "TType.I64"; break; + case t_base_type::TYPE_DOUBLE : return "TType.DOUBLE"; break; + default : throw std::runtime_error("Unknown thrift type \"" + type->get_name() + "\" passed to t_java_generator::get_java_type_string!"); break; // This should never happen! + } + } else { + throw std::runtime_error("Unknown thrift type \"" + type->get_name() + "\" passed to t_java_generator::get_java_type_string!"); // This should never happen! + } +} + +void t_java_generator::generate_field_value_meta_data(std::ofstream& out, t_type* type){ + out << endl; + indent_up(); + indent_up(); + if (type->is_struct()){ + indent(out) << "new StructMetaData(TType.STRUCT, " << type_name(type) << ".class"; + } else if (type->is_container()){ + if (type->is_list()){ + indent(out) << "new ListMetaData(TType.LIST, "; + t_type* elem_type = ((t_list*)type)->get_elem_type(); + generate_field_value_meta_data(out, elem_type); + } else if (type->is_set()){ + indent(out) << "new SetMetaData(TType.SET, "; + t_type* elem_type = ((t_list*)type)->get_elem_type(); + generate_field_value_meta_data(out, elem_type); + } else{ // map + indent(out) << "new MapMetaData(TType.MAP, "; + t_type* key_type = ((t_map*)type)->get_key_type(); + t_type* val_type = ((t_map*)type)->get_val_type(); + generate_field_value_meta_data(out, key_type); + out << ", "; + generate_field_value_meta_data(out, val_type); + } + } else { + indent(out) << "new FieldValueMetaData(" << get_java_type_string(type); + } + out << ")"; + indent_down(); + indent_down(); +} + + +/** + * Generates a thrift service. In C++, this comprises an entirely separate + * header and source file. The header file defines the methods and includes + * the data types defined in the main header file, and the implementation + * file contains implementations of the basic printer and default interfaces. + * + * @param tservice The service definition + */ +void t_java_generator::generate_service(t_service* tservice) { + // Make output file + string f_service_name = package_dir_+"/"+service_name_+".java"; + f_service_.open(f_service_name.c_str()); + + f_service_ << + autogen_comment() << + java_package() << + java_type_imports() << + java_thrift_imports(); + + f_service_ << + "public class " << service_name_ << " {" << endl << + endl; + indent_up(); + + // Generate the three main parts of the service + generate_service_interface(tservice); + generate_service_client(tservice); + generate_service_server(tservice); + generate_service_helpers(tservice); + + indent_down(); + f_service_ << + "}" << endl; + f_service_.close(); +} + +/** + * Generates a service interface definition. + * + * @param tservice The service to generate a header definition for + */ +void t_java_generator::generate_service_interface(t_service* tservice) { + string extends = ""; + string extends_iface = ""; + if (tservice->get_extends() != NULL) { + extends = type_name(tservice->get_extends()); + extends_iface = " extends " + extends + ".Iface"; + } + + generate_java_doc(f_service_, tservice); + f_service_ << indent() << "public interface Iface" << extends_iface << + " {" << endl << endl; + indent_up(); + vector functions = tservice->get_functions(); + vector::iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + generate_java_doc(f_service_, *f_iter); + indent(f_service_) << "public " << function_signature(*f_iter) << ";" << + endl << endl; + } + indent_down(); + f_service_ << + indent() << "}" << endl << + endl; +} + +/** + * Generates structs for all the service args and return types + * + * @param tservice The service + */ +void t_java_generator::generate_service_helpers(t_service* tservice) { + vector functions = tservice->get_functions(); + vector::iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + t_struct* ts = (*f_iter)->get_arglist(); + generate_java_struct_definition(f_service_, ts, false, true); + generate_function_helpers(*f_iter); + } +} + +/** + * Generates a service client definition. + * + * @param tservice The service to generate a server for. + */ +void t_java_generator::generate_service_client(t_service* tservice) { + string extends = ""; + string extends_client = ""; + if (tservice->get_extends() != NULL) { + extends = type_name(tservice->get_extends()); + extends_client = " extends " + extends + ".Client"; + } + + indent(f_service_) << + "public static class Client" << extends_client << " implements Iface {" << endl; + indent_up(); + + indent(f_service_) << + "public Client(TProtocol prot)" << endl; + scope_up(f_service_); + indent(f_service_) << + "this(prot, prot);" << endl; + scope_down(f_service_); + f_service_ << endl; + + indent(f_service_) << + "public Client(TProtocol iprot, TProtocol oprot)" << endl; + scope_up(f_service_); + if (extends.empty()) { + f_service_ << + indent() << "iprot_ = iprot;" << endl << + indent() << "oprot_ = oprot;" << endl; + } else { + f_service_ << + indent() << "super(iprot, oprot);" << endl; + } + scope_down(f_service_); + f_service_ << endl; + + if (extends.empty()) { + f_service_ << + indent() << "protected TProtocol iprot_;" << endl << + indent() << "protected TProtocol oprot_;" << endl << + endl << + indent() << "protected int seqid_;" << endl << + endl; + + indent(f_service_) << + "public TProtocol getInputProtocol()" << endl; + scope_up(f_service_); + indent(f_service_) << + "return this.iprot_;" << endl; + scope_down(f_service_); + f_service_ << endl; + + indent(f_service_) << + "public TProtocol getOutputProtocol()" << endl; + scope_up(f_service_); + indent(f_service_) << + "return this.oprot_;" << endl; + scope_down(f_service_); + f_service_ << endl; + + } + + // Generate client method implementations + vector functions = tservice->get_functions(); + vector::const_iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + string funname = (*f_iter)->get_name(); + + // Open function + indent(f_service_) << + "public " << function_signature(*f_iter) << endl; + scope_up(f_service_); + indent(f_service_) << + "send_" << funname << "("; + + // Get the struct of function call params + t_struct* arg_struct = (*f_iter)->get_arglist(); + + // Declare the function arguments + const vector& fields = arg_struct->get_members(); + vector::const_iterator fld_iter; + bool first = true; + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + if (first) { + first = false; + } else { + f_service_ << ", "; + } + f_service_ << (*fld_iter)->get_name(); + } + f_service_ << ");" << endl; + + if (!(*f_iter)->is_oneway()) { + f_service_ << indent(); + if (!(*f_iter)->get_returntype()->is_void()) { + f_service_ << "return "; + } + f_service_ << + "recv_" << funname << "();" << endl; + } + scope_down(f_service_); + f_service_ << endl; + + t_function send_function(g_type_void, + string("send_") + (*f_iter)->get_name(), + (*f_iter)->get_arglist()); + + string argsname = (*f_iter)->get_name() + "_args"; + + // Open function + indent(f_service_) << + "public " << function_signature(&send_function) << endl; + scope_up(f_service_); + + // Serialize the request + f_service_ << + indent() << "oprot_.writeMessageBegin(new TMessage(\"" << funname << "\", TMessageType.CALL, seqid_));" << endl << + indent() << argsname << " args = new " << argsname << "();" << endl; + + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + f_service_ << + indent() << "args." << (*fld_iter)->get_name() << " = " << (*fld_iter)->get_name() << ";" << endl; + } + + f_service_ << + indent() << "args.write(oprot_);" << endl << + indent() << "oprot_.writeMessageEnd();" << endl << + indent() << "oprot_.getTransport().flush();" << endl; + + scope_down(f_service_); + f_service_ << endl; + + if (!(*f_iter)->is_oneway()) { + string resultname = (*f_iter)->get_name() + "_result"; + + t_struct noargs(program_); + t_function recv_function((*f_iter)->get_returntype(), + string("recv_") + (*f_iter)->get_name(), + &noargs, + (*f_iter)->get_xceptions()); + // Open function + indent(f_service_) << + "public " << function_signature(&recv_function) << endl; + scope_up(f_service_); + + // TODO(mcslee): Message validation here, was the seqid etc ok? + + f_service_ << + indent() << "TMessage msg = iprot_.readMessageBegin();" << endl << + indent() << "if (msg.type == TMessageType.EXCEPTION) {" << endl << + indent() << " TApplicationException x = TApplicationException.read(iprot_);" << endl << + indent() << " iprot_.readMessageEnd();" << endl << + indent() << " throw x;" << endl << + indent() << "}" << endl << + indent() << resultname << " result = new " << resultname << "();" << endl << + indent() << "result.read(iprot_);" << endl << + indent() << "iprot_.readMessageEnd();" << endl; + + // Careful, only return _result if not a void function + if (!(*f_iter)->get_returntype()->is_void()) { + f_service_ << + indent() << "if (result." << generate_isset_check("success") << ") {" << endl << + indent() << " return result.success;" << endl << + indent() << "}" << endl; + } + + t_struct* xs = (*f_iter)->get_xceptions(); + const std::vector& xceptions = xs->get_members(); + vector::const_iterator x_iter; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + f_service_ << + indent() << "if (result." << (*x_iter)->get_name() << " != null) {" << endl << + indent() << " throw result." << (*x_iter)->get_name() << ";" << endl << + indent() << "}" << endl; + } + + // If you get here it's an exception, unless a void function + if ((*f_iter)->get_returntype()->is_void()) { + indent(f_service_) << + "return;" << endl; + } else { + f_service_ << + indent() << "throw new TApplicationException(TApplicationException.MISSING_RESULT, \"" << (*f_iter)->get_name() << " failed: unknown result\");" << endl; + } + + // Close function + scope_down(f_service_); + f_service_ << endl; + } + } + + indent_down(); + indent(f_service_) << + "}" << endl; +} + +/** + * Generates a service server definition. + * + * @param tservice The service to generate a server for. + */ +void t_java_generator::generate_service_server(t_service* tservice) { + // Generate the dispatch methods + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + // Extends stuff + string extends = ""; + string extends_processor = ""; + if (tservice->get_extends() != NULL) { + extends = type_name(tservice->get_extends()); + extends_processor = " extends " + extends + ".Processor"; + } + + // Generate the header portion + indent(f_service_) << + "public static class Processor" << extends_processor << " implements TProcessor {" << endl; + indent_up(); + + indent(f_service_) << + "public Processor(Iface iface)" << endl; + scope_up(f_service_); + if (!extends.empty()) { + f_service_ << + indent() << "super(iface);" << endl; + } + f_service_ << + indent() << "iface_ = iface;" << endl; + + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + f_service_ << + indent() << "processMap_.put(\"" << (*f_iter)->get_name() << "\", new " << (*f_iter)->get_name() << "());" << endl; + } + + scope_down(f_service_); + f_service_ << endl; + + if (extends.empty()) { + f_service_ << + indent() << "protected static interface ProcessFunction {" << endl << + indent() << " public void process(int seqid, TProtocol iprot, TProtocol oprot) throws TException;" << endl << + indent() << "}" << endl << + endl; + } + + f_service_ << + indent() << "private Iface iface_;" << endl; + + if (extends.empty()) { + f_service_ << + indent() << "protected final HashMap processMap_ = new HashMap();" << endl; + } + + f_service_ << endl; + + // Generate the server implementation + indent(f_service_) << + "public boolean process(TProtocol iprot, TProtocol oprot) throws TException" << endl; + scope_up(f_service_); + + f_service_ << + indent() << "TMessage msg = iprot.readMessageBegin();" << endl; + + // TODO(mcslee): validate message, was the seqid etc. legit? + + f_service_ << + indent() << "ProcessFunction fn = processMap_.get(msg.name);" << endl << + indent() << "if (fn == null) {" << endl << + indent() << " TProtocolUtil.skip(iprot, TType.STRUCT);" << endl << + indent() << " iprot.readMessageEnd();" << endl << + indent() << " TApplicationException x = new TApplicationException(TApplicationException.UNKNOWN_METHOD, \"Invalid method name: '\"+msg.name+\"'\");" << endl << + indent() << " oprot.writeMessageBegin(new TMessage(msg.name, TMessageType.EXCEPTION, msg.seqid));" << endl << + indent() << " x.write(oprot);" << endl << + indent() << " oprot.writeMessageEnd();" << endl << + indent() << " oprot.getTransport().flush();" << endl << + indent() << " return true;" << endl << + indent() << "}" << endl << + indent() << "fn.process(msg.seqid, iprot, oprot);" << endl; + + f_service_ << + indent() << "return true;" << endl; + + scope_down(f_service_); + f_service_ << endl; + + // Generate the process subfunctions + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + generate_process_function(tservice, *f_iter); + } + + indent_down(); + indent(f_service_) << + "}" << endl << + endl; +} + +/** + * Generates a struct and helpers for a function. + * + * @param tfunction The function + */ +void t_java_generator::generate_function_helpers(t_function* tfunction) { + if (tfunction->is_oneway()) { + return; + } + + t_struct result(program_, tfunction->get_name() + "_result"); + t_field success(tfunction->get_returntype(), "success", 0); + if (!tfunction->get_returntype()->is_void()) { + result.append(&success); + } + + t_struct* xs = tfunction->get_xceptions(); + const vector& fields = xs->get_members(); + vector::const_iterator f_iter; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + result.append(*f_iter); + } + + generate_java_struct_definition(f_service_, &result, false, true, true); +} + +/** + * Generates a process function definition. + * + * @param tfunction The function to write a dispatcher for + */ +void t_java_generator::generate_process_function(t_service* tservice, + t_function* tfunction) { + // Open class + indent(f_service_) << + "private class " << tfunction->get_name() << " implements ProcessFunction {" << endl; + indent_up(); + + // Open function + indent(f_service_) << + "public void process(int seqid, TProtocol iprot, TProtocol oprot) throws TException" << endl; + scope_up(f_service_); + + string argsname = tfunction->get_name() + "_args"; + string resultname = tfunction->get_name() + "_result"; + + f_service_ << + indent() << argsname << " args = new " << argsname << "();" << endl << + indent() << "args.read(iprot);" << endl << + indent() << "iprot.readMessageEnd();" << endl; + + t_struct* xs = tfunction->get_xceptions(); + const std::vector& xceptions = xs->get_members(); + vector::const_iterator x_iter; + + // Declare result for non oneway function + if (!tfunction->is_oneway()) { + f_service_ << + indent() << resultname << " result = new " << resultname << "();" << endl; + } + + // Try block for a function with exceptions + if (xceptions.size() > 0) { + f_service_ << + indent() << "try {" << endl; + indent_up(); + } + + // Generate the function call + t_struct* arg_struct = tfunction->get_arglist(); + const std::vector& fields = arg_struct->get_members(); + vector::const_iterator f_iter; + + f_service_ << indent(); + if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) { + f_service_ << "result.success = "; + } + f_service_ << + "iface_." << tfunction->get_name() << "("; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + } else { + f_service_ << ", "; + } + f_service_ << "args." << (*f_iter)->get_name(); + } + f_service_ << ");" << endl; + + // Set isset on success field + if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void() && !type_can_be_null(tfunction->get_returntype())) { + f_service_ << + indent() << "result.__isset.success = true;" << endl; + } + + if (!tfunction->is_oneway() && xceptions.size() > 0) { + indent_down(); + f_service_ << indent() << "}"; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + f_service_ << " catch (" << type_name((*x_iter)->get_type(), false, false) << " " << (*x_iter)->get_name() << ") {" << endl; + if (!tfunction->is_oneway()) { + indent_up(); + f_service_ << + indent() << "result." << (*x_iter)->get_name() << " = " << (*x_iter)->get_name() << ";" << endl; + indent_down(); + f_service_ << indent() << "}"; + } else { + f_service_ << "}"; + } + } + f_service_ << endl; + } + + // Shortcut out here for oneway functions + if (tfunction->is_oneway()) { + f_service_ << + indent() << "return;" << endl; + scope_down(f_service_); + + // Close class + indent_down(); + f_service_ << + indent() << "}" << endl << + endl; + return; + } + + f_service_ << + indent() << "oprot.writeMessageBegin(new TMessage(\"" << tfunction->get_name() << "\", TMessageType.REPLY, seqid));" << endl << + indent() << "result.write(oprot);" << endl << + indent() << "oprot.writeMessageEnd();" << endl << + indent() << "oprot.getTransport().flush();" << endl; + + // Close function + scope_down(f_service_); + f_service_ << endl; + + // Close class + indent_down(); + f_service_ << + indent() << "}" << endl << + endl; +} + +/** + * Deserializes a field of any type. + * + * @param tfield The field + * @param prefix The variable name or container for this field + */ +void t_java_generator::generate_deserialize_field(ofstream& out, + t_field* tfield, + string prefix) { + t_type* type = get_true_type(tfield->get_type()); + + if (type->is_void()) { + throw "CANNOT GENERATE DESERIALIZE CODE FOR void TYPE: " + + prefix + tfield->get_name(); + } + + string name = prefix + tfield->get_name(); + + if (type->is_struct() || type->is_xception()) { + generate_deserialize_struct(out, + (t_struct*)type, + name); + } else if (type->is_container()) { + generate_deserialize_container(out, type, name); + } else if (type->is_base_type() || type->is_enum()) { + + indent(out) << + name << " = iprot."; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "compiler error: cannot serialize void field in a struct: " + + name; + break; + case t_base_type::TYPE_STRING: + if (((t_base_type*)type)->is_binary()) { + out << "readBinary();"; + } else { + out << "readString();"; + } + break; + case t_base_type::TYPE_BOOL: + out << "readBool();"; + break; + case t_base_type::TYPE_BYTE: + out << "readByte();"; + break; + case t_base_type::TYPE_I16: + out << "readI16();"; + break; + case t_base_type::TYPE_I32: + out << "readI32();"; + break; + case t_base_type::TYPE_I64: + out << "readI64();"; + break; + case t_base_type::TYPE_DOUBLE: + out << "readDouble();"; + break; + default: + throw "compiler error: no Java name for base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + out << "readI32();"; + } + out << + endl; + } else { + printf("DO NOT KNOW HOW TO DESERIALIZE FIELD '%s' TYPE '%s'\n", + tfield->get_name().c_str(), type_name(type).c_str()); + } +} + +/** + * Generates an unserializer for a struct, invokes read() + */ +void t_java_generator::generate_deserialize_struct(ofstream& out, + t_struct* tstruct, + string prefix) { + out << + indent() << prefix << " = new " << type_name(tstruct) << "();" << endl << + indent() << prefix << ".read(iprot);" << endl; +} + +/** + * Deserializes a container by reading its size and then iterating + */ +void t_java_generator::generate_deserialize_container(ofstream& out, + t_type* ttype, + string prefix) { + scope_up(out); + + string obj; + + if (ttype->is_map()) { + obj = tmp("_map"); + } else if (ttype->is_set()) { + obj = tmp("_set"); + } else if (ttype->is_list()) { + obj = tmp("_list"); + } + + // Declare variables, read header + if (ttype->is_map()) { + indent(out) << "TMap " << obj << " = iprot.readMapBegin();" << endl; + } else if (ttype->is_set()) { + indent(out) << "TSet " << obj << " = iprot.readSetBegin();" << endl; + } else if (ttype->is_list()) { + indent(out) << "TList " << obj << " = iprot.readListBegin();" << endl; + } + + indent(out) + << prefix << " = new " << type_name(ttype, false, true) + // size the collection correctly + << "(" + << (ttype->is_list() ? "" : "2*" ) + << obj << ".size" + << ");" << endl; + + // For loop iterates over elements + string i = tmp("_i"); + indent(out) << + "for (int " << i << " = 0; " << + i << " < " << obj << ".size" << "; " << + "++" << i << ")" << endl; + + scope_up(out); + + if (ttype->is_map()) { + generate_deserialize_map_element(out, (t_map*)ttype, prefix); + } else if (ttype->is_set()) { + generate_deserialize_set_element(out, (t_set*)ttype, prefix); + } else if (ttype->is_list()) { + generate_deserialize_list_element(out, (t_list*)ttype, prefix); + } + + scope_down(out); + + // Read container end + if (ttype->is_map()) { + indent(out) << "iprot.readMapEnd();" << endl; + } else if (ttype->is_set()) { + indent(out) << "iprot.readSetEnd();" << endl; + } else if (ttype->is_list()) { + indent(out) << "iprot.readListEnd();" << endl; + } + + scope_down(out); +} + + +/** + * Generates code to deserialize a map + */ +void t_java_generator::generate_deserialize_map_element(ofstream& out, + t_map* tmap, + string prefix) { + string key = tmp("_key"); + string val = tmp("_val"); + t_field fkey(tmap->get_key_type(), key); + t_field fval(tmap->get_val_type(), val); + + indent(out) << + declare_field(&fkey) << endl; + indent(out) << + declare_field(&fval) << endl; + + generate_deserialize_field(out, &fkey); + generate_deserialize_field(out, &fval); + + indent(out) << + prefix << ".put(" << key << ", " << val << ");" << endl; +} + +/** + * Deserializes a set element + */ +void t_java_generator::generate_deserialize_set_element(ofstream& out, + t_set* tset, + string prefix) { + string elem = tmp("_elem"); + t_field felem(tset->get_elem_type(), elem); + + indent(out) << + declare_field(&felem) << endl; + + generate_deserialize_field(out, &felem); + + indent(out) << + prefix << ".add(" << elem << ");" << endl; +} + +/** + * Deserializes a list element + */ +void t_java_generator::generate_deserialize_list_element(ofstream& out, + t_list* tlist, + string prefix) { + string elem = tmp("_elem"); + t_field felem(tlist->get_elem_type(), elem); + + indent(out) << + declare_field(&felem) << endl; + + generate_deserialize_field(out, &felem); + + indent(out) << + prefix << ".add(" << elem << ");" << endl; +} + + +/** + * Serializes a field of any type. + * + * @param tfield The field to serialize + * @param prefix Name to prepend to field name + */ +void t_java_generator::generate_serialize_field(ofstream& out, + t_field* tfield, + string prefix) { + t_type* type = get_true_type(tfield->get_type()); + + // Do nothing for void types + if (type->is_void()) { + throw "CANNOT GENERATE SERIALIZE CODE FOR void TYPE: " + + prefix + tfield->get_name(); + } + + if (type->is_struct() || type->is_xception()) { + generate_serialize_struct(out, + (t_struct*)type, + prefix + tfield->get_name()); + } else if (type->is_container()) { + generate_serialize_container(out, + type, + prefix + tfield->get_name()); + } else if (type->is_base_type() || type->is_enum()) { + + string name = prefix + tfield->get_name(); + indent(out) << + "oprot."; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw + "compiler error: cannot serialize void field in a struct: " + name; + break; + case t_base_type::TYPE_STRING: + if (((t_base_type*)type)->is_binary()) { + out << "writeBinary(" << name << ");"; + } else { + out << "writeString(" << name << ");"; + } + break; + case t_base_type::TYPE_BOOL: + out << "writeBool(" << name << ");"; + break; + case t_base_type::TYPE_BYTE: + out << "writeByte(" << name << ");"; + break; + case t_base_type::TYPE_I16: + out << "writeI16(" << name << ");"; + break; + case t_base_type::TYPE_I32: + out << "writeI32(" << name << ");"; + break; + case t_base_type::TYPE_I64: + out << "writeI64(" << name << ");"; + break; + case t_base_type::TYPE_DOUBLE: + out << "writeDouble(" << name << ");"; + break; + default: + throw "compiler error: no Java name for base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + out << "writeI32(" << name << ");"; + } + out << endl; + } else { + printf("DO NOT KNOW HOW TO SERIALIZE FIELD '%s%s' TYPE '%s'\n", + prefix.c_str(), + tfield->get_name().c_str(), + type_name(type).c_str()); + } +} + +/** + * Serializes all the members of a struct. + * + * @param tstruct The struct to serialize + * @param prefix String prefix to attach to all fields + */ +void t_java_generator::generate_serialize_struct(ofstream& out, + t_struct* tstruct, + string prefix) { + out << + indent() << prefix << ".write(oprot);" << endl; +} + +/** + * Serializes a container by writing its size then the elements. + * + * @param ttype The type of container + * @param prefix String prefix for fields + */ +void t_java_generator::generate_serialize_container(ofstream& out, + t_type* ttype, + string prefix) { + scope_up(out); + + if (ttype->is_map()) { + indent(out) << + "oprot.writeMapBegin(new TMap(" << + type_to_enum(((t_map*)ttype)->get_key_type()) << ", " << + type_to_enum(((t_map*)ttype)->get_val_type()) << ", " << + prefix << ".size()));" << endl; + } else if (ttype->is_set()) { + indent(out) << + "oprot.writeSetBegin(new TSet(" << + type_to_enum(((t_set*)ttype)->get_elem_type()) << ", " << + prefix << ".size()));" << endl; + } else if (ttype->is_list()) { + indent(out) << + "oprot.writeListBegin(new TList(" << + type_to_enum(((t_list*)ttype)->get_elem_type()) << ", " << + prefix << ".size()));" << endl; + } + + string iter = tmp("_iter"); + if (ttype->is_map()) { + indent(out) << + "for (Map.Entry<" << + type_name(((t_map*)ttype)->get_key_type(), true, false) << ", " << + type_name(((t_map*)ttype)->get_val_type(), true, false) << "> " << iter << + " : " << + prefix << ".entrySet())"; + } else if (ttype->is_set()) { + indent(out) << + "for (" << + type_name(((t_set*)ttype)->get_elem_type()) << " " << iter << + " : " << + prefix << ")"; + } else if (ttype->is_list()) { + indent(out) << + "for (" << + type_name(((t_list*)ttype)->get_elem_type()) << " " << iter << + " : " << + prefix << ")"; + } + + scope_up(out); + + if (ttype->is_map()) { + generate_serialize_map_element(out, (t_map*)ttype, iter, prefix); + } else if (ttype->is_set()) { + generate_serialize_set_element(out, (t_set*)ttype, iter); + } else if (ttype->is_list()) { + generate_serialize_list_element(out, (t_list*)ttype, iter); + } + + scope_down(out); + + if (ttype->is_map()) { + indent(out) << + "oprot.writeMapEnd();" << endl; + } else if (ttype->is_set()) { + indent(out) << + "oprot.writeSetEnd();" << endl; + } else if (ttype->is_list()) { + indent(out) << + "oprot.writeListEnd();" << endl; + } + + scope_down(out); +} + +/** + * Serializes the members of a map. + */ +void t_java_generator::generate_serialize_map_element(ofstream& out, + t_map* tmap, + string iter, + string map) { + t_field kfield(tmap->get_key_type(), iter + ".getKey()"); + generate_serialize_field(out, &kfield, ""); + t_field vfield(tmap->get_val_type(), iter + ".getValue()"); + generate_serialize_field(out, &vfield, ""); +} + +/** + * Serializes the members of a set. + */ +void t_java_generator::generate_serialize_set_element(ofstream& out, + t_set* tset, + string iter) { + t_field efield(tset->get_elem_type(), iter); + generate_serialize_field(out, &efield, ""); +} + +/** + * Serializes the members of a list. + */ +void t_java_generator::generate_serialize_list_element(ofstream& out, + t_list* tlist, + string iter) { + t_field efield(tlist->get_elem_type(), iter); + generate_serialize_field(out, &efield, ""); +} + +/** + * Returns a Java type name + * + * @param ttype The type + * @param container Is the type going inside a container? + * @return Java type name, i.e. HashMap + */ +string t_java_generator::type_name(t_type* ttype, bool in_container, bool in_init) { + // In Java typedefs are just resolved to their real type + ttype = get_true_type(ttype); + string prefix; + + if (ttype->is_base_type()) { + return base_type_name((t_base_type*)ttype, in_container); + } else if (ttype->is_enum()) { + return (in_container ? "Integer" : "int"); + } else if (ttype->is_map()) { + t_map* tmap = (t_map*) ttype; + if (in_init) { + prefix = "HashMap"; + } else { + prefix = "Map"; + } + return prefix + "<" + + type_name(tmap->get_key_type(), true) + "," + + type_name(tmap->get_val_type(), true) + ">"; + } else if (ttype->is_set()) { + t_set* tset = (t_set*) ttype; + if (in_init) { + prefix = "HashSet<"; + } else { + prefix = "Set<"; + } + return prefix + type_name(tset->get_elem_type(), true) + ">"; + } else if (ttype->is_list()) { + t_list* tlist = (t_list*) ttype; + if (in_init) { + prefix = "ArrayList<"; + } else { + prefix = "List<"; + } + return prefix + type_name(tlist->get_elem_type(), true) + ">"; + } + + // Check for namespacing + t_program* program = ttype->get_program(); + if (program != NULL && program != program_) { + string package = program->get_namespace("java"); + if (!package.empty()) { + return package + "." + ttype->get_name(); + } + } + + return ttype->get_name(); +} + +/** + * Returns the C++ type that corresponds to the thrift type. + * + * @param tbase The base type + * @param container Is it going in a Java container? + */ +string t_java_generator::base_type_name(t_base_type* type, + bool in_container) { + t_base_type::t_base tbase = type->get_base(); + + switch (tbase) { + case t_base_type::TYPE_VOID: + return "void"; + case t_base_type::TYPE_STRING: + if (type->is_binary()) { + return "byte[]"; + } else { + return "String"; + } + case t_base_type::TYPE_BOOL: + return (in_container ? "Boolean" : "boolean"); + case t_base_type::TYPE_BYTE: + return (in_container ? "Byte" : "byte"); + case t_base_type::TYPE_I16: + return (in_container ? "Short" : "short"); + case t_base_type::TYPE_I32: + return (in_container ? "Integer" : "int"); + case t_base_type::TYPE_I64: + return (in_container ? "Long" : "long"); + case t_base_type::TYPE_DOUBLE: + return (in_container ? "Double" : "double"); + default: + throw "compiler error: no C++ name for base type " + t_base_type::t_base_name(tbase); + } +} + +/** + * Declares a field, which may include initialization as necessary. + * + * @param ttype The type + */ +string t_java_generator::declare_field(t_field* tfield, bool init) { + // TODO(mcslee): do we ever need to initialize the field? + string result = type_name(tfield->get_type()) + " " + tfield->get_name(); + if (init) { + t_type* ttype = get_true_type(tfield->get_type()); + if (ttype->is_base_type() && tfield->get_value() != NULL) { + ofstream dummy; + result += " = " + render_const_value(dummy, tfield->get_name(), ttype, tfield->get_value()); + } else if (ttype->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)ttype)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "NO T_VOID CONSTRUCT"; + case t_base_type::TYPE_STRING: + result += " = null"; + break; + case t_base_type::TYPE_BOOL: + result += " = false"; + break; + case t_base_type::TYPE_BYTE: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + case t_base_type::TYPE_I64: + result += " = 0"; + break; + case t_base_type::TYPE_DOUBLE: + result += " = (double)0"; + break; + } + + } else if (ttype->is_enum()) { + result += " = 0"; + } else if (ttype->is_container()) { + result += " = new " + type_name(ttype, false, true) + "()"; + } else { + result += " = new " + type_name(ttype, false, true) + "()";; + } + } + return result + ";"; +} + +/** + * Renders a function signature of the form 'type name(args)' + * + * @param tfunction Function definition + * @return String of rendered function definition + */ +string t_java_generator::function_signature(t_function* tfunction, + string prefix) { + t_type* ttype = tfunction->get_returntype(); + std::string result = + type_name(ttype) + " " + prefix + tfunction->get_name() + "(" + argument_list(tfunction->get_arglist()) + ") throws "; + t_struct* xs = tfunction->get_xceptions(); + const std::vector& xceptions = xs->get_members(); + vector::const_iterator x_iter; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + result += type_name((*x_iter)->get_type(), false, false) + ", "; + } + result += "TException"; + return result; +} + +/** + * Renders a comma separated field list, with type names + */ +string t_java_generator::argument_list(t_struct* tstruct) { + string result = ""; + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + } else { + result += ", "; + } + result += type_name((*f_iter)->get_type()) + " " + (*f_iter)->get_name(); + } + return result; +} + +/** + * Converts the parse type to a C++ enum string for the given type. + */ +string t_java_generator::type_to_enum(t_type* type) { + type = get_true_type(type); + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "NO T_VOID CONSTRUCT"; + case t_base_type::TYPE_STRING: + return "TType.STRING"; + case t_base_type::TYPE_BOOL: + return "TType.BOOL"; + case t_base_type::TYPE_BYTE: + return "TType.BYTE"; + case t_base_type::TYPE_I16: + return "TType.I16"; + case t_base_type::TYPE_I32: + return "TType.I32"; + case t_base_type::TYPE_I64: + return "TType.I64"; + case t_base_type::TYPE_DOUBLE: + return "TType.DOUBLE"; + } + } else if (type->is_enum()) { + return "TType.I32"; + } else if (type->is_struct() || type->is_xception()) { + return "TType.STRUCT"; + } else if (type->is_map()) { + return "TType.MAP"; + } else if (type->is_set()) { + return "TType.SET"; + } else if (type->is_list()) { + return "TType.LIST"; + } + + throw "INVALID TYPE IN type_to_enum: " + type->get_name(); +} + +/** + * Applies the correct style to a string based on the value of nocamel_style_ + */ +std::string t_java_generator::get_cap_name(std::string name){ + if (nocamel_style_) { + return "_" + name; + } else { + name[0] = toupper(name[0]); + return name; + } +} + +string t_java_generator::constant_name(string name) { + string constant_name; + + bool is_first = true; + bool was_previous_char_upper = false; + for (string::iterator iter = name.begin(); iter != name.end(); ++iter) { + string::value_type character = (*iter); + + bool is_upper = isupper(character); + + if (is_upper && !is_first && !was_previous_char_upper) { + constant_name += '_'; + } + constant_name += toupper(character); + + is_first = false; + was_previous_char_upper = is_upper; + } + + return constant_name; +} + +/** + * Emits a JavaDoc comment if the provided object has a doc in Thrift + */ +void t_java_generator::generate_java_doc(ofstream &out, + t_doc* tdoc) { + if (tdoc->has_doc()) { + generate_docstring_comment(out, + "/**\n", + " * ", tdoc->get_doc(), + " */\n"); + } +} + +/** + * Emits a JavaDoc comment if the provided function object has a doc in Thrift + */ +void t_java_generator::generate_java_doc(ofstream &out, + t_function* tfunction) { + if (tfunction->has_doc()) { + stringstream ss; + ss << tfunction->get_doc(); + const vector& fields = tfunction->get_arglist()->get_members(); + vector::const_iterator p_iter; + for (p_iter = fields.begin(); p_iter != fields.end(); ++p_iter) { + t_field* p = *p_iter; + ss << "\n@param " << p->get_name(); + if (p->has_doc()) { + ss << " " << p->get_doc(); + } + } + generate_docstring_comment(out, + "/**\n", + " * ", ss.str(), + " */\n"); + } +} + +void t_java_generator::generate_deep_copy_container(ofstream &out, std::string source_name_p1, std::string source_name_p2, + std::string result_name, t_type* type) { + + t_container* container = (t_container*)type; + std::string source_name; + if (source_name_p2 == "") + source_name = source_name_p1; + else + source_name = source_name_p1 + "." + source_name_p2; + + indent(out) << type_name(type, true, false) << " " << result_name << " = new " << type_name(container, false, true) << "();" << endl; + + std::string iterator_element_name = source_name_p1 + "_element"; + std::string result_element_name = result_name + "_copy"; + + if(container->is_map()) { + t_type* key_type = ((t_map*)container)->get_key_type(); + t_type* val_type = ((t_map*)container)->get_val_type(); + + indent(out) << + "for (Map.Entry<" << type_name(key_type, true, false) << ", " << type_name(val_type, true, false) << "> " << iterator_element_name << " : " << source_name << ".entrySet()) {" << endl; + indent_up(); + + out << endl; + + indent(out) << type_name(key_type, true, false) << " " << iterator_element_name << "_key = " << iterator_element_name << ".getKey();" << endl; + indent(out) << type_name(val_type, true, false) << " " << iterator_element_name << "_value = " << iterator_element_name << ".getValue();" << endl; + + out << endl; + + if (key_type->is_container()) { + generate_deep_copy_container(out, iterator_element_name + "_key", "", result_element_name + "_key", key_type); + } else { + indent(out) << type_name(key_type, true, false) << " " << result_element_name << "_key = "; + generate_deep_copy_non_container(out, iterator_element_name + "_key", result_element_name + "_key", key_type); + out << ";" << endl; + } + + out << endl; + + if (val_type->is_container()) { + generate_deep_copy_container(out, iterator_element_name + "_value", "", result_element_name + "_value", val_type); + } else { + indent(out) << type_name(val_type, true, false) << " " << result_element_name << "_value = "; + generate_deep_copy_non_container(out, iterator_element_name + "_value", result_element_name + "_value", val_type); + out << ";" << endl; + } + + out << endl; + + indent(out) << result_name << ".put(" << result_element_name << "_key, " << result_element_name << "_value);" << endl; + + indent_down(); + indent(out) << "}" << endl; + + } else { + t_type* elem_type; + + if (container->is_set()) { + elem_type = ((t_set*)container)->get_elem_type(); + } else { + elem_type = ((t_list*)container)->get_elem_type(); + } + + indent(out) + << "for (" << type_name(elem_type, true, false) << " " << iterator_element_name << " : " << source_name << ") {" << endl; + + indent_up(); + + if (elem_type->is_container()) { + // recursive deep copy + generate_deep_copy_container(out, iterator_element_name, "", result_element_name, elem_type); + indent(out) << result_name << ".add(" << result_element_name << ");" << endl; + } else { + // iterative copy + if(((t_base_type*)elem_type)->is_binary()){ + indent(out) << "byte[] temp_binary_element = "; + generate_deep_copy_non_container(out, iterator_element_name, "temp_binary_element", elem_type); + out << ";" << endl; + indent(out) << result_name << ".add(temp_binary_element);" << endl; + } + else{ + indent(out) << result_name << ".add("; + generate_deep_copy_non_container(out, iterator_element_name, result_name, elem_type); + out << ");" << endl; + } + } + + indent_down(); + + indent(out) << "}" << endl; + + } +} + +void t_java_generator::generate_deep_copy_non_container(ofstream& out, std::string source_name, std::string dest_name, t_type* type) { + if (type->is_base_type() || type->is_enum() || type->is_typedef()) { + // binary fields need to be copied with System.arraycopy + if (((t_base_type*)type)->is_binary()){ + out << "new byte[" << source_name << ".length];" << endl; + indent(out) << "System.arraycopy(" << source_name << ", 0, " << dest_name << ", 0, " << source_name << ".length)"; + } + // everything else can be copied directly + else + out << source_name; + } else { + out << "new " << type_name(type, true, true) << "(" << source_name << ")"; + } +} + +std::string t_java_generator::generate_isset_check(t_field* field) { + return generate_isset_check(field->get_name()); +} + +std::string t_java_generator::generate_isset_check(std::string field_name) { + return "is" + get_cap_name("set") + get_cap_name(field_name) + "()"; +} + +void t_java_generator::generate_isset_set(ofstream& out, t_field* field) { + if (!type_can_be_null(field->get_type())) { + indent(out) << "this.__isset." << field->get_name() << " = true;" << endl; + } +} + +std::string t_java_generator::get_enum_class_name(t_type* type) { + string package = ""; + t_program* program = type->get_program(); + if (program != NULL && program != program_) { + package = program->get_namespace("java") + "."; + } + return package + type->get_name(); +} + +THRIFT_REGISTER_GENERATOR(java, "Java", +" beans: Generate bean-style output files.\n" +" nocamel: Do not use CamelCase field accessors with beans.\n" +" hashcode: Generate quality hashCode methods.\n" +); diff --git a/compiler/cpp/src/generate/t_ocaml_generator.cc b/compiler/cpp/src/generate/t_ocaml_generator.cc new file mode 100644 index 00000000..0405a04f --- /dev/null +++ b/compiler/cpp/src/generate/t_ocaml_generator.cc @@ -0,0 +1,1673 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include "t_oop_generator.h" +#include "platform.h" +using namespace std; + + +/** + * OCaml code generator. + * + */ +class t_ocaml_generator : public t_oop_generator { + public: + t_ocaml_generator( + t_program* program, + const std::map& parsed_options, + const std::string& option_string) + : t_oop_generator(program) + { + out_dir_base_ = "gen-ocaml"; + } + + /** + * Init and close methods + */ + + void init_generator(); + void close_generator(); + + /** + * Program-level generation functions + */ + void generate_program (); + void generate_typedef (t_typedef* ttypedef); + void generate_enum (t_enum* tenum); + void generate_const (t_const* tconst); + void generate_struct (t_struct* tstruct); + void generate_xception (t_struct* txception); + void generate_service (t_service* tservice); + + std::string render_const_value(t_type* type, t_const_value* value); + + /** + * Struct generation code + */ + + void generate_ocaml_struct(t_struct* tstruct, bool is_exception); + void generate_ocaml_struct_definition(std::ofstream& out, t_struct* tstruct, bool is_xception=false); + void generate_ocaml_struct_sig(std::ofstream& out, t_struct* tstruct, bool is_exception); + void generate_ocaml_struct_reader(std::ofstream& out, t_struct* tstruct); + void generate_ocaml_struct_writer(std::ofstream& out, t_struct* tstruct); + void generate_ocaml_function_helpers(t_function* tfunction); + + /** + * Service-level generation functions + */ + + void generate_service_helpers (t_service* tservice); + void generate_service_interface (t_service* tservice); + void generate_service_client (t_service* tservice); + void generate_service_server (t_service* tservice); + void generate_process_function (t_service* tservice, t_function* tfunction); + + /** + * Serialization constructs + */ + + void generate_deserialize_field (std::ofstream &out, + t_field* tfield, + std::string prefix); + + void generate_deserialize_struct (std::ofstream &out, + t_struct* tstruct); + + void generate_deserialize_container (std::ofstream &out, + t_type* ttype); + + void generate_deserialize_set_element (std::ofstream &out, + t_set* tset); + + + void generate_deserialize_list_element (std::ofstream &out, + t_list* tlist, + std::string prefix=""); + void generate_deserialize_type (std::ofstream &out, + t_type* type); + + void generate_serialize_field (std::ofstream &out, + t_field* tfield, + std::string name= ""); + + void generate_serialize_struct (std::ofstream &out, + t_struct* tstruct, + std::string prefix=""); + + void generate_serialize_container (std::ofstream &out, + t_type* ttype, + std::string prefix=""); + + void generate_serialize_map_element (std::ofstream &out, + t_map* tmap, + std::string kiter, + std::string viter); + + void generate_serialize_set_element (std::ofstream &out, + t_set* tmap, + std::string iter); + + void generate_serialize_list_element (std::ofstream &out, + t_list* tlist, + std::string iter); + + /** + * Helper rendering functions + */ + + std::string ocaml_autogen_comment(); + std::string ocaml_imports(); + std::string type_name(t_type* ttype); + std::string function_signature(t_function* tfunction, std::string prefix=""); + std::string function_type(t_function* tfunc, bool method=false, bool options = false); + std::string argument_list(t_struct* tstruct); + std::string type_to_enum(t_type* ttype); + std::string render_ocaml_type(t_type* type); + + + private: + + /** + * File streams + */ + + std::ofstream f_types_; + std::ofstream f_consts_; + std::ofstream f_service_; + + std::ofstream f_types_i_; + std::ofstream f_service_i_; + +}; + + +/* + * This is necessary because we want typedefs to appear later, + * after all the types have been declared. + */ +void t_ocaml_generator::generate_program() { + // Initialize the generator + init_generator(); + + // Generate enums + vector enums = program_->get_enums(); + vector::iterator en_iter; + for (en_iter = enums.begin(); en_iter != enums.end(); ++en_iter) { + generate_enum(*en_iter); + } + + // Generate structs + vector structs = program_->get_structs(); + vector::iterator st_iter; + for (st_iter = structs.begin(); st_iter != structs.end(); ++st_iter) { + generate_struct(*st_iter); + } + + // Generate xceptions + vector xceptions = program_->get_xceptions(); + vector::iterator x_iter; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + generate_xception(*x_iter); + } + + // Generate typedefs + vector typedefs = program_->get_typedefs(); + vector::iterator td_iter; + for (td_iter = typedefs.begin(); td_iter != typedefs.end(); ++td_iter) { + generate_typedef(*td_iter); + } + + // Generate services + vector services = program_->get_services(); + vector::iterator sv_iter; + for (sv_iter = services.begin(); sv_iter != services.end(); ++sv_iter) { + service_name_ = get_service_name(*sv_iter); + generate_service(*sv_iter); + } + + // Generate constants + vector consts = program_->get_consts(); + generate_consts(consts); + + // Close the generator + close_generator(); +} + + +/** + * Prepares for file generation by opening up the necessary file output + * streams. + * + * @param tprogram The program to generate + */ +void t_ocaml_generator::init_generator() { + // Make output directory + MKDIR(get_out_dir().c_str()); + + // Make output file + string f_types_name = get_out_dir()+program_name_+"_types.ml"; + f_types_.open(f_types_name.c_str()); + string f_types_i_name = get_out_dir()+program_name_+"_types.mli"; + f_types_i_.open(f_types_i_name.c_str()); + + string f_consts_name = get_out_dir()+program_name_+"_consts.ml"; + f_consts_.open(f_consts_name.c_str()); + + // Print header + f_types_ << + ocaml_autogen_comment() << endl << + ocaml_imports() << endl; + f_types_i_ << + ocaml_autogen_comment() << endl << + ocaml_imports() << endl; + f_consts_ << + ocaml_autogen_comment() << endl << + ocaml_imports() << endl << + "open " << capitalize(program_name_)<<"_types"<< endl; +} + + +/** + * Autogen'd comment + */ +string t_ocaml_generator::ocaml_autogen_comment() { + return + std::string("(*\n") + + " Autogenerated by Thrift\n" + + "\n" + + " DO NOT EDIT UNLESS YOU ARE SURE YOU KNOW WHAT YOU ARE DOING\n" + + "*)\n"; +} + +/** + * Prints standard thrift imports + */ +string t_ocaml_generator::ocaml_imports() { + return "open Thrift"; +} + +/** + * Closes the type files + */ +void t_ocaml_generator::close_generator() { + // Close types file + f_types_.close(); +} + +/** + * Generates a typedef. Ez. + * + * @param ttypedef The type definition + */ +void t_ocaml_generator::generate_typedef(t_typedef* ttypedef) { + f_types_ << + indent() << "type "<< decapitalize(ttypedef->get_symbolic()) << " = " << render_ocaml_type(ttypedef->get_type()) << endl << endl; + f_types_i_ << + indent() << "type "<< decapitalize(ttypedef->get_symbolic()) << " = " << render_ocaml_type(ttypedef->get_type()) << endl << endl; +} + +/** + * Generates code for an enumerated type. + * the values. + * + * @param tenum The enumeration + */ +void t_ocaml_generator::generate_enum(t_enum* tenum) { + indent(f_types_) << "module " << capitalize(tenum->get_name()) << " = " << endl << "struct" << endl; + indent(f_types_i_) << "module " << capitalize(tenum->get_name()) << " : " << endl << "sig" << endl; + indent_up(); + indent(f_types_) << "type t = " << endl; + indent(f_types_i_) << "type t = " << endl; + indent_up(); + vector constants = tenum->get_constants(); + vector::iterator c_iter; + int value = -1; + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + string name = capitalize((*c_iter)->get_name()); + indent(f_types_) << "| " << name << endl; + indent(f_types_i_) << "| " << name << endl; + } + indent_down(); + + indent(f_types_) << "let to_i = function" << endl; + indent(f_types_i_) << "val to_i : t -> int" << endl; + indent_up(); + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + if ((*c_iter)->has_value()) { + value = (*c_iter)->get_value(); + } else { + ++value; + } + string name = capitalize((*c_iter)->get_name()); + + f_types_ << + indent() << "| " << name << " -> " << value << endl; + } + indent_down(); + + indent(f_types_) << "let of_i = function" << endl; + indent(f_types_i_) << "val of_i : int -> t" << endl; + indent_up(); + for(c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + if ((*c_iter)->has_value()) { + value = (*c_iter)->get_value(); + } else { + ++value; + } + string name = capitalize((*c_iter)->get_name()); + + f_types_ << + indent() << "| " << value << " -> " << name << endl; + } + indent(f_types_) << "| _ -> raise Thrift_error" << endl; + indent_down(); + indent_down(); + indent(f_types_) << "end" << endl; + indent(f_types_i_) << "end" << endl; +} + +/** + * Generate a constant value + */ +void t_ocaml_generator::generate_const(t_const* tconst) { + t_type* type = tconst->get_type(); + string name = decapitalize(tconst->get_name()); + t_const_value* value = tconst->get_value(); + + indent(f_consts_) << "let " << name << " = " << render_const_value(type, value) << endl << endl; +} + +/** + * Prints the value of a constant with the given type. Note that type checking + * is NOT performed in this function as it is always run beforehand using the + * validate_types method in main.cc + */ +string t_ocaml_generator::render_const_value(t_type* type, t_const_value* value) { + type = get_true_type(type); + std::ostringstream out; + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_STRING: + out << '"' << get_escaped_string(value) << '"'; + break; + case t_base_type::TYPE_BOOL: + out << (value->get_integer() > 0 ? "true" : "false"); + break; + case t_base_type::TYPE_BYTE: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + out << value->get_integer(); + break; + case t_base_type::TYPE_I64: + out << value->get_integer() << "L"; + break; + case t_base_type::TYPE_DOUBLE: + if (value->get_type() == t_const_value::CV_INTEGER) { + out << value->get_integer(); + } else { + out << value->get_double(); + } + break; + default: + throw "compiler error: no const of base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + t_enum* tenum = (t_enum*)type; + vector constants = tenum->get_constants(); + vector::iterator c_iter; + int val = -1; + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + if ((*c_iter)->has_value()) { + val = (*c_iter)->get_value(); + } else { + ++val; + } + if(val == value->get_integer()){ + indent(out) << capitalize(tenum->get_name()) << "." << capitalize((*c_iter)->get_name()); + break; + } + } + } else if (type->is_struct() || type->is_xception()) { + string cname = type_name(type); + string ct = tmp("_c"); + out << endl; + indent_up(); + indent(out) << "(let " << ct << " = new " << cname << " in" << endl; + indent_up(); + const vector& fields = ((t_struct*)type)->get_members(); + vector::const_iterator f_iter; + const map& val = value->get_map(); + map::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + t_type* field_type = NULL; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if ((*f_iter)->get_name() == v_iter->first->get_string()) { + field_type = (*f_iter)->get_type(); + } + } + if (field_type == NULL) { + throw "type error: " + type->get_name() + " has no field " + v_iter->first->get_string(); + } + string fname = v_iter->first->get_string(); + out << indent(); + out << ct <<"#set_" << fname << " "; + out << render_const_value(field_type, v_iter->second); + out << ";" << endl; + } + indent(out) << ct << ")"; + indent_down(); + indent_down(); + } else if (type->is_map()) { + t_type* ktype = ((t_map*)type)->get_key_type(); + t_type* vtype = ((t_map*)type)->get_val_type(); + const map& val = value->get_map(); + map::const_iterator v_iter; + string hm = tmp("_hm"); + out << endl; + indent_up(); + indent(out) << "(let " << hm << " = Hashtbl.create " << val.size() << " in" << endl; + indent_up(); + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + string key = render_const_value(ktype, v_iter->first); + string val = render_const_value(vtype, v_iter->second); + indent(out) << "Hashtbl.add " << hm << " " << key << " " << val << ";" << endl; + } + indent(out) << hm << ")"; + indent_down(); + indent_down(); + } else if (type->is_list()) { + t_type* etype; + etype = ((t_list*)type)->get_elem_type(); + out << "[" << endl; + indent_up(); + const vector& val = value->get_list(); + vector::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + out << indent(); + out << render_const_value(etype, *v_iter); + out << ";" << endl; + } + indent_down(); + indent(out) << "]"; + } else if (type->is_set()) { + t_type* etype = ((t_set*)type)->get_elem_type(); + const vector& val = value->get_list(); + vector::const_iterator v_iter; + string hm = tmp("_hm"); + indent(out) << "(let " << hm << " = Hashtbl.create " << val.size() << " in" << endl; + indent_up(); + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + string val = render_const_value(etype, *v_iter); + indent(out) << "Hashtbl.add " << hm << " " << val << " true;" << endl; + } + indent(out) << hm << ")" << endl; + indent_down(); + out << endl; + } else { + throw "CANNOT GENERATE CONSTANT FOR TYPE: " + type->get_name(); + } + return out.str(); +} + +/** + * Generates a "struct" + */ +void t_ocaml_generator::generate_struct(t_struct* tstruct) { + generate_ocaml_struct(tstruct, false); +} + +/** + * Generates a struct definition for a thrift exception. Basically the same + * as a struct, but also has an exception declaration. + * + * @param txception The struct definition + */ +void t_ocaml_generator::generate_xception(t_struct* txception) { + generate_ocaml_struct(txception, true); +} + +/** + * Generates an OCaml struct + */ +void t_ocaml_generator::generate_ocaml_struct(t_struct* tstruct, + bool is_exception) { + generate_ocaml_struct_definition(f_types_, tstruct, is_exception); + generate_ocaml_struct_sig(f_types_i_,tstruct,is_exception); +} + +/** + * Generates a struct definition for a thrift data type. + * + * @param tstruct The struct definition + */ +void t_ocaml_generator::generate_ocaml_struct_definition(ofstream& out, + t_struct* tstruct, + bool is_exception) { + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + string tname = type_name(tstruct); + indent(out) << "class " << tname << " =" << endl; + indent(out) << "object (self)" << endl; + + indent_up(); + + string x = tmp("_x"); + if (members.size() > 0) { + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + string mname = decapitalize((*m_iter)->get_name()); + indent(out) << "val mutable _" << mname << " : " << render_ocaml_type((*m_iter)->get_type()) << " option = None" << endl; + indent(out) << "method get_" << mname << " = _" << mname << endl; + indent(out) << "method grab_" << mname << " = match _"<raise (Field_empty \""< " << x << endl; + indent(out) << "method set_" << mname << " " << x << " = _" << mname << " <- Some " << x << endl; + } + } + generate_ocaml_struct_writer(out, tstruct); + indent_down(); + indent(out) << "end" << endl; + + if(is_exception){ + indent(out) << "exception " << capitalize(tname) <<" of " << tname << endl; + } + + generate_ocaml_struct_reader(out, tstruct); +} + +/** + * Generates a struct definition for a thrift data type. + * + * @param tstruct The struct definition + */ +void t_ocaml_generator::generate_ocaml_struct_sig(ofstream& out, + t_struct* tstruct, + bool is_exception) { + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + string tname = type_name(tstruct); + indent(out) << "class " << tname << " :" << endl; + indent(out) << "object" << endl; + + indent_up(); + + string x = tmp("_x"); + if (members.size() > 0) { + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + string mname = decapitalize((*m_iter)->get_name()); + string type = render_ocaml_type((*m_iter)->get_type()); + indent(out) << "method get_" << mname << " : " << type << " option" << endl; + indent(out) << "method grab_" << mname << " : " << type << endl; + indent(out) << "method set_" << mname << " : " << type << " -> unit" << endl; + } + } + indent(out) << "method write : Protocol.t -> unit" << endl; + indent_down(); + indent(out) << "end" << endl; + + if(is_exception){ + indent(out) << "exception " << capitalize(tname) <<" of " << tname << endl; + } + + indent(out) << "val read_" << tname << " : Protocol.t -> " << tname << endl; +} + +/** + * Generates the read method for a struct + */ +void t_ocaml_generator::generate_ocaml_struct_reader(ofstream& out, t_struct* tstruct) { + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + string sname = type_name(tstruct); + string str = tmp("_str"); + string t = tmp("_t"); + string id = tmp("_id"); + indent(out) << + "let rec read_" << sname << " (iprot : Protocol.t) =" << endl; + indent_up(); + indent(out) << "let " << str << " = new " << sname << " in" << endl; + indent_up(); + indent(out) << + "ignore(iprot#readStructBegin);" << endl; + + // Loop over reading in fields + indent(out) << + "(try while true do" << endl; + indent_up(); + indent_up(); + + // Read beginning field marker + indent(out) << + "let (_," << t <<","<get_key() << " -> ("; + out << "if " << t <<" = " << type_to_enum((*f_iter)->get_type()) << " then" << endl; + indent_up(); + indent_up(); + generate_deserialize_field(out, *f_iter,str); + indent_down(); + out << + indent() << "else" << endl << + indent() << " iprot#skip "<< t << ")" << endl; + indent_down(); + } + + // In the default case we skip the field + out << + indent() << "| _ -> " << "iprot#skip "< ());" << endl; + + indent(out) << + "iprot#readStructEnd;" << endl; + + indent(out) << str << endl << endl; + indent_down(); + indent_down(); +} + +void t_ocaml_generator::generate_ocaml_struct_writer(ofstream& out, + t_struct* tstruct) { + string name = tstruct->get_name(); + const vector& fields = tstruct->get_sorted_members(); + vector::const_iterator f_iter; + string str = tmp("_str"); + string f = tmp("_f"); + + indent(out) << + "method write (oprot : Protocol.t) =" << endl; + indent_up(); + indent(out) << + "oprot#writeStructBegin \""<get_name()); + indent(out) << + "(match " << mname << " with None -> () | Some _v -> " << endl; + indent_up(); + indent(out) << "oprot#writeFieldBegin(\""<< (*f_iter)->get_name()<<"\"," + <get_type())<<"," + <<(*f_iter)->get_key()<<");" << endl; + + // Write field contents + generate_serialize_field(out, *f_iter, "_v"); + + // Write field closer + indent(out) << "oprot#writeFieldEnd" << endl; + + indent_down(); + indent(out) << ");" << endl; + } + + // Write the struct map + out << + indent() << "oprot#writeFieldStop;" << endl << + indent() << "oprot#writeStructEnd" << endl; + + indent_down(); +} + +/** + * Generates a thrift service. + * + * @param tservice The service definition + */ +void t_ocaml_generator::generate_service(t_service* tservice) { + string f_service_name = get_out_dir()+capitalize(service_name_)+".ml"; + f_service_.open(f_service_name.c_str()); + string f_service_i_name = get_out_dir()+capitalize(service_name_)+".mli"; + f_service_i_.open(f_service_i_name.c_str()); + + f_service_ << + ocaml_autogen_comment() << endl << + ocaml_imports() << endl; + f_service_i_ << + ocaml_autogen_comment() << endl << + ocaml_imports() << endl; + + /* if (tservice->get_extends() != NULL) { + f_service_ << + "open " << capitalize(tservice->get_extends()->get_name()) << endl; + f_service_i_ << + "open " << capitalize(tservice->get_extends()->get_name()) << endl; + } + */ + f_service_ << + "open " << capitalize(program_name_) << "_types" << endl << + endl; + + f_service_i_ << + "open " << capitalize(program_name_) << "_types" << endl << + endl; + + // Generate the three main parts of the service + generate_service_helpers(tservice); + generate_service_interface(tservice); + generate_service_client(tservice); + generate_service_server(tservice); + + + // Close service file + f_service_.close(); + f_service_i_.close(); +} + +/** + * Generates helper functions for a service. + * + * @param tservice The service to generate a header definition for + */ +void t_ocaml_generator::generate_service_helpers(t_service* tservice) { + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + indent(f_service_) << + "(* HELPER FUNCTIONS AND STRUCTURES *)" << endl << endl; + + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + t_struct* ts = (*f_iter)->get_arglist(); + generate_ocaml_struct_definition(f_service_, ts, false); + generate_ocaml_function_helpers(*f_iter); + } +} + +/** + * Generates a struct and helpers for a function. + * + * @param tfunction The function + */ +void t_ocaml_generator::generate_ocaml_function_helpers(t_function* tfunction) { + t_struct result(program_, decapitalize(tfunction->get_name()) + "_result"); + t_field success(tfunction->get_returntype(), "success", 0); + if (!tfunction->get_returntype()->is_void()) { + result.append(&success); + } + + t_struct* xs = tfunction->get_xceptions(); + const vector& fields = xs->get_members(); + vector::const_iterator f_iter; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + result.append(*f_iter); + } + generate_ocaml_struct_definition(f_service_, &result, false); +} + +/** + * Generates a service interface definition. + * + * @param tservice The service to generate a header definition for + */ +void t_ocaml_generator::generate_service_interface(t_service* tservice) { + f_service_ << + indent() << "class virtual iface =" << endl << "object (self)" << endl; + f_service_i_ << + indent() << "class virtual iface :" << endl << "object" << endl; + + indent_up(); + + if (tservice->get_extends() != NULL) { + string extends = type_name(tservice->get_extends()); + indent(f_service_) << "inherit " << extends << ".iface" << endl; + indent(f_service_i_) << "inherit " << extends << ".iface" << endl; + } + + vector functions = tservice->get_functions(); + vector::iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + string ft = function_type(*f_iter,true,true); + f_service_ << + indent() << "method virtual " << decapitalize((*f_iter)->get_name()) << " : " << ft << endl; + f_service_i_ << + indent() << "method virtual " << decapitalize((*f_iter)->get_name()) << " : " << ft << endl; + } + indent_down(); + indent(f_service_) << "end" << endl << endl; + indent(f_service_i_) << "end" << endl << endl; +} + +/** + * Generates a service client definition. Note that in OCaml, the client doesn't implement iface. This is because + * The client does not (and should not have to) deal with arguments being None. + * + * @param tservice The service to generate a server for. + */ +void t_ocaml_generator::generate_service_client(t_service* tservice) { + string extends = ""; + indent(f_service_) << + "class client (iprot : Protocol.t) (oprot : Protocol.t) =" << endl << "object (self)" << endl; + indent(f_service_i_) << + "class client : Protocol.t -> Protocol.t -> " << endl << "object" << endl; + indent_up(); + + + if (tservice->get_extends() != NULL) { + extends = type_name(tservice->get_extends()); + indent(f_service_) << "inherit " << extends << ".client iprot oprot as super" << endl; + indent(f_service_i_) << "inherit " << extends << ".client" << endl; + } + indent(f_service_) << "val mutable seqid = 0" << endl; + + + // Generate client method implementations + vector functions = tservice->get_functions(); + vector::const_iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + t_struct* arg_struct = (*f_iter)->get_arglist(); + const vector& fields = arg_struct->get_members(); + vector::const_iterator fld_iter; + string funname = (*f_iter)->get_name(); + + // Open function + indent(f_service_) << + "method " << function_signature(*f_iter) << " = " << endl; + indent(f_service_i_) << + "method " << decapitalize((*f_iter)->get_name()) << " : " << function_type(*f_iter,true,false) << endl; + indent_up(); + indent(f_service_) << + "self#send_" << funname; + + + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + f_service_ << " " << decapitalize((*fld_iter)->get_name()); + } + f_service_ << ";" << endl; + + if (!(*f_iter)->is_oneway()) { + f_service_ << indent(); + f_service_ << + "self#recv_" << funname << endl; + } + indent_down(); + + indent(f_service_) << + "method private send_" << function_signature(*f_iter) << " = " << endl; + indent_up(); + + std::string argsname = decapitalize((*f_iter)->get_name() + "_args"); + + // Serialize the request header + f_service_ << + indent() << "oprot#writeMessageBegin (\"" << (*f_iter)->get_name() << "\", Protocol.CALL, seqid);" << endl; + + f_service_ << + indent() << "let args = new " << argsname << " in" << endl; + indent_up(); + + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + f_service_ << + indent() << "args#set_" << (*fld_iter)->get_name() << " " << (*fld_iter)->get_name() << ";" << endl; + } + + // Write to the stream + f_service_ << + indent() << "args#write oprot;" << endl << + indent() << "oprot#writeMessageEnd;" << endl << + indent() << "oprot#getTransport#flush" << endl; + + indent_down(); + indent_down(); + + if (!(*f_iter)->is_oneway()) { + std::string resultname = decapitalize((*f_iter)->get_name() + "_result"); + t_struct noargs(program_); + + t_function recv_function((*f_iter)->get_returntype(), + string("recv_") + (*f_iter)->get_name(), + &noargs); + // Open function + f_service_ << + indent() << "method private " << function_signature(&recv_function) << " =" << endl; + indent_up(); + + // TODO(mcslee): Validate message reply here, seq ids etc. + + f_service_ << + indent() << "let (fname, mtype, rseqid) = iprot#readMessageBegin in" << endl; + indent_up(); + f_service_ << + indent() << "(if mtype = Protocol.EXCEPTION then" << endl << + indent() << " let x = Application_Exn.read iprot in" << endl; + indent_up(); + f_service_ << + indent() << " (iprot#readMessageEnd;" << + indent() << " raise (Application_Exn.E x))" << endl; + indent_down(); + f_service_ << + indent() << "else ());" << endl; + string res = "_"; + + t_struct* xs = (*f_iter)->get_xceptions(); + const std::vector& xceptions = xs->get_members(); + + if (!(*f_iter)->get_returntype()->is_void() || xceptions.size() > 0) { + res = "result"; + } + f_service_ << + indent() << "let "<get_returntype()->is_void()) { + f_service_ << + indent() << "match result#get_success with Some v -> v | None -> (" << endl; + indent_up(); + } + + + vector::const_iterator x_iter; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + f_service_ << + indent() << "(match result#get_" << (*x_iter)->get_name() << " with None -> () | Some _v ->" << endl; + indent(f_service_) << " raise (" << capitalize(type_name((*x_iter)->get_type())) << " _v));" << endl; + } + + // Careful, only return _result if not a void function + if ((*f_iter)->get_returntype()->is_void()) { + indent(f_service_) << + "()" << endl; + } else { + f_service_ << + indent() << "raise (Application_Exn.E (Application_Exn.create Application_Exn.MISSING_RESULT \"" << (*f_iter)->get_name() << " failed: unknown result\")))" << endl; + indent_down(); + } + + // Close function + indent_down(); + indent_down(); + indent_down(); + } + } + + indent_down(); + indent(f_service_) << "end" << endl << endl; + indent(f_service_i_) << "end" << endl << endl; +} + +/** + * Generates a service server definition. + * + * @param tservice The service to generate a server for. + */ +void t_ocaml_generator::generate_service_server(t_service* tservice) { + // Generate the dispatch methods + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + + // Generate the header portion + indent(f_service_) << + "class processor (handler : iface) =" << endl << indent() << "object (self)" << endl; + indent(f_service_i_) << + "class processor : iface ->" << endl << indent() << "object" << endl; + indent_up(); + + f_service_ << + indent() << "inherit Processor.t" << endl << + endl; + f_service_i_ << + indent() << "inherit Processor.t" << endl << + endl; + string extends = ""; + + if (tservice->get_extends() != NULL) { + extends = type_name(tservice->get_extends()); + indent(f_service_) << "inherit " + extends + ".processor (handler :> " + extends + ".iface)" << endl; + indent(f_service_i_) << "inherit " + extends + ".processor" << endl; + } + + if (extends.empty()) { + indent(f_service_) << "val processMap = Hashtbl.create " << functions.size() << endl; + } + indent(f_service_i_) << "val processMap : (string, int * Protocol.t * Protocol.t -> unit) Hashtbl.t" << endl; + + // Generate the server implementation + indent(f_service_) << + "method process iprot oprot =" << endl; + indent(f_service_i_) << + "method process : Protocol.t -> Protocol.t -> bool" << endl; + indent_up(); + + f_service_ << + indent() << "let (name, typ, seqid) = iprot#readMessageBegin in" << endl; + indent_up(); + // TODO(mcslee): validate message + + // HOT: dictionary function lookup + f_service_ << + indent() << "if Hashtbl.mem processMap name then" << endl << + indent() << " (Hashtbl.find processMap name) (seqid, iprot, oprot)" << endl << + indent() << "else (" << endl << + indent() << " iprot#skip(Protocol.T_STRUCT);" << endl << + indent() << " iprot#readMessageEnd;" << endl << + indent() << " let x = Application_Exn.create Application_Exn.UNKNOWN_METHOD (\"Unknown function \"^name) in" << endl << + indent() << " oprot#writeMessageBegin(name, Protocol.EXCEPTION, seqid);" << endl << + indent() << " x#write oprot;" << endl << + indent() << " oprot#writeMessageEnd;" << endl << + indent() << " oprot#getTransport#flush" << endl << + indent() << ");" << endl; + + // Read end of args field, the T_STOP, and the struct close + f_service_ << + indent() << "true" << endl; + indent_down(); + indent_down(); + // Generate the process subfunctions + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + generate_process_function(tservice, *f_iter); + } + + indent(f_service_) << "initializer" << endl; + indent_up(); + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + f_service_ << + indent() << "Hashtbl.add processMap \"" << (*f_iter)->get_name() << "\" self#process_" << (*f_iter)->get_name() << ";" << endl; + } + indent_down(); + + indent_down(); + indent(f_service_) << "end" << endl << endl; + indent(f_service_i_) << "end" << endl << endl; +} + +/** + * Generates a process function definition. + * + * @param tfunction The function to write a dispatcher for + */ +void t_ocaml_generator::generate_process_function(t_service* tservice, + t_function* tfunction) { + // Open function + indent(f_service_) << + "method private process_" << tfunction->get_name() << + " (seqid, iprot, oprot) =" << endl; + indent_up(); + + string argsname = decapitalize(tfunction->get_name()) + "_args"; + string resultname = decapitalize(tfunction->get_name()) + "_result"; + + // Generate the function call + t_struct* arg_struct = tfunction->get_arglist(); + const std::vector& fields = arg_struct->get_members(); + vector::const_iterator f_iter; + + string args = "args"; + if(fields.size() == 0){ + args="_"; + } + + f_service_ << + indent() << "let "<get_xceptions(); + const std::vector& xceptions = xs->get_members(); + vector::const_iterator x_iter; + + // Declare result for non oneway function + if (!tfunction->is_oneway()) { + f_service_ << + indent() << "let result = new " << resultname << " in" << endl; + indent_up(); + } + + // Try block for a function with exceptions + if (xceptions.size() > 0) { + f_service_ << + indent() << "(try" << endl; + indent_up(); + } + + + + + f_service_ << indent(); + if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) { + f_service_ << "result#set_success "; + } + f_service_ << + "(handler#" << tfunction->get_name(); + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + f_service_ << " args#get_" << (*f_iter)->get_name(); + } + f_service_ << ");" << endl; + + + if (xceptions.size() > 0) { + indent_down(); + indent(f_service_) << "with" <get_type())) << " " << (*x_iter)->get_name() << " -> " << endl; + indent_up(); + indent_up(); + if(!tfunction->is_oneway()){ + f_service_ << + indent() << "result#set_" << (*x_iter)->get_name() << " " << (*x_iter)->get_name() << endl; + } else { + indent(f_service_) << "()"; + } + indent_down(); + indent_down(); + } + indent_down(); + f_service_ << indent() << ");" << endl; + } + + + + // Shortcut out here for oneway functions + if (tfunction->is_oneway()) { + f_service_ << + indent() << "()" << endl; + indent_down(); + indent_down(); + return; + } + + f_service_ << + indent() << "oprot#writeMessageBegin (\"" << tfunction->get_name() << "\", Protocol.REPLY, seqid);" << endl << + indent() << "result#write oprot;" << endl << + indent() << "oprot#writeMessageEnd;" << endl << + indent() << "oprot#getTransport#flush" << endl; + + // Close function + indent_down(); + indent_down(); + indent_down(); +} + +/** + * Deserializes a field of any type. + */ +void t_ocaml_generator::generate_deserialize_field(ofstream &out, + t_field* tfield, + string prefix){ + t_type* type = tfield->get_type(); + + + string name = decapitalize(tfield->get_name()); + indent(out) << prefix << "#set_"<is_void()) { + throw "CANNOT GENERATE DESERIALIZE CODE FOR void TYPE"; + } + + + if (type->is_struct() || type->is_xception()) { + generate_deserialize_struct(out, + (t_struct*)type); + } else if (type->is_container()) { + generate_deserialize_container(out, type); + } else if (type->is_base_type()) { + out << "iprot#"; + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "compiler error: cannot serialize void field in a struct"; + break; + case t_base_type::TYPE_STRING: + out << "readString"; + break; + case t_base_type::TYPE_BOOL: + out << "readBool"; + break; + case t_base_type::TYPE_BYTE: + out << "readByte"; + break; + case t_base_type::TYPE_I16: + out << "readI16"; + break; + case t_base_type::TYPE_I32: + out << "readI32"; + break; + case t_base_type::TYPE_I64: + out << "readI64"; + break; + case t_base_type::TYPE_DOUBLE: + out << "readDouble"; + break; + default: + throw "compiler error: no PHP name for base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + string ename = capitalize(type->get_name()); + out << "(" <get_name().c_str()); + } +} + + +/** + * Generates an unserializer for a struct, calling read() + */ +void t_ocaml_generator::generate_deserialize_struct(ofstream &out, + t_struct* tstruct) { + string name = decapitalize(tstruct->get_name()); + out << "(read_" << name << " iprot)"; + +} + +/** + * Serialize a container by writing out the header followed by + * data and then a footer. + */ +void t_ocaml_generator::generate_deserialize_container(ofstream &out, + t_type* ttype) { + string size = tmp("_size"); + string ktype = tmp("_ktype"); + string vtype = tmp("_vtype"); + string etype = tmp("_etype"); + string con = tmp("_con"); + + t_field fsize(g_type_i32, size); + t_field fktype(g_type_byte, ktype); + t_field fvtype(g_type_byte, vtype); + t_field fetype(g_type_byte, etype); + + out << endl; + indent_up(); + // Declare variables, read header + if (ttype->is_map()) { + indent(out) << "(let ("<is_set()) { + indent(out) << "(let ("<get_elem_type()); + out << " true" << endl; + indent_down(); + indent(out) << "done; iprot#readSetEnd; "<is_list()) { + indent(out) << "(let ("< "; + generate_deserialize_type(out,((t_list*)ttype)->get_elem_type()); + out << "))) in" << endl; + indent_up(); + indent(out) << "iprot#readListEnd; "<get_type()); + + // Do nothing for void types + if (type->is_void()) { + throw "CANNOT GENERATE SERIALIZE CODE FOR void TYPE: " + + tfield->get_name(); + } + + if(name.length() == 0){ + name = decapitalize(tfield->get_name()); + } + + if (type->is_struct() || type->is_xception()) { + generate_serialize_struct(out, + (t_struct*)type, + name); + } else if (type->is_container()) { + generate_serialize_container(out, + type, + name); + } else if (type->is_base_type() || type->is_enum()) { + + + indent(out) << + "oprot#"; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw + "compiler error: cannot serialize void field in a struct: " + name; + break; + case t_base_type::TYPE_STRING: + out << "writeString(" << name << ")"; + break; + case t_base_type::TYPE_BOOL: + out << "writeBool(" << name << ")"; + break; + case t_base_type::TYPE_BYTE: + out << "writeByte(" << name << ")"; + break; + case t_base_type::TYPE_I16: + out << "writeI16(" << name << ")"; + break; + case t_base_type::TYPE_I32: + out << "writeI32(" << name << ")"; + break; + case t_base_type::TYPE_I64: + out << "writeI64(" << name << ")"; + break; + case t_base_type::TYPE_DOUBLE: + out << "writeDouble(" << name << ")"; + break; + default: + throw "compiler error: no ocaml name for base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + string ename = capitalize(type->get_name()); + out << "writeI32("<get_name().c_str(), + type->get_name().c_str()); + } + out << ";" << endl; +} + +/** + * Serializes all the members of a struct. + * + * @param tstruct The struct to serialize + * @param prefix String prefix to attach to all fields + */ +void t_ocaml_generator::generate_serialize_struct(ofstream &out, + t_struct* tstruct, + string prefix) { + indent(out) << prefix << "#write(oprot)"; +} + +void t_ocaml_generator::generate_serialize_container(ofstream &out, + t_type* ttype, + string prefix) { + if (ttype->is_map()) { + indent(out) << "oprot#writeMapBegin("<< type_to_enum(((t_map*)ttype)->get_key_type()) << ","; + out << type_to_enum(((t_map*)ttype)->get_val_type()) << ","; + out << "Hashtbl.length " << prefix << ");" << endl; + } else if (ttype->is_set()) { + indent(out) << + "oprot#writeSetBegin(" << type_to_enum(((t_set*)ttype)->get_elem_type()) << ","; + out << "Hashtbl.length " << prefix << ");" << endl; + } else if (ttype->is_list()) { + indent(out) << + "oprot#writeListBegin(" << type_to_enum(((t_list*)ttype)->get_elem_type()) << ","; + out << "List.length " << prefix << ");" << endl; + } + + if (ttype->is_map()) { + string kiter = tmp("_kiter"); + string viter = tmp("_viter"); + indent(out) << "Hashtbl.iter (fun "< fun " << viter << " -> " << endl; + indent_up(); + generate_serialize_map_element(out, (t_map*)ttype, kiter, viter); + indent_down(); + indent(out) << ") " << prefix << ";" << endl; + } else if (ttype->is_set()) { + string iter = tmp("_iter"); + indent(out) << "Hashtbl.iter (fun "< fun _ -> "; + indent_up(); + generate_serialize_set_element(out, (t_set*)ttype, iter); + indent_down(); + indent(out) << ") " << prefix << ";" << endl; + } else if (ttype->is_list()) { + string iter = tmp("_iter"); + indent(out) << "List.iter (fun "< "; + indent_up(); + generate_serialize_list_element(out, (t_list*)ttype, iter); + indent_down(); + indent(out) << ") " << prefix << ";" << endl; + } + + if (ttype->is_map()) { + indent(out) << + "oprot#writeMapEnd"; + } else if (ttype->is_set()) { + indent(out) << + "oprot#writeSetEnd"; + } else if (ttype->is_list()) { + indent(out) << + "oprot#writeListEnd"; + } +} + +/** + * Serializes the members of a map. + * + */ +void t_ocaml_generator::generate_serialize_map_element(ofstream &out, + t_map* tmap, + string kiter, + string viter) { + t_field kfield(tmap->get_key_type(), kiter); + generate_serialize_field(out, &kfield); + + t_field vfield(tmap->get_val_type(), viter); + generate_serialize_field(out, &vfield); +} + +/** + * Serializes the members of a set. + */ +void t_ocaml_generator::generate_serialize_set_element(ofstream &out, + t_set* tset, + string iter) { + t_field efield(tset->get_elem_type(), iter); + generate_serialize_field(out, &efield); +} + +/** + * Serializes the members of a list. + */ +void t_ocaml_generator::generate_serialize_list_element(ofstream &out, + t_list* tlist, + string iter) { + t_field efield(tlist->get_elem_type(), iter); + generate_serialize_field(out, &efield); +} + + + +/** + * Renders a function signature of the form 'name args' + * + * @param tfunction Function definition + * @return String of rendered function definition + */ +string t_ocaml_generator::function_signature(t_function* tfunction, + string prefix) { + return + prefix + decapitalize(tfunction->get_name()) + + " " + argument_list(tfunction->get_arglist()); +} + +string t_ocaml_generator::function_type(t_function* tfunc, bool method, bool options){ + string result=""; + + const vector& fields = tfunc->get_arglist()->get_members(); + vector::const_iterator f_iter; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + result += render_ocaml_type((*f_iter)->get_type()); + if(options) + result += " option"; + result += " -> "; + } + if(fields.empty() && !method){ + result += "unit -> "; + } + result += render_ocaml_type(tfunc->get_returntype()); + return result; +} + +/** + * Renders a field list + */ +string t_ocaml_generator::argument_list(t_struct* tstruct) { + string result = ""; + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + } else { + result += " "; + } + result += (*f_iter)->get_name(); + } + return result; +} + +string t_ocaml_generator::type_name(t_type* ttype) { + string prefix = ""; + t_program* program = ttype->get_program(); + if (program != NULL && program != program_) { + if (!ttype->is_service()) { + prefix = capitalize(program->get_name()) + "_types."; + } + } + + string name = ttype->get_name(); + if(ttype->is_service()){ + name = capitalize(name); + } else { + name = decapitalize(name); + } + return prefix + name; +} + +/** + * Converts the parse type to a Protocol.t_type enum + */ +string t_ocaml_generator::type_to_enum(t_type* type) { + type = get_true_type(type); + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + return "Protocol.T_VOID"; + case t_base_type::TYPE_STRING: + return "Protocol.T_STRING"; + case t_base_type::TYPE_BOOL: + return "Protocol.T_BOOL"; + case t_base_type::TYPE_BYTE: + return "Protocol.T_BYTE"; + case t_base_type::TYPE_I16: + return "Protocol.T_I16"; + case t_base_type::TYPE_I32: + return "Protocol.T_I32"; + case t_base_type::TYPE_I64: + return "Protocol.T_I64"; + case t_base_type::TYPE_DOUBLE: + return "Protocol.T_DOUBLE"; + } + } else if (type->is_enum()) { + return "Protocol.T_I32"; + } else if (type->is_struct() || type->is_xception()) { + return "Protocol.T_STRUCT"; + } else if (type->is_map()) { + return "Protocol.T_MAP"; + } else if (type->is_set()) { + return "Protocol.T_SET"; + } else if (type->is_list()) { + return "Protocol.T_LIST"; + } + + throw "INVALID TYPE IN type_to_enum: " + type->get_name(); +} + +/** + * Converts the parse type to an ocaml type + */ +string t_ocaml_generator::render_ocaml_type(t_type* type) { + type = get_true_type(type); + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + return "unit"; + case t_base_type::TYPE_STRING: + return "string"; + case t_base_type::TYPE_BOOL: + return "bool"; + case t_base_type::TYPE_BYTE: + return "int"; + case t_base_type::TYPE_I16: + return "int"; + case t_base_type::TYPE_I32: + return "int"; + case t_base_type::TYPE_I64: + return "Int64.t"; + case t_base_type::TYPE_DOUBLE: + return "float"; + } + } else if (type->is_enum()) { + return capitalize(((t_enum*)type)->get_name())+".t"; + } else if (type->is_struct() || type->is_xception()) { + return type_name((t_struct*)type); + } else if (type->is_map()) { + t_type* ktype = ((t_map*)type)->get_key_type(); + t_type* vtype = ((t_map*)type)->get_val_type(); + return "("+render_ocaml_type(ktype)+","+render_ocaml_type(vtype)+") Hashtbl.t"; + } else if (type->is_set()) { + t_type* etype = ((t_set*)type)->get_elem_type(); + return "("+render_ocaml_type(etype)+",bool) Hashtbl.t"; + } else if (type->is_list()) { + t_type* etype = ((t_list*)type)->get_elem_type(); + return render_ocaml_type(etype)+" list"; + } + + throw "INVALID TYPE IN type_to_enum: " + type->get_name(); +} + + +THRIFT_REGISTER_GENERATOR(ocaml, "OCaml", ""); diff --git a/compiler/cpp/src/generate/t_oop_generator.h b/compiler/cpp/src/generate/t_oop_generator.h new file mode 100644 index 00000000..bf757862 --- /dev/null +++ b/compiler/cpp/src/generate/t_oop_generator.h @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_OOP_GENERATOR_H +#define T_OOP_GENERATOR_H + +#include +#include + +#include "globals.h" +#include "t_generator.h" + +#include + +/** + * Class with utility methods shared across common object oriented languages. + * Specifically, most of this stuff is for C++/Java. + * + */ +class t_oop_generator : public t_generator { + public: + t_oop_generator(t_program* program) : + t_generator(program) {} + + /** + * Scoping, using curly braces! + */ + + void scope_up(std::ostream& out) { + indent(out) << "{" << std::endl; + indent_up(); + } + + void scope_down(std::ostream& out) { + indent_down(); + indent(out) << "}" << std::endl; + } + + std::string upcase_string(std::string original) { + std::transform(original.begin(), original.end(), original.begin(), (int(*)(int)) toupper); + return original; + } + + /** + * Generates a comment about this code being autogenerated, using C++ style + * comments, which are also fair game in Java / PHP, yay! + * + * @return C-style comment mentioning that this file is autogenerated. + */ + virtual std::string autogen_comment() { + return + std::string("/**\n") + + " * Autogenerated by Thrift\n" + + " *\n" + + " * DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING\n" + + " */\n"; + } +}; + +#endif + diff --git a/compiler/cpp/src/generate/t_perl_generator.cc b/compiler/cpp/src/generate/t_perl_generator.cc new file mode 100644 index 00000000..ae204fd9 --- /dev/null +++ b/compiler/cpp/src/generate/t_perl_generator.cc @@ -0,0 +1,1815 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include "t_oop_generator.h" +#include "platform.h" +using namespace std; + + +/** + * PERL code generator. + * + */ +class t_perl_generator : public t_oop_generator { + public: + t_perl_generator( + t_program* program, + const std::map& parsed_options, + const std::string& option_string) + : t_oop_generator(program) + { + out_dir_base_ = "gen-perl"; + escape_['$'] = "\\$"; + escape_['@'] = "\\@"; + } + + /** + * Init and close methods + */ + + void init_generator(); + void close_generator(); + + /** + * Program-level generation functions + */ + + void generate_typedef (t_typedef* ttypedef); + void generate_enum (t_enum* tenum); + void generate_const (t_const* tconst); + void generate_struct (t_struct* tstruct); + void generate_xception (t_struct* txception); + void generate_service (t_service* tservice); + + std::string render_const_value(t_type* type, t_const_value* value); + + /** + * Structs! + */ + + void generate_perl_struct(t_struct* tstruct, bool is_exception); + void generate_perl_struct_definition(std::ofstream& out, t_struct* tstruct, bool is_xception=false); + void generate_perl_struct_reader(std::ofstream& out, t_struct* tstruct); + void generate_perl_struct_writer(std::ofstream& out, t_struct* tstruct); + void generate_perl_function_helpers(t_function* tfunction); + + /** + * Service-level generation functions + */ + + void generate_service_helpers (t_service* tservice); + void generate_service_interface (t_service* tservice); + void generate_service_rest (t_service* tservice); + void generate_service_client (t_service* tservice); + void generate_service_processor (t_service* tservice); + void generate_process_function (t_service* tservice, t_function* tfunction); + + /** + * Serialization constructs + */ + + void generate_deserialize_field (std::ofstream &out, + t_field* tfield, + std::string prefix="", + bool inclass=false); + + void generate_deserialize_struct (std::ofstream &out, + t_struct* tstruct, + std::string prefix=""); + + void generate_deserialize_container (std::ofstream &out, + t_type* ttype, + std::string prefix=""); + + void generate_deserialize_set_element (std::ofstream &out, + t_set* tset, + std::string prefix=""); + + void generate_deserialize_map_element (std::ofstream &out, + t_map* tmap, + std::string prefix=""); + + void generate_deserialize_list_element (std::ofstream &out, + t_list* tlist, + std::string prefix=""); + + void generate_serialize_field (std::ofstream &out, + t_field* tfield, + std::string prefix=""); + + void generate_serialize_struct (std::ofstream &out, + t_struct* tstruct, + std::string prefix=""); + + void generate_serialize_container (std::ofstream &out, + t_type* ttype, + std::string prefix=""); + + void generate_serialize_map_element (std::ofstream &out, + t_map* tmap, + std::string kiter, + std::string viter); + + void generate_serialize_set_element (std::ofstream &out, + t_set* tmap, + std::string iter); + + void generate_serialize_list_element (std::ofstream &out, + t_list* tlist, + std::string iter); + + /** + * Helper rendering functions + */ + + std::string perl_includes(); + std::string declare_field(t_field* tfield, bool init=false, bool obj=false); + std::string function_signature(t_function* tfunction, std::string prefix=""); + std::string argument_list(t_struct* tstruct); + std::string type_to_enum(t_type* ttype); + + std::string autogen_comment() { + return + std::string("#\n") + + "# Autogenerated by Thrift\n" + + "#\n" + + "# DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING\n" + + "#\n"; + } + + void perl_namespace_dirs(t_program* p, std::list& dirs) { + std::string ns = p->get_namespace("perl"); + std::string::size_type loc; + + if (ns.size() > 0) { + while ((loc = ns.find(".")) != std::string::npos) { + dirs.push_back(ns.substr(0, loc)); + ns = ns.substr(loc+1); + } + } + + if (ns.size() > 0) { + dirs.push_back(ns); + } + } + + std::string perl_namespace(t_program* p) { + std::string ns = p->get_namespace("perl"); + std::string result = ""; + std::string::size_type loc; + + if (ns.size() > 0) { + while ((loc = ns.find(".")) != std::string::npos) { + result += ns.substr(0, loc); + result += "::"; + ns = ns.substr(loc+1); + } + + if (ns.size() > 0) { + result += ns + "::"; + } + } + + return result; + } + + std::string get_namespace_out_dir() { + std::string outdir = get_out_dir(); + std::list dirs; + perl_namespace_dirs(program_, dirs); + std::list::iterator it; + for (it = dirs.begin(); it != dirs.end(); it++) { + outdir += *it + "/"; + } + return outdir; + } + + private: + + /** + * File streams + */ + std::ofstream f_types_; + std::ofstream f_consts_; + std::ofstream f_helpers_; + std::ofstream f_service_; + +}; + + +/** + * Prepares for file generation by opening up the necessary file output + * streams. + * + * @param tprogram The program to generate + */ +void t_perl_generator::init_generator() { + // Make output directory + MKDIR(get_out_dir().c_str()); + + string outdir = get_out_dir(); + std::list dirs; + perl_namespace_dirs(program_, dirs); + std::list::iterator it; + for (it = dirs.begin(); it != dirs.end(); it++) { + outdir += *it + "/"; + MKDIR(outdir.c_str()); + } + + // Make output file + string f_types_name = outdir+"Types.pm"; + f_types_.open(f_types_name.c_str()); + string f_consts_name = outdir+"Constants.pm"; + f_consts_.open(f_consts_name.c_str()); + + // Print header + f_types_ << + autogen_comment() << + perl_includes(); + + // Print header + f_consts_ << + autogen_comment() << + "package "<< perl_namespace(program_) <<"Constants;"<get_name()<<";"< constants = tenum->get_constants(); + vector::iterator c_iter; + int value = -1; + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + if ((*c_iter)->has_value()) { + value = (*c_iter)->get_value(); + } else { + ++value; + } + + f_types_ << "use constant "<<(*c_iter)->get_name() << " => " << value << ";" << endl; + } +} + +/** + * Generate a constant value + */ +void t_perl_generator::generate_const(t_const* tconst) { + t_type* type = tconst->get_type(); + string name = tconst->get_name(); + t_const_value* value = tconst->get_value(); + + f_consts_ << "use constant " << name << " => "; + f_consts_ << render_const_value(type, value); + f_consts_ << ";" << endl << endl; +} + +/** + * Prints the value of a constant with the given type. Note that type checking + * is NOT performed in this function as it is always run beforehand using the + * validate_types method in main.cc + */ +string t_perl_generator::render_const_value(t_type* type, t_const_value* value) { + std::ostringstream out; + + type = get_true_type(type); + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_STRING: + out << '"' << get_escaped_string(value) << '"'; + break; + case t_base_type::TYPE_BOOL: + out << (value->get_integer() > 0 ? "1" : "0"); + break; + case t_base_type::TYPE_BYTE: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + case t_base_type::TYPE_I64: + out << value->get_integer(); + break; + case t_base_type::TYPE_DOUBLE: + if (value->get_type() == t_const_value::CV_INTEGER) { + out << value->get_integer(); + } else { + out << value->get_double(); + } + break; + default: + throw "compiler error: no const of base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + out << value->get_integer(); + } else if (type->is_struct() || type->is_xception()) { + out << "new " << perl_namespace(type->get_program()) << type->get_name() << "({" << endl; + indent_up(); + const vector& fields = ((t_struct*)type)->get_members(); + vector::const_iterator f_iter; + const map& val = value->get_map(); + map::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + t_type* field_type = NULL; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if ((*f_iter)->get_name() == v_iter->first->get_string()) { + field_type = (*f_iter)->get_type(); + } + } + if (field_type == NULL) { + throw "type error: " + type->get_name() + " has no field " + v_iter->first->get_string(); + } + out << render_const_value(g_type_string, v_iter->first); + out << " => "; + out << render_const_value(field_type, v_iter->second); + out << ","; + out << endl; + } + + out << "})"; + } else if (type->is_map()) { + t_type* ktype = ((t_map*)type)->get_key_type(); + t_type* vtype = ((t_map*)type)->get_val_type(); + out << "{" << endl; + + const map& val = value->get_map(); + map::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + out << render_const_value(ktype, v_iter->first); + out << " => "; + out << render_const_value(vtype, v_iter->second); + out << "," << endl; + } + + out << "}"; + } else if (type->is_list() || type->is_set()) { + t_type* etype; + if (type->is_list()) { + etype = ((t_list*)type)->get_elem_type(); + } else { + etype = ((t_set*)type)->get_elem_type(); + } + out << "[" << endl; + const vector& val = value->get_list(); + vector::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + + out << render_const_value(etype, *v_iter); + if (type->is_set()) { + out << " => 1"; + } + out << "," << endl; + } + out << "]"; + } + return out.str(); +} + +/** + * Make a struct + */ +void t_perl_generator::generate_struct(t_struct* tstruct) { + generate_perl_struct(tstruct, false); +} + +/** + * Generates a struct definition for a thrift exception. Basically the same + * as a struct but extends the Exception class. + * + * @param txception The struct definition + */ +void t_perl_generator::generate_xception(t_struct* txception) { + generate_perl_struct(txception, true); +} + +/** + * Structs can be normal or exceptions. + */ +void t_perl_generator::generate_perl_struct(t_struct* tstruct, + bool is_exception) { + generate_perl_struct_definition(f_types_, tstruct, is_exception); +} + +/** + * Generates a struct definition for a thrift data type. This is nothing in PERL + * where the objects are all just associative arrays (unless of course we + * decide to start using objects for them...) + * + * @param tstruct The struct definition + */ +void t_perl_generator::generate_perl_struct_definition(ofstream& out, + t_struct* tstruct, + bool is_exception) { + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + + out << + "package " << perl_namespace(tstruct->get_program()) << tstruct->get_name() <<";\n"; + if (is_exception) { + out << "use base('Thrift::TException');\n"; + } + + //Create simple acessor methods + out << "use Class::Accessor;\n"; + out << "use base('Class::Accessor');\n"; + + if (members.size() > 0) { + out << perl_namespace(tstruct->get_program()) << tstruct->get_name() <<"->mk_accessors( qw( "; + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + t_type* t = get_true_type((*m_iter)->get_type()); + if (!t->is_xception()) { + out << (*m_iter)->get_name() << " "; + } + } + + out << ") );\n"; + } + + + // new() + out << "sub new {\n"; + indent_up(); + out << "my $classname = shift;\n"; + out << "my $self = {};\n"; + out << "my $vals = shift || {};\n"; + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + string dval = "undef"; + t_type* t = get_true_type((*m_iter)->get_type()); + if ((*m_iter)->get_value() != NULL && !(t->is_struct() || t->is_xception())) { + dval = render_const_value((*m_iter)->get_type(), (*m_iter)->get_value()); + } + out << + "$self->{" << (*m_iter)->get_name() << "} = " << dval << ";" << endl; + } + + // Generate constructor from array + if (members.size() > 0) { + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + t_type* t = get_true_type((*m_iter)->get_type()); + if ((*m_iter)->get_value() != NULL && (t->is_struct() || t->is_xception())) { + indent(out) << "$self->{" << (*m_iter)->get_name() << "} = " << render_const_value(t, (*m_iter)->get_value()) << ";" << endl; + } + } + + out << indent() << "if (UNIVERSAL::isa($vals,'HASH')) {" << endl; + indent_up(); + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + out << + indent() << "if (defined $vals->{" << (*m_iter)->get_name() << "}) {" << endl << + indent() << " $self->{" << (*m_iter)->get_name() << "} = $vals->{" << (*m_iter)->get_name() << "};" << endl << + indent() << "}" << endl; + } + indent_down(); + out << + indent() << "}" << endl; + + } + + out << "return bless($self,$classname);\n"; + indent_down(); + out << "}\n\n"; + + out << + "sub getName {" << endl << + indent() << " return '" << tstruct->get_name() << "';" << endl << + indent() << "}" << endl << + endl; + + generate_perl_struct_reader(out, tstruct); + generate_perl_struct_writer(out, tstruct); + +} + +/** + * Generates the read() method for a struct + */ +void t_perl_generator::generate_perl_struct_reader(ofstream& out, + t_struct* tstruct) { + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + + out << "sub read {" <readStructBegin(\\$fname);" << endl; + + + // Loop over reading in fields + indent(out) << "while (1) " << endl; + + scope_up(out); + + indent(out) << "$xfer += $input->readFieldBegin(\\$fname, \\$ftype, \\$fid);" << endl; + + // Check for field STOP marker and break + indent(out) << "if ($ftype == TType::STOP) {" << endl; + indent_up(); + indent(out) << "last;" << endl; + indent_down(); + indent(out) << "}" << endl; + + // Switch statement on the field we are reading + indent(out) << "SWITCH: for($fid)" << endl; + + scope_up(out); + + // Generate deserialization code for known cases + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + + indent(out) << "/^" << (*f_iter)->get_key() << "$/ && do{"; + indent(out) << "if ($ftype == " << type_to_enum((*f_iter)->get_type()) << ") {" << endl; + + indent_up(); + generate_deserialize_field(out, *f_iter, "self->"); + indent_down(); + + indent(out) << "} else {" << endl; + + indent(out) << " $xfer += $input->skip($ftype);" << endl; + + out << + indent() << "}" << endl << + indent() << "last; };" << endl; + + } + // In the default case we skip the field + + indent(out) << " $xfer += $input->skip($ftype);" << endl; + + scope_down(out); + + indent(out) << "$xfer += $input->readFieldEnd();" << endl; + + scope_down(out); + + indent(out) << "$xfer += $input->readStructEnd();" << endl; + + indent(out) << "return $xfer;" << endl; + + indent_down(); + out << indent() << "}" << endl << endl; +} + +/** + * Generates the write() method for a struct + */ +void t_perl_generator::generate_perl_struct_writer(ofstream& out, + t_struct* tstruct) { + string name = tstruct->get_name(); + const vector& fields = tstruct->get_sorted_members(); + vector::const_iterator f_iter; + + out << "sub write {" << endl; + + indent_up(); + indent(out) << "my $self = shift;"<writeStructBegin('" << name << "');" << endl; + + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + out << indent() << "if (defined $self->{" << (*f_iter)->get_name() << "}) {" << endl; + indent_up(); + + indent(out) << + "$xfer += $output->writeFieldBegin(" << + "'" << (*f_iter)->get_name() << "', " << + type_to_enum((*f_iter)->get_type()) << ", " << + (*f_iter)->get_key() << ");" << endl; + + + // Write field contents + generate_serialize_field(out, *f_iter, "self->"); + + indent(out) << + "$xfer += $output->writeFieldEnd();" << endl; + + indent_down(); + indent(out) << "}" << endl; + } + + + out << + indent() << "$xfer += $output->writeFieldStop();" << endl << + indent() << "$xfer += $output->writeStructEnd();" << endl; + + out <get_program()) << "Types;" << endl; + + t_service* extends_s = tservice->get_extends(); + if (extends_s != NULL) { + f_service_ << + "use " << perl_namespace(extends_s->get_program()) << extends_s->get_name() << ";" << endl; + } + + f_service_ << + endl; + + // Generate the three main parts of the service (well, two for now in PERL) + generate_service_helpers(tservice); + generate_service_interface(tservice); + generate_service_rest(tservice); + generate_service_client(tservice); + generate_service_processor(tservice); + + // Close service file + f_service_ << "1;" << endl; + f_service_.close(); +} + +/** + * Generates a service server definition. + * + * @param tservice The service to generate a server for. + */ +void t_perl_generator::generate_service_processor(t_service* tservice) { + // Generate the dispatch methods + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + string extends = ""; + string extends_processor = ""; + t_service* extends_s = tservice->get_extends(); + if (extends_s != NULL) { + extends = perl_namespace(extends_s->get_program()) + extends_s->get_name(); + extends_processor = "use base('" + extends + "Processor');"; + } + + indent_up(); + + // Generate the header portion + f_service_ << + "package " << perl_namespace(program_) << service_name_ << "Processor;" << endl << extends_processor << endl; + + + if (extends.empty()) { + f_service_ << "sub new {" << endl; + + indent_up(); + + f_service_ << + indent() << "my $classname = shift;"<< endl << + indent() << "my $handler = shift;"<< endl << + indent() << "my $self = {};" << endl; + + f_service_ << + indent() << "$self->{handler} = $handler;" << endl; + + f_service_ << + indent() << "return bless($self,$classname);"<readMessageBegin(\\$fname, \\$mtype, \\$rseqid);" << endl; + + // HOT: check for method implementation + f_service_ << + indent() << "my $methodname = 'process_'.$fname;" << endl << + indent() << "if (!method_exists($self, $methodname)) {" << endl; + + f_service_ << + indent() << " $input->skip(TType::STRUCT);" << endl << + indent() << " $input->readMessageEnd();" << endl << + indent() << " my $x = new TApplicationException('Function '.$fname.' not implemented.', TApplicationException::UNKNOWN_METHOD);" << endl << + indent() << " $output->writeMessageBegin($fname, TMessageType::EXCEPTION, $rseqid);" << endl << + indent() << " $x->write($output);" << endl << + indent() << " $output->writeMessageEnd();" << endl << + indent() << " $output->getTransport()->flush();" << endl << + indent() << " return;" << endl; + + f_service_ << + indent() << "}" << endl << + indent() << "$self->$methodname($rseqid, $input, $output);" << endl << + indent() << "return 1;" << endl; + + indent_down(); + + f_service_ << + indent() << "}" << endl <get_name() << "{"<get_program()) + service_name_ + "_" + tfunction->get_name() + "_args"; + string resultname = perl_namespace(tservice->get_program()) + service_name_ + "_" + tfunction->get_name() + "_result"; + + f_service_ << + indent() << "my $args = new " << argsname << "();" << endl << + indent() << "$args->read($input);" << endl; + + f_service_ << + indent() << "$input->readMessageEnd();" << endl; + + t_struct* xs = tfunction->get_xceptions(); + const std::vector& xceptions = xs->get_members(); + vector::const_iterator x_iter; + + // Declare result for non oneway function + if (!tfunction->is_oneway()) { + f_service_ << + indent() << "my $result = new " << resultname << "();" << endl; + } + + // Try block for a function with exceptions + if (xceptions.size() > 0) { + f_service_ << + indent() << "eval {" << endl; + indent_up(); + } + + // Generate the function call + t_struct* arg_struct = tfunction->get_arglist(); + const std::vector& fields = arg_struct->get_members(); + vector::const_iterator f_iter; + + f_service_ << indent(); + if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) { + f_service_ << "$result->{success} = "; + } + f_service_ << + "$self->{handler}->" << tfunction->get_name() << "("; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + } else { + f_service_ << ", "; + } + f_service_ << "$args->" << (*f_iter)->get_name(); + } + f_service_ << ");" << endl; + + if (!tfunction->is_oneway() && xceptions.size() > 0) { + indent_down(); + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + f_service_ << + indent() << "}; if( UNIVERSAL::isa($@,'"<<(*x_iter)->get_type()->get_name()<<"') ){ "<is_oneway()) { + indent_up(); + f_service_ << + indent() << "$result->{" << (*x_iter)->get_name() << "} = $@;" << endl; + indent_down(); + f_service_ << indent(); + } + } + indent_down(); + f_service_ << "}" << endl; + } + + // Shortcut out here for oneway functions + if (tfunction->is_oneway()) { + f_service_ << + indent() << "return;" << endl; + indent_down(); + f_service_ << + indent() << "}" << endl; + return; + } + indent_up(); + // Serialize the request header + f_service_ << + indent() << "$output->writeMessageBegin('" << tfunction->get_name() << "', TMessageType::REPLY, $seqid);" << endl << + indent() << "$result->write($output);" << endl << + indent() << "$output->getTransport()->flush();" << endl; + indent_down(); + + // Close function + indent_down(); + f_service_ << + indent() << "}" << endl; +} + +/** + * Generates helper functions for a service. + * + * @param tservice The service to generate a header definition for + */ +void t_perl_generator::generate_service_helpers(t_service* tservice) { + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + f_service_ << + "# HELPER FUNCTIONS AND STRUCTURES" << endl << endl; + + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + t_struct* ts = (*f_iter)->get_arglist(); + string name = ts->get_name(); + ts->set_name(service_name_ + "_" + name); + generate_perl_struct_definition(f_service_, ts, false); + generate_perl_function_helpers(*f_iter); + ts->set_name(name); + } +} + +/** + * Generates a struct and helpers for a function. + * + * @param tfunction The function + */ +void t_perl_generator::generate_perl_function_helpers(t_function* tfunction) { + t_struct result(program_, service_name_ + "_" + tfunction->get_name() + "_result"); + t_field success(tfunction->get_returntype(), "success", 0); + if (!tfunction->get_returntype()->is_void()) { + result.append(&success); + } + + t_struct* xs = tfunction->get_xceptions(); + const vector& fields = xs->get_members(); + vector::const_iterator f_iter; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + result.append(*f_iter); + } + + generate_perl_struct_definition(f_service_, &result, false); +} + +/** + * Generates a service interface definition. + * + * @param tservice The service to generate a header definition for + */ +void t_perl_generator::generate_service_interface(t_service* tservice) { + string extends_if = ""; + t_service* extends_s = tservice->get_extends(); + if (extends_s != NULL) { + extends_if = "use base('" + perl_namespace(extends_s->get_program()) + extends_s->get_name() + "If');"; + } + + f_service_ << + "package " << perl_namespace(program_) << service_name_ << "If;"< functions = tservice->get_functions(); + vector::iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + f_service_ << + "sub " << function_signature(*f_iter) <get_extends(); + if (extends_s != NULL) { + extends = extends_s->get_name(); + extends_if = "use base('" + perl_namespace(extends_s->get_program()) + extends_s->get_name() + "Rest');"; + } + f_service_ << + "package " << perl_namespace(program_) << service_name_ << "Rest;"< $impl };"< functions = tservice->get_functions(); + vector::iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + f_service_ << + "sub " << (*f_iter)->get_name() << + "{" <& args = (*f_iter)->get_arglist()->get_members(); + vector::const_iterator a_iter; + for (a_iter = args.begin(); a_iter != args.end(); ++a_iter) { + t_type* atype = get_true_type((*a_iter)->get_type()); + string req = "$request->{'" + (*a_iter)->get_name() + "'}"; + f_service_ << + indent() << "my $" << (*a_iter)->get_name() << " = (" << req << ") ? " << req << " : undef;" << endl; + if (atype->is_string() && + ((t_base_type*)atype)->is_string_list()) { + f_service_ << + indent() << "my @" << (*a_iter)->get_name() << " = split(/,/, $" << (*a_iter)->get_name() << ");" << endl << + indent() << "$"<<(*a_iter)->get_name() <<" = \\@"<<(*a_iter)->get_name()<{impl}->" << (*f_iter)->get_name() << "(" << argument_list((*f_iter)->get_arglist()) << ");" << endl; + indent_down(); + indent(f_service_) << "}" << endl <get_extends(); + if (extends_s != NULL) { + extends = perl_namespace(extends_s->get_program()) + extends_s->get_name(); + extends_client = "use base('" + extends + "Client');"; + } + + f_service_ << + "package " << perl_namespace(program_) << service_name_ << "Client;"<SUPER::new($input, $output);" << endl; + } else { + f_service_ << + indent() << " $self->{input} = $input;" << endl << + indent() << " $self->{output} = defined $output ? $output : $input;" << endl << + indent() << " $self->{seqid} = 0;" << endl; + } + + f_service_ << + indent() << "return bless($self,$classname);"< functions = tservice->get_functions(); + vector::const_iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + t_struct* arg_struct = (*f_iter)->get_arglist(); + const vector& fields = arg_struct->get_members(); + vector::const_iterator fld_iter; + string funname = (*f_iter)->get_name(); + + // Open function + f_service_ << "sub " << function_signature(*f_iter) << endl; + + indent_up(); + + indent(f_service_) << indent() << + "$self->send_" << funname << "("; + + bool first = true; + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + if (first) { + first = false; + } else { + f_service_ << ", "; + } + f_service_ << "$" << (*fld_iter)->get_name(); + } + f_service_ << ");" << endl; + + if (!(*f_iter)->is_oneway()) { + f_service_ << indent(); + if (!(*f_iter)->get_returntype()->is_void()) { + f_service_ << "return "; + } + f_service_ << + "$self->recv_" << funname << "();" << endl; + } + + indent_down(); + + f_service_ << "}" << endl << endl; + + f_service_ << + "sub send_" << function_signature(*f_iter) << endl; + + indent_up(); + + std::string argsname = perl_namespace(tservice->get_program()) + service_name_ + "_" + (*f_iter)->get_name() + "_args"; + + // Serialize the request header + f_service_ << + indent() << "$self->{output}->writeMessageBegin('" << (*f_iter)->get_name() << "', TMessageType::CALL, $self->{seqid});" << endl; + + f_service_ << + indent() << "my $args = new " << argsname << "();" << endl; + + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + f_service_ << + indent() << "$args->{" << (*fld_iter)->get_name() << "} = $" << (*fld_iter)->get_name() << ";" << endl; + } + + // Write to the stream + f_service_ << + indent() << "$args->write($self->{output});" << endl << + indent() << "$self->{output}->writeMessageEnd();" << endl << + indent() << "$self->{output}->getTransport()->flush();" << endl; + + + indent_down(); + + f_service_ << "}" << endl; + + + if (!(*f_iter)->is_oneway()) { + std::string resultname = perl_namespace(tservice->get_program()) + service_name_ + "_" + (*f_iter)->get_name() + "_result"; + t_struct noargs(program_); + + t_function recv_function((*f_iter)->get_returntype(), + string("recv_") + (*f_iter)->get_name(), + &noargs); + // Open function + f_service_ << + endl << + "sub " << function_signature(&recv_function) << endl; + + indent_up(); + + f_service_ << + indent() << "my $rseqid = 0;" << endl << + indent() << "my $fname;" << endl << + indent() << "my $mtype = 0;" << endl << + endl; + + f_service_ << + indent() << "$self->{input}->readMessageBegin(\\$fname, \\$mtype, \\$rseqid);" << endl << + indent() << "if ($mtype == TMessageType::EXCEPTION) {" << endl << + indent() << " my $x = new TApplicationException();" << endl << + indent() << " $x->read($self->{input});" << endl << + indent() << " $self->{input}->readMessageEnd();" << endl << + indent() << " die $x;" << endl << + indent() << "}" << endl; + + + f_service_ << + indent() << "my $result = new " << resultname << "();" << endl << + indent() << "$result->read($self->{input});" << endl; + + + f_service_ << + indent() << "$self->{input}->readMessageEnd();" << endl << + endl; + + + // Careful, only return result if not a void function + if (!(*f_iter)->get_returntype()->is_void()) { + f_service_ << + indent() << "if (defined $result->{success} ) {" << endl << + indent() << " return $result->{success};" << endl << + indent() << "}" << endl; + } + + t_struct* xs = (*f_iter)->get_xceptions(); + const std::vector& xceptions = xs->get_members(); + vector::const_iterator x_iter; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + f_service_ << + indent() << "if (defined $result->{" << (*x_iter)->get_name() << "}) {" << endl << + indent() << " die $result->{" << (*x_iter)->get_name() << "};" << endl << + indent() << "}" << endl; + } + + // Careful, only return _result if not a void function + if ((*f_iter)->get_returntype()->is_void()) { + indent(f_service_) << + "return;" << endl; + } else { + f_service_ << + indent() << "die \"" << (*f_iter)->get_name() << " failed: unknown result\";" << endl; + } + + // Close function + indent_down(); + f_service_ << "}"<get_type()); + + if (type->is_void()) { + throw "CANNOT GENERATE DESERIALIZE CODE FOR void TYPE: " + + prefix + tfield->get_name(); + } + + string name = tfield->get_name(); + + //Hack for when prefix is defined (always a hash ref) + if (!prefix.empty()) { + name = prefix + "{" + tfield->get_name() + "}"; + } + + if (type->is_struct() || type->is_xception()) { + generate_deserialize_struct(out, + (t_struct*)type, + name); + } else if (type->is_container()) { + generate_deserialize_container(out, type, name); + } else if (type->is_base_type() || type->is_enum()) { + indent(out) << + "$xfer += $input->"; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "compiler error: cannot serialize void field in a struct: " + + name; + break; + case t_base_type::TYPE_STRING: + out << "readString(\\$" << name << ");"; + break; + case t_base_type::TYPE_BOOL: + out << "readBool(\\$" << name << ");"; + break; + case t_base_type::TYPE_BYTE: + out << "readByte(\\$" << name << ");"; + break; + case t_base_type::TYPE_I16: + out << "readI16(\\$" << name << ");"; + break; + case t_base_type::TYPE_I32: + out << "readI32(\\$" << name << ");"; + break; + case t_base_type::TYPE_I64: + out << "readI64(\\$" << name << ");"; + break; + case t_base_type::TYPE_DOUBLE: + out << "readDouble(\\$" << name << ");"; + break; + default: + throw "compiler error: no PERL name for base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + out << "readI32(\\$" << name << ");"; + } + out << endl; + + } else { + printf("DO NOT KNOW HOW TO DESERIALIZE FIELD '%s' TYPE '%s'\n", + tfield->get_name().c_str(), type->get_name().c_str()); + } +} + +/** + * Generates an unserializer for a variable. This makes two key assumptions, + * first that there is a const char* variable named data that points to the + * buffer for deserialization, and that there is a variable protocol which + * is a reference to a TProtocol serialization object. + */ +void t_perl_generator::generate_deserialize_struct(ofstream &out, + t_struct* tstruct, + string prefix) { + out << + indent() << "$" << prefix << " = new " << perl_namespace(tstruct->get_program()) << tstruct->get_name() << "();" << endl << + indent() << "$xfer += $" << prefix << "->read($input);" << endl; +} + +void t_perl_generator::generate_deserialize_container(ofstream &out, + t_type* ttype, + string prefix) { + scope_up(out); + + string size = tmp("_size"); + string ktype = tmp("_ktype"); + string vtype = tmp("_vtype"); + string etype = tmp("_etype"); + + t_field fsize(g_type_i32, size); + t_field fktype(g_type_byte, ktype); + t_field fvtype(g_type_byte, vtype); + t_field fetype(g_type_byte, etype); + + out << + indent() << "my $" << size << " = 0;" << endl; + + // Declare variables, read header + if (ttype->is_map()) { + out << + indent() << "$" << prefix << " = {};" << endl << + indent() << "my $" << ktype << " = 0;" << endl << + indent() << "my $" << vtype << " = 0;" << endl; + + out << + indent() << "$xfer += $input->readMapBegin(" << + "\\$" << ktype << ", \\$" << vtype << ", \\$" << size << ");" << endl; + + } else if (ttype->is_set()) { + + out << + indent() << "$" << prefix << " = {};" << endl << + indent() << "my $" << etype << " = 0;" << endl << + indent() << "$xfer += $input->readSetBegin(" << + "\\$" << etype << ", \\$" << size << ");" << endl; + + } else if (ttype->is_list()) { + + out << + indent() << "$" << prefix << " = [];" << endl << + indent() << "my $" << etype << " = 0;" << endl << + indent() << "$xfer += $input->readListBegin(" << + "\\$" << etype << ", \\$" << size << ");" << endl; + + } + + // For loop iterates over elements + string i = tmp("_i"); + indent(out) << + "for (my $" << + i << " = 0; $" << i << " < $" << size << "; ++$" << i << ")" << endl; + + scope_up(out); + + if (ttype->is_map()) { + generate_deserialize_map_element(out, (t_map*)ttype, prefix); + } else if (ttype->is_set()) { + generate_deserialize_set_element(out, (t_set*)ttype, prefix); + } else if (ttype->is_list()) { + generate_deserialize_list_element(out, (t_list*)ttype, prefix); + } + + scope_down(out); + + + // Read container end + if (ttype->is_map()) { + indent(out) << "$xfer += $input->readMapEnd();" << endl; + } else if (ttype->is_set()) { + indent(out) << "$xfer += $input->readSetEnd();" << endl; + } else if (ttype->is_list()) { + indent(out) << "$xfer += $input->readListEnd();" << endl; + } + + scope_down(out); +} + + +/** + * Generates code to deserialize a map + */ +void t_perl_generator::generate_deserialize_map_element(ofstream &out, + t_map* tmap, + string prefix) { + string key = tmp("key"); + string val = tmp("val"); + t_field fkey(tmap->get_key_type(), key); + t_field fval(tmap->get_val_type(), val); + + indent(out) << + declare_field(&fkey, true, true) << endl; + indent(out) << + declare_field(&fval, true, true) << endl; + + generate_deserialize_field(out, &fkey); + generate_deserialize_field(out, &fval); + + indent(out) << + "$" << prefix << "->{$" << key << "} = $" << val << ";" << endl; +} + +void t_perl_generator::generate_deserialize_set_element(ofstream &out, + t_set* tset, + string prefix) { + string elem = tmp("elem"); + t_field felem(tset->get_elem_type(), elem); + + indent(out) << + "my $" << elem << " = undef;" << endl; + + generate_deserialize_field(out, &felem); + + indent(out) << + "$" << prefix << "->{$" << elem << "} = 1;" << endl; +} + +void t_perl_generator::generate_deserialize_list_element(ofstream &out, + t_list* tlist, + string prefix) { + string elem = tmp("elem"); + t_field felem(tlist->get_elem_type(), elem); + + indent(out) << + "my $" << elem << " = undef;" << endl; + + generate_deserialize_field(out, &felem); + + indent(out) << + "push(@{$" << prefix << "},$" << elem << ");" << endl; +} + + +/** + * Serializes a field of any type. + * + * @param tfield The field to serialize + * @param prefix Name to prepend to field name + */ +void t_perl_generator::generate_serialize_field(ofstream &out, + t_field* tfield, + string prefix) { + t_type* type = get_true_type(tfield->get_type()); + + // Do nothing for void types + if (type->is_void()) { + throw "CANNOT GENERATE SERIALIZE CODE FOR void TYPE: " + + prefix + tfield->get_name(); + } + + if (type->is_struct() || type->is_xception()) { + generate_serialize_struct(out, + (t_struct*)type, + prefix + "{"+tfield->get_name()+"}" ); + } else if (type->is_container()) { + generate_serialize_container(out, + type, + prefix + "{" + tfield->get_name()+"}"); + } else if (type->is_base_type() || type->is_enum()) { + + string name = tfield->get_name(); + + //Hack for when prefix is defined (always a hash ref) + if(!prefix.empty()) + name = prefix + "{" + tfield->get_name() + "}"; + + indent(out) << + "$xfer += $output->"; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw + "compiler error: cannot serialize void field in a struct: " + name; + break; + case t_base_type::TYPE_STRING: + out << "writeString($" << name << ");"; + break; + case t_base_type::TYPE_BOOL: + out << "writeBool($" << name << ");"; + break; + case t_base_type::TYPE_BYTE: + out << "writeByte($" << name << ");"; + break; + case t_base_type::TYPE_I16: + out << "writeI16($" << name << ");"; + break; + case t_base_type::TYPE_I32: + out << "writeI32($" << name << ");"; + break; + case t_base_type::TYPE_I64: + out << "writeI64($" << name << ");"; + break; + case t_base_type::TYPE_DOUBLE: + out << "writeDouble($" << name << ");"; + break; + default: + throw "compiler error: no PERL name for base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + out << "writeI32($" << name << ");"; + } + out << endl; + + } else { + printf("DO NOT KNOW HOW TO SERIALIZE FIELD '%s%s' TYPE '%s'\n", + prefix.c_str(), + tfield->get_name().c_str(), + type->get_name().c_str()); + } +} + +/** + * Serializes all the members of a struct. + * + * @param tstruct The struct to serialize + * @param prefix String prefix to attach to all fields + */ +void t_perl_generator::generate_serialize_struct(ofstream &out, + t_struct* tstruct, + string prefix) { + indent(out) << + "$xfer += $" << prefix << "->write($output);" << endl; +} + +/** + * Writes out a container + */ +void t_perl_generator::generate_serialize_container(ofstream &out, + t_type* ttype, + string prefix) { + scope_up(out); + + if (ttype->is_map()) { + indent(out) << + "$output->writeMapBegin(" << + type_to_enum(((t_map*)ttype)->get_key_type()) << ", " << + type_to_enum(((t_map*)ttype)->get_val_type()) << ", " << + "scalar(keys %{$" << prefix << "}));" << endl; + } else if (ttype->is_set()) { + indent(out) << + "$output->writeSetBegin(" << + type_to_enum(((t_set*)ttype)->get_elem_type()) << ", " << + "scalar(@{$" << prefix << "}));" << endl; + + } else if (ttype->is_list()) { + + indent(out) << + "$output->writeListBegin(" << + type_to_enum(((t_list*)ttype)->get_elem_type()) << ", " << + "scalar(@{$" << prefix << "}));" << endl; + + } + + scope_up(out); + + if (ttype->is_map()) { + string kiter = tmp("kiter"); + string viter = tmp("viter"); + indent(out) << + "while( my ($"<is_set()) { + string iter = tmp("iter"); + indent(out) << + "foreach my $"<is_list()) { + string iter = tmp("iter"); + indent(out) << + "foreach my $"<is_map()) { + indent(out) << + "$output->writeMapEnd();" << endl; + } else if (ttype->is_set()) { + indent(out) << + "$output->writeSetEnd();" << endl; + } else if (ttype->is_list()) { + indent(out) << + "$output->writeListEnd();" << endl; + } + + scope_down(out); +} + +/** + * Serializes the members of a map. + * + */ +void t_perl_generator::generate_serialize_map_element(ofstream &out, + t_map* tmap, + string kiter, + string viter) { + t_field kfield(tmap->get_key_type(), kiter); + generate_serialize_field(out, &kfield); + + t_field vfield(tmap->get_val_type(), viter); + generate_serialize_field(out, &vfield); +} + +/** + * Serializes the members of a set. + */ +void t_perl_generator::generate_serialize_set_element(ofstream &out, + t_set* tset, + string iter) { + t_field efield(tset->get_elem_type(), iter); + generate_serialize_field(out, &efield); +} + +/** + * Serializes the members of a list. + */ +void t_perl_generator::generate_serialize_list_element(ofstream &out, + t_list* tlist, + string iter) { + t_field efield(tlist->get_elem_type(), iter); + generate_serialize_field(out, &efield); +} + +/** + * Declares a field, which may include initialization as necessary. + * + * @param ttype The type + */ +string t_perl_generator::declare_field(t_field* tfield, bool init, bool obj) { + string result = "my $" + tfield->get_name(); + if (init) { + t_type* type = get_true_type(tfield->get_type()); + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + break; + case t_base_type::TYPE_STRING: + result += " = ''"; + break; + case t_base_type::TYPE_BOOL: + result += " = 0"; + break; + case t_base_type::TYPE_BYTE: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + case t_base_type::TYPE_I64: + result += " = 0"; + break; + case t_base_type::TYPE_DOUBLE: + result += " = 0.0"; + break; + default: + throw "compiler error: no PERL initializer for base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + result += " = 0"; + } else if (type->is_container()) { + result += " = []"; + } else if (type->is_struct() || type->is_xception()) { + if (obj) { + result += " = new " + perl_namespace(type->get_program()) + type->get_name() + "()"; + } else { + result += " = undef"; + } + } + } + return result + ";"; +} + +/** + * Renders a function signature of the form 'type name(args)' + * + * @param tfunction Function definition + * @return String of rendered function definition + */ +string t_perl_generator::function_signature(t_function* tfunction, + string prefix) { + + string str; + + str = prefix + tfunction->get_name() + "{\n"; + str += " my $self = shift;\n"; + + //Need to create perl function arg inputs + const vector &fields = tfunction->get_arglist()->get_members(); + vector::const_iterator f_iter; + + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + str += " my $" + (*f_iter)->get_name() + " = shift;\n"; + } + + return str; +} + +/** + * Renders a field list + */ +string t_perl_generator::argument_list(t_struct* tstruct) { + string result = ""; + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + } else { + result += ", "; + } + result += "$" + (*f_iter)->get_name(); + } + return result; +} + +/** + * Converts the parse type to a C++ enum string for the given type. + */ +string t_perl_generator ::type_to_enum(t_type* type) { + type = get_true_type(type); + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "NO T_VOID CONSTRUCT"; + case t_base_type::TYPE_STRING: + return "TType::STRING"; + case t_base_type::TYPE_BOOL: + return "TType::BOOL"; + case t_base_type::TYPE_BYTE: + return "TType::BYTE"; + case t_base_type::TYPE_I16: + return "TType::I16"; + case t_base_type::TYPE_I32: + return "TType::I32"; + case t_base_type::TYPE_I64: + return "TType::I64"; + case t_base_type::TYPE_DOUBLE: + return "TType::DOUBLE"; + } + } else if (type->is_enum()) { + return "TType::I32"; + } else if (type->is_struct() || type->is_xception()) { + return "TType::STRUCT"; + } else if (type->is_map()) { + return "TType::MAP"; + } else if (type->is_set()) { + return "TType::SET"; + } else if (type->is_list()) { + return "TType::LIST"; + } + + throw "INVALID TYPE IN type_to_enum: " + type->get_name(); +} + +THRIFT_REGISTER_GENERATOR(perl, "Perl", ""); diff --git a/compiler/cpp/src/generate/t_php_generator.cc b/compiler/cpp/src/generate/t_php_generator.cc new file mode 100644 index 00000000..436a6325 --- /dev/null +++ b/compiler/cpp/src/generate/t_php_generator.cc @@ -0,0 +1,2281 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include "t_oop_generator.h" +#include "platform.h" +using namespace std; + + +/** + * PHP code generator. + * + */ +class t_php_generator : public t_oop_generator { + public: + t_php_generator( + t_program* program, + const std::map& parsed_options, + const std::string& option_string) + : t_oop_generator(program) + { + std::map::const_iterator iter; + + iter = parsed_options.find("inlined"); + binary_inline_ = (iter != parsed_options.end()); + + iter = parsed_options.find("rest"); + rest_ = (iter != parsed_options.end()); + + iter = parsed_options.find("server"); + phps_ = (iter != parsed_options.end()); + + iter = parsed_options.find("autoload"); + autoload_ = (iter != parsed_options.end()); + + iter = parsed_options.find("oop"); + oop_ = (iter != parsed_options.end()); + + if (oop_ && binary_inline_) { + throw "oop and inlined are mutually exclusive."; + } + + out_dir_base_ = (binary_inline_ ? "gen-phpi" : "gen-php"); + escape_['$'] = "\\$"; + } + + /** + * Init and close methods + */ + + void init_generator(); + void close_generator(); + + /** + * Program-level generation functions + */ + + void generate_typedef (t_typedef* ttypedef); + void generate_enum (t_enum* tenum); + void generate_const (t_const* tconst); + void generate_struct (t_struct* tstruct); + void generate_xception (t_struct* txception); + void generate_service (t_service* tservice); + + std::string render_const_value(t_type* type, t_const_value* value); + + /** + * Structs! + */ + + void generate_php_struct(t_struct* tstruct, bool is_exception); + void generate_php_struct_definition(std::ofstream& out, t_struct* tstruct, bool is_xception=false); + void _generate_php_struct_definition(std::ofstream& out, t_struct* tstruct, bool is_xception=false); + void generate_php_struct_reader(std::ofstream& out, t_struct* tstruct); + void generate_php_struct_writer(std::ofstream& out, t_struct* tstruct); + void generate_php_function_helpers(t_function* tfunction); + + void generate_php_type_spec(std::ofstream &out, t_type* t); + void generate_php_struct_spec(std::ofstream &out, t_struct* tstruct); + + /** + * Service-level generation functions + */ + + void generate_service_helpers (t_service* tservice); + void generate_service_interface (t_service* tservice); + void generate_service_rest (t_service* tservice); + void generate_service_client (t_service* tservice); + void _generate_service_client (std::ofstream &out, t_service* tservice); + void generate_service_processor (t_service* tservice); + void generate_process_function (t_service* tservice, t_function* tfunction); + + /** + * Serialization constructs + */ + + void generate_deserialize_field (std::ofstream &out, + t_field* tfield, + std::string prefix="", + bool inclass=false); + + void generate_deserialize_struct (std::ofstream &out, + t_struct* tstruct, + std::string prefix=""); + + void generate_deserialize_container (std::ofstream &out, + t_type* ttype, + std::string prefix=""); + + void generate_deserialize_set_element (std::ofstream &out, + t_set* tset, + std::string prefix=""); + + void generate_deserialize_map_element (std::ofstream &out, + t_map* tmap, + std::string prefix=""); + + void generate_deserialize_list_element (std::ofstream &out, + t_list* tlist, + std::string prefix=""); + + void generate_serialize_field (std::ofstream &out, + t_field* tfield, + std::string prefix=""); + + void generate_serialize_struct (std::ofstream &out, + t_struct* tstruct, + std::string prefix=""); + + void generate_serialize_container (std::ofstream &out, + t_type* ttype, + std::string prefix=""); + + void generate_serialize_map_element (std::ofstream &out, + t_map* tmap, + std::string kiter, + std::string viter); + + void generate_serialize_set_element (std::ofstream &out, + t_set* tmap, + std::string iter); + + void generate_serialize_list_element (std::ofstream &out, + t_list* tlist, + std::string iter); + + /** + * Helper rendering functions + */ + + std::string php_includes(); + std::string declare_field(t_field* tfield, bool init=false, bool obj=false); + std::string function_signature(t_function* tfunction, std::string prefix=""); + std::string argument_list(t_struct* tstruct); + std::string type_to_cast(t_type* ttype); + std::string type_to_enum(t_type* ttype); + + std::string php_namespace(t_program* p) { + std::string ns = p->get_namespace("php"); + return ns.size() ? (ns + "_") : ""; + } + + private: + + /** + * File streams + */ + std::ofstream f_types_; + std::ofstream f_consts_; + std::ofstream f_helpers_; + std::ofstream f_service_; + + /** + * Generate protocol-independent template? Or Binary inline code? + */ + bool binary_inline_; + + /** + * Generate a REST handler class + */ + bool rest_; + + /** + * Generate stubs for a PHP server + */ + bool phps_; + + /** + * Generate PHP code that uses autoload + */ + bool autoload_; + + /** + * Whether to use OOP base class TBase + */ + bool oop_; + +}; + + +/** + * Prepares for file generation by opening up the necessary file output + * streams. + * + * @param tprogram The program to generate + */ +void t_php_generator::init_generator() { + // Make output directory + MKDIR(get_out_dir().c_str()); + + // Make output file + string f_types_name = get_out_dir()+program_name_+"_types.php"; + f_types_.open(f_types_name.c_str()); + + // Print header + f_types_ << + "& includes = program_->get_includes(); + for (size_t i = 0; i < includes.size(); ++i) { + string package = includes[i]->get_name(); + f_types_ << + "include_once $GLOBALS['THRIFT_ROOT'].'/packages/" << package << "/" << package << "_types.php';" << endl; + } + f_types_ << endl; + + // Print header + if (!program_->get_consts().empty()) { + string f_consts_name = get_out_dir()+program_name_+"_constants.php"; + f_consts_.open(f_consts_name.c_str()); + f_consts_ << + "" << endl; + f_types_.close(); + + if (!program_->get_consts().empty()) { + f_consts_ << "?>" << endl; + f_consts_.close(); + } +} + +/** + * Generates a typedef. This is not done in PHP, types are all implicit. + * + * @param ttypedef The type definition + */ +void t_php_generator::generate_typedef(t_typedef* ttypedef) {} + +/** + * Generates code for an enumerated type. Since define is expensive to lookup + * in PHP, we use a global array for this. + * + * @param tenum The enumeration + */ +void t_php_generator::generate_enum(t_enum* tenum) { + f_types_ << + "$GLOBALS['" << php_namespace(tenum->get_program()) << "E_" << tenum->get_name() << "'] = array(" << endl; + + vector constants = tenum->get_constants(); + vector::iterator c_iter; + int value = -1; + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + if ((*c_iter)->has_value()) { + value = (*c_iter)->get_value(); + } else { + ++value; + } + + f_types_ << + " '" << (*c_iter)->get_name() << "' => " << value << "," << endl; + } + + f_types_ << + ");" << endl << endl; + + + // We're also doing it this way to see how it performs. It's more legible + // code but you can't do things like an 'extract' on it, which is a bit of + // a downer. + f_types_ << + "final class " << php_namespace(tenum->get_program()) << tenum->get_name() << " {" << endl; + indent_up(); + + value = -1; + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + if ((*c_iter)->has_value()) { + value = (*c_iter)->get_value(); + } else { + ++value; + } + + indent(f_types_) << + "const " << (*c_iter)->get_name() << " = " << value << ";" << endl; + } + + indent(f_types_) << + "static public $__names = array(" << endl; + value = -1; + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + if ((*c_iter)->has_value()) { + value = (*c_iter)->get_value(); + } else { + ++value; + } + + indent(f_types_) << + " " << value << " => '" << (*c_iter)->get_name() << "'," << endl; + } + indent(f_types_) << + ");" << endl; + + indent_down(); + f_types_ << "}" << endl << endl; +} + +/** + * Generate a constant value + */ +void t_php_generator::generate_const(t_const* tconst) { + t_type* type = tconst->get_type(); + string name = tconst->get_name(); + t_const_value* value = tconst->get_value(); + + f_consts_ << "$GLOBALS['" << program_name_ << "_CONSTANTS']['" << name << "'] = "; + f_consts_ << render_const_value(type, value); + f_consts_ << ";" << endl << endl; +} + +/** + * Prints the value of a constant with the given type. Note that type checking + * is NOT performed in this function as it is always run beforehand using the + * validate_types method in main.cc + */ +string t_php_generator::render_const_value(t_type* type, t_const_value* value) { + std::ostringstream out; + type = get_true_type(type); + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_STRING: + out << '"' << get_escaped_string(value) << '"'; + break; + case t_base_type::TYPE_BOOL: + out << (value->get_integer() > 0 ? "true" : "false"); + break; + case t_base_type::TYPE_BYTE: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + case t_base_type::TYPE_I64: + out << value->get_integer(); + break; + case t_base_type::TYPE_DOUBLE: + if (value->get_type() == t_const_value::CV_INTEGER) { + out << value->get_integer(); + } else { + out << value->get_double(); + } + break; + default: + throw "compiler error: no const of base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + indent(out) << value->get_integer(); + } else if (type->is_struct() || type->is_xception()) { + out << "new " << php_namespace(type->get_program()) << type->get_name() << "(array(" << endl; + indent_up(); + const vector& fields = ((t_struct*)type)->get_members(); + vector::const_iterator f_iter; + const map& val = value->get_map(); + map::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + t_type* field_type = NULL; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if ((*f_iter)->get_name() == v_iter->first->get_string()) { + field_type = (*f_iter)->get_type(); + } + } + if (field_type == NULL) { + throw "type error: " + type->get_name() + " has no field " + v_iter->first->get_string(); + } + out << indent(); + out << render_const_value(g_type_string, v_iter->first); + out << " => "; + out << render_const_value(field_type, v_iter->second); + out << endl; + } + indent_down(); + indent(out) << "))"; + } else if (type->is_map()) { + t_type* ktype = ((t_map*)type)->get_key_type(); + t_type* vtype = ((t_map*)type)->get_val_type(); + out << "array(" << endl; + indent_up(); + const map& val = value->get_map(); + map::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + out << indent(); + out << render_const_value(ktype, v_iter->first); + out << " => "; + out << render_const_value(vtype, v_iter->second); + out << "," << endl; + } + indent_down(); + indent(out) << ")"; + } else if (type->is_list() || type->is_set()) { + t_type* etype; + if (type->is_list()) { + etype = ((t_list*)type)->get_elem_type(); + } else { + etype = ((t_set*)type)->get_elem_type(); + } + out << "array(" << endl; + indent_up(); + const vector& val = value->get_list(); + vector::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + out << indent(); + out << render_const_value(etype, *v_iter); + if (type->is_set()) { + out << " => true"; + } + out << "," << endl; + } + indent_down(); + indent(out) << ")"; + } + return out.str(); +} + +/** + * Make a struct + */ +void t_php_generator::generate_struct(t_struct* tstruct) { + generate_php_struct(tstruct, false); +} + +/** + * Generates a struct definition for a thrift exception. Basically the same + * as a struct but extends the Exception class. + * + * @param txception The struct definition + */ +void t_php_generator::generate_xception(t_struct* txception) { + generate_php_struct(txception, true); +} + +/** + * Structs can be normal or exceptions. + */ +void t_php_generator::generate_php_struct(t_struct* tstruct, + bool is_exception) { + generate_php_struct_definition(f_types_, tstruct, is_exception); +} + +void t_php_generator::generate_php_type_spec(ofstream& out, + t_type* t) { + t = get_true_type(t); + indent(out) << "'type' => " << type_to_enum(t) << "," << endl; + + if (t->is_base_type() || t->is_enum()) { + // Noop, type is all we need + } else if (t->is_struct() || t->is_xception()) { + indent(out) << "'class' => '" << php_namespace(t->get_program()) << t->get_name() <<"'," << endl; + } else if (t->is_map()) { + t_type* ktype = get_true_type(((t_map*)t)->get_key_type()); + t_type* vtype = get_true_type(((t_map*)t)->get_val_type()); + indent(out) << "'ktype' => " << type_to_enum(ktype) << "," << endl; + indent(out) << "'vtype' => " << type_to_enum(vtype) << "," << endl; + indent(out) << "'key' => array(" << endl; + indent_up(); + generate_php_type_spec(out, ktype); + indent_down(); + indent(out) << ")," << endl; + indent(out) << "'val' => array(" << endl; + indent_up(); + generate_php_type_spec(out, vtype); + indent(out) << ")," << endl; + indent_down(); + } else if (t->is_list() || t->is_set()) { + t_type* etype; + if (t->is_list()) { + etype = get_true_type(((t_list*)t)->get_elem_type()); + } else { + etype = get_true_type(((t_set*)t)->get_elem_type()); + } + indent(out) << "'etype' => " << type_to_enum(etype) <<"," << endl; + indent(out) << "'elem' => array(" << endl; + indent_up(); + generate_php_type_spec(out, etype); + indent(out) << ")," << endl; + indent_down(); + } else { + throw "compiler error: no type for php struct spec field"; + } + +} + +/** + * Generates the struct specification structure, which fully qualifies enough + * type information to generalize serialization routines. + */ +void t_php_generator::generate_php_struct_spec(ofstream& out, + t_struct* tstruct) { + indent(out) << "if (!isset(self::$_TSPEC)) {" << endl; + indent_up(); + + indent(out) << "self::$_TSPEC = array(" << endl; + indent_up(); + + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + t_type* t = get_true_type((*m_iter)->get_type()); + indent(out) << (*m_iter)->get_key() << " => array(" << endl; + indent_up(); + out << + indent() << "'var' => '" << (*m_iter)->get_name() << "'," << endl; + generate_php_type_spec(out, t); + indent(out) << ")," << endl; + indent_down(); + } + + indent_down(); + indent(out) << " );" << endl; + indent_down(); + indent(out) << "}" << endl; +} + + +void t_php_generator::generate_php_struct_definition(ofstream& out, + t_struct* tstruct, + bool is_exception) { + if (autoload_) { + // Make output file + ofstream autoload_out; + string f_struct = program_name_+"."+(tstruct->get_name())+".php"; + string f_struct_name = get_out_dir()+f_struct; + autoload_out.open(f_struct_name.c_str()); + autoload_out << "" << endl; + autoload_out.close(); + + f_types_ << + "$GLOBALS['THRIFT_AUTOLOAD']['" << lowercase(php_namespace(tstruct->get_program()) + tstruct->get_name()) << "'] = '" << program_name_ << "/" << f_struct << "';" << endl; + + } else { + _generate_php_struct_definition(out, tstruct, is_exception); + } +} + +/** + * Generates a struct definition for a thrift data type. This is nothing in PHP + * where the objects are all just associative arrays (unless of course we + * decide to start using objects for them...) + * + * @param tstruct The struct definition + */ +void t_php_generator::_generate_php_struct_definition(ofstream& out, + t_struct* tstruct, + bool is_exception) { + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + + out << + "class " << php_namespace(tstruct->get_program()) << tstruct->get_name(); + if (is_exception) { + out << " extends TException"; + } else if (oop_) { + out << " extends TBase"; + } + out << + " {" << endl; + indent_up(); + + indent(out) << "static $_TSPEC;" << endl << endl; + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + string dval = "null"; + t_type* t = get_true_type((*m_iter)->get_type()); + if ((*m_iter)->get_value() != NULL && !(t->is_struct() || t->is_xception())) { + dval = render_const_value((*m_iter)->get_type(), (*m_iter)->get_value()); + } + indent(out) << + "public $" << (*m_iter)->get_name() << " = " << dval << ";" << endl; + } + + out << endl; + + // Generate constructor from array + string param = (members.size() > 0) ? "$vals=null" : ""; + out << + indent() << "public function __construct(" << param << ") {" << endl; + indent_up(); + + generate_php_struct_spec(out, tstruct); + + if (members.size() > 0) { + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + t_type* t = get_true_type((*m_iter)->get_type()); + if ((*m_iter)->get_value() != NULL && (t->is_struct() || t->is_xception())) { + indent(out) << "$this->" << (*m_iter)->get_name() << " = " << render_const_value(t, (*m_iter)->get_value()) << ";" << endl; + } + } + out << + indent() << "if (is_array($vals)) {" << endl; + indent_up(); + if (oop_) { + out << indent() << "parent::__construct(self::$_TSPEC, $vals);" << endl; + } else { + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + out << + indent() << "if (isset($vals['" << (*m_iter)->get_name() << "'])) {" << endl << + indent() << " $this->" << (*m_iter)->get_name() << " = $vals['" << (*m_iter)->get_name() << "'];" << endl << + indent() << "}" << endl; + } + } + indent_down(); + out << + indent() << "}" << endl; + } + scope_down(out); + out << endl; + + out << + indent() << "public function getName() {" << endl << + indent() << " return '" << tstruct->get_name() << "';" << endl << + indent() << "}" << endl << + endl; + + generate_php_struct_reader(out, tstruct); + generate_php_struct_writer(out, tstruct); + + indent_down(); + out << + indent() << "}" << endl << + endl; +} + +/** + * Generates the read() method for a struct + */ +void t_php_generator::generate_php_struct_reader(ofstream& out, + t_struct* tstruct) { + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + + indent(out) << + "public function read($input)" << endl; + scope_up(out); + + if (oop_) { + indent(out) << "return $this->_read('" << tstruct->get_name() << "', self::$_TSPEC, $input);" << endl; + scope_down(out); + return; + } + + out << + indent() << "$xfer = 0;" << endl << + indent() << "$fname = null;" << endl << + indent() << "$ftype = 0;" << endl << + indent() << "$fid = 0;" << endl; + + // Declare stack tmp variables + if (!binary_inline_) { + indent(out) << + "$xfer += $input->readStructBegin($fname);" << endl; + } + + // Loop over reading in fields + indent(out) << + "while (true)" << endl; + + scope_up(out); + + // Read beginning field marker + if (binary_inline_) { + t_field fftype(g_type_byte, "ftype"); + t_field ffid(g_type_i16, "fid"); + generate_deserialize_field(out, &fftype); + out << + indent() << "if ($ftype == TType::STOP) {" << endl << + indent() << " break;" << endl << + indent() << "}" << endl; + generate_deserialize_field(out, &ffid); + } else { + indent(out) << + "$xfer += $input->readFieldBegin($fname, $ftype, $fid);" << endl; + // Check for field STOP marker and break + indent(out) << + "if ($ftype == TType::STOP) {" << endl; + indent_up(); + indent(out) << + "break;" << endl; + indent_down(); + indent(out) << + "}" << endl; + } + + // Switch statement on the field we are reading + indent(out) << + "switch ($fid)" << endl; + + scope_up(out); + + // Generate deserialization code for known cases + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + indent(out) << + "case " << (*f_iter)->get_key() << ":" << endl; + indent_up(); + indent(out) << "if ($ftype == " << type_to_enum((*f_iter)->get_type()) << ") {" << endl; + indent_up(); + generate_deserialize_field(out, *f_iter, "this->"); + indent_down(); + out << + indent() << "} else {" << endl; + if (binary_inline_) { + indent(out) << " $xfer += TProtocol::skipBinary($input, $ftype);" << endl; + } else { + indent(out) << " $xfer += $input->skip($ftype);" << endl; + } + out << + indent() << "}" << endl << + indent() << "break;" << endl; + indent_down(); + } + + // In the default case we skip the field + indent(out) << "default:" << endl; + if (binary_inline_) { + indent(out) << " $xfer += TProtocol::skipBinary($input, $ftype);" << endl; + } else { + indent(out) << " $xfer += $input->skip($ftype);" << endl; + } + indent(out) << " break;" << endl; + + scope_down(out); + + if (!binary_inline_) { + // Read field end marker + indent(out) << + "$xfer += $input->readFieldEnd();" << endl; + } + + scope_down(out); + + if (!binary_inline_) { + indent(out) << + "$xfer += $input->readStructEnd();" << endl; + } + + indent(out) << + "return $xfer;" << endl; + + indent_down(); + out << + indent() << "}" << endl << + endl; +} + +/** + * Generates the write() method for a struct + */ +void t_php_generator::generate_php_struct_writer(ofstream& out, + t_struct* tstruct) { + string name = tstruct->get_name(); + const vector& fields = tstruct->get_sorted_members(); + vector::const_iterator f_iter; + + if (binary_inline_) { + indent(out) << + "public function write(&$output) {" << endl; + } else { + indent(out) << + "public function write($output) {" << endl; + } + indent_up(); + + if (oop_) { + indent(out) << "return $this->_write('" << tstruct->get_name() << "', self::$_TSPEC, $output);" << endl; + scope_down(out); + return; + } + + indent(out) << + "$xfer = 0;" << endl; + + if (!binary_inline_) { + indent(out) << + "$xfer += $output->writeStructBegin('" << name << "');" << endl; + } + + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + out << + indent() << "if ($this->" << (*f_iter)->get_name() << " !== null) {" << endl; + indent_up(); + + t_type* type = get_true_type((*f_iter)->get_type()); + string expect; + if (type->is_container()) { + expect = "array"; + } else if (type->is_struct()) { + expect = "object"; + } + if (!expect.empty()) { + out << + indent() << "if (!is_" << expect << "($this->" << (*f_iter)->get_name() << ")) {" << endl; + indent_up(); + out << + indent() << "throw new TProtocolException('Bad type in structure.', TProtocolException::INVALID_DATA);" << endl; + scope_down(out); + } + + // Write field header + if (binary_inline_) { + out << + indent() << "$output .= pack('c', " << type_to_enum((*f_iter)->get_type()) << ");" << endl << + indent() << "$output .= pack('n', " << (*f_iter)->get_key() << ");" << endl; + } else { + indent(out) << + "$xfer += $output->writeFieldBegin(" << + "'" << (*f_iter)->get_name() << "', " << + type_to_enum((*f_iter)->get_type()) << ", " << + (*f_iter)->get_key() << ");" << endl; + } + + // Write field contents + generate_serialize_field(out, *f_iter, "this->"); + + // Write field closer + if (!binary_inline_) { + indent(out) << + "$xfer += $output->writeFieldEnd();" << endl; + } + + indent_down(); + indent(out) << + "}" << endl; + } + + if (binary_inline_) { + out << + indent() << "$output .= pack('c', TType::STOP);" << endl; + } else { + out << + indent() << "$xfer += $output->writeFieldStop();" << endl << + indent() << "$xfer += $output->writeStructEnd();" << endl; + } + + out << + indent() << "return $xfer;" << endl; + + indent_down(); + out << + indent() << "}" << endl << + endl; +} + +/** + * Generates a thrift service. + * + * @param tservice The service definition + */ +void t_php_generator::generate_service(t_service* tservice) { + string f_service_name = get_out_dir()+service_name_+".php"; + f_service_.open(f_service_name.c_str()); + + f_service_ << + "get_extends() != NULL) { + f_service_ << + "include_once $GLOBALS['THRIFT_ROOT'].'/packages/" << tservice->get_extends()->get_program()->get_name() << "/" << tservice->get_extends()->get_name() << ".php';" << endl; + } + + f_service_ << + endl; + + // Generate the three main parts of the service (well, two for now in PHP) + generate_service_interface(tservice); + if (rest_) { + generate_service_rest(tservice); + } + generate_service_client(tservice); + generate_service_helpers(tservice); + if (phps_) { + generate_service_processor(tservice); + } + + // Close service file + f_service_ << "?>" << endl; + f_service_.close(); +} + +/** + * Generates a service server definition. + * + * @param tservice The service to generate a server for. + */ +void t_php_generator::generate_service_processor(t_service* tservice) { + // Generate the dispatch methods + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + string extends = ""; + string extends_processor = ""; + if (tservice->get_extends() != NULL) { + extends = tservice->get_extends()->get_name(); + extends_processor = " extends " + extends + "Processor"; + } + + // Generate the header portion + f_service_ << + "class " << service_name_ << "Processor" << extends_processor << " {" << endl; + indent_up(); + + if (extends.empty()) { + f_service_ << + indent() << "protected $handler_ = null;" << endl; + } + + f_service_ << + indent() << "public function __construct($handler) {" << endl; + if (extends.empty()) { + f_service_ << + indent() << " $this->handler_ = $handler;" << endl; + } else { + f_service_ << + indent() << " parent::__construct($handler);" << endl; + } + f_service_ << + indent() << "}" << endl << + endl; + + // Generate the server implementation + indent(f_service_) << + "public function process($input, $output) {" << endl; + indent_up(); + + f_service_ << + indent() << "$rseqid = 0;" << endl << + indent() << "$fname = null;" << endl << + indent() << "$mtype = 0;" << endl << + endl; + + if (binary_inline_) { + t_field ffname(g_type_string, "fname"); + t_field fmtype(g_type_byte, "mtype"); + t_field fseqid(g_type_i32, "rseqid"); + generate_deserialize_field(f_service_, &ffname, "", true); + generate_deserialize_field(f_service_, &fmtype, "", true); + generate_deserialize_field(f_service_, &fseqid, "", true); + } else { + f_service_ << + indent() << "$input->readMessageBegin($fname, $mtype, $rseqid);" << endl; + } + + // HOT: check for method implementation + f_service_ << + indent() << "$methodname = 'process_'.$fname;" << endl << + indent() << "if (!method_exists($this, $methodname)) {" << endl; + if (binary_inline_) { + f_service_ << + indent() << " throw new Exception('Function '.$fname.' not implemented.');" << endl; + } else { + f_service_ << + indent() << " $input->skip(TType::STRUCT);" << endl << + indent() << " $input->readMessageEnd();" << endl << + indent() << " $x = new TApplicationException('Function '.$fname.' not implemented.', TApplicationException::UNKNOWN_METHOD);" << endl << + indent() << " $output->writeMessageBegin($fname, TMessageType::EXCEPTION, $rseqid);" << endl << + indent() << " $x->write($output);" << endl << + indent() << " $output->writeMessageEnd();" << endl << + indent() << " $output->getTransport()->flush();" << endl << + indent() << " return;" << endl; + } + f_service_ << + indent() << "}" << endl << + indent() << "$this->$methodname($rseqid, $input, $output);" << endl << + indent() << "return true;" << endl; + indent_down(); + f_service_ << + indent() << "}" << endl << + endl; + + // Generate the process subfunctions + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + generate_process_function(tservice, *f_iter); + } + + indent_down(); + f_service_ << "}" << endl; +} + +/** + * Generates a process function definition. + * + * @param tfunction The function to write a dispatcher for + */ +void t_php_generator::generate_process_function(t_service* tservice, + t_function* tfunction) { + // Open function + indent(f_service_) << + "protected function process_" << tfunction->get_name() << + "($seqid, $input, $output) {" << endl; + indent_up(); + + string argsname = php_namespace(tservice->get_program()) + service_name_ + "_" + tfunction->get_name() + "_args"; + string resultname = php_namespace(tservice->get_program()) + service_name_ + "_" + tfunction->get_name() + "_result"; + + f_service_ << + indent() << "$args = new " << argsname << "();" << endl << + indent() << "$args->read($input);" << endl; + if (!binary_inline_) { + f_service_ << + indent() << "$input->readMessageEnd();" << endl; + } + + t_struct* xs = tfunction->get_xceptions(); + const std::vector& xceptions = xs->get_members(); + vector::const_iterator x_iter; + + // Declare result for non oneway function + if (!tfunction->is_oneway()) { + f_service_ << + indent() << "$result = new " << resultname << "();" << endl; + } + + // Try block for a function with exceptions + if (xceptions.size() > 0) { + f_service_ << + indent() << "try {" << endl; + indent_up(); + } + + // Generate the function call + t_struct* arg_struct = tfunction->get_arglist(); + const std::vector& fields = arg_struct->get_members(); + vector::const_iterator f_iter; + + f_service_ << indent(); + if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) { + f_service_ << "$result->success = "; + } + f_service_ << + "$this->handler_->" << tfunction->get_name() << "("; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + } else { + f_service_ << ", "; + } + f_service_ << "$args->" << (*f_iter)->get_name(); + } + f_service_ << ");" << endl; + + if (!tfunction->is_oneway() && xceptions.size() > 0) { + indent_down(); + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + f_service_ << + indent() << "} catch (" << php_namespace((*x_iter)->get_type()->get_program()) << (*x_iter)->get_type()->get_name() << " $" << (*x_iter)->get_name() << ") {" << endl; + if (!tfunction->is_oneway()) { + indent_up(); + f_service_ << + indent() << "$result->" << (*x_iter)->get_name() << " = $" << (*x_iter)->get_name() << ";" << endl; + indent_down(); + f_service_ << indent(); + } + } + f_service_ << "}" << endl; + } + + // Shortcut out here for oneway functions + if (tfunction->is_oneway()) { + f_service_ << + indent() << "return;" << endl; + indent_down(); + f_service_ << + indent() << "}" << endl; + return; + } + + // Serialize the request header + if (binary_inline_) { + f_service_ << + indent() << "$buff = pack('N', (0x80010000 | TMessageType::REPLY)); " << endl << + indent() << "$buff .= pack('N', strlen('" << tfunction->get_name() << "'));" << endl << + indent() << "$buff .= '" << tfunction->get_name() << "';" << endl << + indent() << "$buff .= pack('N', $seqid);" << endl << + indent() << "$result->write($buff);" << endl << + indent() << "$output->write($buff);" << endl << + indent() << "$output->flush();" << endl; + } else { + f_service_ << + indent() << "$output->writeMessageBegin('" << tfunction->get_name() << "', TMessageType::REPLY, $seqid);" << endl << + indent() << "$result->write($output);" << endl << + indent() << "$output->getTransport()->flush();" << endl; + } + + // Close function + indent_down(); + f_service_ << + indent() << "}" << endl; +} + +/** + * Generates helper functions for a service. + * + * @param tservice The service to generate a header definition for + */ +void t_php_generator::generate_service_helpers(t_service* tservice) { + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + f_service_ << + "// HELPER FUNCTIONS AND STRUCTURES" << endl << endl; + + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + t_struct* ts = (*f_iter)->get_arglist(); + string name = ts->get_name(); + ts->set_name(service_name_ + "_" + name); + generate_php_struct_definition(f_service_, ts, false); + generate_php_function_helpers(*f_iter); + ts->set_name(name); + } +} + +/** + * Generates a struct and helpers for a function. + * + * @param tfunction The function + */ +void t_php_generator::generate_php_function_helpers(t_function* tfunction) { + if (!tfunction->is_oneway()) { + t_struct result(program_, service_name_ + "_" + tfunction->get_name() + "_result"); + t_field success(tfunction->get_returntype(), "success", 0); + if (!tfunction->get_returntype()->is_void()) { + result.append(&success); + } + + t_struct* xs = tfunction->get_xceptions(); + const vector& fields = xs->get_members(); + vector::const_iterator f_iter; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + result.append(*f_iter); + } + + generate_php_struct_definition(f_service_, &result, false); + } +} + +/** + * Generates a service interface definition. + * + * @param tservice The service to generate a header definition for + */ +void t_php_generator::generate_service_interface(t_service* tservice) { + string extends = ""; + string extends_if = ""; + if (tservice->get_extends() != NULL) { + extends = " extends " + tservice->get_extends()->get_name(); + extends_if = " extends " + tservice->get_extends()->get_name() + "If"; + } + f_service_ << + "interface " << service_name_ << "If" << extends_if << " {" << endl; + indent_up(); + vector functions = tservice->get_functions(); + vector::iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + indent(f_service_) << + "public function " << function_signature(*f_iter) << ";" << endl; + } + indent_down(); + f_service_ << + "}" << endl << endl; +} + +/** + * Generates a REST interface + */ +void t_php_generator::generate_service_rest(t_service* tservice) { + string extends = ""; + string extends_if = ""; + if (tservice->get_extends() != NULL) { + extends = " extends " + tservice->get_extends()->get_name(); + extends_if = " extends " + tservice->get_extends()->get_name() + "Rest"; + } + f_service_ << + "class " << service_name_ << "Rest" << extends_if << " {" << endl; + indent_up(); + + if (extends.empty()) { + f_service_ << + indent() << "protected $impl_;" << endl << + endl; + } + + f_service_ << + indent() << "public function __construct($impl) {" << endl << + indent() << " $this->impl_ = $impl;" << endl << + indent() << "}" << endl << + endl; + + vector functions = tservice->get_functions(); + vector::iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + indent(f_service_) << + "public function " << (*f_iter)->get_name() << "($request) {" << endl; + indent_up(); + const vector& args = (*f_iter)->get_arglist()->get_members(); + vector::const_iterator a_iter; + for (a_iter = args.begin(); a_iter != args.end(); ++a_iter) { + t_type* atype = get_true_type((*a_iter)->get_type()); + string cast = type_to_cast(atype); + string req = "$request['" + (*a_iter)->get_name() + "']"; + if (atype->is_bool()) { + f_service_ << + indent() << "$" << (*a_iter)->get_name() << " = " << cast << "(!empty(" << req << ") && (" << req << " !== 'false'));" << endl; + } else { + f_service_ << + indent() << "$" << (*a_iter)->get_name() << " = isset(" << req << ") ? " << cast << req << " : null;" << endl; + } + if (atype->is_string() && + ((t_base_type*)atype)->is_string_list()) { + f_service_ << + indent() << "$" << (*a_iter)->get_name() << " = explode(',', $" << (*a_iter)->get_name() << ");" << endl; + } else if (atype->is_map() || atype->is_list()) { + f_service_ << + indent() << "$" << (*a_iter)->get_name() << " = json_decode($" << (*a_iter)->get_name() << ", true);" << endl; + } else if (atype->is_set()) { + f_service_ << + indent() << "$" << (*a_iter)->get_name() << " = array_fill_keys(json_decode($" << (*a_iter)->get_name() << ", true), 1);" << endl; + } else if (atype->is_struct() || atype->is_xception()) { + f_service_ << + indent() << "if ($" << (*a_iter)->get_name() << " !== null) {" << endl << + indent() << " $" << (*a_iter)->get_name() << " = new " << php_namespace(atype->get_program()) << atype->get_name() << "(json_decode($" << (*a_iter)->get_name() << ", true));" << endl << + indent() << "}" << endl; + } + } + f_service_ << + indent() << "return $this->impl_->" << (*f_iter)->get_name() << "(" << argument_list((*f_iter)->get_arglist()) << ");" << endl; + indent_down(); + indent(f_service_) << + "}" << endl << + endl; + } + indent_down(); + f_service_ << + "}" << endl << endl; +} + +void t_php_generator::generate_service_client(t_service* tservice) { + if (autoload_) { + // Make output file + ofstream autoload_out; + string f_struct = program_name_+"."+(tservice->get_name())+".client.php"; + string f_struct_name = get_out_dir()+f_struct; + autoload_out.open(f_struct_name.c_str()); + autoload_out << "" << endl; + autoload_out.close(); + + f_service_ << + "$GLOBALS['THRIFT_AUTOLOAD']['" << lowercase(service_name_ + "Client") << "'] = '" << program_name_ << "/" << f_struct << "';" << endl; + + } else { + _generate_service_client(f_service_, tservice); + } +} + +/** + * Generates a service client definition. + * + * @param tservice The service to generate a server for. + */ +void t_php_generator::_generate_service_client(ofstream& out, t_service* tservice) { + string extends = ""; + string extends_client = ""; + if (tservice->get_extends() != NULL) { + extends = tservice->get_extends()->get_name(); + extends_client = " extends " + extends + "Client"; + } + + out << + "class " << service_name_ << "Client" << extends_client << " implements " << service_name_ << "If {" << endl; + indent_up(); + + // Private members + if (extends.empty()) { + out << + indent() << "protected $input_ = null;" << endl << + indent() << "protected $output_ = null;" << endl << + endl; + out << + indent() << "protected $seqid_ = 0;" << endl << + endl; + } + + // Constructor function + out << + indent() << "public function __construct($input, $output=null) {" << endl; + if (!extends.empty()) { + out << + indent() << " parent::__construct($input, $output);" << endl; + } else { + out << + indent() << " $this->input_ = $input;" << endl << + indent() << " $this->output_ = $output ? $output : $input;" << endl; + } + out << + indent() << "}" << endl << endl; + + // Generate client method implementations + vector functions = tservice->get_functions(); + vector::const_iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + t_struct* arg_struct = (*f_iter)->get_arglist(); + const vector& fields = arg_struct->get_members(); + vector::const_iterator fld_iter; + string funname = (*f_iter)->get_name(); + + // Open function + indent(out) << + "public function " << function_signature(*f_iter) << endl; + scope_up(out); + indent(out) << + "$this->send_" << funname << "("; + + bool first = true; + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + if (first) { + first = false; + } else { + out << ", "; + } + out << "$" << (*fld_iter)->get_name(); + } + out << ");" << endl; + + if (!(*f_iter)->is_oneway()) { + out << indent(); + if (!(*f_iter)->get_returntype()->is_void()) { + out << "return "; + } + out << + "$this->recv_" << funname << "();" << endl; + } + scope_down(out); + out << endl; + + indent(out) << + "public function send_" << function_signature(*f_iter) << endl; + scope_up(out); + + std::string argsname = php_namespace(tservice->get_program()) + service_name_ + "_" + (*f_iter)->get_name() + "_args"; + + out << + indent() << "$args = new " << argsname << "();" << endl; + + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + out << + indent() << "$args->" << (*fld_iter)->get_name() << " = $" << (*fld_iter)->get_name() << ";" << endl; + } + + out << + indent() << "$bin_accel = ($this->output_ instanceof TProtocol::$TBINARYPROTOCOLACCELERATED) && function_exists('thrift_protocol_write_binary');" << endl; + + out << + indent() << "if ($bin_accel)" << endl; + scope_up(out); + + out << + indent() << "thrift_protocol_write_binary($this->output_, '" << (*f_iter)->get_name() << "', TMessageType::CALL, $args, $this->seqid_, $this->output_->isStrictWrite());" << endl; + + scope_down(out); + out << + indent() << "else" << endl; + scope_up(out); + + // Serialize the request header + if (binary_inline_) { + out << + indent() << "$buff = pack('N', (0x80010000 | TMessageType::CALL));" << endl << + indent() << "$buff .= pack('N', strlen('" << funname << "'));" << endl << + indent() << "$buff .= '" << funname << "';" << endl << + indent() << "$buff .= pack('N', $this->seqid_);" << endl; + } else { + out << + indent() << "$this->output_->writeMessageBegin('" << (*f_iter)->get_name() << "', TMessageType::CALL, $this->seqid_);" << endl; + } + + // Write to the stream + if (binary_inline_) { + out << + indent() << "$args->write($buff);" << endl << + indent() << "$this->output_->write($buff);" << endl << + indent() << "$this->output_->flush();" << endl; + } else { + out << + indent() << "$args->write($this->output_);" << endl << + indent() << "$this->output_->writeMessageEnd();" << endl << + indent() << "$this->output_->getTransport()->flush();" << endl; + } + + scope_down(out); + + scope_down(out); + + + if (!(*f_iter)->is_oneway()) { + std::string resultname = php_namespace(tservice->get_program()) + service_name_ + "_" + (*f_iter)->get_name() + "_result"; + t_struct noargs(program_); + + t_function recv_function((*f_iter)->get_returntype(), + string("recv_") + (*f_iter)->get_name(), + &noargs); + // Open function + out << + endl << + indent() << "public function " << function_signature(&recv_function) << endl; + scope_up(out); + + out << + indent() << "$bin_accel = ($this->input_ instanceof TProtocol::$TBINARYPROTOCOLACCELERATED)" + << " && function_exists('thrift_protocol_read_binary');" << endl; + + out << + indent() << "if ($bin_accel) $result = thrift_protocol_read_binary($this->input_, '" << resultname << "', $this->input_->isStrictRead());" << endl; + out << + indent() << "else" << endl; + scope_up(out); + + out << + indent() << "$rseqid = 0;" << endl << + indent() << "$fname = null;" << endl << + indent() << "$mtype = 0;" << endl << + endl; + + if (binary_inline_) { + t_field ffname(g_type_string, "fname"); + t_field fseqid(g_type_i32, "rseqid"); + out << + indent() << "$ver = unpack('N', $this->input_->readAll(4));" << endl << + indent() << "$ver = $ver[1];" << endl << + indent() << "$mtype = $ver & 0xff;" << endl << + indent() << "$ver = $ver & 0xffff0000;" << endl << + indent() << "if ($ver != 0x80010000) throw new TProtocolException('Bad version identifier: '.$ver, TProtocolException::BAD_VERSION);" << endl; + generate_deserialize_field(out, &ffname, "", true); + generate_deserialize_field(out, &fseqid, "", true); + } else { + out << + indent() << "$this->input_->readMessageBegin($fname, $mtype, $rseqid);" << endl << + indent() << "if ($mtype == TMessageType::EXCEPTION) {" << endl << + indent() << " $x = new TApplicationException();" << endl << + indent() << " $x->read($this->input_);" << endl << + indent() << " $this->input_->readMessageEnd();" << endl << + indent() << " throw $x;" << endl << + indent() << "}" << endl; + } + + out << + indent() << "$result = new " << resultname << "();" << endl << + indent() << "$result->read($this->input_);" << endl; + + if (!binary_inline_) { + out << + indent() << "$this->input_->readMessageEnd();" << endl; + } + + scope_down(out); + + // Careful, only return result if not a void function + if (!(*f_iter)->get_returntype()->is_void()) { + out << + indent() << "if ($result->success !== null) {" << endl << + indent() << " return $result->success;" << endl << + indent() << "}" << endl; + } + + t_struct* xs = (*f_iter)->get_xceptions(); + const std::vector& xceptions = xs->get_members(); + vector::const_iterator x_iter; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + out << + indent() << "if ($result->" << (*x_iter)->get_name() << " !== null) {" << endl << + indent() << " throw $result->" << (*x_iter)->get_name() << ";" << endl << + indent() << "}" << endl; + } + + // Careful, only return _result if not a void function + if ((*f_iter)->get_returntype()->is_void()) { + indent(out) << + "return;" << endl; + } else { + out << + indent() << "throw new Exception(\"" << (*f_iter)->get_name() << " failed: unknown result\");" << endl; + } + + // Close function + scope_down(out); + out << endl; + + } + } + + indent_down(); + out << + "}" << endl << endl; +} + +/** + * Deserializes a field of any type. + */ +void t_php_generator::generate_deserialize_field(ofstream &out, + t_field* tfield, + string prefix, + bool inclass) { + t_type* type = get_true_type(tfield->get_type()); + + if (type->is_void()) { + throw "CANNOT GENERATE DESERIALIZE CODE FOR void TYPE: " + + prefix + tfield->get_name(); + } + + string name = prefix + tfield->get_name(); + + if (type->is_struct() || type->is_xception()) { + generate_deserialize_struct(out, + (t_struct*)type, + name); + } else { + + if (type->is_container()) { + generate_deserialize_container(out, type, name); + } else if (type->is_base_type() || type->is_enum()) { + + if (binary_inline_) { + std::string itrans = (inclass ? "$this->input_" : "$input"); + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "compiler error: cannot serialize void field in a struct: " + + name; + break; + case t_base_type::TYPE_STRING: + out << + indent() << "$len = unpack('N', " << itrans << "->readAll(4));" << endl << + indent() << "$len = $len[1];" << endl << + indent() << "if ($len > 0x7fffffff) {" << endl << + indent() << " $len = 0 - (($len - 1) ^ 0xffffffff);" << endl << + indent() << "}" << endl << + indent() << "$" << name << " = " << itrans << "->readAll($len);" << endl; + break; + case t_base_type::TYPE_BOOL: + out << + indent() << "$" << name << " = unpack('c', " << itrans << "->readAll(1));" << endl << + indent() << "$" << name << " = (bool)$" << name << "[1];" << endl; + break; + case t_base_type::TYPE_BYTE: + out << + indent() << "$" << name << " = unpack('c', " << itrans << "->readAll(1));" << endl << + indent() << "$" << name << " = $" << name << "[1];" << endl; + break; + case t_base_type::TYPE_I16: + out << + indent() << "$val = unpack('n', " << itrans << "->readAll(2));" << endl << + indent() << "$val = $val[1];" << endl << + indent() << "if ($val > 0x7fff) {" << endl << + indent() << " $val = 0 - (($val - 1) ^ 0xffff);" << endl << + indent() << "}" << endl << + indent() << "$" << name << " = $val;" << endl; + break; + case t_base_type::TYPE_I32: + out << + indent() << "$val = unpack('N', " << itrans << "->readAll(4));" << endl << + indent() << "$val = $val[1];" << endl << + indent() << "if ($val > 0x7fffffff) {" << endl << + indent() << " $val = 0 - (($val - 1) ^ 0xffffffff);" << endl << + indent() << "}" << endl << + indent() << "$" << name << " = $val;" << endl; + break; + case t_base_type::TYPE_I64: + out << + indent() << "$arr = unpack('N2', " << itrans << "->readAll(8));" << endl << + indent() << "if ($arr[1] & 0x80000000) {" << endl << + indent() << " $arr[1] = $arr[1] ^ 0xFFFFFFFF;" << endl << + indent() << " $arr[2] = $arr[2] ^ 0xFFFFFFFF;" << endl << + indent() << " $" << name << " = 0 - $arr[1]*4294967296 - $arr[2] - 1;" << endl << + indent() << "} else {" << endl << + indent() << " $" << name << " = $arr[1]*4294967296 + $arr[2];" << endl << + indent() << "}" << endl; + break; + case t_base_type::TYPE_DOUBLE: + out << + indent() << "$arr = unpack('d', strrev(" << itrans << "->readAll(8)));" << endl << + indent() << "$" << name << " = $arr[1];" << endl; + break; + default: + throw "compiler error: no PHP name for base type " + t_base_type::t_base_name(tbase) + tfield->get_name(); + } + } else if (type->is_enum()) { + out << + indent() << "$val = unpack('N', " << itrans << "->readAll(4));" << endl << + indent() << "$val = $val[1];" << endl << + indent() << "if ($val > 0x7fffffff) {" << endl << + indent() << " $val = 0 - (($val - 1) ^ 0xffffffff);" << endl << + indent() << "}" << endl << + indent() << "$" << name << " = $val;" << endl; + } + } else { + + indent(out) << + "$xfer += $input->"; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "compiler error: cannot serialize void field in a struct: " + + name; + break; + case t_base_type::TYPE_STRING: + out << "readString($" << name << ");"; + break; + case t_base_type::TYPE_BOOL: + out << "readBool($" << name << ");"; + break; + case t_base_type::TYPE_BYTE: + out << "readByte($" << name << ");"; + break; + case t_base_type::TYPE_I16: + out << "readI16($" << name << ");"; + break; + case t_base_type::TYPE_I32: + out << "readI32($" << name << ");"; + break; + case t_base_type::TYPE_I64: + out << "readI64($" << name << ");"; + break; + case t_base_type::TYPE_DOUBLE: + out << "readDouble($" << name << ");"; + break; + default: + throw "compiler error: no PHP name for base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + out << "readI32($" << name << ");"; + } + out << endl; + } + } else { + printf("DO NOT KNOW HOW TO DESERIALIZE FIELD '%s' TYPE '%s'\n", + tfield->get_name().c_str(), type->get_name().c_str()); + } + } +} + +/** + * Generates an unserializer for a variable. This makes two key assumptions, + * first that there is a const char* variable named data that points to the + * buffer for deserialization, and that there is a variable protocol which + * is a reference to a TProtocol serialization object. + */ +void t_php_generator::generate_deserialize_struct(ofstream &out, + t_struct* tstruct, + string prefix) { + out << + indent() << "$" << prefix << " = new " << php_namespace(tstruct->get_program()) << tstruct->get_name() << "();" << endl << + indent() << "$xfer += $" << prefix << "->read($input);" << endl; +} + +void t_php_generator::generate_deserialize_container(ofstream &out, + t_type* ttype, + string prefix) { + string size = tmp("_size"); + string ktype = tmp("_ktype"); + string vtype = tmp("_vtype"); + string etype = tmp("_etype"); + + t_field fsize(g_type_i32, size); + t_field fktype(g_type_byte, ktype); + t_field fvtype(g_type_byte, vtype); + t_field fetype(g_type_byte, etype); + + out << + indent() << "$" << prefix << " = array();" << endl << + indent() << "$" << size << " = 0;" << endl; + + // Declare variables, read header + if (ttype->is_map()) { + out << + indent() << "$" << ktype << " = 0;" << endl << + indent() << "$" << vtype << " = 0;" << endl; + if (binary_inline_) { + generate_deserialize_field(out, &fktype); + generate_deserialize_field(out, &fvtype); + generate_deserialize_field(out, &fsize); + } else { + out << + indent() << "$xfer += $input->readMapBegin(" << + "$" << ktype << ", $" << vtype << ", $" << size << ");" << endl; + } + } else if (ttype->is_set()) { + if (binary_inline_) { + generate_deserialize_field(out, &fetype); + generate_deserialize_field(out, &fsize); + } else { + out << + indent() << "$" << etype << " = 0;" << endl << + indent() << "$xfer += $input->readSetBegin(" << + "$" << etype << ", $" << size << ");" << endl; + } + } else if (ttype->is_list()) { + if (binary_inline_) { + generate_deserialize_field(out, &fetype); + generate_deserialize_field(out, &fsize); + } else { + out << + indent() << "$" << etype << " = 0;" << endl << + indent() << "$xfer += $input->readListBegin(" << + "$" << etype << ", $" << size << ");" << endl; + } + } + + // For loop iterates over elements + string i = tmp("_i"); + indent(out) << + "for ($" << + i << " = 0; $" << i << " < $" << size << "; ++$" << i << ")" << endl; + + scope_up(out); + + if (ttype->is_map()) { + generate_deserialize_map_element(out, (t_map*)ttype, prefix); + } else if (ttype->is_set()) { + generate_deserialize_set_element(out, (t_set*)ttype, prefix); + } else if (ttype->is_list()) { + generate_deserialize_list_element(out, (t_list*)ttype, prefix); + } + + scope_down(out); + + if (!binary_inline_) { + // Read container end + if (ttype->is_map()) { + indent(out) << "$xfer += $input->readMapEnd();" << endl; + } else if (ttype->is_set()) { + indent(out) << "$xfer += $input->readSetEnd();" << endl; + } else if (ttype->is_list()) { + indent(out) << "$xfer += $input->readListEnd();" << endl; + } + } +} + + +/** + * Generates code to deserialize a map + */ +void t_php_generator::generate_deserialize_map_element(ofstream &out, + t_map* tmap, + string prefix) { + string key = tmp("key"); + string val = tmp("val"); + t_field fkey(tmap->get_key_type(), key); + t_field fval(tmap->get_val_type(), val); + + indent(out) << + declare_field(&fkey, true, true) << endl; + indent(out) << + declare_field(&fval, true, true) << endl; + + generate_deserialize_field(out, &fkey); + generate_deserialize_field(out, &fval); + + indent(out) << + "$" << prefix << "[$" << key << "] = $" << val << ";" << endl; +} + +void t_php_generator::generate_deserialize_set_element(ofstream &out, + t_set* tset, + string prefix) { + string elem = tmp("elem"); + t_field felem(tset->get_elem_type(), elem); + + indent(out) << + "$" << elem << " = null;" << endl; + + generate_deserialize_field(out, &felem); + + indent(out) << + "$" << prefix << "[$" << elem << "] = true;" << endl; +} + +void t_php_generator::generate_deserialize_list_element(ofstream &out, + t_list* tlist, + string prefix) { + string elem = tmp("elem"); + t_field felem(tlist->get_elem_type(), elem); + + indent(out) << + "$" << elem << " = null;" << endl; + + generate_deserialize_field(out, &felem); + + indent(out) << + "$" << prefix << " []= $" << elem << ";" << endl; +} + + +/** + * Serializes a field of any type. + * + * @param tfield The field to serialize + * @param prefix Name to prepend to field name + */ +void t_php_generator::generate_serialize_field(ofstream &out, + t_field* tfield, + string prefix) { + t_type* type = get_true_type(tfield->get_type()); + + // Do nothing for void types + if (type->is_void()) { + throw "CANNOT GENERATE SERIALIZE CODE FOR void TYPE: " + + prefix + tfield->get_name(); + } + + if (type->is_struct() || type->is_xception()) { + generate_serialize_struct(out, + (t_struct*)type, + prefix + tfield->get_name()); + } else if (type->is_container()) { + generate_serialize_container(out, + type, + prefix + tfield->get_name()); + } else if (type->is_base_type() || type->is_enum()) { + + string name = prefix + tfield->get_name(); + + if (binary_inline_) { + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw + "compiler error: cannot serialize void field in a struct: " + name; + break; + case t_base_type::TYPE_STRING: + out << + indent() << "$output .= pack('N', strlen($" << name << "));" << endl << + indent() << "$output .= $" << name << ";" << endl; + break; + case t_base_type::TYPE_BOOL: + out << + indent() << "$output .= pack('c', $" << name << " ? 1 : 0);" << endl; + break; + case t_base_type::TYPE_BYTE: + out << + indent() << "$output .= pack('c', $" << name << ");" << endl; + break; + case t_base_type::TYPE_I16: + out << + indent() << "$output .= pack('n', $" << name << ");" << endl; + break; + case t_base_type::TYPE_I32: + out << + indent() << "$output .= pack('N', $" << name << ");" << endl; + break; + case t_base_type::TYPE_I64: + out << + indent() << "$output .= pack('N2', $" << name << " >> 32, $" << name << " & 0xFFFFFFFF);" << endl; + break; + case t_base_type::TYPE_DOUBLE: + out << + indent() << "$output .= strrev(pack('d', $" << name << "));" << endl; + break; + default: + throw "compiler error: no PHP name for base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + out << + indent() << "$output .= pack('N', $" << name << ");" << endl; + } + } else { + + indent(out) << + "$xfer += $output->"; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw + "compiler error: cannot serialize void field in a struct: " + name; + break; + case t_base_type::TYPE_STRING: + out << "writeString($" << name << ");"; + break; + case t_base_type::TYPE_BOOL: + out << "writeBool($" << name << ");"; + break; + case t_base_type::TYPE_BYTE: + out << "writeByte($" << name << ");"; + break; + case t_base_type::TYPE_I16: + out << "writeI16($" << name << ");"; + break; + case t_base_type::TYPE_I32: + out << "writeI32($" << name << ");"; + break; + case t_base_type::TYPE_I64: + out << "writeI64($" << name << ");"; + break; + case t_base_type::TYPE_DOUBLE: + out << "writeDouble($" << name << ");"; + break; + default: + throw "compiler error: no PHP name for base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + out << "writeI32($" << name << ");"; + } + out << endl; + } + } else { + printf("DO NOT KNOW HOW TO SERIALIZE FIELD '%s%s' TYPE '%s'\n", + prefix.c_str(), + tfield->get_name().c_str(), + type->get_name().c_str()); + } +} + +/** + * Serializes all the members of a struct. + * + * @param tstruct The struct to serialize + * @param prefix String prefix to attach to all fields + */ +void t_php_generator::generate_serialize_struct(ofstream &out, + t_struct* tstruct, + string prefix) { + indent(out) << + "$xfer += $" << prefix << "->write($output);" << endl; +} + +/** + * Writes out a container + */ +void t_php_generator::generate_serialize_container(ofstream &out, + t_type* ttype, + string prefix) { + scope_up(out); + + if (ttype->is_map()) { + if (binary_inline_) { + out << + indent() << "$output .= pack('c', " << type_to_enum(((t_map*)ttype)->get_key_type()) << ");" << endl << + indent() << "$output .= pack('c', " << type_to_enum(((t_map*)ttype)->get_val_type()) << ");" << endl << + indent() << "$output .= strrev(pack('l', count($" << prefix << ")));" << endl; + } else { + indent(out) << + "$output->writeMapBegin(" << + type_to_enum(((t_map*)ttype)->get_key_type()) << ", " << + type_to_enum(((t_map*)ttype)->get_val_type()) << ", " << + "count($" << prefix << "));" << endl; + } + } else if (ttype->is_set()) { + if (binary_inline_) { + out << + indent() << "$output .= pack('c', " << type_to_enum(((t_set*)ttype)->get_elem_type()) << ");" << endl << + indent() << "$output .= strrev(pack('l', count($" << prefix << ")));" << endl; + + } else { + indent(out) << + "$output->writeSetBegin(" << + type_to_enum(((t_set*)ttype)->get_elem_type()) << ", " << + "count($" << prefix << "));" << endl; + } + } else if (ttype->is_list()) { + if (binary_inline_) { + out << + indent() << "$output .= pack('c', " << type_to_enum(((t_list*)ttype)->get_elem_type()) << ");" << endl << + indent() << "$output .= strrev(pack('l', count($" << prefix << ")));" << endl; + + } else { + indent(out) << + "$output->writeListBegin(" << + type_to_enum(((t_list*)ttype)->get_elem_type()) << ", " << + "count($" << prefix << "));" << endl; + } + } + + scope_up(out); + + if (ttype->is_map()) { + string kiter = tmp("kiter"); + string viter = tmp("viter"); + indent(out) << + "foreach ($" << prefix << " as " << + "$" << kiter << " => $" << viter << ")" << endl; + scope_up(out); + generate_serialize_map_element(out, (t_map*)ttype, kiter, viter); + scope_down(out); + } else if (ttype->is_set()) { + string iter = tmp("iter"); + indent(out) << + "foreach ($" << prefix << " as $" << iter << " => $true)" << endl; + scope_up(out); + generate_serialize_set_element(out, (t_set*)ttype, iter); + scope_down(out); + } else if (ttype->is_list()) { + string iter = tmp("iter"); + indent(out) << + "foreach ($" << prefix << " as $" << iter << ")" << endl; + scope_up(out); + generate_serialize_list_element(out, (t_list*)ttype, iter); + scope_down(out); + } + + scope_down(out); + + if (!binary_inline_) { + if (ttype->is_map()) { + indent(out) << + "$output->writeMapEnd();" << endl; + } else if (ttype->is_set()) { + indent(out) << + "$output->writeSetEnd();" << endl; + } else if (ttype->is_list()) { + indent(out) << + "$output->writeListEnd();" << endl; + } + } + + scope_down(out); +} + +/** + * Serializes the members of a map. + * + */ +void t_php_generator::generate_serialize_map_element(ofstream &out, + t_map* tmap, + string kiter, + string viter) { + t_field kfield(tmap->get_key_type(), kiter); + generate_serialize_field(out, &kfield, ""); + + t_field vfield(tmap->get_val_type(), viter); + generate_serialize_field(out, &vfield, ""); +} + +/** + * Serializes the members of a set. + */ +void t_php_generator::generate_serialize_set_element(ofstream &out, + t_set* tset, + string iter) { + t_field efield(tset->get_elem_type(), iter); + generate_serialize_field(out, &efield, ""); +} + +/** + * Serializes the members of a list. + */ +void t_php_generator::generate_serialize_list_element(ofstream &out, + t_list* tlist, + string iter) { + t_field efield(tlist->get_elem_type(), iter); + generate_serialize_field(out, &efield, ""); +} + +/** + * Declares a field, which may include initialization as necessary. + * + * @param ttype The type + */ +string t_php_generator::declare_field(t_field* tfield, bool init, bool obj) { + string result = "$" + tfield->get_name(); + if (init) { + t_type* type = get_true_type(tfield->get_type()); + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + break; + case t_base_type::TYPE_STRING: + result += " = ''"; + break; + case t_base_type::TYPE_BOOL: + result += " = false"; + break; + case t_base_type::TYPE_BYTE: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + case t_base_type::TYPE_I64: + result += " = 0"; + break; + case t_base_type::TYPE_DOUBLE: + result += " = 0.0"; + break; + default: + throw "compiler error: no PHP initializer for base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + result += " = 0"; + } else if (type->is_container()) { + result += " = array()"; + } else if (type->is_struct() || type->is_xception()) { + if (obj) { + result += " = new " + php_namespace(type->get_program()) + type->get_name() + "()"; + } else { + result += " = null"; + } + } + } + return result + ";"; +} + +/** + * Renders a function signature of the form 'type name(args)' + * + * @param tfunction Function definition + * @return String of rendered function definition + */ +string t_php_generator::function_signature(t_function* tfunction, + string prefix) { + return + prefix + tfunction->get_name() + + "(" + argument_list(tfunction->get_arglist()) + ")"; +} + +/** + * Renders a field list + */ +string t_php_generator::argument_list(t_struct* tstruct) { + string result = ""; + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + } else { + result += ", "; + } + result += "$" + (*f_iter)->get_name(); + } + return result; +} + +/** + * Gets a typecast string for a particular type. + */ +string t_php_generator::type_to_cast(t_type* type) { + if (type->is_base_type()) { + t_base_type* btype = (t_base_type*)type; + switch (btype->get_base()) { + case t_base_type::TYPE_BOOL: + return "(bool)"; + case t_base_type::TYPE_BYTE: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + case t_base_type::TYPE_I64: + return "(int)"; + case t_base_type::TYPE_DOUBLE: + return "(double)"; + case t_base_type::TYPE_STRING: + return "(string)"; + default: + return ""; + } + } else if (type->is_enum()) { + return "(int)"; + } + return ""; +} + +/** + * Converts the parse type to a C++ enum string for the given type. + */ +string t_php_generator ::type_to_enum(t_type* type) { + type = get_true_type(type); + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "NO T_VOID CONSTRUCT"; + case t_base_type::TYPE_STRING: + return "TType::STRING"; + case t_base_type::TYPE_BOOL: + return "TType::BOOL"; + case t_base_type::TYPE_BYTE: + return "TType::BYTE"; + case t_base_type::TYPE_I16: + return "TType::I16"; + case t_base_type::TYPE_I32: + return "TType::I32"; + case t_base_type::TYPE_I64: + return "TType::I64"; + case t_base_type::TYPE_DOUBLE: + return "TType::DOUBLE"; + } + } else if (type->is_enum()) { + return "TType::I32"; + } else if (type->is_struct() || type->is_xception()) { + return "TType::STRUCT"; + } else if (type->is_map()) { + return "TType::MAP"; + } else if (type->is_set()) { + return "TType::SET"; + } else if (type->is_list()) { + return "TType::LST"; + } + + throw "INVALID TYPE IN type_to_enum: " + type->get_name(); +} + +THRIFT_REGISTER_GENERATOR(php, "PHP", +" inlined: Generate PHP inlined files\n" +" server: Generate PHP server stubs\n" +" autoload: Generate PHP with autoload\n" +" oop: Generate PHP with object oriented subclasses\n" +" rest: Generate PHP REST processors\n" +); diff --git a/compiler/cpp/src/generate/t_py_generator.cc b/compiler/cpp/src/generate/t_py_generator.cc new file mode 100644 index 00000000..343c982b --- /dev/null +++ b/compiler/cpp/src/generate/t_py_generator.cc @@ -0,0 +1,2310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include "t_generator.h" +#include "platform.h" +using namespace std; + + +/** + * Python code generator. + * + */ +class t_py_generator : public t_generator { + public: + t_py_generator( + t_program* program, + const std::map& parsed_options, + const std::string& option_string) + : t_generator(program) + { + std::map::const_iterator iter; + + iter = parsed_options.find("new_style"); + gen_newstyle_ = (iter != parsed_options.end()); + + iter = parsed_options.find("twisted"); + gen_twisted_ = (iter != parsed_options.end()); + + out_dir_base_ = "gen-py"; + } + + /** + * Init and close methods + */ + + void init_generator(); + void close_generator(); + + /** + * Program-level generation functions + */ + + void generate_typedef (t_typedef* ttypedef); + void generate_enum (t_enum* tenum); + void generate_const (t_const* tconst); + void generate_struct (t_struct* tstruct); + void generate_xception (t_struct* txception); + void generate_service (t_service* tservice); + + std::string render_const_value(t_type* type, t_const_value* value); + + /** + * Struct generation code + */ + + void generate_py_struct(t_struct* tstruct, bool is_exception); + void generate_py_struct_definition(std::ofstream& out, t_struct* tstruct, bool is_xception=false, bool is_result=false); + void generate_py_struct_reader(std::ofstream& out, t_struct* tstruct); + void generate_py_struct_writer(std::ofstream& out, t_struct* tstruct); + void generate_py_function_helpers(t_function* tfunction); + + /** + * Service-level generation functions + */ + + void generate_service_helpers (t_service* tservice); + void generate_service_interface (t_service* tservice); + void generate_service_client (t_service* tservice); + void generate_service_remote (t_service* tservice); + void generate_service_server (t_service* tservice); + void generate_process_function (t_service* tservice, t_function* tfunction); + + /** + * Serialization constructs + */ + + void generate_deserialize_field (std::ofstream &out, + t_field* tfield, + std::string prefix="", + bool inclass=false); + + void generate_deserialize_struct (std::ofstream &out, + t_struct* tstruct, + std::string prefix=""); + + void generate_deserialize_container (std::ofstream &out, + t_type* ttype, + std::string prefix=""); + + void generate_deserialize_set_element (std::ofstream &out, + t_set* tset, + std::string prefix=""); + + void generate_deserialize_map_element (std::ofstream &out, + t_map* tmap, + std::string prefix=""); + + void generate_deserialize_list_element (std::ofstream &out, + t_list* tlist, + std::string prefix=""); + + void generate_serialize_field (std::ofstream &out, + t_field* tfield, + std::string prefix=""); + + void generate_serialize_struct (std::ofstream &out, + t_struct* tstruct, + std::string prefix=""); + + void generate_serialize_container (std::ofstream &out, + t_type* ttype, + std::string prefix=""); + + void generate_serialize_map_element (std::ofstream &out, + t_map* tmap, + std::string kiter, + std::string viter); + + void generate_serialize_set_element (std::ofstream &out, + t_set* tmap, + std::string iter); + + void generate_serialize_list_element (std::ofstream &out, + t_list* tlist, + std::string iter); + + void generate_python_docstring (std::ofstream& out, + t_struct* tstruct); + + void generate_python_docstring (std::ofstream& out, + t_function* tfunction); + + void generate_python_docstring (std::ofstream& out, + t_doc* tdoc, + t_struct* tstruct, + const char* subheader); + + void generate_python_docstring (std::ofstream& out, + t_doc* tdoc); + + /** + * Helper rendering functions + */ + + std::string py_autogen_comment(); + std::string py_imports(); + std::string render_includes(); + std::string render_fastbinary_includes(); + std::string declare_argument(t_field* tfield); + std::string render_field_default_value(t_field* tfield); + std::string type_name(t_type* ttype); + std::string function_signature(t_function* tfunction, std::string prefix=""); + std::string function_signature_if(t_function* tfunction, std::string prefix=""); + std::string argument_list(t_struct* tstruct); + std::string type_to_enum(t_type* ttype); + std::string type_to_spec_args(t_type* ttype); + + static std::string get_real_py_module(const t_program* program) { + std::string real_module = program->get_namespace("py"); + if (real_module.empty()) { + return program->get_name(); + } + return real_module; + } + + private: + + /** + * True iff we should generate new-style classes. + */ + bool gen_newstyle_; + + /** + * True iff we should generate Twisted-friendly RPC services. + */ + bool gen_twisted_; + + /** + * File streams + */ + + std::ofstream f_types_; + std::ofstream f_consts_; + std::ofstream f_service_; + + std::string package_dir_; + +}; + + +/** + * Prepares for file generation by opening up the necessary file output + * streams. + * + * @param tprogram The program to generate + */ +void t_py_generator::init_generator() { + // Make output directory + string module = get_real_py_module(program_); + package_dir_ = get_out_dir(); + while (true) { + // TODO: Do better error checking here. + MKDIR(package_dir_.c_str()); + std::ofstream init_py((package_dir_+"/__init__.py").c_str()); + init_py.close(); + if (module.empty()) { + break; + } + string::size_type pos = module.find('.'); + if (pos == string::npos) { + package_dir_ += "/"; + package_dir_ += module; + module.clear(); + } else { + package_dir_ += "/"; + package_dir_ += module.substr(0, pos); + module.erase(0, pos+1); + } + } + + // Make output file + string f_types_name = package_dir_+"/"+"ttypes.py"; + f_types_.open(f_types_name.c_str()); + + string f_consts_name = package_dir_+"/"+"constants.py"; + f_consts_.open(f_consts_name.c_str()); + + string f_init_name = package_dir_+"/__init__.py"; + ofstream f_init; + f_init.open(f_init_name.c_str()); + f_init << + "__all__ = ['ttypes', 'constants'"; + vector services = program_->get_services(); + vector::iterator sv_iter; + for (sv_iter = services.begin(); sv_iter != services.end(); ++sv_iter) { + f_init << ", '" << (*sv_iter)->get_name() << "'"; + } + f_init << "]" << endl; + f_init.close(); + + // Print header + f_types_ << + py_autogen_comment() << endl << + py_imports() << endl << + render_includes() << endl << + render_fastbinary_includes() << + endl << endl; + + f_consts_ << + py_autogen_comment() << endl << + py_imports() << endl << + "from ttypes import *" << endl << + endl; +} + +/** + * Renders all the imports necessary for including another Thrift program + */ +string t_py_generator::render_includes() { + const vector& includes = program_->get_includes(); + string result = ""; + for (size_t i = 0; i < includes.size(); ++i) { + result += "import " + get_real_py_module(includes[i]) + ".ttypes\n"; + } + if (includes.size() > 0) { + result += "\n"; + } + return result; +} + +/** + * Renders all the imports necessary to use the accelerated TBinaryProtocol + */ +string t_py_generator::render_fastbinary_includes() { + return + "from thrift.transport import TTransport\n" + "from thrift.protocol import TBinaryProtocol\n" + "try:\n" + " from thrift.protocol import fastbinary\n" + "except:\n" + " fastbinary = None\n"; +} + +/** + * Autogen'd comment + */ +string t_py_generator::py_autogen_comment() { + return + std::string("#\n") + + "# Autogenerated by Thrift\n" + + "#\n" + + "# DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING\n" + + "#\n"; +} + +/** + * Prints standard thrift imports + */ +string t_py_generator::py_imports() { + return + string("from thrift.Thrift import *"); +} + +/** + * Closes the type files + */ +void t_py_generator::close_generator() { + // Close types file + f_types_.close(); + f_consts_.close(); +} + +/** + * Generates a typedef. This is not done in Python, types are all implicit. + * + * @param ttypedef The type definition + */ +void t_py_generator::generate_typedef(t_typedef* ttypedef) {} + +/** + * Generates code for an enumerated type. Done using a class to scope + * the values. + * + * @param tenum The enumeration + */ +void t_py_generator::generate_enum(t_enum* tenum) { + f_types_ << + "class " << tenum->get_name() << + (gen_newstyle_ ? "(object)" : "") << + ":" << endl; + indent_up(); + generate_python_docstring(f_types_, tenum); + + vector constants = tenum->get_constants(); + vector::iterator c_iter; + int value = -1; + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + if ((*c_iter)->has_value()) { + value = (*c_iter)->get_value(); + } else { + ++value; + } + + f_types_ << + indent() << (*c_iter)->get_name() << " = " << value << endl; + } + + indent_down(); + f_types_ << endl; +} + +/** + * Generate a constant value + */ +void t_py_generator::generate_const(t_const* tconst) { + t_type* type = tconst->get_type(); + string name = tconst->get_name(); + t_const_value* value = tconst->get_value(); + + indent(f_consts_) << name << " = " << render_const_value(type, value); + f_consts_ << endl << endl; +} + +/** + * Prints the value of a constant with the given type. Note that type checking + * is NOT performed in this function as it is always run beforehand using the + * validate_types method in main.cc + */ +string t_py_generator::render_const_value(t_type* type, t_const_value* value) { + type = get_true_type(type); + std::ostringstream out; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_STRING: + out << '"' << get_escaped_string(value) << '"'; + break; + case t_base_type::TYPE_BOOL: + out << (value->get_integer() > 0 ? "True" : "False"); + break; + case t_base_type::TYPE_BYTE: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + case t_base_type::TYPE_I64: + out << value->get_integer(); + break; + case t_base_type::TYPE_DOUBLE: + if (value->get_type() == t_const_value::CV_INTEGER) { + out << value->get_integer(); + } else { + out << value->get_double(); + } + break; + default: + throw "compiler error: no const of base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + indent(out) << value->get_integer(); + } else if (type->is_struct() || type->is_xception()) { + out << type->get_name() << "(**{" << endl; + indent_up(); + const vector& fields = ((t_struct*)type)->get_members(); + vector::const_iterator f_iter; + const map& val = value->get_map(); + map::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + t_type* field_type = NULL; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if ((*f_iter)->get_name() == v_iter->first->get_string()) { + field_type = (*f_iter)->get_type(); + } + } + if (field_type == NULL) { + throw "type error: " + type->get_name() + " has no field " + v_iter->first->get_string(); + } + out << indent(); + out << render_const_value(g_type_string, v_iter->first); + out << " : "; + out << render_const_value(field_type, v_iter->second); + out << "," << endl; + } + indent_down(); + indent(out) << "})"; + } else if (type->is_map()) { + t_type* ktype = ((t_map*)type)->get_key_type(); + t_type* vtype = ((t_map*)type)->get_val_type(); + out << "{" << endl; + indent_up(); + const map& val = value->get_map(); + map::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + out << indent(); + out << render_const_value(ktype, v_iter->first); + out << " : "; + out << render_const_value(vtype, v_iter->second); + out << "," << endl; + } + indent_down(); + indent(out) << "}"; + } else if (type->is_list() || type->is_set()) { + t_type* etype; + if (type->is_list()) { + etype = ((t_list*)type)->get_elem_type(); + } else { + etype = ((t_set*)type)->get_elem_type(); + } + if (type->is_set()) { + out << "set("; + } + out << "[" << endl; + indent_up(); + const vector& val = value->get_list(); + vector::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + out << indent(); + out << render_const_value(etype, *v_iter); + out << "," << endl; + } + indent_down(); + indent(out) << "]"; + if (type->is_set()) { + out << ")"; + } + } else { + throw "CANNOT GENERATE CONSTANT FOR TYPE: " + type->get_name(); + } + + return out.str(); +} + +/** + * Generates a python struct + */ +void t_py_generator::generate_struct(t_struct* tstruct) { + generate_py_struct(tstruct, false); +} + +/** + * Generates a struct definition for a thrift exception. Basically the same + * as a struct but extends the Exception class. + * + * @param txception The struct definition + */ +void t_py_generator::generate_xception(t_struct* txception) { + generate_py_struct(txception, true); +} + +/** + * Generates a python struct + */ +void t_py_generator::generate_py_struct(t_struct* tstruct, + bool is_exception) { + generate_py_struct_definition(f_types_, tstruct, is_exception); +} + +/** + * Generates a struct definition for a thrift data type. + * + * @param tstruct The struct definition + */ +void t_py_generator::generate_py_struct_definition(ofstream& out, + t_struct* tstruct, + bool is_exception, + bool is_result) { + + const vector& members = tstruct->get_members(); + const vector& sorted_members = tstruct->get_sorted_members(); + vector::const_iterator m_iter; + + out << + "class " << tstruct->get_name(); + if (is_exception) { + out << "(Exception)"; + } else if (gen_newstyle_) { + out << "(object)"; + } + out << + ":" << endl; + indent_up(); + generate_python_docstring(out, tstruct); + + out << endl; + + /* + Here we generate the structure specification for the fastbinary codec. + These specifications have the following structure: + thrift_spec -> tuple of item_spec + item_spec -> None | (tag, type_enum, name, spec_args, default) + tag -> integer + type_enum -> TType.I32 | TType.STRING | TType.STRUCT | ... + name -> string_literal + default -> None # Handled by __init__ + spec_args -> None # For simple types + | (type_enum, spec_args) # Value type for list/set + | (type_enum, spec_args, type_enum, spec_args) + # Key and value for map + | (class_name, spec_args_ptr) # For struct/exception + class_name -> identifier # Basically a pointer to the class + spec_args_ptr -> expression # just class_name.spec_args + + TODO(dreiss): Consider making this work for structs with negative tags. + */ + + // TODO(dreiss): Look into generating an empty tuple instead of None + // for structures with no members. + // TODO(dreiss): Test encoding of structs where some inner structs + // don't have thrift_spec. + if (sorted_members.empty() || (sorted_members[0]->get_key() >= 0)) { + indent(out) << "thrift_spec = (" << endl; + indent_up(); + + int sorted_keys_pos = 0; + for (m_iter = sorted_members.begin(); m_iter != sorted_members.end(); ++m_iter) { + + for (; sorted_keys_pos != (*m_iter)->get_key(); sorted_keys_pos++) { + indent(out) << "None, # " << sorted_keys_pos << endl; + } + + indent(out) << "(" << (*m_iter)->get_key() << ", " + << type_to_enum((*m_iter)->get_type()) << ", " + << "'" << (*m_iter)->get_name() << "'" << ", " + << type_to_spec_args((*m_iter)->get_type()) << ", " + << render_field_default_value(*m_iter) << ", " + << ")," + << " # " << sorted_keys_pos + << endl; + + sorted_keys_pos ++; + } + + indent_down(); + indent(out) << ")" << endl << endl; + } else { + indent(out) << "thrift_spec = None" << endl; + } + + + if (members.size() > 0) { + out << + indent() << "def __init__(self,"; + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + // This fills in default values, as opposed to nulls + out << " " << declare_argument(*m_iter) << ","; + } + + out << "):" << endl; + + indent_up(); + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + // Initialize fields + t_type* type = (*m_iter)->get_type(); + if (!type->is_base_type() && !type->is_enum() && (*m_iter)->get_value() != NULL) { + indent(out) << + "if " << (*m_iter)->get_name() << " is " << "self.thrift_spec[" << + (*m_iter)->get_key() << "][4]:" << endl; + indent(out) << " " << (*m_iter)->get_name() << " = " << + render_field_default_value(*m_iter) << endl; + } + indent(out) << + "self." << (*m_iter)->get_name() << " = " << (*m_iter)->get_name() << endl; + } + + indent_down(); + + out << endl; + } + + generate_py_struct_reader(out, tstruct); + generate_py_struct_writer(out, tstruct); + + // For exceptions only, generate a __str__ method. This is + // because when raised exceptions are printed to the console, __repr__ + // isn't used. See python bug #5882 + if (is_exception) { + out << + indent() << "def __str__(self):" << endl << + indent() << " return repr(self)" << endl << + endl; + } + + // Printing utilities so that on the command line thrift + // structs look pretty like dictionaries + out << + indent() << "def __repr__(self):" << endl << + indent() << " L = ['%s=%r' % (key, value)" << endl << + indent() << " for key, value in self.__dict__.iteritems()]" << endl << + indent() << " return '%s(%s)' % (self.__class__.__name__, ', '.join(L))" << endl << + endl; + + // Equality and inequality methods that compare by value + out << + indent() << "def __eq__(self, other):" << endl; + indent_up(); + out << + indent() << "return isinstance(other, self.__class__) and " + "self.__dict__ == other.__dict__" << endl; + indent_down(); + out << endl; + + out << + indent() << "def __ne__(self, other):" << endl; + indent_up(); + out << + indent() << "return not (self == other)" << endl; + indent_down(); + out << endl; + + indent_down(); +} + +/** + * Generates the read method for a struct + */ +void t_py_generator::generate_py_struct_reader(ofstream& out, + t_struct* tstruct) { + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + + indent(out) << + "def read(self, iprot):" << endl; + indent_up(); + + indent(out) << + "if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated " + "and isinstance(iprot.trans, TTransport.CReadableTransport) " + "and self.thrift_spec is not None " + "and fastbinary is not None:" << endl; + indent_up(); + + indent(out) << + "fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec))" << endl; + indent(out) << + "return" << endl; + indent_down(); + + indent(out) << + "iprot.readStructBegin()" << endl; + + // Loop over reading in fields + indent(out) << + "while True:" << endl; + indent_up(); + + // Read beginning field marker + indent(out) << + "(fname, ftype, fid) = iprot.readFieldBegin()" << endl; + + // Check for field STOP marker and break + indent(out) << + "if ftype == TType.STOP:" << endl; + indent_up(); + indent(out) << + "break" << endl; + indent_down(); + + // Switch statement on the field we are reading + bool first = true; + + // Generate deserialization code for known cases + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + out << + indent() << "if "; + } else { + out << + indent() << "elif "; + } + out << "fid == " << (*f_iter)->get_key() << ":" << endl; + indent_up(); + indent(out) << "if ftype == " << type_to_enum((*f_iter)->get_type()) << ":" << endl; + indent_up(); + generate_deserialize_field(out, *f_iter, "self."); + indent_down(); + out << + indent() << "else:" << endl << + indent() << " iprot.skip(ftype)" << endl; + indent_down(); + } + + // In the default case we skip the field + out << + indent() << "else:" << endl << + indent() << " iprot.skip(ftype)" << endl; + + // Read field end marker + indent(out) << + "iprot.readFieldEnd()" << endl; + + indent_down(); + + indent(out) << + "iprot.readStructEnd()" << endl; + + indent_down(); + out << endl; +} + +void t_py_generator::generate_py_struct_writer(ofstream& out, + t_struct* tstruct) { + string name = tstruct->get_name(); + const vector& fields = tstruct->get_sorted_members(); + vector::const_iterator f_iter; + + indent(out) << + "def write(self, oprot):" << endl; + indent_up(); + + indent(out) << + "if oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated " + "and self.thrift_spec is not None " + "and fastbinary is not None:" << endl; + indent_up(); + + indent(out) << + "oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec)))" << endl; + indent(out) << + "return" << endl; + indent_down(); + + indent(out) << + "oprot.writeStructBegin('" << name << "')" << endl; + + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + // Write field header + indent(out) << + "if self." << (*f_iter)->get_name() << " != None:" << endl; + indent_up(); + indent(out) << + "oprot.writeFieldBegin(" << + "'" << (*f_iter)->get_name() << "', " << + type_to_enum((*f_iter)->get_type()) << ", " << + (*f_iter)->get_key() << ")" << endl; + + // Write field contents + generate_serialize_field(out, *f_iter, "self."); + + // Write field closer + indent(out) << + "oprot.writeFieldEnd()" << endl; + + indent_down(); + } + + // Write the struct map + out << + indent() << "oprot.writeFieldStop()" << endl << + indent() << "oprot.writeStructEnd()" << endl; + + indent_down(); + out << + endl; +} + +/** + * Generates a thrift service. + * + * @param tservice The service definition + */ +void t_py_generator::generate_service(t_service* tservice) { + string f_service_name = package_dir_+"/"+service_name_+".py"; + f_service_.open(f_service_name.c_str()); + + f_service_ << + py_autogen_comment() << endl << + py_imports() << endl; + + if (tservice->get_extends() != NULL) { + f_service_ << + "import " << get_real_py_module(tservice->get_extends()->get_program()) << + "." << tservice->get_extends()->get_name() << endl; + } + + f_service_ << + "from ttypes import *" << endl << + "from thrift.Thrift import TProcessor" << endl << + render_fastbinary_includes() << endl; + + if (gen_twisted_) { + f_service_ << + "from zope.interface import Interface, implements" << endl << + "from twisted.internet import defer" << endl << + "from thrift.transport import TTwisted" << endl; + } + + f_service_ << endl; + + // Generate the three main parts of the service (well, two for now in PHP) + generate_service_interface(tservice); + generate_service_client(tservice); + generate_service_server(tservice); + generate_service_helpers(tservice); + generate_service_remote(tservice); + + // Close service file + f_service_ << endl; + f_service_.close(); +} + +/** + * Generates helper functions for a service. + * + * @param tservice The service to generate a header definition for + */ +void t_py_generator::generate_service_helpers(t_service* tservice) { + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + f_service_ << + "# HELPER FUNCTIONS AND STRUCTURES" << endl << endl; + + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + t_struct* ts = (*f_iter)->get_arglist(); + generate_py_struct_definition(f_service_, ts, false); + generate_py_function_helpers(*f_iter); + } +} + +/** + * Generates a struct and helpers for a function. + * + * @param tfunction The function + */ +void t_py_generator::generate_py_function_helpers(t_function* tfunction) { + if (!tfunction->is_oneway()) { + t_struct result(program_, tfunction->get_name() + "_result"); + t_field success(tfunction->get_returntype(), "success", 0); + if (!tfunction->get_returntype()->is_void()) { + result.append(&success); + } + + t_struct* xs = tfunction->get_xceptions(); + const vector& fields = xs->get_members(); + vector::const_iterator f_iter; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + result.append(*f_iter); + } + generate_py_struct_definition(f_service_, &result, false, true); + } +} + +/** + * Generates a service interface definition. + * + * @param tservice The service to generate a header definition for + */ +void t_py_generator::generate_service_interface(t_service* tservice) { + string extends = ""; + string extends_if = ""; + if (tservice->get_extends() != NULL) { + extends = type_name(tservice->get_extends()); + extends_if = "(" + extends + ".Iface)"; + } else { + if (gen_twisted_) { + extends_if = "(Interface)"; + } + } + + f_service_ << + "class Iface" << extends_if << ":" << endl; + indent_up(); + generate_python_docstring(f_service_, tservice); + vector functions = tservice->get_functions(); + if (functions.empty()) { + f_service_ << + indent() << "pass" << endl; + } else { + vector::iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + f_service_ << + indent() << "def " << function_signature_if(*f_iter) << ":" << endl; + indent_up(); + generate_python_docstring(f_service_, (*f_iter)); + f_service_ << + indent() << "pass" << endl << endl; + indent_down(); + } + } + + indent_down(); + f_service_ << + endl; +} + +/** + * Generates a service client definition. + * + * @param tservice The service to generate a server for. + */ +void t_py_generator::generate_service_client(t_service* tservice) { + string extends = ""; + string extends_client = ""; + if (tservice->get_extends() != NULL) { + extends = type_name(tservice->get_extends()); + if (gen_twisted_) { + extends_client = "(" + extends + ".Client)"; + } else { + extends_client = extends + ".Client, "; + } + } else { + if (gen_twisted_ && gen_newstyle_) { + extends_client = "(object)"; + } + } + + if (gen_twisted_) { + f_service_ << + "class Client" << extends_client << ":" << endl << + " implements(Iface)" << endl << endl; + } else { + f_service_ << + "class Client(" << extends_client << "Iface):" << endl; + } + indent_up(); + generate_python_docstring(f_service_, tservice); + + // Constructor function + if (gen_twisted_) { + f_service_ << + indent() << "def __init__(self, transport, oprot_factory):" << endl; + } else { + f_service_ << + indent() << "def __init__(self, iprot, oprot=None):" << endl; + } + if (extends.empty()) { + if (gen_twisted_) { + f_service_ << + indent() << " self._transport = transport" << endl << + indent() << " self._oprot_factory = oprot_factory" << endl << + indent() << " self._seqid = 0" << endl << + indent() << " self._reqs = {}" << endl << + endl; + } else { + f_service_ << + indent() << " self._iprot = self._oprot = iprot" << endl << + indent() << " if oprot != None:" << endl << + indent() << " self._oprot = oprot" << endl << + indent() << " self._seqid = 0" << endl << + endl; + } + } else { + if (gen_twisted_) { + f_service_ << + indent() << " " << extends << ".Client.__init__(self, transport, oprot_factory)" << endl << + endl; + } else { + f_service_ << + indent() << " " << extends << ".Client.__init__(self, iprot, oprot)" << endl << + endl; + } + } + + // Generate client method implementations + vector functions = tservice->get_functions(); + vector::const_iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + t_struct* arg_struct = (*f_iter)->get_arglist(); + const vector& fields = arg_struct->get_members(); + vector::const_iterator fld_iter; + string funname = (*f_iter)->get_name(); + + // Open function + indent(f_service_) << + "def " << function_signature(*f_iter) << ":" << endl; + indent_up(); + generate_python_docstring(f_service_, (*f_iter)); + if (gen_twisted_) { + indent(f_service_) << "self._seqid += 1" << endl; + if (!(*f_iter)->is_oneway()) { + indent(f_service_) << + "d = self._reqs[self._seqid] = defer.Deferred()" << endl; + } + } + + indent(f_service_) << + "self.send_" << funname << "("; + + bool first = true; + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + if (first) { + first = false; + } else { + f_service_ << ", "; + } + f_service_ << (*fld_iter)->get_name(); + } + f_service_ << ")" << endl; + + if (!(*f_iter)->is_oneway()) { + f_service_ << indent(); + if (gen_twisted_) { + f_service_ << "return d" << endl; + } else { + if (!(*f_iter)->get_returntype()->is_void()) { + f_service_ << "return "; + } + f_service_ << + "self.recv_" << funname << "()" << endl; + } + } else { + if (gen_twisted_) { + f_service_ << + indent() << "return defer.succeed(None)" << endl; + } + } + indent_down(); + f_service_ << endl; + + indent(f_service_) << + "def send_" << function_signature(*f_iter) << ":" << endl; + + indent_up(); + + std::string argsname = (*f_iter)->get_name() + "_args"; + + // Serialize the request header + if (gen_twisted_) { + f_service_ << + indent() << "oprot = self._oprot_factory.getProtocol(self._transport)" << endl << + indent() << + "oprot.writeMessageBegin('" << (*f_iter)->get_name() << "', TMessageType.CALL, self._seqid)" + << endl; + } else { + f_service_ << + indent() << "self._oprot.writeMessageBegin('" << (*f_iter)->get_name() << "', TMessageType.CALL, self._seqid)" << endl; + } + + f_service_ << + indent() << "args = " << argsname << "()" << endl; + + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + f_service_ << + indent() << "args." << (*fld_iter)->get_name() << " = " << (*fld_iter)->get_name() << endl; + } + + // Write to the stream + if (gen_twisted_) { + f_service_ << + indent() << "args.write(oprot)" << endl << + indent() << "oprot.writeMessageEnd()" << endl << + indent() << "oprot.trans.flush()" << endl; + } else { + f_service_ << + indent() << "args.write(self._oprot)" << endl << + indent() << "self._oprot.writeMessageEnd()" << endl << + indent() << "self._oprot.trans.flush()" << endl; + } + + indent_down(); + + if (!(*f_iter)->is_oneway()) { + std::string resultname = (*f_iter)->get_name() + "_result"; + // Open function + f_service_ << + endl; + if (gen_twisted_) { + f_service_ << + indent() << "def recv_" << (*f_iter)->get_name() << + "(self, iprot, mtype, rseqid):" << endl; + } else { + t_struct noargs(program_); + t_function recv_function((*f_iter)->get_returntype(), + string("recv_") + (*f_iter)->get_name(), + &noargs); + f_service_ << + indent() << "def " << function_signature(&recv_function) << ":" << endl; + } + indent_up(); + + // TODO(mcslee): Validate message reply here, seq ids etc. + + if (gen_twisted_) { + f_service_ << + indent() << "d = self._reqs.pop(rseqid)" << endl; + } else { + f_service_ << + indent() << "(fname, mtype, rseqid) = self._iprot.readMessageBegin()" << endl; + } + + f_service_ << + indent() << "if mtype == TMessageType.EXCEPTION:" << endl << + indent() << " x = TApplicationException()" << endl; + + if (gen_twisted_) { + f_service_ << + indent() << " x.read(iprot)" << endl << + indent() << " iprot.readMessageEnd()" << endl << + indent() << " return d.errback(x)" << endl << + indent() << "result = " << resultname << "()" << endl << + indent() << "result.read(iprot)" << endl << + indent() << "iprot.readMessageEnd()" << endl; + } else { + f_service_ << + indent() << " x.read(self._iprot)" << endl << + indent() << " self._iprot.readMessageEnd()" << endl << + indent() << " raise x" << endl << + indent() << "result = " << resultname << "()" << endl << + indent() << "result.read(self._iprot)" << endl << + indent() << "self._iprot.readMessageEnd()" << endl; + } + + // Careful, only return _result if not a void function + if (!(*f_iter)->get_returntype()->is_void()) { + f_service_ << + indent() << "if result.success != None:" << endl; + if (gen_twisted_) { + f_service_ << + indent() << " return d.callback(result.success)" << endl; + } else { + f_service_ << + indent() << " return result.success" << endl; + } + } + + t_struct* xs = (*f_iter)->get_xceptions(); + const std::vector& xceptions = xs->get_members(); + vector::const_iterator x_iter; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + f_service_ << + indent() << "if result." << (*x_iter)->get_name() << " != None:" << endl; + if (gen_twisted_) { + f_service_ << + indent() << " return d.errback(result." << (*x_iter)->get_name() << ")" << endl; + + } else { + f_service_ << + indent() << " raise result." << (*x_iter)->get_name() << "" << endl; + } + } + + // Careful, only return _result if not a void function + if ((*f_iter)->get_returntype()->is_void()) { + if (gen_twisted_) { + indent(f_service_) << + "return d.callback(None)" << endl; + } else { + indent(f_service_) << + "return" << endl; + } + } else { + if (gen_twisted_) { + f_service_ << + indent() << "return d.errback(TApplicationException(TApplicationException.MISSING_RESULT, \"" << (*f_iter)->get_name() << " failed: unknown result\"))" << endl; + } else { + f_service_ << + indent() << "raise TApplicationException(TApplicationException.MISSING_RESULT, \"" << (*f_iter)->get_name() << " failed: unknown result\");" << endl; + } + } + + // Close function + indent_down(); + f_service_ << endl; + } + } + + indent_down(); + f_service_ << + endl; +} + +/** + * Generates a command line tool for making remote requests + * + * @param tservice The service to generate a remote for. + */ +void t_py_generator::generate_service_remote(t_service* tservice) { + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + string f_remote_name = package_dir_+"/"+service_name_+"-remote"; + ofstream f_remote; + f_remote.open(f_remote_name.c_str()); + + f_remote << + "#!/usr/bin/env python" << endl << + py_autogen_comment() << endl << + "import sys" << endl << + "import pprint" << endl << + "from urlparse import urlparse" << endl << + "from thrift.transport import TTransport" << endl << + "from thrift.transport import TSocket" << endl << + "from thrift.transport import THttpClient" << endl << + "from thrift.protocol import TBinaryProtocol" << endl << + endl; + + f_remote << + "import " << service_name_ << endl << + "from ttypes import *" << endl << + endl; + + f_remote << + "if len(sys.argv) <= 1 or sys.argv[1] == '--help':" << endl << + " print ''" << endl << + " print 'Usage: ' + sys.argv[0] + ' [-h host:port] [-u url] [-f[ramed]] function [arg1 [arg2...]]'" << endl << + " print ''" << endl << + " print 'Functions:'" << endl; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + f_remote << + " print ' " << (*f_iter)->get_returntype()->get_name() << " " << (*f_iter)->get_name() << "("; + t_struct* arg_struct = (*f_iter)->get_arglist(); + const std::vector& args = arg_struct->get_members(); + vector::const_iterator a_iter; + int num_args = args.size(); + bool first = true; + for (int i = 0; i < num_args; ++i) { + if (first) { + first = false; + } else { + f_remote << ", "; + } + f_remote << + args[i]->get_type()->get_name() << " " << args[i]->get_name(); + } + f_remote << ")'" << endl; + } + f_remote << + " print ''" << endl << + " sys.exit(0)" << endl << + endl; + + f_remote << + "pp = pprint.PrettyPrinter(indent = 2)" << endl << + "host = 'localhost'" << endl << + "port = 9090" << endl << + "uri = ''" << endl << + "framed = False" << endl << + "http = False" << endl << + "argi = 1" << endl << + endl << + "if sys.argv[argi] == '-h':" << endl << + " parts = sys.argv[argi+1].split(':') " << endl << + " host = parts[0]" << endl << + " port = int(parts[1])" << endl << + " argi += 2" << endl << + endl << + "if sys.argv[argi] == '-u':" << endl << + " url = urlparse(sys.argv[argi+1])" << endl << + " parts = url[1].split(':') " << endl << + " host = parts[0]" << endl << + " if len(parts) > 1:" << endl << + " port = int(parts[1])" << endl << + " else:" << endl << + " port = 80" << endl << + " uri = url[2]" << endl << + " http = True" << endl << + " argi += 2" << endl << + endl << + "if sys.argv[argi] == '-f' or sys.argv[argi] == '-framed':" << endl << + " framed = True" << endl << + " argi += 1" << endl << + endl << + "cmd = sys.argv[argi]" << endl << + "args = sys.argv[argi+1:]" << endl << + endl << + "if http:" << endl << + " transport = THttpClient.THttpClient(host, port, uri)" << endl << + "else:" << endl << + " socket = TSocket.TSocket(host, port)" << endl << + " if framed:" << endl << + " transport = TTransport.TFramedTransport(socket)" << endl << + " else:" << endl << + " transport = TTransport.TBufferedTransport(socket)" << endl << + "protocol = TBinaryProtocol.TBinaryProtocol(transport)" << endl << + "client = " << service_name_ << ".Client(protocol)" << endl << + "transport.open()" << endl << + endl; + + // Generate the dispatch methods + bool first = true; + + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + if (first) { + first = false; + } else { + f_remote << "el"; + } + + t_struct* arg_struct = (*f_iter)->get_arglist(); + const std::vector& args = arg_struct->get_members(); + vector::const_iterator a_iter; + int num_args = args.size(); + + f_remote << + "if cmd == '" << (*f_iter)->get_name() << "':" << endl << + " if len(args) != " << num_args << ":" << endl << + " print '" << (*f_iter)->get_name() << " requires " << num_args << " args'" << endl << + " sys.exit(1)" << endl << + " pp.pprint(client." << (*f_iter)->get_name() << "("; + for (int i = 0; i < num_args; ++i) { + if (args[i]->get_type()->is_string()) { + f_remote << "args[" << i << "],"; + } else { + f_remote << "eval(args[" << i << "]),"; + } + } + f_remote << "))" << endl; + + f_remote << endl; + } + + f_remote << "transport.close()" << endl; + + // Close service file + f_remote.close(); + + // Make file executable, love that bitwise OR action + chmod(f_remote_name.c_str(), + S_IRUSR + | S_IWUSR + | S_IXUSR +#ifndef MINGW + | S_IRGRP + | S_IXGRP + | S_IROTH + | S_IXOTH +#endif + ); +} + +/** + * Generates a service server definition. + * + * @param tservice The service to generate a server for. + */ +void t_py_generator::generate_service_server(t_service* tservice) { + // Generate the dispatch methods + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + string extends = ""; + string extends_processor = ""; + if (tservice->get_extends() != NULL) { + extends = type_name(tservice->get_extends()); + extends_processor = extends + ".Processor, "; + } + + // Generate the header portion + if (gen_twisted_) { + f_service_ << + "class Processor(" << extends_processor << "TProcessor):" << endl << + " implements(Iface)" << endl << endl; + } else { + f_service_ << + "class Processor(" << extends_processor << "Iface, TProcessor):" << endl; + } + + indent_up(); + + indent(f_service_) << + "def __init__(self, handler):" << endl; + indent_up(); + if (extends.empty()) { + if (gen_twisted_) { + f_service_ << + indent() << "self._handler = Iface(handler)" << endl; + } else { + f_service_ << + indent() << "self._handler = handler" << endl; + } + + f_service_ << + indent() << "self._processMap = {}" << endl; + } else { + if (gen_twisted_) { + f_service_ << + indent() << extends << ".Processor.__init__(self, Iface(handler))" << endl; + } else { + f_service_ << + indent() << extends << ".Processor.__init__(self, handler)" << endl; + } + } + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + f_service_ << + indent() << "self._processMap[\"" << (*f_iter)->get_name() << "\"] = Processor.process_" << (*f_iter)->get_name() << endl; + } + indent_down(); + f_service_ << endl; + + // Generate the server implementation + indent(f_service_) << + "def process(self, iprot, oprot):" << endl; + indent_up(); + + f_service_ << + indent() << "(name, type, seqid) = iprot.readMessageBegin()" << endl; + + // TODO(mcslee): validate message + + // HOT: dictionary function lookup + f_service_ << + indent() << "if name not in self._processMap:" << endl << + indent() << " iprot.skip(TType.STRUCT)" << endl << + indent() << " iprot.readMessageEnd()" << endl << + indent() << " x = TApplicationException(TApplicationException.UNKNOWN_METHOD, 'Unknown function %s' % (name))" << endl << + indent() << " oprot.writeMessageBegin(name, TMessageType.EXCEPTION, seqid)" << endl << + indent() << " x.write(oprot)" << endl << + indent() << " oprot.writeMessageEnd()" << endl << + indent() << " oprot.trans.flush()" << endl; + + if (gen_twisted_) { + f_service_ << + indent() << " return defer.succeed(None)" << endl; + } else { + f_service_ << + indent() << " return" << endl; + } + + f_service_ << + indent() << "else:" << endl; + + if (gen_twisted_) { + f_service_ << + indent() << " return self._processMap[name](self, seqid, iprot, oprot)" << endl; + } else { + f_service_ << + indent() << " self._processMap[name](self, seqid, iprot, oprot)" << endl; + + // Read end of args field, the T_STOP, and the struct close + f_service_ << + indent() << "return True" << endl; + } + + indent_down(); + f_service_ << endl; + + // Generate the process subfunctions + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + generate_process_function(tservice, *f_iter); + } + + indent_down(); + f_service_ << endl; +} + +/** + * Generates a process function definition. + * + * @param tfunction The function to write a dispatcher for + */ +void t_py_generator::generate_process_function(t_service* tservice, + t_function* tfunction) { + // Open function + indent(f_service_) << + "def process_" << tfunction->get_name() << + "(self, seqid, iprot, oprot):" << endl; + indent_up(); + + string argsname = tfunction->get_name() + "_args"; + string resultname = tfunction->get_name() + "_result"; + + f_service_ << + indent() << "args = " << argsname << "()" << endl << + indent() << "args.read(iprot)" << endl << + indent() << "iprot.readMessageEnd()" << endl; + + t_struct* xs = tfunction->get_xceptions(); + const std::vector& xceptions = xs->get_members(); + vector::const_iterator x_iter; + + // Declare result for non oneway function + if (!tfunction->is_oneway()) { + f_service_ << + indent() << "result = " << resultname << "()" << endl; + } + + if (gen_twisted_) { + // Generate the function call + t_struct* arg_struct = tfunction->get_arglist(); + const std::vector& fields = arg_struct->get_members(); + vector::const_iterator f_iter; + + f_service_ << + indent() << "d = defer.maybeDeferred(self._handler." << + tfunction->get_name() << ", "; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + } else { + f_service_ << ", "; + } + f_service_ << "args." << (*f_iter)->get_name(); + } + f_service_ << ")" << endl; + + // Shortcut out here for oneway functions + if (tfunction->is_oneway()) { + f_service_ << + indent() << "return d" << endl; + indent_down(); + f_service_ << endl; + return; + } + + f_service_ << + indent() << + "d.addCallback(self.write_results_success_" << + tfunction->get_name() << ", result, seqid, oprot)" << endl; + + if (xceptions.size() > 0) { + f_service_ << + indent() << + "d.addErrback(self.write_results_exception_" << + tfunction->get_name() << ", result, seqid, oprot)" << endl; + } + + f_service_ << + indent() << "return d" << endl; + + indent_down(); + f_service_ << endl; + + indent(f_service_) << + "def write_results_success_" << tfunction->get_name() << + "(self, success, result, seqid, oprot):" << endl; + indent_up(); + f_service_ << + indent() << "result.success = success" << endl << + indent() << "oprot.writeMessageBegin(\"" << tfunction->get_name() << + "\", TMessageType.REPLY, seqid)" << endl << + indent() << "result.write(oprot)" << endl << + indent() << "oprot.writeMessageEnd()" << endl << + indent() << "oprot.trans.flush()" << endl; + indent_down(); + f_service_ << endl; + + // Try block for a function with exceptions + if (!tfunction->is_oneway() && xceptions.size() > 0) { + indent(f_service_) << + "def write_results_exception_" << tfunction->get_name() << + "(self, error, result, seqid, oprot):" << endl; + indent_up(); + f_service_ << + indent() << "try:" << endl; + + // Kinda absurd + f_service_ << + indent() << " error.raiseException()" << endl; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + f_service_ << + indent() << "except " << type_name((*x_iter)->get_type()) << ", " << (*x_iter)->get_name() << ":" << endl; + if (!tfunction->is_oneway()) { + indent_up(); + f_service_ << + indent() << "result." << (*x_iter)->get_name() << " = " << (*x_iter)->get_name() << endl; + indent_down(); + } else { + f_service_ << + indent() << "pass" << endl; + } + } + f_service_ << + indent() << "oprot.writeMessageBegin(\"" << tfunction->get_name() << + "\", TMessageType.REPLY, seqid)" << endl << + indent() << "result.write(oprot)" << endl << + indent() << "oprot.writeMessageEnd()" << endl << + indent() << "oprot.trans.flush()" << endl; + indent_down(); + f_service_ << endl; + } + } else { + + // Try block for a function with exceptions + if (xceptions.size() > 0) { + f_service_ << + indent() << "try:" << endl; + indent_up(); + } + + // Generate the function call + t_struct* arg_struct = tfunction->get_arglist(); + const std::vector& fields = arg_struct->get_members(); + vector::const_iterator f_iter; + + f_service_ << indent(); + if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) { + f_service_ << "result.success = "; + } + f_service_ << + "self._handler." << tfunction->get_name() << "("; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + } else { + f_service_ << ", "; + } + f_service_ << "args." << (*f_iter)->get_name(); + } + f_service_ << ")" << endl; + + if (!tfunction->is_oneway() && xceptions.size() > 0) { + indent_down(); + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + f_service_ << + indent() << "except " << type_name((*x_iter)->get_type()) << ", " << (*x_iter)->get_name() << ":" << endl; + if (!tfunction->is_oneway()) { + indent_up(); + f_service_ << + indent() << "result." << (*x_iter)->get_name() << " = " << (*x_iter)->get_name() << endl; + indent_down(); + } else { + f_service_ << + indent() << "pass" << endl; + } + } + } + + // Shortcut out here for oneway functions + if (tfunction->is_oneway()) { + f_service_ << + indent() << "return" << endl; + indent_down(); + f_service_ << endl; + return; + } + + f_service_ << + indent() << "oprot.writeMessageBegin(\"" << tfunction->get_name() << "\", TMessageType.REPLY, seqid)" << endl << + indent() << "result.write(oprot)" << endl << + indent() << "oprot.writeMessageEnd()" << endl << + indent() << "oprot.trans.flush()" << endl; + + // Close function + indent_down(); + f_service_ << endl; + } +} + +/** + * Deserializes a field of any type. + */ +void t_py_generator::generate_deserialize_field(ofstream &out, + t_field* tfield, + string prefix, + bool inclass) { + t_type* type = get_true_type(tfield->get_type()); + + if (type->is_void()) { + throw "CANNOT GENERATE DESERIALIZE CODE FOR void TYPE: " + + prefix + tfield->get_name(); + } + + string name = prefix + tfield->get_name(); + + if (type->is_struct() || type->is_xception()) { + generate_deserialize_struct(out, + (t_struct*)type, + name); + } else if (type->is_container()) { + generate_deserialize_container(out, type, name); + } else if (type->is_base_type() || type->is_enum()) { + indent(out) << + name << " = iprot."; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "compiler error: cannot serialize void field in a struct: " + + name; + break; + case t_base_type::TYPE_STRING: + out << "readString();"; + break; + case t_base_type::TYPE_BOOL: + out << "readBool();"; + break; + case t_base_type::TYPE_BYTE: + out << "readByte();"; + break; + case t_base_type::TYPE_I16: + out << "readI16();"; + break; + case t_base_type::TYPE_I32: + out << "readI32();"; + break; + case t_base_type::TYPE_I64: + out << "readI64();"; + break; + case t_base_type::TYPE_DOUBLE: + out << "readDouble();"; + break; + default: + throw "compiler error: no PHP name for base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + out << "readI32();"; + } + out << endl; + + } else { + printf("DO NOT KNOW HOW TO DESERIALIZE FIELD '%s' TYPE '%s'\n", + tfield->get_name().c_str(), type->get_name().c_str()); + } +} + +/** + * Generates an unserializer for a struct, calling read() + */ +void t_py_generator::generate_deserialize_struct(ofstream &out, + t_struct* tstruct, + string prefix) { + out << + indent() << prefix << " = " << type_name(tstruct) << "()" << endl << + indent() << prefix << ".read(iprot)" << endl; +} + +/** + * Serialize a container by writing out the header followed by + * data and then a footer. + */ +void t_py_generator::generate_deserialize_container(ofstream &out, + t_type* ttype, + string prefix) { + string size = tmp("_size"); + string ktype = tmp("_ktype"); + string vtype = tmp("_vtype"); + string etype = tmp("_etype"); + + t_field fsize(g_type_i32, size); + t_field fktype(g_type_byte, ktype); + t_field fvtype(g_type_byte, vtype); + t_field fetype(g_type_byte, etype); + + // Declare variables, read header + if (ttype->is_map()) { + out << + indent() << prefix << " = {}" << endl << + indent() << "(" << ktype << ", " << vtype << ", " << size << " ) = iprot.readMapBegin() " << endl; + } else if (ttype->is_set()) { + out << + indent() << prefix << " = set()" << endl << + indent() << "(" << etype << ", " << size << ") = iprot.readSetBegin()" << endl; + } else if (ttype->is_list()) { + out << + indent() << prefix << " = []" << endl << + indent() << "(" << etype << ", " << size << ") = iprot.readListBegin()" << endl; + } + + // For loop iterates over elements + string i = tmp("_i"); + indent(out) << + "for " << i << " in xrange(" << size << "):" << endl; + + indent_up(); + + if (ttype->is_map()) { + generate_deserialize_map_element(out, (t_map*)ttype, prefix); + } else if (ttype->is_set()) { + generate_deserialize_set_element(out, (t_set*)ttype, prefix); + } else if (ttype->is_list()) { + generate_deserialize_list_element(out, (t_list*)ttype, prefix); + } + + indent_down(); + + // Read container end + if (ttype->is_map()) { + indent(out) << "iprot.readMapEnd()" << endl; + } else if (ttype->is_set()) { + indent(out) << "iprot.readSetEnd()" << endl; + } else if (ttype->is_list()) { + indent(out) << "iprot.readListEnd()" << endl; + } +} + + +/** + * Generates code to deserialize a map + */ +void t_py_generator::generate_deserialize_map_element(ofstream &out, + t_map* tmap, + string prefix) { + string key = tmp("_key"); + string val = tmp("_val"); + t_field fkey(tmap->get_key_type(), key); + t_field fval(tmap->get_val_type(), val); + + generate_deserialize_field(out, &fkey); + generate_deserialize_field(out, &fval); + + indent(out) << + prefix << "[" << key << "] = " << val << endl; +} + +/** + * Write a set element + */ +void t_py_generator::generate_deserialize_set_element(ofstream &out, + t_set* tset, + string prefix) { + string elem = tmp("_elem"); + t_field felem(tset->get_elem_type(), elem); + + generate_deserialize_field(out, &felem); + + indent(out) << + prefix << ".add(" << elem << ")" << endl; +} + +/** + * Write a list element + */ +void t_py_generator::generate_deserialize_list_element(ofstream &out, + t_list* tlist, + string prefix) { + string elem = tmp("_elem"); + t_field felem(tlist->get_elem_type(), elem); + + generate_deserialize_field(out, &felem); + + indent(out) << + prefix << ".append(" << elem << ")" << endl; +} + + +/** + * Serializes a field of any type. + * + * @param tfield The field to serialize + * @param prefix Name to prepend to field name + */ +void t_py_generator::generate_serialize_field(ofstream &out, + t_field* tfield, + string prefix) { + t_type* type = get_true_type(tfield->get_type()); + + // Do nothing for void types + if (type->is_void()) { + throw "CANNOT GENERATE SERIALIZE CODE FOR void TYPE: " + + prefix + tfield->get_name(); + } + + if (type->is_struct() || type->is_xception()) { + generate_serialize_struct(out, + (t_struct*)type, + prefix + tfield->get_name()); + } else if (type->is_container()) { + generate_serialize_container(out, + type, + prefix + tfield->get_name()); + } else if (type->is_base_type() || type->is_enum()) { + + string name = prefix + tfield->get_name(); + + indent(out) << + "oprot."; + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw + "compiler error: cannot serialize void field in a struct: " + name; + break; + case t_base_type::TYPE_STRING: + out << "writeString(" << name << ")"; + break; + case t_base_type::TYPE_BOOL: + out << "writeBool(" << name << ")"; + break; + case t_base_type::TYPE_BYTE: + out << "writeByte(" << name << ")"; + break; + case t_base_type::TYPE_I16: + out << "writeI16(" << name << ")"; + break; + case t_base_type::TYPE_I32: + out << "writeI32(" << name << ")"; + break; + case t_base_type::TYPE_I64: + out << "writeI64(" << name << ")"; + break; + case t_base_type::TYPE_DOUBLE: + out << "writeDouble(" << name << ")"; + break; + default: + throw "compiler error: no PHP name for base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + out << "writeI32(" << name << ")"; + } + out << endl; + } else { + printf("DO NOT KNOW HOW TO SERIALIZE FIELD '%s%s' TYPE '%s'\n", + prefix.c_str(), + tfield->get_name().c_str(), + type->get_name().c_str()); + } +} + +/** + * Serializes all the members of a struct. + * + * @param tstruct The struct to serialize + * @param prefix String prefix to attach to all fields + */ +void t_py_generator::generate_serialize_struct(ofstream &out, + t_struct* tstruct, + string prefix) { + indent(out) << + prefix << ".write(oprot)" << endl; +} + +void t_py_generator::generate_serialize_container(ofstream &out, + t_type* ttype, + string prefix) { + if (ttype->is_map()) { + indent(out) << + "oprot.writeMapBegin(" << + type_to_enum(((t_map*)ttype)->get_key_type()) << ", " << + type_to_enum(((t_map*)ttype)->get_val_type()) << ", " << + "len(" << prefix << "))" << endl; + } else if (ttype->is_set()) { + indent(out) << + "oprot.writeSetBegin(" << + type_to_enum(((t_set*)ttype)->get_elem_type()) << ", " << + "len(" << prefix << "))" << endl; + } else if (ttype->is_list()) { + indent(out) << + "oprot.writeListBegin(" << + type_to_enum(((t_list*)ttype)->get_elem_type()) << ", " << + "len(" << prefix << "))" << endl; + } + + if (ttype->is_map()) { + string kiter = tmp("kiter"); + string viter = tmp("viter"); + indent(out) << + "for " << kiter << "," << viter << " in " << prefix << ".items():" << endl; + indent_up(); + generate_serialize_map_element(out, (t_map*)ttype, kiter, viter); + indent_down(); + } else if (ttype->is_set()) { + string iter = tmp("iter"); + indent(out) << + "for " << iter << " in " << prefix << ":" << endl; + indent_up(); + generate_serialize_set_element(out, (t_set*)ttype, iter); + indent_down(); + } else if (ttype->is_list()) { + string iter = tmp("iter"); + indent(out) << + "for " << iter << " in " << prefix << ":" << endl; + indent_up(); + generate_serialize_list_element(out, (t_list*)ttype, iter); + indent_down(); + } + + if (ttype->is_map()) { + indent(out) << + "oprot.writeMapEnd()" << endl; + } else if (ttype->is_set()) { + indent(out) << + "oprot.writeSetEnd()" << endl; + } else if (ttype->is_list()) { + indent(out) << + "oprot.writeListEnd()" << endl; + } +} + +/** + * Serializes the members of a map. + * + */ +void t_py_generator::generate_serialize_map_element(ofstream &out, + t_map* tmap, + string kiter, + string viter) { + t_field kfield(tmap->get_key_type(), kiter); + generate_serialize_field(out, &kfield, ""); + + t_field vfield(tmap->get_val_type(), viter); + generate_serialize_field(out, &vfield, ""); +} + +/** + * Serializes the members of a set. + */ +void t_py_generator::generate_serialize_set_element(ofstream &out, + t_set* tset, + string iter) { + t_field efield(tset->get_elem_type(), iter); + generate_serialize_field(out, &efield, ""); +} + +/** + * Serializes the members of a list. + */ +void t_py_generator::generate_serialize_list_element(ofstream &out, + t_list* tlist, + string iter) { + t_field efield(tlist->get_elem_type(), iter); + generate_serialize_field(out, &efield, ""); +} + +/** + * Generates the docstring for a given struct. + */ +void t_py_generator::generate_python_docstring(ofstream& out, + t_struct* tstruct) { + generate_python_docstring(out, tstruct, tstruct, "Attributes"); +} + +/** + * Generates the docstring for a given function. + */ +void t_py_generator::generate_python_docstring(ofstream& out, + t_function* tfunction) { + generate_python_docstring(out, tfunction, tfunction->get_arglist(), "Parameters"); +} + +/** + * Generates the docstring for a struct or function. + */ +void t_py_generator::generate_python_docstring(ofstream& out, + t_doc* tdoc, + t_struct* tstruct, + const char* subheader) { + bool has_doc = false; + stringstream ss; + if (tdoc->has_doc()) { + has_doc = true; + ss << tdoc->get_doc(); + } + + const vector& fields = tstruct->get_members(); + if (fields.size() > 0) { + if (has_doc) { + ss << endl; + } + has_doc = true; + ss << subheader << ":\n"; + vector::const_iterator p_iter; + for (p_iter = fields.begin(); p_iter != fields.end(); ++p_iter) { + t_field* p = *p_iter; + ss << " - " << p->get_name(); + if (p->has_doc()) { + ss << ": " << p->get_doc(); + } else { + ss << endl; + } + } + } + + if (has_doc) { + generate_docstring_comment(out, + "\"\"\"\n", + "", ss.str(), + "\"\"\"\n"); + } +} + +/** + * Generates the docstring for a generic object. + */ +void t_py_generator::generate_python_docstring(ofstream& out, + t_doc* tdoc) { + if (tdoc->has_doc()) { + generate_docstring_comment(out, + "\"\"\"\n", + "", tdoc->get_doc(), + "\"\"\"\n"); + } +} + +/** + * Declares an argument, which may include initialization as necessary. + * + * @param tfield The field + */ +string t_py_generator::declare_argument(t_field* tfield) { + std::ostringstream result; + result << tfield->get_name() << "="; + if (tfield->get_value() != NULL) { + result << "thrift_spec[" << + tfield->get_key() << "][4]"; + } else { + result << "None"; + } + return result.str(); +} + +/** + * Renders a field default value, returns None otherwise. + * + * @param tfield The field + */ +string t_py_generator::render_field_default_value(t_field* tfield) { + t_type* type = get_true_type(tfield->get_type()); + if (tfield->get_value() != NULL) { + return render_const_value(type, tfield->get_value()); + } else { + return "None"; + } +} + +/** + * Renders a function signature of the form 'type name(args)' + * + * @param tfunction Function definition + * @return String of rendered function definition + */ +string t_py_generator::function_signature(t_function* tfunction, + string prefix) { + // TODO(mcslee): Nitpicky, no ',' if argument_list is empty + return + prefix + tfunction->get_name() + + "(self, " + argument_list(tfunction->get_arglist()) + ")"; +} + +/** + * Renders an interface function signature of the form 'type name(args)' + * + * @param tfunction Function definition + * @return String of rendered function definition + */ +string t_py_generator::function_signature_if(t_function* tfunction, + string prefix) { + // TODO(mcslee): Nitpicky, no ',' if argument_list is empty + string signature = prefix + tfunction->get_name() + "("; + if (!gen_twisted_) { + signature += "self, "; + } + signature += argument_list(tfunction->get_arglist()) + ")"; + return signature; +} + + +/** + * Renders a field list + */ +string t_py_generator::argument_list(t_struct* tstruct) { + string result = ""; + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + } else { + result += ", "; + } + result += (*f_iter)->get_name(); + } + return result; +} + +string t_py_generator::type_name(t_type* ttype) { + t_program* program = ttype->get_program(); + if (ttype->is_service()) { + return get_real_py_module(program) + "." + ttype->get_name(); + } + if (program != NULL && program != program_) { + return get_real_py_module(program) + ".ttypes." + ttype->get_name(); + } + return ttype->get_name(); +} + +/** + * Converts the parse type to a Python tyoe + */ +string t_py_generator::type_to_enum(t_type* type) { + type = get_true_type(type); + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "NO T_VOID CONSTRUCT"; + case t_base_type::TYPE_STRING: + return "TType.STRING"; + case t_base_type::TYPE_BOOL: + return "TType.BOOL"; + case t_base_type::TYPE_BYTE: + return "TType.BYTE"; + case t_base_type::TYPE_I16: + return "TType.I16"; + case t_base_type::TYPE_I32: + return "TType.I32"; + case t_base_type::TYPE_I64: + return "TType.I64"; + case t_base_type::TYPE_DOUBLE: + return "TType.DOUBLE"; + } + } else if (type->is_enum()) { + return "TType.I32"; + } else if (type->is_struct() || type->is_xception()) { + return "TType.STRUCT"; + } else if (type->is_map()) { + return "TType.MAP"; + } else if (type->is_set()) { + return "TType.SET"; + } else if (type->is_list()) { + return "TType.LIST"; + } + + throw "INVALID TYPE IN type_to_enum: " + type->get_name(); +} + +/** See the comment inside generate_py_struct_definition for what this is. */ +string t_py_generator::type_to_spec_args(t_type* ttype) { + while (ttype->is_typedef()) { + ttype = ((t_typedef*)ttype)->get_type(); + } + + if (ttype->is_base_type() || ttype->is_enum()) { + return "None"; + } else if (ttype->is_struct() || ttype->is_xception()) { + return "(" + type_name(ttype) + ", " + type_name(ttype) + ".thrift_spec)"; + } else if (ttype->is_map()) { + return "(" + + type_to_enum(((t_map*)ttype)->get_key_type()) + "," + + type_to_spec_args(((t_map*)ttype)->get_key_type()) + "," + + type_to_enum(((t_map*)ttype)->get_val_type()) + "," + + type_to_spec_args(((t_map*)ttype)->get_val_type()) + + ")"; + + } else if (ttype->is_set()) { + return "(" + + type_to_enum(((t_set*)ttype)->get_elem_type()) + "," + + type_to_spec_args(((t_set*)ttype)->get_elem_type()) + + ")"; + + } else if (ttype->is_list()) { + return "(" + + type_to_enum(((t_list*)ttype)->get_elem_type()) + "," + + type_to_spec_args(((t_list*)ttype)->get_elem_type()) + + ")"; + } + + throw "INVALID TYPE IN type_to_spec_args: " + ttype->get_name(); +} + + +THRIFT_REGISTER_GENERATOR(py, "Python", +" new_style: Generate new-style classes.\n" \ +" twisted: Generate Twisted-friendly RPC services.\n" +); diff --git a/compiler/cpp/src/generate/t_rb_generator.cc b/compiler/cpp/src/generate/t_rb_generator.cc new file mode 100644 index 00000000..708cd42a --- /dev/null +++ b/compiler/cpp/src/generate/t_rb_generator.cc @@ -0,0 +1,1097 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include + +#include "t_oop_generator.h" +#include "platform.h" +using namespace std; + + +/** + * Ruby code generator. + * + */ +class t_rb_generator : public t_oop_generator { + public: + t_rb_generator( + t_program* program, + const std::map& parsed_options, + const std::string& option_string) + : t_oop_generator(program) + { + out_dir_base_ = "gen-rb"; + } + + /** + * Init and close methods + */ + + void init_generator(); + void close_generator(); + + /** + * Program-level generation functions + */ + + void generate_typedef (t_typedef* ttypedef); + void generate_enum (t_enum* tenum); + void generate_const (t_const* tconst); + void generate_struct (t_struct* tstruct); + void generate_xception (t_struct* txception); + void generate_service (t_service* tservice); + + std::string render_const_value(t_type* type, t_const_value* value); + + /** + * Struct generation code + */ + + void generate_rb_struct(std::ofstream& out, t_struct* tstruct, bool is_exception); + void generate_rb_struct_required_validator(std::ofstream& out, t_struct* tstruct); + void generate_rb_function_helpers(t_function* tfunction); + void generate_rb_simple_constructor(std::ofstream& out, t_struct* tstruct); + void generate_rb_simple_exception_constructor(std::ofstream& out, t_struct* tstruct); + void generate_field_constants (std::ofstream& out, t_struct* tstruct); + void generate_accessors (std::ofstream& out, t_struct* tstruct); + void generate_field_defns (std::ofstream& out, t_struct* tstruct); + void generate_field_data (std::ofstream& out, t_type* field_type, const std::string& field_name, t_const_value* field_value, bool optional); + + /** + * Service-level generation functions + */ + + void generate_service_helpers (t_service* tservice); + void generate_service_interface (t_service* tservice); + void generate_service_client (t_service* tservice); + void generate_service_server (t_service* tservice); + void generate_process_function (t_service* tservice, t_function* tfunction); + + /** + * Serialization constructs + */ + + void generate_deserialize_field (std::ofstream &out, + t_field* tfield, + std::string prefix="", + bool inclass=false); + + void generate_deserialize_struct (std::ofstream &out, + t_struct* tstruct, + std::string prefix=""); + + void generate_deserialize_container (std::ofstream &out, + t_type* ttype, + std::string prefix=""); + + void generate_deserialize_set_element (std::ofstream &out, + t_set* tset, + std::string prefix=""); + + void generate_deserialize_map_element (std::ofstream &out, + t_map* tmap, + std::string prefix=""); + + void generate_deserialize_list_element (std::ofstream &out, + t_list* tlist, + std::string prefix=""); + + void generate_serialize_field (std::ofstream &out, + t_field* tfield, + std::string prefix=""); + + void generate_serialize_struct (std::ofstream &out, + t_struct* tstruct, + std::string prefix=""); + + void generate_serialize_container (std::ofstream &out, + t_type* ttype, + std::string prefix=""); + + void generate_serialize_map_element (std::ofstream &out, + t_map* tmap, + std::string kiter, + std::string viter); + + void generate_serialize_set_element (std::ofstream &out, + t_set* tmap, + std::string iter); + + void generate_serialize_list_element (std::ofstream &out, + t_list* tlist, + std::string iter); + + void generate_rdoc (std::ofstream& out, + t_doc* tdoc); + + /** + * Helper rendering functions + */ + + std::string rb_autogen_comment(); + std::string render_includes(); + std::string declare_field(t_field* tfield); + std::string type_name(t_type* ttype); + std::string full_type_name(t_type* ttype); + std::string function_signature(t_function* tfunction, std::string prefix=""); + std::string argument_list(t_struct* tstruct); + std::string type_to_enum(t_type* ttype); + + + + std::vector ruby_modules(t_program* p) { + std::string ns = p->get_namespace("rb"); + boost::tokenizer<> tok(ns); + std::vector modules; + + for(boost::tokenizer<>::iterator beg=tok.begin(); beg != tok.end(); ++beg) { + modules.push_back(capitalize(*beg)); + } + + return modules; + } + + void begin_namespace(std::ofstream&, std::vector); + void end_namespace(std::ofstream&, std::vector); + + private: + + /** + * File streams + */ + + std::ofstream f_types_; + std::ofstream f_consts_; + std::ofstream f_service_; + +}; + + +/** + * Prepares for file generation by opening up the necessary file output + * streams. + * + * @param tprogram The program to generate + */ +void t_rb_generator::init_generator() { + // Make output directory + MKDIR(get_out_dir().c_str()); + + // Make output file + string f_types_name = get_out_dir()+underscore(program_name_)+"_types.rb"; + f_types_.open(f_types_name.c_str()); + + string f_consts_name = get_out_dir()+underscore(program_name_)+"_constants.rb"; + f_consts_.open(f_consts_name.c_str()); + + // Print header + f_types_ << + rb_autogen_comment() << endl << + render_includes() << endl; + begin_namespace(f_types_, ruby_modules(program_)); + + f_consts_ << + rb_autogen_comment() << endl << + "require File.dirname(__FILE__) + '/" << underscore(program_name_) << "_types'" << endl << + endl; + begin_namespace(f_consts_, ruby_modules(program_)); + +} + +/** + * Renders all the imports necessary for including another Thrift program + */ +string t_rb_generator::render_includes() { + const vector& includes = program_->get_includes(); + string result = ""; + for (size_t i = 0; i < includes.size(); ++i) { + result += "require '" + underscore(includes[i]->get_name()) + "_types'\n"; + } + if (includes.size() > 0) { + result += "\n"; + } + return result; +} + +/** + * Autogen'd comment + */ +string t_rb_generator::rb_autogen_comment() { + return + std::string("#\n") + + "# Autogenerated by Thrift\n" + + "#\n" + + "# DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING\n" + + "#\n"; +} + +/** + * Closes the type files + */ +void t_rb_generator::close_generator() { + // Close types file + end_namespace(f_types_, ruby_modules(program_)); + end_namespace(f_consts_, ruby_modules(program_)); + f_types_.close(); + f_consts_.close(); +} + +/** + * Generates a typedef. This is not done in Ruby, types are all implicit. + * + * @param ttypedef The type definition + */ +void t_rb_generator::generate_typedef(t_typedef* ttypedef) {} + +/** + * Generates code for an enumerated type. Done using a class to scope + * the values. + * + * @param tenum The enumeration + */ +void t_rb_generator::generate_enum(t_enum* tenum) { + indent(f_types_) << + "module " << capitalize(tenum->get_name()) << endl; + indent_up(); + + vector constants = tenum->get_constants(); + vector::iterator c_iter; + int value = -1; + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + if ((*c_iter)->has_value()) { + value = (*c_iter)->get_value(); + } else { + ++value; + } + + // Ruby class constants have to be capitalized... omg i am so on the fence + // about languages strictly enforcing capitalization why can't we just all + // agree and play nice. + string name = capitalize((*c_iter)->get_name()); + + f_types_ << + indent() << name << " = " << value << endl; + } + + // Create a set with valid values for this enum + indent(f_types_) << "VALID_VALUES = Set.new(["; + bool first = true; + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + // Populate the set + first ? first = false: f_types_ << ", "; + f_types_ << capitalize((*c_iter)->get_name()); + } + f_types_ << "]).freeze" << endl; + + indent_down(); + indent(f_types_) << + "end" << endl << endl; +} + +/** + * Generate a constant value + */ +void t_rb_generator::generate_const(t_const* tconst) { + t_type* type = tconst->get_type(); + string name = tconst->get_name(); + t_const_value* value = tconst->get_value(); + + name[0] = toupper(name[0]); + + indent(f_consts_) << name << " = " << render_const_value(type, value); + f_consts_ << endl << endl; +} + +/** + * Prints the value of a constant with the given type. Note that type checking + * is NOT performed in this function as it is always run beforehand using the + * validate_types method in main.cc + */ +string t_rb_generator::render_const_value(t_type* type, t_const_value* value) { + type = get_true_type(type); + std::ostringstream out; + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_STRING: + out << "%q\"" << get_escaped_string(value) << '"'; + break; + case t_base_type::TYPE_BOOL: + out << (value->get_integer() > 0 ? "true" : "false"); + break; + case t_base_type::TYPE_BYTE: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + case t_base_type::TYPE_I64: + out << value->get_integer(); + break; + case t_base_type::TYPE_DOUBLE: + if (value->get_type() == t_const_value::CV_INTEGER) { + out << value->get_integer(); + } else { + out << value->get_double(); + } + break; + default: + throw "compiler error: no const of base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + indent(out) << value->get_integer(); + } else if (type->is_struct() || type->is_xception()) { + out << type->get_name() << ".new({" << endl; + indent_up(); + const vector& fields = ((t_struct*)type)->get_members(); + vector::const_iterator f_iter; + const map& val = value->get_map(); + map::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + t_type* field_type = NULL; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if ((*f_iter)->get_name() == v_iter->first->get_string()) { + field_type = (*f_iter)->get_type(); + } + } + if (field_type == NULL) { + throw "type error: " + type->get_name() + " has no field " + v_iter->first->get_string(); + } + out << indent(); + out << render_const_value(g_type_string, v_iter->first); + out << " => "; + out << render_const_value(field_type, v_iter->second); + out << "," << endl; + } + indent_down(); + indent(out) << "})"; + } else if (type->is_map()) { + t_type* ktype = ((t_map*)type)->get_key_type(); + t_type* vtype = ((t_map*)type)->get_val_type(); + out << "{" << endl; + indent_up(); + const map& val = value->get_map(); + map::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + out << indent(); + out << render_const_value(ktype, v_iter->first); + out << " => "; + out << render_const_value(vtype, v_iter->second); + out << "," << endl; + } + indent_down(); + indent(out) << "}"; + } else if (type->is_list() || type->is_set()) { + t_type* etype; + if (type->is_list()) { + etype = ((t_list*)type)->get_elem_type(); + } else { + etype = ((t_set*)type)->get_elem_type(); + } + if (type->is_set()) { + out << "Set.new(["; + } else { + out << "[" << endl; + } + indent_up(); + const vector& val = value->get_list(); + vector::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + out << indent(); + out << render_const_value(etype, *v_iter); + out << "," << endl; + } + indent_down(); + if (type->is_set()) { + indent(out) << "])"; + } else { + indent(out) << "]"; + } + } else { + throw "CANNOT GENERATE CONSTANT FOR TYPE: " + type->get_name(); + } + return out.str(); +} + +/** + * Generates a ruby struct + */ +void t_rb_generator::generate_struct(t_struct* tstruct) { + generate_rb_struct(f_types_, tstruct, false); +} + +/** + * Generates a struct definition for a thrift exception. Basically the same + * as a struct but extends the Exception class. + * + * @param txception The struct definition + */ +void t_rb_generator::generate_xception(t_struct* txception) { + generate_rb_struct(f_types_, txception, true); +} + +/** + * Generates a ruby struct + */ +void t_rb_generator::generate_rb_struct(std::ofstream& out, t_struct* tstruct, bool is_exception = false) { + generate_rdoc(out, tstruct); + indent(out) << "class " << type_name(tstruct); + if (is_exception) { + out << " < ::Thrift::Exception"; + } + out << endl; + + indent_up(); + indent(out) << "include ::Thrift::Struct" << endl; + + if (is_exception) { + generate_rb_simple_exception_constructor(out, tstruct); + } + + generate_field_constants(out, tstruct); + generate_accessors(out, tstruct); + generate_field_defns(out, tstruct); + generate_rb_struct_required_validator(out, tstruct); + + indent_down(); + indent(out) << "end" << endl << endl; +} + +void t_rb_generator::generate_rb_simple_exception_constructor(std::ofstream& out, t_struct* tstruct) { + const vector& members = tstruct->get_members(); + + if (members.size() == 1) { + vector::const_iterator m_iter = members.begin(); + + if ((*m_iter)->get_type()->is_string()) { + string name = (*m_iter)->get_name(); + + indent(out) << "def initialize(message=nil)" << endl; + indent_up(); + indent(out) << "super()" << endl; + indent(out) << "self." << name << " = message" << endl; + indent_down(); + indent(out) << "end" << endl << endl; + + if (name != "message") { + indent(out) << "def message; " << name << " end" << endl << endl; + } + } + } +} + +void t_rb_generator::generate_field_constants(std::ofstream& out, t_struct* tstruct) { + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + std::string field_name = (*f_iter)->get_name(); + std::string cap_field_name = upcase_string(field_name); + + indent(out) << cap_field_name << " = " << (*f_iter)->get_key() << endl; + } + out << endl; +} + +void t_rb_generator::generate_accessors(std::ofstream& out, t_struct* tstruct) { + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + + if (members.size() > 0) { + indent(out) << "::Thrift::Struct.field_accessor self"; + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + out << ", :" << (*m_iter)->get_name(); + } + out << endl; + } +} + +void t_rb_generator::generate_field_defns(std::ofstream& out, t_struct* tstruct) { + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + + indent(out) << "FIELDS = {" << endl; + indent_up(); + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (f_iter != fields.begin()) { + out << "," << endl; + } + + // generate the field docstrings within the FIELDS constant. no real better place... + generate_rdoc(out, *f_iter); + + indent(out) << + upcase_string((*f_iter)->get_name()) << " => "; + + generate_field_data(out, (*f_iter)->get_type(), (*f_iter)->get_name(), (*f_iter)->get_value(), + (*f_iter)->get_req() == t_field::T_OPTIONAL); + } + indent_down(); + out << endl; + indent(out) << "}" << endl << endl; + + indent(out) << "def struct_fields; FIELDS; end" << endl << endl; + +} + +void t_rb_generator::generate_field_data(std::ofstream& out, t_type* field_type, + const std::string& field_name = "", t_const_value* field_value = NULL, bool optional = false) { + field_type = get_true_type(field_type); + + // Begin this field's defn + out << "{:type => " << type_to_enum(field_type); + + if (!field_name.empty()) { + out << ", :name => '" << field_name << "'"; + } + + if (field_value != NULL) { + out << ", :default => " << render_const_value(field_type, field_value); + } + + if (!field_type->is_base_type()) { + if (field_type->is_struct() || field_type->is_xception()) { + out << ", :class => " << full_type_name((t_struct*)field_type); + } else if (field_type->is_list()) { + out << ", :element => "; + generate_field_data(out, ((t_list*)field_type)->get_elem_type()); + } else if (field_type->is_map()) { + out << ", :key => "; + generate_field_data(out, ((t_map*)field_type)->get_key_type()); + out << ", :value => "; + generate_field_data(out, ((t_map*)field_type)->get_val_type()); + } else if (field_type->is_set()) { + out << ", :element => "; + generate_field_data(out, ((t_set*)field_type)->get_elem_type()); + } + } + + if(optional) { + out << ", :optional => true"; + } + + if (field_type->is_enum()) { + out << ", :enum_class => " << full_type_name(field_type); + } + + // End of this field's defn + out << "}"; +} + +void t_rb_generator::begin_namespace(std::ofstream& out, vector modules) { + for (vector::iterator m_iter = modules.begin(); m_iter != modules.end(); ++m_iter) { + indent(out) << "module " << *m_iter << endl; + indent_up(); + } +} + +void t_rb_generator::end_namespace(std::ofstream& out, vector modules) { + for (vector::reverse_iterator m_iter = modules.rbegin(); m_iter != modules.rend(); ++m_iter) { + indent_down(); + indent(out) << "end" << endl; + } +} + + +/** + * Generates a thrift service. + * + * @param tservice The service definition + */ +void t_rb_generator::generate_service(t_service* tservice) { + string f_service_name = get_out_dir()+underscore(service_name_)+".rb"; + f_service_.open(f_service_name.c_str()); + + f_service_ << + rb_autogen_comment() << endl << + "require 'thrift'" << endl; + + if (tservice->get_extends() != NULL) { + f_service_ << + "require '" << underscore(tservice->get_extends()->get_name()) << "'" << endl; + } + + f_service_ << + "require File.dirname(__FILE__) + '/" << underscore(program_name_) << "_types'" << endl << + endl; + + begin_namespace(f_service_, ruby_modules(tservice->get_program())); + + indent(f_service_) << "module " << capitalize(tservice->get_name()) << endl; + indent_up(); + + // Generate the three main parts of the service (well, two for now in PHP) + generate_service_client(tservice); + generate_service_server(tservice); + generate_service_helpers(tservice); + + indent_down(); + indent(f_service_) << "end" << endl << + endl; + + end_namespace(f_service_, ruby_modules(tservice->get_program())); + + // Close service file + f_service_.close(); +} + +/** + * Generates helper functions for a service. + * + * @param tservice The service to generate a header definition for + */ +void t_rb_generator::generate_service_helpers(t_service* tservice) { + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + indent(f_service_) << + "# HELPER FUNCTIONS AND STRUCTURES" << endl << endl; + + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + t_struct* ts = (*f_iter)->get_arglist(); + generate_rb_struct(f_service_, ts); + generate_rb_function_helpers(*f_iter); + } +} + +/** + * Generates a struct and helpers for a function. + * + * @param tfunction The function + */ +void t_rb_generator::generate_rb_function_helpers(t_function* tfunction) { + t_struct result(program_, tfunction->get_name() + "_result"); + t_field success(tfunction->get_returntype(), "success", 0); + if (!tfunction->get_returntype()->is_void()) { + result.append(&success); + } + + t_struct* xs = tfunction->get_xceptions(); + const vector& fields = xs->get_members(); + vector::const_iterator f_iter; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + result.append(*f_iter); + } + generate_rb_struct(f_service_, &result); +} + +/** + * Generates a service client definition. + * + * @param tservice The service to generate a server for. + */ +void t_rb_generator::generate_service_client(t_service* tservice) { + string extends = ""; + string extends_client = ""; + if (tservice->get_extends() != NULL) { + extends = full_type_name(tservice->get_extends()); + extends_client = " < " + extends + "::Client "; + } + + indent(f_service_) << + "class Client" << extends_client << endl; + indent_up(); + + indent(f_service_) << + "include ::Thrift::Client" << endl << endl; + + // Generate client method implementations + vector functions = tservice->get_functions(); + vector::const_iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + t_struct* arg_struct = (*f_iter)->get_arglist(); + const vector& fields = arg_struct->get_members(); + vector::const_iterator fld_iter; + string funname = (*f_iter)->get_name(); + + // Open function + indent(f_service_) << + "def " << function_signature(*f_iter) << endl; + indent_up(); + indent(f_service_) << + "send_" << funname << "("; + + bool first = true; + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + if (first) { + first = false; + } else { + f_service_ << ", "; + } + f_service_ << (*fld_iter)->get_name(); + } + f_service_ << ")" << endl; + + if (!(*f_iter)->is_oneway()) { + f_service_ << indent(); + if (!(*f_iter)->get_returntype()->is_void()) { + f_service_ << "return "; + } + f_service_ << + "recv_" << funname << "()" << endl; + } + indent_down(); + indent(f_service_) << "end" << endl; + f_service_ << endl; + + indent(f_service_) << + "def send_" << function_signature(*f_iter) << endl; + indent_up(); + + std::string argsname = capitalize((*f_iter)->get_name() + "_args"); + + indent(f_service_) << "send_message('" << funname << "', " << argsname; + + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + f_service_ << ", :" << (*fld_iter)->get_name() << " => " << (*fld_iter)->get_name(); + } + + f_service_ << ")" << endl; + + indent_down(); + indent(f_service_) << "end" << endl; + + if (!(*f_iter)->is_oneway()) { + std::string resultname = capitalize((*f_iter)->get_name() + "_result"); + t_struct noargs(program_); + + t_function recv_function((*f_iter)->get_returntype(), + string("recv_") + (*f_iter)->get_name(), + &noargs); + // Open function + f_service_ << + endl << + indent() << "def " << function_signature(&recv_function) << endl; + indent_up(); + + // TODO(mcslee): Validate message reply here, seq ids etc. + + f_service_ << + indent() << "result = receive_message(" << resultname << ")" << endl; + + // Careful, only return _result if not a void function + if (!(*f_iter)->get_returntype()->is_void()) { + f_service_ << + indent() << "return result.success unless result.success.nil?" << endl; + } + + t_struct* xs = (*f_iter)->get_xceptions(); + const std::vector& xceptions = xs->get_members(); + vector::const_iterator x_iter; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + indent(f_service_) << + "raise result." << (*x_iter)->get_name() << + " unless result." << (*x_iter)->get_name() << ".nil?" << endl; + } + + // Careful, only return _result if not a void function + if ((*f_iter)->get_returntype()->is_void()) { + indent(f_service_) << + "return" << endl; + } else { + f_service_ << + indent() << "raise ::Thrift::ApplicationException.new(::Thrift::ApplicationException::MISSING_RESULT, '" << (*f_iter)->get_name() << " failed: unknown result')" << endl; + } + + // Close function + indent_down(); + indent(f_service_) << "end" << endl << endl; + } + } + + indent_down(); + indent(f_service_) << "end" << endl << endl; +} + +/** + * Generates a service server definition. + * + * @param tservice The service to generate a server for. + */ +void t_rb_generator::generate_service_server(t_service* tservice) { + // Generate the dispatch methods + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + string extends = ""; + string extends_processor = ""; + if (tservice->get_extends() != NULL) { + extends = full_type_name(tservice->get_extends()); + extends_processor = " < " + extends + "::Processor "; + } + + // Generate the header portion + indent(f_service_) << + "class Processor" << extends_processor << endl; + indent_up(); + + f_service_ << + indent() << "include ::Thrift::Processor" << endl << + endl; + + // Generate the process subfunctions + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + generate_process_function(tservice, *f_iter); + } + + indent_down(); + indent(f_service_) << "end" << endl << endl; +} + +/** + * Generates a process function definition. + * + * @param tfunction The function to write a dispatcher for + */ +void t_rb_generator::generate_process_function(t_service* tservice, + t_function* tfunction) { + // Open function + indent(f_service_) << + "def process_" << tfunction->get_name() << + "(seqid, iprot, oprot)" << endl; + indent_up(); + + string argsname = capitalize(tfunction->get_name()) + "_args"; + string resultname = capitalize(tfunction->get_name()) + "_result"; + + f_service_ << + indent() << "args = read_args(iprot, " << argsname << ")" << endl; + + t_struct* xs = tfunction->get_xceptions(); + const std::vector& xceptions = xs->get_members(); + vector::const_iterator x_iter; + + // Declare result for non oneway function + if (!tfunction->is_oneway()) { + f_service_ << + indent() << "result = " << resultname << ".new()" << endl; + } + + // Try block for a function with exceptions + if (xceptions.size() > 0) { + f_service_ << + indent() << "begin" << endl; + indent_up(); + } + + // Generate the function call + t_struct* arg_struct = tfunction->get_arglist(); + const std::vector& fields = arg_struct->get_members(); + vector::const_iterator f_iter; + + f_service_ << indent(); + if (!tfunction->is_oneway() && !tfunction->get_returntype()->is_void()) { + f_service_ << "result.success = "; + } + f_service_ << + "@handler." << tfunction->get_name() << "("; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + } else { + f_service_ << ", "; + } + f_service_ << "args." << (*f_iter)->get_name(); + } + f_service_ << ")" << endl; + + if (!tfunction->is_oneway() && xceptions.size() > 0) { + indent_down(); + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + f_service_ << + indent() << "rescue " << full_type_name((*x_iter)->get_type()) << " => " << (*x_iter)->get_name() << endl; + if (!tfunction->is_oneway()) { + indent_up(); + f_service_ << + indent() << "result." << (*x_iter)->get_name() << " = " << (*x_iter)->get_name() << endl; + indent_down(); + } + } + indent(f_service_) << "end" << endl; + } + + // Shortcut out here for oneway functions + if (tfunction->is_oneway()) { + f_service_ << + indent() << "return" << endl; + indent_down(); + indent(f_service_) << "end" << endl << endl; + return; + } + + f_service_ << + indent() << "write_result(result, oprot, '" << tfunction->get_name() << "', seqid)" << endl; + + // Close function + indent_down(); + indent(f_service_) << "end" << endl << endl; +} + +/** + * Renders a function signature of the form 'type name(args)' + * + * @param tfunction Function definition + * @return String of rendered function definition + */ +string t_rb_generator::function_signature(t_function* tfunction, + string prefix) { + // TODO(mcslee): Nitpicky, no ',' if argument_list is empty + return + prefix + tfunction->get_name() + + "(" + argument_list(tfunction->get_arglist()) + ")"; +} + +/** + * Renders a field list + */ +string t_rb_generator::argument_list(t_struct* tstruct) { + string result = ""; + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + } else { + result += ", "; + } + result += (*f_iter)->get_name(); + } + return result; +} + +string t_rb_generator::type_name(t_type* ttype) { + string prefix = ""; + + string name = ttype->get_name(); + if (ttype->is_struct() || ttype->is_xception() || ttype->is_enum()) { + name = capitalize(ttype->get_name()); + } + + return prefix + name; +} + +string t_rb_generator::full_type_name(t_type* ttype) { + string prefix = ""; + vector modules = ruby_modules(ttype->get_program()); + for (vector::iterator m_iter = modules.begin(); + m_iter != modules.end(); ++m_iter) { + prefix += *m_iter + "::"; + } + return prefix + type_name(ttype); +} + +/** + * Converts the parse type to a Ruby tyoe + */ +string t_rb_generator::type_to_enum(t_type* type) { + type = get_true_type(type); + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "NO T_VOID CONSTRUCT"; + case t_base_type::TYPE_STRING: + return "::Thrift::Types::STRING"; + case t_base_type::TYPE_BOOL: + return "::Thrift::Types::BOOL"; + case t_base_type::TYPE_BYTE: + return "::Thrift::Types::BYTE"; + case t_base_type::TYPE_I16: + return "::Thrift::Types::I16"; + case t_base_type::TYPE_I32: + return "::Thrift::Types::I32"; + case t_base_type::TYPE_I64: + return "::Thrift::Types::I64"; + case t_base_type::TYPE_DOUBLE: + return "::Thrift::Types::DOUBLE"; + } + } else if (type->is_enum()) { + return "::Thrift::Types::I32"; + } else if (type->is_struct() || type->is_xception()) { + return "::Thrift::Types::STRUCT"; + } else if (type->is_map()) { + return "::Thrift::Types::MAP"; + } else if (type->is_set()) { + return "::Thrift::Types::SET"; + } else if (type->is_list()) { + return "::Thrift::Types::LIST"; + } + + throw "INVALID TYPE IN type_to_enum: " + type->get_name(); +} + + +void t_rb_generator::generate_rdoc(std::ofstream& out, t_doc* tdoc) { + if (tdoc->has_doc()) { + generate_docstring_comment(out, + "", "# ", tdoc->get_doc(), ""); + } +} + +void t_rb_generator::generate_rb_struct_required_validator(std::ofstream& out, + t_struct* tstruct) { + indent(out) << "def validate" << endl; + indent_up(); + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + t_field* field = (*f_iter); + if (field->get_req() == t_field::T_REQUIRED) { + indent(out) << "raise ::Thrift::ProtocolException.new(::Thrift::ProtocolException::UNKNOWN, 'Required field " << field->get_name() << " is unset!')"; + if (field->get_type()->is_bool()) { + out << " if @" << field->get_name() << ".nil?"; + } else { + out << " unless @" << field->get_name(); + } + out << endl; + } + } + + // if field is an enum, check that its value is valid + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + t_field* field = (*f_iter); + + if (field->get_type()->is_enum()){ + indent(out) << "unless @" << field->get_name() << ".nil? || " << field->get_type()->get_name() << "::VALID_VALUES.include?(@" << field->get_name() << ")" << endl; + indent_up(); + indent(out) << "raise ::Thrift::ProtocolException.new(::Thrift::ProtocolException::UNKNOWN, 'Invalid value of field " << field->get_name() << "!')" << endl; + indent_down(); + indent(out) << "end" << endl; + } + } + + indent_down(); + indent(out) << "end" << endl << endl; + +} + +THRIFT_REGISTER_GENERATOR(rb, "Ruby", ""); diff --git a/compiler/cpp/src/generate/t_st_generator.cc b/compiler/cpp/src/generate/t_st_generator.cc new file mode 100644 index 00000000..3600a3b8 --- /dev/null +++ b/compiler/cpp/src/generate/t_st_generator.cc @@ -0,0 +1,1071 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "platform.h" +#include "t_oop_generator.h" +using namespace std; + + +/** + * Smalltalk code generator. + * + */ +class t_st_generator : public t_oop_generator { + public: + t_st_generator( + t_program* program, + const std::map& parsed_options, + const std::string& option_string) + : t_oop_generator(program) + { + out_dir_base_ = "gen-st"; + } + + /** + * Init and close methods + */ + + void init_generator(); + void close_generator(); + + /** + * Program-level generation functions + */ + + void generate_typedef (t_typedef* ttypedef); + void generate_enum (t_enum* tenum); + void generate_const (t_const* tconst); + void generate_struct (t_struct* tstruct); + void generate_xception (t_struct* txception); + void generate_service (t_service* tservice); + void generate_class_side_definition (); + void generate_force_consts (); + + + std::string render_const_value(t_type* type, t_const_value* value); + + /** + * Struct generation code + */ + + void generate_st_struct (std::ofstream& out, t_struct* tstruct, bool is_exception); + void generate_accessors (std::ofstream& out, t_struct* tstruct); + + /** + * Service-level generation functions + */ + + void generate_service_client (t_service* tservice); + + void generate_send_method (t_function* tfunction); + void generate_recv_method (t_function* tfunction); + + std::string map_reader (t_map *tmap); + std::string list_reader (t_list *tlist); + std::string set_reader (t_set *tset); + std::string struct_reader (t_struct *tstruct, std::string clsName); + + std::string map_writer (t_map *tmap, std::string name); + std::string list_writer (t_list *tlist, std::string name); + std::string set_writer (t_set *tset, std::string name); + std::string struct_writer (t_struct *tstruct, std::string fname); + + std::string write_val (t_type *t, std::string fname); + std::string read_val (t_type *t); + + /** + * Helper rendering functions + */ + + std::string st_autogen_comment(); + + void st_class_def(std::ofstream &out, std::string name); + void st_method(std::ofstream &out, std::string cls, std::string name); + void st_method(std::ofstream &out, std::string cls, std::string name, std::string category); + void st_close_method(std::ofstream &out); + void st_class_method(std::ofstream &out, std::string cls, std::string name); + void st_class_method(std::ofstream &out, std::string cls, std::string name, std::string category); + void st_setter(std::ofstream &out, std::string cls, std::string name, std::string type); + void st_getter(std::ofstream &out, std::string cls, std::string name); + void st_accessors(std::ofstream &out, std::string cls, std::string name, std::string type); + + std::string class_name(); + std::string client_class_name(); + std::string prefix(std::string name); + std::string declare_field(t_field* tfield); + std::string sanitize(std::string s); + std::string type_name(t_type* ttype); + + std::string function_signature(t_function* tfunction); + std::string argument_list(t_struct* tstruct); + std::string function_types_comment(t_function* fn); + + std::string type_to_enum(t_type* ttype); + std::string a_type(t_type* type); + bool is_vowel(char c); + std::string temp_name(); + std::string generated_category(); + + private: + + /** + * File streams + */ + int temporary_var; + std::ofstream f_; + +}; + + +/** + * Prepares for file generation by opening up the necessary file output + * streams. + * + * @param tprogram The program to generate + */ +void t_st_generator::init_generator() { + // Make output directory + MKDIR(get_out_dir().c_str()); + + temporary_var = 0; + + // Make output file + string f_name = get_out_dir()+"/"+program_name_+".st"; + f_.open(f_name.c_str()); + + // Print header + f_ << st_autogen_comment() << endl; + + st_class_def(f_, program_name_); + generate_class_side_definition(); + + //Generate enums + vector enums = program_->get_enums(); + vector::iterator en_iter; + for (en_iter = enums.begin(); en_iter != enums.end(); ++en_iter) { + generate_enum(*en_iter); + } +} + +string t_st_generator::class_name() { + return capitalize(program_name_); +} + +string t_st_generator::prefix(string class_name) { + string prefix = program_->get_namespace("smalltalk.prefix"); + string name = capitalize(class_name); + name = prefix.empty() ? name : (prefix + name); + return name; +} + +string t_st_generator::client_class_name() { + return capitalize(service_name_) + "Client"; +} + +/** + * Autogen'd comment + */ +string t_st_generator::st_autogen_comment() { + return + std::string("'") + + "Autogenerated by Thrift\n" + + "\n" + + "DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING\n" + + "'!\n"; +} + +void t_st_generator::generate_force_consts() { + f_ << prefix(class_name()) << " enums keysAndValuesDo: [:k :v | " << + prefix(class_name()) << " enums at: k put: v value].!" << endl; + + f_ << prefix(class_name()) << " constants keysAndValuesDo: [:k :v | " << + prefix(class_name()) << " constants at: k put: v value].!" << endl; + +} + +void t_st_generator::close_generator() { + generate_force_consts(); + f_.close(); +} + +string t_st_generator::generated_category() { + string cat = program_->get_namespace("smalltalk.category"); + // For compatibility with the Thrift grammar, the category must + // be punctuated by dots. Replaces them with dashes here. + for (string::iterator iter = cat.begin(); iter != cat.end(); ++iter) { + if (*iter == '.') { + *iter = '-'; + } + } + return cat.size() ? cat : "Generated-" + class_name(); +} + +/** + * Generates a typedef. This is not done in Smalltalk, types are all implicit. + * + * @param ttypedef The type definition + */ +void t_st_generator::generate_typedef(t_typedef* ttypedef) {} + +void t_st_generator::st_class_def(std::ofstream &out, string name) { + out << "Object subclass: #" << prefix(name) << endl; + indent_up(); + out << indent() << "instanceVariableNames: ''" << endl << + indent() << "classVariableNames: ''" << endl << + indent() << "poolDictionaries: ''" << endl << + indent() << "category: '" << generated_category() << "'!" << endl << endl; +} + +void t_st_generator::st_method(std::ofstream &out, string cls, string name) { + st_method(out, cls, name, "as yet uncategorized"); +} + +void t_st_generator::st_class_method(std::ofstream &out, string cls, string name) { + st_method(out, cls + " class", name); +} + +void t_st_generator::st_class_method(std::ofstream &out, string cls, string name, string category) { + st_method(out, cls, name, category); +} + +void t_st_generator::st_method(std::ofstream &out, string cls, string name, string category) { + char timestr[50]; + time_t rawtime; + struct tm *tinfo; + + time(&rawtime); + tinfo = localtime(&rawtime); + strftime(timestr, 50, "%m/%d/%Y %H:%M", tinfo); + + out << "!" << prefix(cls) << + " methodsFor: '"+category+"' stamp: 'thrift " << timestr << "'!\n" << + name << endl; + + indent_up(); + out << indent(); +} + +void t_st_generator::st_close_method(std::ofstream &out) { + out << "! !" << endl << endl; + indent_down(); +} + +void t_st_generator::st_setter(std::ofstream &out, string cls, string name, string type = "anObject") { + st_method(out, cls, name + ": " + type); + out << name << " := " + type; + st_close_method(out); +} + +void t_st_generator::st_getter(std::ofstream &out, string cls, string name) { + st_method(out, cls, name + ""); + out << "^ " << name; + st_close_method(out); +} + +void t_st_generator::st_accessors(std::ofstream &out, string cls, string name, string type = "anObject") { + st_setter(out, cls, name, type); + st_getter(out, cls, name); +} + +void t_st_generator::generate_class_side_definition() { + f_ << prefix(class_name()) << " class" << endl << + "\tinstanceVariableNames: 'constants enums'!" << endl << endl; + + st_accessors(f_, class_name() + " class", "enums"); + st_accessors(f_, class_name() + " class", "constants"); + + f_ << prefix(class_name()) << " enums: Dictionary new!" << endl; + f_ << prefix(class_name()) << " constants: Dictionary new!" << endl; + + f_ << endl; +} + +/** + * Generates code for an enumerated type. Done using a class to scope + * the values. + * + * @param tenum The enumeration + */ +void t_st_generator::generate_enum(t_enum* tenum) { + string cls_name = program_name_ + capitalize(tenum->get_name()); + + f_ << prefix(class_name()) << " enums at: '" << tenum->get_name() << "' put: [" << + "(Dictionary new " << endl; + + vector constants = tenum->get_constants(); + vector::iterator c_iter; + int value = -1; + for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { + if ((*c_iter)->has_value()) { + value = (*c_iter)->get_value(); + } else { + ++value; + } + + f_ << "\tat: '" << (*c_iter)->get_name() << "' put: " << value << ";" << endl; + } + + f_ << "\tyourself)]!" << endl << endl; +} + +/** + * Generate a constant value + */ +void t_st_generator::generate_const(t_const* tconst) { + t_type* type = tconst->get_type(); + string name = tconst->get_name(); + t_const_value* value = tconst->get_value(); + + f_ << prefix(class_name()) << " constants at: '" << name << "' put: [" << + render_const_value(type, value) << "]!" << endl << endl; +} + +/** + * Prints the value of a constant with the given type. Note that type checking + * is NOT performed in this function as it is always run beforehand using the + * validate_types method in main.cc + */ +string t_st_generator::render_const_value(t_type* type, t_const_value* value) { + type = get_true_type(type); + std::ostringstream out; + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_STRING: + out << '"' << get_escaped_string(value) << '"'; + break; + case t_base_type::TYPE_BOOL: + out << (value->get_integer() > 0 ? "true" : "false"); + break; + case t_base_type::TYPE_BYTE: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + case t_base_type::TYPE_I64: + out << value->get_integer(); + break; + case t_base_type::TYPE_DOUBLE: + if (value->get_type() == t_const_value::CV_INTEGER) { + out << value->get_integer(); + } else { + out << value->get_double(); + } + break; + default: + throw "compiler error: no const of base type " + t_base_type::t_base_name(tbase); + } + } else if (type->is_enum()) { + indent(out) << value->get_integer(); + } else if (type->is_struct() || type->is_xception()) { + out << "(" << capitalize(type->get_name()) << " new " << endl; + indent_up(); + + const vector& fields = ((t_struct*)type)->get_members(); + vector::const_iterator f_iter; + const map& val = value->get_map(); + map::const_iterator v_iter; + + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + t_type* field_type = NULL; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if ((*f_iter)->get_name() == v_iter->first->get_string()) { + field_type = (*f_iter)->get_type(); + } + } + if (field_type == NULL) { + throw "type error: " + type->get_name() + " has no field " + v_iter->first->get_string(); + } + + out << indent() << v_iter->first->get_string() << ": " << + render_const_value(field_type, v_iter->second) << ";" << endl; + } + out << indent() << "yourself)"; + + indent_down(); + } else if (type->is_map()) { + t_type* ktype = ((t_map*)type)->get_key_type(); + t_type* vtype = ((t_map*)type)->get_val_type(); + out << "(Dictionary new" << endl; + indent_up(); + indent_up(); + const map& val = value->get_map(); + map::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + out << indent() << indent(); + out << "at: " << render_const_value(ktype, v_iter->first); + out << " put: "; + out << render_const_value(vtype, v_iter->second); + out << ";" << endl; + } + out << indent() << indent() << "yourself)"; + indent_down(); + indent_down(); + } else if (type->is_list() || type->is_set()) { + t_type* etype; + if (type->is_list()) { + etype = ((t_list*)type)->get_elem_type(); + } else { + etype = ((t_set*)type)->get_elem_type(); + } + if (type->is_set()) { + out << "(Set new" << endl; + } else { + out << "(OrderedCollection new" << endl; + } + indent_up(); + indent_up(); + const vector& val = value->get_list(); + vector::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + out << indent() << indent(); + out << "add: " << render_const_value(etype, *v_iter); + out << ";" << endl; + } + out << indent() << indent() << "yourself)"; + indent_down(); + indent_down(); + } else { + throw "CANNOT GENERATE CONSTANT FOR TYPE: " + type->get_name(); + } + return out.str(); +} + +/** + * Generates a Smalltalk struct + */ +void t_st_generator::generate_struct(t_struct* tstruct) { + generate_st_struct(f_, tstruct, false); +} + +/** + * Generates a struct definition for a thrift exception. Basically the same + * as a struct but extends the Exception class. + * + * @param txception The struct definition + */ +void t_st_generator::generate_xception(t_struct* txception) { + generate_st_struct(f_, txception, true); +} + +/** + * Generates a smalltalk class to represent a struct + */ +void t_st_generator::generate_st_struct(std::ofstream& out, t_struct* tstruct, bool is_exception = false) { + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + + if (is_exception) + out << "Error"; + else + out << "Object"; + + out << " subclass: #" << prefix(type_name(tstruct)) << endl << + "\tinstanceVariableNames: '"; + + if (members.size() > 0) { + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + if (m_iter != members.begin()) out << " "; + out << sanitize((*m_iter)->get_name()); + } + } + + out << "'\n" << + "\tclassVariableNames: ''\n" << + "\tpoolDictionaries: ''\n" << + "\tcategory: '" << generated_category() << "'!\n\n"; + + generate_accessors(out, tstruct); +} + +bool t_st_generator::is_vowel(char c) { + switch(tolower(c)) { + case 'a': case 'e': case 'i': case 'o': case 'u': + return true; + } + return false; +} + +string t_st_generator::a_type(t_type* type) { + string prefix; + + if (is_vowel(type_name(type)[0])) + prefix = "an"; + else + prefix = "a"; + + return prefix + capitalize(type_name(type)); +} + +void t_st_generator::generate_accessors(std::ofstream& out, t_struct* tstruct) { + const vector& members = tstruct->get_members(); + vector::const_iterator m_iter; + string type; + string prefix; + + if (members.size() > 0) { + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + st_accessors(out, + capitalize(type_name(tstruct)), + sanitize((*m_iter)->get_name()), + a_type((*m_iter)->get_type())); + } + out << endl; + } +} + +/** + * Generates a thrift service. + * + * @param tservice The service definition + */ +void t_st_generator::generate_service(t_service* tservice) { + generate_service_client(tservice); + // generate_service_server(tservice); +} + +string t_st_generator::temp_name() { + std::ostringstream out; + out << "temp" << temporary_var++; + return out.str(); +} + +string t_st_generator::map_writer(t_map *tmap, string fname) { + std::ostringstream out; + string key = temp_name(); + string val = temp_name(); + + out << "[oprot writeMapBegin: (TMap new keyType: " << type_to_enum(tmap->get_key_type()) << + "; valueType: " << type_to_enum(tmap->get_val_type()) << "; size: " << fname << " size)." << endl; + indent_up(); + + out << indent() << fname << " keysAndValuesDo: [:" << key << " :" << val << " |" << endl; + indent_up(); + + out << indent() << write_val(tmap->get_key_type(), key) << "." << endl << + indent() << write_val(tmap->get_val_type(), val); + indent_down(); + + out << "]." << endl << + indent() << "oprot writeMapEnd] value"; + indent_down(); + + return out.str(); +} + +string t_st_generator::map_reader(t_map *tmap) { + std::ostringstream out; + string desc = temp_name(); + string val = temp_name(); + + out << "[|" << desc << " " << val << "| " << endl; + indent_up(); + + out << indent() << desc << " := iprot readMapBegin." << endl << + indent() << val << " := Dictionary new." << endl << + indent() << desc << " size timesRepeat: [" << endl; + + indent_up(); + out << indent() << val << " at: " << read_val(tmap->get_key_type()) << + " put: " << read_val(tmap->get_val_type()); + indent_down(); + + out << "]." << endl << + indent() << "iprot readMapEnd." << endl << + indent() << val << "] value"; + indent_down(); + + return out.str(); +} + +string t_st_generator::list_writer(t_list *tlist, string fname) { + std::ostringstream out; + string val = temp_name(); + + out << "[oprot writeListBegin: (TList new elemType: " << + type_to_enum(tlist->get_elem_type()) << "; size: " << fname << " size)." << endl; + indent_up(); + + out << indent() << fname << " do: [:" << val << "|" << endl; + indent_up(); + + out << indent() << write_val(tlist->get_elem_type(), val) << endl; + indent_down(); + + out << "]." << endl << + indent() << "oprot writeListEnd] value"; + indent_down(); + + return out.str(); +} + +string t_st_generator::list_reader(t_list *tlist) { + std::ostringstream out; + string desc = temp_name(); + string val = temp_name(); + + out << "[|" << desc << " " << val << "| " << desc << " := iprot readListBegin." << endl; + indent_up(); + + out << indent() << val << " := OrderedCollection new." << endl << + indent() << desc << " size timesRepeat: [" << endl; + + indent_up(); + out << indent() << val << " add: " << read_val(tlist->get_elem_type()); + indent_down(); + + out << "]." << endl << + indent() << "iprot readListEnd." << endl << + indent() << val << "] value"; + indent_down(); + + return out.str(); +} + +string t_st_generator::set_writer(t_set *tset, string fname) { + std::ostringstream out; + string val = temp_name(); + + out << "[oprot writeSetBegin: (TSet new elemType: " << type_to_enum(tset->get_elem_type()) << + "; size: " << fname << " size)." << endl; + indent_up(); + + out << indent() << fname << " do: [:" << val << "|" << endl; + indent_up(); + + out << indent() << write_val(tset->get_elem_type(), val) << endl; + indent_down(); + + out << "]." << endl << + indent() << "oprot writeSetEnd] value"; + indent_down(); + + return out.str(); +} + +string t_st_generator::set_reader(t_set *tset) { + std::ostringstream out; + string desc = temp_name(); + string val = temp_name(); + + out << "[|" << desc << " " << val << "| " << desc << " := iprot readSetBegin." << endl; + indent_up(); + + out << indent() << val << " := Set new." << endl << + indent() << desc << " size timesRepeat: [" << endl; + + indent_up(); + out << indent() << val << " add: " << read_val(tset->get_elem_type()); + indent_down(); + + out << "]." << endl << + indent() << "iprot readSetEnd." << endl << + indent() << val << "] value"; + indent_down(); + + return out.str(); +} + +string t_st_generator::struct_writer(t_struct *tstruct, string sname) { + std::ostringstream out; + const vector& fields = tstruct->get_sorted_members(); + vector::const_iterator fld_iter; + + out << "[oprot writeStructBegin: " << + "(TStruct new name: '" + tstruct->get_name() +"')." << endl; + indent_up(); + + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + bool optional = (*fld_iter)->get_req() == t_field::T_OPTIONAL; + string fname = (*fld_iter)->get_name(); + string accessor = sname + " " + sanitize(fname); + + if (optional) { + out << indent() << accessor << " ifNotNil: [" << endl; + indent_up(); + } + + out << indent() << "oprot writeFieldBegin: (TField new name: '" << fname << + "'; type: " << type_to_enum((*fld_iter)->get_type()) << + "; id: " << (*fld_iter)->get_key() << ")." << endl; + + out << indent() << write_val((*fld_iter)->get_type(), accessor) << "." << endl << + indent() << "oprot writeFieldEnd"; + + if (optional) { + out << "]"; + indent_down(); + } + + out << "." << endl; + } + + out << indent() << "oprot writeFieldStop; writeStructEnd] value"; + indent_down(); + + return out.str(); +} + +string t_st_generator::struct_reader(t_struct *tstruct, string clsName = "") { + std::ostringstream out; + const vector& fields = tstruct->get_members(); + vector::const_iterator fld_iter; + string val = temp_name(); + string desc = temp_name(); + string found = temp_name(); + + if (clsName.size() == 0) { + clsName = tstruct->get_name(); + } + + out << "[|" << desc << " " << val << "|" << endl; + indent_up(); + + //This is nasty, but without it we'll break things by prefixing TResult. + string name = ((capitalize(clsName) == "TResult") ? capitalize(clsName) : prefix(clsName)); + out << indent() << val << " := " << name << " new." << endl; + + out << indent() << "iprot readStructBegin." << endl << + indent() << "[" << desc << " := iprot readFieldBegin." << endl << + indent() << desc << " type = TType stop] whileFalse: [|" << found << "|" << endl; + indent_up(); + + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + out << indent() << desc << " id = " << (*fld_iter)->get_key() << + " ifTrue: [" << endl; + indent_up(); + + out << indent() << found << " := true." << endl << + indent() << val << " " << sanitize((*fld_iter)->get_name()) << ": " << + read_val((*fld_iter)->get_type()); + indent_down(); + + out << "]." << endl; + } + + out << indent() << found << " ifNil: [iprot skip: " << desc << " type]]." << endl; + indent_down(); + + out << indent() << "oprot readStructEnd." << endl << + indent() << val << "] value"; + indent_down(); + + return out.str(); +} + +string t_st_generator::write_val(t_type *t, string fname) { + t = get_true_type(t); + + if (t->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*) t)->get_base(); + switch(tbase) { + case t_base_type::TYPE_DOUBLE: + return "iprot writeDouble: " + fname + " asFloat"; + break; + case t_base_type::TYPE_BYTE: + case t_base_type::TYPE_I16: + case t_base_type::TYPE_I32: + case t_base_type::TYPE_I64: + return "iprot write" + capitalize(type_name(t)) + ": " + fname + " asInteger"; + default: + return "iprot write" + capitalize(type_name(t)) + ": " + fname; + } + } else if (t->is_map()) { + return map_writer((t_map*) t, fname); + } else if (t->is_struct() || t->is_xception()) { + return struct_writer((t_struct*) t, fname); + } else if (t->is_list()) { + return list_writer((t_list*) t, fname); + } else if (t->is_set()) { + return set_writer((t_set*) t, fname); + } else if (t->is_enum()) { + return "iprot writeI32: " + fname; + } else { + throw "Sorry, I don't know how to write this: " + type_name(t); + } +} + +string t_st_generator::read_val(t_type *t) { + t = get_true_type(t); + + if (t->is_base_type()) { + return "iprot read" + capitalize(type_name(t)); + } else if (t->is_map()) { + return map_reader((t_map*) t); + } else if (t->is_struct() || t->is_xception()) { + return struct_reader((t_struct*) t); + } else if (t->is_list()) { + return list_reader((t_list*) t); + } else if (t->is_set()) { + return set_reader((t_set*) t); + } else if (t->is_enum()) { + return "iprot readI32"; + } else { + throw "Sorry, I don't know how to read this: " + type_name(t); + } +} + +void t_st_generator::generate_send_method(t_function* function) { + string funname = function->get_name(); + string signature = function_signature(function); + t_struct* arg_struct = function->get_arglist(); + const vector& fields = arg_struct->get_members(); + vector::const_iterator fld_iter; + + st_method(f_, client_class_name(), "send" + capitalize(signature)); + f_ << "oprot writeMessageBegin:" << endl; + indent_up(); + + f_ << indent() << "(TCallMessage new" << endl; + indent_up(); + + f_ << indent() << "name: '" << funname << "'; " << endl << + indent() << "seqid: self nextSeqid)." << endl; + indent_down(); + indent_down(); + + f_ << indent() << "oprot writeStructBegin: " << + "(TStruct new name: '" + capitalize(function->get_name()) + "_args')." << endl; + + for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) { + string fname = (*fld_iter)->get_name(); + + f_ << indent() << "oprot writeFieldBegin: (TField new name: '" << fname << + "'; type: " << type_to_enum((*fld_iter)->get_type()) << + "; id: " << (*fld_iter)->get_key() << ")." << endl; + + f_ << indent() << write_val((*fld_iter)->get_type(), fname) << "." << endl << + indent() << "oprot writeFieldEnd." << endl; + } + + f_ << indent() << "oprot writeFieldStop; writeStructEnd; writeMessageEnd." << endl; + f_ << indent() << "oprot transport flush"; + + st_close_method(f_); +} + +// We only support receiving TResult structures (so this won't work on the server side) +void t_st_generator::generate_recv_method(t_function* function) { + string funname = function->get_name(); + string signature = function_signature(function); + + t_struct result(program_, "TResult"); + t_field success(function->get_returntype(), "success", 0); + result.append(&success); + + t_struct* xs = function->get_xceptions(); + const vector& fields = xs->get_members(); + vector::const_iterator f_iter; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + // duplicate the field, but call it "exception"... we don't need a dynamic name + t_field *exception = new t_field((*f_iter)->get_type(), "exception", (*f_iter)->get_key()); + result.append(exception); + } + + st_method(f_, client_class_name(), "recv" + capitalize(funname)); + f_ << "| f msg res | " << endl << + indent() << "msg := oprot readMessageBegin." << endl << + indent() << "self validateRemoteMessage: msg." << endl << + indent() << "res := " << struct_reader(&result) << "." << endl << + indent() << "oprot readMessageEnd." << endl << + indent() << "oprot transport flush." << endl << + indent() << "res exception ifNotNil: [res exception signal]." << endl << + indent() << "^ res"; + st_close_method(f_); +} + +string t_st_generator::function_types_comment(t_function* fn) { + std::ostringstream out; + const vector& fields = fn->get_arglist()->get_members(); + vector::const_iterator f_iter; + + out << "\""; + + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + out << (*f_iter)->get_name() << ": " << type_name((*f_iter)->get_type()); + if ((f_iter + 1) != fields.end()) { + out << ", "; + } + } + + out << "\""; + + return out.str(); +} + +/** + * Generates a service client definition. + * + * @param tservice The service to generate a server for. + */ +void t_st_generator::generate_service_client(t_service* tservice) { + string extends = ""; + string extends_client = "TClient"; + vector functions = tservice->get_functions(); + vector::iterator f_iter; + + if (tservice->get_extends() != NULL) { + extends = type_name(tservice->get_extends()); + extends_client = extends + "Client"; + } + + f_ << extends_client << " subclass: #" << prefix(client_class_name()) << endl << + "\tinstanceVariableNames: ''\n" << + "\tclassVariableNames: ''\n" << + "\tpoolDictionaries: ''\n" << + "\tcategory: '" << generated_category() << "'!\n\n"; + + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + string funname = (*f_iter)->get_name(); + string signature = function_signature(*f_iter); + + st_method(f_, client_class_name(), signature); + f_ << function_types_comment(*f_iter) << endl << + indent() << "self send" << capitalize(signature) << "." << endl; + + if (!(*f_iter)->is_oneway()) { + f_ << indent() << "^ self recv" << capitalize(funname) << " success " << endl; + } + + st_close_method(f_); + + generate_send_method(*f_iter); + if (!(*f_iter)->is_oneway()) { + generate_recv_method(*f_iter); + } + } +} + +string t_st_generator::sanitize(string s) { + std::ostringstream out; + bool underscore = false; + + for (unsigned int i = 0; i < s.size(); i++) { + if (s[i] == '_') { + underscore = true; + continue; + } + if (underscore) { + out << (char) toupper(s[i]); + underscore = false; + continue; + } + out << s[i]; + } + + return out.str(); +} + +/** + * Renders a function signature of the form 'type name(args)' + * + * @param tfunction Function definition + * @return String of rendered function definition + */ +string t_st_generator::function_signature(t_function* tfunction) { + return tfunction->get_name() + capitalize(argument_list(tfunction->get_arglist())); +} + +/** + * Renders a field list + */ +string t_st_generator::argument_list(t_struct* tstruct) { + string result = ""; + + const vector& fields = tstruct->get_members(); + vector::const_iterator f_iter; + bool first = true; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if (first) { + first = false; + } else { + result += " "; + } + result += (*f_iter)->get_name() + ": " + (*f_iter)->get_name(); + } + return result; +} + +string t_st_generator::type_name(t_type* ttype) { + string prefix = ""; + t_program* program = ttype->get_program(); + if (program != NULL && program != program_) { + if (!ttype->is_service()) { + prefix = program->get_name() + "_types."; + } + } + + string name = ttype->get_name(); + if (ttype->is_struct() || ttype->is_xception()) { + name = capitalize(ttype->get_name()); + } + + return prefix + name; +} + +/* Convert t_type to Smalltalk type code */ +string t_st_generator::type_to_enum(t_type* type) { + type = get_true_type(type); + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_VOID: + throw "NO T_VOID CONSTRUCT"; + case t_base_type::TYPE_STRING: + return "TType string"; + case t_base_type::TYPE_BOOL: + return "TType bool"; + case t_base_type::TYPE_BYTE: + return "TType byte"; + case t_base_type::TYPE_I16: + return "TType i16"; + case t_base_type::TYPE_I32: + return "TType i32"; + case t_base_type::TYPE_I64: + return "TType i64"; + case t_base_type::TYPE_DOUBLE: + return "TType double"; + } + } else if (type->is_enum()) { + return "TType i32"; + } else if (type->is_struct() || type->is_xception()) { + return "TType struct"; + } else if (type->is_map()) { + return "TType map"; + } else if (type->is_set()) { + return "TType set"; + } else if (type->is_list()) { + return "TType list"; + } + + throw "INVALID TYPE IN type_to_enum: " + type->get_name(); +} + + +THRIFT_REGISTER_GENERATOR(st, "Smalltalk", ""); diff --git a/compiler/cpp/src/generate/t_xsd_generator.cc b/compiler/cpp/src/generate/t_xsd_generator.cc new file mode 100644 index 00000000..729a91ae --- /dev/null +++ b/compiler/cpp/src/generate/t_xsd_generator.cc @@ -0,0 +1,354 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +#include +#include +#include +#include "t_generator.h" +#include "platform.h" +using namespace std; + + +/** + * XSD generator, creates an XSD for the base types etc. + * + */ +class t_xsd_generator : public t_generator { + public: + t_xsd_generator( + t_program* program, + const std::map& parsed_options, + const std::string& option_string) + : t_generator(program) + { + out_dir_base_ = "gen-xsd"; + } + + virtual ~t_xsd_generator() {} + + /** + * Init and close methods + */ + + void init_generator(); + void close_generator(); + + /** + * Program-level generation functions + */ + + void generate_typedef(t_typedef* ttypedef); + void generate_enum(t_enum* tenum) {} + + void generate_service(t_service* tservice); + void generate_struct(t_struct* tstruct); + + private: + + void generate_element(std::ostream& out, std::string name, t_type* ttype, t_struct* attrs=NULL, bool optional=false, bool nillable=false, bool list_element=false); + + std::string ns(std::string in, std::string ns) { + return ns + ":" + in; + } + + std::string xsd(std::string in) { + return ns(in, "xsd"); + } + + std::string type_name(t_type* ttype); + std::string base_type_name(t_base_type::t_base tbase); + + /** + * Output xsd/php file + */ + std::ofstream f_xsd_; + std::ofstream f_php_; + + /** + * Output string stream + */ + std::ostringstream s_xsd_types_; + +}; + + +void t_xsd_generator::init_generator() { + // Make output directory + MKDIR(get_out_dir().c_str()); + + // Make output file + string f_php_name = get_out_dir()+program_->get_name()+"_xsd.php"; + f_php_.open(f_php_name.c_str()); + + f_php_ << + "" << endl; + f_php_.close(); +} + +void t_xsd_generator::generate_typedef(t_typedef* ttypedef) { + indent(s_xsd_types_) << + "get_name() << "\">" << endl; + indent_up(); + if (ttypedef->get_type()->is_string() && ((t_base_type*)ttypedef->get_type())->is_string_enum()) { + indent(s_xsd_types_) << + "get_type()) << "\">" << endl; + indent_up(); + const vector& values = ((t_base_type*)ttypedef->get_type())->get_string_enum_vals(); + vector::const_iterator v_iter; + for (v_iter = values.begin(); v_iter != values.end(); ++v_iter) { + indent(s_xsd_types_) << + "" << endl; + } + indent_down(); + indent(s_xsd_types_) << + "" << endl; + } else { + indent(s_xsd_types_) << + "get_type()) << "\" />" << endl; + } + indent_down(); + indent(s_xsd_types_) << + "" << endl << endl; +} + +void t_xsd_generator::generate_struct(t_struct* tstruct) { + vector::const_iterator m_iter; + const vector& members = tstruct->get_members(); + bool xsd_all = tstruct->get_xsd_all(); + + indent(s_xsd_types_) << "get_name() << "\">" << endl; + indent_up(); + if (xsd_all) { + indent(s_xsd_types_) << "" << endl; + } else { + indent(s_xsd_types_) << "" << endl; + } + indent_up(); + + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + generate_element(s_xsd_types_, (*m_iter)->get_name(), (*m_iter)->get_type(), (*m_iter)->get_xsd_attrs(), (*m_iter)->get_xsd_optional() || xsd_all, (*m_iter)->get_xsd_nillable()); + } + + indent_down(); + if (xsd_all) { + indent(s_xsd_types_) << "" << endl; + } else { + indent(s_xsd_types_) << "" << endl; + } + indent_down(); + indent(s_xsd_types_) << + "" << endl << + endl; +} + +void t_xsd_generator::generate_element(ostream& out, + string name, + t_type* ttype, + t_struct* attrs, + bool optional, + bool nillable, + bool list_element) { + string sminOccurs = (optional || list_element) ? " minOccurs=\"0\"" : ""; + string smaxOccurs = list_element ? " maxOccurs=\"unbounded\"" : ""; + string soptional = sminOccurs + smaxOccurs; + string snillable = nillable ? " nillable=\"true\"" : ""; + + if (ttype->is_void() || ttype->is_list()) { + indent(out) << + "" << endl; + indent_up(); + if (attrs == NULL && ttype->is_void()) { + indent(out) << + "" << endl; + } else { + indent(out) << + "" << endl; + indent_up(); + if (ttype->is_list()) { + indent(out) << "" << endl; + indent_up(); + string subname; + t_type* subtype = ((t_list*)ttype)->get_elem_type(); + if (subtype->is_base_type() || subtype->is_container()) { + subname = name + "_elt"; + } else { + subname = type_name(subtype); + } + f_php_ << "$GLOBALS['" << program_->get_name() << "_xsd_elt_" << name << "'] = '" << subname << "';" << endl; + generate_element(out, subname, subtype, NULL, false, false, true); + indent_down(); + indent(out) << "" << endl; + indent(out) << "" << endl; + } + if (attrs != NULL) { + const vector& members = attrs->get_members(); + vector::const_iterator a_iter; + for (a_iter = members.begin(); a_iter != members.end(); ++a_iter) { + indent(out) << "get_name() << "\" type=\"" << type_name((*a_iter)->get_type()) << "\" />" << endl; + } + } + indent_down(); + indent(out) << + "" << endl; + } + indent_down(); + indent(out) << + "" << endl; + } else { + if (attrs == NULL) { + indent(out) << + "" << endl; + } else { + // Wow, all this work for a SIMPLE TYPE with attributes?!?!?! + indent(out) << "" << endl; + indent_up(); + indent(out) << "" << endl; + indent_up(); + indent(out) << "" << endl; + indent_up(); + indent(out) << "" << endl; + indent_up(); + const vector& members = attrs->get_members(); + vector::const_iterator a_iter; + for (a_iter = members.begin(); a_iter != members.end(); ++a_iter) { + indent(out) << "get_name() << "\" type=\"" << type_name((*a_iter)->get_type()) << "\" />" << endl; + } + indent_down(); + indent(out) << "" << endl; + indent_down(); + indent(out) << "" << endl; + indent_down(); + indent(out) << "" << endl; + indent_down(); + indent(out) << "" << endl; + } + } +} + +void t_xsd_generator::generate_service(t_service* tservice) { + // Make output file + string f_xsd_name = get_out_dir()+tservice->get_name()+".xsd"; + f_xsd_.open(f_xsd_name.c_str()); + + string ns = program_->get_namespace("xsd"); + if (ns.size() > 0) { + ns = " targetNamespace=\"" + ns + "\" xmlns=\"" + ns + "\" " + + "elementFormDefault=\"qualified\""; + } + + // Print the XSD header + f_xsd_ << + "" << endl << + "" << endl << + endl << + "" << endl << + endl; + + // Print out the type definitions + indent(f_xsd_) << s_xsd_types_.str(); + + // Keep a list of all the possible exceptions that might get thrown + map all_xceptions; + + // List the elements that you might actually get + vector functions = tservice->get_functions(); + vector::iterator f_iter; + for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) { + string elemname = (*f_iter)->get_name() + "_response"; + t_type* returntype = (*f_iter)->get_returntype(); + generate_element(f_xsd_, elemname, returntype); + f_xsd_ << endl; + + t_struct* xs = (*f_iter)->get_xceptions(); + const std::vector& xceptions = xs->get_members(); + vector::const_iterator x_iter; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + all_xceptions[(*x_iter)->get_name()] = (t_struct*)((*x_iter)->get_type()); + } + } + + map::iterator ax_iter; + for (ax_iter = all_xceptions.begin(); ax_iter != all_xceptions.end(); ++ax_iter) { + generate_element(f_xsd_, ax_iter->first, ax_iter->second); + } + + // Close the XSD document + f_xsd_ << endl << "" << endl; + f_xsd_.close(); +} + +string t_xsd_generator::type_name(t_type* ttype) { + if (ttype->is_typedef()) { + return ttype->get_name(); + } + + if (ttype->is_base_type()) { + return xsd(base_type_name(((t_base_type*)ttype)->get_base())); + } + + if (ttype->is_enum()) { + return xsd("int"); + } + + if (ttype->is_struct() || ttype->is_xception()) { + return ttype->get_name(); + } + + return "container"; +} + +/** + * Returns the XSD type that corresponds to the thrift type. + * + * @param tbase The base type + * @return Explicit XSD type, i.e. xsd:string + */ +string t_xsd_generator::base_type_name(t_base_type::t_base tbase) { + switch (tbase) { + case t_base_type::TYPE_VOID: + return "void"; + case t_base_type::TYPE_STRING: + return "string"; + case t_base_type::TYPE_BOOL: + return "boolean"; + case t_base_type::TYPE_BYTE: + return "byte"; + case t_base_type::TYPE_I16: + return "short"; + case t_base_type::TYPE_I32: + return "int"; + case t_base_type::TYPE_I64: + return "long"; + case t_base_type::TYPE_DOUBLE: + return "decimal"; + default: + throw "compiler error: no C++ base type name for base type " + t_base_type::t_base_name(tbase); + } +} + +THRIFT_REGISTER_GENERATOR(xsd, "XSD", ""); diff --git a/compiler/cpp/src/globals.h b/compiler/cpp/src/globals.h new file mode 100644 index 00000000..b2041436 --- /dev/null +++ b/compiler/cpp/src/globals.h @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_GLOBALS_H +#define T_GLOBALS_H + +#include +#include +#include +#include +#include + +/** + * This module contains all the global variables (slap on the wrist) that are + * shared throughout the program. The reason for this is to facilitate simple + * interaction between the parser and the rest of the program. Before calling + * yyparse(), the main.cc program will make necessary adjustments to these + * global variables such that the parser does the right thing and puts entries + * into the right containers, etc. + * + */ + +/** + * Hooray for forward declaration of types! + */ + +class t_program; +class t_scope; +class t_type; + +/** + * Parsing mode, two passes up in this gin rummy! + */ + +enum PARSE_MODE { + INCLUDES = 1, + PROGRAM = 2 +}; + +/** + * Strictness level + */ +extern int g_strict; + +/** + * The master program parse tree. This is accessed from within the parser code + * to build up the program elements. + */ +extern t_program* g_program; + +/** + * Global types for the parser to be able to reference + */ + +extern t_type* g_type_void; +extern t_type* g_type_string; +extern t_type* g_type_binary; +extern t_type* g_type_slist; +extern t_type* g_type_bool; +extern t_type* g_type_byte; +extern t_type* g_type_i16; +extern t_type* g_type_i32; +extern t_type* g_type_i64; +extern t_type* g_type_double; + +/** + * The scope that we are currently parsing into + */ +extern t_scope* g_scope; + +/** + * The parent scope to also load symbols into + */ +extern t_scope* g_parent_scope; + +/** + * The prefix for the parent scope entries + */ +extern std::string g_parent_prefix; + +/** + * The parsing pass that we are on. We do different things on each pass. + */ +extern PARSE_MODE g_parse_mode; + +/** + * Global time string, used in formatting error messages etc. + */ +extern char* g_time_str; + +/** + * The last parsed doctext comment. + */ +extern char* g_doctext; + +/** + * The location of the last parsed doctext comment. + */ +extern int g_doctext_lineno; + +#endif diff --git a/compiler/cpp/src/main.cc b/compiler/cpp/src/main.cc new file mode 100644 index 00000000..7a5d2d49 --- /dev/null +++ b/compiler/cpp/src/main.cc @@ -0,0 +1,1207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * thrift - a lightweight cross-language rpc/serialization tool + * + * This file contains the main compiler engine for Thrift, which invokes the + * scanner/parser to build the thrift object tree. The interface generation + * code for each language lives in a file by the language name under the + * generate/ folder, and all parse structures live in parse/ + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef MINGW +# include /* for GetFullPathName */ +#endif + +// Careful: must include globals first for extern definitions +#include "globals.h" + +#include "main.h" +#include "parse/t_program.h" +#include "parse/t_scope.h" +#include "generate/t_generator.h" + +#include "version.h" + +using namespace std; + +/** + * Global program tree + */ +t_program* g_program; + +/** + * Global types + */ + +t_type* g_type_void; +t_type* g_type_string; +t_type* g_type_binary; +t_type* g_type_slist; +t_type* g_type_bool; +t_type* g_type_byte; +t_type* g_type_i16; +t_type* g_type_i32; +t_type* g_type_i64; +t_type* g_type_double; + +/** + * Global scope + */ +t_scope* g_scope; + +/** + * Parent scope to also parse types + */ +t_scope* g_parent_scope; + +/** + * Prefix for putting types in parent scope + */ +string g_parent_prefix; + +/** + * Parsing pass + */ +PARSE_MODE g_parse_mode; + +/** + * Current directory of file being parsed + */ +string g_curdir; + +/** + * Current file being parsed + */ +string g_curpath; + +/** + * Search path for inclusions + */ +vector g_incl_searchpath; + +/** + * Should C++ include statements use path prefixes for other thrift-generated + * header files + */ +bool g_cpp_use_include_prefix = false; + +/** + * Global debug state + */ +int g_debug = 0; + +/** + * Strictness level + */ +int g_strict = 127; + +/** + * Warning level + */ +int g_warn = 1; + +/** + * Verbose output + */ +int g_verbose = 0; + +/** + * Global time string + */ +char* g_time_str; + +/** + * The last parsed doctext comment. + */ +char* g_doctext; + +/** + * The location of the last parsed doctext comment. + */ +int g_doctext_lineno; + +/** + * Flags to control code generation + */ +bool gen_cpp = false; +bool gen_dense = false; +bool gen_java = false; +bool gen_javabean = false; +bool gen_rb = false; +bool gen_py = false; +bool gen_py_newstyle = false; +bool gen_xsd = false; +bool gen_php = false; +bool gen_phpi = false; +bool gen_phps = true; +bool gen_phpa = false; +bool gen_phpo = false; +bool gen_rest = false; +bool gen_perl = false; +bool gen_erl = false; +bool gen_ocaml = false; +bool gen_hs = false; +bool gen_cocoa = false; +bool gen_csharp = false; +bool gen_st = false; +bool gen_recurse = false; + +/** + * MinGW doesn't have realpath, so use fallback implementation in that case, + * otherwise this just calls through to realpath + */ +char *saferealpath(const char *path, char *resolved_path) { +#ifdef MINGW + char buf[MAX_PATH]; + char* basename; + DWORD len = GetFullPathName(path, MAX_PATH, buf, &basename); + if (len == 0 || len > MAX_PATH - 1){ + strcpy(resolved_path, path); + } else { + CharLowerBuff(buf, len); + strcpy(resolved_path, buf); + } + return resolved_path; +#else + return realpath(path, resolved_path); +#endif +} + + +/** + * Report an error to the user. This is called yyerror for historical + * reasons (lex and yacc expect the error reporting routine to be called + * this). Call this function to report any errors to the user. + * yyerror takes printf style arguments. + * + * @param fmt C format string followed by additional arguments + */ +void yyerror(const char* fmt, ...) { + va_list args; + fprintf(stderr, + "[ERROR:%s:%d] (last token was '%s')\n", + g_curpath.c_str(), + yylineno, + yytext); + + va_start(args, fmt); + vfprintf(stderr, fmt, args); + va_end(args); + + fprintf(stderr, "\n"); +} + +/** + * Prints a debug message from the parser. + * + * @param fmt C format string followed by additional arguments + */ +void pdebug(const char* fmt, ...) { + if (g_debug == 0) { + return; + } + va_list args; + printf("[PARSE:%d] ", yylineno); + va_start(args, fmt); + vprintf(fmt, args); + va_end(args); + printf("\n"); +} + +/** + * Prints a verbose output mode message + * + * @param fmt C format string followed by additional arguments + */ +void pverbose(const char* fmt, ...) { + if (g_verbose == 0) { + return; + } + va_list args; + va_start(args, fmt); + vprintf(fmt, args); + va_end(args); +} + +/** + * Prints a warning message + * + * @param fmt C format string followed by additional arguments + */ +void pwarning(int level, const char* fmt, ...) { + if (g_warn < level) { + return; + } + va_list args; + printf("[WARNING:%s:%d] ", g_curpath.c_str(), yylineno); + va_start(args, fmt); + vprintf(fmt, args); + va_end(args); + printf("\n"); +} + +/** + * Prints a failure message and exits + * + * @param fmt C format string followed by additional arguments + */ +void failure(const char* fmt, ...) { + va_list args; + fprintf(stderr, "[FAILURE:%s:%d] ", g_curpath.c_str(), yylineno); + va_start(args, fmt); + vfprintf(stderr, fmt, args); + va_end(args); + printf("\n"); + exit(1); +} + +/** + * Converts a string filename into a thrift program name + */ +string program_name(string filename) { + string::size_type slash = filename.rfind("/"); + if (slash != string::npos) { + filename = filename.substr(slash+1); + } + string::size_type dot = filename.rfind("."); + if (dot != string::npos) { + filename = filename.substr(0, dot); + } + return filename; +} + +/** + * Gets the directory path of a filename + */ +string directory_name(string filename) { + string::size_type slash = filename.rfind("/"); + // No slash, just use the current directory + if (slash == string::npos) { + return "."; + } + return filename.substr(0, slash); +} + +/** + * Finds the appropriate file path for the given filename + */ +string include_file(string filename) { + // Absolute path? Just try that + if (filename[0] == '/') { + // Realpath! + char rp[PATH_MAX]; + if (saferealpath(filename.c_str(), rp) == NULL) { + pwarning(0, "Cannot open include file %s\n", filename.c_str()); + return std::string(); + } + + // Stat this file + struct stat finfo; + if (stat(rp, &finfo) == 0) { + return rp; + } + } else { // relative path, start searching + // new search path with current dir global + vector sp = g_incl_searchpath; + sp.insert(sp.begin(), g_curdir); + + // iterate through paths + vector::iterator it; + for (it = sp.begin(); it != sp.end(); it++) { + string sfilename = *(it) + "/" + filename; + + // Realpath! + char rp[PATH_MAX]; + if (saferealpath(sfilename.c_str(), rp) == NULL) { + continue; + } + + // Stat this files + struct stat finfo; + if (stat(rp, &finfo) == 0) { + return rp; + } + } + } + + // Uh oh + pwarning(0, "Could not find include file %s\n", filename.c_str()); + return std::string(); +} + +/** + * Clears any previously stored doctext string. + * Also prints a warning if we are discarding information. + */ +void clear_doctext() { + if (g_doctext != NULL) { + pwarning(2, "Uncaptured doctext at on line %d.", g_doctext_lineno); + } + free(g_doctext); + g_doctext = NULL; +} + +/** + * Cleans up text commonly found in doxygen-like comments + * + * Warning: if you mix tabs and spaces in a non-uniform way, + * you will get what you deserve. + */ +char* clean_up_doctext(char* doctext) { + // Convert to C++ string, and remove Windows's carriage returns. + string docstring = doctext; + docstring.erase( + remove(docstring.begin(), docstring.end(), '\r'), + docstring.end()); + + // Separate into lines. + vector lines; + string::size_type pos = string::npos; + string::size_type last; + while (true) { + last = (pos == string::npos) ? 0 : pos+1; + pos = docstring.find('\n', last); + if (pos == string::npos) { + // First bit of cleaning. If the last line is only whitespace, drop it. + string::size_type nonwhite = docstring.find_first_not_of(" \t", last); + if (nonwhite != string::npos) { + lines.push_back(docstring.substr(last)); + } + break; + } + lines.push_back(docstring.substr(last, pos-last)); + } + + // A very profound docstring. + if (lines.empty()) { + return NULL; + } + + // Clear leading whitespace from the first line. + pos = lines.front().find_first_not_of(" \t"); + lines.front().erase(0, pos); + + // If every nonblank line after the first has the same number of spaces/tabs, + // then a star, remove them. + bool have_prefix = true; + bool found_prefix = false; + string::size_type prefix_len = 0; + vector::iterator l_iter; + for (l_iter = lines.begin()+1; l_iter != lines.end(); ++l_iter) { + if (l_iter->empty()) { + continue; + } + + pos = l_iter->find_first_not_of(" \t"); + if (!found_prefix) { + if (pos != string::npos) { + if (l_iter->at(pos) == '*') { + found_prefix = true; + prefix_len = pos; + } else { + have_prefix = false; + break; + } + } else { + // Whitespace-only line. Truncate it. + l_iter->clear(); + } + } else if (l_iter->size() > pos + && l_iter->at(pos) == '*' + && pos == prefix_len) { + // Business as usual. + } else if (pos == string::npos) { + // Whitespace-only line. Let's truncate it for them. + l_iter->clear(); + } else { + // The pattern has been broken. + have_prefix = false; + break; + } + } + + // If our prefix survived, delete it from every line. + if (have_prefix) { + // Get the star too. + prefix_len++; + for (l_iter = lines.begin()+1; l_iter != lines.end(); ++l_iter) { + l_iter->erase(0, prefix_len); + } + } + + // Now delete the minimum amount of leading whitespace from each line. + prefix_len = string::npos; + for (l_iter = lines.begin()+1; l_iter != lines.end(); ++l_iter) { + if (l_iter->empty()) { + continue; + } + pos = l_iter->find_first_not_of(" \t"); + if (pos != string::npos + && (prefix_len == string::npos || pos < prefix_len)) { + prefix_len = pos; + } + } + + // If our prefix survived, delete it from every line. + if (prefix_len != string::npos) { + for (l_iter = lines.begin()+1; l_iter != lines.end(); ++l_iter) { + l_iter->erase(0, prefix_len); + } + } + + // Remove trailing whitespace from every line. + for (l_iter = lines.begin(); l_iter != lines.end(); ++l_iter) { + pos = l_iter->find_last_not_of(" \t"); + if (pos != string::npos && pos != l_iter->length()-1) { + l_iter->erase(pos+1); + } + } + + // If the first line is empty, remove it. + // Don't do this earlier because a lot of steps skip the first line. + if (lines.front().empty()) { + lines.erase(lines.begin()); + } + + // Now rejoin the lines and copy them back into doctext. + docstring.clear(); + for (l_iter = lines.begin(); l_iter != lines.end(); ++l_iter) { + docstring += *l_iter; + docstring += '\n'; + } + + assert(docstring.length() <= strlen(doctext)); + strcpy(doctext, docstring.c_str()); + return doctext; +} + +/** Set to true to debug docstring parsing */ +static bool dump_docs = false; + +/** + * Dumps docstrings to stdout + * Only works for top-level definitions and the whole program doc + * (i.e., not enum constants, struct fields, or functions. + */ +void dump_docstrings(t_program* program) { + string progdoc = program->get_doc(); + if (!progdoc.empty()) { + printf("Whole program doc:\n%s\n", progdoc.c_str()); + } + const vector& typedefs = program->get_typedefs(); + vector::const_iterator t_iter; + for (t_iter = typedefs.begin(); t_iter != typedefs.end(); ++t_iter) { + t_typedef* td = *t_iter; + if (td->has_doc()) { + printf("typedef %s:\n%s\n", td->get_name().c_str(), td->get_doc().c_str()); + } + } + const vector& enums = program->get_enums(); + vector::const_iterator e_iter; + for (e_iter = enums.begin(); e_iter != enums.end(); ++e_iter) { + t_enum* en = *e_iter; + if (en->has_doc()) { + printf("enum %s:\n%s\n", en->get_name().c_str(), en->get_doc().c_str()); + } + } + const vector& consts = program->get_consts(); + vector::const_iterator c_iter; + for (c_iter = consts.begin(); c_iter != consts.end(); ++c_iter) { + t_const* co = *c_iter; + if (co->has_doc()) { + printf("const %s:\n%s\n", co->get_name().c_str(), co->get_doc().c_str()); + } + } + const vector& structs = program->get_structs(); + vector::const_iterator s_iter; + for (s_iter = structs.begin(); s_iter != structs.end(); ++s_iter) { + t_struct* st = *s_iter; + if (st->has_doc()) { + printf("struct %s:\n%s\n", st->get_name().c_str(), st->get_doc().c_str()); + } + } + const vector& xceptions = program->get_xceptions(); + vector::const_iterator x_iter; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + t_struct* xn = *x_iter; + if (xn->has_doc()) { + printf("xception %s:\n%s\n", xn->get_name().c_str(), xn->get_doc().c_str()); + } + } + const vector& services = program->get_services(); + vector::const_iterator v_iter; + for (v_iter = services.begin(); v_iter != services.end(); ++v_iter) { + t_service* sv = *v_iter; + if (sv->has_doc()) { + printf("service %s:\n%s\n", sv->get_name().c_str(), sv->get_doc().c_str()); + } + } +} + +/** + * Call generate_fingerprint for every structure and enum. + */ +void generate_all_fingerprints(t_program* program) { + const vector& structs = program->get_structs(); + vector::const_iterator s_iter; + for (s_iter = structs.begin(); s_iter != structs.end(); ++s_iter) { + t_struct* st = *s_iter; + st->generate_fingerprint(); + } + + const vector& xceptions = program->get_xceptions(); + vector::const_iterator x_iter; + for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) { + t_struct* st = *x_iter; + st->generate_fingerprint(); + } + + const vector& enums = program->get_enums(); + vector::const_iterator e_iter; + for (e_iter = enums.begin(); e_iter != enums.end(); ++e_iter) { + t_enum* e = *e_iter; + e->generate_fingerprint(); + } + + g_type_void->generate_fingerprint(); + + // If you want to generate fingerprints for implicit structures, start here. + /* + const vector& services = program->get_services(); + vector::const_iterator v_iter; + for (v_iter = services.begin(); v_iter != services.end(); ++v_iter) { + t_service* sv = *v_iter; + } + */ +} + +/** + * Prints the version number + */ +void version() { + printf("Thrift version %s-%s\n", THRIFT_VERSION, THRIFT_REVISION); +} + +/** + * Diplays the usage message and then exits with an error code. + */ +void usage() { + fprintf(stderr, "Usage: thrift [options] file\n"); + fprintf(stderr, "Options:\n"); + fprintf(stderr, " -version Print the compiler version\n"); + fprintf(stderr, " -o dir Set the output directory for gen-* packages\n"); + fprintf(stderr, " (default: current directory)\n"); + fprintf(stderr, " -I dir Add a directory to the list of directories\n"); + fprintf(stderr, " searched for include directives\n"); + fprintf(stderr, " -nowarn Suppress all compiler warnings (BAD!)\n"); + fprintf(stderr, " -strict Strict compiler warnings on\n"); + fprintf(stderr, " -v[erbose] Verbose mode\n"); + fprintf(stderr, " -r[ecurse] Also generate included files\n"); + fprintf(stderr, " -debug Parse debug trace to stdout\n"); + fprintf(stderr, " --gen STR Generate code with a dynamically-registered generator.\n"); + fprintf(stderr, " STR has the form language[:key1=val1[,key2,[key3=val3]]].\n"); + fprintf(stderr, " Keys and values are options passed to the generator.\n"); + fprintf(stderr, " Many options will not require values.\n"); + fprintf(stderr, "\n"); + fprintf(stderr, "Available generators (and options):\n"); + + t_generator_registry::gen_map_t gen_map = t_generator_registry::get_generator_map(); + t_generator_registry::gen_map_t::iterator iter; + for (iter = gen_map.begin(); iter != gen_map.end(); ++iter) { + fprintf(stderr, " %s (%s):\n", + iter->second->get_short_name().c_str(), + iter->second->get_long_name().c_str()); + fprintf(stderr, "%s", iter->second->get_documentation().c_str()); + } + exit(1); +} + +/** + * You know, when I started working on Thrift I really thought it wasn't going + * to become a programming language because it was just a generator and it + * wouldn't need runtime type information and all that jazz. But then we + * decided to add constants, and all of a sudden that means runtime type + * validation and inference, except the "runtime" is the code generator + * runtime. Shit. I've been had. + */ +void validate_const_rec(std::string name, t_type* type, t_const_value* value) { + if (type->is_void()) { + throw "type error: cannot declare a void const: " + name; + } + + if (type->is_base_type()) { + t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); + switch (tbase) { + case t_base_type::TYPE_STRING: + if (value->get_type() != t_const_value::CV_STRING) { + throw "type error: const \"" + name + "\" was declared as string"; + } + break; + case t_base_type::TYPE_BOOL: + if (value->get_type() != t_const_value::CV_INTEGER) { + throw "type error: const \"" + name + "\" was declared as bool"; + } + break; + case t_base_type::TYPE_BYTE: + if (value->get_type() != t_const_value::CV_INTEGER) { + throw "type error: const \"" + name + "\" was declared as byte"; + } + break; + case t_base_type::TYPE_I16: + if (value->get_type() != t_const_value::CV_INTEGER) { + throw "type error: const \"" + name + "\" was declared as i16"; + } + break; + case t_base_type::TYPE_I32: + if (value->get_type() != t_const_value::CV_INTEGER) { + throw "type error: const \"" + name + "\" was declared as i32"; + } + break; + case t_base_type::TYPE_I64: + if (value->get_type() != t_const_value::CV_INTEGER) { + throw "type error: const \"" + name + "\" was declared as i64"; + } + break; + case t_base_type::TYPE_DOUBLE: + if (value->get_type() != t_const_value::CV_INTEGER && + value->get_type() != t_const_value::CV_DOUBLE) { + throw "type error: const \"" + name + "\" was declared as double"; + } + break; + default: + throw "compiler error: no const of base type " + t_base_type::t_base_name(tbase) + name; + } + } else if (type->is_enum()) { + if (value->get_type() != t_const_value::CV_INTEGER) { + throw "type error: const \"" + name + "\" was declared as enum"; + } + } else if (type->is_struct() || type->is_xception()) { + if (value->get_type() != t_const_value::CV_MAP) { + throw "type error: const \"" + name + "\" was declared as struct/xception"; + } + const vector& fields = ((t_struct*)type)->get_members(); + vector::const_iterator f_iter; + + const map& val = value->get_map(); + map::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + if (v_iter->first->get_type() != t_const_value::CV_STRING) { + throw "type error: " + name + " struct key must be string"; + } + t_type* field_type = NULL; + for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { + if ((*f_iter)->get_name() == v_iter->first->get_string()) { + field_type = (*f_iter)->get_type(); + } + } + if (field_type == NULL) { + throw "type error: " + type->get_name() + " has no field " + v_iter->first->get_string(); + } + + validate_const_rec(name + "." + v_iter->first->get_string(), field_type, v_iter->second); + } + } else if (type->is_map()) { + t_type* k_type = ((t_map*)type)->get_key_type(); + t_type* v_type = ((t_map*)type)->get_val_type(); + const map& val = value->get_map(); + map::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + validate_const_rec(name + "", k_type, v_iter->first); + validate_const_rec(name + "", v_type, v_iter->second); + } + } else if (type->is_list() || type->is_set()) { + t_type* e_type; + if (type->is_list()) { + e_type = ((t_list*)type)->get_elem_type(); + } else { + e_type = ((t_set*)type)->get_elem_type(); + } + const vector& val = value->get_list(); + vector::const_iterator v_iter; + for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { + validate_const_rec(name + "", e_type, *v_iter); + } + } +} + +/** + * Check the type of the parsed const information against its declared type + */ +void validate_const_type(t_const* c) { + validate_const_rec(c->get_name(), c->get_type(), c->get_value()); +} + +/** + * Check the type of a default value assigned to a field. + */ +void validate_field_value(t_field* field, t_const_value* cv) { + validate_const_rec(field->get_name(), field->get_type(), cv); +} + +/** + * Check that all the elements of a throws block are actually exceptions. + */ +bool validate_throws(t_struct* throws) { + const vector& members = throws->get_members(); + vector::const_iterator m_iter; + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + if (!(*m_iter)->get_type()->is_xception()) { + return false; + } + } + return true; +} + +/** + * Parses a program + */ +void parse(t_program* program, t_program* parent_program) { + // Get scope file path + string path = program->get_path(); + + // Set current dir global, which is used in the include_file function + g_curdir = directory_name(path); + g_curpath = path; + + // Open the file + yyin = fopen(path.c_str(), "r"); + if (yyin == 0) { + failure("Could not open input file: \"%s\"", path.c_str()); + } + + // Create new scope and scan for includes + pverbose("Scanning %s for includes\n", path.c_str()); + g_parse_mode = INCLUDES; + g_program = program; + g_scope = program->scope(); + try { + yylineno = 1; + if (yyparse() != 0) { + failure("Parser error during include pass."); + } + } catch (string x) { + failure(x.c_str()); + } + fclose(yyin); + + // Recursively parse all the include programs + vector& includes = program->get_includes(); + vector::iterator iter; + for (iter = includes.begin(); iter != includes.end(); ++iter) { + parse(*iter, program); + } + + // Parse the program file + g_parse_mode = PROGRAM; + g_program = program; + g_scope = program->scope(); + g_parent_scope = (parent_program != NULL) ? parent_program->scope() : NULL; + g_parent_prefix = program->get_name() + "."; + g_curpath = path; + yyin = fopen(path.c_str(), "r"); + if (yyin == 0) { + failure("Could not open input file: \"%s\"", path.c_str()); + } + pverbose("Parsing %s for types\n", path.c_str()); + yylineno = 1; + try { + if (yyparse() != 0) { + failure("Parser error during types pass."); + } + } catch (string x) { + failure(x.c_str()); + } + fclose(yyin); +} + +/** + * Generate code + */ +void generate(t_program* program, const vector& generator_strings) { + // Oooohh, recursive code generation, hot!! + if (gen_recurse) { + const vector& includes = program->get_includes(); + for (size_t i = 0; i < includes.size(); ++i) { + // Propogate output path from parent to child programs + includes[i]->set_out_path(program->get_out_path()); + + generate(includes[i], generator_strings); + } + } + + // Generate code! + try { + pverbose("Program: %s\n", program->get_path().c_str()); + + // Compute fingerprints. + generate_all_fingerprints(program); + + if (dump_docs) { + dump_docstrings(program); + } + + vector::const_iterator iter; + for (iter = generator_strings.begin(); iter != generator_strings.end(); ++iter) { + t_generator* generator = t_generator_registry::get_generator(program, *iter); + + if (generator == NULL) { + pwarning(1, "Unable to get a generator for \"%s\".\n", iter->c_str()); + } else { + pverbose("Generating \"%s\"\n", iter->c_str()); + generator->generate_program(); + delete generator; + } + } + + } catch (string s) { + printf("Error: %s\n", s.c_str()); + } catch (const char* exc) { + printf("Error: %s\n", exc); + } + +} + +/** + * Parse it up.. then spit it back out, in pretty much every language. Alright + * not that many languages, but the cool ones that we care about. + */ +int main(int argc, char** argv) { + int i; + std::string out_path; + + // Setup time string + time_t now = time(NULL); + g_time_str = ctime(&now); + + // Check for necessary arguments, you gotta have at least a filename and + // an output language flag + if (argc < 2) { + usage(); + } + + vector generator_strings; + + // Set the current path to a dummy value to make warning messages clearer. + g_curpath = "arguments"; + + // Hacky parameter handling... I didn't feel like using a library sorry! + for (i = 1; i < argc-1; i++) { + char* arg; + + arg = strtok(argv[i], " "); + while (arg != NULL) { + // Treat double dashes as single dashes + if (arg[0] == '-' && arg[1] == '-') { + ++arg; + } + + if (strcmp(arg, "-version") == 0) { + version(); + exit(1); + } else if (strcmp(arg, "-debug") == 0) { + g_debug = 1; + } else if (strcmp(arg, "-nowarn") == 0) { + g_warn = 0; + } else if (strcmp(arg, "-strict") == 0) { + g_strict = 255; + g_warn = 2; + } else if (strcmp(arg, "-v") == 0 || strcmp(arg, "-verbose") == 0 ) { + g_verbose = 1; + } else if (strcmp(arg, "-r") == 0 || strcmp(arg, "-recurse") == 0 ) { + gen_recurse = true; + } else if (strcmp(arg, "-gen") == 0) { + arg = argv[++i]; + if (arg == NULL) { + fprintf(stderr, "!!! Missing generator specification\n"); + usage(); + } + generator_strings.push_back(arg); + } else if (strcmp(arg, "-dense") == 0) { + gen_dense = true; + } else if (strcmp(arg, "-cpp") == 0) { + gen_cpp = true; + } else if (strcmp(arg, "-javabean") == 0) { + gen_javabean = true; + } else if (strcmp(arg, "-java") == 0) { + gen_java = true; + } else if (strcmp(arg, "-php") == 0) { + gen_php = true; + } else if (strcmp(arg, "-phpi") == 0) { + gen_phpi = true; + } else if (strcmp(arg, "-phps") == 0) { + gen_php = true; + gen_phps = true; + } else if (strcmp(arg, "-phpl") == 0) { + gen_php = true; + gen_phps = false; + } else if (strcmp(arg, "-phpa") == 0) { + gen_php = true; + gen_phps = false; + gen_phpa = true; + } else if (strcmp(arg, "-phpo") == 0) { + gen_php = true; + gen_phpo = true; + } else if (strcmp(arg, "-rest") == 0) { + gen_rest = true; + } else if (strcmp(arg, "-py") == 0) { + gen_py = true; + } else if (strcmp(arg, "-pyns") == 0) { + gen_py = true; + gen_py_newstyle = true; + } else if (strcmp(arg, "-rb") == 0) { + gen_rb = true; + } else if (strcmp(arg, "-xsd") == 0) { + gen_xsd = true; + } else if (strcmp(arg, "-perl") == 0) { + gen_perl = true; + } else if (strcmp(arg, "-erl") == 0) { + gen_erl = true; + } else if (strcmp(arg, "-ocaml") == 0) { + gen_ocaml = true; + } else if (strcmp(arg, "-hs") == 0) { + gen_hs = true; + } else if (strcmp(arg, "-cocoa") == 0) { + gen_cocoa = true; + } else if (strcmp(arg, "-st") == 0) { + gen_st = true; + } else if (strcmp(arg, "-csharp") == 0) { + gen_csharp = true; + } else if (strcmp(arg, "-cpp_use_include_prefix") == 0) { + g_cpp_use_include_prefix = true; + } else if (strcmp(arg, "-I") == 0) { + // An argument of "-I\ asdf" is invalid and has unknown results + arg = argv[++i]; + + if (arg == NULL) { + fprintf(stderr, "!!! Missing Include directory\n"); + usage(); + } + g_incl_searchpath.push_back(arg); + } else if (strcmp(arg, "-o") == 0) { + arg = argv[++i]; + if (arg == NULL) { + fprintf(stderr, "-o: missing output directory\n"); + usage(); + } + out_path = arg; + +#ifdef MINGW + //strip out trailing \ on Windows + int last = out_path.length()-1; + if (out_path[last] == '\\') + { + out_path.erase(last); + } +#endif + + struct stat sb; + if (stat(out_path.c_str(), &sb) < 0) { + fprintf(stderr, "Output directory %s is unusable: %s\n", out_path.c_str(), strerror(errno)); + return -1; + } + if (! S_ISDIR(sb.st_mode)) { + fprintf(stderr, "Output directory %s exists but is not a directory\n", out_path.c_str()); + return -1; + } + } else { + fprintf(stderr, "!!! Unrecognized option: %s\n", arg); + usage(); + } + + // Tokenize more + arg = strtok(NULL, " "); + } + } + + // if you're asking for version, you have a right not to pass a file + if (strcmp(argv[argc-1], "-version") == 0) { + version(); + exit(1); + } + + // TODO(dreiss): Delete these when everyone is using the new hotness. + if (gen_cpp) { + pwarning(1, "-cpp is deprecated. Use --gen cpp"); + string gen_string = "cpp:"; + if (gen_dense) { + gen_string.append("dense,"); + } + if (g_cpp_use_include_prefix) { + gen_string.append("include_prefix,"); + } + generator_strings.push_back(gen_string); + } + if (gen_java) { + pwarning(1, "-java is deprecated. Use --gen java"); + generator_strings.push_back("java"); + } + if (gen_javabean) { + pwarning(1, "-javabean is deprecated. Use --gen java:beans"); + generator_strings.push_back("java:beans"); + } + if (gen_csharp) { + pwarning(1, "-csharp is deprecated. Use --gen csharp"); + generator_strings.push_back("csharp"); + } + if (gen_py) { + pwarning(1, "-py is deprecated. Use --gen py"); + generator_strings.push_back("py"); + } + if (gen_rb) { + pwarning(1, "-rb is deprecated. Use --gen rb"); + generator_strings.push_back("rb"); + } + if (gen_perl) { + pwarning(1, "-perl is deprecated. Use --gen perl"); + generator_strings.push_back("perl"); + } + if (gen_php || gen_phpi) { + pwarning(1, "-php is deprecated. Use --gen php"); + string gen_string = "php:"; + if (gen_phpi) { + gen_string.append("inlined,"); + } else if(gen_phps) { + gen_string.append("server,"); + } else if(gen_phpa) { + gen_string.append("autoload,"); + } else if(gen_phpo) { + gen_string.append("oop,"); + } else if(gen_rest) { + gen_string.append("rest,"); + } + generator_strings.push_back(gen_string); + } + if (gen_cocoa) { + pwarning(1, "-cocoa is deprecated. Use --gen cocoa"); + generator_strings.push_back("cocoa"); + } + if (gen_erl) { + pwarning(1, "-erl is deprecated. Use --gen erl"); + generator_strings.push_back("erl"); + } + if (gen_st) { + pwarning(1, "-st is deprecated. Use --gen st"); + generator_strings.push_back("st"); + } + if (gen_ocaml) { + pwarning(1, "-ocaml is deprecated. Use --gen ocaml"); + generator_strings.push_back("ocaml"); + } + if (gen_hs) { + pwarning(1, "-hs is deprecated. Use --gen hs"); + generator_strings.push_back("hs"); + } + if (gen_xsd) { + pwarning(1, "-xsd is deprecated. Use --gen xsd"); + generator_strings.push_back("xsd"); + } + + // You gotta generate something! + if (generator_strings.empty()) { + fprintf(stderr, "!!! No output language(s) specified\n\n"); + usage(); + } + + // Real-pathify it + char rp[PATH_MAX]; + if (argv[i] == NULL) { + fprintf(stderr, "!!! Missing file name\n"); + usage(); + } + if (saferealpath(argv[i], rp) == NULL) { + failure("Could not open input file with realpath: %s", argv[i]); + } + string input_file(rp); + + // Instance of the global parse tree + t_program* program = new t_program(input_file); + if (out_path.size()) { + program->set_out_path(out_path); + } + + // Compute the cpp include prefix. + // infer this from the filename passed in + string input_filename = argv[i]; + string include_prefix; + + string::size_type last_slash = string::npos; + if ((last_slash = input_filename.rfind("/")) != string::npos) { + include_prefix = input_filename.substr(0, last_slash); + } + + program->set_include_prefix(include_prefix); + + // Initialize global types + g_type_void = new t_base_type("void", t_base_type::TYPE_VOID); + g_type_string = new t_base_type("string", t_base_type::TYPE_STRING); + g_type_binary = new t_base_type("string", t_base_type::TYPE_STRING); + ((t_base_type*)g_type_binary)->set_binary(true); + g_type_slist = new t_base_type("string", t_base_type::TYPE_STRING); + ((t_base_type*)g_type_slist)->set_string_list(true); + g_type_bool = new t_base_type("bool", t_base_type::TYPE_BOOL); + g_type_byte = new t_base_type("byte", t_base_type::TYPE_BYTE); + g_type_i16 = new t_base_type("i16", t_base_type::TYPE_I16); + g_type_i32 = new t_base_type("i32", t_base_type::TYPE_I32); + g_type_i64 = new t_base_type("i64", t_base_type::TYPE_I64); + g_type_double = new t_base_type("double", t_base_type::TYPE_DOUBLE); + + // Parse it! + parse(program, NULL); + + // The current path is not really relevant when we are doing generation. + // Reset the variable to make warning messages clearer. + g_curpath = "generation"; + // Reset yylineno for the heck of it. Use 1 instead of 0 because + // That is what shows up during argument parsing. + yylineno = 1; + + // Generate it! + generate(program, generator_strings); + + // Clean up. Who am I kidding... this program probably orphans heap memory + // all over the place, but who cares because it is about to exit and it is + // all referenced and used by this wacky parse tree up until now anyways. + + delete program; + delete g_type_void; + delete g_type_string; + delete g_type_bool; + delete g_type_byte; + delete g_type_i16; + delete g_type_i32; + delete g_type_i64; + delete g_type_double; + + // Finished + return 0; +} diff --git a/compiler/cpp/src/main.h b/compiler/cpp/src/main.h new file mode 100644 index 00000000..9b7d82d7 --- /dev/null +++ b/compiler/cpp/src/main.h @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_MAIN_H +#define T_MAIN_H + +#include +#include "parse/t_const.h" +#include "parse/t_field.h" + +/** + * Defined in the flex library + */ + +int yylex(void); + +int yyparse(void); + +/** + * Expected to be defined by Flex/Bison + */ +void yyerror(const char* fmt, ...); + +/** + * Parse debugging output, used to print helpful info + */ +void pdebug(const char* fmt, ...); + +/** + * Parser warning + */ +void pwarning(int level, const char* fmt, ...); + +/** + * Failure! + */ +void failure(const char* fmt, ...); + +/** + * Check constant types + */ +void validate_const_type(t_const* c); + +/** + * Check constant types + */ +void validate_field_value(t_field* field, t_const_value* cv); + +/** + * Check members of a throws block + */ +bool validate_throws(t_struct* throws); + +/** + * Converts a string filename into a thrift program name + */ +std::string program_name(std::string filename); + +/** + * Gets the directory path of a filename + */ +std::string directory_name(std::string filename); + +/** + * Get the absolute path for an include file + */ +std::string include_file(std::string filename); + +/** + * Clears any previously stored doctext string. + */ +void clear_doctext(); + +/** + * Cleans up text commonly found in doxygen-like comments + */ +char* clean_up_doctext(char* doctext); + +/** + * Flex utilities + */ + +extern int yylineno; +extern char yytext[]; +extern FILE* yyin; + +#endif diff --git a/compiler/cpp/src/md5.c b/compiler/cpp/src/md5.c new file mode 100644 index 00000000..c35d96c5 --- /dev/null +++ b/compiler/cpp/src/md5.c @@ -0,0 +1,381 @@ +/* + Copyright (C) 1999, 2000, 2002 Aladdin Enterprises. All rights reserved. + + This software is provided 'as-is', without any express or implied + warranty. In no event will the authors be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + 3. This notice may not be removed or altered from any source distribution. + + L. Peter Deutsch + ghost@aladdin.com + + */ +/* $Id: md5.c,v 1.6 2002/04/13 19:20:28 lpd Exp $ */ +/* + Independent implementation of MD5 (RFC 1321). + + This code implements the MD5 Algorithm defined in RFC 1321, whose + text is available at + http://www.ietf.org/rfc/rfc1321.txt + The code is derived from the text of the RFC, including the test suite + (section A.5) but excluding the rest of Appendix A. It does not include + any code or documentation that is identified in the RFC as being + copyrighted. + + The original and principal author of md5.c is L. Peter Deutsch + . Other authors are noted in the change history + that follows (in reverse chronological order): + + 2002-04-13 lpd Clarified derivation from RFC 1321; now handles byte order + either statically or dynamically; added missing #include + in library. + 2002-03-11 lpd Corrected argument list for main(), and added int return + type, in test program and T value program. + 2002-02-21 lpd Added missing #include in test program. + 2000-07-03 lpd Patched to eliminate warnings about "constant is + unsigned in ANSI C, signed in traditional"; made test program + self-checking. + 1999-11-04 lpd Edited comments slightly for automatic TOC extraction. + 1999-10-18 lpd Fixed typo in header comment (ansi2knr rather than md5). + 1999-05-03 lpd Original version. + */ + +#include "md5.h" +#include + +#undef BYTE_ORDER /* 1 = big-endian, -1 = little-endian, 0 = unknown */ +#ifdef ARCH_IS_BIG_ENDIAN +# define BYTE_ORDER (ARCH_IS_BIG_ENDIAN ? 1 : -1) +#else +# define BYTE_ORDER 0 +#endif + +#define T_MASK ((md5_word_t)~0) +#define T1 /* 0xd76aa478 */ (T_MASK ^ 0x28955b87) +#define T2 /* 0xe8c7b756 */ (T_MASK ^ 0x173848a9) +#define T3 0x242070db +#define T4 /* 0xc1bdceee */ (T_MASK ^ 0x3e423111) +#define T5 /* 0xf57c0faf */ (T_MASK ^ 0x0a83f050) +#define T6 0x4787c62a +#define T7 /* 0xa8304613 */ (T_MASK ^ 0x57cfb9ec) +#define T8 /* 0xfd469501 */ (T_MASK ^ 0x02b96afe) +#define T9 0x698098d8 +#define T10 /* 0x8b44f7af */ (T_MASK ^ 0x74bb0850) +#define T11 /* 0xffff5bb1 */ (T_MASK ^ 0x0000a44e) +#define T12 /* 0x895cd7be */ (T_MASK ^ 0x76a32841) +#define T13 0x6b901122 +#define T14 /* 0xfd987193 */ (T_MASK ^ 0x02678e6c) +#define T15 /* 0xa679438e */ (T_MASK ^ 0x5986bc71) +#define T16 0x49b40821 +#define T17 /* 0xf61e2562 */ (T_MASK ^ 0x09e1da9d) +#define T18 /* 0xc040b340 */ (T_MASK ^ 0x3fbf4cbf) +#define T19 0x265e5a51 +#define T20 /* 0xe9b6c7aa */ (T_MASK ^ 0x16493855) +#define T21 /* 0xd62f105d */ (T_MASK ^ 0x29d0efa2) +#define T22 0x02441453 +#define T23 /* 0xd8a1e681 */ (T_MASK ^ 0x275e197e) +#define T24 /* 0xe7d3fbc8 */ (T_MASK ^ 0x182c0437) +#define T25 0x21e1cde6 +#define T26 /* 0xc33707d6 */ (T_MASK ^ 0x3cc8f829) +#define T27 /* 0xf4d50d87 */ (T_MASK ^ 0x0b2af278) +#define T28 0x455a14ed +#define T29 /* 0xa9e3e905 */ (T_MASK ^ 0x561c16fa) +#define T30 /* 0xfcefa3f8 */ (T_MASK ^ 0x03105c07) +#define T31 0x676f02d9 +#define T32 /* 0x8d2a4c8a */ (T_MASK ^ 0x72d5b375) +#define T33 /* 0xfffa3942 */ (T_MASK ^ 0x0005c6bd) +#define T34 /* 0x8771f681 */ (T_MASK ^ 0x788e097e) +#define T35 0x6d9d6122 +#define T36 /* 0xfde5380c */ (T_MASK ^ 0x021ac7f3) +#define T37 /* 0xa4beea44 */ (T_MASK ^ 0x5b4115bb) +#define T38 0x4bdecfa9 +#define T39 /* 0xf6bb4b60 */ (T_MASK ^ 0x0944b49f) +#define T40 /* 0xbebfbc70 */ (T_MASK ^ 0x4140438f) +#define T41 0x289b7ec6 +#define T42 /* 0xeaa127fa */ (T_MASK ^ 0x155ed805) +#define T43 /* 0xd4ef3085 */ (T_MASK ^ 0x2b10cf7a) +#define T44 0x04881d05 +#define T45 /* 0xd9d4d039 */ (T_MASK ^ 0x262b2fc6) +#define T46 /* 0xe6db99e5 */ (T_MASK ^ 0x1924661a) +#define T47 0x1fa27cf8 +#define T48 /* 0xc4ac5665 */ (T_MASK ^ 0x3b53a99a) +#define T49 /* 0xf4292244 */ (T_MASK ^ 0x0bd6ddbb) +#define T50 0x432aff97 +#define T51 /* 0xab9423a7 */ (T_MASK ^ 0x546bdc58) +#define T52 /* 0xfc93a039 */ (T_MASK ^ 0x036c5fc6) +#define T53 0x655b59c3 +#define T54 /* 0x8f0ccc92 */ (T_MASK ^ 0x70f3336d) +#define T55 /* 0xffeff47d */ (T_MASK ^ 0x00100b82) +#define T56 /* 0x85845dd1 */ (T_MASK ^ 0x7a7ba22e) +#define T57 0x6fa87e4f +#define T58 /* 0xfe2ce6e0 */ (T_MASK ^ 0x01d3191f) +#define T59 /* 0xa3014314 */ (T_MASK ^ 0x5cfebceb) +#define T60 0x4e0811a1 +#define T61 /* 0xf7537e82 */ (T_MASK ^ 0x08ac817d) +#define T62 /* 0xbd3af235 */ (T_MASK ^ 0x42c50dca) +#define T63 0x2ad7d2bb +#define T64 /* 0xeb86d391 */ (T_MASK ^ 0x14792c6e) + + +static void +md5_process(md5_state_t *pms, const md5_byte_t *data /*[64]*/) +{ + md5_word_t + a = pms->abcd[0], b = pms->abcd[1], + c = pms->abcd[2], d = pms->abcd[3]; + md5_word_t t; +#if BYTE_ORDER > 0 + /* Define storage only for big-endian CPUs. */ + md5_word_t X[16]; +#else + /* Define storage for little-endian or both types of CPUs. */ + md5_word_t xbuf[16]; + const md5_word_t *X; +#endif + + { +#if BYTE_ORDER == 0 + /* + * Determine dynamically whether this is a big-endian or + * little-endian machine, since we can use a more efficient + * algorithm on the latter. + */ + static const int w = 1; + + if (*((const md5_byte_t *)&w)) /* dynamic little-endian */ +#endif +#if BYTE_ORDER <= 0 /* little-endian */ + { + /* + * On little-endian machines, we can process properly aligned + * data without copying it. + */ + if (!((data - (const md5_byte_t *)0) & 3)) { + /* data are properly aligned */ + X = (const md5_word_t *)data; + } else { + /* not aligned */ + memcpy(xbuf, data, 64); + X = xbuf; + } + } +#endif +#if BYTE_ORDER == 0 + else /* dynamic big-endian */ +#endif +#if BYTE_ORDER >= 0 /* big-endian */ + { + /* + * On big-endian machines, we must arrange the bytes in the + * right order. + */ + const md5_byte_t *xp = data; + int i; + +# if BYTE_ORDER == 0 + X = xbuf; /* (dynamic only) */ +# else +# define xbuf X /* (static only) */ +# endif + for (i = 0; i < 16; ++i, xp += 4) + xbuf[i] = xp[0] + (xp[1] << 8) + (xp[2] << 16) + (xp[3] << 24); + } +#endif + } + +#define ROTATE_LEFT(x, n) (((x) << (n)) | ((x) >> (32 - (n)))) + + /* Round 1. */ + /* Let [abcd k s i] denote the operation + a = b + ((a + F(b,c,d) + X[k] + T[i]) <<< s). */ +#define F(x, y, z) (((x) & (y)) | (~(x) & (z))) +#define SET(a, b, c, d, k, s, Ti)\ + t = a + F(b,c,d) + X[k] + Ti;\ + a = ROTATE_LEFT(t, s) + b + /* Do the following 16 operations. */ + SET(a, b, c, d, 0, 7, T1); + SET(d, a, b, c, 1, 12, T2); + SET(c, d, a, b, 2, 17, T3); + SET(b, c, d, a, 3, 22, T4); + SET(a, b, c, d, 4, 7, T5); + SET(d, a, b, c, 5, 12, T6); + SET(c, d, a, b, 6, 17, T7); + SET(b, c, d, a, 7, 22, T8); + SET(a, b, c, d, 8, 7, T9); + SET(d, a, b, c, 9, 12, T10); + SET(c, d, a, b, 10, 17, T11); + SET(b, c, d, a, 11, 22, T12); + SET(a, b, c, d, 12, 7, T13); + SET(d, a, b, c, 13, 12, T14); + SET(c, d, a, b, 14, 17, T15); + SET(b, c, d, a, 15, 22, T16); +#undef SET + + /* Round 2. */ + /* Let [abcd k s i] denote the operation + a = b + ((a + G(b,c,d) + X[k] + T[i]) <<< s). */ +#define G(x, y, z) (((x) & (z)) | ((y) & ~(z))) +#define SET(a, b, c, d, k, s, Ti)\ + t = a + G(b,c,d) + X[k] + Ti;\ + a = ROTATE_LEFT(t, s) + b + /* Do the following 16 operations. */ + SET(a, b, c, d, 1, 5, T17); + SET(d, a, b, c, 6, 9, T18); + SET(c, d, a, b, 11, 14, T19); + SET(b, c, d, a, 0, 20, T20); + SET(a, b, c, d, 5, 5, T21); + SET(d, a, b, c, 10, 9, T22); + SET(c, d, a, b, 15, 14, T23); + SET(b, c, d, a, 4, 20, T24); + SET(a, b, c, d, 9, 5, T25); + SET(d, a, b, c, 14, 9, T26); + SET(c, d, a, b, 3, 14, T27); + SET(b, c, d, a, 8, 20, T28); + SET(a, b, c, d, 13, 5, T29); + SET(d, a, b, c, 2, 9, T30); + SET(c, d, a, b, 7, 14, T31); + SET(b, c, d, a, 12, 20, T32); +#undef SET + + /* Round 3. */ + /* Let [abcd k s t] denote the operation + a = b + ((a + H(b,c,d) + X[k] + T[i]) <<< s). */ +#define H(x, y, z) ((x) ^ (y) ^ (z)) +#define SET(a, b, c, d, k, s, Ti)\ + t = a + H(b,c,d) + X[k] + Ti;\ + a = ROTATE_LEFT(t, s) + b + /* Do the following 16 operations. */ + SET(a, b, c, d, 5, 4, T33); + SET(d, a, b, c, 8, 11, T34); + SET(c, d, a, b, 11, 16, T35); + SET(b, c, d, a, 14, 23, T36); + SET(a, b, c, d, 1, 4, T37); + SET(d, a, b, c, 4, 11, T38); + SET(c, d, a, b, 7, 16, T39); + SET(b, c, d, a, 10, 23, T40); + SET(a, b, c, d, 13, 4, T41); + SET(d, a, b, c, 0, 11, T42); + SET(c, d, a, b, 3, 16, T43); + SET(b, c, d, a, 6, 23, T44); + SET(a, b, c, d, 9, 4, T45); + SET(d, a, b, c, 12, 11, T46); + SET(c, d, a, b, 15, 16, T47); + SET(b, c, d, a, 2, 23, T48); +#undef SET + + /* Round 4. */ + /* Let [abcd k s t] denote the operation + a = b + ((a + I(b,c,d) + X[k] + T[i]) <<< s). */ +#define I(x, y, z) ((y) ^ ((x) | ~(z))) +#define SET(a, b, c, d, k, s, Ti)\ + t = a + I(b,c,d) + X[k] + Ti;\ + a = ROTATE_LEFT(t, s) + b + /* Do the following 16 operations. */ + SET(a, b, c, d, 0, 6, T49); + SET(d, a, b, c, 7, 10, T50); + SET(c, d, a, b, 14, 15, T51); + SET(b, c, d, a, 5, 21, T52); + SET(a, b, c, d, 12, 6, T53); + SET(d, a, b, c, 3, 10, T54); + SET(c, d, a, b, 10, 15, T55); + SET(b, c, d, a, 1, 21, T56); + SET(a, b, c, d, 8, 6, T57); + SET(d, a, b, c, 15, 10, T58); + SET(c, d, a, b, 6, 15, T59); + SET(b, c, d, a, 13, 21, T60); + SET(a, b, c, d, 4, 6, T61); + SET(d, a, b, c, 11, 10, T62); + SET(c, d, a, b, 2, 15, T63); + SET(b, c, d, a, 9, 21, T64); +#undef SET + + /* Then perform the following additions. (That is increment each + of the four registers by the value it had before this block + was started.) */ + pms->abcd[0] += a; + pms->abcd[1] += b; + pms->abcd[2] += c; + pms->abcd[3] += d; +} + +void +md5_init(md5_state_t *pms) +{ + pms->count[0] = pms->count[1] = 0; + pms->abcd[0] = 0x67452301; + pms->abcd[1] = /*0xefcdab89*/ T_MASK ^ 0x10325476; + pms->abcd[2] = /*0x98badcfe*/ T_MASK ^ 0x67452301; + pms->abcd[3] = 0x10325476; +} + +void +md5_append(md5_state_t *pms, const md5_byte_t *data, int nbytes) +{ + const md5_byte_t *p = data; + int left = nbytes; + int offset = (pms->count[0] >> 3) & 63; + md5_word_t nbits = (md5_word_t)(nbytes << 3); + + if (nbytes <= 0) + return; + + /* Update the message length. */ + pms->count[1] += nbytes >> 29; + pms->count[0] += nbits; + if (pms->count[0] < nbits) + pms->count[1]++; + + /* Process an initial partial block. */ + if (offset) { + int copy = (offset + nbytes > 64 ? 64 - offset : nbytes); + + memcpy(pms->buf + offset, p, copy); + if (offset + copy < 64) + return; + p += copy; + left -= copy; + md5_process(pms, pms->buf); + } + + /* Process full blocks. */ + for (; left >= 64; p += 64, left -= 64) + md5_process(pms, p); + + /* Process a final partial block. */ + if (left) + memcpy(pms->buf, p, left); +} + +void +md5_finish(md5_state_t *pms, md5_byte_t digest[16]) +{ + static const md5_byte_t pad[64] = { + 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + }; + md5_byte_t data[8]; + int i; + + /* Save the length before padding. */ + for (i = 0; i < 8; ++i) + data[i] = (md5_byte_t)(pms->count[i >> 2] >> ((i & 3) << 3)); + /* Pad to 56 bytes mod 64. */ + md5_append(pms, pad, ((55 - (pms->count[0] >> 3)) & 63) + 1); + /* Append the length. */ + md5_append(pms, data, 8); + for (i = 0; i < 16; ++i) + digest[i] = (md5_byte_t)(pms->abcd[i >> 2] >> ((i & 3) << 3)); +} diff --git a/compiler/cpp/src/md5.h b/compiler/cpp/src/md5.h new file mode 100644 index 00000000..698c995d --- /dev/null +++ b/compiler/cpp/src/md5.h @@ -0,0 +1,91 @@ +/* + Copyright (C) 1999, 2002 Aladdin Enterprises. All rights reserved. + + This software is provided 'as-is', without any express or implied + warranty. In no event will the authors be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + 3. This notice may not be removed or altered from any source distribution. + + L. Peter Deutsch + ghost@aladdin.com + + */ +/* $Id: md5.h,v 1.4 2002/04/13 19:20:28 lpd Exp $ */ +/* + Independent implementation of MD5 (RFC 1321). + + This code implements the MD5 Algorithm defined in RFC 1321, whose + text is available at + http://www.ietf.org/rfc/rfc1321.txt + The code is derived from the text of the RFC, including the test suite + (section A.5) but excluding the rest of Appendix A. It does not include + any code or documentation that is identified in the RFC as being + copyrighted. + + The original and principal author of md5.h is L. Peter Deutsch + . Other authors are noted in the change history + that follows (in reverse chronological order): + + 2002-04-13 lpd Removed support for non-ANSI compilers; removed + references to Ghostscript; clarified derivation from RFC 1321; + now handles byte order either statically or dynamically. + 1999-11-04 lpd Edited comments slightly for automatic TOC extraction. + 1999-10-18 lpd Fixed typo in header comment (ansi2knr rather than md5); + added conditionalization for C++ compilation from Martin + Purschke . + 1999-05-03 lpd Original version. + */ + +#ifndef md5_INCLUDED +# define md5_INCLUDED + +/* + * This package supports both compile-time and run-time determination of CPU + * byte order. If ARCH_IS_BIG_ENDIAN is defined as 0, the code will be + * compiled to run only on little-endian CPUs; if ARCH_IS_BIG_ENDIAN is + * defined as non-zero, the code will be compiled to run only on big-endian + * CPUs; if ARCH_IS_BIG_ENDIAN is not defined, the code will be compiled to + * run on either big- or little-endian CPUs, but will run slightly less + * efficiently on either one than if ARCH_IS_BIG_ENDIAN is defined. + */ + +typedef unsigned char md5_byte_t; /* 8-bit byte */ +typedef unsigned int md5_word_t; /* 32-bit word */ + +/* Define the state of the MD5 Algorithm. */ +typedef struct md5_state_s { + md5_word_t count[2]; /* message length in bits, lsw first */ + md5_word_t abcd[4]; /* digest buffer */ + md5_byte_t buf[64]; /* accumulate block */ +} md5_state_t; + +#ifdef __cplusplus +extern "C" +{ +#endif + +/* Initialize the algorithm. */ +void md5_init(md5_state_t *pms); + +/* Append a string to the message. */ +void md5_append(md5_state_t *pms, const md5_byte_t *data, int nbytes); + +/* Finish the message and return the digest. */ +void md5_finish(md5_state_t *pms, md5_byte_t digest[16]); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif /* md5_INCLUDED */ diff --git a/compiler/cpp/src/parse/t_base_type.h b/compiler/cpp/src/parse/t_base_type.h new file mode 100644 index 00000000..1751df9b --- /dev/null +++ b/compiler/cpp/src/parse/t_base_type.h @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_BASE_TYPE_H +#define T_BASE_TYPE_H + +#include +#include "t_type.h" + +/** + * A thrift base type, which must be one of the defined enumerated types inside + * this definition. + * + */ +class t_base_type : public t_type { + public: + /** + * Enumeration of thrift base types + */ + enum t_base { + TYPE_VOID, + TYPE_STRING, + TYPE_BOOL, + TYPE_BYTE, + TYPE_I16, + TYPE_I32, + TYPE_I64, + TYPE_DOUBLE + }; + + t_base_type(std::string name, t_base base) : + t_type(name), + base_(base), + string_list_(false), + binary_(false), + string_enum_(false){} + + t_base get_base() const { + return base_; + } + + bool is_void() const { + return base_ == TYPE_VOID; + } + + bool is_string() const { + return base_ == TYPE_STRING; + } + + bool is_bool() const { + return base_ == TYPE_BOOL; + } + + void set_string_list(bool val) { + string_list_ = val; + } + + bool is_string_list() const { + return (base_ == TYPE_STRING) && string_list_; + } + + void set_binary(bool val) { + binary_ = val; + } + + bool is_binary() const { + return (base_ == TYPE_STRING) && binary_; + } + + void set_string_enum(bool val) { + string_enum_ = true; + } + + bool is_string_enum() const { + return base_ == TYPE_STRING && string_enum_; + } + + void add_string_enum_val(std::string val) { + string_enum_vals_.push_back(val); + } + + const std::vector& get_string_enum_vals() const { + return string_enum_vals_; + } + + bool is_base_type() const { + return true; + } + + virtual std::string get_fingerprint_material() const { + std::string rv = t_base_name(base_); + if (rv == "(unknown)") { + throw "BUG: Can't get fingerprint material for this base type."; + } + return rv; + } + + static std::string t_base_name(t_base tbase) { + switch (tbase) { + case TYPE_VOID : return "void"; break; + case TYPE_STRING : return "string"; break; + case TYPE_BOOL : return "bool"; break; + case TYPE_BYTE : return "byte"; break; + case TYPE_I16 : return "i16"; break; + case TYPE_I32 : return "i32"; break; + case TYPE_I64 : return "i64"; break; + case TYPE_DOUBLE : return "double"; break; + default : return "(unknown)"; break; + } + } + + private: + t_base base_; + + bool string_list_; + bool binary_; + bool string_enum_; + std::vector string_enum_vals_; +}; + +#endif diff --git a/compiler/cpp/src/parse/t_const.h b/compiler/cpp/src/parse/t_const.h new file mode 100644 index 00000000..7fd81bd1 --- /dev/null +++ b/compiler/cpp/src/parse/t_const.h @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_CONST_H +#define T_CONST_H + +#include "t_type.h" +#include "t_const_value.h" + +/** + * A const is a constant value defined across languages that has a type and + * a value. The trick here is that the declared type might not match the type + * of the value object, since that is not determined until after parsing the + * whole thing out. + * + */ +class t_const : public t_doc { + public: + t_const(t_type* type, std::string name, t_const_value* value) : + type_(type), + name_(name), + value_(value) {} + + t_type* get_type() const { + return type_; + } + + std::string get_name() const { + return name_; + } + + t_const_value* get_value() const { + return value_; + } + + private: + t_type* type_; + std::string name_; + t_const_value* value_; +}; + +#endif + diff --git a/compiler/cpp/src/parse/t_const_value.h b/compiler/cpp/src/parse/t_const_value.h new file mode 100644 index 00000000..a7d6e31c --- /dev/null +++ b/compiler/cpp/src/parse/t_const_value.h @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_CONST_VALUE_H +#define T_CONST_VALUE_H + +#include "t_const.h" +#include +#include +#include + +/** + * A const value is something parsed that could be a map, set, list, struct + * or whatever. + * + */ +class t_const_value { + public: + + enum t_const_value_type { + CV_INTEGER, + CV_DOUBLE, + CV_STRING, + CV_MAP, + CV_LIST + }; + + t_const_value() {} + + t_const_value(int64_t val) { + set_integer(val); + } + + t_const_value(std::string val) { + set_string(val); + } + + void set_string(std::string val) { + valType_ = CV_STRING; + stringVal_ = val; + } + + std::string get_string() const { + return stringVal_; + } + + void set_integer(int64_t val) { + valType_ = CV_INTEGER; + intVal_ = val; + } + + int64_t get_integer() const { + return intVal_; + } + + void set_double(double val) { + valType_ = CV_DOUBLE; + doubleVal_ = val; + } + + double get_double() const { + return doubleVal_; + } + + void set_map() { + valType_ = CV_MAP; + } + + void add_map(t_const_value* key, t_const_value* val) { + mapVal_[key] = val; + } + + const std::map& get_map() const { + return mapVal_; + } + + void set_list() { + valType_ = CV_LIST; + } + + void add_list(t_const_value* val) { + listVal_.push_back(val); + } + + const std::vector& get_list() const { + return listVal_; + } + + t_const_value_type get_type() const { + return valType_; + } + + private: + std::map mapVal_; + std::vector listVal_; + std::string stringVal_; + int64_t intVal_; + double doubleVal_; + + t_const_value_type valType_; + +}; + +#endif + diff --git a/compiler/cpp/src/parse/t_container.h b/compiler/cpp/src/parse/t_container.h new file mode 100644 index 00000000..6753493a --- /dev/null +++ b/compiler/cpp/src/parse/t_container.h @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_CONTAINER_H +#define T_CONTAINER_H + +#include "t_type.h" + +class t_container : public t_type { + public: + t_container() : + cpp_name_(), + has_cpp_name_(false) {} + + virtual ~t_container() {} + + void set_cpp_name(std::string cpp_name) { + cpp_name_ = cpp_name; + has_cpp_name_ = true; + } + + bool has_cpp_name() { + return has_cpp_name_; + } + + std::string get_cpp_name() { + return cpp_name_; + } + + bool is_container() const { + return true; + } + + private: + std::string cpp_name_; + bool has_cpp_name_; + +}; + +#endif diff --git a/compiler/cpp/src/parse/t_doc.h b/compiler/cpp/src/parse/t_doc.h new file mode 100644 index 00000000..e52068cb --- /dev/null +++ b/compiler/cpp/src/parse/t_doc.h @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_DOC_H +#define T_DOC_H + +/** + * Documentation stubs + * + */ +class t_doc { + + public: + t_doc() : has_doc_(false) {} + + void set_doc(const std::string& doc) { + doc_ = doc; + has_doc_ = true; + } + + const std::string& get_doc() const { + return doc_; + } + + bool has_doc() { + return has_doc_; + } + + private: + std::string doc_; + bool has_doc_; + +}; + +#endif diff --git a/compiler/cpp/src/parse/t_enum.h b/compiler/cpp/src/parse/t_enum.h new file mode 100644 index 00000000..740f95ca --- /dev/null +++ b/compiler/cpp/src/parse/t_enum.h @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_ENUM_H +#define T_ENUM_H + +#include "t_enum_value.h" +#include + +/** + * An enumerated type. A list of constant objects with a name for the type. + * + */ +class t_enum : public t_type { + public: + t_enum(t_program* program) : + t_type(program) {} + + void set_name(const std::string& name) { + name_ = name; + } + + void append(t_enum_value* constant) { + constants_.push_back(constant); + } + + const std::vector& get_constants() { + return constants_; + } + + bool is_enum() const { + return true; + } + + virtual std::string get_fingerprint_material() const { + return "enum"; + } + + private: + std::vector constants_; +}; + +#endif diff --git a/compiler/cpp/src/parse/t_enum_value.h b/compiler/cpp/src/parse/t_enum_value.h new file mode 100644 index 00000000..68e905bd --- /dev/null +++ b/compiler/cpp/src/parse/t_enum_value.h @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_ENUM_VALUE_H +#define T_ENUM_VALUE_H + +#include +#include "t_doc.h" + +/** + * A constant. These are used inside of enum definitions. Constants are just + * symbol identifiers that may or may not have an explicit value associated + * with them. + * + */ +class t_enum_value : public t_doc { + public: + t_enum_value(std::string name) : + name_(name), + has_value_(false), + value_(0) {} + + t_enum_value(std::string name, int value) : + name_(name), + has_value_(true), + value_(value) {} + + ~t_enum_value() {} + + const std::string& get_name() { + return name_; + } + + bool has_value() { + return has_value_; + } + + int get_value() { + return value_; + } + + private: + std::string name_; + bool has_value_; + int value_; +}; + +#endif diff --git a/compiler/cpp/src/parse/t_field.h b/compiler/cpp/src/parse/t_field.h new file mode 100644 index 00000000..67a2125c --- /dev/null +++ b/compiler/cpp/src/parse/t_field.h @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_FIELD_H +#define T_FIELD_H + +#include +#include + +#include "t_doc.h" + +// Forward declare for xsd_attrs +class t_struct; + +/** + * Class to represent a field in a thrift structure. A field has a data type, + * a symbolic name, and a numeric identifier. + * + */ +class t_field : public t_doc { + public: + t_field(t_type* type, std::string name) : + type_(type), + name_(name), + key_(0), + value_(NULL), + xsd_optional_(false), + xsd_nillable_(false), + xsd_attrs_(NULL) {} + + t_field(t_type* type, std::string name, int32_t key) : + type_(type), + name_(name), + key_(key), + req_(T_OPT_IN_REQ_OUT), + value_(NULL), + xsd_optional_(false), + xsd_nillable_(false), + xsd_attrs_(NULL) {} + + ~t_field() {} + + t_type* get_type() const { + return type_; + } + + const std::string& get_name() const { + return name_; + } + + int32_t get_key() const { + return key_; + } + + enum e_req { + T_REQUIRED, + T_OPTIONAL, + T_OPT_IN_REQ_OUT, + }; + + void set_req(e_req req) { + req_ = req; + } + + e_req get_req() const { + return req_; + } + + void set_value(t_const_value* value) { + value_ = value; + } + + t_const_value* get_value() { + return value_; + } + + void set_xsd_optional(bool xsd_optional) { + xsd_optional_ = xsd_optional; + } + + bool get_xsd_optional() const { + return xsd_optional_; + } + + void set_xsd_nillable(bool xsd_nillable) { + xsd_nillable_ = xsd_nillable; + } + + bool get_xsd_nillable() const { + return xsd_nillable_; + } + + void set_xsd_attrs(t_struct* xsd_attrs) { + xsd_attrs_ = xsd_attrs; + } + + t_struct* get_xsd_attrs() { + return xsd_attrs_; + } + + // This is not the same function as t_type::get_fingerprint_material, + // but it does the same thing. + std::string get_fingerprint_material() const { + return boost::lexical_cast(key_) + ":" + + ((req_ == T_OPTIONAL) ? "opt-" : "") + + type_->get_fingerprint_material(); + } + + /** + * Comparator to sort fields in ascending order by key. + * Make this a functor instead of a function to help GCC inline it. + * The arguments are (const) references to const pointers to const t_fields. + */ + struct key_compare { + bool operator()(t_field const * const & a, t_field const * const & b) { + return a->get_key() < b->get_key(); + } + }; + + + private: + t_type* type_; + std::string name_; + int32_t key_; + e_req req_; + t_const_value* value_; + + bool xsd_optional_; + bool xsd_nillable_; + t_struct* xsd_attrs_; + +}; + +#endif diff --git a/compiler/cpp/src/parse/t_function.h b/compiler/cpp/src/parse/t_function.h new file mode 100644 index 00000000..a72aa6c3 --- /dev/null +++ b/compiler/cpp/src/parse/t_function.h @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_FUNCTION_H +#define T_FUNCTION_H + +#include +#include "t_type.h" +#include "t_struct.h" +#include "t_doc.h" + +/** + * Representation of a function. Key parts are return type, function name, + * optional modifiers, and an argument list, which is implemented as a thrift + * struct. + * + */ +class t_function : public t_doc { + public: + t_function(t_type* returntype, + std::string name, + t_struct* arglist, + bool oneway=false) : + returntype_(returntype), + name_(name), + arglist_(arglist), + oneway_(oneway) { + xceptions_ = new t_struct(NULL); + } + + t_function(t_type* returntype, + std::string name, + t_struct* arglist, + t_struct* xceptions, + bool oneway=false) : + returntype_(returntype), + name_(name), + arglist_(arglist), + xceptions_(xceptions), + oneway_(oneway) + { + if (oneway_ && !xceptions_->get_members().empty()) { + throw std::string("Oneway methods can't throw exceptions."); + } + } + + ~t_function() {} + + t_type* get_returntype() const { + return returntype_; + } + + const std::string& get_name() const { + return name_; + } + + t_struct* get_arglist() const { + return arglist_; + } + + t_struct* get_xceptions() const { + return xceptions_; + } + + bool is_oneway() const { + return oneway_; + } + + private: + t_type* returntype_; + std::string name_; + t_struct* arglist_; + t_struct* xceptions_; + bool oneway_; +}; + +#endif diff --git a/compiler/cpp/src/parse/t_list.h b/compiler/cpp/src/parse/t_list.h new file mode 100644 index 00000000..21a9625e --- /dev/null +++ b/compiler/cpp/src/parse/t_list.h @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_LIST_H +#define T_LIST_H + +#include "t_container.h" + +/** + * A list is a lightweight container type that just wraps another data type. + * + */ +class t_list : public t_container { + public: + t_list(t_type* elem_type) : + elem_type_(elem_type) {} + + t_type* get_elem_type() const { + return elem_type_; + } + + bool is_list() const { + return true; + } + + virtual std::string get_fingerprint_material() const { + return "list<" + elem_type_->get_fingerprint_material() + ">"; + } + + virtual void generate_fingerprint() { + t_type::generate_fingerprint(); + elem_type_->generate_fingerprint(); + } + + private: + t_type* elem_type_; +}; + +#endif + diff --git a/compiler/cpp/src/parse/t_map.h b/compiler/cpp/src/parse/t_map.h new file mode 100644 index 00000000..c4e358fd --- /dev/null +++ b/compiler/cpp/src/parse/t_map.h @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_MAP_H +#define T_MAP_H + +#include "t_container.h" + +/** + * A map is a lightweight container type that just wraps another two data + * types. + * + */ +class t_map : public t_container { + public: + t_map(t_type* key_type, t_type* val_type) : + key_type_(key_type), + val_type_(val_type) {} + + t_type* get_key_type() const { + return key_type_; + } + + t_type* get_val_type() const { + return val_type_; + } + + bool is_map() const { + return true; + } + + virtual std::string get_fingerprint_material() const { + return "map<" + key_type_->get_fingerprint_material() + + "," + val_type_->get_fingerprint_material() + ">"; + } + + virtual void generate_fingerprint() { + t_type::generate_fingerprint(); + key_type_->generate_fingerprint(); + val_type_->generate_fingerprint(); + } + + private: + t_type* key_type_; + t_type* val_type_; +}; + +#endif diff --git a/compiler/cpp/src/parse/t_program.h b/compiler/cpp/src/parse/t_program.h new file mode 100644 index 00000000..4e1ab6a5 --- /dev/null +++ b/compiler/cpp/src/parse/t_program.h @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_PROGRAM_H +#define T_PROGRAM_H + +#include +#include +#include + +// For program_name() +#include "main.h" + +#include "t_doc.h" +#include "t_scope.h" +#include "t_base_type.h" +#include "t_typedef.h" +#include "t_enum.h" +#include "t_const.h" +#include "t_struct.h" +#include "t_service.h" +#include "t_list.h" +#include "t_map.h" +#include "t_set.h" +//#include "t_doc.h" + +/** + * Top level class representing an entire thrift program. A program consists + * fundamentally of the following: + * + * Typedefs + * Enumerations + * Constants + * Structs + * Exceptions + * Services + * + * The program module also contains the definitions of the base types. + * + */ +class t_program : public t_doc { + public: + t_program(std::string path, std::string name) : + path_(path), + name_(name), + out_path_("./") { + scope_ = new t_scope(); + } + + t_program(std::string path) : + path_(path), + out_path_("./") { + name_ = program_name(path); + scope_ = new t_scope(); + } + + // Path accessor + const std::string& get_path() const { return path_; } + + // Output path accessor + const std::string& get_out_path() const { return out_path_; } + + // Name accessor + const std::string& get_name() const { return name_; } + + // Namespace + const std::string& get_namespace() const { return namespace_; } + + // Include prefix accessor + const std::string& get_include_prefix() const { return include_prefix_; } + + // Accessors for program elements + const std::vector& get_typedefs() const { return typedefs_; } + const std::vector& get_enums() const { return enums_; } + const std::vector& get_consts() const { return consts_; } + const std::vector& get_structs() const { return structs_; } + const std::vector& get_xceptions() const { return xceptions_; } + const std::vector& get_objects() const { return objects_; } + const std::vector& get_services() const { return services_; } + + // Program elements + void add_typedef (t_typedef* td) { typedefs_.push_back(td); } + void add_enum (t_enum* te) { enums_.push_back(te); } + void add_const (t_const* tc) { consts_.push_back(tc); } + void add_struct (t_struct* ts) { objects_.push_back(ts); + structs_.push_back(ts); } + void add_xception (t_struct* tx) { objects_.push_back(tx); + xceptions_.push_back(tx); } + void add_service (t_service* ts) { services_.push_back(ts); } + + // Programs to include + const std::vector& get_includes() const { return includes_; } + + void set_out_path(std::string out_path) { + out_path_ = out_path; + // Ensure that it ends with a trailing '/' (or '\' for windows machines) + char c = out_path_.at(out_path_.size() - 1); + if (!(c == '/' || c == '\\')) { + out_path_.push_back('/'); + } + } + + // Scoping and namespacing + void set_namespace(std::string name) { + namespace_ = name; + } + + // Scope accessor + t_scope* scope() { + return scope_; + } + + // Includes + + void add_include(std::string path, std::string include_site) { + t_program* program = new t_program(path); + + // include prefix for this program is the site at which it was included + // (minus the filename) + std::string include_prefix; + std::string::size_type last_slash = std::string::npos; + if ((last_slash = include_site.rfind("/")) != std::string::npos) { + include_prefix = include_site.substr(0, last_slash); + } + + program->set_include_prefix(include_prefix); + includes_.push_back(program); + } + + std::vector& get_includes() { + return includes_; + } + + void set_include_prefix(std::string include_prefix) { + include_prefix_ = include_prefix; + + // this is intended to be a directory; add a trailing slash if necessary + int len = include_prefix_.size(); + if (len > 0 && include_prefix_[len - 1] != '/') { + include_prefix_ += '/'; + } + } + + // Language neutral namespace / packaging + void set_namespace(std::string language, std::string name_space) { + namespaces_[language] = name_space; + } + + std::string get_namespace(std::string language) const { + std::map::const_iterator iter = namespaces_.find(language); + if (iter == namespaces_.end()) { + return std::string(); + } + return iter->second; + } + + // Language specific namespace / packaging + + void add_cpp_include(std::string path) { + cpp_includes_.push_back(path); + } + + const std::vector& get_cpp_includes() { + return cpp_includes_; + } + + private: + + // File path + std::string path_; + + // Name + std::string name_; + + // Output directory + std::string out_path_; + + // Namespace + std::string namespace_; + + // Included programs + std::vector includes_; + + // Include prefix for this program, if any + std::string include_prefix_; + + // Identifier lookup scope + t_scope* scope_; + + // Components to generate code for + std::vector typedefs_; + std::vector enums_; + std::vector consts_; + std::vector objects_; + std::vector structs_; + std::vector xceptions_; + std::vector services_; + + // Dynamic namespaces + std::map namespaces_; + + // C++ extra includes + std::vector cpp_includes_; + +}; + +#endif diff --git a/compiler/cpp/src/parse/t_scope.h b/compiler/cpp/src/parse/t_scope.h new file mode 100644 index 00000000..122e3256 --- /dev/null +++ b/compiler/cpp/src/parse/t_scope.h @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_SCOPE_H +#define T_SCOPE_H + +#include +#include + +#include "t_type.h" +#include "t_service.h" + +/** + * This represents a variable scope used for looking up predefined types and + * services. Typically, a scope is associated with a t_program. Scopes are not + * used to determine code generation, but rather to resolve identifiers at + * parse time. + * + */ +class t_scope { + public: + t_scope() {} + + void add_type(std::string name, t_type* type) { + types_[name] = type; + } + + t_type* get_type(std::string name) { + return types_[name]; + } + + void add_service(std::string name, t_service* service) { + services_[name] = service; + } + + t_service* get_service(std::string name) { + return services_[name]; + } + + void add_constant(std::string name, t_const* constant) { + constants_[name] = constant; + } + + t_const* get_constant(std::string name) { + return constants_[name]; + } + + void print() { + std::map::iterator iter; + for (iter = types_.begin(); iter != types_.end(); ++iter) { + printf("%s => %s\n", + iter->first.c_str(), + iter->second->get_name().c_str()); + } + } + + private: + + // Map of names to types + std::map types_; + + // Map of names to constants + std::map constants_; + + // Map of names to services + std::map services_; + +}; + +#endif diff --git a/compiler/cpp/src/parse/t_service.h b/compiler/cpp/src/parse/t_service.h new file mode 100644 index 00000000..eee2dac1 --- /dev/null +++ b/compiler/cpp/src/parse/t_service.h @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_SERVICE_H +#define T_SERVICE_H + +#include "t_function.h" +#include + +class t_program; + +/** + * A service consists of a set of functions. + * + */ +class t_service : public t_type { + public: + t_service(t_program* program) : + t_type(program), + extends_(NULL) {} + + bool is_service() const { + return true; + } + + void set_extends(t_service* extends) { + extends_ = extends; + } + + void add_function(t_function* func) { + functions_.push_back(func); + } + + const std::vector& get_functions() const { + return functions_; + } + + t_service* get_extends() { + return extends_; + } + + virtual std::string get_fingerprint_material() const { + // Services should never be used in fingerprints. + throw "BUG: Can't get fingerprint material for service."; + } + + private: + std::vector functions_; + t_service* extends_; +}; + +#endif diff --git a/compiler/cpp/src/parse/t_set.h b/compiler/cpp/src/parse/t_set.h new file mode 100644 index 00000000..d1983577 --- /dev/null +++ b/compiler/cpp/src/parse/t_set.h @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_SET_H +#define T_SET_H + +#include "t_container.h" + +/** + * A set is a lightweight container type that just wraps another data type. + * + */ +class t_set : public t_container { + public: + t_set(t_type* elem_type) : + elem_type_(elem_type) {} + + t_type* get_elem_type() const { + return elem_type_; + } + + bool is_set() const { + return true; + } + + virtual std::string get_fingerprint_material() const { + return "set<" + elem_type_->get_fingerprint_material() + ">"; + } + + virtual void generate_fingerprint() { + t_type::generate_fingerprint(); + elem_type_->generate_fingerprint(); + } + + private: + t_type* elem_type_; +}; + +#endif diff --git a/compiler/cpp/src/parse/t_struct.h b/compiler/cpp/src/parse/t_struct.h new file mode 100644 index 00000000..7980f803 --- /dev/null +++ b/compiler/cpp/src/parse/t_struct.h @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_STRUCT_H +#define T_STRUCT_H + +#include +#include +#include +#include + +#include "t_type.h" +#include "t_field.h" + +// Forward declare that puppy +class t_program; + +/** + * A struct is a container for a set of member fields that has a name. Structs + * are also used to implement exception types. + * + */ +class t_struct : public t_type { + public: + typedef std::vector members_type; + + t_struct(t_program* program) : + t_type(program), + is_xception_(false), + xsd_all_(false) {} + + t_struct(t_program* program, const std::string& name) : + t_type(program, name), + is_xception_(false), + xsd_all_(false) {} + + void set_name(const std::string& name) { + name_ = name; + } + + void set_xception(bool is_xception) { + is_xception_ = is_xception; + } + + void set_xsd_all(bool xsd_all) { + xsd_all_ = xsd_all; + } + + bool get_xsd_all() const { + return xsd_all_; + } + + bool append(t_field* elem) { + members_.push_back(elem); + + typedef members_type::iterator iter_type; + std::pair bounds = std::equal_range( + members_in_id_order_.begin(), members_in_id_order_.end(), elem, t_field::key_compare() + ); + if (bounds.first != bounds.second) { + return false; + } + members_in_id_order_.insert(bounds.second, elem); + return true; + } + + const members_type& get_members() { + return members_; + } + + const members_type& get_sorted_members() { + return members_in_id_order_; + } + + bool is_struct() const { + return !is_xception_; + } + + bool is_xception() const { + return is_xception_; + } + + virtual std::string get_fingerprint_material() const { + std::string rv = "{"; + members_type::const_iterator m_iter; + for (m_iter = members_in_id_order_.begin(); m_iter != members_in_id_order_.end(); ++m_iter) { + rv += (*m_iter)->get_fingerprint_material(); + rv += ";"; + } + rv += "}"; + return rv; + } + + virtual void generate_fingerprint() { + t_type::generate_fingerprint(); + members_type::const_iterator m_iter; + for (m_iter = members_in_id_order_.begin(); m_iter != members_in_id_order_.end(); ++m_iter) { + (*m_iter)->get_type()->generate_fingerprint(); + } + } + + private: + + members_type members_; + members_type members_in_id_order_; + bool is_xception_; + + bool xsd_all_; +}; + +#endif diff --git a/compiler/cpp/src/parse/t_type.h b/compiler/cpp/src/parse/t_type.h new file mode 100644 index 00000000..4ce2eda1 --- /dev/null +++ b/compiler/cpp/src/parse/t_type.h @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_TYPE_H +#define T_TYPE_H + +#include +#include +#include +#include +#include "t_doc.h" + +// What's worse? This, or making a src/parse/non_inlined.cc? +#include "md5.h" + +class t_program; + +/** + * Generic representation of a thrift type. These objects are used by the + * parser module to build up a tree of object that are all explicitly typed. + * The generic t_type class exports a variety of useful methods that are + * used by the code generator to branch based upon different handling for the + * various types. + * + */ +class t_type : public t_doc { + public: + virtual ~t_type() {} + + virtual void set_name(const std::string& name) { + name_ = name; + } + + virtual const std::string& get_name() const { + return name_; + } + + virtual bool is_void() const { return false; } + virtual bool is_base_type() const { return false; } + virtual bool is_string() const { return false; } + virtual bool is_bool() const { return false; } + virtual bool is_typedef() const { return false; } + virtual bool is_enum() const { return false; } + virtual bool is_struct() const { return false; } + virtual bool is_xception() const { return false; } + virtual bool is_container() const { return false; } + virtual bool is_list() const { return false; } + virtual bool is_set() const { return false; } + virtual bool is_map() const { return false; } + virtual bool is_service() const { return false; } + + t_program* get_program() { + return program_; + } + + + // Return a string that uniquely identifies this type + // from any other thrift type in the world, as far as + // TDenseProtocol is concerned. + // We don't cache this, which is a little sloppy, + // but the compiler is so fast that it doesn't really matter. + virtual std::string get_fingerprint_material() const = 0; + + // Fingerprint should change whenever (and only when) + // the encoding via TDenseProtocol changes. + static const int fingerprint_len = 16; + + // Call this before trying get_*_fingerprint(). + virtual void generate_fingerprint() { + std::string material = get_fingerprint_material(); + md5_state_t ctx; + md5_init(&ctx); + md5_append(&ctx, (md5_byte_t*)(material.data()), (int)material.size()); + md5_finish(&ctx, (md5_byte_t*)fingerprint_); + } + + bool has_fingerprint() const { + for (int i = 0; i < fingerprint_len; i++) { + if (fingerprint_[i] != 0) { + return true; + } + } + return false; + } + + const uint8_t* get_binary_fingerprint() const { + return fingerprint_; + } + + std::string get_ascii_fingerprint() const { + std::string rv; + const uint8_t* fp = get_binary_fingerprint(); + for (int i = 0; i < fingerprint_len; i++) { + rv += byte_to_hex(fp[i]); + } + return rv; + } + + // This function will break (maybe badly) unless 0 <= num <= 16. + static char nybble_to_xdigit(int num) { + if (num < 10) { + return '0' + num; + } else { + return 'A' + num - 10; + } + } + + static std::string byte_to_hex(uint8_t byte) { + std::string rv; + rv += nybble_to_xdigit(byte >> 4); + rv += nybble_to_xdigit(byte & 0x0f); + return rv; + } + + std::map annotations_; + + protected: + t_type() : + program_(NULL) + { + memset(fingerprint_, 0, sizeof(fingerprint_)); + } + + t_type(t_program* program) : + program_(program) + { + memset(fingerprint_, 0, sizeof(fingerprint_)); + } + + t_type(t_program* program, std::string name) : + program_(program), + name_(name) + { + memset(fingerprint_, 0, sizeof(fingerprint_)); + } + + t_type(std::string name) : + program_(NULL), + name_(name) + { + memset(fingerprint_, 0, sizeof(fingerprint_)); + } + + t_program* program_; + std::string name_; + + uint8_t fingerprint_[fingerprint_len]; +}; + + +/** + * Placeholder struct for returning the key and value of an annotation + * during parsing. + */ +struct t_annotation { + std::string key; + std::string val; +}; + +#endif diff --git a/compiler/cpp/src/parse/t_typedef.h b/compiler/cpp/src/parse/t_typedef.h new file mode 100644 index 00000000..4c77d97a --- /dev/null +++ b/compiler/cpp/src/parse/t_typedef.h @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef T_TYPEDEF_H +#define T_TYPEDEF_H + +#include +#include "t_type.h" + +/** + * A typedef is a mapping from a symbolic name to another type. In dymanically + * typed languages (i.e. php/python) the code generator can actually usually + * ignore typedefs and just use the underlying type directly, though in C++ + * the symbolic naming can be quite useful for code clarity. + * + */ +class t_typedef : public t_type { + public: + t_typedef(t_program* program, t_type* type, std::string symbolic) : + t_type(program, symbolic), + type_(type), + symbolic_(symbolic) {} + + ~t_typedef() {} + + t_type* get_type() const { + return type_; + } + + const std::string& get_symbolic() const { + return symbolic_; + } + + bool is_typedef() const { + return true; + } + + virtual std::string get_fingerprint_material() const { + return type_->get_fingerprint_material(); + } + + virtual void generate_fingerprint() { + t_type::generate_fingerprint(); + if (!type_->has_fingerprint()) { + type_->generate_fingerprint(); + } + } + + private: + t_type* type_; + std::string symbolic_; +}; + +#endif diff --git a/compiler/cpp/src/platform.h b/compiler/cpp/src/platform.h new file mode 100644 index 00000000..bd97f68e --- /dev/null +++ b/compiler/cpp/src/platform.h @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * define for mkdir,since the method signature + * is different for the non-POSIX MinGW + */ + +#ifdef MINGW +#include +#else +#include +#include +#endif + +#if defined MINGW +#define MKDIR(x) mkdir(x) +#else +#define MKDIR(x) mkdir(x, S_IRWXU | S_IRWXG | S_IRWXO) +#endif diff --git a/compiler/cpp/src/thriftl.ll b/compiler/cpp/src/thriftl.ll new file mode 100644 index 00000000..2a8ab67e --- /dev/null +++ b/compiler/cpp/src/thriftl.ll @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Thrift scanner. + * + * Tokenizes a thrift definition file. + */ + +%{ + +#include +#include + +#include "main.h" +#include "globals.h" +#include "parse/t_program.h" + +/** + * Must be included AFTER parse/t_program.h, but I can't remember why anymore + * because I wrote this a while ago. + */ +#include "thrifty.h" + +void thrift_reserved_keyword(char* keyword) { + yyerror("Cannot use reserved language keyword: \"%s\"\n", keyword); + exit(1); +} + +void integer_overflow(char* text) { + yyerror("This integer is too big: \"%s\"\n", text); + exit(1); +} + +%} + +/** + * Provides the yylineno global, useful for debugging output + */ +%option lex-compat + +/** + * Helper definitions, comments, constants, and whatnot + */ + +intconstant ([+-]?[0-9]+) +hexconstant ("0x"[0-9A-Fa-f]+) +dubconstant ([+-]?[0-9]*(\.[0-9]+)?([eE][+-]?[0-9]+)?) +identifier ([a-zA-Z_][\.a-zA-Z_0-9]*) +whitespace ([ \t\r\n]*) +sillycomm ("/*""*"*"*/") +multicomm ("/*"[^*]"/"*([^*/]|[^*]"/"|"*"[^/])*"*"*"*/") +doctext ("/**"([^*/]|[^*]"/"|"*"[^/])*"*"*"*/") +comment ("//"[^\n]*) +unixcomment ("#"[^\n]*) +symbol ([:;\,\{\}\(\)\=<>\[\]]) +st_identifier ([a-zA-Z-][\.a-zA-Z_0-9-]*) +literal_begin (['\"]) + +%% + +{whitespace} { /* do nothing */ } +{sillycomm} { /* do nothing */ } +{multicomm} { /* do nothing */ } +{comment} { /* do nothing */ } +{unixcomment} { /* do nothing */ } + +{symbol} { return yytext[0]; } + +"namespace" { return tok_namespace; } +"cpp_namespace" { return tok_cpp_namespace; } +"cpp_include" { return tok_cpp_include; } +"cpp_type" { return tok_cpp_type; } +"java_package" { return tok_java_package; } +"cocoa_prefix" { return tok_cocoa_prefix; } +"csharp_namespace" { return tok_csharp_namespace; } +"php_namespace" { return tok_php_namespace; } +"py_module" { return tok_py_module; } +"perl_package" { return tok_perl_package; } +"ruby_namespace" { return tok_ruby_namespace; } +"smalltalk_category" { return tok_smalltalk_category; } +"smalltalk_prefix" { return tok_smalltalk_prefix; } +"xsd_all" { return tok_xsd_all; } +"xsd_optional" { return tok_xsd_optional; } +"xsd_nillable" { return tok_xsd_nillable; } +"xsd_namespace" { return tok_xsd_namespace; } +"xsd_attrs" { return tok_xsd_attrs; } +"include" { return tok_include; } +"void" { return tok_void; } +"bool" { return tok_bool; } +"byte" { return tok_byte; } +"i16" { return tok_i16; } +"i32" { return tok_i32; } +"i64" { return tok_i64; } +"double" { return tok_double; } +"string" { return tok_string; } +"binary" { return tok_binary; } +"slist" { return tok_slist; } +"senum" { return tok_senum; } +"map" { return tok_map; } +"list" { return tok_list; } +"set" { return tok_set; } +"oneway" { return tok_oneway; } +"typedef" { return tok_typedef; } +"struct" { return tok_struct; } +"exception" { return tok_xception; } +"extends" { return tok_extends; } +"throws" { return tok_throws; } +"service" { return tok_service; } +"enum" { return tok_enum; } +"const" { return tok_const; } +"required" { return tok_required; } +"optional" { return tok_optional; } +"async" { + pwarning(0, "\"async\" is deprecated. It is called \"oneway\" now.\n"); + return tok_oneway; +} + + +"abstract" { thrift_reserved_keyword(yytext); } +"and" { thrift_reserved_keyword(yytext); } +"args" { thrift_reserved_keyword(yytext); } +"as" { thrift_reserved_keyword(yytext); } +"assert" { thrift_reserved_keyword(yytext); } +"break" { thrift_reserved_keyword(yytext); } +"case" { thrift_reserved_keyword(yytext); } +"class" { thrift_reserved_keyword(yytext); } +"continue" { thrift_reserved_keyword(yytext); } +"declare" { thrift_reserved_keyword(yytext); } +"def" { thrift_reserved_keyword(yytext); } +"default" { thrift_reserved_keyword(yytext); } +"del" { thrift_reserved_keyword(yytext); } +"delete" { thrift_reserved_keyword(yytext); } +"do" { thrift_reserved_keyword(yytext); } +"elif" { thrift_reserved_keyword(yytext); } +"else" { thrift_reserved_keyword(yytext); } +"elseif" { thrift_reserved_keyword(yytext); } +"except" { thrift_reserved_keyword(yytext); } +"exec" { thrift_reserved_keyword(yytext); } +"false" { thrift_reserved_keyword(yytext); } +"finally" { thrift_reserved_keyword(yytext); } +"float" { thrift_reserved_keyword(yytext); } +"for" { thrift_reserved_keyword(yytext); } +"foreach" { thrift_reserved_keyword(yytext); } +"function" { thrift_reserved_keyword(yytext); } +"global" { thrift_reserved_keyword(yytext); } +"goto" { thrift_reserved_keyword(yytext); } +"if" { thrift_reserved_keyword(yytext); } +"implements" { thrift_reserved_keyword(yytext); } +"import" { thrift_reserved_keyword(yytext); } +"in" { thrift_reserved_keyword(yytext); } +"inline" { thrift_reserved_keyword(yytext); } +"instanceof" { thrift_reserved_keyword(yytext); } +"interface" { thrift_reserved_keyword(yytext); } +"is" { thrift_reserved_keyword(yytext); } +"lambda" { thrift_reserved_keyword(yytext); } +"native" { thrift_reserved_keyword(yytext); } +"new" { thrift_reserved_keyword(yytext); } +"not" { thrift_reserved_keyword(yytext); } +"or" { thrift_reserved_keyword(yytext); } +"pass" { thrift_reserved_keyword(yytext); } +"public" { thrift_reserved_keyword(yytext); } +"print" { thrift_reserved_keyword(yytext); } +"private" { thrift_reserved_keyword(yytext); } +"protected" { thrift_reserved_keyword(yytext); } +"raise" { thrift_reserved_keyword(yytext); } +"return" { thrift_reserved_keyword(yytext); } +"sizeof" { thrift_reserved_keyword(yytext); } +"static" { thrift_reserved_keyword(yytext); } +"switch" { thrift_reserved_keyword(yytext); } +"synchronized" { thrift_reserved_keyword(yytext); } +"this" { thrift_reserved_keyword(yytext); } +"throw" { thrift_reserved_keyword(yytext); } +"transient" { thrift_reserved_keyword(yytext); } +"true" { thrift_reserved_keyword(yytext); } +"try" { thrift_reserved_keyword(yytext); } +"unsigned" { thrift_reserved_keyword(yytext); } +"var" { thrift_reserved_keyword(yytext); } +"virtual" { thrift_reserved_keyword(yytext); } +"volatile" { thrift_reserved_keyword(yytext); } +"while" { thrift_reserved_keyword(yytext); } +"with" { thrift_reserved_keyword(yytext); } +"union" { thrift_reserved_keyword(yytext); } +"yield" { thrift_reserved_keyword(yytext); } + +{intconstant} { + errno = 0; + yylval.iconst = strtoll(yytext, NULL, 10); + if (errno == ERANGE) { + integer_overflow(yytext); + } + return tok_int_constant; +} + +{hexconstant} { + errno = 0; + yylval.iconst = strtoll(yytext+2, NULL, 16); + if (errno == ERANGE) { + integer_overflow(yytext); + } + return tok_int_constant; +} + +{dubconstant} { + yylval.dconst = atof(yytext); + return tok_dub_constant; +} + +{identifier} { + yylval.id = strdup(yytext); + return tok_identifier; +} + +{st_identifier} { + yylval.id = strdup(yytext); + return tok_st_identifier; +} + +{literal_begin} { + char mark = yytext[0]; + std::string result; + for(;;) + { + int ch = yyinput(); + switch (ch) { + case EOF: + yyerror("End of file while read string at %d\n", yylineno); + exit(1); + case '\n': + yyerror("End of line while read string at %d\n", yylineno - 1); + exit(1); + case '\\': + ch = yyinput(); + switch (ch) { + case 'r': + result.push_back('\r'); + continue; + case 'n': + result.push_back('\n'); + continue; + case 't': + result.push_back('\t'); + continue; + case '"': + result.push_back('"'); + continue; + case '\'': + result.push_back('\''); + continue; + case '\\': + result.push_back('\\'); + continue; + default: + yyerror("Bad escape character\n"); + return -1; + } + break; + default: + if (ch == mark) { + yylval.id = strdup(result.c_str()); + return tok_literal; + } else { + result.push_back(ch); + } + } + } +} + + +{doctext} { + /* This does not show up in the parse tree. */ + /* Rather, the parser will grab it out of the global. */ + if (g_parse_mode == PROGRAM) { + clear_doctext(); + g_doctext = strdup(yytext + 3); + g_doctext[strlen(g_doctext) - 2] = '\0'; + g_doctext = clean_up_doctext(g_doctext); + g_doctext_lineno = yylineno; + } +} + + +%% + +/* vim: filetype=lex +*/ diff --git a/compiler/cpp/src/thrifty.yy b/compiler/cpp/src/thrifty.yy new file mode 100644 index 00000000..bf5408e3 --- /dev/null +++ b/compiler/cpp/src/thrifty.yy @@ -0,0 +1,1127 @@ +%{ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Thrift parser. + * + * This parser is used on a thrift definition file. + * + */ + +#define __STDC_LIMIT_MACROS +#define __STDC_FORMAT_MACROS +#include +#include +#include +#include "main.h" +#include "globals.h" +#include "parse/t_program.h" +#include "parse/t_scope.h" + +/** + * This global variable is used for automatic numbering of field indices etc. + * when parsing the members of a struct. Field values are automatically + * assigned starting from -1 and working their way down. + */ +int y_field_val = -1; +int g_arglist = 0; + +%} + +/** + * This structure is used by the parser to hold the data types associated with + * various parse nodes. + */ +%union { + char* id; + int64_t iconst; + double dconst; + bool tbool; + t_doc* tdoc; + t_type* ttype; + t_base_type* tbase; + t_typedef* ttypedef; + t_enum* tenum; + t_enum_value* tenumv; + t_const* tconst; + t_const_value* tconstv; + t_struct* tstruct; + t_service* tservice; + t_function* tfunction; + t_field* tfield; + char* dtext; + t_field::e_req ereq; + t_annotation* tannot; +} + +/** + * Strings identifier + */ +%token tok_identifier +%token tok_literal +%token tok_doctext +%token tok_st_identifier + +/** + * Constant values + */ +%token tok_int_constant +%token tok_dub_constant + +/** + * Header keywords + */ +%token tok_include +%token tok_namespace +%token tok_cpp_namespace +%token tok_cpp_include +%token tok_cpp_type +%token tok_php_namespace +%token tok_py_module +%token tok_perl_package +%token tok_java_package +%token tok_xsd_all +%token tok_xsd_optional +%token tok_xsd_nillable +%token tok_xsd_namespace +%token tok_xsd_attrs +%token tok_ruby_namespace +%token tok_smalltalk_category +%token tok_smalltalk_prefix +%token tok_cocoa_prefix +%token tok_csharp_namespace + +/** + * Base datatype keywords + */ +%token tok_void +%token tok_bool +%token tok_byte +%token tok_string +%token tok_binary +%token tok_slist +%token tok_senum +%token tok_i16 +%token tok_i32 +%token tok_i64 +%token tok_double + +/** + * Complex type keywords + */ +%token tok_map +%token tok_list +%token tok_set + +/** + * Function modifiers + */ +%token tok_oneway + +/** + * Thrift language keywords + */ +%token tok_typedef +%token tok_struct +%token tok_xception +%token tok_throws +%token tok_extends +%token tok_service +%token tok_enum +%token tok_const +%token tok_required +%token tok_optional + +/** + * Grammar nodes + */ + +%type BaseType +%type ContainerType +%type SimpleContainerType +%type MapType +%type SetType +%type ListType + +%type Definition +%type TypeDefinition + +%type Typedef +%type DefinitionType + +%type TypeAnnotations +%type TypeAnnotationList +%type TypeAnnotation + +%type Field +%type FieldIdentifier +%type FieldRequiredness +%type FieldType +%type FieldValue +%type FieldList + +%type Enum +%type EnumDefList +%type EnumDef + +%type Senum +%type SenumDefList +%type SenumDef + +%type Const +%type ConstValue +%type ConstList +%type ConstListContents +%type ConstMap +%type ConstMapContents + +%type Struct +%type Xception +%type Service + +%type Function +%type FunctionType +%type FunctionList + +%type Throws +%type Extends +%type Oneway +%type XsdAll +%type XsdOptional +%type XsdNillable +%type XsdAttributes +%type CppType + +%type CaptureDocText + +%% + +/** + * Thrift Grammar Implementation. + * + * For the most part this source file works its way top down from what you + * might expect to find in a typical .thrift file, i.e. type definitions and + * namespaces up top followed by service definitions using those types. + */ + +Program: + HeaderList DefinitionList + { + pdebug("Program -> Headers DefinitionList"); + /* + TODO(dreiss): Decide whether full-program doctext is worth the trouble. + if ($1 != NULL) { + g_program->set_doc($1); + } + */ + clear_doctext(); + } + +CaptureDocText: + { + if (g_parse_mode == PROGRAM) { + $$ = g_doctext; + g_doctext = NULL; + } else { + $$ = NULL; + } + } + +/* TODO(dreiss): Try to DestroyDocText in all sorts or random places. */ +DestroyDocText: + { + if (g_parse_mode == PROGRAM) { + clear_doctext(); + } + } + +/* We have to DestroyDocText here, otherwise it catches the doctext + on the first real element. */ +HeaderList: + HeaderList DestroyDocText Header + { + pdebug("HeaderList -> HeaderList Header"); + } +| + { + pdebug("HeaderList -> "); + } + +Header: + Include + { + pdebug("Header -> Include"); + } +| tok_namespace tok_identifier tok_identifier + { + pdebug("Header -> tok_namespace tok_identifier tok_identifier"); + if (g_parse_mode == PROGRAM) { + g_program->set_namespace($2, $3); + } + } +/* TODO(dreiss): Get rid of this once everyone is using the new hotness. */ +| tok_cpp_namespace tok_identifier + { + pwarning(1, "'cpp_namespace' is deprecated. Use 'namespace cpp' instead"); + pdebug("Header -> tok_cpp_namespace tok_identifier"); + if (g_parse_mode == PROGRAM) { + g_program->set_namespace("cpp", $2); + } + } +| tok_cpp_include tok_literal + { + pdebug("Header -> tok_cpp_include tok_literal"); + if (g_parse_mode == PROGRAM) { + g_program->add_cpp_include($2); + } + } +| tok_php_namespace tok_identifier + { + pwarning(1, "'php_namespace' is deprecated. Use 'namespace php' instead"); + pdebug("Header -> tok_php_namespace tok_identifier"); + if (g_parse_mode == PROGRAM) { + g_program->set_namespace("php", $2); + } + } +/* TODO(dreiss): Get rid of this once everyone is using the new hotness. */ +| tok_py_module tok_identifier + { + pwarning(1, "'py_module' is deprecated. Use 'namespace py' instead"); + pdebug("Header -> tok_py_module tok_identifier"); + if (g_parse_mode == PROGRAM) { + g_program->set_namespace("py", $2); + } + } +/* TODO(dreiss): Get rid of this once everyone is using the new hotness. */ +| tok_perl_package tok_identifier + { + pwarning(1, "'perl_package' is deprecated. Use 'namespace perl' instead"); + pdebug("Header -> tok_perl_namespace tok_identifier"); + if (g_parse_mode == PROGRAM) { + g_program->set_namespace("perl", $2); + } + } +/* TODO(dreiss): Get rid of this once everyone is using the new hotness. */ +| tok_ruby_namespace tok_identifier + { + pwarning(1, "'ruby_namespace' is deprecated. Use 'namespace rb' instead"); + pdebug("Header -> tok_ruby_namespace tok_identifier"); + if (g_parse_mode == PROGRAM) { + g_program->set_namespace("rb", $2); + } + } +/* TODO(dreiss): Get rid of this once everyone is using the new hotness. */ +| tok_smalltalk_category tok_st_identifier + { + pwarning(1, "'smalltalk_category' is deprecated. Use 'namespace smalltalk.category' instead"); + pdebug("Header -> tok_smalltalk_category tok_st_identifier"); + if (g_parse_mode == PROGRAM) { + g_program->set_namespace("smalltalk.category", $2); + } + } +/* TODO(dreiss): Get rid of this once everyone is using the new hotness. */ +| tok_smalltalk_prefix tok_identifier + { + pwarning(1, "'smalltalk_prefix' is deprecated. Use 'namespace smalltalk.prefix' instead"); + pdebug("Header -> tok_smalltalk_prefix tok_identifier"); + if (g_parse_mode == PROGRAM) { + g_program->set_namespace("smalltalk.prefix", $2); + } + } +/* TODO(dreiss): Get rid of this once everyone is using the new hotness. */ +| tok_java_package tok_identifier + { + pwarning(1, "'java_package' is deprecated. Use 'namespace java' instead"); + pdebug("Header -> tok_java_package tok_identifier"); + if (g_parse_mode == PROGRAM) { + g_program->set_namespace("java", $2); + } + } +/* TODO(dreiss): Get rid of this once everyone is using the new hotness. */ +| tok_cocoa_prefix tok_identifier + { + pwarning(1, "'cocoa_prefix' is deprecated. Use 'namespace cocoa' instead"); + pdebug("Header -> tok_cocoa_prefix tok_identifier"); + if (g_parse_mode == PROGRAM) { + g_program->set_namespace("cocoa", $2); + } + } +/* TODO(dreiss): Get rid of this once everyone is using the new hotness. */ +| tok_xsd_namespace tok_literal + { + pwarning(1, "'xsd_namespace' is deprecated. Use 'namespace xsd' instead"); + pdebug("Header -> tok_xsd_namespace tok_literal"); + if (g_parse_mode == PROGRAM) { + g_program->set_namespace("cocoa", $2); + } + } +/* TODO(dreiss): Get rid of this once everyone is using the new hotness. */ +| tok_csharp_namespace tok_identifier + { + pwarning(1, "'csharp_namespace' is deprecated. Use 'namespace csharp' instead"); + pdebug("Header -> tok_csharp_namespace tok_identifier"); + if (g_parse_mode == PROGRAM) { + g_program->set_namespace("csharp", $2); + } + } + +Include: + tok_include tok_literal + { + pdebug("Include -> tok_include tok_literal"); + if (g_parse_mode == INCLUDES) { + std::string path = include_file(std::string($2)); + if (!path.empty()) { + g_program->add_include(path, std::string($2)); + } + } + } + +DefinitionList: + DefinitionList CaptureDocText Definition + { + pdebug("DefinitionList -> DefinitionList Definition"); + if ($2 != NULL && $3 != NULL) { + $3->set_doc($2); + } + } +| + { + pdebug("DefinitionList -> "); + } + +Definition: + Const + { + pdebug("Definition -> Const"); + if (g_parse_mode == PROGRAM) { + g_program->add_const($1); + } + $$ = $1; + } +| TypeDefinition + { + pdebug("Definition -> TypeDefinition"); + if (g_parse_mode == PROGRAM) { + g_scope->add_type($1->get_name(), $1); + if (g_parent_scope != NULL) { + g_parent_scope->add_type(g_parent_prefix + $1->get_name(), $1); + } + } + $$ = $1; + } +| Service + { + pdebug("Definition -> Service"); + if (g_parse_mode == PROGRAM) { + g_scope->add_service($1->get_name(), $1); + if (g_parent_scope != NULL) { + g_parent_scope->add_service(g_parent_prefix + $1->get_name(), $1); + } + g_program->add_service($1); + } + $$ = $1; + } + +TypeDefinition: + Typedef + { + pdebug("TypeDefinition -> Typedef"); + if (g_parse_mode == PROGRAM) { + g_program->add_typedef($1); + } + } +| Enum + { + pdebug("TypeDefinition -> Enum"); + if (g_parse_mode == PROGRAM) { + g_program->add_enum($1); + } + } +| Senum + { + pdebug("TypeDefinition -> Senum"); + if (g_parse_mode == PROGRAM) { + g_program->add_typedef($1); + } + } +| Struct + { + pdebug("TypeDefinition -> Struct"); + if (g_parse_mode == PROGRAM) { + g_program->add_struct($1); + } + } +| Xception + { + pdebug("TypeDefinition -> Xception"); + if (g_parse_mode == PROGRAM) { + g_program->add_xception($1); + } + } + +Typedef: + tok_typedef DefinitionType tok_identifier + { + pdebug("TypeDef -> tok_typedef DefinitionType tok_identifier"); + t_typedef *td = new t_typedef(g_program, $2, $3); + $$ = td; + } + +CommaOrSemicolonOptional: + ',' + {} +| ';' + {} +| + {} + +Enum: + tok_enum tok_identifier '{' EnumDefList '}' + { + pdebug("Enum -> tok_enum tok_identifier { EnumDefList }"); + $$ = $4; + $$->set_name($2); + } + +EnumDefList: + EnumDefList EnumDef + { + pdebug("EnumDefList -> EnumDefList EnumDef"); + $$ = $1; + $$->append($2); + } +| + { + pdebug("EnumDefList -> "); + $$ = new t_enum(g_program); + } + +EnumDef: + CaptureDocText tok_identifier '=' tok_int_constant CommaOrSemicolonOptional + { + pdebug("EnumDef -> tok_identifier = tok_int_constant"); + if ($4 < 0) { + pwarning(1, "Negative value supplied for enum %s.\n", $2); + } + if ($4 > INT_MAX) { + pwarning(1, "64-bit value supplied for enum %s.\n", $2); + } + $$ = new t_enum_value($2, $4); + if ($1 != NULL) { + $$->set_doc($1); + } + if (g_parse_mode == PROGRAM) { + g_scope->add_constant($2, new t_const(g_type_i32, $2, new t_const_value($4))); + if (g_parent_scope != NULL) { + g_parent_scope->add_constant(g_parent_prefix + $2, new t_const(g_type_i32, $2, new t_const_value($4))); + } + } + } +| + CaptureDocText tok_identifier CommaOrSemicolonOptional + { + pdebug("EnumDef -> tok_identifier"); + $$ = new t_enum_value($2); + if ($1 != NULL) { + $$->set_doc($1); + } + } + +Senum: + tok_senum tok_identifier '{' SenumDefList '}' + { + pdebug("Senum -> tok_senum tok_identifier { SenumDefList }"); + $$ = new t_typedef(g_program, $4, $2); + } + +SenumDefList: + SenumDefList SenumDef + { + pdebug("SenumDefList -> SenumDefList SenumDef"); + $$ = $1; + $$->add_string_enum_val($2); + } +| + { + pdebug("SenumDefList -> "); + $$ = new t_base_type("string", t_base_type::TYPE_STRING); + $$->set_string_enum(true); + } + +SenumDef: + tok_literal CommaOrSemicolonOptional + { + pdebug("SenumDef -> tok_literal"); + $$ = $1; + } + +Const: + tok_const FieldType tok_identifier '=' ConstValue CommaOrSemicolonOptional + { + pdebug("Const -> tok_const FieldType tok_identifier = ConstValue"); + if (g_parse_mode == PROGRAM) { + $$ = new t_const($2, $3, $5); + validate_const_type($$); + + g_scope->add_constant($3, $$); + if (g_parent_scope != NULL) { + g_parent_scope->add_constant(g_parent_prefix + $3, $$); + } + + } else { + $$ = NULL; + } + } + +ConstValue: + tok_int_constant + { + pdebug("ConstValue => tok_int_constant"); + $$ = new t_const_value(); + $$->set_integer($1); + if ($1 < INT32_MIN || $1 > INT32_MAX) { + pwarning(1, "64-bit constant \"%"PRIi64"\" may not work in all languages.\n", $1); + } + } +| tok_dub_constant + { + pdebug("ConstValue => tok_dub_constant"); + $$ = new t_const_value(); + $$->set_double($1); + } +| tok_literal + { + pdebug("ConstValue => tok_literal"); + $$ = new t_const_value($1); + } +| tok_identifier + { + pdebug("ConstValue => tok_identifier"); + t_const* constant = g_scope->get_constant($1); + if (constant != NULL) { + $$ = constant->get_value(); + } else { + if (g_parse_mode == PROGRAM) { + pwarning(1, "Constant strings should be quoted: %s\n", $1); + } + $$ = new t_const_value($1); + } + } +| ConstList + { + pdebug("ConstValue => ConstList"); + $$ = $1; + } +| ConstMap + { + pdebug("ConstValue => ConstMap"); + $$ = $1; + } + +ConstList: + '[' ConstListContents ']' + { + pdebug("ConstList => [ ConstListContents ]"); + $$ = $2; + } + +ConstListContents: + ConstListContents ConstValue CommaOrSemicolonOptional + { + pdebug("ConstListContents => ConstListContents ConstValue CommaOrSemicolonOptional"); + $$ = $1; + $$->add_list($2); + } +| + { + pdebug("ConstListContents =>"); + $$ = new t_const_value(); + $$->set_list(); + } + +ConstMap: + '{' ConstMapContents '}' + { + pdebug("ConstMap => { ConstMapContents }"); + $$ = $2; + } + +ConstMapContents: + ConstMapContents ConstValue ':' ConstValue CommaOrSemicolonOptional + { + pdebug("ConstMapContents => ConstMapContents ConstValue CommaOrSemicolonOptional"); + $$ = $1; + $$->add_map($2, $4); + } +| + { + pdebug("ConstMapContents =>"); + $$ = new t_const_value(); + $$->set_map(); + } + +Struct: + tok_struct tok_identifier XsdAll '{' FieldList '}' TypeAnnotations + { + pdebug("Struct -> tok_struct tok_identifier { FieldList }"); + $5->set_xsd_all($3); + $$ = $5; + $$->set_name($2); + if ($7 != NULL) { + $$->annotations_ = $7->annotations_; + delete $7; + } + } + +XsdAll: + tok_xsd_all + { + $$ = true; + } +| + { + $$ = false; + } + +XsdOptional: + tok_xsd_optional + { + $$ = true; + } +| + { + $$ = false; + } + +XsdNillable: + tok_xsd_nillable + { + $$ = true; + } +| + { + $$ = false; + } + +XsdAttributes: + tok_xsd_attrs '{' FieldList '}' + { + $$ = $3; + } +| + { + $$ = NULL; + } + +Xception: + tok_xception tok_identifier '{' FieldList '}' + { + pdebug("Xception -> tok_xception tok_identifier { FieldList }"); + $4->set_name($2); + $4->set_xception(true); + $$ = $4; + } + +Service: + tok_service tok_identifier Extends '{' FlagArgs FunctionList UnflagArgs '}' + { + pdebug("Service -> tok_service tok_identifier { FunctionList }"); + $$ = $6; + $$->set_name($2); + $$->set_extends($3); + } + +FlagArgs: + { + g_arglist = 1; + } + +UnflagArgs: + { + g_arglist = 0; + } + +Extends: + tok_extends tok_identifier + { + pdebug("Extends -> tok_extends tok_identifier"); + $$ = NULL; + if (g_parse_mode == PROGRAM) { + $$ = g_scope->get_service($2); + if ($$ == NULL) { + yyerror("Service \"%s\" has not been defined.", $2); + exit(1); + } + } + } +| + { + $$ = NULL; + } + +FunctionList: + FunctionList Function + { + pdebug("FunctionList -> FunctionList Function"); + $$ = $1; + $1->add_function($2); + } +| + { + pdebug("FunctionList -> "); + $$ = new t_service(g_program); + } + +Function: + CaptureDocText Oneway FunctionType tok_identifier '(' FieldList ')' Throws CommaOrSemicolonOptional + { + $6->set_name(std::string($4) + "_args"); + $$ = new t_function($3, $4, $6, $8, $2); + if ($1 != NULL) { + $$->set_doc($1); + } + } + +Oneway: + tok_oneway + { + $$ = true; + } +| + { + $$ = false; + } + +Throws: + tok_throws '(' FieldList ')' + { + pdebug("Throws -> tok_throws ( FieldList )"); + $$ = $3; + if (g_parse_mode == PROGRAM && !validate_throws($$)) { + yyerror("Throws clause may not contain non-exception types"); + exit(1); + } + } +| + { + $$ = new t_struct(g_program); + } + +FieldList: + FieldList Field + { + pdebug("FieldList -> FieldList , Field"); + $$ = $1; + if (!($$->append($2))) { + yyerror("Field identifier %d for \"%s\" has already been used", $2->get_key(), $2->get_name().c_str()); + exit(1); + } + } +| + { + pdebug("FieldList -> "); + y_field_val = -1; + $$ = new t_struct(g_program); + } + +Field: + CaptureDocText FieldIdentifier FieldRequiredness FieldType tok_identifier FieldValue XsdOptional XsdNillable XsdAttributes CommaOrSemicolonOptional + { + pdebug("tok_int_constant : Field -> FieldType tok_identifier"); + if ($2 < 0) { + pwarning(1, "No field key specified for %s, resulting protocol may have conflicts or not be backwards compatible!\n", $5); + if (g_strict >= 192) { + yyerror("Implicit field keys are deprecated and not allowed with -strict"); + exit(1); + } + } + $$ = new t_field($4, $5, $2); + $$->set_req($3); + if ($6 != NULL) { + validate_field_value($$, $6); + $$->set_value($6); + } + $$->set_xsd_optional($7); + $$->set_xsd_nillable($8); + if ($1 != NULL) { + $$->set_doc($1); + } + if ($9 != NULL) { + $$->set_xsd_attrs($9); + } + } + +FieldIdentifier: + tok_int_constant ':' + { + if ($1 <= 0) { + pwarning(1, "Nonpositive value (%d) not allowed as a field key.\n", $1); + $1 = y_field_val--; + } + $$ = $1; + } +| + { + $$ = y_field_val--; + } + +FieldRequiredness: + tok_required + { + if (g_arglist) { + if (g_parse_mode == PROGRAM) { + pwarning(1, "required keyword is ignored in argument lists.\n"); + } + $$ = t_field::T_OPT_IN_REQ_OUT; + } else { + $$ = t_field::T_REQUIRED; + } + } +| tok_optional + { + if (g_arglist) { + if (g_parse_mode == PROGRAM) { + pwarning(1, "optional keyword is ignored in argument lists.\n"); + } + $$ = t_field::T_OPT_IN_REQ_OUT; + } else { + $$ = t_field::T_OPTIONAL; + } + } +| + { + $$ = t_field::T_OPT_IN_REQ_OUT; + } + +FieldValue: + '=' ConstValue + { + if (g_parse_mode == PROGRAM) { + $$ = $2; + } else { + $$ = NULL; + } + } +| + { + $$ = NULL; + } + +DefinitionType: + BaseType + { + pdebug("DefinitionType -> BaseType"); + $$ = $1; + } +| ContainerType + { + pdebug("DefinitionType -> ContainerType"); + $$ = $1; + } + +FunctionType: + FieldType + { + pdebug("FunctionType -> FieldType"); + $$ = $1; + } +| tok_void + { + pdebug("FunctionType -> tok_void"); + $$ = g_type_void; + } + +FieldType: + tok_identifier + { + pdebug("FieldType -> tok_identifier"); + if (g_parse_mode == INCLUDES) { + // Ignore identifiers in include mode + $$ = NULL; + } else { + // Lookup the identifier in the current scope + $$ = g_scope->get_type($1); + if ($$ == NULL) { + yyerror("Type \"%s\" has not been defined.", $1); + exit(1); + } + } + } +| BaseType + { + pdebug("FieldType -> BaseType"); + $$ = $1; + } +| ContainerType + { + pdebug("FieldType -> ContainerType"); + $$ = $1; + } + +BaseType: + tok_string + { + pdebug("BaseType -> tok_string"); + $$ = g_type_string; + } +| tok_binary + { + pdebug("BaseType -> tok_binary"); + $$ = g_type_binary; + } +| tok_slist + { + pdebug("BaseType -> tok_slist"); + $$ = g_type_slist; + } +| tok_bool + { + pdebug("BaseType -> tok_bool"); + $$ = g_type_bool; + } +| tok_byte + { + pdebug("BaseType -> tok_byte"); + $$ = g_type_byte; + } +| tok_i16 + { + pdebug("BaseType -> tok_i16"); + $$ = g_type_i16; + } +| tok_i32 + { + pdebug("BaseType -> tok_i32"); + $$ = g_type_i32; + } +| tok_i64 + { + pdebug("BaseType -> tok_i64"); + $$ = g_type_i64; + } +| tok_double + { + pdebug("BaseType -> tok_double"); + $$ = g_type_double; + } + +ContainerType: SimpleContainerType TypeAnnotations + { + pdebug("ContainerType -> SimpleContainerType TypeAnnotations"); + $$ = $1; + if ($2 != NULL) { + $$->annotations_ = $2->annotations_; + delete $2; + } + } + +SimpleContainerType: + MapType + { + pdebug("SimpleContainerType -> MapType"); + $$ = $1; + } +| SetType + { + pdebug("SimpleContainerType -> SetType"); + $$ = $1; + } +| ListType + { + pdebug("SimpleContainerType -> ListType"); + $$ = $1; + } + +MapType: + tok_map CppType '<' FieldType ',' FieldType '>' + { + pdebug("MapType -> tok_map "); + $$ = new t_map($4, $6); + if ($2 != NULL) { + ((t_container*)$$)->set_cpp_name(std::string($2)); + } + } + +SetType: + tok_set CppType '<' FieldType '>' + { + pdebug("SetType -> tok_set"); + $$ = new t_set($4); + if ($2 != NULL) { + ((t_container*)$$)->set_cpp_name(std::string($2)); + } + } + +ListType: + tok_list '<' FieldType '>' CppType + { + pdebug("ListType -> tok_list"); + $$ = new t_list($3); + if ($5 != NULL) { + ((t_container*)$$)->set_cpp_name(std::string($5)); + } + } + +CppType: + tok_cpp_type tok_literal + { + $$ = $2; + } +| + { + $$ = NULL; + } + +TypeAnnotations: + '(' TypeAnnotationList ')' + { + pdebug("TypeAnnotations -> ( TypeAnnotationList )"); + $$ = $2; + } +| + { + $$ = NULL; + } + +TypeAnnotationList: + TypeAnnotationList TypeAnnotation + { + pdebug("TypeAnnotationList -> TypeAnnotationList , TypeAnnotation"); + $$ = $1; + $$->annotations_[$2->key] = $2->val; + delete $2; + } +| + { + /* Just use a dummy structure to hold the annotations. */ + $$ = new t_struct(g_program); + } + +TypeAnnotation: + tok_identifier '=' tok_literal CommaOrSemicolonOptional + { + pdebug("TypeAnnotation -> tok_identifier = tok_literal"); + $$ = new t_annotation; + $$->key = $1; + $$->val = $3; + } + +%% diff --git a/configure.ac b/configure.ac new file mode 100644 index 00000000..6bba72d1 --- /dev/null +++ b/configure.ac @@ -0,0 +1,226 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +AC_PREREQ(2.59) + +AC_INIT([thrift], [0.1-dev]) + +AC_CONFIG_AUX_DIR([.]) + +AM_INIT_AUTOMAKE + +AC_ARG_VAR([PY_PREFIX], [Prefix for installing Python modules. + (Normal --prefix is ignored for Python because + Python has different conventions.) + Default = "/usr"]) +AS_IF([test "x$PY_PREFIX" = x], [PY_PREFIX="/usr"]) + +AC_ARG_VAR([JAVA_PREFIX], [Prefix for installing the Java lib jar. + (Normal --prefix is ignored for Java because + Java has different conevntions.) + Default = "/usr/local/lib"]) +AS_IF([test "x$JAVA_PREFIX" = x], [JAVA_PREFIX="/usr/local/lib"]) + +AC_PROG_CC +AC_PROG_CPP +AC_PROG_CXX +AC_PROG_INSTALL +AC_PROG_LIBTOOL +AC_PROG_MAKE_SET +AC_PROG_YACC +AC_PROG_LEX +AM_PROG_LEX +AC_PROG_LN_S +AC_PROG_MKDIR_P + +AC_LANG([C++]) +AX_BOOST_BASE([1.33.1]) + +AX_LIB_EVENT([1.0]) +AM_CONDITIONAL([AMX_HAVE_LIBEVENT], [test "$success" = "yes"]) + +AX_LIB_ZLIB([1.2.3]) +AM_CONDITIONAL([AMX_HAVE_ZLIB], [test "$success" = "yes"]) + +AX_THRIFT_LIB(csharp, [C#], yes) +if test "$with_csharp" = "yes"; then + PKG_CHECK_MODULES(MONO, mono >= 2.0.0, net_3_5=yes, net_3_5=no) + PKG_CHECK_MODULES(MONO, mono >= 1.2.4, have_mono=yes, have_mono=no) +fi +AM_CONDITIONAL(WITH_MONO, [test "$have_mono" = "yes"]) +AM_CONDITIONAL(NET_2_0, [test "$net_3_5" = "no"]) + +AX_THRIFT_LIB(java, [Java], yes) +if test "$with_java" = "yes"; then + AX_JAVAC_AND_JAVA + AC_PATH_PROG([ANT], [ant]) +fi +AM_CONDITIONAL([WITH_JAVA], + [test -n "$ANT" -a -n "$JAVA" -a -n "$JAVAC"]) + +AX_THRIFT_LIB(erlang, [Erlang], yes) +if test "$with_erlang" = "yes"; then + AC_PATH_PROG([ERLC], [erlc]) +fi +AM_CONDITIONAL(WITH_ERLANG, [test -n "$ERLC"]) + +AX_THRIFT_LIB(py, [Python], yes) +if test "$with_py" = "yes"; then + AM_PATH_PYTHON(2.4,, :) +fi +AM_CONDITIONAL(WITH_PYTHON, [test -n "$PYTHON" -a "$PYTHON" != ":"]) + +AX_THRIFT_LIB(perl, [Perl], yes) +if test "$with_perl" = "yes"; then + AC_PATH_PROG([PERL], [perl]) +fi +AM_CONDITIONAL(WITH_PERL, [test -n "$PERL"]) + +AX_THRIFT_LIB(ruby, [Ruby], yes) +if test "$with_ruby" = "yes"; then + AC_PATH_PROG([RUBY], [ruby]) + AC_PATH_PROG([RSPEC], [spec]) +fi +AM_CONDITIONAL(WITH_RUBY, [test -n "$RUBY"]) +AM_CONDITIONAL(HAVE_RSPEC, [test -n "$RSPEC"]) + + +AC_C_CONST +AC_C_INLINE +AC_C_VOLATILE + +AC_HEADER_STDBOOL +AC_HEADER_STDC +AC_HEADER_TIME +AC_CHECK_HEADERS([arpa/inet.h]) +AC_CHECK_HEADERS([endian.h]) +AC_CHECK_HEADERS([fcntl.h]) +AC_CHECK_HEADERS([inttypes.h]) +AC_CHECK_HEADERS([limits.h]) +AC_CHECK_HEADERS([netdb.h]) +AC_CHECK_HEADERS([netinet/in.h]) +AC_CHECK_HEADERS([pthread.h]) +AC_CHECK_HEADERS([stddef.h]) +AC_CHECK_HEADERS([stdlib.h]) +AC_CHECK_HEADERS([sys/socket.h]) +AC_CHECK_HEADERS([sys/time.h]) +AC_CHECK_HEADERS([unistd.h]) +AC_CHECK_HEADERS([libintl.h]) +AC_CHECK_HEADERS([malloc.h]) + +AC_CHECK_LIB(pthread, pthread_create) +AC_CHECK_LIB(rt, sched_get_priority_min) + +AC_TYPE_INT16_T +AC_TYPE_INT32_T +AC_TYPE_INT64_T +AC_TYPE_INT8_T +AC_TYPE_MODE_T +AC_TYPE_OFF_T +AC_TYPE_SIZE_T +AC_TYPE_SSIZE_T +AC_TYPE_UINT16_T +AC_TYPE_UINT32_T +AC_TYPE_UINT64_T +AC_TYPE_UINT8_T +AC_CHECK_TYPES([ptrdiff_t], [], [exit 1]) + +AC_STRUCT_TM + +AC_FUNC_ALLOCA +AC_FUNC_MALLOC +AC_FUNC_MEMCMP +AC_FUNC_REALLOC +AC_FUNC_SELECT_ARGTYPES +AC_FUNC_STAT +AC_FUNC_STRERROR_R +AC_FUNC_STRFTIME +AC_FUNC_VPRINTF +AC_CHECK_FUNCS([strtoul]) +AC_CHECK_FUNCS([bzero]) +AC_CHECK_FUNCS([clock_gettime]) +AC_CHECK_FUNCS([ftruncate]) +AC_CHECK_FUNCS([gethostbyname]) +AC_CHECK_FUNCS([gettimeofday]) +AC_CHECK_FUNCS([memmove]) +AC_CHECK_FUNCS([memset]) +AC_CHECK_FUNCS([mkdir]) +AC_CHECK_FUNCS([realpath]) +AC_CHECK_FUNCS([select]) +AC_CHECK_FUNCS([socket]) +AC_CHECK_FUNCS([strchr]) +AC_CHECK_FUNCS([strdup]) +AC_CHECK_FUNCS([strerror]) +AC_CHECK_FUNCS([strstr]) +AC_CHECK_FUNCS([strtol]) +AC_CHECK_FUNCS([sqrt]) + +AX_SIGNED_RIGHT_SHIFT + +AX_THRIFT_GEN(cpp, [C++], yes) +AM_CONDITIONAL([THRIFT_GEN_cpp], [test "$ax_thrift_gen_cpp" = "yes"]) +AX_THRIFT_GEN(java, [Java], yes) +AM_CONDITIONAL([THRIFT_GEN_java], [test "$ax_thrift_gen_java" = "yes"]) +AX_THRIFT_GEN(csharp, [C#], yes) +AM_CONDITIONAL([THRIFT_GEN_csharp], [test "$ax_thrift_gen_csharp" = "yes"]) +AX_THRIFT_GEN(py, [Python], yes) +AM_CONDITIONAL([THRIFT_GEN_py], [test "$ax_thrift_gen_py" = "yes"]) +AX_THRIFT_GEN(rb, [Ruby], yes) +AM_CONDITIONAL([THRIFT_GEN_rb], [test "$ax_thrift_gen_rb" = "yes"]) +AX_THRIFT_GEN(perl, [Perl], yes) +AM_CONDITIONAL([THRIFT_GEN_perl], [test "$ax_thrift_gen_perl" = "yes"]) +AX_THRIFT_GEN(php, [PHP], yes) +AM_CONDITIONAL([THRIFT_GEN_php], [test "$ax_thrift_gen_php" = "yes"]) +AX_THRIFT_GEN(erl, [Erlang], yes) +AM_CONDITIONAL([THRIFT_GEN_erl], [test "$ax_thrift_gen_erl" = "yes"]) +AX_THRIFT_GEN(cocoa, [Cocoa], yes) +AM_CONDITIONAL([THRIFT_GEN_cocoa], [test "$ax_thrift_gen_cocoa" = "yes"]) +AX_THRIFT_GEN(st, [Smalltalk], yes) +AM_CONDITIONAL([THRIFT_GEN_st], [test "$ax_thrift_gen_st" = "yes"]) +AX_THRIFT_GEN(ocaml, [OCaml], yes) +AM_CONDITIONAL([THRIFT_GEN_ocaml], [test "$ax_thrift_gen_ocaml" = "yes"]) +AX_THRIFT_GEN(hs, [Haskell], yes) +AM_CONDITIONAL([THRIFT_GEN_hs], [test "$ax_thrift_gen_hs" = "yes"]) +AX_THRIFT_GEN(xsd, [XSD], yes) +AM_CONDITIONAL([THRIFT_GEN_xsd], [test "$ax_thrift_gen_xsd" = "yes"]) +AX_THRIFT_GEN(html, [HTML], yes) +AM_CONDITIONAL([THRIFT_GEN_html], [test "$ax_thrift_gen_html" = "yes"]) + +AC_CONFIG_HEADERS(config.h:config.hin) + +AC_CONFIG_FILES([ + Makefile + compiler/cpp/Makefile + lib/Makefile + lib/cpp/Makefile + lib/cpp/thrift.pc + lib/cpp/thrift-nb.pc + lib/cpp/thrift-z.pc + lib/csharp/Makefile + lib/java/Makefile + lib/perl/Makefile + lib/perl/test/Makefile + lib/py/Makefile + lib/rb/Makefile + test/Makefile + test/py/Makefile + test/rb/Makefile +]) + +AC_OUTPUT diff --git a/contrib/fb303/LICENSE b/contrib/fb303/LICENSE new file mode 100644 index 00000000..4eacb643 --- /dev/null +++ b/contrib/fb303/LICENSE @@ -0,0 +1,16 @@ +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. diff --git a/contrib/fb303/Makefile.am b/contrib/fb303/Makefile.am new file mode 100644 index 00000000..de7fbb60 --- /dev/null +++ b/contrib/fb303/Makefile.am @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +@GLOBAL_HEADER_MK@ + +@PRODUCT_MK@ + +SUBDIRS = . cpp py + +BUILT_SOURCES = + + +clean-local: clean-common + +@GLOBAL_FOOTER_MK@ diff --git a/contrib/fb303/README b/contrib/fb303/README new file mode 100644 index 00000000..8ade560c --- /dev/null +++ b/contrib/fb303/README @@ -0,0 +1,37 @@ +Project FB303: The Facebook Bassline +------------------------------------ + +* Curious about the 303? * +http://en.wikipedia.org/wiki/Roland_TB-303 + +* Why the name? * +The TB303 makes bass lines. +.Bass is what lies underneath any strong tune. +..fb303 is the shared root of all thrift services. +...fb303 => FacebookBase303. + +* How do I use this? * +Take a look at the examples to see how your backend project can +and should inherit from this service. + +* What does it provide? * +A standard interface to monitoring, dynamic options and configuration, +uptime reports, activity, etc. + +* I want more. * +Think carefully first about whether the functionality you are going to add +belongs here or in your application. If it can be abstracted and is generally +useful, then it probably belongs somewhere in the fb303 tree. Keep in mind, +not every product has to use ALL the functionality of fb303, but every product +CANNOT use functionality that is NOT in fb303. + +* Is this open source? * +Yes. fb303 is distributed under the Thrift Software License. See the +LICENSE file for more details. + +* Installation * +fb303 is configured/built/installed similar to Thrift. See the README +in the Thrift root directory for more information. + +* Who wrote this README? * +mcslee@facebook.com diff --git a/contrib/fb303/acinclude.m4 b/contrib/fb303/acinclude.m4 new file mode 100644 index 00000000..faafba6b --- /dev/null +++ b/contrib/fb303/acinclude.m4 @@ -0,0 +1,258 @@ +dnl Copyright (C) 2009 Facebook +dnl Copying and distribution of this file, with or without modification, +dnl are permitted in any medium without royalty provided the copyright +dnl notice and this notice are preserved. + +AC_DEFUN([FB_INITIALIZE], +[ +AM_INIT_AUTOMAKE([ foreign 1.9.5 no-define ]) +if test "x$1" = "xlocalinstall"; then +wdir=`pwd` +# To use $wdir undef quote. +# +########## +AC_PREFIX_DEFAULT([`pwd`/install]) +echo +fi +AC_PROG_CC +AC_PROG_CXX +AC_PROG_RANLIB(RANLIB, ranlib) +AC_PATH_PROGS(BASH, bash) +AC_PATH_PROGS(PERL, perl) +AC_PATH_PROGS(PYTHON, python) +AC_PATH_PROGS(AR, ar) +AC_PATH_PROGS(ANT, ant) +PRODUCT_MK="" +]) + +AC_DEFUN([FB_WITH_EXTERNAL_PATH], +[ +cdir=`pwd` +AC_MSG_CHECKING([Checking EXTERNAL_PATH set to]) +AC_ARG_WITH([externalpath], + [ --with-externalpath=DIR User specified path to external facebook components.], + [ + if test "x${EXTERNAL_PATH}" != "x"; then + echo "" + echo "ERROR: You have already set EXTERNAL_PATH in your environment" + echo "Cannot override it using --with-externalpath. Unset EXTERNAL_PATH to use this option" + exit 1 + fi + EXTERNAL_PATH=$withval + ], + [ + if test "x${EXTERNAL_PATH}" = "x"; then + EXTERNAL_PATH=$1 + fi + ] +) +if test "x${EXTERNAL_PATH}" = "x"; then + export EXTERNAL_PATH="$cdir/external" + GLOBAL_HEADER_MK="include ${EXTERNAL_PATH}/global_header.mk" + GLOBAL_FOOTER_MK="include ${EXTERNAL_PATH}/global_footer.mk" +else + export EXTERNAL_PATH + GLOBAL_HEADER_MK="include ${EXTERNAL_PATH}/global_header.mk" + GLOBAL_FOOTER_MK="include ${EXTERNAL_PATH}/global_footer.mk" +fi +AC_MSG_RESULT($EXTERNAL_PATH) +if test ! -d ${EXTERNAL_PATH}; then + echo "" + echo "ERROR: EXTERNAL_PATH set to an nonexistent directory ${EXTERNAL_PATH}" + exit 1 +fi +AC_SUBST(EXTERNAL_PATH) +AC_SUBST(GLOBAL_HEADER_MK) +AC_SUBST(GLOBAL_FOOTER_MK) +]) + +# Set option to enable shared mode. Set DEBUG and OPT for use in Makefile.am. +AC_DEFUN([FB_ENABLE_DEFAULT_OPT_BUILD], +[ +AC_MSG_CHECKING([whether to enable optimized build]) +AC_ARG_ENABLE([opt], + [ --disable-opt Set up debug mode.], + [ + ENABLED_OPT=$enableval + ], + [ + ENABLED_OPT="yes" + ] +) +if test "$ENABLED_OPT" = "yes" +then + CFLAGS="-Wall -O3" + CXXFLAGS="-Wall -O3" +else + CFLAGS="-Wall -g" + CXXFLAGS="-Wall -g" +fi +AC_MSG_RESULT($ENABLED_OPT) +AM_CONDITIONAL([OPT], [test "$ENABLED_OPT" = yes]) +AM_CONDITIONAL([DEBUG], [test "$ENABLED_OPT" = no]) +]) + +# Set option to enable debug mode. Set DEBUG and OPT for use in Makefile.am. +AC_DEFUN([FB_ENABLE_DEFAULT_DEBUG_BUILD], +[ +AC_MSG_CHECKING([whether to enable debug build]) +AC_ARG_ENABLE([debug], + [ --disable-debug Set up opt mode.], + [ + ENABLED_DEBUG=$enableval + ], + [ + ENABLED_DEBUG="yes" + ] +) +if test "$ENABLED_DEBUG" = "yes" +then + CFLAGS="-Wall -g" + CXXFLAGS="-Wall -g" +else + CFLAGS="-Wall -O3" + CXXFLAGS="-Wall -O3" +fi +AC_MSG_RESULT($ENABLED_DEBUG) +AM_CONDITIONAL([DEBUG], [test "$ENABLED_DEBUG" = yes]) +AM_CONDITIONAL([OPT], [test "$ENABLED_DEBUG" = no]) +]) + +# Set option to enable static libs. +AC_DEFUN([FB_ENABLE_DEFAULT_STATIC], +[ +SHARED="" +STATIC="" +AC_MSG_CHECKING([whether to enable static mode]) +AC_ARG_ENABLE([static], + [ --disable-static Set up shared mode.], + [ + ENABLED_STATIC=$enableval + ], + [ + ENABLED_STATIC="yes" + ] +) +if test "$ENABLED_STATIC" = "yes" +then + LTYPE=".a" +else + LTYPE=".so" + SHARED_CXXFLAGS="-fPIC" + SHARED_CFLAGS="-fPIC" + SHARED_LDFLAGS="-shared -fPIC" + AC_SUBST(SHARED_CXXFLAGS) + AC_SUBST(SHARED_CFLAGS) + AC_SUBST(SHARED_LDFLAGS) +fi +AC_MSG_RESULT($ENABLED_STATIC) +AC_SUBST(LTYPE) +AM_CONDITIONAL([STATIC], [test "$ENABLED_STATIC" = yes]) +AM_CONDITIONAL([SHARED], [test "$ENABLED_STATIC" = no]) +]) + +# Set option to enable shared libs. +AC_DEFUN([FB_ENABLE_DEFAULT_SHARED], +[ +SHARED="" +STATIC="" +AC_MSG_CHECKING([whether to enable shared mode]) +AC_ARG_ENABLE([shared], + [ --disable-shared Set up static mode.], + [ + ENABLED_SHARED=$enableval + ], + [ + ENABLED_SHARED="yes" + ] +) +if test "$ENABLED_SHARED" = "yes" +then + LTYPE=".so" + SHARED_CXXFLAGS="-fPIC" + SHARED_CFLAGS="-fPIC" + SHARED_LDFLAGS="-shared -fPIC" + AC_SUBST(SHARED_CXXFLAGS) + AC_SUBST(SHARED_CFLAGS) + AC_SUBST(SHARED_LDFLAGS) +else + LTYPE=".a" +fi +AC_MSG_RESULT($ENABLED_SHARED) +AC_SUBST(LTYPE) +AM_CONDITIONAL([SHARED], [test "$ENABLED_SHARED" = yes]) +AM_CONDITIONAL([STATIC], [test "$ENABLED_SHARED" = no]) +]) + +# Generates define flags and conditionals as specified by user. +# This gets enabled *only* if user selects --enable- otion. +AC_DEFUN([FB_ENABLE_FEATURE], +[ +ENABLE="" +flag="$1" +value="$3" +AC_MSG_CHECKING([whether to enable $1]) +AC_ARG_ENABLE([$2], + [ --enable-$2 Enable $2.], + [ + ENABLE=$enableval + ], + [ + ENABLE="no" + ] +) +AM_CONDITIONAL([$1], [test "$ENABLE" = yes]) +if test "$ENABLE" = "yes" +then + if test "x${value}" = "x" + then + AC_DEFINE([$1]) + else + AC_DEFINE_UNQUOTED([$1], [$value]) + fi +fi +AC_MSG_RESULT($ENABLE) +]) + + +# can also use eval $2=$withval;AC_SUBST($2) +AC_DEFUN([FB_WITH_PATH], +[ +USRFLAG="" +USRFLAG=$1 +AC_MSG_CHECKING([Checking $1 set to]) +AC_ARG_WITH([$2], + [ --with-$2=DIR User specified path.], + [ + LOC=$withval + eval $USRFLAG=$withval + ], + [ + LOC=$3 + eval $USRFLAG=$3 + ] +) +AC_SUBST([$1]) +AC_MSG_RESULT($LOC) +]) + +AC_DEFUN([FB_SET_FLAG_VALUE], +[ +SETFLAG="" +AC_MSG_CHECKING([Checking $1 set to]) +SETFLAG=$1 +eval $SETFLAG=\"$2\" +AC_SUBST([$SETFLAG]) +AC_MSG_RESULT($2) +]) + +# NOTES +# if using if else bourne stmt you must have more than a macro in it. +# EX1 is not correct. EX2 is correct +# EX1: if test "$XX" = "yes"; then +# AC_SUBST(xx) +# fi +# EX2: if test "$XX" = "yes"; then +# xx="foo" +# AC_SUBST(xx) +# fi diff --git a/contrib/fb303/aclocal/ax_boost_base.m4 b/contrib/fb303/aclocal/ax_boost_base.m4 new file mode 100644 index 00000000..e56bb738 --- /dev/null +++ b/contrib/fb303/aclocal/ax_boost_base.m4 @@ -0,0 +1,198 @@ +##### http://autoconf-archive.cryp.to/ax_boost_base.html +# +# SYNOPSIS +# +# AX_BOOST_BASE([MINIMUM-VERSION]) +# +# DESCRIPTION +# +# Test for the Boost C++ libraries of a particular version (or newer) +# +# If no path to the installed boost library is given the macro +# searchs under /usr, /usr/local, /opt and /opt/local and evaluates +# the $BOOST_ROOT environment variable. Further documentation is +# available at . +# +# This macro calls: +# +# AC_SUBST(BOOST_CPPFLAGS) / AC_SUBST(BOOST_LDFLAGS) +# +# And sets: +# +# HAVE_BOOST +# +# LAST MODIFICATION +# +# 2007-07-28 +# +# COPYLEFT +# +# Copyright (c) 2007 Thomas Porschberg +# +# Copying and distribution of this file, with or without +# modification, are permitted in any medium without royalty provided +# the copyright notice and this notice are preserved. + +AC_DEFUN([AX_BOOST_BASE], +[ +AC_ARG_WITH([boost], + AS_HELP_STRING([--with-boost@<:@=DIR@:>@], [use boost (default is yes) - it is possible to specify the root directory for boost (optional)]), + [ + if test "$withval" = "no"; then + want_boost="no" + elif test "$withval" = "yes"; then + want_boost="yes" + ac_boost_path="" + else + want_boost="yes" + ac_boost_path="$withval" + fi + ], + [want_boost="yes"]) + +if test "x$want_boost" = "xyes"; then + boost_lib_version_req=ifelse([$1], ,1.20.0,$1) + boost_lib_version_req_shorten=`expr $boost_lib_version_req : '\([[0-9]]*\.[[0-9]]*\)'` + boost_lib_version_req_major=`expr $boost_lib_version_req : '\([[0-9]]*\)'` + boost_lib_version_req_minor=`expr $boost_lib_version_req : '[[0-9]]*\.\([[0-9]]*\)'` + boost_lib_version_req_sub_minor=`expr $boost_lib_version_req : '[[0-9]]*\.[[0-9]]*\.\([[0-9]]*\)'` + if test "x$boost_lib_version_req_sub_minor" = "x" ; then + boost_lib_version_req_sub_minor="0" + fi + WANT_BOOST_VERSION=`expr $boost_lib_version_req_major \* 100000 \+ $boost_lib_version_req_minor \* 100 \+ $boost_lib_version_req_sub_minor` + AC_MSG_CHECKING(for boostlib >= $boost_lib_version_req) + succeeded=no + + dnl first we check the system location for boost libraries + dnl this location ist chosen if boost libraries are installed with the --layout=system option + dnl or if you install boost with RPM + if test "$ac_boost_path" != ""; then + BOOST_LDFLAGS="-L$ac_boost_path/lib" + BOOST_CPPFLAGS="-I$ac_boost_path/include" + else + for ac_boost_path_tmp in /usr /usr/local /opt /opt/local ; do + if test -d "$ac_boost_path_tmp/include/boost" && test -r "$ac_boost_path_tmp/include/boost"; then + BOOST_LDFLAGS="-L$ac_boost_path_tmp/lib" + BOOST_CPPFLAGS="-I$ac_boost_path_tmp/include" + break; + fi + done + fi + + CPPFLAGS_SAVED="$CPPFLAGS" + CPPFLAGS="$CPPFLAGS $BOOST_CPPFLAGS" + export CPPFLAGS + + LDFLAGS_SAVED="$LDFLAGS" + LDFLAGS="$LDFLAGS $BOOST_LDFLAGS" + export LDFLAGS + + AC_LANG_PUSH(C++) + AC_COMPILE_IFELSE([AC_LANG_PROGRAM([[ + @%:@include + ]], [[ + #if BOOST_VERSION >= $WANT_BOOST_VERSION + // Everything is okay + #else + # error Boost version is too old + #endif + ]])],[ + AC_MSG_RESULT(yes) + succeeded=yes + found_system=yes + ],[ + ]) + AC_LANG_POP([C++]) + + + + dnl if we found no boost with system layout we search for boost libraries + dnl built and installed without the --layout=system option or for a staged(not installed) version + if test "x$succeeded" != "xyes"; then + _version=0 + if test "$ac_boost_path" != ""; then + BOOST_LDFLAGS="-L$ac_boost_path/lib" + if test -d "$ac_boost_path" && test -r "$ac_boost_path"; then + for i in `ls -d $ac_boost_path/include/boost-* 2>/dev/null`; do + _version_tmp=`echo $i | sed "s#$ac_boost_path##" | sed 's/\/include\/boost-//' | sed 's/_/./'` + V_CHECK=`expr $_version_tmp \> $_version` + if test "$V_CHECK" = "1" ; then + _version=$_version_tmp + fi + VERSION_UNDERSCORE=`echo $_version | sed 's/\./_/'` + BOOST_CPPFLAGS="-I$ac_boost_path/include/boost-$VERSION_UNDERSCORE" + done + fi + else + for ac_boost_path in /usr /usr/local /opt /opt/local ; do + if test -d "$ac_boost_path" && test -r "$ac_boost_path"; then + for i in `ls -d $ac_boost_path/include/boost-* 2>/dev/null`; do + _version_tmp=`echo $i | sed "s#$ac_boost_path##" | sed 's/\/include\/boost-//' | sed 's/_/./'` + V_CHECK=`expr $_version_tmp \> $_version` + if test "$V_CHECK" = "1" ; then + _version=$_version_tmp + best_path=$ac_boost_path + fi + done + fi + done + + VERSION_UNDERSCORE=`echo $_version | sed 's/\./_/'` + BOOST_CPPFLAGS="-I$best_path/include/boost-$VERSION_UNDERSCORE" + BOOST_LDFLAGS="-L$best_path/lib" + + if test "x$BOOST_ROOT" != "x"; then + if test -d "$BOOST_ROOT" && test -r "$BOOST_ROOT" && test -d "$BOOST_ROOT/stage/lib" && test -r "$BOOST_ROOT/stage/lib"; then + version_dir=`expr //$BOOST_ROOT : '.*/\(.*\)'` + stage_version=`echo $version_dir | sed 's/boost_//' | sed 's/_/./g'` + stage_version_shorten=`expr $stage_version : '\([[0-9]]*\.[[0-9]]*\)'` + V_CHECK=`expr $stage_version_shorten \>\= $_version` + if test "$V_CHECK" = "1" ; then + AC_MSG_NOTICE(We will use a staged boost library from $BOOST_ROOT) + BOOST_CPPFLAGS="-I$BOOST_ROOT" + BOOST_LDFLAGS="-L$BOOST_ROOT/stage/lib" + fi + fi + fi + fi + + CPPFLAGS="$CPPFLAGS $BOOST_CPPFLAGS" + export CPPFLAGS + LDFLAGS="$LDFLAGS $BOOST_LDFLAGS" + export LDFLAGS + + AC_LANG_PUSH(C++) + AC_COMPILE_IFELSE([AC_LANG_PROGRAM([[ + @%:@include + ]], [[ + #if BOOST_VERSION >= $WANT_BOOST_VERSION + // Everything is okay + #else + # error Boost version is too old + #endif + ]])],[ + AC_MSG_RESULT(yes) + succeeded=yes + found_system=yes + ],[ + ]) + AC_LANG_POP([C++]) + fi + + if test "$succeeded" != "yes" ; then + if test "$_version" = "0" ; then + AC_MSG_ERROR([[We could not detect the boost libraries (version $boost_lib_version_req_shorten or higher). If you have a staged boost library (still not installed) please specify \$BOOST_ROOT in your environment and do not give a PATH to --with-boost option. If you are sure you have boost installed, then check your version number looking in . See http://randspringer.de/boost for more documentation.]]) + else + AC_MSG_NOTICE([Your boost libraries seems to old (version $_version).]) + fi + else + AC_SUBST(BOOST_CPPFLAGS) + AC_SUBST(BOOST_LDFLAGS) + AC_DEFINE(HAVE_BOOST,,[define if the Boost library is available]) + fi + + CPPFLAGS="$CPPFLAGS_SAVED" + LDFLAGS="$LDFLAGS_SAVED" +fi + +]) diff --git a/contrib/fb303/bootstrap.sh b/contrib/fb303/bootstrap.sh new file mode 100755 index 00000000..3cbeddb3 --- /dev/null +++ b/contrib/fb303/bootstrap.sh @@ -0,0 +1,26 @@ +#!/bin/sh + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# To be safe include -I flag +aclocal -I ./aclocal +automake -a +autoconf +./configure --config-cache $* diff --git a/contrib/fb303/configure.ac b/contrib/fb303/configure.ac new file mode 100644 index 00000000..67cc1108 --- /dev/null +++ b/contrib/fb303/configure.ac @@ -0,0 +1,115 @@ +# Autoconf input file +# $Id$ + +# AC - autoconf +# FB - facebook + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +######################################################################### +# DO NOT TOUCH EXCEPT TO CHANGE REV# IN AC_INIT + +AC_PREREQ(2.52) +AC_INIT([libfb303],[20080209]) +#AC_CONFIG_AUX_DIR([/usr/share/automake-1.9]) +# To install locally +FB_INITIALIZE([localinstall]) +AC_PREFIX_DEFAULT([/usr/local]) + +############################################################################ +# User Configurable. Change With CAUTION! +# User can include custom makefile rules. Uncomment and update only in PRODUCT_MK. +# Include where appropriate in any Makefile.am as @PRODUCT_MK@ + +#PRODUCT_MK="include ${EXTERNAL_PATH}/shared/build/.mk" + +# Default path to external Facebook components and shared build toools I.e fb303 etc. +# To point to other locations set environment variable EXTERNAL_PATH. +# To change the current default you must change bootstrap.sh. +FB_WITH_EXTERNAL_PATH([`pwd`]) + +AC_ARG_VAR([PY_PREFIX], [Prefix for installing Python modules. + (Normal --prefix is ignored for Python because + Python has different conventions.) + Default = "/usr"]) +AS_IF([test "x$PY_PREFIX" = x], [PY_PREFIX="/usr"]) + +########################################################################## +# User Configurable + +# Pre-defined macro to set opt build mode. Run with --disable-shared option to turn off optimization. +FB_ENABLE_DEFAULT_OPT_BUILD + +# Predefined macro to set static library mode. Run with --disable-static option to turn off static lib mode. +FB_ENABLE_DEFAULT_STATIC + +# Personalized feature generator. Creates defines/conditionals and --enable --disable command line options. +# FB_ENABLE_FEATURE([FEATURE], [feature]) OR FB_ENABLE_FEATURE([FEATURE], [feature], [\"\"]) + +# Example: Macro supplies -DFACEBOOK at compile time and "if FACEBOOK endif" capabilities. + +# Personalized path generator Sets default paths. Provides --with-xx=DIR options. +# FB_WITH_PATH([_home], [path], [] + +# Example: sets $(thrift_home) variable with default path set to /usr/local. +FB_WITH_PATH([thrift_home], [thriftpath], [/usr/local]) + +# Require boost 1.33.1 or later +AX_BOOST_BASE([1.33.1]) + +# Generates Makefile from Makefile.am. Modify when new subdirs are added. +# Change Makefile.am also to add subdirectly. +AC_CONFIG_FILES(Makefile cpp/Makefile py/Makefile) + +############################################################################ +# DO NOT TOUCH. + +AC_SUBST(PRODUCT_MK) +AC_OUTPUT + +############################################################################# +######### FINISH ############################################################ + +echo "EXTERNAL_PATH $EXTERNAL_PATH" + + +# +# NOTES FOR USER +# Short cut to create conditional flags. +#enable_facebook="yes" +#AM_CONDITIONAL([FACEBOOK], [test "$enable_facebook" = yes]) +#enable_hdfs="yes" +#AM_CONDITIONAL([HDFS], [test "$enable_hdfs" = yes]) + +# Enable options with --enable and --disable configurable. +#AC_MSG_CHECKING([whether to enable FACEBOOK]) +#FACEBOOK="" +#AC_ARG_ENABLE([facebook], +# [ --enable-facebook Enable facebook.], +# [ +# ENABLE_FACEBOOK=$enableval +# ], +# [ +# ENABLE_FACEBOOK="no" +# ] +#) +#AM_CONDITIONAL([FACEBOOK], [test "$ENABLE_FACEBOOK" = yes]) +#AC_MSG_RESULT($ENABLE_FACEBOOK) + diff --git a/contrib/fb303/cpp/FacebookBase.cpp b/contrib/fb303/cpp/FacebookBase.cpp new file mode 100644 index 00000000..80033406 --- /dev/null +++ b/contrib/fb303/cpp/FacebookBase.cpp @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "FacebookBase.h" + +using namespace facebook::fb303; +using apache::thrift::concurrency::Guard; + +FacebookBase::FacebookBase(std::string name) : + name_(name) { + aliveSince_ = (int64_t) time(NULL); +} + +inline void FacebookBase::getName(std::string& _return) { + _return = name_; +} + +void FacebookBase::setOption(const std::string& key, const std::string& value) { + Guard g(optionsLock_); + options_[key] = value; +} + +void FacebookBase::getOption(std::string& _return, const std::string& key) { + Guard g(optionsLock_); + _return = options_[key]; +} + +void FacebookBase::getOptions(std::map & _return) { + Guard g(optionsLock_); + _return = options_; +} + +int64_t FacebookBase::incrementCounter(const std::string& key, int64_t amount) { + counters_.acquireRead(); + + // if we didn't find the key, we need to write lock the whole map to create it + ReadWriteCounterMap::iterator it = counters_.find(key); + if (it == counters_.end()) { + counters_.release(); + counters_.acquireWrite(); + + // we need to check again to make sure someone didn't create this key + // already while we released the lock + it = counters_.find(key); + if(it == counters_.end()){ + counters_[key].value = amount; + counters_.release(); + return amount; + } + } + + it->second.acquireWrite(); + int64_t count = it->second.value + amount; + it->second.value = count; + it->second.release(); + counters_.release(); + return count; +} + +int64_t FacebookBase::setCounter(const std::string& key, int64_t value) { + counters_.acquireRead(); + + // if we didn't find the key, we need to write lock the whole map to create it + ReadWriteCounterMap::iterator it = counters_.find(key); + if (it == counters_.end()) { + counters_.release(); + counters_.acquireWrite(); + counters_[key].value = value; + counters_.release(); + return value; + } + + it->second.acquireWrite(); + it->second.value = value; + it->second.release(); + counters_.release(); + return value; +} + +void FacebookBase::getCounters(std::map& _return) { + // we need to lock the whole thing and actually build the map since we don't + // want our read/write structure to go over the wire + counters_.acquireRead(); + for(ReadWriteCounterMap::iterator it = counters_.begin(); + it != counters_.end(); it++) + { + _return[it->first] = it->second.value; + } + counters_.release(); +} + +int64_t FacebookBase::getCounter(const std::string& key) { + int64_t rv = 0; + counters_.acquireRead(); + ReadWriteCounterMap::iterator it = counters_.find(key); + if (it != counters_.end()) { + it->second.acquireRead(); + rv = it->second.value; + it->second.release(); + } + counters_.release(); + return rv; +} + +inline int64_t FacebookBase::aliveSince() { + return aliveSince_; +} + diff --git a/contrib/fb303/cpp/FacebookBase.h b/contrib/fb303/cpp/FacebookBase.h new file mode 100644 index 00000000..fd169e62 --- /dev/null +++ b/contrib/fb303/cpp/FacebookBase.h @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _FACEBOOK_TB303_FACEBOOKBASE_H_ +#define _FACEBOOK_TB303_FACEBOOKBASE_H_ 1 + +#include "FacebookService.h" + +#include "server/TServer.h" +#include "concurrency/Mutex.h" + +#include +#include +#include + +namespace facebook { namespace fb303 { + +using apache::thrift::concurrency::Mutex; +using apache::thrift::concurrency::ReadWriteMutex; +using apache::thrift::server::TServer; + +struct ReadWriteInt : ReadWriteMutex {int64_t value;}; +struct ReadWriteCounterMap : ReadWriteMutex, + std::map {}; + +/** + * Base Facebook service implementation in C++. + * + */ +class FacebookBase : virtual public FacebookServiceIf { + protected: + FacebookBase(std::string name); + virtual ~FacebookBase() {} + + public: + void getName(std::string& _return); + virtual void getVersion(std::string& _return) { _return = ""; } + + virtual fb_status getStatus() = 0; + virtual void getStatusDetails(std::string& _return) { _return = ""; } + + void setOption(const std::string& key, const std::string& value); + void getOption(std::string& _return, const std::string& key); + void getOptions(std::map & _return); + + int64_t aliveSince(); + + virtual void reinitialize() {} + + virtual void shutdown() { + if (server_.get() != NULL) { + server_->stop(); + } + } + + int64_t incrementCounter(const std::string& key, int64_t amount = 1); + int64_t setCounter(const std::string& key, int64_t value); + + void getCounters(std::map& _return); + int64_t getCounter(const std::string& key); + + /** + * Set server handle for shutdown method + */ + void setServer(boost::shared_ptr server) { + server_ = server; + } + + void getCpuProfile(std::string& _return, int32_t durSecs) { _return = ""; } + + private: + + std::string name_; + int64_t aliveSince_; + + std::map options_; + Mutex optionsLock_; + + ReadWriteCounterMap counters_; + + boost::shared_ptr server_; + +}; + +}} // facebook::tb303 + +#endif // _FACEBOOK_TB303_FACEBOOKBASE_H_ diff --git a/contrib/fb303/cpp/Makefile.am b/contrib/fb303/cpp/Makefile.am new file mode 100644 index 00000000..e62608ce --- /dev/null +++ b/contrib/fb303/cpp/Makefile.am @@ -0,0 +1,84 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +@GLOBAL_HEADER_MK@ + +@PRODUCT_MK@ + + +# User specified path variables set in configure.ac. +# thrift_home +# +THRIFT = $(thrift_home)/bin/thrift + +# User defined conditionals and conditonal statements set up in configure.ac. +if DEBUG + DEBUG_CPPFLAGS = -DDEBUG_TIMING +endif + +# Set common flags recognized by automake. +# DO NOT USE CPPFLAGS, CXXFLAGS, CFLAGS, LDFLAGS here! Set in configure.ac and|or override on command line. +# USE flags AM_CXXFLAGS, AM_CFLAGS, AM_CPPFLAGS, AM_LDFLAGS, LDADD in this section. + +AM_CPPFLAGS = -I.. +AM_CPPFLAGS += -Igen-cpp +AM_CPPFLAGS += -I$(thrift_home)/include/thrift +AM_CPPFLAGS += $(BOOST_CPPFLAGS) +AM_CPPFLAGS += $(FB_CPPFLAGS) $(DEBUG_CPPFLAGS) + +# GENERATE BUILD RULES +# Set Program/library specific flags recognized by automake. +# Use _ to set prog / lib specific flag s +# foo_CXXFLAGS foo_CPPFLAGS foo_LDFLAGS foo_LDADD + +fb303_lib = gen-cpp/FacebookService.cpp gen-cpp/fb303_constants.cpp gen-cpp/fb303_types.cpp FacebookBase.cpp ServiceTracker.cpp + +# Static -- multiple libraries can be defined +if STATIC +lib_LIBRARIES = libfb303.a +libfb303_a_SOURCES = $(fb303_lib) +INTERNAL_LIBS = libfb303.a +endif + +# Shared -- multiple libraries can be defined +if SHARED +shareddir = lib +shared_PROGRAMS = libfb303.so +libfb303_so_SOURCES = $(fb303_lib) +libfb303_so_CXXFLAGS = $(SHARED_CXXFLAGS) +libfb303_so_LDFLAGS = $(SHARED_LDFLAGS) +INTERNAL_LIBS = libfb303.so +endif + +# Set up Thrift specific activity here. +# We assume that a +types.cpp will always be built from .thrift. +$(eval $(call thrift_template,.,../if/fb303.thrift,-I $(thrift_home)/share --gen cpp )) + +include_fb303dir = $(includedir)/thrift/fb303 +include_fb303_HEADERS = FacebookBase.h ServiceTracker.h gen-cpp/FacebookService.h gen-cpp/fb303_constants.h gen-cpp/fb303_types.h + +include_fb303ifdir = $(prefix)/share/fb303/if +include_fb303if_HEADERS = ../if/fb303.thrift + +BUILT_SOURCES = thriftstyle + +# Add to pre-existing target clean +clean-local: clean-common + +@GLOBAL_FOOTER_MK@ diff --git a/contrib/fb303/cpp/ServiceTracker.cpp b/contrib/fb303/cpp/ServiceTracker.cpp new file mode 100644 index 00000000..c20a0683 --- /dev/null +++ b/contrib/fb303/cpp/ServiceTracker.cpp @@ -0,0 +1,481 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +#include "FacebookBase.h" +#include "ServiceTracker.h" +#include "concurrency/ThreadManager.h" + +using namespace std; +using namespace facebook::fb303; +using namespace apache::thrift::concurrency; + + +uint64_t ServiceTracker::CHECKPOINT_MINIMUM_INTERVAL_SECONDS = 60; +int ServiceTracker::LOG_LEVEL = 5; + + +ServiceTracker::ServiceTracker(facebook::fb303::FacebookBase *handler, + void (*logMethod)(int, const string &), + bool featureCheckpoint, + bool featureStatusCheck, + bool featureThreadCheck, + Stopwatch::Unit stopwatchUnit) + : handler_(handler), logMethod_(logMethod), + featureCheckpoint_(featureCheckpoint), + featureStatusCheck_(featureStatusCheck), + featureThreadCheck_(featureThreadCheck), + stopwatchUnit_(stopwatchUnit), + checkpointServices_(0) +{ + if (featureCheckpoint_) { + time_t now = time(NULL); + checkpointTime_ = now; + } else { + checkpointTime_ = 0; + } +} + +/** + * Registers the beginning of a "service method": basically, any of + * the implementations of Thrift remote procedure calls that a + * FacebookBase handler is handling. Controls concurrent + * services and reports statistics (via log and via fb303 counters). + * Throws an exception if the server is not ready to handle service + * methods yet. + * + * note: The relationship between startService() and finishService() + * is currently defined so that a call to finishService() should only + * be matched to this call to startService() if this method returns + * without exception. It wouldn't be a problem to implement things + * the other way, so that *every* start needed a finish, but this + * convention was chosen to match the way an object's constructor and + * destructor work together, i.e. to work well with ServiceMethod + * objects. + * + * @param const ServiceMethod &serviceMethod A reference to the ServiceMethod + * object instantiated at the start + * of the service method. + */ +void +ServiceTracker::startService(const ServiceMethod &serviceMethod) +{ + // note: serviceMethod.timer_ automatically starts at construction. + + // log service start + logMethod_(5, serviceMethod.signature_); + + // check handler ready + if (featureStatusCheck_ && !serviceMethod.featureLogOnly_) { + // note: Throwing exceptions before counting statistics. See note + // in method header. + // note: A STOPPING server is not accepting new connections, but it + // is still handling any already-connected threads -- so from the + // service method's point of view, a status of STOPPING is a green + // light. + facebook::fb303::fb_status status = handler_->getStatus(); + if (status != facebook::fb303::ALIVE + && status != facebook::fb303::STOPPING) { + if (status == facebook::fb303::STARTING) { + throw ServiceException("Server starting up; please try again later"); + } else { + throw ServiceException("Server not alive; please try again later"); + } + } + } + + // check server threads + if (featureThreadCheck_ && !serviceMethod.featureLogOnly_) { + // note: Might want to put these messages in reportCheckpoint() if + // log is getting spammed. + if (threadManager_ != NULL) { + size_t idle_count = threadManager_->idleWorkerCount(); + if (idle_count == 0) { + stringstream message; + message << "service " << serviceMethod.signature_ + << ": all threads (" << threadManager_->workerCount() + << ") in use"; + logMethod_(3, message.str()); + } + } + } +} + +/** + * Logs a significant step in the middle of a "service method"; see + * startService. + * + * @param const ServiceMethod &serviceMethod A reference to the ServiceMethod + * object instantiated at the start + * of the service method. + * @return int64_t Elapsed units (see stopwatchUnit_) since ServiceMethod + * instantiation. + */ +int64_t +ServiceTracker::stepService(const ServiceMethod &serviceMethod, + const string &stepName) +{ + stringstream message; + string elapsed_label; + int64_t elapsed = serviceMethod.timer_.elapsedUnits(stopwatchUnit_, + &elapsed_label); + message << serviceMethod.signature_ + << ' ' << stepName + << " [" << elapsed_label << ']'; + logMethod_(5, message.str()); + return elapsed; +} + +/** + * Registers the end of a "service method"; see startService(). + * + * @param const ServiceMethod &serviceMethod A reference to the ServiceMethod + * object instantiated at the start + * of the service method. + */ +void +ServiceTracker::finishService(const ServiceMethod &serviceMethod) +{ + // log end of service + stringstream message; + string duration_label; + int64_t duration = serviceMethod.timer_.elapsedUnits(stopwatchUnit_, + &duration_label); + message << serviceMethod.signature_ + << " finish [" << duration_label << ']'; + logMethod_(5, message.str()); + + // count, record, and maybe report service statistics + if (!serviceMethod.featureLogOnly_) { + + if (!featureCheckpoint_) { + + // lifetime counters + // (note: No need to lock statisticsMutex_ if not doing checkpoint; + // FacebookService::incrementCounter() is already thread-safe.) + handler_->incrementCounter("lifetime_services"); + + } else { + + statisticsMutex_.lock(); + // note: No exceptions expected from this code block. Wrap in a try + // just to be safe. + try { + + // lifetime counters + // note: Good to synchronize this with the increment of + // checkpoint services, even though incrementCounter() is + // already thread-safe, for the sake of checkpoint reporting + // consistency (i.e. since the last checkpoint, + // lifetime_services has incremented by checkpointServices_). + handler_->incrementCounter("lifetime_services"); + + // checkpoint counters + checkpointServices_++; + checkpointDuration_ += duration; + + // per-service timing + // note kjv: According to my tests it is very slightly faster to + // call insert() once (and detect not-found) than calling find() + // and then maybe insert (if not-found). However, the difference + // is tiny for small maps like this one, and the code for the + // faster solution is slightly less readable. Also, I wonder if + // the instantiation of the (often unused) pair to insert makes + // the first algorithm slower after all. + map >::iterator iter; + iter = checkpointServiceDuration_.find(serviceMethod.name_); + if (iter != checkpointServiceDuration_.end()) { + iter->second.first++; + iter->second.second += duration; + } else { + checkpointServiceDuration_.insert(make_pair(serviceMethod.name_, + make_pair(1, duration))); + } + + // maybe report checkpoint + // note: ...if it's been long enough since the last report. + time_t now = time(NULL); + uint64_t check_interval = now - checkpointTime_; + if (check_interval >= CHECKPOINT_MINIMUM_INTERVAL_SECONDS) { + reportCheckpoint(); + } + + } catch (...) { + statisticsMutex_.unlock(); + throw; + } + statisticsMutex_.unlock(); + + } + } +} + +/** + * Logs some statistics gathered since the last call to this method. + * + * note: Thread race conditions on this method could cause + * misreporting and/or undefined behavior; the caller must protect + * uses of the object variables (and calls to this method) with a + * mutex. + * + */ +void +ServiceTracker::reportCheckpoint() +{ + time_t now = time(NULL); + + uint64_t check_count = checkpointServices_; + uint64_t check_interval = now - checkpointTime_; + uint64_t check_duration = checkpointDuration_; + + // export counters for timing of service methods (by service name) + handler_->setCounter("checkpoint_time", check_interval); + map >::iterator iter; + uint64_t count; + for (iter = checkpointServiceDuration_.begin(); + iter != checkpointServiceDuration_.end(); + iter++) { + count = iter->second.first; + handler_->setCounter(string("checkpoint_count_") + iter->first, count); + if (count == 0) { + handler_->setCounter(string("checkpoint_speed_") + iter->first, + 0); + } else { + handler_->setCounter(string("checkpoint_speed_") + iter->first, + iter->second.second / count); + } + } + + // reset checkpoint variables + // note: Clearing the map while other threads are using it might + // cause undefined behavior. + checkpointServiceDuration_.clear(); + checkpointTime_ = now; + checkpointServices_ = 0; + checkpointDuration_ = 0; + + // get lifetime variables + uint64_t life_count = handler_->getCounter("lifetime_services"); + uint64_t life_interval = now - handler_->aliveSince(); + + // log checkpoint + stringstream message; + message << "checkpoint_time:" << check_interval + << " checkpoint_services:" << check_count + << " checkpoint_speed_sum:" << check_duration + << " lifetime_time:" << life_interval + << " lifetime_services:" << life_count; + if (featureThreadCheck_ && threadManager_ != NULL) { + size_t worker_count = threadManager_->workerCount(); + size_t idle_count = threadManager_->idleWorkerCount(); + message << " total_workers:" << worker_count + << " active_workers:" << (worker_count - idle_count); + } + logMethod_(4, message.str()); +} + +/** + * Remembers the thread manager used in the server, for monitoring thread + * activity. + * + * @param shared_ptr threadManager The server's thread manager. + */ +void +ServiceTracker::setThreadManager(boost::shared_ptr + threadManager) +{ + threadManager_ = threadManager; +} + +/** + * Logs messages to stdout; the passed message will be logged if the + * passed level is less than or equal to LOG_LEVEL. + * + * This is the default logging method used by the ServiceTracker. An + * alternate logging method (that accepts the same parameters) may be + * specified to the constructor. + * + * @param int level A level associated with the message: higher levels + * are used to indicate higher levels of detail. + * @param string message The message to log. + */ +void +ServiceTracker::defaultLogMethod(int level, const string &message) +{ + if (level <= LOG_LEVEL) { + string level_string; + time_t now = time(NULL); + char now_pretty[26]; + ctime_r(&now, now_pretty); + now_pretty[24] = '\0'; + switch (level) { + case 1: + level_string = "CRITICAL"; + break; + case 2: + level_string = "ERROR"; + break; + case 3: + level_string = "WARNING"; + break; + case 5: + level_string = "DEBUG"; + break; + case 4: + default: + level_string = "INFO"; + break; + } + cout << '[' << level_string << "] [" << now_pretty << "] " + << message << endl; + } +} + + +/** + * Creates a Stopwatch, which can report the time elapsed since its + * creation. + * + */ +Stopwatch::Stopwatch() +{ + gettimeofday(&startTime_, NULL); +} + +void +Stopwatch::reset() +{ + gettimeofday(&startTime_, NULL); +} + +uint64_t +Stopwatch::elapsedUnits(Stopwatch::Unit unit, string *label) const +{ + timeval now_time; + gettimeofday(&now_time, NULL); + time_t duration_secs = now_time.tv_sec - startTime_.tv_sec; + + uint64_t duration_units; + switch (unit) { + case UNIT_SECONDS: + duration_units = duration_secs + + (now_time.tv_usec - startTime_.tv_usec + 500000) / 1000000; + if (NULL != label) { + stringstream ss_label; + ss_label << duration_units << " secs"; + label->assign(ss_label.str()); + } + break; + case UNIT_MICROSECONDS: + duration_units = duration_secs * 1000000 + + now_time.tv_usec - startTime_.tv_usec; + if (NULL != label) { + stringstream ss_label; + ss_label << duration_units << " us"; + label->assign(ss_label.str()); + } + break; + case UNIT_MILLISECONDS: + default: + duration_units = duration_secs * 1000 + + (now_time.tv_usec - startTime_.tv_usec + 500) / 1000; + if (NULL != label) { + stringstream ss_label; + ss_label << duration_units << " ms"; + label->assign(ss_label.str()); + } + break; + } + return duration_units; +} + +/** + * Creates a ServiceMethod, used for tracking a single service method + * invocation (via the ServiceTracker). The passed name of the + * ServiceMethod is used to group statistics (e.g. counts and durations) + * for similar invocations; the passed signature is used to uniquely + * identify the particular invocation in the log. + * + * note: A version of this constructor is provided that automatically + * forms a signature the name and a passed numeric id. Silly, sure, + * but commonly used, since it often saves the caller a line or two of + * code. + * + * @param ServiceTracker *tracker The service tracker that will track this + * ServiceMethod. + * @param const string &name The service method name (usually independent + * of service method parameters). + * @param const string &signature A signature uniquely identifying the method + * invocation (usually name plus parameters). + */ +ServiceMethod::ServiceMethod(ServiceTracker *tracker, + const string &name, + const string &signature, + bool featureLogOnly) + : tracker_(tracker), name_(name), signature_(signature), + featureLogOnly_(featureLogOnly) +{ + // note: timer_ automatically starts at construction. + + // invoke tracker to start service + // note: Might throw. If it throws, then this object's destructor + // won't be called, which is according to plan: finishService() is + // only supposed to be matched to startService() if startService() + // returns without error. + tracker_->startService(*this); +} + +ServiceMethod::ServiceMethod(ServiceTracker *tracker, + const string &name, + uint64_t id, + bool featureLogOnly) + : tracker_(tracker), name_(name), featureLogOnly_(featureLogOnly) +{ + // note: timer_ automatically starts at construction. + stringstream ss_signature; + ss_signature << name << " (" << id << ')'; + signature_ = ss_signature.str(); + + // invoke tracker to start service + // note: Might throw. If it throws, then this object's destructor + // won't be called, which is according to plan: finishService() is + // only supposed to be matched to startService() if startService() + // returns without error. + tracker_->startService(*this); +} + +ServiceMethod::~ServiceMethod() +{ + // invoke tracker to finish service + // note: Not expecting an exception from this code, but + // finishService() might conceivably throw an out-of-memory + // exception. + try { + tracker_->finishService(*this); + } catch (...) { + // don't throw + } +} + +uint64_t +ServiceMethod::step(const std::string &stepName) +{ + return tracker_->stepService(*this, stepName); +} diff --git a/contrib/fb303/cpp/ServiceTracker.h b/contrib/fb303/cpp/ServiceTracker.h new file mode 100644 index 00000000..93043863 --- /dev/null +++ b/contrib/fb303/cpp/ServiceTracker.h @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * ServiceTracker is a utility class for logging and timing service + * calls to a fb303 Thrift server. Currently, ServiceTracker offers + * the following features: + * + * . Logging of service method start, end (and duration), and + * optional steps in between. + * + * . Automatic check of server status via fb303::getStatus() + * with a ServiceException thrown if server not alive + * (at method start). + * + * . A periodic logged checkpoint reporting lifetime time, lifetime + * service count, and per-method statistics since the last checkpoint + * time (at method finish). + * + * . Export of fb303 counters for lifetime and checkpoint statistics + * (at method finish). + * + * . For TThreadPoolServers, a logged warning when all server threads + * are busy (at method start). (Must call setThreadManager() after + * ServiceTracker instantiation for this feature to be enabled.) + * + * Individual features may be enabled or disabled by arguments to the + * constructor. The constructor also accepts a pointer to a logging + * method -- if no pointer is passed, the tracker will log to stdout. + * + * ServiceTracker defines private methods for service start, finish, + * and step, which are designed to be accessed by instantiating a + * friend ServiceMethod object, as in the following example: + * + * #include + * class MyServiceHandler : virtual public MyServiceIf, + * public facebook::fb303::FacebookBase + * { + * public: + * MyServiceHandler::MyServiceHandler() : mServiceTracker(this) {} + * void MyServiceHandler::myServiceMethod(int userId) { + * // note: Instantiating a ServiceMethod object starts a timer + * // and tells the ServiceTracker to log the start. Might throw + * // a ServiceException. + * ServiceMethod serviceMethod(&mServiceTracker, + * "myServiceMethod", + * userId); + * ... + * // note: Calling the step method tells the ServiceTracker to + * // log the step, with a time elapsed since start. + * serviceMethod.step("post parsing, begin processing"); + * ... + * // note: When the ServiceMethod object goes out of scope, the + * // ServiceTracker will log the total elapsed time of the method. + * } + * ... + * private: + * ServiceTracker mServiceTracker; + * } + * + * The step() method call is optional; the startService() and + * finishService() methods are handled by the object's constructor and + * destructor. + * + * The ServiceTracker is (intended to be) thread-safe. + * + * Future: + * + * . Come up with something better for logging than passing a + * function pointer to the constructor. + * + * . Add methods for tracking errors from service methods, e.g. + * ServiceTracker::reportService(). + */ + +#ifndef SERVICETRACKER_H +#define SERVICETRACKER_H + + +#include +#include +#include +#include +#include +#include + +#include "concurrency/Mutex.h" + + +namespace apache { namespace thrift { namespace concurrency { + class ThreadManager; +}}} + + +namespace facebook { namespace fb303 { + + +class FacebookBase; +class ServiceMethod; + + +class Stopwatch +{ +public: + enum Unit { UNIT_SECONDS, UNIT_MILLISECONDS, UNIT_MICROSECONDS }; + Stopwatch(); + uint64_t elapsedUnits(Unit unit, std::string *label = NULL) const; + void reset(); +private: + timeval startTime_; +}; + + +class ServiceTracker +{ + friend class ServiceMethod; + +public: + + static uint64_t CHECKPOINT_MINIMUM_INTERVAL_SECONDS; + static int LOG_LEVEL; + + ServiceTracker(facebook::fb303::FacebookBase *handler, + void (*logMethod)(int, const std::string &) + = &ServiceTracker::defaultLogMethod, + bool featureCheckpoint = true, + bool featureStatusCheck = true, + bool featureThreadCheck = true, + Stopwatch::Unit stopwatchUnit + = Stopwatch::UNIT_MILLISECONDS); + + void setThreadManager(boost::shared_ptr threadManager); + +private: + + facebook::fb303::FacebookBase *handler_; + void (*logMethod_)(int, const std::string &); + boost::shared_ptr threadManager_; + + bool featureCheckpoint_; + bool featureStatusCheck_; + bool featureThreadCheck_; + Stopwatch::Unit stopwatchUnit_; + + apache::thrift::concurrency::Mutex statisticsMutex_; + time_t checkpointTime_; + uint64_t checkpointServices_; + uint64_t checkpointDuration_; + std::map > checkpointServiceDuration_; + + void startService(const ServiceMethod &serviceMethod); + int64_t stepService(const ServiceMethod &serviceMethod, + const std::string &stepName); + void finishService(const ServiceMethod &serviceMethod); + void reportCheckpoint(); + static void defaultLogMethod(int level, const std::string &message); +}; + + +class ServiceMethod +{ + friend class ServiceTracker; +public: + ServiceMethod(ServiceTracker *tracker, + const std::string &name, + const std::string &signature, + bool featureLogOnly = false); + ServiceMethod(ServiceTracker *tracker, + const std::string &name, + uint64_t id, + bool featureLogOnly = false); + ~ServiceMethod(); + uint64_t step(const std::string &stepName); +private: + ServiceTracker *tracker_; + std::string name_; + std::string signature_; + bool featureLogOnly_; + Stopwatch timer_; +}; + + +class ServiceException : public std::exception +{ +public: + explicit ServiceException(const std::string &message, int code = 0) + : message_(message), code_(code) {} + ~ServiceException() throw() {} + virtual const char *what() const throw() { return message_.c_str(); } + int code() const throw() { return code_; } +private: + std::string message_; + int code_; +}; + + +}} // facebook::fb303 + +#endif diff --git a/contrib/fb303/global_footer.mk b/contrib/fb303/global_footer.mk new file mode 100644 index 00000000..96f82ebd --- /dev/null +++ b/contrib/fb303/global_footer.mk @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +thriftstyle : $(XBUILT_SOURCES) + diff --git a/contrib/fb303/global_header.mk b/contrib/fb303/global_header.mk new file mode 100644 index 00000000..77c9455e --- /dev/null +++ b/contrib/fb303/global_header.mk @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +#define thrift_template +# $(1) : $(2) +# $$(THRIFT) $(3) $(4) $(5) $(6) $(7) $(8) $$< +#endef + +define thrift_template +XTARGET := $(shell perl -e '@val = split("\/","$(2)"); $$last = pop(@val);split("\\.",$$last);print "$(1)/"."gen-cpp/"."@_[0]"."_types.cpp\n"' ) + +ifneq ($$(XBUILT_SOURCES),) + XBUILT_SOURCES := $$(XBUILT_SOURCES) $$(XTARGET) +else + XBUILT_SOURCES := $$(XTARGET) +endif +$$(XTARGET) : $(2) + $$(THRIFT) -o $1 $3 $$< +endef + +clean-common: + rm -rf gen-* diff --git a/contrib/fb303/if/fb303.thrift b/contrib/fb303/if/fb303.thrift new file mode 100644 index 00000000..66c83152 --- /dev/null +++ b/contrib/fb303/if/fb303.thrift @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * fb303.thrift + */ + +namespace java com.facebook.fb303 +namespace cpp facebook.fb303 +namespace perl Facebook.FB303 + +/** + * Common status reporting mechanism across all services + */ +enum fb_status { + DEAD = 0, + STARTING = 1, + ALIVE = 2, + STOPPING = 3, + STOPPED = 4, + WARNING = 5, +} + +/** + * Standard base service + */ +service FacebookService { + + /** + * Returns a descriptive name of the service + */ + string getName(), + + /** + * Returns the version of the service + */ + string getVersion(), + + /** + * Gets the status of this service + */ + fb_status getStatus(), + + /** + * User friendly description of status, such as why the service is in + * the dead or warning state, or what is being started or stopped. + */ + string getStatusDetails(), + + /** + * Gets the counters for this service + */ + map getCounters(), + + /** + * Gets the value of a single counter + */ + i64 getCounter(1: string key), + + /** + * Sets an option + */ + void setOption(1: string key, 2: string value), + + /** + * Gets an option + */ + string getOption(1: string key), + + /** + * Gets all options + */ + map getOptions(), + + /** + * Returns a CPU profile over the given time interval (client and server + * must agree on the profile format). + */ + string getCpuProfile(1: i32 profileDurationInSec), + + /** + * Returns the unix time that the server has been running since + */ + i64 aliveSince(), + + /** + * Tell the server to reload its configuration, reopen log files, etc + */ + oneway void reinitialize(), + + /** + * Suggest a shutdown to the server + */ + oneway void shutdown(), + +} diff --git a/contrib/fb303/java/FacebookBase.java b/contrib/fb303/java/FacebookBase.java new file mode 100644 index 00000000..5778cc8b --- /dev/null +++ b/contrib/fb303/java/FacebookBase.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package com.facebook.fb303; + +import java.util.AbstractMap; +import java.util.HashMap; +import java.util.concurrent.ConcurrentHashMap; + +public abstract class FacebookBase implements FacebookService.Iface { + + private String name_; + + private long alive_; + + private final ConcurrentHashMap counters_ = + new ConcurrentHashMap(); + + private final ConcurrentHashMap options_ = + new ConcurrentHashMap(); + + protected FacebookBase(String name) { + name_ = name; + alive_ = System.currentTimeMillis() / 1000; + } + + public String getName() { + return name_; + } + + public abstract int getStatus(); + + public String getStatusDetails() { + return ""; + } + + public void deleteCounter(String key) { + counters_.remove(key); + } + + public void resetCounter(String key) { + counters_.put(key, 0L); + } + + public long incrementCounter(String key) { + long val = getCounter(key) + 1; + counters_.put(key, val); + return val; + } + + public AbstractMap getCounters() { + return counters_; + } + + public long getCounter(String key) { + Long val = counters_.get(key); + if (val == null) { + return 0; + } + return val.longValue(); + } + + public void setOption(String key, String value) { + options_.put(key, value); + } + + public String getOption(String key) { + return options_.get(key); + } + + public AbstractMap getOptions() { + return options_; + } + + public long aliveSince() { + return alive_; + } + + public String getCpuProfile() { + return ""; + } + + public void reinitialize() {} + + public void shutdown() {} + +} diff --git a/contrib/fb303/java/build.xml b/contrib/fb303/java/build.xml new file mode 100755 index 00000000..4ad30e53 --- /dev/null +++ b/contrib/fb303/java/build.xml @@ -0,0 +1,84 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + generating thrift fb303 files + + + + + + + + + + + + + Building libfb303.jar .... + + + + + + + + + + + + + + + + + + + + Cleaning old stuff .... + + + + + diff --git a/contrib/fb303/php/FacebookBase.php b/contrib/fb303/php/FacebookBase.php new file mode 100644 index 00000000..2ac318fb --- /dev/null +++ b/contrib/fb303/php/FacebookBase.php @@ -0,0 +1,89 @@ +name_ = $name; + } + + public function getName() { + return $this->name_; + } + + public function getVersion() { + return ''; + } + + public function getStatus() { + return null; + } + + public function getStatusDetails() { + return ''; + } + + public function getCounters() { + return array(); + } + + public function getCounter($key) { + return null; + } + + public function setOption($key, $value) { + return; + } + + public function getOption($key) { + return ''; + } + + public function getOptions() { + return array(); + } + + public function aliveSince() { + return 0; + } + + public function getCpuProfile($duration) { + return ''; + } + + public function getLimitedReflection() { + return array(); + } + + public function reinitialize() { + return; + } + + public function shutdown() { + return; + } + +} + diff --git a/contrib/fb303/py/Makefile.am b/contrib/fb303/py/Makefile.am new file mode 100644 index 00000000..060495e5 --- /dev/null +++ b/contrib/fb303/py/Makefile.am @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +DESTDIR ?= / +EXTRA_DIST = setup.py src + +all: + +all-local: + $(thrift_home)/bin/thrift --gen py $(top_srcdir)/if/fb303.thrift + mv gen-py/fb303/* fb303 + $(PYTHON) setup.py build + +# We're ignoring prefix here because site-packages seems to be +# the equivalent of /usr/local/lib in Python land. +# Old version (can't put inline because it's not portable). +#$(PYTHON) setup.py install --prefix=$(prefix) --root=$(DESTDIR) $(PYTHON_SETUPUTIL_ARGS) +install-exec-hook: + $(PYTHON) setup.py install --root=$(DESTDIR) --prefix=$(PY_PREFIX) $(PYTHON_SETUPUTIL_ARGS) + + + +clean: clean-local + +clean-local: + $(RM) -r build + +check-local: all diff --git a/contrib/fb303/py/fb303/FacebookBase.py b/contrib/fb303/py/fb303/FacebookBase.py new file mode 100644 index 00000000..685ff20f --- /dev/null +++ b/contrib/fb303/py/fb303/FacebookBase.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import time +import FacebookService +import thrift.reflection.limited +from ttypes import fb_status + +class FacebookBase(FacebookService.Iface): + + def __init__(self, name): + self.name = name + self.alive = int(time.time()) + self.counters = {} + + def getName(self, ): + return self.name + + def getVersion(self, ): + return '' + + def getStatus(self, ): + return fb_status.ALIVE + + def getCounters(self): + return self.counters + + def resetCounter(self, key): + self.counters[key] = 0 + + def getCounter(self, key): + if self.counters.has_key(key): + return self.counters[key] + return 0 + + def incrementCounter(self, key): + self.counters[key] = self.getCounter(key) + 1 + + def setOption(self, key, value): + pass + + def getOption(self, key): + return "" + + def getOptions(self): + return {} + + def getOptions(self): + return {} + + def aliveSince(self): + return self.alive + + def getCpuProfile(self, duration): + return "" + + def getLimitedReflection(self): + return thrift.reflection.limited.Service() + + def reinitialize(self): + pass + + def shutdown(self): + pass diff --git a/contrib/fb303/py/fb303_scripts/__init__.py b/contrib/fb303/py/fb303_scripts/__init__.py new file mode 100644 index 00000000..f8e3a94b --- /dev/null +++ b/contrib/fb303/py/fb303_scripts/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['fb303_simple_mgmt'] diff --git a/contrib/fb303/py/fb303_scripts/fb303_simple_mgmt.py b/contrib/fb303/py/fb303_scripts/fb303_simple_mgmt.py new file mode 100644 index 00000000..4f8ce993 --- /dev/null +++ b/contrib/fb303/py/fb303_scripts/fb303_simple_mgmt.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys, os +from optparse import OptionParser + +from thrift.Thrift import * + +from thrift.transport import TSocket +from thrift.transport import TTransport +from thrift.protocol import TBinaryProtocol + +from fb303 import * +from fb303.ttypes import * + +def service_ctrl( + command, + port, + trans_factory = None, + prot_factory = None): + """ + service_ctrl is a generic function to execute standard fb303 functions + + @param command: one of stop, start, reload, status, counters, name, alive + @param port: service's port + @param trans_factory: TTransportFactory to use for obtaining a TTransport. Default is + TBufferedTransportFactory + @param prot_factory: TProtocolFactory to use for obtaining a TProtocol. Default is + TBinaryProtocolFactory + """ + + if command in ["status"]: + try: + status = fb303_wrapper('status', port, trans_factory, prot_factory) + status_details = fb303_wrapper('get_status_details', port, trans_factory, prot_factory) + + msg = fb_status_string(status) + if (len(status_details)): + msg += " - %s" % status_details + print msg + + if (status == fb_status.ALIVE): + return 2 + else: + return 3 + except: + print "Failed to get status" + return 3 + + # scalar commands + if command in ["version","alive","name"]: + try: + result = fb303_wrapper(command, port, trans_factory, prot_factory) + print result + return 0 + except: + print "failed to get ",command + return 3 + + # counters + if command in ["counters"]: + try: + counters = fb303_wrapper('counters', port, trans_factory, prot_factory) + for counter in counters: + print "%s: %d" % (counter, counters[counter]) + return 0 + except: + print "failed to get counters" + return 3 + + + # Only root should be able to run the following commands + if os.getuid() == 0: + # async commands + if command in ["stop","reload"] : + try: + fb303_wrapper(command, port, trans_factory, prot_factory) + return 0 + except: + print "failed to tell the service to ", command + return 3 + else: + if command in ["stop","reload"]: + print "root privileges are required to stop or reload the service." + return 4 + + print "The following commands are available:" + for command in ["counters","name","version","alive","status"]: + print "\t%s" % command + print "The following commands are available for users with root privileges:" + for command in ["stop","reload"]: + print "\t%s" % command + + + + return 0; + + +def fb303_wrapper(command, port, trans_factory = None, prot_factory = None): + sock = TSocket.TSocket('localhost', port) + + # use input transport factory if provided + if (trans_factory is None): + trans = TTransport.TBufferedTransport(sock) + else: + trans = trans_factory.getTransport(sock) + + # use input protocol factory if provided + if (prot_factory is None): + prot = TBinaryProtocol.TBinaryProtocol(trans) + else: + prot = prot_factory.getProtocol(trans) + + # initialize client and open transport + fb303_client = FacebookService.Client(prot, prot) + trans.open() + + if (command == 'reload'): + fb303_client.reinitialize() + + elif (command == 'stop'): + fb303_client.shutdown() + + elif (command == 'status'): + return fb303_client.getStatus() + + elif (command == 'version'): + return fb303_client.getVersion() + + elif (command == 'get_status_details'): + return fb303_client.getStatusDetails() + + elif (command == 'counters'): + return fb303_client.getCounters() + + elif (command == 'name'): + return fb303_client.getName() + + elif (command == 'alive'): + return fb303_client.aliveSince() + + trans.close() + + +def fb_status_string(status_enum): + if (status_enum == fb_status.DEAD): + return "DEAD" + if (status_enum == fb_status.STARTING): + return "STARTING" + if (status_enum == fb_status.ALIVE): + return "ALIVE" + if (status_enum == fb_status.STOPPING): + return "STOPPING" + if (status_enum == fb_status.STOPPED): + return "STOPPED" + if (status_enum == fb_status.WARNING): + return "WARNING" + + +def main(): + + # parse command line options + parser = OptionParser() + commands=["stop","counters","status","reload","version","name","alive"] + + parser.add_option("-c", "--command", dest="command", help="execute this API", + choices=commands, default="status") + parser.add_option("-p","--port",dest="port",help="the service's port", + default=9082) + + (options, args) = parser.parse_args() + status = service_ctrl(options.command, options.port) + sys.exit(status) + + +if __name__ == '__main__': + main() diff --git a/contrib/fb303/py/setup.py b/contrib/fb303/py/setup.py new file mode 100644 index 00000000..a29f9642 --- /dev/null +++ b/contrib/fb303/py/setup.py @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from distutils.core import setup + +setup(name='fb303', + version='1.0', + packages=['fb303', 'fb303_scripts'], + ) + + diff --git a/contrib/thrift.el b/contrib/thrift.el new file mode 100644 index 00000000..cd3e0e89 --- /dev/null +++ b/contrib/thrift.el @@ -0,0 +1,126 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one +;; or more contributor license agreements. See the NOTICE file +;; distributed with this work for additional information +;; regarding copyright ownership. The ASF licenses this file +;; to you under the Apache License, Version 2.0 (the +;; "License"); you may not use this file except in compliance +;; with the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, +;; software distributed under the License is distributed on an +;; "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +;; KIND, either express or implied. See the License for the +;; specific language governing permissions and limitations +;; under the License. +;; + +(require 'font-lock) + +(defvar thrift-mode-hook nil) +(add-to-list 'auto-mode-alist '("\\.thrift\\'" . thrift-mode)) + +(defvar thrift-indent-level 2 + "Defines 2 spaces for thrift indentation.") + +;; syntax coloring +(defconst thrift-font-lock-keywords + (list + '("#.*$" . font-lock-comment-face) ;; perl style comments + '("\\<\\(include\\|struct\\|exception\\|typedef\\|const\\|enum\\|service\\|extends\\|void\\|oneway\\|throws\\|optional\\|required\\)\\>" . font-lock-keyword-face) ;; keywords + '("\\<\\(bool\\|byte\\|i16\\|i32\\|i64\\|double\\|string\\|binary\\|map\\|list\\|set\\)\\>" . font-lock-type-face) ;; built-in types + '("\\<\\([0-9]+\\)\\>" . font-lock-variable-name-face) ;; ordinals + '("\\<\\(\\w+\\)\\s-*(" (1 font-lock-function-name-face)) ;; functions + ) + "Thrift Keywords") + +;; indentation +(defun thrift-indent-line () + "Indent current line as Thrift code." + (interactive) + (beginning-of-line) + (if (bobp) + (indent-line-to 0) + (let ((not-indented t) cur-indent) + (if (looking-at "^[ \t]*\\(}\\|throws\\)") + (if (looking-at "^[ \t]*}") + (progn + (save-excursion + (forward-line -1) + (setq cur-indent (- (current-indentation) thrift-indent-level))) + (if (< cur-indent 0) + (setq cur-indent 0))) + (progn + (save-excursion + (forward-line -1) + (if (looking-at "^[ \t]*[\\.<>[:word:]]+[ \t]+[\\.<>[:word:]]+[ \t]*(") + (setq cur-indent (+ (current-indentation) thrift-indent-level)) + (setq cur-indent (current-indentation)))))) + (save-excursion + (while not-indented + (forward-line -1) + (if (looking-at "^[ \t]*}") + (progn + (setq cur-indent (current-indentation)) + (setq not-indented nil)) + (if (looking-at "^.*{[^}]*$") + (progn + (setq cur-indent (+ (current-indentation) thrift-indent-level)) + (setq not-indented nil)) + (if (bobp) + (setq not-indented nil))) + (if (looking-at "^[ \t]*throws") + (progn + (setq cur-indent (- (current-indentation) thrift-indent-level)) + (if (< cur-indent 0) + (setq cur-indent 0)) + (setq not-indented nil)) + (if (bobp) + (setq not-indented nil))) + (if (looking-at "^[ \t]*[\\.<>[:word:]]+[ \t]+[\\.<>[:word:]]+[ \t]*([^)]*$") + (progn + (setq cur-indent (+ (current-indentation) thrift-indent-level)) + (setq not-indented nil)) + (if (bobp) + (setq not-indented nil))) + (if (looking-at "^[ \t]*\\/\\*") + (progn + (setq cur-indent (+ (current-indentation) 1)) + (setq not-indented nil)) + (if (bobp) + (setq not-indented nil))) + (if (looking-at "^[ \t]*\\*\\/") + (progn + (setq cur-indent (- (current-indentation) 1)) + (setq not-indented nil)) + (if (bobp) + (setq not-indented nil))) + )))) + (if cur-indent + (indent-line-to cur-indent) + (indent-line-to 0))))) + +;; C/C++ comments; also allowing underscore in words +(defvar thrift-mode-syntax-table + (let ((thrift-mode-syntax-table (make-syntax-table))) + (modify-syntax-entry ?_ "w" thrift-mode-syntax-table) + (modify-syntax-entry ?/ ". 1456" thrift-mode-syntax-table) + (modify-syntax-entry ?* ". 23" thrift-mode-syntax-table) + (modify-syntax-entry ?\n "> b" thrift-mode-syntax-table) + thrift-mode-syntax-table) + "Syntax table for thrift-mode") + +(defun thrift-mode () + "Mode for editing Thrift files" + (interactive) + (kill-all-local-variables) + (set-syntax-table thrift-mode-syntax-table) + (set (make-local-variable 'font-lock-defaults) '(thrift-font-lock-keywords)) + (setq major-mode 'thrift-mode) + (setq mode-name "Thrift") + (run-hooks 'thrift-mode-hook) + (set (make-local-variable 'indent-line-function) 'thrift-indent-line) + ) +(provide 'thrift-mode) diff --git a/contrib/thrift.spec b/contrib/thrift.spec new file mode 100644 index 00000000..ecee631a --- /dev/null +++ b/contrib/thrift.spec @@ -0,0 +1,206 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# TODO(dreiss): Have a Python build with and without the extension. +%{!?python_sitelib: %define python_sitelib %(%{__python} -c "from distutils.sysconfig import get_python_lib; print get_python_lib()")} +%{!?python_sitearch: %define python_sitearch %(%{__python} -c "from distutils.sysconfig import get_python_lib; print get_python_lib(1)")} +# TODO(dreiss): Where is this supposed to go? +%{!?thrift_erlang_root: %define thrift_erlang_root /opt/thrift-erl} + +Name: thrift +License: Apache License v2.0 +Group: Development +Summary: RPC and serialization framework +Version: 20080529svn +Epoch: 1 +Release: 1 +URL: http://developers.facebook.com/thrift +Packager: David Reiss +Source0: %{name}-%{version}.tar.gz + +BuildRequires: gcc >= 3.4.6 +BuildRequires: gcc-c++ + +# TODO(dreiss): Can these be moved into the individual packages? +%if %{!?without_java: 1} +BuildRequires: java-devel >= 0:1.5.0 +BuildRequires: ant >= 0:1.6.5 +%endif + +%if %{!?without_python: 1} +BuildRequires: python-devel +%endif + +%if %{!?without_erlang: 1} +BuildRequires: erlang +%endif + +BuildRoot: %{_tmppath}/%{name}-%{version}-%{release}-root-%(%{__id_u} -n) + +%description +Thrift is a software framework for scalable cross-language services +development. It combines a powerful software stack with a code generation +engine to build services that work efficiently and seamlessly between C++, +Java, C#, Python, Ruby, Perl, PHP, Objective C/Cocoa, Smalltalk, Erlang, +Objective Caml, and Haskell. + +%files +%defattr(-,root,root) +%{_bindir}/thrift + + +%package lib-cpp +Summary: Thrift C++ library +Group: Libraries + +%description lib-cpp +C++ libraries for Thrift. + +%files lib-cpp +%defattr(-,root,root) +%{_libdir}/libthrift*.so.* + + +%package lib-cpp-devel +Summary: Thrift C++ library development files +Group: Libraries +Requires: %{name} = %{version}-%{release} +Requires: boost-devel +%if %{!?without_libevent: 1} +Requires: libevent-devel >= 1.2 +%endif +%if %{!?without_zlib: 1} +Requires: zlib-devel +%endif + +%description lib-cpp-devel +C++ static libraries and headers for Thrift. + +%files lib-cpp-devel +%defattr(-,root,root) +%{_includedir}/thrift/ +%{_libdir}/libthrift*.*a +%{_libdir}/libthrift*.so +%{_libdir}/pkgconfig/thrift*.pc + + +%if %{!?without_java: 1} +%package lib-java +Summary: Thrift Java library +Group: Libraries +Requires: java >= 0:1.5.0 + +%description lib-java +Java libraries for Thrift. + +%files lib-java +%defattr(-,root,root) +%{_javadir}/* +%endif + + +%if %{!?without_python: 1} +%package lib-python +Summary: Thrift Python library +Group: Libraries + +%description lib-python +Python libraries for Thrift. + +%files lib-python +%defattr(-,root,root) +%{python_sitearch}/* +%endif + + +%if %{!?without_erlang: 1} +%package lib-erlang +Summary: Thrift Python library +Group: Libraries +Requires: erlang + +%description lib-erlang +Erlang libraries for Thrift. + +%files lib-erlang +%defattr(-,root,root) +%{thrift_erlang_root} +%endif + + +%prep +%setup -q + +%build +# TODO(dreiss): Implement a single --without-build-kludges. +%configure \ + %{?without_libevent: --without-libevent } \ + %{?without_zlib: --without-zlib } \ + --without-java \ + --without-csharp \ + --without-py \ + --without-erlang \ + +make + +%if %{!?without_java: 1} +cd lib/java +%ant +cd ../.. +%endif + +%if %{!?without_python: 1} +cd lib/py +CFLAGS="%{optflags}" %{__python} setup.py build +cd ../.. +%endif + +%if %{!?without_erlang: 1} +cd lib/erl +make +cd ../.. +%endif + +%install +%makeinstall + +%if %{!?without_java: 1} +mkdir -p $RPM_BUILD_ROOT%{_javadir} +cp -p lib/java/*.jar $RPM_BUILD_ROOT%{_javadir} +%endif + +%if %{!?without_python: 1} +cd lib/py +%{__python} setup.py install -O1 --skip-build --root $RPM_BUILD_ROOT +cd ../.. +%endif + +%if %{!?without_erlang: 1} +mkdir -p ${RPM_BUILD_ROOT}%{thrift_erlang_root} +cp -r lib/erl/ebin ${RPM_BUILD_ROOT}%{thrift_erlang_root} +%endif + + +%clean +rm -rf ${RPM_BUILD_ROOT} + + +%changelog +* Wed May 28 2008 David Reiss - 20080529svn +- Initial build, based on the work of Kevin Smith and Ben Maurer. diff --git a/contrib/thrift.vim b/contrib/thrift.vim new file mode 100644 index 00000000..79ce5472 --- /dev/null +++ b/contrib/thrift.vim @@ -0,0 +1,91 @@ +" Vim syntax file +" Language: Thrift +" Maintainer: Martin Smith +" Last Change: $Date: $ +" Copy to ~/.vim/ +" Add to ~/.vimrc +" au BufRead,BufNewFile *.thrift set filetype=thrift +" au! Syntax thrift source ~/.vim/thrift.vim +" +" $Id: $ +" +" Licensed to the Apache Software Foundation (ASF) under one +" or more contributor license agreements. See the NOTICE file +" distributed with this work for additional information +" regarding copyright ownership. The ASF licenses this file +" to you under the Apache License, Version 2.0 (the +" "License"); you may not use this file except in compliance +" with the License. You may obtain a copy of the License at +" +" http://www.apache.org/licenses/LICENSE-2.0 +" +" Unless required by applicable law or agreed to in writing, +" software distributed under the License is distributed on an +" "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +" KIND, either express or implied. See the License for the +" specific language governing permissions and limitations +" under the License. +" + +if version < 600 + syntax clear +elseif exists("b:current_syntax") + finish +endif + +" Todo +syn keyword thriftTodo TODO todo FIXME fixme XXX xxx contained + +" Comments +syn match thriftComment "#.*" contains=thriftTodo +syn region thriftComment start="/\*" end="\*/" contains=thriftTodo +syn match thriftComment "//.\{-}\(?>\|$\)\@=" + +" String +syn region thriftStringDouble matchgroup=None start=+"+ end=+"+ + +" Number +syn match thriftNumber "-\=\<\d\+\>" contained + +" Keywords +syn keyword thriftKeyword namespace +syn keyword thriftKeyword xsd_all xsd_optional xsd_nillable xsd_attrs +syn keyword thriftKeyword include cpp_include cpp_type const optional required +syn keyword thriftBasicTypes void bool byte i16 i32 i64 double string binary +syn keyword thriftStructure map list set struct typedef exception enum throws + +" Special +syn match thriftSpecial "\d\+:" + +" Structure +syn keyword thriftStructure service oneway extends +"async" { return tok_async; } +"exception" { return tok_xception; } +"extends" { return tok_extends; } +"throws" { return tok_throws; } +"service" { return tok_service; } +"enum" { return tok_enum; } +"const" { return tok_const; } + +if version >= 508 || !exists("did_thrift_syn_inits") + if version < 508 + let did_thrift_syn_inits = 1 + command! -nargs=+ HiLink hi link + else + command! -nargs=+ HiLink hi def link + endif + + HiLink thriftComment Comment + HiLink thriftKeyword Special + HiLink thriftBasicTypes Type + HiLink thriftStructure StorageClass + HiLink thriftTodo Todo + HiLink thriftString String + HiLink thriftNumber Number + HiLink thriftSpecial Special + HiLink thriftStructure Structure + + delcommand HiLink +endif + +let b:current_syntax = "thrift" diff --git a/contrib/thrift_dump.cpp b/contrib/thrift_dump.cpp new file mode 100644 index 00000000..0ddfcec3 --- /dev/null +++ b/contrib/thrift_dump.cpp @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include +#include +#include +#include +#include + +using namespace std; +using boost::shared_ptr; +using namespace apache::thrift::transport; +using namespace apache::thrift::protocol; + +void usage() { + fprintf(stderr, + "usage: thrift_dump {-b|-f|-s} < input > ouput\n" + " -b TBufferedTransport messages\n" + " -f TFramedTransport messages\n" + " -s Raw structures\n"); + exit(EXIT_FAILURE); +} + +int main(int argc, char *argv[]) { + if (argc != 2) { + usage(); + } + + shared_ptr stdin_trans(new TFDTransport(STDIN_FILENO)); + shared_ptr itrans; + + if (argv[1] == std::string("-b") || argv[1] == std::string("-s")) { + itrans.reset(new TBufferedTransport(stdin_trans)); + } else if (argv[1] == std::string("-f")) { + itrans.reset(new TFramedTransport(stdin_trans)); + } else { + usage(); + } + + shared_ptr iprot(new TBinaryProtocol(itrans)); + shared_ptr oprot( + new TDebugProtocol( + shared_ptr(new TBufferedTransport( + shared_ptr(new TFDTransport(STDOUT_FILENO)))))); + + TProtocolTap tap(iprot, oprot); + + try { + if (argv[1] == std::string("-s")) { + for (;;) { + tap.skip(T_STRUCT); + } + } else { + std::string name; + TMessageType messageType; + int32_t seqid; + for (;;) { + tap.readMessageBegin(name, messageType, seqid); + tap.skip(T_STRUCT); + tap.readMessageEnd(); + } + } + } catch (TProtocolException exn) { + cout << "Protocol Exception: " << exn.what() << endl; + } catch (...) { + oprot->getTransport()->flush(); + } + + cout << endl; + + return 0; +} diff --git a/doc/lgpl-2.1.txt b/doc/lgpl-2.1.txt new file mode 100644 index 00000000..5ab7695a --- /dev/null +++ b/doc/lgpl-2.1.txt @@ -0,0 +1,504 @@ + GNU LESSER GENERAL PUBLIC LICENSE + Version 2.1, February 1999 + + Copyright (C) 1991, 1999 Free Software Foundation, Inc. + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + +[This is the first released version of the Lesser GPL. It also counts + as the successor of the GNU Library Public License, version 2, hence + the version number 2.1.] + + Preamble + + The licenses for most software are designed to take away your +freedom to share and change it. By contrast, the GNU General Public +Licenses are intended to guarantee your freedom to share and change +free software--to make sure the software is free for all its users. + + This license, the Lesser General Public License, applies to some +specially designated software packages--typically libraries--of the +Free Software Foundation and other authors who decide to use it. You +can use it too, but we suggest you first think carefully about whether +this license or the ordinary General Public License is the better +strategy to use in any particular case, based on the explanations below. + + When we speak of free software, we are referring to freedom of use, +not price. Our General Public Licenses are designed to make sure that +you have the freedom to distribute copies of free software (and charge +for this service if you wish); that you receive source code or can get +it if you want it; that you can change the software and use pieces of +it in new free programs; and that you are informed that you can do +these things. + + To protect your rights, we need to make restrictions that forbid +distributors to deny you these rights or to ask you to surrender these +rights. These restrictions translate to certain responsibilities for +you if you distribute copies of the library or if you modify it. + + For example, if you distribute copies of the library, whether gratis +or for a fee, you must give the recipients all the rights that we gave +you. You must make sure that they, too, receive or can get the source +code. If you link other code with the library, you must provide +complete object files to the recipients, so that they can relink them +with the library after making changes to the library and recompiling +it. And you must show them these terms so they know their rights. + + We protect your rights with a two-step method: (1) we copyright the +library, and (2) we offer you this license, which gives you legal +permission to copy, distribute and/or modify the library. + + To protect each distributor, we want to make it very clear that +there is no warranty for the free library. Also, if the library is +modified by someone else and passed on, the recipients should know +that what they have is not the original version, so that the original +author's reputation will not be affected by problems that might be +introduced by others. + + Finally, software patents pose a constant threat to the existence of +any free program. We wish to make sure that a company cannot +effectively restrict the users of a free program by obtaining a +restrictive license from a patent holder. Therefore, we insist that +any patent license obtained for a version of the library must be +consistent with the full freedom of use specified in this license. + + Most GNU software, including some libraries, is covered by the +ordinary GNU General Public License. This license, the GNU Lesser +General Public License, applies to certain designated libraries, and +is quite different from the ordinary General Public License. We use +this license for certain libraries in order to permit linking those +libraries into non-free programs. + + When a program is linked with a library, whether statically or using +a shared library, the combination of the two is legally speaking a +combined work, a derivative of the original library. The ordinary +General Public License therefore permits such linking only if the +entire combination fits its criteria of freedom. The Lesser General +Public License permits more lax criteria for linking other code with +the library. + + We call this license the "Lesser" General Public License because it +does Less to protect the user's freedom than the ordinary General +Public License. It also provides other free software developers Less +of an advantage over competing non-free programs. These disadvantages +are the reason we use the ordinary General Public License for many +libraries. However, the Lesser license provides advantages in certain +special circumstances. + + For example, on rare occasions, there may be a special need to +encourage the widest possible use of a certain library, so that it becomes +a de-facto standard. To achieve this, non-free programs must be +allowed to use the library. A more frequent case is that a free +library does the same job as widely used non-free libraries. In this +case, there is little to gain by limiting the free library to free +software only, so we use the Lesser General Public License. + + In other cases, permission to use a particular library in non-free +programs enables a greater number of people to use a large body of +free software. For example, permission to use the GNU C Library in +non-free programs enables many more people to use the whole GNU +operating system, as well as its variant, the GNU/Linux operating +system. + + Although the Lesser General Public License is Less protective of the +users' freedom, it does ensure that the user of a program that is +linked with the Library has the freedom and the wherewithal to run +that program using a modified version of the Library. + + The precise terms and conditions for copying, distribution and +modification follow. Pay close attention to the difference between a +"work based on the library" and a "work that uses the library". The +former contains code derived from the library, whereas the latter must +be combined with the library in order to run. + + GNU LESSER GENERAL PUBLIC LICENSE + TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION + + 0. This License Agreement applies to any software library or other +program which contains a notice placed by the copyright holder or +other authorized party saying it may be distributed under the terms of +this Lesser General Public License (also called "this License"). +Each licensee is addressed as "you". + + A "library" means a collection of software functions and/or data +prepared so as to be conveniently linked with application programs +(which use some of those functions and data) to form executables. + + The "Library", below, refers to any such software library or work +which has been distributed under these terms. A "work based on the +Library" means either the Library or any derivative work under +copyright law: that is to say, a work containing the Library or a +portion of it, either verbatim or with modifications and/or translated +straightforwardly into another language. (Hereinafter, translation is +included without limitation in the term "modification".) + + "Source code" for a work means the preferred form of the work for +making modifications to it. For a library, complete source code means +all the source code for all modules it contains, plus any associated +interface definition files, plus the scripts used to control compilation +and installation of the library. + + Activities other than copying, distribution and modification are not +covered by this License; they are outside its scope. The act of +running a program using the Library is not restricted, and output from +such a program is covered only if its contents constitute a work based +on the Library (independent of the use of the Library in a tool for +writing it). Whether that is true depends on what the Library does +and what the program that uses the Library does. + + 1. You may copy and distribute verbatim copies of the Library's +complete source code as you receive it, in any medium, provided that +you conspicuously and appropriately publish on each copy an +appropriate copyright notice and disclaimer of warranty; keep intact +all the notices that refer to this License and to the absence of any +warranty; and distribute a copy of this License along with the +Library. + + You may charge a fee for the physical act of transferring a copy, +and you may at your option offer warranty protection in exchange for a +fee. + + 2. You may modify your copy or copies of the Library or any portion +of it, thus forming a work based on the Library, and copy and +distribute such modifications or work under the terms of Section 1 +above, provided that you also meet all of these conditions: + + a) The modified work must itself be a software library. + + b) You must cause the files modified to carry prominent notices + stating that you changed the files and the date of any change. + + c) You must cause the whole of the work to be licensed at no + charge to all third parties under the terms of this License. + + d) If a facility in the modified Library refers to a function or a + table of data to be supplied by an application program that uses + the facility, other than as an argument passed when the facility + is invoked, then you must make a good faith effort to ensure that, + in the event an application does not supply such function or + table, the facility still operates, and performs whatever part of + its purpose remains meaningful. + + (For example, a function in a library to compute square roots has + a purpose that is entirely well-defined independent of the + application. Therefore, Subsection 2d requires that any + application-supplied function or table used by this function must + be optional: if the application does not supply it, the square + root function must still compute square roots.) + +These requirements apply to the modified work as a whole. If +identifiable sections of that work are not derived from the Library, +and can be reasonably considered independent and separate works in +themselves, then this License, and its terms, do not apply to those +sections when you distribute them as separate works. But when you +distribute the same sections as part of a whole which is a work based +on the Library, the distribution of the whole must be on the terms of +this License, whose permissions for other licensees extend to the +entire whole, and thus to each and every part regardless of who wrote +it. + +Thus, it is not the intent of this section to claim rights or contest +your rights to work written entirely by you; rather, the intent is to +exercise the right to control the distribution of derivative or +collective works based on the Library. + +In addition, mere aggregation of another work not based on the Library +with the Library (or with a work based on the Library) on a volume of +a storage or distribution medium does not bring the other work under +the scope of this License. + + 3. You may opt to apply the terms of the ordinary GNU General Public +License instead of this License to a given copy of the Library. To do +this, you must alter all the notices that refer to this License, so +that they refer to the ordinary GNU General Public License, version 2, +instead of to this License. (If a newer version than version 2 of the +ordinary GNU General Public License has appeared, then you can specify +that version instead if you wish.) Do not make any other change in +these notices. + + Once this change is made in a given copy, it is irreversible for +that copy, so the ordinary GNU General Public License applies to all +subsequent copies and derivative works made from that copy. + + This option is useful when you wish to copy part of the code of +the Library into a program that is not a library. + + 4. You may copy and distribute the Library (or a portion or +derivative of it, under Section 2) in object code or executable form +under the terms of Sections 1 and 2 above provided that you accompany +it with the complete corresponding machine-readable source code, which +must be distributed under the terms of Sections 1 and 2 above on a +medium customarily used for software interchange. + + If distribution of object code is made by offering access to copy +from a designated place, then offering equivalent access to copy the +source code from the same place satisfies the requirement to +distribute the source code, even though third parties are not +compelled to copy the source along with the object code. + + 5. A program that contains no derivative of any portion of the +Library, but is designed to work with the Library by being compiled or +linked with it, is called a "work that uses the Library". Such a +work, in isolation, is not a derivative work of the Library, and +therefore falls outside the scope of this License. + + However, linking a "work that uses the Library" with the Library +creates an executable that is a derivative of the Library (because it +contains portions of the Library), rather than a "work that uses the +library". The executable is therefore covered by this License. +Section 6 states terms for distribution of such executables. + + When a "work that uses the Library" uses material from a header file +that is part of the Library, the object code for the work may be a +derivative work of the Library even though the source code is not. +Whether this is true is especially significant if the work can be +linked without the Library, or if the work is itself a library. The +threshold for this to be true is not precisely defined by law. + + If such an object file uses only numerical parameters, data +structure layouts and accessors, and small macros and small inline +functions (ten lines or less in length), then the use of the object +file is unrestricted, regardless of whether it is legally a derivative +work. (Executables containing this object code plus portions of the +Library will still fall under Section 6.) + + Otherwise, if the work is a derivative of the Library, you may +distribute the object code for the work under the terms of Section 6. +Any executables containing that work also fall under Section 6, +whether or not they are linked directly with the Library itself. + + 6. As an exception to the Sections above, you may also combine or +link a "work that uses the Library" with the Library to produce a +work containing portions of the Library, and distribute that work +under terms of your choice, provided that the terms permit +modification of the work for the customer's own use and reverse +engineering for debugging such modifications. + + You must give prominent notice with each copy of the work that the +Library is used in it and that the Library and its use are covered by +this License. You must supply a copy of this License. If the work +during execution displays copyright notices, you must include the +copyright notice for the Library among them, as well as a reference +directing the user to the copy of this License. Also, you must do one +of these things: + + a) Accompany the work with the complete corresponding + machine-readable source code for the Library including whatever + changes were used in the work (which must be distributed under + Sections 1 and 2 above); and, if the work is an executable linked + with the Library, with the complete machine-readable "work that + uses the Library", as object code and/or source code, so that the + user can modify the Library and then relink to produce a modified + executable containing the modified Library. (It is understood + that the user who changes the contents of definitions files in the + Library will not necessarily be able to recompile the application + to use the modified definitions.) + + b) Use a suitable shared library mechanism for linking with the + Library. A suitable mechanism is one that (1) uses at run time a + copy of the library already present on the user's computer system, + rather than copying library functions into the executable, and (2) + will operate properly with a modified version of the library, if + the user installs one, as long as the modified version is + interface-compatible with the version that the work was made with. + + c) Accompany the work with a written offer, valid for at + least three years, to give the same user the materials + specified in Subsection 6a, above, for a charge no more + than the cost of performing this distribution. + + d) If distribution of the work is made by offering access to copy + from a designated place, offer equivalent access to copy the above + specified materials from the same place. + + e) Verify that the user has already received a copy of these + materials or that you have already sent this user a copy. + + For an executable, the required form of the "work that uses the +Library" must include any data and utility programs needed for +reproducing the executable from it. However, as a special exception, +the materials to be distributed need not include anything that is +normally distributed (in either source or binary form) with the major +components (compiler, kernel, and so on) of the operating system on +which the executable runs, unless that component itself accompanies +the executable. + + It may happen that this requirement contradicts the license +restrictions of other proprietary libraries that do not normally +accompany the operating system. Such a contradiction means you cannot +use both them and the Library together in an executable that you +distribute. + + 7. You may place library facilities that are a work based on the +Library side-by-side in a single library together with other library +facilities not covered by this License, and distribute such a combined +library, provided that the separate distribution of the work based on +the Library and of the other library facilities is otherwise +permitted, and provided that you do these two things: + + a) Accompany the combined library with a copy of the same work + based on the Library, uncombined with any other library + facilities. This must be distributed under the terms of the + Sections above. + + b) Give prominent notice with the combined library of the fact + that part of it is a work based on the Library, and explaining + where to find the accompanying uncombined form of the same work. + + 8. You may not copy, modify, sublicense, link with, or distribute +the Library except as expressly provided under this License. Any +attempt otherwise to copy, modify, sublicense, link with, or +distribute the Library is void, and will automatically terminate your +rights under this License. However, parties who have received copies, +or rights, from you under this License will not have their licenses +terminated so long as such parties remain in full compliance. + + 9. You are not required to accept this License, since you have not +signed it. However, nothing else grants you permission to modify or +distribute the Library or its derivative works. These actions are +prohibited by law if you do not accept this License. Therefore, by +modifying or distributing the Library (or any work based on the +Library), you indicate your acceptance of this License to do so, and +all its terms and conditions for copying, distributing or modifying +the Library or works based on it. + + 10. Each time you redistribute the Library (or any work based on the +Library), the recipient automatically receives a license from the +original licensor to copy, distribute, link with or modify the Library +subject to these terms and conditions. You may not impose any further +restrictions on the recipients' exercise of the rights granted herein. +You are not responsible for enforcing compliance by third parties with +this License. + + 11. If, as a consequence of a court judgment or allegation of patent +infringement or for any other reason (not limited to patent issues), +conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot +distribute so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you +may not distribute the Library at all. For example, if a patent +license would not permit royalty-free redistribution of the Library by +all those who receive copies directly or indirectly through you, then +the only way you could satisfy both it and this License would be to +refrain entirely from distribution of the Library. + +If any portion of this section is held invalid or unenforceable under any +particular circumstance, the balance of the section is intended to apply, +and the section as a whole is intended to apply in other circumstances. + +It is not the purpose of this section to induce you to infringe any +patents or other property right claims or to contest validity of any +such claims; this section has the sole purpose of protecting the +integrity of the free software distribution system which is +implemented by public license practices. Many people have made +generous contributions to the wide range of software distributed +through that system in reliance on consistent application of that +system; it is up to the author/donor to decide if he or she is willing +to distribute software through any other system and a licensee cannot +impose that choice. + +This section is intended to make thoroughly clear what is believed to +be a consequence of the rest of this License. + + 12. If the distribution and/or use of the Library is restricted in +certain countries either by patents or by copyrighted interfaces, the +original copyright holder who places the Library under this License may add +an explicit geographical distribution limitation excluding those countries, +so that distribution is permitted only in or among countries not thus +excluded. In such case, this License incorporates the limitation as if +written in the body of this License. + + 13. The Free Software Foundation may publish revised and/or new +versions of the Lesser General Public License from time to time. +Such new versions will be similar in spirit to the present version, +but may differ in detail to address new problems or concerns. + +Each version is given a distinguishing version number. If the Library +specifies a version number of this License which applies to it and +"any later version", you have the option of following the terms and +conditions either of that version or of any later version published by +the Free Software Foundation. If the Library does not specify a +license version number, you may choose any version ever published by +the Free Software Foundation. + + 14. If you wish to incorporate parts of the Library into other free +programs whose distribution conditions are incompatible with these, +write to the author to ask for permission. For software which is +copyrighted by the Free Software Foundation, write to the Free +Software Foundation; we sometimes make exceptions for this. Our +decision will be guided by the two goals of preserving the free status +of all derivatives of our free software and of promoting the sharing +and reuse of software generally. + + NO WARRANTY + + 15. BECAUSE THE LIBRARY IS LICENSED FREE OF CHARGE, THERE IS NO +WARRANTY FOR THE LIBRARY, TO THE EXTENT PERMITTED BY APPLICABLE LAW. +EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR +OTHER PARTIES PROVIDE THE LIBRARY "AS IS" WITHOUT WARRANTY OF ANY +KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE +LIBRARY IS WITH YOU. SHOULD THE LIBRARY PROVE DEFECTIVE, YOU ASSUME +THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN +WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY +AND/OR REDISTRIBUTE THE LIBRARY AS PERMITTED ABOVE, BE LIABLE TO YOU +FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR +CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE +LIBRARY (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING +RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A +FAILURE OF THE LIBRARY TO OPERATE WITH ANY OTHER SOFTWARE), EVEN IF +SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH +DAMAGES. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Libraries + + If you develop a new library, and you want it to be of the greatest +possible use to the public, we recommend making it free software that +everyone can redistribute and change. You can do so by permitting +redistribution under these terms (or, alternatively, under the terms of the +ordinary General Public License). + + To apply these terms, attach the following notices to the library. It is +safest to attach them to the start of each source file to most effectively +convey the exclusion of warranty; and each file should have at least the +"copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This library is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 2.1 of the License, or (at your option) any later version. + + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this library; if not, write to the Free Software + Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + +Also add information on how to contact you by electronic and paper mail. + +You should also get your employer (if you work as a programmer) or your +school, if any, to sign a "copyright disclaimer" for the library, if +necessary. Here is a sample; alter the names: + + Yoyodyne, Inc., hereby disclaims all copyright interest in the + library `Frob' (a library for tweaking knobs) written by James Random Hacker. + + , 1 April 1990 + Ty Coon, President of Vice + +That's all there is to it! + + diff --git a/doc/otp-base-license.txt b/doc/otp-base-license.txt new file mode 100644 index 00000000..8ee29920 --- /dev/null +++ b/doc/otp-base-license.txt @@ -0,0 +1,20 @@ +Tue Oct 24 12:28:44 CDT 2006 + +Copyright (c) <2006> + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software (OTP Base, fslib, G.A.S) and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE +OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/doc/thrift.bnf b/doc/thrift.bnf new file mode 100644 index 00000000..24d83f68 --- /dev/null +++ b/doc/thrift.bnf @@ -0,0 +1,96 @@ +Thrift Protocol Structure + +Last Modified: 2007-Jun-29 + +-------------------------------------------------------------------- + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +-------------------------------------------------------------------- + +This document describes the structure of the Thrift protocol +without specifying the encoding. Thus, the order of elements +could in some cases be rearranged depending upon the TProtocol +implementation, but this document specifies the minimum required +structure. There are some "dumb" terminals like STRING and INT +that take the place of an actual encoding specification. + +They key point to notice is that ALL messages are just one wrapped +. Depending upon the message type, the can be +interpreted as the argument list to a function, the return value +of a function, or an exception. + +-------------------------------------------------------------------- + + ::= + + ::= + + ::= STRING + + ::= T_CALL | T_REPLY | T_EXCEPTION + + ::= I32 + + ::= * + + ::= + + ::= STRING + + ::= T_STOP + + ::= + + ::= + + ::= STRING + + ::= T_BOOL | T_BYTE | T_I8 | T_I16 | T_I32 | T_I64 | T_DOUBLE + | T_STRING | T_BINARY | T_STRUCT | T_MAP | T_SET | T_LIST + + ::= I16 + + ::= I8 | I16 | I32 | I64 | DOUBLE | STRING | BINARY + | | | + + ::= * + + ::= + + ::= + + ::= + + ::= I32 + + ::= * + + ::= + + ::= + + ::= I32 + + ::= * + + ::= + + ::= + + ::= I32 diff --git a/doc/thrift.tex b/doc/thrift.tex new file mode 100644 index 00000000..d20b7377 --- /dev/null +++ b/doc/thrift.tex @@ -0,0 +1,1057 @@ +%----------------------------------------------------------------------------- +% +% Thrift whitepaper +% +% Name: thrift.tex +% +% Authors: Mark Slee (mcslee@facebook.com) +% +% Created: 05 March 2007 +% +% You will need a copy of sigplanconf.cls to format this document. +% It is available at . +% +%----------------------------------------------------------------------------- + + +\documentclass[nocopyrightspace,blockstyle]{sigplanconf} + +\usepackage{amssymb} +\usepackage{amsfonts} +\usepackage{amsmath} +\usepackage{url} + +\begin{document} + +% \conferenceinfo{WXYZ '05}{date, City.} +% \copyrightyear{2007} +% \copyrightdata{[to be supplied]} + +% \titlebanner{banner above paper title} % These are ignored unless +% \preprintfooter{short description of paper} % 'preprint' option specified. + +\title{Thrift: Scalable Cross-Language Services Implementation} +\subtitle{} + +\authorinfo{Mark Slee, Aditya Agarwal and Marc Kwiatkowski} + {Facebook, 156 University Ave, Palo Alto, CA} + {\{mcslee,aditya,marc\}@facebook.com} + +\maketitle + +\begin{abstract} +Thrift is a software library and set of code-generation tools developed at +Facebook to expedite development and implementation of efficient and scalable +backend services. Its primary goal is to enable efficient and reliable +communication across programming languages by abstracting the portions of each +language that tend to require the most customization into a common library +that is implemented in each language. Specifically, Thrift allows developers to +define datatypes and service interfaces in a single language-neutral file +and generate all the necessary code to build RPC clients and servers. + +This paper details the motivations and design choices we made in Thrift, as +well as some of the more interesting implementation details. It is not +intended to be taken as research, but rather it is an exposition on what we did +and why. +\end{abstract} + +% \category{D.3.3}{Programming Languages}{Language constructs and features} + +%\terms +%Languages, serialization, remote procedure call + +%\keywords +%Data description language, interface definition language, remote procedure call + +\section{Introduction} +As Facebook's traffic and network structure have scaled, the resource +demands of many operations on the site (i.e. search, +ad selection and delivery, event logging) have presented technical requirements +drastically outside the scope of the LAMP framework. In our implementation of +these services, various programming languages have been selected to +optimize for the right combination of performance, ease and speed of +development, availability of existing libraries, etc. By and large, +Facebook's engineering culture has tended towards choosing the best +tools and implementations available over standardizing on any one +programming language and begrudgingly accepting its inherent limitations. + +Given this design choice, we were presented with the challenge of building +a transparent, high-performance bridge across many programming languages. +We found that most available solutions were either too limited, did not offer +sufficient datatype freedom, or suffered from subpar performance. +\footnote{See Appendix A for a discussion of alternative systems.} + +The solution that we have implemented combines a language-neutral software +stack implemented across numerous programming languages and an associated code +generation engine that transforms a simple interface and data definition +language into client and server remote procedure call libraries. +Choosing static code generation over a dynamic system allows us to create +validated code that can be run without the need for +any advanced introspective run-time type checking. It is also designed to +be as simple as possible for the developer, who can typically define all +the necessary data structures and interfaces for a complex service in a single +short file. + +Surprised that a robust open solution to these relatively common problems +did not yet exist, we committed early on to making the Thrift implementation +open source. + +In evaluating the challenges of cross-language interaction in a networked +environment, some key components were identified: + +\textit{Types.} A common type system must exist across programming languages +without requiring that the application developer use custom Thrift datatypes +or write their own serialization code. That is, +a C++ programmer should be able to transparently exchange a strongly typed +STL map for a dynamic Python dictionary. Neither +programmer should be forced to write any code below the application layer +to achieve this. Section 2 details the Thrift type system. + +\textit{Transport.} Each language must have a common interface to +bidirectional raw data transport. The specifics of how a given +transport is implemented should not matter to the service developer. +The same application code should be able to run against TCP stream sockets, +raw data in memory, or files on disk. Section 3 details the Thrift Transport +layer. + +\textit{Protocol.} Datatypes must have some way of using the Transport +layer to encode and decode themselves. Again, the application +developer need not be concerned by this layer. Whether the service uses +an XML or binary protocol is immaterial to the application code. +All that matters is that the data can be read and written in a consistent, +deterministic matter. Section 4 details the Thrift Protocol layer. + +\textit{Versioning.} For robust services, the involved datatypes must +provide a mechanism for versioning themselves. Specifically, +it should be possible to add or remove fields in an object or alter the +argument list of a function without any interruption in service (or, +worse yet, nasty segmentation faults). Section 5 details Thrift's versioning +system. + +\textit{Processors.} Finally, we generate code capable of processing data +streams to accomplish remote procedure calls. Section 6 details the generated +code and TProcessor paradigm. + +Section 7 discusses implementation details, and Section 8 describes +our conclusions. + +\section{Types} + +The goal of the Thrift type system is to enable programmers to develop using +completely natively defined types, no matter what programming language they +use. By design, the Thrift type system does not introduce any special dynamic +types or wrapper objects. It also does not require that the developer write +any code for object serialization or transport. The Thrift IDL (Interface +Definition Language) file is +logically a way for developers to annotate their data structures with the +minimal amount of extra information necessary to tell a code generator +how to safely transport the objects across languages. + +\subsection{Base Types} + +The type system rests upon a few base types. In considering which types to +support, we aimed for clarity and simplicity over abundance, focusing +on the key types available in all programming languages, ommitting any +niche types available only in specific languages. + +The base types supported by Thrift are: +\begin{itemize} +\item \texttt{bool} A boolean value, true or false +\item \texttt{byte} A signed byte +\item \texttt{i16} A 16-bit signed integer +\item \texttt{i32} A 32-bit signed integer +\item \texttt{i64} A 64-bit signed integer +\item \texttt{double} A 64-bit floating point number +\item \texttt{string} An encoding-agnostic text or binary string +\item \texttt{binary} A byte array representation for blobs +\end{itemize} + +Of particular note is the absence of unsigned integer types. Because these +types have no direct translation to native primitive types in many languages, +the advantages they afford are lost. Further, there is no way to prevent the +application developer in a language like Python from assigning a negative value +to an integer variable, leading to unpredictable behavior. From a design +standpoint, we observed that unsigned integers were very rarely, if ever, used +for arithmetic purposes, but in practice were much more often used as keys or +identifiers. In this case, the sign is irrelevant. Signed integers serve this +same purpose and can be safely cast to their unsigned counterparts (most +commonly in C++) when absolutely necessary. + +\subsection{Structs} + +A Thrift struct defines a common object to be used across languages. A struct +is essentially equivalent to a class in object oriented programming +languages. A struct has a set of strongly typed fields, each with a unique +name identifier. The basic syntax for defining a Thrift struct looks very +similar to a C struct definition. Fields may be annotated with an integer field +identifier (unique to the scope of that struct) and optional default values. +Field identifiers will be automatically assigned if omitted, though they are +strongly encouraged for versioning reasons discussed later. + +\subsection{Containers} + +Thrift containers are strongly typed containers that map to the most commonly +used containers in common programming languages. They are annotated using +the C++ template (or Java Generics) style. There are three types available: +\begin{itemize} +\item \texttt{list} An ordered list of elements. Translates directly into +an STL \texttt{vector}, Java \texttt{ArrayList}, or native array in scripting languages. May +contain duplicates. +\item \texttt{set} An unordered set of unique elements. Translates into +an STL \texttt{set}, Java \texttt{HashSet}, \texttt{set} in Python, or native +dictionary in PHP/Ruby. +\item \texttt{map} A map of strictly unique keys to values +Translates into an STL \texttt{map}, Java \texttt{HashMap}, PHP associative +array, or Python/Ruby dictionary. +\end{itemize} + +While defaults are provided, the type mappings are not explicitly fixed. Custom +code generator directives have been added to substitute custom types in +destination languages (i.e. +\texttt{hash\_map} or Google's sparse hash map can be used in C++). The +only requirement is that the custom types support all the necessary iteration +primitives. Container elements may be of any valid Thrift type, including other +containers or structs. + +\begin{verbatim} +struct Example { + 1:i32 number=10, + 2:i64 bigNumber, + 3:double decimals, + 4:string name="thrifty" +}\end{verbatim} + +In the target language, each definition generates a type with two methods, +\texttt{read} and \texttt{write}, which perform serialization and transport +of the objects using a Thrift TProtocol object. + +\subsection{Exceptions} + +Exceptions are syntactically and functionally equivalent to structs except +that they are declared using the \texttt{exception} keyword instead of the +\texttt{struct} keyword. + +The generated objects inherit from an exception base class as appropriate +in each target programming language, in order to seamlessly +integrate with native exception handling in any given +language. Again, the design emphasis is on making the code familiar to the +application developer. + +\subsection{Services} + +Services are defined using Thrift types. Definition of a service is +semantically equivalent to defining an interface (or a pure virtual abstract +class) in object oriented +programming. The Thrift compiler generates fully functional client and +server stubs that implement the interface. Services are defined as follows: + +\begin{verbatim} +service { + () + [throws ()] + ... +}\end{verbatim} + +An example: + +\begin{verbatim} +service StringCache { + void set(1:i32 key, 2:string value), + string get(1:i32 key) throws (1:KeyNotFound knf), + void delete(1:i32 key) +} +\end{verbatim} + +Note that \texttt{void} is a valid type for a function return, in addition to +all other defined Thrift types. Additionally, an \texttt{async} modifier +keyword may be added to a \texttt{void} function, which will generate code that does +not wait for a response from the server. Note that a pure \texttt{void} +function will return a response to the client which guarantees that the +operation has completed on the server side. With \texttt{async} method calls +the client will only be guaranteed that the request succeeded at the +transport layer. (In many transport scenarios this is inherently unreliable +due to the Byzantine Generals' Problem. Therefore, application developers +should take care only to use the async optimization in cases where dropped +method calls are acceptable or the transport is known to be reliable.) + +Also of note is the fact that argument lists and exception lists for functions +are implemented as Thrift structs. All three constructs are identical in both +notation and behavior. + +\section{Transport} + +The transport layer is used by the generated code to facilitate data transfer. + +\subsection{Interface} + +A key design choice in the implementation of Thrift was to decouple the +transport layer from the code generation layer. Though Thrift is typically +used on top of the TCP/IP stack with streaming sockets as the base layer of +communication, there was no compelling reason to build that constraint into +the system. The performance tradeoff incurred by an abstracted I/O layer +(roughly one virtual method lookup / function call per operation) was +immaterial compared to the cost of actual I/O operations (typically invoking +system calls). + +Fundamentally, generated Thrift code only needs to know how to read and +write data. The origin and destination of the data are irrelevant; it may be a +socket, a segment of shared memory, or a file on the local disk. The Thrift +transport interface supports the following methods: + +\begin{itemize} +\item \texttt{open} Opens the tranpsort +\item \texttt{close} Closes the tranport +\item \texttt{isOpen} Indicates whether the transport is open +\item \texttt{read} Reads from the transport +\item \texttt{write} Writes to the transport +\item \texttt{flush} Forces any pending writes +\end{itemize} + +There are a few additional methods not documented here which are used to aid +in batching reads and optionally signaling the completion of a read or +write operation from the generated code. + +In addition to the above +\texttt{TTransport} interface, there is a\\ +\texttt{TServerTransport} interface +used to accept or create primitive transport objects. Its interface is as +follows: + +\begin{itemize} +\item \texttt{open} Opens the transport +\item \texttt{listen} Begins listening for connections +\item \texttt{accept} Returns a new client transport +\item \texttt{close} Closes the transport +\end{itemize} + +\subsection{Implementation} + +The transport interface is designed for simple implementation in any +programming language. New transport mechanisms can be easily defined as needed +by application developers. + +\subsubsection{TSocket} + +The \texttt{TSocket} class is implemented across all target languages. It +provides a common, simple interface to a TCP/IP stream socket. + +\subsubsection{TFileTransport} + +The \texttt{TFileTransport} is an abstraction of an on-disk file to a data +stream. It can be used to write out a set of incoming Thrift requests to a file +on disk. The on-disk data can then be replayed from the log, either for +post-processing or for reproduction and/or simulation of past events. + +\subsubsection{Utilities} + +The Transport interface is designed to support easy extension using common +OOP techniques, such as composition. Some simple utilites include the +\texttt{TBufferedTransport}, which buffers the writes and reads on an +underlying transport, the \texttt{TFramedTransport}, which transmits data with frame +size headers for chunking optimization or nonblocking operation, and the +\texttt{TMemoryBuffer}, which allows reading and writing directly from the heap +or stack memory owned by the process. + +\section{Protocol} + +A second major abstraction in Thrift is the separation of data structure from +transport representation. Thrift enforces a certain messaging structure when +transporting data, but it is agnostic to the protocol encoding in use. That is, +it does not matter whether data is encoded as XML, human-readable ASCII, or a +dense binary format as long as the data supports a fixed set of operations +that allow it to be deterministically read and written by generated code. + +\subsection{Interface} + +The Thrift Protocol interface is very straightforward. It fundamentally +supports two things: 1) bidirectional sequenced messaging, and +2) encoding of base types, containers, and structs. + +\begin{verbatim} +writeMessageBegin(name, type, seq) +writeMessageEnd() +writeStructBegin(name) +writeStructEnd() +writeFieldBegin(name, type, id) +writeFieldEnd() +writeFieldStop() +writeMapBegin(ktype, vtype, size) +writeMapEnd() +writeListBegin(etype, size) +writeListEnd() +writeSetBegin(etype, size) +writeSetEnd() +writeBool(bool) +writeByte(byte) +writeI16(i16) +writeI32(i32) +writeI64(i64) +writeDouble(double) +writeString(string) + +name, type, seq = readMessageBegin() + readMessageEnd() +name = readStructBegin() + readStructEnd() +name, type, id = readFieldBegin() + readFieldEnd() +k, v, size = readMapBegin() + readMapEnd() +etype, size = readListBegin() + readListEnd() +etype, size = readSetBegin() + readSetEnd() +bool = readBool() +byte = readByte() +i16 = readI16() +i32 = readI32() +i64 = readI64() +double = readDouble() +string = readString() +\end{verbatim} + +Note that every \texttt{write} function has exactly one \texttt{read} counterpart, with +the exception of \texttt{writeFieldStop()}. This is a special method +that signals the end of a struct. The procedure for reading a struct is to +\texttt{readFieldBegin()} until the stop field is encountered, and then to +\texttt{readStructEnd()}. The +generated code relies upon this call sequence to ensure that everything written by +a protocol encoder can be read by a matching protocol decoder. Further note +that this set of functions is by design more robust than necessary. +For example, \texttt{writeStructEnd()} is not strictly necessary, as the end of +a struct may be implied by the stop field. This method is a convenience for +verbose protocols in which it is cleaner to separate these calls (e.g. a closing +\texttt{} tag in XML). + +\subsection{Structure} + +Thrift structures are designed to support encoding into a streaming +protocol. The implementation should never need to frame or compute the +entire data length of a structure prior to encoding it. This is critical to +performance in many scenarios. Consider a long list of relatively large +strings. If the protocol interface required reading or writing a list to be an +atomic operation, then the implementation would need to perform a linear pass over the +entire list before encoding any data. However, if the list can be written +as iteration is performed, the corresponding read may begin in parallel, +theoretically offering an end-to-end speedup of $(kN - C)$, where $N$ is the size +of the list, $k$ the cost factor associated with serializing a single +element, and $C$ is fixed offset for the delay between data being written +and becoming available to read. + +Similarly, structs do not encode their data lengths a priori. Instead, they are +encoded as a sequence of fields, with each field having a type specifier and a +unique field identifier. Note that the inclusion of type specifiers allows +the protocol to be safely parsed and decoded without any generated code +or access to the original IDL file. Structs are terminated by a field header +with a special \texttt{STOP} type. Because all the basic types can be read +deterministically, all structs (even those containing other structs) can be +read deterministically. The Thrift protocol is self-delimiting without any +framing and regardless of the encoding format. + +In situations where streaming is unnecessary or framing is advantageous, it +can be very simply added into the transport layer, using the +\texttt{TFramedTransport} abstraction. + +\subsection{Implementation} + +Facebook has implemented and deployed a space-efficient binary protocol which +is used by most backend services. Essentially, it writes all data +in a flat binary format. Integer types are converted to network byte order, +strings are prepended with their byte length, and all message and field headers +are written using the primitive integer serialization constructs. String names +for fields are omitted - when using generated code, field identifiers are +sufficient. + +We decided against some extreme storage optimizations (i.e. packing +small integers into ASCII or using a 7-bit continuation format) for the sake +of simplicity and clarity in the code. These alterations can easily be made +if and when we encounter a performance-critical use case that demands them. + +\section{Versioning} + +Thrift is robust in the face of versioning and data definition changes. This +is critical to enable staged rollouts of changes to deployed services. The +system must be able to support reading of old data from log files, as well as +requests from out-of-date clients to new servers, and vice versa. + +\subsection{Field Identifiers} + +Versioning in Thrift is implemented via field identifiers. The field header +for every member of a struct in Thrift is encoded with a unique field +identifier. The combination of this field identifier and its type specifier +is used to uniquely identify the field. The Thrift definition language +supports automatic assignment of field identifiers, but it is good +programming practice to always explicitly specify field identifiers. +Identifiers are specified as follows: + +\begin{verbatim} +struct Example { + 1:i32 number=10, + 2:i64 bigNumber, + 3:double decimals, + 4:string name="thrifty" +}\end{verbatim} + +To avoid conflicts between manually and automatically assigned identifiers, +fields with identifiers omitted are assigned identifiers +decrementing from -1, and the language only supports the manual assignment of +positive identifiers. + +When data is being deserialized, the generated code can use these identifiers +to properly identify the field and determine whether it aligns with a field in +its definition file. If a field identifier is not recognized, the generated +code can use the type specifier to skip the unknown field without any error. +Again, this is possible due to the fact that all datatypes are self +delimiting. + +Field identifiers can (and should) also be specified in function argument +lists. In fact, argument lists are not only represented as structs on the +backend, but actually share the same code in the compiler frontend. This +allows for version-safe modification of method parameters + +\begin{verbatim} +service StringCache { + void set(1:i32 key, 2:string value), + string get(1:i32 key) throws (1:KeyNotFound knf), + void delete(1:i32 key) +} +\end{verbatim} + +The syntax for specifying field identifiers was chosen to echo their structure. +Structs can be thought of as a dictionary where the identifiers are keys, and +the values are strongly-typed named fields. + +Field identifiers internally use the \texttt{i16} Thrift type. Note, however, +that the \texttt{TProtocol} abstraction may encode identifiers in any format. + +\subsection{Isset} + +When an unexpected field is encountered, it can be safely ignored and +discarded. When an expected field is not found, there must be some way to +signal to the developer that it was not present. This is implemented via an +inner \texttt{isset} structure inside the defined objects. (Isset functionality +is implicit with a \texttt{null} value in PHP, \texttt{None} in Python +and \texttt{nil} in Ruby.) Essentially, +the inner \texttt{isset} object of each Thrift struct contains a boolean value +for each field which denotes whether or not that field is present in the +struct. When a reader receives a struct, it should check for a field being set +before operating directly on it. + +\begin{verbatim} +class Example { + public: + Example() : + number(10), + bigNumber(0), + decimals(0), + name("thrifty") {} + + int32_t number; + int64_t bigNumber; + double decimals; + std::string name; + + struct __isset { + __isset() : + number(false), + bigNumber(false), + decimals(false), + name(false) {} + bool number; + bool bigNumber; + bool decimals; + bool name; + } __isset; +... +} +\end{verbatim} + +\subsection{Case Analysis} + +There are four cases in which version mismatches may occur. + +\begin{enumerate} +\item \textit{Added field, old client, new server.} In this case, the old +client does not send the new field. The new server recognizes that the field +is not set, and implements default behavior for out-of-date requests. +\item \textit{Removed field, old client, new server.} In this case, the old +client sends the removed field. The new server simply ignores it. +\item \textit{Added field, new client, old server.} The new client sends a +field that the old server does not recognize. The old server simply ignores +it and processes as normal. +\item \textit{Removed field, new client, old server.} This is the most +dangerous case, as the old server is unlikely to have suitable default +behavior implemented for the missing field. It is recommended that in this +situation the new server be rolled out prior to the new clients. +\end{enumerate} + +\subsection{Protocol/Transport Versioning} +The \texttt{TProtocol} abstractions are also designed to give protocol +implementations the freedom to version themselves in whatever manner they +see fit. Specifically, any protocol implementation is free to send whatever +it likes in the \texttt{writeMessageBegin()} call. It is entirely up to the +implementor how to handle versioning at the protocol level. The key point is +that protocol encoding changes are safely isolated from interface definition +version changes. + +Note that the exact same is true of the \texttt{TTransport} interface. For +example, if we wished to add some new checksumming or error detection to the +\texttt{TFileTransport}, we could simply add a version header into the +data it writes to the file in such a way that it would still accept old +log files without the given header. + +\section{RPC Implementation} + +\subsection{TProcessor} + +The last core interface in the Thrift design is the \texttt{TProcessor}, +perhaps the most simple of the constructs. The interface is as follows: + +\begin{verbatim} +interface TProcessor { + bool process(TProtocol in, TProtocol out) + throws TException +} +\end{verbatim} + +The key design idea here is that the complex systems we build can fundamentally +be broken down into agents or services that operate on inputs and outputs. In +most cases, there is actually just one input and output (an RPC client) that +needs handling. + +\subsection{Generated Code} + +When a service is defined, we generate a +\texttt{TProcessor} instance capable of handling RPC requests to that service, +using a few helpers. The fundamental structure (illustrated in pseudo-C++) is +as follows: + +\begin{verbatim} +Service.thrift + => Service.cpp + interface ServiceIf + class ServiceClient : virtual ServiceIf + TProtocol in + TProtocol out + class ServiceProcessor : TProcessor + ServiceIf handler + +ServiceHandler.cpp + class ServiceHandler : virtual ServiceIf + +TServer.cpp + TServer(TProcessor processor, + TServerTransport transport, + TTransportFactory tfactory, + TProtocolFactory pfactory) + serve() +\end{verbatim} + +From the Thrift definition file, we generate the virtual service interface. +A client class is generated, which implements the interface and +uses two \texttt{TProtocol} instances to perform the I/O operations. The +generated processor implements the \texttt{TProcessor} interface. The generated +code has all the logic to handle RPC invocations via the \texttt{process()} +call, and takes as a parameter an instance of the service interface, as +implemented by the application developer. + +The user provides an implementation of the application interface in separate, +non-generated source code. + +\subsection{TServer} + +Finally, the Thrift core libraries provide a \texttt{TServer} abstraction. +The \texttt{TServer} object generally works as follows. + +\begin{itemize} +\item Use the \texttt{TServerTransport} to get a \texttt{TTransport} +\item Use the \texttt{TTransportFactory} to optionally convert the primitive +transport into a suitable application transport (typically the +\texttt{TBufferedTransportFactory} is used here) +\item Use the \texttt{TProtocolFactory} to create an input and output protocol +for the \texttt{TTransport} +\item Invoke the \texttt{process()} method of the \texttt{TProcessor} object +\end{itemize} + +The layers are appropriately separated such that the server code needs to know +nothing about any of the transports, encodings, or applications in play. The +server encapsulates the logic around connection handling, threading, etc. +while the processor deals with RPC. The only code written by the application +developer lives in the definitional Thrift file and the interface +implementation. + +Facebook has deployed multiple \texttt{TServer} implementations, including +the single-threaded \texttt{TSimpleServer}, thread-per-connection +\texttt{TThreadedServer}, and thread-pooling \texttt{TThreadPoolServer}. + +The \texttt{TProcessor} interface is very general by design. There is no +requirement that a \texttt{TServer} take a generated \texttt{TProcessor} +object. Thrift allows the application developer to easily write any type of +server that operates on \texttt{TProtocol} objects (for instance, a server +could simply stream a certain type of object without any actual RPC method +invocation). + +\section{Implementation Details} +\subsection{Target Languages} +Thrift currently supports five target languages: C++, Java, Python, Ruby, and +PHP. At Facebook, we have deployed servers predominantly in C++, Java, and +Python. Thrift services implemented in PHP have also been embedded into the +Apache web server, providing transparent backend access to many of our +frontend constructs using a \texttt{THttpClient} implementation of the +\texttt{TTransport} interface. + +Though Thrift was explicitly designed to be much more efficient and robust +than typical web technologies, as we were designing our XML-based REST web +services API we noticed that Thrift could be easily used to define our +service interface. Though we do not currently employ SOAP envelopes (in the +authors' opinions there is already far too much repetitive enterprise Java +software to do that sort of thing), we were able to quickly extend Thrift to +generate XML Schema Definition files for our service, as well as a framework +for versioning different implementations of our web service. Though public +web services are admittedly tangential to Thrift's core use case and design, +Thrift facilitated rapid iteration and affords us the ability to quickly +migrate our entire XML-based web service onto a higher performance system +should the need arise. + +\subsection{Generated Structs} +We made a conscious decision to make our generated structs as transparent as +possible. All fields are publicly accessible; there are no \texttt{set()} and +\texttt{get()} methods. Similarly, use of the \texttt{isset} object is not +enforced. We do not include any \texttt{FieldNotSetException} construct. +Developers have the option to use these fields to write more robust code, but +the system is robust to the developer ignoring the \texttt{isset} construct +entirely and will provide suitable default behavior in all cases. + +This choice was motivated by the desire to ease application development. Our stated +goal is not to make developers learn a rich new library in their language of +choice, but rather to generate code that allow them to work with the constructs +that are most familiar in each language. + +We also made the \texttt{read()} and \texttt{write()} methods of the generated +objects public so that the objects can be used outside of the context +of RPC clients and servers. Thrift is a useful tool simply for generating +objects that are easily serializable across programming languages. + +\subsection{RPC Method Identification} +Method calls in RPC are implemented by sending the method name as a string. One +issue with this approach is that longer method names require more bandwidth. +We experimented with using fixed-size hashes to identify methods, but in the +end concluded that the savings were not worth the headaches incurred. Reliably +dealing with conflicts across versions of an interface definition file is +impossible without a meta-storage system (i.e. to generate non-conflicting +hashes for the current version of a file, we would have to know about all +conflicts that ever existed in any previous version of the file). + +We wanted to avoid too many unnecessary string comparisons upon +method invocation. To deal with this, we generate maps from strings to function +pointers, so that invocation is effectively accomplished via a constant-time +hash lookup in the common case. This requires the use of a couple interesting +code constructs. Because Java does not have function pointers, process +functions are all private member classes implementing a common interface. + +\begin{verbatim} +private class ping implements ProcessFunction { + public void process(int seqid, + TProtocol iprot, + TProtocol oprot) + throws TException + { ...} +} + +HashMap processMap_ = + new HashMap(); +\end{verbatim} + +In C++, we use a relatively esoteric language construct: member function +pointers. + +\begin{verbatim} +std::map + processMap_; +\end{verbatim} + +Using these techniques, the cost of string processing is minimized, and we +reap the benefit of being able to easily debug corrupt or misunderstood data by +inspecting it for known string method names. + +\subsection{Servers and Multithreading} +Thrift services require basic multithreading to handle simultaneous +requests from multiple clients. For the Python and Java implementations of +Thrift server logic, the standard threading libraries distributed with the +languages provide adequate support. For the C++ implementation, no standard multithread runtime +library exists. Specifically, robust, lightweight, and portable +thread manager and timer class implementations do not exist. We investigated +existing implementations, namely \texttt{boost::thread}, +\texttt{boost::threadpool}, \texttt{ACE\_Thread\_Manager} and +\texttt{ACE\_Timer}. + +While \texttt{boost::threads}\cite{boost.threads} provides clean, +lightweight and robust implementations of multi-thread primitives (mutexes, +conditions, threads) it does not provide a thread manager or timer +implementation. + +\texttt{boost::threadpool}\cite{boost.threadpool} also looked promising but +was not far enough along for our purposes. We wanted to limit the dependency on +third-party libraries as much as possible. Because\\ +\texttt{boost::threadpool} is +not a pure template library and requires runtime libraries and because it is +not yet part of the official Boost distribution we felt it was not ready for +use in Thrift. As \texttt{boost::threadpool} evolves and especially if it is +added to the Boost distribution we may reconsider our decision to not use it. + +ACE has both a thread manager and timer class in addition to multi-thread +primitives. The biggest problem with ACE is that it is ACE. Unlike Boost, ACE +API quality is poor. Everything in ACE has large numbers of dependencies on +everything else in ACE - thus forcing developers to throw out standard +classes, such as STL collections, in favor of ACE's homebrewed implementations. In +addition, unlike Boost, ACE implementations demonstrate little understanding +of the power and pitfalls of C++ programming and take no advantage of modern +templating techniques to ensure compile time safety and reasonable compiler +error messages. For all these reasons, ACE was rejected. Instead, we chose +to implement our own library, described in the following sections. + +\subsection{Thread Primitives} + +The Thrift thread libraries are implemented in the namespace\\ +\texttt{facebook::thrift::concurrency} and have three components: +\begin{itemize} +\item primitives +\item thread pool manager +\item timer manager +\end{itemize} + +As mentioned above, we were hesitant to introduce any additional dependencies +on Thrift. We decided to use \texttt{boost::shared\_ptr} because it is so +useful for multithreaded application, it requires no link-time or +runtime libraries (i.e. it is a pure template library) and it is due +to become part of the C++0x standard. + +We implement standard \texttt{Mutex} and \texttt{Condition} classes, and a + \texttt{Monitor} class. The latter is simply a combination of a mutex and +condition variable and is analogous to the \texttt{Monitor} implementation provided for +the Java \texttt{Object} class. This is also sometimes referred to as a barrier. We +provide a \texttt{Synchronized} guard class to allow Java-like synchronized blocks. +This is just a bit of syntactic sugar, but, like its Java counterpart, clearly +delimits critical sections of code. Unlike its Java counterpart, we still +have the ability to programmatically lock, unlock, block, and signal monitors. + +\begin{verbatim} +void run() { + {Synchronized s(manager->monitor); + if (manager->state == TimerManager::STARTING) { + manager->state = TimerManager::STARTED; + manager->monitor.notifyAll(); + } + } +} +\end{verbatim} + +We again borrowed from Java the distinction between a thread and a runnable +class. A \texttt{Thread} is the actual schedulable object. The +\texttt{Runnable} is the logic to execute within the thread. +The \texttt{Thread} implementation deals with all the platform-specific thread +creation and destruction issues, while the \texttt{Runnable} implementation deals +with the application-specific per-thread logic. The benefit of this approach +is that developers can easily subclass the Runnable class without pulling in +platform-specific super-classes. + +\subsection{Thread, Runnable, and shared\_ptr} +We use \texttt{boost::shared\_ptr} throughout the \texttt{ThreadManager} and +\texttt{TimerManager} implementations to guarantee cleanup of dead objects that can +be accessed by multiple threads. For \texttt{Thread} class implementations, +\texttt{boost::shared\_ptr} usage requires particular attention to make sure +\texttt{Thread} objects are neither leaked nor dereferenced prematurely while +creating and shutting down threads. + +Thread creation requires calling into a C library. (In our case the POSIX +thread library, \texttt{libpthread}, but the same would be true for WIN32 threads). +Typically, the OS makes few, if any, guarantees about when \texttt{ThreadMain}, a C thread's entry-point function, will be called. Therefore, it is +possible that our thread create call, +\texttt{ThreadFactory::newThread()} could return to the caller +well before that time. To ensure that the returned \texttt{Thread} object is not +prematurely cleaned up if the caller gives up its reference prior to the +\texttt{ThreadMain} call, the \texttt{Thread} object makes a weak referenence to +itself in its \texttt{start} method. + +With the weak reference in hand the \texttt{ThreadMain} function can attempt to get +a strong reference before entering the \texttt{Runnable::run} method of the +\texttt{Runnable} object bound to the \texttt{Thread}. If no strong references to the +thread are obtained between exiting \texttt{Thread::start} and entering \texttt{ThreadMain}, the weak reference returns \texttt{null} and the function +exits immediately. + +The need for the \texttt{Thread} to make a weak reference to itself has a +significant impact on the API. Since references are managed through the +\texttt{boost::shared\_ptr} templates, the \texttt{Thread} object must have a reference +to itself wrapped by the same \texttt{boost::shared\_ptr} envelope that is returned +to the caller. This necessitated the use of the factory pattern. +\texttt{ThreadFactory} creates the raw \texttt{Thread} object and a +\texttt{boost::shared\_ptr} wrapper, and calls a private helper method of the class +implementing the \texttt{Thread} interface (in this case, \texttt{PosixThread::weakRef}) + to allow it to make add weak reference to itself through the + \texttt{boost::shared\_ptr} envelope. + +\texttt{Thread} and \texttt{Runnable} objects reference each other. A \texttt{Runnable} +object may need to know about the thread in which it is executing, and a Thread, obviously, +needs to know what \texttt{Runnable} object it is hosting. This interdependency is +further complicated because the lifecycle of each object is independent of the +other. An application may create a set of \texttt{Runnable} object to be reused in different threads, or it may create and forget a \texttt{Runnable} object +once a thread has been created and started for it. + +The \texttt{Thread} class takes a \texttt{boost::shared\_ptr} reference to the hosted +\texttt{Runnable} object in its constructor, while the \texttt{Runnable} class has an +explicit \texttt{thread} method to allow explicit binding of the hosted thread. +\texttt{ThreadFactory::newThread} binds the objects to each other. + +\subsection{ThreadManager} + +\texttt{ThreadManager} creates a pool of worker threads and +allows applications to schedule tasks for execution as free worker threads +become available. The \texttt{ThreadManager} does not implement dynamic +thread pool resizing, but provides primitives so that applications can add +and remove threads based on load. This approach was chosen because +implementing load metrics and thread pool size is very application +specific. For example some applications may want to adjust pool size based +on running-average of work arrival rates that are measured via polled +samples. Others may simply wish to react immediately to work-queue +depth high and low water marks. Rather than trying to create a complex +API abstract enough to capture these different approaches, we +simply leave it up to the particular application and provide the +primitives to enact the desired policy and sample current status. + +\subsection{TimerManager} + +\texttt{TimerManager} allows applications to schedule + \texttt{Runnable} objects for execution at some point in the future. Its specific task +is to allows applications to sample \texttt{ThreadManager} load at regular +intervals and make changes to the thread pool size based on application policy. +Of course, it can be used to generate any number of timer or alarm events. + +The default implementation of \texttt{TimerManager} uses a single thread to +execute expired \texttt{Runnable} objects. Thus, if a timer operation needs to +do a large amount of work and especially if it needs to do blocking I/O, +that should be done in a separate thread. + +\subsection{Nonblocking Operation} +Though the Thrift transport interfaces map more directly to a blocking I/O +model, we have implemented a high performance \texttt{TNonBlockingServer} +in C++ based on \texttt{libevent} and the \texttt{TFramedTransport}. We +implemented this by moving all I/O into one tight event loop using a +state machine. Essentially, the event loop reads framed requests into +\texttt{TMemoryBuffer} objects. Once entire requests are ready, they are +dispatched to the \texttt{TProcessor} object which can read directly from +the data in memory. + +\subsection{Compiler} +The Thrift compiler is implemented in C++ using standard \texttt{lex}/\texttt{yacc} +lexing and parsing. Though it could have been implemented with fewer +lines of code in another language (i.e. Python Lex-Yacc (PLY) or \texttt{ocamlyacc}), using C++ +forces explicit definition of the language constructs. Strongly typing the +parse tree elements (debatably) makes the code more approachable for new +developers. + +Code generation is done using two passes. The first pass looks only for +include files and type definitions. Type definitions are not checked during +this phase, since they may depend upon include files. All included files +are sequentially scanned in a first pass. Once the include tree has been +resolved, a second pass over all files is taken that inserts type definitions +into the parse tree and raises an error on any undefined types. The program is +then generated against the parse tree. + +Due to inherent complexities and potential for circular dependencies, +we explicitly disallow forward declaration. Two Thrift structs cannot +each contain an instance of the other. (Since we do not allow \texttt{null} +struct instances in the generated C++ code, this would actually be impossible.) + +\subsection{TFileTransport} +The \texttt{TFileTransport} logs Thrift requests/structs by +framing incoming data with its length and writing it out to disk. +Using a framed on-disk format allows for better error checking and +helps with the processing of a finite number of discrete events. The\\ +\texttt{TFileWriterTransport} uses a system of swapping in-memory buffers +to ensure good performance while logging large amounts of data. +A Thrift log file is split up into chunks of a specified size; logged messages +are not allowed to cross chunk boundaries. A message that would cross a chunk +boundary will cause padding to be added until the end of the chunk and the +first byte of the message are aligned to the beginning of the next chunk. +Partitioning the file into chunks makes it possible to read and interpret data +from a particular point in the file. + +\section{Facebook Thrift Services} +Thrift has been employed in a large number of applications at Facebook, including +search, logging, mobile, ads and the developer platform. Two specific usages are discussed below. + +\subsection{Search} +Thrift is used as the underlying protocol and transport layer for the Facebook Search service. +The multi-language code generation is well suited for search because it allows for application +development in an efficient server side language (C++) and allows the Facebook PHP-based web application +to make calls to the search service using Thrift PHP libraries. There is also a large +variety of search stats, deployment and testing functionality that is built on top +of generated Python code. Additionally, the Thrift log file format is +used as a redo log for providing real-time search index updates. Thrift has allowed the +search team to leverage each language for its strengths and to develop code at a rapid pace. + +\subsection{Logging} +The Thrift \texttt{TFileTransport} functionality is used for structured logging. Each +service function definition along with its parameters can be considered to be +a structured log entry identified by the function name. This log can then be used for +a variety of purposes, including inline and offline processing, stats aggregation and as a redo log. + +\section{Conclusions} +Thrift has enabled Facebook to build scalable backend +services efficiently by enabling engineers to divide and conquer. Application +developers can focus on application code without worrying about the +sockets layer. We avoid duplicated work by writing buffering and I/O logic +in one place, rather than interspersing it in each application. + +Thrift has been employed in a wide variety of applications at Facebook, +including search, logging, mobile, ads, and the developer platform. We have +found that the marginal performance cost incurred by an extra layer of +software abstraction is far eclipsed by the gains in developer efficiency and +systems reliability. + +\appendix + +\section{Similar Systems} +The following are software systems similar to Thrift. Each is (very!) briefly +described: + +\begin{itemize} +\item \textit{SOAP.} XML-based. Designed for web services via HTTP, excessive +XML parsing overhead. +\item \textit{CORBA.} Relatively comprehensive, debatably overdesigned and +heavyweight. Comparably cumbersome software installation. +\item \textit{COM.} Embraced mainly in Windows client softare. Not an entirely +open solution. +\item \textit{Pillar.} Lightweight and high-performance, but missing versioning +and abstraction. +\item \textit{Protocol Buffers.} Closed-source, owned by Google. Described in +Sawzall paper. +\end{itemize} + +\acks + +Many thanks for feedback on Thrift (and extreme trial by fire) are due to +Martin Smith, Karl Voskuil and Yishan Wong. + +Thrift is a successor to Pillar, a similar system developed +by Adam D'Angelo, first while at Caltech and continued later at Facebook. +Thrift simply would not have happened without Adam's insights. + +\begin{thebibliography}{} + +\bibitem{boost.threads} +Kempf, William, +``Boost.Threads'', +\url{http://www.boost.org/doc/html/threads.html} + +\bibitem{boost.threadpool} +Henkel, Philipp, +``threadpool'', +\url{http://threadpool.sourceforge.net} + +\end{thebibliography} + +\end{document} diff --git a/lib/Makefile.am b/lib/Makefile.am new file mode 100644 index 00000000..3558dd80 --- /dev/null +++ b/lib/Makefile.am @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +SUBDIRS = \ + cpp + +if WITH_MONO +SUBDIRS += csharp +endif + +if WITH_JAVA +SUBDIRS += java +endif + +if WITH_PYTHON +SUBDIRS += py +endif + +if WITH_ERLANG +SUBDIRS += erl +endif + +if WITH_RUBY +SUBDIRS += rb +endif + +if WITH_PERL +SUBDIRS += perl +endif + +# All of the libs that don't use Automake need to go in here +# so they will end up in our release tarballs. +EXTRA_DIST = \ + cocoa \ + hs \ + ocaml \ + php \ + erl \ + st diff --git a/lib/cocoa/README b/lib/cocoa/README new file mode 100644 index 00000000..bbe3c934 --- /dev/null +++ b/lib/cocoa/README @@ -0,0 +1,21 @@ +Thrift Cocoa Software Library + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. diff --git a/lib/cocoa/src/TApplicationException.h b/lib/cocoa/src/TApplicationException.h new file mode 100644 index 00000000..cf1641d9 --- /dev/null +++ b/lib/cocoa/src/TApplicationException.h @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TException.h" +#import "TProtocol.h" + +enum { + TApplicationException_UNKNOWN = 0, + TApplicationException_UNKNOWN_METHOD = 1, + TApplicationException_INVALID_MESSAGE_TYPE = 2, + TApplicationException_WRONG_METHOD_NAME = 3, + TApplicationException_BAD_SEQUENCE_ID = 4, + TApplicationException_MISSING_RESULT = 5 +}; + +// FIXME +@interface TApplicationException : TException { + int mType; +} + ++ (TApplicationException *) read: (id ) protocol; + +- (void) write: (id ) protocol; + ++ (TApplicationException *) exceptionWithType: (int) type + reason: (NSString *) message; + +@end diff --git a/lib/cocoa/src/TApplicationException.m b/lib/cocoa/src/TApplicationException.m new file mode 100644 index 00000000..70687537 --- /dev/null +++ b/lib/cocoa/src/TApplicationException.m @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TApplicationException.h" +#import "TProtocolUtil.h" + +@implementation TApplicationException + +- (id) initWithType: (int) type + reason: (NSString *) reason +{ + mType = type; + + NSString * name; + switch (type) { + case TApplicationException_UNKNOWN_METHOD: + name = @"Unknown method"; + break; + case TApplicationException_INVALID_MESSAGE_TYPE: + name = @"Invalid message type"; + break; + case TApplicationException_WRONG_METHOD_NAME: + name = @"Wrong method name"; + break; + case TApplicationException_BAD_SEQUENCE_ID: + name = @"Bad sequence ID"; + break; + case TApplicationException_MISSING_RESULT: + name = @"Missing result"; + break; + default: + name = @"Unknown"; + break; + } + + self = [super initWithName: name reason: reason userInfo: nil]; + return self; +} + + ++ (TApplicationException *) read: (id ) protocol +{ + NSString * reason = nil; + int type = TApplicationException_UNKNOWN; + int fieldType; + int fieldID; + + [protocol readStructBeginReturningName: NULL]; + + while (true) { + [protocol readFieldBeginReturningName: NULL + type: &fieldType + fieldID: &fieldID]; + if (fieldType == TType_STOP) { + break; + } + switch (fieldID) { + case 1: + if (fieldType == TType_STRING) { + reason = [protocol readString]; + } else { + [TProtocolUtil skipType: fieldType onProtocol: protocol]; + } + break; + case 2: + if (fieldType == TType_I32) { + type = [protocol readI32]; + } else { + [TProtocolUtil skipType: fieldType onProtocol: protocol]; + } + break; + default: + [TProtocolUtil skipType: fieldType onProtocol: protocol]; + break; + } + [protocol readFieldEnd]; + } + [protocol readStructEnd]; + + return [TApplicationException exceptionWithType: type reason: reason]; +} + + +- (void) write: (id ) protocol +{ + [protocol writeStructBeginWithName: @"TApplicationException"]; + + if ([self reason] != nil) { + [protocol writeFieldBeginWithName: @"message" + type: TType_STRING + fieldID: 1]; + [protocol writeString: [self reason]]; + [protocol writeFieldEnd]; + } + + [protocol writeFieldBeginWithName: @"type" + type: TType_I32 + fieldID: 2]; + [protocol writeI32: mType]; + [protocol writeFieldEnd]; + + [protocol writeFieldStop]; + [protocol writeStructEnd]; +} + + ++ (TApplicationException *) exceptionWithType: (int) type + reason: (NSString *) reason +{ + return [[[TApplicationException alloc] initWithType: type + reason: reason] autorelease]; +} + +@end diff --git a/lib/cocoa/src/TException.h b/lib/cocoa/src/TException.h new file mode 100644 index 00000000..b069a868 --- /dev/null +++ b/lib/cocoa/src/TException.h @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import + +@interface TException : NSException { +} + ++ (id) exceptionWithName: (NSString *) name; + ++ (id) exceptionWithName: (NSString *) name + reason: (NSString *) reason; + ++ (id) exceptionWithName: (NSString *) name + reason: (NSString *) reason + error: (NSError *) error; + +@end diff --git a/lib/cocoa/src/TException.m b/lib/cocoa/src/TException.m new file mode 100644 index 00000000..7c84199d --- /dev/null +++ b/lib/cocoa/src/TException.m @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TException.h" + +@implementation TException + ++ (id) exceptionWithName: (NSString *) name +{ + return [self exceptionWithName: name reason: @"unknown" error: nil]; +} + + ++ (id) exceptionWithName: (NSString *) name + reason: (NSString *) reason +{ + return [self exceptionWithName: name reason: reason error: nil]; +} + + ++ (id) exceptionWithName: (NSString *) name + reason: (NSString *) reason + error: (NSError *) error +{ + NSDictionary * userInfo = nil; + if (error != nil) { + userInfo = [NSDictionary dictionaryWithObject: error forKey: @"error"]; + } + + return [super exceptionWithName: name + reason: reason + userInfo: userInfo]; +} + + +- (NSString *) description +{ + NSMutableString * result = [NSMutableString stringWithString: [self name]]; + [result appendFormat: @": %@", [self reason]]; + if ([self userInfo] != nil) { + [result appendFormat: @"\n userInfo = %@", [self userInfo]]; + } + + return result; +} + + +@end diff --git a/lib/cocoa/src/TProcessor.h b/lib/cocoa/src/TProcessor.h new file mode 100644 index 00000000..f8df225e --- /dev/null +++ b/lib/cocoa/src/TProcessor.h @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import + + +@protocol TProcessor + +- (BOOL) processOnInputProtocol: (id ) inProtocol + outputProtocol: (id ) outProtocol; + +@end diff --git a/lib/cocoa/src/protocol/TBinaryProtocol.h b/lib/cocoa/src/protocol/TBinaryProtocol.h new file mode 100644 index 00000000..52cf2669 --- /dev/null +++ b/lib/cocoa/src/protocol/TBinaryProtocol.h @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TProtocol.h" +#import "TTransport.h" +#import "TProtocolFactory.h" + + +@interface TBinaryProtocol : NSObject { + id mTransport; + BOOL mStrictRead; + BOOL mStrictWrite; + int32_t mMessageSizeLimit; +} + +- (id) initWithTransport: (id ) transport; + +- (id) initWithTransport: (id ) transport + strictRead: (BOOL) strictRead + strictWrite: (BOOL) strictWrite; + +- (int32_t) messageSizeLimit; +- (void) setMessageSizeLimit: (int32_t) sizeLimit; + +@end; + + +@interface TBinaryProtocolFactory : NSObject { +} + ++ (TBinaryProtocolFactory *) sharedFactory; + +- (TBinaryProtocol *) newProtocolOnTransport: (id ) transport; + +@end diff --git a/lib/cocoa/src/protocol/TBinaryProtocol.m b/lib/cocoa/src/protocol/TBinaryProtocol.m new file mode 100644 index 00000000..ba7f4629 --- /dev/null +++ b/lib/cocoa/src/protocol/TBinaryProtocol.m @@ -0,0 +1,469 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TBinaryProtocol.h" +#import "TProtocolException.h" + +int32_t VERSION_1 = 0x80010000; +int32_t VERSION_MASK = 0xffff0000; + + +static TBinaryProtocolFactory * gSharedFactory = nil; + +@implementation TBinaryProtocolFactory + ++ (TBinaryProtocolFactory *) sharedFactory { + if (gSharedFactory == nil) { + gSharedFactory = [[TBinaryProtocolFactory alloc] init]; + } + + return gSharedFactory; +} + +- (TBinaryProtocol *) newProtocolOnTransport: (id ) transport { + return [[[TBinaryProtocol alloc] initWithTransport: transport] autorelease]; +} + +@end + + + +@implementation TBinaryProtocol + +- (id) initWithTransport: (id ) transport +{ + return [self initWithTransport: transport strictRead: NO strictWrite: NO]; +} + +- (id) initWithTransport: (id ) transport + strictRead: (BOOL) strictRead + strictWrite: (BOOL) strictWrite +{ + self = [super init]; + mTransport = [transport retain]; + mStrictRead = strictRead; + mStrictWrite = strictWrite; + return self; +} + + +- (int32_t) messageSizeLimit +{ + return mMessageSizeLimit; +} + + +- (void) setMessageSizeLimit: (int32_t) sizeLimit +{ + mMessageSizeLimit = sizeLimit; +} + + +- (void) dealloc +{ + [mTransport release]; + [super dealloc]; +} + + +- (id ) transport +{ + return mTransport; +} + + +- (NSString *) readStringBody: (int) size +{ + char buff[size+1]; + [mTransport readAll: (uint8_t *) buff offset: 0 length: size]; + buff[size] = 0; + return [NSString stringWithUTF8String: buff]; +} + + +- (void) readMessageBeginReturningName: (NSString **) name + type: (int *) type + sequenceID: (int *) sequenceID +{ + int32_t size = [self readI32]; + if (size < 0) { + int version = size & VERSION_MASK; + if (version != VERSION_1) { + @throw [TProtocolException exceptionWithName: @"TProtocolException" + reason: @"Bad version in readMessageBegin"]; + } + if (type != NULL) { + *type = version & 0x00FF; + } + NSString * messageName = [self readString]; + if (name != NULL) { + *name = messageName; + } + int seqID = [self readI32]; + if (sequenceID != NULL) { + *sequenceID = seqID; + } + } else { + if (mStrictRead) { + @throw [TProtocolException exceptionWithName: @"TProtocolException" + reason: @"Missing version in readMessageBegin, old client?"]; + } + if ([self messageSizeLimit] > 0 && size > [self messageSizeLimit]) { + @throw [TProtocolException exceptionWithName: @"TProtocolException" + reason: [NSString stringWithFormat: @"Message too big. Size limit is: %d Message size is: %d", + mMessageSizeLimit, + size]]; + } + NSString * messageName = [self readStringBody: size]; + if (name != NULL) { + *name = messageName; + } + int messageType = [self readByte]; + if (type != NULL) { + *type = messageType; + } + int seqID = [self readI32]; + if (sequenceID != NULL) { + *sequenceID = seqID; + } + } +} + + +- (void) readMessageEnd {} + + +- (void) readStructBeginReturningName: (NSString **) name +{ + if (name != NULL) { + *name = nil; + } +} + + +- (void) readStructEnd {} + + +- (void) readFieldBeginReturningName: (NSString **) name + type: (int *) fieldType + fieldID: (int *) fieldID +{ + if (name != NULL) { + *name = nil; + } + int ft = [self readByte]; + if (fieldType != NULL) { + *fieldType = ft; + } + if (ft != TType_STOP) { + int fid = [self readI16]; + if (fieldID != NULL) { + *fieldID = fid; + } + } +} + + +- (void) readFieldEnd {} + + +- (int32_t) readI32 +{ + uint8_t i32rd[4]; + [mTransport readAll: i32rd offset: 0 length: 4]; + return + ((i32rd[0] & 0xff) << 24) | + ((i32rd[1] & 0xff) << 16) | + ((i32rd[2] & 0xff) << 8) | + ((i32rd[3] & 0xff)); +} + + +- (NSString *) readString +{ + int size = [self readI32]; + return [self readStringBody: size]; +} + + +- (BOOL) readBool +{ + return [self readByte] == 1; +} + +- (uint8_t) readByte +{ + uint8_t myByte; + [mTransport readAll: &myByte offset: 0 length: 1]; + return myByte; +} + +- (short) readI16 +{ + uint8_t buff[2]; + [mTransport readAll: buff offset: 0 length: 2]; + return (short) + (((buff[0] & 0xff) << 8) | + ((buff[1] & 0xff))); + return 0; +} + +- (int64_t) readI64; +{ + uint8_t i64rd[8]; + [mTransport readAll: i64rd offset: 0 length: 8]; + return + ((int64_t)(i64rd[0] & 0xff) << 56) | + ((int64_t)(i64rd[1] & 0xff) << 48) | + ((int64_t)(i64rd[2] & 0xff) << 40) | + ((int64_t)(i64rd[3] & 0xff) << 32) | + ((int64_t)(i64rd[4] & 0xff) << 24) | + ((int64_t)(i64rd[5] & 0xff) << 16) | + ((int64_t)(i64rd[6] & 0xff) << 8) | + ((int64_t)(i64rd[7] & 0xff)); +} + +- (double) readDouble; +{ + // FIXME - will this get us into trouble on PowerPC? + int64_t ieee754 = [self readI64]; + return *((double *) &ieee754); +} + + +- (NSData *) readBinary +{ + int32_t size = [self readI32]; + uint8_t * buff = malloc(size); + if (buff == NULL) { + @throw [TProtocolException + exceptionWithName: @"TProtocolException" + reason: [NSString stringWithFormat: @"Out of memory. Unable to allocate %d bytes trying to read binary data.", + size]]; + } + [mTransport readAll: buff offset: 0 length: size]; + return [NSData dataWithBytesNoCopy: buff length: size]; +} + + +- (void) readMapBeginReturningKeyType: (int *) keyType + valueType: (int *) valueType + size: (int *) size +{ + int kt = [self readByte]; + int vt = [self readByte]; + int s = [self readI32]; + if (keyType != NULL) { + *keyType = kt; + } + if (valueType != NULL) { + *valueType = vt; + } + if (size != NULL) { + *size = s; + } +} + +- (void) readMapEnd {} + + +- (void) readSetBeginReturningElementType: (int *) elementType + size: (int *) size +{ + int et = [self readByte]; + int s = [self readI32]; + if (elementType != NULL) { + *elementType = et; + } + if (size != NULL) { + *size = s; + } +} + + +- (void) readSetEnd {} + + +- (void) readListBeginReturningElementType: (int *) elementType + size: (int *) size +{ + int et = [self readByte]; + int s = [self readI32]; + if (elementType != NULL) { + *elementType = et; + } + if (size != NULL) { + *size = s; + } +} + + +- (void) readListEnd {} + + +- (void) writeByte: (uint8_t) value +{ + [mTransport write: &value offset: 0 length: 1]; +} + + +- (void) writeMessageBeginWithName: (NSString *) name + type: (int) messageType + sequenceID: (int) sequenceID +{ + if (mStrictWrite) { + int version = VERSION_1 | messageType; + [self writeI32: version]; + [self writeString: name]; + [self writeI32: sequenceID]; + } else { + [self writeString: name]; + [self writeByte: messageType]; + [self writeI32: sequenceID]; + } +} + + +- (void) writeMessageEnd {} + + +- (void) writeStructBeginWithName: (NSString *) name {} + + +- (void) writeStructEnd {} + + +- (void) writeFieldBeginWithName: (NSString *) name + type: (int) fieldType + fieldID: (int) fieldID +{ + [self writeByte: fieldType]; + [self writeI16: fieldID]; +} + + +- (void) writeI32: (int32_t) value +{ + uint8_t buff[4]; + buff[0] = 0xFF & (value >> 24); + buff[1] = 0xFF & (value >> 16); + buff[2] = 0xFF & (value >> 8); + buff[3] = 0xFF & value; + [mTransport write: buff offset: 0 length: 4]; +} + +- (void) writeI16: (short) value +{ + uint8_t buff[2]; + buff[0] = 0xff & (value >> 8); + buff[1] = 0xff & value; + [mTransport write: buff offset: 0 length: 2]; +} + + +- (void) writeI64: (int64_t) value +{ + uint8_t buff[8]; + buff[0] = 0xFF & (value >> 56); + buff[1] = 0xFF & (value >> 48); + buff[2] = 0xFF & (value >> 40); + buff[3] = 0xFF & (value >> 32); + buff[4] = 0xFF & (value >> 24); + buff[5] = 0xFF & (value >> 16); + buff[6] = 0xFF & (value >> 8); + buff[7] = 0xFF & value; + [mTransport write: buff offset: 0 length: 8]; +} + +- (void) writeDouble: (double) value +{ + // spit out IEEE 754 bits - FIXME - will this get us in trouble on + // PowerPC? + [self writeI64: *((int64_t *) &value)]; +} + + +- (void) writeString: (NSString *) value +{ + if (value != nil) { + const char * utf8Bytes = [value UTF8String]; + size_t length = strlen(utf8Bytes); + [self writeI32: length]; + [mTransport write: (uint8_t *) utf8Bytes offset: 0 length: length]; + } else { + // instead of crashing when we get null, let's write out a zero + // length string + [self writeI32: 0]; + } +} + + +- (void) writeBinary: (NSData *) data +{ + [self writeI32: [data length]]; + [mTransport write: [data bytes] offset: 0 length: [data length]]; +} + +- (void) writeFieldStop +{ + [self writeByte: TType_STOP]; +} + + +- (void) writeFieldEnd {} + + +- (void) writeMapBeginWithKeyType: (int) keyType + valueType: (int) valueType + size: (int) size +{ + [self writeByte: keyType]; + [self writeByte: valueType]; + [self writeI32: size]; +} + +- (void) writeMapEnd {} + + +- (void) writeSetBeginWithElementType: (int) elementType + size: (int) size +{ + [self writeByte: elementType]; + [self writeI32: size]; +} + +- (void) writeSetEnd {} + + +- (void) writeListBeginWithElementType: (int) elementType + size: (int) size +{ + [self writeByte: elementType]; + [self writeI32: size]; +} + +- (void) writeListEnd {} + + +- (void) writeBool: (BOOL) value +{ + [self writeByte: (value ? 1 : 0)]; +} + +@end diff --git a/lib/cocoa/src/protocol/TProtocol.h b/lib/cocoa/src/protocol/TProtocol.h new file mode 100644 index 00000000..cc8cdb4b --- /dev/null +++ b/lib/cocoa/src/protocol/TProtocol.h @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import + +#import "TTransport.h" + + +enum { + TMessageType_CALL = 1, + TMessageType_REPLY = 2, + TMessageType_EXCEPTION = 3, + TMessageType_ONEWAY = 4 +}; + +enum { + TType_STOP = 0, + TType_VOID = 1, + TType_BOOL = 2, + TType_BYTE = 3, + TType_DOUBLE = 4, + TType_I16 = 6, + TType_I32 = 8, + TType_I64 = 10, + TType_STRING = 11, + TType_STRUCT = 12, + TType_MAP = 13, + TType_SET = 14, + TType_LIST = 15 +}; + + +@protocol TProtocol + +- (id ) transport; + +- (void) readMessageBeginReturningName: (NSString **) name + type: (int *) type + sequenceID: (int *) sequenceID; +- (void) readMessageEnd; + +- (void) readStructBeginReturningName: (NSString **) name; +- (void) readStructEnd; + +- (void) readFieldBeginReturningName: (NSString **) name + type: (int *) fieldType + fieldID: (int *) fieldID; +- (void) readFieldEnd; + +- (NSString *) readString; + +- (BOOL) readBool; + +- (unsigned char) readByte; + +- (short) readI16; + +- (int32_t) readI32; + +- (int64_t) readI64; + +- (double) readDouble; + +- (NSData *) readBinary; + +- (void) readMapBeginReturningKeyType: (int *) keyType + valueType: (int *) valueType + size: (int *) size; +- (void) readMapEnd; + + +- (void) readSetBeginReturningElementType: (int *) elementType + size: (int *) size; +- (void) readSetEnd; + + +- (void) readListBeginReturningElementType: (int *) elementType + size: (int *) size; +- (void) readListEnd; + + +- (void) writeMessageBeginWithName: (NSString *) name + type: (int) messageType + sequenceID: (int) sequenceID; +- (void) writeMessageEnd; + +- (void) writeStructBeginWithName: (NSString *) name; +- (void) writeStructEnd; + +- (void) writeFieldBeginWithName: (NSString *) name + type: (int) fieldType + fieldID: (int) fieldID; + +- (void) writeI32: (int32_t) value; + +- (void) writeI64: (int64_t) value; + +- (void) writeI16: (short) value; + +- (void) writeByte: (uint8_t) value; + +- (void) writeString: (NSString *) value; + +- (void) writeDouble: (double) value; + +- (void) writeBool: (BOOL) value; + +- (void) writeBinary: (NSData *) data; + +- (void) writeFieldStop; + +- (void) writeFieldEnd; + +- (void) writeMapBeginWithKeyType: (int) keyType + valueType: (int) valueType + size: (int) size; +- (void) writeMapEnd; + + +- (void) writeSetBeginWithElementType: (int) elementType + size: (int) size; +- (void) writeSetEnd; + + +- (void) writeListBeginWithElementType: (int) elementType + size: (int) size; + +- (void) writeListEnd; + + +@end + diff --git a/lib/cocoa/src/protocol/TProtocolException.h b/lib/cocoa/src/protocol/TProtocolException.h new file mode 100644 index 00000000..ad354fc2 --- /dev/null +++ b/lib/cocoa/src/protocol/TProtocolException.h @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TException.h" + +@interface TProtocolException : TException { +} + +@end diff --git a/lib/cocoa/src/protocol/TProtocolException.m b/lib/cocoa/src/protocol/TProtocolException.m new file mode 100644 index 00000000..681487a4 --- /dev/null +++ b/lib/cocoa/src/protocol/TProtocolException.m @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TProtocolException.h" + +@implementation TProtocolException +@end diff --git a/lib/cocoa/src/protocol/TProtocolFactory.h b/lib/cocoa/src/protocol/TProtocolFactory.h new file mode 100644 index 00000000..2d125e96 --- /dev/null +++ b/lib/cocoa/src/protocol/TProtocolFactory.h @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import +#import "TProtocol.h" +#import "TTransport.h" + + +@protocol TProtocolFactory + +- (id ) newProtocolOnTransport: (id ) transport; + +@end diff --git a/lib/cocoa/src/protocol/TProtocolUtil.h b/lib/cocoa/src/protocol/TProtocolUtil.h new file mode 100644 index 00000000..c2d2521c --- /dev/null +++ b/lib/cocoa/src/protocol/TProtocolUtil.h @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TProtocol.h" +#import "TTransport.h" + +@interface TProtocolUtil : NSObject { + +} + ++ (void) skipType: (int) type onProtocol: (id ) protocol; + +@end; diff --git a/lib/cocoa/src/protocol/TProtocolUtil.m b/lib/cocoa/src/protocol/TProtocolUtil.m new file mode 100644 index 00000000..13d70954 --- /dev/null +++ b/lib/cocoa/src/protocol/TProtocolUtil.m @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TProtocolUtil.h" + +@implementation TProtocolUtil + ++ (void) skipType: (int) type onProtocol: (id ) protocol +{ + switch (type) { + case TType_BOOL: + [protocol readBool]; + break; + case TType_BYTE: + [protocol readByte]; + break; + case TType_I16: + [protocol readI16]; + break; + case TType_I32: + [protocol readI32]; + break; + case TType_I64: + [protocol readI64]; + break; + case TType_DOUBLE: + [protocol readDouble]; + break; + case TType_STRING: + [protocol readString]; + break; + case TType_STRUCT: + [protocol readStructBeginReturningName: NULL]; + while (true) { + int fieldType; + [protocol readFieldBeginReturningName: nil type: &fieldType fieldID: nil]; + if (fieldType == TType_STOP) { + break; + } + [TProtocolUtil skipType: fieldType onProtocol: protocol]; + [protocol readFieldEnd]; + } + [protocol readStructEnd]; + break; + case TType_MAP: + { + int keyType; + int valueType; + int size; + [protocol readMapBeginReturningKeyType: &keyType valueType: &valueType size: &size]; + int i; + for (i = 0; i < size; i++) { + [TProtocolUtil skipType: keyType onProtocol: protocol]; + [TProtocolUtil skipType: valueType onProtocol: protocol]; + } + [protocol readMapEnd]; + } + break; + case TType_SET: + { + int elemType; + int size; + [protocol readSetBeginReturningElementType: &elemType size: &size]; + int i; + for (i = 0; i < size; i++) { + [TProtocolUtil skipType: elemType onProtocol: protocol]; + } + [protocol readSetEnd]; + } + break; + case TType_LIST: + { + int elemType; + int size; + [protocol readListBeginReturningElementType: &elemType size: &size]; + int i; + for (i = 0; i < size; i++) { + [TProtocolUtil skipType: elemType onProtocol: protocol]; + } + [protocol readListEnd]; + } + break; + default: + return; + } +} + +@end diff --git a/lib/cocoa/src/server/TSocketServer.h b/lib/cocoa/src/server/TSocketServer.h new file mode 100644 index 00000000..3d4a9e0e --- /dev/null +++ b/lib/cocoa/src/server/TSocketServer.h @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import +#import "TProtocolFactory.h" +#import "TProcessor.h" + + +@interface TSocketServer : NSObject { + NSSocketPort * mServerSocket; + NSFileHandle * mSocketFileHandle; + id mInputProtocolFactory; + id mOutputProtocolFactory; + id mProcessor; +} + +- (id) initWithPort: (int) port + protocolFactory: (id ) protocolFactory + processor: (id ) processor; + +@end + + + diff --git a/lib/cocoa/src/server/TSocketServer.m b/lib/cocoa/src/server/TSocketServer.m new file mode 100644 index 00000000..97d8bae7 --- /dev/null +++ b/lib/cocoa/src/server/TSocketServer.m @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import +#import "TSocketServer.h" +#import "TNSFileHandleTransport.h" +#import "TProtocol.h" +#import "TTransportException.h" + + +@implementation TSocketServer + +- (id) initWithPort: (int) port + protocolFactory: (id ) protocolFactory + processor: (id ) processor; +{ + self = [super init]; + + mInputProtocolFactory = [protocolFactory retain]; + mOutputProtocolFactory = [protocolFactory retain]; + mProcessor = [processor retain]; + + // create a socket + mServerSocket = [[NSSocketPort alloc] initWithTCPPort: port]; + // FIXME - move this separate start method and add method to close + // and cleanup any open ports + + if (mServerSocket == nil) { + NSLog(@"Unable to listen on TCP port %d", port); + } else { + NSLog(@"Listening on TCP port %d", port); + + // wrap it in a file handle so we can get messages from it + mSocketFileHandle = [[NSFileHandle alloc] initWithFileDescriptor: [mServerSocket socket] + closeOnDealloc: YES]; + + // register for notifications of accepted incoming connections + [[NSNotificationCenter defaultCenter] addObserver: self + selector: @selector(connectionAccepted:) + name: NSFileHandleConnectionAcceptedNotification + object: mSocketFileHandle]; + + // tell socket to listen + [mSocketFileHandle acceptConnectionInBackgroundAndNotify]; + } + + return self; +} + + +- (void) dealloc { + [mInputProtocolFactory release]; + [mOutputProtocolFactory release]; + [mProcessor release]; + [mSocketFileHandle release]; + [mServerSocket release]; + [super dealloc]; +} + + +- (void) connectionAccepted: (NSNotification *) aNotification +{ + NSFileHandle * socket = [[aNotification userInfo] objectForKey: NSFileHandleNotificationFileHandleItem]; + + // now that we have a client connected, spin off a thread to handle activity + [NSThread detachNewThreadSelector: @selector(handleClientConnection:) + toTarget: self + withObject: socket]; + + [[aNotification object] acceptConnectionInBackgroundAndNotify]; +} + + +- (void) handleClientConnection: (NSFileHandle *) clientSocket +{ + NSAutoreleasePool * pool = [[NSAutoreleasePool alloc] init]; + + TNSFileHandleTransport * transport = [[TNSFileHandleTransport alloc] initWithFileHandle: clientSocket]; + + id inProtocol = [mInputProtocolFactory newProtocolOnTransport: transport]; + id outProtocol = [mOutputProtocolFactory newProtocolOnTransport: transport]; + + @try { + while ([mProcessor processOnInputProtocol: inProtocol outputProtocol: outProtocol]); + } + @catch (TTransportException * te) { + NSLog(@"%@", te); + } + + [pool release]; +} + + + +@end + + + diff --git a/lib/cocoa/src/transport/THTTPClient.h b/lib/cocoa/src/transport/THTTPClient.h new file mode 100644 index 00000000..86f3f054 --- /dev/null +++ b/lib/cocoa/src/transport/THTTPClient.h @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import +#import "TTransport.h" + +@interface THTTPClient : NSObject { + NSURL * mURL; + NSMutableURLRequest * mRequest; + NSMutableData * mRequestData; + NSData * mResponseData; + int mResponseDataOffset; + NSString * mUserAgent; + int mTimeout; +} + +- (id) initWithURL: (NSURL *) aURL; + +- (id) initWithURL: (NSURL *) aURL + userAgent: (NSString *) userAgent + timeout: (int) timeout; + +- (void) setURL: (NSURL *) aURL; + +@end + diff --git a/lib/cocoa/src/transport/THTTPClient.m b/lib/cocoa/src/transport/THTTPClient.m new file mode 100644 index 00000000..6391bead --- /dev/null +++ b/lib/cocoa/src/transport/THTTPClient.m @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "THTTPClient.h" +#import "TTransportException.h" + +@implementation THTTPClient + + +- (void) setupRequest +{ + if (mRequest != nil) { + [mRequest release]; + } + + // set up our request object that we'll use for each request + mRequest = [[NSMutableURLRequest alloc] initWithURL: mURL]; + [mRequest setHTTPMethod: @"POST"]; + [mRequest setValue: @"application/x-thrift" forHTTPHeaderField: @"Content-Type"]; + [mRequest setValue: @"application/x-thrift" forHTTPHeaderField: @"Accept"]; + + NSString * userAgent = mUserAgent; + if (!userAgent) { + userAgent = @"Cocoa/THTTPClient"; + } + [mRequest setValue: userAgent forHTTPHeaderField: @"User-Agent"]; + + [mRequest setCachePolicy: NSURLRequestReloadIgnoringCacheData]; + if (mTimeout) { + [mRequest setTimeoutInterval: mTimeout]; + } +} + + +- (id) initWithURL: (NSURL *) aURL +{ + return [self initWithURL: aURL + userAgent: nil + timeout: 0]; +} + + +- (id) initWithURL: (NSURL *) aURL + userAgent: (NSString *) userAgent + timeout: (int) timeout +{ + self = [super init]; + if (!self) { + return nil; + } + + mTimeout = timeout; + if (userAgent) { + mUserAgent = [userAgent retain]; + } + mURL = [aURL retain]; + + [self setupRequest]; + + // create our request data buffer + mRequestData = [[NSMutableData alloc] initWithCapacity: 1024]; + + return self; +} + + +- (void) setURL: (NSURL *) aURL +{ + [aURL retain]; + [mURL release]; + mURL = aURL; + + [self setupRequest]; +} + + +- (void) dealloc +{ + [mURL release]; + [mUserAgent release]; + [mRequest release]; + [mRequestData release]; + [mResponseData release]; + [super dealloc]; +} + + +- (int) readAll: (uint8_t *) buf offset: (int) off length: (int) len +{ + NSRange r; + r.location = mResponseDataOffset; + r.length = len; + + [mResponseData getBytes: buf+off range: r]; + mResponseDataOffset += len; + + return len; +} + + +- (void) write: (const uint8_t *) data offset: (unsigned int) offset length: (unsigned int) length +{ + [mRequestData appendBytes: data+offset length: length]; +} + + +- (void) flush +{ + [mRequest setHTTPBody: mRequestData]; // not sure if it copies the data + + // make the HTTP request + NSURLResponse * response; + NSError * error; + NSData * responseData = + [NSURLConnection sendSynchronousRequest: mRequest returningResponse: &response error: &error]; + + [mRequestData setLength: 0]; + + if (responseData == nil) { + @throw [TTransportException exceptionWithName: @"TTransportException" + reason: @"Could not make HTTP request" + error: error]; + } + if (![response isKindOfClass: [NSHTTPURLResponse class]]) { + @throw [TTransportException exceptionWithName: @"TTransportException" + reason: @"Unexpected NSURLResponse type"]; + } + + NSHTTPURLResponse * httpResponse = (NSHTTPURLResponse *) response; + if ([httpResponse statusCode] != 200) { + @throw [TTransportException exceptionWithName: @"TTransportException" + reason: [NSString stringWithFormat: @"Bad response from HTTP server: %d", + [httpResponse statusCode]]]; + } + + // phew! + [mResponseData release]; + mResponseData = [responseData retain]; + mResponseDataOffset = 0; +} + + +@end diff --git a/lib/cocoa/src/transport/TNSFileHandleTransport.h b/lib/cocoa/src/transport/TNSFileHandleTransport.h new file mode 100644 index 00000000..64a6af3c --- /dev/null +++ b/lib/cocoa/src/transport/TNSFileHandleTransport.h @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +#import +#import "TTransport.h" + +@interface TNSFileHandleTransport : NSObject { + NSFileHandle * mInputFileHandle; + NSFileHandle * mOutputFileHandle; +} + +- (id) initWithFileHandle: (NSFileHandle *) fileHandle; + +- (id) initWithInputFileHandle: (NSFileHandle *) inputFileHandle + outputFileHandle: (NSFileHandle *) outputFileHandle; + + +@end diff --git a/lib/cocoa/src/transport/TNSFileHandleTransport.m b/lib/cocoa/src/transport/TNSFileHandleTransport.m new file mode 100644 index 00000000..15339341 --- /dev/null +++ b/lib/cocoa/src/transport/TNSFileHandleTransport.m @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +#import "TNSFileHandleTransport.h" +#import "TTransportException.h" + + +@implementation TNSFileHandleTransport + +- (id) initWithFileHandle: (NSFileHandle *) fileHandle +{ + return [self initWithInputFileHandle: fileHandle + outputFileHandle: fileHandle]; +} + + +- (id) initWithInputFileHandle: (NSFileHandle *) inputFileHandle + outputFileHandle: (NSFileHandle *) outputFileHandle +{ + self = [super init]; + + mInputFileHandle = [inputFileHandle retain]; + mOutputFileHandle = [outputFileHandle retain]; + + return self; +} + + +- (void) dealloc { + [mInputFileHandle release]; + [mOutputFileHandle release]; + [super dealloc]; +} + + +- (int) readAll: (uint8_t *) buf offset: (int) off length: (int) len +{ + int got = 0; + while (got < len) { + NSData * d = [mInputFileHandle readDataOfLength: len-got]; + if ([d length] == 0) { + @throw [TTransportException exceptionWithName: @"TTransportException" + reason: @"Cannot read. No more data."]; + } + [d getBytes: buf+got]; + got += [d length]; + } + return got; +} + + +- (void) write: (uint8_t *) data offset: (unsigned int) offset length: (unsigned int) length +{ + NSData * dataObject = [[NSData alloc] initWithBytesNoCopy: data+offset + length: length + freeWhenDone: NO]; + + [mOutputFileHandle writeData: dataObject]; + + + [dataObject release]; +} + + +- (void) flush +{ + +} + +@end diff --git a/lib/cocoa/src/transport/TNSStreamTransport.h b/lib/cocoa/src/transport/TNSStreamTransport.h new file mode 100644 index 00000000..295a185c --- /dev/null +++ b/lib/cocoa/src/transport/TNSStreamTransport.h @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import +#import "TTransport.h" + +@interface TNSStreamTransport : NSObject { + NSInputStream * mInput; + NSOutputStream * mOutput; +} + +- (id) initWithInputStream: (NSInputStream *) input + outputStream: (NSOutputStream *) output; + +- (id) initWithInputStream: (NSInputStream *) input; + +- (id) initWithOutputStream: (NSOutputStream *) output; + +@end + + + diff --git a/lib/cocoa/src/transport/TNSStreamTransport.m b/lib/cocoa/src/transport/TNSStreamTransport.m new file mode 100644 index 00000000..52a02e27 --- /dev/null +++ b/lib/cocoa/src/transport/TNSStreamTransport.m @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TNSStreamTransport.h" +#import "TTransportException.h" + + +@implementation TNSStreamTransport + +- (id) initWithInputStream: (NSInputStream *) input + outputStream: (NSOutputStream *) output +{ + [super init]; + mInput = [input retain]; + mOutput = [output retain]; + return self; +} + +- (id) initWithInputStream: (NSInputStream *) input +{ + return [self initWithInputStream: input outputStream: nil]; +} + +- (id) initWithOutputStream: (NSOutputStream *) output +{ + return [self initWithInputStream: nil outputStream: output]; +} + +- (void) dealloc +{ + [mInput release]; + [mOutput release]; + [super dealloc]; +} + + +- (int) readAll: (uint8_t *) buf offset: (int) off length: (int) len +{ + int got = 0; + int ret = 0; + while (got < len) { + ret = [mInput read: buf+off+got maxLength: len-got]; + if (ret <= 0) { + @throw [TTransportException exceptionWithReason: @"Cannot read. Remote side has closed."]; + } + got += ret; + } + return got; +} + + +// FIXME:geech:20071019 - make this write all +- (void) write: (uint8_t *) data offset: (unsigned int) offset length: (unsigned int) length +{ + int result = [mOutput write: data+offset maxLength: length]; + if (result == -1) { + @throw [TTransportException exceptionWithReason: @"Error writing to transport output stream." + error: [mOutput streamError]]; + } else if (result == 0) { + @throw [TTransportException exceptionWithReason: @"End of output stream."]; + } else if (result != length) { + @throw [TTransportException exceptionWithReason: @"Output stream did not write all of our data."]; + } +} + +- (void) flush +{ + // no flush for you! +} + +@end diff --git a/lib/cocoa/src/transport/TSocketClient.h b/lib/cocoa/src/transport/TSocketClient.h new file mode 100644 index 00000000..a883acbb --- /dev/null +++ b/lib/cocoa/src/transport/TSocketClient.h @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import +#import "TNSStreamTransport.h" + +@interface TSocketClient : TNSStreamTransport { +} + +- (id) initWithHostname: (NSString *) hostname + port: (int) port; + +@end + + + diff --git a/lib/cocoa/src/transport/TSocketClient.m b/lib/cocoa/src/transport/TSocketClient.m new file mode 100644 index 00000000..7c07c561 --- /dev/null +++ b/lib/cocoa/src/transport/TSocketClient.m @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import +#import "TSocketClient.h" + +@implementation TSocketClient + +- (id) initWithHostname: (NSString *) hostname + port: (int) port +{ + NSInputStream * input = nil; + NSOutputStream * output = nil; + + [NSStream getStreamsToHost: [NSHost hostWithName: hostname] + port: port + inputStream: &input + outputStream: &output]; + + self = [super initWithInputStream: input outputStream: output]; + [input open]; + [output open]; + + return self; +} + + +@end + + + diff --git a/lib/cocoa/src/transport/TTransport.h b/lib/cocoa/src/transport/TTransport.h new file mode 100644 index 00000000..61ebbd21 --- /dev/null +++ b/lib/cocoa/src/transport/TTransport.h @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +@protocol TTransport + + /** + * Guarantees that all of len bytes are read + * + * @param buf Buffer to read into + * @param off Index in buffer to start storing bytes at + * @param len Maximum number of bytes to read + * @return The number of bytes actually read, which must be equal to len + * @throws TTransportException if there was an error reading data + */ +- (int) readAll: (uint8_t *) buf offset: (int) off length: (int) len; + +- (void) write: (const uint8_t *) data offset: (unsigned int) offset length: (unsigned int) length; + +- (void) flush; +@end diff --git a/lib/cocoa/src/transport/TTransportException.h b/lib/cocoa/src/transport/TTransportException.h new file mode 100644 index 00000000..6749fe28 --- /dev/null +++ b/lib/cocoa/src/transport/TTransportException.h @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TException.h" + +@interface TTransportException : TException { +} + ++ (id) exceptionWithReason: (NSString *) reason + error: (NSError *) error; + ++ (id) exceptionWithReason: (NSString *) reason; + +@end diff --git a/lib/cocoa/src/transport/TTransportException.m b/lib/cocoa/src/transport/TTransportException.m new file mode 100644 index 00000000..aa67149e --- /dev/null +++ b/lib/cocoa/src/transport/TTransportException.m @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#import "TTransportException.h" + +@implementation TTransportException + ++ (id) exceptionWithReason: (NSString *) reason + error: (NSError *) error +{ + NSDictionary * userInfo = nil; + if (error != nil) { + userInfo = [NSDictionary dictionaryWithObject: error forKey: @"error"]; + } + + return [super exceptionWithName: @"TTransportException" + reason: reason + userInfo: userInfo]; +} + + ++ (id) exceptionWithReason: (NSString *) reason +{ + return [self exceptionWithReason: reason error: nil]; +} + +@end diff --git a/lib/cpp/Makefile.am b/lib/cpp/Makefile.am new file mode 100644 index 00000000..dc0b6ae5 --- /dev/null +++ b/lib/cpp/Makefile.am @@ -0,0 +1,158 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +ACLOCAL_AMFLAGS = -I ./aclocal + +pkgconfigdir = $(libdir)/pkgconfig + +lib_LTLIBRARIES = libthrift.la +pkgconfig_DATA = thrift.pc + +## We only build the extra libraries if we have the dependencies, +## but we install all of the headers unconditionally. +if AMX_HAVE_LIBEVENT +lib_LTLIBRARIES += libthriftnb.la +pkgconfig_DATA += thrift-nb.pc +endif +if AMX_HAVE_ZLIB +lib_LTLIBRARIES += libthriftz.la +pkgconfig_DATA += thrift-z.pc +endif + +AM_CXXFLAGS = -Wall +AM_CPPFLAGS = $(BOOST_CPPFLAGS) -I$(srcdir)/src + +# Define the source files for the module + +libthrift_la_SOURCES = src/Thrift.cpp \ + src/concurrency/Mutex.cpp \ + src/concurrency/Monitor.cpp \ + src/concurrency/PosixThreadFactory.cpp \ + src/concurrency/ThreadManager.cpp \ + src/concurrency/TimerManager.cpp \ + src/concurrency/Util.cpp \ + src/protocol/TBinaryProtocol.cpp \ + src/protocol/TCompactProtocol.cpp \ + src/protocol/TDebugProtocol.cpp \ + src/protocol/TDenseProtocol.cpp \ + src/protocol/TJSONProtocol.cpp \ + src/protocol/TBase64Utils.cpp \ + src/transport/TTransportException.cpp \ + src/transport/TFDTransport.cpp \ + src/transport/TFileTransport.cpp \ + src/transport/TSimpleFileTransport.cpp \ + src/transport/THttpClient.cpp \ + src/transport/TSocket.cpp \ + src/transport/TSocketPool.cpp \ + src/transport/TServerSocket.cpp \ + src/transport/TTransportUtils.cpp \ + src/transport/TBufferTransports.cpp \ + src/server/TServer.cpp \ + src/server/TSimpleServer.cpp \ + src/server/TThreadPoolServer.cpp \ + src/server/TThreadedServer.cpp \ + src/processor/PeekProcessor.cpp + +libthriftnb_la_SOURCES = src/server/TNonblockingServer.cpp + +libthriftz_la_SOURCES = src/transport/TZlibTransport.cpp + + +# Flags for the various libraries +libthriftnb_la_CPPFLAGS = $(AM_CPPFLAGS) $(LIBEVENT_CPPFLAGS) +libthriftz_la_CPPFLAGS = $(AM_CPPFLAGS) $(ZLIB_CPPFLAGS) + + +include_thriftdir = $(includedir)/thrift +include_thrift_HEADERS = \ + $(top_builddir)/config.h \ + src/Thrift.h \ + src/TReflectionLocal.h \ + src/TProcessor.h \ + src/TLogging.h + +include_concurrencydir = $(include_thriftdir)/concurrency +include_concurrency_HEADERS = \ + src/concurrency/Exception.h \ + src/concurrency/Mutex.h \ + src/concurrency/Monitor.h \ + src/concurrency/PosixThreadFactory.h \ + src/concurrency/Thread.h \ + src/concurrency/ThreadManager.h \ + src/concurrency/TimerManager.h \ + src/concurrency/FunctionRunner.h \ + src/concurrency/Util.h + +include_protocoldir = $(include_thriftdir)/protocol +include_protocol_HEADERS = \ + src/protocol/TBinaryProtocol.h \ + src/protocol/TCompactProtocol.h \ + src/protocol/TDenseProtocol.h \ + src/protocol/TDebugProtocol.h \ + src/protocol/TOneWayProtocol.h \ + src/protocol/TBase64Utils.h \ + src/protocol/TJSONProtocol.h \ + src/protocol/TProtocolTap.h \ + src/protocol/TProtocolException.h \ + src/protocol/TProtocol.h + +include_transportdir = $(include_thriftdir)/transport +include_transport_HEADERS = \ + src/transport/TFDTransport.h \ + src/transport/TFileTransport.h \ + src/transport/TSimpleFileTransport.h \ + src/transport/TServerSocket.h \ + src/transport/TServerTransport.h \ + src/transport/THttpClient.h \ + src/transport/TSocket.h \ + src/transport/TSocketPool.h \ + src/transport/TTransport.h \ + src/transport/TTransportException.h \ + src/transport/TTransportUtils.h \ + src/transport/TBufferTransports.h \ + src/transport/TShortReadTransport.h \ + src/transport/TZlibTransport.h + +include_serverdir = $(include_thriftdir)/server +include_server_HEADERS = \ + src/server/TServer.h \ + src/server/TSimpleServer.h \ + src/server/TThreadPoolServer.h \ + src/server/TThreadedServer.h \ + src/server/TNonblockingServer.h + +include_processordir = $(include_thriftdir)/processor +include_processor_HEADERS = \ + src/processor/PeekProcessor.h \ + src/processor/StatsProcessor.h + +noinst_PROGRAMS = concurrency_test + +concurrency_test_SOURCES = src/concurrency/test/Tests.cpp \ + src/concurrency/test/ThreadFactoryTests.h \ + src/concurrency/test/ThreadManagerTests.h \ + src/concurrency/test/TimerManagerTests.h + +concurrency_test_LDADD = libthrift.la + +EXTRA_DIST = \ + README \ + thrift-nb.pc.in \ + thrift.pc.in \ + thrift-z.pc.in diff --git a/lib/cpp/README b/lib/cpp/README new file mode 100644 index 00000000..576d0170 --- /dev/null +++ b/lib/cpp/README @@ -0,0 +1,67 @@ +Thrift C++ Software Library + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Using Thrift with C++ +===================== + +The Thrift C++ libraries are built using the GNU tools. Follow the instructions +in the top-level README, or run bootstrap.sh in this folder to generate the +Makefiles. + +In case you do not want to open another README file, do this: + ./bootstrap.sh + ./configure (--with-boost=/usr/local) + make + sudo make install + +Thrift is divided into two libraries. + +libthrift + The core Thrift library contains all the core Thrift code. It requires + boost shared pointers, pthreads, and librt. + +libthriftnb + This library contains the Thrift nonblocking server, which uses libevent. + To link this library you will also need to link libevent. + +Linking Against Thrift +====================== + +After you build and install Thrift the libraries are installed to +/usr/local/lib by default. Make sure this is in your LDPATH. + +On Linux, the best way to do this is to ensure that /usr/local/lib is in +your /etc/ld.so.conf and then run /sbin/ldconfig. + +Depending upon whether you are linking dynamically or statically and how +your build environment it set up, you may need to include additional +libraries when linking against thrift, such as librt and/or libpthread. If +you are using libthriftnb you will also need libevent. + +Dependencies +============ + +boost shared pointers +http://www.boost.org/libs/smart_ptr/smart_ptr.htm + +libevent (for libthriftnb only) +http://monkey.org/~provos/libevent/ diff --git a/lib/cpp/src/TLogging.h b/lib/cpp/src/TLogging.h new file mode 100644 index 00000000..2df82dd7 --- /dev/null +++ b/lib/cpp/src/TLogging.h @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TLOGGING_H_ +#define _THRIFT_TLOGGING_H_ 1 + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +/** + * Contains utility macros for debugging and logging. + * + */ + +#ifndef HAVE_CLOCK_GETTIME +#include +#else +#include +#endif + +#ifdef HAVE_STDINT_H +#include +#endif + +/** + * T_GLOBAL_DEBUGGING_LEVEL = 0: all debugging turned off, debug macros undefined + * T_GLOBAL_DEBUGGING_LEVEL = 1: all debugging turned on + */ +#define T_GLOBAL_DEBUGGING_LEVEL 0 + + +/** + * T_GLOBAL_LOGGING_LEVEL = 0: all logging turned off, logging macros undefined + * T_GLOBAL_LOGGING_LEVEL = 1: all logging turned on + */ +#define T_GLOBAL_LOGGING_LEVEL 1 + + +/** + * Standard wrapper around fprintf what will prefix the file name and line + * number to the line. Uses T_GLOBAL_DEBUGGING_LEVEL to control whether it is + * turned on or off. + * + * @param format_string + */ +#if T_GLOBAL_DEBUGGING_LEVEL > 0 + #define T_DEBUG(format_string,...) \ + if (T_GLOBAL_DEBUGGING_LEVEL > 0) { \ + fprintf(stderr,"[%s,%d] " #format_string " \n", __FILE__, __LINE__,##__VA_ARGS__); \ + } +#else + #define T_DEBUG(format_string,...) +#endif + + +/** + * analagous to T_DEBUG but also prints the time + * + * @param string format_string input: printf style format string + */ +#if T_GLOBAL_DEBUGGING_LEVEL > 0 + #define T_DEBUG_T(format_string,...) \ + { \ + if (T_GLOBAL_DEBUGGING_LEVEL > 0) { \ + time_t now; \ + char dbgtime[26] ; \ + time(&now); \ + ctime_r(&now, dbgtime); \ + dbgtime[24] = '\0'; \ + fprintf(stderr,"[%s,%d] [%s] " #format_string " \n", __FILE__, __LINE__,dbgtime,##__VA_ARGS__); \ + } \ + } +#else + #define T_DEBUG_T(format_string,...) +#endif + + +/** + * analagous to T_DEBUG but uses input level to determine whether or not the string + * should be logged. + * + * @param int level: specified debug level + * @param string format_string input: format string + */ +#define T_DEBUG_L(level, format_string,...) \ + if ((level) > 0) { \ + fprintf(stderr,"[%s,%d] " #format_string " \n", __FILE__, __LINE__,##__VA_ARGS__); \ + } + + +/** + * Explicit error logging. Prints time, file name and line number + * + * @param string format_string input: printf style format string + */ +#define T_ERROR(format_string,...) \ + { \ + time_t now; \ + char dbgtime[26] ; \ + time(&now); \ + ctime_r(&now, dbgtime); \ + dbgtime[24] = '\0'; \ + fprintf(stderr,"[%s,%d] [%s] ERROR: " #format_string " \n", __FILE__, __LINE__,dbgtime,##__VA_ARGS__); \ + } + + +/** + * Analagous to T_ERROR, additionally aborting the process. + * WARNING: macro calls abort(), ending program execution + * + * @param string format_string input: printf style format string + */ +#define T_ERROR_ABORT(format_string,...) \ + { \ + time_t now; \ + char dbgtime[26] ; \ + time(&now); \ + ctime_r(&now, dbgtime); \ + dbgtime[24] = '\0'; \ + fprintf(stderr,"[%s,%d] [%s] ERROR: Going to abort " #format_string " \n", __FILE__, __LINE__,dbgtime,##__VA_ARGS__); \ + exit(1); \ + } + + +/** + * Log input message + * + * @param string format_string input: printf style format string + */ +#if T_GLOBAL_LOGGING_LEVEL > 0 + #define T_LOG_OPER(format_string,...) \ + { \ + if (T_GLOBAL_LOGGING_LEVEL > 0) { \ + time_t now; \ + char dbgtime[26] ; \ + time(&now); \ + ctime_r(&now, dbgtime); \ + dbgtime[24] = '\0'; \ + fprintf(stderr,"[%s] " #format_string " \n", dbgtime,##__VA_ARGS__); \ + } \ + } +#else + #define T_LOG_OPER(format_string,...) +#endif + +#endif // #ifndef _THRIFT_TLOGGING_H_ diff --git a/lib/cpp/src/TProcessor.h b/lib/cpp/src/TProcessor.h new file mode 100644 index 00000000..f2d5279a --- /dev/null +++ b/lib/cpp/src/TProcessor.h @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TPROCESSOR_H_ +#define _THRIFT_TPROCESSOR_H_ 1 + +#include +#include +#include + +namespace apache { namespace thrift { + +/** + * A processor is a generic object that acts upon two streams of data, one + * an input and the other an output. The definition of this object is loose, + * though the typical case is for some sort of server that either generates + * responses to an input stream or forwards data from one pipe onto another. + * + */ +class TProcessor { + public: + virtual ~TProcessor() {} + + virtual bool process(boost::shared_ptr in, + boost::shared_ptr out) = 0; + + bool process(boost::shared_ptr io) { + return process(io, io); + } + + protected: + TProcessor() {} +}; + +}} // apache::thrift + +#endif // #ifndef _THRIFT_PROCESSOR_H_ diff --git a/lib/cpp/src/TReflectionLocal.h b/lib/cpp/src/TReflectionLocal.h new file mode 100644 index 00000000..e83e4753 --- /dev/null +++ b/lib/cpp/src/TReflectionLocal.h @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TREFLECTIONLOCAL_H_ +#define _THRIFT_TREFLECTIONLOCAL_H_ 1 + +#include +#include +#include + +/** + * Local Reflection is a blanket term referring to the the structure + * and generation of this particular representation of Thrift types. + * (It is called local because it cannot be serialized by Thrift). + * + */ + +namespace apache { namespace thrift { namespace reflection { namespace local { + +using apache::thrift::protocol::TType; + +// We include this many bytes of the structure's fingerprint when serializing +// a top-level structure. Long enough to make collisions unlikely, short +// enough to not significantly affect the amount of memory used. +const int FP_PREFIX_LEN = 4; + +struct FieldMeta { + int16_t tag; + bool is_optional; +}; + +struct TypeSpec { + TType ttype; + uint8_t fp_prefix[FP_PREFIX_LEN]; + + // Use an anonymous union here so we can fit two TypeSpecs in one cache line. + union { + struct { + // Use parallel arrays here for denser packing (of the arrays). + FieldMeta* metas; + TypeSpec** specs; + } tstruct; + struct { + TypeSpec *subtype1; + TypeSpec *subtype2; + } tcontainer; + }; + + // Static initialization of unions isn't really possible, + // so take the plunge and use constructors. + // Hopefully they'll be evaluated at compile time. + + TypeSpec(TType ttype) : ttype(ttype) { + std::memset(fp_prefix, 0, FP_PREFIX_LEN); + } + + TypeSpec(TType ttype, + const uint8_t* fingerprint, + FieldMeta* metas, + TypeSpec** specs) : + ttype(ttype) + { + std::memcpy(fp_prefix, fingerprint, FP_PREFIX_LEN); + tstruct.metas = metas; + tstruct.specs = specs; + } + + TypeSpec(TType ttype, TypeSpec* subtype1, TypeSpec* subtype2) : + ttype(ttype) + { + std::memset(fp_prefix, 0, FP_PREFIX_LEN); + tcontainer.subtype1 = subtype1; + tcontainer.subtype2 = subtype2; + } + +}; + +}}}} // apache::thrift::reflection::local + +#endif // #ifndef _THRIFT_TREFLECTIONLOCAL_H_ diff --git a/lib/cpp/src/Thrift.cpp b/lib/cpp/src/Thrift.cpp new file mode 100644 index 00000000..ed99205b --- /dev/null +++ b/lib/cpp/src/Thrift.cpp @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include + +namespace apache { namespace thrift { + +TOutput GlobalOutput; + +void TOutput::printf(const char *message, ...) { + // Try to reduce heap usage, even if printf is called rarely. + static const int STACK_BUF_SIZE = 256; + char stack_buf[STACK_BUF_SIZE]; + va_list ap; + + va_start(ap, message); + int need = vsnprintf(stack_buf, STACK_BUF_SIZE, message, ap); + va_end(ap); + + if (need < STACK_BUF_SIZE) { + f_(stack_buf); + return; + } + + char *heap_buf = (char*)malloc((need+1) * sizeof(char)); + if (heap_buf == NULL) { + // Malloc failed. We might as well print the stack buffer. + f_(stack_buf); + return; + } + + va_start(ap, message); + int rval = vsnprintf(heap_buf, need+1, message, ap); + va_end(ap); + // TODO(shigin): inform user + if (rval != -1) { + f_(heap_buf); + } + free(heap_buf); +} + +void TOutput::perror(const char *message, int errno_copy) { + std::string out = message + strerror_s(errno_copy); + f_(out.c_str()); +} + +std::string TOutput::strerror_s(int errno_copy) { +#ifndef HAVE_STRERROR_R + return "errno = " + boost::lexical_cast(errno_copy); +#else // HAVE_STRERROR_R + + char b_errbuf[1024] = { '\0' }; +#ifdef STRERROR_R_CHAR_P + char *b_error = strerror_r(errno_copy, b_errbuf, sizeof(b_errbuf)); +#else + char *b_error = b_errbuf; + int rv = strerror_r(errno_copy, b_errbuf, sizeof(b_errbuf)); + if (rv == -1) { + // strerror_r failed. omgwtfbbq. + return "XSI-compliant strerror_r() failed with errno = " + + boost::lexical_cast(errno_copy); + } +#endif + // Can anyone prove that explicit cast is probably not necessary + // to ensure that the string object is constructed before + // b_error becomes invalid? + return std::string(b_error); + +#endif // HAVE_STRERROR_R +} + +uint32_t TApplicationException::read(apache::thrift::protocol::TProtocol* iprot) { + uint32_t xfer = 0; + std::string fname; + apache::thrift::protocol::TType ftype; + int16_t fid; + + xfer += iprot->readStructBegin(fname); + + while (true) { + xfer += iprot->readFieldBegin(fname, ftype, fid); + if (ftype == apache::thrift::protocol::T_STOP) { + break; + } + switch (fid) { + case 1: + if (ftype == apache::thrift::protocol::T_STRING) { + xfer += iprot->readString(message_); + } else { + xfer += iprot->skip(ftype); + } + break; + case 2: + if (ftype == apache::thrift::protocol::T_I32) { + int32_t type; + xfer += iprot->readI32(type); + type_ = (TApplicationExceptionType)type; + } else { + xfer += iprot->skip(ftype); + } + break; + default: + xfer += iprot->skip(ftype); + break; + } + xfer += iprot->readFieldEnd(); + } + + xfer += iprot->readStructEnd(); + return xfer; +} + +uint32_t TApplicationException::write(apache::thrift::protocol::TProtocol* oprot) const { + uint32_t xfer = 0; + xfer += oprot->writeStructBegin("TApplicationException"); + xfer += oprot->writeFieldBegin("message", apache::thrift::protocol::T_STRING, 1); + xfer += oprot->writeString(message_); + xfer += oprot->writeFieldEnd(); + xfer += oprot->writeFieldBegin("type", apache::thrift::protocol::T_I32, 2); + xfer += oprot->writeI32(type_); + xfer += oprot->writeFieldEnd(); + xfer += oprot->writeFieldStop(); + xfer += oprot->writeStructEnd(); + return xfer; +} + +}} // apache::thrift diff --git a/lib/cpp/src/Thrift.h b/lib/cpp/src/Thrift.h new file mode 100644 index 00000000..26d2b0fc --- /dev/null +++ b/lib/cpp/src/Thrift.h @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_THRIFT_H_ +#define _THRIFT_THRIFT_H_ 1 + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif +#include + +#include +#ifdef HAVE_INTTYPES_H +#include +#endif +#include +#include +#include +#include +#include +#include + +#include "TLogging.h" + +namespace apache { namespace thrift { + +class TOutput { + public: + TOutput() : f_(&errorTimeWrapper) {} + + inline void setOutputFunction(void (*function)(const char *)){ + f_ = function; + } + + inline void operator()(const char *message){ + f_(message); + } + + // It is important to have a const char* overload here instead of + // just the string version, otherwise errno could be corrupted + // if there is some problem allocating memory when constructing + // the string. + void perror(const char *message, int errno_copy); + inline void perror(const std::string &message, int errno_copy) { + perror(message.c_str(), errno_copy); + } + + void printf(const char *message, ...); + + inline static void errorTimeWrapper(const char* msg) { + time_t now; + char dbgtime[25]; + time(&now); + ctime_r(&now, dbgtime); + dbgtime[24] = 0; + fprintf(stderr, "Thrift: %s %s\n", dbgtime, msg); + } + + /** Just like strerror_r but returns a C++ string object. */ + static std::string strerror_s(int errno_copy); + + private: + void (*f_)(const char *); +}; + +extern TOutput GlobalOutput; + +namespace protocol { + class TProtocol; +} + +class TException : public std::exception { + public: + TException() {} + + TException(const std::string& message) : + message_(message) {} + + virtual ~TException() throw() {} + + virtual const char* what() const throw() { + if (message_.empty()) { + return "Default TException."; + } else { + return message_.c_str(); + } + } + + protected: + std::string message_; + +}; + +class TApplicationException : public TException { + public: + + /** + * Error codes for the various types of exceptions. + */ + enum TApplicationExceptionType + { UNKNOWN = 0 + , UNKNOWN_METHOD = 1 + , INVALID_MESSAGE_TYPE = 2 + , WRONG_METHOD_NAME = 3 + , BAD_SEQUENCE_ID = 4 + , MISSING_RESULT = 5 + }; + + TApplicationException() : + TException(), + type_(UNKNOWN) {} + + TApplicationException(TApplicationExceptionType type) : + TException(), + type_(type) {} + + TApplicationException(const std::string& message) : + TException(message), + type_(UNKNOWN) {} + + TApplicationException(TApplicationExceptionType type, + const std::string& message) : + TException(message), + type_(type) {} + + virtual ~TApplicationException() throw() {} + + /** + * Returns an error code that provides information about the type of error + * that has occurred. + * + * @return Error code + */ + TApplicationExceptionType getType() { + return type_; + } + + virtual const char* what() const throw() { + if (message_.empty()) { + switch (type_) { + case UNKNOWN : return "TApplicationException: Unknown application exception"; + case UNKNOWN_METHOD : return "TApplicationException: Unknown method"; + case INVALID_MESSAGE_TYPE : return "TApplicationException: Invalid message type"; + case WRONG_METHOD_NAME : return "TApplicationException: Wrong method name"; + case BAD_SEQUENCE_ID : return "TApplicationException: Bad sequence identifier"; + case MISSING_RESULT : return "TApplicationException: Missing result"; + default : return "TApplicationException: (Invalid exception type)"; + }; + } else { + return message_.c_str(); + } + } + + uint32_t read(protocol::TProtocol* iprot); + uint32_t write(protocol::TProtocol* oprot) const; + + protected: + /** + * Error code + */ + TApplicationExceptionType type_; + +}; + + +// Forward declare this structure used by TDenseProtocol +namespace reflection { namespace local { +struct TypeSpec; +}} + + +}} // apache::thrift + +#endif // #ifndef _THRIFT_THRIFT_H_ diff --git a/lib/cpp/src/concurrency/Exception.h b/lib/cpp/src/concurrency/Exception.h new file mode 100644 index 00000000..ec466297 --- /dev/null +++ b/lib/cpp/src/concurrency/Exception.h @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_EXCEPTION_H_ +#define _THRIFT_CONCURRENCY_EXCEPTION_H_ 1 + +#include +#include + +namespace apache { namespace thrift { namespace concurrency { + +class NoSuchTaskException : public apache::thrift::TException {}; + +class UncancellableTaskException : public apache::thrift::TException {}; + +class InvalidArgumentException : public apache::thrift::TException {}; + +class IllegalStateException : public apache::thrift::TException {}; + +class TimedOutException : public apache::thrift::TException { +public: + TimedOutException():TException("TimedOutException"){}; + TimedOutException(const std::string& message ) : + TException(message) {} +}; + +class TooManyPendingTasksException : public apache::thrift::TException { +public: + TooManyPendingTasksException():TException("TooManyPendingTasksException"){}; + TooManyPendingTasksException(const std::string& message ) : + TException(message) {} +}; + +class SystemResourceException : public apache::thrift::TException { +public: + SystemResourceException() {} + + SystemResourceException(const std::string& message) : + TException(message) {} +}; + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_EXCEPTION_H_ diff --git a/lib/cpp/src/concurrency/FunctionRunner.h b/lib/cpp/src/concurrency/FunctionRunner.h new file mode 100644 index 00000000..22169276 --- /dev/null +++ b/lib/cpp/src/concurrency/FunctionRunner.h @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_FUNCTION_RUNNER_H +#define _THRIFT_CONCURRENCY_FUNCTION_RUNNER_H 1 + +#include +#include "thrift/lib/cpp/concurrency/Thread.h" + +namespace apache { namespace thrift { namespace concurrency { + +/** + * Convenient implementation of Runnable that will execute arbitrary callbacks. + * Interfaces are provided to accept both a generic 'void(void)' callback, and + * a 'void* (void*)' pthread_create-style callback. + * + * Example use: + * void* my_thread_main(void* arg); + * shared_ptr factory = ...; + * shared_ptr thread = + * factory->newThread(shared_ptr( + * new FunctionRunner(my_thread_main, some_argument))); + * thread->start(); + * + * + */ + +class FunctionRunner : public Runnable { + public: + // This is the type of callback 'pthread_create()' expects. + typedef void* (*PthreadFuncPtr)(void *arg); + // This a fully-generic void(void) callback for custom bindings. + typedef std::tr1::function VoidFunc; + + /** + * Given a 'pthread_create' style callback, this FunctionRunner will + * execute the given callback. Note that the 'void*' return value is ignored. + */ + FunctionRunner(PthreadFuncPtr func, void* arg) + : func_(std::tr1::bind(func, arg)) + { } + + /** + * Given a generic callback, this FunctionRunner will execute it. + */ + FunctionRunner(const VoidFunc& cob) + : func_(cob) + { } + + + void run() { + func_(); + } + + private: + VoidFunc func_; +}; + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_FUNCTION_RUNNER_H diff --git a/lib/cpp/src/concurrency/Monitor.cpp b/lib/cpp/src/concurrency/Monitor.cpp new file mode 100644 index 00000000..2055caa9 --- /dev/null +++ b/lib/cpp/src/concurrency/Monitor.cpp @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "Monitor.h" +#include "Exception.h" +#include "Util.h" + +#include +#include + +#include + +#include + +namespace apache { namespace thrift { namespace concurrency { + +/** + * Monitor implementation using the POSIX pthread library + * + * @version $Id:$ + */ +class Monitor::Impl { + + public: + + Impl() : + mutexInitialized_(false), + condInitialized_(false) { + + if (pthread_mutex_init(&pthread_mutex_, NULL) == 0) { + mutexInitialized_ = true; + + if (pthread_cond_init(&pthread_cond_, NULL) == 0) { + condInitialized_ = true; + } + } + + if (!mutexInitialized_ || !condInitialized_) { + cleanup(); + throw SystemResourceException(); + } + } + + ~Impl() { cleanup(); } + + void lock() const { pthread_mutex_lock(&pthread_mutex_); } + + void unlock() const { pthread_mutex_unlock(&pthread_mutex_); } + + void wait(int64_t timeout) const { + + // XXX Need to assert that caller owns mutex + assert(timeout >= 0LL); + if (timeout == 0LL) { + int iret = pthread_cond_wait(&pthread_cond_, &pthread_mutex_); + assert(iret == 0); + } else { + struct timespec abstime; + int64_t now = Util::currentTime(); + Util::toTimespec(abstime, now + timeout); + int result = pthread_cond_timedwait(&pthread_cond_, + &pthread_mutex_, + &abstime); + if (result == ETIMEDOUT) { + // pthread_cond_timedwait has been observed to return early on + // various platforms, so comment out this assert. + //assert(Util::currentTime() >= (now + timeout)); + throw TimedOutException(); + } + } + } + + void notify() { + // XXX Need to assert that caller owns mutex + int iret = pthread_cond_signal(&pthread_cond_); + assert(iret == 0); + } + + void notifyAll() { + // XXX Need to assert that caller owns mutex + int iret = pthread_cond_broadcast(&pthread_cond_); + assert(iret == 0); + } + + private: + + void cleanup() { + if (mutexInitialized_) { + mutexInitialized_ = false; + int iret = pthread_mutex_destroy(&pthread_mutex_); + assert(iret == 0); + } + + if (condInitialized_) { + condInitialized_ = false; + int iret = pthread_cond_destroy(&pthread_cond_); + assert(iret == 0); + } + } + + mutable pthread_mutex_t pthread_mutex_; + mutable bool mutexInitialized_; + mutable pthread_cond_t pthread_cond_; + mutable bool condInitialized_; +}; + +Monitor::Monitor() : impl_(new Monitor::Impl()) {} + +Monitor::~Monitor() { delete impl_; } + +void Monitor::lock() const { impl_->lock(); } + +void Monitor::unlock() const { impl_->unlock(); } + +void Monitor::wait(int64_t timeout) const { impl_->wait(timeout); } + +void Monitor::notify() const { impl_->notify(); } + +void Monitor::notifyAll() const { impl_->notifyAll(); } + +}}} // apache::thrift::concurrency diff --git a/lib/cpp/src/concurrency/Monitor.h b/lib/cpp/src/concurrency/Monitor.h new file mode 100644 index 00000000..234bf326 --- /dev/null +++ b/lib/cpp/src/concurrency/Monitor.h @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_MONITOR_H_ +#define _THRIFT_CONCURRENCY_MONITOR_H_ 1 + +#include "Exception.h" + +namespace apache { namespace thrift { namespace concurrency { + +/** + * A monitor is a combination mutex and condition-event. Waiting and + * notifying condition events requires that the caller own the mutex. Mutex + * lock and unlock operations can be performed independently of condition + * events. This is more or less analogous to java.lang.Object multi-thread + * operations + * + * Note that all methods are const. Monitors implement logical constness, not + * bit constness. This allows const methods to call monitor methods without + * needing to cast away constness or change to non-const signatures. + * + * @version $Id:$ + */ +class Monitor { + + public: + + Monitor(); + + virtual ~Monitor(); + + virtual void lock() const; + + virtual void unlock() const; + + virtual void wait(int64_t timeout=0LL) const; + + virtual void notify() const; + + virtual void notifyAll() const; + + private: + + class Impl; + + Impl* impl_; +}; + +class Synchronized { + public: + + Synchronized(const Monitor& value) : + monitor_(value) { + monitor_.lock(); + } + + ~Synchronized() { + monitor_.unlock(); + } + + private: + const Monitor& monitor_; +}; + + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_MONITOR_H_ diff --git a/lib/cpp/src/concurrency/Mutex.cpp b/lib/cpp/src/concurrency/Mutex.cpp new file mode 100644 index 00000000..045dbdfe --- /dev/null +++ b/lib/cpp/src/concurrency/Mutex.cpp @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "Mutex.h" + +#include +#include + +using boost::shared_ptr; + +namespace apache { namespace thrift { namespace concurrency { + +/** + * Implementation of Mutex class using POSIX mutex + * + * @version $Id:$ + */ +class Mutex::impl { + public: + impl(Initializer init) : initialized_(false) { + init(&pthread_mutex_); + initialized_ = true; + } + + ~impl() { + if (initialized_) { + initialized_ = false; + int ret = pthread_mutex_destroy(&pthread_mutex_); + assert(ret == 0); + } + } + + void lock() const { pthread_mutex_lock(&pthread_mutex_); } + + bool trylock() const { return (0 == pthread_mutex_trylock(&pthread_mutex_)); } + + void unlock() const { pthread_mutex_unlock(&pthread_mutex_); } + + private: + mutable pthread_mutex_t pthread_mutex_; + mutable bool initialized_; +}; + +Mutex::Mutex(Initializer init) : impl_(new Mutex::impl(init)) {} + +void Mutex::lock() const { impl_->lock(); } + +bool Mutex::trylock() const { return impl_->trylock(); } + +void Mutex::unlock() const { impl_->unlock(); } + +void Mutex::DEFAULT_INITIALIZER(void* arg) { + pthread_mutex_t* pthread_mutex = (pthread_mutex_t*)arg; + int ret = pthread_mutex_init(pthread_mutex, NULL); + assert(ret == 0); +} + +static void init_with_kind(pthread_mutex_t* mutex, int kind) { + pthread_mutexattr_t mutexattr; + int ret = pthread_mutexattr_init(&mutexattr); + assert(ret == 0); + + // Apparently, this can fail. Should we really be aborting? + ret = pthread_mutexattr_settype(&mutexattr, kind); + assert(ret == 0); + + ret = pthread_mutex_init(mutex, &mutexattr); + assert(ret == 0); + + ret = pthread_mutexattr_destroy(&mutexattr); + assert(ret == 0); +} + +#ifdef PTHREAD_ADAPTIVE_MUTEX_INITIALIZER_NP +void Mutex::ADAPTIVE_INITIALIZER(void* arg) { + // From mysql source: mysys/my_thr_init.c + // Set mutex type to "fast" a.k.a "adaptive" + // + // In this case the thread may steal the mutex from some other thread + // that is waiting for the same mutex. This will save us some + // context switches but may cause a thread to 'starve forever' while + // waiting for the mutex (not likely if the code within the mutex is + // short). + init_with_kind((pthread_mutex_t*)arg, PTHREAD_MUTEX_ADAPTIVE_NP); +} +#endif + +#ifdef PTHREAD_RECURSIVE_MUTEX_INITIALIZER_NP +void Mutex::RECURSIVE_INITIALIZER(void* arg) { + init_with_kind((pthread_mutex_t*)arg, PTHREAD_MUTEX_RECURSIVE_NP); +} +#endif + + +/** + * Implementation of ReadWriteMutex class using POSIX rw lock + * + * @version $Id:$ + */ +class ReadWriteMutex::impl { +public: + impl() : initialized_(false) { + int ret = pthread_rwlock_init(&rw_lock_, NULL); + assert(ret == 0); + initialized_ = true; + } + + ~impl() { + if(initialized_) { + initialized_ = false; + int ret = pthread_rwlock_destroy(&rw_lock_); + assert(ret == 0); + } + } + + void acquireRead() const { pthread_rwlock_rdlock(&rw_lock_); } + + void acquireWrite() const { pthread_rwlock_wrlock(&rw_lock_); } + + bool attemptRead() const { return pthread_rwlock_tryrdlock(&rw_lock_); } + + bool attemptWrite() const { return pthread_rwlock_trywrlock(&rw_lock_); } + + void release() const { pthread_rwlock_unlock(&rw_lock_); } + +private: + mutable pthread_rwlock_t rw_lock_; + mutable bool initialized_; +}; + +ReadWriteMutex::ReadWriteMutex() : impl_(new ReadWriteMutex::impl()) {} + +void ReadWriteMutex::acquireRead() const { impl_->acquireRead(); } + +void ReadWriteMutex::acquireWrite() const { impl_->acquireWrite(); } + +bool ReadWriteMutex::attemptRead() const { return impl_->attemptRead(); } + +bool ReadWriteMutex::attemptWrite() const { return impl_->attemptWrite(); } + +void ReadWriteMutex::release() const { impl_->release(); } + +}}} // apache::thrift::concurrency + diff --git a/lib/cpp/src/concurrency/Mutex.h b/lib/cpp/src/concurrency/Mutex.h new file mode 100644 index 00000000..884412be --- /dev/null +++ b/lib/cpp/src/concurrency/Mutex.h @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_MUTEX_H_ +#define _THRIFT_CONCURRENCY_MUTEX_H_ 1 + +#include + +namespace apache { namespace thrift { namespace concurrency { + +/** + * A simple mutex class + * + * @version $Id:$ + */ +class Mutex { + public: + typedef void (*Initializer)(void*); + + Mutex(Initializer init = DEFAULT_INITIALIZER); + virtual ~Mutex() {} + virtual void lock() const; + virtual bool trylock() const; + virtual void unlock() const; + + static void DEFAULT_INITIALIZER(void*); + static void ADAPTIVE_INITIALIZER(void*); + static void RECURSIVE_INITIALIZER(void*); + + private: + + class impl; + boost::shared_ptr impl_; +}; + +class ReadWriteMutex { +public: + ReadWriteMutex(); + virtual ~ReadWriteMutex() {} + + // these get the lock and block until it is done successfully + virtual void acquireRead() const; + virtual void acquireWrite() const; + + // these attempt to get the lock, returning false immediately if they fail + virtual bool attemptRead() const; + virtual bool attemptWrite() const; + + // this releases both read and write locks + virtual void release() const; + +private: + + class impl; + boost::shared_ptr impl_; +}; + +class Guard { + public: + Guard(const Mutex& value) : mutex_(value) { + mutex_.lock(); + } + ~Guard() { + mutex_.unlock(); + } + + private: + const Mutex& mutex_; +}; + +class RWGuard { + public: + RWGuard(const ReadWriteMutex& value, bool write = 0) : rw_mutex_(value) { + if (write) { + rw_mutex_.acquireWrite(); + } else { + rw_mutex_.acquireRead(); + } + } + ~RWGuard() { + rw_mutex_.release(); + } + private: + const ReadWriteMutex& rw_mutex_; +}; + + +// A little hack to prevent someone from trying to do "Guard(m);" +// Sorry for polluting the global namespace, but I think it's worth it. +#define Guard(m) incorrect_use_of_Guard(m) +#define RWGuard(m) incorrect_use_of_RWGuard(m) + + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_MUTEX_H_ diff --git a/lib/cpp/src/concurrency/PosixThreadFactory.cpp b/lib/cpp/src/concurrency/PosixThreadFactory.cpp new file mode 100644 index 00000000..e48dce39 --- /dev/null +++ b/lib/cpp/src/concurrency/PosixThreadFactory.cpp @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "PosixThreadFactory.h" +#include "Exception.h" + +#if GOOGLE_PERFTOOLS_REGISTER_THREAD +# include +#endif + +#include +#include + +#include + +#include + +namespace apache { namespace thrift { namespace concurrency { + +using boost::shared_ptr; +using boost::weak_ptr; + +/** + * The POSIX thread class. + * + * @version $Id:$ + */ +class PthreadThread: public Thread { + public: + + enum STATE { + uninitialized, + starting, + started, + stopping, + stopped + }; + + static const int MB = 1024 * 1024; + + static void* threadMain(void* arg); + + private: + pthread_t pthread_; + STATE state_; + int policy_; + int priority_; + int stackSize_; + weak_ptr self_; + bool detached_; + + public: + + PthreadThread(int policy, int priority, int stackSize, bool detached, shared_ptr runnable) : + pthread_(0), + state_(uninitialized), + policy_(policy), + priority_(priority), + stackSize_(stackSize), + detached_(detached) { + + this->Thread::runnable(runnable); + } + + ~PthreadThread() { + /* Nothing references this thread, if is is not detached, do a join + now, otherwise the thread-id and, possibly, other resources will + be leaked. */ + if(!detached_) { + try { + join(); + } catch(...) { + // We're really hosed. + } + } + } + + void start() { + if (state_ != uninitialized) { + return; + } + + pthread_attr_t thread_attr; + if (pthread_attr_init(&thread_attr) != 0) { + throw SystemResourceException("pthread_attr_init failed"); + } + + if(pthread_attr_setdetachstate(&thread_attr, + detached_ ? + PTHREAD_CREATE_DETACHED : + PTHREAD_CREATE_JOINABLE) != 0) { + throw SystemResourceException("pthread_attr_setdetachstate failed"); + } + + // Set thread stack size + if (pthread_attr_setstacksize(&thread_attr, MB * stackSize_) != 0) { + throw SystemResourceException("pthread_attr_setstacksize failed"); + } + + // Set thread policy + if (pthread_attr_setschedpolicy(&thread_attr, policy_) != 0) { + throw SystemResourceException("pthread_attr_setschedpolicy failed"); + } + + struct sched_param sched_param; + sched_param.sched_priority = priority_; + + // Set thread priority + if (pthread_attr_setschedparam(&thread_attr, &sched_param) != 0) { + throw SystemResourceException("pthread_attr_setschedparam failed"); + } + + // Create reference + shared_ptr* selfRef = new shared_ptr(); + *selfRef = self_.lock(); + + state_ = starting; + + if (pthread_create(&pthread_, &thread_attr, threadMain, (void*)selfRef) != 0) { + throw SystemResourceException("pthread_create failed"); + } + } + + void join() { + if (!detached_ && state_ != uninitialized) { + void* ignore; + /* XXX + If join fails it is most likely due to the fact + that the last reference was the thread itself and cannot + join. This results in leaked threads and will eventually + cause the process to run out of thread resources. + We're beyond the point of throwing an exception. Not clear how + best to handle this. */ + detached_ = pthread_join(pthread_, &ignore) == 0; + } + } + + Thread::id_t getId() { + return (Thread::id_t)pthread_; + } + + shared_ptr runnable() const { return Thread::runnable(); } + + void runnable(shared_ptr value) { Thread::runnable(value); } + + void weakRef(shared_ptr self) { + assert(self.get() == this); + self_ = weak_ptr(self); + } +}; + +void* PthreadThread::threadMain(void* arg) { + shared_ptr thread = *(shared_ptr*)arg; + delete reinterpret_cast*>(arg); + + if (thread == NULL) { + return (void*)0; + } + + if (thread->state_ != starting) { + return (void*)0; + } + +#if GOOGLE_PERFTOOLS_REGISTER_THREAD + ProfilerRegisterThread(); +#endif + + thread->state_ = starting; + thread->runnable()->run(); + if (thread->state_ != stopping && thread->state_ != stopped) { + thread->state_ = stopping; + } + + return (void*)0; +} + +/** + * POSIX Thread factory implementation + */ +class PosixThreadFactory::Impl { + + private: + POLICY policy_; + PRIORITY priority_; + int stackSize_; + bool detached_; + + /** + * Converts generic posix thread schedule policy enums into pthread + * API values. + */ + static int toPthreadPolicy(POLICY policy) { + switch (policy) { + case OTHER: + return SCHED_OTHER; + case FIFO: + return SCHED_FIFO; + case ROUND_ROBIN: + return SCHED_RR; + } + return SCHED_OTHER; + } + + /** + * Converts relative thread priorities to absolute value based on posix + * thread scheduler policy + * + * The idea is simply to divide up the priority range for the given policy + * into the correpsonding relative priority level (lowest..highest) and + * then pro-rate accordingly. + */ + static int toPthreadPriority(POLICY policy, PRIORITY priority) { + int pthread_policy = toPthreadPolicy(policy); + int min_priority = sched_get_priority_min(pthread_policy); + int max_priority = sched_get_priority_max(pthread_policy); + int quanta = (HIGHEST - LOWEST) + 1; + float stepsperquanta = (max_priority - min_priority) / quanta; + + if (priority <= HIGHEST) { + return (int)(min_priority + stepsperquanta * priority); + } else { + // should never get here for priority increments. + assert(false); + return (int)(min_priority + stepsperquanta * NORMAL); + } + } + + public: + + Impl(POLICY policy, PRIORITY priority, int stackSize, bool detached) : + policy_(policy), + priority_(priority), + stackSize_(stackSize), + detached_(detached) {} + + /** + * Creates a new POSIX thread to run the runnable object + * + * @param runnable A runnable object + */ + shared_ptr newThread(shared_ptr runnable) const { + shared_ptr result = shared_ptr(new PthreadThread(toPthreadPolicy(policy_), toPthreadPriority(policy_, priority_), stackSize_, detached_, runnable)); + result->weakRef(result); + runnable->thread(result); + return result; + } + + int getStackSize() const { return stackSize_; } + + void setStackSize(int value) { stackSize_ = value; } + + PRIORITY getPriority() const { return priority_; } + + /** + * Sets priority. + * + * XXX + * Need to handle incremental priorities properly. + */ + void setPriority(PRIORITY value) { priority_ = value; } + + bool isDetached() const { return detached_; } + + void setDetached(bool value) { detached_ = value; } + + Thread::id_t getCurrentThreadId() const { + // TODO(dreiss): Stop using C-style casts. + return (id_t)pthread_self(); + } + +}; + +PosixThreadFactory::PosixThreadFactory(POLICY policy, PRIORITY priority, int stackSize, bool detached) : + impl_(new PosixThreadFactory::Impl(policy, priority, stackSize, detached)) {} + +shared_ptr PosixThreadFactory::newThread(shared_ptr runnable) const { return impl_->newThread(runnable); } + +int PosixThreadFactory::getStackSize() const { return impl_->getStackSize(); } + +void PosixThreadFactory::setStackSize(int value) { impl_->setStackSize(value); } + +PosixThreadFactory::PRIORITY PosixThreadFactory::getPriority() const { return impl_->getPriority(); } + +void PosixThreadFactory::setPriority(PosixThreadFactory::PRIORITY value) { impl_->setPriority(value); } + +bool PosixThreadFactory::isDetached() const { return impl_->isDetached(); } + +void PosixThreadFactory::setDetached(bool value) { impl_->setDetached(value); } + +Thread::id_t PosixThreadFactory::getCurrentThreadId() const { return impl_->getCurrentThreadId(); } + +}}} // apache::thrift::concurrency diff --git a/lib/cpp/src/concurrency/PosixThreadFactory.h b/lib/cpp/src/concurrency/PosixThreadFactory.h new file mode 100644 index 00000000..d6d83a3a --- /dev/null +++ b/lib/cpp/src/concurrency/PosixThreadFactory.h @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_POSIXTHREADFACTORY_H_ +#define _THRIFT_CONCURRENCY_POSIXTHREADFACTORY_H_ 1 + +#include "Thread.h" + +#include + +namespace apache { namespace thrift { namespace concurrency { + +/** + * A thread factory to create posix threads + * + * @version $Id:$ + */ +class PosixThreadFactory : public ThreadFactory { + + public: + + /** + * POSIX Thread scheduler policies + */ + enum POLICY { + OTHER, + FIFO, + ROUND_ROBIN + }; + + /** + * POSIX Thread scheduler relative priorities, + * + * Absolute priority is determined by scheduler policy and OS. This + * enumeration specifies relative priorities such that one can specify a + * priority withing a giving scheduler policy without knowing the absolute + * value of the priority. + */ + enum PRIORITY { + LOWEST = 0, + LOWER = 1, + LOW = 2, + NORMAL = 3, + HIGH = 4, + HIGHER = 5, + HIGHEST = 6, + INCREMENT = 7, + DECREMENT = 8 + }; + + /** + * Posix thread (pthread) factory. All threads created by a factory are reference-counted + * via boost::shared_ptr and boost::weak_ptr. The factory guarantees that threads and + * the Runnable tasks they host will be properly cleaned up once the last strong reference + * to both is given up. + * + * Threads are created with the specified policy, priority, stack-size and detachable-mode + * detached means the thread is free-running and will release all system resources the + * when it completes. A detachable thread is not joinable. The join method + * of a detachable thread will return immediately with no error. + * + * By default threads are not joinable. + */ + + PosixThreadFactory(POLICY policy=ROUND_ROBIN, PRIORITY priority=NORMAL, int stackSize=1, bool detached=true); + + // From ThreadFactory; + boost::shared_ptr newThread(boost::shared_ptr runnable) const; + + // From ThreadFactory; + Thread::id_t getCurrentThreadId() const; + + /** + * Gets stack size for created threads + * + * @return int size in megabytes + */ + virtual int getStackSize() const; + + /** + * Sets stack size for created threads + * + * @param value size in megabytes + */ + virtual void setStackSize(int value); + + /** + * Gets priority relative to current policy + */ + virtual PRIORITY getPriority() const; + + /** + * Sets priority relative to current policy + */ + virtual void setPriority(PRIORITY priority); + + /** + * Sets detached mode of threads + */ + virtual void setDetached(bool detached); + + /** + * Gets current detached mode + */ + virtual bool isDetached() const; + + private: + class Impl; + boost::shared_ptr impl_; +}; + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_POSIXTHREADFACTORY_H_ diff --git a/lib/cpp/src/concurrency/Thread.h b/lib/cpp/src/concurrency/Thread.h new file mode 100644 index 00000000..d4282adb --- /dev/null +++ b/lib/cpp/src/concurrency/Thread.h @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_THREAD_H_ +#define _THRIFT_CONCURRENCY_THREAD_H_ 1 + +#include +#include +#include + +namespace apache { namespace thrift { namespace concurrency { + +class Thread; + +/** + * Minimal runnable class. More or less analogous to java.lang.Runnable. + * + * @version $Id:$ + */ +class Runnable { + + public: + virtual ~Runnable() {}; + virtual void run() = 0; + + /** + * Gets the thread object that is hosting this runnable object - can return + * an empty boost::shared pointer if no references remain on thet thread object + */ + virtual boost::shared_ptr thread() { return thread_.lock(); } + + /** + * Sets the thread that is executing this object. This is only meant for + * use by concrete implementations of Thread. + */ + virtual void thread(boost::shared_ptr value) { thread_ = value; } + + private: + boost::weak_ptr thread_; +}; + +/** + * Minimal thread class. Returned by thread factory bound to a Runnable object + * and ready to start execution. More or less analogous to java.lang.Thread + * (minus all the thread group, priority, mode and other baggage, since that + * is difficult to abstract across platforms and is left for platform-specific + * ThreadFactory implemtations to deal with + * + * @see apache::thrift::concurrency::ThreadFactory) + */ +class Thread { + + public: + + typedef uint64_t id_t; + + virtual ~Thread() {}; + + /** + * Starts the thread. Does platform specific thread creation and + * configuration then invokes the run method of the Runnable object bound + * to this thread. + */ + virtual void start() = 0; + + /** + * Join this thread. Current thread blocks until this target thread + * completes. + */ + virtual void join() = 0; + + /** + * Gets the thread's platform-specific ID + */ + virtual id_t getId() = 0; + + /** + * Gets the runnable object this thread is hosting + */ + virtual boost::shared_ptr runnable() const { return _runnable; } + + protected: + virtual void runnable(boost::shared_ptr value) { _runnable = value; } + + private: + boost::shared_ptr _runnable; + +}; + +/** + * Factory to create platform-specific thread object and bind them to Runnable + * object for execution + */ +class ThreadFactory { + + public: + virtual ~ThreadFactory() {} + virtual boost::shared_ptr newThread(boost::shared_ptr runnable) const = 0; + + /** Gets the current thread id or unknown_thread_id if the current thread is not a thrift thread */ + + static const Thread::id_t unknown_thread_id; + + virtual Thread::id_t getCurrentThreadId() const = 0; +}; + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_THREAD_H_ diff --git a/lib/cpp/src/concurrency/ThreadManager.cpp b/lib/cpp/src/concurrency/ThreadManager.cpp new file mode 100644 index 00000000..abfcf6e7 --- /dev/null +++ b/lib/cpp/src/concurrency/ThreadManager.cpp @@ -0,0 +1,493 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "ThreadManager.h" +#include "Exception.h" +#include "Monitor.h" + +#include + +#include +#include +#include + +#if defined(DEBUG) +#include +#endif //defined(DEBUG) + +namespace apache { namespace thrift { namespace concurrency { + +using boost::shared_ptr; +using boost::dynamic_pointer_cast; + +/** + * ThreadManager class + * + * This class manages a pool of threads. It uses a ThreadFactory to create + * threads. It never actually creates or destroys worker threads, rather + * it maintains statistics on number of idle threads, number of active threads, + * task backlog, and average wait and service times. + * + * @version $Id:$ + */ +class ThreadManager::Impl : public ThreadManager { + + public: + Impl() : + workerCount_(0), + workerMaxCount_(0), + idleCount_(0), + pendingTaskCountMax_(0), + state_(ThreadManager::UNINITIALIZED) {} + + ~Impl() { stop(); } + + void start(); + + void stop() { stopImpl(false); } + + void join() { stopImpl(true); } + + const ThreadManager::STATE state() const { + return state_; + } + + shared_ptr threadFactory() const { + Synchronized s(monitor_); + return threadFactory_; + } + + void threadFactory(shared_ptr value) { + Synchronized s(monitor_); + threadFactory_ = value; + } + + void addWorker(size_t value); + + void removeWorker(size_t value); + + size_t idleWorkerCount() const { + return idleCount_; + } + + size_t workerCount() const { + Synchronized s(monitor_); + return workerCount_; + } + + size_t pendingTaskCount() const { + Synchronized s(monitor_); + return tasks_.size(); + } + + size_t totalTaskCount() const { + Synchronized s(monitor_); + return tasks_.size() + workerCount_ - idleCount_; + } + + size_t pendingTaskCountMax() const { + Synchronized s(monitor_); + return pendingTaskCountMax_; + } + + void pendingTaskCountMax(const size_t value) { + Synchronized s(monitor_); + pendingTaskCountMax_ = value; + } + + bool canSleep(); + + void add(shared_ptr value, int64_t timeout); + + void remove(shared_ptr task); + +private: + void stopImpl(bool join); + + size_t workerCount_; + size_t workerMaxCount_; + size_t idleCount_; + size_t pendingTaskCountMax_; + + ThreadManager::STATE state_; + shared_ptr threadFactory_; + + + friend class ThreadManager::Task; + std::queue > tasks_; + Monitor monitor_; + Monitor workerMonitor_; + + friend class ThreadManager::Worker; + std::set > workers_; + std::set > deadWorkers_; + std::map > idMap_; +}; + +class ThreadManager::Task : public Runnable { + + public: + enum STATE { + WAITING, + EXECUTING, + CANCELLED, + COMPLETE + }; + + Task(shared_ptr runnable) : + runnable_(runnable), + state_(WAITING) {} + + ~Task() {} + + void run() { + if (state_ == EXECUTING) { + runnable_->run(); + state_ = COMPLETE; + } + } + + private: + shared_ptr runnable_; + friend class ThreadManager::Worker; + STATE state_; +}; + +class ThreadManager::Worker: public Runnable { + enum STATE { + UNINITIALIZED, + STARTING, + STARTED, + STOPPING, + STOPPED + }; + + public: + Worker(ThreadManager::Impl* manager) : + manager_(manager), + state_(UNINITIALIZED), + idle_(false) {} + + ~Worker() {} + + private: + bool isActive() const { + return + (manager_->workerCount_ <= manager_->workerMaxCount_) || + (manager_->state_ == JOINING && !manager_->tasks_.empty()); + } + + public: + /** + * Worker entry point + * + * As long as worker thread is running, pull tasks off the task queue and + * execute. + */ + void run() { + bool active = false; + bool notifyManager = false; + + /** + * Increment worker semaphore and notify manager if worker count reached + * desired max + * + * Note: We have to release the monitor and acquire the workerMonitor + * since that is what the manager blocks on for worker add/remove + */ + { + Synchronized s(manager_->monitor_); + active = manager_->workerCount_ < manager_->workerMaxCount_; + if (active) { + manager_->workerCount_++; + notifyManager = manager_->workerCount_ == manager_->workerMaxCount_; + } + } + + if (notifyManager) { + Synchronized s(manager_->workerMonitor_); + manager_->workerMonitor_.notify(); + notifyManager = false; + } + + while (active) { + shared_ptr task; + + /** + * While holding manager monitor block for non-empty task queue (Also + * check that the thread hasn't been requested to stop). Once the queue + * is non-empty, dequeue a task, release monitor, and execute. If the + * worker max count has been decremented such that we exceed it, mark + * ourself inactive, decrement the worker count and notify the manager + * (technically we're notifying the next blocked thread but eventually + * the manager will see it. + */ + { + Synchronized s(manager_->monitor_); + active = isActive(); + + while (active && manager_->tasks_.empty()) { + manager_->idleCount_++; + idle_ = true; + manager_->monitor_.wait(); + active = isActive(); + idle_ = false; + manager_->idleCount_--; + } + + if (active) { + if (!manager_->tasks_.empty()) { + task = manager_->tasks_.front(); + manager_->tasks_.pop(); + if (task->state_ == ThreadManager::Task::WAITING) { + task->state_ = ThreadManager::Task::EXECUTING; + } + + /* If we have a pending task max and we just dropped below it, wakeup any + thread that might be blocked on add. */ + if (manager_->pendingTaskCountMax_ != 0 && + manager_->tasks_.size() == manager_->pendingTaskCountMax_ - 1) { + manager_->monitor_.notify(); + } + } + } else { + idle_ = true; + manager_->workerCount_--; + notifyManager = (manager_->workerCount_ == manager_->workerMaxCount_); + } + } + + if (task != NULL) { + if (task->state_ == ThreadManager::Task::EXECUTING) { + try { + task->run(); + } catch(...) { + // XXX need to log this + } + } + } + } + + { + Synchronized s(manager_->workerMonitor_); + manager_->deadWorkers_.insert(this->thread()); + if (notifyManager) { + manager_->workerMonitor_.notify(); + } + } + + return; + } + + private: + ThreadManager::Impl* manager_; + friend class ThreadManager::Impl; + STATE state_; + bool idle_; +}; + + + void ThreadManager::Impl::addWorker(size_t value) { + std::set > newThreads; + for (size_t ix = 0; ix < value; ix++) { + class ThreadManager::Worker; + shared_ptr worker = shared_ptr(new ThreadManager::Worker(this)); + newThreads.insert(threadFactory_->newThread(worker)); + } + + { + Synchronized s(monitor_); + workerMaxCount_ += value; + workers_.insert(newThreads.begin(), newThreads.end()); + } + + for (std::set >::iterator ix = newThreads.begin(); ix != newThreads.end(); ix++) { + shared_ptr worker = dynamic_pointer_cast((*ix)->runnable()); + worker->state_ = ThreadManager::Worker::STARTING; + (*ix)->start(); + idMap_.insert(std::pair >((*ix)->getId(), *ix)); + } + + { + Synchronized s(workerMonitor_); + while (workerCount_ != workerMaxCount_) { + workerMonitor_.wait(); + } + } +} + +void ThreadManager::Impl::start() { + + if (state_ == ThreadManager::STOPPED) { + return; + } + + { + Synchronized s(monitor_); + if (state_ == ThreadManager::UNINITIALIZED) { + if (threadFactory_ == NULL) { + throw InvalidArgumentException(); + } + state_ = ThreadManager::STARTED; + monitor_.notifyAll(); + } + + while (state_ == STARTING) { + monitor_.wait(); + } + } +} + +void ThreadManager::Impl::stopImpl(bool join) { + bool doStop = false; + if (state_ == ThreadManager::STOPPED) { + return; + } + + { + Synchronized s(monitor_); + if (state_ != ThreadManager::STOPPING && + state_ != ThreadManager::JOINING && + state_ != ThreadManager::STOPPED) { + doStop = true; + state_ = join ? ThreadManager::JOINING : ThreadManager::STOPPING; + } + } + + if (doStop) { + removeWorker(workerCount_); + } + + // XXX + // should be able to block here for transition to STOPPED since we're no + // using shared_ptrs + + { + Synchronized s(monitor_); + state_ = ThreadManager::STOPPED; + } + +} + +void ThreadManager::Impl::removeWorker(size_t value) { + std::set > removedThreads; + { + Synchronized s(monitor_); + if (value > workerMaxCount_) { + throw InvalidArgumentException(); + } + + workerMaxCount_ -= value; + + if (idleCount_ < value) { + for (size_t ix = 0; ix < idleCount_; ix++) { + monitor_.notify(); + } + } else { + monitor_.notifyAll(); + } + } + + { + Synchronized s(workerMonitor_); + + while (workerCount_ != workerMaxCount_) { + workerMonitor_.wait(); + } + + for (std::set >::iterator ix = deadWorkers_.begin(); ix != deadWorkers_.end(); ix++) { + workers_.erase(*ix); + idMap_.erase((*ix)->getId()); + } + + deadWorkers_.clear(); + } +} + + bool ThreadManager::Impl::canSleep() { + const Thread::id_t id = threadFactory_->getCurrentThreadId(); + return idMap_.find(id) == idMap_.end(); + } + + void ThreadManager::Impl::add(shared_ptr value, int64_t timeout) { + Synchronized s(monitor_); + + if (state_ != ThreadManager::STARTED) { + throw IllegalStateException(); + } + + if (pendingTaskCountMax_ > 0 && (tasks_.size() >= pendingTaskCountMax_)) { + if (canSleep() && timeout >= 0) { + while (pendingTaskCountMax_ > 0 && tasks_.size() >= pendingTaskCountMax_) { + monitor_.wait(timeout); + } + } else { + throw TooManyPendingTasksException(); + } + } + + tasks_.push(shared_ptr(new ThreadManager::Task(value))); + + // If idle thread is available notify it, otherwise all worker threads are + // running and will get around to this task in time. + if (idleCount_ > 0) { + monitor_.notify(); + } + } + +void ThreadManager::Impl::remove(shared_ptr task) { + Synchronized s(monitor_); + if (state_ != ThreadManager::STARTED) { + throw IllegalStateException(); + } +} + +class SimpleThreadManager : public ThreadManager::Impl { + + public: + SimpleThreadManager(size_t workerCount=4, size_t pendingTaskCountMax=0) : + workerCount_(workerCount), + pendingTaskCountMax_(pendingTaskCountMax), + firstTime_(true) { + } + + void start() { + ThreadManager::Impl::pendingTaskCountMax(pendingTaskCountMax_); + ThreadManager::Impl::start(); + addWorker(workerCount_); + } + + private: + const size_t workerCount_; + const size_t pendingTaskCountMax_; + bool firstTime_; + Monitor monitor_; +}; + + +shared_ptr ThreadManager::newThreadManager() { + return shared_ptr(new ThreadManager::Impl()); +} + +shared_ptr ThreadManager::newSimpleThreadManager(size_t count, size_t pendingTaskCountMax) { + return shared_ptr(new SimpleThreadManager(count, pendingTaskCountMax)); +} + +}}} // apache::thrift::concurrency + diff --git a/lib/cpp/src/concurrency/ThreadManager.h b/lib/cpp/src/concurrency/ThreadManager.h new file mode 100644 index 00000000..6e5a1781 --- /dev/null +++ b/lib/cpp/src/concurrency/ThreadManager.h @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_THREADMANAGER_H_ +#define _THRIFT_CONCURRENCY_THREADMANAGER_H_ 1 + +#include +#include +#include "Thread.h" + +namespace apache { namespace thrift { namespace concurrency { + +/** + * Thread Pool Manager and related classes + * + * @version $Id:$ + */ +class ThreadManager; + +/** + * ThreadManager class + * + * This class manages a pool of threads. It uses a ThreadFactory to create + * threads. It never actually creates or destroys worker threads, rather + * It maintains statistics on number of idle threads, number of active threads, + * task backlog, and average wait and service times and informs the PoolPolicy + * object bound to instances of this manager of interesting transitions. It is + * then up the PoolPolicy object to decide if the thread pool size needs to be + * adjusted and call this object addWorker and removeWorker methods to make + * changes. + * + * This design allows different policy implementations to used this code to + * handle basic worker thread management and worker task execution and focus on + * policy issues. The simplest policy, StaticPolicy, does nothing other than + * create a fixed number of threads. + */ +class ThreadManager { + + protected: + ThreadManager() {} + + public: + virtual ~ThreadManager() {} + + /** + * Starts the thread manager. Verifies all attributes have been properly + * initialized, then allocates necessary resources to begin operation + */ + virtual void start() = 0; + + /** + * Stops the thread manager. Aborts all remaining unprocessed task, shuts + * down all created worker threads, and realeases all allocated resources. + * This method blocks for all worker threads to complete, thus it can + * potentially block forever if a worker thread is running a task that + * won't terminate. + */ + virtual void stop() = 0; + + /** + * Joins the thread manager. This is the same as stop, except that it will + * block until all the workers have finished their work. At that point + * the ThreadManager will transition into the STOPPED state. + */ + virtual void join() = 0; + + enum STATE { + UNINITIALIZED, + STARTING, + STARTED, + JOINING, + STOPPING, + STOPPED + }; + + virtual const STATE state() const = 0; + + virtual boost::shared_ptr threadFactory() const = 0; + + virtual void threadFactory(boost::shared_ptr value) = 0; + + virtual void addWorker(size_t value=1) = 0; + + virtual void removeWorker(size_t value=1) = 0; + + /** + * Gets the current number of idle worker threads + */ + virtual size_t idleWorkerCount() const = 0; + + /** + * Gets the current number of total worker threads + */ + virtual size_t workerCount() const = 0; + + /** + * Gets the current number of pending tasks + */ + virtual size_t pendingTaskCount() const = 0; + + /** + * Gets the current number of pending and executing tasks + */ + virtual size_t totalTaskCount() const = 0; + + /** + * Gets the maximum pending task count. 0 indicates no maximum + */ + virtual size_t pendingTaskCountMax() const = 0; + + /** + * Adds a task to be executed at some time in the future by a worker thread. + * + * This method will block if pendingTaskCountMax() in not zero and pendingTaskCount() + * is greater than or equalt to pendingTaskCountMax(). If this method is called in the + * context of a ThreadManager worker thread it will throw a + * TooManyPendingTasksException + * + * @param task The task to queue for execution + * + * @param timeout Time to wait in milliseconds to add a task when a pending-task-count + * is specified. Specific cases: + * timeout = 0 : Wait forever to queue task. + * timeout = -1 : Return immediately if pending task count exceeds specified max + * + * @throws TooManyPendingTasksException Pending task count exceeds max pending task count + */ + virtual void add(boost::shared_ptrtask, int64_t timeout=0LL) = 0; + + /** + * Removes a pending task + */ + virtual void remove(boost::shared_ptr task) = 0; + + static boost::shared_ptr newThreadManager(); + + /** + * Creates a simple thread manager the uses count number of worker threads and has + * a pendingTaskCountMax maximum pending tasks. The default, 0, specified no limit + * on pending tasks + */ + static boost::shared_ptr newSimpleThreadManager(size_t count=4, size_t pendingTaskCountMax=0); + + class Task; + + class Worker; + + class Impl; +}; + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_THREADMANAGER_H_ diff --git a/lib/cpp/src/concurrency/TimerManager.cpp b/lib/cpp/src/concurrency/TimerManager.cpp new file mode 100644 index 00000000..25515dc8 --- /dev/null +++ b/lib/cpp/src/concurrency/TimerManager.cpp @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "TimerManager.h" +#include "Exception.h" +#include "Util.h" + +#include +#include +#include + +namespace apache { namespace thrift { namespace concurrency { + +using boost::shared_ptr; + +typedef std::multimap >::iterator task_iterator; +typedef std::pair task_range; + +/** + * TimerManager class + * + * @version $Id:$ + */ +class TimerManager::Task : public Runnable { + + public: + enum STATE { + WAITING, + EXECUTING, + CANCELLED, + COMPLETE + }; + + Task(shared_ptr runnable) : + runnable_(runnable), + state_(WAITING) {} + + ~Task() { + } + + void run() { + if (state_ == EXECUTING) { + runnable_->run(); + state_ = COMPLETE; + } + } + + private: + shared_ptr runnable_; + class TimerManager::Dispatcher; + friend class TimerManager::Dispatcher; + STATE state_; +}; + +class TimerManager::Dispatcher: public Runnable { + + public: + Dispatcher(TimerManager* manager) : + manager_(manager) {} + + ~Dispatcher() {} + + /** + * Dispatcher entry point + * + * As long as dispatcher thread is running, pull tasks off the task taskMap_ + * and execute. + */ + void run() { + { + Synchronized s(manager_->monitor_); + if (manager_->state_ == TimerManager::STARTING) { + manager_->state_ = TimerManager::STARTED; + manager_->monitor_.notifyAll(); + } + } + + do { + std::set > expiredTasks; + { + Synchronized s(manager_->monitor_); + task_iterator expiredTaskEnd; + int64_t now = Util::currentTime(); + while (manager_->state_ == TimerManager::STARTED && + (expiredTaskEnd = manager_->taskMap_.upper_bound(now)) == manager_->taskMap_.begin()) { + int64_t timeout = 0LL; + if (!manager_->taskMap_.empty()) { + timeout = manager_->taskMap_.begin()->first - now; + } + assert((timeout != 0 && manager_->taskCount_ > 0) || (timeout == 0 && manager_->taskCount_ == 0)); + try { + manager_->monitor_.wait(timeout); + } catch (TimedOutException &e) {} + now = Util::currentTime(); + } + + if (manager_->state_ == TimerManager::STARTED) { + for (task_iterator ix = manager_->taskMap_.begin(); ix != expiredTaskEnd; ix++) { + shared_ptr task = ix->second; + expiredTasks.insert(task); + if (task->state_ == TimerManager::Task::WAITING) { + task->state_ = TimerManager::Task::EXECUTING; + } + manager_->taskCount_--; + } + manager_->taskMap_.erase(manager_->taskMap_.begin(), expiredTaskEnd); + } + } + + for (std::set >::iterator ix = expiredTasks.begin(); ix != expiredTasks.end(); ix++) { + (*ix)->run(); + } + + } while (manager_->state_ == TimerManager::STARTED); + + { + Synchronized s(manager_->monitor_); + if (manager_->state_ == TimerManager::STOPPING) { + manager_->state_ = TimerManager::STOPPED; + manager_->monitor_.notify(); + } + } + return; + } + + private: + TimerManager* manager_; + friend class TimerManager; +}; + +TimerManager::TimerManager() : + taskCount_(0), + state_(TimerManager::UNINITIALIZED), + dispatcher_(shared_ptr(new Dispatcher(this))) { +} + + +TimerManager::~TimerManager() { + + // If we haven't been explicitly stopped, do so now. We don't need to grab + // the monitor here, since stop already takes care of reentrancy. + + if (state_ != STOPPED) { + try { + stop(); + } catch(...) { + throw; + // uhoh + } + } +} + +void TimerManager::start() { + bool doStart = false; + { + Synchronized s(monitor_); + if (threadFactory_ == NULL) { + throw InvalidArgumentException(); + } + if (state_ == TimerManager::UNINITIALIZED) { + state_ = TimerManager::STARTING; + doStart = true; + } + } + + if (doStart) { + dispatcherThread_ = threadFactory_->newThread(dispatcher_); + dispatcherThread_->start(); + } + + { + Synchronized s(monitor_); + while (state_ == TimerManager::STARTING) { + monitor_.wait(); + } + assert(state_ != TimerManager::STARTING); + } +} + +void TimerManager::stop() { + bool doStop = false; + { + Synchronized s(monitor_); + if (state_ == TimerManager::UNINITIALIZED) { + state_ = TimerManager::STOPPED; + } else if (state_ != STOPPING && state_ != STOPPED) { + doStop = true; + state_ = STOPPING; + monitor_.notifyAll(); + } + while (state_ != STOPPED) { + monitor_.wait(); + } + } + + if (doStop) { + // Clean up any outstanding tasks + for (task_iterator ix = taskMap_.begin(); ix != taskMap_.end(); ix++) { + taskMap_.erase(ix); + } + + // Remove dispatcher's reference to us. + dispatcher_->manager_ = NULL; + } +} + +shared_ptr TimerManager::threadFactory() const { + Synchronized s(monitor_); + return threadFactory_; +} + +void TimerManager::threadFactory(shared_ptr value) { + Synchronized s(monitor_); + threadFactory_ = value; +} + +size_t TimerManager::taskCount() const { + return taskCount_; +} + +void TimerManager::add(shared_ptr task, int64_t timeout) { + int64_t now = Util::currentTime(); + timeout += now; + + { + Synchronized s(monitor_); + if (state_ != TimerManager::STARTED) { + throw IllegalStateException(); + } + + taskCount_++; + taskMap_.insert(std::pair >(timeout, shared_ptr(new Task(task)))); + + // If the task map was empty, or if we have an expiration that is earlier + // than any previously seen, kick the dispatcher so it can update its + // timeout + if (taskCount_ == 1 || timeout < taskMap_.begin()->first) { + monitor_.notify(); + } + } +} + +void TimerManager::add(shared_ptr task, const struct timespec& value) { + + int64_t expiration; + Util::toMilliseconds(expiration, value); + + int64_t now = Util::currentTime(); + + if (expiration < now) { + throw InvalidArgumentException(); + } + + add(task, expiration - now); +} + + +void TimerManager::remove(shared_ptr task) { + Synchronized s(monitor_); + if (state_ != TimerManager::STARTED) { + throw IllegalStateException(); + } +} + +const TimerManager::STATE TimerManager::state() const { return state_; } + +}}} // apache::thrift::concurrency + diff --git a/lib/cpp/src/concurrency/TimerManager.h b/lib/cpp/src/concurrency/TimerManager.h new file mode 100644 index 00000000..f3f799f9 --- /dev/null +++ b/lib/cpp/src/concurrency/TimerManager.h @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_TIMERMANAGER_H_ +#define _THRIFT_CONCURRENCY_TIMERMANAGER_H_ 1 + +#include "Exception.h" +#include "Monitor.h" +#include "Thread.h" + +#include +#include +#include + +namespace apache { namespace thrift { namespace concurrency { + +/** + * Timer Manager + * + * This class dispatches timer tasks when they fall due. + * + * @version $Id:$ + */ +class TimerManager { + + public: + + TimerManager(); + + virtual ~TimerManager(); + + virtual boost::shared_ptr threadFactory() const; + + virtual void threadFactory(boost::shared_ptr value); + + /** + * Starts the timer manager service + * + * @throws IllegalArgumentException Missing thread factory attribute + */ + virtual void start(); + + /** + * Stops the timer manager service + */ + virtual void stop(); + + virtual size_t taskCount() const ; + + /** + * Adds a task to be executed at some time in the future by a worker thread. + * + * @param task The task to execute + * @param timeout Time in milliseconds to delay before executing task + */ + virtual void add(boost::shared_ptr task, int64_t timeout); + + /** + * Adds a task to be executed at some time in the future by a worker thread. + * + * @param task The task to execute + * @param timeout Absolute time in the future to execute task. + */ + virtual void add(boost::shared_ptr task, const struct timespec& timeout); + + /** + * Removes a pending task + * + * @throws NoSuchTaskException Specified task doesn't exist. It was either + * processed already or this call was made for a + * task that was never added to this timer + * + * @throws UncancellableTaskException Specified task is already being + * executed or has completed execution. + */ + virtual void remove(boost::shared_ptr task); + + enum STATE { + UNINITIALIZED, + STARTING, + STARTED, + STOPPING, + STOPPED + }; + + virtual const STATE state() const; + + private: + boost::shared_ptr threadFactory_; + class Task; + friend class Task; + std::multimap > taskMap_; + size_t taskCount_; + Monitor monitor_; + STATE state_; + class Dispatcher; + friend class Dispatcher; + boost::shared_ptr dispatcher_; + boost::shared_ptr dispatcherThread_; +}; + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_TIMERMANAGER_H_ diff --git a/lib/cpp/src/concurrency/Util.cpp b/lib/cpp/src/concurrency/Util.cpp new file mode 100644 index 00000000..1c449371 --- /dev/null +++ b/lib/cpp/src/concurrency/Util.cpp @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "Util.h" + +#ifdef HAVE_CONFIG_H +#include +#endif + +#if defined(HAVE_CLOCK_GETTIME) +#include +#elif defined(HAVE_GETTIMEOFDAY) +#include +#endif // defined(HAVE_CLOCK_GETTIME) + +namespace apache { namespace thrift { namespace concurrency { + +const int64_t Util::currentTime() { + int64_t result; + +#if defined(HAVE_CLOCK_GETTIME) + struct timespec now; + int ret = clock_gettime(CLOCK_REALTIME, &now); + assert(ret == 0); + toMilliseconds(result, now); +#elif defined(HAVE_GETTIMEOFDAY) + struct timeval now; + int ret = gettimeofday(&now, NULL); + assert(ret == 0); + toMilliseconds(result, now); +#else +#error "No high-precision clock is available." +#endif // defined(HAVE_CLOCK_GETTIME) + + return result; +} + + +}}} // apache::thrift::concurrency diff --git a/lib/cpp/src/concurrency/Util.h b/lib/cpp/src/concurrency/Util.h new file mode 100644 index 00000000..25fcc208 --- /dev/null +++ b/lib/cpp/src/concurrency/Util.h @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_CONCURRENCY_UTIL_H_ +#define _THRIFT_CONCURRENCY_UTIL_H_ 1 + +#include +#include +#include +#include +#include + +namespace apache { namespace thrift { namespace concurrency { + +/** + * Utility methods + * + * This class contains basic utility methods for converting time formats, + * and other common platform-dependent concurrency operations. + * It should not be included in API headers for other concurrency library + * headers, since it will, by definition, pull in all sorts of horrid + * platform dependent crap. Rather it should be inluded directly in + * concurrency library implementation source. + * + * @version $Id:$ + */ +class Util { + + static const int64_t NS_PER_S = 1000000000LL; + static const int64_t US_PER_S = 1000000LL; + static const int64_t MS_PER_S = 1000LL; + + static const int64_t NS_PER_MS = NS_PER_S / MS_PER_S; + static const int64_t US_PER_MS = US_PER_S / MS_PER_S; + + public: + + /** + * Converts millisecond timestamp into a timespec struct + * + * @param struct timespec& result + * @param time or duration in milliseconds + */ + static void toTimespec(struct timespec& result, int64_t value) { + result.tv_sec = value / MS_PER_S; // ms to s + result.tv_nsec = (value % MS_PER_S) * NS_PER_MS; // ms to ns + } + + static void toTimeval(struct timeval& result, int64_t value) { + result.tv_sec = value / MS_PER_S; // ms to s + result.tv_usec = (value % MS_PER_S) * US_PER_MS; // ms to us + } + + /** + * Converts struct timespec to milliseconds + */ + static const void toMilliseconds(int64_t& result, const struct timespec& value) { + result = (value.tv_sec * MS_PER_S) + (value.tv_nsec / NS_PER_MS); + // round up -- int64_t cast is to avoid a compiler error for some GCCs + if (int64_t(value.tv_nsec) % NS_PER_MS >= (NS_PER_MS / 2)) { + ++result; + } + } + + /** + * Converts struct timeval to milliseconds + */ + static const void toMilliseconds(int64_t& result, const struct timeval& value) { + result = (value.tv_sec * MS_PER_S) + (value.tv_usec / US_PER_MS); + // round up -- int64_t cast is to avoid a compiler error for some GCCs + if (int64_t(value.tv_usec) % US_PER_MS >= (US_PER_MS / 2)) { + ++result; + } + } + + /** + * Get current time as milliseconds from epoch + */ + static const int64_t currentTime(); +}; + +}}} // apache::thrift::concurrency + +#endif // #ifndef _THRIFT_CONCURRENCY_UTIL_H_ diff --git a/lib/cpp/src/concurrency/test/Tests.cpp b/lib/cpp/src/concurrency/test/Tests.cpp new file mode 100644 index 00000000..c80bb883 --- /dev/null +++ b/lib/cpp/src/concurrency/test/Tests.cpp @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +#include "ThreadFactoryTests.h" +#include "TimerManagerTests.h" +#include "ThreadManagerTests.h" + +int main(int argc, char** argv) { + + std::string arg; + + std::vector args(argc - 1 > 1 ? argc - 1 : 1); + + args[0] = "all"; + + for (int ix = 1; ix < argc; ix++) { + args[ix - 1] = std::string(argv[ix]); + } + + bool runAll = args[0].compare("all") == 0; + + if (runAll || args[0].compare("thread-factory") == 0) { + + ThreadFactoryTests threadFactoryTests; + + std::cout << "ThreadFactory tests..." << std::endl; + + size_t count = 1000; + size_t floodLoops = 1; + size_t floodCount = 100000; + + std::cout << "\t\tThreadFactory reap N threads test: N = " << count << std::endl; + + assert(threadFactoryTests.reapNThreads(count)); + + std::cout << "\t\tThreadFactory floodN threads test: N = " << floodCount << std::endl; + + assert(threadFactoryTests.floodNTest(floodLoops, floodCount)); + + std::cout << "\t\tThreadFactory synchronous start test" << std::endl; + + assert(threadFactoryTests.synchStartTest()); + + std::cout << "\t\tThreadFactory monitor timeout test" << std::endl; + + assert(threadFactoryTests.monitorTimeoutTest()); + } + + if (runAll || args[0].compare("util") == 0) { + + std::cout << "Util tests..." << std::endl; + + std::cout << "\t\tUtil minimum time" << std::endl; + + int64_t time00 = Util::currentTime(); + int64_t time01 = Util::currentTime(); + + std::cout << "\t\t\tMinimum time: " << time01 - time00 << "ms" << std::endl; + + time00 = Util::currentTime(); + time01 = time00; + size_t count = 0; + + while (time01 < time00 + 10) { + count++; + time01 = Util::currentTime(); + } + + std::cout << "\t\t\tscall per ms: " << count / (time01 - time00) << std::endl; + } + + + if (runAll || args[0].compare("timer-manager") == 0) { + + std::cout << "TimerManager tests..." << std::endl; + + std::cout << "\t\tTimerManager test00" << std::endl; + + TimerManagerTests timerManagerTests; + + assert(timerManagerTests.test00()); + } + + if (runAll || args[0].compare("thread-manager") == 0) { + + std::cout << "ThreadManager tests..." << std::endl; + + { + + size_t workerCount = 100; + + size_t taskCount = 100000; + + int64_t delay = 10LL; + + std::cout << "\t\tThreadManager load test: worker count: " << workerCount << " task count: " << taskCount << " delay: " << delay << std::endl; + + ThreadManagerTests threadManagerTests; + + assert(threadManagerTests.loadTest(taskCount, delay, workerCount)); + + std::cout << "\t\tThreadManager block test: worker count: " << workerCount << " delay: " << delay << std::endl; + + assert(threadManagerTests.blockTest(delay, workerCount)); + + } + } + + if (runAll || args[0].compare("thread-manager-benchmark") == 0) { + + std::cout << "ThreadManager benchmark tests..." << std::endl; + + { + + size_t minWorkerCount = 2; + + size_t maxWorkerCount = 512; + + size_t tasksPerWorker = 1000; + + int64_t delay = 10LL; + + for (size_t workerCount = minWorkerCount; workerCount < maxWorkerCount; workerCount*= 2) { + + size_t taskCount = workerCount * tasksPerWorker; + + std::cout << "\t\tThreadManager load test: worker count: " << workerCount << " task count: " << taskCount << " delay: " << delay << std::endl; + + ThreadManagerTests threadManagerTests; + + threadManagerTests.loadTest(taskCount, delay, workerCount); + } + } + } +} diff --git a/lib/cpp/src/concurrency/test/ThreadFactoryTests.h b/lib/cpp/src/concurrency/test/ThreadFactoryTests.h new file mode 100644 index 00000000..859fbaf5 --- /dev/null +++ b/lib/cpp/src/concurrency/test/ThreadFactoryTests.h @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace apache { namespace thrift { namespace concurrency { namespace test { + +using boost::shared_ptr; +using namespace apache::thrift::concurrency; + +/** + * ThreadManagerTests class + * + * @version $Id:$ + */ +class ThreadFactoryTests { + +public: + + static const double ERROR; + + class Task: public Runnable { + + public: + + Task() {} + + void run() { + std::cout << "\t\t\tHello World" << std::endl; + } + }; + + /** + * Hello world test + */ + bool helloWorldTest() { + + PosixThreadFactory threadFactory = PosixThreadFactory(); + + shared_ptr task = shared_ptr(new ThreadFactoryTests::Task()); + + shared_ptr thread = threadFactory.newThread(task); + + thread->start(); + + thread->join(); + + std::cout << "\t\t\tSuccess!" << std::endl; + + return true; + } + + /** + * Reap N threads + */ + class ReapNTask: public Runnable { + + public: + + ReapNTask(Monitor& monitor, int& activeCount) : + _monitor(monitor), + _count(activeCount) {} + + void run() { + Synchronized s(_monitor); + + _count--; + + //std::cout << "\t\t\tthread count: " << _count << std::endl; + + if (_count == 0) { + _monitor.notify(); + } + } + + Monitor& _monitor; + + int& _count; + }; + + bool reapNThreads(int loop=1, int count=10) { + + PosixThreadFactory threadFactory = PosixThreadFactory(); + + Monitor* monitor = new Monitor(); + + for(int lix = 0; lix < loop; lix++) { + + int* activeCount = new int(count); + + std::set > threads; + + int tix; + + for (tix = 0; tix < count; tix++) { + try { + threads.insert(threadFactory.newThread(shared_ptr(new ReapNTask(*monitor, *activeCount)))); + } catch(SystemResourceException& e) { + std::cout << "\t\t\tfailed to create " << lix * count + tix << " thread " << e.what() << std::endl; + throw e; + } + } + + tix = 0; + for (std::set >::const_iterator thread = threads.begin(); thread != threads.end(); tix++, ++thread) { + + try { + (*thread)->start(); + } catch(SystemResourceException& e) { + std::cout << "\t\t\tfailed to start " << lix * count + tix << " thread " << e.what() << std::endl; + throw e; + } + } + + { + Synchronized s(*monitor); + while (*activeCount > 0) { + monitor->wait(1000); + } + } + + for (std::set >::const_iterator thread = threads.begin(); thread != threads.end(); thread++) { + threads.erase(*thread); + } + + std::cout << "\t\t\treaped " << lix * count << " threads" << std::endl; + } + + std::cout << "\t\t\tSuccess!" << std::endl; + + return true; + } + + class SynchStartTask: public Runnable { + + public: + + enum STATE { + UNINITIALIZED, + STARTING, + STARTED, + STOPPING, + STOPPED + }; + + SynchStartTask(Monitor& monitor, volatile STATE& state) : + _monitor(monitor), + _state(state) {} + + void run() { + { + Synchronized s(_monitor); + if (_state == SynchStartTask::STARTING) { + _state = SynchStartTask::STARTED; + _monitor.notify(); + } + } + + { + Synchronized s(_monitor); + while (_state == SynchStartTask::STARTED) { + _monitor.wait(); + } + + if (_state == SynchStartTask::STOPPING) { + _state = SynchStartTask::STOPPED; + _monitor.notifyAll(); + } + } + } + + private: + Monitor& _monitor; + volatile STATE& _state; + }; + + bool synchStartTest() { + + Monitor monitor; + + SynchStartTask::STATE state = SynchStartTask::UNINITIALIZED; + + shared_ptr task = shared_ptr(new SynchStartTask(monitor, state)); + + PosixThreadFactory threadFactory = PosixThreadFactory(); + + shared_ptr thread = threadFactory.newThread(task); + + if (state == SynchStartTask::UNINITIALIZED) { + + state = SynchStartTask::STARTING; + + thread->start(); + } + + { + Synchronized s(monitor); + while (state == SynchStartTask::STARTING) { + monitor.wait(); + } + } + + assert(state != SynchStartTask::STARTING); + + { + Synchronized s(monitor); + + try { + monitor.wait(100); + } catch(TimedOutException& e) { + } + + if (state == SynchStartTask::STARTED) { + + state = SynchStartTask::STOPPING; + + monitor.notify(); + } + + while (state == SynchStartTask::STOPPING) { + monitor.wait(); + } + } + + assert(state == SynchStartTask::STOPPED); + + bool success = true; + + std::cout << "\t\t\t" << (success ? "Success" : "Failure") << "!" << std::endl; + + return true; + } + + /** See how accurate monitor timeout is. */ + + bool monitorTimeoutTest(size_t count=1000, int64_t timeout=10) { + + Monitor monitor; + + int64_t startTime = Util::currentTime(); + + for (size_t ix = 0; ix < count; ix++) { + { + Synchronized s(monitor); + try { + monitor.wait(timeout); + } catch(TimedOutException& e) { + } + } + } + + int64_t endTime = Util::currentTime(); + + double error = ((endTime - startTime) - (count * timeout)) / (double)(count * timeout); + + if (error < 0.0) { + + error *= 1.0; + } + + bool success = error < ThreadFactoryTests::ERROR; + + std::cout << "\t\t\t" << (success ? "Success" : "Failure") << "! expected time: " << count * timeout << "ms elapsed time: "<< endTime - startTime << "ms error%: " << error * 100.0 << std::endl; + + return success; + } + + + class FloodTask : public Runnable { + public: + + FloodTask(const size_t id) :_id(id) {} + ~FloodTask(){ + if(_id % 1000 == 0) { + std::cout << "\t\tthread " << _id << " done" << std::endl; + } + } + + void run(){ + if(_id % 1000 == 0) { + std::cout << "\t\tthread " << _id << " started" << std::endl; + } + + usleep(1); + } + const size_t _id; + }; + + void foo(PosixThreadFactory *tf) { + } + + bool floodNTest(size_t loop=1, size_t count=100000) { + + bool success = false; + + for(size_t lix = 0; lix < loop; lix++) { + + PosixThreadFactory threadFactory = PosixThreadFactory(); + threadFactory.setDetached(true); + + for(size_t tix = 0; tix < count; tix++) { + + try { + + shared_ptr task(new FloodTask(lix * count + tix )); + + shared_ptr thread = threadFactory.newThread(task); + + thread->start(); + + usleep(1); + + } catch (TException& e) { + + std::cout << "\t\t\tfailed to start " << lix * count + tix << " thread " << e.what() << std::endl; + + return success; + } + } + + std::cout << "\t\t\tflooded " << (lix + 1) * count << " threads" << std::endl; + + success = true; + } + + return success; + } +}; + +const double ThreadFactoryTests::ERROR = .20; + +}}}} // apache::thrift::concurrency::test + diff --git a/lib/cpp/src/concurrency/test/ThreadManagerTests.h b/lib/cpp/src/concurrency/test/ThreadManagerTests.h new file mode 100644 index 00000000..e7b51743 --- /dev/null +++ b/lib/cpp/src/concurrency/test/ThreadManagerTests.h @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace apache { namespace thrift { namespace concurrency { namespace test { + +using namespace apache::thrift::concurrency; + +/** + * ThreadManagerTests class + * + * @version $Id:$ + */ +class ThreadManagerTests { + +public: + + static const double ERROR; + + class Task: public Runnable { + + public: + + Task(Monitor& monitor, size_t& count, int64_t timeout) : + _monitor(monitor), + _count(count), + _timeout(timeout), + _done(false) {} + + void run() { + + _startTime = Util::currentTime(); + + { + Synchronized s(_sleep); + + try { + _sleep.wait(_timeout); + } catch(TimedOutException& e) { + ; + }catch(...) { + assert(0); + } + } + + _endTime = Util::currentTime(); + + _done = true; + + { + Synchronized s(_monitor); + + // std::cout << "Thread " << _count << " completed " << std::endl; + + _count--; + + if (_count == 0) { + + _monitor.notify(); + } + } + } + + Monitor& _monitor; + size_t& _count; + int64_t _timeout; + int64_t _startTime; + int64_t _endTime; + bool _done; + Monitor _sleep; + }; + + /** + * Dispatch count tasks, each of which blocks for timeout milliseconds then + * completes. Verify that all tasks completed and that thread manager cleans + * up properly on delete. + */ + bool loadTest(size_t count=100, int64_t timeout=100LL, size_t workerCount=4) { + + Monitor monitor; + + size_t activeCount = count; + + shared_ptr threadManager = ThreadManager::newSimpleThreadManager(workerCount); + + shared_ptr threadFactory = shared_ptr(new PosixThreadFactory()); + + threadFactory->setPriority(PosixThreadFactory::HIGHEST); + + threadManager->threadFactory(threadFactory); + + threadManager->start(); + + std::set > tasks; + + for (size_t ix = 0; ix < count; ix++) { + + tasks.insert(shared_ptr(new ThreadManagerTests::Task(monitor, activeCount, timeout))); + } + + int64_t time00 = Util::currentTime(); + + for (std::set >::iterator ix = tasks.begin(); ix != tasks.end(); ix++) { + + threadManager->add(*ix); + } + + { + Synchronized s(monitor); + + while(activeCount > 0) { + + monitor.wait(); + } + } + + int64_t time01 = Util::currentTime(); + + int64_t firstTime = 9223372036854775807LL; + int64_t lastTime = 0; + + double averageTime = 0; + int64_t minTime = 9223372036854775807LL; + int64_t maxTime = 0; + + for (std::set >::iterator ix = tasks.begin(); ix != tasks.end(); ix++) { + + shared_ptr task = *ix; + + int64_t delta = task->_endTime - task->_startTime; + + assert(delta > 0); + + if (task->_startTime < firstTime) { + firstTime = task->_startTime; + } + + if (task->_endTime > lastTime) { + lastTime = task->_endTime; + } + + if (delta < minTime) { + minTime = delta; + } + + if (delta > maxTime) { + maxTime = delta; + } + + averageTime+= delta; + } + + averageTime /= count; + + std::cout << "\t\t\tfirst start: " << firstTime << "ms Last end: " << lastTime << "ms min: " << minTime << "ms max: " << maxTime << "ms average: " << averageTime << "ms" << std::endl; + + double expectedTime = ((count + (workerCount - 1)) / workerCount) * timeout; + + double error = ((time01 - time00) - expectedTime) / expectedTime; + + if (error < 0) { + error*= -1.0; + } + + bool success = error < ERROR; + + std::cout << "\t\t\t" << (success ? "Success" : "Failure") << "! expected time: " << expectedTime << "ms elapsed time: "<< time01 - time00 << "ms error%: " << error * 100.0 << std::endl; + + return success; + } + + class BlockTask: public Runnable { + + public: + + BlockTask(Monitor& monitor, Monitor& bmonitor, size_t& count) : + _monitor(monitor), + _bmonitor(bmonitor), + _count(count) {} + + void run() { + { + Synchronized s(_bmonitor); + + _bmonitor.wait(); + + } + + { + Synchronized s(_monitor); + + _count--; + + if (_count == 0) { + + _monitor.notify(); + } + } + } + + Monitor& _monitor; + Monitor& _bmonitor; + size_t& _count; + }; + + /** + * Block test. Create pendingTaskCountMax tasks. Verify that we block adding the + * pendingTaskCountMax + 1th task. Verify that we unblock when a task completes */ + + bool blockTest(int64_t timeout=100LL, size_t workerCount=2) { + + bool success = false; + + try { + + Monitor bmonitor; + Monitor monitor; + + size_t pendingTaskMaxCount = workerCount; + + size_t activeCounts[] = {workerCount, pendingTaskMaxCount, 1}; + + shared_ptr threadManager = ThreadManager::newSimpleThreadManager(workerCount, pendingTaskMaxCount); + + shared_ptr threadFactory = shared_ptr(new PosixThreadFactory()); + + threadFactory->setPriority(PosixThreadFactory::HIGHEST); + + threadManager->threadFactory(threadFactory); + + threadManager->start(); + + std::set > tasks; + + for (size_t ix = 0; ix < workerCount; ix++) { + + tasks.insert(shared_ptr(new ThreadManagerTests::BlockTask(monitor, bmonitor,activeCounts[0]))); + } + + for (size_t ix = 0; ix < pendingTaskMaxCount; ix++) { + + tasks.insert(shared_ptr(new ThreadManagerTests::BlockTask(monitor, bmonitor,activeCounts[1]))); + } + + for (std::set >::iterator ix = tasks.begin(); ix != tasks.end(); ix++) { + threadManager->add(*ix); + } + + if(!(success = (threadManager->totalTaskCount() == pendingTaskMaxCount + workerCount))) { + throw TException("Unexpected pending task count"); + } + + shared_ptr extraTask(new ThreadManagerTests::BlockTask(monitor, bmonitor, activeCounts[2])); + + try { + threadManager->add(extraTask, 1); + throw TException("Unexpected success adding task in excess of pending task count"); + } catch(TimedOutException& e) { + } + + std::cout << "\t\t\t" << "Pending tasks " << threadManager->pendingTaskCount() << std::endl; + + { + Synchronized s(bmonitor); + + bmonitor.notifyAll(); + } + + { + Synchronized s(monitor); + + while(activeCounts[0] != 0) { + monitor.wait(); + } + } + + std::cout << "\t\t\t" << "Pending tasks " << threadManager->pendingTaskCount() << std::endl; + + try { + threadManager->add(extraTask, 1); + } catch(TimedOutException& e) { + std::cout << "\t\t\t" << "add timed out unexpectedly" << std::endl; + throw TException("Unexpected timeout adding task"); + + } catch(TooManyPendingTasksException& e) { + std::cout << "\t\t\t" << "add encountered too many pending exepctions" << std::endl; + throw TException("Unexpected timeout adding task"); + } + + // Wake up tasks that were pending before and wait for them to complete + + { + Synchronized s(bmonitor); + + bmonitor.notifyAll(); + } + + { + Synchronized s(monitor); + + while(activeCounts[1] != 0) { + monitor.wait(); + } + } + + // Wake up the extra task and wait for it to complete + + { + Synchronized s(bmonitor); + + bmonitor.notifyAll(); + } + + { + Synchronized s(monitor); + + while(activeCounts[2] != 0) { + monitor.wait(); + } + } + + if(!(success = (threadManager->totalTaskCount() == 0))) { + throw TException("Unexpected pending task count"); + } + + } catch(TException& e) { + } + + std::cout << "\t\t\t" << (success ? "Success" : "Failure") << std::endl; + return success; + } +}; + +const double ThreadManagerTests::ERROR = .20; + +}}}} // apache::thrift::concurrency + +using namespace apache::thrift::concurrency::test; + diff --git a/lib/cpp/src/concurrency/test/TimerManagerTests.h b/lib/cpp/src/concurrency/test/TimerManagerTests.h new file mode 100644 index 00000000..e6fe6ce7 --- /dev/null +++ b/lib/cpp/src/concurrency/test/TimerManagerTests.h @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include +#include + +namespace apache { namespace thrift { namespace concurrency { namespace test { + +using namespace apache::thrift::concurrency; + +/** + * ThreadManagerTests class + * + * @version $Id:$ + */ +class TimerManagerTests { + + public: + + static const double ERROR; + + class Task: public Runnable { + public: + + Task(Monitor& monitor, int64_t timeout) : + _timeout(timeout), + _startTime(Util::currentTime()), + _monitor(monitor), + _success(false), + _done(false) {} + + ~Task() { std::cerr << this << std::endl; } + + void run() { + + _endTime = Util::currentTime(); + + // Figure out error percentage + + int64_t delta = _endTime - _startTime; + + + delta = delta > _timeout ? delta - _timeout : _timeout - delta; + + float error = delta / _timeout; + + if(error < ERROR) { + _success = true; + } + + _done = true; + + std::cout << "\t\t\tTimerManagerTests::Task[" << this << "] done" << std::endl; //debug + + {Synchronized s(_monitor); + _monitor.notifyAll(); + } + } + + int64_t _timeout; + int64_t _startTime; + int64_t _endTime; + Monitor& _monitor; + bool _success; + bool _done; + }; + + /** + * This test creates two tasks and waits for the first to expire within 10% + * of the expected expiration time. It then verifies that the timer manager + * properly clean up itself and the remaining orphaned timeout task when the + * manager goes out of scope and its destructor is called. + */ + bool test00(int64_t timeout=1000LL) { + + shared_ptr orphanTask = shared_ptr(new TimerManagerTests::Task(_monitor, 10 * timeout)); + + { + + TimerManager timerManager; + + timerManager.threadFactory(shared_ptr(new PosixThreadFactory())); + + timerManager.start(); + + assert(timerManager.state() == TimerManager::STARTED); + + shared_ptr task = shared_ptr(new TimerManagerTests::Task(_monitor, timeout)); + + { + Synchronized s(_monitor); + + timerManager.add(orphanTask, 10 * timeout); + + timerManager.add(task, timeout); + + _monitor.wait(); + } + + assert(task->_done); + + + std::cout << "\t\t\t" << (task->_success ? "Success" : "Failure") << "!" << std::endl; + } + + // timerManager.stop(); This is where it happens via destructor + + assert(!orphanTask->_done); + + return true; + } + + friend class TestTask; + + Monitor _monitor; +}; + +const double TimerManagerTests::ERROR = .20; + +}}}} // apache::thrift::concurrency + diff --git a/lib/cpp/src/processor/PeekProcessor.cpp b/lib/cpp/src/processor/PeekProcessor.cpp new file mode 100644 index 00000000..c721861b --- /dev/null +++ b/lib/cpp/src/processor/PeekProcessor.cpp @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "PeekProcessor.h" + +using namespace apache::thrift::transport; +using namespace apache::thrift::protocol; +using namespace apache::thrift; + +namespace apache { namespace thrift { namespace processor { + +PeekProcessor::PeekProcessor() { + memoryBuffer_.reset(new TMemoryBuffer()); + targetTransport_ = memoryBuffer_; +} +PeekProcessor::~PeekProcessor() {} + +void PeekProcessor::initialize(boost::shared_ptr actualProcessor, + boost::shared_ptr protocolFactory, + boost::shared_ptr transportFactory) { + actualProcessor_ = actualProcessor; + pipedProtocol_ = protocolFactory->getProtocol(targetTransport_); + transportFactory_ = transportFactory; + transportFactory_->initializeTargetTransport(targetTransport_); +} + +boost::shared_ptr PeekProcessor::getPipedTransport(boost::shared_ptr in) { + return transportFactory_->getTransport(in); +} + +void PeekProcessor::setTargetTransport(boost::shared_ptr targetTransport) { + targetTransport_ = targetTransport; + if (boost::dynamic_pointer_cast(targetTransport_)) { + memoryBuffer_ = boost::dynamic_pointer_cast(targetTransport); + } else if (boost::dynamic_pointer_cast(targetTransport_)) { + memoryBuffer_ = boost::dynamic_pointer_cast(boost::dynamic_pointer_cast(targetTransport_)->getTargetTransport()); + } + + if (!memoryBuffer_) { + throw TException("Target transport must be a TMemoryBuffer or a TPipedTransport with TMemoryBuffer"); + } +} + +bool PeekProcessor::process(boost::shared_ptr in, + boost::shared_ptr out) { + + std::string fname; + TMessageType mtype; + int32_t seqid; + in->readMessageBegin(fname, mtype, seqid); + + if (mtype != T_CALL) { + throw TException("Unexpected message type"); + } + + // Peek at the name + peekName(fname); + + TType ftype; + int16_t fid; + while (true) { + in->readFieldBegin(fname, ftype, fid); + if (ftype == T_STOP) { + break; + } + + // Peek at the variable + peek(in, ftype, fid); + in->readFieldEnd(); + } + in->readMessageEnd(); + in->getTransport()->readEnd(); + + // + // All the data is now in memoryBuffer_ and ready to be processed + // + + // Let's first take a peek at the full data in memory + uint8_t* buffer; + uint32_t size; + memoryBuffer_->getBuffer(&buffer, &size); + peekBuffer(buffer, size); + + // Done peeking at variables + peekEnd(); + + bool ret = actualProcessor_->process(pipedProtocol_, out); + memoryBuffer_->resetBuffer(); + return ret; +} + +void PeekProcessor::peekName(const std::string& fname) { +} + +void PeekProcessor::peekBuffer(uint8_t* buffer, uint32_t size) { +} + +void PeekProcessor::peek(boost::shared_ptr in, + TType ftype, + int16_t fid) { + in->skip(ftype); +} + +void PeekProcessor::peekEnd() {} + +}}} diff --git a/lib/cpp/src/processor/PeekProcessor.h b/lib/cpp/src/processor/PeekProcessor.h new file mode 100644 index 00000000..0f7c016a --- /dev/null +++ b/lib/cpp/src/processor/PeekProcessor.h @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef PEEKPROCESSOR_H +#define PEEKPROCESSOR_H + +#include +#include +#include +#include +#include +#include + +namespace apache { namespace thrift { namespace processor { + +/* + * Class for peeking at the raw data that is being processed by another processor + * and gives the derived class a chance to change behavior accordingly + * + */ +class PeekProcessor : public apache::thrift::TProcessor { + + public: + PeekProcessor(); + virtual ~PeekProcessor(); + + // Input here: actualProcessor - the underlying processor + // protocolFactory - the protocol factory used to wrap the memory buffer + // transportFactory - this TPipedTransportFactory is used to wrap the source transport + // via a call to getPipedTransport + void initialize(boost::shared_ptr actualProcessor, + boost::shared_ptr protocolFactory, + boost::shared_ptr transportFactory); + + boost::shared_ptr getPipedTransport(boost::shared_ptr in); + + void setTargetTransport(boost::shared_ptr targetTransport); + + virtual bool process(boost::shared_ptr in, + boost::shared_ptr out); + + // The following three functions can be overloaded by child classes to + // achieve desired peeking behavior + virtual void peekName(const std::string& fname); + virtual void peekBuffer(uint8_t* buffer, uint32_t size); + virtual void peek(boost::shared_ptr in, + apache::thrift::protocol::TType ftype, + int16_t fid); + virtual void peekEnd(); + + private: + boost::shared_ptr actualProcessor_; + boost::shared_ptr pipedProtocol_; + boost::shared_ptr transportFactory_; + boost::shared_ptr memoryBuffer_; + boost::shared_ptr targetTransport_; +}; + +}}} // apache::thrift::processor + +#endif diff --git a/lib/cpp/src/processor/StatsProcessor.h b/lib/cpp/src/processor/StatsProcessor.h new file mode 100644 index 00000000..820b3ad4 --- /dev/null +++ b/lib/cpp/src/processor/StatsProcessor.h @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef STATSPROCESSOR_H +#define STATSPROCESSOR_H + +#include +#include +#include +#include + +namespace apache { namespace thrift { namespace processor { + +/* + * Class for keeping track of function call statistics and printing them if desired + * + */ +class StatsProcessor : public apache::thrift::TProcessor { +public: + StatsProcessor(bool print, bool frequency) + : print_(print), + frequency_(frequency) + {} + virtual ~StatsProcessor() {}; + + virtual bool process(boost::shared_ptr piprot, boost::shared_ptr poprot) { + + piprot_ = piprot; + + std::string fname; + apache::thrift::protocol::TMessageType mtype; + int32_t seqid; + + piprot_->readMessageBegin(fname, mtype, seqid); + if (mtype != apache::thrift::protocol::T_CALL) { + if (print_) { + printf("Unknown message type\n"); + } + throw apache::thrift::TException("Unexpected message type"); + } + if (print_) { + printf("%s (", fname.c_str()); + } + if (frequency_) { + if (frequency_map_.find(fname) != frequency_map_.end()) { + frequency_map_[fname]++; + } else { + frequency_map_[fname] = 1; + } + } + + apache::thrift::protocol::TType ftype; + int16_t fid; + + while (true) { + piprot_->readFieldBegin(fname, ftype, fid); + if (ftype == apache::thrift::protocol::T_STOP) { + break; + } + + printAndPassToBuffer(ftype); + if (print_) { + printf(", "); + } + } + + if (print_) { + printf("\b\b)\n"); + } + return true; + } + + const std::map& get_frequency_map() { + return frequency_map_; + } + +protected: + void printAndPassToBuffer(apache::thrift::protocol::TType ftype) { + switch (ftype) { + case apache::thrift::protocol::T_BOOL: + { + bool boolv; + piprot_->readBool(boolv); + if (print_) { + printf("%d", boolv); + } + } + break; + case apache::thrift::protocol::T_BYTE: + { + int8_t bytev; + piprot_->readByte(bytev); + if (print_) { + printf("%d", bytev); + } + } + break; + case apache::thrift::protocol::T_I16: + { + int16_t i16; + piprot_->readI16(i16); + if (print_) { + printf("%d", i16); + } + } + break; + case apache::thrift::protocol::T_I32: + { + int32_t i32; + piprot_->readI32(i32); + if (print_) { + printf("%d", i32); + } + } + break; + case apache::thrift::protocol::T_I64: + { + int64_t i64; + piprot_->readI64(i64); + if (print_) { + printf("%ld", i64); + } + } + break; + case apache::thrift::protocol::T_DOUBLE: + { + double dub; + piprot_->readDouble(dub); + if (print_) { + printf("%f", dub); + } + } + break; + case apache::thrift::protocol::T_STRING: + { + std::string str; + piprot_->readString(str); + if (print_) { + printf("%s", str.c_str()); + } + } + break; + case apache::thrift::protocol::T_STRUCT: + { + std::string name; + int16_t fid; + apache::thrift::protocol::TType ftype; + piprot_->readStructBegin(name); + if (print_) { + printf("<"); + } + while (true) { + piprot_->readFieldBegin(name, ftype, fid); + if (ftype == apache::thrift::protocol::T_STOP) { + break; + } + printAndPassToBuffer(ftype); + if (print_) { + printf(","); + } + piprot_->readFieldEnd(); + } + piprot_->readStructEnd(); + if (print_) { + printf("\b>"); + } + } + break; + case apache::thrift::protocol::T_MAP: + { + apache::thrift::protocol::TType keyType; + apache::thrift::protocol::TType valType; + uint32_t i, size; + piprot_->readMapBegin(keyType, valType, size); + if (print_) { + printf("{"); + } + for (i = 0; i < size; i++) { + printAndPassToBuffer(keyType); + if (print_) { + printf("=>"); + } + printAndPassToBuffer(valType); + if (print_) { + printf(","); + } + } + piprot_->readMapEnd(); + if (print_) { + printf("\b}"); + } + } + break; + case apache::thrift::protocol::T_SET: + { + apache::thrift::protocol::TType elemType; + uint32_t i, size; + piprot_->readSetBegin(elemType, size); + if (print_) { + printf("{"); + } + for (i = 0; i < size; i++) { + printAndPassToBuffer(elemType); + if (print_) { + printf(","); + } + } + piprot_->readSetEnd(); + if (print_) { + printf("\b}"); + } + } + break; + case apache::thrift::protocol::T_LIST: + { + apache::thrift::protocol::TType elemType; + uint32_t i, size; + piprot_->readListBegin(elemType, size); + if (print_) { + printf("["); + } + for (i = 0; i < size; i++) { + printAndPassToBuffer(elemType); + if (print_) { + printf(","); + } + } + piprot_->readListEnd(); + if (print_) { + printf("\b]"); + } + } + break; + default: + break; + } + } + + boost::shared_ptr piprot_; + std::map frequency_map_; + + bool print_; + bool frequency_; +}; + +}}} // apache::thrift::processor + +#endif diff --git a/lib/cpp/src/protocol/TBase64Utils.cpp b/lib/cpp/src/protocol/TBase64Utils.cpp new file mode 100644 index 00000000..14481c49 --- /dev/null +++ b/lib/cpp/src/protocol/TBase64Utils.cpp @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "TBase64Utils.h" + +#include + +using std::string; + +namespace apache { namespace thrift { namespace protocol { + + +static const uint8_t *kBase64EncodeTable = (const uint8_t *) + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + +void base64_encode(const uint8_t *in, uint32_t len, uint8_t *buf) { + buf[0] = kBase64EncodeTable[(in[0] >> 2) & 0x3F]; + if (len == 3) { + buf[1] = kBase64EncodeTable[((in[0] << 4) + (in[1] >> 4)) & 0x3f]; + buf[2] = kBase64EncodeTable[((in[1] << 2) + (in[2] >> 6)) & 0x3f]; + buf[3] = kBase64EncodeTable[in[2] & 0x3f]; + } else if (len == 2) { + buf[1] = kBase64EncodeTable[((in[0] << 4) + (in[1] >> 4)) & 0x3f]; + buf[2] = kBase64EncodeTable[(in[1] << 2) & 0x3f]; + } else { // len == 1 + buf[1] = kBase64EncodeTable[(in[0] << 4) & 0x3f]; + } +} + +static const uint8_t kBase64DecodeTable[256] ={ + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,62,-1,-1,-1,63, + 52,53,54,55,56,57,58,59,60,61,-1,-1,-1,-1,-1,-1, + -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14, + 15,16,17,18,19,20,21,22,23,24,25,-1,-1,-1,-1,-1, + -1,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40, + 41,42,43,44,45,46,47,48,49,50,51,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, +}; + +void base64_decode(uint8_t *buf, uint32_t len) { + buf[0] = (kBase64DecodeTable[buf[0]] << 2) | + (kBase64DecodeTable[buf[1]] >> 4); + if (len > 2) { + buf[1] = ((kBase64DecodeTable[buf[1]] << 4) & 0xf0) | + (kBase64DecodeTable[buf[2]] >> 2); + if (len > 3) { + buf[2] = ((kBase64DecodeTable[buf[2]] << 6) & 0xc0) | + (kBase64DecodeTable[buf[3]]); + } + } +} + + +}}} // apache::thrift::protocol diff --git a/lib/cpp/src/protocol/TBase64Utils.h b/lib/cpp/src/protocol/TBase64Utils.h new file mode 100644 index 00000000..3def7335 --- /dev/null +++ b/lib/cpp/src/protocol/TBase64Utils.h @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TBASE64UTILS_H_ +#define _THRIFT_PROTOCOL_TBASE64UTILS_H_ + +#include +#include + +namespace apache { namespace thrift { namespace protocol { + +// in must be at least len bytes +// len must be 1, 2, or 3 +// buf must be a buffer of at least 4 bytes and may not overlap in +// the data is not padded with '='; the caller can do this if desired +void base64_encode(const uint8_t *in, uint32_t len, uint8_t *buf); + +// buf must be a buffer of at least 4 bytes and contain base64 encoded values +// buf will be changed to contain output bytes +// len is number of bytes to consume from input (must be 2, 3, or 4) +// no '=' padding should be included in the input +void base64_decode(uint8_t *buf, uint32_t len); + +}}} // apache::thrift::protocol + +#endif // #define _THRIFT_PROTOCOL_TBASE64UTILS_H_ diff --git a/lib/cpp/src/protocol/TBinaryProtocol.cpp b/lib/cpp/src/protocol/TBinaryProtocol.cpp new file mode 100644 index 00000000..6a4838b4 --- /dev/null +++ b/lib/cpp/src/protocol/TBinaryProtocol.cpp @@ -0,0 +1,394 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "TBinaryProtocol.h" + +#include + +using std::string; + +namespace apache { namespace thrift { namespace protocol { + +uint32_t TBinaryProtocol::writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) { + if (strict_write_) { + int32_t version = (VERSION_1) | ((int32_t)messageType); + uint32_t wsize = 0; + wsize += writeI32(version); + wsize += writeString(name); + wsize += writeI32(seqid); + return wsize; + } else { + uint32_t wsize = 0; + wsize += writeString(name); + wsize += writeByte((int8_t)messageType); + wsize += writeI32(seqid); + return wsize; + } +} + +uint32_t TBinaryProtocol::writeMessageEnd() { + return 0; +} + +uint32_t TBinaryProtocol::writeStructBegin(const char* name) { + return 0; +} + +uint32_t TBinaryProtocol::writeStructEnd() { + return 0; +} + +uint32_t TBinaryProtocol::writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) { + uint32_t wsize = 0; + wsize += writeByte((int8_t)fieldType); + wsize += writeI16(fieldId); + return wsize; +} + +uint32_t TBinaryProtocol::writeFieldEnd() { + return 0; +} + +uint32_t TBinaryProtocol::writeFieldStop() { + return + writeByte((int8_t)T_STOP); +} + +uint32_t TBinaryProtocol::writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) { + uint32_t wsize = 0; + wsize += writeByte((int8_t)keyType); + wsize += writeByte((int8_t)valType); + wsize += writeI32((int32_t)size); + return wsize; +} + +uint32_t TBinaryProtocol::writeMapEnd() { + return 0; +} + +uint32_t TBinaryProtocol::writeListBegin(const TType elemType, + const uint32_t size) { + uint32_t wsize = 0; + wsize += writeByte((int8_t) elemType); + wsize += writeI32((int32_t)size); + return wsize; +} + +uint32_t TBinaryProtocol::writeListEnd() { + return 0; +} + +uint32_t TBinaryProtocol::writeSetBegin(const TType elemType, + const uint32_t size) { + uint32_t wsize = 0; + wsize += writeByte((int8_t)elemType); + wsize += writeI32((int32_t)size); + return wsize; +} + +uint32_t TBinaryProtocol::writeSetEnd() { + return 0; +} + +uint32_t TBinaryProtocol::writeBool(const bool value) { + uint8_t tmp = value ? 1 : 0; + trans_->write(&tmp, 1); + return 1; +} + +uint32_t TBinaryProtocol::writeByte(const int8_t byte) { + trans_->write((uint8_t*)&byte, 1); + return 1; +} + +uint32_t TBinaryProtocol::writeI16(const int16_t i16) { + int16_t net = (int16_t)htons(i16); + trans_->write((uint8_t*)&net, 2); + return 2; +} + +uint32_t TBinaryProtocol::writeI32(const int32_t i32) { + int32_t net = (int32_t)htonl(i32); + trans_->write((uint8_t*)&net, 4); + return 4; +} + +uint32_t TBinaryProtocol::writeI64(const int64_t i64) { + int64_t net = (int64_t)htonll(i64); + trans_->write((uint8_t*)&net, 8); + return 8; +} + +uint32_t TBinaryProtocol::writeDouble(const double dub) { + BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t)); + BOOST_STATIC_ASSERT(std::numeric_limits::is_iec559); + + uint64_t bits = bitwise_cast(dub); + bits = htonll(bits); + trans_->write((uint8_t*)&bits, 8); + return 8; +} + + +uint32_t TBinaryProtocol::writeString(const string& str) { + uint32_t size = str.size(); + uint32_t result = writeI32((int32_t)size); + if (size > 0) { + trans_->write((uint8_t*)str.data(), size); + } + return result + size; +} + +uint32_t TBinaryProtocol::writeBinary(const string& str) { + return TBinaryProtocol::writeString(str); +} + +/** + * Reading functions + */ + +uint32_t TBinaryProtocol::readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) { + uint32_t result = 0; + int32_t sz; + result += readI32(sz); + + if (sz < 0) { + // Check for correct version number + int32_t version = sz & VERSION_MASK; + if (version != VERSION_1) { + throw TProtocolException(TProtocolException::BAD_VERSION, "Bad version identifier"); + } + messageType = (TMessageType)(sz & 0x000000ff); + result += readString(name); + result += readI32(seqid); + } else { + if (strict_read_) { + throw TProtocolException(TProtocolException::BAD_VERSION, "No version identifier... old protocol client in strict mode?"); + } else { + // Handle pre-versioned input + int8_t type; + result += readStringBody(name, sz); + result += readByte(type); + messageType = (TMessageType)type; + result += readI32(seqid); + } + } + return result; +} + +uint32_t TBinaryProtocol::readMessageEnd() { + return 0; +} + +uint32_t TBinaryProtocol::readStructBegin(string& name) { + name = ""; + return 0; +} + +uint32_t TBinaryProtocol::readStructEnd() { + return 0; +} + +uint32_t TBinaryProtocol::readFieldBegin(string& name, + TType& fieldType, + int16_t& fieldId) { + uint32_t result = 0; + int8_t type; + result += readByte(type); + fieldType = (TType)type; + if (fieldType == T_STOP) { + fieldId = 0; + return result; + } + result += readI16(fieldId); + return result; +} + +uint32_t TBinaryProtocol::readFieldEnd() { + return 0; +} + +uint32_t TBinaryProtocol::readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) { + int8_t k, v; + uint32_t result = 0; + int32_t sizei; + result += readByte(k); + keyType = (TType)k; + result += readByte(v); + valType = (TType)v; + result += readI32(sizei); + if (sizei < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && sizei > container_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + return result; +} + +uint32_t TBinaryProtocol::readMapEnd() { + return 0; +} + +uint32_t TBinaryProtocol::readListBegin(TType& elemType, + uint32_t& size) { + int8_t e; + uint32_t result = 0; + int32_t sizei; + result += readByte(e); + elemType = (TType)e; + result += readI32(sizei); + if (sizei < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && sizei > container_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + return result; +} + +uint32_t TBinaryProtocol::readListEnd() { + return 0; +} + +uint32_t TBinaryProtocol::readSetBegin(TType& elemType, + uint32_t& size) { + int8_t e; + uint32_t result = 0; + int32_t sizei; + result += readByte(e); + elemType = (TType)e; + result += readI32(sizei); + if (sizei < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && sizei > container_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + return result; +} + +uint32_t TBinaryProtocol::readSetEnd() { + return 0; +} + +uint32_t TBinaryProtocol::readBool(bool& value) { + uint8_t b[1]; + trans_->readAll(b, 1); + value = *(int8_t*)b != 0; + return 1; +} + +uint32_t TBinaryProtocol::readByte(int8_t& byte) { + uint8_t b[1]; + trans_->readAll(b, 1); + byte = *(int8_t*)b; + return 1; +} + +uint32_t TBinaryProtocol::readI16(int16_t& i16) { + uint8_t b[2]; + trans_->readAll(b, 2); + i16 = *(int16_t*)b; + i16 = (int16_t)ntohs(i16); + return 2; +} + +uint32_t TBinaryProtocol::readI32(int32_t& i32) { + uint8_t b[4]; + trans_->readAll(b, 4); + i32 = *(int32_t*)b; + i32 = (int32_t)ntohl(i32); + return 4; +} + +uint32_t TBinaryProtocol::readI64(int64_t& i64) { + uint8_t b[8]; + trans_->readAll(b, 8); + i64 = *(int64_t*)b; + i64 = (int64_t)ntohll(i64); + return 8; +} + +uint32_t TBinaryProtocol::readDouble(double& dub) { + BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t)); + BOOST_STATIC_ASSERT(std::numeric_limits::is_iec559); + + uint64_t bits; + uint8_t b[8]; + trans_->readAll(b, 8); + bits = *(uint64_t*)b; + bits = ntohll(bits); + dub = bitwise_cast(bits); + return 8; +} + +uint32_t TBinaryProtocol::readString(string& str) { + uint32_t result; + int32_t size; + result = readI32(size); + return result + readStringBody(str, size); +} + +uint32_t TBinaryProtocol::readBinary(string& str) { + return TBinaryProtocol::readString(str); +} + +uint32_t TBinaryProtocol::readStringBody(string& str, int32_t size) { + uint32_t result = 0; + + // Catch error cases + if (size < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } + if (string_limit_ > 0 && size > string_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + + // Catch empty string case + if (size == 0) { + str = ""; + return result; + } + + // Use the heap here to prevent stack overflow for v. large strings + if (size > string_buf_size_ || string_buf_ == NULL) { + void* new_string_buf = std::realloc(string_buf_, (uint32_t)size); + if (new_string_buf == NULL) { + throw TProtocolException(TProtocolException::UNKNOWN, "Out of memory in TBinaryProtocol::readString"); + } + string_buf_ = (uint8_t*)new_string_buf; + string_buf_size_ = size; + } + trans_->readAll(string_buf_, size); + str = string((char*)string_buf_, size); + return (uint32_t)size; +} + +}}} // apache::thrift::protocol diff --git a/lib/cpp/src/protocol/TBinaryProtocol.h b/lib/cpp/src/protocol/TBinaryProtocol.h new file mode 100644 index 00000000..7fd3de67 --- /dev/null +++ b/lib/cpp/src/protocol/TBinaryProtocol.h @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_ 1 + +#include "TProtocol.h" + +#include + +namespace apache { namespace thrift { namespace protocol { + +/** + * The default binary protocol for thrift. Writes all data in a very basic + * binary format, essentially just spitting out the raw bytes. + * + */ +class TBinaryProtocol : public TProtocol { + protected: + static const int32_t VERSION_MASK = 0xffff0000; + static const int32_t VERSION_1 = 0x80010000; + // VERSION_2 (0x80020000) is taken by TDenseProtocol. + + public: + TBinaryProtocol(boost::shared_ptr trans) : + TProtocol(trans), + string_limit_(0), + container_limit_(0), + strict_read_(false), + strict_write_(true), + string_buf_(NULL), + string_buf_size_(0) {} + + TBinaryProtocol(boost::shared_ptr trans, + int32_t string_limit, + int32_t container_limit, + bool strict_read, + bool strict_write) : + TProtocol(trans), + string_limit_(string_limit), + container_limit_(container_limit), + strict_read_(strict_read), + strict_write_(strict_write), + string_buf_(NULL), + string_buf_size_(0) {} + + ~TBinaryProtocol() { + if (string_buf_ != NULL) { + std::free(string_buf_); + string_buf_size_ = 0; + } + } + + void setStringSizeLimit(int32_t string_limit) { + string_limit_ = string_limit; + } + + void setContainerSizeLimit(int32_t container_limit) { + container_limit_ = container_limit; + } + + void setStrict(bool strict_read, bool strict_write) { + strict_read_ = strict_read; + strict_write_ = strict_write; + } + + /** + * Writing functions. + */ + + virtual uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid); + + virtual uint32_t writeMessageEnd(); + + + uint32_t writeStructBegin(const char* name); + + uint32_t writeStructEnd(); + + uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId); + + uint32_t writeFieldEnd(); + + uint32_t writeFieldStop(); + + uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size); + + uint32_t writeMapEnd(); + + uint32_t writeListBegin(const TType elemType, + const uint32_t size); + + uint32_t writeListEnd(); + + uint32_t writeSetBegin(const TType elemType, + const uint32_t size); + + uint32_t writeSetEnd(); + + uint32_t writeBool(const bool value); + + uint32_t writeByte(const int8_t byte); + + uint32_t writeI16(const int16_t i16); + + uint32_t writeI32(const int32_t i32); + + uint32_t writeI64(const int64_t i64); + + uint32_t writeDouble(const double dub); + + uint32_t writeString(const std::string& str); + + uint32_t writeBinary(const std::string& str); + + /** + * Reading functions + */ + + + uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid); + + uint32_t readMessageEnd(); + + uint32_t readStructBegin(std::string& name); + + uint32_t readStructEnd(); + + uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId); + + uint32_t readFieldEnd(); + + uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size); + + uint32_t readMapEnd(); + + uint32_t readListBegin(TType& elemType, + uint32_t& size); + + uint32_t readListEnd(); + + uint32_t readSetBegin(TType& elemType, + uint32_t& size); + + uint32_t readSetEnd(); + + uint32_t readBool(bool& value); + + uint32_t readByte(int8_t& byte); + + uint32_t readI16(int16_t& i16); + + uint32_t readI32(int32_t& i32); + + uint32_t readI64(int64_t& i64); + + uint32_t readDouble(double& dub); + + uint32_t readString(std::string& str); + + uint32_t readBinary(std::string& str); + + protected: + uint32_t readStringBody(std::string& str, int32_t sz); + + int32_t string_limit_; + int32_t container_limit_; + + // Enforce presence of version identifier + bool strict_read_; + bool strict_write_; + + // Buffer for reading strings, save for the lifetime of the protocol to + // avoid memory churn allocating memory on every string read + uint8_t* string_buf_; + int32_t string_buf_size_; + +}; + +/** + * Constructs binary protocol handlers + */ +class TBinaryProtocolFactory : public TProtocolFactory { + public: + TBinaryProtocolFactory() : + string_limit_(0), + container_limit_(0), + strict_read_(false), + strict_write_(true) {} + + TBinaryProtocolFactory(int32_t string_limit, int32_t container_limit, bool strict_read, bool strict_write) : + string_limit_(string_limit), + container_limit_(container_limit), + strict_read_(strict_read), + strict_write_(strict_write) {} + + virtual ~TBinaryProtocolFactory() {} + + void setStringSizeLimit(int32_t string_limit) { + string_limit_ = string_limit; + } + + void setContainerSizeLimit(int32_t container_limit) { + container_limit_ = container_limit; + } + + void setStrict(bool strict_read, bool strict_write) { + strict_read_ = strict_read; + strict_write_ = strict_write; + } + + boost::shared_ptr getProtocol(boost::shared_ptr trans) { + return boost::shared_ptr(new TBinaryProtocol(trans, string_limit_, container_limit_, strict_read_, strict_write_)); + } + + private: + int32_t string_limit_; + int32_t container_limit_; + bool strict_read_; + bool strict_write_; + +}; + +}}} // apache::thrift::protocol + +#endif // #ifndef _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_ diff --git a/lib/cpp/src/protocol/TCompactProtocol.cpp b/lib/cpp/src/protocol/TCompactProtocol.cpp new file mode 100644 index 00000000..ce2ee54d --- /dev/null +++ b/lib/cpp/src/protocol/TCompactProtocol.cpp @@ -0,0 +1,736 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "TCompactProtocol.h" + +#include +#include + +/* + * TCompactProtocol::i*ToZigzag depend on the fact that the right shift + * operator on a signed integer is an arithmetic (sign-extending) shift. + * If this is not the case, the current implementation will not work. + * If anyone encounters this error, we can try to figure out the best + * way to implement an arithmetic right shift on their platform. + */ +#if !defined(SIGNED_RIGHT_SHIFT_IS) || !defined(ARITHMETIC_RIGHT_SHIFT) +# error "Unable to determine the behavior of a signed right shift" +#endif +#if SIGNED_RIGHT_SHIFT_IS != ARITHMETIC_RIGHT_SHIFT +# error "TCompactProtocol currenly only works if a signed right shift is arithmetic" +#endif + +#ifdef __GNUC__ +#define UNLIKELY(val) (__builtin_expect((val), 0)) +#else +#define UNLIKELY(val) (val) +#endif + +namespace apache { namespace thrift { namespace protocol { + +const int8_t TCompactProtocol::TTypeToCType[16] = { + CT_STOP, // T_STOP + 0, // unused + CT_BOOLEAN_TRUE, // T_BOOL + CT_BYTE, // T_BYTE + CT_DOUBLE, // T_DOUBLE + 0, // unused + CT_I16, // T_I16 + 0, // unused + CT_I32, // T_I32 + 0, // unused + CT_I64, // T_I64 + CT_BINARY, // T_STRING + CT_STRUCT, // T_STRUCT + CT_MAP, // T_MAP + CT_SET, // T_SET + CT_LIST, // T_LIST + }; + + +uint32_t TCompactProtocol::writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) { + uint32_t wsize = 0; + wsize += writeByte(PROTOCOL_ID); + wsize += writeByte((VERSION_N & VERSION_MASK) | (((int32_t)messageType << TYPE_SHIFT_AMOUNT) & TYPE_MASK)); + wsize += writeVarint32(seqid); + wsize += writeString(name); + return wsize; +} + +/** + * Write a field header containing the field id and field type. If the + * difference between the current field id and the last one is small (< 15), + * then the field id will be encoded in the 4 MSB as a delta. Otherwise, the + * field id will follow the type header as a zigzag varint. + */ +uint32_t TCompactProtocol::writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) { + if (fieldType == T_BOOL) { + booleanField_.name = name; + booleanField_.fieldType = fieldType; + booleanField_.fieldId = fieldId; + } else { + return writeFieldBeginInternal(name, fieldType, fieldId, -1); + } + return 0; +} + +/** + * Write the STOP symbol so we know there are no more fields in this struct. + */ +uint32_t TCompactProtocol::writeFieldStop() { + return writeByte(T_STOP); +} + +/** + * Write a struct begin. This doesn't actually put anything on the wire. We + * use it as an opportunity to put special placeholder markers on the field + * stack so we can get the field id deltas correct. + */ +uint32_t TCompactProtocol::writeStructBegin(const char* name) { + lastField_.push(lastFieldId_); + lastFieldId_ = 0; + return 0; +} + +/** + * Write a struct end. This doesn't actually put anything on the wire. We use + * this as an opportunity to pop the last field from the current struct off + * of the field stack. + */ +uint32_t TCompactProtocol::writeStructEnd() { + lastFieldId_ = lastField_.top(); + lastField_.pop(); + return 0; +} + +/** + * Write a List header. + */ +uint32_t TCompactProtocol::writeListBegin(const TType elemType, + const uint32_t size) { + return writeCollectionBegin(elemType, size); +} + +/** + * Write a set header. + */ +uint32_t TCompactProtocol::writeSetBegin(const TType elemType, + const uint32_t size) { + return writeCollectionBegin(elemType, size); +} + +/** + * Write a map header. If the map is empty, omit the key and value type + * headers, as we don't need any additional information to skip it. + */ +uint32_t TCompactProtocol::writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) { + uint32_t wsize = 0; + + if (size == 0) { + wsize += writeByte(0); + } else { + wsize += writeVarint32(size); + wsize += writeByte(getCompactType(keyType) << 4 | getCompactType(valType)); + } + return wsize; +} + +/** + * Write a boolean value. Potentially, this could be a boolean field, in + * which case the field header info isn't written yet. If so, decide what the + * right type header is for the value and then write the field header. + * Otherwise, write a single byte. + */ +uint32_t TCompactProtocol::writeBool(const bool value) { + uint32_t wsize = 0; + + if (booleanField_.name != NULL) { + // we haven't written the field header yet + wsize += writeFieldBeginInternal(booleanField_.name, + booleanField_.fieldType, + booleanField_.fieldId, + value ? CT_BOOLEAN_TRUE : CT_BOOLEAN_FALSE); + booleanField_.name = NULL; + } else { + // we're not part of a field, so just write the value + wsize += writeByte(value ? CT_BOOLEAN_TRUE : CT_BOOLEAN_FALSE); + } + return wsize; +} + +uint32_t TCompactProtocol::writeByte(const int8_t byte) { + trans_->write((uint8_t*)&byte, 1); + return 1; +} + +/** + * Write an i16 as a zigzag varint. + */ +uint32_t TCompactProtocol::writeI16(const int16_t i16) { + return writeVarint32(i32ToZigzag(i16)); +} + +/** + * Write an i32 as a zigzag varint. + */ +uint32_t TCompactProtocol::writeI32(const int32_t i32) { + return writeVarint32(i32ToZigzag(i32)); +} + +/** + * Write an i64 as a zigzag varint. + */ +uint32_t TCompactProtocol::writeI64(const int64_t i64) { + return writeVarint64(i64ToZigzag(i64)); +} + +/** + * Write a double to the wire as 8 bytes. + */ +uint32_t TCompactProtocol::writeDouble(const double dub) { + BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t)); + BOOST_STATIC_ASSERT(std::numeric_limits::is_iec559); + + uint64_t bits = bitwise_cast(dub); + bits = htolell(bits); + trans_->write((uint8_t*)&bits, 8); + return 8; +} + +/** + * Write a string to the wire with a varint size preceeding. + */ +uint32_t TCompactProtocol::writeString(const std::string& str) { + return writeBinary(str); +} + +uint32_t TCompactProtocol::writeBinary(const std::string& str) { + uint32_t ssize = str.size(); + uint32_t wsize = writeVarint32(ssize) + ssize; + trans_->write((uint8_t*)str.data(), ssize); + return wsize; +} + +// +// Internal Writing methods +// + +/** + * The workhorse of writeFieldBegin. It has the option of doing a + * 'type override' of the type header. This is used specifically in the + * boolean field case. + */ +int32_t TCompactProtocol::writeFieldBeginInternal(const char* name, + const TType fieldType, + const int16_t fieldId, + int8_t typeOverride) { + uint32_t wsize = 0; + + // if there's a type override, use that. + int8_t typeToWrite = (typeOverride == -1 ? getCompactType(fieldType) : typeOverride); + + // check if we can use delta encoding for the field id + if (fieldId > lastFieldId_ && fieldId - lastFieldId_ <= 15) { + // write them together + wsize += writeByte((fieldId - lastFieldId_) << 4 | typeToWrite); + } else { + // write them separate + wsize += writeByte(typeToWrite); + wsize += writeI16(fieldId); + } + + lastFieldId_ = fieldId; + return wsize; +} + +/** + * Abstract method for writing the start of lists and sets. List and sets on + * the wire differ only by the type indicator. + */ +uint32_t TCompactProtocol::writeCollectionBegin(int8_t elemType, int32_t size) { + uint32_t wsize = 0; + if (size <= 14) { + wsize += writeByte(size << 4 | getCompactType(elemType)); + } else { + wsize += writeByte(0xf0 | getCompactType(elemType)); + wsize += writeVarint32(size); + } + return wsize; +} + +/** + * Write an i32 as a varint. Results in 1-5 bytes on the wire. + */ +uint32_t TCompactProtocol::writeVarint32(uint32_t n) { + uint8_t buf[5]; + uint32_t wsize = 0; + + while (true) { + if ((n & ~0x7F) == 0) { + buf[wsize++] = (int8_t)n; + break; + } else { + buf[wsize++] = (int8_t)((n & 0x7F) | 0x80); + n >>= 7; + } + } + trans_->write(buf, wsize); + return wsize; +} + +/** + * Write an i64 as a varint. Results in 1-10 bytes on the wire. + */ +uint32_t TCompactProtocol::writeVarint64(uint64_t n) { + uint8_t buf[10]; + uint32_t wsize = 0; + + while (true) { + if ((n & ~0x7FL) == 0) { + buf[wsize++] = (int8_t)n; + break; + } else { + buf[wsize++] = (int8_t)((n & 0x7F) | 0x80); + n >>= 7; + } + } + trans_->write(buf, wsize); + return wsize; +} + +/** + * Convert l into a zigzag long. This allows negative numbers to be + * represented compactly as a varint. + */ +uint64_t TCompactProtocol::i64ToZigzag(const int64_t l) { + return (l << 1) ^ (l >> 63); +} + +/** + * Convert n into a zigzag int. This allows negative numbers to be + * represented compactly as a varint. + */ +uint32_t TCompactProtocol::i32ToZigzag(const int32_t n) { + return (n << 1) ^ (n >> 31); +} + +/** + * Given a TType value, find the appropriate TCompactProtocol.Type value + */ +int8_t TCompactProtocol::getCompactType(int8_t ttype) { + return TTypeToCType[ttype]; +} + +// +// Reading Methods +// + +/** + * Read a message header. + */ +uint32_t TCompactProtocol::readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) { + uint32_t rsize = 0; + int8_t protocolId; + int8_t versionAndType; + int8_t version; + + rsize += readByte(protocolId); + if (protocolId != PROTOCOL_ID) { + throw TProtocolException(TProtocolException::BAD_VERSION, "Bad protocol identifier"); + } + + rsize += readByte(versionAndType); + version = (int8_t)(versionAndType & VERSION_MASK); + if (version != VERSION_N) { + throw TProtocolException(TProtocolException::BAD_VERSION, "Bad protocol version"); + } + + messageType = (TMessageType)((versionAndType >> TYPE_SHIFT_AMOUNT) & 0x03); + rsize += readVarint32(seqid); + rsize += readString(name); + + return rsize; +} + +/** + * Read a struct begin. There's nothing on the wire for this, but it is our + * opportunity to push a new struct begin marker on the field stack. + */ +uint32_t TCompactProtocol::readStructBegin(std::string& name) { + name = ""; + lastField_.push(lastFieldId_); + lastFieldId_ = 0; + return 0; +} + +/** + * Doesn't actually consume any wire data, just removes the last field for + * this struct from the field stack. + */ +uint32_t TCompactProtocol::readStructEnd() { + lastFieldId_ = lastField_.top(); + lastField_.pop(); + return 0; +} + +/** + * Read a field header off the wire. + */ +uint32_t TCompactProtocol::readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId) { + uint32_t rsize = 0; + int8_t byte; + int8_t type; + + rsize += readByte(byte); + type = (byte & 0x0f); + + // if it's a stop, then we can return immediately, as the struct is over. + if (type == T_STOP) { + fieldType = T_STOP; + fieldId = 0; + return rsize; + } + + // mask off the 4 MSB of the type header. it could contain a field id delta. + int16_t modifier = (int16_t)(((uint8_t)byte & 0xf0) >> 4); + if (modifier == 0) { + // not a delta, look ahead for the zigzag varint field id. + rsize += readI16(fieldId); + } else { + fieldId = (int16_t)(lastFieldId_ + modifier); + } + fieldType = getTType(type); + + // if this happens to be a boolean field, the value is encoded in the type + if (type == CT_BOOLEAN_TRUE || type == CT_BOOLEAN_FALSE) { + // save the boolean value in a special instance variable. + boolValue_.hasBoolValue = true; + boolValue_.boolValue = (type == CT_BOOLEAN_TRUE ? true : false); + } + + // push the new field onto the field stack so we can keep the deltas going. + lastFieldId_ = fieldId; + return rsize; +} + +/** + * Read a map header off the wire. If the size is zero, skip reading the key + * and value type. This means that 0-length maps will yield TMaps without the + * "correct" types. + */ +uint32_t TCompactProtocol::readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) { + uint32_t rsize = 0; + int8_t kvType = 0; + int32_t msize = 0; + + rsize += readVarint32(msize); + if (msize != 0) + rsize += readByte(kvType); + + if (msize < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && msize > container_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + + keyType = getTType((int8_t)((uint8_t)kvType >> 4)); + valType = getTType((int8_t)((uint8_t)kvType & 0xf)); + size = (uint32_t)msize; + + return rsize; +} + +/** + * Read a list header off the wire. If the list size is 0-14, the size will + * be packed into the element type header. If it's a longer list, the 4 MSB + * of the element type header will be 0xF, and a varint will follow with the + * true size. + */ +uint32_t TCompactProtocol::readListBegin(TType& elemType, + uint32_t& size) { + int8_t size_and_type; + uint32_t rsize = 0; + int32_t lsize; + + rsize += readByte(size_and_type); + + lsize = ((uint8_t)size_and_type >> 4) & 0x0f; + if (lsize == 15) { + rsize += readVarint32(lsize); + } + + if (lsize < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && lsize > container_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + + elemType = getTType((int8_t)(size_and_type & 0x0f)); + size = (uint32_t)lsize; + + return rsize; +} + +/** + * Read a set header off the wire. If the set size is 0-14, the size will + * be packed into the element type header. If it's a longer set, the 4 MSB + * of the element type header will be 0xF, and a varint will follow with the + * true size. + */ +uint32_t TCompactProtocol::readSetBegin(TType& elemType, + uint32_t& size) { + return readListBegin(elemType, size); +} + +/** + * Read a boolean off the wire. If this is a boolean field, the value should + * already have been read during readFieldBegin, so we'll just consume the + * pre-stored value. Otherwise, read a byte. + */ +uint32_t TCompactProtocol::readBool(bool& value) { + if (boolValue_.hasBoolValue == true) { + value = boolValue_.boolValue; + boolValue_.hasBoolValue = false; + return 0; + } else { + int8_t val; + readByte(val); + value = (val == CT_BOOLEAN_TRUE); + return 1; + } +} + +/** + * Read a single byte off the wire. Nothing interesting here. + */ +uint32_t TCompactProtocol::readByte(int8_t& byte) { + uint8_t b[1]; + trans_->readAll(b, 1); + byte = *(int8_t*)b; + return 1; +} + +/** + * Read an i16 from the wire as a zigzag varint. + */ +uint32_t TCompactProtocol::readI16(int16_t& i16) { + int32_t value; + uint32_t rsize = readVarint32(value); + i16 = (int16_t)zigzagToI32(value); + return rsize; +} + +/** + * Read an i32 from the wire as a zigzag varint. + */ +uint32_t TCompactProtocol::readI32(int32_t& i32) { + int32_t value; + uint32_t rsize = readVarint32(value); + i32 = zigzagToI32(value); + return rsize; +} + +/** + * Read an i64 from the wire as a zigzag varint. + */ +uint32_t TCompactProtocol::readI64(int64_t& i64) { + int64_t value; + uint32_t rsize = readVarint64(value); + i64 = zigzagToI64(value); + return rsize; +} + +/** + * No magic here - just read a double off the wire. + */ +uint32_t TCompactProtocol::readDouble(double& dub) { + BOOST_STATIC_ASSERT(sizeof(double) == sizeof(uint64_t)); + BOOST_STATIC_ASSERT(std::numeric_limits::is_iec559); + + uint64_t bits; + uint8_t b[8]; + trans_->readAll(b, 8); + bits = *(uint64_t*)b; + bits = letohll(bits); + dub = bitwise_cast(bits); + return 8; +} + +uint32_t TCompactProtocol::readString(std::string& str) { + return readBinary(str); +} + +/** + * Read a byte[] from the wire. + */ +uint32_t TCompactProtocol::readBinary(std::string& str) { + int32_t rsize = 0; + int32_t size; + + rsize += readVarint32(size); + // Catch empty string case + if (size == 0) { + str = ""; + return rsize; + } + + // Catch error cases + if (size < 0) { + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } + if (string_limit_ > 0 && size > string_limit_) { + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + + // Use the heap here to prevent stack overflow for v. large strings + if (size > string_buf_size_ || string_buf_ == NULL) { + void* new_string_buf = std::realloc(string_buf_, (uint32_t)size); + if (new_string_buf == NULL) { + throw TProtocolException(TProtocolException::UNKNOWN, "Out of memory in TCompactProtocol::readString"); + } + string_buf_ = (uint8_t*)new_string_buf; + string_buf_size_ = size; + } + trans_->readAll(string_buf_, size); + str.assign((char*)string_buf_, size); + + return rsize + (uint32_t)size; +} + +/** + * Read an i32 from the wire as a varint. The MSB of each byte is set + * if there is another byte to follow. This can read up to 5 bytes. + */ +uint32_t TCompactProtocol::readVarint32(int32_t& i32) { + int64_t val; + uint32_t rsize = readVarint64(val); + i32 = (int32_t)val; + return rsize; +} + +/** + * Read an i64 from the wire as a proper varint. The MSB of each byte is set + * if there is another byte to follow. This can read up to 10 bytes. + */ +uint32_t TCompactProtocol::readVarint64(int64_t& i64) { + uint32_t rsize = 0; + uint64_t val = 0; + int shift = 0; + uint8_t buf[10]; // 64 bits / (7 bits/byte) = 10 bytes. + uint32_t buf_size = sizeof(buf); + const uint8_t* borrowed = trans_->borrow(buf, &buf_size); + + // Fast path. + if (borrowed != NULL) { + while (true) { + uint8_t byte = borrowed[rsize]; + rsize++; + val |= (uint64_t)(byte & 0x7f) << shift; + shift += 7; + if (!(byte & 0x80)) { + i64 = val; + trans_->consume(rsize); + return rsize; + } + // Have to check for invalid data so we don't crash. + if (UNLIKELY(rsize == sizeof(buf))) { + throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes."); + } + } + } + + // Slow path. + else { + while (true) { + uint8_t byte; + rsize += trans_->readAll(&byte, 1); + val |= (uint64_t)(byte & 0x7f) << shift; + shift += 7; + if (!(byte & 0x80)) { + i64 = val; + return rsize; + } + // Might as well check for invalid data on the slow path too. + if (UNLIKELY(rsize >= sizeof(buf))) { + throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes."); + } + } + } +} + +/** + * Convert from zigzag int to int. + */ +int32_t TCompactProtocol::zigzagToI32(uint32_t n) { + return (n >> 1) ^ -(n & 1); +} + +/** + * Convert from zigzag long to long. + */ +int64_t TCompactProtocol::zigzagToI64(uint64_t n) { + return (n >> 1) ^ -(n & 1); +} + +TType TCompactProtocol::getTType(int8_t type) { + switch (type) { + case T_STOP: + return T_STOP; + case CT_BOOLEAN_FALSE: + case CT_BOOLEAN_TRUE: + return T_BOOL; + case CT_BYTE: + return T_BYTE; + case CT_I16: + return T_I16; + case CT_I32: + return T_I32; + case CT_I64: + return T_I64; + case CT_DOUBLE: + return T_DOUBLE; + case CT_BINARY: + return T_STRING; + case CT_LIST: + return T_LIST; + case CT_SET: + return T_SET; + case CT_MAP: + return T_MAP; + case CT_STRUCT: + return T_STRUCT; + default: + throw TException("don't know what type: " + type); + } + return T_STOP; +} + +}}} // apache::thrift::protocol diff --git a/lib/cpp/src/protocol/TCompactProtocol.h b/lib/cpp/src/protocol/TCompactProtocol.h new file mode 100644 index 00000000..b4e06f0a --- /dev/null +++ b/lib/cpp/src/protocol/TCompactProtocol.h @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TCOMPACTPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TCOMPACTPROTOCOL_H_ 1 + +#include "TProtocol.h" + +#include +#include + +namespace apache { namespace thrift { namespace protocol { + +/** + * C++ Implementation of the Compact Protocol as described in THRIFT-110 + */ +class TCompactProtocol : public TProtocol { + + protected: + static const int8_t PROTOCOL_ID = 0x82; + static const int8_t VERSION_N = 1; + static const int8_t VERSION_MASK = 0x1f; // 0001 1111 + static const int8_t TYPE_MASK = 0xE0; // 1110 0000 + static const int32_t TYPE_SHIFT_AMOUNT = 5; + + /** + * (Writing) If we encounter a boolean field begin, save the TField here + * so it can have the value incorporated. + */ + struct { + const char* name; + TType fieldType; + int16_t fieldId; + } booleanField_; + + /** + * (Reading) If we read a field header, and it's a boolean field, save + * the boolean value here so that readBool can use it. + */ + struct { + bool hasBoolValue; + bool boolValue; + } boolValue_; + + /** + * Used to keep track of the last field for the current and previous structs, + * so we can do the delta stuff. + */ + + std::stack lastField_; + int16_t lastFieldId_; + + enum Types { + CT_STOP = 0x00, + CT_BOOLEAN_TRUE = 0x01, + CT_BOOLEAN_FALSE = 0x02, + CT_BYTE = 0x03, + CT_I16 = 0x04, + CT_I32 = 0x05, + CT_I64 = 0x06, + CT_DOUBLE = 0x07, + CT_BINARY = 0x08, + CT_LIST = 0x09, + CT_SET = 0x0A, + CT_MAP = 0x0B, + CT_STRUCT = 0x0C, + }; + + static const int8_t TTypeToCType[16]; + + public: + TCompactProtocol(boost::shared_ptr trans) : + TProtocol(trans), + lastFieldId_(0), + string_limit_(0), + string_buf_(NULL), + string_buf_size_(0), + container_limit_(0) { + booleanField_.name = NULL; + boolValue_.hasBoolValue = false; + } + + TCompactProtocol(boost::shared_ptr trans, + int32_t string_limit, + int32_t container_limit) : + TProtocol(trans), + lastFieldId_(0), + string_limit_(string_limit), + string_buf_(NULL), + string_buf_size_(0), + container_limit_(container_limit) { + booleanField_.name = NULL; + boolValue_.hasBoolValue = false; + } + + + + /** + * Writing functions + */ + + virtual uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid); + + uint32_t writeStructBegin(const char* name); + + uint32_t writeStructEnd(); + + uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId); + + uint32_t writeFieldStop(); + + uint32_t writeListBegin(const TType elemType, + const uint32_t size); + + uint32_t writeSetBegin(const TType elemType, + const uint32_t size); + + virtual uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size); + + uint32_t writeBool(const bool value); + + uint32_t writeByte(const int8_t byte); + + uint32_t writeI16(const int16_t i16); + + uint32_t writeI32(const int32_t i32); + + uint32_t writeI64(const int64_t i64); + + uint32_t writeDouble(const double dub); + + uint32_t writeString(const std::string& str); + + uint32_t writeBinary(const std::string& str); + + /** + * These methods are called by structs, but don't actually have any wired + * output or purpose + */ + virtual uint32_t writeMessageEnd() { return 0; } + uint32_t writeMapEnd() { return 0; } + uint32_t writeListEnd() { return 0; } + uint32_t writeSetEnd() { return 0; } + uint32_t writeFieldEnd() { return 0; } + + protected: + int32_t writeFieldBeginInternal(const char* name, + const TType fieldType, + const int16_t fieldId, + int8_t typeOverride); + uint32_t writeCollectionBegin(int8_t elemType, int32_t size); + uint32_t writeVarint32(uint32_t n); + uint32_t writeVarint64(uint64_t n); + uint64_t i64ToZigzag(const int64_t l); + uint32_t i32ToZigzag(const int32_t n); + inline int8_t getCompactType(int8_t ttype); + + public: + uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid); + + uint32_t readStructBegin(std::string& name); + + uint32_t readStructEnd(); + + uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId); + + uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size); + + uint32_t readListBegin(TType& elemType, + uint32_t& size); + + uint32_t readSetBegin(TType& elemType, + uint32_t& size); + + uint32_t readBool(bool& value); + + uint32_t readByte(int8_t& byte); + + uint32_t readI16(int16_t& i16); + + uint32_t readI32(int32_t& i32); + + uint32_t readI64(int64_t& i64); + + uint32_t readDouble(double& dub); + + uint32_t readString(std::string& str); + + uint32_t readBinary(std::string& str); + + /* + *These methods are here for the struct to call, but don't have any wire + * encoding. + */ + uint32_t readMessageEnd() { return 0; } + uint32_t readFieldEnd() { return 0; } + uint32_t readMapEnd() { return 0; } + uint32_t readListEnd() { return 0; } + uint32_t readSetEnd() { return 0; } + + protected: + uint32_t readVarint32(int32_t& i32); + uint32_t readVarint64(int64_t& i64); + int32_t zigzagToI32(uint32_t n); + int64_t zigzagToI64(uint64_t n); + TType getTType(int8_t type); + + // Buffer for reading strings, save for the lifetime of the protocol to + // avoid memory churn allocating memory on every string read + int32_t string_limit_; + uint8_t* string_buf_; + int32_t string_buf_size_; + int32_t container_limit_; +}; + +/** + * Constructs compact protocol handlers + */ +class TCompactProtocolFactory : public TProtocolFactory { + public: + TCompactProtocolFactory() : + string_limit_(0), + container_limit_(0) {} + + TCompactProtocolFactory(int32_t string_limit, int32_t container_limit) : + string_limit_(string_limit), + container_limit_(container_limit) {} + + virtual ~TCompactProtocolFactory() {} + + void setStringSizeLimit(int32_t string_limit) { + string_limit_ = string_limit; + } + + void setContainerSizeLimit(int32_t container_limit) { + container_limit_ = container_limit; + } + + boost::shared_ptr getProtocol(boost::shared_ptr trans) { + return boost::shared_ptr(new TCompactProtocol(trans, string_limit_, container_limit_)); + } + + private: + int32_t string_limit_; + int32_t container_limit_; + +}; + +}}} // apache::thrift::protocol + +#endif diff --git a/lib/cpp/src/protocol/TDebugProtocol.cpp b/lib/cpp/src/protocol/TDebugProtocol.cpp new file mode 100644 index 00000000..40aa36ba --- /dev/null +++ b/lib/cpp/src/protocol/TDebugProtocol.cpp @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "TDebugProtocol.h" + +#include +#include +#include +#include +#include +#include + +using std::string; + + +static string byte_to_hex(const uint8_t byte) { + char buf[3]; + int ret = std::sprintf(buf, "%02x", (int)byte); + assert(ret == 2); + assert(buf[2] == '\0'); + return buf; +} + + +namespace apache { namespace thrift { namespace protocol { + +string TDebugProtocol::fieldTypeName(TType type) { + switch (type) { + case T_STOP : return "stop" ; + case T_VOID : return "void" ; + case T_BOOL : return "bool" ; + case T_BYTE : return "byte" ; + case T_I16 : return "i16" ; + case T_I32 : return "i32" ; + case T_U64 : return "u64" ; + case T_I64 : return "i64" ; + case T_DOUBLE : return "double" ; + case T_STRING : return "string" ; + case T_STRUCT : return "struct" ; + case T_MAP : return "map" ; + case T_SET : return "set" ; + case T_LIST : return "list" ; + case T_UTF8 : return "utf8" ; + case T_UTF16 : return "utf16" ; + default: return "unknown"; + } +} + +void TDebugProtocol::indentUp() { + indent_str_ += string(indent_inc, ' '); +} + +void TDebugProtocol::indentDown() { + if (indent_str_.length() < (string::size_type)indent_inc) { + throw TProtocolException(TProtocolException::INVALID_DATA); + } + indent_str_.erase(indent_str_.length() - indent_inc); +} + +uint32_t TDebugProtocol::writePlain(const string& str) { + trans_->write((uint8_t*)str.data(), str.length()); + return str.length(); +} + +uint32_t TDebugProtocol::writeIndented(const string& str) { + trans_->write((uint8_t*)indent_str_.data(), indent_str_.length()); + trans_->write((uint8_t*)str.data(), str.length()); + return indent_str_.length() + str.length(); +} + +uint32_t TDebugProtocol::startItem() { + uint32_t size; + + switch (write_state_.back()) { + case UNINIT: + // XXX figure out what to do here. + //throw TProtocolException(TProtocolException::INVALID_DATA); + //return writeIndented(str); + return 0; + case STRUCT: + return 0; + case SET: + return writeIndented(""); + case MAP_KEY: + return writeIndented(""); + case MAP_VALUE: + return writePlain(" -> "); + case LIST: + size = writeIndented( + "[" + boost::lexical_cast(list_idx_.back()) + "] = "); + list_idx_.back()++; + return size; + default: + throw std::logic_error("Invalid enum value."); + } +} + +uint32_t TDebugProtocol::endItem() { + //uint32_t size; + + switch (write_state_.back()) { + case UNINIT: + // XXX figure out what to do here. + //throw TProtocolException(TProtocolException::INVALID_DATA); + //return writeIndented(str); + return 0; + case STRUCT: + return writePlain(",\n"); + case SET: + return writePlain(",\n"); + case MAP_KEY: + write_state_.back() = MAP_VALUE; + return 0; + case MAP_VALUE: + write_state_.back() = MAP_KEY; + return writePlain(",\n"); + case LIST: + return writePlain(",\n"); + default: + throw std::logic_error("Invalid enum value."); + } +} + +uint32_t TDebugProtocol::writeItem(const std::string& str) { + uint32_t size = 0; + size += startItem(); + size += writePlain(str); + size += endItem(); + return size; +} + +uint32_t TDebugProtocol::writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) { + string mtype; + switch (messageType) { + case T_CALL : mtype = "call" ; break; + case T_REPLY : mtype = "reply" ; break; + case T_EXCEPTION : mtype = "exn" ; break; + } + + uint32_t size = writeIndented("(" + mtype + ") " + name + "("); + indentUp(); + return size; +} + +uint32_t TDebugProtocol::writeMessageEnd() { + indentDown(); + return writeIndented(")\n"); +} + +uint32_t TDebugProtocol::writeStructBegin(const char* name) { + uint32_t size = 0; + size += startItem(); + size += writePlain(string(name) + " {\n"); + indentUp(); + write_state_.push_back(STRUCT); + return size; +} + +uint32_t TDebugProtocol::writeStructEnd() { + indentDown(); + write_state_.pop_back(); + uint32_t size = 0; + size += writeIndented("}"); + size += endItem(); + return size; +} + +uint32_t TDebugProtocol::writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) { + // sprintf(id_str, "%02d", fieldId); + string id_str = boost::lexical_cast(fieldId); + if (id_str.length() == 1) id_str = '0' + id_str; + + return writeIndented( + id_str + ": " + + name + " (" + + fieldTypeName(fieldType) + ") = "); +} + +uint32_t TDebugProtocol::writeFieldEnd() { + assert(write_state_.back() == STRUCT); + return 0; +} + +uint32_t TDebugProtocol::writeFieldStop() { + return 0; + //writeIndented("***STOP***\n"); +} + +uint32_t TDebugProtocol::writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) { + // TODO(dreiss): Optimize short maps? + uint32_t bsize = 0; + bsize += startItem(); + bsize += writePlain( + "map<" + fieldTypeName(keyType) + "," + fieldTypeName(valType) + ">" + "[" + boost::lexical_cast(size) + "] {\n"); + indentUp(); + write_state_.push_back(MAP_KEY); + return bsize; +} + +uint32_t TDebugProtocol::writeMapEnd() { + indentDown(); + write_state_.pop_back(); + uint32_t size = 0; + size += writeIndented("}"); + size += endItem(); + return size; +} + +uint32_t TDebugProtocol::writeListBegin(const TType elemType, + const uint32_t size) { + // TODO(dreiss): Optimize short arrays. + uint32_t bsize = 0; + bsize += startItem(); + bsize += writePlain( + "list<" + fieldTypeName(elemType) + ">" + "[" + boost::lexical_cast(size) + "] {\n"); + indentUp(); + write_state_.push_back(LIST); + list_idx_.push_back(0); + return bsize; +} + +uint32_t TDebugProtocol::writeListEnd() { + indentDown(); + write_state_.pop_back(); + list_idx_.pop_back(); + uint32_t size = 0; + size += writeIndented("}"); + size += endItem(); + return size; +} + +uint32_t TDebugProtocol::writeSetBegin(const TType elemType, + const uint32_t size) { + // TODO(dreiss): Optimize short sets. + uint32_t bsize = 0; + bsize += startItem(); + bsize += writePlain( + "set<" + fieldTypeName(elemType) + ">" + "[" + boost::lexical_cast(size) + "] {\n"); + indentUp(); + write_state_.push_back(SET); + return bsize; +} + +uint32_t TDebugProtocol::writeSetEnd() { + indentDown(); + write_state_.pop_back(); + uint32_t size = 0; + size += writeIndented("}"); + size += endItem(); + return size; +} + +uint32_t TDebugProtocol::writeBool(const bool value) { + return writeItem(value ? "true" : "false"); +} + +uint32_t TDebugProtocol::writeByte(const int8_t byte) { + return writeItem("0x" + byte_to_hex(byte)); +} + +uint32_t TDebugProtocol::writeI16(const int16_t i16) { + return writeItem(boost::lexical_cast(i16)); +} + +uint32_t TDebugProtocol::writeI32(const int32_t i32) { + return writeItem(boost::lexical_cast(i32)); +} + +uint32_t TDebugProtocol::writeI64(const int64_t i64) { + return writeItem(boost::lexical_cast(i64)); +} + +uint32_t TDebugProtocol::writeDouble(const double dub) { + return writeItem(boost::lexical_cast(dub)); +} + + +uint32_t TDebugProtocol::writeString(const string& str) { + // XXX Raw/UTF-8? + + string to_show = str; + if (to_show.length() > (string::size_type)string_limit_) { + to_show = str.substr(0, string_prefix_size_); + to_show += "[...](" + boost::lexical_cast(str.length()) + ")"; + } + + string output = "\""; + + for (string::const_iterator it = to_show.begin(); it != to_show.end(); ++it) { + if (*it == '\\') { + output += "\\\\"; + } else if (*it == '"') { + output += "\\\""; + } else if (std::isprint(*it)) { + output += *it; + } else { + switch (*it) { + case '\a': output += "\\a"; break; + case '\b': output += "\\b"; break; + case '\f': output += "\\f"; break; + case '\n': output += "\\n"; break; + case '\r': output += "\\r"; break; + case '\t': output += "\\t"; break; + case '\v': output += "\\v"; break; + default: + output += "\\x"; + output += byte_to_hex(*it); + } + } + } + + output += '\"'; + return writeItem(output); +} + +uint32_t TDebugProtocol::writeBinary(const string& str) { + // XXX Hex? + return TDebugProtocol::writeString(str); +} + +}}} // apache::thrift::protocol diff --git a/lib/cpp/src/protocol/TDebugProtocol.h b/lib/cpp/src/protocol/TDebugProtocol.h new file mode 100644 index 00000000..ab69e0ca --- /dev/null +++ b/lib/cpp/src/protocol/TDebugProtocol.h @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TDEBUGPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TDEBUGPROTOCOL_H_ 1 + +#include "TProtocol.h" +#include "TOneWayProtocol.h" + +#include + +namespace apache { namespace thrift { namespace protocol { + +/* + +!!! EXPERIMENTAL CODE !!! + +This protocol is very much a work in progress. +It doesn't handle many cases properly. +It throws exceptions in many cases. +It probably segfaults in many cases. +Bug reports and feature requests are welcome. +Complaints are not. :R + +*/ + + +/** + * Protocol that prints the payload in a nice human-readable format. + * Reading from this protocol is not supported. + * + */ +class TDebugProtocol : public TWriteOnlyProtocol { + private: + enum write_state_t + { UNINIT + , STRUCT + , LIST + , SET + , MAP_KEY + , MAP_VALUE + }; + + public: + TDebugProtocol(boost::shared_ptr trans) + : TWriteOnlyProtocol(trans, "TDebugProtocol") + , string_limit_(DEFAULT_STRING_LIMIT) + , string_prefix_size_(DEFAULT_STRING_PREFIX_SIZE) + { + write_state_.push_back(UNINIT); + } + + static const int32_t DEFAULT_STRING_LIMIT = 256; + static const int32_t DEFAULT_STRING_PREFIX_SIZE = 16; + + void setStringSizeLimit(int32_t string_limit) { + string_limit_ = string_limit; + } + + void setStringPrefixSize(int32_t string_prefix_size) { + string_prefix_size_ = string_prefix_size; + } + + + virtual uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid); + + virtual uint32_t writeMessageEnd(); + + + uint32_t writeStructBegin(const char* name); + + uint32_t writeStructEnd(); + + uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId); + + uint32_t writeFieldEnd(); + + uint32_t writeFieldStop(); + + uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size); + + uint32_t writeMapEnd(); + + uint32_t writeListBegin(const TType elemType, + const uint32_t size); + + uint32_t writeListEnd(); + + uint32_t writeSetBegin(const TType elemType, + const uint32_t size); + + uint32_t writeSetEnd(); + + uint32_t writeBool(const bool value); + + uint32_t writeByte(const int8_t byte); + + uint32_t writeI16(const int16_t i16); + + uint32_t writeI32(const int32_t i32); + + uint32_t writeI64(const int64_t i64); + + uint32_t writeDouble(const double dub); + + uint32_t writeString(const std::string& str); + + uint32_t writeBinary(const std::string& str); + + + private: + void indentUp(); + void indentDown(); + uint32_t writePlain(const std::string& str); + uint32_t writeIndented(const std::string& str); + uint32_t startItem(); + uint32_t endItem(); + uint32_t writeItem(const std::string& str); + + static std::string fieldTypeName(TType type); + + int32_t string_limit_; + int32_t string_prefix_size_; + + std::string indent_str_; + static const int indent_inc = 2; + + std::vector write_state_; + std::vector list_idx_; +}; + +/** + * Constructs debug protocol handlers + */ +class TDebugProtocolFactory : public TProtocolFactory { + public: + TDebugProtocolFactory() {} + virtual ~TDebugProtocolFactory() {} + + boost::shared_ptr getProtocol(boost::shared_ptr trans) { + return boost::shared_ptr(new TDebugProtocol(trans)); + } + +}; + +}}} // apache::thrift::protocol + + +// TODO(dreiss): Move (part of) ThriftDebugString into a .cpp file and remove this. +#include + +namespace apache { namespace thrift { + +template +std::string ThriftDebugString(const ThriftStruct& ts) { + using namespace apache::thrift::transport; + using namespace apache::thrift::protocol; + TMemoryBuffer* buffer = new TMemoryBuffer; + boost::shared_ptr trans(buffer); + TDebugProtocol protocol(trans); + + ts.write(&protocol); + + uint8_t* buf; + uint32_t size; + buffer->getBuffer(&buf, &size); + return std::string((char*)buf, (unsigned int)size); +} + +// TODO(dreiss): This is badly broken. Don't use it unless you are me. +#if 0 +template +std::string DebugString(const std::vector& vec) { + using namespace apache::thrift::transport; + using namespace apache::thrift::protocol; + TMemoryBuffer* buffer = new TMemoryBuffer; + boost::shared_ptr trans(buffer); + TDebugProtocol protocol(trans); + + // I am gross! + protocol.writeStructBegin("SomeRandomVector"); + + // TODO: Fix this with a trait. + protocol.writeListBegin((TType)99, vec.size()); + typename std::vector::const_iterator it; + for (it = vec.begin(); it != vec.end(); ++it) { + it->write(&protocol); + } + protocol.writeListEnd(); + + uint8_t* buf; + uint32_t size; + buffer->getBuffer(&buf, &size); + return std::string((char*)buf, (unsigned int)size); +} +#endif // 0 + +}} // apache::thrift + + +#endif // #ifndef _THRIFT_PROTOCOL_TDEBUGPROTOCOL_H_ + + diff --git a/lib/cpp/src/protocol/TDenseProtocol.cpp b/lib/cpp/src/protocol/TDenseProtocol.cpp new file mode 100644 index 00000000..8e76dc47 --- /dev/null +++ b/lib/cpp/src/protocol/TDenseProtocol.cpp @@ -0,0 +1,762 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + +IMPLEMENTATION DETAILS + +TDenseProtocol was designed to have a smaller serialized form than +TBinaryProtocol. This is accomplished using two techniques. The first is +variable-length integer encoding. We use the same technique that the Standard +MIDI File format uses for "variable-length quantities" +(http://en.wikipedia.org/wiki/Variable-length_quantity). +All integers (including i16, but not byte) are first cast to uint64_t, +then written out as variable-length quantities. This has the unfortunate side +effect that all negative numbers require 10 bytes, but negative numbers tend +to be far less common than positive ones. + +The second technique eliminating the field ids used by TBinaryProtocol. This +decision required support from the Thrift compiler and also sacrifices some of +the backward and forward compatibility of TBinaryProtocol. + +We considered implementing this technique by generating separate readers and +writers for the dense protocol (this is how Pillar, Thrift's predecessor, +worked), but this idea had a few problems: +- Our abstractions go out the window. +- We would have to maintain a second code generator. +- Preserving compatibility with old versions of the structures would be a + nightmare. + +Therefore, we chose an alternate implementation that stored the description of +the data neither in the data itself (like TBinaryProtocol) nor in the +serialization code (like Pillar), but instead in a separate data structure, +called a TypeSpec. TypeSpecs are generated by the Thrift compiler +(specifically in the t_cpp_generator), and their structure should be +documented there (TODO(dreiss): s/should be/is/). + +We maintain a stack of TypeSpecs within the protocol so it knows where the +generated code is in the reading/writing process. For example, if we are +writing an i32 contained in a struct bar, contained in a struct foo, then the +stack would look like: TOP , i32 , struct bar , struct foo , BOTTOM. +The following invariant: whenever we are about to read/write an object +(structBegin, containerBegin, or a scalar), the TypeSpec on the top of the +stack must match the type being read/written. The main reasons that this +invariant must be maintained is that if we ever start reading a structure, we +must have its exact TypeSpec in order to pass the right tags to the +deserializer. + +We use the following strategies for maintaining this invariant: + +- For structures, we have a separate stack of indexes, one for each structure + on the TypeSpec stack. These are indexes into the list of fields in the + structure's TypeSpec. When we {read,write}FieldBegin, we push on the + TypeSpec for the field. +- When we begin writing a list or set, we push on the TypeSpec for the + element type. +- For maps, we have a separate stack of booleans, one for each map on the + TypeSpec stack. The boolean is true if we are writing the key for that + map, and false if we are writing the value. Maps are the trickiest case + because the generated code does not call any protocol method between + the key and the value. As a result, we potentially have to switch + between map key state and map value state after reading/writing any object. +- This job is handled by the stateTransition method. It is called after + reading/writing every object. It pops the current TypeSpec off the stack, + then optionally pushes a new one on, depending on what the next TypeSpec is. + If it is a struct, the job is left to the next writeFieldBegin. If it is a + set or list, the just-popped typespec is pushed back on. If it is a map, + the top of the key/value stack is toggled, and the appropriate TypeSpec + is pushed. + +Optional fields are a little tricky also. We write a zero byte if they are +absent and prefix them with an 0x01 byte if they are present +*/ + +#define __STDC_LIMIT_MACROS +#include +#include "TDenseProtocol.h" +#include "TReflectionLocal.h" + +// Leaving this on for now. Disabling it will turn off asserts, which should +// give a performance boost. When we have *really* thorough test cases, +// we should drop this. +#define DEBUG_TDENSEPROTOCOL + +// NOTE: Assertions should *only* be used to detect bugs in code, +// either in TDenseProtocol itself, or in code using it. +// (For example, using the wrong TypeSpec.) +// Invalid data should NEVER cause an assertion failure, +// no matter how grossly corrupted, nor how ingeniously crafted. +#ifdef DEBUG_TDENSEPROTOCOL +#undef NDEBUG +#else +#define NDEBUG +#endif +#include + +using std::string; + +#ifdef __GNUC__ +#define UNLIKELY(val) (__builtin_expect((val), 0)) +#else +#define UNLIKELY(val) (val) +#endif + +namespace apache { namespace thrift { namespace protocol { + +const int TDenseProtocol::FP_PREFIX_LEN = + apache::thrift::reflection::local::FP_PREFIX_LEN; + +// Top TypeSpec. TypeSpec of the structure being encoded. +#define TTS (ts_stack_.back()) // type = TypeSpec* +// InDeX. Index into TTS of the current/next field to encode. +#define IDX (idx_stack_.back()) // type = int +// Field TypeSpec. TypeSpec of the current/next field to encode. +#define FTS (TTS->tstruct.specs[IDX]) // type = TypeSpec* +// Field MeTa. Metadata of the current/next field to encode. +#define FMT (TTS->tstruct.metas[IDX]) // type = FieldMeta +// SubType 1/2. TypeSpec of the first/second subtype of this container. +#define ST1 (TTS->tcontainer.subtype1) +#define ST2 (TTS->tcontainer.subtype2) + + +/** + * Checks that @c ttype is indeed the ttype that we should be writing, + * according to our typespec. Aborts if the test fails and debugging in on. + */ +inline void TDenseProtocol::checkTType(const TType ttype) { + assert(!ts_stack_.empty()); + assert(TTS->ttype == ttype); +} + +/** + * Makes sure that the TypeSpec stack is correct for the next object. + * See top-of-file comments. + */ +inline void TDenseProtocol::stateTransition() { + TypeSpec* old_tts = ts_stack_.back(); + ts_stack_.pop_back(); + + // If this is the end of the top-level write, we should have just popped + // the TypeSpec passed to the constructor. + if (ts_stack_.empty()) { + assert(old_tts = type_spec_); + return; + } + + switch (TTS->ttype) { + + case T_STRUCT: + assert(old_tts == FTS); + break; + + case T_LIST: + case T_SET: + assert(old_tts == ST1); + ts_stack_.push_back(old_tts); + break; + + case T_MAP: + assert(old_tts == (mkv_stack_.back() ? ST1 : ST2)); + mkv_stack_.back() = !mkv_stack_.back(); + ts_stack_.push_back(mkv_stack_.back() ? ST1 : ST2); + break; + + default: + assert(!"Invalid TType in stateTransition."); + break; + + } +} + + +/* + * Variable-length quantity functions. + */ + +inline uint32_t TDenseProtocol::vlqRead(uint64_t& vlq) { + uint32_t used = 0; + uint64_t val = 0; + uint8_t buf[10]; // 64 bits / (7 bits/byte) = 10 bytes. + uint32_t buf_size = sizeof(buf); + const uint8_t* borrowed = trans_->borrow(buf, &buf_size); + + // Fast path. TODO(dreiss): Make it faster. + if (borrowed != NULL) { + while (true) { + uint8_t byte = borrowed[used]; + used++; + val = (val << 7) | (byte & 0x7f); + if (!(byte & 0x80)) { + vlq = val; + trans_->consume(used); + return used; + } + // Have to check for invalid data so we don't crash. + if (UNLIKELY(used == sizeof(buf))) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes."); + } + } + } + + // Slow path. + else { + while (true) { + uint8_t byte; + used += trans_->readAll(&byte, 1); + val = (val << 7) | (byte & 0x7f); + if (!(byte & 0x80)) { + vlq = val; + return used; + } + // Might as well check for invalid data on the slow path too. + if (UNLIKELY(used >= sizeof(buf))) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, "Variable-length int over 10 bytes."); + } + } + } +} + +inline uint32_t TDenseProtocol::vlqWrite(uint64_t vlq) { + uint8_t buf[10]; // 64 bits / (7 bits/byte) = 10 bytes. + int32_t pos = sizeof(buf) - 1; + + // Write the thing from back to front. + buf[pos] = vlq & 0x7f; + vlq >>= 7; + pos--; + + while (vlq > 0) { + assert(pos >= 0); + buf[pos] = (vlq | 0x80); + vlq >>= 7; + pos--; + } + + // Back up one step before writing. + pos++; + + trans_->write(buf+pos, sizeof(buf) - pos); + return sizeof(buf) - pos; +} + + + +/* + * Writing functions. + */ + +uint32_t TDenseProtocol::writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) { + throw TApplicationException("TDenseProtocol doesn't work with messages (yet)."); + + int32_t version = (VERSION_2) | ((int32_t)messageType); + uint32_t wsize = 0; + wsize += subWriteI32(version); + wsize += subWriteString(name); + wsize += subWriteI32(seqid); + return wsize; +} + +uint32_t TDenseProtocol::writeMessageEnd() { + return 0; +} + +uint32_t TDenseProtocol::writeStructBegin(const char* name) { + uint32_t xfer = 0; + + // The TypeSpec stack should be empty if this is the top-level read/write. + // If it is, we push the TypeSpec passed to the constructor. + if (ts_stack_.empty()) { + assert(standalone_); + + if (type_spec_ == NULL) { + resetState(); + throw TApplicationException("TDenseProtocol: No type specified."); + } else { + assert(type_spec_->ttype == T_STRUCT); + ts_stack_.push_back(type_spec_); + // Write out a prefix of the structure fingerprint. + trans_->write(type_spec_->fp_prefix, FP_PREFIX_LEN); + xfer += FP_PREFIX_LEN; + } + } + + // We need a new field index for this structure. + idx_stack_.push_back(0); + return 0; +} + +uint32_t TDenseProtocol::writeStructEnd() { + idx_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) { + uint32_t xfer = 0; + + // Skip over optional fields. + while (FMT.tag != fieldId) { + // TODO(dreiss): Old meta here. + assert(FTS->ttype != T_STOP); + assert(FMT.is_optional); + // Write a zero byte so the reader can skip it. + xfer += subWriteBool(false); + // And advance to the next field. + IDX++; + } + + // TODO(dreiss): give a better exception. + assert(FTS->ttype == fieldType); + + if (FMT.is_optional) { + subWriteBool(true); + xfer += 1; + } + + // writeFieldStop shares all lot of logic up to this point. + // Instead of replicating it all, we just call this method from that one + // and use a gross special case here. + if (UNLIKELY(FTS->ttype != T_STOP)) { + // For normal fields, push the TypeSpec that we're about to use. + ts_stack_.push_back(FTS); + } + return xfer; +} + +uint32_t TDenseProtocol::writeFieldEnd() { + // Just move on to the next field. + IDX++; + return 0; +} + +uint32_t TDenseProtocol::writeFieldStop() { + return TDenseProtocol::writeFieldBegin("", T_STOP, 0); +} + +uint32_t TDenseProtocol::writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) { + checkTType(T_MAP); + + assert(keyType == ST1->ttype); + assert(valType == ST2->ttype); + + ts_stack_.push_back(ST1); + mkv_stack_.push_back(true); + + return subWriteI32((int32_t)size); +} + +uint32_t TDenseProtocol::writeMapEnd() { + // Pop off the value type, as well as our entry in the map key/value stack. + // stateTransition takes care of popping off our TypeSpec. + ts_stack_.pop_back(); + mkv_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::writeListBegin(const TType elemType, + const uint32_t size) { + checkTType(T_LIST); + + assert(elemType == ST1->ttype); + ts_stack_.push_back(ST1); + return subWriteI32((int32_t)size); +} + +uint32_t TDenseProtocol::writeListEnd() { + // Pop off the element type. stateTransition takes care of popping off ours. + ts_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::writeSetBegin(const TType elemType, + const uint32_t size) { + checkTType(T_SET); + + assert(elemType == ST1->ttype); + ts_stack_.push_back(ST1); + return subWriteI32((int32_t)size); +} + +uint32_t TDenseProtocol::writeSetEnd() { + // Pop off the element type. stateTransition takes care of popping off ours. + ts_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::writeBool(const bool value) { + checkTType(T_BOOL); + stateTransition(); + return TBinaryProtocol::writeBool(value); +} + +uint32_t TDenseProtocol::writeByte(const int8_t byte) { + checkTType(T_BYTE); + stateTransition(); + return TBinaryProtocol::writeByte(byte); +} + +uint32_t TDenseProtocol::writeI16(const int16_t i16) { + checkTType(T_I16); + stateTransition(); + return vlqWrite(i16); +} + +uint32_t TDenseProtocol::writeI32(const int32_t i32) { + checkTType(T_I32); + stateTransition(); + return vlqWrite(i32); +} + +uint32_t TDenseProtocol::writeI64(const int64_t i64) { + checkTType(T_I64); + stateTransition(); + return vlqWrite(i64); +} + +uint32_t TDenseProtocol::writeDouble(const double dub) { + checkTType(T_DOUBLE); + stateTransition(); + return TBinaryProtocol::writeDouble(dub); +} + +uint32_t TDenseProtocol::writeString(const std::string& str) { + checkTType(T_STRING); + stateTransition(); + return subWriteString(str); +} + +uint32_t TDenseProtocol::writeBinary(const std::string& str) { + return TDenseProtocol::writeString(str); +} + +inline uint32_t TDenseProtocol::subWriteI32(const int32_t i32) { + return vlqWrite(i32); +} + +uint32_t TDenseProtocol::subWriteString(const std::string& str) { + uint32_t size = str.size(); + uint32_t xfer = subWriteI32((int32_t)size); + if (size > 0) { + trans_->write((uint8_t*)str.data(), size); + } + return xfer + size; +} + + + +/* + * Reading functions + * + * These have a lot of the same logic as the writing functions, so if + * something is confusing, look for comments in the corresponding writer. + */ + +uint32_t TDenseProtocol::readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) { + throw TApplicationException("TDenseProtocol doesn't work with messages (yet)."); + + uint32_t xfer = 0; + int32_t sz; + xfer += subReadI32(sz); + + if (sz < 0) { + // Check for correct version number + int32_t version = sz & VERSION_MASK; + if (version != VERSION_2) { + throw TProtocolException(TProtocolException::BAD_VERSION, "Bad version identifier"); + } + messageType = (TMessageType)(sz & 0x000000ff); + xfer += subReadString(name); + xfer += subReadI32(seqid); + } else { + throw TProtocolException(TProtocolException::BAD_VERSION, "No version identifier... old protocol client in strict mode?"); + } + return xfer; +} + +uint32_t TDenseProtocol::readMessageEnd() { + return 0; +} + +uint32_t TDenseProtocol::readStructBegin(string& name) { + uint32_t xfer = 0; + + if (ts_stack_.empty()) { + assert(standalone_); + + if (type_spec_ == NULL) { + resetState(); + throw TApplicationException("TDenseProtocol: No type specified."); + } else { + assert(type_spec_->ttype == T_STRUCT); + ts_stack_.push_back(type_spec_); + + // Check the fingerprint prefix. + uint8_t buf[FP_PREFIX_LEN]; + xfer += trans_->read(buf, FP_PREFIX_LEN); + if (std::memcmp(buf, type_spec_->fp_prefix, FP_PREFIX_LEN) != 0) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, + "Fingerprint in data does not match type_spec."); + } + } + } + + // We need a new field index for this structure. + idx_stack_.push_back(0); + return 0; +} + +uint32_t TDenseProtocol::readStructEnd() { + idx_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::readFieldBegin(string& name, + TType& fieldType, + int16_t& fieldId) { + uint32_t xfer = 0; + + // For optional fields, check to see if they are there. + while (FMT.is_optional) { + bool is_present; + xfer += subReadBool(is_present); + if (is_present) { + break; + } + IDX++; + } + + // Once we hit a mandatory field, or an optional field that is present, + // we know that FMT and FTS point to the appropriate field. + + fieldId = FMT.tag; + fieldType = FTS->ttype; + + // Normally, we push the TypeSpec that we are about to read, + // but no reading is done for T_STOP. + if (FTS->ttype != T_STOP) { + ts_stack_.push_back(FTS); + } + return xfer; +} + +uint32_t TDenseProtocol::readFieldEnd() { + IDX++; + return 0; +} + +uint32_t TDenseProtocol::readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) { + checkTType(T_MAP); + + uint32_t xfer = 0; + int32_t sizei; + xfer += subReadI32(sizei); + if (sizei < 0) { + resetState(); + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && sizei > container_limit_) { + resetState(); + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + + keyType = ST1->ttype; + valType = ST2->ttype; + + ts_stack_.push_back(ST1); + mkv_stack_.push_back(true); + + return xfer; +} + +uint32_t TDenseProtocol::readMapEnd() { + ts_stack_.pop_back(); + mkv_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::readListBegin(TType& elemType, + uint32_t& size) { + checkTType(T_LIST); + + uint32_t xfer = 0; + int32_t sizei; + xfer += subReadI32(sizei); + if (sizei < 0) { + resetState(); + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && sizei > container_limit_) { + resetState(); + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + + elemType = ST1->ttype; + + ts_stack_.push_back(ST1); + + return xfer; +} + +uint32_t TDenseProtocol::readListEnd() { + ts_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::readSetBegin(TType& elemType, + uint32_t& size) { + checkTType(T_SET); + + uint32_t xfer = 0; + int32_t sizei; + xfer += subReadI32(sizei); + if (sizei < 0) { + resetState(); + throw TProtocolException(TProtocolException::NEGATIVE_SIZE); + } else if (container_limit_ && sizei > container_limit_) { + resetState(); + throw TProtocolException(TProtocolException::SIZE_LIMIT); + } + size = (uint32_t)sizei; + + elemType = ST1->ttype; + + ts_stack_.push_back(ST1); + + return xfer; +} + +uint32_t TDenseProtocol::readSetEnd() { + ts_stack_.pop_back(); + stateTransition(); + return 0; +} + +uint32_t TDenseProtocol::readBool(bool& value) { + checkTType(T_BOOL); + stateTransition(); + return TBinaryProtocol::readBool(value); +} + +uint32_t TDenseProtocol::readByte(int8_t& byte) { + checkTType(T_BYTE); + stateTransition(); + return TBinaryProtocol::readByte(byte); +} + +uint32_t TDenseProtocol::readI16(int16_t& i16) { + checkTType(T_I16); + stateTransition(); + uint64_t u64; + uint32_t rv = vlqRead(u64); + int64_t val = (int64_t)u64; + if (UNLIKELY(val > INT16_MAX || val < INT16_MIN)) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, + "i16 out of range."); + } + i16 = (int16_t)val; + return rv; +} + +uint32_t TDenseProtocol::readI32(int32_t& i32) { + checkTType(T_I32); + stateTransition(); + uint64_t u64; + uint32_t rv = vlqRead(u64); + int64_t val = (int64_t)u64; + if (UNLIKELY(val > INT32_MAX || val < INT32_MIN)) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, + "i32 out of range."); + } + i32 = (int32_t)val; + return rv; +} + +uint32_t TDenseProtocol::readI64(int64_t& i64) { + checkTType(T_I64); + stateTransition(); + uint64_t u64; + uint32_t rv = vlqRead(u64); + int64_t val = (int64_t)u64; + if (UNLIKELY(val > INT64_MAX || val < INT64_MIN)) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, + "i64 out of range."); + } + i64 = (int64_t)val; + return rv; +} + +uint32_t TDenseProtocol::readDouble(double& dub) { + checkTType(T_DOUBLE); + stateTransition(); + return TBinaryProtocol::readDouble(dub); +} + +uint32_t TDenseProtocol::readString(std::string& str) { + checkTType(T_STRING); + stateTransition(); + return subReadString(str); +} + +uint32_t TDenseProtocol::readBinary(std::string& str) { + return TDenseProtocol::readString(str); +} + +uint32_t TDenseProtocol::subReadI32(int32_t& i32) { + uint64_t u64; + uint32_t rv = vlqRead(u64); + int64_t val = (int64_t)u64; + if (UNLIKELY(val > INT32_MAX || val < INT32_MIN)) { + resetState(); + throw TProtocolException(TProtocolException::INVALID_DATA, + "i32 out of range."); + } + i32 = (int32_t)val; + return rv; +} + +uint32_t TDenseProtocol::subReadString(std::string& str) { + uint32_t xfer; + int32_t size; + xfer = subReadI32(size); + return xfer + readStringBody(str, size); +} + +}}} // apache::thrift::protocol diff --git a/lib/cpp/src/protocol/TDenseProtocol.h b/lib/cpp/src/protocol/TDenseProtocol.h new file mode 100644 index 00000000..7655a479 --- /dev/null +++ b/lib/cpp/src/protocol/TDenseProtocol.h @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TDENSEPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TDENSEPROTOCOL_H_ 1 + +#include "TBinaryProtocol.h" + +namespace apache { namespace thrift { namespace protocol { + +/** + * !!!WARNING!!! + * This class is still highly experimental. Incompatible changes + * WILL be made to it without notice. DO NOT USE IT YET unless + * you are coordinating your testing with the author. + * + * The dense protocol is designed to use as little space as possible. + * + * There are two types of dense protocol instances. Standalone instances + * are not used for RPC and just encoded and decode structures of + * a predetermined type. Non-standalone instances are used for RPC. + * Currently, only standalone instances exist. + * + * To use a standalone dense protocol object, you must set the type_spec + * property (either in the constructor, or with setTypeSpec) to the local + * reflection TypeSpec of the structures you will write to (or read from) the + * protocol instance. + * + * BEST PRACTICES: + * - Never use optional for primitives or containers. + * - Only use optional for structures if they are very big and very rarely set. + * - All integers are variable-length, so you can use i64 without bloating. + * - NEVER EVER change the struct definitions IN ANY WAY without either + * changing your cache keys or talking to dreiss. + * + * TODO(dreiss): New class write with old meta. + * + * We override all of TBinaryProtocol's methods. + * We inherit so that we can can explicitly call TBPs's primitive-writing + * methods within our versions. + * + */ +class TDenseProtocol : public TBinaryProtocol { + protected: + static const int32_t VERSION_MASK = 0xffff0000; + // VERSION_1 (0x80010000) is taken by TBinaryProtocol. + static const int32_t VERSION_2 = 0x80020000; + + public: + typedef apache::thrift::reflection::local::TypeSpec TypeSpec; + static const int FP_PREFIX_LEN; + + /** + * @param tran The transport to use. + * @param type_spec The TypeSpec of the structures using this protocol. + */ + TDenseProtocol(boost::shared_ptr trans, + TypeSpec* type_spec = NULL) : + TBinaryProtocol(trans), + type_spec_(type_spec), + standalone_(true) + {} + + void setTypeSpec(TypeSpec* type_spec) { + type_spec_ = type_spec; + } + TypeSpec* getTypeSpec() { + return type_spec_; + } + + + /* + * Writing functions. + */ + + virtual uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid); + + virtual uint32_t writeMessageEnd(); + + + virtual uint32_t writeStructBegin(const char* name); + + virtual uint32_t writeStructEnd(); + + virtual uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId); + + virtual uint32_t writeFieldEnd(); + + virtual uint32_t writeFieldStop(); + + virtual uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size); + + virtual uint32_t writeMapEnd(); + + virtual uint32_t writeListBegin(const TType elemType, + const uint32_t size); + + virtual uint32_t writeListEnd(); + + virtual uint32_t writeSetBegin(const TType elemType, + const uint32_t size); + + virtual uint32_t writeSetEnd(); + + virtual uint32_t writeBool(const bool value); + + virtual uint32_t writeByte(const int8_t byte); + + virtual uint32_t writeI16(const int16_t i16); + + virtual uint32_t writeI32(const int32_t i32); + + virtual uint32_t writeI64(const int64_t i64); + + virtual uint32_t writeDouble(const double dub); + + virtual uint32_t writeString(const std::string& str); + + virtual uint32_t writeBinary(const std::string& str); + + + /* + * Helper writing functions (don't do state transitions). + */ + inline uint32_t subWriteI32(const int32_t i32); + + inline uint32_t subWriteString(const std::string& str); + + uint32_t subWriteBool(const bool value) { + return TBinaryProtocol::writeBool(value); + } + + + /* + * Reading functions + */ + + uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid); + + uint32_t readMessageEnd(); + + uint32_t readStructBegin(std::string& name); + + uint32_t readStructEnd(); + + uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId); + + uint32_t readFieldEnd(); + + uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size); + + uint32_t readMapEnd(); + + uint32_t readListBegin(TType& elemType, + uint32_t& size); + + uint32_t readListEnd(); + + uint32_t readSetBegin(TType& elemType, + uint32_t& size); + + uint32_t readSetEnd(); + + uint32_t readBool(bool& value); + + uint32_t readByte(int8_t& byte); + + uint32_t readI16(int16_t& i16); + + uint32_t readI32(int32_t& i32); + + uint32_t readI64(int64_t& i64); + + uint32_t readDouble(double& dub); + + uint32_t readString(std::string& str); + + uint32_t readBinary(std::string& str); + + /* + * Helper reading functions (don't do state transitions). + */ + inline uint32_t subReadI32(int32_t& i32); + + inline uint32_t subReadString(std::string& str); + + uint32_t subReadBool(bool& value) { + return TBinaryProtocol::readBool(value); + } + + + private: + + // Implementation functions, documented in the .cpp. + inline void checkTType(const TType ttype); + inline void stateTransition(); + + // Read and write variable-length integers. + // Uses the same technique as the MIDI file format. + inline uint32_t vlqRead(uint64_t& vlq); + inline uint32_t vlqWrite(uint64_t vlq); + + // Called before throwing an exception to make the object reusable. + void resetState() { + ts_stack_.clear(); + idx_stack_.clear(); + mkv_stack_.clear(); + } + + // TypeSpec of the top-level structure to write, + // for standalone protocol objects. + TypeSpec* type_spec_; + + std::vector ts_stack_; // TypeSpec stack. + std::vector idx_stack_; // InDeX stack. + std::vector mkv_stack_; // Map Key/Vlue stack. + // True = key, False = value. + + // True iff this is a standalone instance (no RPC). + bool standalone_; +}; + +}}} // apache::thrift::protocol + +#endif // #ifndef _THRIFT_PROTOCOL_TDENSEPROTOCOL_H_ diff --git a/lib/cpp/src/protocol/TJSONProtocol.cpp b/lib/cpp/src/protocol/TJSONProtocol.cpp new file mode 100644 index 00000000..2a9c8f0b --- /dev/null +++ b/lib/cpp/src/protocol/TJSONProtocol.cpp @@ -0,0 +1,998 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "TJSONProtocol.h" + +#include +#include +#include "TBase64Utils.h" +#include + +using namespace apache::thrift::transport; + +namespace apache { namespace thrift { namespace protocol { + + +// Static data + +static const uint8_t kJSONObjectStart = '{'; +static const uint8_t kJSONObjectEnd = '}'; +static const uint8_t kJSONArrayStart = '['; +static const uint8_t kJSONArrayEnd = ']'; +static const uint8_t kJSONNewline = '\n'; +static const uint8_t kJSONPairSeparator = ':'; +static const uint8_t kJSONElemSeparator = ','; +static const uint8_t kJSONBackslash = '\\'; +static const uint8_t kJSONStringDelimiter = '"'; +static const uint8_t kJSONZeroChar = '0'; +static const uint8_t kJSONEscapeChar = 'u'; + +static const std::string kJSONEscapePrefix("\\u00"); + +static const uint32_t kThriftVersion1 = 1; + +static const std::string kThriftNan("NaN"); +static const std::string kThriftInfinity("Infinity"); +static const std::string kThriftNegativeInfinity("-Infinity"); + +static const std::string kTypeNameBool("tf"); +static const std::string kTypeNameByte("i8"); +static const std::string kTypeNameI16("i16"); +static const std::string kTypeNameI32("i32"); +static const std::string kTypeNameI64("i64"); +static const std::string kTypeNameDouble("dbl"); +static const std::string kTypeNameStruct("rec"); +static const std::string kTypeNameString("str"); +static const std::string kTypeNameMap("map"); +static const std::string kTypeNameList("lst"); +static const std::string kTypeNameSet("set"); + +static const std::string &getTypeNameForTypeID(TType typeID) { + switch (typeID) { + case T_BOOL: + return kTypeNameBool; + case T_BYTE: + return kTypeNameByte; + case T_I16: + return kTypeNameI16; + case T_I32: + return kTypeNameI32; + case T_I64: + return kTypeNameI64; + case T_DOUBLE: + return kTypeNameDouble; + case T_STRING: + return kTypeNameString; + case T_STRUCT: + return kTypeNameStruct; + case T_MAP: + return kTypeNameMap; + case T_SET: + return kTypeNameSet; + case T_LIST: + return kTypeNameList; + default: + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + "Unrecognized type"); + } +} + +static TType getTypeIDForTypeName(const std::string &name) { + TType result = T_STOP; // Sentinel value + if (name.length() > 1) { + switch (name[0]) { + case 'd': + result = T_DOUBLE; + break; + case 'i': + switch (name[1]) { + case '8': + result = T_BYTE; + break; + case '1': + result = T_I16; + break; + case '3': + result = T_I32; + break; + case '6': + result = T_I64; + break; + } + break; + case 'l': + result = T_LIST; + break; + case 'm': + result = T_MAP; + break; + case 'r': + result = T_STRUCT; + break; + case 's': + if (name[1] == 't') { + result = T_STRING; + } + else if (name[1] == 'e') { + result = T_SET; + } + break; + case 't': + result = T_BOOL; + break; + } + } + if (result == T_STOP) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + "Unrecognized type"); + } + return result; +} + + +// This table describes the handling for the first 0x30 characters +// 0 : escape using "\u00xx" notation +// 1 : just output index +// : escape using "\" notation +static const uint8_t kJSONCharTable[0x30] = { +// 0 1 2 3 4 5 6 7 8 9 A B C D E F + 0, 0, 0, 0, 0, 0, 0, 0,'b','t','n', 0,'f','r', 0, 0, // 0 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 1 + 1, 1,'"', 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 2 +}; + + +// This string's characters must match up with the elements in kEscapeCharVals. +// I don't have '/' on this list even though it appears on www.json.org -- +// it is not in the RFC +const static std::string kEscapeChars("\"\\bfnrt"); + +// The elements of this array must match up with the sequence of characters in +// kEscapeChars +const static uint8_t kEscapeCharVals[7] = { + '"', '\\', '\b', '\f', '\n', '\r', '\t', +}; + + +// Static helper functions + +// Read 1 character from the transport trans and verify that it is the +// expected character ch. +// Throw a protocol exception if it is not. +static uint32_t readSyntaxChar(TJSONProtocol::LookaheadReader &reader, + uint8_t ch) { + uint8_t ch2 = reader.read(); + if (ch2 != ch) { + throw TProtocolException(TProtocolException::INVALID_DATA, + "Expected \'" + std::string((char *)&ch, 1) + + "\'; got \'" + std::string((char *)&ch2, 1) + + "\'."); + } + return 1; +} + +// Return the integer value of a hex character ch. +// Throw a protocol exception if the character is not [0-9a-f]. +static uint8_t hexVal(uint8_t ch) { + if ((ch >= '0') && (ch <= '9')) { + return ch - '0'; + } + else if ((ch >= 'a') && (ch <= 'f')) { + return ch - 'a'; + } + else { + throw TProtocolException(TProtocolException::INVALID_DATA, + "Expected hex val ([0-9a-f]); got \'" + + std::string((char *)&ch, 1) + "\'."); + } +} + +// Return the hex character representing the integer val. The value is masked +// to make sure it is in the correct range. +static uint8_t hexChar(uint8_t val) { + val &= 0x0F; + if (val < 10) { + return val + '0'; + } + else { + return val + 'a'; + } +} + +// Return true if the character ch is in [-+0-9.Ee]; false otherwise +static bool isJSONNumeric(uint8_t ch) { + switch (ch) { + case '+': + case '-': + case '.': + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + case 'E': + case 'e': + return true; + } + return false; +} + + +/** + * Class to serve as base JSON context and as base class for other context + * implementations + */ +class TJSONContext { + + public: + + TJSONContext() {}; + + virtual ~TJSONContext() {}; + + /** + * Write context data to the transport. Default is to do nothing. + */ + virtual uint32_t write(TTransport &trans) { + return 0; + }; + + /** + * Read context data from the transport. Default is to do nothing. + */ + virtual uint32_t read(TJSONProtocol::LookaheadReader &reader) { + return 0; + }; + + /** + * Return true if numbers need to be escaped as strings in this context. + * Default behavior is to return false. + */ + virtual bool escapeNum() { + return false; + } +}; + +// Context class for object member key-value pairs +class JSONPairContext : public TJSONContext { + +public: + + JSONPairContext() : + first_(true), + colon_(true) { + } + + uint32_t write(TTransport &trans) { + if (first_) { + first_ = false; + colon_ = true; + return 0; + } + else { + trans.write(colon_ ? &kJSONPairSeparator : &kJSONElemSeparator, 1); + colon_ = !colon_; + return 1; + } + } + + uint32_t read(TJSONProtocol::LookaheadReader &reader) { + if (first_) { + first_ = false; + colon_ = true; + return 0; + } + else { + uint8_t ch = (colon_ ? kJSONPairSeparator : kJSONElemSeparator); + colon_ = !colon_; + return readSyntaxChar(reader, ch); + } + } + + // Numbers must be turned into strings if they are the key part of a pair + virtual bool escapeNum() { + return colon_; + } + + private: + + bool first_; + bool colon_; +}; + +// Context class for lists +class JSONListContext : public TJSONContext { + +public: + + JSONListContext() : + first_(true) { + } + + uint32_t write(TTransport &trans) { + if (first_) { + first_ = false; + return 0; + } + else { + trans.write(&kJSONElemSeparator, 1); + return 1; + } + } + + uint32_t read(TJSONProtocol::LookaheadReader &reader) { + if (first_) { + first_ = false; + return 0; + } + else { + return readSyntaxChar(reader, kJSONElemSeparator); + } + } + + private: + bool first_; +}; + + +TJSONProtocol::TJSONProtocol(boost::shared_ptr ptrans) : + TProtocol(ptrans), + context_(new TJSONContext()), + reader_(*ptrans) { +} + +TJSONProtocol::~TJSONProtocol() {} + +void TJSONProtocol::pushContext(boost::shared_ptr c) { + contexts_.push(context_); + context_ = c; +} + +void TJSONProtocol::popContext() { + context_ = contexts_.top(); + contexts_.pop(); +} + +// Write the character ch as a JSON escape sequence ("\u00xx") +uint32_t TJSONProtocol::writeJSONEscapeChar(uint8_t ch) { + trans_->write((const uint8_t *)kJSONEscapePrefix.c_str(), + kJSONEscapePrefix.length()); + uint8_t outCh = hexChar(ch >> 4); + trans_->write(&outCh, 1); + outCh = hexChar(ch); + trans_->write(&outCh, 1); + return 6; +} + +// Write the character ch as part of a JSON string, escaping as appropriate. +uint32_t TJSONProtocol::writeJSONChar(uint8_t ch) { + if (ch >= 0x30) { + if (ch == kJSONBackslash) { // Only special character >= 0x30 is '\' + trans_->write(&kJSONBackslash, 1); + trans_->write(&kJSONBackslash, 1); + return 2; + } + else { + trans_->write(&ch, 1); + return 1; + } + } + else { + uint8_t outCh = kJSONCharTable[ch]; + // Check if regular character, backslash escaped, or JSON escaped + if (outCh == 1) { + trans_->write(&ch, 1); + return 1; + } + else if (outCh > 1) { + trans_->write(&kJSONBackslash, 1); + trans_->write(&outCh, 1); + return 2; + } + else { + return writeJSONEscapeChar(ch); + } + } +} + +// Write out the contents of the string str as a JSON string, escaping +// characters as appropriate. +uint32_t TJSONProtocol::writeJSONString(const std::string &str) { + uint32_t result = context_->write(*trans_); + result += 2; // For quotes + trans_->write(&kJSONStringDelimiter, 1); + std::string::const_iterator iter(str.begin()); + std::string::const_iterator end(str.end()); + while (iter != end) { + result += writeJSONChar(*iter++); + } + trans_->write(&kJSONStringDelimiter, 1); + return result; +} + +// Write out the contents of the string as JSON string, base64-encoding +// the string's contents, and escaping as appropriate +uint32_t TJSONProtocol::writeJSONBase64(const std::string &str) { + uint32_t result = context_->write(*trans_); + result += 2; // For quotes + trans_->write(&kJSONStringDelimiter, 1); + uint8_t b[4]; + const uint8_t *bytes = (const uint8_t *)str.c_str(); + uint32_t len = str.length(); + while (len >= 3) { + // Encode 3 bytes at a time + base64_encode(bytes, 3, b); + trans_->write(b, 4); + result += 4; + bytes += 3; + len -=3; + } + if (len) { // Handle remainder + base64_encode(bytes, len, b); + trans_->write(b, len + 1); + result += len + 1; + } + trans_->write(&kJSONStringDelimiter, 1); + return result; +} + +// Convert the given integer type to a JSON number, or a string +// if the context requires it (eg: key in a map pair). +template +uint32_t TJSONProtocol::writeJSONInteger(NumberType num) { + uint32_t result = context_->write(*trans_); + std::string val(boost::lexical_cast(num)); + bool escapeNum = context_->escapeNum(); + if (escapeNum) { + trans_->write(&kJSONStringDelimiter, 1); + result += 1; + } + trans_->write((const uint8_t *)val.c_str(), val.length()); + result += val.length(); + if (escapeNum) { + trans_->write(&kJSONStringDelimiter, 1); + result += 1; + } + return result; +} + +// Convert the given double to a JSON string, which is either the number, +// "NaN" or "Infinity" or "-Infinity". +uint32_t TJSONProtocol::writeJSONDouble(double num) { + uint32_t result = context_->write(*trans_); + std::string val(boost::lexical_cast(num)); + + // Normalize output of boost::lexical_cast for NaNs and Infinities + bool special = false; + switch (val[0]) { + case 'N': + case 'n': + val = kThriftNan; + special = true; + break; + case 'I': + case 'i': + val = kThriftInfinity; + special = true; + break; + case '-': + if ((val[1] == 'I') || (val[1] == 'i')) { + val = kThriftNegativeInfinity; + special = true; + } + break; + } + + bool escapeNum = special || context_->escapeNum(); + if (escapeNum) { + trans_->write(&kJSONStringDelimiter, 1); + result += 1; + } + trans_->write((const uint8_t *)val.c_str(), val.length()); + result += val.length(); + if (escapeNum) { + trans_->write(&kJSONStringDelimiter, 1); + result += 1; + } + return result; +} + +uint32_t TJSONProtocol::writeJSONObjectStart() { + uint32_t result = context_->write(*trans_); + trans_->write(&kJSONObjectStart, 1); + pushContext(boost::shared_ptr(new JSONPairContext())); + return result + 1; +} + +uint32_t TJSONProtocol::writeJSONObjectEnd() { + popContext(); + trans_->write(&kJSONObjectEnd, 1); + return 1; +} + +uint32_t TJSONProtocol::writeJSONArrayStart() { + uint32_t result = context_->write(*trans_); + trans_->write(&kJSONArrayStart, 1); + pushContext(boost::shared_ptr(new JSONListContext())); + return result + 1; +} + +uint32_t TJSONProtocol::writeJSONArrayEnd() { + popContext(); + trans_->write(&kJSONArrayEnd, 1); + return 1; +} + +uint32_t TJSONProtocol::writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) { + uint32_t result = writeJSONArrayStart(); + result += writeJSONInteger(kThriftVersion1); + result += writeJSONString(name); + result += writeJSONInteger(messageType); + result += writeJSONInteger(seqid); + return result; +} + +uint32_t TJSONProtocol::writeMessageEnd() { + return writeJSONArrayEnd(); +} + +uint32_t TJSONProtocol::writeStructBegin(const char* name) { + return writeJSONObjectStart(); +} + +uint32_t TJSONProtocol::writeStructEnd() { + return writeJSONObjectEnd(); +} + +uint32_t TJSONProtocol::writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) { + uint32_t result = writeJSONInteger(fieldId); + result += writeJSONObjectStart(); + result += writeJSONString(getTypeNameForTypeID(fieldType)); + return result; +} + +uint32_t TJSONProtocol::writeFieldEnd() { + return writeJSONObjectEnd(); +} + +uint32_t TJSONProtocol::writeFieldStop() { + return 0; +} + +uint32_t TJSONProtocol::writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) { + uint32_t result = writeJSONArrayStart(); + result += writeJSONString(getTypeNameForTypeID(keyType)); + result += writeJSONString(getTypeNameForTypeID(valType)); + result += writeJSONInteger((int64_t)size); + result += writeJSONObjectStart(); + return result; +} + +uint32_t TJSONProtocol::writeMapEnd() { + return writeJSONObjectEnd() + writeJSONArrayEnd(); +} + +uint32_t TJSONProtocol::writeListBegin(const TType elemType, + const uint32_t size) { + uint32_t result = writeJSONArrayStart(); + result += writeJSONString(getTypeNameForTypeID(elemType)); + result += writeJSONInteger((int64_t)size); + return result; +} + +uint32_t TJSONProtocol::writeListEnd() { + return writeJSONArrayEnd(); +} + +uint32_t TJSONProtocol::writeSetBegin(const TType elemType, + const uint32_t size) { + uint32_t result = writeJSONArrayStart(); + result += writeJSONString(getTypeNameForTypeID(elemType)); + result += writeJSONInteger((int64_t)size); + return result; +} + +uint32_t TJSONProtocol::writeSetEnd() { + return writeJSONArrayEnd(); +} + +uint32_t TJSONProtocol::writeBool(const bool value) { + return writeJSONInteger(value); +} + +uint32_t TJSONProtocol::writeByte(const int8_t byte) { + // writeByte() must be handled specially becuase boost::lexical cast sees + // int8_t as a text type instead of an integer type + return writeJSONInteger((int16_t)byte); +} + +uint32_t TJSONProtocol::writeI16(const int16_t i16) { + return writeJSONInteger(i16); +} + +uint32_t TJSONProtocol::writeI32(const int32_t i32) { + return writeJSONInteger(i32); +} + +uint32_t TJSONProtocol::writeI64(const int64_t i64) { + return writeJSONInteger(i64); +} + +uint32_t TJSONProtocol::writeDouble(const double dub) { + return writeJSONDouble(dub); +} + +uint32_t TJSONProtocol::writeString(const std::string& str) { + return writeJSONString(str); +} + +uint32_t TJSONProtocol::writeBinary(const std::string& str) { + return writeJSONBase64(str); +} + + /** + * Reading functions + */ + +// Reads 1 byte and verifies that it matches ch. +uint32_t TJSONProtocol::readJSONSyntaxChar(uint8_t ch) { + return readSyntaxChar(reader_, ch); +} + +// Decodes the four hex parts of a JSON escaped string character and returns +// the character via out. The first two characters must be "00". +uint32_t TJSONProtocol::readJSONEscapeChar(uint8_t *out) { + uint8_t b[2]; + readJSONSyntaxChar(kJSONZeroChar); + readJSONSyntaxChar(kJSONZeroChar); + b[0] = reader_.read(); + b[1] = reader_.read(); + *out = (hexVal(b[0]) << 4) + hexVal(b[1]); + return 4; +} + +// Decodes a JSON string, including unescaping, and returns the string via str +uint32_t TJSONProtocol::readJSONString(std::string &str, bool skipContext) { + uint32_t result = (skipContext ? 0 : context_->read(reader_)); + result += readJSONSyntaxChar(kJSONStringDelimiter); + uint8_t ch; + str.clear(); + while (true) { + ch = reader_.read(); + ++result; + if (ch == kJSONStringDelimiter) { + break; + } + if (ch == kJSONBackslash) { + ch = reader_.read(); + ++result; + if (ch == kJSONEscapeChar) { + result += readJSONEscapeChar(&ch); + } + else { + size_t pos = kEscapeChars.find(ch); + if (pos == std::string::npos) { + throw TProtocolException(TProtocolException::INVALID_DATA, + "Expected control char, got '" + + std::string((const char *)&ch, 1) + "'."); + } + ch = kEscapeCharVals[pos]; + } + } + str += ch; + } + return result; +} + +// Reads a block of base64 characters, decoding it, and returns via str +uint32_t TJSONProtocol::readJSONBase64(std::string &str) { + std::string tmp; + uint32_t result = readJSONString(tmp); + uint8_t *b = (uint8_t *)tmp.c_str(); + uint32_t len = tmp.length(); + str.clear(); + while (len >= 4) { + base64_decode(b, 4); + str.append((const char *)b, 3); + b += 4; + len -= 4; + } + // Don't decode if we hit the end or got a single leftover byte (invalid + // base64 but legal for skip of regular string type) + if (len > 1) { + base64_decode(b, len); + str.append((const char *)b, len - 1); + } + return result; +} + +// Reads a sequence of characters, stopping at the first one that is not +// a valid JSON numeric character. +uint32_t TJSONProtocol::readJSONNumericChars(std::string &str) { + uint32_t result = 0; + str.clear(); + while (true) { + uint8_t ch = reader_.peek(); + if (!isJSONNumeric(ch)) { + break; + } + reader_.read(); + str += ch; + ++result; + } + return result; +} + +// Reads a sequence of characters and assembles them into a number, +// returning them via num +template +uint32_t TJSONProtocol::readJSONInteger(NumberType &num) { + uint32_t result = context_->read(reader_); + if (context_->escapeNum()) { + result += readJSONSyntaxChar(kJSONStringDelimiter); + } + std::string str; + result += readJSONNumericChars(str); + try { + num = boost::lexical_cast(str); + } + catch (boost::bad_lexical_cast e) { + throw new TProtocolException(TProtocolException::INVALID_DATA, + "Expected numeric value; got \"" + str + + "\""); + } + if (context_->escapeNum()) { + result += readJSONSyntaxChar(kJSONStringDelimiter); + } + return result; +} + +// Reads a JSON number or string and interprets it as a double. +uint32_t TJSONProtocol::readJSONDouble(double &num) { + uint32_t result = context_->read(reader_); + std::string str; + if (reader_.peek() == kJSONStringDelimiter) { + result += readJSONString(str, true); + // Check for NaN, Infinity and -Infinity + if (str == kThriftNan) { + num = HUGE_VAL/HUGE_VAL; // generates NaN + } + else if (str == kThriftInfinity) { + num = HUGE_VAL; + } + else if (str == kThriftNegativeInfinity) { + num = -HUGE_VAL; + } + else { + if (!context_->escapeNum()) { + // Throw exception -- we should not be in a string in this case + throw new TProtocolException(TProtocolException::INVALID_DATA, + "Numeric data unexpectedly quoted"); + } + try { + num = boost::lexical_cast(str); + } + catch (boost::bad_lexical_cast e) { + throw new TProtocolException(TProtocolException::INVALID_DATA, + "Expected numeric value; got \"" + str + + "\""); + } + } + } + else { + if (context_->escapeNum()) { + // This will throw - we should have had a quote if escapeNum == true + readJSONSyntaxChar(kJSONStringDelimiter); + } + result += readJSONNumericChars(str); + try { + num = boost::lexical_cast(str); + } + catch (boost::bad_lexical_cast e) { + throw new TProtocolException(TProtocolException::INVALID_DATA, + "Expected numeric value; got \"" + str + + "\""); + } + } + return result; +} + +uint32_t TJSONProtocol::readJSONObjectStart() { + uint32_t result = context_->read(reader_); + result += readJSONSyntaxChar(kJSONObjectStart); + pushContext(boost::shared_ptr(new JSONPairContext())); + return result; +} + +uint32_t TJSONProtocol::readJSONObjectEnd() { + uint32_t result = readJSONSyntaxChar(kJSONObjectEnd); + popContext(); + return result; +} + +uint32_t TJSONProtocol::readJSONArrayStart() { + uint32_t result = context_->read(reader_); + result += readJSONSyntaxChar(kJSONArrayStart); + pushContext(boost::shared_ptr(new JSONListContext())); + return result; +} + +uint32_t TJSONProtocol::readJSONArrayEnd() { + uint32_t result = readJSONSyntaxChar(kJSONArrayEnd); + popContext(); + return result; +} + +uint32_t TJSONProtocol::readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) { + uint32_t result = readJSONArrayStart(); + uint64_t tmpVal = 0; + result += readJSONInteger(tmpVal); + if (tmpVal != kThriftVersion1) { + throw TProtocolException(TProtocolException::BAD_VERSION, + "Message contained bad version."); + } + result += readJSONString(name); + result += readJSONInteger(tmpVal); + messageType = (TMessageType)tmpVal; + result += readJSONInteger(tmpVal); + seqid = tmpVal; + return result; +} + +uint32_t TJSONProtocol::readMessageEnd() { + return readJSONArrayEnd(); +} + +uint32_t TJSONProtocol::readStructBegin(std::string& name) { + return readJSONObjectStart(); +} + +uint32_t TJSONProtocol::readStructEnd() { + return readJSONObjectEnd(); +} + +uint32_t TJSONProtocol::readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId) { + uint32_t result = 0; + // Check if we hit the end of the list + uint8_t ch = reader_.peek(); + if (ch == kJSONObjectEnd) { + fieldType = apache::thrift::protocol::T_STOP; + } + else { + uint64_t tmpVal = 0; + std::string tmpStr; + result += readJSONInteger(tmpVal); + fieldId = tmpVal; + result += readJSONObjectStart(); + result += readJSONString(tmpStr); + fieldType = getTypeIDForTypeName(tmpStr); + } + return result; +} + +uint32_t TJSONProtocol::readFieldEnd() { + return readJSONObjectEnd(); +} + +uint32_t TJSONProtocol::readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) { + uint64_t tmpVal = 0; + std::string tmpStr; + uint32_t result = readJSONArrayStart(); + result += readJSONString(tmpStr); + keyType = getTypeIDForTypeName(tmpStr); + result += readJSONString(tmpStr); + valType = getTypeIDForTypeName(tmpStr); + result += readJSONInteger(tmpVal); + size = tmpVal; + result += readJSONObjectStart(); + return result; +} + +uint32_t TJSONProtocol::readMapEnd() { + return readJSONObjectEnd() + readJSONArrayEnd(); +} + +uint32_t TJSONProtocol::readListBegin(TType& elemType, + uint32_t& size) { + uint64_t tmpVal = 0; + std::string tmpStr; + uint32_t result = readJSONArrayStart(); + result += readJSONString(tmpStr); + elemType = getTypeIDForTypeName(tmpStr); + result += readJSONInteger(tmpVal); + size = tmpVal; + return result; +} + +uint32_t TJSONProtocol::readListEnd() { + return readJSONArrayEnd(); +} + +uint32_t TJSONProtocol::readSetBegin(TType& elemType, + uint32_t& size) { + uint64_t tmpVal = 0; + std::string tmpStr; + uint32_t result = readJSONArrayStart(); + result += readJSONString(tmpStr); + elemType = getTypeIDForTypeName(tmpStr); + result += readJSONInteger(tmpVal); + size = tmpVal; + return result; +} + +uint32_t TJSONProtocol::readSetEnd() { + return readJSONArrayEnd(); +} + +uint32_t TJSONProtocol::readBool(bool& value) { + return readJSONInteger(value); +} + +// readByte() must be handled properly becuase boost::lexical cast sees int8_t +// as a text type instead of an integer type +uint32_t TJSONProtocol::readByte(int8_t& byte) { + int16_t tmp = (int16_t) byte; + uint32_t result = readJSONInteger(tmp); + assert(tmp < 256); + byte = (int8_t)tmp; + return result; +} + +uint32_t TJSONProtocol::readI16(int16_t& i16) { + return readJSONInteger(i16); +} + +uint32_t TJSONProtocol::readI32(int32_t& i32) { + return readJSONInteger(i32); +} + +uint32_t TJSONProtocol::readI64(int64_t& i64) { + return readJSONInteger(i64); +} + +uint32_t TJSONProtocol::readDouble(double& dub) { + return readJSONDouble(dub); +} + +uint32_t TJSONProtocol::readString(std::string &str) { + return readJSONString(str); +} + +uint32_t TJSONProtocol::readBinary(std::string &str) { + return readJSONBase64(str); +} + +}}} // apache::thrift::protocol diff --git a/lib/cpp/src/protocol/TJSONProtocol.h b/lib/cpp/src/protocol/TJSONProtocol.h new file mode 100644 index 00000000..2df499ac --- /dev/null +++ b/lib/cpp/src/protocol/TJSONProtocol.h @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TJSONPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TJSONPROTOCOL_H_ 1 + +#include "TProtocol.h" + +#include + +namespace apache { namespace thrift { namespace protocol { + +// Forward declaration +class TJSONContext; + +/** + * JSON protocol for Thrift. + * + * Implements a protocol which uses JSON as the wire-format. + * + * Thrift types are represented as described below: + * + * 1. Every Thrift integer type is represented as a JSON number. + * + * 2. Thrift doubles are represented as JSON numbers. Some special values are + * represented as strings: + * a. "NaN" for not-a-number values + * b. "Infinity" for postive infinity + * c. "-Infinity" for negative infinity + * + * 3. Thrift string values are emitted as JSON strings, with appropriate + * escaping. + * + * 4. Thrift binary values are encoded into Base64 and emitted as JSON strings. + * The readBinary() method is written such that it will properly skip if + * called on a Thrift string (although it will decode garbage data). + * + * 5. Thrift structs are represented as JSON objects, with the field ID as the + * key, and the field value represented as a JSON object with a single + * key-value pair. The key is a short string identifier for that type, + * followed by the value. The valid type identifiers are: "tf" for bool, + * "i8" for byte, "i16" for 16-bit integer, "i32" for 32-bit integer, "i64" + * for 64-bit integer, "dbl" for double-precision loating point, "str" for + * string (including binary), "rec" for struct ("records"), "map" for map, + * "lst" for list, "set" for set. + * + * 6. Thrift lists and sets are represented as JSON arrays, with the first + * element of the JSON array being the string identifier for the Thrift + * element type and the second element of the JSON array being the count of + * the Thrift elements. The Thrift elements then follow. + * + * 7. Thrift maps are represented as JSON arrays, with the first two elements + * of the JSON array being the string identifiers for the Thrift key type + * and value type, followed by the count of the Thrift pairs, followed by a + * JSON object containing the key-value pairs. Note that JSON keys can only + * be strings, which means that the key type of the Thrift map should be + * restricted to numeric or string types -- in the case of numerics, they + * are serialized as strings. + * + * 8. Thrift messages are represented as JSON arrays, with the protocol + * version #, the message name, the message type, and the sequence ID as + * the first 4 elements. + * + * More discussion of the double handling is probably warranted. The aim of + * the current implementation is to match as closely as possible the behavior + * of Java's Double.toString(), which has no precision loss. Implementors in + * other languages should strive to achieve that where possible. I have not + * yet verified whether boost:lexical_cast, which is doing that work for me in + * C++, loses any precision, but I am leaving this as a future improvement. I + * may try to provide a C component for this, so that other languages could + * bind to the same underlying implementation for maximum consistency. + * + * Note further that JavaScript itself is not capable of representing + * floating point infinities -- presumably when we have a JavaScript Thrift + * client, this would mean that infinities get converted to not-a-number in + * transmission. I don't know of any work-around for this issue. + * + */ +class TJSONProtocol : public TProtocol { + public: + + TJSONProtocol(boost::shared_ptr ptrans); + + ~TJSONProtocol(); + + private: + + void pushContext(boost::shared_ptr c); + + void popContext(); + + uint32_t writeJSONEscapeChar(uint8_t ch); + + uint32_t writeJSONChar(uint8_t ch); + + uint32_t writeJSONString(const std::string &str); + + uint32_t writeJSONBase64(const std::string &str); + + template + uint32_t writeJSONInteger(NumberType num); + + uint32_t writeJSONDouble(double num); + + uint32_t writeJSONObjectStart() ; + + uint32_t writeJSONObjectEnd(); + + uint32_t writeJSONArrayStart(); + + uint32_t writeJSONArrayEnd(); + + uint32_t readJSONSyntaxChar(uint8_t ch); + + uint32_t readJSONEscapeChar(uint8_t *out); + + uint32_t readJSONString(std::string &str, bool skipContext = false); + + uint32_t readJSONBase64(std::string &str); + + uint32_t readJSONNumericChars(std::string &str); + + template + uint32_t readJSONInteger(NumberType &num); + + uint32_t readJSONDouble(double &num); + + uint32_t readJSONObjectStart(); + + uint32_t readJSONObjectEnd(); + + uint32_t readJSONArrayStart(); + + uint32_t readJSONArrayEnd(); + + public: + + /** + * Writing functions. + */ + + uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid); + + uint32_t writeMessageEnd(); + + uint32_t writeStructBegin(const char* name); + + uint32_t writeStructEnd(); + + uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId); + + uint32_t writeFieldEnd(); + + uint32_t writeFieldStop(); + + uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size); + + uint32_t writeMapEnd(); + + uint32_t writeListBegin(const TType elemType, + const uint32_t size); + + uint32_t writeListEnd(); + + uint32_t writeSetBegin(const TType elemType, + const uint32_t size); + + uint32_t writeSetEnd(); + + uint32_t writeBool(const bool value); + + uint32_t writeByte(const int8_t byte); + + uint32_t writeI16(const int16_t i16); + + uint32_t writeI32(const int32_t i32); + + uint32_t writeI64(const int64_t i64); + + uint32_t writeDouble(const double dub); + + uint32_t writeString(const std::string& str); + + uint32_t writeBinary(const std::string& str); + + /** + * Reading functions + */ + + uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid); + + uint32_t readMessageEnd(); + + uint32_t readStructBegin(std::string& name); + + uint32_t readStructEnd(); + + uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId); + + uint32_t readFieldEnd(); + + uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size); + + uint32_t readMapEnd(); + + uint32_t readListBegin(TType& elemType, + uint32_t& size); + + uint32_t readListEnd(); + + uint32_t readSetBegin(TType& elemType, + uint32_t& size); + + uint32_t readSetEnd(); + + uint32_t readBool(bool& value); + + uint32_t readByte(int8_t& byte); + + uint32_t readI16(int16_t& i16); + + uint32_t readI32(int32_t& i32); + + uint32_t readI64(int64_t& i64); + + uint32_t readDouble(double& dub); + + uint32_t readString(std::string& str); + + uint32_t readBinary(std::string& str); + + class LookaheadReader { + + public: + + LookaheadReader(TTransport &trans) : + trans_(&trans), + hasData_(false) { + } + + uint8_t read() { + if (hasData_) { + hasData_ = false; + } + else { + trans_->readAll(&data_, 1); + } + return data_; + } + + uint8_t peek() { + if (!hasData_) { + trans_->readAll(&data_, 1); + } + hasData_ = true; + return data_; + } + + private: + TTransport *trans_; + bool hasData_; + uint8_t data_; + }; + + private: + + std::stack > contexts_; + boost::shared_ptr context_; + LookaheadReader reader_; +}; + +/** + * Constructs input and output protocol objects given transports. + */ +class TJSONProtocolFactory : public TProtocolFactory { + public: + TJSONProtocolFactory() {} + + virtual ~TJSONProtocolFactory() {} + + boost::shared_ptr getProtocol(boost::shared_ptr trans) { + return boost::shared_ptr(new TJSONProtocol(trans)); + } +}; + +}}} // apache::thrift::protocol + + +// TODO(dreiss): Move part of ThriftJSONString into a .cpp file and remove this. +#include + +namespace apache { namespace thrift { + +template + std::string ThriftJSONString(const ThriftStruct& ts) { + using namespace apache::thrift::transport; + using namespace apache::thrift::protocol; + TMemoryBuffer* buffer = new TMemoryBuffer; + boost::shared_ptr trans(buffer); + TJSONProtocol protocol(trans); + + ts.write(&protocol); + + uint8_t* buf; + uint32_t size; + buffer->getBuffer(&buf, &size); + return std::string((char*)buf, (unsigned int)size); +} + +}} // apache::thrift + +#endif // #define _THRIFT_PROTOCOL_TJSONPROTOCOL_H_ 1 diff --git a/lib/cpp/src/protocol/TOneWayProtocol.h b/lib/cpp/src/protocol/TOneWayProtocol.h new file mode 100644 index 00000000..6f08fe1d --- /dev/null +++ b/lib/cpp/src/protocol/TOneWayProtocol.h @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TONEWAYPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TONEWAYPROTOCOL_H_ 1 + +#include "TProtocol.h" + +namespace apache { namespace thrift { namespace protocol { + +/** + * Abstract class for implementing a protocol that can only be written, + * not read. + * + */ +class TWriteOnlyProtocol : public TProtocol { + public: + /** + * @param subclass_name The name of the concrete subclass. + */ + TWriteOnlyProtocol(boost::shared_ptr trans, + const std::string& subclass_name) + : TProtocol(trans) + , subclass_(subclass_name) + {} + + // All writing functions remain abstract. + + /** + * Reading functions all throw an exception. + */ + + uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readMessageEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readStructBegin(std::string& name) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readStructEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readFieldEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readMapEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readListBegin(TType& elemType, + uint32_t& size) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readListEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readSetBegin(TType& elemType, + uint32_t& size) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readSetEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readBool(bool& value) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readByte(int8_t& byte) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readI16(int16_t& i16) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readI32(int32_t& i32) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readI64(int64_t& i64) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readDouble(double& dub) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readString(std::string& str) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + uint32_t readBinary(std::string& str) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support reading (yet)."); + } + + private: + std::string subclass_; +}; + + +/** + * Abstract class for implementing a protocol that can only be read, + * not written. + * + */ +class TReadOnlyProtocol : public TProtocol { + public: + /** + * @param subclass_name The name of the concrete subclass. + */ + TReadOnlyProtocol(boost::shared_ptr trans, + const std::string& subclass_name) + : TProtocol(trans) + , subclass_(subclass_name) + {} + + // All reading functions remain abstract. + + /** + * Writing functions all throw an exception. + */ + + uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeMessageEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + + uint32_t writeStructBegin(const char* name) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeStructEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeFieldEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeFieldStop() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeMapEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeListBegin(const TType elemType, + const uint32_t size) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeListEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeSetBegin(const TType elemType, + const uint32_t size) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeSetEnd() { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeBool(const bool value) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeByte(const int8_t byte) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeI16(const int16_t i16) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeI32(const int32_t i32) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeI64(const int64_t i64) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeDouble(const double dub) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeString(const std::string& str) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + uint32_t writeBinary(const std::string& str) { + throw TProtocolException(TProtocolException::NOT_IMPLEMENTED, + subclass_ + " does not support writing (yet)."); + } + + private: + std::string subclass_; +}; + +}}} // apache::thrift::protocol + +#endif // #ifndef _THRIFT_PROTOCOL_TBINARYPROTOCOL_H_ diff --git a/lib/cpp/src/protocol/TProtocol.h b/lib/cpp/src/protocol/TProtocol.h new file mode 100644 index 00000000..40258277 --- /dev/null +++ b/lib/cpp/src/protocol/TProtocol.h @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TPROTOCOL_H_ +#define _THRIFT_PROTOCOL_TPROTOCOL_H_ 1 + +#include +#include + +#include +#include + +#include +#include +#include +#include + + +// Use this to get around strict aliasing rules. +// For example, uint64_t i = bitwise_cast(returns_double()); +// The most obvious implementation is to just cast a pointer, +// but that doesn't work. +// For a pretty in-depth explanation of the problem, see +// http://www.cellperformance.com/mike_acton/2006/06/ (...) +// understanding_strict_aliasing.html +template +static inline To bitwise_cast(From from) { + BOOST_STATIC_ASSERT(sizeof(From) == sizeof(To)); + + // BAD!!! These are all broken with -O2. + //return *reinterpret_cast(&from); // BAD!!! + //return *static_cast(static_cast(&from)); // BAD!!! + //return *(To*)(void*)&from; // BAD!!! + + // Super clean and paritally blessed by section 3.9 of the standard. + //unsigned char c[sizeof(from)]; + //memcpy(c, &from, sizeof(from)); + //To to; + //memcpy(&to, c, sizeof(c)); + //return to; + + // Slightly more questionable. + // Same code emitted by GCC. + //To to; + //memcpy(&to, &from, sizeof(from)); + //return to; + + // Technically undefined, but almost universally supported, + // and the most efficient implementation. + union { + From f; + To t; + } u; + u.f = from; + return u.t; +} + + +namespace apache { namespace thrift { namespace protocol { + +using apache::thrift::transport::TTransport; + +#ifdef HAVE_ENDIAN_H +#include +#endif + +#ifndef __BYTE_ORDER +# if defined(BYTE_ORDER) && defined(LITTLE_ENDIAN) && defined(BIG_ENDIAN) +# define __BYTE_ORDER BYTE_ORDER +# define __LITTLE_ENDIAN LITTLE_ENDIAN +# define __BIG_ENDIAN BIG_ENDIAN +# else +# error "Cannot determine endianness" +# endif +#endif + +#if __BYTE_ORDER == __BIG_ENDIAN +# define ntohll(n) (n) +# define htonll(n) (n) +# if defined(__GNUC__) && defined(__GLIBC__) +# include +# define htolell(n) bswap_64(n) +# define letohll(n) bswap_64(n) +# else /* GNUC & GLIBC */ +# define bswap_64(n) \ + ( (((n) & 0xff00000000000000ull) >> 56) \ + | (((n) & 0x00ff000000000000ull) >> 40) \ + | (((n) & 0x0000ff0000000000ull) >> 24) \ + | (((n) & 0x000000ff00000000ull) >> 8) \ + | (((n) & 0x00000000ff000000ull) << 8) \ + | (((n) & 0x0000000000ff0000ull) << 24) \ + | (((n) & 0x000000000000ff00ull) << 40) \ + | (((n) & 0x00000000000000ffull) << 56) ) +# define ntolell(n) bswap_64(n) +# define letonll(n) bswap_64(n) +# endif /* GNUC & GLIBC */ +#elif __BYTE_ORDER == __LITTLE_ENDIAN +# define htolell(n) (n) +# define letohll(n) (n) +# if defined(__GNUC__) && defined(__GLIBC__) +# include +# define ntohll(n) bswap_64(n) +# define htonll(n) bswap_64(n) +# else /* GNUC & GLIBC */ +# define ntohll(n) ( (((unsigned long long)ntohl(n)) << 32) + ntohl(n >> 32) ) +# define htonll(n) ( (((unsigned long long)htonl(n)) << 32) + htonl(n >> 32) ) +# endif /* GNUC & GLIBC */ +#else /* __BYTE_ORDER */ +# error "Can't define htonll or ntohll!" +#endif + +/** + * Enumerated definition of the types that the Thrift protocol supports. + * Take special note of the T_END type which is used specifically to mark + * the end of a sequence of fields. + */ +enum TType { + T_STOP = 0, + T_VOID = 1, + T_BOOL = 2, + T_BYTE = 3, + T_I08 = 3, + T_I16 = 6, + T_I32 = 8, + T_U64 = 9, + T_I64 = 10, + T_DOUBLE = 4, + T_STRING = 11, + T_UTF7 = 11, + T_STRUCT = 12, + T_MAP = 13, + T_SET = 14, + T_LIST = 15, + T_UTF8 = 16, + T_UTF16 = 17 +}; + +/** + * Enumerated definition of the message types that the Thrift protocol + * supports. + */ +enum TMessageType { + T_CALL = 1, + T_REPLY = 2, + T_EXCEPTION = 3, + T_ONEWAY = 4 +}; + +/** + * Abstract class for a thrift protocol driver. These are all the methods that + * a protocol must implement. Essentially, there must be some way of reading + * and writing all the base types, plus a mechanism for writing out structs + * with indexed fields. + * + * TProtocol objects should not be shared across multiple encoding contexts, + * as they may need to maintain internal state in some protocols (i.e. XML). + * Note that is is acceptable for the TProtocol module to do its own internal + * buffered reads/writes to the underlying TTransport where appropriate (i.e. + * when parsing an input XML stream, reading should be batched rather than + * looking ahead character by character for a close tag). + * + */ +class TProtocol { + public: + virtual ~TProtocol() {} + + /** + * Writing functions. + */ + + virtual uint32_t writeMessageBegin(const std::string& name, + const TMessageType messageType, + const int32_t seqid) = 0; + + virtual uint32_t writeMessageEnd() = 0; + + + virtual uint32_t writeStructBegin(const char* name) = 0; + + virtual uint32_t writeStructEnd() = 0; + + virtual uint32_t writeFieldBegin(const char* name, + const TType fieldType, + const int16_t fieldId) = 0; + + virtual uint32_t writeFieldEnd() = 0; + + virtual uint32_t writeFieldStop() = 0; + + virtual uint32_t writeMapBegin(const TType keyType, + const TType valType, + const uint32_t size) = 0; + + virtual uint32_t writeMapEnd() = 0; + + virtual uint32_t writeListBegin(const TType elemType, + const uint32_t size) = 0; + + virtual uint32_t writeListEnd() = 0; + + virtual uint32_t writeSetBegin(const TType elemType, + const uint32_t size) = 0; + + virtual uint32_t writeSetEnd() = 0; + + virtual uint32_t writeBool(const bool value) = 0; + + virtual uint32_t writeByte(const int8_t byte) = 0; + + virtual uint32_t writeI16(const int16_t i16) = 0; + + virtual uint32_t writeI32(const int32_t i32) = 0; + + virtual uint32_t writeI64(const int64_t i64) = 0; + + virtual uint32_t writeDouble(const double dub) = 0; + + virtual uint32_t writeString(const std::string& str) = 0; + + virtual uint32_t writeBinary(const std::string& str) = 0; + + /** + * Reading functions + */ + + virtual uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) = 0; + + virtual uint32_t readMessageEnd() = 0; + + virtual uint32_t readStructBegin(std::string& name) = 0; + + virtual uint32_t readStructEnd() = 0; + + virtual uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId) = 0; + + virtual uint32_t readFieldEnd() = 0; + + virtual uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) = 0; + + virtual uint32_t readMapEnd() = 0; + + virtual uint32_t readListBegin(TType& elemType, + uint32_t& size) = 0; + + virtual uint32_t readListEnd() = 0; + + virtual uint32_t readSetBegin(TType& elemType, + uint32_t& size) = 0; + + virtual uint32_t readSetEnd() = 0; + + virtual uint32_t readBool(bool& value) = 0; + + virtual uint32_t readByte(int8_t& byte) = 0; + + virtual uint32_t readI16(int16_t& i16) = 0; + + virtual uint32_t readI32(int32_t& i32) = 0; + + virtual uint32_t readI64(int64_t& i64) = 0; + + virtual uint32_t readDouble(double& dub) = 0; + + virtual uint32_t readString(std::string& str) = 0; + + virtual uint32_t readBinary(std::string& str) = 0; + + uint32_t readBool(std::vector::reference ref) { + bool value; + uint32_t rv = readBool(value); + ref = value; + return rv; + } + + /** + * Method to arbitrarily skip over data. + */ + uint32_t skip(TType type) { + switch (type) { + case T_BOOL: + { + bool boolv; + return readBool(boolv); + } + case T_BYTE: + { + int8_t bytev; + return readByte(bytev); + } + case T_I16: + { + int16_t i16; + return readI16(i16); + } + case T_I32: + { + int32_t i32; + return readI32(i32); + } + case T_I64: + { + int64_t i64; + return readI64(i64); + } + case T_DOUBLE: + { + double dub; + return readDouble(dub); + } + case T_STRING: + { + std::string str; + return readBinary(str); + } + case T_STRUCT: + { + uint32_t result = 0; + std::string name; + int16_t fid; + TType ftype; + result += readStructBegin(name); + while (true) { + result += readFieldBegin(name, ftype, fid); + if (ftype == T_STOP) { + break; + } + result += skip(ftype); + result += readFieldEnd(); + } + result += readStructEnd(); + return result; + } + case T_MAP: + { + uint32_t result = 0; + TType keyType; + TType valType; + uint32_t i, size; + result += readMapBegin(keyType, valType, size); + for (i = 0; i < size; i++) { + result += skip(keyType); + result += skip(valType); + } + result += readMapEnd(); + return result; + } + case T_SET: + { + uint32_t result = 0; + TType elemType; + uint32_t i, size; + result += readSetBegin(elemType, size); + for (i = 0; i < size; i++) { + result += skip(elemType); + } + result += readSetEnd(); + return result; + } + case T_LIST: + { + uint32_t result = 0; + TType elemType; + uint32_t i, size; + result += readListBegin(elemType, size); + for (i = 0; i < size; i++) { + result += skip(elemType); + } + result += readListEnd(); + return result; + } + default: + return 0; + } + } + + inline boost::shared_ptr getTransport() { + return ptrans_; + } + + // TODO: remove these two calls, they are for backwards + // compatibility + inline boost::shared_ptr getInputTransport() { + return ptrans_; + } + inline boost::shared_ptr getOutputTransport() { + return ptrans_; + } + + protected: + TProtocol(boost::shared_ptr ptrans): + ptrans_(ptrans) { + trans_ = ptrans.get(); + } + + boost::shared_ptr ptrans_; + TTransport* trans_; + + private: + TProtocol() {} +}; + +/** + * Constructs input and output protocol objects given transports. + */ +class TProtocolFactory { + public: + TProtocolFactory() {} + + virtual ~TProtocolFactory() {} + + virtual boost::shared_ptr getProtocol(boost::shared_ptr trans) = 0; +}; + +}}} // apache::thrift::protocol + +#endif // #define _THRIFT_PROTOCOL_TPROTOCOL_H_ 1 diff --git a/lib/cpp/src/protocol/TProtocolException.h b/lib/cpp/src/protocol/TProtocolException.h new file mode 100644 index 00000000..33011b37 --- /dev/null +++ b/lib/cpp/src/protocol/TProtocolException.h @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TPROTOCOLEXCEPTION_H_ +#define _THRIFT_PROTOCOL_TPROTOCOLEXCEPTION_H_ 1 + +#include + +namespace apache { namespace thrift { namespace protocol { + +/** + * Class to encapsulate all the possible types of protocol errors that may + * occur in various protocol systems. This provides a sort of generic + * wrapper around the shitty UNIX E_ error codes that lets a common code + * base of error handling to be used for various types of protocols, i.e. + * pipes etc. + * + */ +class TProtocolException : public apache::thrift::TException { + public: + + /** + * Error codes for the various types of exceptions. + */ + enum TProtocolExceptionType + { UNKNOWN = 0 + , INVALID_DATA = 1 + , NEGATIVE_SIZE = 2 + , SIZE_LIMIT = 3 + , BAD_VERSION = 4 + , NOT_IMPLEMENTED = 5 + }; + + TProtocolException() : + apache::thrift::TException(), + type_(UNKNOWN) {} + + TProtocolException(TProtocolExceptionType type) : + apache::thrift::TException(), + type_(type) {} + + TProtocolException(const std::string& message) : + apache::thrift::TException(message), + type_(UNKNOWN) {} + + TProtocolException(TProtocolExceptionType type, const std::string& message) : + apache::thrift::TException(message), + type_(type) {} + + virtual ~TProtocolException() throw() {} + + /** + * Returns an error code that provides information about the type of error + * that has occurred. + * + * @return Error code + */ + TProtocolExceptionType getType() { + return type_; + } + + virtual const char* what() const throw() { + if (message_.empty()) { + switch (type_) { + case UNKNOWN : return "TProtocolException: Unknown protocol exception"; + case INVALID_DATA : return "TProtocolException: Invalid data"; + case NEGATIVE_SIZE : return "TProtocolException: Negative size"; + case SIZE_LIMIT : return "TProtocolException: Exceeded size limit"; + case BAD_VERSION : return "TProtocolException: Invalid version"; + case NOT_IMPLEMENTED : return "TProtocolException: Not implemented"; + default : return "TProtocolException: (Invalid exception type)"; + } + } else { + return message_.c_str(); + } + } + + protected: + /** + * Error code + */ + TProtocolExceptionType type_; + +}; + +}}} // apache::thrift::protocol + +#endif // #ifndef _THRIFT_PROTOCOL_TPROTOCOLEXCEPTION_H_ diff --git a/lib/cpp/src/protocol/TProtocolTap.h b/lib/cpp/src/protocol/TProtocolTap.h new file mode 100644 index 00000000..5580216a --- /dev/null +++ b/lib/cpp/src/protocol/TProtocolTap.h @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_PROTOCOL_TPROTOCOLTAP_H_ +#define _THRIFT_PROTOCOL_TPROTOCOLTAP_H_ 1 + +#include + +namespace apache { namespace thrift { namespace protocol { + +using apache::thrift::transport::TTransport; + +/** + * Puts a wiretap on a protocol object. Any reads to this class are passed + * through to an enclosed protocol object, but also mirrored as write to a + * second protocol object. + * + */ +class TProtocolTap : public TReadOnlyProtocol { + public: + TProtocolTap(boost::shared_ptr source, + boost::shared_ptr sink) + : TReadOnlyProtocol(source->getTransport(), "TProtocolTap") + , source_(source) + , sink_(sink) + {} + + virtual uint32_t readMessageBegin(std::string& name, + TMessageType& messageType, + int32_t& seqid) { + uint32_t rv = source_->readMessageBegin(name, messageType, seqid); + sink_->writeMessageBegin(name, messageType, seqid); + return rv; + } + + virtual uint32_t readMessageEnd() { + uint32_t rv = source_->readMessageEnd(); + sink_->writeMessageEnd(); + return rv; + } + + virtual uint32_t readStructBegin(std::string& name) { + uint32_t rv = source_->readStructBegin(name); + sink_->writeStructBegin(name.c_str()); + return rv; + } + + virtual uint32_t readStructEnd() { + uint32_t rv = source_->readStructEnd(); + sink_->writeStructEnd(); + return rv; + } + + virtual uint32_t readFieldBegin(std::string& name, + TType& fieldType, + int16_t& fieldId) { + uint32_t rv = source_->readFieldBegin(name, fieldType, fieldId); + if (fieldType == T_STOP) { + sink_->writeFieldStop(); + } else { + sink_->writeFieldBegin(name.c_str(), fieldType, fieldId); + } + return rv; + } + + + virtual uint32_t readFieldEnd() { + uint32_t rv = source_->readFieldEnd(); + sink_->writeFieldEnd(); + return rv; + } + + virtual uint32_t readMapBegin(TType& keyType, + TType& valType, + uint32_t& size) { + uint32_t rv = source_->readMapBegin(keyType, valType, size); + sink_->writeMapBegin(keyType, valType, size); + return rv; + } + + + virtual uint32_t readMapEnd() { + uint32_t rv = source_->readMapEnd(); + sink_->writeMapEnd(); + return rv; + } + + virtual uint32_t readListBegin(TType& elemType, + uint32_t& size) { + uint32_t rv = source_->readListBegin(elemType, size); + sink_->writeListBegin(elemType, size); + return rv; + } + + + virtual uint32_t readListEnd() { + uint32_t rv = source_->readListEnd(); + sink_->writeListEnd(); + return rv; + } + + virtual uint32_t readSetBegin(TType& elemType, + uint32_t& size) { + uint32_t rv = source_->readSetBegin(elemType, size); + sink_->writeSetBegin(elemType, size); + return rv; + } + + + virtual uint32_t readSetEnd() { + uint32_t rv = source_->readSetEnd(); + sink_->writeSetEnd(); + return rv; + } + + virtual uint32_t readBool(bool& value) { + uint32_t rv = source_->readBool(value); + sink_->writeBool(value); + return rv; + } + + virtual uint32_t readByte(int8_t& byte) { + uint32_t rv = source_->readByte(byte); + sink_->writeByte(byte); + return rv; + } + + virtual uint32_t readI16(int16_t& i16) { + uint32_t rv = source_->readI16(i16); + sink_->writeI16(i16); + return rv; + } + + virtual uint32_t readI32(int32_t& i32) { + uint32_t rv = source_->readI32(i32); + sink_->writeI32(i32); + return rv; + } + + virtual uint32_t readI64(int64_t& i64) { + uint32_t rv = source_->readI64(i64); + sink_->writeI64(i64); + return rv; + } + + virtual uint32_t readDouble(double& dub) { + uint32_t rv = source_->readDouble(dub); + sink_->writeDouble(dub); + return rv; + } + + virtual uint32_t readString(std::string& str) { + uint32_t rv = source_->readString(str); + sink_->writeString(str); + return rv; + } + + virtual uint32_t readBinary(std::string& str) { + uint32_t rv = source_->readBinary(str); + sink_->writeBinary(str); + return rv; + } + + private: + boost::shared_ptr source_; + boost::shared_ptr sink_; +}; + +}}} // apache::thrift::protocol + +#endif // #define _THRIFT_PROTOCOL_TPROTOCOLTAP_H_ 1 diff --git a/lib/cpp/src/server/TNonblockingServer.cpp b/lib/cpp/src/server/TNonblockingServer.cpp new file mode 100644 index 00000000..45f635cb --- /dev/null +++ b/lib/cpp/src/server/TNonblockingServer.cpp @@ -0,0 +1,750 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "TNonblockingServer.h" +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace apache { namespace thrift { namespace server { + +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; +using namespace apache::thrift::concurrency; +using namespace std; + +class TConnection::Task: public Runnable { + public: + Task(boost::shared_ptr processor, + boost::shared_ptr input, + boost::shared_ptr output, + int taskHandle) : + processor_(processor), + input_(input), + output_(output), + taskHandle_(taskHandle) {} + + void run() { + try { + while (processor_->process(input_, output_)) { + if (!input_->getTransport()->peek()) { + break; + } + } + } catch (TTransportException& ttx) { + cerr << "TNonblockingServer client died: " << ttx.what() << endl; + } catch (TException& x) { + cerr << "TNonblockingServer exception: " << x.what() << endl; + } catch (...) { + cerr << "TNonblockingServer uncaught exception." << endl; + } + + // Signal completion back to the libevent thread via a socketpair + int8_t b = 0; + if (-1 == send(taskHandle_, &b, sizeof(int8_t), 0)) { + GlobalOutput.perror("TNonblockingServer::Task: send ", errno); + } + if (-1 == ::close(taskHandle_)) { + GlobalOutput.perror("TNonblockingServer::Task: close, possible resource leak ", errno); + } + } + + private: + boost::shared_ptr processor_; + boost::shared_ptr input_; + boost::shared_ptr output_; + int taskHandle_; +}; + +void TConnection::init(int socket, short eventFlags, TNonblockingServer* s) { + socket_ = socket; + server_ = s; + appState_ = APP_INIT; + eventFlags_ = 0; + + readBufferPos_ = 0; + readWant_ = 0; + + writeBuffer_ = NULL; + writeBufferSize_ = 0; + writeBufferPos_ = 0; + + socketState_ = SOCKET_RECV; + appState_ = APP_INIT; + + taskHandle_ = -1; + + // Set flags, which also registers the event + setFlags(eventFlags); + + // get input/transports + factoryInputTransport_ = s->getInputTransportFactory()->getTransport(inputTransport_); + factoryOutputTransport_ = s->getOutputTransportFactory()->getTransport(outputTransport_); + + // Create protocol + inputProtocol_ = s->getInputProtocolFactory()->getProtocol(factoryInputTransport_); + outputProtocol_ = s->getOutputProtocolFactory()->getProtocol(factoryOutputTransport_); +} + +void TConnection::workSocket() { + int flags=0, got=0, left=0, sent=0; + uint32_t fetch = 0; + + switch (socketState_) { + case SOCKET_RECV: + // It is an error to be in this state if we already have all the data + assert(readBufferPos_ < readWant_); + + // Double the buffer size until it is big enough + if (readWant_ > readBufferSize_) { + while (readWant_ > readBufferSize_) { + readBufferSize_ *= 2; + } + readBuffer_ = (uint8_t*)std::realloc(readBuffer_, readBufferSize_); + if (readBuffer_ == NULL) { + GlobalOutput("TConnection::workSocket() realloc"); + close(); + return; + } + } + + // Read from the socket + fetch = readWant_ - readBufferPos_; + got = recv(socket_, readBuffer_ + readBufferPos_, fetch, 0); + + if (got > 0) { + // Move along in the buffer + readBufferPos_ += got; + + // Check that we did not overdo it + assert(readBufferPos_ <= readWant_); + + // We are done reading, move onto the next state + if (readBufferPos_ == readWant_) { + transition(); + } + return; + } else if (got == -1) { + // Blocking errors are okay, just move on + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return; + } + + if (errno != ECONNRESET) { + GlobalOutput.perror("TConnection::workSocket() recv -1 ", errno); + } + } + + // Whenever we get down here it means a remote disconnect + close(); + + return; + + case SOCKET_SEND: + // Should never have position past size + assert(writeBufferPos_ <= writeBufferSize_); + + // If there is no data to send, then let us move on + if (writeBufferPos_ == writeBufferSize_) { + GlobalOutput("WARNING: Send state with no data to send\n"); + transition(); + return; + } + + flags = 0; + #ifdef MSG_NOSIGNAL + // Note the use of MSG_NOSIGNAL to suppress SIGPIPE errors, instead we + // check for the EPIPE return condition and close the socket in that case + flags |= MSG_NOSIGNAL; + #endif // ifdef MSG_NOSIGNAL + + left = writeBufferSize_ - writeBufferPos_; + sent = send(socket_, writeBuffer_ + writeBufferPos_, left, flags); + + if (sent <= 0) { + // Blocking errors are okay, just move on + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return; + } + if (errno != EPIPE) { + GlobalOutput.perror("TConnection::workSocket() send -1 ", errno); + } + close(); + return; + } + + writeBufferPos_ += sent; + + // Did we overdo it? + assert(writeBufferPos_ <= writeBufferSize_); + + // We are done! + if (writeBufferPos_ == writeBufferSize_) { + transition(); + } + + return; + + default: + GlobalOutput.printf("Shit Got Ill. Socket State %d", socketState_); + assert(0); + } +} + +/** + * This is called when the application transitions from one state into + * another. This means that it has finished writing the data that it needed + * to, or finished receiving the data that it needed to. + */ +void TConnection::transition() { + + int sz = 0; + + // Switch upon the state that we are currently in and move to a new state + switch (appState_) { + + case APP_READ_REQUEST: + // We are done reading the request, package the read buffer into transport + // and get back some data from the dispatch function + // If we've used these transport buffers enough times, reset them to avoid bloating + + inputTransport_->resetBuffer(readBuffer_, readBufferPos_); + ++numReadsSinceReset_; + if (numWritesSinceReset_ < 512) { + outputTransport_->resetBuffer(); + } else { + // reset the capacity of the output transport if we used it enough times that it might be bloated + try { + outputTransport_->resetBuffer(true); + numWritesSinceReset_ = 0; + } catch (TTransportException &ttx) { + GlobalOutput.printf("TTransportException: TMemoryBuffer::resetBuffer() %s", ttx.what()); + close(); + return; + } + } + + // Prepend four bytes of blank space to the buffer so we can + // write the frame size there later. + outputTransport_->getWritePtr(4); + outputTransport_->wroteBytes(4); + + if (server_->isThreadPoolProcessing()) { + // We are setting up a Task to do this work and we will wait on it + int sv[2]; + if (-1 == socketpair(AF_LOCAL, SOCK_STREAM, 0, sv)) { + GlobalOutput.perror("TConnection::socketpair() failed ", errno); + // Now we will fall through to the APP_WAIT_TASK block with no response + } else { + // Create task and dispatch to the thread manager + boost::shared_ptr task = + boost::shared_ptr(new Task(server_->getProcessor(), + inputProtocol_, + outputProtocol_, + sv[1])); + // The application is now waiting on the task to finish + appState_ = APP_WAIT_TASK; + + // Create an event to be notified when the task finishes + event_set(&taskEvent_, + taskHandle_ = sv[0], + EV_READ, + TConnection::taskHandler, + this); + + // Attach to the base + event_base_set(server_->getEventBase(), &taskEvent_); + + // Add the event and start up the server + if (-1 == event_add(&taskEvent_, 0)) { + GlobalOutput("TNonblockingServer::serve(): coult not event_add"); + return; + } + try { + server_->addTask(task); + } catch (IllegalStateException & ise) { + // The ThreadManager is not ready to handle any more tasks (it's probably shutting down). + GlobalOutput.printf("IllegalStateException: Server::process() %s", ise.what()); + close(); + } + + // Set this connection idle so that libevent doesn't process more + // data on it while we're still waiting for the threadmanager to + // finish this task + setIdle(); + return; + } + } else { + try { + // Invoke the processor + server_->getProcessor()->process(inputProtocol_, outputProtocol_); + } catch (TTransportException &ttx) { + GlobalOutput.printf("TTransportException: Server::process() %s", ttx.what()); + close(); + return; + } catch (TException &x) { + GlobalOutput.printf("TException: Server::process() %s", x.what()); + close(); + return; + } catch (...) { + GlobalOutput.printf("Server::process() unknown exception"); + close(); + return; + } + } + + // Intentionally fall through here, the call to process has written into + // the writeBuffer_ + + case APP_WAIT_TASK: + // We have now finished processing a task and the result has been written + // into the outputTransport_, so we grab its contents and place them into + // the writeBuffer_ for actual writing by the libevent thread + + // Get the result of the operation + outputTransport_->getBuffer(&writeBuffer_, &writeBufferSize_); + + // If the function call generated return data, then move into the send + // state and get going + // 4 bytes were reserved for frame size + if (writeBufferSize_ > 4) { + + // Move into write state + writeBufferPos_ = 0; + socketState_ = SOCKET_SEND; + + // Put the frame size into the write buffer + int32_t frameSize = (int32_t)htonl(writeBufferSize_ - 4); + memcpy(writeBuffer_, &frameSize, 4); + + // Socket into write mode + appState_ = APP_SEND_RESULT; + setWrite(); + + // Try to work the socket immediately + // workSocket(); + + return; + } + + // In this case, the request was oneway and we should fall through + // right back into the read frame header state + goto LABEL_APP_INIT; + + case APP_SEND_RESULT: + + ++numWritesSinceReset_; + + // N.B.: We also intentionally fall through here into the INIT state! + + LABEL_APP_INIT: + case APP_INIT: + + // reset the input buffer if we used it enough times that it might be bloated + if (numReadsSinceReset_ > 512) + { + void * new_buffer = std::realloc(readBuffer_, 1024); + if (new_buffer == NULL) { + GlobalOutput("TConnection::transition() realloc"); + close(); + return; + } + readBuffer_ = (uint8_t*) new_buffer; + readBufferSize_ = 1024; + numReadsSinceReset_ = 0; + } + + // Clear write buffer variables + writeBuffer_ = NULL; + writeBufferPos_ = 0; + writeBufferSize_ = 0; + + // Set up read buffer for getting 4 bytes + readBufferPos_ = 0; + readWant_ = 4; + + // Into read4 state we go + socketState_ = SOCKET_RECV; + appState_ = APP_READ_FRAME_SIZE; + + // Register read event + setRead(); + + // Try to work the socket right away + // workSocket(); + + return; + + case APP_READ_FRAME_SIZE: + // We just read the request length, deserialize it + sz = *(int32_t*)readBuffer_; + sz = (int32_t)ntohl(sz); + + if (sz <= 0) { + GlobalOutput.printf("TConnection:transition() Negative frame size %d, remote side not using TFramedTransport?", sz); + close(); + return; + } + + // Reset the read buffer + readWant_ = (uint32_t)sz; + readBufferPos_= 0; + + // Move into read request state + appState_ = APP_READ_REQUEST; + + // Work the socket right away + // workSocket(); + + return; + + default: + GlobalOutput.printf("Totally Fucked. Application State %d", appState_); + assert(0); + } +} + +void TConnection::setFlags(short eventFlags) { + // Catch the do nothing case + if (eventFlags_ == eventFlags) { + return; + } + + // Delete a previously existing event + if (eventFlags_ != 0) { + if (event_del(&event_) == -1) { + GlobalOutput("TConnection::setFlags event_del"); + return; + } + } + + // Update in memory structure + eventFlags_ = eventFlags; + + // Do not call event_set if there are no flags + if (!eventFlags_) { + return; + } + + /** + * event_set: + * + * Prepares the event structure &event to be used in future calls to + * event_add() and event_del(). The event will be prepared to call the + * eventHandler using the 'sock' file descriptor to monitor events. + * + * The events can be either EV_READ, EV_WRITE, or both, indicating + * that an application can read or write from the file respectively without + * blocking. + * + * The eventHandler will be called with the file descriptor that triggered + * the event and the type of event which will be one of: EV_TIMEOUT, + * EV_SIGNAL, EV_READ, EV_WRITE. + * + * The additional flag EV_PERSIST makes an event_add() persistent until + * event_del() has been called. + * + * Once initialized, the &event struct can be used repeatedly with + * event_add() and event_del() and does not need to be reinitialized unless + * the eventHandler and/or the argument to it are to be changed. However, + * when an ev structure has been added to libevent using event_add() the + * structure must persist until the event occurs (assuming EV_PERSIST + * is not set) or is removed using event_del(). You may not reuse the same + * ev structure for multiple monitored descriptors; each descriptor needs + * its own ev. + */ + event_set(&event_, socket_, eventFlags_, TConnection::eventHandler, this); + event_base_set(server_->getEventBase(), &event_); + + // Add the event + if (event_add(&event_, 0) == -1) { + GlobalOutput("TConnection::setFlags(): could not event_add"); + } +} + +/** + * Closes a connection + */ +void TConnection::close() { + // Delete the registered libevent + if (event_del(&event_) == -1) { + GlobalOutput("TConnection::close() event_del"); + } + + // Close the socket + if (socket_ > 0) { + ::close(socket_); + } + socket_ = 0; + + // close any factory produced transports + factoryInputTransport_->close(); + factoryOutputTransport_->close(); + + // Give this object back to the server that owns it + server_->returnConnection(this); +} + +void TConnection::checkIdleBufferMemLimit(uint32_t limit) { + if (readBufferSize_ > limit) { + readBufferSize_ = limit; + readBuffer_ = (uint8_t*)std::realloc(readBuffer_, readBufferSize_); + if (readBuffer_ == NULL) { + GlobalOutput("TConnection::checkIdleBufferMemLimit() realloc"); + close(); + } + } +} + +/** + * Creates a new connection either by reusing an object off the stack or + * by allocating a new one entirely + */ +TConnection* TNonblockingServer::createConnection(int socket, short flags) { + // Check the stack + if (connectionStack_.empty()) { + return new TConnection(socket, flags, this); + } else { + TConnection* result = connectionStack_.top(); + connectionStack_.pop(); + result->init(socket, flags, this); + return result; + } +} + +/** + * Returns a connection to the stack + */ +void TNonblockingServer::returnConnection(TConnection* connection) { + if (connectionStackLimit_ && + (connectionStack_.size() >= connectionStackLimit_)) { + delete connection; + } else { + connection->checkIdleBufferMemLimit(idleBufferMemLimit_); + connectionStack_.push(connection); + } +} + +/** + * Server socket had something happen. We accept all waiting client + * connections on fd and assign TConnection objects to handle those requests. + */ +void TNonblockingServer::handleEvent(int fd, short which) { + // Make sure that libevent didn't fuck up the socket handles + assert(fd == serverSocket_); + + // Server socket accepted a new connection + socklen_t addrLen; + struct sockaddr addr; + addrLen = sizeof(addr); + + // Going to accept a new client socket + int clientSocket; + + // Accept as many new clients as possible, even though libevent signaled only + // one, this helps us to avoid having to go back into the libevent engine so + // many times + while ((clientSocket = accept(fd, &addr, &addrLen)) != -1) { + + // Explicitly set this socket to NONBLOCK mode + int flags; + if ((flags = fcntl(clientSocket, F_GETFL, 0)) < 0 || + fcntl(clientSocket, F_SETFL, flags | O_NONBLOCK) < 0) { + GlobalOutput.perror("thriftServerEventHandler: set O_NONBLOCK (fcntl) ", errno); + close(clientSocket); + return; + } + + // Create a new TConnection for this client socket. + TConnection* clientConnection = + createConnection(clientSocket, EV_READ | EV_PERSIST); + + // Fail fast if we could not create a TConnection object + if (clientConnection == NULL) { + GlobalOutput.printf("thriftServerEventHandler: failed TConnection factory"); + close(clientSocket); + return; + } + + // Put this client connection into the proper state + clientConnection->transition(); + } + + // Done looping accept, now we have to make sure the error is due to + // blocking. Any other error is a problem + if (errno != EAGAIN && errno != EWOULDBLOCK) { + GlobalOutput.perror("thriftServerEventHandler: accept() ", errno); + } +} + +/** + * Creates a socket to listen on and binds it to the local port. + */ +void TNonblockingServer::listenSocket() { + int s; + struct addrinfo hints, *res, *res0; + int error; + + char port[sizeof("65536") + 1]; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = PF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; + sprintf(port, "%d", port_); + + // Wildcard address + error = getaddrinfo(NULL, port, &hints, &res0); + if (error) { + string errStr = "TNonblockingServer::serve() getaddrinfo " + string(gai_strerror(error)); + GlobalOutput(errStr.c_str()); + return; + } + + // Pick the ipv6 address first since ipv4 addresses can be mapped + // into ipv6 space. + for (res = res0; res; res = res->ai_next) { + if (res->ai_family == AF_INET6 || res->ai_next == NULL) + break; + } + + // Create the server socket + s = socket(res->ai_family, res->ai_socktype, res->ai_protocol); + if (s == -1) { + freeaddrinfo(res0); + throw TException("TNonblockingServer::serve() socket() -1"); + } + + #ifdef IPV6_V6ONLY + int zero = 0; + if (-1 == setsockopt(s, IPPROTO_IPV6, IPV6_V6ONLY, &zero, sizeof(zero))) { + GlobalOutput("TServerSocket::listen() IPV6_V6ONLY"); + } + #endif // #ifdef IPV6_V6ONLY + + + int one = 1; + + // Set reuseaddr to avoid 2MSL delay on server restart + setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)); + + if (bind(s, res->ai_addr, res->ai_addrlen) == -1) { + close(s); + freeaddrinfo(res0); + throw TException("TNonblockingServer::serve() bind"); + } + + // Done with the addr info + freeaddrinfo(res0); + + // Set up this file descriptor for listening + listenSocket(s); +} + +/** + * Takes a socket created by listenSocket() and sets various options on it + * to prepare for use in the server. + */ +void TNonblockingServer::listenSocket(int s) { + // Set socket to nonblocking mode + int flags; + if ((flags = fcntl(s, F_GETFL, 0)) < 0 || + fcntl(s, F_SETFL, flags | O_NONBLOCK) < 0) { + close(s); + throw TException("TNonblockingServer::serve() O_NONBLOCK"); + } + + int one = 1; + struct linger ling = {0, 0}; + + // Keepalive to ensure full result flushing + setsockopt(s, SOL_SOCKET, SO_KEEPALIVE, &one, sizeof(one)); + + // Turn linger off to avoid hung sockets + setsockopt(s, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling)); + + // Set TCP nodelay if available, MAC OS X Hack + // See http://lists.danga.com/pipermail/memcached/2005-March/001240.html + #ifndef TCP_NOPUSH + setsockopt(s, IPPROTO_TCP, TCP_NODELAY, &one, sizeof(one)); + #endif + + if (listen(s, LISTEN_BACKLOG) == -1) { + close(s); + throw TException("TNonblockingServer::serve() listen"); + } + + // Cool, this socket is good to go, set it as the serverSocket_ + serverSocket_ = s; +} + +/** + * Register the core libevent events onto the proper base. + */ +void TNonblockingServer::registerEvents(event_base* base) { + assert(serverSocket_ != -1); + assert(!eventBase_); + eventBase_ = base; + + // Print some libevent stats + GlobalOutput.printf("libevent %s method %s", + event_get_version(), + event_get_method()); + + // Register the server event + event_set(&serverEvent_, + serverSocket_, + EV_READ | EV_PERSIST, + TNonblockingServer::eventHandler, + this); + event_base_set(eventBase_, &serverEvent_); + + // Add the event and start up the server + if (-1 == event_add(&serverEvent_, 0)) { + throw TException("TNonblockingServer::serve(): coult not event_add"); + } +} + +/** + * Main workhorse function, starts up the server listening on a port and + * loops over the libevent handler. + */ +void TNonblockingServer::serve() { + // Init socket + listenSocket(); + + // Initialize libevent core + registerEvents(static_cast(event_init())); + + // Run the preServe event + if (eventHandler_ != NULL) { + eventHandler_->preServe(); + } + + // Run libevent engine, never returns, invokes calls to eventHandler + event_base_loop(eventBase_, 0); +} + +}}} // apache::thrift::server diff --git a/lib/cpp/src/server/TNonblockingServer.h b/lib/cpp/src/server/TNonblockingServer.h new file mode 100644 index 00000000..1684b64a --- /dev/null +++ b/lib/cpp/src/server/TNonblockingServer.h @@ -0,0 +1,434 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_SERVER_TNONBLOCKINGSERVER_H_ +#define _THRIFT_SERVER_TNONBLOCKINGSERVER_H_ 1 + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace apache { namespace thrift { namespace server { + +using apache::thrift::transport::TMemoryBuffer; +using apache::thrift::protocol::TProtocol; +using apache::thrift::concurrency::Runnable; +using apache::thrift::concurrency::ThreadManager; + +// Forward declaration of class +class TConnection; + +/** + * This is a non-blocking server in C++ for high performance that operates a + * single IO thread. It assumes that all incoming requests are framed with a + * 4 byte length indicator and writes out responses using the same framing. + * + * It does not use the TServerTransport framework, but rather has socket + * operations hardcoded for use with select. + * + */ +class TNonblockingServer : public TServer { + private: + + // Listen backlog + static const int LISTEN_BACKLOG = 1024; + + // Default limit on size of idle connection pool + static const size_t CONNECTION_STACK_LIMIT = 1024; + + // Maximum size of buffer allocated to idle connection + static const uint32_t IDLE_BUFFER_MEM_LIMIT = 8192; + + // Server socket file descriptor + int serverSocket_; + + // Port server runs on + int port_; + + // For processing via thread pool, may be NULL + boost::shared_ptr threadManager_; + + // Is thread pool processing? + bool threadPoolProcessing_; + + // The event base for libevent + event_base* eventBase_; + + // Event struct, for use with eventBase_ + struct event serverEvent_; + + // Number of TConnection object we've created + size_t numTConnections_; + + // Limit for how many TConnection objects to cache + size_t connectionStackLimit_; + + /** + * Max read buffer size for an idle connection. When we place an idle + * TConnection into connectionStack_, we insure that its read buffer is + * reduced to this size to insure that idle connections don't hog memory. + */ + uint32_t idleBufferMemLimit_; + + /** + * This is a stack of all the objects that have been created but that + * are NOT currently in use. When we close a connection, we place it on this + * stack so that the object can be reused later, rather than freeing the + * memory and reallocating a new object later. + */ + std::stack connectionStack_; + + void handleEvent(int fd, short which); + + public: + TNonblockingServer(boost::shared_ptr processor, + int port) : + TServer(processor), + serverSocket_(-1), + port_(port), + threadPoolProcessing_(false), + eventBase_(NULL), + numTConnections_(0), + connectionStackLimit_(CONNECTION_STACK_LIMIT), + idleBufferMemLimit_(IDLE_BUFFER_MEM_LIMIT) {} + + TNonblockingServer(boost::shared_ptr processor, + boost::shared_ptr protocolFactory, + int port, + boost::shared_ptr threadManager = boost::shared_ptr()) : + TServer(processor), + serverSocket_(-1), + port_(port), + threadManager_(threadManager), + eventBase_(NULL), + numTConnections_(0), + connectionStackLimit_(CONNECTION_STACK_LIMIT), + idleBufferMemLimit_(IDLE_BUFFER_MEM_LIMIT) { + setInputTransportFactory(boost::shared_ptr(new TTransportFactory())); + setOutputTransportFactory(boost::shared_ptr(new TTransportFactory())); + setInputProtocolFactory(protocolFactory); + setOutputProtocolFactory(protocolFactory); + setThreadManager(threadManager); + } + + TNonblockingServer(boost::shared_ptr processor, + boost::shared_ptr inputTransportFactory, + boost::shared_ptr outputTransportFactory, + boost::shared_ptr inputProtocolFactory, + boost::shared_ptr outputProtocolFactory, + int port, + boost::shared_ptr threadManager = boost::shared_ptr()) : + TServer(processor), + serverSocket_(0), + port_(port), + threadManager_(threadManager), + eventBase_(NULL), + numTConnections_(0), + connectionStackLimit_(CONNECTION_STACK_LIMIT), + idleBufferMemLimit_(IDLE_BUFFER_MEM_LIMIT) { + setInputTransportFactory(inputTransportFactory); + setOutputTransportFactory(outputTransportFactory); + setInputProtocolFactory(inputProtocolFactory); + setOutputProtocolFactory(outputProtocolFactory); + setThreadManager(threadManager); + } + + ~TNonblockingServer() {} + + void setThreadManager(boost::shared_ptr threadManager) { + threadManager_ = threadManager; + threadPoolProcessing_ = (threadManager != NULL); + } + + boost::shared_ptr getThreadManager() { + return threadManager_; + } + + /** + * Get the maximum number of unused TConnection we will hold in reserve. + * + * @return the current limit on TConnection pool size. + */ + size_t getConnectionStackLimit() const { + return connectionStackLimit_; + } + + /** + * Set the maximum number of unused TConnection we will hold in reserve. + * + * @param sz the new limit for TConnection pool size. + */ + void setConnectionStackLimit(size_t sz) { + connectionStackLimit_ = sz; + } + + bool isThreadPoolProcessing() const { + return threadPoolProcessing_; + } + + void addTask(boost::shared_ptr task) { + threadManager_->add(task); + } + + event_base* getEventBase() const { + return eventBase_; + } + + void incrementNumConnections() { + ++numTConnections_; + } + + void decrementNumConnections() { + --numTConnections_; + } + + size_t getNumConnections() { + return numTConnections_; + } + + size_t getNumIdleConnections() { + return connectionStack_.size(); + } + + /** + * Get the maximum limit of memory allocated to idle TConnection objects. + * + * @return # bytes beyond which we will shrink buffers when idle. + */ + size_t getIdleBufferMemLimit() const { + return idleBufferMemLimit_; + } + + /** + * Set the maximum limit of memory allocated to idle TConnection objects. + * If a TConnection object goes idle with more than this much memory + * allocated to its buffer, we shrink it to this value. + * + * @param limit of bytes beyond which we will shrink buffers when idle. + */ + void setIdleBufferMemLimit(size_t limit) { + idleBufferMemLimit_ = limit; + } + + TConnection* createConnection(int socket, short flags); + + void returnConnection(TConnection* connection); + + static void eventHandler(int fd, short which, void* v) { + ((TNonblockingServer*)v)->handleEvent(fd, which); + } + + void listenSocket(); + + void listenSocket(int fd); + + void registerEvents(event_base* base); + + void serve(); +}; + +/** + * Two states for sockets, recv and send mode + */ +enum TSocketState { + SOCKET_RECV, + SOCKET_SEND +}; + +/** + * Four states for the nonblocking servr: + * 1) initialize + * 2) read 4 byte frame size + * 3) read frame of data + * 4) send back data (if any) + */ +enum TAppState { + APP_INIT, + APP_READ_FRAME_SIZE, + APP_READ_REQUEST, + APP_WAIT_TASK, + APP_SEND_RESULT +}; + +/** + * Represents a connection that is handled via libevent. This connection + * essentially encapsulates a socket that has some associated libevent state. + */ +class TConnection { + private: + + class Task; + + // Server handle + TNonblockingServer* server_; + + // Socket handle + int socket_; + + // Libevent object + struct event event_; + + // Libevent flags + short eventFlags_; + + // Socket mode + TSocketState socketState_; + + // Application state + TAppState appState_; + + // How much data needed to read + uint32_t readWant_; + + // Where in the read buffer are we + uint32_t readBufferPos_; + + // Read buffer + uint8_t* readBuffer_; + + // Read buffer size + uint32_t readBufferSize_; + + // Write buffer + uint8_t* writeBuffer_; + + // Write buffer size + uint32_t writeBufferSize_; + + // How far through writing are we? + uint32_t writeBufferPos_; + + // How many times have we read since our last buffer reset? + uint32_t numReadsSinceReset_; + + // How many times have we written since our last buffer reset? + uint32_t numWritesSinceReset_; + + // Task handle + int taskHandle_; + + // Task event + struct event taskEvent_; + + // Transport to read from + boost::shared_ptr inputTransport_; + + // Transport that processor writes to + boost::shared_ptr outputTransport_; + + // extra transport generated by transport factory (e.g. BufferedRouterTransport) + boost::shared_ptr factoryInputTransport_; + boost::shared_ptr factoryOutputTransport_; + + // Protocol decoder + boost::shared_ptr inputProtocol_; + + // Protocol encoder + boost::shared_ptr outputProtocol_; + + // Go into read mode + void setRead() { + setFlags(EV_READ | EV_PERSIST); + } + + // Go into write mode + void setWrite() { + setFlags(EV_WRITE | EV_PERSIST); + } + + // Set socket idle + void setIdle() { + setFlags(0); + } + + // Set event flags + void setFlags(short eventFlags); + + // Libevent handlers + void workSocket(); + + // Close this client and reset + void close(); + + public: + + // Constructor + TConnection(int socket, short eventFlags, TNonblockingServer *s) { + readBuffer_ = (uint8_t*)std::malloc(1024); + if (readBuffer_ == NULL) { + throw new apache::thrift::TException("Out of memory."); + } + readBufferSize_ = 1024; + + numReadsSinceReset_ = 0; + numWritesSinceReset_ = 0; + + // Allocate input and output tranpsorts + // these only need to be allocated once per TConnection (they don't need to be + // reallocated on init() call) + inputTransport_ = boost::shared_ptr(new TMemoryBuffer(readBuffer_, readBufferSize_)); + outputTransport_ = boost::shared_ptr(new TMemoryBuffer()); + + init(socket, eventFlags, s); + server_->incrementNumConnections(); + } + + ~TConnection() { + server_->decrementNumConnections(); + } + + /** + * Check read buffer against a given limit and shrink it if exceeded. + * + * @param limit we limit buffer size to. + */ + void checkIdleBufferMemLimit(uint32_t limit); + + // Initialize + void init(int socket, short eventFlags, TNonblockingServer *s); + + // Transition into a new state + void transition(); + + // Handler wrapper + static void eventHandler(int fd, short /* which */, void* v) { + assert(fd == ((TConnection*)v)->socket_); + ((TConnection*)v)->workSocket(); + } + + // Handler wrapper for task block + static void taskHandler(int fd, short /* which */, void* v) { + assert(fd == ((TConnection*)v)->taskHandle_); + if (-1 == ::close(((TConnection*)v)->taskHandle_)) { + GlobalOutput.perror("TConnection::taskHandler close handle failed, resource leak ", errno); + } + ((TConnection*)v)->transition(); + } + +}; + +}}} // apache::thrift::server + +#endif // #ifndef _THRIFT_SERVER_TSIMPLESERVER_H_ diff --git a/lib/cpp/src/server/TServer.cpp b/lib/cpp/src/server/TServer.cpp new file mode 100644 index 00000000..6b692ab0 --- /dev/null +++ b/lib/cpp/src/server/TServer.cpp @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +namespace apache { namespace thrift { namespace server { + +int increase_max_fds(int max_fds=(1<<24)) { + struct rlimit fdmaxrl; + + for(fdmaxrl.rlim_cur = max_fds, fdmaxrl.rlim_max = max_fds; + max_fds && (setrlimit(RLIMIT_NOFILE, &fdmaxrl) < 0); + fdmaxrl.rlim_cur = max_fds, fdmaxrl.rlim_max = max_fds) { + max_fds /= 2; + } + + return fdmaxrl.rlim_cur; +} + +}}} // apache::thrift::server diff --git a/lib/cpp/src/server/TServer.h b/lib/cpp/src/server/TServer.h new file mode 100644 index 00000000..5c4c588d --- /dev/null +++ b/lib/cpp/src/server/TServer.h @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_SERVER_TSERVER_H_ +#define _THRIFT_SERVER_TSERVER_H_ 1 + +#include +#include +#include +#include + +#include + +namespace apache { namespace thrift { namespace server { + +using apache::thrift::TProcessor; +using apache::thrift::protocol::TBinaryProtocolFactory; +using apache::thrift::protocol::TProtocol; +using apache::thrift::protocol::TProtocolFactory; +using apache::thrift::transport::TServerTransport; +using apache::thrift::transport::TTransport; +using apache::thrift::transport::TTransportFactory; + +/** + * Virtual interface class that can handle events from the server core. To + * use this you should subclass it and implement the methods that you care + * about. Your subclass can also store local data that you may care about, + * such as additional "arguments" to these methods (stored in the object + * instance's state). + */ +class TServerEventHandler { + public: + + virtual ~TServerEventHandler() {} + + /** + * Called before the server begins. + */ + virtual void preServe() {} + + /** + * Called when a new client has connected and is about to being processing. + */ + virtual void clientBegin(boost::shared_ptr /* input */, + boost::shared_ptr /* output */) {} + + /** + * Called when a client has finished making requests. + */ + virtual void clientEnd(boost::shared_ptr /* input */, + boost::shared_ptr /* output */) {} + + protected: + + /** + * Prevent direct instantiation. + */ + TServerEventHandler() {} + +}; + +/** + * Thrift server. + * + */ +class TServer : public concurrency::Runnable { + public: + + virtual ~TServer() {} + + virtual void serve() = 0; + + virtual void stop() {} + + // Allows running the server as a Runnable thread + virtual void run() { + serve(); + } + + boost::shared_ptr getProcessor() { + return processor_; + } + + boost::shared_ptr getServerTransport() { + return serverTransport_; + } + + boost::shared_ptr getInputTransportFactory() { + return inputTransportFactory_; + } + + boost::shared_ptr getOutputTransportFactory() { + return outputTransportFactory_; + } + + boost::shared_ptr getInputProtocolFactory() { + return inputProtocolFactory_; + } + + boost::shared_ptr getOutputProtocolFactory() { + return outputProtocolFactory_; + } + + boost::shared_ptr getEventHandler() { + return eventHandler_; + } + +protected: + TServer(boost::shared_ptr processor): + processor_(processor) { + setInputTransportFactory(boost::shared_ptr(new TTransportFactory())); + setOutputTransportFactory(boost::shared_ptr(new TTransportFactory())); + setInputProtocolFactory(boost::shared_ptr(new TBinaryProtocolFactory())); + setOutputProtocolFactory(boost::shared_ptr(new TBinaryProtocolFactory())); + } + + TServer(boost::shared_ptr processor, + boost::shared_ptr serverTransport): + processor_(processor), + serverTransport_(serverTransport) { + setInputTransportFactory(boost::shared_ptr(new TTransportFactory())); + setOutputTransportFactory(boost::shared_ptr(new TTransportFactory())); + setInputProtocolFactory(boost::shared_ptr(new TBinaryProtocolFactory())); + setOutputProtocolFactory(boost::shared_ptr(new TBinaryProtocolFactory())); + } + + TServer(boost::shared_ptr processor, + boost::shared_ptr serverTransport, + boost::shared_ptr transportFactory, + boost::shared_ptr protocolFactory): + processor_(processor), + serverTransport_(serverTransport), + inputTransportFactory_(transportFactory), + outputTransportFactory_(transportFactory), + inputProtocolFactory_(protocolFactory), + outputProtocolFactory_(protocolFactory) {} + + TServer(boost::shared_ptr processor, + boost::shared_ptr serverTransport, + boost::shared_ptr inputTransportFactory, + boost::shared_ptr outputTransportFactory, + boost::shared_ptr inputProtocolFactory, + boost::shared_ptr outputProtocolFactory): + processor_(processor), + serverTransport_(serverTransport), + inputTransportFactory_(inputTransportFactory), + outputTransportFactory_(outputTransportFactory), + inputProtocolFactory_(inputProtocolFactory), + outputProtocolFactory_(outputProtocolFactory) {} + + + // Class variables + boost::shared_ptr processor_; + boost::shared_ptr serverTransport_; + + boost::shared_ptr inputTransportFactory_; + boost::shared_ptr outputTransportFactory_; + + boost::shared_ptr inputProtocolFactory_; + boost::shared_ptr outputProtocolFactory_; + + boost::shared_ptr eventHandler_; + +public: + void setInputTransportFactory(boost::shared_ptr inputTransportFactory) { + inputTransportFactory_ = inputTransportFactory; + } + + void setOutputTransportFactory(boost::shared_ptr outputTransportFactory) { + outputTransportFactory_ = outputTransportFactory; + } + + void setInputProtocolFactory(boost::shared_ptr inputProtocolFactory) { + inputProtocolFactory_ = inputProtocolFactory; + } + + void setOutputProtocolFactory(boost::shared_ptr outputProtocolFactory) { + outputProtocolFactory_ = outputProtocolFactory; + } + + void setServerEventHandler(boost::shared_ptr eventHandler) { + eventHandler_ = eventHandler; + } + +}; + +/** + * Helper function to increase the max file descriptors limit + * for the current process and all of its children. + * By default, tries to increase it to as much as 2^24. + */ + int increase_max_fds(int max_fds=(1<<24)); + + +}}} // apache::thrift::server + +#endif // #ifndef _THRIFT_SERVER_TSERVER_H_ diff --git a/lib/cpp/src/server/TSimpleServer.cpp b/lib/cpp/src/server/TSimpleServer.cpp new file mode 100644 index 00000000..394ce21e --- /dev/null +++ b/lib/cpp/src/server/TSimpleServer.cpp @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "server/TSimpleServer.h" +#include "transport/TTransportException.h" +#include +#include + +namespace apache { namespace thrift { namespace server { + +using namespace std; +using namespace apache::thrift; +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; +using boost::shared_ptr; + +/** + * A simple single-threaded application server. Perfect for unit tests! + * + */ +void TSimpleServer::serve() { + + shared_ptr client; + shared_ptr inputTransport; + shared_ptr outputTransport; + shared_ptr inputProtocol; + shared_ptr outputProtocol; + + try { + // Start the server listening + serverTransport_->listen(); + } catch (TTransportException& ttx) { + cerr << "TSimpleServer::run() listen(): " << ttx.what() << endl; + return; + } + + // Run the preServe event + if (eventHandler_ != NULL) { + eventHandler_->preServe(); + } + + // Fetch client from server + while (!stop_) { + try { + client = serverTransport_->accept(); + inputTransport = inputTransportFactory_->getTransport(client); + outputTransport = outputTransportFactory_->getTransport(client); + inputProtocol = inputProtocolFactory_->getProtocol(inputTransport); + outputProtocol = outputProtocolFactory_->getProtocol(outputTransport); + if (eventHandler_ != NULL) { + eventHandler_->clientBegin(inputProtocol, outputProtocol); + } + try { + while (processor_->process(inputProtocol, outputProtocol)) { + // Peek ahead, is the remote side closed? + if (!inputTransport->peek()) { + break; + } + } + } catch (TTransportException& ttx) { + cerr << "TSimpleServer client died: " << ttx.what() << endl; + } catch (TException& tx) { + cerr << "TSimpleServer exception: " << tx.what() << endl; + } + if (eventHandler_ != NULL) { + eventHandler_->clientEnd(inputProtocol, outputProtocol); + } + inputTransport->close(); + outputTransport->close(); + client->close(); + } catch (TTransportException& ttx) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + cerr << "TServerTransport died on accept: " << ttx.what() << endl; + continue; + } catch (TException& tx) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + cerr << "Some kind of accept exception: " << tx.what() << endl; + continue; + } catch (string s) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + cerr << "TThreadPoolServer: Unknown exception: " << s << endl; + break; + } + } + + if (stop_) { + try { + serverTransport_->close(); + } catch (TTransportException &ttx) { + cerr << "TServerTransport failed on close: " << ttx.what() << endl; + } + stop_ = false; + } +} + +}}} // apache::thrift::server diff --git a/lib/cpp/src/server/TSimpleServer.h b/lib/cpp/src/server/TSimpleServer.h new file mode 100644 index 00000000..c4fc91c7 --- /dev/null +++ b/lib/cpp/src/server/TSimpleServer.h @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_SERVER_TSIMPLESERVER_H_ +#define _THRIFT_SERVER_TSIMPLESERVER_H_ 1 + +#include "server/TServer.h" +#include "transport/TServerTransport.h" + +namespace apache { namespace thrift { namespace server { + +/** + * This is the most basic simple server. It is single-threaded and runs a + * continuous loop of accepting a single connection, processing requests on + * that connection until it closes, and then repeating. It is a good example + * of how to extend the TServer interface. + * + */ +class TSimpleServer : public TServer { + public: + TSimpleServer(boost::shared_ptr processor, + boost::shared_ptr serverTransport, + boost::shared_ptr transportFactory, + boost::shared_ptr protocolFactory) : + TServer(processor, serverTransport, transportFactory, protocolFactory), + stop_(false) {} + + TSimpleServer(boost::shared_ptr processor, + boost::shared_ptr serverTransport, + boost::shared_ptr inputTransportFactory, + boost::shared_ptr outputTransportFactory, + boost::shared_ptr inputProtocolFactory, + boost::shared_ptr outputProtocolFactory): + TServer(processor, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory), + stop_(false) {} + + ~TSimpleServer() {} + + void serve(); + + void stop() { + stop_ = true; + } + + protected: + bool stop_; + +}; + +}}} // apache::thrift::server + +#endif // #ifndef _THRIFT_SERVER_TSIMPLESERVER_H_ diff --git a/lib/cpp/src/server/TThreadPoolServer.cpp b/lib/cpp/src/server/TThreadPoolServer.cpp new file mode 100644 index 00000000..0894cfa5 --- /dev/null +++ b/lib/cpp/src/server/TThreadPoolServer.cpp @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "server/TThreadPoolServer.h" +#include "transport/TTransportException.h" +#include "concurrency/Thread.h" +#include "concurrency/ThreadManager.h" +#include +#include + +namespace apache { namespace thrift { namespace server { + +using boost::shared_ptr; +using namespace std; +using namespace apache::thrift; +using namespace apache::thrift::concurrency; +using namespace apache::thrift::protocol;; +using namespace apache::thrift::transport; + +class TThreadPoolServer::Task : public Runnable { + +public: + + Task(TThreadPoolServer &server, + shared_ptr processor, + shared_ptr input, + shared_ptr output) : + server_(server), + processor_(processor), + input_(input), + output_(output) { + } + + ~Task() {} + + void run() { + boost::shared_ptr eventHandler = + server_.getEventHandler(); + if (eventHandler != NULL) { + eventHandler->clientBegin(input_, output_); + } + try { + while (processor_->process(input_, output_)) { + if (!input_->getTransport()->peek()) { + break; + } + } + } catch (TTransportException& ttx) { + // This is reasonably expected, client didn't send a full request so just + // ignore him + // string errStr = string("TThreadPoolServer client died: ") + ttx.what(); + // GlobalOutput(errStr.c_str()); + } catch (TException& x) { + string errStr = string("TThreadPoolServer exception: ") + x.what(); + GlobalOutput(errStr.c_str()); + } catch (std::exception &x) { + string errStr = string("TThreadPoolServer, std::exception: ") + x.what(); + GlobalOutput(errStr.c_str()); + } + + if (eventHandler != NULL) { + eventHandler->clientEnd(input_, output_); + } + + try { + input_->getTransport()->close(); + } catch (TTransportException& ttx) { + string errStr = string("TThreadPoolServer input close failed: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } + try { + output_->getTransport()->close(); + } catch (TTransportException& ttx) { + string errStr = string("TThreadPoolServer output close failed: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } + + } + + private: + TServer& server_; + shared_ptr processor_; + shared_ptr input_; + shared_ptr output_; + +}; + +TThreadPoolServer::TThreadPoolServer(shared_ptr processor, + shared_ptr serverTransport, + shared_ptr transportFactory, + shared_ptr protocolFactory, + shared_ptr threadManager) : + TServer(processor, serverTransport, transportFactory, protocolFactory), + threadManager_(threadManager), + stop_(false), timeout_(0) {} + +TThreadPoolServer::TThreadPoolServer(shared_ptr processor, + shared_ptr serverTransport, + shared_ptr inputTransportFactory, + shared_ptr outputTransportFactory, + shared_ptr inputProtocolFactory, + shared_ptr outputProtocolFactory, + shared_ptr threadManager) : + TServer(processor, serverTransport, inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory), + threadManager_(threadManager), + stop_(false), timeout_(0) {} + + +TThreadPoolServer::~TThreadPoolServer() {} + +void TThreadPoolServer::serve() { + shared_ptr client; + shared_ptr inputTransport; + shared_ptr outputTransport; + shared_ptr inputProtocol; + shared_ptr outputProtocol; + + try { + // Start the server listening + serverTransport_->listen(); + } catch (TTransportException& ttx) { + string errStr = string("TThreadPoolServer::run() listen(): ") + ttx.what(); + GlobalOutput(errStr.c_str()); + return; + } + + // Run the preServe event + if (eventHandler_ != NULL) { + eventHandler_->preServe(); + } + + while (!stop_) { + try { + client.reset(); + inputTransport.reset(); + outputTransport.reset(); + inputProtocol.reset(); + outputProtocol.reset(); + + // Fetch client from server + client = serverTransport_->accept(); + + // Make IO transports + inputTransport = inputTransportFactory_->getTransport(client); + outputTransport = outputTransportFactory_->getTransport(client); + inputProtocol = inputProtocolFactory_->getProtocol(inputTransport); + outputProtocol = outputProtocolFactory_->getProtocol(outputTransport); + + // Add to threadmanager pool + threadManager_->add(shared_ptr(new TThreadPoolServer::Task(*this, processor_, inputProtocol, outputProtocol)), timeout_); + + } catch (TTransportException& ttx) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + if (!stop_ || ttx.getType() != TTransportException::INTERRUPTED) { + string errStr = string("TThreadPoolServer: TServerTransport died on accept: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } + continue; + } catch (TException& tx) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + string errStr = string("TThreadPoolServer: Caught TException: ") + tx.what(); + GlobalOutput(errStr.c_str()); + continue; + } catch (string s) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + string errStr = "TThreadPoolServer: Unknown exception: " + s; + GlobalOutput(errStr.c_str()); + break; + } + } + + // If stopped manually, join the existing threads + if (stop_) { + try { + serverTransport_->close(); + threadManager_->join(); + } catch (TException &tx) { + string errStr = string("TThreadPoolServer: Exception shutting down: ") + tx.what(); + GlobalOutput(errStr.c_str()); + } + stop_ = false; + } + +} + +int64_t TThreadPoolServer::getTimeout() const { + return timeout_; +} + +void TThreadPoolServer::setTimeout(int64_t value) { + timeout_ = value; +} + +}}} // apache::thrift::server diff --git a/lib/cpp/src/server/TThreadPoolServer.h b/lib/cpp/src/server/TThreadPoolServer.h new file mode 100644 index 00000000..7b7e9064 --- /dev/null +++ b/lib/cpp/src/server/TThreadPoolServer.h @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_SERVER_TTHREADPOOLSERVER_H_ +#define _THRIFT_SERVER_TTHREADPOOLSERVER_H_ 1 + +#include +#include +#include + +#include + +namespace apache { namespace thrift { namespace server { + +using apache::thrift::concurrency::ThreadManager; +using apache::thrift::protocol::TProtocolFactory; +using apache::thrift::transport::TServerTransport; +using apache::thrift::transport::TTransportFactory; + +class TThreadPoolServer : public TServer { + public: + class Task; + + TThreadPoolServer(boost::shared_ptr processor, + boost::shared_ptr serverTransport, + boost::shared_ptr transportFactory, + boost::shared_ptr protocolFactory, + boost::shared_ptr threadManager); + + TThreadPoolServer(boost::shared_ptr processor, + boost::shared_ptr serverTransport, + boost::shared_ptr inputTransportFactory, + boost::shared_ptr outputTransportFactory, + boost::shared_ptr inputProtocolFactory, + boost::shared_ptr outputProtocolFactory, + boost::shared_ptr threadManager); + + virtual ~TThreadPoolServer(); + + virtual void serve(); + + virtual int64_t getTimeout() const; + + virtual void setTimeout(int64_t value); + + virtual void stop() { + stop_ = true; + serverTransport_->interrupt(); + } + + protected: + + boost::shared_ptr threadManager_; + + volatile bool stop_; + + volatile int64_t timeout_; + +}; + +}}} // apache::thrift::server + +#endif // #ifndef _THRIFT_SERVER_TTHREADPOOLSERVER_H_ diff --git a/lib/cpp/src/server/TThreadedServer.cpp b/lib/cpp/src/server/TThreadedServer.cpp new file mode 100644 index 00000000..cc30f8ff --- /dev/null +++ b/lib/cpp/src/server/TThreadedServer.cpp @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "server/TThreadedServer.h" +#include "transport/TTransportException.h" +#include "concurrency/PosixThreadFactory.h" + +#include +#include +#include +#include + +namespace apache { namespace thrift { namespace server { + +using boost::shared_ptr; +using namespace std; +using namespace apache::thrift; +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; +using namespace apache::thrift::concurrency; + +class TThreadedServer::Task: public Runnable { + +public: + + Task(TThreadedServer& server, + shared_ptr processor, + shared_ptr input, + shared_ptr output) : + server_(server), + processor_(processor), + input_(input), + output_(output) { + } + + ~Task() {} + + void run() { + boost::shared_ptr eventHandler = + server_.getEventHandler(); + if (eventHandler != NULL) { + eventHandler->clientBegin(input_, output_); + } + try { + while (processor_->process(input_, output_)) { + if (!input_->getTransport()->peek()) { + break; + } + } + } catch (TTransportException& ttx) { + string errStr = string("TThreadedServer client died: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } catch (TException& x) { + string errStr = string("TThreadedServer exception: ") + x.what(); + GlobalOutput(errStr.c_str()); + } catch (...) { + GlobalOutput("TThreadedServer uncaught exception."); + } + if (eventHandler != NULL) { + eventHandler->clientEnd(input_, output_); + } + + try { + input_->getTransport()->close(); + } catch (TTransportException& ttx) { + string errStr = string("TThreadedServer input close failed: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } + try { + output_->getTransport()->close(); + } catch (TTransportException& ttx) { + string errStr = string("TThreadedServer output close failed: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } + + // Remove this task from parent bookkeeping + { + Synchronized s(server_.tasksMonitor_); + server_.tasks_.erase(this); + if (server_.tasks_.empty()) { + server_.tasksMonitor_.notify(); + } + } + + } + + private: + TThreadedServer& server_; + friend class TThreadedServer; + + shared_ptr processor_; + shared_ptr input_; + shared_ptr output_; +}; + + +TThreadedServer::TThreadedServer(shared_ptr processor, + shared_ptr serverTransport, + shared_ptr transportFactory, + shared_ptr protocolFactory): + TServer(processor, serverTransport, transportFactory, protocolFactory), + stop_(false) { + threadFactory_ = shared_ptr(new PosixThreadFactory()); +} + +TThreadedServer::TThreadedServer(boost::shared_ptr processor, + boost::shared_ptr serverTransport, + boost::shared_ptr transportFactory, + boost::shared_ptr protocolFactory, + boost::shared_ptr threadFactory): + TServer(processor, serverTransport, transportFactory, protocolFactory), + threadFactory_(threadFactory), + stop_(false) { +} + +TThreadedServer::~TThreadedServer() {} + +void TThreadedServer::serve() { + + shared_ptr client; + shared_ptr inputTransport; + shared_ptr outputTransport; + shared_ptr inputProtocol; + shared_ptr outputProtocol; + + try { + // Start the server listening + serverTransport_->listen(); + } catch (TTransportException& ttx) { + string errStr = string("TThreadedServer::run() listen(): ") +ttx.what(); + GlobalOutput(errStr.c_str()); + return; + } + + // Run the preServe event + if (eventHandler_ != NULL) { + eventHandler_->preServe(); + } + + while (!stop_) { + try { + client.reset(); + inputTransport.reset(); + outputTransport.reset(); + inputProtocol.reset(); + outputProtocol.reset(); + + // Fetch client from server + client = serverTransport_->accept(); + + // Make IO transports + inputTransport = inputTransportFactory_->getTransport(client); + outputTransport = outputTransportFactory_->getTransport(client); + inputProtocol = inputProtocolFactory_->getProtocol(inputTransport); + outputProtocol = outputProtocolFactory_->getProtocol(outputTransport); + + TThreadedServer::Task* task = new TThreadedServer::Task(*this, + processor_, + inputProtocol, + outputProtocol); + + // Create a task + shared_ptr runnable = + shared_ptr(task); + + // Create a thread for this task + shared_ptr thread = + shared_ptr(threadFactory_->newThread(runnable)); + + // Insert thread into the set of threads + { + Synchronized s(tasksMonitor_); + tasks_.insert(task); + } + + // Start the thread! + thread->start(); + + } catch (TTransportException& ttx) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + if (!stop_ || ttx.getType() != TTransportException::INTERRUPTED) { + string errStr = string("TThreadedServer: TServerTransport died on accept: ") + ttx.what(); + GlobalOutput(errStr.c_str()); + } + continue; + } catch (TException& tx) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + string errStr = string("TThreadedServer: Caught TException: ") + tx.what(); + GlobalOutput(errStr.c_str()); + continue; + } catch (string s) { + if (inputTransport != NULL) { inputTransport->close(); } + if (outputTransport != NULL) { outputTransport->close(); } + if (client != NULL) { client->close(); } + string errStr = "TThreadedServer: Unknown exception: " + s; + GlobalOutput(errStr.c_str()); + break; + } + } + + // If stopped manually, make sure to close server transport + if (stop_) { + try { + serverTransport_->close(); + } catch (TException &tx) { + string errStr = string("TThreadedServer: Exception shutting down: ") + tx.what(); + GlobalOutput(errStr.c_str()); + } + try { + Synchronized s(tasksMonitor_); + while (!tasks_.empty()) { + tasksMonitor_.wait(); + } + } catch (TException &tx) { + string errStr = string("TThreadedServer: Exception joining workers: ") + tx.what(); + GlobalOutput(errStr.c_str()); + } + stop_ = false; + } + +} + +}}} // apache::thrift::server diff --git a/lib/cpp/src/server/TThreadedServer.h b/lib/cpp/src/server/TThreadedServer.h new file mode 100644 index 00000000..4d0811aa --- /dev/null +++ b/lib/cpp/src/server/TThreadedServer.h @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_SERVER_TTHREADEDSERVER_H_ +#define _THRIFT_SERVER_TTHREADEDSERVER_H_ 1 + +#include +#include +#include +#include + +#include + +namespace apache { namespace thrift { namespace server { + +using apache::thrift::TProcessor; +using apache::thrift::transport::TServerTransport; +using apache::thrift::transport::TTransportFactory; +using apache::thrift::concurrency::Monitor; +using apache::thrift::concurrency::ThreadFactory; + +class TThreadedServer : public TServer { + + public: + class Task; + + TThreadedServer(boost::shared_ptr processor, + boost::shared_ptr serverTransport, + boost::shared_ptr transportFactory, + boost::shared_ptr protocolFactory); + + TThreadedServer(boost::shared_ptr processor, + boost::shared_ptr serverTransport, + boost::shared_ptr transportFactory, + boost::shared_ptr protocolFactory, + boost::shared_ptr threadFactory); + + virtual ~TThreadedServer(); + + virtual void serve(); + + void stop() { + stop_ = true; + serverTransport_->interrupt(); + } + + protected: + boost::shared_ptr threadFactory_; + volatile bool stop_; + + Monitor tasksMonitor_; + std::set tasks_; + +}; + +}}} // apache::thrift::server + +#endif // #ifndef _THRIFT_SERVER_TTHREADEDSERVER_H_ diff --git a/lib/cpp/src/transport/TBufferTransports.cpp b/lib/cpp/src/transport/TBufferTransports.cpp new file mode 100644 index 00000000..7a7e5e92 --- /dev/null +++ b/lib/cpp/src/transport/TBufferTransports.cpp @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include + +using std::string; + +namespace apache { namespace thrift { namespace transport { + + +uint32_t TBufferedTransport::readSlow(uint8_t* buf, uint32_t len) { + uint32_t want = len; + uint32_t have = rBound_ - rBase_; + + // We should only take the slow path if we can't satisfy the read + // with the data already in the buffer. + assert(have < want); + + // Copy out whatever we have. + if (have > 0) { + memcpy(buf, rBase_, have); + want -= have; + buf += have; + } + // Get more from underlying transport up to buffer size. + // Note that this makes a lot of sense if len < rBufSize_ + // and almost no sense otherwise. TODO(dreiss): Fix that + // case (possibly including some readv hotness). + setReadBuffer(rBuf_.get(), transport_->read(rBuf_.get(), rBufSize_)); + + // Hand over whatever we have. + uint32_t give = std::min(want, static_cast(rBound_ - rBase_)); + memcpy(buf, rBase_, give); + rBase_ += give; + want -= give; + + return (len - want); +} + +void TBufferedTransport::writeSlow(const uint8_t* buf, uint32_t len) { + uint32_t have_bytes = wBase_ - wBuf_.get(); + uint32_t space = wBound_ - wBase_; + // We should only take the slow path if we can't accomodate the write + // with the free space already in the buffer. + assert(wBound_ - wBase_ < static_cast(len)); + + // Now here's the tricky question: should we copy data from buf into our + // internal buffer and write it from there, or should we just write out + // the current internal buffer in one syscall and write out buf in another. + // If our currently buffered data plus buf is at least double our buffer + // size, we will have to do two syscalls no matter what (except in the + // degenerate case when our buffer is empty), so there is no use copying. + // Otherwise, there is sort of a sliding scale. If we have N-1 bytes + // buffered and need to write 2, it would be crazy to do two syscalls. + // On the other hand, if we have 2 bytes buffered and are writing 2N-3, + // we can save a syscall in the short term by loading up our buffer, writing + // it out, and copying the rest of the bytes into our buffer. Of course, + // if we get another 2-byte write, we haven't saved any syscalls at all, + // and have just copied nearly 2N bytes for nothing. Finding a perfect + // policy would require predicting the size of future writes, so we're just + // going to always eschew syscalls if we have less than 2N bytes to write. + + // The case where we have to do two syscalls. + // This case also covers the case where the buffer is empty, + // but it is clearer (I think) to think of it as two separate cases. + if ((have_bytes + len >= 2*wBufSize_) || (have_bytes == 0)) { + // TODO(dreiss): writev + if (have_bytes > 0) { + transport_->write(wBuf_.get(), have_bytes); + } + transport_->write(buf, len); + wBase_ = wBuf_.get(); + return; + } + + // Fill up our internal buffer for a write. + memcpy(wBase_, buf, space); + buf += space; + len -= space; + transport_->write(wBuf_.get(), wBufSize_); + + // Copy the rest into our buffer. + assert(len < wBufSize_); + memcpy(wBuf_.get(), buf, len); + wBase_ = wBuf_.get() + len; + return; +} + +const uint8_t* TBufferedTransport::borrowSlow(uint8_t* buf, uint32_t* len) { + // If the request is bigger than our buffer, we are hosed. + if (*len > rBufSize_) { + return NULL; + } + + // The number of bytes of data we have already. + uint32_t have = rBound_ - rBase_; + // The number of additional bytes we need from the underlying transport. + int32_t need = *len - have; + // The space from the start of the buffer to the end of our data. + uint32_t offset = rBound_ - rBuf_.get(); + assert(need > 0); + + // If we have less than half our buffer space available, shift the data + // we have down to the start. If the borrow is big compared to our buffer, + // this could be kind of a waste, but if the borrow is small, it frees up + // space at the end of our buffer to do a bigger single read from the + // underlying transport. Also, if our needs extend past the end of the + // buffer, we have to do a copy no matter what. + if ((offset > rBufSize_/2) || (offset + need > rBufSize_)) { + memmove(rBuf_.get(), rBase_, have); + setReadBuffer(rBuf_.get(), have); + } + + // First try to fill up the buffer. + uint32_t got = transport_->read(rBound_, rBufSize_ - have); + rBound_ += got; + need -= got; + + // If that fails, readAll until we get what we need. + if (need > 0) { + rBound_ += transport_->readAll(rBound_, need); + } + + *len = rBound_ - rBase_; + return rBase_; +} + +void TBufferedTransport::flush() { + // Write out any data waiting in the write buffer. + uint32_t have_bytes = wBase_ - wBuf_.get(); + if (have_bytes > 0) { + // Note that we reset wBase_ prior to the underlying write + // to ensure we're in a sane state (i.e. internal buffer cleaned) + // if the underlying write throws up an exception + wBase_ = wBuf_.get(); + transport_->write(wBuf_.get(), have_bytes); + } + + // Flush the underlying transport. + transport_->flush(); +} + + +uint32_t TFramedTransport::readSlow(uint8_t* buf, uint32_t len) { + uint32_t want = len; + uint32_t have = rBound_ - rBase_; + + // We should only take the slow path if we can't satisfy the read + // with the data already in the buffer. + assert(have < want); + + // Copy out whatever we have. + if (have > 0) { + memcpy(buf, rBase_, have); + want -= have; + buf += have; + } + + // Read another frame. + readFrame(); + + // TODO(dreiss): Should we warn when reads cross frames? + + // Hand over whatever we have. + uint32_t give = std::min(want, static_cast(rBound_ - rBase_)); + memcpy(buf, rBase_, give); + rBase_ += give; + want -= give; + + return (len - want); +} + +void TFramedTransport::readFrame() { + // TODO(dreiss): Think about using readv here, even though it would + // result in (gasp) read-ahead. + + // Read the size of the next frame. + int32_t sz; + transport_->readAll((uint8_t*)&sz, sizeof(sz)); + sz = ntohl(sz); + + if (sz < 0) { + throw TTransportException("Frame size has negative value"); + } + + // Read the frame payload, and reset markers. + if (sz > static_cast(rBufSize_)) { + rBuf_.reset(new uint8_t[sz]); + rBufSize_ = sz; + } + transport_->readAll(rBuf_.get(), sz); + setReadBuffer(rBuf_.get(), sz); +} + +void TFramedTransport::writeSlow(const uint8_t* buf, uint32_t len) { + // Double buffer size until sufficient. + uint32_t have = wBase_ - wBuf_.get(); + while (wBufSize_ < len + have) { + wBufSize_ *= 2; + } + + // TODO(dreiss): Consider modifying this class to use malloc/free + // so we can use realloc here. + + // Allocate new buffer. + uint8_t* new_buf = new uint8_t[wBufSize_]; + + // Copy the old buffer to the new one. + memcpy(new_buf, wBuf_.get(), have); + + // Now point buf to the new one. + wBuf_.reset(new_buf); + wBase_ = wBuf_.get() + have; + wBound_ = wBuf_.get() + wBufSize_; + + // Copy the data into the new buffer. + memcpy(wBase_, buf, len); + wBase_ += len; +} + +void TFramedTransport::flush() { + int32_t sz_hbo, sz_nbo; + assert(wBufSize_ > sizeof(sz_nbo)); + + // Slip the frame size into the start of the buffer. + sz_hbo = wBase_ - (wBuf_.get() + sizeof(sz_nbo)); + sz_nbo = (int32_t)htonl((uint32_t)(sz_hbo)); + memcpy(wBuf_.get(), (uint8_t*)&sz_nbo, sizeof(sz_nbo)); + + if (sz_hbo > 0) { + // Note that we reset wBase_ (with a pad for the frame size) + // prior to the underlying write to ensure we're in a sane state + // (i.e. internal buffer cleaned) if the underlying write throws + // up an exception + wBase_ = wBuf_.get() + sizeof(sz_nbo); + + // Write size and frame body. + transport_->write(wBuf_.get(), sizeof(sz_nbo)+sz_hbo); + } + + // Flush the underlying transport. + transport_->flush(); +} + +const uint8_t* TFramedTransport::borrowSlow(uint8_t* buf, uint32_t* len) { + // Don't try to be clever with shifting buffers. + // If the fast path failed let the protocol use its slow path. + // Besides, who is going to try to borrow across messages? + return NULL; +} + + +void TMemoryBuffer::computeRead(uint32_t len, uint8_t** out_start, uint32_t* out_give) { + // Correct rBound_ so we can use the fast path in the future. + rBound_ = wBase_; + + // Decide how much to give. + uint32_t give = std::min(len, available_read()); + + *out_start = rBase_; + *out_give = give; + + // Preincrement rBase_ so the caller doesn't have to. + rBase_ += give; +} + +uint32_t TMemoryBuffer::readSlow(uint8_t* buf, uint32_t len) { + uint8_t* start; + uint32_t give; + computeRead(len, &start, &give); + + // Copy into the provided buffer. + memcpy(buf, start, give); + + return give; +} + +uint32_t TMemoryBuffer::readAppendToString(std::string& str, uint32_t len) { + // Don't get some stupid assertion failure. + if (buffer_ == NULL) { + return 0; + } + + uint8_t* start; + uint32_t give; + computeRead(len, &start, &give); + + // Append to the provided string. + str.append((char*)start, give); + + return give; +} + +void TMemoryBuffer::ensureCanWrite(uint32_t len) { + // Check available space + uint32_t avail = available_write(); + if (len <= avail) { + return; + } + + if (!owner_) { + throw TTransportException("Insufficient space in external MemoryBuffer"); + } + + // Grow the buffer as necessary. + while (len > avail) { + bufferSize_ *= 2; + wBound_ = buffer_ + bufferSize_; + avail = available_write(); + } + + // Allocate into a new pointer so we don't bork ours if it fails. + void* new_buffer = std::realloc(buffer_, bufferSize_); + if (new_buffer == NULL) { + throw TTransportException("Out of memory."); + } + + ptrdiff_t offset = (uint8_t*)new_buffer - buffer_; + buffer_ += offset; + rBase_ += offset; + rBound_ += offset; + wBase_ += offset; + wBound_ += offset; +} + +void TMemoryBuffer::writeSlow(const uint8_t* buf, uint32_t len) { + ensureCanWrite(len); + + // Copy into the buffer and increment wBase_. + memcpy(wBase_, buf, len); + wBase_ += len; +} + +void TMemoryBuffer::wroteBytes(uint32_t len) { + uint32_t avail = available_write(); + if (len > avail) { + throw TTransportException("Client wrote more bytes than size of buffer."); + } + wBase_ += len; +} + +const uint8_t* TMemoryBuffer::borrowSlow(uint8_t* buf, uint32_t* len) { + rBound_ = wBase_; + if (available_read() >= *len) { + *len = available_read(); + return rBase_; + } + return NULL; +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TBufferTransports.h b/lib/cpp/src/transport/TBufferTransports.h new file mode 100644 index 00000000..1908205f --- /dev/null +++ b/lib/cpp/src/transport/TBufferTransports.h @@ -0,0 +1,667 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TBUFFERTRANSPORTS_H_ +#define _THRIFT_TRANSPORT_TBUFFERTRANSPORTS_H_ 1 + +#include +#include "boost/scoped_array.hpp" + +#include + +#ifdef __GNUC__ +#define TDB_LIKELY(val) (__builtin_expect((val), 1)) +#define TDB_UNLIKELY(val) (__builtin_expect((val), 0)) +#else +#define TDB_LIKELY(val) (val) +#define TDB_UNLIKELY(val) (val) +#endif + +namespace apache { namespace thrift { namespace transport { + + +/** + * Base class for all transports that use read/write buffers for performance. + * + * TBufferBase is designed to implement the fast-path "memcpy" style + * operations that work in the common case. It does so with small and + * (eventually) nonvirtual, inlinable methods. TBufferBase is an abstract + * class. Subclasses are expected to define the "slow path" operations + * that have to be done when the buffers are full or empty. + * + */ +class TBufferBase : public TTransport { + + public: + + /** + * Fast-path read. + * + * When we have enough data buffered to fulfill the read, we can satisfy it + * with a single memcpy, then adjust our internal pointers. If the buffer + * is empty, we call out to our slow path, implemented by a subclass. + * This method is meant to eventually be nonvirtual and inlinable. + */ + uint32_t read(uint8_t* buf, uint32_t len) { + uint8_t* new_rBase = rBase_ + len; + if (TDB_LIKELY(new_rBase <= rBound_)) { + std::memcpy(buf, rBase_, len); + rBase_ = new_rBase; + return len; + } + return readSlow(buf, len); + } + + /** + * Fast-path write. + * + * When we have enough empty space in our buffer to accomodate the write, we + * can satisfy it with a single memcpy, then adjust our internal pointers. + * If the buffer is full, we call out to our slow path, implemented by a + * subclass. This method is meant to eventually be nonvirtual and + * inlinable. + */ + void write(const uint8_t* buf, uint32_t len) { + uint8_t* new_wBase = wBase_ + len; + if (TDB_LIKELY(new_wBase <= wBound_)) { + std::memcpy(wBase_, buf, len); + wBase_ = new_wBase; + return; + } + writeSlow(buf, len); + } + + /** + * Fast-path borrow. A lot like the fast-path read. + */ + const uint8_t* borrow(uint8_t* buf, uint32_t* len) { + if (TDB_LIKELY(static_cast(*len) <= rBound_ - rBase_)) { + // With strict aliasing, writing to len shouldn't force us to + // refetch rBase_ from memory. TODO(dreiss): Verify this. + *len = rBound_ - rBase_; + return rBase_; + } + return borrowSlow(buf, len); + } + + /** + * Consume doesn't require a slow path. + */ + void consume(uint32_t len) { + if (TDB_LIKELY(static_cast(len) <= rBound_ - rBase_)) { + rBase_ += len; + } else { + throw TTransportException(TTransportException::BAD_ARGS, + "consume did not follow a borrow."); + } + } + + + protected: + + /// Slow path read. + virtual uint32_t readSlow(uint8_t* buf, uint32_t len) = 0; + + /// Slow path write. + virtual void writeSlow(const uint8_t* buf, uint32_t len) = 0; + + /** + * Slow path borrow. + * + * POSTCONDITION: return == NULL || rBound_ - rBase_ >= *len + */ + virtual const uint8_t* borrowSlow(uint8_t* buf, uint32_t* len) = 0; + + /** + * Trivial constructor. + * + * Initialize pointers safely. Constructing is not a very + * performance-sensitive operation, so it is okay to just leave it to + * the concrete class to set up pointers correctly. + */ + TBufferBase() + : rBase_(NULL) + , rBound_(NULL) + , wBase_(NULL) + , wBound_(NULL) + {} + + /// Convenience mutator for setting the read buffer. + void setReadBuffer(uint8_t* buf, uint32_t len) { + rBase_ = buf; + rBound_ = buf+len; + } + + /// Convenience mutator for setting the write buffer. + void setWriteBuffer(uint8_t* buf, uint32_t len) { + wBase_ = buf; + wBound_ = buf+len; + } + + virtual ~TBufferBase() {} + + /// Reads begin here. + uint8_t* rBase_; + /// Reads may extend to just before here. + uint8_t* rBound_; + + /// Writes begin here. + uint8_t* wBase_; + /// Writes may extend to just before here. + uint8_t* wBound_; +}; + + +/** + * Base class for all transport which wraps transport to new one. + */ +class TUnderlyingTransport : public TBufferBase { + public: + static const int DEFAULT_BUFFER_SIZE = 512; + + virtual bool peek() { + return (rBase_ < rBound_) || transport_->peek(); + } + + void open() { + transport_->open(); + } + + bool isOpen() { + return transport_->isOpen(); + } + + void close() { + flush(); + transport_->close(); + } + + boost::shared_ptr getUnderlyingTransport() { + return transport_; + } + + protected: + boost::shared_ptr transport_; + + uint32_t rBufSize_; + uint32_t wBufSize_; + boost::scoped_array rBuf_; + boost::scoped_array wBuf_; + + TUnderlyingTransport(boost::shared_ptr transport, uint32_t sz) + : transport_(transport) + , rBufSize_(sz) + , wBufSize_(sz) + , rBuf_(new uint8_t[rBufSize_]) + , wBuf_(new uint8_t[wBufSize_]) {} + + TUnderlyingTransport(boost::shared_ptr transport) + : transport_(transport) + , rBufSize_(DEFAULT_BUFFER_SIZE) + , wBufSize_(DEFAULT_BUFFER_SIZE) + , rBuf_(new uint8_t[rBufSize_]) + , wBuf_(new uint8_t[wBufSize_]) {} + + TUnderlyingTransport(boost::shared_ptr transport, uint32_t rsz, uint32_t wsz) + : transport_(transport) + , rBufSize_(rsz) + , wBufSize_(wsz) + , rBuf_(new uint8_t[rBufSize_]) + , wBuf_(new uint8_t[wBufSize_]) {} +}; + +/** + * Buffered transport. For reads it will read more data than is requested + * and will serve future data out of a local buffer. For writes, data is + * stored to an in memory buffer before being written out. + * + */ +class TBufferedTransport : public TUnderlyingTransport { + public: + + /// Use default buffer sizes. + TBufferedTransport(boost::shared_ptr transport) + : TUnderlyingTransport(transport) + { + initPointers(); + } + + /// Use specified buffer sizes. + TBufferedTransport(boost::shared_ptr transport, uint32_t sz) + : TUnderlyingTransport(transport, sz) + { + initPointers(); + } + + /// Use specified read and write buffer sizes. + TBufferedTransport(boost::shared_ptr transport, uint32_t rsz, uint32_t wsz) + : TUnderlyingTransport(transport, rsz, wsz) + { + initPointers(); + } + + virtual bool peek() { + /* shigin: see THRIFT-96 discussion */ + if (rBase_ == rBound_) { + setReadBuffer(rBuf_.get(), transport_->read(rBuf_.get(), rBufSize_)); + } + return (rBound_ > rBase_); + } + virtual uint32_t readSlow(uint8_t* buf, uint32_t len); + + virtual void writeSlow(const uint8_t* buf, uint32_t len); + + void flush(); + + + /** + * The following behavior is currently implemented by TBufferedTransport, + * but that may change in a future version: + * 1/ If len is at most rBufSize_, borrow will never return NULL. + * Depending on the underlying transport, it could throw an exception + * or hang forever. + * 2/ Some borrow requests may copy bytes internally. However, + * if len is at most rBufSize_/2, none of the copied bytes + * will ever have to be copied again. For optimial performance, + * stay under this limit. + */ + virtual const uint8_t* borrowSlow(uint8_t* buf, uint32_t* len); + + protected: + void initPointers() { + setReadBuffer(rBuf_.get(), 0); + setWriteBuffer(wBuf_.get(), wBufSize_); + // Write size never changes. + } +}; + + +/** + * Wraps a transport into a buffered one. + * + */ +class TBufferedTransportFactory : public TTransportFactory { + public: + TBufferedTransportFactory() {} + + virtual ~TBufferedTransportFactory() {} + + /** + * Wraps the transport into a buffered one. + */ + virtual boost::shared_ptr getTransport(boost::shared_ptr trans) { + return boost::shared_ptr(new TBufferedTransport(trans)); + } + +}; + + +/** + * Framed transport. All writes go into an in-memory buffer until flush is + * called, at which point the transport writes the length of the entire + * binary chunk followed by the data payload. This allows the receiver on the + * other end to always do fixed-length reads. + * + */ +class TFramedTransport : public TUnderlyingTransport { + public: + + /// Use default buffer sizes. + TFramedTransport(boost::shared_ptr transport) + : TUnderlyingTransport(transport) + { + initPointers(); + } + + TFramedTransport(boost::shared_ptr transport, uint32_t sz) + : TUnderlyingTransport(transport, sz) + { + initPointers(); + } + + virtual uint32_t readSlow(uint8_t* buf, uint32_t len); + + virtual void writeSlow(const uint8_t* buf, uint32_t len); + + virtual void flush(); + + const uint8_t* borrowSlow(uint8_t* buf, uint32_t* len); + + protected: + /** + * Reads a frame of input from the underlying stream. + */ + void readFrame(); + + void initPointers() { + setReadBuffer(NULL, 0); + setWriteBuffer(wBuf_.get(), wBufSize_); + + // Pad the buffer so we can insert the size later. + int32_t pad = 0; + this->write((uint8_t*)&pad, sizeof(pad)); + } +}; + +/** + * Wraps a transport into a framed one. + * + */ +class TFramedTransportFactory : public TTransportFactory { + public: + TFramedTransportFactory() {} + + virtual ~TFramedTransportFactory() {} + + /** + * Wraps the transport into a framed one. + */ + virtual boost::shared_ptr getTransport(boost::shared_ptr trans) { + return boost::shared_ptr(new TFramedTransport(trans)); + } + +}; + + +/** + * A memory buffer is a tranpsort that simply reads from and writes to an + * in memory buffer. Anytime you call write on it, the data is simply placed + * into a buffer, and anytime you call read, data is read from that buffer. + * + * The buffers are allocated using C constructs malloc,realloc, and the size + * doubles as necessary. We've considered using scoped + * + */ +class TMemoryBuffer : public TBufferBase { + private: + + // Common initialization done by all constructors. + void initCommon(uint8_t* buf, uint32_t size, bool owner, uint32_t wPos) { + if (buf == NULL && size != 0) { + assert(owner); + buf = (uint8_t*)std::malloc(size); + if (buf == NULL) { + throw TTransportException("Out of memory"); + } + } + + buffer_ = buf; + bufferSize_ = size; + + rBase_ = buffer_; + rBound_ = buffer_ + wPos; + // TODO(dreiss): Investigate NULL-ing this if !owner. + wBase_ = buffer_ + wPos; + wBound_ = buffer_ + bufferSize_; + + owner_ = owner; + + // rBound_ is really an artifact. In principle, it should always be + // equal to wBase_. We update it in a few places (computeRead, etc.). + } + + public: + static const uint32_t defaultSize = 1024; + + /** + * This enum specifies how a TMemoryBuffer should treat + * memory passed to it via constructors or resetBuffer. + * + * OBSERVE: + * TMemoryBuffer will simply store a pointer to the memory. + * It is the callers responsibility to ensure that the pointer + * remains valid for the lifetime of the TMemoryBuffer, + * and that it is properly cleaned up. + * Note that no data can be written to observed buffers. + * + * COPY: + * TMemoryBuffer will make an internal copy of the buffer. + * The caller has no responsibilities. + * + * TAKE_OWNERSHIP: + * TMemoryBuffer will become the "owner" of the buffer, + * and will be responsible for freeing it. + * The membory must have been allocated with malloc. + */ + enum MemoryPolicy + { OBSERVE = 1 + , COPY = 2 + , TAKE_OWNERSHIP = 3 + }; + + /** + * Construct a TMemoryBuffer with a default-sized buffer, + * owned by the TMemoryBuffer object. + */ + TMemoryBuffer() { + initCommon(NULL, defaultSize, true, 0); + } + + /** + * Construct a TMemoryBuffer with a buffer of a specified size, + * owned by the TMemoryBuffer object. + * + * @param sz The initial size of the buffer. + */ + TMemoryBuffer(uint32_t sz) { + initCommon(NULL, sz, true, 0); + } + + /** + * Construct a TMemoryBuffer with buf as its initial contents. + * + * @param buf The initial contents of the buffer. + * Note that, while buf is a non-const pointer, + * TMemoryBuffer will not write to it if policy == OBSERVE, + * so it is safe to const_cast(whatever). + * @param sz The size of @c buf. + * @param policy See @link MemoryPolicy @endlink . + */ + TMemoryBuffer(uint8_t* buf, uint32_t sz, MemoryPolicy policy = OBSERVE) { + if (buf == NULL && sz != 0) { + throw TTransportException(TTransportException::BAD_ARGS, + "TMemoryBuffer given null buffer with non-zero size."); + } + + switch (policy) { + case OBSERVE: + case TAKE_OWNERSHIP: + initCommon(buf, sz, policy == TAKE_OWNERSHIP, sz); + break; + case COPY: + initCommon(NULL, sz, true, 0); + this->write(buf, sz); + break; + default: + throw TTransportException(TTransportException::BAD_ARGS, + "Invalid MemoryPolicy for TMemoryBuffer"); + } + } + + ~TMemoryBuffer() { + if (owner_) { + std::free(buffer_); + } + } + + bool isOpen() { + return true; + } + + bool peek() { + return (rBase_ < wBase_); + } + + void open() {} + + void close() {} + + // TODO(dreiss): Make bufPtr const. + void getBuffer(uint8_t** bufPtr, uint32_t* sz) { + *bufPtr = rBase_; + *sz = wBase_ - rBase_; + } + + std::string getBufferAsString() { + if (buffer_ == NULL) { + return ""; + } + uint8_t* buf; + uint32_t sz; + getBuffer(&buf, &sz); + return std::string((char*)buf, (std::string::size_type)sz); + } + + void appendBufferToString(std::string& str) { + if (buffer_ == NULL) { + return; + } + uint8_t* buf; + uint32_t sz; + getBuffer(&buf, &sz); + str.append((char*)buf, sz); + } + + void resetBuffer(bool reset_capacity = false) { + if (reset_capacity) + { + assert(owner_); + + void* new_buffer = std::realloc(buffer_, defaultSize); + + if (new_buffer == NULL) { + throw TTransportException("Out of memory."); + } + + buffer_ = (uint8_t*) new_buffer; + bufferSize_ = defaultSize; + + wBound_ = buffer_ + bufferSize_; + } + + rBase_ = buffer_; + rBound_ = buffer_; + wBase_ = buffer_; + // It isn't safe to write into a buffer we don't own. + if (!owner_) { + wBound_ = wBase_; + bufferSize_ = 0; + } + } + + /// See constructor documentation. + void resetBuffer(uint8_t* buf, uint32_t sz, MemoryPolicy policy = OBSERVE) { + // Use a variant of the copy-and-swap trick for assignment operators. + // This is sub-optimal in terms of performance for two reasons: + // 1/ The constructing and swapping of the (small) values + // in the temporary object takes some time, and is not necessary. + // 2/ If policy == COPY, we allocate the new buffer before + // freeing the old one, precluding the possibility of + // reusing that memory. + // I doubt that either of these problems could be optimized away, + // but the second is probably no a common case, and the first is minor. + // I don't expect resetBuffer to be a common operation, so I'm willing to + // bite the performance bullet to make the method this simple. + + // Construct the new buffer. + TMemoryBuffer new_buffer(buf, sz, policy); + // Move it into ourself. + this->swap(new_buffer); + // Our old self gets destroyed. + } + + std::string readAsString(uint32_t len) { + std::string str; + (void)readAppendToString(str, len); + return str; + } + + uint32_t readAppendToString(std::string& str, uint32_t len); + + void readEnd() { + if (rBase_ == wBase_) { + resetBuffer(); + } + } + + uint32_t available_read() const { + // Remember, wBase_ is the real rBound_. + return wBase_ - rBase_; + } + + uint32_t available_write() const { + return wBound_ - wBase_; + } + + // Returns a pointer to where the client can write data to append to + // the TMemoryBuffer, and ensures the buffer is big enough to accomodate a + // write of the provided length. The returned pointer is very convenient for + // passing to read(), recv(), or similar. You must call wroteBytes() as soon + // as data is written or the buffer will not be aware that data has changed. + uint8_t* getWritePtr(uint32_t len) { + ensureCanWrite(len); + return wBase_; + } + + // Informs the buffer that the client has written 'len' bytes into storage + // that had been provided by getWritePtr(). + void wroteBytes(uint32_t len); + + protected: + void swap(TMemoryBuffer& that) { + using std::swap; + swap(buffer_, that.buffer_); + swap(bufferSize_, that.bufferSize_); + + swap(rBase_, that.rBase_); + swap(rBound_, that.rBound_); + swap(wBase_, that.wBase_); + swap(wBound_, that.wBound_); + + swap(owner_, that.owner_); + } + + // Make sure there's at least 'len' bytes available for writing. + void ensureCanWrite(uint32_t len); + + // Compute the position and available data for reading. + void computeRead(uint32_t len, uint8_t** out_start, uint32_t* out_give); + + uint32_t readSlow(uint8_t* buf, uint32_t len); + + void writeSlow(const uint8_t* buf, uint32_t len); + + const uint8_t* borrowSlow(uint8_t* buf, uint32_t* len); + + // Data buffer + uint8_t* buffer_; + + // Allocated buffer size + uint32_t bufferSize_; + + // Is this object the owner of the buffer? + bool owner_; + + // Don't forget to update constrctors, initCommon, and swap if + // you add new members. +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TBUFFERTRANSPORTS_H_ diff --git a/lib/cpp/src/transport/TFDTransport.cpp b/lib/cpp/src/transport/TFDTransport.cpp new file mode 100644 index 00000000..a042f8b7 --- /dev/null +++ b/lib/cpp/src/transport/TFDTransport.cpp @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include + +#include + +using namespace std; + +namespace apache { namespace thrift { namespace transport { + +void TFDTransport::close() { + if (!isOpen()) { + return; + } + + int rv = ::close(fd_); + int errno_copy = errno; + fd_ = -1; + // Have to check uncaught_exception because this is called in the destructor. + if (rv < 0 && !std::uncaught_exception()) { + throw TTransportException(TTransportException::UNKNOWN, + "TFDTransport::close()", + errno_copy); + } +} + +uint32_t TFDTransport::read(uint8_t* buf, uint32_t len) { + ssize_t rv = ::read(fd_, buf, len); + if (rv < 0) { + int errno_copy = errno; + throw TTransportException(TTransportException::UNKNOWN, + "TFDTransport::read()", + errno_copy); + } + return rv; +} + +void TFDTransport::write(const uint8_t* buf, uint32_t len) { + while (len > 0) { + ssize_t rv = ::write(fd_, buf, len); + + if (rv < 0) { + int errno_copy = errno; + throw TTransportException(TTransportException::UNKNOWN, + "TFDTransport::write()", + errno_copy); + } else if (rv == 0) { + throw TTransportException(TTransportException::END_OF_FILE, + "TFDTransport::write()"); + } + + buf += rv; + len -= rv; + } +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TFDTransport.h b/lib/cpp/src/transport/TFDTransport.h new file mode 100644 index 00000000..bda5d82a --- /dev/null +++ b/lib/cpp/src/transport/TFDTransport.h @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TFDTRANSPORT_H_ +#define _THRIFT_TRANSPORT_TFDTRANSPORT_H_ 1 + +#include +#include + +#include "TTransport.h" +#include "TServerSocket.h" + +namespace apache { namespace thrift { namespace transport { + +/** + * Dead-simple wrapper around a file descriptor. + * + */ +class TFDTransport : public TTransport { + public: + enum ClosePolicy + { NO_CLOSE_ON_DESTROY = 0 + , CLOSE_ON_DESTROY = 1 + }; + + TFDTransport(int fd, ClosePolicy close_policy = NO_CLOSE_ON_DESTROY) + : fd_(fd) + , close_policy_(close_policy) + {} + + ~TFDTransport() { + if (close_policy_ == CLOSE_ON_DESTROY) { + close(); + } + } + + bool isOpen() { return fd_ >= 0; } + + void open() {} + + void close(); + + uint32_t read(uint8_t* buf, uint32_t len); + + void write(const uint8_t* buf, uint32_t len); + + void setFD(int fd) { fd_ = fd; } + int getFD() { return fd_; } + + protected: + int fd_; + ClosePolicy close_policy_; +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TFDTRANSPORT_H_ diff --git a/lib/cpp/src/transport/TFileTransport.cpp b/lib/cpp/src/transport/TFileTransport.cpp new file mode 100644 index 00000000..f67b9e35 --- /dev/null +++ b/lib/cpp/src/transport/TFileTransport.cpp @@ -0,0 +1,953 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include "TFileTransport.h" +#include "TTransportUtils.h" + +#include +#ifdef HAVE_SYS_TIME_H +#include +#else +#include +#endif +#include +#include +#include +#ifdef HAVE_STRINGS_H +#include +#endif +#include +#include +#include +#include + +namespace apache { namespace thrift { namespace transport { + +using boost::shared_ptr; +using namespace std; +using namespace apache::thrift::protocol; + +#ifndef HAVE_CLOCK_GETTIME + +/** + * Fake clock_gettime for systems like darwin + * + */ +#define CLOCK_REALTIME 0 +static int clock_gettime(int clk_id /*ignored*/, struct timespec *tp) { + struct timeval now; + + int rv = gettimeofday(&now, NULL); + if (rv != 0) { + return rv; + } + + tp->tv_sec = now.tv_sec; + tp->tv_nsec = now.tv_usec * 1000; + return 0; +} +#endif + +TFileTransport::TFileTransport(string path, bool readOnly) + : readState_() + , readBuff_(NULL) + , currentEvent_(NULL) + , readBuffSize_(DEFAULT_READ_BUFF_SIZE) + , readTimeout_(NO_TAIL_READ_TIMEOUT) + , chunkSize_(DEFAULT_CHUNK_SIZE) + , eventBufferSize_(DEFAULT_EVENT_BUFFER_SIZE) + , flushMaxUs_(DEFAULT_FLUSH_MAX_US) + , flushMaxBytes_(DEFAULT_FLUSH_MAX_BYTES) + , maxEventSize_(DEFAULT_MAX_EVENT_SIZE) + , maxCorruptedEvents_(DEFAULT_MAX_CORRUPTED_EVENTS) + , eofSleepTime_(DEFAULT_EOF_SLEEP_TIME_US) + , corruptedEventSleepTime_(DEFAULT_CORRUPTED_SLEEP_TIME_US) + , writerThreadId_(0) + , dequeueBuffer_(NULL) + , enqueueBuffer_(NULL) + , closing_(false) + , forceFlush_(false) + , filename_(path) + , fd_(0) + , bufferAndThreadInitialized_(false) + , offset_(0) + , lastBadChunk_(0) + , numCorruptedEventsInChunk_(0) + , readOnly_(readOnly) +{ + // initialize all the condition vars/mutexes + pthread_mutex_init(&mutex_, NULL); + pthread_cond_init(¬Full_, NULL); + pthread_cond_init(¬Empty_, NULL); + pthread_cond_init(&flushed_, NULL); + + openLogFile(); +} + +void TFileTransport::resetOutputFile(int fd, string filename, int64_t offset) { + filename_ = filename; + offset_ = offset; + + // check if current file is still open + if (fd_ > 0) { + // flush any events in the queue + flush(); + GlobalOutput.printf("error, current file (%s) not closed", filename_.c_str()); + if (-1 == ::close(fd_)) { + int errno_copy = errno; + GlobalOutput.perror("TFileTransport: resetOutputFile() ::close() ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "TFileTransport: error in file close", errno_copy); + } + } + + if (fd) { + fd_ = fd; + } else { + // open file if the input fd is 0 + openLogFile(); + } +} + + +TFileTransport::~TFileTransport() { + // flush the buffer if a writer thread is active + if (writerThreadId_ > 0) { + // reduce the flush timeout so that closing is quicker + setFlushMaxUs(300*1000); + + // flush output buffer + flush(); + + // set state to closing + closing_ = true; + + // TODO: make sure event queue is empty + // currently only the write buffer is flushed + // we dont actually wait until the queue is empty. This shouldn't be a big + // deal in the common case because writing is quick + + pthread_join(writerThreadId_, NULL); + writerThreadId_ = 0; + } + + if (dequeueBuffer_) { + delete dequeueBuffer_; + dequeueBuffer_ = NULL; + } + + if (enqueueBuffer_) { + delete enqueueBuffer_; + enqueueBuffer_ = NULL; + } + + if (readBuff_) { + delete[] readBuff_; + readBuff_ = NULL; + } + + if (currentEvent_) { + delete currentEvent_; + currentEvent_ = NULL; + } + + // close logfile + if (fd_ > 0) { + if(-1 == ::close(fd_)) { + GlobalOutput.perror("TFileTransport: ~TFileTransport() ::close() ", errno); + } + } +} + +bool TFileTransport::initBufferAndWriteThread() { + if (bufferAndThreadInitialized_) { + T_ERROR("Trying to double-init TFileTransport"); + return false; + } + + if (writerThreadId_ == 0) { + if (pthread_create(&writerThreadId_, NULL, startWriterThread, (void *)this) != 0) { + T_ERROR("Could not create writer thread"); + return false; + } + } + + dequeueBuffer_ = new TFileTransportBuffer(eventBufferSize_); + enqueueBuffer_ = new TFileTransportBuffer(eventBufferSize_); + bufferAndThreadInitialized_ = true; + + return true; +} + +void TFileTransport::write(const uint8_t* buf, uint32_t len) { + if (readOnly_) { + throw TTransportException("TFileTransport: attempting to write to file opened readonly"); + } + + enqueueEvent(buf, len, false); +} + +void TFileTransport::enqueueEvent(const uint8_t* buf, uint32_t eventLen, bool blockUntilFlush) { + // can't enqueue more events if file is going to close + if (closing_) { + return; + } + + // make sure that event size is valid + if ( (maxEventSize_ > 0) && (eventLen > maxEventSize_) ) { + T_ERROR("msg size is greater than max event size: %u > %u\n", eventLen, maxEventSize_); + return; + } + + if (eventLen == 0) { + T_ERROR("cannot enqueue an empty event"); + return; + } + + eventInfo* toEnqueue = new eventInfo(); + toEnqueue->eventBuff_ = (uint8_t *)std::malloc((sizeof(uint8_t) * eventLen) + 4); + // first 4 bytes is the event length + memcpy(toEnqueue->eventBuff_, (void*)(&eventLen), 4); + // actual event contents + memcpy(toEnqueue->eventBuff_ + 4, buf, eventLen); + toEnqueue->eventSize_ = eventLen + 4; + + // lock mutex + pthread_mutex_lock(&mutex_); + + // make sure that enqueue buffer is initialized and writer thread is running + if (!bufferAndThreadInitialized_) { + if (!initBufferAndWriteThread()) { + delete toEnqueue; + pthread_mutex_unlock(&mutex_); + return; + } + } + + // Can't enqueue while buffer is full + while (enqueueBuffer_->isFull()) { + pthread_cond_wait(¬Full_, &mutex_); + } + + // add to the buffer + if (!enqueueBuffer_->addEvent(toEnqueue)) { + delete toEnqueue; + pthread_mutex_unlock(&mutex_); + return; + } + + // signal anybody who's waiting for the buffer to be non-empty + pthread_cond_signal(¬Empty_); + + if (blockUntilFlush) { + pthread_cond_wait(&flushed_, &mutex_); + } + + // this really should be a loop where it makes sure it got flushed + // because condition variables can get triggered by the os for no reason + // it is probably a non-factor for the time being + pthread_mutex_unlock(&mutex_); +} + +bool TFileTransport::swapEventBuffers(struct timespec* deadline) { + pthread_mutex_lock(&mutex_); + if (deadline != NULL) { + // if we were handed a deadline time struct, do a timed wait + pthread_cond_timedwait(¬Empty_, &mutex_, deadline); + } else { + // just wait until the buffer gets an item + pthread_cond_wait(¬Empty_, &mutex_); + } + + bool swapped = false; + + // could be empty if we timed out + if (!enqueueBuffer_->isEmpty()) { + TFileTransportBuffer *temp = enqueueBuffer_; + enqueueBuffer_ = dequeueBuffer_; + dequeueBuffer_ = temp; + + swapped = true; + } + + // unlock the mutex and signal if required + pthread_mutex_unlock(&mutex_); + + if (swapped) { + pthread_cond_signal(¬Full_); + } + + return swapped; +} + + +void TFileTransport::writerThread() { + // open file if it is not open + if(!fd_) { + openLogFile(); + } + + // set the offset to the correct value (EOF) + try { + seekToEnd(); + } catch (TException &te) { + } + + // throw away any partial events + offset_ += readState_.lastDispatchPtr_; + ftruncate(fd_, offset_); + readState_.resetAllValues(); + + // Figure out the next time by which a flush must take place + + struct timespec ts_next_flush; + getNextFlushTime(&ts_next_flush); + uint32_t unflushed = 0; + + while(1) { + // this will only be true when the destructor is being invoked + if(closing_) { + // empty out both the buffers + if (enqueueBuffer_->isEmpty() && dequeueBuffer_->isEmpty()) { + if (-1 == ::close(fd_)) { + int errno_copy = errno; + GlobalOutput.perror("TFileTransport: writerThread() ::close() ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "TFileTransport: error in file close", errno_copy); + } + // just be safe and sync to disk + fsync(fd_); + fd_ = 0; + pthread_exit(NULL); + return; + } + } + + if (swapEventBuffers(&ts_next_flush)) { + eventInfo* outEvent; + while (NULL != (outEvent = dequeueBuffer_->getNext())) { + if (!outEvent) { + T_DEBUG_L(1, "Got an empty event"); + return; + } + + // sanity check on event + if ((maxEventSize_ > 0) && (outEvent->eventSize_ > maxEventSize_)) { + T_ERROR("msg size is greater than max event size: %u > %u\n", outEvent->eventSize_, maxEventSize_); + continue; + } + + // If chunking is required, then make sure that msg does not cross chunk boundary + if ((outEvent->eventSize_ > 0) && (chunkSize_ != 0)) { + + // event size must be less than chunk size + if(outEvent->eventSize_ > chunkSize_) { + T_ERROR("TFileTransport: event size(%u) is greater than chunk size(%u): skipping event", + outEvent->eventSize_, chunkSize_); + continue; + } + + int64_t chunk1 = offset_/chunkSize_; + int64_t chunk2 = (offset_ + outEvent->eventSize_ - 1)/chunkSize_; + + // if adding this event will cross a chunk boundary, pad the chunk with zeros + if (chunk1 != chunk2) { + // refetch the offset to keep in sync + offset_ = lseek(fd_, 0, SEEK_CUR); + int32_t padding = (int32_t)((offset_/chunkSize_ + 1)*chunkSize_ - offset_); + + uint8_t zeros[padding]; + bzero(zeros, padding); + if (-1 == ::write(fd_, zeros, padding)) { + int errno_copy = errno; + GlobalOutput.perror("TFileTransport: writerThread() error while padding zeros ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "TFileTransport: error while padding zeros", errno_copy); + } + unflushed += padding; + offset_ += padding; + } + } + + // write the dequeued event to the file + if (outEvent->eventSize_ > 0) { + if (-1 == ::write(fd_, outEvent->eventBuff_, outEvent->eventSize_)) { + int errno_copy = errno; + GlobalOutput.perror("TFileTransport: error while writing event ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "TFileTransport: error while writing event", errno_copy); + } + + unflushed += outEvent->eventSize_; + offset_ += outEvent->eventSize_; + } + } + dequeueBuffer_->reset(); + } + + bool flushTimeElapsed = false; + struct timespec current_time; + clock_gettime(CLOCK_REALTIME, ¤t_time); + + if (current_time.tv_sec > ts_next_flush.tv_sec || + (current_time.tv_sec == ts_next_flush.tv_sec && current_time.tv_nsec > ts_next_flush.tv_nsec)) { + flushTimeElapsed = true; + getNextFlushTime(&ts_next_flush); + } + + // couple of cases from which a flush could be triggered + if ((flushTimeElapsed && unflushed > 0) || + unflushed > flushMaxBytes_ || + forceFlush_) { + + // sync (force flush) file to disk + fsync(fd_); + unflushed = 0; + + // notify anybody waiting for flush completion + forceFlush_ = false; + pthread_cond_broadcast(&flushed_); + } + } +} + +void TFileTransport::flush() { + // file must be open for writing for any flushing to take place + if (writerThreadId_ <= 0) { + return; + } + // wait for flush to take place + pthread_mutex_lock(&mutex_); + + forceFlush_ = true; + + while (forceFlush_) { + pthread_cond_wait(&flushed_, &mutex_); + } + + pthread_mutex_unlock(&mutex_); +} + + +uint32_t TFileTransport::readAll(uint8_t* buf, uint32_t len) { + uint32_t have = 0; + uint32_t get = 0; + + while (have < len) { + get = read(buf+have, len-have); + if (get <= 0) { + throw TEOFException(); + } + have += get; + } + + return have; +} + +uint32_t TFileTransport::read(uint8_t* buf, uint32_t len) { + // check if there an event is ready to be read + if (!currentEvent_) { + currentEvent_ = readEvent(); + } + + // did not manage to read an event from the file. This could have happened + // if the timeout expired or there was some other error + if (!currentEvent_) { + return 0; + } + + // read as much of the current event as possible + int32_t remaining = currentEvent_->eventSize_ - currentEvent_->eventBuffPos_; + if (remaining <= (int32_t)len) { + // copy over anything thats remaining + if (remaining > 0) { + memcpy(buf, + currentEvent_->eventBuff_ + currentEvent_->eventBuffPos_, + remaining); + } + delete(currentEvent_); + currentEvent_ = NULL; + return remaining; + } + + // read as much as possible + memcpy(buf, currentEvent_->eventBuff_ + currentEvent_->eventBuffPos_, len); + currentEvent_->eventBuffPos_ += len; + return len; +} + +eventInfo* TFileTransport::readEvent() { + int readTries = 0; + + if (!readBuff_) { + readBuff_ = new uint8_t[readBuffSize_]; + } + + while (1) { + // read from the file if read buffer is exhausted + if (readState_.bufferPtr_ == readState_.bufferLen_) { + // advance the offset pointer + offset_ += readState_.bufferLen_; + readState_.bufferLen_ = ::read(fd_, readBuff_, readBuffSize_); + // if (readState_.bufferLen_) { + // T_DEBUG_L(1, "Amount read: %u (offset: %lu)", readState_.bufferLen_, offset_); + // } + readState_.bufferPtr_ = 0; + readState_.lastDispatchPtr_ = 0; + + // read error + if (readState_.bufferLen_ == -1) { + readState_.resetAllValues(); + GlobalOutput("TFileTransport: error while reading from file"); + throw TTransportException("TFileTransport: error while reading from file"); + } else if (readState_.bufferLen_ == 0) { // EOF + // wait indefinitely if there is no timeout + if (readTimeout_ == TAIL_READ_TIMEOUT) { + usleep(eofSleepTime_); + continue; + } else if (readTimeout_ == NO_TAIL_READ_TIMEOUT) { + // reset state + readState_.resetState(0); + return NULL; + } else if (readTimeout_ > 0) { + // timeout already expired once + if (readTries > 0) { + readState_.resetState(0); + return NULL; + } else { + usleep(readTimeout_ * 1000); + readTries++; + continue; + } + } + } + } + + readTries = 0; + + // attempt to read an event from the buffer + while(readState_.bufferPtr_ < readState_.bufferLen_) { + if (readState_.readingSize_) { + if(readState_.eventSizeBuffPos_ == 0) { + if ( (offset_ + readState_.bufferPtr_)/chunkSize_ != + ((offset_ + readState_.bufferPtr_ + 3)/chunkSize_)) { + // skip one byte towards chunk boundary + // T_DEBUG_L(1, "Skipping a byte"); + readState_.bufferPtr_++; + continue; + } + } + + readState_.eventSizeBuff_[readState_.eventSizeBuffPos_++] = + readBuff_[readState_.bufferPtr_++]; + if (readState_.eventSizeBuffPos_ == 4) { + // 0 length event indicates padding + if (*((uint32_t *)(readState_.eventSizeBuff_)) == 0) { + // T_DEBUG_L(1, "Got padding"); + readState_.resetState(readState_.lastDispatchPtr_); + continue; + } + // got a valid event + readState_.readingSize_ = false; + if (readState_.event_) { + delete(readState_.event_); + } + readState_.event_ = new eventInfo(); + readState_.event_->eventSize_ = *((uint32_t *)(readState_.eventSizeBuff_)); + + // check if the event is corrupted and perform recovery if required + if (isEventCorrupted()) { + performRecovery(); + // start from the top + break; + } + } + } else { + if (!readState_.event_->eventBuff_) { + readState_.event_->eventBuff_ = new uint8_t[readState_.event_->eventSize_]; + readState_.event_->eventBuffPos_ = 0; + } + // take either the entire event or the remaining bytes in the buffer + int reclaimBuffer = min((uint32_t)(readState_.bufferLen_ - readState_.bufferPtr_), + readState_.event_->eventSize_ - readState_.event_->eventBuffPos_); + + // copy data from read buffer into event buffer + memcpy(readState_.event_->eventBuff_ + readState_.event_->eventBuffPos_, + readBuff_ + readState_.bufferPtr_, + reclaimBuffer); + + // increment position ptrs + readState_.event_->eventBuffPos_ += reclaimBuffer; + readState_.bufferPtr_ += reclaimBuffer; + + // check if the event has been read in full + if (readState_.event_->eventBuffPos_ == readState_.event_->eventSize_) { + // set the completed event to the current event + eventInfo* completeEvent = readState_.event_; + completeEvent->eventBuffPos_ = 0; + + readState_.event_ = NULL; + readState_.resetState(readState_.bufferPtr_); + + // exit criteria + return completeEvent; + } + } + } + + } +} + +bool TFileTransport::isEventCorrupted() { + // an error is triggered if: + if ( (maxEventSize_ > 0) && (readState_.event_->eventSize_ > maxEventSize_)) { + // 1. Event size is larger than user-speficied max-event size + T_ERROR("Read corrupt event. Event size(%u) greater than max event size (%u)", + readState_.event_->eventSize_, maxEventSize_); + return true; + } else if (readState_.event_->eventSize_ > chunkSize_) { + // 2. Event size is larger than chunk size + T_ERROR("Read corrupt event. Event size(%u) greater than chunk size (%u)", + readState_.event_->eventSize_, chunkSize_); + return true; + } else if( ((offset_ + readState_.bufferPtr_ - 4)/chunkSize_) != + ((offset_ + readState_.bufferPtr_ + readState_.event_->eventSize_ - 1)/chunkSize_) ) { + // 3. size indicates that event crosses chunk boundary + T_ERROR("Read corrupt event. Event crosses chunk boundary. Event size:%u Offset:%ld", + readState_.event_->eventSize_, offset_ + readState_.bufferPtr_ + 4); + return true; + } + + return false; +} + +void TFileTransport::performRecovery() { + // perform some kickass recovery + uint32_t curChunk = getCurChunk(); + if (lastBadChunk_ == curChunk) { + numCorruptedEventsInChunk_++; + } else { + lastBadChunk_ = curChunk; + numCorruptedEventsInChunk_ = 1; + } + + if (numCorruptedEventsInChunk_ < maxCorruptedEvents_) { + // maybe there was an error in reading the file from disk + // seek to the beginning of chunk and try again + seekToChunk(curChunk); + } else { + + // just skip ahead to the next chunk if we not already at the last chunk + if (curChunk != (getNumChunks() - 1)) { + seekToChunk(curChunk + 1); + } else if (readTimeout_ == TAIL_READ_TIMEOUT) { + // if tailing the file, wait until there is enough data to start + // the next chunk + while(curChunk == (getNumChunks() - 1)) { + usleep(DEFAULT_CORRUPTED_SLEEP_TIME_US); + } + seekToChunk(curChunk + 1); + } else { + // pretty hosed at this stage, rewind the file back to the last successful + // point and punt on the error + readState_.resetState(readState_.lastDispatchPtr_); + currentEvent_ = NULL; + char errorMsg[1024]; + sprintf(errorMsg, "TFileTransport: log file corrupted at offset: %lu", + offset_ + readState_.lastDispatchPtr_); + GlobalOutput(errorMsg); + throw TTransportException(errorMsg); + } + } + +} + +void TFileTransport::seekToChunk(int32_t chunk) { + if (fd_ <= 0) { + throw TTransportException("File not open"); + } + + int32_t numChunks = getNumChunks(); + + // file is empty, seeking to chunk is pointless + if (numChunks == 0) { + return; + } + + // negative indicates reverse seek (from the end) + if (chunk < 0) { + chunk += numChunks; + } + + // too large a value for reverse seek, just seek to beginning + if (chunk < 0) { + T_DEBUG("Incorrect value for reverse seek. Seeking to beginning...", chunk) + chunk = 0; + } + + // cannot seek past EOF + bool seekToEnd = false; + uint32_t minEndOffset = 0; + if (chunk >= numChunks) { + T_DEBUG("Trying to seek past EOF. Seeking to EOF instead..."); + seekToEnd = true; + chunk = numChunks - 1; + // this is the min offset to process events till + minEndOffset = lseek(fd_, 0, SEEK_END); + } + + off_t newOffset = off_t(chunk) * chunkSize_; + offset_ = lseek(fd_, newOffset, SEEK_SET); + readState_.resetAllValues(); + currentEvent_ = NULL; + if (offset_ == -1) { + GlobalOutput("TFileTransport: lseek error in seekToChunk"); + throw TTransportException("TFileTransport: lseek error in seekToChunk"); + } + + // seek to EOF if user wanted to go to last chunk + if (seekToEnd) { + uint32_t oldReadTimeout = getReadTimeout(); + setReadTimeout(NO_TAIL_READ_TIMEOUT); + // keep on reading unti the last event at point of seekChunk call + while (readEvent() && ((offset_ + readState_.bufferPtr_) < minEndOffset)) {}; + setReadTimeout(oldReadTimeout); + } + +} + +void TFileTransport::seekToEnd() { + seekToChunk(getNumChunks()); +} + +uint32_t TFileTransport::getNumChunks() { + if (fd_ <= 0) { + return 0; + } + + struct stat f_info; + int rv = fstat(fd_, &f_info); + + if (rv < 0) { + int errno_copy = errno; + throw TTransportException(TTransportException::UNKNOWN, + "TFileTransport::getNumChunks() (fstat)", + errno_copy); + } + + if (f_info.st_size > 0) { + return ((f_info.st_size)/chunkSize_) + 1; + } + + // empty file has no chunks + return 0; +} + +uint32_t TFileTransport::getCurChunk() { + return offset_/chunkSize_; +} + +// Utility Functions +void TFileTransport::openLogFile() { + mode_t mode = readOnly_ ? S_IRUSR | S_IRGRP | S_IROTH : S_IRUSR | S_IWUSR| S_IRGRP | S_IROTH; + int flags = readOnly_ ? O_RDONLY : O_RDWR | O_CREAT | O_APPEND; + fd_ = ::open(filename_.c_str(), flags, mode); + offset_ = 0; + + // make sure open call was successful + if(fd_ == -1) { + int errno_copy = errno; + GlobalOutput.perror("TFileTransport: openLogFile() ::open() file: " + filename_, errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, filename_, errno_copy); + } + +} + +void TFileTransport::getNextFlushTime(struct timespec* ts_next_flush) { + clock_gettime(CLOCK_REALTIME, ts_next_flush); + ts_next_flush->tv_nsec += (flushMaxUs_ % 1000000) * 1000; + if (ts_next_flush->tv_nsec > 1000000000) { + ts_next_flush->tv_nsec -= 1000000000; + ts_next_flush->tv_sec += 1; + } + ts_next_flush->tv_sec += flushMaxUs_ / 1000000; +} + +TFileTransportBuffer::TFileTransportBuffer(uint32_t size) + : bufferMode_(WRITE) + , writePoint_(0) + , readPoint_(0) + , size_(size) +{ + buffer_ = new eventInfo*[size]; +} + +TFileTransportBuffer::~TFileTransportBuffer() { + if (buffer_) { + for (uint32_t i = 0; i < writePoint_; i++) { + delete buffer_[i]; + } + delete[] buffer_; + buffer_ = NULL; + } +} + +bool TFileTransportBuffer::addEvent(eventInfo *event) { + if (bufferMode_ == READ) { + GlobalOutput("Trying to write to a buffer in read mode"); + } + if (writePoint_ < size_) { + buffer_[writePoint_++] = event; + return true; + } else { + // buffer is full + return false; + } +} + +eventInfo* TFileTransportBuffer::getNext() { + if (bufferMode_ == WRITE) { + bufferMode_ = READ; + } + if (readPoint_ < writePoint_) { + return buffer_[readPoint_++]; + } else { + // no more entries + return NULL; + } +} + +void TFileTransportBuffer::reset() { + if (bufferMode_ == WRITE || writePoint_ > readPoint_) { + T_DEBUG("Resetting a buffer with unread entries"); + } + // Clean up the old entries + for (uint32_t i = 0; i < writePoint_; i++) { + delete buffer_[i]; + } + bufferMode_ = WRITE; + writePoint_ = 0; + readPoint_ = 0; +} + +bool TFileTransportBuffer::isFull() { + return writePoint_ == size_; +} + +bool TFileTransportBuffer::isEmpty() { + return writePoint_ == 0; +} + +TFileProcessor::TFileProcessor(shared_ptr processor, + shared_ptr protocolFactory, + shared_ptr inputTransport): + processor_(processor), + inputProtocolFactory_(protocolFactory), + outputProtocolFactory_(protocolFactory), + inputTransport_(inputTransport) { + + // default the output transport to a null transport (common case) + outputTransport_ = shared_ptr(new TNullTransport()); +} + +TFileProcessor::TFileProcessor(shared_ptr processor, + shared_ptr inputProtocolFactory, + shared_ptr outputProtocolFactory, + shared_ptr inputTransport): + processor_(processor), + inputProtocolFactory_(inputProtocolFactory), + outputProtocolFactory_(outputProtocolFactory), + inputTransport_(inputTransport) { + + // default the output transport to a null transport (common case) + outputTransport_ = shared_ptr(new TNullTransport()); +} + +TFileProcessor::TFileProcessor(shared_ptr processor, + shared_ptr protocolFactory, + shared_ptr inputTransport, + shared_ptr outputTransport): + processor_(processor), + inputProtocolFactory_(protocolFactory), + outputProtocolFactory_(protocolFactory), + inputTransport_(inputTransport), + outputTransport_(outputTransport) {}; + +void TFileProcessor::process(uint32_t numEvents, bool tail) { + shared_ptr inputProtocol = inputProtocolFactory_->getProtocol(inputTransport_); + shared_ptr outputProtocol = outputProtocolFactory_->getProtocol(outputTransport_); + + // set the read timeout to 0 if tailing is required + int32_t oldReadTimeout = inputTransport_->getReadTimeout(); + if (tail) { + // save old read timeout so it can be restored + inputTransport_->setReadTimeout(TFileTransport::TAIL_READ_TIMEOUT); + } + + uint32_t numProcessed = 0; + while(1) { + // bad form to use exceptions for flow control but there is really + // no other way around it + try { + processor_->process(inputProtocol, outputProtocol); + numProcessed++; + if ( (numEvents > 0) && (numProcessed == numEvents)) { + return; + } + } catch (TEOFException& teof) { + if (!tail) { + break; + } + } catch (TException &te) { + cerr << te.what() << endl; + break; + } + } + + // restore old read timeout + if (tail) { + inputTransport_->setReadTimeout(oldReadTimeout); + } + +} + +void TFileProcessor::processChunk() { + shared_ptr inputProtocol = inputProtocolFactory_->getProtocol(inputTransport_); + shared_ptr outputProtocol = outputProtocolFactory_->getProtocol(outputTransport_); + + uint32_t curChunk = inputTransport_->getCurChunk(); + + while(1) { + // bad form to use exceptions for flow control but there is really + // no other way around it + try { + processor_->process(inputProtocol, outputProtocol); + if (curChunk != inputTransport_->getCurChunk()) { + break; + } + } catch (TEOFException& teof) { + break; + } catch (TException &te) { + cerr << te.what() << endl; + break; + } + } +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TFileTransport.h b/lib/cpp/src/transport/TFileTransport.h new file mode 100644 index 00000000..fbaf2cd0 --- /dev/null +++ b/lib/cpp/src/transport/TFileTransport.h @@ -0,0 +1,440 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TFILETRANSPORT_H_ +#define _THRIFT_TRANSPORT_TFILETRANSPORT_H_ 1 + +#include "TTransport.h" +#include "Thrift.h" +#include "TProcessor.h" + +#include +#include + +#include + +namespace apache { namespace thrift { namespace transport { + +using apache::thrift::TProcessor; +using apache::thrift::protocol::TProtocolFactory; + +// Data pertaining to a single event +typedef struct eventInfo { + uint8_t* eventBuff_; + uint32_t eventSize_; + uint32_t eventBuffPos_; + + eventInfo():eventBuff_(NULL), eventSize_(0), eventBuffPos_(0){}; + ~eventInfo() { + if (eventBuff_) { + delete[] eventBuff_; + } + } +} eventInfo; + +// information about current read state +typedef struct readState { + eventInfo* event_; + + // keep track of event size + uint8_t eventSizeBuff_[4]; + uint8_t eventSizeBuffPos_; + bool readingSize_; + + // read buffer variables + int32_t bufferPtr_; + int32_t bufferLen_; + + // last successful dispatch point + int32_t lastDispatchPtr_; + + void resetState(uint32_t lastDispatchPtr) { + readingSize_ = true; + eventSizeBuffPos_ = 0; + lastDispatchPtr_ = lastDispatchPtr; + } + + void resetAllValues() { + resetState(0); + bufferPtr_ = 0; + bufferLen_ = 0; + if (event_) { + delete(event_); + } + event_ = 0; + } + + readState() { + event_ = 0; + resetAllValues(); + } + + ~readState() { + if (event_) { + delete(event_); + } + } + +} readState; + +/** + * TFileTransportBuffer - buffer class used by TFileTransport for queueing up events + * to be written to disk. Should be used in the following way: + * 1) Buffer created + * 2) Buffer written to (addEvent) + * 3) Buffer read from (getNext) + * 4) Buffer reset (reset) + * 5) Go back to 2, or destroy buffer + * + * The buffer should never be written to after it is read from, unless it is reset first. + * Note: The above rules are enforced mainly for debugging its sole client TFileTransport + * which uses the buffer in this way. + * + */ +class TFileTransportBuffer { + public: + TFileTransportBuffer(uint32_t size); + ~TFileTransportBuffer(); + + bool addEvent(eventInfo *event); + eventInfo* getNext(); + void reset(); + bool isFull(); + bool isEmpty(); + + private: + TFileTransportBuffer(); // should not be used + + enum mode { + WRITE, + READ + }; + mode bufferMode_; + + uint32_t writePoint_; + uint32_t readPoint_; + uint32_t size_; + eventInfo** buffer_; +}; + +/** + * Abstract interface for transports used to read files + */ +class TFileReaderTransport : virtual public TTransport { + public: + virtual int32_t getReadTimeout() = 0; + virtual void setReadTimeout(int32_t readTimeout) = 0; + + virtual uint32_t getNumChunks() = 0; + virtual uint32_t getCurChunk() = 0; + virtual void seekToChunk(int32_t chunk) = 0; + virtual void seekToEnd() = 0; +}; + +/** + * Abstract interface for transports used to write files + */ +class TFileWriterTransport : virtual public TTransport { + public: + virtual uint32_t getChunkSize() = 0; + virtual void setChunkSize(uint32_t chunkSize) = 0; +}; + +/** + * File implementation of a transport. Reads and writes are done to a + * file on disk. + * + */ +class TFileTransport : public TFileReaderTransport, + public TFileWriterTransport { + public: + TFileTransport(std::string path, bool readOnly=false); + ~TFileTransport(); + + // TODO: what is the correct behaviour for this? + // the log file is generally always open + bool isOpen() { + return true; + } + + void write(const uint8_t* buf, uint32_t len); + void flush(); + + uint32_t readAll(uint8_t* buf, uint32_t len); + uint32_t read(uint8_t* buf, uint32_t len); + + // log-file specific functions + void seekToChunk(int32_t chunk); + void seekToEnd(); + uint32_t getNumChunks(); + uint32_t getCurChunk(); + + // for changing the output file + void resetOutputFile(int fd, std::string filename, int64_t offset); + + // Setter/Getter functions for user-controllable options + void setReadBuffSize(uint32_t readBuffSize) { + if (readBuffSize) { + readBuffSize_ = readBuffSize; + } + } + uint32_t getReadBuffSize() { + return readBuffSize_; + } + + static const int32_t TAIL_READ_TIMEOUT = -1; + static const int32_t NO_TAIL_READ_TIMEOUT = 0; + void setReadTimeout(int32_t readTimeout) { + readTimeout_ = readTimeout; + } + int32_t getReadTimeout() { + return readTimeout_; + } + + void setChunkSize(uint32_t chunkSize) { + if (chunkSize) { + chunkSize_ = chunkSize; + } + } + uint32_t getChunkSize() { + return chunkSize_; + } + + void setEventBufferSize(uint32_t bufferSize) { + if (bufferAndThreadInitialized_) { + GlobalOutput("Cannot change the buffer size after writer thread started"); + return; + } + eventBufferSize_ = bufferSize; + } + + uint32_t getEventBufferSize() { + return eventBufferSize_; + } + + void setFlushMaxUs(uint32_t flushMaxUs) { + if (flushMaxUs) { + flushMaxUs_ = flushMaxUs; + } + } + uint32_t getFlushMaxUs() { + return flushMaxUs_; + } + + void setFlushMaxBytes(uint32_t flushMaxBytes) { + if (flushMaxBytes) { + flushMaxBytes_ = flushMaxBytes; + } + } + uint32_t getFlushMaxBytes() { + return flushMaxBytes_; + } + + void setMaxEventSize(uint32_t maxEventSize) { + maxEventSize_ = maxEventSize; + } + uint32_t getMaxEventSize() { + return maxEventSize_; + } + + void setMaxCorruptedEvents(uint32_t maxCorruptedEvents) { + maxCorruptedEvents_ = maxCorruptedEvents; + } + uint32_t getMaxCorruptedEvents() { + return maxCorruptedEvents_; + } + + void setEofSleepTimeUs(uint32_t eofSleepTime) { + if (eofSleepTime) { + eofSleepTime_ = eofSleepTime; + } + } + uint32_t getEofSleepTimeUs() { + return eofSleepTime_; + } + + private: + // helper functions for writing to a file + void enqueueEvent(const uint8_t* buf, uint32_t eventLen, bool blockUntilFlush); + bool swapEventBuffers(struct timespec* deadline); + bool initBufferAndWriteThread(); + + // control for writer thread + static void* startWriterThread(void* ptr) { + (((TFileTransport*)ptr)->writerThread()); + return 0; + } + void writerThread(); + + // helper functions for reading from a file + eventInfo* readEvent(); + + // event corruption-related functions + bool isEventCorrupted(); + void performRecovery(); + + // Utility functions + void openLogFile(); + void getNextFlushTime(struct timespec* ts_next_flush); + + // Class variables + readState readState_; + uint8_t* readBuff_; + eventInfo* currentEvent_; + + uint32_t readBuffSize_; + static const uint32_t DEFAULT_READ_BUFF_SIZE = 1 * 1024 * 1024; + + int32_t readTimeout_; + static const int32_t DEFAULT_READ_TIMEOUT_MS = 200; + + // size of chunks that file will be split up into + uint32_t chunkSize_; + static const uint32_t DEFAULT_CHUNK_SIZE = 16 * 1024 * 1024; + + // size of event buffers + uint32_t eventBufferSize_; + static const uint32_t DEFAULT_EVENT_BUFFER_SIZE = 10000; + + // max number of microseconds that can pass without flushing + uint32_t flushMaxUs_; + static const uint32_t DEFAULT_FLUSH_MAX_US = 3000000; + + // max number of bytes that can be written without flushing + uint32_t flushMaxBytes_; + static const uint32_t DEFAULT_FLUSH_MAX_BYTES = 1000 * 1024; + + // max event size + uint32_t maxEventSize_; + static const uint32_t DEFAULT_MAX_EVENT_SIZE = 0; + + // max number of corrupted events per chunk + uint32_t maxCorruptedEvents_; + static const uint32_t DEFAULT_MAX_CORRUPTED_EVENTS = 0; + + // sleep duration when EOF is hit + uint32_t eofSleepTime_; + static const uint32_t DEFAULT_EOF_SLEEP_TIME_US = 500 * 1000; + + // sleep duration when a corrupted event is encountered + uint32_t corruptedEventSleepTime_; + static const uint32_t DEFAULT_CORRUPTED_SLEEP_TIME_US = 1 * 1000 * 1000; + + // writer thread id + pthread_t writerThreadId_; + + // buffers to hold data before it is flushed. Each element of the buffer stores a msg that + // needs to be written to the file. The buffers are swapped by the writer thread. + TFileTransportBuffer *dequeueBuffer_; + TFileTransportBuffer *enqueueBuffer_; + + // conditions used to block when the buffer is full or empty + pthread_cond_t notFull_, notEmpty_; + volatile bool closing_; + + // To keep track of whether the buffer has been flushed + pthread_cond_t flushed_; + volatile bool forceFlush_; + + // Mutex that is grabbed when enqueueing and swapping the read/write buffers + pthread_mutex_t mutex_; + + // File information + std::string filename_; + int fd_; + + // Whether the writer thread and buffers have been initialized + bool bufferAndThreadInitialized_; + + // Offset within the file + off_t offset_; + + // event corruption information + uint32_t lastBadChunk_; + uint32_t numCorruptedEventsInChunk_; + + bool readOnly_; +}; + +// Exception thrown when EOF is hit +class TEOFException : public TTransportException { + public: + TEOFException(): + TTransportException(TTransportException::END_OF_FILE) {}; +}; + + +// wrapper class to process events from a file containing thrift events +class TFileProcessor { + public: + /** + * Constructor that defaults output transport to null transport + * + * @param processor processes log-file events + * @param protocolFactory protocol factory + * @param inputTransport file transport + */ + TFileProcessor(boost::shared_ptr processor, + boost::shared_ptr protocolFactory, + boost::shared_ptr inputTransport); + + TFileProcessor(boost::shared_ptr processor, + boost::shared_ptr inputProtocolFactory, + boost::shared_ptr outputProtocolFactory, + boost::shared_ptr inputTransport); + + /** + * Constructor + * + * @param processor processes log-file events + * @param protocolFactory protocol factory + * @param inputTransport input file transport + * @param output output transport + */ + TFileProcessor(boost::shared_ptr processor, + boost::shared_ptr protocolFactory, + boost::shared_ptr inputTransport, + boost::shared_ptr outputTransport); + + /** + * processes events from the file + * + * @param numEvents number of events to process (0 for unlimited) + * @param tail tails the file if true + */ + void process(uint32_t numEvents, bool tail); + + /** + * process events until the end of the chunk + * + */ + void processChunk(); + + private: + boost::shared_ptr processor_; + boost::shared_ptr inputProtocolFactory_; + boost::shared_ptr outputProtocolFactory_; + boost::shared_ptr inputTransport_; + boost::shared_ptr outputTransport_; +}; + + +}}} // apache::thrift::transport + +#endif // _THRIFT_TRANSPORT_TFILETRANSPORT_H_ diff --git a/lib/cpp/src/transport/THttpClient.cpp b/lib/cpp/src/transport/THttpClient.cpp new file mode 100644 index 00000000..59f23396 --- /dev/null +++ b/lib/cpp/src/transport/THttpClient.cpp @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include "THttpClient.h" +#include "TSocket.h" + +namespace apache { namespace thrift { namespace transport { + +using namespace std; + +/** + * Http client implementation. + * + */ + +// Yeah, yeah, hacky to put these here, I know. +static const char* CRLF = "\r\n"; +static const int CRLF_LEN = 2; + +THttpClient::THttpClient(boost::shared_ptr transport, string host, string path) : + transport_(transport), + host_(host), + path_(path), + readHeaders_(true), + chunked_(false), + chunkedDone_(false), + chunkSize_(0), + contentLength_(0), + httpBuf_(NULL), + httpPos_(0), + httpBufLen_(0), + httpBufSize_(1024) { + init(); +} + +THttpClient::THttpClient(string host, int port, string path) : + host_(host), + path_(path), + readHeaders_(true), + chunked_(false), + chunkedDone_(false), + chunkSize_(0), + contentLength_(0), + httpBuf_(NULL), + httpPos_(0), + httpBufLen_(0), + httpBufSize_(1024) { + transport_ = boost::shared_ptr(new TSocket(host, port)); + init(); +} + +void THttpClient::init() { + httpBuf_ = (char*)std::malloc(httpBufSize_+1); + if (httpBuf_ == NULL) { + throw TTransportException("Out of memory."); + } + httpBuf_[httpBufLen_] = '\0'; +} + +THttpClient::~THttpClient() { + if (httpBuf_ != NULL) { + std::free(httpBuf_); + } +} + +uint32_t THttpClient::read(uint8_t* buf, uint32_t len) { + if (readBuffer_.available_read() == 0) { + readBuffer_.resetBuffer(); + uint32_t got = readMoreData(); + if (got == 0) { + return 0; + } + } + return readBuffer_.read(buf, len); +} + +void THttpClient::readEnd() { + // Read any pending chunked data (footers etc.) + if (chunked_) { + while (!chunkedDone_) { + readChunked(); + } + } +} + +uint32_t THttpClient::readMoreData() { + // Get more data! + refill(); + + if (readHeaders_) { + readHeaders(); + } + + if (chunked_) { + return readChunked(); + } else { + return readContent(contentLength_); + } +} + +uint32_t THttpClient::readChunked() { + uint32_t length = 0; + + char* line = readLine(); + uint32_t chunkSize = parseChunkSize(line); + if (chunkSize == 0) { + readChunkedFooters(); + } else { + // Read data content + length += readContent(chunkSize); + // Read trailing CRLF after content + readLine(); + } + return length; +} + +void THttpClient::readChunkedFooters() { + // End of data, read footer lines until a blank one appears + while (true) { + char* line = readLine(); + if (strlen(line) == 0) { + chunkedDone_ = true; + break; + } + } +} + +uint32_t THttpClient::parseChunkSize(char* line) { + char* semi = strchr(line, ';'); + if (semi != NULL) { + *semi = '\0'; + } + int size = 0; + sscanf(line, "%x", &size); + return (uint32_t)size; +} + +uint32_t THttpClient::readContent(uint32_t size) { + uint32_t need = size; + while (need > 0) { + uint32_t avail = httpBufLen_ - httpPos_; + if (avail == 0) { + // We have given all the data, reset position to head of the buffer + httpPos_ = 0; + httpBufLen_ = 0; + refill(); + + // Now have available however much we read + avail = httpBufLen_; + } + uint32_t give = avail; + if (need < give) { + give = need; + } + readBuffer_.write((uint8_t*)(httpBuf_+httpPos_), give); + httpPos_ += give; + need -= give; + } + return size; +} + +char* THttpClient::readLine() { + while (true) { + char* eol = NULL; + + eol = strstr(httpBuf_+httpPos_, CRLF); + + // No CRLF yet? + if (eol == NULL) { + // Shift whatever we have now to front and refill + shift(); + refill(); + } else { + // Return pointer to next line + *eol = '\0'; + char* line = httpBuf_+httpPos_; + httpPos_ = (eol-httpBuf_) + CRLF_LEN; + return line; + } + } + +} + +void THttpClient::shift() { + if (httpBufLen_ > httpPos_) { + // Shift down remaining data and read more + uint32_t length = httpBufLen_ - httpPos_; + memmove(httpBuf_, httpBuf_+httpPos_, length); + httpBufLen_ = length; + } else { + httpBufLen_ = 0; + } + httpPos_ = 0; + httpBuf_[httpBufLen_] = '\0'; +} + +void THttpClient::refill() { + uint32_t avail = httpBufSize_ - httpBufLen_; + if (avail <= (httpBufSize_ / 4)) { + httpBufSize_ *= 2; + httpBuf_ = (char*)std::realloc(httpBuf_, httpBufSize_+1); + if (httpBuf_ == NULL) { + throw TTransportException("Out of memory."); + } + } + + // Read more data + uint32_t got = transport_->read((uint8_t*)(httpBuf_+httpBufLen_), httpBufSize_-httpBufLen_); + httpBufLen_ += got; + httpBuf_[httpBufLen_] = '\0'; + + if (got == 0) { + throw TTransportException("Could not refill buffer"); + } +} + +void THttpClient::readHeaders() { + // Initialize headers state variables + contentLength_ = 0; + chunked_ = false; + chunkedDone_ = false; + chunkSize_ = 0; + + // Control state flow + bool statusLine = true; + bool finished = false; + + // Loop until headers are finished + while (true) { + char* line = readLine(); + + if (strlen(line) == 0) { + if (finished) { + readHeaders_ = false; + return; + } else { + // Must have been an HTTP 100, keep going for another status line + statusLine = true; + } + } else { + if (statusLine) { + statusLine = false; + finished = parseStatusLine(line); + } else { + parseHeader(line); + } + } + } +} + +bool THttpClient::parseStatusLine(char* status) { + char* http = status; + + char* code = strchr(http, ' '); + if (code == NULL) { + throw TTransportException(string("Bad Status: ") + status); + } + + *code = '\0'; + while (*(code++) == ' '); + + char* msg = strchr(code, ' '); + if (msg == NULL) { + throw TTransportException(string("Bad Status: ") + status); + } + *msg = '\0'; + + if (strcmp(code, "200") == 0) { + // HTTP 200 = OK, we got the response + return true; + } else if (strcmp(code, "100") == 0) { + // HTTP 100 = continue, just keep reading + return false; + } else { + throw TTransportException(string("Bad Status: ") + status); + } +} + +void THttpClient::parseHeader(char* header) { + char* colon = strchr(header, ':'); + if (colon == NULL) { + return; + } + uint32_t sz = colon - header; + char* value = colon+1; + + if (strncmp(header, "Transfer-Encoding", sz) == 0) { + if (strstr(value, "chunked") != NULL) { + chunked_ = true; + } + } else if (strncmp(header, "Content-Length", sz) == 0) { + chunked_ = false; + contentLength_ = atoi(value); + } +} + +void THttpClient::write(const uint8_t* buf, uint32_t len) { + writeBuffer_.write(buf, len); +} + +void THttpClient::flush() { + // Fetch the contents of the write buffer + uint8_t* buf; + uint32_t len; + writeBuffer_.getBuffer(&buf, &len); + + // Construct the HTTP header + std::ostringstream h; + h << + "POST " << path_ << " HTTP/1.1" << CRLF << + "Host: " << host_ << CRLF << + "Content-Type: application/x-thrift" << CRLF << + "Content-Length: " << len << CRLF << + "Accept: application/x-thrift" << CRLF << + "User-Agent: C++/THttpClient" << CRLF << + CRLF; + string header = h.str(); + + // Write the header, then the data, then flush + transport_->write((const uint8_t*)header.c_str(), header.size()); + transport_->write(buf, len); + transport_->flush(); + + // Reset the buffer and header variables + writeBuffer_.resetBuffer(); + readHeaders_ = true; +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/THttpClient.h b/lib/cpp/src/transport/THttpClient.h new file mode 100644 index 00000000..f4be4c1a --- /dev/null +++ b/lib/cpp/src/transport/THttpClient.h @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_THTTPCLIENT_H_ +#define _THRIFT_TRANSPORT_THTTPCLIENT_H_ 1 + +#include + +namespace apache { namespace thrift { namespace transport { + +/** + * HTTP client implementation of the thrift transport. This was irritating + * to write, but the alternatives in C++ land are daunting. Linking CURL + * requires 23 dynamic libraries last time I checked (WTF?!?). All we have + * here is a VERY basic HTTP/1.1 client which supports HTTP 100 Continue, + * chunked transfer encoding, keepalive, etc. Tested against Apache. + * + */ +class THttpClient : public TTransport { + public: + THttpClient(boost::shared_ptr transport, std::string host, std::string path=""); + + THttpClient(std::string host, int port, std::string path=""); + + virtual ~THttpClient(); + + void open() { + transport_->open(); + } + + bool isOpen() { + return transport_->isOpen(); + } + + bool peek() { + return transport_->peek(); + } + + void close() { + transport_->close(); + } + + uint32_t read(uint8_t* buf, uint32_t len); + + void readEnd(); + + void write(const uint8_t* buf, uint32_t len); + + void flush(); + + private: + void init(); + + protected: + + boost::shared_ptr transport_; + + TMemoryBuffer writeBuffer_; + TMemoryBuffer readBuffer_; + + std::string host_; + std::string path_; + + bool readHeaders_; + bool chunked_; + bool chunkedDone_; + uint32_t chunkSize_; + uint32_t contentLength_; + + char* httpBuf_; + uint32_t httpPos_; + uint32_t httpBufLen_; + uint32_t httpBufSize_; + + uint32_t readMoreData(); + char* readLine(); + + void readHeaders(); + void parseHeader(char* header); + bool parseStatusLine(char* status); + + uint32_t readChunked(); + void readChunkedFooters(); + uint32_t parseChunkSize(char* line); + + uint32_t readContent(uint32_t size); + + void refill(); + void shift(); + +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_THTTPCLIENT_H_ diff --git a/lib/cpp/src/transport/TServerSocket.cpp b/lib/cpp/src/transport/TServerSocket.cpp new file mode 100644 index 00000000..9b47aa53 --- /dev/null +++ b/lib/cpp/src/transport/TServerSocket.cpp @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "TSocket.h" +#include "TServerSocket.h" +#include + +namespace apache { namespace thrift { namespace transport { + +using namespace std; +using boost::shared_ptr; + +TServerSocket::TServerSocket(int port) : + port_(port), + serverSocket_(-1), + acceptBacklog_(1024), + sendTimeout_(0), + recvTimeout_(0), + retryLimit_(0), + retryDelay_(0), + tcpSendBuffer_(0), + tcpRecvBuffer_(0), + intSock1_(-1), + intSock2_(-1) {} + +TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout) : + port_(port), + serverSocket_(-1), + acceptBacklog_(1024), + sendTimeout_(sendTimeout), + recvTimeout_(recvTimeout), + retryLimit_(0), + retryDelay_(0), + tcpSendBuffer_(0), + tcpRecvBuffer_(0), + intSock1_(-1), + intSock2_(-1) {} + +TServerSocket::~TServerSocket() { + close(); +} + +void TServerSocket::setSendTimeout(int sendTimeout) { + sendTimeout_ = sendTimeout; +} + +void TServerSocket::setRecvTimeout(int recvTimeout) { + recvTimeout_ = recvTimeout; +} + +void TServerSocket::setRetryLimit(int retryLimit) { + retryLimit_ = retryLimit; +} + +void TServerSocket::setRetryDelay(int retryDelay) { + retryDelay_ = retryDelay; +} + +void TServerSocket::setTcpSendBuffer(int tcpSendBuffer) { + tcpSendBuffer_ = tcpSendBuffer; +} + +void TServerSocket::setTcpRecvBuffer(int tcpRecvBuffer) { + tcpRecvBuffer_ = tcpRecvBuffer; +} + +void TServerSocket::listen() { + int sv[2]; + if (-1 == socketpair(AF_LOCAL, SOCK_STREAM, 0, sv)) { + GlobalOutput.perror("TServerSocket::listen() socketpair() ", errno); + intSock1_ = -1; + intSock2_ = -1; + } else { + intSock1_ = sv[1]; + intSock2_ = sv[0]; + } + + struct addrinfo hints, *res, *res0; + int error; + char port[sizeof("65536") + 1]; + std::memset(&hints, 0, sizeof(hints)); + hints.ai_family = PF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; + sprintf(port, "%d", port_); + + // Wildcard address + error = getaddrinfo(NULL, port, &hints, &res0); + if (error) { + GlobalOutput.printf("getaddrinfo %d: %s", error, gai_strerror(error)); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not resolve host for server socket."); + } + + // Pick the ipv6 address first since ipv4 addresses can be mapped + // into ipv6 space. + for (res = res0; res; res = res->ai_next) { + if (res->ai_family == AF_INET6 || res->ai_next == NULL) + break; + } + + serverSocket_ = socket(res->ai_family, res->ai_socktype, res->ai_protocol); + if (serverSocket_ == -1) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::listen() socket() ", errno_copy); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not create server socket.", errno_copy); + } + + // Set reusaddress to prevent 2MSL delay on accept + int one = 1; + if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_REUSEADDR, + &one, sizeof(one))) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::listen() setsockopt() SO_REUSEADDR ", errno_copy); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not set SO_REUSEADDR", errno_copy); + } + + // Set TCP buffer sizes + if (tcpSendBuffer_ > 0) { + if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_SNDBUF, + &tcpSendBuffer_, sizeof(tcpSendBuffer_))) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::listen() setsockopt() SO_SNDBUF ", errno_copy); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not set SO_SNDBUF", errno_copy); + } + } + + if (tcpRecvBuffer_ > 0) { + if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_RCVBUF, + &tcpRecvBuffer_, sizeof(tcpRecvBuffer_))) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::listen() setsockopt() SO_RCVBUF ", errno_copy); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not set SO_RCVBUF", errno_copy); + } + } + + // Defer accept + #ifdef TCP_DEFER_ACCEPT + if (-1 == setsockopt(serverSocket_, SOL_SOCKET, TCP_DEFER_ACCEPT, + &one, sizeof(one))) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::listen() setsockopt() TCP_DEFER_ACCEPT ", errno_copy); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not set TCP_DEFER_ACCEPT", errno_copy); + } + #endif // #ifdef TCP_DEFER_ACCEPT + + #ifdef IPV6_V6ONLY + int zero = 0; + if (-1 == setsockopt(serverSocket_, IPPROTO_IPV6, IPV6_V6ONLY, + &zero, sizeof(zero))) { + GlobalOutput.perror("TServerSocket::listen() IPV6_V6ONLY ", errno); + } + #endif // #ifdef IPV6_V6ONLY + + // Turn linger off, don't want to block on calls to close + struct linger ling = {0, 0}; + if (-1 == setsockopt(serverSocket_, SOL_SOCKET, SO_LINGER, + &ling, sizeof(ling))) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::listen() setsockopt() SO_LINGER ", errno_copy); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not set SO_LINGER", errno_copy); + } + + // TCP Nodelay, speed over bandwidth + if (-1 == setsockopt(serverSocket_, IPPROTO_TCP, TCP_NODELAY, + &one, sizeof(one))) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::listen() setsockopt() TCP_NODELAY ", errno_copy); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not set TCP_NODELAY", errno_copy); + } + + // Set NONBLOCK on the accept socket + int flags = fcntl(serverSocket_, F_GETFL, 0); + if (flags == -1) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::listen() fcntl() F_GETFL ", errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed", errno_copy); + } + + if (-1 == fcntl(serverSocket_, F_SETFL, flags | O_NONBLOCK)) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::listen() fcntl() O_NONBLOCK ", errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed", errno_copy); + } + + // prepare the port information + // we may want to try to bind more than once, since SO_REUSEADDR doesn't + // always seem to work. The client can configure the retry variables. + int retries = 0; + do { + if (0 == bind(serverSocket_, res->ai_addr, res->ai_addrlen)) { + break; + } + + // use short circuit evaluation here to only sleep if we need to + } while ((retries++ < retryLimit_) && (sleep(retryDelay_) == 0)); + + // free addrinfo + freeaddrinfo(res0); + + // throw an error if we failed to bind properly + if (retries > retryLimit_) { + char errbuf[1024]; + sprintf(errbuf, "TServerSocket::listen() BIND %d", port_); + GlobalOutput(errbuf); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not bind"); + } + + // Call listen + if (-1 == ::listen(serverSocket_, acceptBacklog_)) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::listen() listen() ", errno_copy); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not listen", errno_copy); + } + + // The socket is now listening! +} + +shared_ptr TServerSocket::acceptImpl() { + if (serverSocket_ < 0) { + throw TTransportException(TTransportException::NOT_OPEN, "TServerSocket not listening"); + } + + struct pollfd fds[2]; + + int maxEintrs = 5; + int numEintrs = 0; + + while (true) { + std::memset(fds, 0 , sizeof(fds)); + fds[0].fd = serverSocket_; + fds[0].events = POLLIN; + if (intSock2_ >= 0) { + fds[1].fd = intSock2_; + fds[1].events = POLLIN; + } + int ret = poll(fds, 2, -1); + + if (ret < 0) { + // error cases + if (errno == EINTR && (numEintrs++ < maxEintrs)) { + // EINTR needs to be handled manually and we can tolerate + // a certain number + continue; + } + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::acceptImpl() poll() ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy); + } else if (ret > 0) { + // Check for an interrupt signal + if (intSock2_ >= 0 && (fds[1].revents & POLLIN)) { + int8_t buf; + if (-1 == recv(intSock2_, &buf, sizeof(int8_t), 0)) { + GlobalOutput.perror("TServerSocket::acceptImpl() recv() interrupt ", errno); + } + throw TTransportException(TTransportException::INTERRUPTED); + } + + // Check for the actual server socket being ready + if (fds[0].revents & POLLIN) { + break; + } + } else { + GlobalOutput("TServerSocket::acceptImpl() poll 0"); + throw TTransportException(TTransportException::UNKNOWN); + } + } + + struct sockaddr_storage clientAddress; + int size = sizeof(clientAddress); + int clientSocket = ::accept(serverSocket_, + (struct sockaddr *) &clientAddress, + (socklen_t *) &size); + + if (clientSocket < 0) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::acceptImpl() ::accept() ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "accept()", errno_copy); + } + + // Make sure client socket is blocking + int flags = fcntl(clientSocket, F_GETFL, 0); + if (flags == -1) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::acceptImpl() fcntl() F_GETFL ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "fcntl(F_GETFL)", errno_copy); + } + + if (-1 == fcntl(clientSocket, F_SETFL, flags & ~O_NONBLOCK)) { + int errno_copy = errno; + GlobalOutput.perror("TServerSocket::acceptImpl() fcntl() F_SETFL ~O_NONBLOCK ", errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "fcntl(F_SETFL)", errno_copy); + } + + shared_ptr client(new TSocket(clientSocket)); + if (sendTimeout_ > 0) { + client->setSendTimeout(sendTimeout_); + } + if (recvTimeout_ > 0) { + client->setRecvTimeout(recvTimeout_); + } + + return client; +} + +void TServerSocket::interrupt() { + if (intSock1_ >= 0) { + int8_t byte = 0; + if (-1 == send(intSock1_, &byte, sizeof(int8_t), 0)) { + GlobalOutput.perror("TServerSocket::interrupt() send() ", errno); + } + } +} + +void TServerSocket::close() { + if (serverSocket_ >= 0) { + shutdown(serverSocket_, SHUT_RDWR); + ::close(serverSocket_); + } + if (intSock1_ >= 0) { + ::close(intSock1_); + } + if (intSock2_ >= 0) { + ::close(intSock2_); + } + serverSocket_ = -1; + intSock1_ = -1; + intSock2_ = -1; +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TServerSocket.h b/lib/cpp/src/transport/TServerSocket.h new file mode 100644 index 00000000..a6be0173 --- /dev/null +++ b/lib/cpp/src/transport/TServerSocket.h @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TSERVERSOCKET_H_ +#define _THRIFT_TRANSPORT_TSERVERSOCKET_H_ 1 + +#include "TServerTransport.h" +#include + +namespace apache { namespace thrift { namespace transport { + +class TSocket; + +/** + * Server socket implementation of TServerTransport. Wrapper around a unix + * socket listen and accept calls. + * + */ +class TServerSocket : public TServerTransport { + public: + TServerSocket(int port); + TServerSocket(int port, int sendTimeout, int recvTimeout); + + ~TServerSocket(); + + void setSendTimeout(int sendTimeout); + void setRecvTimeout(int recvTimeout); + + void setRetryLimit(int retryLimit); + void setRetryDelay(int retryDelay); + + void setTcpSendBuffer(int tcpSendBuffer); + void setTcpRecvBuffer(int tcpRecvBuffer); + + void listen(); + void close(); + + void interrupt(); + + protected: + boost::shared_ptr acceptImpl(); + + private: + int port_; + int serverSocket_; + int acceptBacklog_; + int sendTimeout_; + int recvTimeout_; + int retryLimit_; + int retryDelay_; + int tcpSendBuffer_; + int tcpRecvBuffer_; + + int intSock1_; + int intSock2_; +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TSERVERSOCKET_H_ diff --git a/lib/cpp/src/transport/TServerTransport.h b/lib/cpp/src/transport/TServerTransport.h new file mode 100644 index 00000000..40bbc6c7 --- /dev/null +++ b/lib/cpp/src/transport/TServerTransport.h @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TSERVERTRANSPORT_H_ +#define _THRIFT_TRANSPORT_TSERVERTRANSPORT_H_ 1 + +#include "TTransport.h" +#include "TTransportException.h" +#include + +namespace apache { namespace thrift { namespace transport { + +/** + * Server transport framework. A server needs to have some facility for + * creating base transports to read/write from. + * + */ +class TServerTransport { + public: + virtual ~TServerTransport() {} + + /** + * Starts the server transport listening for new connections. Prior to this + * call most transports will not return anything when accept is called. + * + * @throws TTransportException if we were unable to listen + */ + virtual void listen() {} + + /** + * Gets a new dynamically allocated transport object and passes it to the + * caller. Note that it is the explicit duty of the caller to free the + * allocated object. The returned TTransport object must always be in the + * opened state. NULL should never be returned, instead an Exception should + * always be thrown. + * + * @return A new TTransport object + * @throws TTransportException if there is an error + */ + boost::shared_ptr accept() { + boost::shared_ptr result = acceptImpl(); + if (result == NULL) { + throw TTransportException("accept() may not return NULL"); + } + return result; + } + + /** + * For "smart" TServerTransport implementations that work in a multi + * threaded context this can be used to break out of an accept() call. + * It is expected that the transport will throw a TTransportException + * with the interrupted error code. + */ + virtual void interrupt() {} + + /** + * Closes this transport such that future calls to accept will do nothing. + */ + virtual void close() = 0; + + protected: + TServerTransport() {} + + /** + * Subclasses should implement this function for accept. + * + * @return A newly allocated TTransport object + * @throw TTransportException If an error occurs + */ + virtual boost::shared_ptr acceptImpl() = 0; + +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TSERVERTRANSPORT_H_ diff --git a/lib/cpp/src/transport/TShortReadTransport.h b/lib/cpp/src/transport/TShortReadTransport.h new file mode 100644 index 00000000..3df8a57c --- /dev/null +++ b/lib/cpp/src/transport/TShortReadTransport.h @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TSHORTREADTRANSPORT_H_ +#define _THRIFT_TRANSPORT_TSHORTREADTRANSPORT_H_ 1 + +#include + +#include + +namespace apache { namespace thrift { namespace transport { namespace test { + +/** + * This class is only meant for testing. It wraps another transport. + * Calls to read are passed through with some probability. Otherwise, + * the read amount is randomly reduced before being passed through. + * + */ +class TShortReadTransport : public TTransport { + public: + TShortReadTransport(boost::shared_ptr transport, double full_prob) + : transport_(transport) + , fullProb_(full_prob) + {} + + bool isOpen() { + return transport_->isOpen(); + } + + bool peek() { + return transport_->peek(); + } + + void open() { + transport_->open(); + } + + void close() { + transport_->close(); + } + + uint32_t read(uint8_t* buf, uint32_t len) { + if (len == 0) { + return 0; + } + + if (rand()/(double)RAND_MAX >= fullProb_) { + len = 1 + rand()%len; + } + return transport_->read(buf, len); + } + + void write(const uint8_t* buf, uint32_t len) { + transport_->write(buf, len); + } + + void flush() { + transport_->flush(); + } + + const uint8_t* borrow(uint8_t* buf, uint32_t* len) { + return transport_->borrow(buf, len); + } + + void consume(uint32_t len) { + return transport_->consume(len); + } + + boost::shared_ptr getUnderlyingTransport() { + return transport_; + } + + protected: + boost::shared_ptr transport_; + double fullProb_; +}; + +}}}} // apache::thrift::transport::test + +#endif // #ifndef _THRIFT_TRANSPORT_TSHORTREADTRANSPORT_H_ diff --git a/lib/cpp/src/transport/TSimpleFileTransport.cpp b/lib/cpp/src/transport/TSimpleFileTransport.cpp new file mode 100644 index 00000000..e58a5743 --- /dev/null +++ b/lib/cpp/src/transport/TSimpleFileTransport.cpp @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "TSimpleFileTransport.h" + +#include +#include +#include + +namespace apache { namespace thrift { namespace transport { + +TSimpleFileTransport:: +TSimpleFileTransport(const std::string& path, bool read, bool write) + : TFDTransport(-1, TFDTransport::CLOSE_ON_DESTROY) { + int flags = 0; + if (read && write) { + flags = O_RDWR; + } else if (read) { + flags = O_RDONLY; + } else if (write) { + flags = O_WRONLY; + } else { + throw TTransportException("Neither READ nor WRITE specified"); + } + if (write) { + flags |= O_CREAT | O_APPEND; + } + int fd = ::open(path.c_str(), + flags, + S_IRUSR | S_IWUSR| S_IRGRP | S_IROTH); + if (fd < 0) { + throw TTransportException("failed to open file for writing: " + path); + } + setFD(fd); + open(); +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TSimpleFileTransport.h b/lib/cpp/src/transport/TSimpleFileTransport.h new file mode 100644 index 00000000..6cc52ea1 --- /dev/null +++ b/lib/cpp/src/transport/TSimpleFileTransport.h @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TSIMPLEFILETRANSPORT_H_ +#define _THRIFT_TRANSPORT_TSIMPLEFILETRANSPORT_H_ 1 + +#include "TFDTransport.h" + +namespace apache { namespace thrift { namespace transport { + +/** + * Dead-simple wrapper around a file. + * + * Writeable files are opened with O_CREAT and O_APPEND + */ +class TSimpleFileTransport : public TFDTransport { + public: + TSimpleFileTransport(const std::string& path, + bool read = true, + bool write = false); +}; + +}}} // apache::thrift::transport + +#endif // _THRIFT_TRANSPORT_TSIMPLEFILETRANSPORT_H_ diff --git a/lib/cpp/src/transport/TSocket.cpp b/lib/cpp/src/transport/TSocket.cpp new file mode 100644 index 00000000..3395dabd --- /dev/null +++ b/lib/cpp/src/transport/TSocket.cpp @@ -0,0 +1,589 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "concurrency/Monitor.h" +#include "TSocket.h" +#include "TTransportException.h" + +namespace apache { namespace thrift { namespace transport { + +using namespace std; + +// Global var to track total socket sys calls +uint32_t g_socket_syscalls = 0; + +/** + * TSocket implementation. + * + */ + +TSocket::TSocket(string host, int port) : + host_(host), + port_(port), + socket_(-1), + connTimeout_(0), + sendTimeout_(0), + recvTimeout_(0), + lingerOn_(1), + lingerVal_(0), + noDelay_(1), + maxRecvRetries_(5) { + recvTimeval_.tv_sec = (int)(recvTimeout_/1000); + recvTimeval_.tv_usec = (int)((recvTimeout_%1000)*1000); +} + +TSocket::TSocket() : + host_(""), + port_(0), + socket_(-1), + connTimeout_(0), + sendTimeout_(0), + recvTimeout_(0), + lingerOn_(1), + lingerVal_(0), + noDelay_(1), + maxRecvRetries_(5) { + recvTimeval_.tv_sec = (int)(recvTimeout_/1000); + recvTimeval_.tv_usec = (int)((recvTimeout_%1000)*1000); +} + +TSocket::TSocket(int socket) : + host_(""), + port_(0), + socket_(socket), + connTimeout_(0), + sendTimeout_(0), + recvTimeout_(0), + lingerOn_(1), + lingerVal_(0), + noDelay_(1), + maxRecvRetries_(5) { + recvTimeval_.tv_sec = (int)(recvTimeout_/1000); + recvTimeval_.tv_usec = (int)((recvTimeout_%1000)*1000); +} + +TSocket::~TSocket() { + close(); +} + +bool TSocket::isOpen() { + return (socket_ >= 0); +} + +bool TSocket::peek() { + if (!isOpen()) { + return false; + } + uint8_t buf; + int r = recv(socket_, &buf, 1, MSG_PEEK); + if (r == -1) { + int errno_copy = errno; + #ifdef __FreeBSD__ + /* shigin: + * freebsd returns -1 and ECONNRESET if socket was closed by + * the other side + */ + if (errno_copy == ECONNRESET) + { + close(); + return false; + } + #endif + GlobalOutput.perror("TSocket::peek() recv() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::UNKNOWN, "recv()", errno_copy); + } + return (r > 0); +} + +void TSocket::openConnection(struct addrinfo *res) { + if (isOpen()) { + throw TTransportException(TTransportException::ALREADY_OPEN); + } + + socket_ = socket(res->ai_family, res->ai_socktype, res->ai_protocol); + if (socket_ == -1) { + int errno_copy = errno; + GlobalOutput.perror("TSocket::open() socket() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "socket()", errno_copy); + } + + // Send timeout + if (sendTimeout_ > 0) { + setSendTimeout(sendTimeout_); + } + + // Recv timeout + if (recvTimeout_ > 0) { + setRecvTimeout(recvTimeout_); + } + + // Linger + setLinger(lingerOn_, lingerVal_); + + // No delay + setNoDelay(noDelay_); + + // Set the socket to be non blocking for connect if a timeout exists + int flags = fcntl(socket_, F_GETFL, 0); + if (connTimeout_ > 0) { + if (-1 == fcntl(socket_, F_SETFL, flags | O_NONBLOCK)) { + int errno_copy = errno; + GlobalOutput.perror("TSocket::open() fcntl() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed", errno_copy); + } + } else { + if (-1 == fcntl(socket_, F_SETFL, flags & ~O_NONBLOCK)) { + int errno_copy = errno; + GlobalOutput.perror("TSocket::open() fcntl " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "fcntl() failed", errno_copy); + } + } + + // Connect the socket + int ret = connect(socket_, res->ai_addr, res->ai_addrlen); + + // success case + if (ret == 0) { + goto done; + } + + if (errno != EINPROGRESS) { + int errno_copy = errno; + GlobalOutput.perror("TSocket::open() connect() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "connect() failed", errno_copy); + } + + + struct pollfd fds[1]; + std::memset(fds, 0 , sizeof(fds)); + fds[0].fd = socket_; + fds[0].events = POLLOUT; + ret = poll(fds, 1, connTimeout_); + + if (ret > 0) { + // Ensure the socket is connected and that there are no errors set + int val; + socklen_t lon; + lon = sizeof(int); + int ret2 = getsockopt(socket_, SOL_SOCKET, SO_ERROR, (void *)&val, &lon); + if (ret2 == -1) { + int errno_copy = errno; + GlobalOutput.perror("TSocket::open() getsockopt() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "getsockopt()", errno_copy); + } + // no errors on socket, go to town + if (val == 0) { + goto done; + } + GlobalOutput.perror("TSocket::open() error on socket (after poll) " + getSocketInfo(), val); + throw TTransportException(TTransportException::NOT_OPEN, "socket open() error", val); + } else if (ret == 0) { + // socket timed out + string errStr = "TSocket::open() timed out " + getSocketInfo(); + GlobalOutput(errStr.c_str()); + throw TTransportException(TTransportException::NOT_OPEN, "open() timed out"); + } else { + // error on poll() + int errno_copy = errno; + GlobalOutput.perror("TSocket::open() poll() " + getSocketInfo(), errno_copy); + throw TTransportException(TTransportException::NOT_OPEN, "poll() failed", errno_copy); + } + + done: + // Set socket back to normal mode (blocking) + fcntl(socket_, F_SETFL, flags); +} + +void TSocket::open() { + if (isOpen()) { + throw TTransportException(TTransportException::ALREADY_OPEN); + } + + // Validate port number + if (port_ < 0 || port_ > 65536) { + throw TTransportException(TTransportException::NOT_OPEN, "Specified port is invalid"); + } + + struct addrinfo hints, *res, *res0; + res = NULL; + res0 = NULL; + int error; + char port[sizeof("65536")]; + std::memset(&hints, 0, sizeof(hints)); + hints.ai_family = PF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG; + sprintf(port, "%d", port_); + + error = getaddrinfo(host_.c_str(), port, &hints, &res0); + + if (error) { + string errStr = "TSocket::open() getaddrinfo() " + getSocketInfo() + string(gai_strerror(error)); + GlobalOutput(errStr.c_str()); + close(); + throw TTransportException(TTransportException::NOT_OPEN, "Could not resolve host for client socket."); + } + + // Cycle through all the returned addresses until one + // connects or push the exception up. + for (res = res0; res; res = res->ai_next) { + try { + openConnection(res); + break; + } catch (TTransportException& ttx) { + if (res->ai_next) { + close(); + } else { + close(); + freeaddrinfo(res0); // cleanup on failure + throw; + } + } + } + + // Free address structure memory + freeaddrinfo(res0); +} + +void TSocket::close() { + if (socket_ >= 0) { + shutdown(socket_, SHUT_RDWR); + ::close(socket_); + } + socket_ = -1; +} + +uint32_t TSocket::read(uint8_t* buf, uint32_t len) { + if (socket_ < 0) { + throw TTransportException(TTransportException::NOT_OPEN, "Called read on non-open socket"); + } + + int32_t retries = 0; + + // EAGAIN can be signalled both when a timeout has occurred and when + // the system is out of resources (an awesome undocumented feature). + // The following is an approximation of the time interval under which + // EAGAIN is taken to indicate an out of resources error. + uint32_t eagainThresholdMicros = 0; + if (recvTimeout_) { + // if a readTimeout is specified along with a max number of recv retries, then + // the threshold will ensure that the read timeout is not exceeded even in the + // case of resource errors + eagainThresholdMicros = (recvTimeout_*1000)/ ((maxRecvRetries_>0) ? maxRecvRetries_ : 2); + } + + try_again: + // Read from the socket + struct timeval begin; + gettimeofday(&begin, NULL); + int got = recv(socket_, buf, len, 0); + int errno_copy = errno; //gettimeofday can change errno + struct timeval end; + gettimeofday(&end, NULL); + uint32_t readElapsedMicros = (((end.tv_sec - begin.tv_sec) * 1000 * 1000) + + (((uint64_t)(end.tv_usec - begin.tv_usec)))); + ++g_socket_syscalls; + + // Check for error on read + if (got < 0) { + if (errno_copy == EAGAIN) { + // check if this is the lack of resources or timeout case + if (!eagainThresholdMicros || (readElapsedMicros < eagainThresholdMicros)) { + if (retries++ < maxRecvRetries_) { + usleep(50); + goto try_again; + } else { + throw TTransportException(TTransportException::TIMED_OUT, + "EAGAIN (unavailable resources)"); + } + } else { + // infer that timeout has been hit + throw TTransportException(TTransportException::TIMED_OUT, + "EAGAIN (timed out)"); + } + } + + // If interrupted, try again + if (errno_copy == EINTR && retries++ < maxRecvRetries_) { + goto try_again; + } + + // Now it's not a try again case, but a real probblez + GlobalOutput.perror("TSocket::read() recv() " + getSocketInfo(), errno_copy); + + // If we disconnect with no linger time + if (errno_copy == ECONNRESET) { + #ifdef __FreeBSD__ + /* shigin: freebsd doesn't follow POSIX semantic of recv and fails with + * ECONNRESET if peer performed shutdown + */ + close(); + return 0; + #else + throw TTransportException(TTransportException::NOT_OPEN, "ECONNRESET"); + #endif + } + + // This ish isn't open + if (errno_copy == ENOTCONN) { + throw TTransportException(TTransportException::NOT_OPEN, "ENOTCONN"); + } + + // Timed out! + if (errno_copy == ETIMEDOUT) { + throw TTransportException(TTransportException::TIMED_OUT, "ETIMEDOUT"); + } + + // Some other error, whatevz + throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy); + } + + // The remote host has closed the socket + if (got == 0) { + close(); + return 0; + } + + // Pack data into string + return got; +} + +void TSocket::write(const uint8_t* buf, uint32_t len) { + if (socket_ < 0) { + throw TTransportException(TTransportException::NOT_OPEN, "Called write on non-open socket"); + } + + uint32_t sent = 0; + + while (sent < len) { + + int flags = 0; + #ifdef MSG_NOSIGNAL + // Note the use of MSG_NOSIGNAL to suppress SIGPIPE errors, instead we + // check for the EPIPE return condition and close the socket in that case + flags |= MSG_NOSIGNAL; + #endif // ifdef MSG_NOSIGNAL + + int b = send(socket_, buf + sent, len - sent, flags); + ++g_socket_syscalls; + + // Fail on a send error + if (b < 0) { + int errno_copy = errno; + GlobalOutput.perror("TSocket::write() send() " + getSocketInfo(), errno_copy); + + if (errno == EPIPE || errno == ECONNRESET || errno == ENOTCONN) { + close(); + throw TTransportException(TTransportException::NOT_OPEN, "write() send()", errno_copy); + } + + throw TTransportException(TTransportException::UNKNOWN, "write() send()", errno_copy); + } + + // Fail on blocked send + if (b == 0) { + throw TTransportException(TTransportException::NOT_OPEN, "Socket send returned 0."); + } + sent += b; + } +} + +std::string TSocket::getHost() { + return host_; +} + +int TSocket::getPort() { + return port_; +} + +void TSocket::setHost(string host) { + host_ = host; +} + +void TSocket::setPort(int port) { + port_ = port; +} + +void TSocket::setLinger(bool on, int linger) { + lingerOn_ = on; + lingerVal_ = linger; + if (socket_ < 0) { + return; + } + + struct linger l = {(lingerOn_ ? 1 : 0), lingerVal_}; + int ret = setsockopt(socket_, SOL_SOCKET, SO_LINGER, &l, sizeof(l)); + if (ret == -1) { + int errno_copy = errno; // Copy errno because we're allocating memory. + GlobalOutput.perror("TSocket::setLinger() setsockopt() " + getSocketInfo(), errno_copy); + } +} + +void TSocket::setNoDelay(bool noDelay) { + noDelay_ = noDelay; + if (socket_ < 0) { + return; + } + + // Set socket to NODELAY + int v = noDelay_ ? 1 : 0; + int ret = setsockopt(socket_, IPPROTO_TCP, TCP_NODELAY, &v, sizeof(v)); + if (ret == -1) { + int errno_copy = errno; // Copy errno because we're allocating memory. + GlobalOutput.perror("TSocket::setNoDelay() setsockopt() " + getSocketInfo(), errno_copy); + } +} + +void TSocket::setConnTimeout(int ms) { + connTimeout_ = ms; +} + +void TSocket::setRecvTimeout(int ms) { + if (ms < 0) { + char errBuf[512]; + sprintf(errBuf, "TSocket::setRecvTimeout with negative input: %d\n", ms); + GlobalOutput(errBuf); + return; + } + recvTimeout_ = ms; + + if (socket_ < 0) { + return; + } + + recvTimeval_.tv_sec = (int)(recvTimeout_/1000); + recvTimeval_.tv_usec = (int)((recvTimeout_%1000)*1000); + + // Copy because poll may modify + struct timeval r = recvTimeval_; + int ret = setsockopt(socket_, SOL_SOCKET, SO_RCVTIMEO, &r, sizeof(r)); + if (ret == -1) { + int errno_copy = errno; // Copy errno because we're allocating memory. + GlobalOutput.perror("TSocket::setRecvTimeout() setsockopt() " + getSocketInfo(), errno_copy); + } +} + +void TSocket::setSendTimeout(int ms) { + if (ms < 0) { + char errBuf[512]; + sprintf(errBuf, "TSocket::setSendTimeout with negative input: %d\n", ms); + GlobalOutput(errBuf); + return; + } + sendTimeout_ = ms; + + if (socket_ < 0) { + return; + } + + struct timeval s = {(int)(sendTimeout_/1000), + (int)((sendTimeout_%1000)*1000)}; + int ret = setsockopt(socket_, SOL_SOCKET, SO_SNDTIMEO, &s, sizeof(s)); + if (ret == -1) { + int errno_copy = errno; // Copy errno because we're allocating memory. + GlobalOutput.perror("TSocket::setSendTimeout() setsockopt() " + getSocketInfo(), errno_copy); + } +} + +void TSocket::setMaxRecvRetries(int maxRecvRetries) { + maxRecvRetries_ = maxRecvRetries; +} + +string TSocket::getSocketInfo() { + std::ostringstream oss; + oss << ""; + return oss.str(); +} + +std::string TSocket::getPeerHost() { + if (peerHost_.empty()) { + struct sockaddr_storage addr; + socklen_t addrLen = sizeof(addr); + + if (socket_ < 0) { + return host_; + } + + int rv = getpeername(socket_, (sockaddr*) &addr, &addrLen); + + if (rv != 0) { + return peerHost_; + } + + char clienthost[NI_MAXHOST]; + char clientservice[NI_MAXSERV]; + + getnameinfo((sockaddr*) &addr, addrLen, + clienthost, sizeof(clienthost), + clientservice, sizeof(clientservice), 0); + + peerHost_ = clienthost; + } + return peerHost_; +} + +std::string TSocket::getPeerAddress() { + if (peerAddress_.empty()) { + struct sockaddr_storage addr; + socklen_t addrLen = sizeof(addr); + + if (socket_ < 0) { + return peerAddress_; + } + + int rv = getpeername(socket_, (sockaddr*) &addr, &addrLen); + + if (rv != 0) { + return peerAddress_; + } + + char clienthost[NI_MAXHOST]; + char clientservice[NI_MAXSERV]; + + getnameinfo((sockaddr*) &addr, addrLen, + clienthost, sizeof(clienthost), + clientservice, sizeof(clientservice), + NI_NUMERICHOST|NI_NUMERICSERV); + + peerAddress_ = clienthost; + peerPort_ = std::atoi(clientservice); + } + return peerAddress_; +} + +int TSocket::getPeerPort() { + getPeerAddress(); + return peerPort_; +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TSocket.h b/lib/cpp/src/transport/TSocket.h new file mode 100644 index 00000000..b0f445aa --- /dev/null +++ b/lib/cpp/src/transport/TSocket.h @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TSOCKET_H_ +#define _THRIFT_TRANSPORT_TSOCKET_H_ 1 + +#include +#include + +#include "TTransport.h" +#include "TServerSocket.h" + +namespace apache { namespace thrift { namespace transport { + +/** + * TCP Socket implementation of the TTransport interface. + * + */ +class TSocket : public TTransport { + /** + * We allow the TServerSocket acceptImpl() method to access the private + * members of a socket so that it can access the TSocket(int socket) + * constructor which creates a socket object from the raw UNIX socket + * handle. + */ + friend class TServerSocket; + + public: + /** + * Constructs a new socket. Note that this does NOT actually connect the + * socket. + * + */ + TSocket(); + + /** + * Constructs a new socket. Note that this does NOT actually connect the + * socket. + * + * @param host An IP address or hostname to connect to + * @param port The port to connect on + */ + TSocket(std::string host, int port); + + /** + * Destroyes the socket object, closing it if necessary. + */ + virtual ~TSocket(); + + /** + * Whether the socket is alive. + * + * @return Is the socket alive? + */ + bool isOpen(); + + /** + * Calls select on the socket to see if there is more data available. + */ + bool peek(); + + /** + * Creates and opens the UNIX socket. + * + * @throws TTransportException If the socket could not connect + */ + virtual void open(); + + /** + * Shuts down communications on the socket. + */ + void close(); + + /** + * Reads from the underlying socket. + */ + uint32_t read(uint8_t* buf, uint32_t len); + + /** + * Writes to the underlying socket. + */ + void write(const uint8_t* buf, uint32_t len); + + /** + * Get the host that the socket is connected to + * + * @return string host identifier + */ + std::string getHost(); + + /** + * Get the port that the socket is connected to + * + * @return int port number + */ + int getPort(); + + /** + * Set the host that socket will connect to + * + * @param host host identifier + */ + void setHost(std::string host); + + /** + * Set the port that socket will connect to + * + * @param port port number + */ + void setPort(int port); + + /** + * Controls whether the linger option is set on the socket. + * + * @param on Whether SO_LINGER is on + * @param linger If linger is active, the number of seconds to linger for + */ + void setLinger(bool on, int linger); + + /** + * Whether to enable/disable Nagle's algorithm. + * + * @param noDelay Whether or not to disable the algorithm. + * @return + */ + void setNoDelay(bool noDelay); + + /** + * Set the connect timeout + */ + void setConnTimeout(int ms); + + /** + * Set the receive timeout + */ + void setRecvTimeout(int ms); + + /** + * Set the send timeout + */ + void setSendTimeout(int ms); + + /** + * Set the max number of recv retries in case of an EAGAIN + * error + */ + void setMaxRecvRetries(int maxRecvRetries); + + /** + * Get socket information formated as a string + */ + std::string getSocketInfo(); + + /** + * Returns the DNS name of the host to which the socket is connected + */ + std::string getPeerHost(); + + /** + * Returns the address of the host to which the socket is connected + */ + std::string getPeerAddress(); + + /** + * Returns the port of the host to which the socket is connected + **/ + int getPeerPort(); + + + protected: + /** + * Constructor to create socket from raw UNIX handle. Never called directly + * but used by the TServerSocket class. + */ + TSocket(int socket); + + /** connect, called by open */ + void openConnection(struct addrinfo *res); + + /** Host to connect to */ + std::string host_; + + /** Peer hostname */ + std::string peerHost_; + + /** Peer address */ + std::string peerAddress_; + + /** Peer port */ + int peerPort_; + + /** Port number to connect on */ + int port_; + + /** Underlying UNIX socket handle */ + int socket_; + + /** Connect timeout in ms */ + int connTimeout_; + + /** Send timeout in ms */ + int sendTimeout_; + + /** Recv timeout in ms */ + int recvTimeout_; + + /** Linger on */ + bool lingerOn_; + + /** Linger val */ + int lingerVal_; + + /** Nodelay */ + bool noDelay_; + + /** Recv EGAIN retries */ + int maxRecvRetries_; + + /** Recv timeout timeval */ + struct timeval recvTimeval_; +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TSOCKET_H_ + diff --git a/lib/cpp/src/transport/TSocketPool.cpp b/lib/cpp/src/transport/TSocketPool.cpp new file mode 100644 index 00000000..1150282b --- /dev/null +++ b/lib/cpp/src/transport/TSocketPool.cpp @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include "TSocketPool.h" + +namespace apache { namespace thrift { namespace transport { + +using namespace std; + +using boost::shared_ptr; + +/** + * TSocketPoolServer implementation + * + */ +TSocketPoolServer::TSocketPoolServer() + : host_(""), + port_(0), + socket_(-1), + lastFailTime_(0), + consecutiveFailures_(0) {} + +/** + * Constructor for TSocketPool server + */ +TSocketPoolServer::TSocketPoolServer(const string &host, int port) + : host_(host), + port_(port), + socket_(-1), + lastFailTime_(0), + consecutiveFailures_(0) {} + +/** + * TSocketPool implementation. + * + */ + +TSocketPool::TSocketPool() : TSocket(), + numRetries_(1), + retryInterval_(60), + maxConsecutiveFailures_(1), + randomize_(true), + alwaysTryLast_(true) { +} + +TSocketPool::TSocketPool(const vector &hosts, + const vector &ports) : TSocket(), + numRetries_(1), + retryInterval_(60), + maxConsecutiveFailures_(1), + randomize_(true), + alwaysTryLast_(true) +{ + if (hosts.size() != ports.size()) { + GlobalOutput("TSocketPool::TSocketPool: hosts.size != ports.size"); + throw TTransportException(TTransportException::BAD_ARGS); + } + + for (unsigned int i = 0; i < hosts.size(); ++i) { + addServer(hosts[i], ports[i]); + } +} + +TSocketPool::TSocketPool(const vector >& servers) : TSocket(), + numRetries_(1), + retryInterval_(60), + maxConsecutiveFailures_(1), + randomize_(true), + alwaysTryLast_(true) +{ + for (unsigned i = 0; i < servers.size(); ++i) { + addServer(servers[i].first, servers[i].second); + } +} + +TSocketPool::TSocketPool(const vector< shared_ptr >& servers) : TSocket(), + servers_(servers), + numRetries_(1), + retryInterval_(60), + maxConsecutiveFailures_(1), + randomize_(true), + alwaysTryLast_(true) +{ +} + +TSocketPool::TSocketPool(const string& host, int port) : TSocket(), + numRetries_(1), + retryInterval_(60), + maxConsecutiveFailures_(1), + randomize_(true), + alwaysTryLast_(true) +{ + addServer(host, port); +} + +TSocketPool::~TSocketPool() { + vector< shared_ptr >::const_iterator iter = servers_.begin(); + vector< shared_ptr >::const_iterator iterEnd = servers_.end(); + for (; iter != iterEnd; ++iter) { + setCurrentServer(*iter); + TSocketPool::close(); + } +} + +void TSocketPool::addServer(const string& host, int port) { + servers_.push_back(shared_ptr(new TSocketPoolServer(host, port))); +} + +void TSocketPool::setServers(const vector< shared_ptr >& servers) { + servers_ = servers; +} + +void TSocketPool::getServers(vector< shared_ptr >& servers) { + servers = servers_; +} + +void TSocketPool::setNumRetries(int numRetries) { + numRetries_ = numRetries; +} + +void TSocketPool::setRetryInterval(int retryInterval) { + retryInterval_ = retryInterval; +} + + +void TSocketPool::setMaxConsecutiveFailures(int maxConsecutiveFailures) { + maxConsecutiveFailures_ = maxConsecutiveFailures; +} + +void TSocketPool::setRandomize(bool randomize) { + randomize_ = randomize; +} + +void TSocketPool::setAlwaysTryLast(bool alwaysTryLast) { + alwaysTryLast_ = alwaysTryLast; +} + +void TSocketPool::setCurrentServer(const shared_ptr &server) { + currentServer_ = server; + host_ = server->host_; + port_ = server->port_; + socket_ = server->socket_; +} + +/* TODO: without apc we ignore a lot of functionality from the php version */ +void TSocketPool::open() { + if (randomize_) { + random_shuffle(servers_.begin(), servers_.end()); + } + + unsigned int numServers = servers_.size(); + for (unsigned int i = 0; i < numServers; ++i) { + + shared_ptr &server = servers_[i]; + bool retryIntervalPassed = (server->lastFailTime_ == 0); + bool isLastServer = alwaysTryLast_ ? (i == (numServers - 1)) : false; + + // Impersonate the server socket + setCurrentServer(server); + + if (isOpen()) { + // already open means we're done + return; + } + + if (server->lastFailTime_ > 0) { + // The server was marked as down, so check if enough time has elapsed to retry + int elapsedTime = time(NULL) - server->lastFailTime_; + if (elapsedTime > retryInterval_) { + retryIntervalPassed = true; + } + } + + if (retryIntervalPassed || isLastServer) { + for (int j = 0; j < numRetries_; ++j) { + try { + TSocket::open(); + + // Copy over the opened socket so that we can keep it persistent + server->socket_ = socket_; + + // reset lastFailTime_ is required + if (server->lastFailTime_) { + server->lastFailTime_ = 0; + } + + // success + return; + } catch (TException e) { + string errStr = "TSocketPool::open failed "+getSocketInfo()+": "+e.what(); + GlobalOutput(errStr.c_str()); + // connection failed + } + } + + ++server->consecutiveFailures_; + if (server->consecutiveFailures_ > maxConsecutiveFailures_) { + // Mark server as down + server->consecutiveFailures_ = 0; + server->lastFailTime_ = time(NULL); + } + } + } + + GlobalOutput("TSocketPool::open: all connections failed"); + throw TTransportException(TTransportException::NOT_OPEN); +} + +void TSocketPool::close() { + if (isOpen()) { + TSocket::close(); + currentServer_->socket_ = -1; + } +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TSocketPool.h b/lib/cpp/src/transport/TSocketPool.h new file mode 100644 index 00000000..8c506695 --- /dev/null +++ b/lib/cpp/src/transport/TSocketPool.h @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TSOCKETPOOL_H_ +#define _THRIFT_TRANSPORT_TSOCKETPOOL_H_ 1 + +#include +#include "TSocket.h" + +namespace apache { namespace thrift { namespace transport { + + /** + * Class to hold server information for TSocketPool + * + */ +class TSocketPoolServer { + + public: + /** + * Default constructor for server info + */ + TSocketPoolServer(); + + /** + * Constructor for TSocketPool server + */ + TSocketPoolServer(const std::string &host, int port); + + // Host name + std::string host_; + + // Port to connect on + int port_; + + // Socket for the server + int socket_; + + // Last time connecting to this server failed + int lastFailTime_; + + // Number of consecutive times connecting to this server failed + int consecutiveFailures_; +}; + +/** + * TCP Socket implementation of the TTransport interface. + * + */ +class TSocketPool : public TSocket { + + public: + + /** + * Socket pool constructor + */ + TSocketPool(); + + /** + * Socket pool constructor + * + * @param hosts list of host names + * @param ports list of port names + */ + TSocketPool(const std::vector &hosts, + const std::vector &ports); + + /** + * Socket pool constructor + * + * @param servers list of pairs of host name and port + */ + TSocketPool(const std::vector >& servers); + + /** + * Socket pool constructor + * + * @param servers list of TSocketPoolServers + */ + TSocketPool(const std::vector< boost::shared_ptr >& servers); + + /** + * Socket pool constructor + * + * @param host single host + * @param port single port + */ + TSocketPool(const std::string& host, int port); + + /** + * Destroyes the socket object, closing it if necessary. + */ + virtual ~TSocketPool(); + + /** + * Add a server to the pool + */ + void addServer(const std::string& host, int port); + + /** + * Set list of servers in this pool + */ + void setServers(const std::vector< boost::shared_ptr >& servers); + + /** + * Get list of servers in this pool + */ + void getServers(std::vector< boost::shared_ptr >& servers); + + /** + * Sets how many times to keep retrying a host in the connect function. + */ + void setNumRetries(int numRetries); + + /** + * Sets how long to wait until retrying a host if it was marked down + */ + void setRetryInterval(int retryInterval); + + /** + * Sets how many times to keep retrying a host before marking it as down. + */ + void setMaxConsecutiveFailures(int maxConsecutiveFailures); + + /** + * Turns randomization in connect order on or off. + */ + void setRandomize(bool randomize); + + /** + * Whether to always try the last server. + */ + void setAlwaysTryLast(bool alwaysTryLast); + + /** + * Creates and opens the UNIX socket. + */ + void open(); + + /* + * Closes the UNIX socket + */ + void close(); + + protected: + + void setCurrentServer(const boost::shared_ptr &server); + + /** List of servers to connect to */ + std::vector< boost::shared_ptr > servers_; + + /** Current server */ + boost::shared_ptr currentServer_; + + /** How many times to retry each host in connect */ + int numRetries_; + + /** Retry interval in seconds, how long to not try a host if it has been + * marked as down. + */ + int retryInterval_; + + /** Max consecutive failures before marking a host down. */ + int maxConsecutiveFailures_; + + /** Try hosts in order? or Randomized? */ + bool randomize_; + + /** Always try last host, even if marked down? */ + bool alwaysTryLast_; +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TSOCKETPOOL_H_ + diff --git a/lib/cpp/src/transport/TTransport.h b/lib/cpp/src/transport/TTransport.h new file mode 100644 index 00000000..eb0d5df8 --- /dev/null +++ b/lib/cpp/src/transport/TTransport.h @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TTRANSPORT_H_ +#define _THRIFT_TRANSPORT_TTRANSPORT_H_ 1 + +#include +#include +#include +#include + +namespace apache { namespace thrift { namespace transport { + +/** + * Generic interface for a method of transporting data. A TTransport may be + * capable of either reading or writing, but not necessarily both. + * + */ +class TTransport { + public: + /** + * Virtual deconstructor. + */ + virtual ~TTransport() {} + + /** + * Whether this transport is open. + */ + virtual bool isOpen() { + return false; + } + + /** + * Tests whether there is more data to read or if the remote side is + * still open. By default this is true whenever the transport is open, + * but implementations should add logic to test for this condition where + * possible (i.e. on a socket). + * This is used by a server to check if it should listen for another + * request. + */ + virtual bool peek() { + return isOpen(); + } + + /** + * Opens the transport for communications. + * + * @return bool Whether the transport was successfully opened + * @throws TTransportException if opening failed + */ + virtual void open() { + throw TTransportException(TTransportException::NOT_OPEN, "Cannot open base TTransport."); + } + + /** + * Closes the transport. + */ + virtual void close() { + throw TTransportException(TTransportException::NOT_OPEN, "Cannot close base TTransport."); + } + + /** + * Attempt to read up to the specified number of bytes into the string. + * + * @param buf Reference to the location to write the data + * @param len How many bytes to read + * @return How many bytes were actually read + * @throws TTransportException If an error occurs + */ + virtual uint32_t read(uint8_t* /* buf */, uint32_t /* len */) { + throw TTransportException(TTransportException::NOT_OPEN, "Base TTransport cannot read."); + } + + /** + * Reads the given amount of data in its entirety no matter what. + * + * @param s Reference to location for read data + * @param len How many bytes to read + * @return How many bytes read, which must be equal to size + * @throws TTransportException If insufficient data was read + */ + virtual uint32_t readAll(uint8_t* buf, uint32_t len) { + uint32_t have = 0; + uint32_t get = 0; + + while (have < len) { + get = read(buf+have, len-have); + if (get <= 0) { + throw TTransportException("No more data to read."); + } + have += get; + } + + return have; + } + + /** + * Called when read is completed. + * This can be over-ridden to perform a transport-specific action + * e.g. logging the request to a file + * + */ + virtual void readEnd() { + // default behaviour is to do nothing + return; + } + + /** + * Writes the string in its entirety to the buffer. + * + * @param buf The data to write out + * @throws TTransportException if an error occurs + */ + virtual void write(const uint8_t* /* buf */, uint32_t /* len */) { + throw TTransportException(TTransportException::NOT_OPEN, "Base TTransport cannot write."); + } + + /** + * Called when write is completed. + * This can be over-ridden to perform a transport-specific action + * at the end of a request. + * + */ + virtual void writeEnd() { + // default behaviour is to do nothing + return; + } + + /** + * Flushes any pending data to be written. Typically used with buffered + * transport mechanisms. + * + * @throws TTransportException if an error occurs + */ + virtual void flush() {} + + /** + * Attempts to return a pointer to \c len bytes, possibly copied into \c buf. + * Does not consume the bytes read (i.e.: a later read will return the same + * data). This method is meant to support protocols that need to read + * variable-length fields. They can attempt to borrow the maximum amount of + * data that they will need, then consume (see next method) what they + * actually use. Some transports will not support this method and others + * will fail occasionally, so protocols must be prepared to use read if + * borrow fails. + * + * @oaram buf A buffer where the data can be stored if needed. + * If borrow doesn't return buf, then the contents of + * buf after the call are undefined. + * @param len *len should initially contain the number of bytes to borrow. + * If borrow succeeds, *len will contain the number of bytes + * available in the returned pointer. This will be at least + * what was requested, but may be more if borrow returns + * a pointer to an internal buffer, rather than buf. + * If borrow fails, the contents of *len are undefined. + * @return If the borrow succeeds, return a pointer to the borrowed data. + * This might be equal to \c buf, or it might be a pointer into + * the transport's internal buffers. + * @throws TTransportException if an error occurs + */ + virtual const uint8_t* borrow(uint8_t* /* buf */, uint32_t* /* len */) { + return NULL; + } + + /** + * Remove len bytes from the transport. This should always follow a borrow + * of at least len bytes, and should always succeed. + * TODO(dreiss): Is there any transport that could borrow but fail to + * consume, or that would require a buffer to dump the consumed data? + * + * @param len How many bytes to consume + * @throws TTransportException If an error occurs + */ + virtual void consume(uint32_t /* len */) { + throw TTransportException(TTransportException::NOT_OPEN, "Base TTransport cannot consume."); + } + + protected: + /** + * Simple constructor. + */ + TTransport() {} +}; + +/** + * Generic factory class to make an input and output transport out of a + * source transport. Commonly used inside servers to make input and output + * streams out of raw clients. + * + */ +class TTransportFactory { + public: + TTransportFactory() {} + + virtual ~TTransportFactory() {} + + /** + * Default implementation does nothing, just returns the transport given. + */ + virtual boost::shared_ptr getTransport(boost::shared_ptr trans) { + return trans; + } + +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TTRANSPORT_H_ diff --git a/lib/cpp/src/transport/TTransportException.cpp b/lib/cpp/src/transport/TTransportException.cpp new file mode 100644 index 00000000..f0aaedc2 --- /dev/null +++ b/lib/cpp/src/transport/TTransportException.cpp @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +using std::string; +using boost::lexical_cast; + +namespace apache { namespace thrift { namespace transport { + +}}} // apache::thrift::transport + diff --git a/lib/cpp/src/transport/TTransportException.h b/lib/cpp/src/transport/TTransportException.h new file mode 100644 index 00000000..330785ce --- /dev/null +++ b/lib/cpp/src/transport/TTransportException.h @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TTRANSPORTEXCEPTION_H_ +#define _THRIFT_TRANSPORT_TTRANSPORTEXCEPTION_H_ 1 + +#include +#include + +namespace apache { namespace thrift { namespace transport { + +/** + * Class to encapsulate all the possible types of transport errors that may + * occur in various transport systems. This provides a sort of generic + * wrapper around the shitty UNIX E_ error codes that lets a common code + * base of error handling to be used for various types of transports, i.e. + * pipes etc. + * + */ +class TTransportException : public apache::thrift::TException { + public: + /** + * Error codes for the various types of exceptions. + */ + enum TTransportExceptionType + { UNKNOWN = 0 + , NOT_OPEN = 1 + , ALREADY_OPEN = 2 + , TIMED_OUT = 3 + , END_OF_FILE = 4 + , INTERRUPTED = 5 + , BAD_ARGS = 6 + , CORRUPTED_DATA = 7 + , INTERNAL_ERROR = 8 + }; + + TTransportException() : + apache::thrift::TException(), + type_(UNKNOWN) {} + + TTransportException(TTransportExceptionType type) : + apache::thrift::TException(), + type_(type) {} + + TTransportException(const std::string& message) : + apache::thrift::TException(message), + type_(UNKNOWN) {} + + TTransportException(TTransportExceptionType type, const std::string& message) : + apache::thrift::TException(message), + type_(type) {} + + TTransportException(TTransportExceptionType type, + const std::string& message, + int errno_copy) : + apache::thrift::TException(message + ": " + TOutput::strerror_s(errno_copy)), + type_(type) {} + + virtual ~TTransportException() throw() {} + + /** + * Returns an error code that provides information about the type of error + * that has occurred. + * + * @return Error code + */ + TTransportExceptionType getType() const throw() { + return type_; + } + + virtual const char* what() const throw() { + if (message_.empty()) { + switch (type_) { + case UNKNOWN : return "TTransportException: Unknown transport exception"; + case NOT_OPEN : return "TTransportException: Transport not open"; + case ALREADY_OPEN : return "TTransportException: Transport already open"; + case TIMED_OUT : return "TTransportException: Timed out"; + case END_OF_FILE : return "TTransportException: End of file"; + case INTERRUPTED : return "TTransportException: Interrupted"; + case BAD_ARGS : return "TTransportException: Invalid arguments"; + case CORRUPTED_DATA : return "TTransportException: Corrupted Data"; + case INTERNAL_ERROR : return "TTransportException: Internal error"; + default : return "TTransportException: (Invalid exception type)"; + } + } else { + return message_.c_str(); + } + } + + protected: + /** Just like strerror_r but returns a C++ string object. */ + std::string strerror_s(int errno_copy); + + /** Error code */ + TTransportExceptionType type_; + +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TTRANSPORTEXCEPTION_H_ diff --git a/lib/cpp/src/transport/TTransportUtils.cpp b/lib/cpp/src/transport/TTransportUtils.cpp new file mode 100644 index 00000000..a840fa6c --- /dev/null +++ b/lib/cpp/src/transport/TTransportUtils.cpp @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +using std::string; + +namespace apache { namespace thrift { namespace transport { + +uint32_t TPipedTransport::read(uint8_t* buf, uint32_t len) { + uint32_t need = len; + + // We don't have enough data yet + if (rLen_-rPos_ < need) { + // Copy out whatever we have + if (rLen_-rPos_ > 0) { + memcpy(buf, rBuf_+rPos_, rLen_-rPos_); + need -= rLen_-rPos_; + buf += rLen_-rPos_; + rPos_ = rLen_; + } + + // Double the size of the underlying buffer if it is full + if (rLen_ == rBufSize_) { + rBufSize_ *=2; + rBuf_ = (uint8_t *)std::realloc(rBuf_, sizeof(uint8_t) * rBufSize_); + } + + // try to fill up the buffer + rLen_ += srcTrans_->read(rBuf_+rPos_, rBufSize_ - rPos_); + } + + + // Hand over whatever we have + uint32_t give = need; + if (rLen_-rPos_ < give) { + give = rLen_-rPos_; + } + if (give > 0) { + memcpy(buf, rBuf_+rPos_, give); + rPos_ += give; + need -= give; + } + + return (len - need); +} + +void TPipedTransport::write(const uint8_t* buf, uint32_t len) { + if (len == 0) { + return; + } + + // Make the buffer as big as it needs to be + if ((len + wLen_) >= wBufSize_) { + uint32_t newBufSize = wBufSize_*2; + while ((len + wLen_) >= newBufSize) { + newBufSize *= 2; + } + wBuf_ = (uint8_t *)std::realloc(wBuf_, sizeof(uint8_t) * newBufSize); + wBufSize_ = newBufSize; + } + + // Copy into the buffer + memcpy(wBuf_ + wLen_, buf, len); + wLen_ += len; +} + +void TPipedTransport::flush() { + // Write out any data waiting in the write buffer + if (wLen_ > 0) { + srcTrans_->write(wBuf_, wLen_); + wLen_ = 0; + } + + // Flush the underlying transport + srcTrans_->flush(); +} + +TPipedFileReaderTransport::TPipedFileReaderTransport(boost::shared_ptr srcTrans, boost::shared_ptr dstTrans) + : TPipedTransport(srcTrans, dstTrans), + srcTrans_(srcTrans) { +} + +TPipedFileReaderTransport::~TPipedFileReaderTransport() { +} + +bool TPipedFileReaderTransport::isOpen() { + return TPipedTransport::isOpen(); +} + +bool TPipedFileReaderTransport::peek() { + return TPipedTransport::peek(); +} + +void TPipedFileReaderTransport::open() { + TPipedTransport::open(); +} + +void TPipedFileReaderTransport::close() { + TPipedTransport::close(); +} + +uint32_t TPipedFileReaderTransport::read(uint8_t* buf, uint32_t len) { + return TPipedTransport::read(buf, len); +} + +uint32_t TPipedFileReaderTransport::readAll(uint8_t* buf, uint32_t len) { + uint32_t have = 0; + uint32_t get = 0; + + while (have < len) { + get = read(buf+have, len-have); + if (get <= 0) { + throw TEOFException(); + } + have += get; + } + + return have; +} + +void TPipedFileReaderTransport::readEnd() { + TPipedTransport::readEnd(); +} + +void TPipedFileReaderTransport::write(const uint8_t* buf, uint32_t len) { + TPipedTransport::write(buf, len); +} + +void TPipedFileReaderTransport::writeEnd() { + TPipedTransport::writeEnd(); +} + +void TPipedFileReaderTransport::flush() { + TPipedTransport::flush(); +} + +int32_t TPipedFileReaderTransport::getReadTimeout() { + return srcTrans_->getReadTimeout(); +} + +void TPipedFileReaderTransport::setReadTimeout(int32_t readTimeout) { + srcTrans_->setReadTimeout(readTimeout); +} + +uint32_t TPipedFileReaderTransport::getNumChunks() { + return srcTrans_->getNumChunks(); +} + +uint32_t TPipedFileReaderTransport::getCurChunk() { + return srcTrans_->getCurChunk(); +} + +void TPipedFileReaderTransport::seekToChunk(int32_t chunk) { + srcTrans_->seekToChunk(chunk); +} + +void TPipedFileReaderTransport::seekToEnd() { + srcTrans_->seekToEnd(); +} + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TTransportUtils.h b/lib/cpp/src/transport/TTransportUtils.h new file mode 100644 index 00000000..d65c9167 --- /dev/null +++ b/lib/cpp/src/transport/TTransportUtils.h @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TTRANSPORTUTILS_H_ +#define _THRIFT_TRANSPORT_TTRANSPORTUTILS_H_ 1 + +#include +#include +#include +#include +#include +// Include the buffered transports that used to be defined here. +#include +#include + +namespace apache { namespace thrift { namespace transport { + +/** + * The null transport is a dummy transport that doesn't actually do anything. + * It's sort of an analogy to /dev/null, you can never read anything from it + * and it will let you write anything you want to it, though it won't actually + * go anywhere. + * + */ +class TNullTransport : public TTransport { + public: + TNullTransport() {} + + ~TNullTransport() {} + + bool isOpen() { + return true; + } + + void open() {} + + void write(const uint8_t* /* buf */, uint32_t /* len */) { + return; + } + +}; + + +/** + * TPipedTransport. This transport allows piping of a request from one + * transport to another either when readEnd() or writeEnd(). The typical + * use case for this is to log a request or a reply to disk. + * The underlying buffer expands to a keep a copy of the entire + * request/response. + * + */ +class TPipedTransport : virtual public TTransport { + public: + TPipedTransport(boost::shared_ptr srcTrans, + boost::shared_ptr dstTrans) : + srcTrans_(srcTrans), + dstTrans_(dstTrans), + rBufSize_(512), rPos_(0), rLen_(0), + wBufSize_(512), wLen_(0) { + + // default is to to pipe the request when readEnd() is called + pipeOnRead_ = true; + pipeOnWrite_ = false; + + rBuf_ = (uint8_t*) std::malloc(sizeof(uint8_t) * rBufSize_); + wBuf_ = (uint8_t*) std::malloc(sizeof(uint8_t) * wBufSize_); + } + + TPipedTransport(boost::shared_ptr srcTrans, + boost::shared_ptr dstTrans, + uint32_t sz) : + srcTrans_(srcTrans), + dstTrans_(dstTrans), + rBufSize_(512), rPos_(0), rLen_(0), + wBufSize_(sz), wLen_(0) { + + rBuf_ = (uint8_t*) std::malloc(sizeof(uint8_t) * rBufSize_); + wBuf_ = (uint8_t*) std::malloc(sizeof(uint8_t) * wBufSize_); + } + + ~TPipedTransport() { + std::free(rBuf_); + std::free(wBuf_); + } + + bool isOpen() { + return srcTrans_->isOpen(); + } + + bool peek() { + if (rPos_ >= rLen_) { + // Double the size of the underlying buffer if it is full + if (rLen_ == rBufSize_) { + rBufSize_ *=2; + rBuf_ = (uint8_t *)std::realloc(rBuf_, sizeof(uint8_t) * rBufSize_); + } + + // try to fill up the buffer + rLen_ += srcTrans_->read(rBuf_+rPos_, rBufSize_ - rPos_); + } + return (rLen_ > rPos_); + } + + + void open() { + srcTrans_->open(); + } + + void close() { + srcTrans_->close(); + } + + void setPipeOnRead(bool pipeVal) { + pipeOnRead_ = pipeVal; + } + + void setPipeOnWrite(bool pipeVal) { + pipeOnWrite_ = pipeVal; + } + + uint32_t read(uint8_t* buf, uint32_t len); + + void readEnd() { + + if (pipeOnRead_) { + dstTrans_->write(rBuf_, rPos_); + dstTrans_->flush(); + } + + srcTrans_->readEnd(); + + // If requests are being pipelined, copy down our read-ahead data, + // then reset our state. + int read_ahead = rLen_ - rPos_; + memcpy(rBuf_, rBuf_ + rPos_, read_ahead); + rPos_ = 0; + rLen_ = read_ahead; + } + + void write(const uint8_t* buf, uint32_t len); + + void writeEnd() { + if (pipeOnWrite_) { + dstTrans_->write(wBuf_, wLen_); + dstTrans_->flush(); + } + } + + void flush(); + + boost::shared_ptr getTargetTransport() { + return dstTrans_; + } + + protected: + boost::shared_ptr srcTrans_; + boost::shared_ptr dstTrans_; + + uint8_t* rBuf_; + uint32_t rBufSize_; + uint32_t rPos_; + uint32_t rLen_; + + uint8_t* wBuf_; + uint32_t wBufSize_; + uint32_t wLen_; + + bool pipeOnRead_; + bool pipeOnWrite_; +}; + + +/** + * Wraps a transport into a pipedTransport instance. + * + */ +class TPipedTransportFactory : public TTransportFactory { + public: + TPipedTransportFactory() {} + TPipedTransportFactory(boost::shared_ptr dstTrans) { + initializeTargetTransport(dstTrans); + } + virtual ~TPipedTransportFactory() {} + + /** + * Wraps the base transport into a piped transport. + */ + virtual boost::shared_ptr getTransport(boost::shared_ptr srcTrans) { + return boost::shared_ptr(new TPipedTransport(srcTrans, dstTrans_)); + } + + virtual void initializeTargetTransport(boost::shared_ptr dstTrans) { + if (dstTrans_.get() == NULL) { + dstTrans_ = dstTrans; + } else { + throw TException("Target transport already initialized"); + } + } + + protected: + boost::shared_ptr dstTrans_; +}; + +/** + * TPipedFileTransport. This is just like a TTransport, except that + * it is a templatized class, so that clients who rely on a specific + * TTransport can still access the original transport. + * + */ +class TPipedFileReaderTransport : public TPipedTransport, + public TFileReaderTransport { + public: + TPipedFileReaderTransport(boost::shared_ptr srcTrans, boost::shared_ptr dstTrans); + + ~TPipedFileReaderTransport(); + + // TTransport functions + bool isOpen(); + bool peek(); + void open(); + void close(); + uint32_t read(uint8_t* buf, uint32_t len); + uint32_t readAll(uint8_t* buf, uint32_t len); + void readEnd(); + void write(const uint8_t* buf, uint32_t len); + void writeEnd(); + void flush(); + + // TFileReaderTransport functions + int32_t getReadTimeout(); + void setReadTimeout(int32_t readTimeout); + uint32_t getNumChunks(); + uint32_t getCurChunk(); + void seekToChunk(int32_t chunk); + void seekToEnd(); + + protected: + // shouldn't be used + TPipedFileReaderTransport(); + boost::shared_ptr srcTrans_; +}; + +/** + * Creates a TPipedFileReaderTransport from a filepath and a destination transport + * + */ +class TPipedFileReaderTransportFactory : public TPipedTransportFactory { + public: + TPipedFileReaderTransportFactory() {} + TPipedFileReaderTransportFactory(boost::shared_ptr dstTrans) + : TPipedTransportFactory(dstTrans) + {} + virtual ~TPipedFileReaderTransportFactory() {} + + boost::shared_ptr getTransport(boost::shared_ptr srcTrans) { + boost::shared_ptr pFileReaderTransport = boost::dynamic_pointer_cast(srcTrans); + if (pFileReaderTransport.get() != NULL) { + return getFileReaderTransport(pFileReaderTransport); + } else { + return boost::shared_ptr(); + } + } + + boost::shared_ptr getFileReaderTransport(boost::shared_ptr srcTrans) { + return boost::shared_ptr(new TPipedFileReaderTransport(srcTrans, dstTrans_)); + } +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TTRANSPORTUTILS_H_ diff --git a/lib/cpp/src/transport/TZlibTransport.cpp b/lib/cpp/src/transport/TZlibTransport.cpp new file mode 100644 index 00000000..2f14e906 --- /dev/null +++ b/lib/cpp/src/transport/TZlibTransport.cpp @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +using std::string; + +namespace apache { namespace thrift { namespace transport { + +// Don't call this outside of the constructor. +void TZlibTransport::initZlib() { + int rv; + bool r_init = false; + try { + rstream_ = new z_stream; + wstream_ = new z_stream; + + rstream_->zalloc = Z_NULL; + wstream_->zalloc = Z_NULL; + rstream_->zfree = Z_NULL; + wstream_->zfree = Z_NULL; + rstream_->opaque = Z_NULL; + wstream_->opaque = Z_NULL; + + rstream_->next_in = crbuf_; + wstream_->next_in = uwbuf_; + rstream_->next_out = urbuf_; + wstream_->next_out = cwbuf_; + rstream_->avail_in = 0; + wstream_->avail_in = 0; + rstream_->avail_out = urbuf_size_; + wstream_->avail_out = cwbuf_size_; + + rv = inflateInit(rstream_); + checkZlibRv(rv, rstream_->msg); + + // Have to set this flag so we know whether to de-initialize. + r_init = true; + + rv = deflateInit(wstream_, Z_DEFAULT_COMPRESSION); + checkZlibRv(rv, wstream_->msg); + } + + catch (...) { + if (r_init) { + rv = inflateEnd(rstream_); + checkZlibRvNothrow(rv, rstream_->msg); + } + // There is no way we can get here if wstream_ was initialized. + + throw; + } +} + +inline void TZlibTransport::checkZlibRv(int status, const char* message) { + if (status != Z_OK) { + throw TZlibTransportException(status, message); + } +} + +inline void TZlibTransport::checkZlibRvNothrow(int status, const char* message) { + if (status != Z_OK) { + string output = "TZlibTransport: zlib failure in destructor: " + + TZlibTransportException::errorMessage(status, message); + GlobalOutput(output.c_str()); + } +} + +TZlibTransport::~TZlibTransport() { + int rv; + rv = inflateEnd(rstream_); + checkZlibRvNothrow(rv, rstream_->msg); + rv = deflateEnd(wstream_); + checkZlibRvNothrow(rv, wstream_->msg); + + delete[] urbuf_; + delete[] crbuf_; + delete[] uwbuf_; + delete[] cwbuf_; + delete rstream_; + delete wstream_; +} + +bool TZlibTransport::isOpen() { + return (readAvail() > 0) || transport_->isOpen(); +} + +// READING STRATEGY +// +// We have two buffers for reading: one containing the compressed data (crbuf_) +// and one containing the uncompressed data (urbuf_). When read is called, +// we repeat the following steps until we have satisfied the request: +// - Copy data from urbuf_ into the caller's buffer. +// - If we had enough, return. +// - If urbuf_ is empty, read some data into it from the underlying transport. +// - Inflate data from crbuf_ into urbuf_. +// +// In standalone objects, we set input_ended_ to true when inflate returns +// Z_STREAM_END. This allows to make sure that a checksum was verified. + +inline int TZlibTransport::readAvail() { + return urbuf_size_ - rstream_->avail_out - urpos_; +} + +uint32_t TZlibTransport::read(uint8_t* buf, uint32_t len) { + int need = len; + + // TODO(dreiss): Skip urbuf on big reads. + + while (true) { + // Copy out whatever we have available, then give them the min of + // what we have and what they want, then advance indices. + int give = std::min(readAvail(), need); + memcpy(buf, urbuf_ + urpos_, give); + need -= give; + buf += give; + urpos_ += give; + + // If they were satisfied, we are done. + if (need == 0) { + return len; + } + + // If we get to this point, we need to get some more data. + + // If zlib has reported the end of a stream, we can't really do any more. + if (input_ended_) { + return len - need; + } + + // The uncompressed read buffer is empty, so reset the stream fields. + rstream_->next_out = urbuf_; + rstream_->avail_out = urbuf_size_; + urpos_ = 0; + + // If we don't have any more compressed data available, + // read some from the underlying transport. + if (rstream_->avail_in == 0) { + uint32_t got = transport_->read(crbuf_, crbuf_size_); + if (got == 0) { + return len - need; + } + rstream_->next_in = crbuf_; + rstream_->avail_in = got; + } + + // We have some compressed data now. Uncompress it. + int zlib_rv = inflate(rstream_, Z_SYNC_FLUSH); + + if (zlib_rv == Z_STREAM_END) { + if (standalone_) { + input_ended_ = true; + } + } else { + checkZlibRv(zlib_rv, rstream_->msg); + } + + // Okay. The read buffer should have whatever we can give it now. + // Loop back to the start and try to give some more. + } +} + + +// WRITING STRATEGY +// +// We buffer up small writes before sending them to zlib, so our logic is: +// - Is the write big? +// - Send the buffer to zlib. +// - Send this data to zlib. +// - Is the write small? +// - Is there insufficient space in the buffer for it? +// - Send the buffer to zlib. +// - Copy the data to the buffer. +// +// We have two buffers for writing also: the uncompressed buffer (mentioned +// above) and the compressed buffer. When sending data to zlib we loop over +// the following until the source (uncompressed buffer or big write) is empty: +// - Is there no more space in the compressed buffer? +// - Write the compressed buffer to the underlying transport. +// - Deflate from the source into the compressed buffer. + +void TZlibTransport::write(const uint8_t* buf, uint32_t len) { + // zlib's "deflate" function has enough logic in it that I think + // we're better off (performance-wise) buffering up small writes. + if ((int)len > MIN_DIRECT_DEFLATE_SIZE) { + flushToZlib(uwbuf_, uwpos_); + uwpos_ = 0; + flushToZlib(buf, len); + } else if (len > 0) { + if (uwbuf_size_ - uwpos_ < (int)len) { + flushToZlib(uwbuf_, uwpos_); + uwpos_ = 0; + } + memcpy(uwbuf_ + uwpos_, buf, len); + uwpos_ += len; + } +} + +void TZlibTransport::flush() { + flushToZlib(uwbuf_, uwpos_, true); + assert((int)wstream_->avail_out != cwbuf_size_); + transport_->write(cwbuf_, cwbuf_size_ - wstream_->avail_out); + transport_->flush(); +} + +void TZlibTransport::flushToZlib(const uint8_t* buf, int len, bool finish) { + int flush = (finish ? Z_FINISH : Z_NO_FLUSH); + + wstream_->next_in = const_cast(buf); + wstream_->avail_in = len; + + while (wstream_->avail_in > 0 || finish) { + // If our ouput buffer is full, flush to the underlying transport. + if (wstream_->avail_out == 0) { + transport_->write(cwbuf_, cwbuf_size_); + wstream_->next_out = cwbuf_; + wstream_->avail_out = cwbuf_size_; + } + + int zlib_rv = deflate(wstream_, flush); + + if (finish && zlib_rv == Z_STREAM_END) { + assert(wstream_->avail_in == 0); + break; + } + + checkZlibRv(zlib_rv, wstream_->msg); + } +} + +const uint8_t* TZlibTransport::borrow(uint8_t* buf, uint32_t* len) { + // Don't try to be clever with shifting buffers. + // If we have enough data, give a pointer to it, + // otherwise let the protcol use its slow path. + if (readAvail() >= (int)*len) { + *len = (uint32_t)readAvail(); + return urbuf_ + urpos_; + } + return NULL; +} + +void TZlibTransport::consume(uint32_t len) { + if (readAvail() >= (int)len) { + urpos_ += len; + } else { + throw TTransportException(TTransportException::BAD_ARGS, + "consume did not follow a borrow."); + } +} + +void TZlibTransport::verifyChecksum() { + if (!standalone_) { + throw TTransportException( + TTransportException::BAD_ARGS, + "TZLibTransport can only verify checksums for standalone objects."); + } + + if (!input_ended_) { + // This should only be called when reading is complete, + // but it's possible that the whole checksum has not been fed to zlib yet. + // We try to read an extra byte here to force zlib to finish the stream. + // It might not always be easy to "unread" this byte, + // but we throw an exception if we get it, which is not really + // a recoverable error, so it doesn't matter. + uint8_t buf[1]; + uint32_t got = this->read(buf, sizeof(buf)); + if (got || !input_ended_) { + throw TTransportException( + TTransportException::CORRUPTED_DATA, + "Zlib stream not complete."); + } + } + + // If the checksum had been bad, we would have gotten an error while + // inflating. +} + + +}}} // apache::thrift::transport diff --git a/lib/cpp/src/transport/TZlibTransport.h b/lib/cpp/src/transport/TZlibTransport.h new file mode 100644 index 00000000..1439d9de --- /dev/null +++ b/lib/cpp/src/transport/TZlibTransport.h @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ +#define _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ 1 + +#include +#include + +struct z_stream_s; + +namespace apache { namespace thrift { namespace transport { + +class TZlibTransportException : public TTransportException { + public: + TZlibTransportException(int status, const char* msg) : + TTransportException(TTransportException::INTERNAL_ERROR, + errorMessage(status, msg)), + zlib_status_(status), + zlib_msg_(msg == NULL ? "(null)" : msg) {} + + virtual ~TZlibTransportException() throw() {} + + int getZlibStatus() { return zlib_status_; } + std::string getZlibMessage() { return zlib_msg_; } + + static std::string errorMessage(int status, const char* msg) { + std::string rv = "zlib error: "; + if (msg) { + rv += msg; + } else { + rv += "(no message)"; + } + rv += " (status = "; + rv += boost::lexical_cast(status); + rv += ")"; + return rv; + } + + int zlib_status_; + std::string zlib_msg_; +}; + +/** + * This transport uses zlib's compressed format on the "far" side. + * + * There are two kinds of TZlibTransport objects: + * - Standalone objects are used to encode self-contained chunks of data + * (like structures). They include checksums. + * - Non-standalone transports are used for RPC. They are not implemented yet. + * + * TODO(dreiss): Don't do an extra copy of the compressed data if + * the underlying transport is TBuffered or TMemory. + * + */ +class TZlibTransport : public TTransport { + public: + + /** + * @param transport The transport to read compressed data from + * and write compressed data to. + * @param use_for_rpc True if this object will be used for RPC, + * false if this is a standalone object. + * @param urbuf_size Uncompressed buffer size for reading. + * @param crbuf_size Compressed buffer size for reading. + * @param uwbuf_size Uncompressed buffer size for writing. + * @param cwbuf_size Compressed buffer size for writing. + * + * TODO(dreiss): Write a constructor that isn't a pain. + */ + TZlibTransport(boost::shared_ptr transport, + bool use_for_rpc, + int urbuf_size = DEFAULT_URBUF_SIZE, + int crbuf_size = DEFAULT_CRBUF_SIZE, + int uwbuf_size = DEFAULT_UWBUF_SIZE, + int cwbuf_size = DEFAULT_CWBUF_SIZE) : + transport_(transport), + standalone_(!use_for_rpc), + urpos_(0), + uwpos_(0), + input_ended_(false), + output_flushed_(false), + urbuf_size_(urbuf_size), + crbuf_size_(crbuf_size), + uwbuf_size_(uwbuf_size), + cwbuf_size_(cwbuf_size), + urbuf_(NULL), + crbuf_(NULL), + uwbuf_(NULL), + cwbuf_(NULL), + rstream_(NULL), + wstream_(NULL) + { + + if (!standalone_) { + throw TTransportException( + TTransportException::BAD_ARGS, + "TZLibTransport has not been tested for RPC."); + } + + if (uwbuf_size_ < MIN_DIRECT_DEFLATE_SIZE) { + // Have to copy this into a local because of a linking issue. + int minimum = MIN_DIRECT_DEFLATE_SIZE; + throw TTransportException( + TTransportException::BAD_ARGS, + "TZLibTransport: uncompressed write buffer must be at least" + + boost::lexical_cast(minimum) + "."); + } + + try { + urbuf_ = new uint8_t[urbuf_size]; + crbuf_ = new uint8_t[crbuf_size]; + uwbuf_ = new uint8_t[uwbuf_size]; + cwbuf_ = new uint8_t[cwbuf_size]; + + // Don't call this outside of the constructor. + initZlib(); + + } catch (...) { + delete[] urbuf_; + delete[] crbuf_; + delete[] uwbuf_; + delete[] cwbuf_; + throw; + } + } + + // Don't call this outside of the constructor. + void initZlib(); + + ~TZlibTransport(); + + bool isOpen(); + + void open() { + transport_->open(); + } + + void close() { + transport_->close(); + } + + uint32_t read(uint8_t* buf, uint32_t len); + + void write(const uint8_t* buf, uint32_t len); + + void flush(); + + const uint8_t* borrow(uint8_t* buf, uint32_t* len); + + void consume(uint32_t len); + + void verifyChecksum(); + + /** + * TODO(someone_smart): Choose smart defaults. + */ + static const int DEFAULT_URBUF_SIZE = 128; + static const int DEFAULT_CRBUF_SIZE = 1024; + static const int DEFAULT_UWBUF_SIZE = 128; + static const int DEFAULT_CWBUF_SIZE = 1024; + + protected: + + inline void checkZlibRv(int status, const char* msg); + inline void checkZlibRvNothrow(int status, const char* msg); + inline int readAvail(); + void flushToZlib(const uint8_t* buf, int len, bool finish = false); + + // Writes smaller than this are buffered up. + // Larger (or equal) writes are dumped straight to zlib. + static const int MIN_DIRECT_DEFLATE_SIZE = 32; + + boost::shared_ptr transport_; + bool standalone_; + + int urpos_; + int uwpos_; + + /// True iff zlib has reached the end of a stream. + /// This is only ever true in standalone protcol objects. + bool input_ended_; + /// True iff we have flushed the output stream. + /// This is only ever true in standalone protcol objects. + bool output_flushed_; + + int urbuf_size_; + int crbuf_size_; + int uwbuf_size_; + int cwbuf_size_; + + uint8_t* urbuf_; + uint8_t* crbuf_; + uint8_t* uwbuf_; + uint8_t* cwbuf_; + + struct z_stream_s* rstream_; + struct z_stream_s* wstream_; +}; + +}}} // apache::thrift::transport + +#endif // #ifndef _THRIFT_TRANSPORT_TZLIBTRANSPORT_H_ diff --git a/lib/cpp/thrift-nb.pc.in b/lib/cpp/thrift-nb.pc.in new file mode 100644 index 00000000..ae051887 --- /dev/null +++ b/lib/cpp/thrift-nb.pc.in @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +prefix=@prefix@ +exec_prefix=@exec_prefix@ +libdir=@libdir@ +includedir=@includedir@ + +Name: Thrift +Description: Thrift Nonblocking API +Version: @VERSION@ +Requires: thrift = @VERSION@ +Libs: -L${libdir} -lthriftnb +Cflags: -I${includedir}/thrift diff --git a/lib/cpp/thrift-z.pc.in b/lib/cpp/thrift-z.pc.in new file mode 100644 index 00000000..72f46bf9 --- /dev/null +++ b/lib/cpp/thrift-z.pc.in @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +prefix=@prefix@ +exec_prefix=@exec_prefix@ +libdir=@libdir@ +includedir=@includedir@ + +Name: Thrift +Description: Thrift Zlib API +Version: @VERSION@ +Requires: thrift = @VERSION@ +Libs: -L${libdir} -lthriftz +Cflags: -I${includedir}/thrift diff --git a/lib/cpp/thrift.pc.in b/lib/cpp/thrift.pc.in new file mode 100644 index 00000000..7aec09f1 --- /dev/null +++ b/lib/cpp/thrift.pc.in @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +prefix=@prefix@ +exec_prefix=@exec_prefix@ +libdir=@libdir@ +includedir=@includedir@ + +Name: Thrift +Description: Thrift C++ API +Version: @VERSION@ +Libs: -L${libdir} -lthrift +Cflags: -I${includedir}/thrift diff --git a/lib/csharp/Makefile.am b/lib/csharp/Makefile.am new file mode 100644 index 00000000..4047011c --- /dev/null +++ b/lib/csharp/Makefile.am @@ -0,0 +1,70 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +THRIFTCODE= \ + src/Collections/THashSet.cs \ + src/Protocol/TBase.cs \ + src/Protocol/TProtocolException.cs \ + src/Protocol/TProtocolFactory.cs \ + src/Protocol/TList.cs \ + src/Protocol/TSet.cs \ + src/Protocol/TMap.cs \ + src/Protocol/TProtocolUtil.cs \ + src/Protocol/TMessageType.cs \ + src/Protocol/TProtocol.cs \ + src/Protocol/TType.cs \ + src/Protocol/TField.cs \ + src/Protocol/TMessage.cs \ + src/Protocol/TStruct.cs \ + src/Protocol/TBinaryProtocol.cs \ + src/Server/TThreadedServer.cs \ + src/Server/TThreadPoolServer.cs \ + src/Server/TSimpleServer.cs \ + src/Server/TServer.cs \ + src/Transport/TBufferedTransport.cs \ + src/Transport/TTransport.cs \ + src/Transport/TSocket.cs \ + src/Transport/TTransportException.cs \ + src/Transport/TStreamTransport.cs \ + src/Transport/TServerTransport.cs \ + src/Transport/TServerSocket.cs \ + src/Transport/TTransportFactory.cs \ + src/TProcessor.cs \ + src/TApplicationException.cs + + +CSC=gmcs + +if NET_2_0 +MONO_DEFINES=/d:NET_2_0 +endif + +all-local: Thrift.dll + +Thrift.dll: $(THRIFTCODE) + $(CSC) $(THRIFTCODE) /out:Thrift.dll /target:library $(MONO_DEFINES) + +clean-local: + $(RM) Thrift.dll + +EXTRA_DIST = \ + $(THRIFTCODE) \ + ThriftMSBuildTask \ + src/Thrift.csproj \ + src/Thrift.sln diff --git a/lib/csharp/README b/lib/csharp/README new file mode 100644 index 00000000..b7dc5de3 --- /dev/null +++ b/lib/csharp/README @@ -0,0 +1,26 @@ +Thrift C# Software Library + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Using Thrift with C# +==================== + +Thrift requires Mono >= 1.2.6 or .NET framework >= 3.5 diff --git a/lib/csharp/ThriftMSBuildTask/Properties/AssemblyInfo.cs b/lib/csharp/ThriftMSBuildTask/Properties/AssemblyInfo.cs new file mode 100644 index 00000000..d79c2039 --- /dev/null +++ b/lib/csharp/ThriftMSBuildTask/Properties/AssemblyInfo.cs @@ -0,0 +1,55 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyTitle("ThriftMSBuildTask")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("ThriftMSBuildTask")] +[assembly: AssemblyCopyright("Copyright © 2009 The Apache Software Foundation")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("5095e09d-7b95-4be1-b250-e1c1db1c485e")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("1.0.*")] +[assembly: AssemblyVersion("1.0.0.0")] +[assembly: AssemblyFileVersion("1.0.0.0")] diff --git a/lib/csharp/ThriftMSBuildTask/ThriftBuild.cs b/lib/csharp/ThriftMSBuildTask/ThriftBuild.cs new file mode 100644 index 00000000..4389e0a6 --- /dev/null +++ b/lib/csharp/ThriftMSBuildTask/ThriftBuild.cs @@ -0,0 +1,242 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Microsoft.Build.Framework; +using Microsoft.Build.Utilities; +using Microsoft.Build.Tasks; +using System.IO; +using System.Diagnostics; + +namespace ThriftMSBuildTask +{ + /// + /// MSBuild Task to generate csharp from .thrift files, and compile the code into a library: ThriftImpl.dll + /// + public class ThriftBuild : Task + { + /// + /// The full path to the thrift.exe compiler + /// + [Required] + public ITaskItem ThriftExecutable + { + get; + set; + } + + /// + /// The full path to a thrift.dll C# library + /// + [Required] + public ITaskItem ThriftLibrary + { + get; + set; + } + + /// + /// A direcotry containing .thrift files + /// + [Required] + public ITaskItem ThriftDefinitionDir + { + get; + set; + } + + /// + /// The name of the auto-gen and compiled thrift library. It will placed in + /// the same directory as ThriftLibrary + /// + [Required] + public ITaskItem OutputName + { + get; + set; + } + + /// + /// The full path to the compiled ThriftLibrary. This allows msbuild tasks to use this + /// output as a variable for use elsewhere. + /// + [Output] + public ITaskItem ThriftImplementation + { + get { return thriftImpl; } + } + + private ITaskItem thriftImpl; + private const string lastCompilationName = "LAST_COMP_TIMESTAMP"; + + //use the Message Build Task to write something to build log + private void LogMessage(string text, MessageImportance importance) + { + Message m = new Message(); + m.Text = text; + m.Importance = importance.ToString(); + m.BuildEngine = this.BuildEngine; + m.Execute(); + } + + //recursively find .cs files in srcDir, paths should initially be non-null and empty + private void FindSourcesHelper(string srcDir, List paths) + { + string[] files = Directory.GetFiles(srcDir, "*.cs"); + foreach (string f in files) + { + paths.Add(f); + } + string[] dirs = Directory.GetDirectories(srcDir); + foreach (string dir in dirs) + { + FindSourcesHelper(dir, paths); + } + } + + /// + /// Quote paths with spaces + /// + private string SafePath(string path) + { + if (path.Contains(' ') && !path.StartsWith("\"")) + { + return "\"" + path + "\""; + } + return path; + } + + private ITaskItem[] FindSources(string srcDir) + { + List files = new List(); + FindSourcesHelper(srcDir, files); + ITaskItem[] items = new ITaskItem[files.Count]; + for (int i = 0; i < items.Length; i++) + { + items[i] = new TaskItem(files[i]); + } + return items; + } + + private string LastWriteTime(string defDir) + { + string[] files = Directory.GetFiles(defDir, "*.thrift"); + DateTime d = (new DirectoryInfo(defDir)).LastWriteTime; + foreach(string file in files) + { + FileInfo f = new FileInfo(file); + DateTime curr = f.LastWriteTime; + if (DateTime.Compare(curr, d) > 0) + { + d = curr; + } + } + return d.ToFileTimeUtc().ToString(); + } + + public override bool Execute() + { + string defDir = SafePath(ThriftDefinitionDir.ItemSpec); + //look for last compilation timestamp + string lastBuildPath = Path.Combine(defDir, lastCompilationName); + DirectoryInfo defDirInfo = new DirectoryInfo(defDir); + string lastWrite = LastWriteTime(defDir); + if (File.Exists(lastBuildPath)) + { + string lastComp = File.ReadAllText(lastBuildPath); + //don't recompile if the thrift library has been updated since lastComp + FileInfo f = new FileInfo(ThriftLibrary.ItemSpec); + string thriftLibTime = f.LastWriteTimeUtc.ToFileTimeUtc().ToString(); + if (lastComp.CompareTo(thriftLibTime) < 0) + { + //new thrift library, do a compile + lastWrite = thriftLibTime; + } + else if (lastComp == lastWrite || (lastComp == thriftLibTime && lastComp.CompareTo(lastWrite) > 0)) + { + //the .thrift dir hasn't been written to since last compilation, don't need to do anything + LogMessage("ThriftImpl up-to-date", MessageImportance.High); + return true; + } + } + + //find the directory of the thriftlibrary (that's where output will go) + FileInfo thriftLibInfo = new FileInfo(SafePath(ThriftLibrary.ItemSpec)); + string thriftDir = thriftLibInfo.Directory.FullName; + + string genDir = Path.Combine(thriftDir, "gen-csharp"); + if (Directory.Exists(genDir)) + { + try + { + Directory.Delete(genDir, true); + } + catch { /*eh i tried, just over-write now*/} + } + + //run the thrift executable to generate C# + foreach (string thriftFile in Directory.GetFiles(defDir, "*.thrift")) + { + LogMessage("Generating code for: " + thriftFile, MessageImportance.Normal); + Process p = new Process(); + p.StartInfo.FileName = SafePath(ThriftExecutable.ItemSpec); + p.StartInfo.Arguments = "--gen csharp -o " + SafePath(thriftDir) + " -r " + thriftFile; + p.StartInfo.UseShellExecute = false; + p.StartInfo.CreateNoWindow = true; + p.StartInfo.RedirectStandardOutput = false; + p.Start(); + p.WaitForExit(); + if (p.ExitCode != 0) + { + LogMessage("thrift.exe failed to compile " + thriftFile, MessageImportance.High); + return false; + } + if (p.ExitCode != 0) + { + LogMessage("thrift.exe failed to compile " + thriftFile, MessageImportance.High); + return false; + } + } + + Csc csc = new Csc(); + csc.TargetType = "library"; + csc.References = new ITaskItem[] { new TaskItem(ThriftLibrary.ItemSpec) }; + csc.EmitDebugInformation = true; + string outputPath = Path.Combine(thriftDir, OutputName.ItemSpec); + csc.OutputAssembly = new TaskItem(outputPath); + csc.Sources = FindSources(Path.Combine(thriftDir, "gen-csharp")); + csc.BuildEngine = this.BuildEngine; + LogMessage("Compiling generated cs...", MessageImportance.Normal); + if (!csc.Execute()) + { + return false; + } + + //write file to defDir to indicate a build was successfully completed + File.WriteAllText(lastBuildPath, lastWrite); + + thriftImpl = new TaskItem(outputPath); + + return true; + } + } +} diff --git a/lib/csharp/ThriftMSBuildTask/ThriftMSBuildTask.csproj b/lib/csharp/ThriftMSBuildTask/ThriftMSBuildTask.csproj new file mode 100644 index 00000000..02110eae --- /dev/null +++ b/lib/csharp/ThriftMSBuildTask/ThriftMSBuildTask.csproj @@ -0,0 +1,66 @@ + + + + Debug + AnyCPU + 9.0.21022 + 2.0 + {EC0A0231-66EA-4593-A792-C6CA3BB8668E} + Library + Properties + ThriftMSBuildTask + ThriftMSBuildTask + v3.5 + 512 + SAK + SAK + SAK + SAK + + + true + full + false + bin\Debug\ + DEBUG;TRACE + prompt + 4 + + + pdbonly + true + bin\Release\ + TRACE + prompt + 4 + + + + + + + + 3.5 + + + 3.5 + + + 3.5 + + + + + + + + + + + diff --git a/lib/csharp/src/Collections/THashSet.cs b/lib/csharp/src/Collections/THashSet.cs new file mode 100644 index 00000000..a9957693 --- /dev/null +++ b/lib/csharp/src/Collections/THashSet.cs @@ -0,0 +1,142 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections; +using System.Collections.Generic; + +namespace Thrift.Collections +{ + public class THashSet : ICollection + { +#if NET_2_0 + TDictSet set = new TDictSet(); +#else + HashSet set = new HashSet(); +#endif + public int Count + { + get { return set.Count; } + } + + public bool IsReadOnly + { + get { return false; } + } + + public void Add(T item) + { + set.Add(item); + } + + public void Clear() + { + set.Clear(); + } + + public bool Contains(T item) + { + return set.Contains(item); + } + + public void CopyTo(T[] array, int arrayIndex) + { + set.CopyTo(array, arrayIndex); + } + + public IEnumerator GetEnumerator() + { + return set.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return ((IEnumerable)set).GetEnumerator(); + } + + public bool Remove(T item) + { + return set.Remove(item); + } + +#if NET_2_0 + private class TDictSet : ICollection + { + Dictionary> dict = new Dictionary>(); + + public int Count + { + get { return dict.Count; } + } + + public bool IsReadOnly + { + get { return false; } + } + + public IEnumerator GetEnumerator() + { + return ((IEnumerable)dict.Keys).GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return dict.Keys.GetEnumerator(); + } + + public bool Add(V item) + { + if (!dict.ContainsKey(item)) + { + dict[item] = this; + return true; + } + + return false; + } + + void ICollection.Add(V item) + { + Add(item); + } + + public void Clear() + { + dict.Clear(); + } + + public bool Contains(V item) + { + return dict.ContainsKey(item); + } + + public void CopyTo(V[] array, int arrayIndex) + { + dict.Keys.CopyTo(array, arrayIndex); + } + + public bool Remove(V item) + { + return dict.Remove(item); + } + } +#endif + } + +} diff --git a/lib/csharp/src/Protocol/TBase.cs b/lib/csharp/src/Protocol/TBase.cs new file mode 100644 index 00000000..1969bb3d --- /dev/null +++ b/lib/csharp/src/Protocol/TBase.cs @@ -0,0 +1,34 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +namespace Thrift.Protocol +{ + public interface TBase + { + /// + /// Reads the TObject from the given input protocol. + /// + void Read(TProtocol tProtocol); + + /// + /// Writes the objects out to the protocol + /// + void Write(TProtocol tProtocol); + } +} diff --git a/lib/csharp/src/Protocol/TBinaryProtocol.cs b/lib/csharp/src/Protocol/TBinaryProtocol.cs new file mode 100644 index 00000000..14ca43b7 --- /dev/null +++ b/lib/csharp/src/Protocol/TBinaryProtocol.cs @@ -0,0 +1,392 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Text; +using Thrift.Transport; + +namespace Thrift.Protocol +{ + public class TBinaryProtocol : TProtocol + { + protected const uint VERSION_MASK = 0xffff0000; + protected const uint VERSION_1 = 0x80010000; + + protected bool strictRead_ = false; + protected bool strictWrite_ = true; + + protected int readLength_; + protected bool checkReadLength_ = false; + + + #region BinaryProtocol Factory + /** + * Factory + */ + public class Factory : TProtocolFactory { + + protected bool strictRead_ = false; + protected bool strictWrite_ = true; + + public Factory() + :this(false, true) + { + } + + public Factory(bool strictRead, bool strictWrite) + { + strictRead_ = strictRead; + strictWrite_ = strictWrite; + } + + public TProtocol GetProtocol(TTransport trans) { + return new TBinaryProtocol(trans, strictRead_, strictWrite_); + } + } + + #endregion + + public TBinaryProtocol(TTransport trans) + : this(trans, false, true) + { + } + + public TBinaryProtocol(TTransport trans, bool strictRead, bool strictWrite) + :base(trans) + { + strictRead_ = strictRead; + strictWrite_ = strictWrite; + } + + #region Write Methods + + public override void WriteMessageBegin(TMessage message) + { + if (strictWrite_) + { + uint version = VERSION_1 | (uint)(message.Type); + WriteI32((int)version); + WriteString(message.Name); + WriteI32(message.SeqID); + } + else + { + WriteString(message.Name); + WriteByte((byte)message.Type); + WriteI32(message.SeqID); + } + } + + public override void WriteMessageEnd() + { + } + + public override void WriteStructBegin(TStruct struc) + { + } + + public override void WriteStructEnd() + { + } + + public override void WriteFieldBegin(TField field) + { + WriteByte((byte)field.Type); + WriteI16(field.ID); + } + + public override void WriteFieldEnd() + { + } + + public override void WriteFieldStop() + { + WriteByte((byte)TType.Stop); + } + + public override void WriteMapBegin(TMap map) + { + WriteByte((byte)map.KeyType); + WriteByte((byte)map.ValueType); + WriteI32(map.Count); + } + + public override void WriteMapEnd() + { + } + + public override void WriteListBegin(TList list) + { + WriteByte((byte)list.ElementType); + WriteI32(list.Count); + } + + public override void WriteListEnd() + { + } + + public override void WriteSetBegin(TSet set) + { + WriteByte((byte)set.ElementType); + WriteI32(set.Count); + } + + public override void WriteSetEnd() + { + } + + public override void WriteBool(bool b) + { + WriteByte(b ? (byte)1 : (byte)0); + } + + private byte[] bout = new byte[1]; + public override void WriteByte(byte b) + { + bout[0] = b; + trans.Write(bout, 0, 1); + } + + private byte[] i16out = new byte[2]; + public override void WriteI16(short s) + { + i16out[0] = (byte)(0xff & (s >> 8)); + i16out[1] = (byte)(0xff & s); + trans.Write(i16out, 0, 2); + } + + private byte[] i32out = new byte[4]; + public override void WriteI32(int i32) + { + i32out[0] = (byte)(0xff & (i32 >> 24)); + i32out[1] = (byte)(0xff & (i32 >> 16)); + i32out[2] = (byte)(0xff & (i32 >> 8)); + i32out[3] = (byte)(0xff & i32); + trans.Write(i32out, 0, 4); + } + + private byte[] i64out = new byte[8]; + public override void WriteI64(long i64) + { + i64out[0] = (byte)(0xff & (i64 >> 56)); + i64out[1] = (byte)(0xff & (i64 >> 48)); + i64out[2] = (byte)(0xff & (i64 >> 40)); + i64out[3] = (byte)(0xff & (i64 >> 32)); + i64out[4] = (byte)(0xff & (i64 >> 24)); + i64out[5] = (byte)(0xff & (i64 >> 16)); + i64out[6] = (byte)(0xff & (i64 >> 8)); + i64out[7] = (byte)(0xff & i64); + trans.Write(i64out, 0, 8); + } + + public override void WriteDouble(double d) + { + WriteI64(BitConverter.DoubleToInt64Bits(d)); + } + + public override void WriteBinary(byte[] b) + { + WriteI32(b.Length); + trans.Write(b, 0, b.Length); + } + + #endregion + + #region ReadMethods + + public override TMessage ReadMessageBegin() + { + TMessage message = new TMessage(); + int size = ReadI32(); + if (size < 0) + { + uint version = (uint)size & VERSION_MASK; + if (version != VERSION_1) + { + throw new TProtocolException(TProtocolException.BAD_VERSION, "Bad version in ReadMessageBegin: " + version); + } + message.Type = (TMessageType)(size & 0x000000ff); + message.Name = ReadString(); + message.SeqID = ReadI32(); + } + else + { + if (strictRead_) + { + throw new TProtocolException(TProtocolException.BAD_VERSION, "Missing version in readMessageBegin, old client?"); + } + message.Name = ReadStringBody(size); + message.Type = (TMessageType)ReadByte(); + message.SeqID = ReadI32(); + } + return message; + } + + public override void ReadMessageEnd() + { + } + + public override TStruct ReadStructBegin() + { + return new TStruct(); + } + + public override void ReadStructEnd() + { + } + + public override TField ReadFieldBegin() + { + TField field = new TField(); + field.Type = (TType)ReadByte(); + + if (field.Type != TType.Stop) + { + field.ID = ReadI16(); + } + + return field; + } + + public override void ReadFieldEnd() + { + } + + public override TMap ReadMapBegin() + { + TMap map = new TMap(); + map.KeyType = (TType)ReadByte(); + map.ValueType = (TType)ReadByte(); + map.Count = ReadI32(); + + return map; + } + + public override void ReadMapEnd() + { + } + + public override TList ReadListBegin() + { + TList list = new TList(); + list.ElementType = (TType)ReadByte(); + list.Count = ReadI32(); + + return list; + } + + public override void ReadListEnd() + { + } + + public override TSet ReadSetBegin() + { + TSet set = new TSet(); + set.ElementType = (TType)ReadByte(); + set.Count = ReadI32(); + + return set; + } + + public override void ReadSetEnd() + { + } + + public override bool ReadBool() + { + return ReadByte() == 1; + } + + private byte[] bin = new byte[1]; + public override byte ReadByte() + { + ReadAll(bin, 0, 1); + return bin[0]; + } + + private byte[] i16in = new byte[2]; + public override short ReadI16() + { + ReadAll(i16in, 0, 2); + return (short)(((i16in[0] & 0xff) << 8) | ((i16in[1] & 0xff))); + } + + private byte[] i32in = new byte[4]; + public override int ReadI32() + { + ReadAll(i32in, 0, 4); + return (int)(((i32in[0] & 0xff) << 24) | ((i32in[1] & 0xff) << 16) | ((i32in[2] & 0xff) << 8) | ((i32in[3] & 0xff))); + } + + private byte[] i64in = new byte[8]; + public override long ReadI64() + { + ReadAll(i64in, 0, 8); + return (long)(((long)(i64in[0] & 0xff) << 56) | ((long)(i64in[1] & 0xff) << 48) | ((long)(i64in[2] & 0xff) << 40) | ((long)(i64in[3] & 0xff) << 32) | + ((long)(i64in[4] & 0xff) << 24) | ((long)(i64in[5] & 0xff) << 16) | ((long)(i64in[6] & 0xff) << 8) | ((long)(i64in[7] & 0xff))); + } + + public override double ReadDouble() + { + return BitConverter.Int64BitsToDouble(ReadI64()); + } + + public void SetReadLength(int readLength) + { + readLength_ = readLength; + checkReadLength_ = true; + } + + protected void CheckReadLength(int length) + { + if (checkReadLength_) + { + readLength_ -= length; + if (readLength_ < 0) + { + throw new Exception("Message length exceeded: " + length); + } + } + } + + public override byte[] ReadBinary() + { + int size = ReadI32(); + CheckReadLength(size); + byte[] buf = new byte[size]; + trans.ReadAll(buf, 0, size); + return buf; + } + private string ReadStringBody(int size) + { + CheckReadLength(size); + byte[] buf = new byte[size]; + trans.ReadAll(buf, 0, size); + return Encoding.UTF8.GetString(buf); + } + + private int ReadAll(byte[] buf, int off, int len) + { + CheckReadLength(len); + return trans.ReadAll(buf, off, len); + } + + #endregion + } +} diff --git a/lib/csharp/src/Protocol/TField.cs b/lib/csharp/src/Protocol/TField.cs new file mode 100644 index 00000000..485c994b --- /dev/null +++ b/lib/csharp/src/Protocol/TField.cs @@ -0,0 +1,58 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Thrift.Protocol +{ + public struct TField + { + private string name; + private TType type; + private short id; + + public TField(string name, TType type, short id) + :this() + { + this.name = name; + this.type = type; + this.id = id; + } + + public string Name + { + get { return name; } + set { name = value; } + } + + public TType Type + { + get { return type; } + set { type = value; } + } + + public short ID + { + get { return id; } + set { id = value; } + } + } +} diff --git a/lib/csharp/src/Protocol/TList.cs b/lib/csharp/src/Protocol/TList.cs new file mode 100644 index 00000000..dbc5c40e --- /dev/null +++ b/lib/csharp/src/Protocol/TList.cs @@ -0,0 +1,50 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Thrift.Protocol +{ + public struct TList + { + private TType elementType; + private int count; + + public TList(TType elementType, int count) + :this() + { + this.elementType = elementType; + this.count = count; + } + + public TType ElementType + { + get { return elementType; } + set { elementType = value; } + } + + public int Count + { + get { return count; } + set { count = value; } + } + } +} diff --git a/lib/csharp/src/Protocol/TMap.cs b/lib/csharp/src/Protocol/TMap.cs new file mode 100644 index 00000000..8b53f899 --- /dev/null +++ b/lib/csharp/src/Protocol/TMap.cs @@ -0,0 +1,58 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Thrift.Protocol +{ + public struct TMap + { + private TType keyType; + private TType valueType; + private int count; + + public TMap(TType keyType, TType valueType, int count) + :this() + { + this.keyType = keyType; + this.valueType = valueType; + this.count = count; + } + + public TType KeyType + { + get { return keyType; } + set { keyType = value; } + } + + public TType ValueType + { + get { return valueType; } + set { valueType = value; } + } + + public int Count + { + get { return count; } + set { count = value; } + } + } +} diff --git a/lib/csharp/src/Protocol/TMessage.cs b/lib/csharp/src/Protocol/TMessage.cs new file mode 100644 index 00000000..8cb6e0b1 --- /dev/null +++ b/lib/csharp/src/Protocol/TMessage.cs @@ -0,0 +1,58 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Thrift.Protocol +{ + public struct TMessage + { + private string name; + private TMessageType type; + private int seqID; + + public TMessage(string name, TMessageType type, int seqid) + :this() + { + this.name = name; + this.type = type; + this.seqID = seqid; + } + + public string Name + { + get { return name; } + set { name = value; } + } + + public TMessageType Type + { + get { return type; } + set { type = value; } + } + + public int SeqID + { + get { return seqID; } + set { seqID = value; } + } + } +} diff --git a/lib/csharp/src/Protocol/TMessageType.cs b/lib/csharp/src/Protocol/TMessageType.cs new file mode 100644 index 00000000..ab07cf6c --- /dev/null +++ b/lib/csharp/src/Protocol/TMessageType.cs @@ -0,0 +1,31 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; + +namespace Thrift.Protocol +{ + public enum TMessageType + { + Call = 1, + Reply = 2, + Exception = 3, + Oneway = 4 + } +} diff --git a/lib/csharp/src/Protocol/TProtocol.cs b/lib/csharp/src/Protocol/TProtocol.cs new file mode 100644 index 00000000..acf9c1b3 --- /dev/null +++ b/lib/csharp/src/Protocol/TProtocol.cs @@ -0,0 +1,87 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Text; +using Thrift.Transport; + +namespace Thrift.Protocol +{ + public abstract class TProtocol + { + protected TTransport trans; + + protected TProtocol(TTransport trans) + { + this.trans = trans; + } + + public TTransport Transport + { + get { return trans; } + } + + public abstract void WriteMessageBegin(TMessage message); + public abstract void WriteMessageEnd(); + public abstract void WriteStructBegin(TStruct struc); + public abstract void WriteStructEnd(); + public abstract void WriteFieldBegin(TField field); + public abstract void WriteFieldEnd(); + public abstract void WriteFieldStop(); + public abstract void WriteMapBegin(TMap map); + public abstract void WriteMapEnd(); + public abstract void WriteListBegin(TList list); + public abstract void WriteListEnd(); + public abstract void WriteSetBegin(TSet set); + public abstract void WriteSetEnd(); + public abstract void WriteBool(bool b); + public abstract void WriteByte(byte b); + public abstract void WriteI16(short i16); + public abstract void WriteI32(int i32); + public abstract void WriteI64(long i64); + public abstract void WriteDouble(double d); + public void WriteString(string s) { + WriteBinary(Encoding.UTF8.GetBytes(s)); + } + public abstract void WriteBinary(byte[] b); + + public abstract TMessage ReadMessageBegin(); + public abstract void ReadMessageEnd(); + public abstract TStruct ReadStructBegin(); + public abstract void ReadStructEnd(); + public abstract TField ReadFieldBegin(); + public abstract void ReadFieldEnd(); + public abstract TMap ReadMapBegin(); + public abstract void ReadMapEnd(); + public abstract TList ReadListBegin(); + public abstract void ReadListEnd(); + public abstract TSet ReadSetBegin(); + public abstract void ReadSetEnd(); + public abstract bool ReadBool(); + public abstract byte ReadByte(); + public abstract short ReadI16(); + public abstract int ReadI32(); + public abstract long ReadI64(); + public abstract double ReadDouble(); + public string ReadString() { + return Encoding.UTF8.GetString(ReadBinary()); + } + public abstract byte[] ReadBinary(); + } +} diff --git a/lib/csharp/src/Protocol/TProtocolException.cs b/lib/csharp/src/Protocol/TProtocolException.cs new file mode 100644 index 00000000..9c250476 --- /dev/null +++ b/lib/csharp/src/Protocol/TProtocolException.cs @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; + +namespace Thrift.Protocol +{ + class TProtocolException : Exception + { + public const int UNKNOWN = 0; + public const int INVALID_DATA = 1; + public const int NEGATIVE_SIZE = 2; + public const int SIZE_LIMIT = 3; + public const int BAD_VERSION = 4; + + protected int type_ = UNKNOWN; + + public TProtocolException() + : base() + { + } + + public TProtocolException(int type) + : base() + { + type_ = type; + } + + public TProtocolException(int type, String message) + : base(message) + { + type_ = type; + } + + public TProtocolException(String message) + : base(message) + { + } + + public int getType() + { + return type_; + } + } +} diff --git a/lib/csharp/src/Protocol/TProtocolFactory.cs b/lib/csharp/src/Protocol/TProtocolFactory.cs new file mode 100644 index 00000000..ae976acd --- /dev/null +++ b/lib/csharp/src/Protocol/TProtocolFactory.cs @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using Thrift.Transport; + +namespace Thrift.Protocol +{ + public interface TProtocolFactory + { + TProtocol GetProtocol(TTransport trans); + } +} diff --git a/lib/csharp/src/Protocol/TProtocolUtil.cs b/lib/csharp/src/Protocol/TProtocolUtil.cs new file mode 100644 index 00000000..57cef0ef --- /dev/null +++ b/lib/csharp/src/Protocol/TProtocolUtil.cs @@ -0,0 +1,94 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; + +namespace Thrift.Protocol +{ + public static class TProtocolUtil + { + public static void Skip(TProtocol prot, TType type) + { + switch (type) + { + case TType.Bool: + prot.ReadBool(); + break; + case TType.Byte: + prot.ReadByte(); + break; + case TType.I16: + prot.ReadI16(); + break; + case TType.I32: + prot.ReadI32(); + break; + case TType.I64: + prot.ReadI64(); + break; + case TType.Double: + prot.ReadDouble(); + break; + case TType.String: + // Don't try to decode the string, just skip it. + prot.ReadBinary(); + break; + case TType.Struct: + prot.ReadStructBegin(); + while (true) + { + TField field = prot.ReadFieldBegin(); + if (field.Type == TType.Stop) + { + break; + } + Skip(prot, field.Type); + prot.ReadFieldEnd(); + } + prot.ReadStructEnd(); + break; + case TType.Map: + TMap map = prot.ReadMapBegin(); + for (int i = 0; i < map.Count; i++) + { + Skip(prot, map.KeyType); + Skip(prot, map.ValueType); + } + prot.ReadMapEnd(); + break; + case TType.Set: + TSet set = prot.ReadSetBegin(); + for (int i = 0; i < set.Count; i++) + { + Skip(prot, set.ElementType); + } + prot.ReadSetEnd(); + break; + case TType.List: + TList list = prot.ReadListBegin(); + for (int i = 0; i < list.Count; i++) + { + Skip(prot, list.ElementType); + } + prot.ReadListEnd(); + break; + } + } + } +} diff --git a/lib/csharp/src/Protocol/TSet.cs b/lib/csharp/src/Protocol/TSet.cs new file mode 100644 index 00000000..ac73992d --- /dev/null +++ b/lib/csharp/src/Protocol/TSet.cs @@ -0,0 +1,50 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Thrift.Protocol +{ + public struct TSet + { + private TType elementType; + private int count; + + public TSet(TType elementType, int count) + :this() + { + this.elementType = elementType; + this.count = count; + } + + public TType ElementType + { + get { return elementType; } + set { elementType = value; } + } + + public int Count + { + get { return count; } + set { count = value; } + } + } +} diff --git a/lib/csharp/src/Protocol/TStruct.cs b/lib/csharp/src/Protocol/TStruct.cs new file mode 100644 index 00000000..0cac2733 --- /dev/null +++ b/lib/csharp/src/Protocol/TStruct.cs @@ -0,0 +1,42 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Thrift.Protocol +{ + public struct TStruct + { + private string name; + + public TStruct(string name) + :this() + { + this.name = name; + } + + public string Name + { + get { return name; } + set { name = value; } + } + } +} diff --git a/lib/csharp/src/Protocol/TType.cs b/lib/csharp/src/Protocol/TType.cs new file mode 100644 index 00000000..c2d78edc --- /dev/null +++ b/lib/csharp/src/Protocol/TType.cs @@ -0,0 +1,40 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; + +namespace Thrift.Protocol +{ + public enum TType : byte + { + Stop = 0, + Void = 1, + Bool = 2, + Byte = 3, + Double = 4, + I16 = 6, + I32 = 8, + I64 = 10, + String = 11, + Struct = 12, + Map = 13, + Set = 14, + List = 15 + } +} diff --git a/lib/csharp/src/Server/TServer.cs b/lib/csharp/src/Server/TServer.cs new file mode 100644 index 00000000..61a9416f --- /dev/null +++ b/lib/csharp/src/Server/TServer.cs @@ -0,0 +1,135 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using Thrift.Protocol; +using Thrift.Transport; +using System.IO; + +namespace Thrift.Server +{ + public abstract class TServer + { + /** + * Core processor + */ + protected TProcessor processor; + + /** + * Server transport + */ + protected TServerTransport serverTransport; + + /** + * Input Transport Factory + */ + protected TTransportFactory inputTransportFactory; + + /** + * Output Transport Factory + */ + protected TTransportFactory outputTransportFactory; + + /** + * Input Protocol Factory + */ + protected TProtocolFactory inputProtocolFactory; + + /** + * Output Protocol Factory + */ + protected TProtocolFactory outputProtocolFactory; + public delegate void LogDelegate(string str); + protected LogDelegate logDelegate; + + /** + * Default constructors. + */ + + public TServer(TProcessor processor, + TServerTransport serverTransport) + :this(processor, serverTransport, new TTransportFactory(), new TTransportFactory(), new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory(), DefaultLogDelegate) + { + } + + public TServer(TProcessor processor, + TServerTransport serverTransport, + LogDelegate logDelegate) + : this(processor, serverTransport, new TTransportFactory(), new TTransportFactory(), new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory(), DefaultLogDelegate) + { + } + + public TServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory transportFactory) + :this(processor, + serverTransport, + transportFactory, + transportFactory, + new TBinaryProtocol.Factory(), + new TBinaryProtocol.Factory(), + DefaultLogDelegate) + { + } + + public TServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory transportFactory, + TProtocolFactory protocolFactory) + :this(processor, + serverTransport, + transportFactory, + transportFactory, + protocolFactory, + protocolFactory, + DefaultLogDelegate) + { + } + + public TServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory inputTransportFactory, + TTransportFactory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory, + LogDelegate logDelegate) + { + this.processor = processor; + this.serverTransport = serverTransport; + this.inputTransportFactory = inputTransportFactory; + this.outputTransportFactory = outputTransportFactory; + this.inputProtocolFactory = inputProtocolFactory; + this.outputProtocolFactory = outputProtocolFactory; + this.logDelegate = logDelegate; + } + + /** + * The run method fires up the server and gets things going. + */ + public abstract void Serve(); + + public abstract void Stop(); + + protected static void DefaultLogDelegate(string s) + { + Console.Error.WriteLine(s); + } + } +} + diff --git a/lib/csharp/src/Server/TSimpleServer.cs b/lib/csharp/src/Server/TSimpleServer.cs new file mode 100644 index 00000000..34a51de4 --- /dev/null +++ b/lib/csharp/src/Server/TSimpleServer.cs @@ -0,0 +1,148 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using Thrift.Transport; +using Thrift.Protocol; + +namespace Thrift.Server +{ + /// + /// Simple single-threaded server for testing + /// + public class TSimpleServer : TServer + { + private bool stop = false; + + public TSimpleServer(TProcessor processor, + TServerTransport serverTransport) + :base(processor, serverTransport, new TTransportFactory(), new TTransportFactory(), new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory(), DefaultLogDelegate) + { + } + + public TSimpleServer(TProcessor processor, + TServerTransport serverTransport, + LogDelegate logDel) + : base(processor, serverTransport, new TTransportFactory(), new TTransportFactory(), new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory(), logDel) + { + } + + public TSimpleServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory transportFactory) + :base(processor, + serverTransport, + transportFactory, + transportFactory, + new TBinaryProtocol.Factory(), + new TBinaryProtocol.Factory(), + DefaultLogDelegate) + { + } + + public TSimpleServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory transportFactory, + TProtocolFactory protocolFactory) + :base(processor, + serverTransport, + transportFactory, + transportFactory, + protocolFactory, + protocolFactory, + DefaultLogDelegate) + { + } + + public override void Serve() + { + try + { + serverTransport.Listen(); + } + catch (TTransportException ttx) + { + logDelegate(ttx.ToString()); + return; + } + + while (!stop) + { + TTransport client = null; + TTransport inputTransport = null; + TTransport outputTransport = null; + TProtocol inputProtocol = null; + TProtocol outputProtocol = null; + try + { + client = serverTransport.Accept(); + if (client != null) + { + inputTransport = inputTransportFactory.GetTransport(client); + outputTransport = outputTransportFactory.GetTransport(client); + inputProtocol = inputProtocolFactory.GetProtocol(inputTransport); + outputProtocol = outputProtocolFactory.GetProtocol(outputTransport); + while (processor.Process(inputProtocol, outputProtocol)) { } + } + } + catch (TTransportException ttx) + { + // Client died, just move on + if (stop) + { + logDelegate("TSimpleServer was shutting down, caught " + ttx.GetType().Name); + } + } + catch (Exception x) + { + logDelegate(x.ToString()); + } + + if (inputTransport != null) + { + inputTransport.Close(); + } + + if (outputTransport != null) + { + outputTransport.Close(); + } + } + + if (stop) + { + try + { + serverTransport.Close(); + } + catch (TTransportException ttx) + { + logDelegate("TServerTranport failed on close: " + ttx.Message); + } + stop = false; + } + } + + public override void Stop() + { + stop = true; + serverTransport.Close(); + } + } +} diff --git a/lib/csharp/src/Server/TThreadPoolServer.cs b/lib/csharp/src/Server/TThreadPoolServer.cs new file mode 100644 index 00000000..efc71f01 --- /dev/null +++ b/lib/csharp/src/Server/TThreadPoolServer.cs @@ -0,0 +1,186 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Threading; +using Thrift.Protocol; +using Thrift.Transport; + +namespace Thrift.Server +{ + /// + /// Server that uses C# built-in ThreadPool to spawn threads when handling requests + /// + public class TThreadPoolServer : TServer + { + private const int DEFAULT_MIN_THREADS = 10; + private const int DEFAULT_MAX_THREADS = 100; + private volatile bool stop = false; + + public TThreadPoolServer(TProcessor processor, TServerTransport serverTransport) + :this(processor, serverTransport, + new TTransportFactory(), new TTransportFactory(), + new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory(), + DEFAULT_MIN_THREADS, DEFAULT_MAX_THREADS, DefaultLogDelegate) + { + } + + public TThreadPoolServer(TProcessor processor, TServerTransport serverTransport, LogDelegate logDelegate) + : this(processor, serverTransport, + new TTransportFactory(), new TTransportFactory(), + new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory(), + DEFAULT_MIN_THREADS, DEFAULT_MAX_THREADS, logDelegate) + { + } + + + public TThreadPoolServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory transportFactory, + TProtocolFactory protocolFactory) + :this(processor, serverTransport, + transportFactory, transportFactory, + protocolFactory, protocolFactory, + DEFAULT_MIN_THREADS, DEFAULT_MAX_THREADS, DefaultLogDelegate) + { + } + + public TThreadPoolServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory inputTransportFactory, + TTransportFactory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory, + int minThreadPoolThreads, int maxThreadPoolThreads, LogDelegate logDel) + :base(processor, serverTransport, inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory, logDel) + { + if (!ThreadPool.SetMinThreads(minThreadPoolThreads, minThreadPoolThreads)) + { + throw new Exception("Error: could not SetMinThreads in ThreadPool"); + } + if (!ThreadPool.SetMaxThreads(maxThreadPoolThreads, maxThreadPoolThreads)) + { + throw new Exception("Error: could not SetMaxThreads in ThreadPool"); + } + } + + /// + /// Use new ThreadPool thread for each new client connection + /// + public override void Serve() + { + try + { + serverTransport.Listen(); + } + catch (TTransportException ttx) + { + logDelegate("Error, could not listen on ServerTransport: " + ttx); + return; + } + + while (!stop) + { + int failureCount = 0; + try + { + TTransport client = serverTransport.Accept(); + ThreadPool.QueueUserWorkItem(this.Execute, client); + } + catch (TTransportException ttx) + { + if (stop) + { + logDelegate("TThreadPoolServer was shutting down, caught " + ttx.GetType().Name); + } + else + { + ++failureCount; + logDelegate(ttx.ToString()); + } + + } + } + + if (stop) + { + try + { + serverTransport.Close(); + } + catch (TTransportException ttx) + { + logDelegate("TServerTransport failed on close: " + ttx.Message); + } + stop = false; + } + } + + /// + /// Loops on processing a client forever + /// threadContext will be a TTransport instance + /// + /// + private void Execute(Object threadContext) + { + TTransport client = (TTransport)threadContext; + TTransport inputTransport = null; + TTransport outputTransport = null; + TProtocol inputProtocol = null; + TProtocol outputProtocol = null; + try + { + inputTransport = inputTransportFactory.GetTransport(client); + outputTransport = outputTransportFactory.GetTransport(client); + inputProtocol = inputProtocolFactory.GetProtocol(inputTransport); + outputProtocol = outputProtocolFactory.GetProtocol(outputTransport); + while (processor.Process(inputProtocol, outputProtocol)) + { + //keep processing requests until client disconnects + } + } + catch (TTransportException) + { + // Assume the client died and continue silently + //Console.WriteLine(ttx); + } + + catch (Exception x) + { + logDelegate("Error: " + x); + } + + if (inputTransport != null) + { + inputTransport.Close(); + } + if (outputTransport != null) + { + outputTransport.Close(); + } + } + + public override void Stop() + { + stop = true; + serverTransport.Close(); + } + } +} diff --git a/lib/csharp/src/Server/TThreadedServer.cs b/lib/csharp/src/Server/TThreadedServer.cs new file mode 100644 index 00000000..75206f15 --- /dev/null +++ b/lib/csharp/src/Server/TThreadedServer.cs @@ -0,0 +1,234 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections.Generic; +using System.Threading; +using Thrift.Collections; +using Thrift.Protocol; +using Thrift.Transport; + +namespace Thrift.Server +{ + /// + /// Server that uses C# threads (as opposed to the ThreadPool) when handling requests + /// + public class TThreadedServer : TServer + { + private const int DEFAULT_MAX_THREADS = 100; + private volatile bool stop = false; + private readonly int maxThreads; + + private Queue clientQueue; + private THashSet clientThreads; + private object clientLock; + private Thread workerThread; + + public TThreadedServer(TProcessor processor, TServerTransport serverTransport) + : this(processor, serverTransport, + new TTransportFactory(), new TTransportFactory(), + new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory(), + DEFAULT_MAX_THREADS, DefaultLogDelegate) + { + } + + public TThreadedServer(TProcessor processor, TServerTransport serverTransport, LogDelegate logDelegate) + : this(processor, serverTransport, + new TTransportFactory(), new TTransportFactory(), + new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory(), + DEFAULT_MAX_THREADS, logDelegate) + { + } + + + public TThreadedServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory transportFactory, + TProtocolFactory protocolFactory) + : this(processor, serverTransport, + transportFactory, transportFactory, + protocolFactory, protocolFactory, + DEFAULT_MAX_THREADS, DefaultLogDelegate) + { + } + + public TThreadedServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory inputTransportFactory, + TTransportFactory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory, + int maxThreads, LogDelegate logDel) + : base(processor, serverTransport, inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory, logDel) + { + this.maxThreads = maxThreads; + clientQueue = new Queue(); + clientLock = new object(); + clientThreads = new THashSet(); + } + + /// + /// Use new Thread for each new client connection. block until numConnections < maxTHreads + /// + public override void Serve() + { + try + { + //start worker thread + workerThread = new Thread(new ThreadStart(Execute)); + workerThread.Start(); + serverTransport.Listen(); + } + catch (TTransportException ttx) + { + logDelegate("Error, could not listen on ServerTransport: " + ttx); + return; + } + + while (!stop) + { + int failureCount = 0; + try + { + TTransport client = serverTransport.Accept(); + lock (clientLock) + { + clientQueue.Enqueue(client); + Monitor.Pulse(clientLock); + } + } + catch (TTransportException ttx) + { + if (stop) + { + logDelegate("TThreadPoolServer was shutting down, caught " + ttx); + } + else + { + ++failureCount; + logDelegate(ttx.ToString()); + } + + } + } + + if (stop) + { + try + { + serverTransport.Close(); + } + catch (TTransportException ttx) + { + logDelegate("TServeTransport failed on close: " + ttx.Message); + } + stop = false; + } + } + + /// + /// Loops on processing a client forever + /// threadContext will be a TTransport instance + /// + /// + private void Execute() + { + while (!stop) + { + TTransport client; + Thread t; + lock (clientLock) + { + //don't dequeue if too many connections + while (clientThreads.Count >= maxThreads) + { + Monitor.Wait(clientLock); + } + + while (clientQueue.Count == 0) + { + Monitor.Wait(clientLock); + } + + client = clientQueue.Dequeue(); + t = new Thread(new ParameterizedThreadStart(ClientWorker)); + clientThreads.Add(t); + } + //start processing requests from client on new thread + t.Start(client); + } + } + + private void ClientWorker(Object context) + { + TTransport client = (TTransport)context; + TTransport inputTransport = null; + TTransport outputTransport = null; + TProtocol inputProtocol = null; + TProtocol outputProtocol = null; + try + { + inputTransport = inputTransportFactory.GetTransport(client); + outputTransport = outputTransportFactory.GetTransport(client); + inputProtocol = inputProtocolFactory.GetProtocol(inputTransport); + outputProtocol = outputProtocolFactory.GetProtocol(outputTransport); + while (processor.Process(inputProtocol, outputProtocol)) + { + //keep processing requests until client disconnects + } + } + catch (TTransportException) + { + } + catch (Exception x) + { + logDelegate("Error: " + x); + } + + if (inputTransport != null) + { + inputTransport.Close(); + } + if (outputTransport != null) + { + outputTransport.Close(); + } + + lock (clientLock) + { + clientThreads.Remove(Thread.CurrentThread); + Monitor.Pulse(clientLock); + } + return; + } + + public override void Stop() + { + stop = true; + serverTransport.Close(); + //clean up all the threads myself + workerThread.Abort(); + foreach (Thread t in clientThreads) + { + t.Abort(); + } + } + } +} diff --git a/lib/csharp/src/TApplicationException.cs b/lib/csharp/src/TApplicationException.cs new file mode 100644 index 00000000..12719686 --- /dev/null +++ b/lib/csharp/src/TApplicationException.cs @@ -0,0 +1,131 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using Thrift.Protocol; + +namespace Thrift +{ + public class TApplicationException : Exception + { + protected ExceptionType type; + + public TApplicationException() + { + } + + public TApplicationException(ExceptionType type) + { + this.type = type; + } + + public TApplicationException(ExceptionType type, string message) + : base(message) + { + this.type = type; + } + + public static TApplicationException Read(TProtocol iprot) + { + TField field; + + string message = null; + ExceptionType type = ExceptionType.Unknown; + + while (true) + { + field = iprot.ReadFieldBegin(); + if (field.Type == TType.Stop) + { + break; + } + + switch (field.ID) + { + case 1: + if (field.Type == TType.String) + { + message = iprot.ReadString(); + } + else + { + TProtocolUtil.Skip(iprot, field.Type); + } + break; + case 2: + if (field.Type == TType.I32) + { + type = (ExceptionType)iprot.ReadI32(); + } + else + { + TProtocolUtil.Skip(iprot, field.Type); + } + break; + default: + TProtocolUtil.Skip(iprot, field.Type); + break; + } + + iprot.ReadFieldEnd(); + } + + iprot.ReadStructEnd(); + + return new TApplicationException(type, message); + } + + public void Write(TProtocol oprot) + { + TStruct struc = new TStruct("TApplicationException"); + TField field = new TField(); + + oprot.WriteStructBegin(struc); + + if (!String.IsNullOrEmpty(Message)) + { + field.Name = "message"; + field.Type = TType.String; + field.ID = 1; + oprot.WriteFieldBegin(field); + oprot.WriteString(Message); + oprot.WriteFieldEnd(); + } + + field.Name = "type"; + field.Type = TType.I32; + field.ID = 2; + oprot.WriteFieldBegin(field); + oprot.WriteI32((int)type); + oprot.WriteFieldEnd(); + oprot.WriteFieldStop(); + oprot.WriteStructEnd(); + } + + public enum ExceptionType + { + Unknown, + UnknownMethod, + InvalidMessageType, + WrongMethodName, + BadSequenceID, + MissingResult + } + } +} diff --git a/lib/csharp/src/TProcessor.cs b/lib/csharp/src/TProcessor.cs new file mode 100644 index 00000000..cbb55b79 --- /dev/null +++ b/lib/csharp/src/TProcessor.cs @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using Thrift.Protocol; + +namespace Thrift +{ + public interface TProcessor + { + bool Process(TProtocol iprot, TProtocol oprot); + } +} diff --git a/lib/csharp/src/Thrift.csproj b/lib/csharp/src/Thrift.csproj new file mode 100644 index 00000000..1eb4355d --- /dev/null +++ b/lib/csharp/src/Thrift.csproj @@ -0,0 +1,77 @@ + + + Debug + AnyCPU + {499EB63C-D74C-47E8-AE48-A2FC94538E9D} + 9.0.21022 + 2.0 + Library + false + Thrift + v3.5 + 512 + Thrift + SAK + SAK + SAK + SAK + + + true + full + false + bin\Debug\ + DEBUG;TRACE + prompt + 4 + + + pdbonly + true + bin\Release\ + TRACE + prompt + 4 + + + + + 3.5 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/lib/csharp/src/Thrift.sln b/lib/csharp/src/Thrift.sln new file mode 100644 index 00000000..cb7342cd --- /dev/null +++ b/lib/csharp/src/Thrift.sln @@ -0,0 +1,51 @@ + +Microsoft Visual Studio Solution File, Format Version 10.00 +# Visual Studio 2008 +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Thrift", "Thrift.csproj", "{499EB63C-D74C-47E8-AE48-A2FC94538E9D}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ThriftTest", "..\..\..\test\csharp\ThriftTest\ThriftTest.csproj", "{48DD757F-CA95-4DD7-BDA4-58DB6F108C2C}" + ProjectSection(ProjectDependencies) = postProject + {499EB63C-D74C-47E8-AE48-A2FC94538E9D} = {499EB63C-D74C-47E8-AE48-A2FC94538E9D} + EndProjectSection +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ThriftMSBuildTask", "..\ThriftMSBuildTask\ThriftMSBuildTask.csproj", "{EC0A0231-66EA-4593-A792-C6CA3BB8668E}" +EndProject +Global + GlobalSection(SourceCodeControl) = preSolution + SccNumberOfProjects = 4 + SccProjectName0 = Perforce\u0020Project + SccLocalPath0 = ..\\..\\.. + SccProvider0 = MSSCCI:Perforce\u0020SCM + SccProjectFilePathRelativizedFromConnection0 = lib\\csharp\\src\\ + SccProjectUniqueName1 = Thrift.csproj + SccLocalPath1 = ..\\..\\.. + SccProjectFilePathRelativizedFromConnection1 = lib\\csharp\\src\\ + SccProjectUniqueName2 = ..\\..\\..\\test\\csharp\\ThriftTest\\ThriftTest.csproj + SccLocalPath2 = ..\\..\\.. + SccProjectFilePathRelativizedFromConnection2 = test\\csharp\\ThriftTest\\ + SccProjectUniqueName3 = ..\\ThriftMSBuildTask\\ThriftMSBuildTask.csproj + SccLocalPath3 = ..\\..\\.. + SccProjectFilePathRelativizedFromConnection3 = lib\\csharp\\ThriftMSBuildTask\\ + EndGlobalSection + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {499EB63C-D74C-47E8-AE48-A2FC94538E9D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {499EB63C-D74C-47E8-AE48-A2FC94538E9D}.Debug|Any CPU.Build.0 = Debug|Any CPU + {499EB63C-D74C-47E8-AE48-A2FC94538E9D}.Release|Any CPU.ActiveCfg = Release|Any CPU + {499EB63C-D74C-47E8-AE48-A2FC94538E9D}.Release|Any CPU.Build.0 = Release|Any CPU + {48DD757F-CA95-4DD7-BDA4-58DB6F108C2C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {48DD757F-CA95-4DD7-BDA4-58DB6F108C2C}.Debug|Any CPU.Build.0 = Debug|Any CPU + {48DD757F-CA95-4DD7-BDA4-58DB6F108C2C}.Release|Any CPU.ActiveCfg = Release|Any CPU + {48DD757F-CA95-4DD7-BDA4-58DB6F108C2C}.Release|Any CPU.Build.0 = Release|Any CPU + {EC0A0231-66EA-4593-A792-C6CA3BB8668E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {EC0A0231-66EA-4593-A792-C6CA3BB8668E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {EC0A0231-66EA-4593-A792-C6CA3BB8668E}.Release|Any CPU.ActiveCfg = Release|Any CPU + {EC0A0231-66EA-4593-A792-C6CA3BB8668E}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection +EndGlobal diff --git a/lib/csharp/src/Transport/TBufferedTransport.cs b/lib/csharp/src/Transport/TBufferedTransport.cs new file mode 100644 index 00000000..28a855a5 --- /dev/null +++ b/lib/csharp/src/Transport/TBufferedTransport.cs @@ -0,0 +1,100 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.IO; + +namespace Thrift.Transport +{ + public class TBufferedTransport : TTransport + { + private BufferedStream inputBuffer; + private BufferedStream outputBuffer; + private int bufSize; + private TStreamTransport transport; + + public TBufferedTransport(TStreamTransport transport) + :this(transport, 1024) + { + + } + + public TBufferedTransport(TStreamTransport transport, int bufSize) + { + this.bufSize = bufSize; + this.transport = transport; + InitBuffers(); + } + + private void InitBuffers() + { + if (transport.InputStream != null) + { + inputBuffer = new BufferedStream(transport.InputStream, bufSize); + } + if (transport.OutputStream != null) + { + outputBuffer = new BufferedStream(transport.OutputStream, bufSize); + } + } + + public TTransport UnderlyingTransport + { + get { return transport; } + } + + public override bool IsOpen + { + get { return transport.IsOpen; } + } + + public override void Open() + { + transport.Open(); + InitBuffers(); + } + + public override void Close() + { + if (inputBuffer != null && inputBuffer.CanRead) + { + inputBuffer.Close(); + } + if (outputBuffer != null && outputBuffer.CanWrite) + { + outputBuffer.Close(); + } + } + + public override int Read(byte[] buf, int off, int len) + { + return inputBuffer.Read(buf, off, len); + } + + public override void Write(byte[] buf, int off, int len) + { + outputBuffer.Write(buf, off, len); + } + + public override void Flush() + { + outputBuffer.Flush(); + } + } +} diff --git a/lib/csharp/src/Transport/TServerSocket.cs b/lib/csharp/src/Transport/TServerSocket.cs new file mode 100644 index 00000000..2658fce8 --- /dev/null +++ b/lib/csharp/src/Transport/TServerSocket.cs @@ -0,0 +1,157 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Net.Sockets; + + +namespace Thrift.Transport +{ + public class TServerSocket : TServerTransport + { + /** + * Underlying server with socket + */ + private TcpListener server = null; + + /** + * Port to listen on + */ + private int port = 0; + + /** + * Timeout for client sockets from accept + */ + private int clientTimeout = 0; + + /** + * Whether or not to wrap new TSocket connections in buffers + */ + private bool useBufferedSockets = false; + + /** + * Creates a server socket from underlying socket object + */ + public TServerSocket(TcpListener listener) + :this(listener, 0) + { + } + + /** + * Creates a server socket from underlying socket object + */ + public TServerSocket(TcpListener listener, int clientTimeout) + { + this.server = listener; + this.clientTimeout = clientTimeout; + } + + /** + * Creates just a port listening server socket + */ + public TServerSocket(int port) + : this(port, 0) + { + } + + /** + * Creates just a port listening server socket + */ + public TServerSocket(int port, int clientTimeout) + :this(port, clientTimeout, false) + { + } + + public TServerSocket(int port, int clientTimeout, bool useBufferedSockets) + { + this.port = port; + this.clientTimeout = clientTimeout; + this.useBufferedSockets = useBufferedSockets; + try + { + // Make server socket + server = new TcpListener(System.Net.IPAddress.Any, this.port); + } + catch (Exception) + { + server = null; + throw new TTransportException("Could not create ServerSocket on port " + port + "."); + } + } + + public override void Listen() + { + // Make sure not to block on accept + if (server != null) + { + try + { + server.Start(); + } + catch (SocketException sx) + { + throw new TTransportException("Could not accept on listening socket: " + sx.Message); + } + } + } + + protected override TTransport AcceptImpl() + { + if (server == null) + { + throw new TTransportException(TTransportException.ExceptionType.NotOpen, "No underlying server socket."); + } + try + { + TcpClient result = server.AcceptTcpClient(); + TSocket result2 = new TSocket(result); + result2.Timeout = clientTimeout; + if (useBufferedSockets) + { + TBufferedTransport result3 = new TBufferedTransport(result2); + return result3; + } + else + { + return result2; + } + } + catch (Exception ex) + { + throw new TTransportException(ex.ToString()); + } + } + + public override void Close() + { + if (server != null) + { + try + { + server.Stop(); + } + catch (Exception ex) + { + throw new TTransportException("WARNING: Could not close server socket: " + ex); + } + server = null; + } + } + } +} diff --git a/lib/csharp/src/Transport/TServerTransport.cs b/lib/csharp/src/Transport/TServerTransport.cs new file mode 100644 index 00000000..9cb52e5c --- /dev/null +++ b/lib/csharp/src/Transport/TServerTransport.cs @@ -0,0 +1,39 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; + +namespace Thrift.Transport +{ + public abstract class TServerTransport + { + public abstract void Listen(); + public abstract void Close(); + protected abstract TTransport AcceptImpl(); + + public TTransport Accept() + { + TTransport transport = AcceptImpl(); + if (transport == null) { + throw new TTransportException("accept() may not return NULL"); + } + return transport; + } + } +} diff --git a/lib/csharp/src/Transport/TSocket.cs b/lib/csharp/src/Transport/TSocket.cs new file mode 100644 index 00000000..18cf1547 --- /dev/null +++ b/lib/csharp/src/Transport/TSocket.cs @@ -0,0 +1,144 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Net.Sockets; + +namespace Thrift.Transport +{ + public class TSocket : TStreamTransport + { + private TcpClient client = null; + private string host = null; + private int port = 0; + private int timeout = 0; + + public TSocket(TcpClient client) + { + this.client = client; + + if (IsOpen) + { + inputStream = client.GetStream(); + outputStream = client.GetStream(); + } + } + + public TSocket(string host, int port) : this(host, port, 0) + { + } + + public TSocket(string host, int port, int timeout) + { + this.host = host; + this.port = port; + this.timeout = timeout; + + InitSocket(); + } + + private void InitSocket() + { + client = new TcpClient(); + client.ReceiveTimeout = client.SendTimeout = timeout; + } + + public int Timeout + { + set + { + client.ReceiveTimeout = client.SendTimeout = timeout = value; + } + } + + public TcpClient TcpClient + { + get + { + return client; + } + } + + public string Host + { + get + { + return host; + } + } + + public int Port + { + get + { + return port; + } + } + + public override bool IsOpen + { + get + { + if (client == null) + { + return false; + } + + return client.Connected; + } + } + + public override void Open() + { + if (IsOpen) + { + throw new TTransportException(TTransportException.ExceptionType.AlreadyOpen, "Socket already connected"); + } + + if (String.IsNullOrEmpty(host)) + { + throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot open null host"); + } + + if (port <= 0) + { + throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot open without port"); + } + + if (client == null) + { + InitSocket(); + } + + client.Connect(host, port); + inputStream = client.GetStream(); + outputStream = client.GetStream(); + } + + public override void Close() + { + base.Close(); + if (client != null) + { + client.Close(); + client = null; + } + } + } +} diff --git a/lib/csharp/src/Transport/TStreamTransport.cs b/lib/csharp/src/Transport/TStreamTransport.cs new file mode 100644 index 00000000..7681e0d9 --- /dev/null +++ b/lib/csharp/src/Transport/TStreamTransport.cs @@ -0,0 +1,103 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.IO; + +namespace Thrift.Transport +{ + public class TStreamTransport : TTransport + { + protected Stream inputStream; + protected Stream outputStream; + + public TStreamTransport() + { + } + + public TStreamTransport(Stream inputStream, Stream outputStream) + { + this.inputStream = inputStream; + this.outputStream = outputStream; + } + + public Stream OutputStream + { + get { return outputStream; } + } + + public Stream InputStream + { + get { return inputStream; } + } + + public override bool IsOpen + { + get { return true; } + } + + public override void Open() + { + } + + public override void Close() + { + if (inputStream != null) + { + inputStream.Close(); + inputStream = null; + } + if (outputStream != null) + { + outputStream.Close(); + outputStream = null; + } + } + + public override int Read(byte[] buf, int off, int len) + { + if (inputStream == null) + { + throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot read from null inputstream"); + } + + return inputStream.Read(buf, off, len); + } + + public override void Write(byte[] buf, int off, int len) + { + if (outputStream == null) + { + throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot write to null outputstream"); + } + + outputStream.Write(buf, off, len); + } + + public override void Flush() + { + if (outputStream == null) + { + throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot flush null outputstream"); + } + + outputStream.Flush(); + } + } +} diff --git a/lib/csharp/src/Transport/TTransport.cs b/lib/csharp/src/Transport/TTransport.cs new file mode 100644 index 00000000..83f6776c --- /dev/null +++ b/lib/csharp/src/Transport/TTransport.cs @@ -0,0 +1,66 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; + +namespace Thrift.Transport +{ + public abstract class TTransport + { + public abstract bool IsOpen + { + get; + } + + public bool Peek() + { + return IsOpen; + } + + public abstract void Open(); + + public abstract void Close(); + + public abstract int Read(byte[] buf, int off, int len); + + public int ReadAll(byte[] buf, int off, int len) + { + int got = 0; + int ret = 0; + + while (got < len) + { + ret = Read(buf, off + got, len - got); + if (ret <= 0) + { + throw new TTransportException("Cannot read, Remote side has closed"); + } + got += ret; + } + + return got; + } + + public abstract void Write(byte[] buf, int off, int len); + + public virtual void Flush() + { + } + } +} diff --git a/lib/csharp/src/Transport/TTransportException.cs b/lib/csharp/src/Transport/TTransportException.cs new file mode 100644 index 00000000..fe10faa5 --- /dev/null +++ b/lib/csharp/src/Transport/TTransportException.cs @@ -0,0 +1,64 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; + +namespace Thrift.Transport +{ + public class TTransportException : Exception + { + protected ExceptionType type; + + public TTransportException() + : base() + { + } + + public TTransportException(ExceptionType type) + : this() + { + this.type = type; + } + + public TTransportException(ExceptionType type, string message) + : base(message) + { + this.type = type; + } + + public TTransportException(string message) + : base(message) + { + } + + public ExceptionType Type + { + get { return type; } + } + + public enum ExceptionType + { + Unknown, + NotOpen, + AlreadyOpen, + TimedOut, + EndOfFile + } + } +} diff --git a/lib/csharp/src/Transport/TTransportFactory.cs b/lib/csharp/src/Transport/TTransportFactory.cs new file mode 100644 index 00000000..3d3694db --- /dev/null +++ b/lib/csharp/src/Transport/TTransportFactory.cs @@ -0,0 +1,38 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; + +namespace Thrift.Transport +{ + /// + /// From Mark Slee & Aditya Agarwal of Facebook: + /// Factory class used to create wrapped instance of Transports. + /// This is used primarily in servers, which get Transports from + /// a ServerTransport and then may want to mutate them (i.e. create + /// a BufferedTransport from the underlying base transport) + /// + public class TTransportFactory + { + public virtual TTransport GetTransport(TTransport trans) + { + return trans; + } + } +} diff --git a/lib/erl/Makefile b/lib/erl/Makefile new file mode 100644 index 00000000..77fe8b6a --- /dev/null +++ b/lib/erl/Makefile @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +MODULES = \ + src + +all clean docs: + for dir in $(MODULES); do \ + (cd $$dir; ${MAKE} $@); \ + done + +install: all + echo 'No install target, sorry.' + +check: all + +distclean: clean + +# Hack to make "make dist" work. +# This should not work, but it appears to. +distdir: diff --git a/lib/erl/README b/lib/erl/README new file mode 100644 index 00000000..ddb6946f --- /dev/null +++ b/lib/erl/README @@ -0,0 +1,56 @@ +Thrift Erlang Software Library + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Example +======= + +Example session using thrift_client: + +118> f(), {ok, C} = thrift_client:start_link("localhost", 9090, thriftTest_thrif +t). +{ok,<0.271.0>} +119> thrift_client:call(C, testVoid, []). +{ok,ok} +120> thrift_client:call(C, testVoid, [asdf]). +{error,{bad_args,testVoid,[asdf]}} +121> thrift_client:call(C, testI32, [123]). +{ok,123} +122> thrift_client:call(C, testOneway, [1]). +{ok,ok} +123> catch thrift_client:call(C, testXception, ["foo"]). +{error,{no_function,testXception}} +124> catch thrift_client:call(C, testException, ["foo"]). +{ok,ok} +125> catch thrift_client:call(C, testException, ["Xception"]). +{xception,1001,"This is an Xception"} +126> thrift_client:call(C, testException, ["Xception"]). + +=ERROR REPORT==== 24-Feb-2008::23:00:23 === +Error in process <0.269.0> with exit value: {{nocatch,{xception,1001,"This is an + Xception"}},[{thrift_client,call,3},{erl_eval,do_apply,5},{shell,exprs,6},{shel +l,eval_loop,3}]} + +** exited: {{nocatch,{xception,1001,"This is an Xception"}}, + [{thrift_client,call,3}, + {erl_eval,do_apply,5}, + {shell,exprs,6}, + {shell,eval_loop,3}]} ** diff --git a/lib/erl/build/beamver b/lib/erl/build/beamver new file mode 100644 index 00000000..2b5f77b3 --- /dev/null +++ b/lib/erl/build/beamver @@ -0,0 +1,59 @@ +#!/bin/sh + +# erlwareSys: otp/build/beamver,v 1.1 2002/02/14 11:45:20 hal Exp $ + +# usage: beamver +# +# if there's a usable -vsn() attribute, print it and exit with status 0 +# otherwise, print nothing and exit with status 1 + +# From the Erlang shell: +# +# 5> code:which(acca_inets). +# "/home/martin/work/otp/releases//../../acca/ebin/.beam" +# +# 8> beam_lib:version(code:which()). +# {ok,{,['$Id: beamver,v 1.1.1.1 2003/06/13 21:43:21 mlogan Exp $ ']}} + +# TMPFILE looks like this: +# +# io:format("hello ~p~n", +# beam_lib:version("/home/hal/work/otp/acca/ebin/acca_inets.beam")]). + +TMPFILE=/tmp/beamver.$$ + +# exit with failure if we can't read the file +test -f "$1" || exit 1 +BEAMFILE=\"$1\" + +cat > $TMPFILE <<_EOF +io:format("~p~n", + [beam_lib:version($BEAMFILE)]). +_EOF + +# beam_result is {ok,{Module_name, Beam_version} or {error,beam_lib,{Reason}} +beam_result=`erl -noshell \ + -s file eval $TMPFILE \ + -s erlang halt` + +rm -f $TMPFILE + +# sed regexes: +# remove brackets and anything outside them +# remove quotes and anything outside them +# remove apostrophes and anything outside them +# remove leading and trailing spaces + +case $beam_result in +\{ok*) + echo $beam_result | sed -e 's/.*\[\(.*\)].*/\1/' \ + -e 's/.*\"\(.*\)\".*/\1/' \ + -e "s/.*\'\(.*\)\'.*/\1/" \ + -e 's/ *$//' -e 's/^ *//' + exit 0 + ;; +*) + exit 1 + ;; +esac + diff --git a/lib/erl/build/buildtargets.mk b/lib/erl/build/buildtargets.mk new file mode 100644 index 00000000..db52b785 --- /dev/null +++ b/lib/erl/build/buildtargets.mk @@ -0,0 +1,15 @@ +EBIN ?= ../ebin +ESRC ?= . +EMULATOR = beam + +ERLC_WFLAGS = -W +ERLC = erlc $(ERLC_WFLAGS) $(ERLC_FLAGS) +ERL = erl -boot start_clean + +$(EBIN)/%.beam: $(ESRC)/%.erl + @echo " ERLC $<" + @$(ERLC) $(ERL_FLAGS) $(ERL_COMPILE_FLAGS) -o$(EBIN) $< + +.erl.beam: + $(ERLC) $(ERL_FLAGS) $(ERL_COMPILE_FLAGS) -o$(dir $@) $< + diff --git a/lib/erl/build/colors.mk b/lib/erl/build/colors.mk new file mode 100644 index 00000000..4d69c41d --- /dev/null +++ b/lib/erl/build/colors.mk @@ -0,0 +1,24 @@ +# Colors to assist visual inspection of make output. + +# Colors +LGRAY=$$'\e[0;37m' +DGRAY=$$'\e[1;30m' +LGREEN=$$'\e[1;32m' +LBLUE=$$'\e[1;34m' +LCYAN=$$'\e[1;36m' +LPURPLE=$$'\e[1;35m' +LRED=$$'\e[1;31m' +NO_COLOR=$$'\e[0m' +DEFAULT=$$'\e[0m' +BLACK=$$'\e[0;30m' +BLUE=$$'\e[0;34m' +GREEN=$$'\e[0;32m' +CYAN=$$'\e[0;36m' +RED=$$'\e[0;31m' +PURPLE=$$'\e[0;35m' +BROWN=$$'\e[0;33m' +YELLOW=$$'\e[1;33m' +WHITE=$$'\e[1;37m' + +BOLD=$$'\e[1;37m' +OFF=$$'\e[0m' diff --git a/lib/erl/build/docs.mk b/lib/erl/build/docs.mk new file mode 100644 index 00000000..b0b7377f --- /dev/null +++ b/lib/erl/build/docs.mk @@ -0,0 +1,12 @@ +EDOC_PATH=../../../tools/utilities + +#single place to include docs from. +docs: + @mkdir -p ../doc + @echo -n $${MY_BLUE:-$(BLUE)}; \ + $(EDOC_PATH)/edoc $(APP_NAME); \ + if [ $$? -eq 0 ]; then \ + echo $${MY_LRED:-$(LRED)}"$$d Doc Failed"; \ + fi; \ + echo -n $(OFF)$(NO_COLOR) + diff --git a/lib/erl/build/mime.types b/lib/erl/build/mime.types new file mode 100644 index 00000000..d6e3c0d0 --- /dev/null +++ b/lib/erl/build/mime.types @@ -0,0 +1,98 @@ + +application/activemessage +application/andrew-inset +application/applefile +application/atomicmail +application/dca-rft +application/dec-dx +application/mac-binhex40 hqx +application/mac-compactpro cpt +application/macwriteii +application/msword doc +application/news-message-id +application/news-transmission +application/octet-stream bin dms lha lzh exe class +application/oda oda +application/pdf pdf +application/postscript ai eps ps +application/powerpoint ppt +application/remote-printing +application/rtf rtf +application/slate +application/wita +application/wordperfect5.1 +application/x-bcpio bcpio +application/x-cdlink vcd +application/x-compress Z +application/x-cpio cpio +application/x-csh csh +application/x-director dcr dir dxr +application/x-dvi dvi +application/x-gtar gtar +application/x-gzip gz +application/x-hdf hdf +application/x-httpd-cgi cgi +application/x-koan skp skd skt skm +application/x-latex latex +application/x-mif mif +application/x-netcdf nc cdf +application/x-sh sh +application/x-shar shar +application/x-stuffit sit +application/x-sv4cpio sv4cpio +application/x-sv4crc sv4crc +application/x-tar tar +application/x-tcl tcl +application/x-tex tex +application/x-texinfo texinfo texi +application/x-troff t tr roff +application/x-troff-man man +application/x-troff-me me +application/x-troff-ms ms +application/x-ustar ustar +application/x-wais-source src +application/zip zip +audio/basic au snd +audio/mpeg mpga mp2 +audio/x-aiff aif aiff aifc +audio/x-pn-realaudio ram +audio/x-pn-realaudio-plugin rpm +audio/x-realaudio ra +audio/x-wav wav +chemical/x-pdb pdb xyz +image/gif gif +image/ief ief +image/jpeg jpeg jpg jpe +image/png png +image/tiff tiff tif +image/x-cmu-raster ras +image/x-portable-anymap pnm +image/x-portable-bitmap pbm +image/x-portable-graymap pgm +image/x-portable-pixmap ppm +image/x-rgb rgb +image/x-xbitmap xbm +image/x-xpixmap xpm +image/x-xwindowdump xwd +message/external-body +message/news +message/partial +message/rfc822 +multipart/alternative +multipart/appledouble +multipart/digest +multipart/mixed +multipart/parallel +text/html html htm +text/x-server-parsed-html shtml +text/plain txt +text/richtext rtx +text/tab-separated-values tsv +text/x-setext etx +text/x-sgml sgml sgm +video/mpeg mpeg mpg mpe +video/quicktime qt mov +video/x-msvideo avi +video/x-sgi-movie movie +x-conference/x-cooltalk ice +x-world/x-vrml wrl vrml diff --git a/lib/erl/build/otp.mk b/lib/erl/build/otp.mk new file mode 100644 index 00000000..1d16e2c8 --- /dev/null +++ b/lib/erl/build/otp.mk @@ -0,0 +1,146 @@ +# +----------------------------------------------------------------------+ +# $Id: otp.mk,v 1.4 2004/07/01 14:57:10 tfee Exp $ +# +----------------------------------------------------------------------+ + +# otp.mk +# - to be included in all OTP Makefiles +# installed to /usr/local/include/erlang/otp.mk + +# gmake looks in /usr/local/include - that's hard-coded +# users of this file will use +# include erlang/top.mk + +# most interface files will be installed to $ERL_RUN_TOP/app-vsn/include/*.hrl + +# group owner for library/include directories +ERLANGDEV_GROUP=erlangdev + +# ERL_TOP is root of Erlang source tree +# ERL_RUN_TOP is root of Erlang target tree (some Ericsson Makefiles use $ROOT) +# ERLANG_OTP is target root for Erlang code +# - see sasl/systools reference manual page; grep "TEST" + +# OS_TYPE is FreeBSD, NetBSD, OpenBSD, Linux, SCO_SV, SunOS. +OS_TYPE=${shell uname} + +# MHOST is the host where this Makefile runs. +MHOST=${shell hostname -s} +ERL_COMPILE_FLAGS+=-W0 + +# The location of the erlang runtime system. +ifndef ERL_RUN_TOP +ERL_RUN_TOP=/usr/local/lib/erlang +endif + + +# Edit to reflect local environment. +# ifeq (${OS_TYPE},Linux) +# ERL_RUN_TOP=/usr/local/lib/erlang +# Note* ERL_RUN_TOP can be determined by starting an +# erlang shell and typing code:root_dir(). +# ERL_TOP=a symbolic link to the actual source top, which changes from version to version +# Note* ERL_TOP is the directory where the erlang +# source files reside. Make sure to run ./configure there. +# TARGET=i686-pc-linux-gnu +# Note* Target can be found in $ERL_TOP/erts +# endif + +# See above for directions. +ifeq (${OS_TYPE},Linux) +ERL_TOP=/opt/OTP_SRC +TARGET=i686-pc-linux-gnu +endif + +ERLANG_OTP=/usr/local/erlang/otp +VAR_OTP=/var/otp + + +# Aliases for common binaries +# Note - CFLAGS is modified in erlang.conf + + +################################ +# SunOS +################################ +ifeq (${OS_TYPE},SunOS) + + CC=gcc + CXX=g++ + AR=/usr/ccs/bin/ar + ARFLAGS=-rv + CXXFLAGS+=${CFLAGS} -I/usr/include/g++ + LD=/usr/ccs/bin/ld + RANLIB=/usr/ccs/bin/ranlib + +CFLAGS+=-Wall -pedantic -ansi -O +CORE=*.core +endif + + +################################ +# FreeBSD +################################ +ifeq (${OS_TYPE},FreeBSD) + + ifdef LINUXBIN + COMPAT_LINUX=/compat/linux + CC=${COMPAT_LINUX}/usr/bin/gcc + CXX=${COMPAT_LINUX}/usr/bin/g++ + AR=${COMPAT_LINUX}/usr/bin/ar + ARFLAGS=-rv + CXXFLAGS+=-fhandle-exceptions ${CFLAGS} -I${COMPAT_LINUX}/usr/include/g++ + LD=${COMPAT_LINUX}/usr/bin/ld + RANLIB=${COMPAT_LINUX}/usr/bin/ranlib + BRANDELF=brandelf -t Linux + else + CC=gcc + CXX=g++ + AR=/usr/bin/ar + ARFLAGS=-rv + CXXFLAGS+=-fhandle-exceptions ${CFLAGS} -I/usr/include/g++ + LD=/usr/bin/ld + RANLIB=/usr/bin/ranlib + BRANDELF=@true + + ifdef USES_PTHREADS + CFLAGS+=-D_THREAD_SAFE + LDFLAGS+=-lc_r + + # -pthread flag for 3.0+ + ifneq (${shell uname -r | cut -d. -f1},2) + CFLAGS+=-pthread + endif + endif + endif + +CFLAGS+=-Wall -pedantic -ansi -O -DFREEBSD +CORE=*.core +endif + +################################ +# OpenBSD +################################ +ifeq (${OS_TYPE},OpenBSD) + + CC=gcc + CXX=g++ + AR=/usr/bin/ar + ARFLAGS=-rv + CXXFLAGS+=${CFLAGS} -I/usr/include/g++ + LD=/usr/bin/ld + RANLIB=/usr/bin/ranlib + + ifdef USES_PTHREADS + CFLAGS+=-D_THREAD_SAFE + LDFLAGS+=-lc_r + + # -pthread flag for 3.0+ + ifneq (${shell uname -r | cut -d. -f1},2) + CFLAGS+=-pthread + endif + endif + +CFLAGS+=-Wall -pedantic -ansi -O -DOPENBSD +CORE=*.core +endif + diff --git a/lib/erl/build/otp_subdir.mk b/lib/erl/build/otp_subdir.mk new file mode 100644 index 00000000..2a36c658 --- /dev/null +++ b/lib/erl/build/otp_subdir.mk @@ -0,0 +1,85 @@ +# Comment by tfee 2004-07-01 +# ========================== +# This file is a mod of the stock OTP one. +# The change allows make to stop when a compile error occurs. +# This file needs to go into two places: +# /usr/local/include/erlang +# /opt/OTP_SRC/make +# +# where OTP_SRC is a symbolic link to a peer directory containing +# the otp source, e.g. otp_src_R9C-2. +# +# After installing OTP, running sudo make install in otp/build +# will push this file out to the two places listed above. +# +# The mod involves setting the shell variable $short_circuit, which we +# introduce - ie it is not in the stock file. This variable is tested +# to affect execution flow and is also returned to affect the flow in +# the calling script (this one). The latter step is necessary because +# of the recursion involved. +# ===================================================================== + + +# ``The contents of this file are subject to the Erlang Public License, +# Version 1.1, (the "License"); you may not use this file except in +# compliance with the License. You should have received a copy of the +# Erlang Public License along with this software. If not, it can be +# retrieved via the world wide web at http://www.erlang.org/. +# +# Software distributed under the License is distributed on an "AS IS" +# basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See +# the License for the specific language governing rights and limitations +# under the License. +# +# The Initial Developer of the Original Code is Ericsson Utvecklings AB. +# Portions created by Ericsson are Copyright 1999, Ericsson Utvecklings +# AB. All Rights Reserved.'' +# +# $Id: otp_subdir.mk,v 1.5 2004/07/12 15:12:23 jeinhorn Exp $ +# +# +# Make include file for otp + +.PHONY: debug opt release docs release_docs tests release_tests \ + clean depend + +# +# Targets that don't affect documentation directories +# +debug opt release docs release_docs tests release_tests clean depend: prepare + @set -e ; \ + app_pwd=`pwd` ; \ + if test -f vsn.mk; then \ + echo "=== Entering application" `basename $$app_pwd` ; \ + fi ; \ + case "$(MAKE)" in *clearmake*) tflag="-T";; *) tflag="";; esac; \ + short_circuit=0 ; \ + for d in $(SUB_DIRECTORIES); do \ + if [[ $$short_circuit = 0 ]]; then \ + if test -f $$d/SKIP ; then \ + echo "=== Skipping subdir $$d, reason:" ; \ + cat $$d/SKIP ; \ + echo "===" ; \ + else \ + if test ! -d $$d ; then \ + echo "=== Skipping subdir $$d, it is missing" ; \ + else \ + xflag="" ; \ + if test -f $$d/ignore_config_record.inf; then \ + xflag=$$tflag ; \ + fi ; \ + (cd $$d && $(MAKE) $$xflag $@) ; \ + if [[ $$? != 0 ]]; then \ + short_circuit=1 ; \ + fi ; \ + fi ; \ + fi ; \ + fi ; \ + done ; \ + if test -f vsn.mk; then \ + echo "=== Leaving application" `basename $$app_pwd` ; \ + fi ; \ + exit $$short_circuit + +prepare: + echo diff --git a/lib/erl/build/raw_test.mk b/lib/erl/build/raw_test.mk new file mode 100644 index 00000000..bf8535d1 --- /dev/null +++ b/lib/erl/build/raw_test.mk @@ -0,0 +1,29 @@ +# for testing erlang files directly. The set up for a +# this type of test would be +# files to test reside in lib//src and the test files which are +# just plain erlang code reside in lib//test +# +# This color codes emitted while the tests run assume that you are using +# a white-on-black display schema. If not, e.g. if you use a white +# background, you will not be able to read the "WHITE" text. +# You can override this by supplying your own "white" color, +# which may in fact be black! You do this by defining an environment +# variable named "MY_WHITE" and setting it to $'\e[0;30m' (which is +# simply bash's way of specifying "Escape [ 0 ; 3 0 m"). +# Similarly, you can set your versions of the standard colors +# found in colors.mk. + +test: + @TEST_MODULES=`ls *_test.erl`; \ + trap "echo $(OFF)$(NO_COLOR); exit 1;" 1 2 3 6; \ + for d in $$TEST_MODULES; do \ + echo $${MY_GREEN:-$(GREEN)}"Testing File $$d" $${MY_WHITE:-$(WHITE)}; \ + echo -n $${MY_BLUE:-$(BLUE)}; \ + erl -name $(APP_NAME) $(TEST_LIBS) \ + -s `basename $$d .erl` all -s init stop -noshell; \ + if [ $$? -ne 0 ]; then \ + echo $${MY_LRED:-$(LRED)}"$$d Test Failed"; \ + fi; \ + echo -n $(OFF)$(NO_COLOR); \ + done + diff --git a/lib/erl/include/thrift_constants.hrl b/lib/erl/include/thrift_constants.hrl new file mode 100644 index 00000000..36eb49bf --- /dev/null +++ b/lib/erl/include/thrift_constants.hrl @@ -0,0 +1,54 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +%% TType +-define(tType_STOP, 0). +-define(tType_VOID, 1). +-define(tType_BOOL, 2). +-define(tType_BYTE, 3). +-define(tType_DOUBLE, 4). +-define(tType_I16, 6). +-define(tType_I32, 8). +-define(tType_I64, 10). +-define(tType_STRING, 11). +-define(tType_STRUCT, 12). +-define(tType_MAP, 13). +-define(tType_SET, 14). +-define(tType_LIST, 15). + +% TMessageType +-define(tMessageType_CALL, 1). +-define(tMessageType_REPLY, 2). +-define(tMessageType_EXCEPTION, 3). +-define(tMessageType_ONEWAY, 4). + +% TApplicationException +-define(TApplicationException_Structure, + {struct, [{1, string}, + {2, i32}]}). + +-record('TApplicationException', {message, type}). + +-define(TApplicationException_UNKNOWN, 0). +-define(TApplicationException_UNKNOWN_METHOD, 1). +-define(TApplicationException_INVALID_MESSAGE_TYPE, 2). +-define(TApplicationException_WRONG_METHOD_NAME, 3). +-define(TApplicationException_BAD_SEQUENCE_ID, 4). +-define(TApplicationException_MISSING_RESULT, 5). + diff --git a/lib/erl/include/thrift_protocol.hrl b/lib/erl/include/thrift_protocol.hrl new file mode 100644 index 00000000..f4e1901f --- /dev/null +++ b/lib/erl/include/thrift_protocol.hrl @@ -0,0 +1,31 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-ifndef(THRIFT_PROTOCOL_INCLUDED). +-define(THRIFT_PROTOCOL_INCLUDED, yea). + +-record(protocol_message_begin, {name, type, seqid}). +-record(protocol_struct_begin, {name}). +-record(protocol_field_begin, {name, type, id}). +-record(protocol_map_begin, {ktype, vtype, size}). +-record(protocol_list_begin, {etype, size}). +-record(protocol_set_begin, {etype, size}). + + +-endif. diff --git a/lib/erl/src/Makefile b/lib/erl/src/Makefile new file mode 100644 index 00000000..980af812 --- /dev/null +++ b/lib/erl/src/Makefile @@ -0,0 +1,116 @@ +# $Id: Makefile,v 1.3 2004/08/13 16:35:59 mlogan Exp $ +# +include ../build/otp.mk +include ../build/colors.mk +include ../build/buildtargets.mk + +# ---------------------------------------------------- +# Application version +# ---------------------------------------------------- + +include ../vsn.mk +APP_NAME=thrift +PFX=thrift +VSN=$(THRIFT_VSN) + +# ---------------------------------------------------- +# Install directory specification +# WARNING: INSTALL_DIR the command to install a directory. +# INSTALL_DST is the target directory +# ---------------------------------------------------- +INSTALL_DST = $(ERLANG_OTP)/lib/$(APP_NAME)-$(VSN) + +# ---------------------------------------------------- +# Target Specs +# ---------------------------------------------------- + + +MODULES = $(shell find . -name \*.erl | sed 's:^\./::' | sed 's/\.erl//') +MODULES_STRING_LIST = $(shell find . -name \*.erl | sed 's:^\./:":' | sed 's/\.erl/",/') + +HRL_FILES= +INTERNAL_HRL_FILES= $(APP_NAME).hrl +ERL_FILES= $(MODULES:%=%.erl) +DOC_FILES=$(ERL_FILES) + +APP_FILE= $(APP_NAME).app +APPUP_FILE= $(APP_NAME).appup + +APP_SRC= $(APP_FILE).src +APPUP_SRC= $(APPUP_FILE).src + +APP_TARGET= $(EBIN)/$(APP_FILE) +APPUP_TARGET= $(EBIN)/$(APPUP_FILE) + +BEAMS= $(MODULES:%=$(EBIN)/%.$(EMULATOR)) +TARGET_FILES= $(BEAMS) $(APP_TARGET) $(APPUP_TARGET) + +WEB_TARGET=/var/yaws/www/$(APP_NAME) + +# ---------------------------------------------------- +# FLAGS +# ---------------------------------------------------- + +ERL_FLAGS += +ERL_INCLUDE = -I../include -I../../fslib/include -I../../system_status/include +ERL_COMPILE_FLAGS += $(ERL_INCLUDE) + +# ---------------------------------------------------- +# Targets +# ---------------------------------------------------- + +all debug opt: $(EBIN) $(TARGET_FILES) + +#$(EBIN)/rm_logger.beam: $(APP_NAME).hrl +include ../build/docs.mk + +# Note: In the open-source build clean must not destroy the preloaded +# beam files. +clean: + rm -f $(TARGET_FILES) + rm -f *~ + rm -f core + rm -rf $(EBIN) + rm -rf *html + +$(EBIN): + mkdir $(EBIN) + +dialyzer: $(TARGET_FILES) + dialyzer --src -r . $(ERL_INCLUDE) + +# ---------------------------------------------------- +# Special Build Targets +# ---------------------------------------------------- + +$(APP_TARGET): $(APP_SRC) ../vsn.mk $(BEAMS) + sed -e 's;%VSN%;$(VSN);' \ + -e 's;%PFX%;$(PFX);' \ + -e 's;%APP_NAME%;$(APP_NAME);' \ + -e 's;%MODULES%;%MODULES%$(MODULES_STRING_LIST);' \ + $< > $<".tmp" + sed -e 's/%MODULES%\(.*\),/\1/' \ + $<".tmp" > $@ + rm $<".tmp" + +$(APPUP_TARGET): $(APPUP_SRC) ../vsn.mk + sed -e 's;%VSN%;$(VSN);' $< > $@ + +$(WEB_TARGET): ../markup/* + rm -rf $(WEB_TARGET) + mkdir $(WEB_TARGET) + cp -r ../markup/ $(WEB_TARGET) + cp -r ../skins/ $(WEB_TARGET) + +# ---------------------------------------------------- +# Install Target +# ---------------------------------------------------- + +install: all $(WEB_TARGET) +# $(INSTALL_DIR) $(INSTALL_DST)/src +# $(INSTALL_DATA) $(ERL_FILES) $(INSTALL_DST)/src +# $(INSTALL_DATA) $(INTERNAL_HRL_FILES) $(INSTALL_DST)/src +# $(INSTALL_DIR) $(INSTALL_DST)/include +# $(INSTALL_DATA) $(HRL_FILES) $(INSTALL_DST)/include +# $(INSTALL_DIR) $(INSTALL_DST)/ebin +# $(INSTALL_DATA) $(TARGET_FILES) $(INSTALL_DST)/ebin diff --git a/lib/erl/src/test_handler.erl b/lib/erl/src/test_handler.erl new file mode 100644 index 00000000..28a3acd3 --- /dev/null +++ b/lib/erl/src/test_handler.erl @@ -0,0 +1,26 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(test_handler). + +-export([handle_function/2]). + +handle_function(add, Params = {A, B}) -> + io:format("Got params: ~p~n", [Params]), + {reply, A + B}. diff --git a/lib/erl/src/test_service.erl b/lib/erl/src/test_service.erl new file mode 100644 index 00000000..7aa4827f --- /dev/null +++ b/lib/erl/src/test_service.erl @@ -0,0 +1,29 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(test_service). +% +% Test service definition + +-export([function_info/2]). + +function_info(add, params_type) -> + {struct, [{1, i32}, + {2, i32}]}; +function_info(add, reply_type) -> i32. diff --git a/lib/erl/src/thrift.app.src b/lib/erl/src/thrift.app.src new file mode 100644 index 00000000..681b3eb3 --- /dev/null +++ b/lib/erl/src/thrift.app.src @@ -0,0 +1,44 @@ +%%% -*- mode:erlang -*- +{application, %APP_NAME%, + [ + % A quick description of the application. + {description, "Thrift bindings"}, + + % The version of the applicaton + {vsn, "%VSN%"}, + + % All modules used by the application. + {modules, [ + %MODULES% + ]}, + + % All of the registered names the application uses. This can be ignored. + {registered, []}, + + % Applications that are to be started prior to this one. This can be ignored + % leave it alone unless you understand it well and let the .rel files in + % your release handle this. + {applications, + [ + kernel, + stdlib + ]}, + + % OTP application loader will load, but not start, included apps. Again + % this can be ignored as well. To load but not start an application it + % is easier to include it in the .rel file followed by the atom 'none' + {included_applications, []}, + + % configuration parameters similar to those in the config file specified + % on the command line. can be fetched with gas:get_env + {env, [ + % If an error/crash occurs during processing of a function, + % should the TApplicationException serialized back to the client + % include the erlang backtrace? + {exceptions_include_traces, true} + ]}, + + % The Module and Args used to start this application. + {mod, {thrift_app, []}} + ] +}. diff --git a/lib/erl/src/thrift.appup.src b/lib/erl/src/thrift.appup.src new file mode 100644 index 00000000..54a63833 --- /dev/null +++ b/lib/erl/src/thrift.appup.src @@ -0,0 +1 @@ +{"%VSN%",[],[]}. diff --git a/lib/erl/src/thrift_base64_transport.erl b/lib/erl/src/thrift_base64_transport.erl new file mode 100644 index 00000000..9d13151c --- /dev/null +++ b/lib/erl/src/thrift_base64_transport.erl @@ -0,0 +1,64 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_base64_transport). + +-behaviour(thrift_transport). + +%% API +-export([new/1, new_transport_factory/1]). + +%% thrift_transport callbacks +-export([write/2, read/2, flush/1, close/1]). + +%% State +-record(b64_transport, {wrapped}). + +new(Wrapped) -> + State = #b64_transport{wrapped = Wrapped}, + thrift_transport:new(?MODULE, State). + + +write(#b64_transport{wrapped = Wrapped}, Data) -> + thrift_transport:write(Wrapped, base64:encode(iolist_to_binary(Data))). + + +%% base64 doesn't support reading quite yet since it would involve +%% nasty buffering and such +read(#b64_transport{wrapped = Wrapped}, Data) -> + {error, no_reads_allowed}. + + +flush(#b64_transport{wrapped = Wrapped}) -> + thrift_transport:write(Wrapped, <<"\n">>), + thrift_transport:flush(Wrapped). + + +close(Me = #b64_transport{wrapped = Wrapped}) -> + flush(Me), + thrift_transport:close(Wrapped). + + +%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +new_transport_factory(WrapFactory) -> + F = fun() -> + {ok, Wrapped} = WrapFactory(), + new(Wrapped) + end, + {ok, F}. diff --git a/lib/erl/src/thrift_binary_protocol.erl b/lib/erl/src/thrift_binary_protocol.erl new file mode 100644 index 00000000..ad533842 --- /dev/null +++ b/lib/erl/src/thrift_binary_protocol.erl @@ -0,0 +1,325 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_binary_protocol). + +-behavior(thrift_protocol). + +-include("thrift_constants.hrl"). +-include("thrift_protocol.hrl"). + +-export([new/1, new/2, + read/2, + write/2, + flush_transport/1, + close_transport/1, + + new_protocol_factory/2 + ]). + +-record(binary_protocol, {transport, + strict_read=true, + strict_write=true + }). + +-define(VERSION_MASK, 16#FFFF0000). +-define(VERSION_1, 16#80010000). +-define(TYPE_MASK, 16#000000ff). + +new(Transport) -> + new(Transport, _Options = []). + +new(Transport, Options) -> + State = #binary_protocol{transport = Transport}, + State1 = parse_options(Options, State), + thrift_protocol:new(?MODULE, State1). + +parse_options([], State) -> + State; +parse_options([{strict_read, Bool} | Rest], State) when is_boolean(Bool) -> + parse_options(Rest, State#binary_protocol{strict_read=Bool}); +parse_options([{strict_write, Bool} | Rest], State) when is_boolean(Bool) -> + parse_options(Rest, State#binary_protocol{strict_write=Bool}). + + +flush_transport(#binary_protocol{transport = Transport}) -> + thrift_transport:flush(Transport). + +close_transport(#binary_protocol{transport = Transport}) -> + thrift_transport:close(Transport). + +%%% +%%% instance methods +%%% + +write(This, #protocol_message_begin{ + name = Name, + type = Type, + seqid = Seqid}) -> + case This#binary_protocol.strict_write of + true -> + write(This, {i32, ?VERSION_1 bor Type}), + write(This, {string, Name}), + write(This, {i32, Seqid}); + false -> + write(This, {string, Name}), + write(This, {byte, Type}), + write(This, {i32, Seqid}) + end, + ok; + +write(This, message_end) -> ok; + +write(This, #protocol_field_begin{ + name = _Name, + type = Type, + id = Id}) -> + write(This, {byte, Type}), + write(This, {i16, Id}), + ok; + +write(This, field_stop) -> + write(This, {byte, ?tType_STOP}), + ok; + +write(This, field_end) -> ok; + +write(This, #protocol_map_begin{ + ktype = Ktype, + vtype = Vtype, + size = Size}) -> + write(This, {byte, Ktype}), + write(This, {byte, Vtype}), + write(This, {i32, Size}), + ok; + +write(This, map_end) -> ok; + +write(This, #protocol_list_begin{ + etype = Etype, + size = Size}) -> + write(This, {byte, Etype}), + write(This, {i32, Size}), + ok; + +write(This, list_end) -> ok; + +write(This, #protocol_set_begin{ + etype = Etype, + size = Size}) -> + write(This, {byte, Etype}), + write(This, {i32, Size}), + ok; + +write(This, set_end) -> ok; + +write(This, #protocol_struct_begin{}) -> ok; +write(This, struct_end) -> ok; + +write(This, {bool, true}) -> write(This, {byte, 1}); +write(This, {bool, false}) -> write(This, {byte, 0}); + +write(This, {byte, Byte}) -> + write(This, <>); + +write(This, {i16, I16}) -> + write(This, <>); + +write(This, {i32, I32}) -> + write(This, <>); + +write(This, {i64, I64}) -> + write(This, <>); + +write(This, {double, Double}) -> + write(This, <>); + +write(This, {string, Str}) when is_list(Str) -> + write(This, {i32, length(Str)}), + write(This, list_to_binary(Str)); + +write(This, {string, Bin}) when is_binary(Bin) -> + write(This, {i32, size(Bin)}), + write(This, Bin); + +%% Data :: iolist() +write(This, Data) -> + thrift_transport:write(This#binary_protocol.transport, Data). + +%% + +read(This, message_begin) -> + case read(This, ui32) of + {ok, Sz} when Sz band ?VERSION_MASK =:= ?VERSION_1 -> + %% we're at version 1 + {ok, Name} = read(This, string), + Type = Sz band ?TYPE_MASK, + {ok, SeqId} = read(This, i32), + #protocol_message_begin{name = binary_to_list(Name), + type = Type, + seqid = SeqId}; + + {ok, Sz} when Sz < 0 -> + %% there's a version number but it's unexpected + {error, {bad_binary_protocol_version, Sz}}; + + {ok, Sz} when This#binary_protocol.strict_read =:= true -> + %% strict_read is true and there's no version header; that's an error + {error, no_binary_protocol_version}; + + {ok, Sz} when This#binary_protocol.strict_read =:= false -> + %% strict_read is false, so just read the old way + {ok, Name} = read(This, Sz), + {ok, Type} = read(This, byte), + {ok, SeqId} = read(This, i32), + #protocol_message_begin{name = binary_to_list(Name), + type = Type, + seqid = SeqId}; + + Err = {error, closed} -> Err; + Err = {error, timeout}-> Err; + Err = {error, ebadf} -> Err + end; + +read(This, message_end) -> ok; + +read(This, struct_begin) -> ok; +read(This, struct_end) -> ok; + +read(This, field_begin) -> + case read(This, byte) of + {ok, Type = ?tType_STOP} -> + #protocol_field_begin{type = Type}; + {ok, Type} -> + {ok, Id} = read(This, i16), + #protocol_field_begin{type = Type, + id = Id} + end; + +read(This, field_end) -> ok; + +read(This, map_begin) -> + {ok, Ktype} = read(This, byte), + {ok, Vtype} = read(This, byte), + {ok, Size} = read(This, i32), + #protocol_map_begin{ktype = Ktype, + vtype = Vtype, + size = Size}; +read(This, map_end) -> ok; + +read(This, list_begin) -> + {ok, Etype} = read(This, byte), + {ok, Size} = read(This, i32), + #protocol_list_begin{etype = Etype, + size = Size}; +read(This, list_end) -> ok; + +read(This, set_begin) -> + {ok, Etype} = read(This, byte), + {ok, Size} = read(This, i32), + #protocol_set_begin{etype = Etype, + size = Size}; +read(This, set_end) -> ok; + +read(This, field_stop) -> + {ok, ?tType_STOP} = read(This, byte), + ok; + +%% + +read(This, bool) -> + case read(This, byte) of + {ok, Byte} -> {ok, Byte /= 0}; + Else -> Else + end; + +read(This, byte) -> + case read(This, 1) of + {ok, <>} -> {ok, Val}; + Else -> Else + end; + +read(This, i16) -> + case read(This, 2) of + {ok, <>} -> {ok, Val}; + Else -> Else + end; + +read(This, i32) -> + case read(This, 4) of + {ok, <>} -> {ok, Val}; + Else -> Else + end; + +%% unsigned ints aren't used by thrift itself, but it's used for the parsing +%% of the packet version header. Without this special function BEAM works fine +%% but hipe thinks it received a bad version header. +read(This, ui32) -> + case read(This, 4) of + {ok, <>} -> {ok, Val}; + Else -> Else + end; + +read(This, i64) -> + case read(This, 8) of + {ok, <>} -> {ok, Val}; + Else -> Else + end; + +read(This, double) -> + case read(This, 8) of + {ok, <>} -> {ok, Val}; + Else -> Else + end; + +% returns a binary directly, call binary_to_list if necessary +read(This, string) -> + {ok, Sz} = read(This, i32), + {ok, Bin} = read(This, Sz); + +read(This, 0) -> {ok, <<>>}; +read(This, Len) when is_integer(Len), Len >= 0 -> + thrift_transport:read(This#binary_protocol.transport, Len). + + +%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +-record(tbp_opts, {strict_read = true, + strict_write = true}). + +parse_factory_options([], Opts) -> + Opts; +parse_factory_options([{strict_read, Bool} | Rest], Opts) when is_boolean(Bool) -> + parse_factory_options(Rest, Opts#tbp_opts{strict_read=Bool}); +parse_factory_options([{strict_write, Bool} | Rest], Opts) when is_boolean(Bool) -> + parse_factory_options(Rest, Opts#tbp_opts{strict_write=Bool}). + + +%% returns a (fun() -> thrift_protocol()) +new_protocol_factory(TransportFactory, Options) -> + ParsedOpts = parse_factory_options(Options, #tbp_opts{}), + F = fun() -> + {ok, Transport} = TransportFactory(), + thrift_binary_protocol:new( + Transport, + [{strict_read, ParsedOpts#tbp_opts.strict_read}, + {strict_write, ParsedOpts#tbp_opts.strict_write}]) + end, + {ok, F}. + diff --git a/lib/erl/src/thrift_buffered_transport.erl b/lib/erl/src/thrift_buffered_transport.erl new file mode 100644 index 00000000..ebc16bd6 --- /dev/null +++ b/lib/erl/src/thrift_buffered_transport.erl @@ -0,0 +1,180 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_buffered_transport). + +-behaviour(gen_server). +-behaviour(thrift_transport). + +%% API +-export([new/1, new_transport_factory/1]). + +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + +%% thrift_transport callbacks +-export([write/2, read/2, flush/1, close/1]). + +-record(buffered_transport, {wrapped, % a thrift_transport + write_buffer % iolist() + }). + +%%==================================================================== +%% API +%%==================================================================== +%%-------------------------------------------------------------------- +%% Function: start_link() -> {ok,Pid} | ignore | {error,Error} +%% Description: Starts the server +%%-------------------------------------------------------------------- +new(WrappedTransport) -> + case gen_server:start_link(?MODULE, [WrappedTransport], []) of + {ok, Pid} -> + thrift_transport:new(?MODULE, Pid); + Else -> + Else + end. + + + +%%-------------------------------------------------------------------- +%% Function: write(Transport, Data) -> ok +%% +%% Data = iolist() +%% +%% Description: Writes data into the buffer +%%-------------------------------------------------------------------- +write(Transport, Data) -> + gen_server:call(Transport, {write, Data}). + +%%-------------------------------------------------------------------- +%% Function: flush(Transport) -> ok +%% +%% Description: Flushes the buffer through to the wrapped transport +%%-------------------------------------------------------------------- +flush(Transport) -> + gen_server:call(Transport, flush). + +%%-------------------------------------------------------------------- +%% Function: close(Transport) -> ok +%% +%% Description: Closes the transport and the wrapped transport +%%-------------------------------------------------------------------- +close(Transport) -> + gen_server:cast(Transport, close). + +%%-------------------------------------------------------------------- +%% Function: Read(Transport, Len) -> {ok, Data} +%% +%% Data = binary() +%% +%% Description: Reads data through from the wrapped transoprt +%%-------------------------------------------------------------------- +read(Transport, Len) when is_integer(Len) -> + gen_server:call(Transport, {read, Len}, _Timeout=10000). + +%%==================================================================== +%% gen_server callbacks +%%==================================================================== + +%%-------------------------------------------------------------------- +%% Function: init(Args) -> {ok, State} | +%% {ok, State, Timeout} | +%% ignore | +%% {stop, Reason} +%% Description: Initiates the server +%%-------------------------------------------------------------------- +init([Wrapped]) -> + {ok, #buffered_transport{wrapped = Wrapped, + write_buffer = []}}. + +%%-------------------------------------------------------------------- +%% Function: %% handle_call(Request, From, State) -> {reply, Reply, State} | +%% {reply, Reply, State, Timeout} | +%% {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, Reply, State} | +%% {stop, Reason, State} +%% Description: Handling call messages +%%-------------------------------------------------------------------- +handle_call({write, Data}, _From, State = #buffered_transport{write_buffer = WBuf}) -> + {reply, ok, State#buffered_transport{write_buffer = [WBuf, Data]}}; + +handle_call({read, Len}, _From, State = #buffered_transport{wrapped = Wrapped}) -> + Response = thrift_transport:read(Wrapped, Len), + {reply, Response, State}; + +handle_call(flush, _From, State = #buffered_transport{write_buffer = WBuf, + wrapped = Wrapped}) -> + Response = thrift_transport:write(Wrapped, WBuf), + thrift_transport:flush(Wrapped), + {reply, Response, State#buffered_transport{write_buffer = []}}. + +%%-------------------------------------------------------------------- +%% Function: handle_cast(Msg, State) -> {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, State} +%% Description: Handling cast messages +%%-------------------------------------------------------------------- +handle_cast(close, State = #buffered_transport{write_buffer = WBuf, + wrapped = Wrapped}) -> + thrift_transport:write(Wrapped, WBuf), + %% Wrapped is closed by terminate/2 + %% error_logger:info_msg("thrift_buffered_transport ~p: closing", [self()]), + {stop, normal, State}; +handle_cast(Msg, State=#buffered_transport{}) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% Function: handle_info(Info, State) -> {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, State} +%% Description: Handling all non call/cast messages +%%-------------------------------------------------------------------- +handle_info(_Info, State) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% Function: terminate(Reason, State) -> void() +%% Description: This function is called by a gen_server when it is about to +%% terminate. It should be the opposite of Module:init/1 and do any necessary +%% cleaning up. When it returns, the gen_server terminates with Reason. +%% The return value is ignored. +%%-------------------------------------------------------------------- +terminate(_Reason, State = #buffered_transport{wrapped=Wrapped}) -> + thrift_transport:close(Wrapped), + ok. + +%%-------------------------------------------------------------------- +%% Func: code_change(OldVsn, State, Extra) -> {ok, NewState} +%% Description: Convert process state when code is changed +%%-------------------------------------------------------------------- +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%-------------------------------------------------------------------- +%%% Internal functions +%%-------------------------------------------------------------------- +%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +new_transport_factory(WrapFactory) -> + F = fun() -> + {ok, Wrapped} = WrapFactory(), + new(Wrapped) + end, + {ok, F}. diff --git a/lib/erl/src/thrift_client.erl b/lib/erl/src/thrift_client.erl new file mode 100644 index 00000000..5ba8aee6 --- /dev/null +++ b/lib/erl/src/thrift_client.erl @@ -0,0 +1,330 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_client). + +-behaviour(gen_server). + +%% API +-export([start_link/2, start_link/3, start_link/4, call/3, send_call/3, close/1]). + +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + + +-include("thrift_constants.hrl"). +-include("thrift_protocol.hrl"). + +-record(state, {service, protocol, seqid}). + +%%==================================================================== +%% API +%%==================================================================== +%%-------------------------------------------------------------------- +%% Function: start_link() -> {ok,Pid} | ignore | {error,Error} +%% Description: Starts the server +%%-------------------------------------------------------------------- +start_link(Host, Port, Service) when is_integer(Port), is_atom(Service) -> + start_link(Host, Port, Service, []). + + +%% +%% Splits client options into protocol options and transport options +%% +%% split_options([Options...]) -> {ProtocolOptions, TransportOptions} +%% +split_options(Options) -> + split_options(Options, [], []). + +split_options([], ProtoIn, TransIn) -> + {ProtoIn, TransIn}; + +split_options([Opt = {OptKey, _} | Rest], ProtoIn, TransIn) + when OptKey =:= strict_read; + OptKey =:= strict_write -> + split_options(Rest, [Opt | ProtoIn], TransIn); + +split_options([Opt = {OptKey, _} | Rest], ProtoIn, TransIn) + when OptKey =:= framed; + OptKey =:= connect_timeout; + OptKey =:= sockopts -> + split_options(Rest, ProtoIn, [Opt | TransIn]). + + +%% Backwards-compatible starter for the common-case of socket transports +start_link(Host, Port, Service, Options) + when is_integer(Port), is_atom(Service), is_list(Options) -> + {ProtoOpts, TransOpts} = split_options(Options), + + {ok, TransportFactory} = + thrift_socket_transport:new_transport_factory(Host, Port, TransOpts), + + {ok, ProtocolFactory} = thrift_binary_protocol:new_protocol_factory( + TransportFactory, ProtoOpts), + + start_link(ProtocolFactory, Service). + + +%% ProtocolFactory :: fun() -> thrift_protocol() +start_link(ProtocolFactory, Service) + when is_function(ProtocolFactory), is_atom(Service) -> + case gen_server:start_link(?MODULE, [Service], []) of + {ok, Pid} -> + case gen_server:call(Pid, {connect, ProtocolFactory}) of + ok -> + {ok, Pid}; + Error -> + Error + end; + Else -> + Else + end. + +call(Client, Function, Args) + when is_pid(Client), is_atom(Function), is_list(Args) -> + case gen_server:call(Client, {call, Function, Args}) of + R = {ok, _} -> R; + R = {error, _} -> R; + {exception, Exception} -> throw(Exception) + end. + +cast(Client, Function, Args) + when is_pid(Client), is_atom(Function), is_list(Args) -> + gen_server:cast(Client, {call, Function, Args}). + +%% Sends a function call but does not read the result. This is useful +%% if you're trying to log non-oneway function calls to write-only +%% transports like thrift_disk_log_transport. +send_call(Client, Function, Args) + when is_pid(Client), is_atom(Function), is_list(Args) -> + gen_server:call(Client, {send_call, Function, Args}). + +close(Client) when is_pid(Client) -> + gen_server:cast(Client, close). + +%%==================================================================== +%% gen_server callbacks +%%==================================================================== + +%%-------------------------------------------------------------------- +%% Function: init(Args) -> {ok, State} | +%% {ok, State, Timeout} | +%% ignore | +%% {stop, Reason} +%% Description: Initiates the server +%%-------------------------------------------------------------------- +init([Service]) -> + {ok, #state{service = Service}}. + +%%-------------------------------------------------------------------- +%% Function: %% handle_call(Request, From, State) -> {reply, Reply, State} | +%% {reply, Reply, State, Timeout} | +%% {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, Reply, State} | +%% {stop, Reason, State} +%% Description: Handling call messages +%%-------------------------------------------------------------------- +handle_call({connect, ProtocolFactory}, _From, + State = #state{service = Service}) -> + case ProtocolFactory() of + {ok, Protocol} -> + {reply, ok, State#state{protocol = Protocol, + seqid = 0}}; + Error -> + {stop, normal, Error, State} + end; + +handle_call({call, Function, Args}, _From, State = #state{service = Service}) -> + Result = catch_function_exceptions( + fun() -> + ok = send_function_call(State, Function, Args), + receive_function_result(State, Function) + end, + Service), + {reply, Result, State}; + + +handle_call({send_call, Function, Args}, _From, State = #state{service = Service}) -> + Result = catch_function_exceptions( + fun() -> + send_function_call(State, Function, Args) + end, + Service), + {reply, Result, State}. + + +%% Helper function that catches exceptions thrown by sending or receiving +%% a function and returns the correct response for call or send_only above. +catch_function_exceptions(Fun, Service) -> + try + Fun() + catch + throw:{return, Return} -> + Return; + error:function_clause -> + ST = erlang:get_stacktrace(), + case hd(ST) of + {Service, function_info, [Function, _]} -> + {error, {no_function, Function}}; + _ -> throw({error, {function_clause, ST}}) + end + end. + + +%%-------------------------------------------------------------------- +%% Function: handle_cast(Msg, State) -> {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, State} +%% Description: Handling cast messages +%%-------------------------------------------------------------------- +handle_cast({call, Function, Args}, State = #state{service = Service, + protocol = Protocol, + seqid = SeqId}) -> + _Result = + try + ok = send_function_call(State, Function, Args), + receive_function_result(State, Function) + catch + Class:Reason -> + error_logger:error_msg("error ignored in handle_cast({cast,...},...): ~p:~p~n", [Class, Reason]) + end, + + {noreply, State}; + +handle_cast(close, State=#state{protocol = Protocol}) -> +%% error_logger:info_msg("thrift_client ~p received close", [self()]), + {stop,normal,State}; +handle_cast(_Msg, State) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% Function: handle_info(Info, State) -> {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, State} +%% Description: Handling all non call/cast messages +%%-------------------------------------------------------------------- +handle_info(_Info, State) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% Function: terminate(Reason, State) -> void() +%% Description: This function is called by a gen_server when it is about to +%% terminate. It should be the opposite of Module:init/1 and do any necessary +%% cleaning up. When it returns, the gen_server terminates with Reason. +%% The return value is ignored. +%%-------------------------------------------------------------------- +terminate(Reason, State = #state{protocol=undefined}) -> + ok; +terminate(Reason, State = #state{protocol=Protocol}) -> + thrift_protocol:close_transport(Protocol), + ok. + +%%-------------------------------------------------------------------- +%% Func: code_change(OldVsn, State, Extra) -> {ok, NewState} +%% Description: Convert process state when code is changed +%%-------------------------------------------------------------------- +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%-------------------------------------------------------------------- +%%% Internal functions +%%-------------------------------------------------------------------- +send_function_call(#state{protocol = Proto, + service = Service, + seqid = SeqId}, + Function, + Args) -> + Params = Service:function_info(Function, params_type), + {struct, PList} = Params, + if + length(PList) =/= length(Args) -> + throw({return, {error, {bad_args, Function, Args}}}); + true -> ok + end, + + Begin = #protocol_message_begin{name = atom_to_list(Function), + type = ?tMessageType_CALL, + seqid = SeqId}, + ok = thrift_protocol:write(Proto, Begin), + ok = thrift_protocol:write(Proto, {Params, list_to_tuple([Function | Args])}), + ok = thrift_protocol:write(Proto, message_end), + thrift_protocol:flush_transport(Proto), + ok. + +receive_function_result(State = #state{protocol = Proto, + service = Service}, + Function) -> + ResultType = Service:function_info(Function, reply_type), + read_result(State, Function, ResultType). + +read_result(_State, + _Function, + oneway_void) -> + {ok, ok}; + +read_result(State = #state{protocol = Proto, + seqid = SeqId}, + Function, + ReplyType) -> + case thrift_protocol:read(Proto, message_begin) of + #protocol_message_begin{seqid = RetSeqId} when RetSeqId =/= SeqId -> + {error, {bad_seq_id, SeqId}}; + + #protocol_message_begin{type = ?tMessageType_EXCEPTION} -> + handle_application_exception(State); + + #protocol_message_begin{type = ?tMessageType_REPLY} -> + handle_reply(State, Function, ReplyType) + end. + +handle_reply(State = #state{protocol = Proto, + service = Service}, + Function, + ReplyType) -> + {struct, ExceptionFields} = Service:function_info(Function, exceptions), + ReplyStructDef = {struct, [{0, ReplyType}] ++ ExceptionFields}, + {ok, Reply} = thrift_protocol:read(Proto, ReplyStructDef), + ReplyList = tuple_to_list(Reply), + true = length(ReplyList) == length(ExceptionFields) + 1, + ExceptionVals = tl(ReplyList), + Thrown = [X || X <- ExceptionVals, + X =/= undefined], + Result = + case Thrown of + [] when ReplyType == {struct, []} -> + {ok, ok}; + [] -> + {ok, hd(ReplyList)}; + [Exception] -> + {exception, Exception} + end, + ok = thrift_protocol:read(Proto, message_end), + Result. + +handle_application_exception(State = #state{protocol = Proto}) -> + {ok, Exception} = thrift_protocol:read(Proto, + ?TApplicationException_Structure), + ok = thrift_protocol:read(Proto, message_end), + XRecord = list_to_tuple( + ['TApplicationException' | tuple_to_list(Exception)]), + error_logger:error_msg("X: ~p~n", [XRecord]), + true = is_record(XRecord, 'TApplicationException'), + {exception, XRecord}. diff --git a/lib/erl/src/thrift_disk_log_transport.erl b/lib/erl/src/thrift_disk_log_transport.erl new file mode 100644 index 00000000..761fa309 --- /dev/null +++ b/lib/erl/src/thrift_disk_log_transport.erl @@ -0,0 +1,118 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +%%% Todo: this might be better off as a gen_server type of transport +%%% that handles stuff like group commit, similar to TFileTransport +%%% in cpp land +-module(thrift_disk_log_transport). + +-behaviour(thrift_transport). + +%% API +-export([new/2, new_transport_factory/2, new_transport_factory/3]). + +%% thrift_transport callbacks +-export([read/2, write/2, force_flush/1, flush/1, close/1]). + +%% state +-record(dl_transport, {log, + close_on_close = false, + sync_every = infinity, + sync_tref}). + + +%% Create a transport attached to an already open log. +%% If you'd like this transport to close the disk_log using disk_log:lclose() +%% when the transport is closed, pass a {close_on_close, true} tuple in the +%% Opts list. +new(LogName, Opts) when is_atom(LogName), is_list(Opts) -> + State = parse_opts(Opts, #dl_transport{log = LogName}), + + State2 = + case State#dl_transport.sync_every of + N when is_integer(N), N > 0 -> + {ok, TRef} = timer:apply_interval(N, ?MODULE, force_flush, State), + State#dl_transport{sync_tref = TRef}; + _ -> State + end, + + thrift_transport:new(?MODULE, State2). + + +parse_opts([], State) -> + State; +parse_opts([{close_on_close, Bool} | Rest], State) when is_boolean(Bool) -> + State#dl_transport{close_on_close = Bool}; +parse_opts([{sync_every, Int} | Rest], State) when is_integer(Int), Int > 0 -> + State#dl_transport{sync_every = Int}. + + +%%%% TRANSPORT IMPLENTATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +%% disk_log_transport is write-only +read(_State, Len) -> + {error, no_read_from_disk_log}. + +write(#dl_transport{log = Log}, Data) -> + disk_log:balog(Log, erlang:iolist_to_binary(Data)). + +force_flush(#dl_transport{log = Log}) -> + error_logger:info_msg("~p syncing~n", [?MODULE]), + disk_log:sync(Log). + +flush(#dl_transport{log = Log, sync_every = SE}) -> + case SE of + undefined -> % no time-based sync + disk_log:sync(Log); + _Else -> % sync will happen automagically + ok + end. + + +%% On close, close the underlying log if we're configured to do so. +close(#dl_transport{close_on_close = false}) -> + ok; +close(#dl_transport{log = Log}) -> + disk_log:lclose(Log). + + +%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +new_transport_factory(Name, ExtraLogOpts) -> + new_transport_factory(Name, ExtraLogOpts, [{close_on_close, true}, + {sync_every, 500}]). + +new_transport_factory(Name, ExtraLogOpts, TransportOpts) -> + F = fun() -> factory_impl(Name, ExtraLogOpts, TransportOpts) end, + {ok, F}. + +factory_impl(Name, ExtraLogOpts, TransportOpts) -> + LogOpts = [{name, Name}, + {format, external}, + {type, wrap} | + ExtraLogOpts], + Log = + case disk_log:open(LogOpts) of + {ok, Log} -> + Log; + {repaired, Log, Info1, Info2} -> + error_logger:info_msg("Disk log ~p repaired: ~p, ~p~n", [Log, Info1, Info2]), + Log + end, + new(Log, TransportOpts). diff --git a/lib/erl/src/thrift_file_transport.erl b/lib/erl/src/thrift_file_transport.erl new file mode 100644 index 00000000..5ac2dbe1 --- /dev/null +++ b/lib/erl/src/thrift_file_transport.erl @@ -0,0 +1,87 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_file_transport). + +-behaviour(thrift_transport). + +-export([new_reader/1, + new/1, + new/2, + write/2, read/2, flush/1, close/1]). + +-record(t_file_transport, {device, + should_close = true, + mode = write}). + +%%%% CONSTRUCTION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +new_reader(Filename) -> + case file:open(Filename, [read, binary, {read_ahead, 1024*1024}]) of + {ok, IODevice} -> + new(IODevice, [{should_close, true}, {mode, read}]); + Error -> Error + end. + +new(Device) -> + new(Device, []). + +%% Device :: io_device() +%% +%% Device should be opened in raw and binary mode. +new(Device, Opts) when is_list(Opts) -> + State = parse_opts(Opts, #t_file_transport{device = Device}), + thrift_transport:new(?MODULE, State). + + +%% Parse options +parse_opts([{should_close, Bool} | Rest], State) when is_boolean(Bool) -> + parse_opts(Rest, State#t_file_transport{should_close = Bool}); +parse_opts([{mode, Mode} | Rest], State) + when Mode =:= write; + Mode =:= read -> + parse_opts(Rest, State#t_file_transport{mode = Mode}); +parse_opts([], State) -> + State. + + +%%%% TRANSPORT IMPL %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +write(#t_file_transport{device = Device, mode = write}, Data) -> + file:write(Device, Data); +write(_T, _D) -> + {error, read_mode}. + + +read(#t_file_transport{device = Device, mode = read}, Len) + when is_integer(Len), Len >= 0 -> + file:read(Device, Len); +read(_T, _D) -> + {error, read_mode}. + +flush(#t_file_transport{device = Device, mode = write}) -> + file:sync(Device). + +close(#t_file_transport{device = Device, should_close = SC}) -> + case SC of + true -> + file:close(Device); + false -> + ok + end. diff --git a/lib/erl/src/thrift_framed_transport.erl b/lib/erl/src/thrift_framed_transport.erl new file mode 100644 index 00000000..01bab70b --- /dev/null +++ b/lib/erl/src/thrift_framed_transport.erl @@ -0,0 +1,208 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_framed_transport). + +-behaviour(gen_server). +-behaviour(thrift_transport). + +%% API +-export([new/1]). + +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + +%% thrift_transport callbacks +-export([write/2, read/2, flush/1, close/1]). + +-record(framed_transport, {wrapped, % a thrift_transport + read_buffer, % iolist() + write_buffer % iolist() + }). + +%%==================================================================== +%% API +%%==================================================================== +%%-------------------------------------------------------------------- +%% Function: start_link() -> {ok,Pid} | ignore | {error,Error} +%% Description: Starts the server +%%-------------------------------------------------------------------- +new(WrappedTransport) -> + case gen_server:start_link(?MODULE, [WrappedTransport], []) of + {ok, Pid} -> + thrift_transport:new(?MODULE, Pid); + Else -> + Else + end. + +%%-------------------------------------------------------------------- +%% Function: write(Transport, Data) -> ok +%% +%% Data = iolist() +%% +%% Description: Writes data into the buffer +%%-------------------------------------------------------------------- +write(Transport, Data) -> + gen_server:call(Transport, {write, Data}). + +%%-------------------------------------------------------------------- +%% Function: flush(Transport) -> ok +%% +%% Description: Flushes the buffer through to the wrapped transport +%%-------------------------------------------------------------------- +flush(Transport) -> + gen_server:call(Transport, flush). + +%%-------------------------------------------------------------------- +%% Function: close(Transport) -> ok +%% +%% Description: Closes the transport and the wrapped transport +%%-------------------------------------------------------------------- +close(Transport) -> + gen_server:cast(Transport, close). + +%%-------------------------------------------------------------------- +%% Function: Read(Transport, Len) -> {ok, Data} +%% +%% Data = binary() +%% +%% Description: Reads data through from the wrapped transoprt +%%-------------------------------------------------------------------- +read(Transport, Len) when is_integer(Len) -> + gen_server:call(Transport, {read, Len}). + +%%==================================================================== +%% gen_server callbacks +%%==================================================================== + +%%-------------------------------------------------------------------- +%% Function: init(Args) -> {ok, State} | +%% {ok, State, Timeout} | +%% ignore | +%% {stop, Reason} +%% Description: Initiates the server +%%-------------------------------------------------------------------- +init([Wrapped]) -> + {ok, #framed_transport{wrapped = Wrapped, + read_buffer = [], + write_buffer = []}}. + +%%-------------------------------------------------------------------- +%% Function: %% handle_call(Request, From, State) -> {reply, Reply, State} | +%% {reply, Reply, State, Timeout} | +%% {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, Reply, State} | +%% {stop, Reason, State} +%% Description: Handling call messages +%%-------------------------------------------------------------------- +handle_call({write, Data}, _From, State = #framed_transport{write_buffer = WBuf}) -> + {reply, ok, State#framed_transport{write_buffer = [WBuf, Data]}}; + +handle_call({read, Len}, _From, State = #framed_transport{wrapped = Wrapped, + read_buffer = RBuf}) -> + {RBuf1, RBuf1Size} = + %% if the read buffer is empty, read another frame + %% otherwise, just read from what's left in the buffer + case iolist_size(RBuf) of + 0 -> + %% read the frame length + {ok, <>} = + thrift_transport:read(Wrapped, 4), + %% then read the data + {ok, Bin} = + thrift_transport:read(Wrapped, FrameLen), + {Bin, erlang:byte_size(Bin)}; + Sz -> + {RBuf, Sz} + end, + + %% pull off Give bytes, return them to the user, leave the rest in the buffer + Give = min(RBuf1Size, Len), + <> = iolist_to_binary(RBuf1), + + Response = {ok, Data}, + State1 = State#framed_transport{read_buffer=RBuf2}, + + {reply, Response, State1}; + +handle_call(flush, _From, State) -> + {Response, State1} = do_flush(State), + {reply, Response, State1}. + +%%-------------------------------------------------------------------- +%% Function: handle_cast(Msg, State) -> {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, State} +%% Description: Handling cast messages +%%-------------------------------------------------------------------- +handle_cast(close, State) -> + {_, State1} = do_flush(State), + %% Wrapped is closed by terminate/2 + %% error_logger:info_msg("thrift_framed_transport ~p: closing", [self()]), + {stop, normal, State}; +handle_cast(Msg, State=#framed_transport{}) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% Function: handle_info(Info, State) -> {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, State} +%% Description: Handling all non call/cast messages +%%-------------------------------------------------------------------- +handle_info(_Info, State) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% Function: terminate(Reason, State) -> void() +%% Description: This function is called by a gen_server when it is about to +%% terminate. It should be the opposite of Module:init/1 and do any necessary +%% cleaning up. When it returns, the gen_server terminates with Reason. +%% The return value is ignored. +%%-------------------------------------------------------------------- +terminate(_Reason, State = #framed_transport{wrapped=Wrapped}) -> + thrift_transport:close(Wrapped), + ok. + +%%-------------------------------------------------------------------- +%% Func: code_change(OldVsn, State, Extra) -> {ok, NewState} +%% Description: Convert process state when code is changed +%%-------------------------------------------------------------------- +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%-------------------------------------------------------------------- +%%% Internal functions +%%-------------------------------------------------------------------- +do_flush(State = #framed_transport{write_buffer = Buffer, + wrapped = Wrapped}) -> + FrameLen = iolist_size(Buffer), + Data = [<>, Buffer], + + Response = thrift_transport:write(Wrapped, Data), + + thrift_transport:flush(Wrapped), + + State1 = State#framed_transport{write_buffer = []}, + {Response, State1}. + +min(A,B) when A A; +min(_,B) -> B. + diff --git a/lib/erl/src/thrift_http_transport.erl b/lib/erl/src/thrift_http_transport.erl new file mode 100644 index 00000000..f8c18277 --- /dev/null +++ b/lib/erl/src/thrift_http_transport.erl @@ -0,0 +1,199 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_http_transport). + +-behaviour(gen_server). +-behaviour(thrift_transport). + +%% API +-export([new/2, new/3]). + +%% gen_server callbacks +-export([init/1, + handle_call/3, + handle_cast/2, + handle_info/2, + terminate/2, + code_change/3]). + +%% thrift_transport callbacks +-export([write/2, read/2, flush/1, close/1]). + +-record(http_transport, {host, % string() + path, % string() + read_buffer, % iolist() + write_buffer, % iolist() + http_options, % see http(3) + extra_headers % [{str(), str()}, ...] + }). + +%%==================================================================== +%% API +%%==================================================================== +%%-------------------------------------------------------------------- +%% Function: new() -> {ok, Transport} | ignore | {error,Error} +%% Description: Starts the server +%%-------------------------------------------------------------------- +new(Host, Path) -> + new(Host, Path, _Options = []). + +%%-------------------------------------------------------------------- +%% Options include: +%% {http_options, HttpOptions} = See http(3) +%% {extra_headers, ExtraHeaders} = List of extra HTTP headers +%%-------------------------------------------------------------------- +new(Host, Path, Options) -> + case gen_server:start_link(?MODULE, {Host, Path, Options}, []) of + {ok, Pid} -> + thrift_transport:new(?MODULE, Pid); + Else -> + Else + end. + +%%-------------------------------------------------------------------- +%% Function: write(Transport, Data) -> ok +%% +%% Data = iolist() +%% +%% Description: Writes data into the buffer +%%-------------------------------------------------------------------- +write(Transport, Data) -> + gen_server:call(Transport, {write, Data}). + +%%-------------------------------------------------------------------- +%% Function: flush(Transport) -> ok +%% +%% Description: Flushes the buffer, making a request +%%-------------------------------------------------------------------- +flush(Transport) -> + gen_server:call(Transport, flush). + +%%-------------------------------------------------------------------- +%% Function: close(Transport) -> ok +%% +%% Description: Closes the transport +%%-------------------------------------------------------------------- +close(Transport) -> + gen_server:cast(Transport, close). + +%%-------------------------------------------------------------------- +%% Function: Read(Transport, Len) -> {ok, Data} +%% +%% Data = binary() +%% +%% Description: Reads data through from the wrapped transoprt +%%-------------------------------------------------------------------- +read(Transport, Len) when is_integer(Len) -> + gen_server:call(Transport, {read, Len}). + +%%==================================================================== +%% gen_server callbacks +%%==================================================================== + +init({Host, Path, Options}) -> + State1 = #http_transport{host = Host, + path = Path, + read_buffer = [], + write_buffer = [], + http_options = [], + extra_headers = []}, + ApplyOption = + fun + ({http_options, HttpOpts}, State = #http_transport{}) -> + State#http_transport{http_options = HttpOpts}; + ({extra_headers, ExtraHeaders}, State = #http_transport{}) -> + State#http_transport{extra_headers = ExtraHeaders}; + (Other, #http_transport{}) -> + {invalid_option, Other}; + (_, Error) -> + Error + end, + case lists:foldl(ApplyOption, State1, Options) of + State2 = #http_transport{} -> + {ok, State2}; + Else -> + {stop, Else} + end. + +handle_call({write, Data}, _From, State = #http_transport{write_buffer = WBuf}) -> + {reply, ok, State#http_transport{write_buffer = [WBuf, Data]}}; + +handle_call({read, Len}, _From, State = #http_transport{read_buffer = RBuf}) -> + %% Pull off Give bytes, return them to the user, leave the rest in the buffer. + Give = min(iolist_size(RBuf), Len), + case iolist_to_binary(RBuf) of + <> -> + Response = {ok, Data}, + State1 = State#http_transport{read_buffer=RBuf1}, + {reply, Response, State1}; + _ -> + {reply, {error, 'EOF'}, State} + end; + +handle_call(flush, _From, State) -> + {Response, State1} = do_flush(State), + {reply, Response, State1}. + +handle_cast(close, State) -> + {_, State1} = do_flush(State), + {stop, normal, State1}; + +handle_cast(_Msg, State=#http_transport{}) -> + {noreply, State}. + +handle_info(_Info, State) -> + {noreply, State}. + +terminate(_Reason, _State) -> + ok. + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%-------------------------------------------------------------------- +%%% Internal functions +%%-------------------------------------------------------------------- +do_flush(State = #http_transport{host = Host, + path = Path, + read_buffer = Rbuf, + write_buffer = Wbuf, + http_options = HttpOptions, + extra_headers = ExtraHeaders}) -> + case iolist_to_binary(Wbuf) of + <<>> -> + %% Don't bother flushing empty buffers. + {ok, State}; + WBinary -> + {ok, {{_Version, 200, _ReasonPhrase}, _Headers, Body}} = + http:request(post, + {"http://" ++ Host ++ Path, + [{"User-Agent", "Erlang/thrift_http_transport"} | ExtraHeaders], + "application/x-thrift", + WBinary}, + HttpOptions, + [{body_format, binary}]), + + State1 = State#http_transport{read_buffer = [Rbuf, Body], + write_buffer = []}, + {ok, State1} + end. + +min(A,B) when A A; +min(_,B) -> B. diff --git a/lib/erl/src/thrift_memory_buffer.erl b/lib/erl/src/thrift_memory_buffer.erl new file mode 100644 index 00000000..b4f607a9 --- /dev/null +++ b/lib/erl/src/thrift_memory_buffer.erl @@ -0,0 +1,164 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_memory_buffer). + +-behaviour(gen_server). +-behaviour(thrift_transport). + +%% API +-export([new/0, new_transport_factory/0]). + +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + +%% thrift_transport callbacks +-export([write/2, read/2, flush/1, close/1]). + +-record(memory_buffer, {buffer}). + +%%==================================================================== +%% API +%%==================================================================== +new() -> + case gen_server:start_link(?MODULE, [], []) of + {ok, Pid} -> + thrift_transport:new(?MODULE, Pid); + Else -> + Else + end. + +new_transport_factory() -> + {ok, fun() -> new() end}. + +%%-------------------------------------------------------------------- +%% Function: write(Transport, Data) -> ok +%% +%% Data = iolist() +%% +%% Description: Writes data into the buffer +%%-------------------------------------------------------------------- +write(Transport, Data) -> + gen_server:call(Transport, {write, Data}). + +%%-------------------------------------------------------------------- +%% Function: flush(Transport) -> ok +%% +%% Description: Flushes the buffer through to the wrapped transport +%%-------------------------------------------------------------------- +flush(Transport) -> + gen_server:call(Transport, flush). + +%%-------------------------------------------------------------------- +%% Function: close(Transport) -> ok +%% +%% Description: Closes the transport and the wrapped transport +%%-------------------------------------------------------------------- +close(Transport) -> + gen_server:cast(Transport, close). + +%%-------------------------------------------------------------------- +%% Function: Read(Transport, Len) -> {ok, Data} +%% +%% Data = binary() +%% +%% Description: Reads data through from the wrapped transoprt +%%-------------------------------------------------------------------- +read(Transport, Len) when is_integer(Len) -> + gen_server:call(Transport, {read, Len}). + +%%==================================================================== +%% gen_server callbacks +%%==================================================================== + +%%-------------------------------------------------------------------- +%% Function: init(Args) -> {ok, State} | +%% {ok, State, Timeout} | +%% ignore | +%% {stop, Reason} +%% Description: Initiates the server +%%-------------------------------------------------------------------- +init([]) -> + {ok, #memory_buffer{buffer = []}}. + +%%-------------------------------------------------------------------- +%% Function: %% handle_call(Request, From, State) -> {reply, Reply, State} | +%% {reply, Reply, State, Timeout} | +%% {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, Reply, State} | +%% {stop, Reason, State} +%% Description: Handling call messages +%%-------------------------------------------------------------------- +handle_call({write, Data}, _From, State = #memory_buffer{buffer = Buf}) -> + {reply, ok, State#memory_buffer{buffer = [Buf, Data]}}; + +handle_call({read, Len}, _From, State = #memory_buffer{buffer = Buf}) -> + Binary = iolist_to_binary(Buf), + Give = min(iolist_size(Binary), Len), + {Result, Remaining} = split_binary(Binary, Give), + {reply, {ok, Result}, State#memory_buffer{buffer = Remaining}}; + +handle_call(flush, _From, State) -> + {reply, ok, State}. + +%%-------------------------------------------------------------------- +%% Function: handle_cast(Msg, State) -> {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, State} +%% Description: Handling cast messages +%%-------------------------------------------------------------------- +handle_cast(close, State) -> + {stop, normal, State}; +handle_cast(Msg, State=#memory_buffer{}) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% Function: handle_info(Info, State) -> {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, State} +%% Description: Handling all non call/cast messages +%%-------------------------------------------------------------------- +handle_info(_Info, State) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% Function: terminate(Reason, State) -> void() +%% Description: This function is called by a gen_server when it is about to +%% terminate. It should be the opposite of Module:init/1 and do any necessary +%% cleaning up. When it returns, the gen_server terminates with Reason. +%% The return value is ignored. +%%-------------------------------------------------------------------- +terminate(_Reason, _State) -> + ok. + +%%-------------------------------------------------------------------- +%% Func: code_change(OldVsn, State, Extra) -> {ok, NewState} +%% Description: Convert process state when code is changed +%%-------------------------------------------------------------------- +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%-------------------------------------------------------------------- +%%% Internal functions +%%-------------------------------------------------------------------- +min(A,B) when A A; +min(_,B) -> B. + diff --git a/lib/erl/src/thrift_processor.erl b/lib/erl/src/thrift_processor.erl new file mode 100644 index 00000000..e26fb330 --- /dev/null +++ b/lib/erl/src/thrift_processor.erl @@ -0,0 +1,188 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_processor). + +-export([init/1]). + +-include("thrift_constants.hrl"). +-include("thrift_protocol.hrl"). + +-record(thrift_processor, {handler, in_protocol, out_protocol, service}). + +init({Server, ProtoGen, Service, Handler}) when is_function(ProtoGen, 0) -> + {ok, IProt, OProt} = ProtoGen(), + loop(#thrift_processor{in_protocol = IProt, + out_protocol = OProt, + service = Service, + handler = Handler}). + +loop(State = #thrift_processor{in_protocol = IProto, + out_protocol = OProto}) -> + case thrift_protocol:read(IProto, message_begin) of + #protocol_message_begin{name = Function, + type = ?tMessageType_CALL} -> + ok = handle_function(State, list_to_atom(Function)), + loop(State); + #protocol_message_begin{name = Function, + type = ?tMessageType_ONEWAY} -> + ok = handle_function(State, list_to_atom(Function)), + loop(State); + {error, timeout} -> + thrift_protocol:close_transport(OProto), + ok; + {error, closed} -> + %% error_logger:info_msg("Client disconnected~n"), + thrift_protocol:close_transport(OProto), + exit(shutdown) + end. + +handle_function(State=#thrift_processor{in_protocol = IProto, + out_protocol = OProto, + handler = Handler, + service = Service}, + Function) -> + InParams = Service:function_info(Function, params_type), + + {ok, Params} = thrift_protocol:read(IProto, InParams), + + try + Result = Handler:handle_function(Function, Params), + %% {Micro, Result} = better_timer(Handler, handle_function, [Function, Params]), + %% error_logger:info_msg("Processed ~p(~p) in ~.4fms~n", + %% [Function, Params, Micro/1000.0]), + handle_success(State, Function, Result) + catch + Type:Data -> + handle_function_catch(State, Function, Type, Data) + end, + after_reply(OProto). + +handle_function_catch(State = #thrift_processor{service = Service}, + Function, ErrType, ErrData) -> + IsOneway = Service:function_info(Function, reply_type) =:= oneway_void, + + case {ErrType, ErrData} of + _ when IsOneway -> + Stack = erlang:get_stacktrace(), + error_logger:warning_msg( + "oneway void ~p threw error which must be ignored: ~p", + [Function, {ErrType, ErrData, Stack}]), + ok; + + {throw, Exception} when is_tuple(Exception), size(Exception) > 0 -> + error_logger:warning_msg("~p threw exception: ~p~n", [Function, Exception]), + handle_exception(State, Function, Exception), + ok; % we still want to accept more requests from this client + + {error, Error} -> + ok = handle_error(State, Function, Error) + end. + +handle_success(State = #thrift_processor{out_protocol = OProto, + service = Service}, + Function, + Result) -> + ReplyType = Service:function_info(Function, reply_type), + StructName = atom_to_list(Function) ++ "_result", + + ok = case Result of + {reply, ReplyData} -> + Reply = {{struct, [{0, ReplyType}]}, {StructName, ReplyData}}, + send_reply(OProto, Function, ?tMessageType_REPLY, Reply); + + ok when ReplyType == {struct, []} -> + send_reply(OProto, Function, ?tMessageType_REPLY, {ReplyType, {StructName}}); + + ok when ReplyType == oneway_void -> + %% no reply for oneway void + ok + end. + +handle_exception(State = #thrift_processor{out_protocol = OProto, + service = Service}, + Function, + Exception) -> + ExceptionType = element(1, Exception), + %% Fetch a structure like {struct, [{-2, {struct, {Module, Type}}}, + %% {-3, {struct, {Module, Type}}}]} + + ReplySpec = Service:function_info(Function, exceptions), + {struct, XInfo} = ReplySpec, + + true = is_list(XInfo), + + %% Assuming we had a type1 exception, we'd get: [undefined, Exception, undefined] + %% e.g.: [{-1, type0}, {-2, type1}, {-3, type2}] + ExceptionList = [case Type of + ExceptionType -> Exception; + _ -> undefined + end + || {_Fid, {struct, {_Module, Type}}} <- XInfo], + + ExceptionTuple = list_to_tuple([Function | ExceptionList]), + + % Make sure we got at least one defined + case lists:all(fun(X) -> X =:= undefined end, ExceptionList) of + true -> + ok = handle_unknown_exception(State, Function, Exception); + false -> + ok = send_reply(OProto, Function, ?tMessageType_REPLY, {ReplySpec, ExceptionTuple}) + end. + +%% +%% Called when an exception has been explicitly thrown by the service, but it was +%% not one of the exceptions that was defined for the function. +%% +handle_unknown_exception(State, Function, Exception) -> + handle_error(State, Function, {exception_not_declared_as_thrown, + Exception}). + +handle_error(#thrift_processor{out_protocol = OProto}, Function, Error) -> + Stack = erlang:get_stacktrace(), + error_logger:error_msg("~p had an error: ~p~n", [Function, {Error, Stack}]), + + Message = + case application:get_env(thrift, exceptions_include_traces) of + {ok, true} -> + lists:flatten(io_lib:format("An error occurred: ~p~n", + [{Error, Stack}])); + _ -> + "An unknown handler error occurred." + end, + Reply = {?TApplicationException_Structure, + #'TApplicationException'{ + message = Message, + type = ?TApplicationException_UNKNOWN}}, + send_reply(OProto, Function, ?tMessageType_EXCEPTION, Reply). + +send_reply(OProto, Function, ReplyMessageType, Reply) -> + ok = thrift_protocol:write(OProto, #protocol_message_begin{ + name = atom_to_list(Function), + type = ReplyMessageType, + seqid = 0}), + ok = thrift_protocol:write(OProto, Reply), + ok = thrift_protocol:write(OProto, message_end), + ok = thrift_protocol:flush_transport(OProto), + ok. + +after_reply(OProto) -> + ok = thrift_protocol:flush_transport(OProto) + %% ok = thrift_protocol:close_transport(OProto) + . diff --git a/lib/erl/src/thrift_protocol.erl b/lib/erl/src/thrift_protocol.erl new file mode 100644 index 00000000..1bfb0a42 --- /dev/null +++ b/lib/erl/src/thrift_protocol.erl @@ -0,0 +1,356 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_protocol). + +-export([new/2, + write/2, + read/2, + read/3, + skip/2, + flush_transport/1, + close_transport/1, + typeid_to_atom/1 + ]). + +-export([behaviour_info/1]). + +-include("thrift_constants.hrl"). +-include("thrift_protocol.hrl"). + +-record(protocol, {module, data}). + +behaviour_info(callbacks) -> + [ + {read, 2}, + {write, 2}, + {flush_transport, 1}, + {close_transport, 1} + ]; +behaviour_info(_Else) -> undefined. + +new(Module, Data) when is_atom(Module) -> + {ok, #protocol{module = Module, + data = Data}}. + +flush_transport(#protocol{module = Module, + data = Data}) -> + Module:flush_transport(Data). + +close_transport(#protocol{module = Module, + data = Data}) -> + Module:close_transport(Data). + +typeid_to_atom(?tType_STOP) -> field_stop; +typeid_to_atom(?tType_VOID) -> void; +typeid_to_atom(?tType_BOOL) -> bool; +typeid_to_atom(?tType_BYTE) -> byte; +typeid_to_atom(?tType_DOUBLE) -> double; +typeid_to_atom(?tType_I16) -> i16; +typeid_to_atom(?tType_I32) -> i32; +typeid_to_atom(?tType_I64) -> i64; +typeid_to_atom(?tType_STRING) -> string; +typeid_to_atom(?tType_STRUCT) -> struct; +typeid_to_atom(?tType_MAP) -> map; +typeid_to_atom(?tType_SET) -> set; +typeid_to_atom(?tType_LIST) -> list. + +term_to_typeid(void) -> ?tType_VOID; +term_to_typeid(bool) -> ?tType_BOOL; +term_to_typeid(byte) -> ?tType_BYTE; +term_to_typeid(double) -> ?tType_DOUBLE; +term_to_typeid(i16) -> ?tType_I16; +term_to_typeid(i32) -> ?tType_I32; +term_to_typeid(i64) -> ?tType_I64; +term_to_typeid(string) -> ?tType_STRING; +term_to_typeid({struct, _}) -> ?tType_STRUCT; +term_to_typeid({map, _, _}) -> ?tType_MAP; +term_to_typeid({set, _}) -> ?tType_SET; +term_to_typeid({list, _}) -> ?tType_LIST. + +%% Structure is like: +%% [{Fid, Type}, ...] +read(IProto, {struct, Structure}, Tag) + when is_list(Structure), is_atom(Tag) -> + + % If we want a tagged tuple, we need to offset all the tuple indices + % by 1 to avoid overwriting the tag. + Offset = if Tag =/= undefined -> 1; true -> 0 end, + IndexList = case length(Structure) of + N when N > 0 -> lists:seq(1 + Offset, N + Offset); + _ -> [] + end, + + SWithIndices = [{Fid, {Type, Index}} || + {{Fid, Type}, Index} <- + lists:zip(Structure, IndexList)], + % Fid -> {Type, Index} + SDict = dict:from_list(SWithIndices), + + ok = read(IProto, struct_begin), + RTuple0 = erlang:make_tuple(length(Structure) + Offset, undefined), + RTuple1 = if Tag =/= undefined -> setelement(1, RTuple0, Tag); + true -> RTuple0 + end, + + RTuple2 = read_struct_loop(IProto, SDict, RTuple1), + {ok, RTuple2}. + +read(IProto, {struct, {Module, StructureName}}) when is_atom(Module), + is_atom(StructureName) -> + read(IProto, Module:struct_info(StructureName), StructureName); + +read(IProto, S={struct, Structure}) when is_list(Structure) -> + read(IProto, S, undefined); + +read(IProto, {list, Type}) -> + #protocol_list_begin{etype = EType, size = Size} = + read(IProto, list_begin), + List = [Result || {ok, Result} <- + [read(IProto, Type) || _X <- lists:duplicate(Size, 0)]], + ok = read(IProto, list_end), + {ok, List}; + +read(IProto, {map, KeyType, ValType}) -> + #protocol_map_begin{size = Size} = + read(IProto, map_begin), + + List = [{Key, Val} || {{ok, Key}, {ok, Val}} <- + [{read(IProto, KeyType), + read(IProto, ValType)} || _X <- lists:duplicate(Size, 0)]], + ok = read(IProto, map_end), + {ok, dict:from_list(List)}; + +read(IProto, {set, Type}) -> + #protocol_set_begin{etype = _EType, + size = Size} = + read(IProto, set_begin), + List = [Result || {ok, Result} <- + [read(IProto, Type) || _X <- lists:duplicate(Size, 0)]], + ok = read(IProto, set_end), + {ok, sets:from_list(List)}; + +read(#protocol{module = Module, + data = ModuleData}, ProtocolType) -> + Module:read(ModuleData, ProtocolType). + +read_struct_loop(IProto, SDict, RTuple) -> + #protocol_field_begin{type = FType, id = Fid, name = Name} = + thrift_protocol:read(IProto, field_begin), + case {FType, Fid} of + {?tType_STOP, _} -> + RTuple; + _Else -> + case dict:find(Fid, SDict) of + {ok, {Type, Index}} -> + case term_to_typeid(Type) of + FType -> + {ok, Val} = read(IProto, Type), + thrift_protocol:read(IProto, field_end), + NewRTuple = setelement(Index, RTuple, Val), + read_struct_loop(IProto, SDict, NewRTuple); + Expected -> + error_logger:info_msg( + "Skipping field ~p with wrong type (~p != ~p)~n", + [Fid, FType, Expected]), + skip_field(FType, IProto, SDict, RTuple) + end; + _Else2 -> + error_logger:info_msg("Skipping field ~p with unknown fid~n", [Fid]), + skip_field(FType, IProto, SDict, RTuple) + end + end. + +skip_field(FType, IProto, SDict, RTuple) -> + FTypeAtom = thrift_protocol:typeid_to_atom(FType), + thrift_protocol:skip(IProto, FTypeAtom), + read(IProto, field_end), + read_struct_loop(IProto, SDict, RTuple). + + +skip(Proto, struct) -> + ok = read(Proto, struct_begin), + ok = skip_struct_loop(Proto), + ok = read(Proto, struct_end); + +skip(Proto, map) -> + Map = read(Proto, map_begin), + ok = skip_map_loop(Proto, Map), + ok = read(Proto, map_end); + +skip(Proto, set) -> + Set = read(Proto, set_begin), + ok = skip_set_loop(Proto, Set), + ok = read(Proto, set_end); + +skip(Proto, list) -> + List = read(Proto, list_begin), + ok = skip_list_loop(Proto, List), + ok = read(Proto, list_end); + +skip(Proto, Type) when is_atom(Type) -> + _Ignore = read(Proto, Type), + ok. + + +skip_struct_loop(Proto) -> + #protocol_field_begin{type = Type} = read(Proto, field_begin), + case Type of + ?tType_STOP -> + ok; + _Else -> + skip(Proto, Type), + ok = read(Proto, field_end), + skip_struct_loop(Proto) + end. + +skip_map_loop(Proto, Map = #protocol_map_begin{ktype = Ktype, + vtype = Vtype, + size = Size}) -> + case Size of + N when N > 0 -> + skip(Proto, Ktype), + skip(Proto, Vtype), + skip_map_loop(Proto, + Map#protocol_map_begin{size = Size - 1}); + 0 -> ok + end. + +skip_set_loop(Proto, Map = #protocol_set_begin{etype = Etype, + size = Size}) -> + case Size of + N when N > 0 -> + skip(Proto, Etype), + skip_set_loop(Proto, + Map#protocol_set_begin{size = Size - 1}); + 0 -> ok + end. + +skip_list_loop(Proto, Map = #protocol_list_begin{etype = Etype, + size = Size}) -> + case Size of + N when N > 0 -> + skip(Proto, Etype), + skip_list_loop(Proto, + Map#protocol_list_begin{size = Size - 1}); + 0 -> ok + end. + + +%%-------------------------------------------------------------------- +%% Function: write(OProto, {Type, Data}) -> ok +%% +%% Type = {struct, StructDef} | +%% {list, Type} | +%% {map, KeyType, ValType} | +%% {set, Type} | +%% BaseType +%% +%% Data = +%% tuple() -- for struct +%% | list() -- for list +%% | dictionary() -- for map +%% | set() -- for set +%% | term() -- for base types +%% +%% Description: +%%-------------------------------------------------------------------- +write(Proto, {{struct, StructDef}, Data}) + when is_list(StructDef), is_tuple(Data), length(StructDef) == size(Data) - 1 -> + + [StructName | Elems] = tuple_to_list(Data), + ok = write(Proto, #protocol_struct_begin{name = StructName}), + ok = struct_write_loop(Proto, StructDef, Elems), + ok = write(Proto, struct_end), + ok; + +write(Proto, {{struct, {Module, StructureName}}, Data}) + when is_atom(Module), + is_atom(StructureName), + element(1, Data) =:= StructureName -> + StructType = Module:struct_info(StructureName), + write(Proto, {Module:struct_info(StructureName), Data}); + +write(Proto, {{list, Type}, Data}) + when is_list(Data) -> + ok = write(Proto, + #protocol_list_begin{ + etype = term_to_typeid(Type), + size = length(Data) + }), + lists:foreach(fun(Elem) -> + ok = write(Proto, {Type, Elem}) + end, + Data), + ok = write(Proto, list_end), + ok; + +write(Proto, {{map, KeyType, ValType}, Data}) -> + ok = write(Proto, + #protocol_map_begin{ + ktype = term_to_typeid(KeyType), + vtype = term_to_typeid(ValType), + size = dict:size(Data) + }), + dict:fold(fun(KeyData, ValData, _Acc) -> + ok = write(Proto, {KeyType, KeyData}), + ok = write(Proto, {ValType, ValData}) + end, + _AccO = ok, + Data), + ok = write(Proto, map_end), + ok; + +write(Proto, {{set, Type}, Data}) -> + true = sets:is_set(Data), + ok = write(Proto, + #protocol_set_begin{ + etype = term_to_typeid(Type), + size = sets:size(Data) + }), + sets:fold(fun(Elem, _Acc) -> + ok = write(Proto, {Type, Elem}) + end, + _Acc0 = ok, + Data), + ok = write(Proto, set_end), + ok; + +write(#protocol{module = Module, + data = ModuleData}, Data) -> + Module:write(ModuleData, Data). + +struct_write_loop(Proto, [{Fid, Type} | RestStructDef], [Data | RestData]) -> + case Data of + undefined -> + % null fields are skipped in response + skip; + _ -> + ok = write(Proto, + #protocol_field_begin{ + type = term_to_typeid(Type), + id = Fid + }), + ok = write(Proto, {Type, Data}), + ok = write(Proto, field_end) + end, + struct_write_loop(Proto, RestStructDef, RestData); +struct_write_loop(Proto, [], []) -> + ok = write(Proto, field_stop), + ok. diff --git a/lib/erl/src/thrift_server.erl b/lib/erl/src/thrift_server.erl new file mode 100644 index 00000000..5d0012ba --- /dev/null +++ b/lib/erl/src/thrift_server.erl @@ -0,0 +1,183 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_server). + +-behaviour(gen_server). + +%% API +-export([start_link/3, stop/1, take_socket/2]). + +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + +-define(SERVER, ?MODULE). + +-record(state, {listen_socket, acceptor_ref, service, handler}). + +%%==================================================================== +%% API +%%==================================================================== +%%-------------------------------------------------------------------- +%% Function: start_link() -> {ok,Pid} | ignore | {error,Error} +%% Description: Starts the server +%%-------------------------------------------------------------------- +start_link(Port, Service, HandlerModule) when is_integer(Port), is_atom(HandlerModule) -> + gen_server:start_link({local, ?SERVER}, ?MODULE, {Port, Service, HandlerModule}, []). + +%%-------------------------------------------------------------------- +%% Function: stop(Pid) -> ok, {error, Reason} +%% Description: Stops the server. +%%-------------------------------------------------------------------- +stop(Pid) when is_pid(Pid) -> + gen_server:call(Pid, stop). + + +take_socket(Server, Socket) -> + gen_server:call(Server, {take_socket, Socket}). + + +%%==================================================================== +%% gen_server callbacks +%%==================================================================== + +%%-------------------------------------------------------------------- +%% Function: init(Args) -> {ok, State} | +%% {ok, State, Timeout} | +%% ignore | +%% {stop, Reason} +%% Description: Initiates the server +%%-------------------------------------------------------------------- +init({Port, Service, Handler}) -> + {ok, Socket} = gen_tcp:listen(Port, + [binary, + {packet, 0}, + {active, false}, + {nodelay, true}, + {reuseaddr, true}]), + {ok, Ref} = prim_inet:async_accept(Socket, -1), + {ok, #state{listen_socket = Socket, + acceptor_ref = Ref, + service = Service, + handler = Handler}}. + +%%-------------------------------------------------------------------- +%% Function: %% handle_call(Request, From, State) -> {reply, Reply, State} | +%% {reply, Reply, State, Timeout} | +%% {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, Reply, State} | +%% {stop, Reason, State} +%% Description: Handling call messages +%%-------------------------------------------------------------------- +handle_call(stop, _From, State) -> + {stop, stopped, ok, State}; + +handle_call({take_socket, Socket}, {FromPid, _Tag}, State) -> + Result = gen_tcp:controlling_process(Socket, FromPid), + {reply, Result, State}. + +%%-------------------------------------------------------------------- +%% Function: handle_cast(Msg, State) -> {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, State} +%% Description: Handling cast messages +%%-------------------------------------------------------------------- +handle_cast(_Msg, State) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% Function: handle_info(Info, State) -> {noreply, State} | +%% {noreply, State, Timeout} | +%% {stop, Reason, State} +%% Description: Handling all non call/cast messages +%%-------------------------------------------------------------------- +handle_info({inet_async, ListenSocket, Ref, {ok, ClientSocket}}, + State = #state{listen_socket = ListenSocket, + acceptor_ref = Ref, + service = Service, + handler = Handler}) -> + case set_sockopt(ListenSocket, ClientSocket) of + ok -> + %% New client connected - start processor + start_processor(ClientSocket, Service, Handler), + {ok, NewRef} = prim_inet:async_accept(ListenSocket, -1), + {noreply, State#state{acceptor_ref = NewRef}}; + {error, Reason} -> + error_logger:error_msg("Couldn't set socket opts: ~p~n", + [Reason]), + {stop, Reason, State} + end; + +handle_info({inet_async, ListenSocket, Ref, Error}, State) -> + error_logger:error_msg("Error in acceptor: ~p~n", [Error]), + {stop, Error, State}; + +handle_info(_Info, State) -> + {noreply, State}. + +%%-------------------------------------------------------------------- +%% Function: terminate(Reason, State) -> void() +%% Description: This function is called by a gen_server when it is about to +%% terminate. It should be the opposite of Module:init/1 and do any necessary +%% cleaning up. When it returns, the gen_server terminates with Reason. +%% The return value is ignored. +%%-------------------------------------------------------------------- +terminate(_Reason, _State) -> + ok. + +%%-------------------------------------------------------------------- +%% Func: code_change(OldVsn, State, Extra) -> {ok, NewState} +%% Description: Convert process state when code is changed +%%-------------------------------------------------------------------- +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%-------------------------------------------------------------------- +%%% Internal functions +%%-------------------------------------------------------------------- +set_sockopt(ListenSocket, ClientSocket) -> + true = inet_db:register_socket(ClientSocket, inet_tcp), + case prim_inet:getopts(ListenSocket, + [active, nodelay, keepalive, delay_send, priority, tos]) of + {ok, Opts} -> + case prim_inet:setopts(ClientSocket, Opts) of + ok -> ok; + Error -> gen_tcp:close(ClientSocket), + Error + end; + Error -> + gen_tcp:close(ClientSocket), + Error + end. + +start_processor(Socket, Service, Handler) -> + Server = self(), + + ProtoGen = fun() -> + % Become the controlling process + ok = take_socket(Server, Socket), + {ok, SocketTransport} = thrift_socket_transport:new(Socket), + {ok, BufferedTransport} = thrift_buffered_transport:new(SocketTransport), + {ok, Protocol} = thrift_binary_protocol:new(BufferedTransport), + {ok, Protocol, Protocol} + end, + + spawn(thrift_processor, init, [{Server, ProtoGen, Service, Handler}]). diff --git a/lib/erl/src/thrift_service.erl b/lib/erl/src/thrift_service.erl new file mode 100644 index 00000000..2ed7b57b --- /dev/null +++ b/lib/erl/src/thrift_service.erl @@ -0,0 +1,25 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_service). + +-export([behaviour_info/1]). + +behaviour_info(callbacks) -> + [{function_info, 2}]. diff --git a/lib/erl/src/thrift_socket_server.erl b/lib/erl/src/thrift_socket_server.erl new file mode 100644 index 00000000..62bdfdaf --- /dev/null +++ b/lib/erl/src/thrift_socket_server.erl @@ -0,0 +1,249 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_socket_server). + +-behaviour(gen_server). + +-export([start/1, stop/1]). + +-export([init/1, handle_call/3, handle_cast/2, terminate/2, code_change/3, + handle_info/2]). + +-export([acceptor_loop/1]). + +-record(thrift_socket_server, + {port, + service, + handler, + name, + max=2048, + ip=any, + listen=null, + acceptor=null, + socket_opts=[{recv_timeout, 500}] + }). + +start(State=#thrift_socket_server{}) -> + start_server(State); +start(Options) -> + start(parse_options(Options)). + +stop(Name) when is_atom(Name) -> + gen_server:cast(Name, stop); +stop(Pid) when is_pid(Pid) -> + gen_server:cast(Pid, stop); +stop({local, Name}) -> + stop(Name); +stop({global, Name}) -> + stop(Name); +stop(Options) -> + State = parse_options(Options), + stop(State#thrift_socket_server.name). + +%% Internal API + +parse_options(Options) -> + parse_options(Options, #thrift_socket_server{}). + +parse_options([], State) -> + State; +parse_options([{name, L} | Rest], State) when is_list(L) -> + Name = {local, list_to_atom(L)}, + parse_options(Rest, State#thrift_socket_server{name=Name}); +parse_options([{name, A} | Rest], State) when is_atom(A) -> + Name = {local, A}, + parse_options(Rest, State#thrift_socket_server{name=Name}); +parse_options([{name, Name} | Rest], State) -> + parse_options(Rest, State#thrift_socket_server{name=Name}); +parse_options([{port, L} | Rest], State) when is_list(L) -> + Port = list_to_integer(L), + parse_options(Rest, State#thrift_socket_server{port=Port}); +parse_options([{port, Port} | Rest], State) -> + parse_options(Rest, State#thrift_socket_server{port=Port}); +parse_options([{ip, Ip} | Rest], State) -> + ParsedIp = case Ip of + any -> + any; + Ip when is_tuple(Ip) -> + Ip; + Ip when is_list(Ip) -> + {ok, IpTuple} = inet_parse:address(Ip), + IpTuple + end, + parse_options(Rest, State#thrift_socket_server{ip=ParsedIp}); +parse_options([{socket_opts, L} | Rest], State) when is_list(L), length(L) > 0 -> + parse_options(Rest, State#thrift_socket_server{socket_opts=L}); +parse_options([{handler, Handler} | Rest], State) -> + parse_options(Rest, State#thrift_socket_server{handler=Handler}); +parse_options([{service, Service} | Rest], State) -> + parse_options(Rest, State#thrift_socket_server{service=Service}); +parse_options([{max, Max} | Rest], State) -> + MaxInt = case Max of + Max when is_list(Max) -> + list_to_integer(Max); + Max when is_integer(Max) -> + Max + end, + parse_options(Rest, State#thrift_socket_server{max=MaxInt}). + +start_server(State=#thrift_socket_server{name=Name}) -> + case Name of + undefined -> + gen_server:start_link(?MODULE, State, []); + _ -> + gen_server:start_link(Name, ?MODULE, State, []) + end. + +init(State=#thrift_socket_server{ip=Ip, port=Port}) -> + process_flag(trap_exit, true), + BaseOpts = [binary, + {reuseaddr, true}, + {packet, 0}, + {backlog, 4096}, + {recbuf, 8192}, + {active, false}], + Opts = case Ip of + any -> + BaseOpts; + Ip -> + [{ip, Ip} | BaseOpts] + end, + case gen_tcp_listen(Port, Opts, State) of + {stop, eacces} -> + %% fdsrv module allows another shot to bind + %% ports which require root access + case Port < 1024 of + true -> + case fdsrv:start() of + {ok, _} -> + case fdsrv:bind_socket(tcp, Port) of + {ok, Fd} -> + gen_tcp_listen(Port, [{fd, Fd} | Opts], State); + _ -> + {stop, fdsrv_bind_failed} + end; + _ -> + {stop, fdsrv_start_failed} + end; + false -> + {stop, eacces} + end; + Other -> + error_logger:info_msg("thrift service listening on port ~p", [Port]), + Other + end. + +gen_tcp_listen(Port, Opts, State) -> + case gen_tcp:listen(Port, Opts) of + {ok, Listen} -> + {ok, ListenPort} = inet:port(Listen), + {ok, new_acceptor(State#thrift_socket_server{listen=Listen, + port=ListenPort})}; + {error, Reason} -> + {stop, Reason} + end. + +new_acceptor(State=#thrift_socket_server{max=0}) -> + error_logger:error_msg("Not accepting new connections"), + State#thrift_socket_server{acceptor=null}; +new_acceptor(State=#thrift_socket_server{acceptor=OldPid, listen=Listen, + service=Service, handler=Handler, + socket_opts=Opts + }) -> + Pid = proc_lib:spawn_link(?MODULE, acceptor_loop, + [{self(), Listen, Service, Handler, Opts}]), +%% error_logger:info_msg("Spawning new acceptor: ~p => ~p", [OldPid, Pid]), + State#thrift_socket_server{acceptor=Pid}. + +acceptor_loop({Server, Listen, Service, Handler, SocketOpts}) + when is_pid(Server), is_list(SocketOpts) -> + case catch gen_tcp:accept(Listen) of % infinite timeout + {ok, Socket} -> + gen_server:cast(Server, {accepted, self()}), + ProtoGen = fun() -> + {ok, SocketTransport} = thrift_socket_transport:new(Socket, SocketOpts), + {ok, BufferedTransport} = thrift_buffered_transport:new(SocketTransport), + {ok, Protocol} = thrift_binary_protocol:new(BufferedTransport), + {ok, IProt=Protocol, OProt=Protocol} + end, + thrift_processor:init({Server, ProtoGen, Service, Handler}); + {error, closed} -> + exit({error, closed}); + Other -> + error_logger:error_report( + [{application, thrift}, + "Accept failed error", + lists:flatten(io_lib:format("~p", [Other]))]), + exit({error, accept_failed}) + end. + +handle_call({get, port}, _From, State=#thrift_socket_server{port=Port}) -> + {reply, Port, State}; +handle_call(_Message, _From, State) -> + Res = error, + {reply, Res, State}. + +handle_cast({accepted, Pid}, + State=#thrift_socket_server{acceptor=Pid, max=Max}) -> + % io:format("accepted ~p~n", [Pid]), + State1 = State#thrift_socket_server{max=Max - 1}, + {noreply, new_acceptor(State1)}; +handle_cast(stop, State) -> + {stop, normal, State}. + +terminate(_Reason, #thrift_socket_server{listen=Listen, port=Port}) -> + gen_tcp:close(Listen), + case Port < 1024 of + true -> + catch fdsrv:stop(), + ok; + false -> + ok + end. + +code_change(_OldVsn, State, _Extra) -> + State. + +handle_info({'EXIT', Pid, normal}, + State=#thrift_socket_server{acceptor=Pid}) -> + {noreply, new_acceptor(State)}; +handle_info({'EXIT', Pid, Reason}, + State=#thrift_socket_server{acceptor=Pid}) -> + error_logger:error_report({?MODULE, ?LINE, + {acceptor_error, Reason}}), + timer:sleep(100), + {noreply, new_acceptor(State)}; +handle_info({'EXIT', _LoopPid, Reason}, + State=#thrift_socket_server{acceptor=Pid, max=Max}) -> + case Reason of + normal -> ok; + shutdown -> ok; + _ -> error_logger:error_report({?MODULE, ?LINE, + {child_error, Reason, erlang:get_stacktrace()}}) + end, + State1 = State#thrift_socket_server{max=Max + 1}, + State2 = case Pid of + null -> new_acceptor(State1); + _ -> State1 + end, + {noreply, State2}; +handle_info(Info, State) -> + error_logger:info_report([{'INFO', Info}, {'State', State}]), + {noreply, State}. diff --git a/lib/erl/src/thrift_socket_transport.erl b/lib/erl/src/thrift_socket_transport.erl new file mode 100644 index 00000000..fcd69449 --- /dev/null +++ b/lib/erl/src/thrift_socket_transport.erl @@ -0,0 +1,119 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_socket_transport). + +-behaviour(thrift_transport). + +-export([new/1, + new/2, + write/2, read/2, flush/1, close/1, + + new_transport_factory/3]). + +-record(data, {socket, + recv_timeout=infinity}). + +new(Socket) -> + new(Socket, []). + +new(Socket, Opts) when is_list(Opts) -> + State = + case lists:keysearch(recv_timeout, 1, Opts) of + {value, {recv_timeout, Timeout}} + when is_integer(Timeout), Timeout > 0 -> + #data{socket=Socket, recv_timeout=Timeout}; + _ -> + #data{socket=Socket} + end, + thrift_transport:new(?MODULE, State). + +%% Data :: iolist() +write(#data{socket = Socket}, Data) -> + gen_tcp:send(Socket, Data). + +read(#data{socket=Socket, recv_timeout=Timeout}, Len) + when is_integer(Len), Len >= 0 -> + case gen_tcp:recv(Socket, Len, Timeout) of + Err = {error, timeout} -> + error_logger:info_msg("read timeout: peer conn ~p", [inet:peername(Socket)]), + gen_tcp:close(Socket), + Err; + Data -> Data + end. + +%% We can't really flush - everything is flushed when we write +flush(_) -> + ok. + +close(#data{socket = Socket}) -> + gen_tcp:close(Socket). + + +%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + + +%% The following "local" record is filled in by parse_factory_options/2 +%% below. These options can be passed to new_protocol_factory/3 in a +%% proplists-style option list. They're parsed like this so it is an O(n) +%% operation instead of O(n^2) +-record(factory_opts, {connect_timeout = infinity, + sockopts = [], + framed = false}). + +parse_factory_options([], Opts) -> + Opts; +parse_factory_options([{framed, Bool} | Rest], Opts) when is_boolean(Bool) -> + parse_factory_options(Rest, Opts#factory_opts{framed=Bool}); +parse_factory_options([{sockopts, OptList} | Rest], Opts) when is_list(OptList) -> + parse_factory_options(Rest, Opts#factory_opts{sockopts=OptList}); +parse_factory_options([{connect_timeout, TO} | Rest], Opts) when TO =:= infinity; is_integer(TO) -> + parse_factory_options(Rest, Opts#factory_opts{connect_timeout=TO}). + + +%% +%% Generates a "transport factory" function - a fun which returns a thrift_transport() +%% instance. +%% This can be passed into a protocol factory to generate a connection to a +%% thrift server over a socket. +%% +new_transport_factory(Host, Port, Options) -> + ParsedOpts = parse_factory_options(Options, #factory_opts{}), + + F = fun() -> + SockOpts = [binary, + {packet, 0}, + {active, false}, + {nodelay, true} | + ParsedOpts#factory_opts.sockopts], + case catch gen_tcp:connect(Host, Port, SockOpts, + ParsedOpts#factory_opts.connect_timeout) of + {ok, Sock} -> + {ok, Transport} = thrift_socket_transport:new(Sock), + {ok, BufTransport} = + case ParsedOpts#factory_opts.framed of + true -> thrift_framed_transport:new(Transport); + false -> thrift_buffered_transport:new(Transport) + end, + {ok, BufTransport}; + Error -> + Error + end + end, + {ok, F}. diff --git a/lib/erl/src/thrift_transport.erl b/lib/erl/src/thrift_transport.erl new file mode 100644 index 00000000..20c4b5dc --- /dev/null +++ b/lib/erl/src/thrift_transport.erl @@ -0,0 +1,57 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(thrift_transport). + +-export([behaviour_info/1]). + +-export([new/2, + write/2, + read/2, + flush/1, + close/1 + ]). + +behaviour_info(callbacks) -> + [{read, 2}, + {write, 2}, + {flush, 1}, + {close, 1} + ]. + +-record(transport, {module, data}). + +new(Module, Data) when is_atom(Module) -> + {ok, #transport{module = Module, + data = Data}}. + +%% Data :: iolist() +write(Transport, Data) -> + Module = Transport#transport.module, + Module:write(Transport#transport.data, Data). + +read(Transport, Len) when is_integer(Len) -> + Module = Transport#transport.module, + Module:read(Transport#transport.data, Len). + +flush(#transport{module = Module, data = Data}) -> + Module:flush(Data). + +close(#transport{module = Module, data = Data}) -> + Module:close(Data). diff --git a/lib/erl/vsn.mk b/lib/erl/vsn.mk new file mode 100644 index 00000000..d9b40014 --- /dev/null +++ b/lib/erl/vsn.mk @@ -0,0 +1 @@ +THRIFT_VSN=0.1 diff --git a/lib/hs/.gitignore b/lib/hs/.gitignore new file mode 100644 index 00000000..849ddff3 --- /dev/null +++ b/lib/hs/.gitignore @@ -0,0 +1 @@ +dist/ diff --git a/lib/hs/README b/lib/hs/README new file mode 100644 index 00000000..e58c8c93 --- /dev/null +++ b/lib/hs/README @@ -0,0 +1,82 @@ +Haskell Thrift Bindings + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Running +======= + +You need -fglasgow-exts. Use Cabal to compile and install. If you're trying to +manually compile or load via ghci, and you're using ghc 6.10 (or really if your +default base package has major version number 4), you must specify a version of +the base package with major version number 3. Furthermore if you have the syb +package installed you need to hide that package to avoid import conflicts. +Here's an example of what I'm talking about: + + ghci -fglasgow-exts -package base-3.0.3.0 -hide-package syb -isrc Thrift.hs + +To determine which versions of the base package you have installed use the +following command: + + ghc-pkg list base + +All of this is taken care of for you if you use Cabal. + + +Enums +===== + +become haskell data types. Use fromEnum to get out the int value. + +Structs +======= + +become records. Field labels are ugly, of the form f_STRUCTNAME_FIELDNAME. All +fields are Maybe types. + +Exceptions +========== + +identical to structs. Throw them with throwDyn. Catch them with catchDyn. + +Client +====== + +just a bunch of functions. You may have to import a bunch of client files to +deal with inheritance. + +Interface +========= + +You should only have to import the last one in the chain of inheritors. To make +an interface, declare a label: + + data MyIface = MyIface + +and then declare it an instance of each iface class, starting with the superest +class and proceding down (all the while defining the methods). Then pass your +label to process as the handler. + +Processor +========= + +Just a function that takes a handler label, protocols. It calls the +superclasses process if there is a superclass. + diff --git a/lib/hs/Setup.lhs b/lib/hs/Setup.lhs new file mode 100644 index 00000000..c9e6d970 --- /dev/null +++ b/lib/hs/Setup.lhs @@ -0,0 +1,23 @@ +#!/usr/bin/env runhaskell + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +> import Distribution.Simple +> main = defaultMain diff --git a/lib/hs/TODO b/lib/hs/TODO new file mode 100644 index 00000000..13681732 --- /dev/null +++ b/lib/hs/TODO @@ -0,0 +1,2 @@ +The library could stand to be built up more. +Many modules need export lists. diff --git a/lib/hs/Thrift.cabal b/lib/hs/Thrift.cabal new file mode 100644 index 00000000..4cef4de6 --- /dev/null +++ b/lib/hs/Thrift.cabal @@ -0,0 +1,20 @@ +Name: Thrift +Version: 0.1.0 +Cabal-Version: >= 1.2 +License: Apache2 +Category: Foreign +Build-Type: Simple +Synopsis: Thrift library package + +Library + Hs-Source-Dirs: + src + Build-Depends: + base >=4, network, ghc-prim + ghc-options: + -fglasgow-exts + Extensions: + DeriveDataTypeable + Exposed-Modules: + Thrift, Thrift.Protocol, Thrift.Transport, Thrift.Protocol.Binary + Thrift.Transport.Handle, Thrift.Server diff --git a/lib/hs/src/Thrift.hs b/lib/hs/src/Thrift.hs new file mode 100644 index 00000000..291bcae5 --- /dev/null +++ b/lib/hs/src/Thrift.hs @@ -0,0 +1,111 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +module Thrift + ( module Thrift.Transport + , module Thrift.Protocol + , AppExnType(..) + , AppExn(..) + , readAppExn + , writeAppExn + , ThriftException(..) + ) where + +import Control.Monad ( when ) +import Control.Exception + +import Data.Typeable ( Typeable ) + +import Thrift.Transport +import Thrift.Protocol + + +data ThriftException = ThriftException + deriving ( Show, Typeable ) +instance Exception ThriftException + +data AppExnType + = AE_UNKNOWN + | AE_UNKNOWN_METHOD + | AE_INVALID_MESSAGE_TYPE + | AE_WRONG_METHOD_NAME + | AE_BAD_SEQUENCE_ID + | AE_MISSING_RESULT + deriving ( Eq, Show, Typeable ) + +instance Enum AppExnType where + toEnum 0 = AE_UNKNOWN + toEnum 1 = AE_UNKNOWN_METHOD + toEnum 2 = AE_INVALID_MESSAGE_TYPE + toEnum 3 = AE_WRONG_METHOD_NAME + toEnum 4 = AE_BAD_SEQUENCE_ID + toEnum 5 = AE_MISSING_RESULT + + fromEnum AE_UNKNOWN = 0 + fromEnum AE_UNKNOWN_METHOD = 1 + fromEnum AE_INVALID_MESSAGE_TYPE = 2 + fromEnum AE_WRONG_METHOD_NAME = 3 + fromEnum AE_BAD_SEQUENCE_ID = 4 + fromEnum AE_MISSING_RESULT = 5 + +data AppExn = AppExn { ae_type :: AppExnType, ae_message :: String } + deriving ( Show, Typeable ) +instance Exception AppExn + +writeAppExn :: (Protocol p, Transport t) => p t -> AppExn -> IO () +writeAppExn pt ae = do + writeStructBegin pt "TApplicationException" + + when (ae_message ae /= "") $ do + writeFieldBegin pt ("message", T_STRING , 1) + writeString pt (ae_message ae) + writeFieldEnd pt + + writeFieldBegin pt ("type", T_I32, 2); + writeI32 pt (fromEnum (ae_type ae)) + writeFieldEnd pt + writeFieldStop pt + writeStructEnd pt + +readAppExn :: (Protocol p, Transport t) => p t -> IO AppExn +readAppExn pt = do + readStructBegin pt + rec <- readAppExnFields pt (AppExn {ae_type = undefined, ae_message = undefined}) + readStructEnd pt + return rec + +readAppExnFields pt rec = do + (n, ft, id) <- readFieldBegin pt + if ft == T_STOP + then return rec + else case id of + 1 -> if ft == T_STRING then + do s <- readString pt + readAppExnFields pt rec{ae_message = s} + else do skip pt ft + readAppExnFields pt rec + 2 -> if ft == T_I32 then + do i <- readI32 pt + readAppExnFields pt rec{ae_type = (toEnum i)} + else do skip pt ft + readAppExnFields pt rec + _ -> do skip pt ft + readFieldEnd pt + readAppExnFields pt rec + diff --git a/lib/hs/src/Thrift/Protocol.hs b/lib/hs/src/Thrift/Protocol.hs new file mode 100644 index 00000000..8fa060ea --- /dev/null +++ b/lib/hs/src/Thrift/Protocol.hs @@ -0,0 +1,191 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +module Thrift.Protocol + ( Protocol(..) + , skip + , MessageType(..) + , ThriftType(..) + , ProtocolExn(..) + , ProtocolExnType(..) + ) where + +import Control.Monad ( replicateM_, unless ) +import Control.Exception + +import Data.Typeable ( Typeable ) +import Data.Int + +import Thrift.Transport + + +data ThriftType + = T_STOP + | T_VOID + | T_BOOL + | T_BYTE + | T_DOUBLE + | T_I16 + | T_I32 + | T_I64 + | T_STRING + | T_STRUCT + | T_MAP + | T_SET + | T_LIST + deriving ( Eq ) + +instance Enum ThriftType where + fromEnum T_STOP = 0 + fromEnum T_VOID = 1 + fromEnum T_BOOL = 2 + fromEnum T_BYTE = 3 + fromEnum T_DOUBLE = 4 + fromEnum T_I16 = 6 + fromEnum T_I32 = 8 + fromEnum T_I64 = 10 + fromEnum T_STRING = 11 + fromEnum T_STRUCT = 12 + fromEnum T_MAP = 13 + fromEnum T_SET = 14 + fromEnum T_LIST = 15 + + toEnum 0 = T_STOP + toEnum 1 = T_VOID + toEnum 2 = T_BOOL + toEnum 3 = T_BYTE + toEnum 4 = T_DOUBLE + toEnum 6 = T_I16 + toEnum 8 = T_I32 + toEnum 10 = T_I64 + toEnum 11 = T_STRING + toEnum 12 = T_STRUCT + toEnum 13 = T_MAP + toEnum 14 = T_SET + toEnum 15 = T_LIST + +data MessageType + = M_CALL + | M_REPLY + | M_EXCEPTION + deriving ( Eq ) + +instance Enum MessageType where + fromEnum M_CALL = 1 + fromEnum M_REPLY = 2 + fromEnum M_EXCEPTION = 3 + + toEnum 1 = M_CALL + toEnum 2 = M_REPLY + toEnum 3 = M_EXCEPTION + + +class Protocol a where + getTransport :: Transport t => a t -> t + + writeMessageBegin :: Transport t => a t -> (String, MessageType, Int) -> IO () + writeMessageEnd :: Transport t => a t -> IO () + + writeStructBegin :: Transport t => a t -> String -> IO () + writeStructEnd :: Transport t => a t -> IO () + writeFieldBegin :: Transport t => a t -> (String, ThriftType, Int) -> IO () + writeFieldEnd :: Transport t => a t -> IO () + writeFieldStop :: Transport t => a t -> IO () + writeMapBegin :: Transport t => a t -> (ThriftType, ThriftType, Int) -> IO () + writeMapEnd :: Transport t => a t -> IO () + writeListBegin :: Transport t => a t -> (ThriftType, Int) -> IO () + writeListEnd :: Transport t => a t -> IO () + writeSetBegin :: Transport t => a t -> (ThriftType, Int) -> IO () + writeSetEnd :: Transport t => a t -> IO () + + writeBool :: Transport t => a t -> Bool -> IO () + writeByte :: Transport t => a t -> Int -> IO () + writeI16 :: Transport t => a t -> Int -> IO () + writeI32 :: Transport t => a t -> Int -> IO () + writeI64 :: Transport t => a t -> Int64 -> IO () + writeDouble :: Transport t => a t -> Double -> IO () + writeString :: Transport t => a t -> String -> IO () + writeBinary :: Transport t => a t -> String -> IO () + + + readMessageBegin :: Transport t => a t -> IO (String, MessageType, Int) + readMessageEnd :: Transport t => a t -> IO () + + readStructBegin :: Transport t => a t -> IO String + readStructEnd :: Transport t => a t -> IO () + readFieldBegin :: Transport t => a t -> IO (String, ThriftType, Int) + readFieldEnd :: Transport t => a t -> IO () + readMapBegin :: Transport t => a t -> IO (ThriftType, ThriftType, Int) + readMapEnd :: Transport t => a t -> IO () + readListBegin :: Transport t => a t -> IO (ThriftType, Int) + readListEnd :: Transport t => a t -> IO () + readSetBegin :: Transport t => a t -> IO (ThriftType, Int) + readSetEnd :: Transport t => a t -> IO () + + readBool :: Transport t => a t -> IO Bool + readByte :: Transport t => a t -> IO Int + readI16 :: Transport t => a t -> IO Int + readI32 :: Transport t => a t -> IO Int + readI64 :: Transport t => a t -> IO Int64 + readDouble :: Transport t => a t -> IO Double + readString :: Transport t => a t -> IO String + readBinary :: Transport t => a t -> IO String + + +skip :: (Protocol p, Transport t) => p t -> ThriftType -> IO () +skip p T_STOP = return () +skip p T_VOID = return () +skip p T_BOOL = readBool p >> return () +skip p T_BYTE = readByte p >> return () +skip p T_I16 = readI16 p >> return () +skip p T_I32 = readI32 p >> return () +skip p T_I64 = readI64 p >> return () +skip p T_DOUBLE = readDouble p >> return () +skip p T_STRING = readString p >> return () +skip p T_STRUCT = do readStructBegin p + skipFields p + readStructEnd p +skip p T_MAP = do (k, v, s) <- readMapBegin p + replicateM_ s (skip p k >> skip p v) + readMapEnd p +skip p T_SET = do (t, n) <- readSetBegin p + replicateM_ n (skip p t) + readSetEnd p +skip p T_LIST = do (t, n) <- readListBegin p + replicateM_ n (skip p t) + readListEnd p + + +skipFields :: (Protocol p, Transport t) => p t -> IO () +skipFields p = do + (_, t, _) <- readFieldBegin p + unless (t == T_STOP) (skip p t >> readFieldEnd p >> skipFields p) + + +data ProtocolExnType + = PE_UNKNOWN + | PE_INVALID_DATA + | PE_NEGATIVE_SIZE + | PE_SIZE_LIMIT + | PE_BAD_VERSION + deriving ( Eq, Show, Typeable ) + +data ProtocolExn = ProtocolExn ProtocolExnType String + deriving ( Show, Typeable ) +instance Exception ProtocolExn diff --git a/lib/hs/src/Thrift/Protocol/Binary.hs b/lib/hs/src/Thrift/Protocol/Binary.hs new file mode 100644 index 00000000..3f798cee --- /dev/null +++ b/lib/hs/src/Thrift/Protocol/Binary.hs @@ -0,0 +1,147 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +module Thrift.Protocol.Binary + ( module Thrift.Protocol + , BinaryProtocol(..) + ) where + +import Control.Exception ( throw ) + +import Data.Bits +import Data.Int +import Data.List ( foldl' ) + +import GHC.Exts +import GHC.Word + +import Thrift.Protocol +import Thrift.Transport + + +version_mask = 0xffff0000 +version_1 = 0x80010000 + +data BinaryProtocol a = Transport a => BinaryProtocol a + + +instance Protocol BinaryProtocol where + getTransport (BinaryProtocol t) = t + + writeMessageBegin p (n, t, s) = do + writeI32 p (version_1 .|. (fromEnum t)) + writeString p n + writeI32 p s + writeMessageEnd _ = return () + + writeStructBegin _ _ = return () + writeStructEnd _ = return () + writeFieldBegin p (_, t, i) = writeType p t >> writeI16 p i + writeFieldEnd _ = return () + writeFieldStop p = writeType p T_STOP + writeMapBegin p (k, v, n) = writeType p k >> writeType p v >> writeI32 p n + writeMapEnd p = return () + writeListBegin p (t, n) = writeType p t >> writeI32 p n + writeListEnd _ = return () + writeSetBegin p (t, n) = writeType p t >> writeI32 p n + writeSetEnd _ = return () + + writeBool p b = tWrite (getTransport p) [toEnum $ if b then 1 else 0] + writeByte p b = tWrite (getTransport p) (getBytes b 1) + writeI16 p b = tWrite (getTransport p) (getBytes b 2) + writeI32 p b = tWrite (getTransport p) (getBytes b 4) + writeI64 p b = tWrite (getTransport p) (getBytes b 8) + writeDouble p d = writeI64 p (fromIntegral $ floatBits d) + writeString p s = writeI32 p (length s) >> tWrite (getTransport p) s + writeBinary = writeString + + readMessageBegin p = do + ver <- readI32 p + if (ver .&. version_mask /= version_1) + then throw $ ProtocolExn PE_BAD_VERSION "Missing version identifier" + else do + s <- readString p + sz <- readI32 p + return (s, toEnum $ ver .&. 0xFF, sz) + readMessageEnd _ = return () + readStructBegin _ = return "" + readStructEnd _ = return () + readFieldBegin p = do + t <- readType p + n <- if t /= T_STOP then readI16 p else return 0 + return ("", t, n) + readFieldEnd _ = return () + readMapBegin p = do + kt <- readType p + vt <- readType p + n <- readI32 p + return (kt, vt, n) + readMapEnd _ = return () + readListBegin p = do + t <- readType p + n <- readI32 p + return (t, n) + readListEnd _ = return () + readSetBegin p = do + t <- readType p + n <- readI32 p + return (t, n) + readSetEnd _ = return () + + readBool p = (== 1) `fmap` readByte p + readByte p = do + bs <- tReadAll (getTransport p) 1 + return $ fromIntegral (composeBytes bs :: Int8) + readI16 p = do + bs <- tReadAll (getTransport p) 2 + return $ fromIntegral (composeBytes bs :: Int16) + readI32 p = composeBytes `fmap` tReadAll (getTransport p) 4 + readI64 p = composeBytes `fmap` tReadAll (getTransport p) 8 + readDouble p = do + bs <- readI64 p + return $ floatOfBits $ fromIntegral bs + readString p = readI32 p >>= tReadAll (getTransport p) + readBinary = readString + + +-- | Write a type as a byte +writeType :: (Protocol p, Transport t) => p t -> ThriftType -> IO () +writeType p t = writeByte p (fromEnum t) + +-- | Read a byte as though it were a ThriftType +readType :: (Protocol p, Transport t) => p t -> IO ThriftType +readType p = toEnum `fmap` readByte p + +composeBytes :: (Bits b, Enum t) => [t] -> b +composeBytes = (foldl' fn 0) . (map $ fromIntegral . fromEnum) + where fn acc b = (acc `shiftL` 8) .|. b + +getByte :: Bits a => a -> Int -> a +getByte i n = 255 .&. (i `shiftR` (8 * n)) + +getBytes :: (Bits a, Integral a) => a -> Int -> String +getBytes i 0 = [] +getBytes i n = (toEnum $ fromIntegral $ getByte i (n-1)):(getBytes i (n-1)) + +floatBits :: Double -> Word64 +floatBits (D# d#) = W64# (unsafeCoerce# d#) + +floatOfBits :: Word64 -> Double +floatOfBits (W64# b#) = D# (unsafeCoerce# b#) + diff --git a/lib/hs/src/Thrift/Server.hs b/lib/hs/src/Thrift/Server.hs new file mode 100644 index 00000000..770965f1 --- /dev/null +++ b/lib/hs/src/Thrift/Server.hs @@ -0,0 +1,65 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +module Thrift.Server + ( runBasicServer + , runThreadedServer + ) where + +import Control.Concurrent ( forkIO ) +import Control.Exception +import Control.Monad ( forever, when ) + +import Network + +import System.IO + +import Thrift +import Thrift.Transport.Handle +import Thrift.Protocol.Binary + + +-- | A threaded sever that is capable of using any Transport or Protocol +-- instances. +runThreadedServer :: (Transport t, Protocol i, Protocol o) + => (Socket -> IO (i t, o t)) + -> h + -> (h -> (i t, o t) -> IO Bool) + -> PortID + -> IO a +runThreadedServer accepter hand proc port = do + socket <- listenOn port + acceptLoop (accepter socket) (proc hand) + +-- | A basic threaded binary protocol socket server. +runBasicServer :: h + -> (h -> (BinaryProtocol Handle, BinaryProtocol Handle) -> IO Bool) + -> PortNumber + -> IO a +runBasicServer hand proc port = runThreadedServer binaryAccept hand proc (PortNumber port) + where binaryAccept s = do + (h, _, _) <- accept s + return (BinaryProtocol h, BinaryProtocol h) + +acceptLoop :: IO t -> (t -> IO Bool) -> IO a +acceptLoop accepter proc = forever $ + do ps <- accepter + forkIO $ handle (\(e :: SomeException) -> return ()) + (loop $ proc ps) + where loop m = do { continue <- m; when continue (loop m) } diff --git a/lib/hs/src/Thrift/Transport.hs b/lib/hs/src/Thrift/Transport.hs new file mode 100644 index 00000000..29f50d07 --- /dev/null +++ b/lib/hs/src/Thrift/Transport.hs @@ -0,0 +1,60 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +module Thrift.Transport + ( Transport(..) + , TransportExn(..) + , TransportExnType(..) + ) where + +import Control.Monad ( when ) +import Control.Exception ( Exception, throw ) + +import Data.Typeable ( Typeable ) + + +class Transport a where + tIsOpen :: a -> IO Bool + tClose :: a -> IO () + tRead :: a -> Int -> IO String + tWrite :: a -> String ->IO () + tFlush :: a -> IO () + tReadAll :: a -> Int -> IO String + + tReadAll a 0 = return [] + tReadAll a len = do + result <- tRead a len + let rlen = length result + when (rlen == 0) (throw $ TransportExn "Cannot read. Remote side has closed." TE_UNKNOWN) + if len <= rlen + then return result + else (result ++) `fmap` (tReadAll a (len - rlen)) + +data TransportExn = TransportExn String TransportExnType + deriving ( Show, Typeable ) +instance Exception TransportExn + +data TransportExnType + = TE_UNKNOWN + | TE_NOT_OPEN + | TE_ALREADY_OPEN + | TE_TIMED_OUT + | TE_END_OF_FILE + deriving ( Eq, Show, Typeable ) + diff --git a/lib/hs/src/Thrift/Transport/Handle.hs b/lib/hs/src/Thrift/Transport/Handle.hs new file mode 100644 index 00000000..e49456b5 --- /dev/null +++ b/lib/hs/src/Thrift/Transport/Handle.hs @@ -0,0 +1,58 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +module Thrift.Transport.Handle + ( module Thrift.Transport + , HandleSource(..) + ) where + +import Control.Exception ( throw ) +import Control.Monad ( replicateM ) + +import Network + +import System.IO +import System.IO.Error ( isEOFError ) + +import Thrift.Transport + + +instance Transport Handle where + tIsOpen = hIsOpen + tClose h = hClose h + tRead h n = replicateM n (hGetChar h) `catch` handleEOF + tWrite h s = mapM_ (hPutChar h) s + tFlush = hFlush + + +-- | Type class for all types that can open a Handle. This class is used to +-- replace tOpen in the Transport type class. +class HandleSource s where + hOpen :: s -> IO Handle + +instance HandleSource FilePath where + hOpen s = openFile s ReadWriteMode + +instance HandleSource (HostName, PortID) where + hOpen = uncurry connectTo + + +handleEOF e = if isEOFError e + then return [] + else throw $ TransportExn "TChannelTransport: Could not read" TE_UNKNOWN diff --git a/lib/java/Makefile.am b/lib/java/Makefile.am new file mode 100644 index 00000000..0a40496d --- /dev/null +++ b/lib/java/Makefile.am @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +EXTRA_DIST = build.xml ivy.xml src test + +all-local: + $(ANT) + +install-exec-hook: + $(ANT) install -Dinstall.path=$(DESTDIR)$(JAVA_PREFIX) \ + -Dinstall.javadoc.path=$(DESTDIR)$(docdir)/java + +# Make sure this doesn't fail if ant is not configured. +clean-local: + ANT=$(ANT) ; if test -z "$$ANT" ; then ANT=: ; fi ; \ + $$ANT clean + +check-local: all + $(ANT) test + diff --git a/lib/java/README b/lib/java/README new file mode 100644 index 00000000..6b8d351b --- /dev/null +++ b/lib/java/README @@ -0,0 +1,43 @@ +Thrift Java Software Library + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Using Thrift with Java +====================== + +The Thrift Java source is not build using the GNU tools, but rather uses +the Apache Ant build system, which tends to be predominant amongst Java +developers. + +To compile the Java Thrift libraries, simply do the following: + +ant + +Yep, that's easy. Look for libthrift.jar in the base directory. + +To include Thrift in your applications simply add libthrift.jar to your +classpath, or install if in your default system classpath of choice. + +Dependencies +============ + +Apache Ant +http://ant.apache.org/ diff --git a/lib/java/build.xml b/lib/java/build.xml new file mode 100644 index 00000000..0a7c8944 --- /dev/null +++ b/lib/java/build.xml @@ -0,0 +1,192 @@ + + + + + Thrift Build File + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + You need Apache Ivy 2.0 or later from http://ant.apache.org/ + It could not be loaded from ${ivy_repo_url} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lib/java/ivy.xml b/lib/java/ivy.xml new file mode 100644 index 00000000..0b1be5d8 --- /dev/null +++ b/lib/java/ivy.xml @@ -0,0 +1,7 @@ + + + + + + + diff --git a/lib/java/src/org/apache/thrift/IntRangeSet.java b/lib/java/src/org/apache/thrift/IntRangeSet.java new file mode 100644 index 00000000..5430134d --- /dev/null +++ b/lib/java/src/org/apache/thrift/IntRangeSet.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Set; + +/** + * IntRangeSet is a specialized Set implementation designed + * specifically to make the generated validate() method calls faster. It groups + * the set values into ranges, and in the contains() call, it does + * num ranges * 2 comparisons max. For the common case, which is a single, + * contiguous range, this approach is about 60% faster than using a HashSet. If + * you had a very ragged value set, like all the odd numbers, for instance, + * then you would end up with pretty poor running time. + */ +public class IntRangeSet implements Set { + /** + * This array keeps the bounds of each extent in alternating cells, always + * increasing. Example: [0,5,10,15], which corresponds to 0-5, 10-15. + */ + private int[] extents; + + /** + * We'll keep a duplicate, real HashSet around internally to satisfy some of + * the other set operations. + */ + private Set realSet = new HashSet(); + + public IntRangeSet(int... values) { + Arrays.sort(values); + + List extent_list = new ArrayList(); + + int ext_start = values[0]; + int ext_end_so_far = values[0]; + for (int i = 1; i < values.length; i++) { + realSet.add(values[i]); + + if (values[i] == ext_end_so_far + 1) { + // advance the end so far + ext_end_so_far = values[i]; + } else { + // create an extent for everything we saw so far, move on to the next one + extent_list.add(ext_start); + extent_list.add(ext_end_so_far); + ext_start = values[i]; + ext_end_so_far = values[i]; + } + } + extent_list.add(ext_start); + extent_list.add(ext_end_so_far); + + extents = new int[extent_list.size()]; + for (int i = 0; i < extent_list.size(); i++) { + extents[i] = extent_list.get(i); + } + } + + public boolean add(Integer i) { + throw new UnsupportedOperationException(); + } + + public void clear() { + throw new UnsupportedOperationException(); + } + + public boolean addAll(Collection arg0) { + throw new UnsupportedOperationException(); + } + + /** + * While this method is here for Set interface compatibility, you should avoid + * using it. It incurs boxing overhead! Use the int method directly, instead. + */ + public boolean contains(Object arg0) { + return contains(((Integer)arg0).intValue()); + } + + /** + * This is much faster, since it doesn't stop at Integer on the way through. + * @param val the value you want to check set membership for + * @return true if val was found, false otherwise + */ + public boolean contains(int val) { + for (int i = 0; i < extents.length / 2; i++) { + if (val < extents[i*2]) { + return false; + } else if (val <= extents[i*2+1]) { + return true; + } + } + + return false; + } + + public boolean containsAll(Collection arg0) { + for (Object o : arg0) { + if (!contains(o)) { + return false; + } + } + return true; + } + + public boolean isEmpty() { + return realSet.isEmpty(); + } + + public Iterator iterator() { + return realSet.iterator(); + } + + public boolean remove(Object arg0) { + throw new UnsupportedOperationException(); + } + + public boolean removeAll(Collection arg0) { + throw new UnsupportedOperationException(); + } + + public boolean retainAll(Collection arg0) { + throw new UnsupportedOperationException(); + } + + public int size() { + return realSet.size(); + } + + public Object[] toArray() { + return realSet.toArray(); + } + + public T[] toArray(T[] arg0) { + return realSet.toArray(arg0); + } + + @Override + public String toString() { + String buf = ""; + for (int i = 0; i < extents.length / 2; i++) { + if (i != 0) { + buf += ", "; + } + buf += "[" + extents[i*2] + "," + extents[i*2+1] + "]"; + } + return buf; + } +} diff --git a/lib/java/src/org/apache/thrift/TApplicationException.java b/lib/java/src/org/apache/thrift/TApplicationException.java new file mode 100644 index 00000000..a85e3705 --- /dev/null +++ b/lib/java/src/org/apache/thrift/TApplicationException.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +import org.apache.thrift.protocol.TField; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolUtil; +import org.apache.thrift.protocol.TStruct; +import org.apache.thrift.protocol.TType; + +/** + * Application level exception + * + */ +public class TApplicationException extends TException { + + private static final TStruct TAPPLICATION_EXCEPTION_STRUCT = new TStruct("TApplicationException"); + private static final TField MESSAGE_FIELD = new TField("message", TType.STRING, (short)1); + private static final TField TYPE_FIELD = new TField("type", TType.I32, (short)2); + + private static final long serialVersionUID = 1L; + + public static final int UNKNOWN = 0; + public static final int UNKNOWN_METHOD = 1; + public static final int INVALID_MESSAGE_TYPE = 2; + public static final int WRONG_METHOD_NAME = 3; + public static final int BAD_SEQUENCE_ID = 4; + public static final int MISSING_RESULT = 5; + + protected int type_ = UNKNOWN; + + public TApplicationException() { + super(); + } + + public TApplicationException(int type) { + super(); + type_ = type; + } + + public TApplicationException(int type, String message) { + super(message); + type_ = type; + } + + public TApplicationException(String message) { + super(message); + } + + public int getType() { + return type_; + } + + public static TApplicationException read(TProtocol iprot) throws TException { + TField field; + iprot.readStructBegin(); + + String message = null; + int type = UNKNOWN; + + while (true) { + field = iprot.readFieldBegin(); + if (field.type == TType.STOP) { + break; + } + switch (field.id) { + case 1: + if (field.type == TType.STRING) { + message = iprot.readString(); + } else { + TProtocolUtil.skip(iprot, field.type); + } + break; + case 2: + if (field.type == TType.I32) { + type = iprot.readI32(); + } else { + TProtocolUtil.skip(iprot, field.type); + } + break; + default: + TProtocolUtil.skip(iprot, field.type); + break; + } + iprot.readFieldEnd(); + } + iprot.readStructEnd(); + + return new TApplicationException(type, message); + } + + public void write(TProtocol oprot) throws TException { + oprot.writeStructBegin(TAPPLICATION_EXCEPTION_STRUCT); + if (getMessage() != null) { + oprot.writeFieldBegin(MESSAGE_FIELD); + oprot.writeString(getMessage()); + oprot.writeFieldEnd(); + } + oprot.writeFieldBegin(TYPE_FIELD); + oprot.writeI32(type_); + oprot.writeFieldEnd(); + oprot.writeFieldStop(); + oprot.writeStructEnd(); + } +} diff --git a/lib/java/src/org/apache/thrift/TBase.java b/lib/java/src/org/apache/thrift/TBase.java new file mode 100644 index 00000000..7c8978a2 --- /dev/null +++ b/lib/java/src/org/apache/thrift/TBase.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +import org.apache.thrift.protocol.TProtocol; + +/** + * Generic base interface for generated Thrift objects. + * + */ +public interface TBase extends Cloneable { + + /** + * Reads the TObject from the given input protocol. + * + * @param iprot Input protocol + */ + public void read(TProtocol iprot) throws TException; + + /** + * Writes the objects out to the protocol + * + * @param oprot Output protocol + */ + public void write(TProtocol oprot) throws TException; + + /** + * Check if a field is currently set or unset. + * + * @param fieldId The field's id tag as found in the IDL. + */ + public boolean isSet(int fieldId); + + /** + * Get a field's value by id. Primitive types will be wrapped in the + * appropriate "boxed" types. + * + * @param fieldId The field's id tag as found in the IDL. + */ + public Object getFieldValue(int fieldId); + + /** + * Set a field's value by id. Primitive types must be "boxed" in the + * appropriate object wrapper type. + * + * @param fieldId The field's id tag as found in the IDL. + */ + public void setFieldValue(int fieldId, Object value); +} diff --git a/lib/java/src/org/apache/thrift/TByteArrayOutputStream.java b/lib/java/src/org/apache/thrift/TByteArrayOutputStream.java new file mode 100644 index 00000000..e35fbcb7 --- /dev/null +++ b/lib/java/src/org/apache/thrift/TByteArrayOutputStream.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +import java.io.ByteArrayOutputStream; + +/** + * Class that allows access to the underlying buf without doing deep + * copies on it. + * + */ +public class TByteArrayOutputStream extends ByteArrayOutputStream { + public TByteArrayOutputStream(int size) { + super(size); + } + + public TByteArrayOutputStream() { + super(); + } + + + public byte[] get() { + return buf; + } + + public int len() { + return count; + } +} diff --git a/lib/java/src/org/apache/thrift/TDeserializer.java b/lib/java/src/org/apache/thrift/TDeserializer.java new file mode 100644 index 00000000..d6dd5d4b --- /dev/null +++ b/lib/java/src/org/apache/thrift/TDeserializer.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +import java.io.ByteArrayInputStream; +import java.io.UnsupportedEncodingException; + +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TIOStreamTransport; + +/** + * Generic utility for easily deserializing objects from a byte array or Java + * String. + * + */ +public class TDeserializer { + private final TProtocolFactory protocolFactory_; + + /** + * Create a new TDeserializer that uses the TBinaryProtocol by default. + */ + public TDeserializer() { + this(new TBinaryProtocol.Factory()); + } + + /** + * Create a new TDeserializer. It will use the TProtocol specified by the + * factory that is passed in. + * + * @param protocolFactory Factory to create a protocol + */ + public TDeserializer(TProtocolFactory protocolFactory) { + protocolFactory_ = protocolFactory; + } + + /** + * Deserialize the Thrift object from a byte array. + * + * @param base The object to read into + * @param bytes The array to read from + */ + public void deserialize(TBase base, byte[] bytes) throws TException { + base.read( + protocolFactory_.getProtocol( + new TIOStreamTransport( + new ByteArrayInputStream(bytes)))); + } + + /** + * Deserialize the Thrift object from a Java string, using a specified + * character set for decoding. + * + * @param base The object to read into + * @param data The string to read from + * @param charset Valid JVM charset + */ + public void deserialize(TBase base, String data, String charset) throws TException { + try { + deserialize(base, data.getBytes(charset)); + } catch (UnsupportedEncodingException uex) { + throw new TException("JVM DOES NOT SUPPORT ENCODING: " + charset); + } + } + + /** + * Deserialize the Thrift object from a Java string, using the default JVM + * charset encoding. + * + * @param base The object to read into + * @param data The string to read from + */ + public void toString(TBase base, String data) throws TException { + deserialize(base, data.getBytes()); + } +} + diff --git a/lib/java/src/org/apache/thrift/TException.java b/lib/java/src/org/apache/thrift/TException.java new file mode 100644 index 00000000..f84f4812 --- /dev/null +++ b/lib/java/src/org/apache/thrift/TException.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +/** + * Generic exception class for Thrift. + * + */ +public class TException extends Exception { + + private static final long serialVersionUID = 1L; + + public TException() { + super(); + } + + public TException(String message) { + super(message); + } + + public TException(Throwable cause) { + super(cause); + } + + public TException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/lib/java/src/org/apache/thrift/TFieldRequirementType.java b/lib/java/src/org/apache/thrift/TFieldRequirementType.java new file mode 100644 index 00000000..74bac4ef --- /dev/null +++ b/lib/java/src/org/apache/thrift/TFieldRequirementType.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +/** + * Requirement type constants. + * + */ +public final class TFieldRequirementType { + public static final byte REQUIRED = 1; + public static final byte OPTIONAL = 2; + public static final byte DEFAULT = 3; +} diff --git a/lib/java/src/org/apache/thrift/TProcessor.java b/lib/java/src/org/apache/thrift/TProcessor.java new file mode 100644 index 00000000..d79522c3 --- /dev/null +++ b/lib/java/src/org/apache/thrift/TProcessor.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +import org.apache.thrift.protocol.TProtocol; + +/** + * A processor is a generic object which operates upon an input stream and + * writes to some output stream. + * + */ +public interface TProcessor { + public boolean process(TProtocol in, TProtocol out) + throws TException; +} diff --git a/lib/java/src/org/apache/thrift/TProcessorFactory.java b/lib/java/src/org/apache/thrift/TProcessorFactory.java new file mode 100644 index 00000000..bcd8a38f --- /dev/null +++ b/lib/java/src/org/apache/thrift/TProcessorFactory.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +import org.apache.thrift.transport.TTransport; + +/** + * The default processor factory just returns a singleton + * instance. + */ +public class TProcessorFactory { + + private final TProcessor processor_; + + public TProcessorFactory(TProcessor processor) { + processor_ = processor; + } + + public TProcessor getProcessor(TTransport trans) { + return processor_; + } +} diff --git a/lib/java/src/org/apache/thrift/TSerializer.java b/lib/java/src/org/apache/thrift/TSerializer.java new file mode 100644 index 00000000..4e1ce612 --- /dev/null +++ b/lib/java/src/org/apache/thrift/TSerializer.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +import java.io.ByteArrayOutputStream; +import java.io.UnsupportedEncodingException; + +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TIOStreamTransport; + +/** + * Generic utility for easily serializing objects into a byte array or Java + * String. + * + */ +public class TSerializer { + + /** + * This is the byte array that data is actually serialized into + */ + private final ByteArrayOutputStream baos_ = new ByteArrayOutputStream(); + + /** + * This transport wraps that byte array + */ + private final TIOStreamTransport transport_ = new TIOStreamTransport(baos_); + + /** + * Internal protocol used for serializing objects. + */ + private TProtocol protocol_; + + /** + * Create a new TSerializer that uses the TBinaryProtocol by default. + */ + public TSerializer() { + this(new TBinaryProtocol.Factory()); + } + + /** + * Create a new TSerializer. It will use the TProtocol specified by the + * factory that is passed in. + * + * @param protocolFactory Factory to create a protocol + */ + public TSerializer(TProtocolFactory protocolFactory) { + protocol_ = protocolFactory.getProtocol(transport_); + } + + /** + * Serialize the Thrift object into a byte array. The process is simple, + * just clear the byte array output, write the object into it, and grab the + * raw bytes. + * + * @param base The object to serialize + * @return Serialized object in byte[] format + */ + public byte[] serialize(TBase base) throws TException { + baos_.reset(); + base.write(protocol_); + return baos_.toByteArray(); + } + + /** + * Serialize the Thrift object into a Java string, using a specified + * character set for encoding. + * + * @param base The object to serialize + * @param charset Valid JVM charset + * @return Serialized object as a String + */ + public String toString(TBase base, String charset) throws TException { + try { + return new String(serialize(base), charset); + } catch (UnsupportedEncodingException uex) { + throw new TException("JVM DOES NOT SUPPORT ENCODING: " + charset); + } + } + + /** + * Serialize the Thrift object into a Java string, using the default JVM + * charset encoding. + * + * @param base The object to serialize + * @return Serialized object as a String + */ + public String toString(TBase base) throws TException { + return new String(serialize(base)); + } +} + diff --git a/lib/java/src/org/apache/thrift/meta_data/FieldMetaData.java b/lib/java/src/org/apache/thrift/meta_data/FieldMetaData.java new file mode 100644 index 00000000..3e90a8b9 --- /dev/null +++ b/lib/java/src/org/apache/thrift/meta_data/FieldMetaData.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.meta_data; + +import java.util.HashMap; +import java.util.Map; +import org.apache.thrift.TBase; + +/** + * This class is used to store meta data about thrift fields. Every field in a + * a struct should have a corresponding instance of this class describing it. + * + */ +public class FieldMetaData implements java.io.Serializable { + public final String fieldName; + public final byte requirementType; + public final FieldValueMetaData valueMetaData; + private static Map, Map> structMap; + + static { + structMap = new HashMap, Map>(); + } + + public FieldMetaData(String name, byte req, FieldValueMetaData vMetaData){ + this.fieldName = name; + this.requirementType = req; + this.valueMetaData = vMetaData; + } + + public static void addStructMetaDataMap(Class sClass, Map map){ + structMap.put(sClass, map); + } + + /** + * Returns a map with metadata (i.e. instances of FieldMetaData) that + * describe the fields of the given class. + * + * @param sClass The TBase class for which the metadata map is requested + */ + public static Map getStructMetaDataMap(Class sClass){ + if (!structMap.containsKey(sClass)){ // Load class if it hasn't been loaded + try{ + sClass.newInstance(); + } catch (InstantiationException e){ + throw new RuntimeException("InstantiationException for TBase class: " + sClass.getName() + ", message: " + e.getMessage()); + } catch (IllegalAccessException e){ + throw new RuntimeException("IllegalAccessException for TBase class: " + sClass.getName() + ", message: " + e.getMessage()); + } + } + return structMap.get(sClass); + } +} diff --git a/lib/java/src/org/apache/thrift/meta_data/FieldValueMetaData.java b/lib/java/src/org/apache/thrift/meta_data/FieldValueMetaData.java new file mode 100644 index 00000000..f72da0cd --- /dev/null +++ b/lib/java/src/org/apache/thrift/meta_data/FieldValueMetaData.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.meta_data; + +import org.apache.thrift.protocol.TType; + +/** + * FieldValueMetaData and collection of subclasses to store metadata about + * the value(s) of a field + */ +public class FieldValueMetaData implements java.io.Serializable { + public final byte type; + + public FieldValueMetaData(byte type){ + this.type = type; + } + + public boolean isStruct() { + return type == TType.STRUCT; + } + + public boolean isContainer() { + return type == TType.LIST || type == TType.MAP || type == TType.SET; + } +} diff --git a/lib/java/src/org/apache/thrift/meta_data/ListMetaData.java b/lib/java/src/org/apache/thrift/meta_data/ListMetaData.java new file mode 100644 index 00000000..8e7073bf --- /dev/null +++ b/lib/java/src/org/apache/thrift/meta_data/ListMetaData.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.meta_data; + +public class ListMetaData extends FieldValueMetaData { + public final FieldValueMetaData elemMetaData; + + public ListMetaData(byte type, FieldValueMetaData eMetaData){ + super(type); + this.elemMetaData = eMetaData; + } +} diff --git a/lib/java/src/org/apache/thrift/meta_data/MapMetaData.java b/lib/java/src/org/apache/thrift/meta_data/MapMetaData.java new file mode 100644 index 00000000..e7c408c7 --- /dev/null +++ b/lib/java/src/org/apache/thrift/meta_data/MapMetaData.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.meta_data; + +public class MapMetaData extends FieldValueMetaData { + public final FieldValueMetaData keyMetaData; + public final FieldValueMetaData valueMetaData; + + public MapMetaData(byte type, FieldValueMetaData kMetaData, FieldValueMetaData vMetaData){ + super(type); + this.keyMetaData = kMetaData; + this.valueMetaData = vMetaData; + } +} diff --git a/lib/java/src/org/apache/thrift/meta_data/SetMetaData.java b/lib/java/src/org/apache/thrift/meta_data/SetMetaData.java new file mode 100644 index 00000000..cf4b96aa --- /dev/null +++ b/lib/java/src/org/apache/thrift/meta_data/SetMetaData.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.meta_data; + +public class SetMetaData extends FieldValueMetaData { + public final FieldValueMetaData elemMetaData; + + public SetMetaData(byte type, FieldValueMetaData eMetaData){ + super(type); + this.elemMetaData = eMetaData; + } +} diff --git a/lib/java/src/org/apache/thrift/meta_data/StructMetaData.java b/lib/java/src/org/apache/thrift/meta_data/StructMetaData.java new file mode 100644 index 00000000..b37d21da --- /dev/null +++ b/lib/java/src/org/apache/thrift/meta_data/StructMetaData.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.meta_data; + +import org.apache.thrift.TBase; + +public class StructMetaData extends FieldValueMetaData { + public final Class structClass; + + public StructMetaData(byte type, Class sClass){ + super(type); + this.structClass = sClass; + } +} diff --git a/lib/java/src/org/apache/thrift/protocol/TBase64Utils.java b/lib/java/src/org/apache/thrift/protocol/TBase64Utils.java new file mode 100644 index 00000000..37a9fd9f --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TBase64Utils.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +/** + * Class for encoding and decoding Base64 data. + * + * This class is kept at package level because the interface does no input + * validation and is therefore too low-level for generalized reuse. + * + * Note also that the encoding does not pad with equal signs , as discussed in + * section 2.2 of the RFC (http://www.faqs.org/rfcs/rfc3548.html). Furthermore, + * bad data encountered when decoding is neither rejected or ignored but simply + * results in bad decoded data -- this is not in compliance with the RFC but is + * done in the interest of performance. + * + */ +class TBase64Utils { + + private static final String ENCODE_TABLE = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + /** + * Encode len bytes of data in src at offset srcOff, storing the result into + * dst at offset dstOff. len must be 1, 2, or 3. dst must have at least len+1 + * bytes of space at dstOff. src and dst should not be the same object. This + * method does no validation of the input values in the interest of + * performance. + * + * @param src the source of bytes to encode + * @param srcOff the offset into the source to read the unencoded bytes + * @param len the number of bytes to encode (must be 1, 2, or 3). + * @param dst the destination for the encoding + * @param dstOff the offset into the destination to place the encoded bytes + */ + static final void encode(byte[] src, int srcOff, int len, byte[] dst, + int dstOff) { + dst[dstOff] = (byte)ENCODE_TABLE.charAt((src[srcOff] >> 2) & 0x3F); + if (len == 3) { + dst[dstOff + 1] = + (byte)ENCODE_TABLE.charAt( + ((src[srcOff] << 4) + (src[srcOff+1] >> 4)) & 0x3F); + dst[dstOff + 2] = + (byte)ENCODE_TABLE.charAt( + ((src[srcOff+1] << 2) + (src[srcOff+2] >> 6)) & 0x3F); + dst[dstOff + 3] = + (byte)ENCODE_TABLE.charAt(src[srcOff+2] & 0x3F); + } + else if (len == 2) { + dst[dstOff+1] = + (byte)ENCODE_TABLE.charAt( + ((src[srcOff] << 4) + (src[srcOff+1] >> 4)) & 0x3F); + dst[dstOff + 2] = + (byte)ENCODE_TABLE.charAt((src[srcOff+1] << 2) & 0x3F); + + } + else { // len == 1) { + dst[dstOff + 1] = + (byte)ENCODE_TABLE.charAt((src[srcOff] << 4) & 0x3F); + } + } + + private static final byte[] DECODE_TABLE = { + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,62,-1,-1,-1,63, + 52,53,54,55,56,57,58,59,60,61,-1,-1,-1,-1,-1,-1, + -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14, + 15,16,17,18,19,20,21,22,23,24,25,-1,-1,-1,-1,-1, + -1,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40, + 41,42,43,44,45,46,47,48,49,50,51,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, + }; + + /** + * Decode len bytes of data in src at offset srcOff, storing the result into + * dst at offset dstOff. len must be 2, 3, or 4. dst must have at least len-1 + * bytes of space at dstOff. src and dst may be the same object as long as + * dstoff <= srcOff. This method does no validation of the input values in + * the interest of performance. + * + * @param src the source of bytes to decode + * @param srcOff the offset into the source to read the encoded bytes + * @param len the number of bytes to decode (must be 2, 3, or 4) + * @param dst the destination for the decoding + * @param dstOff the offset into the destination to place the decoded bytes + */ + static final void decode(byte[] src, int srcOff, int len, byte[] dst, + int dstOff) { + dst[dstOff] = (byte) + ((DECODE_TABLE[src[srcOff] & 0x0FF] << 2) | + (DECODE_TABLE[src[srcOff+1] & 0x0FF] >> 4)); + if (len > 2) { + dst[dstOff+1] = (byte) + (((DECODE_TABLE[src[srcOff+1] & 0x0FF] << 4) & 0xF0) | + (DECODE_TABLE[src[srcOff+2] & 0x0FF] >> 2)); + if (len > 3) { + dst[dstOff+2] = (byte) + (((DECODE_TABLE[src[srcOff+2] & 0x0FF] << 6) & 0xC0) | + DECODE_TABLE[src[srcOff+3] & 0x0FF]); + } + } + } +} diff --git a/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java b/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java new file mode 100644 index 00000000..e9bd8b79 --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +import java.io.UnsupportedEncodingException; + +import org.apache.thrift.TException; +import org.apache.thrift.transport.TTransport; + +/** + * Binary protocol implementation for thrift. + * + */ +public class TBinaryProtocol extends TProtocol { + private static final TStruct ANONYMOUS_STRUCT = new TStruct(); + + protected static final int VERSION_MASK = 0xffff0000; + protected static final int VERSION_1 = 0x80010000; + + protected boolean strictRead_ = false; + protected boolean strictWrite_ = true; + + protected int readLength_; + protected boolean checkReadLength_ = false; + + /** + * Factory + */ + public static class Factory implements TProtocolFactory { + protected boolean strictRead_ = false; + protected boolean strictWrite_ = true; + + public Factory() { + this(false, true); + } + + public Factory(boolean strictRead, boolean strictWrite) { + strictRead_ = strictRead; + strictWrite_ = strictWrite; + } + + public TProtocol getProtocol(TTransport trans) { + return new TBinaryProtocol(trans, strictRead_, strictWrite_); + } + } + + /** + * Constructor + */ + public TBinaryProtocol(TTransport trans) { + this(trans, false, true); + } + + public TBinaryProtocol(TTransport trans, boolean strictRead, boolean strictWrite) { + super(trans); + strictRead_ = strictRead; + strictWrite_ = strictWrite; + } + + public void writeMessageBegin(TMessage message) throws TException { + if (strictWrite_) { + int version = VERSION_1 | message.type; + writeI32(version); + writeString(message.name); + writeI32(message.seqid); + } else { + writeString(message.name); + writeByte(message.type); + writeI32(message.seqid); + } + } + + public void writeMessageEnd() {} + + public void writeStructBegin(TStruct struct) {} + + public void writeStructEnd() {} + + public void writeFieldBegin(TField field) throws TException { + writeByte(field.type); + writeI16(field.id); + } + + public void writeFieldEnd() {} + + public void writeFieldStop() throws TException { + writeByte(TType.STOP); + } + + public void writeMapBegin(TMap map) throws TException { + writeByte(map.keyType); + writeByte(map.valueType); + writeI32(map.size); + } + + public void writeMapEnd() {} + + public void writeListBegin(TList list) throws TException { + writeByte(list.elemType); + writeI32(list.size); + } + + public void writeListEnd() {} + + public void writeSetBegin(TSet set) throws TException { + writeByte(set.elemType); + writeI32(set.size); + } + + public void writeSetEnd() {} + + public void writeBool(boolean b) throws TException { + writeByte(b ? (byte)1 : (byte)0); + } + + private byte [] bout = new byte[1]; + public void writeByte(byte b) throws TException { + bout[0] = b; + trans_.write(bout, 0, 1); + } + + private byte[] i16out = new byte[2]; + public void writeI16(short i16) throws TException { + i16out[0] = (byte)(0xff & (i16 >> 8)); + i16out[1] = (byte)(0xff & (i16)); + trans_.write(i16out, 0, 2); + } + + private byte[] i32out = new byte[4]; + public void writeI32(int i32) throws TException { + i32out[0] = (byte)(0xff & (i32 >> 24)); + i32out[1] = (byte)(0xff & (i32 >> 16)); + i32out[2] = (byte)(0xff & (i32 >> 8)); + i32out[3] = (byte)(0xff & (i32)); + trans_.write(i32out, 0, 4); + } + + private byte[] i64out = new byte[8]; + public void writeI64(long i64) throws TException { + i64out[0] = (byte)(0xff & (i64 >> 56)); + i64out[1] = (byte)(0xff & (i64 >> 48)); + i64out[2] = (byte)(0xff & (i64 >> 40)); + i64out[3] = (byte)(0xff & (i64 >> 32)); + i64out[4] = (byte)(0xff & (i64 >> 24)); + i64out[5] = (byte)(0xff & (i64 >> 16)); + i64out[6] = (byte)(0xff & (i64 >> 8)); + i64out[7] = (byte)(0xff & (i64)); + trans_.write(i64out, 0, 8); + } + + public void writeDouble(double dub) throws TException { + writeI64(Double.doubleToLongBits(dub)); + } + + public void writeString(String str) throws TException { + try { + byte[] dat = str.getBytes("UTF-8"); + writeI32(dat.length); + trans_.write(dat, 0, dat.length); + } catch (UnsupportedEncodingException uex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + } + + public void writeBinary(byte[] bin) throws TException { + writeI32(bin.length); + trans_.write(bin, 0, bin.length); + } + + /** + * Reading methods. + */ + + public TMessage readMessageBegin() throws TException { + int size = readI32(); + if (size < 0) { + int version = size & VERSION_MASK; + if (version != VERSION_1) { + throw new TProtocolException(TProtocolException.BAD_VERSION, "Bad version in readMessageBegin"); + } + return new TMessage(readString(), (byte)(size & 0x000000ff), readI32()); + } else { + if (strictRead_) { + throw new TProtocolException(TProtocolException.BAD_VERSION, "Missing version in readMessageBegin, old client?"); + } + return new TMessage(readStringBody(size), readByte(), readI32()); + } + } + + public void readMessageEnd() {} + + public TStruct readStructBegin() { + return ANONYMOUS_STRUCT; + } + + public void readStructEnd() {} + + public TField readFieldBegin() throws TException { + byte type = readByte(); + short id = type == TType.STOP ? 0 : readI16(); + return new TField("", type, id); + } + + public void readFieldEnd() {} + + public TMap readMapBegin() throws TException { + return new TMap(readByte(), readByte(), readI32()); + } + + public void readMapEnd() {} + + public TList readListBegin() throws TException { + return new TList(readByte(), readI32()); + } + + public void readListEnd() {} + + public TSet readSetBegin() throws TException { + return new TSet(readByte(), readI32()); + } + + public void readSetEnd() {} + + public boolean readBool() throws TException { + return (readByte() == 1); + } + + private byte[] bin = new byte[1]; + public byte readByte() throws TException { + readAll(bin, 0, 1); + return bin[0]; + } + + private byte[] i16rd = new byte[2]; + public short readI16() throws TException { + readAll(i16rd, 0, 2); + return + (short) + (((i16rd[0] & 0xff) << 8) | + ((i16rd[1] & 0xff))); + } + + private byte[] i32rd = new byte[4]; + public int readI32() throws TException { + readAll(i32rd, 0, 4); + return + ((i32rd[0] & 0xff) << 24) | + ((i32rd[1] & 0xff) << 16) | + ((i32rd[2] & 0xff) << 8) | + ((i32rd[3] & 0xff)); + } + + private byte[] i64rd = new byte[8]; + public long readI64() throws TException { + readAll(i64rd, 0, 8); + return + ((long)(i64rd[0] & 0xff) << 56) | + ((long)(i64rd[1] & 0xff) << 48) | + ((long)(i64rd[2] & 0xff) << 40) | + ((long)(i64rd[3] & 0xff) << 32) | + ((long)(i64rd[4] & 0xff) << 24) | + ((long)(i64rd[5] & 0xff) << 16) | + ((long)(i64rd[6] & 0xff) << 8) | + ((long)(i64rd[7] & 0xff)); + } + + public double readDouble() throws TException { + return Double.longBitsToDouble(readI64()); + } + + public String readString() throws TException { + int size = readI32(); + return readStringBody(size); + } + + public String readStringBody(int size) throws TException { + try { + checkReadLength(size); + byte[] buf = new byte[size]; + trans_.readAll(buf, 0, size); + return new String(buf, "UTF-8"); + } catch (UnsupportedEncodingException uex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + } + + public byte[] readBinary() throws TException { + int size = readI32(); + checkReadLength(size); + byte[] buf = new byte[size]; + trans_.readAll(buf, 0, size); + return buf; + } + + private int readAll(byte[] buf, int off, int len) throws TException { + checkReadLength(len); + return trans_.readAll(buf, off, len); + } + + public void setReadLength(int readLength) { + readLength_ = readLength; + checkReadLength_ = true; + } + + protected void checkReadLength(int length) throws TException { + if (checkReadLength_) { + readLength_ -= length; + if (readLength_ < 0) { + throw new TException("Message length exceeded: " + length); + } + } + } + +} diff --git a/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java b/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java new file mode 100755 index 00000000..e2d0bfdc --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java @@ -0,0 +1,741 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.protocol; + +import java.util.Stack; +import java.io.UnsupportedEncodingException; + +import org.apache.thrift.transport.TTransport; +import org.apache.thrift.TException; + +/** + * TCompactProtocol2 is the Java implementation of the compact protocol specified + * in THRIFT-110. The fundamental approach to reducing the overhead of + * structures is a) use variable-length integers all over the place and b) make + * use of unused bits wherever possible. Your savings will obviously vary + * based on the specific makeup of your structs, but in general, the more + * fields, nested structures, short strings and collections, and low-value i32 + * and i64 fields you have, the more benefit you'll see. + */ +public final class TCompactProtocol extends TProtocol { + + private final static TStruct ANONYMOUS_STRUCT = new TStruct(""); + private final static TField TSTOP = new TField("", TType.STOP, (short)0); + + private final static byte[] ttypeToCompactType = new byte[16]; + + static { + ttypeToCompactType[TType.STOP] = TType.STOP; + ttypeToCompactType[TType.BOOL] = Types.BOOLEAN_TRUE; + ttypeToCompactType[TType.BYTE] = Types.BYTE; + ttypeToCompactType[TType.I16] = Types.I16; + ttypeToCompactType[TType.I32] = Types.I32; + ttypeToCompactType[TType.I64] = Types.I64; + ttypeToCompactType[TType.DOUBLE] = Types.DOUBLE; + ttypeToCompactType[TType.STRING] = Types.BINARY; + ttypeToCompactType[TType.LIST] = Types.LIST; + ttypeToCompactType[TType.SET] = Types.SET; + ttypeToCompactType[TType.MAP] = Types.MAP; + ttypeToCompactType[TType.STRUCT] = Types.STRUCT; + } + + /** + * TProtocolFactory that produces TCompactProtocols. + */ + public static class Factory implements TProtocolFactory { + public Factory() {} + + public TProtocol getProtocol(TTransport trans) { + return new TCompactProtocol(trans); + } + } + + private static final byte PROTOCOL_ID = (byte)0x82; + private static final byte VERSION = 1; + private static final byte VERSION_MASK = 0x1f; // 0001 1111 + private static final byte TYPE_MASK = (byte)0xE0; // 1110 0000 + private static final int TYPE_SHIFT_AMOUNT = 5; + + /** + * All of the on-wire type codes. + */ + private static class Types { + public static final byte BOOLEAN_TRUE = 0x01; + public static final byte BOOLEAN_FALSE = 0x02; + public static final byte BYTE = 0x03; + public static final byte I16 = 0x04; + public static final byte I32 = 0x05; + public static final byte I64 = 0x06; + public static final byte DOUBLE = 0x07; + public static final byte BINARY = 0x08; + public static final byte LIST = 0x09; + public static final byte SET = 0x0A; + public static final byte MAP = 0x0B; + public static final byte STRUCT = 0x0C; + } + + /** + * Used to keep track of the last field for the current and previous structs, + * so we can do the delta stuff. + */ + private Stack lastField_ = new Stack(); + + private short lastFieldId_ = 0; + + /** + * If we encounter a boolean field begin, save the TField here so it can + * have the value incorporated. + */ + private TField booleanField_ = null; + + /** + * If we read a field header, and it's a boolean field, save the boolean + * value here so that readBool can use it. + */ + private Boolean boolValue_ = null; + + /** + * Create a TCompactProtocol. + * + * @param transport the TTransport object to read from or write to. + */ + public TCompactProtocol(TTransport transport) { + super(transport); + } + + + // + // Public Writing methods. + // + + /** + * Write a message header to the wire. Compact Protocol messages contain the + * protocol version so we can migrate forwards in the future if need be. + */ + public void writeMessageBegin(TMessage message) throws TException { + writeByteDirect(PROTOCOL_ID); + writeByteDirect((VERSION & VERSION_MASK) | ((message.type << TYPE_SHIFT_AMOUNT) & TYPE_MASK)); + writeVarint32(message.seqid); + writeString(message.name); + } + + /** + * Write a struct begin. This doesn't actually put anything on the wire. We + * use it as an opportunity to put special placeholder markers on the field + * stack so we can get the field id deltas correct. + */ + public void writeStructBegin(TStruct struct) throws TException { + lastField_.push(lastFieldId_); + lastFieldId_ = 0; + } + + /** + * Write a struct end. This doesn't actually put anything on the wire. We use + * this as an opportunity to pop the last field from the current struct off + * of the field stack. + */ + public void writeStructEnd() throws TException { + lastFieldId_ = lastField_.pop(); + } + + /** + * Write a field header containing the field id and field type. If the + * difference between the current field id and the last one is small (< 15), + * then the field id will be encoded in the 4 MSB as a delta. Otherwise, the + * field id will follow the type header as a zigzag varint. + */ + public void writeFieldBegin(TField field) throws TException { + if (field.type == TType.BOOL) { + // we want to possibly include the value, so we'll wait. + booleanField_ = field; + } else { + writeFieldBeginInternal(field, (byte)-1); + } + } + + /** + * The workhorse of writeFieldBegin. It has the option of doing a + * 'type override' of the type header. This is used specifically in the + * boolean field case. + */ + private void writeFieldBeginInternal(TField field, byte typeOverride) throws TException { + // short lastField = lastField_.pop(); + + // if there's a type override, use that. + byte typeToWrite = typeOverride == -1 ? getCompactType(field.type) : typeOverride; + + // check if we can use delta encoding for the field id + if (field.id > lastFieldId_ && field.id - lastFieldId_ <= 15) { + // write them together + writeByteDirect((field.id - lastFieldId_) << 4 | typeToWrite); + } else { + // write them separate + writeByteDirect(typeToWrite); + writeI16(field.id); + } + + lastFieldId_ = field.id; + // lastField_.push(field.id); + } + + /** + * Write the STOP symbol so we know there are no more fields in this struct. + */ + public void writeFieldStop() throws TException { + writeByteDirect(TType.STOP); + } + + /** + * Write a map header. If the map is empty, omit the key and value type + * headers, as we don't need any additional information to skip it. + */ + public void writeMapBegin(TMap map) throws TException { + if (map.size == 0) { + writeByteDirect(0); + } else { + writeVarint32(map.size); + writeByteDirect(getCompactType(map.keyType) << 4 | getCompactType(map.valueType)); + } + } + + /** + * Write a list header. + */ + public void writeListBegin(TList list) throws TException { + writeCollectionBegin(list.elemType, list.size); + } + + /** + * Write a set header. + */ + public void writeSetBegin(TSet set) throws TException { + writeCollectionBegin(set.elemType, set.size); + } + + /** + * Write a boolean value. Potentially, this could be a boolean field, in + * which case the field header info isn't written yet. If so, decide what the + * right type header is for the value and then write the field header. + * Otherwise, write a single byte. + */ + public void writeBool(boolean b) throws TException { + if (booleanField_ != null) { + // we haven't written the field header yet + writeFieldBeginInternal(booleanField_, b ? Types.BOOLEAN_TRUE : Types.BOOLEAN_FALSE); + booleanField_ = null; + } else { + // we're not part of a field, so just write the value. + writeByteDirect(b ? Types.BOOLEAN_TRUE : Types.BOOLEAN_FALSE); + } + } + + /** + * Write a byte. Nothing to see here! + */ + public void writeByte(byte b) throws TException { + writeByteDirect(b); + } + + /** + * Write an I16 as a zigzag varint. + */ + public void writeI16(short i16) throws TException { + writeVarint32(intToZigZag(i16)); + } + + /** + * Write an i32 as a zigzag varint. + */ + public void writeI32(int i32) throws TException { + writeVarint32(intToZigZag(i32)); + } + + /** + * Write an i64 as a zigzag varint. + */ + public void writeI64(long i64) throws TException { + writeVarint64(longToZigzag(i64)); + } + + /** + * Write a double to the wire as 8 bytes. + */ + public void writeDouble(double dub) throws TException { + byte[] data = new byte[]{0, 0, 0, 0, 0, 0, 0, 0}; + fixedLongToBytes(Double.doubleToLongBits(dub), data, 0); + trans_.write(data); + } + + /** + * Write a string to the wire with a varint size preceeding. + */ + public void writeString(String str) throws TException { + try { + writeBinary(str.getBytes("UTF-8")); + } catch (UnsupportedEncodingException e) { + throw new TException("UTF-8 not supported!"); + } + } + + /** + * Write a byte array, using a varint for the size. + */ + public void writeBinary(byte[] bin) throws TException { + writeVarint32(bin.length); + trans_.write(bin); + } + + // + // These methods are called by structs, but don't actually have any wire + // output or purpose. + // + + public void writeMessageEnd() throws TException {} + public void writeMapEnd() throws TException {} + public void writeListEnd() throws TException {} + public void writeSetEnd() throws TException {} + public void writeFieldEnd() throws TException {} + + // + // Internal writing methods + // + + /** + * Abstract method for writing the start of lists and sets. List and sets on + * the wire differ only by the type indicator. + */ + protected void writeCollectionBegin(byte elemType, int size) throws TException { + if (size <= 14) { + writeByteDirect(size << 4 | getCompactType(elemType)); + } else { + writeByteDirect(0xf0 | getCompactType(elemType)); + writeVarint32(size); + } + } + + /** + * Write an i32 as a varint. Results in 1-5 bytes on the wire. + * TODO: make a permanent buffer like writeVarint64? + */ + byte[] i32buf = new byte[5]; + private void writeVarint32(int n) throws TException { + int idx = 0; + while (true) { + if ((n & ~0x7F) == 0) { + i32buf[idx++] = (byte)n; + // writeByteDirect((byte)n); + break; + // return; + } else { + i32buf[idx++] = (byte)((n & 0x7F) | 0x80); + // writeByteDirect((byte)((n & 0x7F) | 0x80)); + n >>>= 7; + } + } + trans_.write(i32buf, 0, idx); + } + + /** + * Write an i64 as a varint. Results in 1-10 bytes on the wire. + */ + byte[] varint64out = new byte[10]; + private void writeVarint64(long n) throws TException { + int idx = 0; + while (true) { + if ((n & ~0x7FL) == 0) { + varint64out[idx++] = (byte)n; + break; + } else { + varint64out[idx++] = ((byte)((n & 0x7F) | 0x80)); + n >>>= 7; + } + } + trans_.write(varint64out, 0, idx); + } + + /** + * Convert l into a zigzag long. This allows negative numbers to be + * represented compactly as a varint. + */ + private long longToZigzag(long l) { + return (l << 1) ^ (l >> 63); + } + + /** + * Convert n into a zigzag int. This allows negative numbers to be + * represented compactly as a varint. + */ + private int intToZigZag(int n) { + return (n << 1) ^ (n >> 31); + } + + /** + * Convert a long into little-endian bytes in buf starting at off and going + * until off+7. + */ + private void fixedLongToBytes(long n, byte[] buf, int off) { + buf[off+0] = (byte)( n & 0xff); + buf[off+1] = (byte)((n >> 8 ) & 0xff); + buf[off+2] = (byte)((n >> 16) & 0xff); + buf[off+3] = (byte)((n >> 24) & 0xff); + buf[off+4] = (byte)((n >> 32) & 0xff); + buf[off+5] = (byte)((n >> 40) & 0xff); + buf[off+6] = (byte)((n >> 48) & 0xff); + buf[off+7] = (byte)((n >> 56) & 0xff); + } + + /** + * Writes a byte without any possiblity of all that field header nonsense. + * Used internally by other writing methods that know they need to write a byte. + */ + private byte[] byteDirectBuffer = new byte[1]; + private void writeByteDirect(byte b) throws TException { + byteDirectBuffer[0] = b; + trans_.write(byteDirectBuffer); + } + + /** + * Writes a byte without any possiblity of all that field header nonsense. + */ + private void writeByteDirect(int n) throws TException { + writeByteDirect((byte)n); + } + + + // + // Reading methods. + // + + /** + * Read a message header. + */ + public TMessage readMessageBegin() throws TException { + byte protocolId = readByte(); + if (protocolId != PROTOCOL_ID) { + throw new TProtocolException("Expected protocol id " + Integer.toHexString(PROTOCOL_ID) + " but got " + Integer.toHexString(protocolId)); + } + byte versionAndType = readByte(); + byte version = (byte)(versionAndType & VERSION_MASK); + if (version != VERSION) { + throw new TProtocolException("Expected version " + VERSION + " but got " + version); + } + byte type = (byte)((versionAndType >> TYPE_SHIFT_AMOUNT) & 0x03); + int seqid = readVarint32(); + String messageName = readString(); + return new TMessage(messageName, type, seqid); + } + + /** + * Read a struct begin. There's nothing on the wire for this, but it is our + * opportunity to push a new struct begin marker onto the field stack. + */ + public TStruct readStructBegin() throws TException { + lastField_.push(lastFieldId_); + lastFieldId_ = 0; + return ANONYMOUS_STRUCT; + } + + /** + * Doesn't actually consume any wire data, just removes the last field for + * this struct from the field stack. + */ + public void readStructEnd() throws TException { + // consume the last field we read off the wire. + lastFieldId_ = lastField_.pop(); + } + + /** + * Read a field header off the wire. + */ + public TField readFieldBegin() throws TException { + byte type = readByte(); + + // if it's a stop, then we can return immediately, as the struct is over. + if ((type & 0x0f) == TType.STOP) { + return TSTOP; + } + + short fieldId; + + // mask off the 4 MSB of the type header. it could contain a field id delta. + short modifier = (short)((type & 0xf0) >> 4); + if (modifier == 0) { + // not a delta. look ahead for the zigzag varint field id. + fieldId = readI16(); + } else { + // has a delta. add the delta to the last read field id. + fieldId = (short)(lastFieldId_ + modifier); + } + + TField field = new TField("", getTType((byte)(type & 0x0f)), fieldId); + + // if this happens to be a boolean field, the value is encoded in the type + if (isBoolType(type)) { + // save the boolean value in a special instance variable. + boolValue_ = (byte)(type & 0x0f) == Types.BOOLEAN_TRUE ? Boolean.TRUE : Boolean.FALSE; + } + + // push the new field onto the field stack so we can keep the deltas going. + lastFieldId_ = field.id; + return field; + } + + /** + * Read a map header off the wire. If the size is zero, skip reading the key + * and value type. This means that 0-length maps will yield TMaps without the + * "correct" types. + */ + public TMap readMapBegin() throws TException { + int size = readVarint32(); + byte keyAndValueType = size == 0 ? 0 : readByte(); + return new TMap(getTType((byte)(keyAndValueType >> 4)), getTType((byte)(keyAndValueType & 0xf)), size); + } + + /** + * Read a list header off the wire. If the list size is 0-14, the size will + * be packed into the element type header. If it's a longer list, the 4 MSB + * of the element type header will be 0xF, and a varint will follow with the + * true size. + */ + public TList readListBegin() throws TException { + byte size_and_type = readByte(); + int size = (size_and_type >> 4) & 0x0f; + if (size == 15) { + size = readVarint32(); + } + byte type = getTType(size_and_type); + return new TList(type, size); + } + + /** + * Read a set header off the wire. If the set size is 0-14, the size will + * be packed into the element type header. If it's a longer set, the 4 MSB + * of the element type header will be 0xF, and a varint will follow with the + * true size. + */ + public TSet readSetBegin() throws TException { + return new TSet(readListBegin()); + } + + /** + * Read a boolean off the wire. If this is a boolean field, the value should + * already have been read during readFieldBegin, so we'll just consume the + * pre-stored value. Otherwise, read a byte. + */ + public boolean readBool() throws TException { + if (boolValue_ != null) { + boolean result = boolValue_.booleanValue(); + boolValue_ = null; + return result; + } + return readByte() == Types.BOOLEAN_TRUE; + } + + byte[] byteRawBuf = new byte[1]; + /** + * Read a single byte off the wire. Nothing interesting here. + */ + public byte readByte() throws TException { + trans_.read(byteRawBuf, 0, 1); + return byteRawBuf[0]; + } + + /** + * Read an i16 from the wire as a zigzag varint. + */ + public short readI16() throws TException { + return (short)zigzagToInt(readVarint32()); + } + + /** + * Read an i32 from the wire as a zigzag varint. + */ + public int readI32() throws TException { + return zigzagToInt(readVarint32()); + } + + /** + * Read an i64 from the wire as a zigzag varint. + */ + public long readI64() throws TException { + return zigzagToLong(readVarint64()); + } + + /** + * No magic here - just read a double off the wire. + */ + public double readDouble() throws TException { + byte[] longBits = new byte[8]; + trans_.read(longBits, 0, 8); + return Double.longBitsToDouble(bytesToLong(longBits)); + } + + /** + * Reads a byte[] (via readBinary), and then UTF-8 decodes it. + */ + public String readString() throws TException { + try { + return new String(readBinary(), "UTF-8"); + } catch (UnsupportedEncodingException e) { + throw new TException("UTF-8 not supported!"); + } + } + + /** + * Read a byte[] from the wire. + */ + public byte[] readBinary() throws TException { + int length = readVarint32(); + if (length == 0) return new byte[0]; + + byte[] buf = new byte[length]; + trans_.read(buf, 0, length); + return buf; + } + + + // + // These methods are here for the struct to call, but don't have any wire + // encoding. + // + public void readMessageEnd() throws TException {} + public void readFieldEnd() throws TException {} + public void readMapEnd() throws TException {} + public void readListEnd() throws TException {} + public void readSetEnd() throws TException {} + + // + // Internal reading methods + // + + /** + * Read an i32 from the wire as a varint. The MSB of each byte is set + * if there is another byte to follow. This can read up to 5 bytes. + */ + private int readVarint32() throws TException { + // if the wire contains the right stuff, this will just truncate the i64 we + // read and get us the right sign. + return (int)readVarint64(); + } + + /** + * Read an i64 from the wire as a proper varint. The MSB of each byte is set + * if there is another byte to follow. This can read up to 10 bytes. + */ + private long readVarint64() throws TException { + int shift = 0; + long result = 0; + while (true) { + byte b = readByte(); + result |= (long) (b & 0x7f) << shift; + if ((b & 0x80) != 0x80) break; + shift +=7; + } + return result; + } + + // + // encoding helpers + // + + /** + * Convert from zigzag int to int. + */ + private int zigzagToInt(int n) { + return (n >>> 1) ^ -(n & 1); + } + + /** + * Convert from zigzag long to long. + */ + private long zigzagToLong(long n) { + return (n >>> 1) ^ -(n & 1); + } + + /** + * Note that it's important that the mask bytes are long literals, + * otherwise they'll default to ints, and when you shift an int left 56 bits, + * you just get a messed up int. + */ + private long bytesToLong(byte[] bytes) { + return + ((bytes[7] & 0xffL) << 56) | + ((bytes[6] & 0xffL) << 48) | + ((bytes[5] & 0xffL) << 40) | + ((bytes[4] & 0xffL) << 32) | + ((bytes[3] & 0xffL) << 24) | + ((bytes[2] & 0xffL) << 16) | + ((bytes[1] & 0xffL) << 8) | + ((bytes[0] & 0xffL)); + } + + // + // type testing and converting + // + + private boolean isBoolType(byte b) { + return (b & 0x0f) == Types.BOOLEAN_TRUE || (b & 0x0f) == Types.BOOLEAN_FALSE; + } + + /** + * Given a TCompactProtocol.Types constant, convert it to its corresponding + * TType value. + */ + private byte getTType(byte type) { + switch ((byte)(type & 0x0f)) { + case TType.STOP: + return TType.STOP; + case Types.BOOLEAN_FALSE: + case Types.BOOLEAN_TRUE: + return TType.BOOL; + case Types.BYTE: + return TType.BYTE; + case Types.I16: + return TType.I16; + case Types.I32: + return TType.I32; + case Types.I64: + return TType.I64; + case Types.DOUBLE: + return TType.DOUBLE; + case Types.BINARY: + return TType.STRING; + case Types.LIST: + return TType.LIST; + case Types.SET: + return TType.SET; + case Types.MAP: + return TType.MAP; + case Types.STRUCT: + return TType.STRUCT; + default: + throw new RuntimeException("don't know what type: " + (byte)(type & 0x0f)); + } + } + + /** + * Given a TType value, find the appropriate TCompactProtocol.Types constant. + */ + private byte getCompactType(byte ttype) { + return ttypeToCompactType[ttype]; + } + +} diff --git a/lib/java/src/org/apache/thrift/protocol/TField.java b/lib/java/src/org/apache/thrift/protocol/TField.java new file mode 100644 index 00000000..03affdaa --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TField.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +/** + * Helper class that encapsulates field metadata. + * + */ +public class TField { + public TField() { + this("", TType.STOP, (short)0); + } + + public TField(String n, byte t, short i) { + name = n; + type = t; + id = i; + } + + public final String name; + public final byte type; + public final short id; + + public String toString() { + return ""; + } + + public boolean equals(TField otherField) { + return type == otherField.type && id == otherField.id; + } +} diff --git a/lib/java/src/org/apache/thrift/protocol/TJSONProtocol.java b/lib/java/src/org/apache/thrift/protocol/TJSONProtocol.java new file mode 100644 index 00000000..631c6a5b --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TJSONProtocol.java @@ -0,0 +1,927 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +import org.apache.thrift.TException; +import org.apache.thrift.TByteArrayOutputStream; +import org.apache.thrift.transport.TTransport; +import java.io.UnsupportedEncodingException; +import java.util.Stack; + +/** + * JSON protocol implementation for thrift. + * + * This is a full-featured protocol supporting write and read. + * + * Please see the C++ class header for a detailed description of the + * protocol's wire format. + * + */ +public class TJSONProtocol extends TProtocol { + + /** + * Factory for JSON protocol objects + */ + public static class Factory implements TProtocolFactory { + + public TProtocol getProtocol(TTransport trans) { + return new TJSONProtocol(trans); + } + + } + + private static final byte[] COMMA = new byte[] {','}; + private static final byte[] COLON = new byte[] {':'}; + private static final byte[] LBRACE = new byte[] {'{'}; + private static final byte[] RBRACE = new byte[] {'}'}; + private static final byte[] LBRACKET = new byte[] {'['}; + private static final byte[] RBRACKET = new byte[] {']'}; + private static final byte[] QUOTE = new byte[] {'"'}; + private static final byte[] BACKSLASH = new byte[] {'\\'}; + private static final byte[] ZERO = new byte[] {'0'}; + + private static final byte[] ESCSEQ = new byte[] {'\\','u','0','0'}; + + private static final long VERSION = 1; + + private static final byte[] JSON_CHAR_TABLE = { + /* 0 1 2 3 4 5 6 7 8 9 A B C D E F */ + 0, 0, 0, 0, 0, 0, 0, 0,'b','t','n', 0,'f','r', 0, 0, // 0 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 1 + 1, 1,'"', 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 2 + }; + + private static final String ESCAPE_CHARS = "\"\\bfnrt"; + + private static final byte[] ESCAPE_CHAR_VALS = { + '"', '\\', '\b', '\f', '\n', '\r', '\t', + }; + + private static final int DEF_STRING_SIZE = 16; + + private static final byte[] NAME_BOOL = new byte[] {'t', 'f'}; + private static final byte[] NAME_BYTE = new byte[] {'i','8'}; + private static final byte[] NAME_I16 = new byte[] {'i','1','6'}; + private static final byte[] NAME_I32 = new byte[] {'i','3','2'}; + private static final byte[] NAME_I64 = new byte[] {'i','6','4'}; + private static final byte[] NAME_DOUBLE = new byte[] {'d','b','l'}; + private static final byte[] NAME_STRUCT = new byte[] {'r','e','c'}; + private static final byte[] NAME_STRING = new byte[] {'s','t','r'}; + private static final byte[] NAME_MAP = new byte[] {'m','a','p'}; + private static final byte[] NAME_LIST = new byte[] {'l','s','t'}; + private static final byte[] NAME_SET = new byte[] {'s','e','t'}; + + private static final TStruct ANONYMOUS_STRUCT = new TStruct(); + + private static final byte[] getTypeNameForTypeID(byte typeID) + throws TException { + switch (typeID) { + case TType.BOOL: + return NAME_BOOL; + case TType.BYTE: + return NAME_BYTE; + case TType.I16: + return NAME_I16; + case TType.I32: + return NAME_I32; + case TType.I64: + return NAME_I64; + case TType.DOUBLE: + return NAME_DOUBLE; + case TType.STRING: + return NAME_STRING; + case TType.STRUCT: + return NAME_STRUCT; + case TType.MAP: + return NAME_MAP; + case TType.SET: + return NAME_SET; + case TType.LIST: + return NAME_LIST; + default: + throw new TProtocolException(TProtocolException.NOT_IMPLEMENTED, + "Unrecognized type"); + } + } + + private static final byte getTypeIDForTypeName(byte[] name) + throws TException { + byte result = TType.STOP; + if (name.length > 1) { + switch (name[0]) { + case 'd': + result = TType.DOUBLE; + break; + case 'i': + switch (name[1]) { + case '8': + result = TType.BYTE; + break; + case '1': + result = TType.I16; + break; + case '3': + result = TType.I32; + break; + case '6': + result = TType.I64; + break; + } + break; + case 'l': + result = TType.LIST; + break; + case 'm': + result = TType.MAP; + break; + case 'r': + result = TType.STRUCT; + break; + case 's': + if (name[1] == 't') { + result = TType.STRING; + } + else if (name[1] == 'e') { + result = TType.SET; + } + break; + case 't': + result = TType.BOOL; + break; + } + } + if (result == TType.STOP) { + throw new TProtocolException(TProtocolException.NOT_IMPLEMENTED, + "Unrecognized type"); + } + return result; + } + + // Base class for tracking JSON contexts that may require inserting/reading + // additional JSON syntax characters + // This base context does nothing. + protected class JSONBaseContext { + protected void write() throws TException {} + + protected void read() throws TException {} + + protected boolean escapeNum() { return false; } + } + + // Context for JSON lists. Will insert/read commas before each item except + // for the first one + protected class JSONListContext extends JSONBaseContext { + private boolean first_ = true; + + @Override + protected void write() throws TException { + if (first_) { + first_ = false; + } else { + trans_.write(COMMA); + } + } + + @Override + protected void read() throws TException { + if (first_) { + first_ = false; + } else { + readJSONSyntaxChar(COMMA); + } + } + } + + // Context for JSON records. Will insert/read colons before the value portion + // of each record pair, and commas before each key except the first. In + // addition, will indicate that numbers in the key position need to be + // escaped in quotes (since JSON keys must be strings). + protected class JSONPairContext extends JSONBaseContext { + private boolean first_ = true; + private boolean colon_ = true; + + @Override + protected void write() throws TException { + if (first_) { + first_ = false; + colon_ = true; + } else { + trans_.write(colon_ ? COLON : COMMA); + colon_ = !colon_; + } + } + + @Override + protected void read() throws TException { + if (first_) { + first_ = false; + colon_ = true; + } else { + readJSONSyntaxChar(colon_ ? COLON : COMMA); + colon_ = !colon_; + } + } + + @Override + protected boolean escapeNum() { + return colon_; + } + } + + // Holds up to one byte from the transport + protected class LookaheadReader { + + private boolean hasData_; + private byte[] data_ = new byte[1]; + + // Return and consume the next byte to be read, either taking it from the + // data buffer if present or getting it from the transport otherwise. + protected byte read() throws TException { + if (hasData_) { + hasData_ = false; + } + else { + trans_.readAll(data_, 0, 1); + } + return data_[0]; + } + + // Return the next byte to be read without consuming, filling the data + // buffer if it has not been filled already. + protected byte peek() throws TException { + if (!hasData_) { + trans_.readAll(data_, 0, 1); + } + hasData_ = true; + return data_[0]; + } + } + + // Stack of nested contexts that we may be in + private Stack contextStack_ = new Stack(); + + // Current context that we are in + private JSONBaseContext context_ = new JSONBaseContext(); + + // Reader that manages a 1-byte buffer + private LookaheadReader reader_ = new LookaheadReader(); + + // Push a new JSON context onto the stack. + private void pushContext(JSONBaseContext c) { + contextStack_.push(context_); + context_ = c; + } + + // Pop the last JSON context off the stack + private void popContext() { + context_ = contextStack_.pop(); + } + + /** + * Constructor + */ + public TJSONProtocol(TTransport trans) { + super(trans); + } + + // Temporary buffer used by several methods + private byte[] tmpbuf_ = new byte[4]; + + // Read a byte that must match b[0]; otherwise an excpetion is thrown. + // Marked protected to avoid synthetic accessor in JSONListContext.read + // and JSONPairContext.read + protected void readJSONSyntaxChar(byte[] b) throws TException { + byte ch = reader_.read(); + if (ch != b[0]) { + throw new TProtocolException(TProtocolException.INVALID_DATA, + "Unexpected character:" + (char)ch); + } + } + + // Convert a byte containing a hex char ('0'-'9' or 'a'-'f') into its + // corresponding hex value + private static final byte hexVal(byte ch) throws TException { + if ((ch >= '0') && (ch <= '9')) { + return (byte)((char)ch - '0'); + } + else if ((ch >= 'a') && (ch <= 'f')) { + return (byte)((char)ch - 'a'); + } + else { + throw new TProtocolException(TProtocolException.INVALID_DATA, + "Expected hex character"); + } + } + + // Convert a byte containing a hex value to its corresponding hex character + private static final byte hexChar(byte val) { + val &= 0x0F; + if (val < 10) { + return (byte)((char)val + '0'); + } + else { + return (byte)((char)val + 'a'); + } + } + + // Write the bytes in array buf as a JSON characters, escaping as needed + private void writeJSONString(byte[] b) throws TException { + context_.write(); + trans_.write(QUOTE); + int len = b.length; + for (int i = 0; i < len; i++) { + if ((b[i] & 0x00FF) >= 0x30) { + if (b[i] == BACKSLASH[0]) { + trans_.write(BACKSLASH); + trans_.write(BACKSLASH); + } + else { + trans_.write(b, i, 1); + } + } + else { + tmpbuf_[0] = JSON_CHAR_TABLE[b[i]]; + if (tmpbuf_[0] == 1) { + trans_.write(b, i, 1); + } + else if (tmpbuf_[0] > 1) { + trans_.write(BACKSLASH); + trans_.write(tmpbuf_, 0, 1); + } + else { + trans_.write(ESCSEQ); + tmpbuf_[0] = hexChar((byte)(b[i] >> 4)); + tmpbuf_[1] = hexChar(b[i]); + trans_.write(tmpbuf_, 0, 2); + } + } + } + trans_.write(QUOTE); + } + + // Write out number as a JSON value. If the context dictates so, it will be + // wrapped in quotes to output as a JSON string. + private void writeJSONInteger(long num) throws TException { + context_.write(); + String str = Long.toString(num); + boolean escapeNum = context_.escapeNum(); + if (escapeNum) { + trans_.write(QUOTE); + } + try { + byte[] buf = str.getBytes("UTF-8"); + trans_.write(buf); + } catch (UnsupportedEncodingException uex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + if (escapeNum) { + trans_.write(QUOTE); + } + } + + // Write out a double as a JSON value. If it is NaN or infinity or if the + // context dictates escaping, write out as JSON string. + private void writeJSONDouble(double num) throws TException { + context_.write(); + String str = Double.toString(num); + boolean special = false; + switch (str.charAt(0)) { + case 'N': // NaN + case 'I': // Infinity + special = true; + break; + case '-': + if (str.charAt(1) == 'I') { // -Infinity + special = true; + } + break; + } + + boolean escapeNum = special || context_.escapeNum(); + if (escapeNum) { + trans_.write(QUOTE); + } + try { + byte[] b = str.getBytes("UTF-8"); + trans_.write(b, 0, b.length); + } catch (UnsupportedEncodingException uex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + if (escapeNum) { + trans_.write(QUOTE); + } + } + + // Write out contents of byte array b as a JSON string with base-64 encoded + // data + private void writeJSONBase64(byte[] b) throws TException { + context_.write(); + trans_.write(QUOTE); + int len = b.length; + int off = 0; + while (len >= 3) { + // Encode 3 bytes at a time + TBase64Utils.encode(b, off, 3, tmpbuf_, 0); + trans_.write(tmpbuf_, 0, 4); + off += 3; + len -= 3; + } + if (len > 0) { + // Encode remainder + TBase64Utils.encode(b, off, len, tmpbuf_, 0); + trans_.write(tmpbuf_, 0, len + 1); + } + trans_.write(QUOTE); + } + + private void writeJSONObjectStart() throws TException { + context_.write(); + trans_.write(LBRACE); + pushContext(new JSONPairContext()); + } + + private void writeJSONObjectEnd() throws TException { + popContext(); + trans_.write(RBRACE); + } + + private void writeJSONArrayStart() throws TException { + context_.write(); + trans_.write(LBRACKET); + pushContext(new JSONListContext()); + } + + private void writeJSONArrayEnd() throws TException { + popContext(); + trans_.write(RBRACKET); + } + + @Override + public void writeMessageBegin(TMessage message) throws TException { + writeJSONArrayStart(); + writeJSONInteger(VERSION); + try { + byte[] b = message.name.getBytes("UTF-8"); + writeJSONString(b); + } catch (UnsupportedEncodingException uex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + writeJSONInteger(message.type); + writeJSONInteger(message.seqid); + } + + @Override + public void writeMessageEnd() throws TException { + writeJSONArrayEnd(); + } + + @Override + public void writeStructBegin(TStruct struct) throws TException { + writeJSONObjectStart(); + } + + @Override + public void writeStructEnd() throws TException { + writeJSONObjectEnd(); + } + + @Override + public void writeFieldBegin(TField field) throws TException { + writeJSONInteger(field.id); + writeJSONObjectStart(); + writeJSONString(getTypeNameForTypeID(field.type)); + } + + @Override + public void writeFieldEnd() throws TException { + writeJSONObjectEnd(); + } + + @Override + public void writeFieldStop() {} + + @Override + public void writeMapBegin(TMap map) throws TException { + writeJSONArrayStart(); + writeJSONString(getTypeNameForTypeID(map.keyType)); + writeJSONString(getTypeNameForTypeID(map.valueType)); + writeJSONInteger(map.size); + writeJSONObjectStart(); + } + + @Override + public void writeMapEnd() throws TException { + writeJSONObjectEnd(); + writeJSONArrayEnd(); + } + + @Override + public void writeListBegin(TList list) throws TException { + writeJSONArrayStart(); + writeJSONString(getTypeNameForTypeID(list.elemType)); + writeJSONInteger(list.size); + } + + @Override + public void writeListEnd() throws TException { + writeJSONArrayEnd(); + } + + @Override + public void writeSetBegin(TSet set) throws TException { + writeJSONArrayStart(); + writeJSONString(getTypeNameForTypeID(set.elemType)); + writeJSONInteger(set.size); + } + + @Override + public void writeSetEnd() throws TException { + writeJSONArrayEnd(); + } + + @Override + public void writeBool(boolean b) throws TException { + writeJSONInteger(b ? (long)1 : (long)0); + } + + @Override + public void writeByte(byte b) throws TException { + writeJSONInteger((long)b); + } + + @Override + public void writeI16(short i16) throws TException { + writeJSONInteger((long)i16); + } + + @Override + public void writeI32(int i32) throws TException { + writeJSONInteger((long)i32); + } + + @Override + public void writeI64(long i64) throws TException { + writeJSONInteger(i64); + } + + @Override + public void writeDouble(double dub) throws TException { + writeJSONDouble(dub); + } + + @Override + public void writeString(String str) throws TException { + try { + byte[] b = str.getBytes("UTF-8"); + writeJSONString(b); + } catch (UnsupportedEncodingException uex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + } + + @Override + public void writeBinary(byte[] bin) throws TException { + writeJSONBase64(bin); + } + + /** + * Reading methods. + */ + + // Read in a JSON string, unescaping as appropriate.. Skip reading from the + // context if skipContext is true. + private TByteArrayOutputStream readJSONString(boolean skipContext) + throws TException { + TByteArrayOutputStream arr = new TByteArrayOutputStream(DEF_STRING_SIZE); + if (!skipContext) { + context_.read(); + } + readJSONSyntaxChar(QUOTE); + while (true) { + byte ch = reader_.read(); + if (ch == QUOTE[0]) { + break; + } + if (ch == ESCSEQ[0]) { + ch = reader_.read(); + if (ch == ESCSEQ[1]) { + readJSONSyntaxChar(ZERO); + readJSONSyntaxChar(ZERO); + trans_.readAll(tmpbuf_, 0, 2); + ch = (byte)((hexVal((byte)tmpbuf_[0]) << 4) + hexVal(tmpbuf_[1])); + } + else { + int off = ESCAPE_CHARS.indexOf(ch); + if (off == -1) { + throw new TProtocolException(TProtocolException.INVALID_DATA, + "Expected control char"); + } + ch = ESCAPE_CHAR_VALS[off]; + } + } + arr.write(ch); + } + return arr; + } + + // Return true if the given byte could be a valid part of a JSON number. + private boolean isJSONNumeric(byte b) { + switch (b) { + case '+': + case '-': + case '.': + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + case 'E': + case 'e': + return true; + } + return false; + } + + // Read in a sequence of characters that are all valid in JSON numbers. Does + // not do a complete regex check to validate that this is actually a number. + private String readJSONNumericChars() throws TException { + StringBuilder strbld = new StringBuilder(); + while (true) { + byte ch = reader_.peek(); + if (!isJSONNumeric(ch)) { + break; + } + strbld.append((char)reader_.read()); + } + return strbld.toString(); + } + + // Read in a JSON number. If the context dictates, read in enclosing quotes. + private long readJSONInteger() throws TException { + context_.read(); + if (context_.escapeNum()) { + readJSONSyntaxChar(QUOTE); + } + String str = readJSONNumericChars(); + if (context_.escapeNum()) { + readJSONSyntaxChar(QUOTE); + } + try { + return Long.valueOf(str); + } + catch (NumberFormatException ex) { + throw new TProtocolException(TProtocolException.INVALID_DATA, + "Bad data encounted in numeric data"); + } + } + + // Read in a JSON double value. Throw if the value is not wrapped in quotes + // when expected or if wrapped in quotes when not expected. + private double readJSONDouble() throws TException { + context_.read(); + if (reader_.peek() == QUOTE[0]) { + TByteArrayOutputStream arr = readJSONString(true); + try { + double dub = Double.valueOf(arr.toString("UTF-8")); + if (!context_.escapeNum() && !Double.isNaN(dub) && + !Double.isInfinite(dub)) { + // Throw exception -- we should not be in a string in this case + throw new TProtocolException(TProtocolException.INVALID_DATA, + "Numeric data unexpectedly quoted"); + } + return dub; + } + catch (UnsupportedEncodingException ex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + } + else { + if (context_.escapeNum()) { + // This will throw - we should have had a quote if escapeNum == true + readJSONSyntaxChar(QUOTE); + } + try { + return Double.valueOf(readJSONNumericChars()); + } + catch (NumberFormatException ex) { + throw new TProtocolException(TProtocolException.INVALID_DATA, + "Bad data encounted in numeric data"); + } + } + } + + // Read in a JSON string containing base-64 encoded data and decode it. + private byte[] readJSONBase64() throws TException { + TByteArrayOutputStream arr = readJSONString(false); + byte[] b = arr.get(); + int len = arr.len(); + int off = 0; + int size = 0; + while (len >= 4) { + // Decode 4 bytes at a time + TBase64Utils.decode(b, off, 4, b, size); // NB: decoded in place + off += 4; + len -= 4; + size += 3; + } + // Don't decode if we hit the end or got a single leftover byte (invalid + // base64 but legal for skip of regular string type) + if (len > 1) { + // Decode remainder + TBase64Utils.decode(b, off, len, b, size); // NB: decoded in place + size += len - 1; + } + // Sadly we must copy the byte[] (any way around this?) + byte [] result = new byte[size]; + System.arraycopy(b, 0, result, 0, size); + return result; + } + + private void readJSONObjectStart() throws TException { + context_.read(); + readJSONSyntaxChar(LBRACE); + pushContext(new JSONPairContext()); + } + + private void readJSONObjectEnd() throws TException { + readJSONSyntaxChar(RBRACE); + popContext(); + } + + private void readJSONArrayStart() throws TException { + context_.read(); + readJSONSyntaxChar(LBRACKET); + pushContext(new JSONListContext()); + } + + private void readJSONArrayEnd() throws TException { + readJSONSyntaxChar(RBRACKET); + popContext(); + } + + @Override + public TMessage readMessageBegin() throws TException { + readJSONArrayStart(); + if (readJSONInteger() != VERSION) { + throw new TProtocolException(TProtocolException.BAD_VERSION, + "Message contained bad version."); + } + String name; + try { + name = readJSONString(false).toString("UTF-8"); + } + catch (UnsupportedEncodingException ex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + byte type = (byte) readJSONInteger(); + int seqid = (int) readJSONInteger(); + return new TMessage(name, type, seqid); + } + + @Override + public void readMessageEnd() throws TException { + readJSONArrayEnd(); + } + + @Override + public TStruct readStructBegin() throws TException { + readJSONObjectStart(); + return ANONYMOUS_STRUCT; + } + + @Override + public void readStructEnd() throws TException { + readJSONObjectEnd(); + } + + @Override + public TField readFieldBegin() throws TException { + byte ch = reader_.peek(); + byte type; + short id = 0; + if (ch == RBRACE[0]) { + type = TType.STOP; + } + else { + id = (short) readJSONInteger(); + readJSONObjectStart(); + type = getTypeIDForTypeName(readJSONString(false).get()); + } + return new TField("", type, id); + } + + @Override + public void readFieldEnd() throws TException { + readJSONObjectEnd(); + } + + @Override + public TMap readMapBegin() throws TException { + readJSONArrayStart(); + byte keyType = getTypeIDForTypeName(readJSONString(false).get()); + byte valueType = getTypeIDForTypeName(readJSONString(false).get()); + int size = (int)readJSONInteger(); + readJSONObjectStart(); + return new TMap(keyType, valueType, size); + } + + @Override + public void readMapEnd() throws TException { + readJSONObjectEnd(); + readJSONArrayEnd(); + } + + @Override + public TList readListBegin() throws TException { + readJSONArrayStart(); + byte elemType = getTypeIDForTypeName(readJSONString(false).get()); + int size = (int)readJSONInteger(); + return new TList(elemType, size); + } + + @Override + public void readListEnd() throws TException { + readJSONArrayEnd(); + } + + @Override + public TSet readSetBegin() throws TException { + readJSONArrayStart(); + byte elemType = getTypeIDForTypeName(readJSONString(false).get()); + int size = (int)readJSONInteger(); + return new TSet(elemType, size); + } + + @Override + public void readSetEnd() throws TException { + readJSONArrayEnd(); + } + + @Override + public boolean readBool() throws TException { + return (readJSONInteger() == 0 ? false : true); + } + + @Override + public byte readByte() throws TException { + return (byte) readJSONInteger(); + } + + @Override + public short readI16() throws TException { + return (short) readJSONInteger(); + } + + @Override + public int readI32() throws TException { + return (int) readJSONInteger(); + } + + @Override + public long readI64() throws TException { + return (long) readJSONInteger(); + } + + @Override + public double readDouble() throws TException { + return readJSONDouble(); + } + + @Override + public String readString() throws TException { + try { + return readJSONString(false).toString("UTF-8"); + } + catch (UnsupportedEncodingException ex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + } + + @Override + public byte[] readBinary() throws TException { + return readJSONBase64(); + } + +} diff --git a/lib/java/src/org/apache/thrift/protocol/TList.java b/lib/java/src/org/apache/thrift/protocol/TList.java new file mode 100644 index 00000000..0d36e83d --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TList.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +/** + * Helper class that encapsulates list metadata. + * + */ +public final class TList { + public TList() { + this(TType.STOP, 0); + } + + public TList(byte t, int s) { + elemType = t; + size = s; + } + + public final byte elemType; + public final int size; +} diff --git a/lib/java/src/org/apache/thrift/protocol/TMap.java b/lib/java/src/org/apache/thrift/protocol/TMap.java new file mode 100644 index 00000000..20881f7a --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TMap.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +/** + * Helper class that encapsulates map metadata. + * + */ +public final class TMap { + public TMap() { + this(TType.STOP, TType.STOP, 0); + } + + public TMap(byte k, byte v, int s) { + keyType = k; + valueType = v; + size = s; + } + + public final byte keyType; + public final byte valueType; + public final int size; +} diff --git a/lib/java/src/org/apache/thrift/protocol/TMessage.java b/lib/java/src/org/apache/thrift/protocol/TMessage.java new file mode 100644 index 00000000..cd56964d --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TMessage.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +/** + * Helper class that encapsulates struct metadata. + * + */ +public final class TMessage { + public TMessage() { + this("", TType.STOP, 0); + } + + public TMessage(String n, byte t, int s) { + name = n; + type = t; + seqid = s; + } + + public final String name; + public final byte type; + public final int seqid; + + public String toString() { + return ""; + } + + public boolean equals(TMessage other) { + return name.equals(other.name) && type == other.type && seqid == other.seqid; + } +} diff --git a/lib/java/src/org/apache/thrift/protocol/TMessageType.java b/lib/java/src/org/apache/thrift/protocol/TMessageType.java new file mode 100644 index 00000000..aa3f9317 --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TMessageType.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +/** + * Message type constants in the Thrift protocol. + * + */ +public final class TMessageType { + public static final byte CALL = 1; + public static final byte REPLY = 2; + public static final byte EXCEPTION = 3; + public static final byte ONEWAY = 4; +} diff --git a/lib/java/src/org/apache/thrift/protocol/TProtocol.java b/lib/java/src/org/apache/thrift/protocol/TProtocol.java new file mode 100644 index 00000000..50d6683d --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TProtocol.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +import org.apache.thrift.TException; +import org.apache.thrift.transport.TTransport; + +/** + * Protocol interface definition. + * + */ +public abstract class TProtocol { + + /** + * Prevent direct instantiation + */ + @SuppressWarnings("unused") + private TProtocol() {} + + /** + * Transport + */ + protected TTransport trans_; + + /** + * Constructor + */ + protected TProtocol(TTransport trans) { + trans_ = trans; + } + + /** + * Transport accessor + */ + public TTransport getTransport() { + return trans_; + } + + /** + * Writing methods. + */ + + public abstract void writeMessageBegin(TMessage message) throws TException; + + public abstract void writeMessageEnd() throws TException; + + public abstract void writeStructBegin(TStruct struct) throws TException; + + public abstract void writeStructEnd() throws TException; + + public abstract void writeFieldBegin(TField field) throws TException; + + public abstract void writeFieldEnd() throws TException; + + public abstract void writeFieldStop() throws TException; + + public abstract void writeMapBegin(TMap map) throws TException; + + public abstract void writeMapEnd() throws TException; + + public abstract void writeListBegin(TList list) throws TException; + + public abstract void writeListEnd() throws TException; + + public abstract void writeSetBegin(TSet set) throws TException; + + public abstract void writeSetEnd() throws TException; + + public abstract void writeBool(boolean b) throws TException; + + public abstract void writeByte(byte b) throws TException; + + public abstract void writeI16(short i16) throws TException; + + public abstract void writeI32(int i32) throws TException; + + public abstract void writeI64(long i64) throws TException; + + public abstract void writeDouble(double dub) throws TException; + + public abstract void writeString(String str) throws TException; + + public abstract void writeBinary(byte[] bin) throws TException; + + /** + * Reading methods. + */ + + public abstract TMessage readMessageBegin() throws TException; + + public abstract void readMessageEnd() throws TException; + + public abstract TStruct readStructBegin() throws TException; + + public abstract void readStructEnd() throws TException; + + public abstract TField readFieldBegin() throws TException; + + public abstract void readFieldEnd() throws TException; + + public abstract TMap readMapBegin() throws TException; + + public abstract void readMapEnd() throws TException; + + public abstract TList readListBegin() throws TException; + + public abstract void readListEnd() throws TException; + + public abstract TSet readSetBegin() throws TException; + + public abstract void readSetEnd() throws TException; + + public abstract boolean readBool() throws TException; + + public abstract byte readByte() throws TException; + + public abstract short readI16() throws TException; + + public abstract int readI32() throws TException; + + public abstract long readI64() throws TException; + + public abstract double readDouble() throws TException; + + public abstract String readString() throws TException; + + public abstract byte[] readBinary() throws TException; + +} diff --git a/lib/java/src/org/apache/thrift/protocol/TProtocolException.java b/lib/java/src/org/apache/thrift/protocol/TProtocolException.java new file mode 100644 index 00000000..248815be --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TProtocolException.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +import org.apache.thrift.TException; + +/** + * Protocol exceptions. + * + */ +public class TProtocolException extends TException { + + + private static final long serialVersionUID = 1L; + public static final int UNKNOWN = 0; + public static final int INVALID_DATA = 1; + public static final int NEGATIVE_SIZE = 2; + public static final int SIZE_LIMIT = 3; + public static final int BAD_VERSION = 4; + public static final int NOT_IMPLEMENTED = 5; + + protected int type_ = UNKNOWN; + + public TProtocolException() { + super(); + } + + public TProtocolException(int type) { + super(); + type_ = type; + } + + public TProtocolException(int type, String message) { + super(message); + type_ = type; + } + + public TProtocolException(String message) { + super(message); + } + + public TProtocolException(int type, Throwable cause) { + super(cause); + type_ = type; + } + + public TProtocolException(Throwable cause) { + super(cause); + } + + public TProtocolException(String message, Throwable cause) { + super(message, cause); + } + + public TProtocolException(int type, String message, Throwable cause) { + super(message, cause); + type_ = type; + } + + public int getType() { + return type_; + } + +} diff --git a/lib/java/src/org/apache/thrift/protocol/TProtocolFactory.java b/lib/java/src/org/apache/thrift/protocol/TProtocolFactory.java new file mode 100644 index 00000000..afa502b7 --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TProtocolFactory.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +import org.apache.thrift.transport.TTransport; + +/** + * Factory interface for constructing protocol instances. + * + */ +public interface TProtocolFactory { + public TProtocol getProtocol(TTransport trans); +} diff --git a/lib/java/src/org/apache/thrift/protocol/TProtocolUtil.java b/lib/java/src/org/apache/thrift/protocol/TProtocolUtil.java new file mode 100644 index 00000000..9bf10f67 --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TProtocolUtil.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +import org.apache.thrift.TException; + +/** + * Utility class with static methods for interacting with protocol data + * streams. + * + */ +public class TProtocolUtil { + + /** + * The maximum recursive depth the skip() function will traverse before + * throwing a TException. + */ + private static int maxSkipDepth = Integer.MAX_VALUE; + + /** + * Specifies the maximum recursive depth that the skip function will + * traverse before throwing a TException. This is a global setting, so + * any call to skip in this JVM will enforce this value. + * + * @param depth the maximum recursive depth. A value of 2 would allow + * the skip function to skip a structure or collection with basic children, + * but it would not permit skipping a struct that had a field containing + * a child struct. A value of 1 would only allow skipping of simple + * types and empty structs/collections. + */ + public static void setMaxSkipDepth(int depth) { + maxSkipDepth = depth; + } + + /** + * Skips over the next data element from the provided input TProtocol object. + * + * @param prot the protocol object to read from + * @param type the next value will be intepreted as this TType value. + */ + public static void skip(TProtocol prot, byte type) + throws TException { + skip(prot, type, maxSkipDepth); + } + + /** + * Skips over the next data element from the provided input TProtocol object. + * + * @param prot the protocol object to read from + * @param type the next value will be intepreted as this TType value. + * @param maxDepth this function will only skip complex objects to this + * recursive depth, to prevent Java stack overflow. + */ + public static void skip(TProtocol prot, byte type, int maxDepth) + throws TException { + if (maxDepth <= 0) { + throw new TException("Maximum skip depth exceeded"); + } + switch (type) { + case TType.BOOL: + { + prot.readBool(); + break; + } + case TType.BYTE: + { + prot.readByte(); + break; + } + case TType.I16: + { + prot.readI16(); + break; + } + case TType.I32: + { + prot.readI32(); + break; + } + case TType.I64: + { + prot.readI64(); + break; + } + case TType.DOUBLE: + { + prot.readDouble(); + break; + } + case TType.STRING: + { + prot.readBinary(); + break; + } + case TType.STRUCT: + { + prot.readStructBegin(); + while (true) { + TField field = prot.readFieldBegin(); + if (field.type == TType.STOP) { + break; + } + skip(prot, field.type, maxDepth - 1); + prot.readFieldEnd(); + } + prot.readStructEnd(); + break; + } + case TType.MAP: + { + TMap map = prot.readMapBegin(); + for (int i = 0; i < map.size; i++) { + skip(prot, map.keyType, maxDepth - 1); + skip(prot, map.valueType, maxDepth - 1); + } + prot.readMapEnd(); + break; + } + case TType.SET: + { + TSet set = prot.readSetBegin(); + for (int i = 0; i < set.size; i++) { + skip(prot, set.elemType, maxDepth - 1); + } + prot.readSetEnd(); + break; + } + case TType.LIST: + { + TList list = prot.readListBegin(); + for (int i = 0; i < list.size; i++) { + skip(prot, list.elemType, maxDepth - 1); + } + prot.readListEnd(); + break; + } + default: + break; + } + } +} diff --git a/lib/java/src/org/apache/thrift/protocol/TSet.java b/lib/java/src/org/apache/thrift/protocol/TSet.java new file mode 100644 index 00000000..38be9a99 --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TSet.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +/** + * Helper class that encapsulates set metadata. + * + */ +public final class TSet { + public TSet() { + this(TType.STOP, 0); + } + + public TSet(byte t, int s) { + elemType = t; + size = s; + } + + public TSet(TList list) { + this(list.elemType, list.size); + } + + public final byte elemType; + public final int size; +} diff --git a/lib/java/src/org/apache/thrift/protocol/TSimpleJSONProtocol.java b/lib/java/src/org/apache/thrift/protocol/TSimpleJSONProtocol.java new file mode 100644 index 00000000..a60bdf40 --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TSimpleJSONProtocol.java @@ -0,0 +1,384 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +import java.io.UnsupportedEncodingException; +import java.util.Stack; + +import org.apache.thrift.TException; +import org.apache.thrift.transport.TTransport; + +/** + * JSON protocol implementation for thrift. + * + * This protocol is write-only and produces a simple output format + * suitable for parsing by scripting languages. It should not be + * confused with the full-featured TJSONProtocol. + * + */ +public class TSimpleJSONProtocol extends TProtocol { + + /** + * Factory + */ + public static class Factory implements TProtocolFactory { + public TProtocol getProtocol(TTransport trans) { + return new TSimpleJSONProtocol(trans); + } + } + + public static final byte[] COMMA = new byte[] {','}; + public static final byte[] COLON = new byte[] {':'}; + public static final byte[] LBRACE = new byte[] {'{'}; + public static final byte[] RBRACE = new byte[] {'}'}; + public static final byte[] LBRACKET = new byte[] {'['}; + public static final byte[] RBRACKET = new byte[] {']'}; + public static final char QUOTE = '"'; + + private static final TStruct ANONYMOUS_STRUCT = new TStruct(); + private static final TField ANONYMOUS_FIELD = new TField(); + private static final TMessage EMPTY_MESSAGE = new TMessage(); + private static final TSet EMPTY_SET = new TSet(); + private static final TList EMPTY_LIST = new TList(); + private static final TMap EMPTY_MAP = new TMap(); + + protected class Context { + protected void write() throws TException {} + } + + protected class ListContext extends Context { + protected boolean first_ = true; + + protected void write() throws TException { + if (first_) { + first_ = false; + } else { + trans_.write(COMMA); + } + } + } + + protected class StructContext extends Context { + protected boolean first_ = true; + protected boolean colon_ = true; + + protected void write() throws TException { + if (first_) { + first_ = false; + colon_ = true; + } else { + trans_.write(colon_ ? COLON : COMMA); + colon_ = !colon_; + } + } + } + + protected final Context BASE_CONTEXT = new Context(); + + /** + * Stack of nested contexts that we may be in. + */ + protected Stack writeContextStack_ = new Stack(); + + /** + * Current context that we are in + */ + protected Context writeContext_ = BASE_CONTEXT; + + /** + * Push a new write context onto the stack. + */ + protected void pushWriteContext(Context c) { + writeContextStack_.push(writeContext_); + writeContext_ = c; + } + + /** + * Pop the last write context off the stack + */ + protected void popWriteContext() { + writeContext_ = writeContextStack_.pop(); + } + + /** + * Constructor + */ + public TSimpleJSONProtocol(TTransport trans) { + super(trans); + } + + public void writeMessageBegin(TMessage message) throws TException { + trans_.write(LBRACKET); + pushWriteContext(new ListContext()); + writeString(message.name); + writeByte(message.type); + writeI32(message.seqid); + } + + public void writeMessageEnd() throws TException { + popWriteContext(); + trans_.write(RBRACKET); + } + + public void writeStructBegin(TStruct struct) throws TException { + writeContext_.write(); + trans_.write(LBRACE); + pushWriteContext(new StructContext()); + } + + public void writeStructEnd() throws TException { + popWriteContext(); + trans_.write(RBRACE); + } + + public void writeFieldBegin(TField field) throws TException { + // Note that extra type information is omitted in JSON! + writeString(field.name); + } + + public void writeFieldEnd() {} + + public void writeFieldStop() {} + + public void writeMapBegin(TMap map) throws TException { + writeContext_.write(); + trans_.write(LBRACE); + pushWriteContext(new StructContext()); + // No metadata! + } + + public void writeMapEnd() throws TException { + popWriteContext(); + trans_.write(RBRACE); + } + + public void writeListBegin(TList list) throws TException { + writeContext_.write(); + trans_.write(LBRACKET); + pushWriteContext(new ListContext()); + // No metadata! + } + + public void writeListEnd() throws TException { + popWriteContext(); + trans_.write(RBRACKET); + } + + public void writeSetBegin(TSet set) throws TException { + writeContext_.write(); + trans_.write(LBRACKET); + pushWriteContext(new ListContext()); + // No metadata! + } + + public void writeSetEnd() throws TException { + popWriteContext(); + trans_.write(RBRACKET); + } + + public void writeBool(boolean b) throws TException { + writeByte(b ? (byte)1 : (byte)0); + } + + public void writeByte(byte b) throws TException { + writeI32(b); + } + + public void writeI16(short i16) throws TException { + writeI32(i16); + } + + public void writeI32(int i32) throws TException { + writeContext_.write(); + _writeStringData(Integer.toString(i32)); + } + + public void _writeStringData(String s) throws TException { + try { + byte[] b = s.getBytes("UTF-8"); + trans_.write(b); + } catch (UnsupportedEncodingException uex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + } + + public void writeI64(long i64) throws TException { + writeContext_.write(); + _writeStringData(Long.toString(i64)); + } + + public void writeDouble(double dub) throws TException { + writeContext_.write(); + _writeStringData(Double.toString(dub)); + } + + public void writeString(String str) throws TException { + writeContext_.write(); + int length = str.length(); + StringBuffer escape = new StringBuffer(length + 16); + escape.append(QUOTE); + for (int i = 0; i < length; ++i) { + char c = str.charAt(i); + switch (c) { + case '"': + case '\\': + escape.append('\\'); + escape.append(c); + break; + case '\b': + escape.append('\\'); + escape.append('b'); + break; + case '\f': + escape.append('\\'); + escape.append('f'); + break; + case '\n': + escape.append('\\'); + escape.append('n'); + break; + case '\r': + escape.append('\\'); + escape.append('r'); + break; + case '\t': + escape.append('\\'); + escape.append('t'); + break; + default: + // Control characeters! According to JSON RFC u0020 (space) + if (c < ' ') { + String hex = Integer.toHexString(c); + escape.append('\\'); + escape.append('u'); + for (int j = 4; j > hex.length(); --j) { + escape.append('0'); + } + escape.append(hex); + } else { + escape.append(c); + } + break; + } + } + escape.append(QUOTE); + _writeStringData(escape.toString()); + } + + public void writeBinary(byte[] bin) throws TException { + try { + // TODO(mcslee): Fix this + writeString(new String(bin, "UTF-8")); + } catch (UnsupportedEncodingException uex) { + throw new TException("JVM DOES NOT SUPPORT UTF-8"); + } + } + + /** + * Reading methods. + */ + + public TMessage readMessageBegin() throws TException { + // TODO(mcslee): implement + return EMPTY_MESSAGE; + } + + public void readMessageEnd() {} + + public TStruct readStructBegin() { + // TODO(mcslee): implement + return ANONYMOUS_STRUCT; + } + + public void readStructEnd() {} + + public TField readFieldBegin() throws TException { + // TODO(mcslee): implement + return ANONYMOUS_FIELD; + } + + public void readFieldEnd() {} + + public TMap readMapBegin() throws TException { + // TODO(mcslee): implement + return EMPTY_MAP; + } + + public void readMapEnd() {} + + public TList readListBegin() throws TException { + // TODO(mcslee): implement + return EMPTY_LIST; + } + + public void readListEnd() {} + + public TSet readSetBegin() throws TException { + // TODO(mcslee): implement + return EMPTY_SET; + } + + public void readSetEnd() {} + + public boolean readBool() throws TException { + return (readByte() == 1); + } + + public byte readByte() throws TException { + // TODO(mcslee): implement + return 0; + } + + public short readI16() throws TException { + // TODO(mcslee): implement + return 0; + } + + public int readI32() throws TException { + // TODO(mcslee): implement + return 0; + } + + public long readI64() throws TException { + // TODO(mcslee): implement + return 0; + } + + public double readDouble() throws TException { + // TODO(mcslee): implement + return 0; + } + + public String readString() throws TException { + // TODO(mcslee): implement + return ""; + } + + public String readStringBody(int size) throws TException { + // TODO(mcslee): implement + return ""; + } + + public byte[] readBinary() throws TException { + // TODO(mcslee): implement + return new byte[0]; + } + +} diff --git a/lib/java/src/org/apache/thrift/protocol/TStruct.java b/lib/java/src/org/apache/thrift/protocol/TStruct.java new file mode 100644 index 00000000..a0f79012 --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TStruct.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +/** + * Helper class that encapsulates struct metadata. + * + */ +public final class TStruct { + public TStruct() { + this(""); + } + + public TStruct(String n) { + name = n; + } + + public final String name; +} diff --git a/lib/java/src/org/apache/thrift/protocol/TType.java b/lib/java/src/org/apache/thrift/protocol/TType.java new file mode 100644 index 00000000..dbdc3caa --- /dev/null +++ b/lib/java/src/org/apache/thrift/protocol/TType.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.protocol; + +/** + * Type constants in the Thrift protocol. + * + */ +public final class TType { + public static final byte STOP = 0; + public static final byte VOID = 1; + public static final byte BOOL = 2; + public static final byte BYTE = 3; + public static final byte DOUBLE = 4; + public static final byte I16 = 6; + public static final byte I32 = 8; + public static final byte I64 = 10; + public static final byte STRING = 11; + public static final byte STRUCT = 12; + public static final byte MAP = 13; + public static final byte SET = 14; + public static final byte LIST = 15; +} diff --git a/lib/java/src/org/apache/thrift/server/THsHaServer.java b/lib/java/src/org/apache/thrift/server/THsHaServer.java new file mode 100644 index 00000000..8bf096ed --- /dev/null +++ b/lib/java/src/org/apache/thrift/server/THsHaServer.java @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.server; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +import org.apache.thrift.TProcessor; +import org.apache.thrift.TProcessorFactory; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TFramedTransport; +import org.apache.thrift.transport.TNonblockingServerTransport; + +/** + * An extension of the TNonblockingServer to a Half-Sync/Half-Async server. + * Like TNonblockingServer, it relies on the use of TFramedTransport. + */ +public class THsHaServer extends TNonblockingServer { + + // This wraps all the functionality of queueing and thread pool management + // for the passing of Invocations from the Selector to workers. + private ExecutorService invoker; + + protected final int MIN_WORKER_THREADS; + protected final int MAX_WORKER_THREADS; + protected final int STOP_TIMEOUT_VAL; + protected final TimeUnit STOP_TIMEOUT_UNIT; + + /** + * Create server with given processor, and server transport. Default server + * options, TBinaryProtocol for the protocol, and TFramedTransport.Factory on + * both input and output transports. A TProcessorFactory will be created that + * always returns the specified processor. + */ + public THsHaServer( TProcessor processor, + TNonblockingServerTransport serverTransport) { + this(processor, serverTransport, new Options()); + } + + /** + * Create server with given processor, server transport, and server options + * using TBinaryProtocol for the protocol, and TFramedTransport.Factory on + * both input and output transports. A TProcessorFactory will be created that + * always returns the specified processor. + */ + public THsHaServer( TProcessor processor, + TNonblockingServerTransport serverTransport, + Options options) { + this(new TProcessorFactory(processor), serverTransport, options); + } + + /** + * Create server with specified processor factory and server transport. Uses + * default options. TBinaryProtocol is assumed. TFramedTransport.Factory is + * used on both input and output transports. + */ + public THsHaServer( TProcessorFactory processorFactory, + TNonblockingServerTransport serverTransport) { + this(processorFactory, serverTransport, new Options()); + } + + /** + * Create server with specified processor factory, server transport, and server + * options. TBinaryProtocol is assumed. TFramedTransport.Factory is used on + * both input and output transports. + */ + public THsHaServer( TProcessorFactory processorFactory, + TNonblockingServerTransport serverTransport, + Options options) { + this(processorFactory, serverTransport, new TFramedTransport.Factory(), + new TBinaryProtocol.Factory(), options); + } + + /** + * Server with specified processor, server transport, and in/out protocol + * factory. Defaults will be used for in/out transport factory and server + * options. + */ + public THsHaServer( TProcessor processor, + TNonblockingServerTransport serverTransport, + TProtocolFactory protocolFactory) { + this(processor, serverTransport, protocolFactory, new Options()); + } + + /** + * Server with specified processor, server transport, and in/out protocol + * factory. Defaults will be used for in/out transport factory and server + * options. + */ + public THsHaServer( TProcessor processor, + TNonblockingServerTransport serverTransport, + TProtocolFactory protocolFactory, + Options options) { + this(processor, serverTransport, new TFramedTransport.Factory(), + protocolFactory); + } + + /** + * Create server with specified processor, server transport, in/out + * transport factory, in/out protocol factory, and default server options. A + * processor factory will be created that always returns the specified + * processor. + */ + public THsHaServer( TProcessor processor, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory transportFactory, + TProtocolFactory protocolFactory) { + this(new TProcessorFactory(processor), serverTransport, + transportFactory, protocolFactory); + } + + /** + * Create server with specified processor factory, server transport, in/out + * transport factory, in/out protocol factory, and default server options. + */ + public THsHaServer( TProcessorFactory processorFactory, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory transportFactory, + TProtocolFactory protocolFactory) { + this(processorFactory, serverTransport, + transportFactory, transportFactory, + protocolFactory, protocolFactory, new Options()); + } + + /** + * Create server with specified processor factory, server transport, in/out + * transport factory, in/out protocol factory, and server options. + */ + public THsHaServer( TProcessorFactory processorFactory, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory transportFactory, + TProtocolFactory protocolFactory, + Options options) { + this(processorFactory, serverTransport, + transportFactory, transportFactory, + protocolFactory, protocolFactory, + options); + } + + /** + * Create server with everything specified, except use default server options. + */ + public THsHaServer( TProcessor processor, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory inputTransportFactory, + TFramedTransport.Factory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory) { + this(new TProcessorFactory(processor), serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory); + } + + /** + * Create server with everything specified, except use default server options. + */ + public THsHaServer( TProcessorFactory processorFactory, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory inputTransportFactory, + TFramedTransport.Factory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory) + { + this(processorFactory, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory, new Options()); + } + + /** + * Create server with every option fully specified. + */ + public THsHaServer( TProcessorFactory processorFactory, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory inputTransportFactory, + TFramedTransport.Factory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory, + Options options) + { + super(processorFactory, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory, + options); + + MIN_WORKER_THREADS = options.minWorkerThreads; + MAX_WORKER_THREADS = options.maxWorkerThreads; + STOP_TIMEOUT_VAL = options.stopTimeoutVal; + STOP_TIMEOUT_UNIT = options.stopTimeoutUnit; + } + + /** @inheritDoc */ + @Override + public void serve() { + if (!startInvokerPool()) { + return; + } + + // start listening, or exit + if (!startListening()) { + return; + } + + // start the selector, or exit + if (!startSelectorThread()) { + return; + } + + // this will block while we serve + joinSelector(); + + gracefullyShutdownInvokerPool(); + + // do a little cleanup + stopListening(); + + // ungracefully shut down the invoker pool? + } + + protected boolean startInvokerPool() { + // start the invoker pool + LinkedBlockingQueue queue = new LinkedBlockingQueue(); + invoker = new ThreadPoolExecutor(MIN_WORKER_THREADS, MAX_WORKER_THREADS, + STOP_TIMEOUT_VAL, STOP_TIMEOUT_UNIT, queue); + + return true; + } + + protected void gracefullyShutdownInvokerPool() { + // try to gracefully shut down the executor service + invoker.shutdown(); + + // Loop until awaitTermination finally does return without a interrupted + // exception. If we don't do this, then we'll shut down prematurely. We want + // to let the executorService clear it's task queue, closing client sockets + // appropriately. + long timeoutMS = 10000; + long now = System.currentTimeMillis(); + while (timeoutMS >= 0) { + try { + invoker.awaitTermination(timeoutMS, TimeUnit.MILLISECONDS); + break; + } catch (InterruptedException ix) { + long newnow = System.currentTimeMillis(); + timeoutMS -= (newnow - now); + now = newnow; + } + } + } + + /** + * We override the standard invoke method here to queue the invocation for + * invoker service instead of immediately invoking. The thread pool takes care of the rest. + */ + @Override + protected void requestInvoke(FrameBuffer frameBuffer) { + invoker.execute(new Invocation(frameBuffer)); + } + + /** + * An Invocation represents a method call that is prepared to execute, given + * an idle worker thread. It contains the input and output protocols the + * thread's processor should use to perform the usual Thrift invocation. + */ + private class Invocation implements Runnable { + + private final FrameBuffer frameBuffer; + + public Invocation(final FrameBuffer frameBuffer) { + this.frameBuffer = frameBuffer; + } + + public void run() { + frameBuffer.invoke(); + } + } + + public static class Options extends TNonblockingServer.Options { + public int minWorkerThreads = 5; + public int maxWorkerThreads = Integer.MAX_VALUE; + public int stopTimeoutVal = 60; + public TimeUnit stopTimeoutUnit = TimeUnit.SECONDS; + } +} diff --git a/lib/java/src/org/apache/thrift/server/TNonblockingServer.java b/lib/java/src/org/apache/thrift/server/TNonblockingServer.java new file mode 100644 index 00000000..95d81e22 --- /dev/null +++ b/lib/java/src/org/apache/thrift/server/TNonblockingServer.java @@ -0,0 +1,769 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.server; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.spi.SelectorProvider; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Set; + +import org.apache.log4j.Logger; + +import org.apache.thrift.TByteArrayOutputStream; +import org.apache.thrift.TException; +import org.apache.thrift.TProcessor; +import org.apache.thrift.TProcessorFactory; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TFramedTransport; +import org.apache.thrift.transport.TIOStreamTransport; +import org.apache.thrift.transport.TNonblockingServerTransport; +import org.apache.thrift.transport.TNonblockingTransport; +import org.apache.thrift.transport.TTransport; +import org.apache.thrift.transport.TTransportException; + +/** + * A nonblocking TServer implementation. This allows for fairness amongst all + * connected clients in terms of invocations. + * + * This server is inherently single-threaded. If you want a limited thread pool + * coupled with invocation-fairness, see THsHaServer. + * + * To use this server, you MUST use a TFramedTransport at the outermost + * transport, otherwise this server will be unable to determine when a whole + * method call has been read off the wire. Clients must also use TFramedTransport. + */ +public class TNonblockingServer extends TServer { + private static final Logger LOGGER = + Logger.getLogger(TNonblockingServer.class.getName()); + + // Flag for stopping the server + private volatile boolean stopped_; + + private SelectThread selectThread_; + + /** + * The maximum amount of memory we will allocate to client IO buffers at a + * time. Without this limit, the server will gladly allocate client buffers + * right into an out of memory exception, rather than waiting. + */ + private final long MAX_READ_BUFFER_BYTES; + + protected final Options options_; + + /** + * How many bytes are currently allocated to read buffers. + */ + private long readBufferBytesAllocated = 0; + + /** + * Create server with given processor and server transport, using + * TBinaryProtocol for the protocol, TFramedTransport.Factory on both input + * and output transports. A TProcessorFactory will be created that always + * returns the specified processor. + */ + public TNonblockingServer(TProcessor processor, + TNonblockingServerTransport serverTransport) { + this(new TProcessorFactory(processor), serverTransport); + } + + /** + * Create server with specified processor factory and server transport. + * TBinaryProtocol is assumed. TFramedTransport.Factory is used on both input + * and output transports. + */ + public TNonblockingServer(TProcessorFactory processorFactory, + TNonblockingServerTransport serverTransport) { + this(processorFactory, serverTransport, + new TFramedTransport.Factory(), new TFramedTransport.Factory(), + new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory()); + } + + public TNonblockingServer(TProcessor processor, + TNonblockingServerTransport serverTransport, + TProtocolFactory protocolFactory) { + this(processor, serverTransport, + new TFramedTransport.Factory(), new TFramedTransport.Factory(), + protocolFactory, protocolFactory); + } + + public TNonblockingServer(TProcessor processor, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory transportFactory, + TProtocolFactory protocolFactory) { + this(processor, serverTransport, + transportFactory, transportFactory, + protocolFactory, protocolFactory); + } + + public TNonblockingServer(TProcessorFactory processorFactory, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory transportFactory, + TProtocolFactory protocolFactory) { + this(processorFactory, serverTransport, + transportFactory, transportFactory, + protocolFactory, protocolFactory); + } + + public TNonblockingServer(TProcessor processor, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory inputTransportFactory, + TFramedTransport.Factory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory) { + this(new TProcessorFactory(processor), serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory); + } + + public TNonblockingServer(TProcessorFactory processorFactory, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory inputTransportFactory, + TFramedTransport.Factory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory) { + this(processorFactory, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory, + new Options()); + } + + public TNonblockingServer(TProcessorFactory processorFactory, + TNonblockingServerTransport serverTransport, + TFramedTransport.Factory inputTransportFactory, + TFramedTransport.Factory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory, + Options options) { + super(processorFactory, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory); + options_ = options; + options_.validate(); + MAX_READ_BUFFER_BYTES = options.maxReadBufferBytes; + } + + /** + * Begin accepting connections and processing invocations. + */ + public void serve() { + // start listening, or exit + if (!startListening()) { + return; + } + + // start the selector, or exit + if (!startSelectorThread()) { + return; + } + + // this will block while we serve + joinSelector(); + + // do a little cleanup + stopListening(); + } + + /** + * Have the server transport start accepting connections. + * + * @return true if we started listening successfully, false if something went + * wrong. + */ + protected boolean startListening() { + try { + serverTransport_.listen(); + return true; + } catch (TTransportException ttx) { + LOGGER.error("Failed to start listening on server socket!", ttx); + return false; + } + } + + /** + * Stop listening for conections. + */ + protected void stopListening() { + serverTransport_.close(); + } + + /** + * Start the selector thread running to deal with clients. + * + * @return true if everything went ok, false if we couldn't start for some + * reason. + */ + protected boolean startSelectorThread() { + // start the selector + try { + selectThread_ = new SelectThread((TNonblockingServerTransport)serverTransport_); + selectThread_.start(); + return true; + } catch (IOException e) { + LOGGER.error("Failed to start selector thread!", e); + return false; + } + } + + /** + * Block until the selector exits. + */ + protected void joinSelector() { + // wait until the selector thread exits + try { + selectThread_.join(); + } catch (InterruptedException e) { + // for now, just silently ignore. technically this means we'll have less of + // a graceful shutdown as a result. + } + } + + /** + * Stop serving and shut everything down. + */ + public void stop() { + stopped_ = true; + selectThread_.wakeupSelector(); + } + + /** + * Perform an invocation. This method could behave several different ways + * - invoke immediately inline, queue for separate execution, etc. + */ + protected void requestInvoke(FrameBuffer frameBuffer) { + frameBuffer.invoke(); + } + + /** + * A FrameBuffer wants to change its selection preferences, but might not be + * in the select thread. + */ + protected void requestSelectInterestChange(FrameBuffer frameBuffer) { + selectThread_.requestSelectInterestChange(frameBuffer); + } + + /** + * The thread that will be doing all the selecting, managing new connections + * and those that still need to be read. + */ + protected class SelectThread extends Thread { + + private final TNonblockingServerTransport serverTransport; + private final Selector selector; + + // List of FrameBuffers that want to change their selection interests. + private final Set selectInterestChanges = + new HashSet(); + + /** + * Set up the SelectorThread. + */ + public SelectThread(final TNonblockingServerTransport serverTransport) + throws IOException { + this.serverTransport = serverTransport; + this.selector = SelectorProvider.provider().openSelector(); + serverTransport.registerSelector(selector); + } + + /** + * The work loop. Handles both selecting (all IO operations) and managing + * the selection preferences of all existing connections. + */ + public void run() { + while (!stopped_) { + select(); + processInterestChanges(); + } + } + + /** + * If the selector is blocked, wake it up. + */ + public void wakeupSelector() { + selector.wakeup(); + } + + /** + * Add FrameBuffer to the list of select interest changes and wake up the + * selector if it's blocked. When the select() call exits, it'll give the + * FrameBuffer a chance to change its interests. + */ + public void requestSelectInterestChange(FrameBuffer frameBuffer) { + synchronized (selectInterestChanges) { + selectInterestChanges.add(frameBuffer); + } + // wakeup the selector, if it's currently blocked. + selector.wakeup(); + } + + /** + * Select and process IO events appropriately: + * If there are connections to be accepted, accept them. + * If there are existing connections with data waiting to be read, read it, + * bufferring until a whole frame has been read. + * If there are any pending responses, buffer them until their target client + * is available, and then send the data. + */ + private void select() { + try { + // wait for io events. + selector.select(); + + // process the io events we received + Iterator selectedKeys = selector.selectedKeys().iterator(); + while (!stopped_ && selectedKeys.hasNext()) { + SelectionKey key = selectedKeys.next(); + selectedKeys.remove(); + + // skip if not valid + if (!key.isValid()) { + cleanupSelectionkey(key); + continue; + } + + // if the key is marked Accept, then it has to be the server + // transport. + if (key.isAcceptable()) { + handleAccept(); + } else if (key.isReadable()) { + // deal with reads + handleRead(key); + } else if (key.isWritable()) { + // deal with writes + handleWrite(key); + } else { + LOGGER.warn("Unexpected state in select! " + key.interestOps()); + } + } + } catch (IOException e) { + LOGGER.warn("Got an IOException while selecting!", e); + } + } + + /** + * Check to see if there are any FrameBuffers that have switched their + * interest type from read to write or vice versa. + */ + private void processInterestChanges() { + synchronized (selectInterestChanges) { + for (FrameBuffer fb : selectInterestChanges) { + fb.changeSelectInterests(); + } + selectInterestChanges.clear(); + } + } + + /** + * Accept a new connection. + */ + private void handleAccept() throws IOException { + SelectionKey clientKey = null; + TNonblockingTransport client = null; + try { + // accept the connection + client = (TNonblockingTransport)serverTransport.accept(); + clientKey = client.registerSelector(selector, SelectionKey.OP_READ); + + // add this key to the map + FrameBuffer frameBuffer = new FrameBuffer(client, clientKey); + clientKey.attach(frameBuffer); + } catch (TTransportException tte) { + // something went wrong accepting. + LOGGER.warn("Exception trying to accept!", tte); + tte.printStackTrace(); + if (clientKey != null) cleanupSelectionkey(clientKey); + if (client != null) client.close(); + } + } + + /** + * Do the work required to read from a readable client. If the frame is + * fully read, then invoke the method call. + */ + private void handleRead(SelectionKey key) { + FrameBuffer buffer = (FrameBuffer)key.attachment(); + if (buffer.read()) { + // if the buffer's frame read is complete, invoke the method. + if (buffer.isFrameFullyRead()) { + requestInvoke(buffer); + } + } else { + cleanupSelectionkey(key); + } + } + + /** + * Let a writable client get written, if there's data to be written. + */ + private void handleWrite(SelectionKey key) { + FrameBuffer buffer = (FrameBuffer)key.attachment(); + if (!buffer.write()) { + cleanupSelectionkey(key); + } + } + + /** + * Do connection-close cleanup on a given SelectionKey. + */ + private void cleanupSelectionkey(SelectionKey key) { + // remove the records from the two maps + FrameBuffer buffer = (FrameBuffer)key.attachment(); + if (buffer != null) { + // close the buffer + buffer.close(); + } + // cancel the selection key + key.cancel(); + } + } // SelectorThread + + /** + * Class that implements a sort of state machine around the interaction with + * a client and an invoker. It manages reading the frame size and frame data, + * getting it handed off as wrapped transports, and then the writing of + * reponse data back to the client. In the process it manages flipping the + * read and write bits on the selection key for its client. + */ + protected class FrameBuffer { + // + // Possible states for the FrameBuffer state machine. + // + // in the midst of reading the frame size off the wire + private static final int READING_FRAME_SIZE = 1; + // reading the actual frame data now, but not all the way done yet + private static final int READING_FRAME = 2; + // completely read the frame, so an invocation can now happen + private static final int READ_FRAME_COMPLETE = 3; + // waiting to get switched to listening for write events + private static final int AWAITING_REGISTER_WRITE = 4; + // started writing response data, not fully complete yet + private static final int WRITING = 6; + // another thread wants this framebuffer to go back to reading + private static final int AWAITING_REGISTER_READ = 7; + // we want our transport and selection key invalidated in the selector thread + private static final int AWAITING_CLOSE = 8; + + // + // Instance variables + // + + // the actual transport hooked up to the client. + private final TNonblockingTransport trans_; + + // the SelectionKey that corresponds to our transport + private final SelectionKey selectionKey_; + + // where in the process of reading/writing are we? + private int state_ = READING_FRAME_SIZE; + + // the ByteBuffer we'll be using to write and read, depending on the state + private ByteBuffer buffer_; + + private TByteArrayOutputStream response_; + + public FrameBuffer( final TNonblockingTransport trans, + final SelectionKey selectionKey) { + trans_ = trans; + selectionKey_ = selectionKey; + buffer_ = ByteBuffer.allocate(4); + } + + /** + * Give this FrameBuffer a chance to read. The selector loop should have + * received a read event for this FrameBuffer. + * + * @return true if the connection should live on, false if it should be + * closed + */ + public boolean read() { + if (state_ == READING_FRAME_SIZE) { + // try to read the frame size completely + if (!internalRead()) { + return false; + } + + // if the frame size has been read completely, then prepare to read the + // actual frame. + if (buffer_.remaining() == 0) { + // pull out the frame size as an integer. + int frameSize = buffer_.getInt(0); + if (frameSize <= 0) { + LOGGER.error("Read an invalid frame size of " + frameSize + + ". Are you using TFramedTransport on the client side?"); + return false; + } + + // if this frame will always be too large for this server, log the + // error and close the connection. + if (frameSize + 4 > MAX_READ_BUFFER_BYTES) { + LOGGER.error("Read a frame size of " + frameSize + + ", which is bigger than the maximum allowable buffer size for ALL connections."); + return false; + } + + // if this frame will push us over the memory limit, then return. + // with luck, more memory will free up the next time around. + if (readBufferBytesAllocated + frameSize + 4 > MAX_READ_BUFFER_BYTES) { + return true; + } + + // incremement the amount of memory allocated to read buffers + readBufferBytesAllocated += frameSize + 4; + + // reallocate the readbuffer as a frame-sized buffer + buffer_ = ByteBuffer.allocate(frameSize + 4); + // put the frame size at the head of the buffer + buffer_.putInt(frameSize); + + state_ = READING_FRAME; + } else { + // this skips the check of READING_FRAME state below, since we can't + // possibly go on to that state if there's data left to be read at + // this one. + return true; + } + } + + // it is possible to fall through from the READING_FRAME_SIZE section + // to READING_FRAME if there's already some frame data available once + // READING_FRAME_SIZE is complete. + + if (state_ == READING_FRAME) { + if (!internalRead()) { + return false; + } + + // since we're already in the select loop here for sure, we can just + // modify our selection key directly. + if (buffer_.remaining() == 0) { + // get rid of the read select interests + selectionKey_.interestOps(0); + state_ = READ_FRAME_COMPLETE; + } + + return true; + } + + // if we fall through to this point, then the state must be invalid. + LOGGER.error("Read was called but state is invalid (" + state_ + ")"); + return false; + } + + /** + * Give this FrameBuffer a chance to write its output to the final client. + */ + public boolean write() { + if (state_ == WRITING) { + try { + if (trans_.write(buffer_) < 0) { + return false; + } + } catch (IOException e) { + LOGGER.warn("Got an IOException during write!", e); + return false; + } + + // we're done writing. now we need to switch back to reading. + if (buffer_.remaining() == 0) { + prepareRead(); + } + return true; + } + + LOGGER.error("Write was called, but state is invalid (" + state_ + ")"); + return false; + } + + /** + * Give this FrameBuffer a chance to set its interest to write, once data + * has come in. + */ + public void changeSelectInterests() { + if (state_ == AWAITING_REGISTER_WRITE) { + // set the OP_WRITE interest + selectionKey_.interestOps(SelectionKey.OP_WRITE); + state_ = WRITING; + } else if (state_ == AWAITING_REGISTER_READ) { + prepareRead(); + } else if (state_ == AWAITING_CLOSE){ + close(); + selectionKey_.cancel(); + } else { + LOGGER.error( + "changeSelectInterest was called, but state is invalid (" + + state_ + ")"); + } + } + + /** + * Shut the connection down. + */ + public void close() { + // if we're being closed due to an error, we might have allocated a + // buffer that we need to subtract for our memory accounting. + if (state_ == READING_FRAME || state_ == READ_FRAME_COMPLETE) { + readBufferBytesAllocated -= buffer_.array().length; + } + trans_.close(); + } + + /** + * Check if this FrameBuffer has a full frame read. + */ + public boolean isFrameFullyRead() { + return state_ == READ_FRAME_COMPLETE; + } + + /** + * After the processor has processed the invocation, whatever thread is + * managing invocations should call this method on this FrameBuffer so we + * know it's time to start trying to write again. Also, if it turns out + * that there actually isn't any data in the response buffer, we'll skip + * trying to write and instead go back to reading. + */ + public void responseReady() { + // the read buffer is definitely no longer in use, so we will decrement + // our read buffer count. we do this here as well as in close because + // we'd like to free this read memory up as quickly as possible for other + // clients. + readBufferBytesAllocated -= buffer_.array().length; + + if (response_.len() == 0) { + // go straight to reading again. this was probably an oneway method + state_ = AWAITING_REGISTER_READ; + buffer_ = null; + } else { + buffer_ = ByteBuffer.wrap(response_.get(), 0, response_.len()); + + // set state that we're waiting to be switched to write. we do this + // asynchronously through requestSelectInterestChange() because there is a + // possibility that we're not in the main thread, and thus currently + // blocked in select(). (this functionality is in place for the sake of + // the HsHa server.) + state_ = AWAITING_REGISTER_WRITE; + } + requestSelectInterestChange(); + } + + /** + * Actually invoke the method signified by this FrameBuffer. + */ + public void invoke() { + TTransport inTrans = getInputTransport(); + TProtocol inProt = inputProtocolFactory_.getProtocol(inTrans); + TProtocol outProt = outputProtocolFactory_.getProtocol(getOutputTransport()); + + try { + processorFactory_.getProcessor(inTrans).process(inProt, outProt); + responseReady(); + return; + } catch (TException te) { + LOGGER.warn("Exception while invoking!", te); + } catch (Exception e) { + LOGGER.error("Unexpected exception while invoking!", e); + } + // This will only be reached when there is an exception. + state_ = AWAITING_CLOSE; + requestSelectInterestChange(); + } + + /** + * Wrap the read buffer in a memory-based transport so a processor can read + * the data it needs to handle an invocation. + */ + private TTransport getInputTransport() { + return inputTransportFactory_.getTransport(new TIOStreamTransport( + new ByteArrayInputStream(buffer_.array()))); + } + + /** + * Get the transport that should be used by the invoker for responding. + */ + private TTransport getOutputTransport() { + response_ = new TByteArrayOutputStream(); + return outputTransportFactory_.getTransport(new TIOStreamTransport(response_)); + } + + /** + * Perform a read into buffer. + * + * @return true if the read succeeded, false if there was an error or the + * connection closed. + */ + private boolean internalRead() { + try { + if (trans_.read(buffer_) < 0) { + return false; + } + return true; + } catch (IOException e) { + LOGGER.warn("Got an IOException in internalRead!", e); + return false; + } + } + + /** + * We're done writing, so reset our interest ops and change state accordingly. + */ + private void prepareRead() { + // we can set our interest directly without using the queue because + // we're in the select thread. + selectionKey_.interestOps(SelectionKey.OP_READ); + // get ready for another go-around + buffer_ = ByteBuffer.allocate(4); + state_ = READING_FRAME_SIZE; + } + + /** + * When this FrameBuffer needs to change it's select interests and execution + * might not be in the select thread, then this method will make sure the + * interest change gets done when the select thread wakes back up. When the + * current thread is the select thread, then it just does the interest change + * immediately. + */ + private void requestSelectInterestChange() { + if (Thread.currentThread() == selectThread_) { + changeSelectInterests(); + } else { + TNonblockingServer.this.requestSelectInterestChange(this); + } + } + } // FrameBuffer + + + public static class Options { + public long maxReadBufferBytes = Long.MAX_VALUE; + + public Options() {} + + public void validate() { + if (maxReadBufferBytes <= 1024) { + throw new IllegalArgumentException("You must allocate at least 1KB to the read buffer."); + } + } + } +} diff --git a/lib/java/src/org/apache/thrift/server/TServer.java b/lib/java/src/org/apache/thrift/server/TServer.java new file mode 100644 index 00000000..eafe0c17 --- /dev/null +++ b/lib/java/src/org/apache/thrift/server/TServer.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.server; + +import org.apache.thrift.TProcessorFactory; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TServerTransport; +import org.apache.thrift.transport.TTransportFactory; + +/** + * Generic interface for a Thrift server. + * + */ +public abstract class TServer { + + /** + * Core processor + */ + protected TProcessorFactory processorFactory_; + + /** + * Server transport + */ + protected TServerTransport serverTransport_; + + /** + * Input Transport Factory + */ + protected TTransportFactory inputTransportFactory_; + + /** + * Output Transport Factory + */ + protected TTransportFactory outputTransportFactory_; + + /** + * Input Protocol Factory + */ + protected TProtocolFactory inputProtocolFactory_; + + /** + * Output Protocol Factory + */ + protected TProtocolFactory outputProtocolFactory_; + + /** + * Default constructors. + */ + + protected TServer(TProcessorFactory processorFactory, + TServerTransport serverTransport) { + this(processorFactory, + serverTransport, + new TTransportFactory(), + new TTransportFactory(), + new TBinaryProtocol.Factory(), + new TBinaryProtocol.Factory()); + } + + protected TServer(TProcessorFactory processorFactory, + TServerTransport serverTransport, + TTransportFactory transportFactory) { + this(processorFactory, + serverTransport, + transportFactory, + transportFactory, + new TBinaryProtocol.Factory(), + new TBinaryProtocol.Factory()); + } + + protected TServer(TProcessorFactory processorFactory, + TServerTransport serverTransport, + TTransportFactory transportFactory, + TProtocolFactory protocolFactory) { + this(processorFactory, + serverTransport, + transportFactory, + transportFactory, + protocolFactory, + protocolFactory); + } + + protected TServer(TProcessorFactory processorFactory, + TServerTransport serverTransport, + TTransportFactory inputTransportFactory, + TTransportFactory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory) { + processorFactory_ = processorFactory; + serverTransport_ = serverTransport; + inputTransportFactory_ = inputTransportFactory; + outputTransportFactory_ = outputTransportFactory; + inputProtocolFactory_ = inputProtocolFactory; + outputProtocolFactory_ = outputProtocolFactory; + } + + /** + * The run method fires up the server and gets things going. + */ + public abstract void serve(); + + /** + * Stop the server. This is optional on a per-implementation basis. Not + * all servers are required to be cleanly stoppable. + */ + public void stop() {} + +} diff --git a/lib/java/src/org/apache/thrift/server/TSimpleServer.java b/lib/java/src/org/apache/thrift/server/TSimpleServer.java new file mode 100644 index 00000000..b3ee5ad6 --- /dev/null +++ b/lib/java/src/org/apache/thrift/server/TSimpleServer.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.server; + +import org.apache.thrift.TException; +import org.apache.thrift.TProcessor; +import org.apache.thrift.TProcessorFactory; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TServerTransport; +import org.apache.thrift.transport.TTransport; +import org.apache.thrift.transport.TTransportFactory; +import org.apache.thrift.transport.TTransportException; +import org.apache.log4j.Logger; + +/** + * Simple singlethreaded server for testing. + * + */ +public class TSimpleServer extends TServer { + + private static final Logger LOGGER = Logger.getLogger(TSimpleServer.class.getName()); + + private boolean stopped_ = false; + + public TSimpleServer(TProcessor processor, + TServerTransport serverTransport) { + super(new TProcessorFactory(processor), serverTransport); + } + + public TSimpleServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory transportFactory, + TProtocolFactory protocolFactory) { + super(new TProcessorFactory(processor), serverTransport, transportFactory, protocolFactory); + } + + public TSimpleServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory inputTransportFactory, + TTransportFactory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory) { + super(new TProcessorFactory(processor), serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory); + } + + public TSimpleServer(TProcessorFactory processorFactory, + TServerTransport serverTransport) { + super(processorFactory, serverTransport); + } + + public TSimpleServer(TProcessorFactory processorFactory, + TServerTransport serverTransport, + TTransportFactory transportFactory, + TProtocolFactory protocolFactory) { + super(processorFactory, serverTransport, transportFactory, protocolFactory); + } + + public TSimpleServer(TProcessorFactory processorFactory, + TServerTransport serverTransport, + TTransportFactory inputTransportFactory, + TTransportFactory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory) { + super(processorFactory, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory); + } + + + public void serve() { + stopped_ = false; + try { + serverTransport_.listen(); + } catch (TTransportException ttx) { + LOGGER.error("Error occurred during listening.", ttx); + return; + } + + while (!stopped_) { + TTransport client = null; + TProcessor processor = null; + TTransport inputTransport = null; + TTransport outputTransport = null; + TProtocol inputProtocol = null; + TProtocol outputProtocol = null; + try { + client = serverTransport_.accept(); + if (client != null) { + processor = processorFactory_.getProcessor(client); + inputTransport = inputTransportFactory_.getTransport(client); + outputTransport = outputTransportFactory_.getTransport(client); + inputProtocol = inputProtocolFactory_.getProtocol(inputTransport); + outputProtocol = outputProtocolFactory_.getProtocol(outputTransport); + while (processor.process(inputProtocol, outputProtocol)) {} + } + } catch (TTransportException ttx) { + // Client died, just move on + } catch (TException tx) { + if (!stopped_) { + LOGGER.error("Thrift error occurred during processing of message.", tx); + } + } catch (Exception x) { + if (!stopped_) { + LOGGER.error("Error occurred during processing of message.", x); + } + } + + if (inputTransport != null) { + inputTransport.close(); + } + + if (outputTransport != null) { + outputTransport.close(); + } + + } + } + + public void stop() { + stopped_ = true; + serverTransport_.interrupt(); + } +} diff --git a/lib/java/src/org/apache/thrift/server/TThreadPoolServer.java b/lib/java/src/org/apache/thrift/server/TThreadPoolServer.java new file mode 100644 index 00000000..ebc5a9be --- /dev/null +++ b/lib/java/src/org/apache/thrift/server/TThreadPoolServer.java @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.server; + +import org.apache.thrift.TException; +import org.apache.thrift.TProcessor; +import org.apache.thrift.TProcessorFactory; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.transport.TServerTransport; +import org.apache.thrift.transport.TTransport; +import org.apache.thrift.transport.TTransportException; +import org.apache.thrift.transport.TTransportFactory; +import org.apache.log4j.Logger; +import org.apache.log4j.Level; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + + +/** + * Server which uses Java's built in ThreadPool management to spawn off + * a worker pool that + * + */ +public class TThreadPoolServer extends TServer { + + private static final Logger LOGGER = Logger.getLogger(TThreadPoolServer.class.getName()); + + // Executor service for handling client connections + private ExecutorService executorService_; + + // Flag for stopping the server + private volatile boolean stopped_; + + // Server options + private Options options_; + + // Customizable server options + public static class Options { + public int minWorkerThreads = 5; + public int maxWorkerThreads = Integer.MAX_VALUE; + public int stopTimeoutVal = 60; + public TimeUnit stopTimeoutUnit = TimeUnit.SECONDS; + } + + public TThreadPoolServer(TProcessor processor, + TServerTransport serverTransport) { + this(processor, serverTransport, + new TTransportFactory(), new TTransportFactory(), + new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory()); + } + + public TThreadPoolServer(TProcessorFactory processorFactory, + TServerTransport serverTransport) { + this(processorFactory, serverTransport, + new TTransportFactory(), new TTransportFactory(), + new TBinaryProtocol.Factory(), new TBinaryProtocol.Factory()); + } + + public TThreadPoolServer(TProcessor processor, + TServerTransport serverTransport, + TProtocolFactory protocolFactory) { + this(processor, serverTransport, + new TTransportFactory(), new TTransportFactory(), + protocolFactory, protocolFactory); + } + + public TThreadPoolServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory transportFactory, + TProtocolFactory protocolFactory) { + this(processor, serverTransport, + transportFactory, transportFactory, + protocolFactory, protocolFactory); + } + + public TThreadPoolServer(TProcessorFactory processorFactory, + TServerTransport serverTransport, + TTransportFactory transportFactory, + TProtocolFactory protocolFactory) { + this(processorFactory, serverTransport, + transportFactory, transportFactory, + protocolFactory, protocolFactory); + } + + public TThreadPoolServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory inputTransportFactory, + TTransportFactory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory) { + this(new TProcessorFactory(processor), serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory); + } + + public TThreadPoolServer(TProcessorFactory processorFactory, + TServerTransport serverTransport, + TTransportFactory inputTransportFactory, + TTransportFactory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory) { + super(processorFactory, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory); + options_ = new Options(); + executorService_ = Executors.newCachedThreadPool(); + } + + public TThreadPoolServer(TProcessor processor, + TServerTransport serverTransport, + TTransportFactory inputTransportFactory, + TTransportFactory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory, + Options options) { + this(new TProcessorFactory(processor), serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory, + options); + } + + public TThreadPoolServer(TProcessorFactory processorFactory, + TServerTransport serverTransport, + TTransportFactory inputTransportFactory, + TTransportFactory outputTransportFactory, + TProtocolFactory inputProtocolFactory, + TProtocolFactory outputProtocolFactory, + Options options) { + super(processorFactory, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory); + + executorService_ = null; + + SynchronousQueue executorQueue = + new SynchronousQueue(); + + executorService_ = new ThreadPoolExecutor(options.minWorkerThreads, + options.maxWorkerThreads, + 60, + TimeUnit.SECONDS, + executorQueue); + + options_ = options; + } + + + public void serve() { + try { + serverTransport_.listen(); + } catch (TTransportException ttx) { + LOGGER.error("Error occurred during listening.", ttx); + return; + } + + stopped_ = false; + while (!stopped_) { + int failureCount = 0; + try { + TTransport client = serverTransport_.accept(); + WorkerProcess wp = new WorkerProcess(client); + executorService_.execute(wp); + } catch (TTransportException ttx) { + if (!stopped_) { + ++failureCount; + LOGGER.warn("Transport error occurred during acceptance of message.", ttx); + } + } + } + + executorService_.shutdown(); + + // Loop until awaitTermination finally does return without a interrupted + // exception. If we don't do this, then we'll shut down prematurely. We want + // to let the executorService clear it's task queue, closing client sockets + // appropriately. + long timeoutMS = options_.stopTimeoutUnit.toMillis(options_.stopTimeoutVal); + long now = System.currentTimeMillis(); + while (timeoutMS >= 0) { + try { + executorService_.awaitTermination(timeoutMS, TimeUnit.MILLISECONDS); + break; + } catch (InterruptedException ix) { + long newnow = System.currentTimeMillis(); + timeoutMS -= (newnow - now); + now = newnow; + } + } + } + + public void stop() { + stopped_ = true; + serverTransport_.interrupt(); + } + + private class WorkerProcess implements Runnable { + + /** + * Client that this services. + */ + private TTransport client_; + + /** + * Default constructor. + * + * @param client Transport to process + */ + private WorkerProcess(TTransport client) { + client_ = client; + } + + /** + * Loops on processing a client forever + */ + public void run() { + TProcessor processor = null; + TTransport inputTransport = null; + TTransport outputTransport = null; + TProtocol inputProtocol = null; + TProtocol outputProtocol = null; + try { + processor = processorFactory_.getProcessor(client_); + inputTransport = inputTransportFactory_.getTransport(client_); + outputTransport = outputTransportFactory_.getTransport(client_); + inputProtocol = inputProtocolFactory_.getProtocol(inputTransport); + outputProtocol = outputProtocolFactory_.getProtocol(outputTransport); + // we check stopped_ first to make sure we're not supposed to be shutting + // down. this is necessary for graceful shutdown. + while (!stopped_ && processor.process(inputProtocol, outputProtocol)) {} + } catch (TTransportException ttx) { + // Assume the client died and continue silently + } catch (TException tx) { + LOGGER.error("Thrift error occurred during processing of message.", tx); + } catch (Exception x) { + LOGGER.error("Error occurred during processing of message.", x); + } + + if (inputTransport != null) { + inputTransport.close(); + } + + if (outputTransport != null) { + outputTransport.close(); + } + } + } +} diff --git a/lib/java/src/org/apache/thrift/transport/TFramedTransport.java b/lib/java/src/org/apache/thrift/transport/TFramedTransport.java new file mode 100644 index 00000000..c83748ad --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TFramedTransport.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import java.io.ByteArrayInputStream; + +import org.apache.thrift.TByteArrayOutputStream; + +/** + * Socket implementation of the TTransport interface. To be commented soon! + * + */ +public class TFramedTransport extends TTransport { + + /** + * Underlying transport + */ + private TTransport transport_ = null; + + /** + * Buffer for output + */ + private final TByteArrayOutputStream writeBuffer_ = + new TByteArrayOutputStream(1024); + + /** + * Buffer for input + */ + private ByteArrayInputStream readBuffer_ = null; + + public static class Factory extends TTransportFactory { + public Factory() { + } + + public TTransport getTransport(TTransport base) { + return new TFramedTransport(base); + } + } + + /** + * Constructor wraps around another tranpsort + */ + public TFramedTransport(TTransport transport) { + transport_ = transport; + } + + public void open() throws TTransportException { + transport_.open(); + } + + public boolean isOpen() { + return transport_.isOpen(); + } + + public void close() { + transport_.close(); + } + + public int read(byte[] buf, int off, int len) throws TTransportException { + if (readBuffer_ != null) { + int got = readBuffer_.read(buf, off, len); + if (got > 0) { + return got; + } + } + + // Read another frame of data + readFrame(); + + return readBuffer_.read(buf, off, len); + } + + private void readFrame() throws TTransportException { + byte[] i32rd = new byte[4]; + transport_.readAll(i32rd, 0, 4); + int size = + ((i32rd[0] & 0xff) << 24) | + ((i32rd[1] & 0xff) << 16) | + ((i32rd[2] & 0xff) << 8) | + ((i32rd[3] & 0xff)); + + byte[] buff = new byte[size]; + transport_.readAll(buff, 0, size); + readBuffer_ = new ByteArrayInputStream(buff); + } + + public void write(byte[] buf, int off, int len) throws TTransportException { + writeBuffer_.write(buf, off, len); + } + + public void flush() throws TTransportException { + byte[] buf = writeBuffer_.get(); + int len = writeBuffer_.len(); + writeBuffer_.reset(); + + byte[] i32out = new byte[4]; + i32out[0] = (byte)(0xff & (len >> 24)); + i32out[1] = (byte)(0xff & (len >> 16)); + i32out[2] = (byte)(0xff & (len >> 8)); + i32out[3] = (byte)(0xff & (len)); + transport_.write(i32out, 0, 4); + transport_.write(buf, 0, len); + transport_.flush(); + } +} diff --git a/lib/java/src/org/apache/thrift/transport/THttpClient.java b/lib/java/src/org/apache/thrift/transport/THttpClient.java new file mode 100644 index 00000000..41923531 --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/THttpClient.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import java.io.ByteArrayOutputStream; +import java.io.InputStream; +import java.io.IOException; + +import java.net.URL; +import java.net.HttpURLConnection; +import java.util.HashMap; +import java.util.Map; + +/** + * HTTP implementation of the TTransport interface. Used for working with a + * Thrift web services implementation. + * + */ +public class THttpClient extends TTransport { + + private URL url_ = null; + + private final ByteArrayOutputStream requestBuffer_ = + new ByteArrayOutputStream(); + + private InputStream inputStream_ = null; + + private int connectTimeout_ = 0; + + private int readTimeout_ = 0; + + private Map customHeaders_ = null; + + public THttpClient(String url) throws TTransportException { + try { + url_ = new URL(url); + } catch (IOException iox) { + throw new TTransportException(iox); + } + } + + public void setConnectTimeout(int timeout) { + connectTimeout_ = timeout; + } + + public void setReadTimeout(int timeout) { + readTimeout_ = timeout; + } + + public void setCustomHeaders(Map headers) { + customHeaders_ = headers; + } + + public void setCustomHeader(String key, String value) { + if (customHeaders_ == null) { + customHeaders_ = new HashMap(); + } + customHeaders_.put(key, value); + } + + public void open() {} + + public void close() { + if (null != inputStream_) { + try { + inputStream_.close(); + } catch (IOException ioe) { + ; + } + inputStream_ = null; + } + } + + public boolean isOpen() { + return true; + } + + public int read(byte[] buf, int off, int len) throws TTransportException { + if (inputStream_ == null) { + throw new TTransportException("Response buffer is empty, no request."); + } + try { + int ret = inputStream_.read(buf, off, len); + if (ret == -1) { + throw new TTransportException("No more data available."); + } + return ret; + } catch (IOException iox) { + throw new TTransportException(iox); + } + } + + public void write(byte[] buf, int off, int len) { + requestBuffer_.write(buf, off, len); + } + + public void flush() throws TTransportException { + // Extract request and reset buffer + byte[] data = requestBuffer_.toByteArray(); + requestBuffer_.reset(); + + try { + // Create connection object + HttpURLConnection connection = (HttpURLConnection)url_.openConnection(); + + // Timeouts, only if explicitly set + if (connectTimeout_ > 0) { + connection.setConnectTimeout(connectTimeout_); + } + if (readTimeout_ > 0) { + connection.setReadTimeout(readTimeout_); + } + + // Make the request + connection.setRequestMethod("POST"); + connection.setRequestProperty("Content-Type", "application/x-thrift"); + connection.setRequestProperty("Accept", "application/x-thrift"); + connection.setRequestProperty("User-Agent", "Java/THttpClient"); + if (customHeaders_ != null) { + for (Map.Entry header : customHeaders_.entrySet()) { + connection.setRequestProperty(header.getKey(), header.getValue()); + } + } + connection.setDoOutput(true); + connection.connect(); + connection.getOutputStream().write(data); + + int responseCode = connection.getResponseCode(); + if (responseCode != HttpURLConnection.HTTP_OK) { + throw new TTransportException("HTTP Response code: " + responseCode); + } + + // Read the responses + inputStream_ = connection.getInputStream(); + + } catch (IOException iox) { + throw new TTransportException(iox); + } + } +} diff --git a/lib/java/src/org/apache/thrift/transport/TIOStreamTransport.java b/lib/java/src/org/apache/thrift/transport/TIOStreamTransport.java new file mode 100644 index 00000000..89cdb582 --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TIOStreamTransport.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import org.apache.log4j.Logger; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + +/** + * This is the most commonly used base transport. It takes an InputStream + * and an OutputStream and uses those to perform all transport operations. + * This allows for compatibility with all the nice constructs Java already + * has to provide a variety of types of streams. + * + */ +public class TIOStreamTransport extends TTransport { + + private static final Logger LOGGER = Logger.getLogger(TIOStreamTransport.class.getName()); + + /** Underlying inputStream */ + protected InputStream inputStream_ = null; + + /** Underlying outputStream */ + protected OutputStream outputStream_ = null; + + /** + * Subclasses can invoke the default constructor and then assign the input + * streams in the open method. + */ + protected TIOStreamTransport() {} + + /** + * Input stream constructor. + * + * @param is Input stream to read from + */ + public TIOStreamTransport(InputStream is) { + inputStream_ = is; + } + + /** + * Output stream constructor. + * + * @param os Output stream to read from + */ + public TIOStreamTransport(OutputStream os) { + outputStream_ = os; + } + + /** + * Two-way stream constructor. + * + * @param is Input stream to read from + * @param os Output stream to read from + */ + public TIOStreamTransport(InputStream is, OutputStream os) { + inputStream_ = is; + outputStream_ = os; + } + + /** + * The streams must already be open at construction time, so this should + * always return true. + * + * @return true + */ + public boolean isOpen() { + return true; + } + + /** + * The streams must already be open. This method does nothing. + */ + public void open() throws TTransportException {} + + /** + * Closes both the input and output streams. + */ + public void close() { + if (inputStream_ != null) { + try { + inputStream_.close(); + } catch (IOException iox) { + LOGGER.warn("Error closing input stream.", iox); + } + inputStream_ = null; + } + if (outputStream_ != null) { + try { + outputStream_.close(); + } catch (IOException iox) { + LOGGER.warn("Error closing output stream.", iox); + } + outputStream_ = null; + } + } + + /** + * Reads from the underlying input stream if not null. + */ + public int read(byte[] buf, int off, int len) throws TTransportException { + if (inputStream_ == null) { + throw new TTransportException(TTransportException.NOT_OPEN, "Cannot read from null inputStream"); + } + try { + return inputStream_.read(buf, off, len); + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + } + + /** + * Writes to the underlying output stream if not null. + */ + public void write(byte[] buf, int off, int len) throws TTransportException { + if (outputStream_ == null) { + throw new TTransportException(TTransportException.NOT_OPEN, "Cannot write to null outputStream"); + } + try { + outputStream_.write(buf, off, len); + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + } + + /** + * Flushes the underlying output stream if not null. + */ + public void flush() throws TTransportException { + if (outputStream_ == null) { + throw new TTransportException(TTransportException.NOT_OPEN, "Cannot flush null outputStream"); + } + try { + outputStream_.flush(); + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + } +} diff --git a/lib/java/src/org/apache/thrift/transport/TMemoryBuffer.java b/lib/java/src/org/apache/thrift/transport/TMemoryBuffer.java new file mode 100644 index 00000000..886fcbf6 --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TMemoryBuffer.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import org.apache.thrift.TByteArrayOutputStream; +import java.io.UnsupportedEncodingException; + +/** + * Memory buffer-based implementation of the TTransport interface. + * + */ +public class TMemoryBuffer extends TTransport { + + /** + * + */ + public TMemoryBuffer(int size) { + arr_ = new TByteArrayOutputStream(size); + } + + @Override + public boolean isOpen() { + return true; + } + + @Override + public void open() { + /* Do nothing */ + } + + @Override + public void close() { + /* Do nothing */ + } + + @Override + public int read(byte[] buf, int off, int len) { + byte[] src = arr_.get(); + int amtToRead = (len > arr_.len() - pos_ ? arr_.len() - pos_ : len); + if (amtToRead > 0) { + System.arraycopy(src, pos_, buf, off, amtToRead); + pos_ += amtToRead; + } + return amtToRead; + } + + @Override + public void write(byte[] buf, int off, int len) { + arr_.write(buf, off, len); + } + + /** + * Output the contents of the memory buffer as a String, using the supplied + * encoding + * @param enc the encoding to use + * @return the contents of the memory buffer as a String + */ + public String toString(String enc) throws UnsupportedEncodingException { + return arr_.toString(enc); + } + + public String inspect() { + String buf = ""; + byte[] bytes = arr_.toByteArray(); + for (int i = 0; i < bytes.length; i++) { + buf += (pos_ == i ? "==>" : "" ) + Integer.toHexString(bytes[i] & 0xff) + " "; + } + return buf; + } + + // The contents of the buffer + private TByteArrayOutputStream arr_; + + // Position to read next byte from + private int pos_; + + public int length() { + return arr_.size(); + } +} + diff --git a/lib/java/src/org/apache/thrift/transport/TNonblockingServerSocket.java b/lib/java/src/org/apache/thrift/transport/TNonblockingServerSocket.java new file mode 100644 index 00000000..571adbff --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TNonblockingServerSocket.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.transport; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.SocketException; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; + +/** + * Wrapper around ServerSocketChannel + */ +public class TNonblockingServerSocket extends TNonblockingServerTransport { + + /** + * This channel is where all the nonblocking magic happens. + */ + private ServerSocketChannel serverSocketChannel = null; + + /** + * Underlying serversocket object + */ + private ServerSocket serverSocket_ = null; + + /** + * Port to listen on + */ + private int port_ = 0; + + /** + * Timeout for client sockets from accept + */ + private int clientTimeout_ = 0; + + /** + * Creates a server socket from underlying socket object + */ + // public TNonblockingServerSocket(ServerSocket serverSocket) { + // this(serverSocket, 0); + // } + + /** + * Creates a server socket from underlying socket object + */ + // public TNonblockingServerSocket(ServerSocket serverSocket, int clientTimeout) { + // serverSocket_ = serverSocket; + // clientTimeout_ = clientTimeout; + // } + + /** + * Creates just a port listening server socket + */ + public TNonblockingServerSocket(int port) throws TTransportException { + this(port, 0); + } + + /** + * Creates just a port listening server socket + */ + public TNonblockingServerSocket(int port, int clientTimeout) throws TTransportException { + port_ = port; + clientTimeout_ = clientTimeout; + try { + serverSocketChannel = ServerSocketChannel.open(); + serverSocketChannel.configureBlocking(false); + + // Make server socket + serverSocket_ = serverSocketChannel.socket(); + // Prevent 2MSL delay problem on server restarts + serverSocket_.setReuseAddress(true); + // Bind to listening port + serverSocket_.bind(new InetSocketAddress(port_)); + } catch (IOException ioe) { + serverSocket_ = null; + throw new TTransportException("Could not create ServerSocket on port " + port + "."); + } + } + + public void listen() throws TTransportException { + // Make sure not to block on accept + if (serverSocket_ != null) { + try { + serverSocket_.setSoTimeout(0); + } catch (SocketException sx) { + sx.printStackTrace(); + } + } + } + + protected TNonblockingSocket acceptImpl() throws TTransportException { + if (serverSocket_ == null) { + throw new TTransportException(TTransportException.NOT_OPEN, "No underlying server socket."); + } + try { + SocketChannel socketChannel = serverSocketChannel.accept(); + if (socketChannel == null) { + return null; + } + + TNonblockingSocket tsocket = new TNonblockingSocket(socketChannel); + tsocket.setTimeout(clientTimeout_); + return tsocket; + } catch (IOException iox) { + throw new TTransportException(iox); + } + } + + public void registerSelector(Selector selector) { + try { + // Register the server socket channel, indicating an interest in + // accepting new connections + serverSocketChannel.register(selector, SelectionKey.OP_ACCEPT); + } catch (ClosedChannelException e) { + // this shouldn't happen, ideally... + // TODO: decide what to do with this. + } + } + + public void close() { + if (serverSocket_ != null) { + try { + serverSocket_.close(); + } catch (IOException iox) { + System.err.println("WARNING: Could not close server socket: " + + iox.getMessage()); + } + serverSocket_ = null; + } + } + + public void interrupt() { + // The thread-safeness of this is dubious, but Java documentation suggests + // that it is safe to do this from a different thread context + close(); + } + +} diff --git a/lib/java/src/org/apache/thrift/transport/TNonblockingServerTransport.java b/lib/java/src/org/apache/thrift/transport/TNonblockingServerTransport.java new file mode 100644 index 00000000..ba45b09d --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TNonblockingServerTransport.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.transport; + +import java.nio.channels.Selector; + +/** + * Server transport that can be operated in a nonblocking fashion. + */ +public abstract class TNonblockingServerTransport extends TServerTransport { + + public abstract void registerSelector(Selector selector); +} diff --git a/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java b/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java new file mode 100644 index 00000000..bc2d5396 --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.transport; + +import java.io.IOException; +import java.net.Socket; +import java.net.SocketException; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; + +/** + * Socket implementation of the TTransport interface. To be commented soon! + */ +public class TNonblockingSocket extends TNonblockingTransport { + + private SocketChannel socketChannel = null; + + /** + * Wrapped Socket object + */ + private Socket socket_ = null; + + /** + * Remote host + */ + private String host_ = null; + + /** + * Remote port + */ + private int port_ = 0; + + /** + * Socket timeout + */ + private int timeout_ = 0; + + /** + * Constructor that takes an already created socket. + * + * @param socketChannel Already created SocketChannel object + * @throws TTransportException if there is an error setting up the streams + */ + public TNonblockingSocket(SocketChannel socketChannel) throws TTransportException { + try { + // make it a nonblocking channel + socketChannel.configureBlocking(false); + } catch (IOException e) { + throw new TTransportException(e); + } + + this.socketChannel = socketChannel; + this.socket_ = socketChannel.socket(); + try { + socket_.setSoLinger(false, 0); + socket_.setTcpNoDelay(true); + } catch (SocketException sx) { + sx.printStackTrace(); + } + } + + /** + * Register this socket with the specified selector for both read and write + * operations. + * + * @param selector + * @return the selection key for this socket. + */ + public SelectionKey registerSelector(Selector selector, int interests) throws IOException { + // Register the new SocketChannel with our Selector, indicating + // we'd like to be notified when there's data waiting to be read + return socketChannel.register(selector, interests); + } + + /** + * Initializes the socket object + */ + private void initSocket() { + socket_ = new Socket(); + try { + socket_.setSoLinger(false, 0); + socket_.setTcpNoDelay(true); + socket_.setSoTimeout(timeout_); + } catch (SocketException sx) { + sx.printStackTrace(); + } + } + + /** + * Sets the socket timeout + * + * @param timeout Milliseconds timeout + */ + public void setTimeout(int timeout) { + timeout_ = timeout; + try { + socket_.setSoTimeout(timeout); + } catch (SocketException sx) { + sx.printStackTrace(); + } + } + + /** + * Returns a reference to the underlying socket. + */ + public Socket getSocket() { + if (socket_ == null) { + initSocket(); + } + return socket_; + } + + /** + * Checks whether the socket is connected. + */ + public boolean isOpen() { + if (socket_ == null) { + return false; + } + return socket_.isConnected(); + } + + /** + * Connects the socket, creating a new socket object if necessary. + */ + public void open() throws TTransportException { + throw new RuntimeException("Not implemented yet"); + } + + /** + * Perform a nonblocking read into buffer. + */ + public int read(ByteBuffer buffer) throws IOException { + return socketChannel.read(buffer); + } + + + /** + * Reads from the underlying input stream if not null. + */ + public int read(byte[] buf, int off, int len) throws TTransportException { + if ((socketChannel.validOps() & SelectionKey.OP_READ) != SelectionKey.OP_READ) { + throw new TTransportException(TTransportException.NOT_OPEN, + "Cannot read from write-only socket channel"); + } + try { + return socketChannel.read(ByteBuffer.wrap(buf, off, len)); + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + } + + /** + * Perform a nonblocking write of the data in buffer; + */ + public int write(ByteBuffer buffer) throws IOException { + return socketChannel.write(buffer); + } + + /** + * Writes to the underlying output stream if not null. + */ + public void write(byte[] buf, int off, int len) throws TTransportException { + if ((socketChannel.validOps() & SelectionKey.OP_WRITE) != SelectionKey.OP_WRITE) { + throw new TTransportException(TTransportException.NOT_OPEN, + "Cannot write to write-only socket channel"); + } + try { + socketChannel.write(ByteBuffer.wrap(buf, off, len)); + } catch (IOException iox) { + throw new TTransportException(TTransportException.UNKNOWN, iox); + } + } + + /** + * Flushes the underlying output stream if not null. + */ + public void flush() throws TTransportException { + // Not supported by SocketChannel. + } + + /** + * Closes the socket. + */ + public void close() { + try { + socketChannel.close(); + } catch (IOException e) { + // silently ignore. + } + } + +} diff --git a/lib/java/src/org/apache/thrift/transport/TNonblockingTransport.java b/lib/java/src/org/apache/thrift/transport/TNonblockingTransport.java new file mode 100644 index 00000000..517eacb7 --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TNonblockingTransport.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import java.io.IOException; +import java.nio.channels.Selector; +import java.nio.channels.SelectionKey; +import java.nio.ByteBuffer; + +public abstract class TNonblockingTransport extends TTransport { + public abstract SelectionKey registerSelector(Selector selector, int interests) throws IOException; + public abstract int read(ByteBuffer buffer) throws IOException; + public abstract int write(ByteBuffer buffer) throws IOException; +} diff --git a/lib/java/src/org/apache/thrift/transport/TServerSocket.java b/lib/java/src/org/apache/thrift/transport/TServerSocket.java new file mode 100644 index 00000000..796cd659 --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TServerSocket.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import org.apache.log4j.Logger; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketException; + +/** + * Wrapper around ServerSocket for Thrift. + * + */ +public class TServerSocket extends TServerTransport { + + private static final Logger LOGGER = Logger.getLogger(TServerSocket.class.getName()); + + /** + * Underlying serversocket object + */ + private ServerSocket serverSocket_ = null; + + /** + * Port to listen on + */ + private int port_ = 0; + + /** + * Timeout for client sockets from accept + */ + private int clientTimeout_ = 0; + + /** + * Creates a server socket from underlying socket object + */ + public TServerSocket(ServerSocket serverSocket) { + this(serverSocket, 0); + } + + /** + * Creates a server socket from underlying socket object + */ + public TServerSocket(ServerSocket serverSocket, int clientTimeout) { + serverSocket_ = serverSocket; + clientTimeout_ = clientTimeout; + } + + /** + * Creates just a port listening server socket + */ + public TServerSocket(int port) throws TTransportException { + this(port, 0); + } + + /** + * Creates just a port listening server socket + */ + public TServerSocket(int port, int clientTimeout) throws TTransportException { + this(new InetSocketAddress(port), clientTimeout); + port_ = port; + } + + public TServerSocket(InetSocketAddress bindAddr) throws TTransportException { + this(bindAddr, 0); + } + + public TServerSocket(InetSocketAddress bindAddr, int clientTimeout) throws TTransportException { + clientTimeout_ = clientTimeout; + try { + // Make server socket + serverSocket_ = new ServerSocket(); + // Prevent 2MSL delay problem on server restarts + serverSocket_.setReuseAddress(true); + // Bind to listening port + serverSocket_.bind(bindAddr); + } catch (IOException ioe) { + serverSocket_ = null; + throw new TTransportException("Could not create ServerSocket on address " + bindAddr.toString() + "."); + } + } + + public void listen() throws TTransportException { + // Make sure not to block on accept + if (serverSocket_ != null) { + try { + serverSocket_.setSoTimeout(0); + } catch (SocketException sx) { + LOGGER.error("Could not set socket timeout.", sx); + } + } + } + + protected TSocket acceptImpl() throws TTransportException { + if (serverSocket_ == null) { + throw new TTransportException(TTransportException.NOT_OPEN, "No underlying server socket."); + } + try { + Socket result = serverSocket_.accept(); + TSocket result2 = new TSocket(result); + result2.setTimeout(clientTimeout_); + return result2; + } catch (IOException iox) { + throw new TTransportException(iox); + } + } + + public void close() { + if (serverSocket_ != null) { + try { + serverSocket_.close(); + } catch (IOException iox) { + LOGGER.warn("Could not close server socket.", iox); + } + serverSocket_ = null; + } + } + + public void interrupt() { + // The thread-safeness of this is dubious, but Java documentation suggests + // that it is safe to do this from a different thread context + close(); + } + +} diff --git a/lib/java/src/org/apache/thrift/transport/TServerTransport.java b/lib/java/src/org/apache/thrift/transport/TServerTransport.java new file mode 100644 index 00000000..17ff86be --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TServerTransport.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +/** + * Server transport. Object which provides client transports. + * + */ +public abstract class TServerTransport { + + public abstract void listen() throws TTransportException; + + public final TTransport accept() throws TTransportException { + TTransport transport = acceptImpl(); + if (transport == null) { + throw new TTransportException("accept() may not return NULL"); + } + return transport; + } + + public abstract void close(); + + protected abstract TTransport acceptImpl() throws TTransportException; + + /** + * Optional method implementation. This signals to the server transport + * that it should break out of any accept() or listen() that it is currently + * blocked on. This method, if implemented, MUST be thread safe, as it may + * be called from a different thread context than the other TServerTransport + * methods. + */ + public void interrupt() {} + +} diff --git a/lib/java/src/org/apache/thrift/transport/TSocket.java b/lib/java/src/org/apache/thrift/transport/TSocket.java new file mode 100644 index 00000000..cdf1bcc4 --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TSocket.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import org.apache.log4j.Logger; + +import java.io.BufferedInputStream; +import java.io.BufferedOutputStream; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.SocketException; + +/** + * Socket implementation of the TTransport interface. To be commented soon! + * + */ +public class TSocket extends TIOStreamTransport { + + private static final Logger LOGGER = Logger.getLogger(TSocket.class.getName()); + + /** + * Wrapped Socket object + */ + private Socket socket_ = null; + + /** + * Remote host + */ + private String host_ = null; + + /** + * Remote port + */ + private int port_ = 0; + + /** + * Socket timeout + */ + private int timeout_ = 0; + + /** + * Constructor that takes an already created socket. + * + * @param socket Already created socket object + * @throws TTransportException if there is an error setting up the streams + */ + public TSocket(Socket socket) throws TTransportException { + socket_ = socket; + try { + socket_.setSoLinger(false, 0); + socket_.setTcpNoDelay(true); + } catch (SocketException sx) { + LOGGER.warn("Could not configure socket.", sx); + } + + if (isOpen()) { + try { + inputStream_ = new BufferedInputStream(socket_.getInputStream(), 1024); + outputStream_ = new BufferedOutputStream(socket_.getOutputStream(), 1024); + } catch (IOException iox) { + close(); + throw new TTransportException(TTransportException.NOT_OPEN, iox); + } + } + } + + /** + * Creates a new unconnected socket that will connect to the given host + * on the given port. + * + * @param host Remote host + * @param port Remote port + */ + public TSocket(String host, int port) { + this(host, port, 0); + } + + /** + * Creates a new unconnected socket that will connect to the given host + * on the given port. + * + * @param host Remote host + * @param port Remote port + * @param timeout Socket timeout + */ + public TSocket(String host, int port, int timeout) { + host_ = host; + port_ = port; + timeout_ = timeout; + initSocket(); + } + + /** + * Initializes the socket object + */ + private void initSocket() { + socket_ = new Socket(); + try { + socket_.setSoLinger(false, 0); + socket_.setTcpNoDelay(true); + socket_.setSoTimeout(timeout_); + } catch (SocketException sx) { + LOGGER.error("Could not configure socket.", sx); + } + } + + /** + * Sets the socket timeout + * + * @param timeout Milliseconds timeout + */ + public void setTimeout(int timeout) { + timeout_ = timeout; + try { + socket_.setSoTimeout(timeout); + } catch (SocketException sx) { + LOGGER.warn("Could not set socket timeout.", sx); + } + } + + /** + * Returns a reference to the underlying socket. + */ + public Socket getSocket() { + if (socket_ == null) { + initSocket(); + } + return socket_; + } + + /** + * Checks whether the socket is connected. + */ + public boolean isOpen() { + if (socket_ == null) { + return false; + } + return socket_.isConnected(); + } + + /** + * Connects the socket, creating a new socket object if necessary. + */ + public void open() throws TTransportException { + if (isOpen()) { + throw new TTransportException(TTransportException.ALREADY_OPEN, "Socket already connected."); + } + + if (host_.length() == 0) { + throw new TTransportException(TTransportException.NOT_OPEN, "Cannot open null host."); + } + if (port_ <= 0) { + throw new TTransportException(TTransportException.NOT_OPEN, "Cannot open without port."); + } + + if (socket_ == null) { + initSocket(); + } + + try { + socket_.connect(new InetSocketAddress(host_, port_)); + inputStream_ = new BufferedInputStream(socket_.getInputStream(), 1024); + outputStream_ = new BufferedOutputStream(socket_.getOutputStream(), 1024); + } catch (IOException iox) { + close(); + throw new TTransportException(TTransportException.NOT_OPEN, iox); + } + } + + /** + * Closes the socket. + */ + public void close() { + // Close the underlying streams + super.close(); + + // Close the socket + if (socket_ != null) { + try { + socket_.close(); + } catch (IOException iox) { + LOGGER.warn("Could not close socket.", iox); + } + socket_ = null; + } + } + +} diff --git a/lib/java/src/org/apache/thrift/transport/TTransport.java b/lib/java/src/org/apache/thrift/transport/TTransport.java new file mode 100644 index 00000000..a6c047bb --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TTransport.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +/** + * Generic class that encapsulates the I/O layer. This is basically a thin + * wrapper around the combined functionality of Java input/output streams. + * + */ +public abstract class TTransport { + + /** + * Queries whether the transport is open. + * + * @return True if the transport is open. + */ + public abstract boolean isOpen(); + + /** + * Is there more data to be read? + * + * @return True if the remote side is still alive and feeding us + */ + public boolean peek() { + return isOpen(); + } + + /** + * Opens the transport for reading/writing. + * + * @throws TTransportException if the transport could not be opened + */ + public abstract void open() + throws TTransportException; + + /** + * Closes the transport. + */ + public abstract void close(); + + /** + * Reads up to len bytes into buffer buf, starting att offset off. + * + * @param buf Array to read into + * @param off Index to start reading at + * @param len Maximum number of bytes to read + * @return The number of bytes actually read + * @throws TTransportException if there was an error reading data + */ + public abstract int read(byte[] buf, int off, int len) + throws TTransportException; + + /** + * Guarantees that all of len bytes are actually read off the transport. + * + * @param buf Array to read into + * @param off Index to start reading at + * @param len Maximum number of bytes to read + * @return The number of bytes actually read, which must be equal to len + * @throws TTransportException if there was an error reading data + */ + public int readAll(byte[] buf, int off, int len) + throws TTransportException { + int got = 0; + int ret = 0; + while (got < len) { + ret = read(buf, off+got, len-got); + if (ret <= 0) { + throw new TTransportException("Cannot read. Remote side has closed. Tried to read " + len + " bytes, but only got " + got + " bytes."); + } + got += ret; + } + return got; + } + + /** + * Writes the buffer to the output + * + * @param buf The output data buffer + * @throws TTransportException if an error occurs writing data + */ + public void write(byte[] buf) throws TTransportException { + write(buf, 0, buf.length); + } + + /** + * Writes up to len bytes from the buffer. + * + * @param buf The output data buffer + * @param off The offset to start writing from + * @param len The number of bytes to write + * @throws TTransportException if there was an error writing data + */ + public abstract void write(byte[] buf, int off, int len) + throws TTransportException; + + /** + * Flush any pending data out of a transport buffer. + * + * @throws TTransportException if there was an error writing out data. + */ + public void flush() + throws TTransportException {} +} diff --git a/lib/java/src/org/apache/thrift/transport/TTransportException.java b/lib/java/src/org/apache/thrift/transport/TTransportException.java new file mode 100644 index 00000000..d08f3b02 --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TTransportException.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import org.apache.thrift.TException; + +/** + * Transport exceptions. + * + */ +public class TTransportException extends TException { + + private static final long serialVersionUID = 1L; + + public static final int UNKNOWN = 0; + public static final int NOT_OPEN = 1; + public static final int ALREADY_OPEN = 2; + public static final int TIMED_OUT = 3; + public static final int END_OF_FILE = 4; + + protected int type_ = UNKNOWN; + + public TTransportException() { + super(); + } + + public TTransportException(int type) { + super(); + type_ = type; + } + + public TTransportException(int type, String message) { + super(message); + type_ = type; + } + + public TTransportException(String message) { + super(message); + } + + public TTransportException(int type, Throwable cause) { + super(cause); + type_ = type; + } + + public TTransportException(Throwable cause) { + super(cause); + } + + public TTransportException(String message, Throwable cause) { + super(message, cause); + } + + public TTransportException(int type, String message, Throwable cause) { + super(message, cause); + type_ = type; + } + + public int getType() { + return type_; + } + +} diff --git a/lib/java/src/org/apache/thrift/transport/TTransportFactory.java b/lib/java/src/org/apache/thrift/transport/TTransportFactory.java new file mode 100644 index 00000000..3e71630a --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TTransportFactory.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +/** + * Factory class used to create wrapped instance of Transports. + * This is used primarily in servers, which get Transports from + * a ServerTransport and then may want to mutate them (i.e. create + * a BufferedTransport from the underlying base transport) + * + */ +public class TTransportFactory { + + /** + * Return a wrapped instance of the base Transport. + * + * @param trans The base transport + * @return Wrapped Transport + */ + public TTransport getTransport(TTransport trans) { + return trans; + } + +} diff --git a/lib/java/test/TestClient b/lib/java/test/TestClient new file mode 100755 index 00000000..bd3c996f --- /dev/null +++ b/lib/java/test/TestClient @@ -0,0 +1,22 @@ +#!/bin/bash -v + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +java -cp thrifttest.jar:../../lib/java/libthrift.jar org.apache.thrift.test.TestClient $* diff --git a/lib/java/test/TestNonblockingServer b/lib/java/test/TestNonblockingServer new file mode 100644 index 00000000..070991c7 --- /dev/null +++ b/lib/java/test/TestNonblockingServer @@ -0,0 +1,22 @@ +#!/bin/bash -v + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +java -server -Xmx256m -cp thrifttest.jar:../../lib/java/libthrift.jar org.apache.thrift.test.TestNonblockingServer $* diff --git a/lib/java/test/TestServer b/lib/java/test/TestServer new file mode 100755 index 00000000..0d36b58e --- /dev/null +++ b/lib/java/test/TestServer @@ -0,0 +1,22 @@ +#!/bin/bash -v + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +java -server -cp thrifttest.jar:../../lib/java/libthrift.jar org.apache.thrift.test.TestServer $* diff --git a/lib/java/test/org/apache/thrift/test/DeepCopyTest.java b/lib/java/test/org/apache/thrift/test/DeepCopyTest.java new file mode 100644 index 00000000..a171cab3 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/DeepCopyTest.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.test; + +import org.apache.thrift.TDeserializer; +import org.apache.thrift.TSerializer; +import org.apache.thrift.protocol.TBinaryProtocol; +import thrift.test.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; + +public class DeepCopyTest { + + private static final byte[] kUnicodeBytes = { + (byte)0xd3, (byte)0x80, (byte)0xe2, (byte)0x85, (byte)0xae, (byte)0xce, + (byte)0x9d, (byte)0x20, (byte)0xd0, (byte)0x9d, (byte)0xce, (byte)0xbf, + (byte)0xe2, (byte)0x85, (byte)0xbf, (byte)0xd0, (byte)0xbe, (byte)0xc9, + (byte)0xa1, (byte)0xd0, (byte)0xb3, (byte)0xd0, (byte)0xb0, (byte)0xcf, + (byte)0x81, (byte)0xe2, (byte)0x84, (byte)0x8e, (byte)0x20, (byte)0xce, + (byte)0x91, (byte)0x74, (byte)0x74, (byte)0xce, (byte)0xb1, (byte)0xe2, + (byte)0x85, (byte)0xbd, (byte)0xce, (byte)0xba, (byte)0x83, (byte)0xe2, + (byte)0x80, (byte)0xbc + }; + + public static void main(String[] args) throws Exception { + TSerializer binarySerializer = new TSerializer(new TBinaryProtocol.Factory()); + TDeserializer binaryDeserializer = new TDeserializer(new TBinaryProtocol.Factory()); + + OneOfEach ooe = new OneOfEach(); + ooe.im_true = true; + ooe.im_false = false; + ooe.a_bite = (byte) 0xd6; + ooe.integer16 = 27000; + ooe.integer32 = 1 << 24; + ooe.integer64 = (long) 6000 * 1000 * 1000; + ooe.double_precision = Math.PI; + ooe.some_characters = "JSON THIS! \"\1"; + ooe.zomg_unicode = new String(kUnicodeBytes, "UTF-8"); + ooe.base64 = "string to bytes".getBytes(); + + Nesting n = new Nesting(new Bonk(), new OneOfEach()); + n.my_ooe.integer16 = 16; + n.my_ooe.integer32 = 32; + n.my_ooe.integer64 = 64; + n.my_ooe.double_precision = (Math.sqrt(5) + 1) / 2; + n.my_ooe.some_characters = ":R (me going \"rrrr\")"; + n.my_ooe.zomg_unicode = new String(kUnicodeBytes, "UTF-8"); + n.my_bonk.type = 31337; + n.my_bonk.message = "I am a bonk... xor!"; + + HolyMoley hm = new HolyMoley(); + + hm.big = new ArrayList(); + hm.big.add(ooe); + hm.big.add(n.my_ooe); + hm.big.get(0).a_bite = (byte) 0x22; + hm.big.get(1).a_bite = (byte) 0x23; + + hm.contain = new HashSet>(); + ArrayList stage1 = new ArrayList(2); + stage1.add("and a one"); + stage1.add("and a two"); + hm.contain.add(stage1); + stage1 = new ArrayList(3); + stage1.add("then a one, two"); + stage1.add("three!"); + stage1.add("FOUR!!"); + hm.contain.add(stage1); + stage1 = new ArrayList(0); + hm.contain.add(stage1); + + ArrayList stage2 = new ArrayList(); + hm.bonks = new HashMap>(); + hm.bonks.put("nothing", stage2); + Bonk b = new Bonk(); + b.type = 1; + b.message = "Wait."; + stage2.add(b); + b = new Bonk(); + b.type = 2; + b.message = "What?"; + stage2.add(b); + stage2 = new ArrayList(); + hm.bonks.put("something", stage2); + b = new Bonk(); + b.type = 3; + b.message = "quoth"; + b = new Bonk(); + b.type = 4; + b.message = "the raven"; + b = new Bonk(); + b.type = 5; + b.message = "nevermore"; + hm.bonks.put("poe", stage2); + + + byte[] binaryCopy = binarySerializer.serialize(hm); + HolyMoley hmCopy = new HolyMoley(); + binaryDeserializer.deserialize(hmCopy, binaryCopy); + HolyMoley hmCopy2 = new HolyMoley(hm); + + if (!hm.equals(hmCopy)) + throw new RuntimeException("copy constructor modified the original object!"); + if (!hmCopy.equals(hmCopy2)) + throw new RuntimeException("copy constructor generated incorrect copy"); + + hm.big.get(0).base64[0]++; // change binary value in original object + if (hm.equals(hmCopy2)) // make sure the change didn't propagate to the copied object + throw new RuntimeException("Binary field not copied correctly!"); + hm.big.get(0).base64[0]--; // undo change + + hmCopy2.bonks.get("nothing").get(1).message = "What else?"; + + if (hm.equals(hmCopy2)) + throw new RuntimeException("A deep copy was not done!"); + + } +} diff --git a/lib/java/test/org/apache/thrift/test/EqualityTest.java b/lib/java/test/org/apache/thrift/test/EqualityTest.java new file mode 100644 index 00000000..f01378f7 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/EqualityTest.java @@ -0,0 +1,661 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* +This program was generated by the following Python script: + +#!/usr/bin/python2.5 + +# Remove this when Python 2.6 hits the streets. +from __future__ import with_statement + +import sys +import os.path + + +# Quines the easy way. +with open(sys.argv[0], 'r') as handle: + source = handle.read() + +with open(os.path.join(os.path.dirname(sys.argv[0]), 'EqualityTest.java'), 'w') as out: + print >> out, ("/""*" r""" + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + """ "*""/") + print >> out + print >> out, "/""*" + print >> out, "This program was generated by the following Python script:" + print >> out + out.write(source) + print >> out, "*""/" + + print >> out, r''' +package org.apache.thrift.test; + +// Generated code +import thrift.test.*; + +/'''r'''** + *'''r'''/ +public class EqualityTest { + public static void main(String[] args) throws Exception { + JavaTestHelper lhs, rhs; +''' + + vals = { + 'int': ("1", "2"), + 'obj': ("\"foo\"", "\"bar\""), + 'bin': ("new byte[]{1,2}", "new byte[]{3,4}"), + } + matrix = ( + (False,False), + (False,True ), + (True ,False), + (True ,True ), + ) + + for type in ('int', 'obj', 'bin'): + for option in ('req', 'opt'): + nulls = matrix[0:1] if type == 'int' else matrix[-1::-1] + issets = matrix + for is_null in nulls: + for is_set in issets: + # isset is implied for non-primitives, so only consider the case + # where isset and non-null match. + if type != 'int' and list(is_set) != [ not null for null in is_null ]: + continue + for equal in (True, False): + print >> out + print >> out, " lhs = new JavaTestHelper();" + print >> out, " rhs = new JavaTestHelper();" + print >> out, " lhs." + option + "_" + type, "=", vals[type][0] + ";" + print >> out, " rhs." + option + "_" + type, "=", vals[type][0 if equal else 1] + ";" + isset_setter = "set" + option[0].upper() + option[1:] + "_" + type + "IsSet" + if (type == 'int' and is_set[0]): print >> out, " lhs." + isset_setter + "(true);" + if (type == 'int' and is_set[1]): print >> out, " rhs." + isset_setter + "(true);" + if (is_null[0]): print >> out, " lhs." + option + "_" + type, "= null;" + if (is_null[1]): print >> out, " rhs." + option + "_" + type, "= null;" + this_present = not is_null[0] and (option == 'req' or is_set[0]) + that_present = not is_null[1] and (option == 'req' or is_set[1]) + print >> out, " // this_present = " + repr(this_present) + print >> out, " // that_present = " + repr(that_present) + is_equal = \ + (not this_present and not that_present) or \ + (this_present and that_present and equal) + eq_str = 'true' if is_equal else 'false' + + print >> out, " if (lhs.equals(rhs) != "+eq_str+")" + print >> out, " throw new RuntimeException(\"Failure\");" + if is_equal: + print >> out, " if (lhs.hashCode() != rhs.hashCode())" + print >> out, " throw new RuntimeException(\"Failure\");" + + print >> out, r''' + } +} +''' +*/ + +package org.apache.thrift.test; + +// Generated code +import thrift.test.*; + +/** + */ +public class EqualityTest { + public static void main(String[] args) throws Exception { + JavaTestHelper lhs, rhs; + + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_int = 1; + rhs.req_int = 1; + // this_present = True + // that_present = True + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_int = 1; + rhs.req_int = 2; + // this_present = True + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_int = 1; + rhs.req_int = 1; + rhs.setReq_intIsSet(true); + // this_present = True + // that_present = True + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_int = 1; + rhs.req_int = 2; + rhs.setReq_intIsSet(true); + // this_present = True + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_int = 1; + rhs.req_int = 1; + lhs.setReq_intIsSet(true); + // this_present = True + // that_present = True + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_int = 1; + rhs.req_int = 2; + lhs.setReq_intIsSet(true); + // this_present = True + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_int = 1; + rhs.req_int = 1; + lhs.setReq_intIsSet(true); + rhs.setReq_intIsSet(true); + // this_present = True + // that_present = True + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_int = 1; + rhs.req_int = 2; + lhs.setReq_intIsSet(true); + rhs.setReq_intIsSet(true); + // this_present = True + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_int = 1; + rhs.opt_int = 1; + // this_present = False + // that_present = False + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_int = 1; + rhs.opt_int = 2; + // this_present = False + // that_present = False + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_int = 1; + rhs.opt_int = 1; + rhs.setOpt_intIsSet(true); + // this_present = False + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_int = 1; + rhs.opt_int = 2; + rhs.setOpt_intIsSet(true); + // this_present = False + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_int = 1; + rhs.opt_int = 1; + lhs.setOpt_intIsSet(true); + // this_present = True + // that_present = False + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_int = 1; + rhs.opt_int = 2; + lhs.setOpt_intIsSet(true); + // this_present = True + // that_present = False + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_int = 1; + rhs.opt_int = 1; + lhs.setOpt_intIsSet(true); + rhs.setOpt_intIsSet(true); + // this_present = True + // that_present = True + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_int = 1; + rhs.opt_int = 2; + lhs.setOpt_intIsSet(true); + rhs.setOpt_intIsSet(true); + // this_present = True + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_obj = "foo"; + rhs.req_obj = "foo"; + lhs.req_obj = null; + rhs.req_obj = null; + // this_present = False + // that_present = False + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_obj = "foo"; + rhs.req_obj = "bar"; + lhs.req_obj = null; + rhs.req_obj = null; + // this_present = False + // that_present = False + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_obj = "foo"; + rhs.req_obj = "foo"; + lhs.req_obj = null; + // this_present = False + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_obj = "foo"; + rhs.req_obj = "bar"; + lhs.req_obj = null; + // this_present = False + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_obj = "foo"; + rhs.req_obj = "foo"; + rhs.req_obj = null; + // this_present = True + // that_present = False + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_obj = "foo"; + rhs.req_obj = "bar"; + rhs.req_obj = null; + // this_present = True + // that_present = False + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_obj = "foo"; + rhs.req_obj = "foo"; + // this_present = True + // that_present = True + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_obj = "foo"; + rhs.req_obj = "bar"; + // this_present = True + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_obj = "foo"; + rhs.opt_obj = "foo"; + lhs.opt_obj = null; + rhs.opt_obj = null; + // this_present = False + // that_present = False + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_obj = "foo"; + rhs.opt_obj = "bar"; + lhs.opt_obj = null; + rhs.opt_obj = null; + // this_present = False + // that_present = False + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_obj = "foo"; + rhs.opt_obj = "foo"; + lhs.opt_obj = null; + // this_present = False + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_obj = "foo"; + rhs.opt_obj = "bar"; + lhs.opt_obj = null; + // this_present = False + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_obj = "foo"; + rhs.opt_obj = "foo"; + rhs.opt_obj = null; + // this_present = True + // that_present = False + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_obj = "foo"; + rhs.opt_obj = "bar"; + rhs.opt_obj = null; + // this_present = True + // that_present = False + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_obj = "foo"; + rhs.opt_obj = "foo"; + // this_present = True + // that_present = True + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_obj = "foo"; + rhs.opt_obj = "bar"; + // this_present = True + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_bin = new byte[]{1,2}; + rhs.req_bin = new byte[]{1,2}; + lhs.req_bin = null; + rhs.req_bin = null; + // this_present = False + // that_present = False + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_bin = new byte[]{1,2}; + rhs.req_bin = new byte[]{3,4}; + lhs.req_bin = null; + rhs.req_bin = null; + // this_present = False + // that_present = False + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_bin = new byte[]{1,2}; + rhs.req_bin = new byte[]{1,2}; + lhs.req_bin = null; + // this_present = False + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_bin = new byte[]{1,2}; + rhs.req_bin = new byte[]{3,4}; + lhs.req_bin = null; + // this_present = False + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_bin = new byte[]{1,2}; + rhs.req_bin = new byte[]{1,2}; + rhs.req_bin = null; + // this_present = True + // that_present = False + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_bin = new byte[]{1,2}; + rhs.req_bin = new byte[]{3,4}; + rhs.req_bin = null; + // this_present = True + // that_present = False + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_bin = new byte[]{1,2}; + rhs.req_bin = new byte[]{1,2}; + // this_present = True + // that_present = True + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.req_bin = new byte[]{1,2}; + rhs.req_bin = new byte[]{3,4}; + // this_present = True + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_bin = new byte[]{1,2}; + rhs.opt_bin = new byte[]{1,2}; + lhs.opt_bin = null; + rhs.opt_bin = null; + // this_present = False + // that_present = False + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_bin = new byte[]{1,2}; + rhs.opt_bin = new byte[]{3,4}; + lhs.opt_bin = null; + rhs.opt_bin = null; + // this_present = False + // that_present = False + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_bin = new byte[]{1,2}; + rhs.opt_bin = new byte[]{1,2}; + lhs.opt_bin = null; + // this_present = False + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_bin = new byte[]{1,2}; + rhs.opt_bin = new byte[]{3,4}; + lhs.opt_bin = null; + // this_present = False + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_bin = new byte[]{1,2}; + rhs.opt_bin = new byte[]{1,2}; + rhs.opt_bin = null; + // this_present = True + // that_present = False + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_bin = new byte[]{1,2}; + rhs.opt_bin = new byte[]{3,4}; + rhs.opt_bin = null; + // this_present = True + // that_present = False + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_bin = new byte[]{1,2}; + rhs.opt_bin = new byte[]{1,2}; + // this_present = True + // that_present = True + if (lhs.equals(rhs) != true) + throw new RuntimeException("Failure"); + if (lhs.hashCode() != rhs.hashCode()) + throw new RuntimeException("Failure"); + + lhs = new JavaTestHelper(); + rhs = new JavaTestHelper(); + lhs.opt_bin = new byte[]{1,2}; + rhs.opt_bin = new byte[]{3,4}; + // this_present = True + // that_present = True + if (lhs.equals(rhs) != false) + throw new RuntimeException("Failure"); + + } +} + diff --git a/lib/java/test/org/apache/thrift/test/Fixtures.java b/lib/java/test/org/apache/thrift/test/Fixtures.java new file mode 100644 index 00000000..14ac44f7 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/Fixtures.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.test; + +import java.util.*; +import thrift.test.*; + +public class Fixtures { + + private static final byte[] kUnicodeBytes = { + (byte)0xd3, (byte)0x80, (byte)0xe2, (byte)0x85, (byte)0xae, (byte)0xce, + (byte)0x9d, (byte)0x20, (byte)0xd0, (byte)0x9d, (byte)0xce, (byte)0xbf, + (byte)0xe2, (byte)0x85, (byte)0xbf, (byte)0xd0, (byte)0xbe, (byte)0xc9, + (byte)0xa1, (byte)0xd0, (byte)0xb3, (byte)0xd0, (byte)0xb0, (byte)0xcf, + (byte)0x81, (byte)0xe2, (byte)0x84, (byte)0x8e, (byte)0x20, (byte)0xce, + (byte)0x91, (byte)0x74, (byte)0x74, (byte)0xce, (byte)0xb1, (byte)0xe2, + (byte)0x85, (byte)0xbd, (byte)0xce, (byte)0xba, (byte)0x83, (byte)0xe2, + (byte)0x80, (byte)0xbc + }; + + + public static final OneOfEach oneOfEach; + public static final Nesting nesting; + public static final HolyMoley holyMoley; + public static final CompactProtoTestStruct compactProtoTestStruct; + + static { + try { + oneOfEach = new OneOfEach(); + oneOfEach.im_true = true; + oneOfEach.im_false = false; + oneOfEach.a_bite = (byte) 0x03; + oneOfEach.integer16 = 27000; + oneOfEach.integer32 = 1 << 24; + oneOfEach.integer64 = (long) 6000 * 1000 * 1000; + oneOfEach.double_precision = Math.PI; + oneOfEach.some_characters = "JSON THIS! \"\1"; + oneOfEach.zomg_unicode = new String(kUnicodeBytes, "UTF-8"); + + nesting = new Nesting(new Bonk(), new OneOfEach()); + nesting.my_ooe.integer16 = 16; + nesting.my_ooe.integer32 = 32; + nesting.my_ooe.integer64 = 64; + nesting.my_ooe.double_precision = (Math.sqrt(5) + 1) / 2; + nesting.my_ooe.some_characters = ":R (me going \"rrrr\")"; + nesting.my_ooe.zomg_unicode = new String(kUnicodeBytes, "UTF-8"); + nesting.my_bonk.type = 31337; + nesting.my_bonk.message = "I am a bonk... xor!"; + + holyMoley = new HolyMoley(); + + holyMoley.big = new ArrayList(); + holyMoley.big.add(new OneOfEach(oneOfEach)); + holyMoley.big.add(nesting.my_ooe); + holyMoley.big.get(0).a_bite = (byte) 0x22; + holyMoley.big.get(1).a_bite = (byte) 0x23; + + holyMoley.contain = new HashSet>(); + ArrayList stage1 = new ArrayList(2); + stage1.add("and a one"); + stage1.add("and a two"); + holyMoley.contain.add(stage1); + stage1 = new ArrayList(3); + stage1.add("then a one, two"); + stage1.add("three!"); + stage1.add("FOUR!!"); + holyMoley.contain.add(stage1); + stage1 = new ArrayList(0); + holyMoley.contain.add(stage1); + + ArrayList stage2 = new ArrayList(); + holyMoley.bonks = new HashMap>(); + // one empty + holyMoley.bonks.put("nothing", stage2); + + // one with two + stage2 = new ArrayList(); + Bonk b = new Bonk(); + b.type = 1; + b.message = "Wait."; + stage2.add(b); + b = new Bonk(); + b.type = 2; + b.message = "What?"; + stage2.add(b); + holyMoley.bonks.put("something", stage2); + + // one with three + stage2 = new ArrayList(); + b = new Bonk(); + b.type = 3; + b.message = "quoth"; + b = new Bonk(); + b.type = 4; + b.message = "the raven"; + b = new Bonk(); + b.type = 5; + b.message = "nevermore"; + holyMoley.bonks.put("poe", stage2); + + // superhuge compact proto test struct + compactProtoTestStruct = new CompactProtoTestStruct(thrift.test.Constants.COMPACT_TEST); + compactProtoTestStruct.a_binary = new byte[]{0,1,2,3,4,5,6,7,8}; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + +} diff --git a/lib/java/test/org/apache/thrift/test/IdentityTest.java b/lib/java/test/org/apache/thrift/test/IdentityTest.java new file mode 100644 index 00000000..c6453ce7 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/IdentityTest.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.test; + +// Generated code +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; + +import org.apache.thrift.TDeserializer; +import org.apache.thrift.TSerializer; +import org.apache.thrift.protocol.TBinaryProtocol; + +import thrift.test.Bonk; +import thrift.test.HolyMoley; +import thrift.test.Nesting; +import thrift.test.OneOfEach; + +/** + * + */ +public class IdentityTest { + public static Object deepCopy(Object oldObj) throws Exception { + ObjectOutputStream oos = null; + ObjectInputStream ois = null; + try { + ByteArrayOutputStream bos = + new ByteArrayOutputStream(); + oos = new ObjectOutputStream(bos); + oos.writeObject(oldObj); + oos.flush(); + ByteArrayInputStream bis = + new ByteArrayInputStream(bos.toByteArray()); + ois = new ObjectInputStream(bis); + return ois.readObject(); + } finally { + oos.close(); + ois.close(); + } + } + + public static void main(String[] args) throws Exception { + TSerializer binarySerializer = new TSerializer(new TBinaryProtocol.Factory()); + TDeserializer binaryDeserializer = new TDeserializer(new TBinaryProtocol.Factory()); + + OneOfEach ooe = new OneOfEach(); + ooe.im_true = true; + ooe.im_false = false; + ooe.a_bite = (byte)0xd6; + ooe.integer16 = 27000; + ooe.integer32 = 1<<24; + ooe.integer64 = (long)6000 * 1000 * 1000; + ooe.double_precision = Math.PI; + ooe.some_characters = "JSON THIS! \"\u0001"; + ooe.base64 = new byte[]{1,2,3,(byte)255}; + + Nesting n = new Nesting(); + n.my_ooe = (OneOfEach)deepCopy(ooe); + n.my_ooe.integer16 = 16; + n.my_ooe.integer32 = 32; + n.my_ooe.integer64 = 64; + n.my_ooe.double_precision = (Math.sqrt(5)+1)/2; + n.my_ooe.some_characters = ":R (me going \"rrrr\")"; + n.my_ooe.zomg_unicode = "\u04c0\u216e\u039d\u0020\u041d\u03bf\u217f"+ + "\u043e\u0261\u0433\u0430\u03c1\u210e\u0020"+ + "\u0391\u0074\u0074\u03b1\u217d\u03ba\u01c3"+ + "\u203c"; + n.my_bonk = new Bonk(); + n.my_bonk.type = 31337; + n.my_bonk.message = "I am a bonk... xor!"; + + HolyMoley hm = new HolyMoley(); + hm.big = new ArrayList(); + hm.contain = new HashSet>(); + hm.bonks = new HashMap>(); + + hm.big.add((OneOfEach)deepCopy(ooe)); + hm.big.add((OneOfEach)deepCopy(n.my_ooe)); + hm.big.get(0).a_bite = 0x22; + hm.big.get(1).a_bite = 0x33; + + List stage1 = new ArrayList(); + stage1.add("and a one"); + stage1.add("and a two"); + hm.contain.add(stage1); + stage1 = new ArrayList(); + stage1.add("then a one, two"); + stage1.add("three!"); + stage1.add("FOUR!!"); + hm.contain.add(stage1); + stage1 = new ArrayList(); + hm.contain.add(stage1); + + List stage2 = new ArrayList(); + hm.bonks.put("nothing", stage2); + stage2.add(new Bonk()); + stage2.get(0).type = 1; + stage2.get(0).message = "Wait."; + stage2.add(new Bonk()); + stage2.get(1).type = 2; + stage2.get(1).message = "What?"; + hm.bonks.put("something", stage2); + stage2 = new ArrayList(); + stage2.add(new Bonk()); + stage2.get(0).type = 3; + stage2.get(0).message = "quoth"; + stage2.add(new Bonk()); + stage2.get(1).type = 4; + stage2.get(1).message = "the raven"; + stage2.add(new Bonk()); + stage2.get(2).type = 5; + stage2.get(2).message = "nevermore"; + hm.bonks.put("poe", stage2); + + OneOfEach ooe2 = new OneOfEach(); + binaryDeserializer.deserialize( + ooe2, + binarySerializer.serialize(ooe)); + + if (!ooe.equals(ooe2)) { + throw new RuntimeException("Failure: ooe (equals)"); + } + if (ooe.hashCode() != ooe2.hashCode()) { + throw new RuntimeException("Failure: ooe (hash)"); + } + + + Nesting n2 = new Nesting(); + binaryDeserializer.deserialize( + n2, + binarySerializer.serialize(n)); + + if (!n.equals(n2)) { + throw new RuntimeException("Failure: n (equals)"); + } + if (n.hashCode() != n2.hashCode()) { + throw new RuntimeException("Failure: n (hash)"); + } + + HolyMoley hm2 = new HolyMoley(); + binaryDeserializer.deserialize( + hm2, + binarySerializer.serialize(hm)); + + if (!hm.equals(hm2)) { + throw new RuntimeException("Failure: hm (equals)"); + } + if (hm.hashCode() != hm2.hashCode()) { + throw new RuntimeException("Failure: hm (hash)"); + } + + } +} diff --git a/lib/java/test/org/apache/thrift/test/JSONProtoTest.java b/lib/java/test/org/apache/thrift/test/JSONProtoTest.java new file mode 100644 index 00000000..59f4ce18 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/JSONProtoTest.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.test; + +// Generated code +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; + +import org.apache.thrift.protocol.TJSONProtocol; +import org.apache.thrift.transport.TMemoryBuffer; + +import thrift.test.Base64; +import thrift.test.Bonk; +import thrift.test.HolyMoley; +import thrift.test.Nesting; +import thrift.test.OneOfEach; + +/** + * Tests for the Java implementation of TJSONProtocol. Mirrors the C++ version + * + */ +public class JSONProtoTest { + + private static final byte[] kUnicodeBytes = { + (byte)0xd3, (byte)0x80, (byte)0xe2, (byte)0x85, (byte)0xae, (byte)0xce, + (byte)0x9d, (byte)0x20, (byte)0xd0, (byte)0x9d, (byte)0xce, (byte)0xbf, + (byte)0xe2, (byte)0x85, (byte)0xbf, (byte)0xd0, (byte)0xbe, (byte)0xc9, + (byte)0xa1, (byte)0xd0, (byte)0xb3, (byte)0xd0, (byte)0xb0, (byte)0xcf, + (byte)0x81, (byte)0xe2, (byte)0x84, (byte)0x8e, (byte)0x20, (byte)0xce, + (byte)0x91, (byte)0x74, (byte)0x74, (byte)0xce, (byte)0xb1, (byte)0xe2, + (byte)0x85, (byte)0xbd, (byte)0xce, (byte)0xba, (byte)0x83, (byte)0xe2, + (byte)0x80, (byte)0xbc + }; + + public static void main(String [] args) throws Exception { + try { + System.out.println("In JSON Proto test"); + + OneOfEach ooe = new OneOfEach(); + ooe.im_true = true; + ooe.im_false = false; + ooe.a_bite = (byte)0xd6; + ooe.integer16 = 27000; + ooe.integer32 = 1<<24; + ooe.integer64 = (long)6000 * 1000 * 1000; + ooe.double_precision = Math.PI; + ooe.some_characters = "JSON THIS! \"\1"; + ooe.zomg_unicode = new String(kUnicodeBytes, "UTF-8"); + + + Nesting n = new Nesting(new Bonk(), new OneOfEach()); + n.my_ooe.integer16 = 16; + n.my_ooe.integer32 = 32; + n.my_ooe.integer64 = 64; + n.my_ooe.double_precision = (Math.sqrt(5)+1)/2; + n.my_ooe.some_characters = ":R (me going \"rrrr\")"; + n.my_ooe.zomg_unicode = new String(kUnicodeBytes, "UTF-8"); + n.my_bonk.type = 31337; + n.my_bonk.message = "I am a bonk... xor!"; + + HolyMoley hm = new HolyMoley(); + + hm.big = new ArrayList(); + hm.big.add(ooe); + hm.big.add(n.my_ooe); + hm.big.get(0).a_bite = (byte)0x22; + hm.big.get(1).a_bite = (byte)0x23; + + hm.contain = new HashSet>(); + ArrayList stage1 = new ArrayList(2); + stage1.add("and a one"); + stage1.add("and a two"); + hm.contain.add(stage1); + stage1 = new ArrayList(3); + stage1.add("then a one, two"); + stage1.add("three!"); + stage1.add("FOUR!!"); + hm.contain.add(stage1); + stage1 = new ArrayList(0); + hm.contain.add(stage1); + + ArrayList stage2 = new ArrayList(); + hm.bonks = new HashMap>(); + hm.bonks.put("nothing", stage2); + Bonk b = new Bonk(); + b.type = 1; + b.message = "Wait."; + stage2.add(b); + b = new Bonk(); + b.type = 2; + b.message = "What?"; + stage2.add(b); + stage2 = new ArrayList(); + hm.bonks.put("something", stage2); + b = new Bonk(); + b.type = 3; + b.message = "quoth"; + b = new Bonk(); + b.type = 4; + b.message = "the raven"; + b = new Bonk(); + b.type = 5; + b.message = "nevermore"; + hm.bonks.put("poe", stage2); + + TMemoryBuffer buffer = new TMemoryBuffer(1024); + TJSONProtocol proto = new TJSONProtocol(buffer); + + System.out.println("Writing ooe"); + ooe.write(proto); + System.out.println("Reading ooe"); + OneOfEach ooe2 = new OneOfEach(); + ooe2.read(proto); + + System.out.println("Comparing ooe"); + if (!ooe.equals(ooe2)) { + throw new RuntimeException("ooe != ooe2"); + } + + System.out.println("Writing hm"); + hm.write(proto); + + System.out.println("Reading hm"); + HolyMoley hm2 = new HolyMoley(); + hm2.read(proto); + + System.out.println("Comparing hm"); + if (!hm.equals(hm2)) { + throw new RuntimeException("hm != hm2"); + } + + hm2.big.get(0).a_bite = (byte)0xFF; + if (hm.equals(hm2)) { + throw new RuntimeException("hm should not equal hm2"); + } + + Base64 base = new Base64(); + base.a = 123; + base.b1 = "1".getBytes("UTF-8"); + base.b2 = "12".getBytes("UTF-8"); + base.b3 = "123".getBytes("UTF-8"); + base.b4 = "1234".getBytes("UTF-8"); + base.b5 = "12345".getBytes("UTF-8"); + base.b6 = "123456".getBytes("UTF-8"); + + System.out.println("Writing base"); + base.write(proto); + + System.out.println("Reading base"); + Base64 base2 = new Base64(); + base2.read(proto); + + System.out.println("Comparing base"); + if (!base.equals(base2)) { + throw new RuntimeException("base != base2"); + } + + } catch (Exception ex) { + ex.printStackTrace(); + throw ex; + } + } + +} diff --git a/lib/java/test/org/apache/thrift/test/JavaBeansTest.java b/lib/java/test/org/apache/thrift/test/JavaBeansTest.java new file mode 100644 index 00000000..b72bd388 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/JavaBeansTest.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.test; + +import java.util.LinkedList; +import thrift.test.OneOfEachBeans; + +public class JavaBeansTest { + public static void main(String[] args) throws Exception { + // Test isSet methods + OneOfEachBeans ooe = new OneOfEachBeans(); + + // Nothing should be set + if (ooe.is_set_a_bite()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + if (ooe.is_set_base64()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + if (ooe.is_set_byte_list()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + if (ooe.is_set_double_precision()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + if (ooe.is_set_i16_list()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + if (ooe.is_set_i64_list()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + if (ooe.is_set_boolean_field()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + if (ooe.is_set_integer16()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + if (ooe.is_set_integer32()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + if (ooe.is_set_integer64()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + if (ooe.is_set_some_characters()) + throw new RuntimeException("isSet method error: unset field returned as set!"); + + for (int i = 1; i < 12; i++){ + if (ooe.isSet(i)) + throw new RuntimeException("isSet method error: unset field " + i + " returned as set!"); + } + + // Everything is set + ooe.set_a_bite((byte) 1); + ooe.set_base64("bytes".getBytes()); + ooe.set_byte_list(new LinkedList()); + ooe.set_double_precision(1); + ooe.set_i16_list(new LinkedList()); + ooe.set_i64_list(new LinkedList()); + ooe.set_boolean_field(true); + ooe.set_integer16((short) 1); + ooe.set_integer32(1); + ooe.set_integer64(1); + ooe.set_some_characters("string"); + + if (!ooe.is_set_a_bite()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + if (!ooe.is_set_base64()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + if (!ooe.is_set_byte_list()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + if (!ooe.is_set_double_precision()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + if (!ooe.is_set_i16_list()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + if (!ooe.is_set_i64_list()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + if (!ooe.is_set_boolean_field()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + if (!ooe.is_set_integer16()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + if (!ooe.is_set_integer32()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + if (!ooe.is_set_integer64()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + if (!ooe.is_set_some_characters()) + throw new RuntimeException("isSet method error: set field returned as unset!"); + + for (int i = 1; i < 12; i++){ + if (!ooe.isSet(i)) + throw new RuntimeException("isSet method error: set field " + i + " returned as unset!"); + } + + // Should throw exception when field doesn't exist + boolean exceptionThrown = false; + try{ + if (ooe.isSet(100)); + } catch (IllegalArgumentException e){ + exceptionThrown = true; + } + if (!exceptionThrown) + throw new RuntimeException("isSet method error: non-existent field provided as agument but no exception thrown!"); + } +} diff --git a/lib/java/test/org/apache/thrift/test/MetaDataTest.java b/lib/java/test/org/apache/thrift/test/MetaDataTest.java new file mode 100644 index 00000000..a0180348 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/MetaDataTest.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.test; + +import java.util.Map; +import org.apache.thrift.TFieldRequirementType; +import org.apache.thrift.meta_data.FieldMetaData; +import org.apache.thrift.meta_data.ListMetaData; +import org.apache.thrift.meta_data.MapMetaData; +import org.apache.thrift.meta_data.SetMetaData; +import org.apache.thrift.meta_data.StructMetaData; +import org.apache.thrift.protocol.TType; +import thrift.test.*; + +public class MetaDataTest { + + public static void main(String[] args) throws Exception { + Map mdMap = CrazyNesting.metaDataMap; + + // Check for struct fields existence + if (mdMap.size() != 3) + throw new RuntimeException("metadata map contains wrong number of entries!"); + if (!mdMap.containsKey(CrazyNesting.SET_FIELD) || !mdMap.containsKey(CrazyNesting.LIST_FIELD) || !mdMap.containsKey(CrazyNesting.STRING_FIELD)) + throw new RuntimeException("metadata map doesn't contain entry for a struct field!"); + + // Check for struct fields contents + if (!mdMap.get(CrazyNesting.STRING_FIELD).fieldName.equals("string_field") || + !mdMap.get(CrazyNesting.LIST_FIELD).fieldName.equals("list_field") || + !mdMap.get(CrazyNesting.SET_FIELD).fieldName.equals("set_field")) + throw new RuntimeException("metadata map contains a wrong fieldname"); + if (mdMap.get(CrazyNesting.STRING_FIELD).requirementType != TFieldRequirementType.DEFAULT || + mdMap.get(CrazyNesting.LIST_FIELD).requirementType != TFieldRequirementType.REQUIRED || + mdMap.get(CrazyNesting.SET_FIELD).requirementType != TFieldRequirementType.OPTIONAL) + throw new RuntimeException("metadata map contains the wrong requirement type for a field"); + if (mdMap.get(CrazyNesting.STRING_FIELD).valueMetaData.type != TType.STRING || + mdMap.get(CrazyNesting.LIST_FIELD).valueMetaData.type != TType.LIST || + mdMap.get(CrazyNesting.SET_FIELD).valueMetaData.type != TType.SET) + throw new RuntimeException("metadata map contains the wrong requirement type for a field"); + + // Check nested structures + if (!mdMap.get(CrazyNesting.LIST_FIELD).valueMetaData.isContainer()) + throw new RuntimeException("value metadata for a list is stored as non-container!"); + if (mdMap.get(CrazyNesting.LIST_FIELD).valueMetaData.isStruct()) + throw new RuntimeException("value metadata for a list is stored as a struct!"); + if (((MapMetaData)((ListMetaData)((SetMetaData)((MapMetaData)((MapMetaData)((ListMetaData)mdMap.get(CrazyNesting.LIST_FIELD).valueMetaData).elemMetaData).valueMetaData).valueMetaData).elemMetaData).elemMetaData).keyMetaData.type != TType.STRUCT) + throw new RuntimeException("metadata map contains wrong type for a value in a deeply nested structure"); + if (((StructMetaData)((MapMetaData)((ListMetaData)((SetMetaData)((MapMetaData)((MapMetaData)((ListMetaData)mdMap.get(CrazyNesting.LIST_FIELD).valueMetaData).elemMetaData).valueMetaData).valueMetaData).elemMetaData).elemMetaData).keyMetaData).structClass != Insanity.class) + throw new RuntimeException("metadata map contains wrong class for a struct in a deeply nested structure"); + + // Check that FieldMetaData contains a map with metadata for all generated struct classes + if (FieldMetaData.getStructMetaDataMap(CrazyNesting.class) == null || + FieldMetaData.getStructMetaDataMap(Insanity.class) == null || + FieldMetaData.getStructMetaDataMap(Xtruct.class) == null) + throw new RuntimeException("global metadata map doesn't contain an entry for a known struct"); + if (FieldMetaData.getStructMetaDataMap(CrazyNesting.class) != CrazyNesting.metaDataMap || + FieldMetaData.getStructMetaDataMap(Insanity.class) != Insanity.metaDataMap) + throw new RuntimeException("global metadata map contains wrong entry for a loaded struct"); + } +} diff --git a/lib/java/test/org/apache/thrift/test/OverloadNonblockingServer.java b/lib/java/test/org/apache/thrift/test/OverloadNonblockingServer.java new file mode 100644 index 00000000..54d78e56 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/OverloadNonblockingServer.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.test; + +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.transport.TSocket; + + +public class OverloadNonblockingServer { + + public static void main(String[] args) throws Exception { + int msg_size_mb = Integer.parseInt(args[0]); + int msg_size = msg_size_mb * 1024 * 1024; + + TSocket socket = new TSocket("localhost", 9090); + TBinaryProtocol binprot = new TBinaryProtocol(socket); + socket.open(); + binprot.writeI32(msg_size); + binprot.writeI32(1); + socket.flush(); + + System.in.read(); + // Thread.sleep(30000); + for (int i = 0; i < msg_size_mb; i++) { + binprot.writeBinary(new byte[1024 * 1024]); + } + + socket.close(); + } +} diff --git a/lib/java/test/org/apache/thrift/test/ReadStruct.java b/lib/java/test/org/apache/thrift/test/ReadStruct.java new file mode 100644 index 00000000..2dc042c5 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/ReadStruct.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.test; + +import java.io.BufferedInputStream; +import java.io.FileInputStream; + +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TIOStreamTransport; +import org.apache.thrift.transport.TTransport; + +import thrift.test.CompactProtoTestStruct; + +public class ReadStruct { + public static void main(String[] args) throws Exception { + if (args.length != 2) { + System.out.println("usage: java -cp build/classes org.apache.thrift.test.ReadStruct filename proto_factory_class"); + System.out.println("Read in an instance of CompactProtocolTestStruct from 'file', making sure that it is equivalent to Fixtures.compactProtoTestStruct. Use a protocol from 'proto_factory_class'."); + } + + TTransport trans = new TIOStreamTransport(new BufferedInputStream(new FileInputStream(args[0]))); + + TProtocolFactory factory = (TProtocolFactory)Class.forName(args[1]).newInstance(); + + TProtocol proto = factory.getProtocol(trans); + + CompactProtoTestStruct cpts = new CompactProtoTestStruct(); + + for (Integer fid : CompactProtoTestStruct.metaDataMap.keySet()) { + cpts.setFieldValue(fid, null); + } + + cpts.read(proto); + + if (cpts.equals(Fixtures.compactProtoTestStruct)) { + System.out.println("Object verified successfully!"); + } else { + System.out.println("Object failed verification!"); + System.out.println("Expected: " + Fixtures.compactProtoTestStruct + " but got " + cpts); + } + + } + +} diff --git a/lib/java/test/org/apache/thrift/test/SerializationBenchmark.java b/lib/java/test/org/apache/thrift/test/SerializationBenchmark.java new file mode 100644 index 00000000..b83b2f9b --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/SerializationBenchmark.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.test; + +import java.io.ByteArrayInputStream; + +import org.apache.thrift.*; +import org.apache.thrift.protocol.*; +import org.apache.thrift.transport.*; + +import thrift.test.*; + +public class SerializationBenchmark { + private final static int HOW_MANY = 10000000; + + public static void main(String[] args) throws Exception { + TProtocolFactory factory = new TBinaryProtocol.Factory(); + + OneOfEach ooe = new OneOfEach(); + ooe.im_true = true; + ooe.im_false = false; + ooe.a_bite = (byte)0xd6; + ooe.integer16 = 27000; + ooe.integer32 = 1<<24; + ooe.integer64 = (long)6000 * 1000 * 1000; + ooe.double_precision = Math.PI; + ooe.some_characters = "JSON THIS! \"\u0001"; + ooe.base64 = new byte[]{1,2,3,(byte)255}; + + testSerialization(factory, ooe); + testDeserialization(factory, ooe, OneOfEach.class); + } + + public static void testSerialization(TProtocolFactory factory, TBase object) throws Exception { + TTransport trans = new TTransport() { + public void write(byte[] bin, int x, int y) throws TTransportException {} + public int read(byte[] bin, int x, int y) throws TTransportException {return 0;} + public void close() {} + public void open() {} + public boolean isOpen() {return true;} + }; + + TProtocol proto = factory.getProtocol(trans); + + long startTime = System.currentTimeMillis(); + for (int i = 0; i < HOW_MANY; i++) { + object.write(proto); + } + long endTime = System.currentTimeMillis(); + + System.out.println("Test time: " + (endTime - startTime) + " ms"); + } + + public static void testDeserialization(TProtocolFactory factory, T object, Class klass) throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(0); + object.write(factory.getProtocol(buf)); + byte[] serialized = new byte[100*1024]; + buf.read(serialized, 0, 100*1024); + + long startTime = System.currentTimeMillis(); + for (int i = 0; i < HOW_MANY; i++) { + T o2 = klass.newInstance(); + o2.read(factory.getProtocol(new TIOStreamTransport(new ByteArrayInputStream(serialized)))); + } + long endTime = System.currentTimeMillis(); + + System.out.println("Test time: " + (endTime - startTime) + " ms"); + } +} \ No newline at end of file diff --git a/lib/java/test/org/apache/thrift/test/TCompactProtocolTest.java b/lib/java/test/org/apache/thrift/test/TCompactProtocolTest.java new file mode 100755 index 00000000..5d1c3cba --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/TCompactProtocolTest.java @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +package org.apache.thrift.test; + +import java.util.Arrays; +import java.util.List; + +import org.apache.thrift.TBase; +import org.apache.thrift.TException; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TCompactProtocol; +import org.apache.thrift.protocol.TField; +import org.apache.thrift.protocol.TMessage; +import org.apache.thrift.protocol.TMessageType; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.protocol.TStruct; +import org.apache.thrift.protocol.TType; +import org.apache.thrift.transport.TMemoryBuffer; + +import thrift.test.CompactProtoTestStruct; +import thrift.test.HolyMoley; +import thrift.test.Nesting; +import thrift.test.OneOfEach; +import thrift.test.Srv; + +public class TCompactProtocolTest { + + static TProtocolFactory factory = new TCompactProtocol.Factory(); + + public static void main(String[] args) throws Exception { + testNakedByte(); + for (int i = 0; i < 128; i++) { + testByteField((byte)i); + testByteField((byte)-i); + } + + testNakedI16((short)0); + testNakedI16((short)1); + testNakedI16((short)15000); + testNakedI16((short)0x7fff); + testNakedI16((short)-1); + testNakedI16((short)-15000); + testNakedI16((short)-0x7fff); + + testI16Field((short)0); + testI16Field((short)1); + testI16Field((short)7); + testI16Field((short)150); + testI16Field((short)15000); + testI16Field((short)0x7fff); + testI16Field((short)-1); + testI16Field((short)-7); + testI16Field((short)-150); + testI16Field((short)-15000); + testI16Field((short)-0x7fff); + + testNakedI32(0); + testNakedI32(1); + testNakedI32(15000); + testNakedI32(0xffff); + testNakedI32(-1); + testNakedI32(-15000); + testNakedI32(-0xffff); + + testI32Field(0); + testI32Field(1); + testI32Field(7); + testI32Field(150); + testI32Field(15000); + testI32Field(31337); + testI32Field(0xffff); + testI32Field(0xffffff); + testI32Field(-1); + testI32Field(-7); + testI32Field(-150); + testI32Field(-15000); + testI32Field(-0xffff); + testI32Field(-0xffffff); + + testNakedI64(0); + for (int i = 0; i < 62; i++) { + testNakedI64(1L << i); + testNakedI64(-(1L << i)); + } + + testI64Field(0); + for (int i = 0; i < 62; i++) { + testI64Field(1L << i); + testI64Field(-(1L << i)); + } + + testDouble(); + + testNakedString(""); + testNakedString("short"); + testNakedString("borderlinetiny"); + testNakedString("a bit longer than the smallest possible"); + + testStringField(""); + testStringField("short"); + testStringField("borderlinetiny"); + testStringField("a bit longer than the smallest possible"); + + testNakedBinary(new byte[]{}); + testNakedBinary(new byte[]{0,1,2,3,4,5,6,7,8,9,10}); + testNakedBinary(new byte[]{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14}); + testNakedBinary(new byte[128]); + + testBinaryField(new byte[]{}); + testBinaryField(new byte[]{0,1,2,3,4,5,6,7,8,9,10}); + testBinaryField(new byte[]{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14}); + testBinaryField(new byte[128]); + + testSerialization(OneOfEach.class, Fixtures.oneOfEach); + testSerialization(Nesting.class, Fixtures.nesting); + testSerialization(HolyMoley.class, Fixtures.holyMoley); + testSerialization(CompactProtoTestStruct.class, Fixtures.compactProtoTestStruct); + + testMessage(); + + testServerRequest(); + } + + public static void testNakedByte() throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(0); + TProtocol proto = factory.getProtocol(buf); + proto.writeByte((byte)123); + byte out = proto.readByte(); + if (out != 123) { + throw new RuntimeException("Byte was supposed to be " + (byte)123 + " but was " + out); + } + } + + public static void testByteField(final byte b) throws Exception { + testStructField(new StructFieldTestCase(TType.BYTE, (short)15) { + public void writeMethod(TProtocol proto) throws TException { + proto.writeByte(b); + } + + public void readMethod(TProtocol proto) throws TException { + byte result = proto.readByte(); + if (result != b) { + throw new RuntimeException("Byte was supposed to be " + (byte)b + " but was " + result); + } + } + }); + } + + public static void testNakedI16(short n) throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(0); + TProtocol proto = factory.getProtocol(buf); + proto.writeI16(n); + // System.out.println(buf.inspect()); + int out = proto.readI16(); + if (out != n) { + throw new RuntimeException("I16 was supposed to be " + n + " but was " + out); + } + } + + public static void testI16Field(final short n) throws Exception { + testStructField(new StructFieldTestCase(TType.I16, (short)15) { + public void writeMethod(TProtocol proto) throws TException { + proto.writeI16(n); + } + + public void readMethod(TProtocol proto) throws TException { + short result = proto.readI16(); + if (result != n) { + throw new RuntimeException("I16 was supposed to be " + n + " but was " + result); + } + } + }); + } + + public static void testNakedI32(int n) throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(0); + TProtocol proto = factory.getProtocol(buf); + proto.writeI32(n); + // System.out.println(buf.inspect()); + int out = proto.readI32(); + if (out != n) { + throw new RuntimeException("I32 was supposed to be " + n + " but was " + out); + } + } + + public static void testI32Field(final int n) throws Exception { + testStructField(new StructFieldTestCase(TType.I32, (short)15) { + public void writeMethod(TProtocol proto) throws TException { + proto.writeI32(n); + } + + public void readMethod(TProtocol proto) throws TException { + int result = proto.readI32(); + if (result != n) { + throw new RuntimeException("I32 was supposed to be " + n + " but was " + result); + } + } + }); + + } + + public static void testNakedI64(long n) throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(0); + TProtocol proto = factory.getProtocol(buf); + proto.writeI64(n); + // System.out.println(buf.inspect()); + long out = proto.readI64(); + if (out != n) { + throw new RuntimeException("I64 was supposed to be " + n + " but was " + out); + } + } + + public static void testI64Field(final long n) throws Exception { + testStructField(new StructFieldTestCase(TType.I64, (short)15) { + public void writeMethod(TProtocol proto) throws TException { + proto.writeI64(n); + } + + public void readMethod(TProtocol proto) throws TException { + long result = proto.readI64(); + if (result != n) { + throw new RuntimeException("I64 was supposed to be " + n + " but was " + result); + } + } + }); + } + + public static void testDouble() throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(1000); + TProtocol proto = factory.getProtocol(buf); + proto.writeDouble(123.456); + double out = proto.readDouble(); + if (out != 123.456) { + throw new RuntimeException("Double was supposed to be " + 123.456 + " but was " + out); + } + } + + public static void testNakedString(String str) throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(0); + TProtocol proto = factory.getProtocol(buf); + proto.writeString(str); + // System.out.println(buf.inspect()); + String out = proto.readString(); + if (!str.equals(out)) { + throw new RuntimeException("String was supposed to be '" + str + "' but was '" + out + "'"); + } + } + + public static void testStringField(final String str) throws Exception { + testStructField(new StructFieldTestCase(TType.STRING, (short)15) { + public void writeMethod(TProtocol proto) throws TException { + proto.writeString(str); + } + + public void readMethod(TProtocol proto) throws TException { + String result = proto.readString(); + if (!result.equals(str)) { + throw new RuntimeException("String was supposed to be " + str + " but was " + result); + } + } + }); + } + + public static void testNakedBinary(byte[] data) throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(0); + TProtocol proto = factory.getProtocol(buf); + proto.writeBinary(data); + // System.out.println(buf.inspect()); + byte[] out = proto.readBinary(); + if (!Arrays.equals(data, out)) { + throw new RuntimeException("Binary was supposed to be '" + data + "' but was '" + out + "'"); + } + } + + public static void testBinaryField(final byte[] data) throws Exception { + testStructField(new StructFieldTestCase(TType.STRING, (short)15) { + public void writeMethod(TProtocol proto) throws TException { + proto.writeBinary(data); + } + + public void readMethod(TProtocol proto) throws TException { + byte[] result = proto.readBinary(); + if (!Arrays.equals(data, result)) { + throw new RuntimeException("Binary was supposed to be '" + bytesToString(data) + "' but was '" + bytesToString(result) + "'"); + } + } + }); + + } + + public static void testSerialization(Class klass, T obj) throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(0); + TBinaryProtocol binproto = new TBinaryProtocol(buf); + + try { + obj.write(binproto); + // System.out.println("Size in binary protocol: " + buf.length()); + + buf = new TMemoryBuffer(0); + TProtocol proto = factory.getProtocol(buf); + + obj.write(proto); + System.out.println("Size in compact protocol: " + buf.length()); + // System.out.println(buf.inspect()); + + T objRead = klass.newInstance(); + objRead.read(proto); + if (!obj.equals(objRead)) { + System.out.println("Expected: " + obj.toString()); + System.out.println("Actual: " + objRead.toString()); + // System.out.println(buf.inspect()); + throw new RuntimeException("Objects didn't match!"); + } + } catch (Exception e) { + System.out.println(buf.inspect()); + throw e; + } + } + + public static void testMessage() throws Exception { + List msgs = Arrays.asList(new TMessage[]{ + new TMessage("short message name", TMessageType.CALL, 0), + new TMessage("1", TMessageType.REPLY, 12345), + new TMessage("loooooooooooooooooooooooooooooooooong", TMessageType.EXCEPTION, 1 << 16), + new TMessage("Janky", TMessageType.CALL, 0), + }); + + for (TMessage msg : msgs) { + TMemoryBuffer buf = new TMemoryBuffer(0); + TProtocol proto = factory.getProtocol(buf); + TMessage output = null; + + proto.writeMessageBegin(msg); + proto.writeMessageEnd(); + + output = proto.readMessageBegin(); + + if (!msg.equals(output)) { + throw new RuntimeException("Message was supposed to be " + msg + " but was " + output); + } + } + } + + public static void testServerRequest() throws Exception { + Srv.Iface handler = new Srv.Iface() { + public int Janky(int i32arg) throws TException { + return i32arg * 2; + } + + public int primitiveMethod() throws TException { + // TODO Auto-generated method stub + return 0; + } + + public CompactProtoTestStruct structMethod() throws TException { + // TODO Auto-generated method stub + return null; + } + + public void voidMethod() throws TException { + // TODO Auto-generated method stub + + } + }; + + Srv.Processor testProcessor = new Srv.Processor(handler); + + TMemoryBuffer clientOutTrans = new TMemoryBuffer(0); + TProtocol clientOutProto = factory.getProtocol(clientOutTrans); + TMemoryBuffer clientInTrans = new TMemoryBuffer(0); + TProtocol clientInProto = factory.getProtocol(clientInTrans); + + Srv.Client testClient = new Srv.Client(clientInProto, clientOutProto); + + testClient.send_Janky(1); + // System.out.println(clientOutTrans.inspect()); + testProcessor.process(clientOutProto, clientInProto); + // System.out.println(clientInTrans.inspect()); + int result = testClient.recv_Janky(); + if (result != 2) { + throw new RuntimeException("Got an unexpected result: " + result); + } + } + + // + // Helper methods + // + + private static String bytesToString(byte[] bytes) { + String s = ""; + for (int i = 0; i < bytes.length; i++) { + s += Integer.toHexString((int)bytes[i]) + " "; + } + return s; + } + + private static void testStructField(StructFieldTestCase testCase) throws Exception { + TMemoryBuffer buf = new TMemoryBuffer(0); + TProtocol proto = factory.getProtocol(buf); + + TField field = new TField("test_field", testCase.type_, testCase.id_); + proto.writeStructBegin(new TStruct("test_struct")); + proto.writeFieldBegin(field); + testCase.writeMethod(proto); + proto.writeFieldEnd(); + proto.writeStructEnd(); + + // System.out.println(buf.inspect()); + + proto.readStructBegin(); + TField readField = proto.readFieldBegin(); + // TODO: verify the field is as expected + if (!field.equals(readField)) { + throw new RuntimeException("Expected " + field + " but got " + readField); + } + testCase.readMethod(proto); + proto.readStructEnd(); + } + + public static abstract class StructFieldTestCase { + byte type_; + short id_; + public StructFieldTestCase(byte type, short id) { + type_ = type; + id_ = id; + } + + public abstract void writeMethod(TProtocol proto) throws TException; + public abstract void readMethod(TProtocol proto) throws TException; + } +} \ No newline at end of file diff --git a/lib/java/test/org/apache/thrift/test/TestClient.java b/lib/java/test/org/apache/thrift/test/TestClient.java new file mode 100644 index 00000000..4a95f7a6 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/TestClient.java @@ -0,0 +1,423 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.test; + +// Generated code +import thrift.test.*; + +import org.apache.thrift.TApplicationException; +import org.apache.thrift.TSerializer; +import org.apache.thrift.transport.TTransport; +import org.apache.thrift.transport.TSocket; +import org.apache.thrift.transport.THttpClient; +import org.apache.thrift.transport.TFramedTransport; +import org.apache.thrift.transport.TTransportException; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TSimpleJSONProtocol; + +import java.util.Map; +import java.util.HashMap; +import java.util.Set; +import java.util.HashSet; +import java.util.List; +import java.util.ArrayList; + +/** + * Test Java client for thrift. Essentially just a copy of the C++ version, + * this makes a variety of requests to enable testing for both performance and + * correctness of the output. + * + */ +public class TestClient { + public static void main(String [] args) { + try { + String host = "localhost"; + int port = 9090; + String url = null; + int numTests = 1; + boolean framed = false; + + int socketTimeout = 1000; + + try { + for (int i = 0; i < args.length; ++i) { + if (args[i].equals("-h")) { + String[] hostport = (args[++i]).split(":"); + host = hostport[0]; + port = Integer.valueOf(hostport[1]); + } else if (args[i].equals("-f") || args[i].equals("-framed")) { + framed = true; + } else if (args[i].equals("-u")) { + url = args[++i]; + } else if (args[i].equals("-n")) { + numTests = Integer.valueOf(args[++i]); + } else if (args[i].equals("-timeout")) { + socketTimeout = Integer.valueOf(args[++i]); + } + } + } catch (Exception x) { + x.printStackTrace(); + } + + TTransport transport; + + if (url != null) { + transport = new THttpClient(url); + } else { + TSocket socket = new TSocket(host, port); + socket.setTimeout(socketTimeout); + transport = socket; + if (framed) { + transport = new TFramedTransport(transport); + } + } + + TBinaryProtocol binaryProtocol = + new TBinaryProtocol(transport); + ThriftTest.Client testClient = + new ThriftTest.Client(binaryProtocol); + Insanity insane = new Insanity(); + + long timeMin = 0; + long timeMax = 0; + long timeTot = 0; + + for (int test = 0; test < numTests; ++test) { + + /** + * CONNECT TEST + */ + System.out.println("Test #" + (test+1) + ", " + "connect " + host + ":" + port); + try { + transport.open(); + } catch (TTransportException ttx) { + System.out.println("Connect failed: " + ttx.getMessage()); + continue; + } + + long start = System.nanoTime(); + + /** + * VOID TEST + */ + try { + System.out.print("testVoid()"); + testClient.testVoid(); + System.out.print(" = void\n"); + } catch (TApplicationException tax) { + tax.printStackTrace(); + } + + /** + * STRING TEST + */ + System.out.print("testString(\"Test\")"); + String s = testClient.testString("Test"); + System.out.print(" = \"" + s + "\"\n"); + + /** + * BYTE TEST + */ + System.out.print("testByte(1)"); + byte i8 = testClient.testByte((byte)1); + System.out.print(" = " + i8 + "\n"); + + /** + * I32 TEST + */ + System.out.print("testI32(-1)"); + int i32 = testClient.testI32(-1); + System.out.print(" = " + i32 + "\n"); + + /** + * I64 TEST + */ + System.out.print("testI64(-34359738368)"); + long i64 = testClient.testI64(-34359738368L); + System.out.print(" = " + i64 + "\n"); + + /** + * DOUBLE TEST + */ + System.out.print("testDouble(5.325098235)"); + double dub = testClient.testDouble(5.325098235); + System.out.print(" = " + dub + "\n"); + + /** + * STRUCT TEST + */ + System.out.print("testStruct({\"Zero\", 1, -3, -5})"); + Xtruct out = new Xtruct(); + out.string_thing = "Zero"; + out.byte_thing = (byte) 1; + out.i32_thing = -3; + out.i64_thing = -5; + Xtruct in = testClient.testStruct(out); + System.out.print(" = {" + "\"" + in.string_thing + "\", " + in.byte_thing + ", " + in.i32_thing + ", " + in.i64_thing + "}\n"); + + /** + * NESTED STRUCT TEST + */ + System.out.print("testNest({1, {\"Zero\", 1, -3, -5}), 5}"); + Xtruct2 out2 = new Xtruct2(); + out2.byte_thing = (short)1; + out2.struct_thing = out; + out2.i32_thing = 5; + Xtruct2 in2 = testClient.testNest(out2); + in = in2.struct_thing; + System.out.print(" = {" + in2.byte_thing + ", {" + "\"" + in.string_thing + "\", " + in.byte_thing + ", " + in.i32_thing + ", " + in.i64_thing + "}, " + in2.i32_thing + "}\n"); + + /** + * MAP TEST + */ + Map mapout = new HashMap(); + for (int i = 0; i < 5; ++i) { + mapout.put(i, i-10); + } + System.out.print("testMap({"); + boolean first = true; + for (int key : mapout.keySet()) { + if (first) { + first = false; + } else { + System.out.print(", "); + } + System.out.print(key + " => " + mapout.get(key)); + } + System.out.print("})"); + Map mapin = testClient.testMap(mapout); + System.out.print(" = {"); + first = true; + for (int key : mapin.keySet()) { + if (first) { + first = false; + } else { + System.out.print(", "); + } + System.out.print(key + " => " + mapout.get(key)); + } + System.out.print("}\n"); + + /** + * SET TEST + */ + Set setout = new HashSet(); + for (int i = -2; i < 3; ++i) { + setout.add(i); + } + System.out.print("testSet({"); + first = true; + for (int elem : setout) { + if (first) { + first = false; + } else { + System.out.print(", "); + } + System.out.print(elem); + } + System.out.print("})"); + Set setin = testClient.testSet(setout); + System.out.print(" = {"); + first = true; + for (int elem : setin) { + if (first) { + first = false; + } else { + System.out.print(", "); + } + System.out.print(elem); + } + System.out.print("}\n"); + + /** + * LIST TEST + */ + List listout = new ArrayList(); + for (int i = -2; i < 3; ++i) { + listout.add(i); + } + System.out.print("testList({"); + first = true; + for (int elem : listout) { + if (first) { + first = false; + } else { + System.out.print(", "); + } + System.out.print(elem); + } + System.out.print("})"); + List listin = testClient.testList(listout); + System.out.print(" = {"); + first = true; + for (int elem : listin) { + if (first) { + first = false; + } else { + System.out.print(", "); + } + System.out.print(elem); + } + System.out.print("}\n"); + + /** + * ENUM TEST + */ + System.out.print("testEnum(ONE)"); + int ret = testClient.testEnum(Numberz.ONE); + System.out.print(" = " + ret + "\n"); + + System.out.print("testEnum(TWO)"); + ret = testClient.testEnum(Numberz.TWO); + System.out.print(" = " + ret + "\n"); + + System.out.print("testEnum(THREE)"); + ret = testClient.testEnum(Numberz.THREE); + System.out.print(" = " + ret + "\n"); + + System.out.print("testEnum(FIVE)"); + ret = testClient.testEnum(Numberz.FIVE); + System.out.print(" = " + ret + "\n"); + + System.out.print("testEnum(EIGHT)"); + ret = testClient.testEnum(Numberz.EIGHT); + System.out.print(" = " + ret + "\n"); + + /** + * TYPEDEF TEST + */ + System.out.print("testTypedef(309858235082523)"); + long uid = testClient.testTypedef(309858235082523L); + System.out.print(" = " + uid + "\n"); + + /** + * NESTED MAP TEST + */ + System.out.print("testMapMap(1)"); + Map> mm = + testClient.testMapMap(1); + System.out.print(" = {"); + for (int key : mm.keySet()) { + System.out.print(key + " => {"); + Map m2 = mm.get(key); + for (int k2 : m2.keySet()) { + System.out.print(k2 + " => " + m2.get(k2) + ", "); + } + System.out.print("}, "); + } + System.out.print("}\n"); + + /** + * INSANITY TEST + */ + insane = new Insanity(); + insane.userMap = new HashMap(); + insane.userMap.put(Numberz.FIVE, (long)5000); + Xtruct truck = new Xtruct(); + truck.string_thing = "Truck"; + truck.byte_thing = (byte)8; + truck.i32_thing = 8; + truck.i64_thing = 8; + insane.xtructs = new ArrayList(); + insane.xtructs.add(truck); + System.out.print("testInsanity()"); + Map> whoa = + testClient.testInsanity(insane); + System.out.print(" = {"); + for (long key : whoa.keySet()) { + Map val = whoa.get(key); + System.out.print(key + " => {"); + + for (int k2 : val.keySet()) { + Insanity v2 = val.get(k2); + System.out.print(k2 + " => {"); + Map userMap = v2.userMap; + System.out.print("{"); + if (userMap != null) { + for (int k3 : userMap.keySet()) { + System.out.print(k3 + " => " + userMap.get(k3) + ", "); + } + } + System.out.print("}, "); + + List xtructs = v2.xtructs; + System.out.print("{"); + if (xtructs != null) { + for (Xtruct x : xtructs) { + System.out.print("{" + "\"" + x.string_thing + "\", " + x.byte_thing + ", " + x.i32_thing + ", "+ x.i64_thing + "}, "); + } + } + System.out.print("}"); + + System.out.print("}, "); + } + System.out.print("}, "); + } + System.out.print("}\n"); + + // Test oneway + System.out.print("testOneway(3)..."); + long startOneway = System.nanoTime(); + testClient.testOneway(3); + long onewayElapsedMillis = (System.nanoTime() - startOneway) / 1000000; + if (onewayElapsedMillis > 200) { + throw new Exception("Oneway test failed: took " + + Long.toString(onewayElapsedMillis) + + "ms"); + } else { + System.out.println("Success - took " + + Long.toString(onewayElapsedMillis) + + "ms"); + } + + + long stop = System.nanoTime(); + long tot = stop-start; + + System.out.println("Total time: " + tot/1000 + "us"); + + if (timeMin == 0 || tot < timeMin) { + timeMin = tot; + } + if (tot > timeMax) { + timeMax = tot; + } + timeTot += tot; + + transport.close(); + } + + long timeAvg = timeTot / numTests; + + System.out.println("Min time: " + timeMin/1000 + "us"); + System.out.println("Max time: " + timeMax/1000 + "us"); + System.out.println("Avg time: " + timeAvg/1000 + "us"); + + String json = (new TSerializer(new TSimpleJSONProtocol.Factory())).toString(insane); + + System.out.println("\nFor good meausre here is some JSON:\n" + json); + + } catch (Exception x) { + x.printStackTrace(); + } + + } + +} diff --git a/lib/java/test/org/apache/thrift/test/TestNonblockingServer.java b/lib/java/test/org/apache/thrift/test/TestNonblockingServer.java new file mode 100644 index 00000000..2e6a1780 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/TestNonblockingServer.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.test; + +import org.apache.thrift.server.THsHaServer; +import org.apache.thrift.server.TNonblockingServer; +import org.apache.thrift.server.TServer; +import org.apache.thrift.transport.TNonblockingServerSocket; + +import thrift.test.ThriftTest; + + +public class TestNonblockingServer extends TestServer { + public static void main(String [] args) { + try { + int port = 9090; + boolean hsha = false; + + for (int i = 0; i < args.length; i++) { + if (args[i].equals("-p")) { + port = Integer.valueOf(args[i++]); + } else if (args[i].equals("-hsha")) { + hsha = true; + } + } + + // Processor + TestHandler testHandler = + new TestHandler(); + ThriftTest.Processor testProcessor = + new ThriftTest.Processor(testHandler); + + // Transport + TNonblockingServerSocket tServerSocket = + new TNonblockingServerSocket(port); + + TServer serverEngine; + + if (hsha) { + // HsHa Server + serverEngine = new THsHaServer(testProcessor, tServerSocket); + } else { + // Nonblocking Server + serverEngine = new TNonblockingServer(testProcessor, tServerSocket); + } + + // Run it + System.out.println("Starting the server on port " + port + "..."); + serverEngine.serve(); + + } catch (Exception x) { + x.printStackTrace(); + } + System.out.println("done."); + } +} diff --git a/lib/java/test/org/apache/thrift/test/TestServer.java b/lib/java/test/org/apache/thrift/test/TestServer.java new file mode 100644 index 00000000..986f8890 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/TestServer.java @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.test; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.server.TServer; +import org.apache.thrift.server.TThreadPoolServer; +import org.apache.thrift.transport.TServerSocket; + +import thrift.test.Insanity; +import thrift.test.Numberz; +import thrift.test.ThriftTest; +import thrift.test.Xception; +import thrift.test.Xception2; +import thrift.test.Xtruct; +import thrift.test.Xtruct2; + +public class TestServer { + + public static class TestHandler implements ThriftTest.Iface { + + public TestHandler() {} + + public void testVoid() { + System.out.print("testVoid()\n"); + } + + public String testString(String thing) { + System.out.print("testString(\"" + thing + "\")\n"); + return thing; + } + + public byte testByte(byte thing) { + System.out.print("testByte(" + thing + ")\n"); + return thing; + } + + public int testI32(int thing) { + System.out.print("testI32(" + thing + ")\n"); + return thing; + } + + public long testI64(long thing) { + System.out.print("testI64(" + thing + ")\n"); + return thing; + } + + public double testDouble(double thing) { + System.out.print("testDouble(" + thing + ")\n"); + return thing; + } + + public Xtruct testStruct(Xtruct thing) { + System.out.print("testStruct({" + + "\"" + thing.string_thing + "\", " + + thing.byte_thing + ", " + + thing.i32_thing + ", " + + thing.i64_thing + "})\n"); + return thing; + } + + public Xtruct2 testNest(Xtruct2 nest) { + Xtruct thing = nest.struct_thing; + System.out.print("testNest({" + + nest.byte_thing + ", {" + + "\"" + thing.string_thing + "\", " + + thing.byte_thing + ", " + + thing.i32_thing + ", " + + thing.i64_thing + "}, " + + nest.i32_thing + "})\n"); + return nest; + } + + public Map testMap(Map thing) { + System.out.print("testMap({"); + boolean first = true; + for (int key : thing.keySet()) { + if (first) { + first = false; + } else { + System.out.print(", "); + } + System.out.print(key + " => " + thing.get(key)); + } + System.out.print("})\n"); + return thing; + } + + public Set testSet(Set thing) { + System.out.print("testSet({"); + boolean first = true; + for (int elem : thing) { + if (first) { + first = false; + } else { + System.out.print(", "); + } + System.out.print(elem); + } + System.out.print("})\n"); + return thing; + } + + public List testList(List thing) { + System.out.print("testList({"); + boolean first = true; + for (int elem : thing) { + if (first) { + first = false; + } else { + System.out.print(", "); + } + System.out.print(elem); + } + System.out.print("})\n"); + return thing; + } + + public int testEnum(int thing) { + System.out.print("testEnum(" + thing + ")\n"); + return thing; + } + + public long testTypedef(long thing) { + System.out.print("testTypedef(" + thing + ")\n"); + return thing; + } + + public Map> testMapMap(int hello) { + System.out.print("testMapMap(" + hello + ")\n"); + Map> mapmap = + new HashMap>(); + + HashMap pos = new HashMap(); + HashMap neg = new HashMap(); + for (int i = 1; i < 5; i++) { + pos.put(i, i); + neg.put(-i, -i); + } + + mapmap.put(4, pos); + mapmap.put(-4, neg); + + return mapmap; + } + + public Map> testInsanity(Insanity argument) { + System.out.print("testInsanity()\n"); + + Xtruct hello = new Xtruct(); + hello.string_thing = "Hello2"; + hello.byte_thing = 2; + hello.i32_thing = 2; + hello.i64_thing = 2; + + Xtruct goodbye = new Xtruct(); + goodbye.string_thing = "Goodbye4"; + goodbye.byte_thing = (byte)4; + goodbye.i32_thing = 4; + goodbye.i64_thing = (long)4; + + Insanity crazy = new Insanity(); + crazy.userMap = new HashMap(); + crazy.xtructs = new ArrayList(); + + crazy.userMap.put(Numberz.EIGHT, (long)8); + crazy.xtructs.add(goodbye); + + Insanity looney = new Insanity(); + crazy.userMap.put(Numberz.FIVE, (long)5); + crazy.xtructs.add(hello); + + HashMap first_map = new HashMap(); + HashMap second_map = new HashMap();; + + first_map.put(Numberz.TWO, crazy); + first_map.put(Numberz.THREE, crazy); + + second_map.put(Numberz.SIX, looney); + + Map> insane = + new HashMap>(); + insane.put((long)1, first_map); + insane.put((long)2, second_map); + + return insane; + } + + public Xtruct testMulti(byte arg0, int arg1, long arg2, Map arg3, int arg4, long arg5) { + System.out.print("testMulti()\n"); + + Xtruct hello = new Xtruct();; + hello.string_thing = "Hello2"; + hello.byte_thing = arg0; + hello.i32_thing = arg1; + hello.i64_thing = arg2; + return hello; + } + + public void testException(String arg) throws Xception { + System.out.print("testException("+arg+")\n"); + if (arg.equals("Xception")) { + Xception x = new Xception(); + x.errorCode = 1001; + x.message = "This is an Xception"; + throw x; + } + return; + } + + public Xtruct testMultiException(String arg0, String arg1) throws Xception, Xception2 { + System.out.print("testMultiException(" + arg0 + ", " + arg1 + ")\n"); + if (arg0.equals("Xception")) { + Xception x = new Xception(); + x.errorCode = 1001; + x.message = "This is an Xception"; + throw x; + } else if (arg0.equals("Xception2")) { + Xception2 x = new Xception2(); + x.errorCode = 2002; + x.struct_thing = new Xtruct(); + x.struct_thing.string_thing = "This is an Xception2"; + throw x; + } + + Xtruct result = new Xtruct(); + result.string_thing = arg1; + return result; + } + + public void testOneway(int sleepFor) { + System.out.println("testOneway(" + Integer.toString(sleepFor) + + ") => sleeping..."); + try { + Thread.sleep(sleepFor * 1000); + System.out.println("Done sleeping!"); + } catch (InterruptedException ie) { + throw new RuntimeException(ie); + } + } + + } // class TestHandler + + public static void main(String [] args) { + try { + int port = 9090; + if (args.length > 1) { + port = Integer.valueOf(args[0]); + } + + // Processor + TestHandler testHandler = + new TestHandler(); + ThriftTest.Processor testProcessor = + new ThriftTest.Processor(testHandler); + + // Transport + TServerSocket tServerSocket = + new TServerSocket(port); + + // Protocol factory + TProtocolFactory tProtocolFactory = + new TBinaryProtocol.Factory(); + + TServer serverEngine; + + // Simple Server + // serverEngine = new TSimpleServer(testProcessor, tServerSocket); + + // ThreadPool Server + serverEngine = new TThreadPoolServer(testProcessor, tServerSocket, tProtocolFactory); + + // Run it + System.out.println("Starting the server on port " + port + "..."); + serverEngine.serve(); + + } catch (Exception x) { + x.printStackTrace(); + } + System.out.println("done."); + } +} diff --git a/lib/java/test/org/apache/thrift/test/ToStringTest.java b/lib/java/test/org/apache/thrift/test/ToStringTest.java new file mode 100644 index 00000000..569a61c4 --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/ToStringTest.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.test; + +import thrift.test.*; + +/** + */ +public class ToStringTest { + public static void main(String[] args) throws Exception { + JavaTestHelper object = new JavaTestHelper(); + object.req_int = 0; + object.req_obj = ""; + + + object.req_bin = new byte[] { + 0, -1, 2, -3, 4, -5, 6, -7, 8, -9, 10, -11, 12, -13, 14, -15, + 16, -17, 18, -19, 20, -21, 22, -23, 24, -25, 26, -27, 28, -29, + 30, -31, 32, -33, 34, -35, 36, -37, 38, -39, 40, -41, 42, -43, 44, + -45, 46, -47, 48, -49, 50, -51, 52, -53, 54, -55, 56, -57, 58, -59, + 60, -61, 62, -63, 64, -65, 66, -67, 68, -69, 70, -71, 72, -73, 74, + -75, 76, -77, 78, -79, 80, -81, 82, -83, 84, -85, 86, -87, 88, -89, + 90, -91, 92, -93, 94, -95, 96, -97, 98, -99, 100, -101, 102, -103, + 104, -105, 106, -107, 108, -109, 110, -111, 112, -113, 114, -115, + 116, -117, 118, -119, 120, -121, 122, -123, 124, -125, 126, -127, + }; + + if (!object.toString().equals( + "JavaTestHelper(req_int:0, req_obj:, req_bin:"+ + "00 FF 02 FD 04 FB 06 F9 08 F7 0A F5 0C F3 0E F1 10 EF 12 ED 14 "+ + "EB 16 E9 18 E7 1A E5 1C E3 1E E1 20 DF 22 DD 24 DB 26 D9 28 D7 "+ + "2A D5 2C D3 2E D1 30 CF 32 CD 34 CB 36 C9 38 C7 3A C5 3C C3 3E "+ + "C1 40 BF 42 BD 44 BB 46 B9 48 B7 4A B5 4C B3 4E B1 50 AF 52 AD "+ + "54 AB 56 A9 58 A7 5A A5 5C A3 5E A1 60 9F 62 9D 64 9B 66 99 68 "+ + "97 6A 95 6C 93 6E 91 70 8F 72 8D 74 8B 76 89 78 87 7A 85 7C 83 "+ + "7E 81)")) { + throw new RuntimeException(); + } + + object.req_bin = new byte[] { + 0, -1, 2, -3, 4, -5, 6, -7, 8, -9, 10, -11, 12, -13, 14, -15, + 16, -17, 18, -19, 20, -21, 22, -23, 24, -25, 26, -27, 28, -29, + 30, -31, 32, -33, 34, -35, 36, -37, 38, -39, 40, -41, 42, -43, 44, + -45, 46, -47, 48, -49, 50, -51, 52, -53, 54, -55, 56, -57, 58, -59, + 60, -61, 62, -63, 64, -65, 66, -67, 68, -69, 70, -71, 72, -73, 74, + -75, 76, -77, 78, -79, 80, -81, 82, -83, 84, -85, 86, -87, 88, -89, + 90, -91, 92, -93, 94, -95, 96, -97, 98, -99, 100, -101, 102, -103, + 104, -105, 106, -107, 108, -109, 110, -111, 112, -113, 114, -115, + 116, -117, 118, -119, 120, -121, 122, -123, 124, -125, 126, -127, + 0, + }; + + if (!object.toString().equals( + "JavaTestHelper(req_int:0, req_obj:, req_bin:"+ + "00 FF 02 FD 04 FB 06 F9 08 F7 0A F5 0C F3 0E F1 10 EF 12 ED 14 "+ + "EB 16 E9 18 E7 1A E5 1C E3 1E E1 20 DF 22 DD 24 DB 26 D9 28 D7 "+ + "2A D5 2C D3 2E D1 30 CF 32 CD 34 CB 36 C9 38 C7 3A C5 3C C3 3E "+ + "C1 40 BF 42 BD 44 BB 46 B9 48 B7 4A B5 4C B3 4E B1 50 AF 52 AD "+ + "54 AB 56 A9 58 A7 5A A5 5C A3 5E A1 60 9F 62 9D 64 9B 66 99 68 "+ + "97 6A 95 6C 93 6E 91 70 8F 72 8D 74 8B 76 89 78 87 7A 85 7C 83 "+ + "7E 81 ...)")) { + throw new RuntimeException(); + } + + object.req_bin = new byte[] {}; + object.setOpt_binIsSet(true); + + + if (!object.toString().equals( + "JavaTestHelper(req_int:0, req_obj:, req_bin:)")) { + throw new RuntimeException(); + } + } +} + diff --git a/lib/java/test/org/apache/thrift/test/WriteStruct.java b/lib/java/test/org/apache/thrift/test/WriteStruct.java new file mode 100644 index 00000000..474c808e --- /dev/null +++ b/lib/java/test/org/apache/thrift/test/WriteStruct.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.test; + +import java.io.BufferedOutputStream; +import java.io.FileOutputStream; + +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.transport.TIOStreamTransport; +import org.apache.thrift.transport.TTransport; + +public class WriteStruct { + public static void main(String[] args) throws Exception { + if (args.length != 2) { + System.out.println("usage: java -cp build/classes org.apache.thrift.test.WriteStruct filename proto_factory_class"); + System.out.println("Write out an instance of Fixtures.compactProtocolTestStruct to 'file'. Use a protocol from 'proto_factory_class'."); + } + + TTransport trans = new TIOStreamTransport(new BufferedOutputStream(new FileOutputStream(args[0]))); + + TProtocolFactory factory = (TProtocolFactory)Class.forName(args[1]).newInstance(); + + TProtocol proto = factory.getProtocol(trans); + + Fixtures.compactProtoTestStruct.write(proto); + trans.flush(); + } + +} diff --git a/lib/ocaml/Makefile b/lib/ocaml/Makefile new file mode 100644 index 00000000..6abeee71 --- /dev/null +++ b/lib/ocaml/Makefile @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +all: + cd src; make; cd .. +clean: + cd src; make clean; cd .. diff --git a/lib/ocaml/OCamlMakefile b/lib/ocaml/OCamlMakefile new file mode 100644 index 00000000..b0b9252c --- /dev/null +++ b/lib/ocaml/OCamlMakefile @@ -0,0 +1,1231 @@ +########################################################################### +# OCamlMakefile +# Copyright (C) 1999-2007 Markus Mottl +# +# For updates see: +# http://www.ocaml.info/home/ocaml_sources.html +# +########################################################################### + +# Modified by damien for .glade.ml compilation + +# Set these variables to the names of the sources to be processed and +# the result variable. Order matters during linkage! + +ifndef SOURCES + SOURCES := foo.ml +endif +export SOURCES + +ifndef RES_CLIB_SUF + RES_CLIB_SUF := _stubs +endif +export RES_CLIB_SUF + +ifndef RESULT + RESULT := foo +endif +export RESULT := $(strip $(RESULT)) + +export LIB_PACK_NAME + +ifndef DOC_FILES + DOC_FILES := $(filter %.mli, $(SOURCES)) +endif +export DOC_FILES +FIRST_DOC_FILE := $(firstword $(DOC_FILES)) + +export BCSUFFIX +export NCSUFFIX + +ifndef TOPSUFFIX + TOPSUFFIX := .top +endif +export TOPSUFFIX + +# Eventually set include- and library-paths, libraries to link, +# additional compilation-, link- and ocamlyacc-flags +# Path- and library information needs not be written with "-I" and such... +# Define THREADS if you need it, otherwise leave it unset (same for +# USE_CAMLP4)! + +export THREADS +export VMTHREADS +export ANNOTATE +export USE_CAMLP4 + +export INCDIRS +export LIBDIRS +export EXTLIBDIRS +export RESULTDEPS +export OCAML_DEFAULT_DIRS + +export LIBS +export CLIBS +export CFRAMEWORKS + +export OCAMLFLAGS +export OCAMLNCFLAGS +export OCAMLBCFLAGS + +export OCAMLLDFLAGS +export OCAMLNLDFLAGS +export OCAMLBLDFLAGS + +export OCAMLMKLIB_FLAGS + +ifndef OCAMLCPFLAGS + OCAMLCPFLAGS := a +endif +export OCAMLCPFLAGS + +ifndef DOC_DIR + DOC_DIR := doc +endif +export DOC_DIR + +export PPFLAGS + +export LFLAGS +export YFLAGS +export IDLFLAGS + +export OCAMLDOCFLAGS + +export OCAMLFIND_INSTFLAGS + +export DVIPSFLAGS + +export STATIC + +# Add a list of optional trash files that should be deleted by "make clean" +export TRASH + +ECHO := echo + +ifdef REALLY_QUIET + export REALLY_QUIET + ECHO := true + LFLAGS := $(LFLAGS) -q + YFLAGS := $(YFLAGS) -q +endif + +#################### variables depending on your OCaml-installation + +ifdef MINGW + export MINGW + WIN32 := 1 + CFLAGS_WIN32 := -mno-cygwin +endif +ifdef MSVC + export MSVC + WIN32 := 1 + ifndef STATIC + CPPFLAGS_WIN32 := -DCAML_DLL + endif + CFLAGS_WIN32 += -nologo + EXT_OBJ := obj + EXT_LIB := lib + ifeq ($(CC),gcc) + # work around GNU Make default value + ifdef THREADS + CC := cl -MT + else + CC := cl + endif + endif + ifeq ($(CXX),g++) + # work around GNU Make default value + CXX := $(CC) + endif + CFLAG_O := -Fo +endif +ifdef WIN32 + EXT_CXX := cpp + EXE := .exe +endif + +ifndef EXT_OBJ + EXT_OBJ := o +endif +ifndef EXT_LIB + EXT_LIB := a +endif +ifndef EXT_CXX + EXT_CXX := cc +endif +ifndef EXE + EXE := # empty +endif +ifndef CFLAG_O + CFLAG_O := -o # do not delete this comment (preserves trailing whitespace)! +endif + +export CC +export CXX +export CFLAGS +export CXXFLAGS +export LDFLAGS +export CPPFLAGS + +ifndef RPATH_FLAG + ifdef ELF_RPATH_FLAG + RPATH_FLAG := $(ELF_RPATH_FLAG) + else + RPATH_FLAG := -R + endif +endif +export RPATH_FLAG + +ifndef MSVC +ifndef PIC_CFLAGS + PIC_CFLAGS := -fPIC +endif +ifndef PIC_CPPFLAGS + PIC_CPPFLAGS := -DPIC +endif +endif + +export PIC_CFLAGS +export PIC_CPPFLAGS + +BCRESULT := $(addsuffix $(BCSUFFIX), $(RESULT)) +NCRESULT := $(addsuffix $(NCSUFFIX), $(RESULT)) +TOPRESULT := $(addsuffix $(TOPSUFFIX), $(RESULT)) + +ifndef OCAMLFIND + OCAMLFIND := ocamlfind +endif +export OCAMLFIND + +ifndef OCAMLC + OCAMLC := ocamlc +endif +export OCAMLC + +ifndef OCAMLOPT + OCAMLOPT := ocamlopt +endif +export OCAMLOPT + +ifndef OCAMLMKTOP + OCAMLMKTOP := ocamlmktop +endif +export OCAMLMKTOP + +ifndef OCAMLCP + OCAMLCP := ocamlcp +endif +export OCAMLCP + +ifndef OCAMLDEP + OCAMLDEP := ocamldep +endif +export OCAMLDEP + +ifndef OCAMLLEX + OCAMLLEX := ocamllex +endif +export OCAMLLEX + +ifndef OCAMLYACC + OCAMLYACC := ocamlyacc +endif +export OCAMLYACC + +ifndef OCAMLMKLIB + OCAMLMKLIB := ocamlmklib +endif +export OCAMLMKLIB + +ifndef OCAML_GLADECC + OCAML_GLADECC := lablgladecc2 +endif +export OCAML_GLADECC + +ifndef OCAML_GLADECC_FLAGS + OCAML_GLADECC_FLAGS := +endif +export OCAML_GLADECC_FLAGS + +ifndef CAMELEON_REPORT + CAMELEON_REPORT := report +endif +export CAMELEON_REPORT + +ifndef CAMELEON_REPORT_FLAGS + CAMELEON_REPORT_FLAGS := +endif +export CAMELEON_REPORT_FLAGS + +ifndef CAMELEON_ZOGGY + CAMELEON_ZOGGY := camlp4o pa_zog.cma pr_o.cmo +endif +export CAMELEON_ZOGGY + +ifndef CAMELEON_ZOGGY_FLAGS + CAMELEON_ZOGGY_FLAGS := +endif +export CAMELEON_ZOGGY_FLAGS + +ifndef OXRIDL + OXRIDL := oxridl +endif +export OXRIDL + +ifndef CAMLIDL + CAMLIDL := camlidl +endif +export CAMLIDL + +ifndef CAMLIDLDLL + CAMLIDLDLL := camlidldll +endif +export CAMLIDLDLL + +ifndef NOIDLHEADER + MAYBE_IDL_HEADER := -header +endif +export NOIDLHEADER + +export NO_CUSTOM + +ifndef CAMLP4 + CAMLP4 := camlp4 +endif +export CAMLP4 + +ifndef REAL_OCAMLFIND + ifdef PACKS + ifndef CREATE_LIB + ifdef THREADS + PACKS += threads + endif + endif + empty := + space := $(empty) $(empty) + comma := , + ifdef PREDS + PRE_OCAML_FIND_PREDICATES := $(subst $(space),$(comma),$(PREDS)) + PRE_OCAML_FIND_PACKAGES := $(subst $(space),$(comma),$(PACKS)) + OCAML_FIND_PREDICATES := -predicates $(PRE_OCAML_FIND_PREDICATES) + # OCAML_DEP_PREDICATES := -syntax $(PRE_OCAML_FIND_PREDICATES) + OCAML_FIND_PACKAGES := $(OCAML_FIND_PREDICATES) -package $(PRE_OCAML_FIND_PACKAGES) + OCAML_DEP_PACKAGES := $(OCAML_DEP_PREDICATES) -package $(PRE_OCAML_FIND_PACKAGES) + else + OCAML_FIND_PACKAGES := -package $(subst $(space),$(comma),$(PACKS)) + OCAML_DEP_PACKAGES := + endif + OCAML_FIND_LINKPKG := -linkpkg + REAL_OCAMLFIND := $(OCAMLFIND) + endif +endif + +export OCAML_FIND_PACKAGES +export OCAML_DEP_PACKAGES +export OCAML_FIND_LINKPKG +export REAL_OCAMLFIND + +ifndef OCAMLDOC + OCAMLDOC := ocamldoc +endif +export OCAMLDOC + +ifndef LATEX + LATEX := latex +endif +export LATEX + +ifndef DVIPS + DVIPS := dvips +endif +export DVIPS + +ifndef PS2PDF + PS2PDF := ps2pdf +endif +export PS2PDF + +ifndef OCAMLMAKEFILE + OCAMLMAKEFILE := OCamlMakefile +endif +export OCAMLMAKEFILE + +ifndef OCAMLLIBPATH + OCAMLLIBPATH := \ + $(shell $(OCAMLC) 2>/dev/null -where || echo /usr/local/lib/ocaml) +endif +export OCAMLLIBPATH + +ifndef OCAML_LIB_INSTALL + OCAML_LIB_INSTALL := $(OCAMLLIBPATH)/contrib +endif +export OCAML_LIB_INSTALL + +########################################################################### + +#################### change following sections only if +#################### you know what you are doing! + +# delete target files when a build command fails +.PHONY: .DELETE_ON_ERROR +.DELETE_ON_ERROR: + +# for pedants using "--warn-undefined-variables" +export MAYBE_IDL +export REAL_RESULT +export CAMLIDLFLAGS +export THREAD_FLAG +export RES_CLIB +export MAKEDLL +export ANNOT_FLAG +export C_OXRIDL +export SUBPROJS +export CFLAGS_WIN32 +export CPPFLAGS_WIN32 + +INCFLAGS := + +SHELL := /bin/sh + +MLDEPDIR := ._d +BCDIDIR := ._bcdi +NCDIDIR := ._ncdi + +FILTER_EXTNS := %.mli %.ml %.mll %.mly %.idl %.oxridl %.c %.m %.$(EXT_CXX) %.rep %.zog %.glade + +FILTERED := $(filter $(FILTER_EXTNS), $(SOURCES)) +SOURCE_DIRS := $(filter-out ./, $(sort $(dir $(FILTERED)))) + +FILTERED_REP := $(filter %.rep, $(FILTERED)) +DEP_REP := $(FILTERED_REP:%.rep=$(MLDEPDIR)/%.d) +AUTO_REP := $(FILTERED_REP:.rep=.ml) + +FILTERED_ZOG := $(filter %.zog, $(FILTERED)) +DEP_ZOG := $(FILTERED_ZOG:%.zog=$(MLDEPDIR)/%.d) +AUTO_ZOG := $(FILTERED_ZOG:.zog=.ml) + +FILTERED_GLADE := $(filter %.glade, $(FILTERED)) +DEP_GLADE := $(FILTERED_GLADE:%.glade=$(MLDEPDIR)/%.d) +AUTO_GLADE := $(FILTERED_GLADE:.glade=.ml) + +FILTERED_ML := $(filter %.ml, $(FILTERED)) +DEP_ML := $(FILTERED_ML:%.ml=$(MLDEPDIR)/%.d) + +FILTERED_MLI := $(filter %.mli, $(FILTERED)) +DEP_MLI := $(FILTERED_MLI:.mli=.di) + +FILTERED_MLL := $(filter %.mll, $(FILTERED)) +DEP_MLL := $(FILTERED_MLL:%.mll=$(MLDEPDIR)/%.d) +AUTO_MLL := $(FILTERED_MLL:.mll=.ml) + +FILTERED_MLY := $(filter %.mly, $(FILTERED)) +DEP_MLY := $(FILTERED_MLY:%.mly=$(MLDEPDIR)/%.d) $(FILTERED_MLY:.mly=.di) +AUTO_MLY := $(FILTERED_MLY:.mly=.mli) $(FILTERED_MLY:.mly=.ml) + +FILTERED_IDL := $(filter %.idl, $(FILTERED)) +DEP_IDL := $(FILTERED_IDL:%.idl=$(MLDEPDIR)/%.d) $(FILTERED_IDL:.idl=.di) +C_IDL := $(FILTERED_IDL:%.idl=%_stubs.c) +ifndef NOIDLHEADER + C_IDL += $(FILTERED_IDL:.idl=.h) +endif +OBJ_C_IDL := $(FILTERED_IDL:%.idl=%_stubs.$(EXT_OBJ)) +AUTO_IDL := $(FILTERED_IDL:.idl=.mli) $(FILTERED_IDL:.idl=.ml) $(C_IDL) + +FILTERED_OXRIDL := $(filter %.oxridl, $(FILTERED)) +DEP_OXRIDL := $(FILTERED_OXRIDL:%.oxridl=$(MLDEPDIR)/%.d) $(FILTERED_OXRIDL:.oxridl=.di) +AUTO_OXRIDL := $(FILTERED_OXRIDL:.oxridl=.mli) $(FILTERED_OXRIDL:.oxridl=.ml) $(C_OXRIDL) + +FILTERED_C_CXX := $(filter %.c %.m %.$(EXT_CXX), $(FILTERED)) +OBJ_C_CXX := $(FILTERED_C_CXX:.c=.$(EXT_OBJ)) +OBJ_C_CXX := $(OBJ_C_CXX:.m=.$(EXT_OBJ)) +OBJ_C_CXX := $(OBJ_C_CXX:.$(EXT_CXX)=.$(EXT_OBJ)) + +PRE_TARGETS += $(AUTO_MLL) $(AUTO_MLY) $(AUTO_IDL) $(AUTO_OXRIDL) $(AUTO_ZOG) $(AUTO_REP) $(AUTO_GLADE) + +ALL_DEPS := $(DEP_ML) $(DEP_MLI) $(DEP_MLL) $(DEP_MLY) $(DEP_IDL) $(DEP_OXRIDL) $(DEP_ZOG) $(DEP_REP) $(DEP_GLADE) + +MLDEPS := $(filter %.d, $(ALL_DEPS)) +MLIDEPS := $(filter %.di, $(ALL_DEPS)) +BCDEPIS := $(MLIDEPS:%.di=$(BCDIDIR)/%.di) +NCDEPIS := $(MLIDEPS:%.di=$(NCDIDIR)/%.di) + +ALLML := $(filter %.mli %.ml %.mll %.mly %.idl %.oxridl %.rep %.zog %.glade, $(FILTERED)) + +IMPLO_INTF := $(ALLML:%.mli=%.mli.__) +IMPLO_INTF := $(foreach file, $(IMPLO_INTF), \ + $(basename $(file)).cmi $(basename $(file)).cmo) +IMPLO_INTF := $(filter-out %.mli.cmo, $(IMPLO_INTF)) +IMPLO_INTF := $(IMPLO_INTF:%.mli.cmi=%.cmi) + +IMPLX_INTF := $(IMPLO_INTF:.cmo=.cmx) + +INTF := $(filter %.cmi, $(IMPLO_INTF)) +IMPL_CMO := $(filter %.cmo, $(IMPLO_INTF)) +IMPL_CMX := $(IMPL_CMO:.cmo=.cmx) +IMPL_ASM := $(IMPL_CMO:.cmo=.asm) +IMPL_S := $(IMPL_CMO:.cmo=.s) + +OBJ_LINK := $(OBJ_C_IDL) $(OBJ_C_CXX) +OBJ_FILES := $(IMPL_CMO:.cmo=.$(EXT_OBJ)) $(OBJ_LINK) + +EXECS := $(addsuffix $(EXE), \ + $(sort $(TOPRESULT) $(BCRESULT) $(NCRESULT))) +ifdef WIN32 + EXECS += $(BCRESULT).dll $(NCRESULT).dll +endif + +CLIB_BASE := $(RESULT)$(RES_CLIB_SUF) +ifneq ($(strip $(OBJ_LINK)),) + RES_CLIB := lib$(CLIB_BASE).$(EXT_LIB) +endif + +ifdef WIN32 +DLLSONAME := $(CLIB_BASE).dll +else +DLLSONAME := dll$(CLIB_BASE).so +endif + +NONEXECS := $(INTF) $(IMPL_CMO) $(IMPL_CMX) $(IMPL_ASM) $(IMPL_S) \ + $(OBJ_FILES) $(PRE_TARGETS) $(BCRESULT).cma $(NCRESULT).cmxa \ + $(NCRESULT).$(EXT_LIB) $(BCRESULT).cmi $(BCRESULT).cmo \ + $(NCRESULT).cmi $(NCRESULT).cmx $(NCRESULT).o \ + $(RES_CLIB) $(IMPL_CMO:.cmo=.annot) \ + $(LIB_PACK_NAME).cmi $(LIB_PACK_NAME).cmo $(LIB_PACK_NAME).cmx $(LIB_PACK_NAME).o + +ifndef STATIC + NONEXECS += $(DLLSONAME) +endif + +ifndef LIBINSTALL_FILES + LIBINSTALL_FILES := $(RESULT).mli $(RESULT).cmi $(RESULT).cma \ + $(RESULT).cmxa $(RESULT).$(EXT_LIB) $(RES_CLIB) + ifndef STATIC + ifneq ($(strip $(OBJ_LINK)),) + LIBINSTALL_FILES += $(DLLSONAME) + endif + endif +endif + +export LIBINSTALL_FILES + +ifdef WIN32 + # some extra stuff is created while linking DLLs + NONEXECS += $(BCRESULT).$(EXT_LIB) $(BCRESULT).exp $(NCRESULT).exp $(CLIB_BASE).exp $(CLIB_BASE).lib +endif + +TARGETS := $(EXECS) $(NONEXECS) + +# If there are IDL-files +ifneq ($(strip $(FILTERED_IDL)),) + MAYBE_IDL := -cclib -lcamlidl +endif + +ifdef USE_CAMLP4 + CAMLP4PATH := \ + $(shell $(CAMLP4) -where 2>/dev/null || echo /usr/local/lib/camlp4) + INCFLAGS := -I $(CAMLP4PATH) + CINCFLAGS := -I$(CAMLP4PATH) +endif + +DINCFLAGS := $(INCFLAGS) $(SOURCE_DIRS:%=-I %) $(OCAML_DEFAULT_DIRS:%=-I %) +INCFLAGS := $(DINCFLAGS) $(INCDIRS:%=-I %) +CINCFLAGS += $(SOURCE_DIRS:%=-I%) $(INCDIRS:%=-I%) $(OCAML_DEFAULT_DIRS:%=-I%) + +ifndef MSVC + CLIBFLAGS += $(SOURCE_DIRS:%=-L%) $(LIBDIRS:%=-L%) \ + $(EXTLIBDIRS:%=-L%) $(OCAML_DEFAULT_DIRS:%=-L%) + + ifeq ($(ELF_RPATH), yes) + CLIBFLAGS += $(EXTLIBDIRS:%=-Wl,$(RPATH_FLAG)%) + endif +endif + +ifndef PROFILING + INTF_OCAMLC := $(OCAMLC) +else + ifndef THREADS + INTF_OCAMLC := $(OCAMLCP) -p $(OCAMLCPFLAGS) + else + # OCaml does not support profiling byte code + # with threads (yet), therefore we force an error. + ifndef REAL_OCAMLC + $(error Profiling of multithreaded byte code not yet supported by OCaml) + endif + INTF_OCAMLC := $(OCAMLC) + endif +endif + +ifndef MSVC + COMMON_LDFLAGS := $(LDFLAGS:%=-ccopt %) $(SOURCE_DIRS:%=-ccopt -L%) \ + $(LIBDIRS:%=-ccopt -L%) $(EXTLIBDIRS:%=-ccopt -L%) \ + $(EXTLIBDIRS:%=-ccopt -Wl $(OCAML_DEFAULT_DIRS:%=-ccopt -L%)) + + ifeq ($(ELF_RPATH),yes) + COMMON_LDFLAGS += $(EXTLIBDIRS:%=-ccopt -Wl,$(RPATH_FLAG)%) + endif +else + COMMON_LDFLAGS := -ccopt "/link -NODEFAULTLIB:LIBC $(LDFLAGS:%=%) $(SOURCE_DIRS:%=-LIBPATH:%) \ + $(LIBDIRS:%=-LIBPATH:%) $(EXTLIBDIRS:%=-LIBPATH:%) \ + $(OCAML_DEFAULT_DIRS:%=-LIBPATH:%) " +endif + +CLIBS_OPTS := $(CLIBS:%=-cclib -l%) $(CFRAMEWORKS:%=-cclib '-framework %') +ifdef MSVC + ifndef STATIC + # MSVC libraries do not have 'lib' prefix + CLIBS_OPTS := $(CLIBS:%=-cclib %.lib) + endif +endif + +ifneq ($(strip $(OBJ_LINK)),) + ifdef CREATE_LIB + OBJS_LIBS := -cclib -l$(CLIB_BASE) $(CLIBS_OPTS) $(MAYBE_IDL) + else + OBJS_LIBS := $(OBJ_LINK) $(CLIBS_OPTS) $(MAYBE_IDL) + endif +else + OBJS_LIBS := $(CLIBS_OPTS) $(MAYBE_IDL) +endif + +# If we have to make byte-code +ifndef REAL_OCAMLC + BYTE_OCAML := y + + # EXTRADEPS is added dependencies we have to insert for all + # executable files we generate. Ideally it should be all of the + # libraries we use, but it's hard to find the ones that get searched on + # the path since I don't know the paths built into the compiler, so + # just include the ones with slashes in their names. + EXTRADEPS := $(addsuffix .cma,$(foreach i,$(LIBS),$(if $(findstring /,$(i)),$(i)))) + SPECIAL_OCAMLFLAGS := $(OCAMLBCFLAGS) + + REAL_OCAMLC := $(INTF_OCAMLC) + + REAL_IMPL := $(IMPL_CMO) + REAL_IMPL_INTF := $(IMPLO_INTF) + IMPL_SUF := .cmo + + DEPFLAGS := + MAKE_DEPS := $(MLDEPS) $(BCDEPIS) + + ifdef CREATE_LIB + override CFLAGS := $(PIC_CFLAGS) $(CFLAGS) + override CPPFLAGS := $(PIC_CPPFLAGS) $(CPPFLAGS) + ifndef STATIC + ifneq ($(strip $(OBJ_LINK)),) + MAKEDLL := $(DLLSONAME) + ALL_LDFLAGS := -dllib $(DLLSONAME) + endif + endif + endif + + ifndef NO_CUSTOM + ifneq "$(strip $(OBJ_LINK) $(THREADS) $(MAYBE_IDL) $(CLIBS) $(CFRAMEWORKS))" "" + ALL_LDFLAGS += -custom + endif + endif + + ALL_LDFLAGS += $(INCFLAGS) $(OCAMLLDFLAGS) $(OCAMLBLDFLAGS) \ + $(COMMON_LDFLAGS) $(LIBS:%=%.cma) + CAMLIDLDLLFLAGS := + + ifdef THREADS + ifdef VMTHREADS + THREAD_FLAG := -vmthread + else + THREAD_FLAG := -thread + endif + ALL_LDFLAGS := $(THREAD_FLAG) $(ALL_LDFLAGS) + ifndef CREATE_LIB + ifndef REAL_OCAMLFIND + ALL_LDFLAGS := unix.cma threads.cma $(ALL_LDFLAGS) + endif + endif + endif + +# we have to make native-code +else + EXTRADEPS := $(addsuffix .cmxa,$(foreach i,$(LIBS),$(if $(findstring /,$(i)),$(i)))) + ifndef PROFILING + SPECIAL_OCAMLFLAGS := $(OCAMLNCFLAGS) + PLDFLAGS := + else + SPECIAL_OCAMLFLAGS := -p $(OCAMLNCFLAGS) + PLDFLAGS := -p + endif + + REAL_IMPL := $(IMPL_CMX) + REAL_IMPL_INTF := $(IMPLX_INTF) + IMPL_SUF := .cmx + + override CPPFLAGS := -DNATIVE_CODE $(CPPFLAGS) + + DEPFLAGS := -native + MAKE_DEPS := $(MLDEPS) $(NCDEPIS) + + ALL_LDFLAGS := $(PLDFLAGS) $(INCFLAGS) $(OCAMLLDFLAGS) \ + $(OCAMLNLDFLAGS) $(COMMON_LDFLAGS) + CAMLIDLDLLFLAGS := -opt + + ifndef CREATE_LIB + ALL_LDFLAGS += $(LIBS:%=%.cmxa) + else + override CFLAGS := $(PIC_CFLAGS) $(CFLAGS) + override CPPFLAGS := $(PIC_CPPFLAGS) $(CPPFLAGS) + endif + + ifdef THREADS + THREAD_FLAG := -thread + ALL_LDFLAGS := $(THREAD_FLAG) $(ALL_LDFLAGS) + ifndef CREATE_LIB + ifndef REAL_OCAMLFIND + ALL_LDFLAGS := unix.cmxa threads.cmxa $(ALL_LDFLAGS) + endif + endif + endif +endif + +export MAKE_DEPS + +ifdef ANNOTATE + ANNOT_FLAG := -dtypes +else +endif + +ALL_OCAMLCFLAGS := $(THREAD_FLAG) $(ANNOT_FLAG) $(OCAMLFLAGS) \ + $(INCFLAGS) $(SPECIAL_OCAMLFLAGS) + +ifdef make_deps + -include $(MAKE_DEPS) + PRE_TARGETS := +endif + +########################################################################### +# USER RULES + +# Call "OCamlMakefile QUIET=" to get rid of all of the @'s. +QUIET=@ + +# generates byte-code (default) +byte-code: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) $(BCRESULT) \ + REAL_RESULT="$(BCRESULT)" make_deps=yes +bc: byte-code + +byte-code-nolink: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) nolink \ + REAL_RESULT="$(BCRESULT)" make_deps=yes +bcnl: byte-code-nolink + +top: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) $(TOPRESULT) \ + REAL_RESULT="$(BCRESULT)" make_deps=yes + +# generates native-code + +native-code: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) $(NCRESULT) \ + REAL_RESULT="$(NCRESULT)" \ + REAL_OCAMLC="$(OCAMLOPT)" \ + make_deps=yes +nc: native-code + +native-code-nolink: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) nolink \ + REAL_RESULT="$(NCRESULT)" \ + REAL_OCAMLC="$(OCAMLOPT)" \ + make_deps=yes +ncnl: native-code-nolink + +# generates byte-code libraries +byte-code-library: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) \ + $(RES_CLIB) $(BCRESULT).cma \ + REAL_RESULT="$(BCRESULT)" \ + CREATE_LIB=yes \ + make_deps=yes +bcl: byte-code-library + +# generates native-code libraries +native-code-library: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) \ + $(RES_CLIB) $(NCRESULT).cmxa \ + REAL_RESULT="$(NCRESULT)" \ + REAL_OCAMLC="$(OCAMLOPT)" \ + CREATE_LIB=yes \ + make_deps=yes +ncl: native-code-library + +ifdef WIN32 +# generates byte-code dll +byte-code-dll: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) \ + $(RES_CLIB) $(BCRESULT).dll \ + REAL_RESULT="$(BCRESULT)" \ + make_deps=yes +bcd: byte-code-dll + +# generates native-code dll +native-code-dll: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) \ + $(RES_CLIB) $(NCRESULT).dll \ + REAL_RESULT="$(NCRESULT)" \ + REAL_OCAMLC="$(OCAMLOPT)" \ + make_deps=yes +ncd: native-code-dll +endif + +# generates byte-code with debugging information +debug-code: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) $(BCRESULT) \ + REAL_RESULT="$(BCRESULT)" make_deps=yes \ + OCAMLFLAGS="-g $(OCAMLFLAGS)" \ + OCAMLLDFLAGS="-g $(OCAMLLDFLAGS)" +dc: debug-code + +debug-code-nolink: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) nolink \ + REAL_RESULT="$(BCRESULT)" make_deps=yes \ + OCAMLFLAGS="-g $(OCAMLFLAGS)" \ + OCAMLLDFLAGS="-g $(OCAMLLDFLAGS)" +dcnl: debug-code-nolink + +# generates byte-code with debugging information (native code) +debug-native-code: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) $(NCRESULT) \ + REAL_RESULT="$(NCRESULT)" make_deps=yes \ + REAL_OCAMLC="$(OCAMLOPT)" \ + OCAMLFLAGS="-g $(OCAMLFLAGS)" \ + OCAMLLDFLAGS="-g $(OCAMLLDFLAGS)" +dnc: debug-native-code + +debug-native-code-nolink: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) nolink \ + REAL_RESULT="$(NCRESULT)" make_deps=yes \ + REAL_OCAMLC="$(OCAMLOPT)" \ + OCAMLFLAGS="-g $(OCAMLFLAGS)" \ + OCAMLLDFLAGS="-g $(OCAMLLDFLAGS)" +dncnl: debug-native-code-nolink + +# generates byte-code libraries with debugging information +debug-code-library: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) \ + $(RES_CLIB) $(BCRESULT).cma \ + REAL_RESULT="$(BCRESULT)" make_deps=yes \ + CREATE_LIB=yes \ + OCAMLFLAGS="-g $(OCAMLFLAGS)" \ + OCAMLLDFLAGS="-g $(OCAMLLDFLAGS)" +dcl: debug-code-library + +# generates byte-code libraries with debugging information (native code) +debug-native-code-library: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) \ + $(RES_CLIB) $(NCRESULT).cma \ + REAL_RESULT="$(NCRESULT)" make_deps=yes \ + REAL_OCAMLC="$(OCAMLOPT)" \ + CREATE_LIB=yes \ + OCAMLFLAGS="-g $(OCAMLFLAGS)" \ + OCAMLLDFLAGS="-g $(OCAMLLDFLAGS)" +dncl: debug-native-code-library + +# generates byte-code for profiling +profiling-byte-code: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) $(BCRESULT) \ + REAL_RESULT="$(BCRESULT)" PROFILING="y" \ + make_deps=yes +pbc: profiling-byte-code + +# generates native-code + +profiling-native-code: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) $(NCRESULT) \ + REAL_RESULT="$(NCRESULT)" \ + REAL_OCAMLC="$(OCAMLOPT)" \ + PROFILING="y" \ + make_deps=yes +pnc: profiling-native-code + +# generates byte-code libraries +profiling-byte-code-library: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) \ + $(RES_CLIB) $(BCRESULT).cma \ + REAL_RESULT="$(BCRESULT)" PROFILING="y" \ + CREATE_LIB=yes \ + make_deps=yes +pbcl: profiling-byte-code-library + +# generates native-code libraries +profiling-native-code-library: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) \ + $(RES_CLIB) $(NCRESULT).cmxa \ + REAL_RESULT="$(NCRESULT)" PROFILING="y" \ + REAL_OCAMLC="$(OCAMLOPT)" \ + CREATE_LIB=yes \ + make_deps=yes +pncl: profiling-native-code-library + +# packs byte-code objects +pack-byte-code: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) $(BCRESULT).cmo \ + REAL_RESULT="$(BCRESULT)" \ + PACK_LIB=yes make_deps=yes +pabc: pack-byte-code + +# packs native-code objects +pack-native-code: $(PRE_TARGETS) + $(QUIET)$(MAKE) -r -f $(OCAMLMAKEFILE) \ + $(NCRESULT).cmx $(NCRESULT).o \ + REAL_RESULT="$(NCRESULT)" \ + REAL_OCAMLC="$(OCAMLOPT)" \ + PACK_LIB=yes make_deps=yes +panc: pack-native-code + +# generates HTML-documentation +htdoc: $(DOC_DIR)/$(RESULT)/html/index.html + +# generates Latex-documentation +ladoc: $(DOC_DIR)/$(RESULT)/latex/doc.tex + +# generates PostScript-documentation +psdoc: $(DOC_DIR)/$(RESULT)/latex/doc.ps + +# generates PDF-documentation +pdfdoc: $(DOC_DIR)/$(RESULT)/latex/doc.pdf + +# generates all supported forms of documentation +doc: htdoc ladoc psdoc pdfdoc + +########################################################################### +# LOW LEVEL RULES + +$(REAL_RESULT): $(REAL_IMPL_INTF) $(OBJ_LINK) $(EXTRADEPS) $(RESULTDEPS) + $(REAL_OCAMLFIND) $(REAL_OCAMLC) \ + $(OCAML_FIND_PACKAGES) $(OCAML_FIND_LINKPKG) \ + $(ALL_LDFLAGS) $(OBJS_LIBS) -o $@$(EXE) \ + $(REAL_IMPL) + +nolink: $(REAL_IMPL_INTF) $(OBJ_LINK) + +ifdef WIN32 +$(REAL_RESULT).dll: $(REAL_IMPL_INTF) $(OBJ_LINK) + $(CAMLIDLDLL) $(CAMLIDLDLLFLAGS) $(OBJ_LINK) $(CLIBS) \ + -o $@ $(REAL_IMPL) +endif + +%$(TOPSUFFIX): $(REAL_IMPL_INTF) $(OBJ_LINK) $(EXTRADEPS) + $(REAL_OCAMLFIND) $(OCAMLMKTOP) \ + $(OCAML_FIND_PACKAGES) $(OCAML_FIND_LINKPKG) \ + $(ALL_LDFLAGS) $(OBJS_LIBS) -o $@$(EXE) \ + $(REAL_IMPL) + +.SUFFIXES: .mli .ml .cmi .cmo .cmx .cma .cmxa .$(EXT_OBJ) \ + .mly .di .d .$(EXT_LIB) .idl %.oxridl .c .m .$(EXT_CXX) .h .so \ + .rep .zog .glade + +ifndef STATIC +ifdef MINGW +$(DLLSONAME): $(OBJ_LINK) + $(CC) $(CFLAGS) $(CFLAGS_WIN32) $(OBJ_LINK) -shared -o $@ \ + -Wl,--whole-archive $(wildcard $(foreach dir,$(LIBDIRS),$(CLIBS:%=$(dir)/lib%.a))) \ + $(OCAMLLIBPATH)/ocamlrun.a \ + -Wl,--export-all-symbols \ + -Wl,--no-whole-archive +else +ifdef MSVC +$(DLLSONAME): $(OBJ_LINK) + link /NOLOGO /DLL /OUT:$@ $(OBJ_LINK) \ + $(wildcard $(foreach dir,$(LIBDIRS),$(CLIBS:%=$(dir)/%.lib))) \ + $(OCAMLLIBPATH)/ocamlrun.lib + +else +$(DLLSONAME): $(OBJ_LINK) + $(OCAMLMKLIB) $(INCFLAGS) $(CLIBFLAGS) \ + -o $(CLIB_BASE) $(OBJ_LINK) $(CLIBS:%=-l%) $(CFRAMEWORKS:%=-framework %) \ + $(OCAMLMKLIB_FLAGS) +endif +endif +endif + +ifndef LIB_PACK_NAME +$(RESULT).cma: $(REAL_IMPL_INTF) $(MAKEDLL) $(EXTRADEPS) $(RESULTDEPS) + $(REAL_OCAMLFIND) $(REAL_OCAMLC) -a $(ALL_LDFLAGS) $(OBJS_LIBS) -o $@ $(REAL_IMPL) + +$(RESULT).cmxa $(RESULT).$(EXT_LIB): $(REAL_IMPL_INTF) $(EXTRADEPS) $(RESULTDEPS) + $(REAL_OCAMLFIND) $(OCAMLOPT) -a $(ALL_LDFLAGS) $(OBJS_LIBS) -o $@ $(REAL_IMPL) +else +ifdef BYTE_OCAML +$(LIB_PACK_NAME).cmi $(LIB_PACK_NAME).cmo: $(REAL_IMPL_INTF) + $(REAL_OCAMLFIND) $(REAL_OCAMLC) -pack -o $(LIB_PACK_NAME).cmo $(OCAMLLDFLAGS) $(REAL_IMPL) +else +$(LIB_PACK_NAME).cmi $(LIB_PACK_NAME).cmx: $(REAL_IMPL_INTF) + $(REAL_OCAMLFIND) $(OCAMLOPT) -pack -o $(LIB_PACK_NAME).cmx $(OCAMLLDFLAGS) $(REAL_IMPL) +endif + +$(RESULT).cma: $(LIB_PACK_NAME).cmi $(LIB_PACK_NAME).cmo $(MAKEDLL) $(EXTRADEPS) $(RESULTDEPS) + $(REAL_OCAMLFIND) $(REAL_OCAMLC) -a $(ALL_LDFLAGS) $(OBJS_LIBS) -o $@ $(LIB_PACK_NAME).cmo + +$(RESULT).cmxa $(RESULT).$(EXT_LIB): $(LIB_PACK_NAME).cmi $(LIB_PACK_NAME).cmx $(EXTRADEPS) $(RESULTDEPS) + $(REAL_OCAMLFIND) $(OCAMLOPT) -a $(filter-out -custom, $(ALL_LDFLAGS)) $(OBJS_LIBS) -o $@ $(LIB_PACK_NAME).cmx +endif + +$(RES_CLIB): $(OBJ_LINK) +ifndef MSVC + ifneq ($(strip $(OBJ_LINK)),) + $(AR) rcs $@ $(OBJ_LINK) + endif +else + ifneq ($(strip $(OBJ_LINK)),) + lib -nologo -debugtype:cv -out:$(RES_CLIB) $(OBJ_LINK) + endif +endif + +.mli.cmi: $(EXTRADEPS) + $(QUIET)pp=`sed -n -e '/^#/d' -e 's/(\*pp \([^*]*\) \*)/\1/p;q' $<`; \ + if [ -z "$$pp" ]; then \ + $(ECHO) $(REAL_OCAMLFIND) $(INTF_OCAMLC) $(OCAML_FIND_PACKAGES) \ + -c $(THREAD_FLAG) $(ANNOT_FLAG) \ + $(OCAMLFLAGS) $(INCFLAGS) $<; \ + $(REAL_OCAMLFIND) $(INTF_OCAMLC) $(OCAML_FIND_PACKAGES) \ + -c $(THREAD_FLAG) $(ANNOT_FLAG) \ + $(OCAMLFLAGS) $(INCFLAGS) $<; \ + else \ + $(ECHO) $(REAL_OCAMLFIND) $(INTF_OCAMLC) $(OCAML_FIND_PACKAGES) \ + -c -pp \"$$pp $(PPFLAGS)\" $(THREAD_FLAG) $(ANNOT_FLAG) \ + $(OCAMLFLAGS) $(INCFLAGS) $<; \ + $(REAL_OCAMLFIND) $(INTF_OCAMLC) $(OCAML_FIND_PACKAGES) \ + -c -pp "$$pp $(PPFLAGS)" $(THREAD_FLAG) $(ANNOT_FLAG) \ + $(OCAMLFLAGS) $(INCFLAGS) $<; \ + fi + +.ml.cmi .ml.$(EXT_OBJ) .ml.cmx .ml.cmo: $(EXTRADEPS) + $(QUIET)pp=`sed -n -e '/^#/d' -e 's/(\*pp \([^*]*\) \*)/\1/p;q' $<`; \ + if [ -z "$$pp" ]; then \ + $(ECHO) $(REAL_OCAMLFIND) $(REAL_OCAMLC) $(OCAML_FIND_PACKAGES) \ + -c $(ALL_OCAMLCFLAGS) $<; \ + $(REAL_OCAMLFIND) $(REAL_OCAMLC) $(OCAML_FIND_PACKAGES) \ + -c $(ALL_OCAMLCFLAGS) $<; \ + else \ + $(ECHO) $(REAL_OCAMLFIND) $(REAL_OCAMLC) $(OCAML_FIND_PACKAGES) \ + -c -pp \"$$pp $(PPFLAGS)\" $(ALL_OCAMLCFLAGS) $<; \ + $(REAL_OCAMLFIND) $(REAL_OCAMLC) $(OCAML_FIND_PACKAGES) \ + -c -pp "$$pp $(PPFLAGS)" $(ALL_OCAMLCFLAGS) $<; \ + fi + +ifdef PACK_LIB +$(REAL_RESULT).cmo $(REAL_RESULT).cmx $(REAL_RESULT).o: $(REAL_IMPL_INTF) $(OBJ_LINK) $(EXTRADEPS) + $(REAL_OCAMLFIND) $(REAL_OCAMLC) -pack $(ALL_LDFLAGS) \ + $(OBJS_LIBS) -o $@ $(REAL_IMPL) +endif + +.PRECIOUS: %.ml +%.ml: %.mll + $(OCAMLLEX) $(LFLAGS) $< + +.PRECIOUS: %.ml %.mli +%.ml %.mli: %.mly + $(OCAMLYACC) $(YFLAGS) $< + $(QUIET)pp=`sed -n -e 's/.*(\*pp \([^*]*\) \*).*/\1/p;q' $<`; \ + if [ ! -z "$$pp" ]; then \ + mv $*.ml $*.ml.temporary; \ + echo "(*pp $$pp $(PPFLAGS)*)" > $*.ml; \ + cat $*.ml.temporary >> $*.ml; \ + rm $*.ml.temporary; \ + mv $*.mli $*.mli.temporary; \ + echo "(*pp $$pp $(PPFLAGS)*)" > $*.mli; \ + cat $*.mli.temporary >> $*.mli; \ + rm $*.mli.temporary; \ + fi + + +.PRECIOUS: %.ml +%.ml: %.rep + $(CAMELEON_REPORT) $(CAMELEON_REPORT_FLAGS) -gen $< + +.PRECIOUS: %.ml +%.ml: %.zog + $(CAMELEON_ZOGGY) $(CAMELEON_ZOGGY_FLAGS) -impl $< > $@ + +.PRECIOUS: %.ml +%.ml: %.glade + $(OCAML_GLADECC) $(OCAML_GLADECC_FLAGS) $< > $@ + +.PRECIOUS: %.ml %.mli +%.ml %.mli: %.oxridl + $(OXRIDL) $< + +.PRECIOUS: %.ml %.mli %_stubs.c %.h +%.ml %.mli %_stubs.c %.h: %.idl + $(CAMLIDL) $(MAYBE_IDL_HEADER) $(IDLFLAGS) \ + $(CAMLIDLFLAGS) $< + $(QUIET)if [ $(NOIDLHEADER) ]; then touch $*.h; fi + +.c.$(EXT_OBJ): + $(OCAMLC) -c -cc "$(CC)" -ccopt "$(CFLAGS) \ + $(CPPFLAGS) $(CPPFLAGS_WIN32) \ + $(CFLAGS_WIN32) $(CINCFLAGS) $(CFLAG_O)$@ " $< + +.m.$(EXT_OBJ): + $(CC) -c $(CFLAGS) $(CINCFLAGS) $(CPPFLAGS) \ + -I'$(OCAMLLIBPATH)' \ + $< $(CFLAG_O)$@ + +.$(EXT_CXX).$(EXT_OBJ): + $(CXX) -c $(CXXFLAGS) $(CINCFLAGS) $(CPPFLAGS) \ + -I'$(OCAMLLIBPATH)' \ + $< $(CFLAG_O)$@ + +$(MLDEPDIR)/%.d: %.ml + $(QUIET)if [ ! -d $(@D) ]; then mkdir -p $(@D); fi + $(QUIET)pp=`sed -n -e '/^#/d' -e 's/(\*pp \([^*]*\) \*)/\1/p;q' $<`; \ + if [ -z "$$pp" ]; then \ + $(ECHO) $(REAL_OCAMLFIND) $(OCAMLDEP) $(OCAML_DEP_PACKAGES) \ + $(DINCFLAGS) $< \> $@; \ + $(REAL_OCAMLFIND) $(OCAMLDEP) $(OCAML_DEP_PACKAGES) \ + $(DINCFLAGS) $< > $@; \ + else \ + $(ECHO) $(REAL_OCAMLFIND) $(OCAMLDEP) $(OCAML_DEP_PACKAGES) \ + -pp \"$$pp $(PPFLAGS)\" $(DINCFLAGS) $< \> $@; \ + $(REAL_OCAMLFIND) $(OCAMLDEP) $(OCAML_DEP_PACKAGES) \ + -pp "$$pp $(PPFLAGS)" $(DINCFLAGS) $< > $@; \ + fi + +$(BCDIDIR)/%.di $(NCDIDIR)/%.di: %.mli + $(QUIET)if [ ! -d $(@D) ]; then mkdir -p $(@D); fi + $(QUIET)pp=`sed -n -e '/^#/d' -e 's/(\*pp \([^*]*\) \*)/\1/p;q' $<`; \ + if [ -z "$$pp" ]; then \ + $(ECHO) $(REAL_OCAMLFIND) $(OCAMLDEP) $(DEPFLAGS) $(DINCFLAGS) $< \> $@; \ + $(REAL_OCAMLFIND) $(OCAMLDEP) $(DEPFLAGS) $(DINCFLAGS) $< > $@; \ + else \ + $(ECHO) $(REAL_OCAMLFIND) $(OCAMLDEP) $(DEPFLAGS) \ + -pp \"$$pp $(PPFLAGS)\" $(DINCFLAGS) $< \> $@; \ + $(REAL_OCAMLFIND) $(OCAMLDEP) $(DEPFLAGS) \ + -pp "$$pp $(PPFLAGS)" $(DINCFLAGS) $< > $@; \ + fi + +$(DOC_DIR)/$(RESULT)/html: + mkdir -p $@ + +$(DOC_DIR)/$(RESULT)/html/index.html: $(DOC_DIR)/$(RESULT)/html $(DOC_FILES) + rm -rf $ + +module Numberz = +struct +type t = +| ONE +| TWO +| THREE +| FIVE +| SIX +| EIGHT + +let of_i = ... +let to_i = ... +end + +typedef format +-------------- +Typedef turns into the type declaration: +typedef i64 UserId + +==> + +type userid Int64.t + +exception format +---------------- +The same as structs except that the module also has an exception type +E of t that is raised/caught. + +For example, with an exception Xception, +raise (Xception.E (new Xception.t)) +and +try + ... +with Xception.E e -> ... + +list format +----------- +Lists are turned into OCaml native lists. + +Map/Set formats +--------------- +These are both turned into Hashtbl.t's. Set values are bool. + +Services +-------- +The client is a class "client" parametrized on input and output +protocols. The processor is a class parametrized on a handler. A +handler is a class inheriting the iface abstract class. Unlike other +implementations, client does not implement iface since iface functions +must take option arguments so as to deal with the case where a client +does not send all the arguments. diff --git a/lib/ocaml/README-OCamlMakefile b/lib/ocaml/README-OCamlMakefile new file mode 100644 index 00000000..54787b96 --- /dev/null +++ b/lib/ocaml/README-OCamlMakefile @@ -0,0 +1,640 @@ +--------------------------------------------------------------------------- + + Distribution of "ocaml_make" + Copyright (C) 1999 - 2006 Markus Mottl - free to copy and modify! + USE AT YOUR OWN RISK! + +--------------------------------------------------------------------------- + + PREREQUISITES + + *** YOU WILL NEED GNU-MAKE VERSION >3.80 *** + +--------------------------------------------------------------------------- + + Contents of this distribution + +Changes - guess what? ;-) + +OCamlMakefile - Makefile for easy handling of compilation of not so easy + OCaml-projects. It generates dependencies of OCaml-files + automatically, is able to handle "ocamllex"-, + "ocamlyacc"-, IDL- and C-files, knows how to run + preprocessors and generates native- or byte-code, as + executable or as library - with thread-support if you + want! Profiling and debugging support can be added on + the fly! There is also support for installing libraries. + Ah, yes, and you can also create toplevels from any + sources: this allows you immediate interactive testing. + Automatic generation of documentation is easy due to + integration of support for OCamldoc. + +README - this file + +calc/ - Directory containing a quite fully-featured example + of what "OCamlMakefile" can do for you. This example + makes use of "ocamllex", "ocamlyacc", IDL + C and + threads. + +camlp4/ - This simple example demonstrates how to automatically + preprocess files with the camlp4-preprocessor. + +gtk/ - Demonstration of how to use OCamlMakefile with GTK + and threads. Courtesy of Tim Freeman . + +idl/ - Contains a very small example of how to use + "camlidl" together with "OCamlMakefile". Also intended + to show, how easy it is to interface OCaml and C. + +threads/ - Two examples of how to use threads (originally + posted by Xavier Leroy some time ago). Shows the use of + "OCamlMakefile" in an environment of multiple compilation + targets. + +--------------------------------------------------------------------------- + + Why should you use it? + +For several reasons: + + * It is well-tested (I use it in all of my projects). + + * In contrast to most other approaches it generates dependencies + correctly by ensuring that all automatically generated OCaml-files + exist before dependency calculation. This is the only way to + guarantee that "ocamldep" works correctly. + + * It is extremely convenient (at least I think so ;-). + Even quite complex compilation processes (see example "calc.ml") + need very little information to work correctly - actually just about + the minimum (file names of sources). + +--------------------------------------------------------------------------- + + When you shouldn't use it... + +In projects where every compilation unit needs different flags - but +in such complicated cases you will be on your own anyway. Luckily, +this doesn't happen too frequently... + +--------------------------------------------------------------------------- + + How to use "OCamlMakefile" in your own project + (Take a look at the examples for a quick introduction!) + +Create your project-specific "Makefile" in the appropriate directory. + +Now there are two ways of making use of "OCamlMakefile": + + 1) Have a look at the default settings in "OCamlMakefile" and set + them to the values that are vaild on your system - whether the + path to the standard libraries is ok, what executables shall be + used, etc... + + 2) Copy it into the directory of the project to be compiled. + Add "-include OCamlMakefile" as a last line of your "Makefile". + + 3) Put it somewhere else on the system. In this case you will have to + set a variable "OCAMLMAKEFILE" in your project-specific "Makefile". + This is the way in which the examples are written: so you need + only one version of "OCamlMakefile" to manage all your projects! + See the examples for details. + +You should usually specify two further variables for your project: + + * SOURCES (default: foo.ml) + * RESULT (default: foo) + +Put all the sources necessary for a target into variable "SOURCES". +Then set "RESULT" to the name of the target. If you want to generate +libraries, you should *not* specify the suffix (".cma", ".cmxa", ".a") +- it will be added automatically if you specify that you want to build +a library. + + ** Don't forget to add the ".mli"-files, too! ** + ** Don't forget that order of the source files matters! ** + +The order is important, because it matters during linking anyway +due to potential side effects caused at program startup. This is +why OCamlMakefile does not attempt to partially order dependencies by +itself, which might confuse users even more. It just compiles and links +OCaml-sources in the order specified by the user, even if it could +determine automatically that the order cannot be correct. + +The minimum of your "Makefile" looks like this (assuming that +"OCamlMakefile" is in the search path of "make"): + + -include OCamlMakefile + +This will assume that you want to compile a file "foo.ml" to a binary +"foo". + +Otherwise, your Makefile will probably contain something like this: + + SOURCES = foo.ml + RESULT = foo + -include OCamlMakefile + +Be careful with the names you put into these variables: if they are wrong, +a "make clean" might erase the wrong files - but I know you will not do +that ;-) + +A simple "make" will generate a byte-code executable. If you want to +change this, you may add an "all"-rule that generates something else. + +E.g.: + + SOURCES = foo.ml + RESULT = foo + all: native-code-library + -include OCamlMakefile + +This will build a native-code library "foo.cmxa" (+ "foo.a") from file +"foo.ml". + +You may even build several targets at once. To produce byte- and native-code +executables with one "make", add the following rule: + + all: byte-code native-code + +You will probably want to use a different suffix for each of these targets +so that the result will not be overwritten (see optional variables below +for details). + +You may also tell "make" at the command-line what kind of target to +produce (e.g. "make nc"). Here all the possibilities with shortcuts +between parenthesis: + + * byte-code (bc) + * byte-code-nolink (bcnl) - no linking stage + * byte-code-library (bcl) + * native-code (nc) + * native-code-nolink (ncnl) - no linking stage + * native-code-library (ncl) + * debug-code (dc) + * debug-code-nolink (dcnl) - no linking stage + * debug-code-library (dcl) + * profiling-byte-code (pbc) + * profiling-byte-code-library (pbcl) + * profiling-native-code (pnc) + * profiling-native-code-library (pncl) + * byte-code-dll (bcd) + * native-code-dll (ncd) + * pack-byte-code (pabc) + * pack-native-code (panc) + * toplevel interpreter (top) + * subprojs + +Here a short note concerning building and linking byte code libraries +with C-files: + + OCaml links C-object files only when they are used in an executable. + After compilation they should be placed in some directory that is in + your include path if you link your library against an executable. + + It is sometimes more convenient to link all C-object files into a + single C-library. Then you have to override the automatic link flags + of your library using "-noautolink" and add another linkflag that + links in your C-library explicitly. + +What concerns maintainance: + + "make clean" removes all (all!) automatically generated files - so + again: make sure your variables are ok! + + "make cleanup" is similar to "make clean" but leaves executables. + +Another way to destroy some important files is by having "OCamlMakefile" +automatically generate files with the same name. Read the documentation +about the tools in the OCaml-distribution to see what kind of files are +generated. "OCamlMakefile" additionally generates ('%' is basename of +source file): + + %_idl.c - "camlidl" generates a file "%.c" from "%.idl", but this is + not such a good idea, because when generating native-code, + both the file "%.c" and "%.ml" would generate files "%.o" + which would overwrite each other. Thus, "OCamlMakefile" + renames "%.c" to "%_idl.c" to work around this problem. + +The dependencies are stored in three different subdirectories (dot dirs): + + ._d - contains dependencies for .ml-files + ._bcdi - contains byte code dependencies for .mli-files + ._ncdi - contains native code dependencies for .mli-files + +The endings of the dependency files are: "%.d" for those generated from +"%.ml"-files, "%.di" for ones derived from "%.mli"-files. + +--------------------------------------------------------------------------- + + Debugging + + This is easy: if you discover a bug, just do a "make clean; make dc" + to recompile your project with debugging information. Then you can + immediately apply "ocamldebug" to the executable. + +--------------------------------------------------------------------------- + + Profiling + + For generating code that can be profiled with "ocamlprof" (byte code) + or "gprof" (native code), compile your project with one of the profiling + targets (see targets above). E.g.: + + * "make pbc" will build byte code that can be profiled with + "ocamlprof". + + * "make pnc" will build native code that can be profiled with + "gprof". + + Please note that it is not currently possible to profile byte code with + threads. OCamlMakefile will force an error if you try to do this. + + A short hint for DEC Alpha-users (under Digital Unix): you may also + compile your sources to native code without any further profiling + options/targets. Then call "pixie my_exec", "my_exec" being your + executable. This will produce (among other files) an executable + "my_exec.pixie". Call it and it will produce profiling information which + can be analysed using "prof -pixie my_exec". The resulting information + is extremely detailed and allows analysis up to the clock cycle level... + +--------------------------------------------------------------------------- + + Using Preprocessors + + Because one could employ any kind of program that reads from standard + input and prints to standard output as preprocessor, there cannot be any + default way to handle all of them correctly without further knowledge. + + Therefore you have to cooperate a bit with OCamlMakefile to let + preprocessing happen automatically. Basically, this only requires + that you put a comment into the first line of files that should be + preprocessed, e.g.: + + (*pp cat *) + ... rest of program ... + + OCamlMakefile looks at the first line of your files, and if it finds + a comment that starts with "(*pp", then it will assume that the + rest of the comment tells it how to correctly call the appropriate + preprocessor. In this case the program "cat" will be called, which will, + of course, just output the source text again without changing it. + + If you are, for example, an advocate of the new "revised syntax", + which is supported by the camlp4 preprocessor, you could simply write: + + (*pp camlp4r *) + ... rest of program in revised syntax ... + + Simple, isn't it? + + If you want to write your own syntax extensions, just take a look at the + example in the directory "camlp4": it implements the "repeat ... until" + extension as described in the camlp4-tutorial. + +--------------------------------------------------------------------------- + + Library (Un-)Installation Support + + OCamlMakefile contains two targets using "ocamlfind" for this purpose: + + * libinstall + * libuninstall + + These two targets require the existence of the variable + "LIBINSTALL_FILES", which should be set to all the files that you + want to install in the library directory (usually %.mli, %.cmi, %.cma, + %.cmxa, %.a and possibly further C-libraries). The target "libinstall" + has the dependency "all" to force compilation of the library so make + sure you define target "all" in your Makefile appropriately. + + The targets inform the user about the configured install path and ask + for confirmation to (un)install there. If you want to use them, it + is often a good idea to just alias them in your Makefile to "install" + and "uninstall" respectively. + + Two other targets allow installation of files into a particular + directory (without using ocamlfind): + + * rawinstall + * rawuninstall + +--------------------------------------------------------------------------- + + Building toplevels + + There is just one target for this: + + * top + + The generated file can be used immediately for interactive sessions - + even with scanners, parsers, C-files, etc.! + +--------------------------------------------------------------------------- + + Generating documentation + + The following targets are supported: + + * htdoc - generates HTML-documentation + * ladoc - generates Latex-documentation + * psdoc - generates PostScript-documentation + * pdfdoc - generates PDF-documentation + * doc - generates all supported forms of documentation + * clean-doc - generates all supported forms of documentation + + All of them generate a sub-directory "doc". More precisely, for HTML it + is "doc/$(RESULT)/html" and for Latex, PostScript and PDF the directory + "doc/$(RESULT)/latex". See the OCamldoc-manual for details and the + optional variables below for settings you can control. + +--------------------------------------------------------------------------- + + Handling subprojects + + You can have several targets in the same directory and manage them + from within an single Makefile. + + Give each subproject a name, e.g. "p1", "p2", etc. Then you export + settings specific to each project by using variables of the form + "PROJ_p1", "PROJ_p2", etc. E.g.: + + define PROJ_p1 + SOURCES="foo.ml main.ml" + RESULT="p1" + OCAMLFLAGS="-unsafe" + endef + export PROJ_p1 + + define PROJ_p2 + ... + endef + export PROJ_p2 + + You may also export common settings used by all projects directly, e.g. + "export THREADS = y". + + Now it is a good idea to define, which projects should be affected by + commands by default. E.g.: + + ifndef SUBPROJS + export SUBPROJS = p1 p2 + endif + + This will automatically generate a given target for all those + subprojects if this variable has not been defined in the shell + environment or in the command line of the make-invocation by the user. + E.g., "make dc" will generate debug code for all subprojects. + + Then you need to define a default action for your subprojects if "make" + has been called without arguments: + + all: bc + + This will build byte code by default for all subprojects. + + Finally, you'll have to define a catch-all target that uses the target + provided by the user for all subprojects. Just add (assuming that + OCAMLMAKEFILE has been defined appropriately): + + %: + @make -f $(OCAMLMAKEFILE) subprojs SUBTARGET=$@ + + See the "threads"-directory in the distribution for a short example! + +--------------------------------------------------------------------------- + + Optional variables that may be passed to "OCamlMakefile" + + * LIB_PACK_NAME - packs all modules of a library into a module whose + name is given in variable "LIB_PACK_NAME". + + * RES_CLIB_SUF - when building a library that contains C-stubs, this + variable controls the suffix appended to the name + of the C-library (default: "_stubs"). + + * THREADS - say "THREADS = yes" if you need thread support compiled in, + otherwise leave it away. + + * VMTHREADS - say "VMTHREADS = yes" if you want to force VM-level + scheduling of threads (byte-code only). + + * ANNOTATE - say "ANNOTATE = yes" to generate type annotation files + (.annot) to support displaying of type information + in editors. + + * USE_CAMLP4 - say "USE_CAMLP4 = yes" in your "Makefile" if you + want to include the camlp4 directory during the + build process, otherwise leave it away. + + * INCDIRS - directories that should be searched for ".cmi"- and + ".cmo"-files. You need not write "-I ..." - just the + plain names. + * LIBDIRS - directories that should be searched for libraries + Also just put the plain paths into this variable + * EXTLIBDIRS - Same as "LIBDIRS", but paths in this variable are + also added to the binary via the "-R"-flag so that + dynamic libraries in non-standard places can be found. + * RESULTDEPS - Targets on which results (executables or libraries) + should additionally depend. + + * PACKS - adds packages under control of "findlib". + + * PREDS - specifies "findlib"-predicates. + + * LIBS - OCaml-libraries that should be linked (just plain names). + E.g. if you want to link the Str-library, just write + "str" (without quotes). + The new OCaml-compiler handles libraries in such + a way that they "remember" whether they have to + be linked against a C-library and it gets linked + in automatically. + If there is a slash in the library name (such as + "./str" or "lib/foo") then make is told that the + generated files depend on the library. This + helps to ensure that changes to your libraries are + taken into account, which is important if you are + regenerating your libraries frequently. + * CLIBS - C-libraries that should be linked (just plain names). + + * PRE_TARGETS - set this to a list of target files that you want + to have buildt before dependency calculation actually + takes place. E.g. use this to automatically compile + modules needed by camlp4, which have to be available + before other modules can be parsed at all. + + ** WARNING **: the files mentioned in this variable + will be removed when "make clean" is executed! + + * LIBINSTALL_FILES - the files of a library that should be installed + using "findlib". Default: + + $(RESULT).mli $(RESULT).cmi $(RESULT).cma + $(RESULT).cmxa $(RESULT).a lib$(RESULT).a + + * OCAML_LIB_INSTALL - target directory for "rawinstall/rawuninstall". + (default: $(OCAMLLIBPATH)/contrib) + + * DOC_FILES - names of files from which documentation is generated. + (default: all .mli-files in your $(SOURCES)). + + * DOC_DIR - name of directory where documentation should be stored. + + * OCAMLFLAGS - flags passed to the compilers + * OCAMLBCFLAGS - flags passed to the byte code compiler only + * OCAMLNCFLAGS - flags passed to the native code compiler only + + * OCAMLLDFLAGS - flags passed to the OCaml-linker + * OCAMLBLDFLAGS - flags passed to the OCaml-linker when linking byte code + * OCAMLNLDFLAGS - flags passed to the OCaml-linker when linking + native code + + * OCAMLMKLIB_FLAGS - flags passed to the OCaml library tool + + * OCAMLCPFLAGS - profiling flags passed to "ocamlcp" (default: "a") + + * PPFLAGS - additional flags passed to the preprocessor (default: none) + + * LFLAGS - flags passed to "ocamllex" + * YFLAGS - flags passed to "ocamlyacc" + * IDLFLAGS - flags passed to "camlidl" + + * OCAMLDOCFLAGS - flags passed to "ocamldoc" + + * OCAMLFIND_INSTFLAGS - flags passed to "ocamlfind" during installation + (default: none) + + * DVIPSFLAGS - flags passed to dvips + (when generating documentation in PostScript). + + * STATIC - set this variable if you want to force creation + of static libraries + + * CC - the C-compiler to be used + * CXX - the C++-compiler to be used + + * CFLAGS - additional flags passed to the C-compiler. + The flag "-DNATIVE_CODE" will be passed automatically + if you choose to build native code. This allows you + to compile your C-files conditionally. But please + note: You should do a "make clean" or remove the + object files manually or touch the %.c-files: + otherwise, they may not be correctly recompiled + between different builds. + + * CXXFLAGS - additional flags passed to the C++-compiler. + + * CPPFLAGS - additional flags passed to the C-preprocessor. + + * CFRAMEWORKS - Objective-C framework to pass to linker on MacOS X. + + * LDFLAGS - additional flags passed to the C-linker + + * RPATH_FLAG - flag passed through to the C-linker to set a path for + dynamic libraries. May need to be set by user on + exotic platforms. (default: "-R"). + + * ELF_RPATH_FLAG - this flag is used to set the rpath on ELF-platforms. + (default: "-R") + + * ELF_RPATH - if this flag is "yes", then the RPATH_FLAG will be + passed by "-Wl" to the linker as normal on + ELF-platforms. + + * OCAMLLIBPATH - path to the OCaml-standard-libraries + (first default: `$(OCAMLC) -where`) + (second default: "/usr/local/lib/ocaml") + + * OCAML_DEFAULT_DIRS - additional path in which the user can supply + default directories to his own collection of + libraries. The idea is to pass this as an environment + variable so that the Makefiles do not have to contain + this path all the time. + + * OCAMLFIND - ocamlfind from findlib (default: "ocamlfind") + * OCAMLC - byte-code compiler (default: "ocamlc") + * OCAMLOPT - native-code compiler (default: "ocamlopt") + * OCAMLMKTOP - top-level compiler (default: "ocamlmktop") + * OCAMLCP - profiling byte-code compiler (default: "ocamlcp") + * OCAMLDEP - dependency generator (default: "ocamldep") + * OCAMLLEX - scanner generator (default: "ocamllex") + * OCAMLYACC - parser generator (default: "ocamlyacc") + * OCAMLMKLIB - tool to create libraries (default: "ocamlmklib") + * CAMLIDL - IDL-code generator (default: "camlidl") + * CAMLIDLDLL - IDL-utility (default: "camlidldll") + * CAMLP4 - camlp4 preprocessor (default: "camlp4") + * OCAMLDOC - OCamldoc-command (default: "ocamldoc") + + * LATEX - Latex-processor (default: "latex") + * DVIPS - dvips-command (default: "dvips") + * PS2PDF - PostScript-to-PDF converter (default: "ps2pdf") + + * CAMELEON_REPORT - report tool of Cameleon (default: "report") + * CAMELEON_REPORT_FLAGS - flags for the report tool of Cameleon + + * CAMELEON_ZOGGY - zoggy tool of Cameleon + (default: "camlp4o pa_zog.cma pr_o.cmo") + * CAMELEON_ZOGGY_FLAGS - flags for the zoggy tool of Cameleon + + * OCAML_GLADECC - Glade compiler for OCaml (default: "lablgladecc2") + * OCAML_GLADECC_FLAGS - flags for the Glade compiler + + * OXRIDL - OXRIDL-generator (default: "oxridl") + + * NOIDLHEADER - set to "yes" to prohibit "OCamlMakefile" from using + the default camlidl-flag "-header". + + * NO_CUSTOM - Prevent linking in custom mode. + + * QUIET - unsetting this variable (e.g. "make QUIET=") + will print all executed commands, including + intermediate ones. This allows more comfortable + debugging when things go wrong during a build. + + * REALLY_QUIET - when set this flag turns off output from some commands. + + * OCAMLMAKEFILE - location of (=path to) this "OCamlMakefile". + Because it calles itself recursively, it has to + know where it is. (default: "OCamlMakefile" = + local directory) + + * BCSUFFIX - Suffix for all byte-code files. E.g.: + + RESULT = foo + BCSUFFIX = _bc + + This will produce byte-code executables/libraries + with basename "foo_bc". + + * NCSUFFIX - Similar to "BCSUFFIX", but for native-code files. + * TOPSUFFIX - Suffix added to toplevel interpreters (default: ".top") + + * SUBPROJS - variable containing the names of subprojects to be + compiled. + + * SUBTARGET - target to be built for all projects in variable + SUBPROJS. + +--------------------------------------------------------------------------- + + Optional variables for Windows users + + * MINGW - variable to detect the MINGW-environment + * MSVC - variable to detect the MSVC-compiler + +--------------------------------------------------------------------------- + +Up-to-date information (newest release of distribution) can always be +found at: + + http://www.ocaml.info/home/ocaml_sources.html + +--------------------------------------------------------------------------- + +Enjoy! + +New York, 2007-04-22 +Markus Mottl + +e-mail: markus.mottl@gmail.com +WWW: http://www.ocaml.info diff --git a/lib/ocaml/TODO b/lib/ocaml/TODO new file mode 100644 index 00000000..4d1dc771 --- /dev/null +++ b/lib/ocaml/TODO @@ -0,0 +1,5 @@ +Write interfaces +Clean up the code generator +Avoid capture properly instead of relying on the user not to use _ + + diff --git a/lib/ocaml/src/Makefile b/lib/ocaml/src/Makefile new file mode 100644 index 00000000..42ec8dbd --- /dev/null +++ b/lib/ocaml/src/Makefile @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +SOURCES = Thrift.ml TBinaryProtocol.ml TSocket.ml TChannelTransport.ml TServer.ml TSimpleServer.ml TServerSocket.ml TThreadedServer.ml +RESULT = thrift +LIBS = unix threads +THREADS = yes +all: native-code-library debug-code-library top +OCAMLMAKEFILE = ../OCamlMakefile +include $(OCAMLMAKEFILE) diff --git a/lib/ocaml/src/TBinaryProtocol.ml b/lib/ocaml/src/TBinaryProtocol.ml new file mode 100644 index 00000000..a06cc9a9 --- /dev/null +++ b/lib/ocaml/src/TBinaryProtocol.ml @@ -0,0 +1,171 @@ +(* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*) + +open Thrift + +module P = Protocol + +let get_byte i b = 255 land (i lsr (8*b)) +let get_byte64 i b = 255 land (Int64.to_int (Int64.shift_right i (8*b))) + + +let tv = P.t_type_to_i +let vt = P.t_type_of_i + + +let comp_int b n = + let s = ref 0l in + let sb = 32 - 8*n in + for i=0 to (n-1) do + s:= Int32.logor !s (Int32.shift_left (Int32.of_int (int_of_char b.[i])) (8*(n-1-i))) + done; + Int32.to_int (Int32.shift_right (Int32.shift_left !s sb) sb) + +let comp_int64 b n = + let s = ref 0L in + for i=0 to (n-1) do + s:=Int64.logor !s (Int64.shift_left (Int64.of_int (int_of_char b.[i])) (8*(n-1-i))) + done; + !s + +let version_mask = 0xffff0000 +let version_1 = 0x80010000 + +class t trans = +object (self) + inherit P.t trans + val ibyte = String.create 8 + method writeBool b = + ibyte.[0] <- char_of_int (if b then 1 else 0); + trans#write ibyte 0 1 + method writeByte i = + ibyte.[0] <- char_of_int (get_byte i 0); + trans#write ibyte 0 1 + method writeI16 i = + let gb = get_byte i in + ibyte.[1] <- char_of_int (gb 0); + ibyte.[0] <- char_of_int (gb 1); + trans#write ibyte 0 2 + method writeI32 i = + let gb = get_byte i in + for i=0 to 3 do + ibyte.[3-i] <- char_of_int (gb i) + done; + trans#write ibyte 0 4 + method writeI64 i= + let gb = get_byte64 i in + for i=0 to 7 do + ibyte.[7-i] <- char_of_int (gb i) + done; + trans#write ibyte 0 8 + method writeDouble d = + self#writeI64 (Int64.bits_of_float d) + method writeString s= + let n = String.length s in + self#writeI32(n); + trans#write s 0 n + method writeBinary a = self#writeString a + method writeMessageBegin (n,t,s) = + self#writeI32 (version_1 lor (P.message_type_to_i t)); + self#writeString n; + self#writeI32 s + method writeMessageEnd = () + method writeStructBegin s = () + method writeStructEnd = () + method writeFieldBegin (n,t,i) = + self#writeByte (tv t); + self#writeI16 i + method writeFieldEnd = () + method writeFieldStop = + self#writeByte (tv (Protocol.T_STOP)) + method writeMapBegin (k,v,s) = + self#writeByte (tv k); + self#writeByte (tv v); + self#writeI32 s + method writeMapEnd = () + method writeListBegin (t,s) = + self#writeByte (tv t); + self#writeI32 s + method writeListEnd = () + method writeSetBegin (t,s) = + self#writeByte (tv t); + self#writeI32 s + method writeSetEnd = () + method readByte = + ignore (trans#readAll ibyte 0 1); + (comp_int ibyte 1) + method readI16 = + ignore (trans#readAll ibyte 0 2); + comp_int ibyte 2 + method readI32 = + ignore (trans#readAll ibyte 0 4); + comp_int ibyte 4 + method readI64 = + ignore (trans#readAll ibyte 0 8); + comp_int64 ibyte 8 + method readDouble = + Int64.float_of_bits (self#readI64) + method readBool = + self#readByte = 1 + method readString = + let sz = self#readI32 in + let buf = String.create sz in + ignore (trans#readAll buf 0 sz); + buf + method readBinary = self#readString + method readMessageBegin = + let ver = self#readI32 in + if (ver land version_mask != version_1) then + (print_int ver; + raise (P.E (P.BAD_VERSION, "Missing version identifier"))) + else + let s = self#readString in + let mt = P.message_type_of_i (ver land 0xFF) in + (s,mt, self#readI32) + method readMessageEnd = () + method readStructBegin = + "" + method readStructEnd = () + method readFieldBegin = + let t = (vt (self#readByte)) + in + if t != P.T_STOP then + ("",t,self#readI16) + else ("",t,0); + method readFieldEnd = () + method readMapBegin = + let kt = vt (self#readByte) in + let vt = vt (self#readByte) in + (kt,vt, self#readI32) + method readMapEnd = () + method readListBegin = + let t = vt (self#readByte) in + (t,self#readI32) + method readListEnd = () + method readSetBegin = + let t = vt (self#readByte) in + (t, self#readI32); + method readSetEnd = () +end + +class factory = +object + inherit P.factory + method getProtocol tr = new t tr +end diff --git a/lib/ocaml/src/TChannelTransport.ml b/lib/ocaml/src/TChannelTransport.ml new file mode 100644 index 00000000..0f7d616f --- /dev/null +++ b/lib/ocaml/src/TChannelTransport.ml @@ -0,0 +1,39 @@ +(* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*) + +open Thrift +module T = Transport + +class t (i,o) = +object (self) + val mutable opened = true + inherit Transport.t + method isOpen = opened + method opn = () + method close = close_in i; opened <- false + method read buf off len = + if opened then + try + really_input i buf off len; len + with _ -> raise (T.E (T.UNKNOWN, ("TChannelTransport: Could not read "^(string_of_int len)))) + else + raise (T.E (T.NOT_OPEN, "TChannelTransport: Channel was closed")) + method write buf off len = output o buf off len + method flush = flush o +end diff --git a/lib/ocaml/src/TServer.ml b/lib/ocaml/src/TServer.ml new file mode 100644 index 00000000..fc51efa8 --- /dev/null +++ b/lib/ocaml/src/TServer.ml @@ -0,0 +1,42 @@ +(* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*) + +open Thrift + +class virtual t + (pf : Processor.t) + (st : Transport.server_t) + (tf : Transport.factory) + (ipf : Protocol.factory) + (opf : Protocol.factory)= +object + method virtual serve : unit +end;; + + + +let run_basic_server proc port = + Unix.establish_server (fun inp -> fun out -> + let trans = new TChannelTransport.t (inp,out) in + let proto = new TBinaryProtocol.t (trans :> Transport.t) in + try + while proc#process proto proto do () done; () + with e -> ()) (Unix.ADDR_INET (Unix.inet_addr_of_string "127.0.0.1",port)) + + diff --git a/lib/ocaml/src/TServerSocket.ml b/lib/ocaml/src/TServerSocket.ml new file mode 100644 index 00000000..405ef82c --- /dev/null +++ b/lib/ocaml/src/TServerSocket.ml @@ -0,0 +1,41 @@ +(* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*) + +open Thrift + +class t port = +object + inherit Transport.server_t + val mutable sock = None + method listen = + let s = Unix.socket Unix.PF_INET Unix.SOCK_STREAM 0 in + sock <- Some s; + Unix.bind s (Unix.ADDR_INET (Unix.inet_addr_any, port)); + Unix.listen s 256 + method close = + match sock with + Some s -> Unix.shutdown s Unix.SHUTDOWN_ALL; Unix.close s; + sock <- None + | _ -> () + method acceptImpl = + match sock with + Some s -> let (fd,_) = Unix.accept s in + new TChannelTransport.t (Unix.in_channel_of_descr fd,Unix.out_channel_of_descr fd) + | _ -> raise (Transport.E (Transport.NOT_OPEN,"TServerSocket: Not listening but tried to accept")) +end diff --git a/lib/ocaml/src/TSimpleServer.ml b/lib/ocaml/src/TSimpleServer.ml new file mode 100644 index 00000000..d19d8c55 --- /dev/null +++ b/lib/ocaml/src/TSimpleServer.ml @@ -0,0 +1,38 @@ +(* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*) + +open Thrift +module S = TServer + +class t pf st tf ipf opf = +object + inherit S.t pf st tf ipf opf + method serve = + try + st#listen; + let c = st#accept in + let trans = tf#getTransport c in + let inp = ipf#getProtocol trans in + let op = opf#getProtocol trans in + try + while (pf#process inp op) do () done; + trans#close + with e -> trans#close; raise e + with _ -> () +end diff --git a/lib/ocaml/src/TSocket.ml b/lib/ocaml/src/TSocket.ml new file mode 100644 index 00000000..109e11c5 --- /dev/null +++ b/lib/ocaml/src/TSocket.ml @@ -0,0 +1,59 @@ +(* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*) + +open Thrift + +module T = Transport + +class t host port= +object (self) + inherit T.t + val mutable chans = None + method isOpen = chans != None + method opn = + try + let addr = (let {Unix.h_addr_list=x} = Unix.gethostbyname host in x.(0)) in + chans <- Some(Unix.open_connection (Unix.ADDR_INET (addr,port))) + with + Unix.Unix_error (e,fn,_) -> raise (T.E (T.NOT_OPEN, ("TSocket: Could not connect to "^host^":"^(string_of_int port)^" because: "^fn^":"^(Unix.error_message e)))) + | _ -> raise (T.E (T.NOT_OPEN, ("TSocket: Could not connect to "^host^":"^(string_of_int port)))) + + method close = + match chans with + None -> () + | Some(inc,out) -> (Unix.shutdown_connection inc; + close_in inc; + chans <- None) + method read buf off len = match chans with + None -> raise (T.E (T.NOT_OPEN, "TSocket: Socket not open")) + | Some(i,o) -> + try + really_input i buf off len; len + with + Unix.Unix_error (e,fn,_) -> raise (T.E (T.UNKNOWN, ("TSocket: Could not read "^(string_of_int len)^" from "^host^":"^(string_of_int port)^" because: "^fn^":"^(Unix.error_message e)))) + | _ -> raise (T.E (T.UNKNOWN, ("TSocket: Could not read "^(string_of_int len)^" from "^host^":"^(string_of_int port)))) + method write buf off len = match chans with + None -> raise (T.E (T.NOT_OPEN, "TSocket: Socket not open")) + | Some(i,o) -> output o buf off len + method flush = match chans with + None -> raise (T.E (T.NOT_OPEN, "TSocket: Socket not open")) + | Some(i,o) -> flush o +end + + diff --git a/lib/ocaml/src/TThreadedServer.ml b/lib/ocaml/src/TThreadedServer.ml new file mode 100644 index 00000000..4462dbd7 --- /dev/null +++ b/lib/ocaml/src/TThreadedServer.ml @@ -0,0 +1,45 @@ +(* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*) + +open Thrift + +class t + (pf : Processor.t) + (st : Transport.server_t) + (tf : Transport.factory) + (ipf : Protocol.factory) + (opf : Protocol.factory)= +object + inherit TServer.t pf st tf ipf opf + method serve = + st#listen; + while true do + let tr = tf#getTransport (st#accept) in + ignore (Thread.create + (fun _ -> + let ip = ipf#getProtocol tr in + let op = opf#getProtocol tr in + try + while pf#process ip op do + () + done + with _ -> ()) ()) + done +end + diff --git a/lib/ocaml/src/Thrift.ml b/lib/ocaml/src/Thrift.ml new file mode 100644 index 00000000..8dc9afa3 --- /dev/null +++ b/lib/ocaml/src/Thrift.ml @@ -0,0 +1,368 @@ +(* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*) + +exception Break;; +exception Thrift_error;; +exception Field_empty of string;; + +class t_exn = +object + val mutable message = "" + method get_message = message + method set_message s = message <- s +end;; + +module Transport = +struct + type exn_type = + | UNKNOWN + | NOT_OPEN + | ALREADY_OPEN + | TIMED_OUT + | END_OF_FILE;; + + exception E of exn_type * string + + class virtual t = + object (self) + method virtual isOpen : bool + method virtual opn : unit + method virtual close : unit + method virtual read : string -> int -> int -> int + method readAll buf off len = + let got = ref 0 in + let ret = ref 0 in + while !got < len do + ret := self#read buf (off+(!got)) (len - (!got)); + if !ret <= 0 then + raise (E (UNKNOWN, "Cannot read. Remote side has closed.")); + got := !got + !ret + done; + !got + method virtual write : string -> int -> int -> unit + method virtual flush : unit + end + + class factory = + object + method getTransport (t : t) = t + end + + class virtual server_t = + object (self) + method virtual listen : unit + method accept = self#acceptImpl + method virtual close : unit + method virtual acceptImpl : t + end + +end;; + + + +module Protocol = +struct + type t_type = + | T_STOP + | T_VOID + | T_BOOL + | T_BYTE + | T_I08 + | T_I16 + | T_I32 + | T_U64 + | T_I64 + | T_DOUBLE + | T_STRING + | T_UTF7 + | T_STRUCT + | T_MAP + | T_SET + | T_LIST + | T_UTF8 + | T_UTF16 + + let t_type_to_i = function + T_STOP -> 0 + | T_VOID -> 1 + | T_BOOL -> 2 + | T_BYTE -> 3 + | T_I08 -> 3 + | T_I16 -> 6 + | T_I32 -> 8 + | T_U64 -> 9 + | T_I64 -> 10 + | T_DOUBLE -> 4 + | T_STRING -> 11 + | T_UTF7 -> 11 + | T_STRUCT -> 12 + | T_MAP -> 13 + | T_SET -> 14 + | T_LIST -> 15 + | T_UTF8 -> 16 + | T_UTF16 -> 17 + + let t_type_of_i = function + 0 -> T_STOP + | 1 -> T_VOID + | 2 -> T_BOOL + | 3 -> T_BYTE + | 6-> T_I16 + | 8 -> T_I32 + | 9 -> T_U64 + | 10 -> T_I64 + | 4 -> T_DOUBLE + | 11 -> T_STRING + | 12 -> T_STRUCT + | 13 -> T_MAP + | 14 -> T_SET + | 15 -> T_LIST + | 16 -> T_UTF8 + | 17 -> T_UTF16 + | _ -> raise Thrift_error + + type message_type = + | CALL + | REPLY + | EXCEPTION + | ONEWAY + + let message_type_to_i = function + | CALL -> 1 + | REPLY -> 2 + | EXCEPTION -> 3 + | ONEWAY -> 4 + + let message_type_of_i = function + | 1 -> CALL + | 2 -> REPLY + | 3 -> EXCEPTION + | 4 -> ONEWAY + | _ -> raise Thrift_error + + class virtual t (trans: Transport.t) = + object (self) + val mutable trans_ = trans + method getTransport = trans_ + (* writing methods *) + method virtual writeMessageBegin : string * message_type * int -> unit + method virtual writeMessageEnd : unit + method virtual writeStructBegin : string -> unit + method virtual writeStructEnd : unit + method virtual writeFieldBegin : string * t_type * int -> unit + method virtual writeFieldEnd : unit + method virtual writeFieldStop : unit + method virtual writeMapBegin : t_type * t_type * int -> unit + method virtual writeMapEnd : unit + method virtual writeListBegin : t_type * int -> unit + method virtual writeListEnd : unit + method virtual writeSetBegin : t_type * int -> unit + method virtual writeSetEnd : unit + method virtual writeBool : bool -> unit + method virtual writeByte : int -> unit + method virtual writeI16 : int -> unit + method virtual writeI32 : int -> unit + method virtual writeI64 : Int64.t -> unit + method virtual writeDouble : float -> unit + method virtual writeString : string -> unit + method virtual writeBinary : string -> unit + (* reading methods *) + method virtual readMessageBegin : string * message_type * int + method virtual readMessageEnd : unit + method virtual readStructBegin : string + method virtual readStructEnd : unit + method virtual readFieldBegin : string * t_type * int + method virtual readFieldEnd : unit + method virtual readMapBegin : t_type * t_type * int + method virtual readMapEnd : unit + method virtual readListBegin : t_type * int + method virtual readListEnd : unit + method virtual readSetBegin : t_type * int + method virtual readSetEnd : unit + method virtual readBool : bool + method virtual readByte : int + method virtual readI16 : int + method virtual readI32: int + method virtual readI64 : Int64.t + method virtual readDouble : float + method virtual readString : string + method virtual readBinary : string + (* skippage *) + method skip typ = + match typ with + | T_STOP -> () + | T_VOID -> () + | T_BOOL -> ignore self#readBool + | T_BYTE + | T_I08 -> ignore self#readByte + | T_I16 -> ignore self#readI16 + | T_I32 -> ignore self#readI32 + | T_U64 + | T_I64 -> ignore self#readI64 + | T_DOUBLE -> ignore self#readDouble + | T_STRING -> ignore self#readString + | T_UTF7 -> () + | T_STRUCT -> ignore ((ignore self#readStructBegin); + (try + while true do + let (_,t,_) = self#readFieldBegin in + if t = T_STOP then + raise Break + else + (self#skip t; + self#readFieldEnd) + done + with Break -> ()); + self#readStructEnd) + | T_MAP -> ignore (let (k,v,s) = self#readMapBegin in + for i=0 to s do + self#skip k; + self#skip v; + done; + self#readMapEnd) + | T_SET -> ignore (let (t,s) = self#readSetBegin in + for i=0 to s do + self#skip t + done; + self#readSetEnd) + | T_LIST -> ignore (let (t,s) = self#readListBegin in + for i=0 to s do + self#skip t + done; + self#readListEnd) + | T_UTF8 -> () + | T_UTF16 -> () + end + + class virtual factory = + object + method virtual getProtocol : Transport.t -> t + end + + type exn_type = + | UNKNOWN + | INVALID_DATA + | NEGATIVE_SIZE + | SIZE_LIMIT + | BAD_VERSION + + exception E of exn_type * string;; + +end;; + + +module Processor = +struct + class virtual t = + object + method virtual process : Protocol.t -> Protocol.t -> bool + end;; + + class factory (processor : t) = + object + val processor_ = processor + method getProcessor (trans : Transport.t) = processor_ + end;; +end + + +(* Ugly *) +module Application_Exn = +struct + type typ= + | UNKNOWN + | UNKNOWN_METHOD + | INVALID_MESSAGE_TYPE + | WRONG_METHOD_NAME + | BAD_SEQUENCE_ID + | MISSING_RESULT + + let typ_of_i = function + 0 -> UNKNOWN + | 1 -> UNKNOWN_METHOD + | 2 -> INVALID_MESSAGE_TYPE + | 3 -> WRONG_METHOD_NAME + | 4 -> BAD_SEQUENCE_ID + | 5 -> MISSING_RESULT + | _ -> raise Thrift_error;; + let typ_to_i = function + | UNKNOWN -> 0 + | UNKNOWN_METHOD -> 1 + | INVALID_MESSAGE_TYPE -> 2 + | WRONG_METHOD_NAME -> 3 + | BAD_SEQUENCE_ID -> 4 + | MISSING_RESULT -> 5 + + class t = + object (self) + inherit t_exn + val mutable typ = UNKNOWN + method get_type = typ + method set_type t = typ <- t + method write (oprot : Protocol.t) = + oprot#writeStructBegin "TApplicationExeception"; + if self#get_message != "" then + (oprot#writeFieldBegin ("message",Protocol.T_STRING, 1); + oprot#writeString self#get_message; + oprot#writeFieldEnd) + else (); + oprot#writeFieldBegin ("type",Protocol.T_I32,2); + oprot#writeI32 (typ_to_i typ); + oprot#writeFieldEnd; + oprot#writeFieldStop; + oprot#writeStructEnd + end;; + + let create typ msg = + let e = new t in + e#set_type typ; + e#set_message msg; + e + + let read (iprot : Protocol.t) = + let msg = ref "" in + let typ = ref 0 in + ignore iprot#readStructBegin; + (try + while true do + let (name,ft,id) =iprot#readFieldBegin in + if ft = Protocol.T_STOP then + raise Break + else (); + (match id with + | 1 -> (if ft = Protocol.T_STRING then + msg := (iprot#readString) + else + iprot#skip ft) + | 2 -> (if ft = Protocol.T_I32 then + typ := iprot#readI32 + else + iprot#skip ft) + | _ -> iprot#skip ft); + iprot#readFieldEnd + done + with Break -> ()); + iprot#readStructEnd; + let e = new t in + e#set_type (typ_of_i !typ); + e#set_message !msg; + e;; + + exception E of t +end;; diff --git a/lib/perl/Makefile.PL b/lib/perl/Makefile.PL new file mode 100644 index 00000000..94ea37ce --- /dev/null +++ b/lib/perl/Makefile.PL @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +use ExtUtils::MakeMaker; +WriteMakefile( 'NAME' => 'Thrift', + 'VERSION_FROM' => 'lib/Thrift.pm', + 'PREREQ_PM' => { + 'Bit::Vector' => 0, + 'Class::Accessor' => 0 + }, + ($] >= 5.005 ? + ( AUTHOR => 'T Jake Luciani ') : ()), + ); diff --git a/lib/perl/Makefile.am b/lib/perl/Makefile.am new file mode 100644 index 00000000..163d0158 --- /dev/null +++ b/lib/perl/Makefile.am @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +SUBDIRS = test + +Makefile-perl.mk : Makefile.PL + $(PERL) Makefile.PL MAKEFILE=Makefile-perl.mk INSTALLDIRS=$(INSTALLDIRS) + +all-local: Makefile-perl.mk + $(MAKE) -f Makefile-perl.mk + find blib -name 'Makefile*' -exec rm -f {} \; + +check-local: + $(PERL) -Iblib/lib -I@abs_srcdir@ -I@builddir@/test/gen-perl \ + @abs_srcdir@/test.pl @abs_srcdir@/test/*.t + +install-exec-local: Makefile-perl.mk + $(MAKE) -f Makefile-perl.mk install DESTDIR=$(DESTDIR)/ + +clean-local: + if test -f Makefile-perl.mk ; then \ + $(MAKE) -f Makefile-perl.mk clean ; \ + fi + rm -f Makefile-perl.mk.old + +EXTRA_DIST = \ + Makefile.PL \ + test.pl \ + lib/Thrift.pm \ + lib/Thrift.pm \ + lib/Thrift/BinaryProtocol.pm \ + lib/Thrift/BufferedTransport.pm \ + lib/Thrift/FramedTransport.pm \ + lib/Thrift/HttpClient.pm \ + lib/Thrift/MemoryBuffer.pm \ + lib/Thrift/Protocol.pm \ + lib/Thrift/Socket.pm \ + lib/Thrift/Transport.pm diff --git a/lib/perl/README b/lib/perl/README new file mode 100644 index 00000000..691488b4 --- /dev/null +++ b/lib/perl/README @@ -0,0 +1,41 @@ +Thrift Perl Software Library + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Using Thrift with Perl +===================== + +Thrift requires Perl >= 5.6.0 + +Exceptions are thrown with die so be sure to wrap eval{} statments +around any code that contains exceptions. + +The 64bit Integers work only upto 2^42 on my machine :-? +Math::BigInt is probably needed. + +Please see tutoral and test dirs for examples... + +Dependencies +============ + +Bit::Vector - comes with modern perl installations. +Class::Accessor + diff --git a/lib/perl/lib/Thrift.pm b/lib/perl/lib/Thrift.pm new file mode 100644 index 00000000..fe0f8e72 --- /dev/null +++ b/lib/perl/lib/Thrift.pm @@ -0,0 +1,177 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +our $VERSION = '0.1'; + +require 5.6.0; +use strict; +use warnings; + +# +# Data types that can be sent via Thrift +# +package TType; +use constant STOP => 0; +use constant VOID => 1; +use constant BOOL => 2; +use constant BYTE => 3; +use constant I08 => 3; +use constant DOUBLE => 4; +use constant I16 => 6; +use constant I32 => 8; +use constant I64 => 10; +use constant STRING => 11; +use constant UTF7 => 11; +use constant STRUCT => 12; +use constant MAP => 13; +use constant SET => 14; +use constant LIST => 15; +use constant UTF8 => 16; +use constant UTF16 => 17; +1; + +# +# Message types for RPC +# +package TMessageType; +use constant CALL => 1; +use constant REPLY => 2; +use constant EXCEPTION => 3; +use constant ONEWAY => 4; +1; + +package Thrift::TException; + +sub new { + my $classname = shift; + my $self = {message => shift, code => shift || 0}; + + return bless($self,$classname); +} +1; + +package TApplicationException; +use base('Thrift::TException'); + +use constant UNKNOWN => 0; +use constant UNKNOWN_METHOD => 1; +use constant INVALID_MESSAGE_TYPE => 2; +use constant WRONG_METHOD_NAME => 3; +use constant BAD_SEQUENCE_ID => 4; +use constant MISSING_RESULT => 5; + +sub new { + my $classname = shift; + + my $self = $classname->SUPER::new(); + + return bless($self,$classname); +} + +sub read { + my $self = shift; + my $input = shift; + + my $xfer = 0; + my $fname = undef; + my $ftype = 0; + my $fid = 0; + + $xfer += $input->readStructBegin($fname); + + while (1) + { + $xfer += $input->readFieldBegin($fname, $ftype, $fid); + if ($ftype == TType::STOP) { + last; next; + } + + SWITCH: for($fid) + { + /1/ && do{ + + if ($ftype == TType::STRING) { + $xfer += $input->readString($self->{message}); + } else { + $xfer += $input->skip($ftype); + } + + last; + }; + + /2/ && do{ + if ($ftype == TType::I32) { + $xfer += $input->readI32($self->{code}); + } else { + $xfer += $input->skip($ftype); + } + last; + }; + + $xfer += $input->skip($ftype); + } + + $xfer += $input->readFieldEnd(); + } + $xfer += $input->readStructEnd(); + + return $xfer; +} + +sub write { + my $self = shift; + my $output = shift; + + my $xfer = 0; + + $xfer += $output->writeStructBegin('TApplicationException'); + + if ($self->getMessage()) { + $xfer += $output->writeFieldBegin('message', TType::STRING, 1); + $xfer += $output->writeString($self->getMessage()); + $xfer += $output->writeFieldEnd(); + } + + if ($self->getCode()) { + $xfer += $output->writeFieldBegin('type', TType::I32, 2); + $xfer += $output->writeI32($self->getCode()); + $xfer += $output->writeFieldEnd(); + } + + $xfer += $output->writeFieldStop(); + $xfer += $output->writeStructEnd(); + + return $xfer; +} + +sub getMessage +{ + my $self = shift; + + return $self->{message}; +} + +sub getCode +{ + my $self = shift; + + return $self->{code}; +} + +1; diff --git a/lib/perl/lib/Thrift/BinaryProtocol.pm b/lib/perl/lib/Thrift/BinaryProtocol.pm new file mode 100644 index 00000000..0e5d61d3 --- /dev/null +++ b/lib/perl/lib/Thrift/BinaryProtocol.pm @@ -0,0 +1,498 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 5.6.0; + +use strict; +use warnings; + +use utf8; +use Encode; + +use Thrift; +use Thrift::Protocol; + +use Bit::Vector; + +# +# Binary implementation of the Thrift protocol. +# +package Thrift::BinaryProtocol; +use base('Thrift::Protocol'); + +use constant VERSION_MASK => 0xffff0000; +use constant VERSION_1 => 0x80010000; + +sub new +{ + my $classname = shift; + my $trans = shift; + my $self = $classname->SUPER::new($trans); + + return bless($self,$classname); +} + +sub writeMessageBegin +{ + my $self = shift; + my ($name, $type, $seqid) = @_; + + return + $self->writeI32(VERSION_1 | $type) + + $self->writeString($name) + + $self->writeI32($seqid); +} + +sub writeMessageEnd +{ + my $self = shift; + return 0; +} + +sub writeStructBegin{ + my $self = shift; + my $name = shift; + return 0; +} + +sub writeStructEnd +{ + my $self = shift; + return 0; +} + +sub writeFieldBegin +{ + my $self = shift; + my ($fieldName, $fieldType, $fieldId) = @_; + + return + $self->writeByte($fieldType) + + $self->writeI16($fieldId); +} + +sub writeFieldEnd +{ + my $self = shift; + return 0; +} + +sub writeFieldStop +{ + my $self = shift; + return $self->writeByte(TType::STOP); +} + +sub writeMapBegin +{ + my $self = shift; + my ($keyType, $valType, $size) = @_; + + return + $self->writeByte($keyType) + + $self->writeByte($valType) + + $self->writeI32($size); +} + +sub writeMapEnd +{ + my $self = shift; + return 0; +} + +sub writeListBegin +{ + my $self = shift; + my ($elemType, $size) = @_; + + return + $self->writeByte($elemType) + + $self->writeI32($size); +} + +sub writeListEnd +{ + my $self = shift; + return 0; +} + +sub writeSetBegin +{ + my $self = shift; + my ($elemType, $size) = @_; + + return + $self->writeByte($elemType) + + $self->writeI32($size); +} + +sub writeSetEnd +{ + my $self = shift; + return 0; +} + +sub writeBool +{ + my $self = shift; + my $value = shift; + + my $data = pack('c', $value ? 1 : 0); + $self->{trans}->write($data, 1); + return 1; +} + +sub writeByte +{ + my $self = shift; + my $value= shift; + + my $data = pack('c', $value); + $self->{trans}->write($data, 1); + return 1; +} + +sub writeI16 +{ + my $self = shift; + my $value= shift; + + my $data = pack('n', $value); + $self->{trans}->write($data, 2); + return 2; +} + +sub writeI32 +{ + my $self = shift; + my $value= shift; + + my $data = pack('N', $value); + $self->{trans}->write($data, 4); + return 4; +} + +sub writeI64 +{ + my $self = shift; + my $value= shift; + my $data; + + my $vec; + #stop annoying error + $vec = Bit::Vector->new_Dec(64, $value); + $data = pack 'NN', $vec->Chunk_Read(32, 32), $vec->Chunk_Read(32, 0); + + $self->{trans}->write($data, 8); + + return 8; +} + + +sub writeDouble +{ + my $self = shift; + my $value= shift; + + my $data = pack('d', $value); + $self->{trans}->write(scalar reverse($data), 8); + return 8; +} + +sub writeString{ + my $self = shift; + my $value= shift; + + if( utf8::is_utf8($value) ){ + $value = Encode::encode_utf8($value); + } + + my $len = length($value); + + my $result = $self->writeI32($len); + + if ($len) { + $self->{trans}->write($value,$len); + } + return $result + $len; + } + + +# +#All references +# +sub readMessageBegin +{ + my $self = shift; + my ($name, $type, $seqid) = @_; + + my $version = 0; + my $result = $self->readI32(\$version); + if (($version & VERSION_MASK) > 0) { + if (($version & VERSION_MASK) != VERSION_1) { + die new Thrift::TException('Missing version identifier') + } + $$type = $version & 0x000000ff; + return + $result + + $self->readString($name) + + $self->readI32($seqid); + } else { # old client support code + return + $result + + $self->readStringBody($name, $version) + # version here holds the size of the string + $self->readByte($type) + + $self->readI32($seqid); + } +} + +sub readMessageEnd +{ + my $self = shift; + return 0; +} + +sub readStructBegin +{ + my $self = shift; + my $name = shift; + + $$name = ''; + + return 0; +} + +sub readStructEnd +{ + my $self = shift; + return 0; +} + +sub readFieldBegin +{ + my $self = shift; + my ($name, $fieldType, $fieldId) = @_; + + my $result = $self->readByte($fieldType); + + if ($$fieldType == TType::STOP) { + $$fieldId = 0; + return $result; + } + + $result += $self->readI16($fieldId); + + return $result; +} + +sub readFieldEnd() { + my $self = shift; + return 0; +} + +sub readMapBegin +{ + my $self = shift; + my ($keyType, $valType, $size) = @_; + + return + $self->readByte($keyType) + + $self->readByte($valType) + + $self->readI32($size); +} + +sub readMapEnd() +{ + my $self = shift; + return 0; +} + +sub readListBegin +{ + my $self = shift; + my ($elemType, $size) = @_; + + return + $self->readByte($elemType) + + $self->readI32($size); +} + +sub readListEnd +{ + my $self = shift; + return 0; +} + +sub readSetBegin +{ + my $self = shift; + my ($elemType, $size) = @_; + + return + $self->readByte($elemType) + + $self->readI32($size); +} + +sub readSetEnd +{ + my $self = shift; + return 0; +} + +sub readBool +{ + my $self = shift; + my $value = shift; + + my $data = $self->{trans}->readAll(1); + my @arr = unpack('c', $data); + $$value = $arr[0] == 1; + return 1; +} + +sub readByte +{ + my $self = shift; + my $value = shift; + + my $data = $self->{trans}->readAll(1); + my @arr = unpack('c', $data); + $$value = $arr[0]; + return 1; +} + +sub readI16 +{ + my $self = shift; + my $value = shift; + + my $data = $self->{trans}->readAll(2); + + my @arr = unpack('n', $data); + + $$value = $arr[0]; + + if ($$value > 0x7fff) { + $$value = 0 - (($$value - 1) ^ 0xffff); + } + + return 2; +} + +sub readI32 +{ + my $self = shift; + my $value= shift; + + my $data = $self->{trans}->readAll(4); + my @arr = unpack('N', $data); + + $$value = $arr[0]; + if ($$value > 0x7fffffff) { + $$value = 0 - (($$value - 1) ^ 0xffffffff); + } + return 4; +} + +sub readI64 +{ + my $self = shift; + my $value = shift; + + my $data = $self->{trans}->readAll(8); + + my ($hi,$lo)=unpack('NN',$data); + + my $vec = new Bit::Vector(64); + + $vec->Chunk_Store(32,32,$hi); + $vec->Chunk_Store(32,0,$lo); + + $$value = $vec->to_Dec(); + + return 8; +} + +sub readDouble +{ + my $self = shift; + my $value = shift; + + my $data = scalar reverse($self->{trans}->readAll(8)); + my @arr = unpack('d', $data); + + $$value = $arr[0]; + + return 8; +} + +sub readString +{ + my $self = shift; + my $value = shift; + + my $len; + my $result = $self->readI32(\$len); + + if ($len) { + $$value = $self->{trans}->readAll($len); + } else { + $$value = ''; + } + + return $result + $len; +} + +sub readStringBody +{ + my $self = shift; + my $value = shift; + my $len = shift; + + if ($len) { + $$value = $self->{trans}->readAll($len); + } else { + $$value = ''; + } + + return $len; +} + +# +# Binary Protocol Factory +# +package TBinaryProtocolFactory; +use base('TProtocolFactory'); + +sub new +{ + my $classname = shift; + my $self = $classname->SUPER::new(); + + return bless($self,$classname); +} + +sub getProtocol{ + my $self = shift; + my $trans = shift; + + return new TBinaryProtocol($trans); +} + +1; diff --git a/lib/perl/lib/Thrift/BufferedTransport.pm b/lib/perl/lib/Thrift/BufferedTransport.pm new file mode 100644 index 00000000..bef564d6 --- /dev/null +++ b/lib/perl/lib/Thrift/BufferedTransport.pm @@ -0,0 +1,109 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 5.6.0; +use strict; +use warnings; + +use Thrift; +use Thrift::Transport; + +package Thrift::BufferedTransport; +use base('Thrift::Transport'); + +sub new +{ + my $classname = shift; + my $transport = shift; + my $rBufSize = shift || 512; + my $wBufSize = shift || 512; + + my $self = { + transport => $transport, + rBufSize => $rBufSize, + wBufSize => $wBufSize, + wBuf => '', + rBuf => '', + }; + + return bless($self,$classname); +} + +sub isOpen +{ + my $self = shift; + + return $self->{transport}->isOpen(); +} + +sub open +{ + my $self = shift; + $self->{transport}->open(); +} + +sub close() +{ + my $self = shift; + $self->{transport}->close(); +} + +sub readAll +{ + my $self = shift; + my $len = shift; + + return $self->{transport}->readAll($len); +} + +sub read +{ + my $self = shift; + my $len = shift; + my $ret; + + # Methinks Perl is already buffering these for us + return $self->{transport}->read($len); +} + +sub write +{ + my $self = shift; + my $buf = shift; + + $self->{wBuf} .= $buf; + if (length($self->{wBuf}) >= $self->{wBufSize}) { + $self->{transport}->write($self->{wBuf}); + $self->{wBuf} = ''; + } +} + +sub flush +{ + my $self = shift; + + if (length($self->{wBuf}) > 0) { + $self->{transport}->write($self->{wBuf}); + $self->{wBuf} = ''; + } + $self->{transport}->flush(); +} + + +1; diff --git a/lib/perl/lib/Thrift/FramedTransport.pm b/lib/perl/lib/Thrift/FramedTransport.pm new file mode 100644 index 00000000..b78b1989 --- /dev/null +++ b/lib/perl/lib/Thrift/FramedTransport.pm @@ -0,0 +1,164 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +use strict; +use warnings; + +use Thrift; +use Thrift::Transport; + +# +# Framed transport. Writes and reads data in chunks that are stamped with +# their length. +# +# @package thrift.transport +# +package Thrift::FramedTransport; + +use base('Thrift::Transport'); + +sub new +{ + my $classname = shift; + my $transport = shift; + my $read = shift || 1; + my $write = shift || 1; + + my $self = { + transport => $transport, + read => $read, + write => $write, + wBuf => '', + rBuf => '', + }; + + return bless($self,$classname); +} + +sub isOpen +{ + my $self = shift; + return $self->{transport}->isOpen(); +} + +sub open +{ + my $self = shift; + + $self->{transport}->open(); +} + +sub close +{ + my $self = shift; + + $self->{transport}->close(); +} + +# +# Reads from the buffer. When more data is required reads another entire +# chunk and serves future reads out of that. +# +# @param int $len How much data +# +sub read +{ + + my $self = shift; + my $len = shift; + + if (!$self->{read}) { + return $self->{transport}->read($len); + } + + if (length($self->{rBuf}) == 0) { + $self->_readFrame(); + } + + + # Just return full buff + if ($len > length($self->{rBuf})) { + my $out = $self->{rBuf}; + $self->{rBuf} = ''; + return $out; + } + + # Return substr + my $out = substr($self->{rBuf}, 0, $len); + $self->{rBuf} = substr($self->{rBuf}, $len); + return $out; +} + +# +# Reads a chunk of data into the internal read buffer. +# (private) +sub _readFrame +{ + my $self = shift; + my $buf = $self->{transport}->readAll(4); + my @val = unpack('N', $buf); + my $sz = $val[0]; + + $self->{rBuf} = $self->{transport}->readAll($sz); +} + +# +# Writes some data to the pending output buffer. +# +# @param string $buf The data +# @param int $len Limit of bytes to write +# +sub write +{ + my $self = shift; + my $buf = shift; + my $len = shift; + + unless($self->{write}) { + return $self->{transport}->write($buf, $len); + } + + if ( defined $len && $len < length($buf)) { + $buf = substr($buf, 0, $len); + } + + $self->{wBuf} .= $buf; + } + +# +# Writes the output buffer to the stream in the format of a 4-byte length +# followed by the actual data. +# +sub flush +{ + my $self = shift; + + unless ($self->{write}) { + return $self->{transport}->flush(); + } + + my $out = pack('N', length($self->{wBuf})); + $out .= $self->{wBuf}; + $self->{transport}->write($out); + $self->{transport}->flush(); + $self->{wBuf} = ''; + +} + +1; diff --git a/lib/perl/lib/Thrift/HttpClient.pm b/lib/perl/lib/Thrift/HttpClient.pm new file mode 100644 index 00000000..d6fc8be3 --- /dev/null +++ b/lib/perl/lib/Thrift/HttpClient.pm @@ -0,0 +1,200 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 5.6.0; +use strict; +use warnings; + +use Thrift; +use Thrift::Transport; + +use HTTP::Request; +use LWP::UserAgent; +use IO::String; + +package Thrift::HttpClient; + +use base('Thrift::Transport'); + +sub new +{ + my $classname = shift; + my $url = shift || 'http://localhost:9090'; + my $debugHandler = shift; + + my $out = IO::String->new; + binmode($out); + + my $self = { + url => $url, + out => $out, + debugHandler => $debugHandler, + debug => 0, + sendTimeout => 100, + recvTimeout => 750, + handle => undef, + }; + + return bless($self,$classname); +} + +sub setSendTimeout +{ + my $self = shift; + my $timeout = shift; + + $self->{sendTimeout} = $timeout; +} + +sub setRecvTimeout +{ + my $self = shift; + my $timeout = shift; + + $self->{recvTimeout} = $timeout; +} + + +# +#Sets debugging output on or off +# +# @param bool $debug +# +sub setDebug +{ + my $self = shift; + my $debug = shift; + + $self->{debug} = $debug; +} + +# +# Tests whether this is open +# +# @return bool true if the socket is open +# +sub isOpen +{ + return 1; +} + +sub open {} + +# +# Cleans up the buffer. +# +sub close +{ + my $self = shift; + if (defined($self->{io})) { + close($self->{io}); + $self->{io} = undef; + } +} + +# +# Guarantees that the full amount of data is read. +# +# @return string The data, of exact length +# @throws TTransportException if cannot read data +# +sub readAll +{ + my $self = shift; + my $len = shift; + + my $buf = $self->read($len); + + if (!defined($buf)) { + die new Thrift::TException('TSocket: Could not read '.$len.' bytes from input buffer'); + } + return $buf; +} + +# +# Read and return string +# +sub read +{ + my $self = shift; + my $len = shift; + + my $buf; + + my $in = $self->{in}; + + if (!defined($in)) { + die new Thrift::TException("Response buffer is empty, no request."); + } + eval { + my $ret = sysread($in, $buf, $len); + if (! defined($ret)) { + die new Thrift::TException("No more data available."); + } + }; if($@){ + die new Thrift::TException($@); + } + + return $buf; +} + +# +# Write string +# +sub write +{ + my $self = shift; + my $buf = shift; + $self->{out}->print($buf); +} + +# +# Flush output (do the actual HTTP/HTTPS request) +# +sub flush +{ + my $self = shift; + + my $ua = LWP::UserAgent->new('timeout' => ($self->{sendTimeout} / 1000), + 'agent' => 'Perl/THttpClient' + ); + $ua->default_header('Accept' => 'application/x-thrift'); + $ua->default_header('Content-Type' => 'application/x-thrift'); + $ua->cookie_jar({}); # hash to remember cookies between redirects + + my $out = $self->{out}; + $out->setpos(0); # rewind + my $buf = join('', <$out>); + + my $request = new HTTP::Request(POST => $self->{url}, undef, $buf); + my $response = $ua->request($request); + my $content_ref = $response->content_ref; + + my $in = IO::String->new($content_ref); + binmode($in); + $self->{in} = $in; + $in->setpos(0); # rewind + + # reset write buffer + $out = IO::String->new; + binmode($out); + $self->{out} = $out; +} + +1; diff --git a/lib/perl/lib/Thrift/MemoryBuffer.pm b/lib/perl/lib/Thrift/MemoryBuffer.pm new file mode 100644 index 00000000..32f14424 --- /dev/null +++ b/lib/perl/lib/Thrift/MemoryBuffer.pm @@ -0,0 +1,126 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 5.6.0; +use strict; +use warnings; + +use Thrift; +use Thrift::Transport; + +package Thrift::MemoryBuffer; +use base('Thrift::Transport'); + +sub new +{ + my $classname = shift; + + my $bufferSize= shift || 1024; + + my $self = { + buffer => '', + bufferSize=> $bufferSize, + wPos => 0, + rPos => 0, + }; + + return bless($self,$classname); +} + +sub isOpen +{ + return 1; +} + +sub open +{ + +} + +sub close +{ + +} + +sub peek +{ + my $self = shift; + return($self->{rPos} < $self->{wPos}); +} + + +sub getBuffer +{ + my $self = shift; + return $self->{buffer}; +} + +sub resetBuffer +{ + my $self = shift; + + my $new_buffer = shift || ''; + + $self->{buffer} = $new_buffer; + $self->{bufferSize} = length($new_buffer); + $self->{wPos} = length($new_buffer); + $self->{rPos} = 0; +} + +sub available +{ + my $self = shift; + return ($self->{wPos} - $self->{rPos}); +} + +sub read +{ + my $self = shift; + my $len = shift; + my $ret; + + my $avail = ($self->{wPos} - $self->{rPos}); + return '' if $avail == 0; + + #how much to give + my $give = $len; + $give = $avail if $avail < $len; + + $ret = substr($self->{buffer},$self->{rPos},$give); + + $self->{rPos} += $give; + + return $ret; +} + +sub write +{ + my $self = shift; + my $buf = shift; + + $self->{buffer} .= $buf; + $self->{wPos} += length($buf); +} + +sub flush +{ + +} + +1; diff --git a/lib/perl/lib/Thrift/Protocol.pm b/lib/perl/lib/Thrift/Protocol.pm new file mode 100644 index 00000000..034711f3 --- /dev/null +++ b/lib/perl/lib/Thrift/Protocol.pm @@ -0,0 +1,543 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 5.6.0; +use strict; +use warnings; + +use Thrift; + +# +# Protocol exceptions +# +package TProtocolException; +use base('Thrift::TException'); + +use constant UNKNOWN => 0; +use constant INVALID_DATA => 1; +use constant NEGATIVE_SIZE => 2; +use constant SIZE_LIMIT => 3; +use constant BAD_VERSION => 4; + +sub new { + my $classname = shift; + + my $self = $classname->SUPER::new(); + + return bless($self,$classname); +} + +# +# Protocol base class module. +# +package Thrift::Protocol; + +sub new { + my $classname = shift; + my $self = {}; + + my $trans = shift; + $self->{trans}= $trans; + + return bless($self,$classname); +} + +sub getTransport +{ + my $self = shift; + + return $self->{trans}; +} + +# +# Writes the message header +# +# @param string $name Function name +# @param int $type message type TMessageType::CALL or TMessageType::REPLY +# @param int $seqid The sequence id of this message +# +sub writeMessageBegin +{ + my ($name, $type, $seqid); + die "abstract"; +} + +# +# Close the message +# +sub writeMessageEnd { + die "abstract"; +} + +# +# Writes a struct header. +# +# @param string $name Struct name +# @throws TException on write error +# @return int How many bytes written +# +sub writeStructBegin { + my ($name); + + die "abstract"; +} + +# +# Close a struct. +# +# @throws TException on write error +# @return int How many bytes written +# +sub writeStructEnd { + die "abstract"; +} + +# +# Starts a field. +# +# @param string $name Field name +# @param int $type Field type +# @param int $fid Field id +# @throws TException on write error +# @return int How many bytes written +# +sub writeFieldBegin { + my ($fieldName, $fieldType, $fieldId); + + die "abstract"; +} + +sub writeFieldEnd { + die "abstract"; +} + +sub writeFieldStop { + die "abstract"; +} + +sub writeMapBegin { + my ($keyType, $valType, $size); + + die "abstract"; +} + +sub writeMapEnd { + die "abstract"; +} + +sub writeListBegin { + my ($elemType, $size); + die "abstract"; +} + +sub writeListEnd { + die "abstract"; +} + +sub writeSetBegin { + my ($elemType, $size); + die "abstract"; +} + +sub writeSetEnd { + die "abstract"; +} + +sub writeBool { + my ($bool); + die "abstract"; +} + +sub writeByte { + my ($byte); + die "abstract"; +} + +sub writeI16 { + my ($i16); + die "abstract"; +} + +sub writeI32 { + my ($i32); + die "abstract"; +} + +sub writeI64 { + my ($i64); + die "abstract"; +} + +sub writeDouble { + my ($dub); + die "abstract"; +} + +sub writeString +{ + my ($str); + die "abstract"; +} + +# +# Reads the message header +# +# @param string $name Function name +# @param int $type message type TMessageType::CALL or TMessageType::REPLY +# @parem int $seqid The sequence id of this message +# +sub readMessageBegin +{ + my ($name, $type, $seqid); + die "abstract"; +} + +# +# Read the close of message +# +sub readMessageEnd +{ + die "abstract"; +} + +sub readStructBegin +{ + my($name); + + die "abstract"; +} + +sub readStructEnd +{ + die "abstract"; +} + +sub readFieldBegin +{ + my ($name, $fieldType, $fieldId); + die "abstract"; +} + +sub readFieldEnd +{ + die "abstract"; +} + +sub readMapBegin +{ + my ($keyType, $valType, $size); + die "abstract"; +} + +sub readMapEnd +{ + die "abstract"; +} + +sub readListBegin +{ + my ($elemType, $size); + die "abstract"; +} + +sub readListEnd +{ + die "abstract"; +} + +sub readSetBegin +{ + my ($elemType, $size); + die "abstract"; +} + +sub readSetEnd +{ + die "abstract"; +} + +sub readBool +{ + my ($bool); + die "abstract"; +} + +sub readByte +{ + my ($byte); + die "abstract"; +} + +sub readI16 +{ + my ($i16); + die "abstract"; +} + +sub readI32 +{ + my ($i32); + die "abstract"; +} + +sub readI64 +{ + my ($i64); + die "abstract"; +} + +sub readDouble +{ + my ($dub); + die "abstract"; +} + +sub readString +{ + my ($str); + die "abstract"; +} + +# +# The skip function is a utility to parse over unrecognized data without +# causing corruption. +# +# @param TType $type What type is it +# +sub skip +{ + my $self = shift; + my $type = shift; + + my $ref; + my $result; + my $i; + + if($type == TType::BOOL) + { + return $self->readBool(\$ref); + } + elsif($type == TType::BYTE){ + return $self->readByte(\$ref); + } + elsif($type == TType::I16){ + return $self->readI16(\$ref); + } + elsif($type == TType::I32){ + return $self->readI32(\$ref); + } + elsif($type == TType::I64){ + return $self->readI64(\$ref); + } + elsif($type == TType::DOUBLE){ + return $self->readDouble(\$ref); + } + elsif($type == TType::STRING) + { + return $self->readString(\$ref); + } + elsif($type == TType::STRUCT) + { + $result = $self->readStructBegin(\$ref); + while (1) { + my ($ftype,$fid); + $result += $self->readFieldBegin(\$ref, \$ftype, \$fid); + if ($ftype == TType::STOP) { + last; + } + $result += $self->skip($ftype); + $result += $self->readFieldEnd(); + } + $result += $self->readStructEnd(); + return $result; + } + elsif($type == TType::MAP) + { + my($keyType,$valType,$size); + $result = $self->readMapBegin(\$keyType, \$valType, \$size); + for ($i = 0; $i < $size; $i++) { + $result += $self->skip($keyType); + $result += $self->skip($valType); + } + $result += $self->readMapEnd(); + return $result; + } + elsif($type == TType::SET) + { + my ($elemType,$size); + $result = $self->readSetBegin(\$elemType, \$size); + for ($i = 0; $i < $size; $i++) { + $result += $self->skip($elemType); + } + $result += $self->readSetEnd(); + return $result; + } + elsif($type == TType::LIST) + { + my ($elemType,$size); + $result = $self->readListBegin(\$elemType, \$size); + for ($i = 0; $i < $size; $i++) { + $result += $self->skip($elemType); + } + $result += $self->readListEnd(); + return $result; + } + + + return 0; + + } + +# +# Utility for skipping binary data +# +# @param TTransport $itrans TTransport object +# @param int $type Field type +# +sub skipBinary +{ + my $self = shift; + my $itrans = shift; + my $type = shift; + + if($type == TType::BOOL) + { + return $itrans->readAll(1); + } + elsif($type == TType::BYTE) + { + return $itrans->readAll(1); + } + elsif($type == TType::I16) + { + return $itrans->readAll(2); + } + elsif($type == TType::I32) + { + return $itrans->readAll(4); + } + elsif($type == TType::I64) + { + return $itrans->readAll(8); + } + elsif($type == TType::DOUBLE) + { + return $itrans->readAll(8); + } + elsif( $type == TType::STRING ) + { + my @len = unpack('N', $itrans->readAll(4)); + my $len = $len[0]; + if ($len > 0x7fffffff) { + $len = 0 - (($len - 1) ^ 0xffffffff); + } + return 4 + $itrans->readAll($len); + } + elsif( $type == TType::STRUCT ) + { + my $result = 0; + while (1) { + my $ftype = 0; + my $fid = 0; + my $data = $itrans->readAll(1); + my @arr = unpack('c', $data); + $ftype = $arr[0]; + if ($ftype == TType::STOP) { + last; + } + # I16 field id + $result += $itrans->readAll(2); + $result += $self->skipBinary($itrans, $ftype); + } + return $result; + } + elsif($type == TType::MAP) + { + # Ktype + my $data = $itrans->readAll(1); + my @arr = unpack('c', $data); + my $ktype = $arr[0]; + # Vtype + $data = $itrans->readAll(1); + @arr = unpack('c', $data); + my $vtype = $arr[0]; + # Size + $data = $itrans->readAll(4); + @arr = unpack('N', $data); + my $size = $arr[0]; + if ($size > 0x7fffffff) { + $size = 0 - (($size - 1) ^ 0xffffffff); + } + my $result = 6; + for (my $i = 0; $i < $size; $i++) { + $result += $self->skipBinary($itrans, $ktype); + $result += $self->skipBinary($itrans, $vtype); + } + return $result; + } + elsif($type == TType::SET || $type == TType::LIST) + { + # Vtype + my $data = $itrans->readAll(1); + my @arr = unpack('c', $data); + my $vtype = $arr[0]; + # Size + $data = $itrans->readAll(4); + @arr = unpack('N', $data); + my $size = $arr[0]; + if ($size > 0x7fffffff) { + $size = 0 - (($size - 1) ^ 0xffffffff); + } + my $result = 5; + for (my $i = 0; $i < $size; $i++) { + $result += $self->skipBinary($itrans, $vtype); + } + return $result; + } + + return 0; + +} + +# +# Protocol factory creates protocol objects from transports +# +package TProtocolFactory; + + +sub new { + my $classname = shift; + my $self = {}; + + return bless($self,$classname); +} + +# +# Build a protocol from the base transport +# +# @return TProtcol protocol +# +sub getProtocol +{ + my ($trans); + die "interface"; +} + + +1; diff --git a/lib/perl/lib/Thrift/Socket.pm b/lib/perl/lib/Thrift/Socket.pm new file mode 100644 index 00000000..67faa510 --- /dev/null +++ b/lib/perl/lib/Thrift/Socket.pm @@ -0,0 +1,271 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 5.6.0; +use strict; +use warnings; + +use Thrift; +use Thrift::Transport; + +use IO::Socket::INET; +use IO::Select; + +package Thrift::Socket; + +use base('Thrift::Transport'); + +sub new +{ + my $classname = shift; + my $host = shift || "localhost"; + my $port = shift || 9090; + my $debugHandler = shift; + + my $self = { + host => $host, + port => $port, + debugHandler => $debugHandler, + debug => 0, + sendTimeout => 100, + recvTimeout => 750, + handle => undef, + }; + + return bless($self,$classname); +} + + +sub setSendTimeout +{ + my $self = shift; + my $timeout = shift; + + $self->{sendTimeout} = $timeout; +} + +sub setRecvTimeout +{ + my $self = shift; + my $timeout = shift; + + $self->{recvTimeout} = $timeout; +} + + +# +#Sets debugging output on or off +# +# @param bool $debug +# +sub setDebug +{ + my $self = shift; + my $debug = shift; + + $self->{debug} = $debug; +} + +# +# Tests whether this is open +# +# @return bool true if the socket is open +# +sub isOpen +{ + my $self = shift; + + if( defined $self->{handle} ){ + return ($self->{handle}->handles())[0]->connected; + } + + return 0; +} + +# +# Connects the socket. +# +sub open +{ + my $self = shift; + + my $sock = IO::Socket::INET->new(PeerAddr => $self->{host}, + PeerPort => $self->{port}, + Proto => 'tcp', + Timeout => $self->{sendTimeout}/1000) + || do { + my $error = 'TSocket: Could not connect to '.$self->{host}.':'.$self->{port}.' ('.$!.')'; + + if ($self->{debug}) { + $self->{debugHandler}->($error); + } + + die new Thrift::TException($error); + + }; + + + $self->{handle} = new IO::Select( $sock ); +} + +# +# Closes the socket. +# +sub close +{ + my $self = shift; + + if( defined $self->{handle} ){ + close( ($self->{handle}->handles())[0] ); + } +} + +# +# Uses stream get contents to do the reading +# +# @param int $len How many bytes +# @return string Binary data +# +sub readAll +{ + my $self = shift; + my $len = shift; + + + return unless defined $self->{handle}; + + my $pre = ""; + while (1) { + + #check for timeout + my @sockets = $self->{handle}->can_read( $self->{recvTimeout} / 1000 ); + + if(@sockets == 0){ + die new Thrift::TException('TSocket: timed out reading '.$len.' bytes from '. + $self->{host}.':'.$self->{port}); + } + + my $sock = $sockets[0]; + + my ($buf,$sz); + $sock->recv($buf, $len); + + if (!defined $buf || $buf eq '') { + + die new Thrift::TException('TSocket: Could not read '.$len.' bytes from '. + $self->{host}.':'.$self->{port}); + + } elsif (($sz = length($buf)) < $len) { + + $pre .= $buf; + $len -= $sz; + + } else { + return $pre.$buf; + } + } +} + +# +# Read from the socket +# +# @param int $len How many bytes +# @return string Binary data +# +sub read +{ + my $self = shift; + my $len = shift; + + return unless defined $self->{handle}; + + #check for timeout + my @sockets = $self->{handle}->can_read( $self->{sendTimeout} / 1000 ); + + if(@sockets == 0){ + die new Thrift::TException('TSocket: timed out reading '.$len.' bytes from '. + $self->{host}.':'.$self->{port}); + } + + my $sock = $sockets[0]; + + my ($buf,$sz); + $sock->recv($buf, $len); + + if (!defined $buf || $buf eq '') { + + die new TException('TSocket: Could not read '.$len.' bytes from '. + $self->{host}.':'.$self->{port}); + + } + + return $buf; +} + + +# +# Write to the socket. +# +# @param string $buf The data to write +# +sub write +{ + my $self = shift; + my $buf = shift; + + + return unless defined $self->{handle}; + + while (length($buf) > 0) { + + + #check for timeout + my @sockets = $self->{handle}->can_write( $self->{recvTimeout} / 1000 ); + + if(@sockets == 0){ + die new Thrift::TException('TSocket: timed out writing to bytes from '. + $self->{host}.':'.$self->{port}); + } + + my $sock = $sockets[0]; + + my $got = $sock->send($buf); + + if (!defined $got || $got == 0 ) { + die new Thrift::TException('TSocket: Could not write '.length($buf).' bytes '. + $self->{host}.':'.$self->{host}); + } + + $buf = substr($buf, $got); + } +} + +# +# Flush output to the socket. +# +sub flush +{ + my $self = shift; + + return unless defined $self->{handle}; + + my $ret = ($self->{handle}->handles())[0]->flush; +} + +1; diff --git a/lib/perl/lib/Thrift/Transport.pm b/lib/perl/lib/Thrift/Transport.pm new file mode 100644 index 00000000..e22592be --- /dev/null +++ b/lib/perl/lib/Thrift/Transport.pm @@ -0,0 +1,129 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 5.6.0; +use strict; +use warnings; + +use Thrift; + +# +# Transport exceptions +# +package TTransportException; +use base('Thrift::TException'); + +use constant UNKNOWN => 0; +use constant NOT_OPEN => 1; +use constant ALREADY_OPEN => 2; +use constant TIMED_OUT => 3; +use constant END_OF_FILE => 4; + +sub new{ + my $classname = shift; + my $self = $classname->SUPER::new(@_); + + return bless($self,$classname); +} + +package Thrift::Transport; + +# +# Whether this transport is open. +# +# @return boolean true if open +# +sub isOpen +{ + die "abstract"; +} + +# +# Open the transport for reading/writing +# +# @throws TTransportException if cannot open +# +sub open +{ + die "abstract"; +} + +# +# Close the transport. +# +sub close +{ + die "abstract"; +} + +# +# Read some data into the array. +# +# @param int $len How much to read +# @return string The data that has been read +# @throws TTransportException if cannot read any more data +# +sub read +{ + my ($len); + die("abstract"); +} + +# +# Guarantees that the full amount of data is read. +# +# @return string The data, of exact length +# @throws TTransportException if cannot read data +# +sub readAll +{ + my $self = shift; + my $len = shift; + + my $data = ''; + my $got = 0; + + while (($got = length($data)) < $len) { + $data .= $self->read($len - $got); + } + + return $data; +} + +# +# Writes the given data out. +# +# @param string $buf The data to write +# @throws TTransportException if writing fails +# +sub write +{ + my ($buf); + die "abstract"; +} + +# +# Flushes any pending data out of a buffer +# +# @throws TTransportException if a writing error occurs +# +sub flush {} + +1; + diff --git a/lib/perl/test.pl b/lib/perl/test.pl new file mode 100644 index 00000000..7e068402 --- /dev/null +++ b/lib/perl/test.pl @@ -0,0 +1,25 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +use strict; +use warnings; + +use Test::Harness; + +runtests(@ARGV); diff --git a/lib/perl/test/Makefile.am b/lib/perl/test/Makefile.am new file mode 100644 index 00000000..ce87c48d --- /dev/null +++ b/lib/perl/test/Makefile.am @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +THRIFT = @top_builddir@/compiler/cpp/thrift +THRIFT_IF = @top_srcdir@/test/ThriftTest.thrift + +check-local: gen-perl/ThriftTest/Types.pm + +gen-perl/ThriftTest/Types.pm: $(THRIFT_IF) + $(THRIFT) --gen perl $(THRIFT_IF) + +clean-local: + rm -rf gen-perl + +EXTRA_DIST = memory_buffer.t diff --git a/lib/perl/test/memory_buffer.t b/lib/perl/test/memory_buffer.t new file mode 100644 index 00000000..8fa9fd72 --- /dev/null +++ b/lib/perl/test/memory_buffer.t @@ -0,0 +1,53 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +use Test::More tests => 6; + +use strict; +use warnings; + +use Data::Dumper; + +use Thrift::BinaryProtocol; +use Thrift::MemoryBuffer; + +use ThriftTest::Types; + + +my $transport = Thrift::MemoryBuffer->new(); +my $protocol = Thrift::BinaryProtocol->new($transport); + +my $a = ThriftTest::Xtruct->new(); +$a->i32_thing(10); +$a->i64_thing(30); +$a->string_thing('Hello, world!'); +$a->write($protocol); + +my $b = ThriftTest::Xtruct->new(); +$b->read($protocol); +is($b->i32_thing, $a->i32_thing); +is($b->i64_thing, $a->i64_thing); +is($b->string_thing, $a->string_thing); + +$b->write($protocol); +my $c = ThriftTest::Xtruct->new(); +$c->read($protocol); +is($c->i32_thing, $a->i32_thing); +is($c->i64_thing, $a->i64_thing); +is($c->string_thing, $a->string_thing); diff --git a/lib/php/README b/lib/php/README new file mode 100644 index 00000000..bb566f42 --- /dev/null +++ b/lib/php/README @@ -0,0 +1,63 @@ +Thrift PHP Software Library + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Using Thrift with PHP +===================== + +Thrift requires PHP 5. Thrift makes as few assumptions about your PHP +environment as possible while trying to make some more advanced PHP +features (i.e. APC cacheing using asbolute path URLs) as simple as possible. + +To use Thrift in your PHP codebase, take the following steps: + +#1) Copy all of thrift/lib/php/src into your PHP codebase +#2) Set $GLOBALS['THRIFT_ROOT'] to the path you installed Thrift +#3) include_once $GLOBALS['THRIFT_ROOT'].'/Thrift.php'; + +Note that #3 must be done before including any other Thrift files. +If you do not do #2, Thrift.php will set this global for you, but it will be +done using dirname(__FILE__), which is less efficient than providing the static +string yourself. + +When you generate a Thrift package using the compiler, it makes an assumption +about where your generated code will live. If your file is "MyPackage.thrift", +the generated files must be installed into: + +$GLOBALS['THRIFT_ROOT'].'/packages/MyPackage/'; + +This allows the code generator to compile your code without any extra flags +for the target directory names while still allowing your include paths to +be absolute (if you have an absolute THRIFT_ROOT). + +Dependencies +============ + +PHP_INT_SIZE + + This built-in signals whether your architecture is 32 or 64 bit and is + used by the TBinaryProtocol to properly use pack() and unpack() to + serialize data. + +apc_fetch(), apc_store() + + APC cache is used by the TSocketPool class. If you do not have APC installed, + Thrift will fill in null stub function definitions. diff --git a/lib/php/README.apache b/lib/php/README.apache new file mode 100644 index 00000000..8c41833d --- /dev/null +++ b/lib/php/README.apache @@ -0,0 +1,62 @@ +Thrift PHP/Apache Integration + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Building PHP Thrift Services with Apache +======================================== + +Thrift can be embedded in the Apache webserver with PHP installed. Sample +code is provided below. Note that to make requests to this type of server +you must use a THttpClient transport. + +Sample Code +=========== + +open(); +$processor->process($protocol, $protocol); +$transport->close(); diff --git a/lib/php/src/Thrift.php b/lib/php/src/Thrift.php new file mode 100644 index 00000000..ef6ab8a4 --- /dev/null +++ b/lib/php/src/Thrift.php @@ -0,0 +1,787 @@ + $fspec) { + $var = $fspec['var']; + if (isset($vals[$var])) { + $this->$var = $vals[$var]; + } + } + } else { + parent::__construct($p1, $p2); + } + } + + static $tmethod = array(TType::BOOL => 'Bool', + TType::BYTE => 'Byte', + TType::I16 => 'I16', + TType::I32 => 'I32', + TType::I64 => 'I64', + TType::DOUBLE => 'Double', + TType::STRING => 'String'); + + private function _readMap(&$var, $spec, $input) { + $xfer = 0; + $ktype = $spec['ktype']; + $vtype = $spec['vtype']; + $kread = $vread = null; + if (isset(TBase::$tmethod[$ktype])) { + $kread = 'read'.TBase::$tmethod[$ktype]; + } else { + $kspec = $spec['key']; + } + if (isset(TBase::$tmethod[$vtype])) { + $vread = 'read'.TBase::$tmethod[$vtype]; + } else { + $vspec = $spec['val']; + } + $var = array(); + $_ktype = $_vtype = $size = 0; + $xfer += $input->readMapBegin($_ktype, $_vtype, $size); + for ($i = 0; $i < $size; ++$i) { + $key = $val = null; + if ($kread !== null) { + $xfer += $input->$kread($key); + } else { + switch ($ktype) { + case TType::STRUCT: + $class = $kspec['class']; + $key = new $class(); + $xfer += $key->read($input); + break; + case TType::MAP: + $xfer += $this->_readMap($key, $kspec, $input); + break; + case TType::LST: + $xfer += $this->_readList($key, $kspec, $input, false); + break; + case TType::SET: + $xfer += $this->_readList($key, $kspec, $input, true); + break; + } + } + if ($vread !== null) { + $xfer += $input->$vread($val); + } else { + switch ($vtype) { + case TType::STRUCT: + $class = $vspec['class']; + $val = new $class(); + $xfer += $val->read($input); + break; + case TType::MAP: + $xfer += $this->_readMap($val, $vspec, $input); + break; + case TType::LST: + $xfer += $this->_readList($val, $vspec, $input, false); + break; + case TType::SET: + $xfer += $this->_readList($val, $vspec, $input, true); + break; + } + } + $var[$key] = $val; + } + $xfer += $input->readMapEnd(); + return $xfer; + } + + private function _readList(&$var, $spec, $input, $set=false) { + $xfer = 0; + $etype = $spec['etype']; + $eread = $vread = null; + if (isset(TBase::$tmethod[$etype])) { + $eread = 'read'.TBase::$tmethod[$etype]; + } else { + $espec = $spec['elem']; + } + $var = array(); + $_etype = $size = 0; + if ($set) { + $xfer += $input->readSetBegin($_etype, $size); + } else { + $xfer += $input->readListBegin($_etype, $size); + } + for ($i = 0; $i < $size; ++$i) { + $elem = null; + if ($eread !== null) { + $xfer += $input->$eread($elem); + } else { + $espec = $spec['elem']; + switch ($etype) { + case TType::STRUCT: + $class = $espec['class']; + $elem = new $class(); + $xfer += $elem->read($input); + break; + case TType::MAP: + $xfer += $this->_readMap($elem, $espec, $input); + break; + case TType::LST: + $xfer += $this->_readList($elem, $espec, $input, false); + break; + case TType::SET: + $xfer += $this->_readList($elem, $espec, $input, true); + break; + } + } + if ($set) { + $var[$elem] = true; + } else { + $var []= $elem; + } + } + if ($set) { + $xfer += $input->readSetEnd(); + } else { + $xfer += $input->readListEnd(); + } + return $xfer; + } + + protected function _read($class, $spec, $input) { + $xfer = 0; + $fname = null; + $ftype = 0; + $fid = 0; + $xfer += $input->readStructBegin($fname); + while (true) { + $xfer += $input->readFieldBegin($fname, $ftype, $fid); + if ($ftype == TType::STOP) { + break; + } + if (isset($spec[$fid])) { + $fspec = $spec[$fid]; + $var = $fspec['var']; + if ($ftype == $fspec['type']) { + $xfer = 0; + if (isset(TBase::$tmethod[$ftype])) { + $func = 'read'.TBase::$tmethod[$ftype]; + $xfer += $input->$func($this->$var); + } else { + switch ($ftype) { + case TType::STRUCT: + $class = $fspec['class']; + $this->$var = new $class(); + $xfer += $this->$var->read($input); + break; + case TType::MAP: + $xfer += $this->_readMap($this->$var, $fspec, $input); + break; + case TType::LST: + $xfer += $this->_readList($this->$var, $fspec, $input, false); + break; + case TType::SET: + $xfer += $this->_readList($this->$var, $fspec, $input, true); + break; + } + } + } else { + $xfer += $input->skip($ftype); + } + } else { + $xfer += $input->skip($ftype); + } + $xfer += $input->readFieldEnd(); + } + $xfer += $input->readStructEnd(); + return $xfer; + } + + private function _writeMap($var, $spec, $output) { + $xfer = 0; + $ktype = $spec['ktype']; + $vtype = $spec['vtype']; + $kwrite = $vwrite = null; + if (isset(TBase::$tmethod[$ktype])) { + $kwrite = 'write'.TBase::$tmethod[$ktype]; + } else { + $kspec = $spec['key']; + } + if (isset(TBase::$tmethod[$vtype])) { + $vwrite = 'write'.TBase::$tmethod[$vtype]; + } else { + $vspec = $spec['val']; + } + $xfer += $output->writeMapBegin($ktype, $vtype, count($var)); + foreach ($var as $key => $val) { + if (isset($kwrite)) { + $xfer += $output->$kwrite($key); + } else { + switch ($ktype) { + case TType::STRUCT: + $xfer += $key->write($output); + break; + case TType::MAP: + $xfer += $this->_writeMap($key, $kspec, $output); + break; + case TType::LST: + $xfer += $this->_writeList($key, $kspec, $output, false); + break; + case TType::SET: + $xfer += $this->_writeList($key, $kspec, $output, true); + break; + } + } + if (isset($vwrite)) { + $xfer += $output->$vwrite($val); + } else { + switch ($vtype) { + case TType::STRUCT: + $xfer += $val->write($output); + break; + case TType::MAP: + $xfer += $this->_writeMap($val, $vspec, $output); + break; + case TType::LST: + $xfer += $this->_writeList($val, $vspec, $output, false); + break; + case TType::SET: + $xfer += $this->_writeList($val, $vspec, $output, true); + break; + } + } + } + $xfer += $output->writeMapEnd(); + return $xfer; + } + + private function _writeList($var, $spec, $output, $set=false) { + $xfer = 0; + $etype = $spec['etype']; + $ewrite = null; + if (isset(TBase::$tmethod[$etype])) { + $ewrite = 'write'.TBase::$tmethod[$etype]; + } else { + $espec = $spec['elem']; + } + if ($set) { + $xfer += $output->writeSetBegin($etype, count($var)); + } else { + $xfer += $output->writeListBegin($etype, count($var)); + } + foreach ($var as $key => $val) { + $elem = $set ? $key : $val; + if (isset($ewrite)) { + $xfer += $output->$ewrite($elem); + } else { + switch ($etype) { + case TType::STRUCT: + $xfer += $elem->write($output); + break; + case TType::MAP: + $xfer += $this->_writeMap($elem, $espec, $output); + break; + case TType::LST: + $xfer += $this->_writeList($elem, $espec, $output, false); + break; + case TType::SET: + $xfer += $this->_writeList($elem, $espec, $output, true); + break; + } + } + } + if ($set) { + $xfer += $output->writeSetEnd(); + } else { + $xfer += $output->writeListEnd(); + } + return $xfer; + } + + protected function _write($class, $spec, $output) { + $xfer = 0; + $xfer += $output->writeStructBegin($class); + foreach ($spec as $fid => $fspec) { + $var = $fspec['var']; + if ($this->$var !== null) { + $ftype = $fspec['type']; + $xfer += $output->writeFieldBegin($var, $ftype, $fid); + if (isset(TBase::$tmethod[$ftype])) { + $func = 'write'.TBase::$tmethod[$ftype]; + $xfer += $output->$func($this->$var); + } else { + switch ($ftype) { + case TType::STRUCT: + $xfer += $this->$var->write($output); + break; + case TType::MAP: + $xfer += $this->_writeMap($this->$var, $fspec, $output); + break; + case TType::LST: + $xfer += $this->_writeList($this->$var, $fspec, $output, false); + break; + case TType::SET: + $xfer += $this->_writeList($this->$var, $fspec, $output, true); + break; + } + } + $xfer += $output->writeFieldEnd(); + } + } + $xfer += $output->writeFieldStop(); + $xfer += $output->writeStructEnd(); + return $xfer; + } + +} + +/** + * Base class from which other Thrift structs extend. This is so that we can + * cut back on the size of the generated code which is turning out to have a + * nontrivial cost just to load thanks to the wondrously abysmal implementation + * of PHP. Note that code is intentionally duplicated in here to avoid making + * function calls for every field or member of a container.. + */ +abstract class TBase { + + static $tmethod = array(TType::BOOL => 'Bool', + TType::BYTE => 'Byte', + TType::I16 => 'I16', + TType::I32 => 'I32', + TType::I64 => 'I64', + TType::DOUBLE => 'Double', + TType::STRING => 'String'); + + abstract function read($input); + + abstract function write($output); + + public function __construct($spec=null, $vals=null) { + if (is_array($spec) && is_array($vals)) { + foreach ($spec as $fid => $fspec) { + $var = $fspec['var']; + if (isset($vals[$var])) { + $this->$var = $vals[$var]; + } + } + } + } + + private function _readMap(&$var, $spec, $input) { + $xfer = 0; + $ktype = $spec['ktype']; + $vtype = $spec['vtype']; + $kread = $vread = null; + if (isset(TBase::$tmethod[$ktype])) { + $kread = 'read'.TBase::$tmethod[$ktype]; + } else { + $kspec = $spec['key']; + } + if (isset(TBase::$tmethod[$vtype])) { + $vread = 'read'.TBase::$tmethod[$vtype]; + } else { + $vspec = $spec['val']; + } + $var = array(); + $_ktype = $_vtype = $size = 0; + $xfer += $input->readMapBegin($_ktype, $_vtype, $size); + for ($i = 0; $i < $size; ++$i) { + $key = $val = null; + if ($kread !== null) { + $xfer += $input->$kread($key); + } else { + switch ($ktype) { + case TType::STRUCT: + $class = $kspec['class']; + $key = new $class(); + $xfer += $key->read($input); + break; + case TType::MAP: + $xfer += $this->_readMap($key, $kspec, $input); + break; + case TType::LST: + $xfer += $this->_readList($key, $kspec, $input, false); + break; + case TType::SET: + $xfer += $this->_readList($key, $kspec, $input, true); + break; + } + } + if ($vread !== null) { + $xfer += $input->$vread($val); + } else { + switch ($vtype) { + case TType::STRUCT: + $class = $vspec['class']; + $val = new $class(); + $xfer += $val->read($input); + break; + case TType::MAP: + $xfer += $this->_readMap($val, $vspec, $input); + break; + case TType::LST: + $xfer += $this->_readList($val, $vspec, $input, false); + break; + case TType::SET: + $xfer += $this->_readList($val, $vspec, $input, true); + break; + } + } + $var[$key] = $val; + } + $xfer += $input->readMapEnd(); + return $xfer; + } + + private function _readList(&$var, $spec, $input, $set=false) { + $xfer = 0; + $etype = $spec['etype']; + $eread = $vread = null; + if (isset(TBase::$tmethod[$etype])) { + $eread = 'read'.TBase::$tmethod[$etype]; + } else { + $espec = $spec['elem']; + } + $var = array(); + $_etype = $size = 0; + if ($set) { + $xfer += $input->readSetBegin($_etype, $size); + } else { + $xfer += $input->readListBegin($_etype, $size); + } + for ($i = 0; $i < $size; ++$i) { + $elem = null; + if ($eread !== null) { + $xfer += $input->$eread($elem); + } else { + $espec = $spec['elem']; + switch ($etype) { + case TType::STRUCT: + $class = $espec['class']; + $elem = new $class(); + $xfer += $elem->read($input); + break; + case TType::MAP: + $xfer += $this->_readMap($elem, $espec, $input); + break; + case TType::LST: + $xfer += $this->_readList($elem, $espec, $input, false); + break; + case TType::SET: + $xfer += $this->_readList($elem, $espec, $input, true); + break; + } + } + if ($set) { + $var[$elem] = true; + } else { + $var []= $elem; + } + } + if ($set) { + $xfer += $input->readSetEnd(); + } else { + $xfer += $input->readListEnd(); + } + return $xfer; + } + + protected function _read($class, $spec, $input) { + $xfer = 0; + $fname = null; + $ftype = 0; + $fid = 0; + $xfer += $input->readStructBegin($fname); + while (true) { + $xfer += $input->readFieldBegin($fname, $ftype, $fid); + if ($ftype == TType::STOP) { + break; + } + if (isset($spec[$fid])) { + $fspec = $spec[$fid]; + $var = $fspec['var']; + if ($ftype == $fspec['type']) { + $xfer = 0; + if (isset(TBase::$tmethod[$ftype])) { + $func = 'read'.TBase::$tmethod[$ftype]; + $xfer += $input->$func($this->$var); + } else { + switch ($ftype) { + case TType::STRUCT: + $class = $fspec['class']; + $this->$var = new $class(); + $xfer += $this->$var->read($input); + break; + case TType::MAP: + $xfer += $this->_readMap($this->$var, $fspec, $input); + break; + case TType::LST: + $xfer += $this->_readList($this->$var, $fspec, $input, false); + break; + case TType::SET: + $xfer += $this->_readList($this->$var, $fspec, $input, true); + break; + } + } + } else { + $xfer += $input->skip($ftype); + } + } else { + $xfer += $input->skip($ftype); + } + $xfer += $input->readFieldEnd(); + } + $xfer += $input->readStructEnd(); + return $xfer; + } + + private function _writeMap($var, $spec, $output) { + $xfer = 0; + $ktype = $spec['ktype']; + $vtype = $spec['vtype']; + $kwrite = $vwrite = null; + if (isset(TBase::$tmethod[$ktype])) { + $kwrite = 'write'.TBase::$tmethod[$ktype]; + } else { + $kspec = $spec['key']; + } + if (isset(TBase::$tmethod[$vtype])) { + $vwrite = 'write'.TBase::$tmethod[$vtype]; + } else { + $vspec = $spec['val']; + } + $xfer += $output->writeMapBegin($ktype, $vtype, count($var)); + foreach ($var as $key => $val) { + if (isset($kwrite)) { + $xfer += $output->$kwrite($key); + } else { + switch ($ktype) { + case TType::STRUCT: + $xfer += $key->write($output); + break; + case TType::MAP: + $xfer += $this->_writeMap($key, $kspec, $output); + break; + case TType::LST: + $xfer += $this->_writeList($key, $kspec, $output, false); + break; + case TType::SET: + $xfer += $this->_writeList($key, $kspec, $output, true); + break; + } + } + if (isset($vwrite)) { + $xfer += $output->$vwrite($val); + } else { + switch ($vtype) { + case TType::STRUCT: + $xfer += $val->write($output); + break; + case TType::MAP: + $xfer += $this->_writeMap($val, $vspec, $output); + break; + case TType::LST: + $xfer += $this->_writeList($val, $vspec, $output, false); + break; + case TType::SET: + $xfer += $this->_writeList($val, $vspec, $output, true); + break; + } + } + } + $xfer += $output->writeMapEnd(); + return $xfer; + } + + private function _writeList($var, $spec, $output, $set=false) { + $xfer = 0; + $etype = $spec['etype']; + $ewrite = null; + if (isset(TBase::$tmethod[$etype])) { + $ewrite = 'write'.TBase::$tmethod[$etype]; + } else { + $espec = $spec['elem']; + } + if ($set) { + $xfer += $output->writeSetBegin($etype, count($var)); + } else { + $xfer += $output->writeListBegin($etype, count($var)); + } + foreach ($var as $key => $val) { + $elem = $set ? $key : $val; + if (isset($ewrite)) { + $xfer += $output->$ewrite($elem); + } else { + switch ($etype) { + case TType::STRUCT: + $xfer += $elem->write($output); + break; + case TType::MAP: + $xfer += $this->_writeMap($elem, $espec, $output); + break; + case TType::LST: + $xfer += $this->_writeList($elem, $espec, $output, false); + break; + case TType::SET: + $xfer += $this->_writeList($elem, $espec, $output, true); + break; + } + } + } + if ($set) { + $xfer += $output->writeSetEnd(); + } else { + $xfer += $output->writeListEnd(); + } + return $xfer; + } + + protected function _write($class, $spec, $output) { + $xfer = 0; + $xfer += $output->writeStructBegin($class); + foreach ($spec as $fid => $fspec) { + $var = $fspec['var']; + if ($this->$var !== null) { + $ftype = $fspec['type']; + $xfer += $output->writeFieldBegin($var, $ftype, $fid); + if (isset(TBase::$tmethod[$ftype])) { + $func = 'write'.TBase::$tmethod[$ftype]; + $xfer += $output->$func($this->$var); + } else { + switch ($ftype) { + case TType::STRUCT: + $xfer += $this->$var->write($output); + break; + case TType::MAP: + $xfer += $this->_writeMap($this->$var, $fspec, $output); + break; + case TType::LST: + $xfer += $this->_writeList($this->$var, $fspec, $output, false); + break; + case TType::SET: + $xfer += $this->_writeList($this->$var, $fspec, $output, true); + break; + } + } + $xfer += $output->writeFieldEnd(); + } + } + $xfer += $output->writeFieldStop(); + $xfer += $output->writeStructEnd(); + return $xfer; + } +} + +class TApplicationException extends TException { + static $_TSPEC = + array(1 => array('var' => 'message', + 'type' => TType::STRING), + 2 => array('var' => 'code', + 'type' => TType::I32)); + + const UNKNOWN = 0; + const UNKNOWN_METHOD = 1; + const INVALID_MESSAGE_TYPE = 2; + const WRONG_METHOD_NAME = 3; + const BAD_SEQUENCE_ID = 4; + const MISSING_RESULT = 5; + + function __construct($message=null, $code=0) { + parent::__construct($message, $code); + } + + public function read($output) { + return $this->_read('TApplicationException', self::$_TSPEC, $output); + } + + public function write($output) { + $xfer = 0; + $xfer += $output->writeStructBegin('TApplicationException'); + if ($message = $this->getMessage()) { + $xfer += $output->writeFieldBegin('message', TType::STRING, 1); + $xfer += $output->writeString($message); + $xfer += $output->writeFieldEnd(); + } + if ($code = $this->getCode()) { + $xfer += $output->writeFieldBegin('type', TType::I32, 2); + $xfer += $output->writeI32($code); + $xfer += $output->writeFieldEnd(); + } + $xfer += $output->writeFieldStop(); + $xfer += $output->writeStructEnd(); + return $xfer; + } +} + +/** + * Set global THRIFT ROOT automatically via inclusion here + */ +if (!isset($GLOBALS['THRIFT_ROOT'])) { + $GLOBALS['THRIFT_ROOT'] = dirname(__FILE__); +} +include_once $GLOBALS['THRIFT_ROOT'].'/protocol/TProtocol.php'; +include_once $GLOBALS['THRIFT_ROOT'].'/transport/TTransport.php'; + +?> diff --git a/lib/php/src/autoload.php b/lib/php/src/autoload.php new file mode 100644 index 00000000..3a35545d --- /dev/null +++ b/lib/php/src/autoload.php @@ -0,0 +1,51 @@ + +#include +#include +#include +#include +#include + +#if __BYTE_ORDER == __LITTLE_ENDIAN +#define htonll(x) bswap_64(x) +#define ntohll(x) bswap_64(x) +#else +#define htonll(x) x +#define ntohll(x) x +#endif + +enum TType { + T_STOP = 0, + T_VOID = 1, + T_BOOL = 2, + T_BYTE = 3, + T_I08 = 3, + T_I16 = 6, + T_I32 = 8, + T_U64 = 9, + T_I64 = 10, + T_DOUBLE = 4, + T_STRING = 11, + T_UTF7 = 11, + T_STRUCT = 12, + T_MAP = 13, + T_SET = 14, + T_LIST = 15, + T_UTF8 = 16, + T_UTF16 = 17 +}; + +const int32_t VERSION_MASK = 0xffff0000; +const int32_t VERSION_1 = 0x80010000; +const int8_t T_CALL = 1; +const int8_t T_REPLY = 2; +const int8_t T_EXCEPTION = 3; +// tprotocolexception +const int INVALID_DATA = 1; +const int BAD_VERSION = 4; + +#include "php.h" +#include "zend_interfaces.h" +#include "zend_exceptions.h" +#include "php_thrift_protocol.h" + +static function_entry thrift_protocol_functions[] = { + PHP_FE(thrift_protocol_write_binary, NULL) + PHP_FE(thrift_protocol_read_binary, NULL) + {NULL, NULL, NULL} +} ; + +zend_module_entry thrift_protocol_module_entry = { + STANDARD_MODULE_HEADER, + "thrift_protocol", + thrift_protocol_functions, + NULL, + NULL, + NULL, + NULL, + NULL, + "1.0", + STANDARD_MODULE_PROPERTIES +}; + +#ifdef COMPILE_DL_THRIFT_PROTOCOL +ZEND_GET_MODULE(thrift_protocol) +#endif + +class PHPExceptionWrapper : public std::exception { +public: + PHPExceptionWrapper(zval* _ex) throw() : ex(_ex) { + snprintf(_what, 40, "PHP exception zval=%p", ex); + } + const char* what() const throw() { return _what; } + ~PHPExceptionWrapper() throw() {} + operator zval*() const throw() { return const_cast(ex); } // Zend API doesn't do 'const'... +protected: + zval* ex; + char _what[40]; +} ; + +class PHPTransport { +public: + zval* protocol() { return p; } + zval* transport() { return t; } +protected: + PHPTransport() {} + + void construct_with_zval(zval* _p, size_t _buffer_size) { + buffer = reinterpret_cast(emalloc(_buffer_size)); + buffer_ptr = buffer; + buffer_used = 0; + buffer_size = _buffer_size; + p = _p; + + // Get the transport for the passed protocol + zval gettransport; + ZVAL_STRING(&gettransport, "getTransport", 0); + MAKE_STD_ZVAL(t); + ZVAL_NULL(t); + TSRMLS_FETCH(); + call_user_function(EG(function_table), &p, &gettransport, t, 0, NULL TSRMLS_CC); + } + ~PHPTransport() { + efree(buffer); + zval_ptr_dtor(&t); + } + + char* buffer; + char* buffer_ptr; + size_t buffer_used; + size_t buffer_size; + + zval* p; + zval* t; +}; + + +class PHPOutputTransport : public PHPTransport { +public: + PHPOutputTransport(zval* _p, size_t _buffer_size = 8192) { + construct_with_zval(_p, _buffer_size); + } + + ~PHPOutputTransport() { + flush(); + directFlush(); + } + + void write(const char* data, size_t len) { + if ((len + buffer_used) > buffer_size) { + flush(); + } + if (len > buffer_size) { + directWrite(data, len); + } else { + memcpy(buffer_ptr, data, len); + buffer_used += len; + buffer_ptr += len; + } + } + + void writeI64(int64_t i) { + i = htonll(i); + write((const char*)&i, 8); + } + + void writeU32(uint32_t i) { + i = htonl(i); + write((const char*)&i, 4); + } + + void writeI32(int32_t i) { + i = htonl(i); + write((const char*)&i, 4); + } + + void writeI16(int16_t i) { + i = htons(i); + write((const char*)&i, 2); + } + + void writeI8(int8_t i) { + write((const char*)&i, 1); + } + + void writeString(const char* str, size_t len) { + writeU32(len); + write(str, len); + } + + void flush() { + if (buffer_used) { + directWrite(buffer, buffer_used); + buffer_ptr = buffer; + buffer_used = 0; + } + } + +protected: + void directFlush() { + zval ret; + ZVAL_NULL(&ret); + zval flushfn; + ZVAL_STRING(&flushfn, "flush", 0); + TSRMLS_FETCH(); + call_user_function(EG(function_table), &t, &flushfn, &ret, 0, NULL TSRMLS_CC); + zval_dtor(&ret); + } + void directWrite(const char* data, size_t len) { + zval writefn; + ZVAL_STRING(&writefn, "write", 0); + char* newbuf = (char*)emalloc(buffer_used + 1); + memcpy(newbuf, buffer, buffer_used); + newbuf[buffer_used] = '\0'; + zval *args[1]; + MAKE_STD_ZVAL(args[0]); + ZVAL_STRINGL(args[0], newbuf, buffer_used, 0); + TSRMLS_FETCH(); + zval ret; + ZVAL_NULL(&ret); + call_user_function(EG(function_table), &t, &writefn, &ret, 1, args TSRMLS_CC); + zval_ptr_dtor(args); + zval_dtor(&ret); + if (EG(exception)) { + zval* ex = EG(exception); + EG(exception) = NULL; + throw PHPExceptionWrapper(ex); + } + } +}; + +class PHPInputTransport : public PHPTransport { +public: + PHPInputTransport(zval* _p, size_t _buffer_size = 8192) { + construct_with_zval(_p, _buffer_size); + } + + ~PHPInputTransport() { + put_back(); + } + + void put_back() { + if (buffer_used) { + zval putbackfn; + ZVAL_STRING(&putbackfn, "putBack", 0); + + char* newbuf = (char*)emalloc(buffer_used + 1); + memcpy(newbuf, buffer_ptr, buffer_used); + newbuf[buffer_used] = '\0'; + + zval *args[1]; + MAKE_STD_ZVAL(args[0]); + ZVAL_STRINGL(args[0], newbuf, buffer_used, 0); + + TSRMLS_FETCH(); + + zval ret; + ZVAL_NULL(&ret); + call_user_function(EG(function_table), &t, &putbackfn, &ret, 1, args TSRMLS_CC); + zval_ptr_dtor(args); + zval_dtor(&ret); + } + buffer_used = 0; + buffer_ptr = buffer; + } + + void skip(size_t len) { + while (len) { + size_t chunk_size = MIN(len, buffer_used); + if (chunk_size) { + buffer_ptr = reinterpret_cast(buffer_ptr) + chunk_size; + buffer_used -= chunk_size; + len -= chunk_size; + } + if (! len) break; + refill(); + } + } + + void readBytes(void* buf, size_t len) { + while (len) { + size_t chunk_size = MIN(len, buffer_used); + if (chunk_size) { + memcpy(buf, buffer_ptr, chunk_size); + buffer_ptr = reinterpret_cast(buffer_ptr) + chunk_size; + buffer_used -= chunk_size; + buf = reinterpret_cast(buf) + chunk_size; + len -= chunk_size; + } + if (! len) break; + refill(); + } + } + + int8_t readI8() { + int8_t c; + readBytes(&c, 1); + return c; + } + + int16_t readI16() { + int16_t c; + readBytes(&c, 2); + return (int16_t)ntohs(c); + } + + uint32_t readU32() { + uint32_t c; + readBytes(&c, 4); + return (uint32_t)ntohl(c); + } + + int32_t readI32() { + int32_t c; + readBytes(&c, 4); + return (int32_t)ntohl(c); + } + +protected: + void refill() { + assert(buffer_used == 0); + zval retval; + ZVAL_NULL(&retval); + + zval *args[1]; + MAKE_STD_ZVAL(args[0]); + ZVAL_LONG(args[0], buffer_size); + + TSRMLS_FETCH(); + + zval funcname; + ZVAL_STRING(&funcname, "read", 0); + + call_user_function(EG(function_table), &t, &funcname, &retval, 1, args TSRMLS_CC); + zval_ptr_dtor(args); + + if (EG(exception)) { + zval_dtor(&retval); + zval* ex = EG(exception); + EG(exception) = NULL; + throw PHPExceptionWrapper(ex); + } + + buffer_used = Z_STRLEN(retval); + memcpy(buffer, Z_STRVAL(retval), buffer_used); + zval_dtor(&retval); + + buffer_ptr = buffer; + } + +}; + +void binary_deserialize_spec(zval* zthis, PHPInputTransport& transport, HashTable* spec); +void binary_serialize_spec(zval* zthis, PHPOutputTransport& transport, HashTable* spec); +void binary_serialize(int8_t thrift_typeID, PHPOutputTransport& transport, zval** value, HashTable* fieldspec); +void skip_element(long thrift_typeID, PHPInputTransport& transport); + +// Create a PHP object given a typename and call the ctor, optionally passing up to 2 arguments +void createObject(char* obj_typename, zval* return_value, int nargs = 0, zval* arg1 = NULL, zval* arg2 = NULL) { + TSRMLS_FETCH(); + size_t obj_typename_len = strlen(obj_typename); + zend_class_entry* ce = zend_fetch_class(obj_typename, obj_typename_len, ZEND_FETCH_CLASS_DEFAULT TSRMLS_CC); + if (! ce) { + php_error_docref(NULL TSRMLS_CC, E_ERROR, "Class %s does not exist", obj_typename); + RETURN_NULL(); + } + + object_and_properties_init(return_value, ce, NULL); + zend_function* constructor = zend_std_get_constructor(return_value TSRMLS_CC); + zval* ctor_rv = NULL; + zend_call_method(&return_value, ce, &constructor, NULL, 0, &ctor_rv, nargs, arg1, arg2 TSRMLS_CC); + zval_ptr_dtor(&ctor_rv); +} + +void throw_tprotocolexception(char* what, long errorcode) { + TSRMLS_FETCH(); + + zval *zwhat, *zerrorcode; + MAKE_STD_ZVAL(zwhat); + MAKE_STD_ZVAL(zerrorcode); + + ZVAL_STRING(zwhat, what, 1); + ZVAL_LONG(zerrorcode, errorcode); + + zval* ex; + MAKE_STD_ZVAL(ex); + createObject("TProtocolException", ex, 2, zwhat, zerrorcode); + zval_ptr_dtor(&zwhat); + zval_ptr_dtor(&zerrorcode); + throw PHPExceptionWrapper(ex); +} + +void binary_deserialize(int8_t thrift_typeID, PHPInputTransport& transport, zval* return_value, HashTable* fieldspec) { + zval** val_ptr; + Z_TYPE_P(return_value) = IS_NULL; // just in case + + switch (thrift_typeID) { + case T_STOP: + case T_VOID: + RETURN_NULL(); + return; + case T_STRUCT: { + if (zend_hash_find(fieldspec, "class", 6, (void**)&val_ptr) != SUCCESS) { + throw_tprotocolexception("no class type in spec", INVALID_DATA); + skip_element(T_STRUCT, transport); + RETURN_NULL(); + } + char* structType = Z_STRVAL_PP(val_ptr); + createObject(structType, return_value); + if (Z_TYPE_P(return_value) == IS_NULL) { + // unable to create class entry + skip_element(T_STRUCT, transport); + RETURN_NULL(); + } + TSRMLS_FETCH(); + zval* spec = zend_read_static_property(zend_get_class_entry(return_value TSRMLS_CC), "_TSPEC", 6, false TSRMLS_CC); + if (Z_TYPE_P(spec) != IS_ARRAY) { + char errbuf[128]; + snprintf(errbuf, 128, "spec for %s is wrong type: %d\n", structType, Z_TYPE_P(spec)); + throw_tprotocolexception(errbuf, INVALID_DATA); + RETURN_NULL(); + } + binary_deserialize_spec(return_value, transport, Z_ARRVAL_P(spec)); + return; + } break; + case T_BOOL: { + uint8_t c; + transport.readBytes(&c, 1); + RETURN_BOOL(c != 0); + } + //case T_I08: // same numeric value as T_BYTE + case T_BYTE: { + uint8_t c; + transport.readBytes(&c, 1); + RETURN_LONG(c); + } + case T_I16: { + uint16_t c; + transport.readBytes(&c, 2); + RETURN_LONG(ntohs(c)); + } + case T_I32: { + uint32_t c; + transport.readBytes(&c, 4); + RETURN_LONG(ntohl(c)); + } + case T_U64: + case T_I64: { + uint64_t c; + transport.readBytes(&c, 8); + RETURN_LONG(ntohll(c)); + } + case T_DOUBLE: { + union { + uint64_t c; + double d; + } a; + transport.readBytes(&(a.c), 8); + a.c = ntohll(a.c); + RETURN_DOUBLE(a.d); + } + //case T_UTF7: // aliases T_STRING + case T_UTF8: + case T_UTF16: + case T_STRING: { + uint32_t size = transport.readU32(); + if (size) { + char* strbuf = (char*) emalloc(size + 1); + transport.readBytes(strbuf, size); + strbuf[size] = '\0'; + ZVAL_STRINGL(return_value, strbuf, size, 0); + } else { + ZVAL_EMPTY_STRING(return_value); + } + return; + } + case T_MAP: { // array of key -> value + uint8_t types[2]; + transport.readBytes(types, 2); + uint32_t size = transport.readU32(); + array_init(return_value); + + zend_hash_find(fieldspec, "key", 4, (void**)&val_ptr); + HashTable* keyspec = Z_ARRVAL_PP(val_ptr); + zend_hash_find(fieldspec, "val", 4, (void**)&val_ptr); + HashTable* valspec = Z_ARRVAL_PP(val_ptr); + + for (uint32_t s = 0; s < size; ++s) { + zval *value; + MAKE_STD_ZVAL(value); + + zval* key; + MAKE_STD_ZVAL(key); + + binary_deserialize(types[0], transport, key, keyspec); + binary_deserialize(types[1], transport, value, valspec); + if (Z_TYPE_P(key) == IS_LONG) { + zend_hash_index_update(return_value->value.ht, Z_LVAL_P(key), &value, sizeof(zval *), NULL); + } + else { + if (Z_TYPE_P(key) != IS_STRING) convert_to_string(key); + zend_hash_update(return_value->value.ht, Z_STRVAL_P(key), Z_STRLEN_P(key) + 1, &value, sizeof(zval *), NULL); + } + zval_ptr_dtor(&key); + } + return; // return_value already populated + } + case T_LIST: { // array with autogenerated numeric keys + int8_t type = transport.readI8(); + uint32_t size = transport.readU32(); + zend_hash_find(fieldspec, "elem", 5, (void**)&val_ptr); + HashTable* elemspec = Z_ARRVAL_PP(val_ptr); + + array_init(return_value); + for (uint32_t s = 0; s < size; ++s) { + zval *value; + MAKE_STD_ZVAL(value); + binary_deserialize(type, transport, value, elemspec); + zend_hash_next_index_insert(return_value->value.ht, &value, sizeof(zval *), NULL); + } + return; + } + case T_SET: { // array of key -> TRUE + uint8_t type; + uint32_t size; + transport.readBytes(&type, 1); + transport.readBytes(&size, 4); + size = ntohl(size); + zend_hash_find(fieldspec, "elem", 5, (void**)&val_ptr); + HashTable* elemspec = Z_ARRVAL_PP(val_ptr); + + array_init(return_value); + + for (uint32_t s = 0; s < size; ++s) { + zval* key; + zval* value; + MAKE_STD_ZVAL(key); + MAKE_STD_ZVAL(value); + ZVAL_TRUE(value); + + binary_deserialize(type, transport, key, elemspec); + + if (Z_TYPE_P(key) == IS_LONG) { + zend_hash_index_update(return_value->value.ht, Z_LVAL_P(key), &value, sizeof(zval *), NULL); + } + else { + if (Z_TYPE_P(key) != IS_STRING) convert_to_string(key); + zend_hash_update(return_value->value.ht, Z_STRVAL_P(key), Z_STRLEN_P(key) + 1, &value, sizeof(zval *), NULL); + } + zval_ptr_dtor(&key); + } + return; + } + }; + + char errbuf[128]; + sprintf(errbuf, "Unknown thrift typeID %d", thrift_typeID); + throw_tprotocolexception(errbuf, INVALID_DATA); +} + +void skip_element(long thrift_typeID, PHPInputTransport& transport) { + switch (thrift_typeID) { + case T_STOP: + case T_VOID: + return; + case T_STRUCT: + while (true) { + int8_t ttype = transport.readI8(); // get field type + if (ttype == T_STOP) break; + transport.skip(2); // skip field number, I16 + skip_element(ttype, transport); // skip field payload + } + return; + case T_BOOL: + case T_BYTE: + transport.skip(1); + return; + case T_I16: + transport.skip(2); + return; + case T_I32: + transport.skip(4); + return; + case T_U64: + case T_I64: + case T_DOUBLE: + transport.skip(8); + return; + //case T_UTF7: // aliases T_STRING + case T_UTF8: + case T_UTF16: + case T_STRING: { + uint32_t len = transport.readU32(); + transport.skip(len); + } return; + case T_MAP: { + int8_t keytype = transport.readI8(); + int8_t valtype = transport.readI8(); + uint32_t size = transport.readU32(); + for (uint32_t i = 0; i < size; ++i) { + skip_element(keytype, transport); + skip_element(valtype, transport); + } + } return; + case T_LIST: + case T_SET: { + int8_t valtype = transport.readI8(); + uint32_t size = transport.readU32(); + for (uint32_t i = 0; i < size; ++i) { + skip_element(valtype, transport); + } + } return; + }; + + char errbuf[128]; + sprintf(errbuf, "Unknown thrift typeID %ld", thrift_typeID); + throw_tprotocolexception(errbuf, INVALID_DATA); +} + +void binary_serialize_hashtable_key(int8_t keytype, PHPOutputTransport& transport, HashTable* ht, HashPosition& ht_pos) { + bool keytype_is_numeric = (!((keytype == T_STRING) || (keytype == T_UTF8) || (keytype == T_UTF16))); + + char* key; + uint key_len; + long index = 0; + + zval* z; + MAKE_STD_ZVAL(z); + + int res = zend_hash_get_current_key_ex(ht, &key, &key_len, (ulong*)&index, 0, &ht_pos); + if (keytype_is_numeric) { + if (res == HASH_KEY_IS_STRING) { + index = strtol(key, NULL, 10); + } + ZVAL_LONG(z, index); + } else { + char buf[64]; + if (res == HASH_KEY_IS_STRING) { + key_len -= 1; // skip the null terminator + } else { + sprintf(buf, "%ld", index); + key = buf; key_len = strlen(buf); + } + ZVAL_STRINGL(z, key, key_len, 1); + } + binary_serialize(keytype, transport, &z, NULL); + zval_ptr_dtor(&z); +} + +inline bool ttype_is_int(int8_t t) { + return ((t == T_BYTE) || ((t >= T_I16) && (t <= T_I64))); +} + +inline bool ttypes_are_compatible(int8_t t1, int8_t t2) { + // Integer types of different widths are considered compatible; + // otherwise the typeID must match. + return ((t1 == t2) || (ttype_is_int(t1) && ttype_is_int(t2))); +} + +void binary_deserialize_spec(zval* zthis, PHPInputTransport& transport, HashTable* spec) { + // SET and LIST have 'elem' => array('type', [optional] 'class') + // MAP has 'val' => array('type', [optiona] 'class') + TSRMLS_FETCH(); + zend_class_entry* ce = zend_get_class_entry(zthis TSRMLS_CC); + while (true) { + zval** val_ptr = NULL; + + int8_t ttype = transport.readI8(); + if (ttype == T_STOP) return; + int16_t fieldno = transport.readI16(); + if (zend_hash_index_find(spec, fieldno, (void**)&val_ptr) == SUCCESS) { + HashTable* fieldspec = Z_ARRVAL_PP(val_ptr); + // pull the field name + // zend hash tables use the null at the end in the length... so strlen(hash key) + 1. + zend_hash_find(fieldspec, "var", 4, (void**)&val_ptr); + char* varname = Z_STRVAL_PP(val_ptr); + + // and the type + zend_hash_find(fieldspec, "type", 5, (void**)&val_ptr); + if (Z_TYPE_PP(val_ptr) != IS_LONG) convert_to_long(*val_ptr); + int8_t expected_ttype = Z_LVAL_PP(val_ptr); + + if (ttypes_are_compatible(ttype, expected_ttype)) { + zval* rv = NULL; + MAKE_STD_ZVAL(rv); + binary_deserialize(ttype, transport, rv, fieldspec); + zend_update_property(ce, zthis, varname, strlen(varname), rv TSRMLS_CC); + zval_ptr_dtor(&rv); + } else { + skip_element(ttype, transport); + } + } else { + skip_element(ttype, transport); + } + } +} + +void binary_serialize(int8_t thrift_typeID, PHPOutputTransport& transport, zval** value, HashTable* fieldspec) { + // At this point the typeID (and field num, if applicable) should've already been written to the output so all we need to do is write the payload. + switch (thrift_typeID) { + case T_STOP: + case T_VOID: + return; + case T_STRUCT: { + TSRMLS_FETCH(); + if (Z_TYPE_PP(value) != IS_OBJECT) { + throw_tprotocolexception("Attempt to send non-object type as a T_STRUCT", INVALID_DATA); + } + zval* spec = zend_read_static_property(zend_get_class_entry(*value TSRMLS_CC), "_TSPEC", 6, false TSRMLS_CC); + binary_serialize_spec(*value, transport, Z_ARRVAL_P(spec)); + } return; + case T_BOOL: + if (Z_TYPE_PP(value) != IS_BOOL) convert_to_boolean(*value); + transport.writeI8(Z_BVAL_PP(value) ? 1 : 0); + return; + case T_BYTE: + if (Z_TYPE_PP(value) != IS_LONG) convert_to_long(*value); + transport.writeI8(Z_LVAL_PP(value)); + return; + case T_I16: + if (Z_TYPE_PP(value) != IS_LONG) convert_to_long(*value); + transport.writeI16(Z_LVAL_PP(value)); + return; + case T_I32: + if (Z_TYPE_PP(value) != IS_LONG) convert_to_long(*value); + transport.writeI32(Z_LVAL_PP(value)); + return; + case T_I64: + case T_U64: + if (Z_TYPE_PP(value) != IS_LONG) convert_to_long(*value); + transport.writeI64(Z_LVAL_PP(value)); + return; + case T_DOUBLE: { + union { + int64_t c; + double d; + } a; + if (Z_TYPE_PP(value) != IS_DOUBLE) convert_to_double(*value); + a.d = Z_DVAL_PP(value); + transport.writeI64(a.c); + } return; + //case T_UTF7: + case T_UTF8: + case T_UTF16: + case T_STRING: + if (Z_TYPE_PP(value) != IS_STRING) convert_to_string(*value); + transport.writeString(Z_STRVAL_PP(value), Z_STRLEN_PP(value)); + return; + case T_MAP: { + if (Z_TYPE_PP(value) != IS_ARRAY) convert_to_array(*value); + if (Z_TYPE_PP(value) != IS_ARRAY) { + throw_tprotocolexception("Attempt to send an incompatible type as an array (T_MAP)", INVALID_DATA); + } + HashTable* ht = Z_ARRVAL_PP(value); + zval** val_ptr; + + zend_hash_find(fieldspec, "ktype", 6, (void**)&val_ptr); + if (Z_TYPE_PP(val_ptr) != IS_LONG) convert_to_long(*val_ptr); + uint8_t keytype = Z_LVAL_PP(val_ptr); + transport.writeI8(keytype); + zend_hash_find(fieldspec, "vtype", 6, (void**)&val_ptr); + if (Z_TYPE_PP(val_ptr) != IS_LONG) convert_to_long(*val_ptr); + uint8_t valtype = Z_LVAL_PP(val_ptr); + transport.writeI8(valtype); + + zend_hash_find(fieldspec, "val", 4, (void**)&val_ptr); + HashTable* valspec = Z_ARRVAL_PP(val_ptr); + + transport.writeI32(zend_hash_num_elements(ht)); + HashPosition key_ptr; + for (zend_hash_internal_pointer_reset_ex(ht, &key_ptr); zend_hash_get_current_data_ex(ht, (void**)&val_ptr, &key_ptr) == SUCCESS; zend_hash_move_forward_ex(ht, &key_ptr)) { + binary_serialize_hashtable_key(keytype, transport, ht, key_ptr); + binary_serialize(valtype, transport, val_ptr, valspec); + } + } return; + case T_LIST: { + if (Z_TYPE_PP(value) != IS_ARRAY) convert_to_array(*value); + if (Z_TYPE_PP(value) != IS_ARRAY) { + throw_tprotocolexception("Attempt to send an incompatible type as an array (T_LIST)", INVALID_DATA); + } + HashTable* ht = Z_ARRVAL_PP(value); + zval** val_ptr; + + zend_hash_find(fieldspec, "etype", 6, (void**)&val_ptr); + if (Z_TYPE_PP(val_ptr) != IS_LONG) convert_to_long(*val_ptr); + uint8_t valtype = Z_LVAL_PP(val_ptr); + transport.writeI8(valtype); + + zend_hash_find(fieldspec, "elem", 5, (void**)&val_ptr); + HashTable* valspec = Z_ARRVAL_PP(val_ptr); + + transport.writeI32(zend_hash_num_elements(ht)); + HashPosition key_ptr; + for (zend_hash_internal_pointer_reset_ex(ht, &key_ptr); zend_hash_get_current_data_ex(ht, (void**)&val_ptr, &key_ptr) == SUCCESS; zend_hash_move_forward_ex(ht, &key_ptr)) { + binary_serialize(valtype, transport, val_ptr, valspec); + } + } return; + case T_SET: { + if (Z_TYPE_PP(value) != IS_ARRAY) convert_to_array(*value); + if (Z_TYPE_PP(value) != IS_ARRAY) { + throw_tprotocolexception("Attempt to send an incompatible type as an array (T_SET)", INVALID_DATA); + } + HashTable* ht = Z_ARRVAL_PP(value); + zval** val_ptr; + + zend_hash_find(fieldspec, "etype", 6, (void**)&val_ptr); + if (Z_TYPE_PP(val_ptr) != IS_LONG) convert_to_long(*val_ptr); + uint8_t keytype = Z_LVAL_PP(val_ptr); + transport.writeI8(keytype); + + transport.writeI32(zend_hash_num_elements(ht)); + HashPosition key_ptr; + for (zend_hash_internal_pointer_reset_ex(ht, &key_ptr); zend_hash_get_current_data_ex(ht, (void**)&val_ptr, &key_ptr) == SUCCESS; zend_hash_move_forward_ex(ht, &key_ptr)) { + binary_serialize_hashtable_key(keytype, transport, ht, key_ptr); + } + } return; + }; + char errbuf[128]; + sprintf(errbuf, "Unknown thrift typeID %d", thrift_typeID); + throw_tprotocolexception(errbuf, INVALID_DATA); +} + + +void binary_serialize_spec(zval* zthis, PHPOutputTransport& transport, HashTable* spec) { + HashPosition key_ptr; + zval** val_ptr; + + TSRMLS_FETCH(); + zend_class_entry* ce = zend_get_class_entry(zthis TSRMLS_CC); + + for (zend_hash_internal_pointer_reset_ex(spec, &key_ptr); zend_hash_get_current_data_ex(spec, (void**)&val_ptr, &key_ptr) == SUCCESS; zend_hash_move_forward_ex(spec, &key_ptr)) { + ulong fieldno; + if (zend_hash_get_current_key_ex(spec, NULL, NULL, &fieldno, 0, &key_ptr) != HASH_KEY_IS_LONG) { + throw_tprotocolexception("Bad keytype in TSPEC (expected 'long')", INVALID_DATA); + return; + } + HashTable* fieldspec = Z_ARRVAL_PP(val_ptr); + + // field name + zend_hash_find(fieldspec, "var", 4, (void**)&val_ptr); + char* varname = Z_STRVAL_PP(val_ptr); + + // thrift type + zend_hash_find(fieldspec, "type", 5, (void**)&val_ptr); + if (Z_TYPE_PP(val_ptr) != IS_LONG) convert_to_long(*val_ptr); + int8_t ttype = Z_LVAL_PP(val_ptr); + + zval* prop = zend_read_property(ce, zthis, varname, strlen(varname), false TSRMLS_CC); + if (Z_TYPE_P(prop) != IS_NULL) { + transport.writeI8(ttype); + transport.writeI16(fieldno); + binary_serialize(ttype, transport, &prop, fieldspec); + } + } + transport.writeI8(T_STOP); // struct end +} + +// 6 params: $transport $method_name $ttype $request_struct $seqID $strict_write +PHP_FUNCTION(thrift_protocol_write_binary) { + int argc = ZEND_NUM_ARGS(); + if (argc < 6) { + WRONG_PARAM_COUNT; + } + + zval ***args = (zval***) emalloc(argc * sizeof(zval**)); + zend_get_parameters_array_ex(argc, args); + + if (Z_TYPE_PP(args[0]) != IS_OBJECT) { + php_error_docref(NULL TSRMLS_CC, E_ERROR, "1st parameter is not an object (transport)"); + efree(args); + RETURN_NULL(); + } + + if (Z_TYPE_PP(args[1]) != IS_STRING) { + php_error_docref(NULL TSRMLS_CC, E_ERROR, "2nd parameter is not a string (method name)"); + efree(args); + RETURN_NULL(); + } + + if (Z_TYPE_PP(args[3]) != IS_OBJECT) { + php_error_docref(NULL TSRMLS_CC, E_ERROR, "4th parameter is not an object (request struct)"); + efree(args); + RETURN_NULL(); + } + + PHPOutputTransport transport(*args[0]); + const char* method_name = Z_STRVAL_PP(args[1]); + convert_to_long(*args[2]); + int32_t msgtype = Z_LVAL_PP(args[2]); + zval* request_struct = *args[3]; + convert_to_long(*args[4]); + int32_t seqID = Z_LVAL_PP(args[4]); + convert_to_boolean(*args[5]); + bool strictWrite = Z_BVAL_PP(args[5]); + efree(args); + args = NULL; + + try { + if (strictWrite) { + int32_t version = VERSION_1 | msgtype; + transport.writeI32(version); + transport.writeString(method_name, strlen(method_name)); + transport.writeI32(seqID); + } else { + transport.writeString(method_name, strlen(method_name)); + transport.writeI8(msgtype); + transport.writeI32(seqID); + } + + zval* spec = zend_read_static_property(zend_get_class_entry(request_struct TSRMLS_CC), "_TSPEC", 6, false TSRMLS_CC); + binary_serialize_spec(request_struct, transport, Z_ARRVAL_P(spec)); + } catch (const PHPExceptionWrapper& ex) { + zend_throw_exception_object(ex TSRMLS_CC); + RETURN_NULL(); + } +} + +// 3 params: $transport $response_Typename $strict_read +PHP_FUNCTION(thrift_protocol_read_binary) { + int argc = ZEND_NUM_ARGS(); + + if (argc < 3) { + WRONG_PARAM_COUNT; + } + + zval ***args = (zval***) emalloc(argc * sizeof(zval**)); + zend_get_parameters_array_ex(argc, args); + + if (Z_TYPE_PP(args[0]) != IS_OBJECT) { + php_error_docref(NULL TSRMLS_CC, E_ERROR, "1st parameter is not an object (transport)"); + efree(args); + RETURN_NULL(); + } + + if (Z_TYPE_PP(args[1]) != IS_STRING) { + php_error_docref(NULL TSRMLS_CC, E_ERROR, "2nd parameter is not a string (typename of expected response struct)"); + efree(args); + RETURN_NULL(); + } + + PHPInputTransport transport(*args[0]); + char* obj_typename = Z_STRVAL_PP(args[1]); + convert_to_boolean(*args[2]); + bool strict_read = Z_BVAL_PP(args[2]); + efree(args); + args = NULL; + + try { + int8_t messageType = 0; + int32_t sz = transport.readI32(); + + if (sz < 0) { + // Check for correct version number + int32_t version = sz & VERSION_MASK; + if (version != VERSION_1) { + throw_tprotocolexception("Bad version identifier", BAD_VERSION); + } + messageType = (sz & 0x000000ff); + int32_t namelen = transport.readI32(); + // skip the name string and the sequence ID, we don't care about those + transport.skip(namelen + 4); + } else { + if (strict_read) { + throw_tprotocolexception("No version identifier... old protocol client in strict mode?", BAD_VERSION); + } else { + // Handle pre-versioned input + transport.skip(sz); // skip string body + messageType = transport.readI8(); + transport.skip(4); // skip sequence number + } + } + + if (messageType == T_EXCEPTION) { + zval* ex; + MAKE_STD_ZVAL(ex); + createObject("TApplicationException", ex); + zval* spec = zend_read_static_property(zend_get_class_entry(ex TSRMLS_CC), "_TSPEC", 6, false TSRMLS_CC); + binary_deserialize_spec(ex, transport, Z_ARRVAL_P(spec)); + throw PHPExceptionWrapper(ex); + } + + createObject(obj_typename, return_value); + zval* spec = zend_read_static_property(zend_get_class_entry(return_value TSRMLS_CC), "_TSPEC", 6, false TSRMLS_CC); + binary_deserialize_spec(return_value, transport, Z_ARRVAL_P(spec)); + } catch (const PHPExceptionWrapper& ex) { + zend_throw_exception_object(ex TSRMLS_CC); + RETURN_NULL(); + } +} + diff --git a/lib/php/src/ext/thrift_protocol/php_thrift_protocol.h b/lib/php/src/ext/thrift_protocol/php_thrift_protocol.h new file mode 100644 index 00000000..c9a3e00f --- /dev/null +++ b/lib/php/src/ext/thrift_protocol/php_thrift_protocol.h @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +PHP_FUNCTION(thrift_protocol_write_binary); +PHP_FUNCTION(thrift_protocol_read_binary); + +extern zend_module_entry thrift_protocole_module_entry; + diff --git a/lib/php/src/protocol/TBinaryProtocol.php b/lib/php/src/protocol/TBinaryProtocol.php new file mode 100644 index 00000000..31bbbf9d --- /dev/null +++ b/lib/php/src/protocol/TBinaryProtocol.php @@ -0,0 +1,431 @@ +strictRead_ = $strictRead; + $this->strictWrite_ = $strictWrite; + } + + public function writeMessageBegin($name, $type, $seqid) { + if ($this->strictWrite_) { + $version = self::VERSION_1 | $type; + return + $this->writeI32($version) + + $this->writeString($name) + + $this->writeI32($seqid); + } else { + return + $this->writeString($name) + + $this->writeByte($type) + + $this->writeI32($seqid); + } + } + + public function writeMessageEnd() { + return 0; + } + + public function writeStructBegin($name) { + return 0; + } + + public function writeStructEnd() { + return 0; + } + + public function writeFieldBegin($fieldName, $fieldType, $fieldId) { + return + $this->writeByte($fieldType) + + $this->writeI16($fieldId); + } + + public function writeFieldEnd() { + return 0; + } + + public function writeFieldStop() { + return + $this->writeByte(TType::STOP); + } + + public function writeMapBegin($keyType, $valType, $size) { + return + $this->writeByte($keyType) + + $this->writeByte($valType) + + $this->writeI32($size); + } + + public function writeMapEnd() { + return 0; + } + + public function writeListBegin($elemType, $size) { + return + $this->writeByte($elemType) + + $this->writeI32($size); + } + + public function writeListEnd() { + return 0; + } + + public function writeSetBegin($elemType, $size) { + return + $this->writeByte($elemType) + + $this->writeI32($size); + } + + public function writeSetEnd() { + return 0; + } + + public function writeBool($value) { + $data = pack('c', $value ? 1 : 0); + $this->trans_->write($data, 1); + return 1; + } + + public function writeByte($value) { + $data = pack('c', $value); + $this->trans_->write($data, 1); + return 1; + } + + public function writeI16($value) { + $data = pack('n', $value); + $this->trans_->write($data, 2); + return 2; + } + + public function writeI32($value) { + $data = pack('N', $value); + $this->trans_->write($data, 4); + return 4; + } + + public function writeI64($value) { + // If we are on a 32bit architecture we have to explicitly deal with + // 64-bit twos-complement arithmetic since PHP wants to treat all ints + // as signed and any int over 2^31 - 1 as a float + if (PHP_INT_SIZE == 4) { + $neg = $value < 0; + + if ($neg) { + $value *= -1; + } + + $hi = (int)($value / 4294967296); + $lo = (int)$value; + + if ($neg) { + $hi = ~$hi; + $lo = ~$lo; + if (($lo & (int)0xffffffff) == (int)0xffffffff) { + $lo = 0; + $hi++; + } else { + $lo++; + } + } + $data = pack('N2', $hi, $lo); + + } else { + $hi = $value >> 32; + $lo = $value & 0xFFFFFFFF; + $data = pack('N2', $hi, $lo); + } + + $this->trans_->write($data, 8); + return 8; + } + + public function writeDouble($value) { + $data = pack('d', $value); + $this->trans_->write(strrev($data), 8); + return 8; + } + + public function writeString($value) { + $len = strlen($value); + $result = $this->writeI32($len); + if ($len) { + $this->trans_->write($value, $len); + } + return $result + $len; + } + + public function readMessageBegin(&$name, &$type, &$seqid) { + $result = $this->readI32($sz); + if ($sz < 0) { + $version = (int) ($sz & self::VERSION_MASK); + if ($version != (int) self::VERSION_1) { + throw new TProtocolException('Bad version identifier: '.$sz, TProtocolException::BAD_VERSION); + } + $type = $sz & 0x000000ff; + $result += + $this->readString($name) + + $this->readI32($seqid); + } else { + if ($this->strictRead_) { + throw new TProtocolException('No version identifier, old protocol client?', TProtocolException::BAD_VERSION); + } else { + // Handle pre-versioned input + $name = $this->trans_->readAll($sz); + $result += + $sz + + $this->readByte($type) + + $this->readI32($seqid); + } + } + return $result; + } + + public function readMessageEnd() { + return 0; + } + + public function readStructBegin(&$name) { + $name = ''; + return 0; + } + + public function readStructEnd() { + return 0; + } + + public function readFieldBegin(&$name, &$fieldType, &$fieldId) { + $result = $this->readByte($fieldType); + if ($fieldType == TType::STOP) { + $fieldId = 0; + return $result; + } + $result += $this->readI16($fieldId); + return $result; + } + + public function readFieldEnd() { + return 0; + } + + public function readMapBegin(&$keyType, &$valType, &$size) { + return + $this->readByte($keyType) + + $this->readByte($valType) + + $this->readI32($size); + } + + public function readMapEnd() { + return 0; + } + + public function readListBegin(&$elemType, &$size) { + return + $this->readByte($elemType) + + $this->readI32($size); + } + + public function readListEnd() { + return 0; + } + + public function readSetBegin(&$elemType, &$size) { + return + $this->readByte($elemType) + + $this->readI32($size); + } + + public function readSetEnd() { + return 0; + } + + public function readBool(&$value) { + $data = $this->trans_->readAll(1); + $arr = unpack('c', $data); + $value = $arr[1] == 1; + return 1; + } + + public function readByte(&$value) { + $data = $this->trans_->readAll(1); + $arr = unpack('c', $data); + $value = $arr[1]; + return 1; + } + + public function readI16(&$value) { + $data = $this->trans_->readAll(2); + $arr = unpack('n', $data); + $value = $arr[1]; + if ($value > 0x7fff) { + $value = 0 - (($value - 1) ^ 0xffff); + } + return 2; + } + + public function readI32(&$value) { + $data = $this->trans_->readAll(4); + $arr = unpack('N', $data); + $value = $arr[1]; + if ($value > 0x7fffffff) { + $value = 0 - (($value - 1) ^ 0xffffffff); + } + return 4; + } + + public function readI64(&$value) { + $data = $this->trans_->readAll(8); + + $arr = unpack('N2', $data); + + // If we are on a 32bit architecture we have to explicitly deal with + // 64-bit twos-complement arithmetic since PHP wants to treat all ints + // as signed and any int over 2^31 - 1 as a float + if (PHP_INT_SIZE == 4) { + + $hi = $arr[1]; + $lo = $arr[2]; + $isNeg = $hi < 0; + + // Check for a negative + if ($isNeg) { + $hi = ~$hi & (int)0xffffffff; + $lo = ~$lo & (int)0xffffffff; + + if ($lo == (int)0xffffffff) { + $hi++; + $lo = 0; + } else { + $lo++; + } + } + + // Force 32bit words in excess of 2G to pe positive - we deal wigh sign + // explicitly below + + if ($hi & (int)0x80000000) { + $hi &= (int)0x7fffffff; + $hi += 0x80000000; + } + + if ($lo & (int)0x80000000) { + $lo &= (int)0x7fffffff; + $lo += 0x80000000; + } + + $value = $hi * 4294967296 + $lo; + + if ($isNeg) { + $value = 0 - $value; + } + } else { + + // Upcast negatives in LSB bit + if ($arr[2] & 0x80000000) { + $arr[2] = $arr[2] & 0xffffffff; + } + + // Check for a negative + if ($arr[1] & 0x80000000) { + $arr[1] = $arr[1] & 0xffffffff; + $arr[1] = $arr[1] ^ 0xffffffff; + $arr[2] = $arr[2] ^ 0xffffffff; + $value = 0 - $arr[1]*4294967296 - $arr[2] - 1; + } else { + $value = $arr[1]*4294967296 + $arr[2]; + } + } + + return 8; + } + + public function readDouble(&$value) { + $data = strrev($this->trans_->readAll(8)); + $arr = unpack('d', $data); + $value = $arr[1]; + return 8; + } + + public function readString(&$value) { + $result = $this->readI32($len); + if ($len) { + $value = $this->trans_->readAll($len); + } else { + $value = ''; + } + return $result + $len; + } +} + +/** + * Binary Protocol Factory + */ +class TBinaryProtocolFactory implements TProtocolFactory { + private $strictRead_ = false; + private $strictWrite_ = false; + + public function __construct($strictRead=false, $strictWrite=false) { + $this->strictRead_ = $strictRead; + $this->strictWrite_ = $strictWrite; + } + + public function getProtocol($trans) { + return new TBinaryProtocol($trans, $this->strictRead, $this->strictWrite); + } +} + +/** + * Accelerated binary protocol: used in conjunction with the thrift_protocol + * extension for faster deserialization + */ +class TBinaryProtocolAccelerated extends TBinaryProtocol { + public function __construct($trans, $strictRead=false, $strictWrite=true) { + // If the transport doesn't implement putBack, wrap it in a + // TBufferedTransport (which does) + if (!method_exists($trans, 'putBack')) { + $trans = new TBufferedTransport($trans); + } + parent::__construct($trans, $strictRead, $strictWrite); + } + public function isStrictRead() { + return $this->strictRead_; + } + public function isStrictWrite() { + return $this->strictWrite_; + } +} + +?> diff --git a/lib/php/src/protocol/TProtocol.php b/lib/php/src/protocol/TProtocol.php new file mode 100644 index 00000000..e9ff41a3 --- /dev/null +++ b/lib/php/src/protocol/TProtocol.php @@ -0,0 +1,377 @@ +trans_ = $trans; + } + + /** + * Accessor for transport + * + * @return TTransport + */ + public function getTransport() { + return $this->trans_; + } + + /** + * Writes the message header + * + * @param string $name Function name + * @param int $type message type TMessageType::CALL or TMessageType::REPLY + * @param int $seqid The sequence id of this message + */ + public abstract function writeMessageBegin($name, $type, $seqid); + + /** + * Close the message + */ + public abstract function writeMessageEnd(); + + /** + * Writes a struct header. + * + * @param string $name Struct name + * @throws TException on write error + * @return int How many bytes written + */ + public abstract function writeStructBegin($name); + + /** + * Close a struct. + * + * @throws TException on write error + * @return int How many bytes written + */ + public abstract function writeStructEnd(); + + /* + * Starts a field. + * + * @param string $name Field name + * @param int $type Field type + * @param int $fid Field id + * @throws TException on write error + * @return int How many bytes written + */ + public abstract function writeFieldBegin($fieldName, $fieldType, $fieldId); + + public abstract function writeFieldEnd(); + + public abstract function writeFieldStop(); + + public abstract function writeMapBegin($keyType, $valType, $size); + + public abstract function writeMapEnd(); + + public abstract function writeListBegin($elemType, $size); + + public abstract function writeListEnd(); + + public abstract function writeSetBegin($elemType, $size); + + public abstract function writeSetEnd(); + + public abstract function writeBool($bool); + + public abstract function writeByte($byte); + + public abstract function writeI16($i16); + + public abstract function writeI32($i32); + + public abstract function writeI64($i64); + + public abstract function writeDouble($dub); + + public abstract function writeString($str); + + /** + * Reads the message header + * + * @param string $name Function name + * @param int $type message type TMessageType::CALL or TMessageType::REPLY + * @parem int $seqid The sequence id of this message + */ + public abstract function readMessageBegin(&$name, &$type, &$seqid); + + /** + * Read the close of message + */ + public abstract function readMessageEnd(); + + public abstract function readStructBegin(&$name); + + public abstract function readStructEnd(); + + public abstract function readFieldBegin(&$name, &$fieldType, &$fieldId); + + public abstract function readFieldEnd(); + + public abstract function readMapBegin(&$keyType, &$valType, &$size); + + public abstract function readMapEnd(); + + public abstract function readListBegin(&$elemType, &$size); + + public abstract function readListEnd(); + + public abstract function readSetBegin(&$elemType, &$size); + + public abstract function readSetEnd(); + + public abstract function readBool(&$bool); + + public abstract function readByte(&$byte); + + public abstract function readI16(&$i16); + + public abstract function readI32(&$i32); + + public abstract function readI64(&$i64); + + public abstract function readDouble(&$dub); + + public abstract function readString(&$str); + + /** + * The skip function is a utility to parse over unrecognized date without + * causing corruption. + * + * @param TType $type What type is it + */ + public function skip($type) { + switch ($type) { + case TType::BOOL: + return $this->readBool($bool); + case TType::BYTE: + return $this->readByte($byte); + case TType::I16: + return $this->readI16($i16); + case TType::I32: + return $this->readI32($i32); + case TType::I64: + return $this->readI64($i64); + case TType::DOUBLE: + return $this->readDouble($dub); + case TType::STRING: + return $this->readString($str); + case TType::STRUCT: + { + $result = $this->readStructBegin($name); + while (true) { + $result += $this->readFieldBegin($name, $ftype, $fid); + if ($ftype == TType::STOP) { + break; + } + $result += $this->skip($ftype); + $result += $this->readFieldEnd(); + } + $result += $this->readStructEnd(); + return $result; + } + case TType::MAP: + { + $result = $this->readMapBegin($keyType, $valType, $size); + for ($i = 0; $i < $size; $i++) { + $result += $this->skip($keyType); + $result += $this->skip($valType); + } + $result += $this->readMapEnd(); + return $result; + } + case TType::SET: + { + $result = $this->readSetBegin($elemType, $size); + for ($i = 0; $i < $size; $i++) { + $result += $this->skip($elemType); + } + $result += $this->readSetEnd(); + return $result; + } + case TType::LST: + { + $result = $this->readListBegin($elemType, $size); + for ($i = 0; $i < $size; $i++) { + $result += $this->skip($elemType); + } + $result += $this->readListEnd(); + return $result; + } + default: + return 0; + } + } + + /** + * Utility for skipping binary data + * + * @param TTransport $itrans TTransport object + * @param int $type Field type + */ + public static function skipBinary($itrans, $type) { + switch ($type) { + case TType::BOOL: + return $itrans->readAll(1); + case TType::BYTE: + return $itrans->readAll(1); + case TType::I16: + return $itrans->readAll(2); + case TType::I32: + return $itrans->readAll(4); + case TType::I64: + return $itrans->readAll(8); + case TType::DOUBLE: + return $itrans->readAll(8); + case TType::STRING: + $len = unpack('N', $itrans->readAll(4)); + $len = $len[1]; + if ($len > 0x7fffffff) { + $len = 0 - (($len - 1) ^ 0xffffffff); + } + return 4 + $itrans->readAll($len); + case TType::STRUCT: + { + $result = 0; + while (true) { + $ftype = 0; + $fid = 0; + $data = $itrans->readAll(1); + $arr = unpack('c', $data); + $ftype = $arr[1]; + if ($ftype == TType::STOP) { + break; + } + // I16 field id + $result += $itrans->readAll(2); + $result += self::skipBinary($itrans, $ftype); + } + return $result; + } + case TType::MAP: + { + // Ktype + $data = $itrans->readAll(1); + $arr = unpack('c', $data); + $ktype = $arr[1]; + // Vtype + $data = $itrans->readAll(1); + $arr = unpack('c', $data); + $vtype = $arr[1]; + // Size + $data = $itrans->readAll(4); + $arr = unpack('N', $data); + $size = $arr[1]; + if ($size > 0x7fffffff) { + $size = 0 - (($size - 1) ^ 0xffffffff); + } + $result = 6; + for ($i = 0; $i < $size; $i++) { + $result += self::skipBinary($itrans, $ktype); + $result += self::skipBinary($itrans, $vtype); + } + return $result; + } + case TType::SET: + case TType::LST: + { + // Vtype + $data = $itrans->readAll(1); + $arr = unpack('c', $data); + $vtype = $arr[1]; + // Size + $data = $itrans->readAll(4); + $arr = unpack('N', $data); + $size = $arr[1]; + if ($size > 0x7fffffff) { + $size = 0 - (($size - 1) ^ 0xffffffff); + } + $result = 5; + for ($i = 0; $i < $size; $i++) { + $result += self::skipBinary($itrans, $vtype); + } + return $result; + } + default: + return 0; + } + } +} + +/** + * Protocol factory creates protocol objects from transports + */ +interface TProtocolFactory { + /** + * Build a protocol from the base transport + * + * @return TProtcol protocol + */ + public function getProtocol($trans); +} + + +?> diff --git a/lib/php/src/transport/TBufferedTransport.php b/lib/php/src/transport/TBufferedTransport.php new file mode 100644 index 00000000..cfae767e --- /dev/null +++ b/lib/php/src/transport/TBufferedTransport.php @@ -0,0 +1,163 @@ +transport_ = $transport; + $this->rBufSize_ = $rBufSize; + $this->wBufSize_ = $wBufSize; + } + + /** + * The underlying transport + * + * @var TTransport + */ + protected $transport_ = null; + + /** + * The receive buffer size + * + * @var int + */ + protected $rBufSize_ = 512; + + /** + * The write buffer size + * + * @var int + */ + protected $wBufSize_ = 512; + + /** + * The write buffer. + * + * @var string + */ + protected $wBuf_ = ''; + + /** + * The read buffer. + * + * @var string + */ + protected $rBuf_ = ''; + + public function isOpen() { + return $this->transport_->isOpen(); + } + + public function open() { + $this->transport_->open(); + } + + public function close() { + $this->transport_->close(); + } + + public function putBack($data) { + if (strlen($this->rBuf_) === 0) { + $this->rBuf_ = $data; + } else { + $this->rBuf_ = ($data . $this->rBuf_); + } + } + + /** + * The reason that we customize readAll here is that the majority of PHP + * streams are already internally buffered by PHP. The socket stream, for + * example, buffers internally and blocks if you call read with $len greater + * than the amount of data available, unlike recv() in C. + * + * Therefore, use the readAll method of the wrapped transport inside + * the buffered readAll. + */ + public function readAll($len) { + $have = strlen($this->rBuf_); + if ($have == 0) { + $data = $this->transport_->readAll($len); + } else if ($have < $len) { + $data = $this->rBuf_; + $this->rBuf_ = ''; + $data .= $this->transport_->readAll($len - $have); + } else if ($have == $len) { + $data = $this->rBuf_; + $this->rBuf_ = ''; + } else if ($have > $len) { + $data = substr($this->rBuf_, 0, $len); + $this->rBuf_ = substr($this->rBuf_, $len); + } + return $data; + } + + public function read($len) { + if (strlen($this->rBuf_) === 0) { + $this->rBuf_ = $this->transport_->read($this->rBufSize_); + } + + if (strlen($this->rBuf_) <= $len) { + $ret = $this->rBuf_; + $this->rBuf_ = ''; + return $ret; + } + + $ret = substr($this->rBuf_, 0, $len); + $this->rBuf_ = substr($this->rBuf_, $len); + return $ret; + } + + public function write($buf) { + $this->wBuf_ .= $buf; + if (strlen($this->wBuf_) >= $this->wBufSize_) { + $out = $this->wBuf_; + + // Note that we clear the internal wBuf_ prior to the underlying write + // to ensure we're in a sane state (i.e. internal buffer cleaned) + // if the underlying write throws up an exception + $this->wBuf_ = ''; + $this->transport_->write($out); + } + } + + public function flush() { + if (strlen($this->wBuf_) > 0) { + $this->transport_->write($this->wBuf_); + $this->wBuf_ = ''; + } + $this->transport_->flush(); + } + +} + +?> diff --git a/lib/php/src/transport/TFramedTransport.php b/lib/php/src/transport/TFramedTransport.php new file mode 100644 index 00000000..dc57392f --- /dev/null +++ b/lib/php/src/transport/TFramedTransport.php @@ -0,0 +1,179 @@ +transport_ = $transport; + $this->read_ = $read; + $this->write_ = $write; + } + + public function isOpen() { + return $this->transport_->isOpen(); + } + + public function open() { + $this->transport_->open(); + } + + public function close() { + $this->transport_->close(); + } + + /** + * Reads from the buffer. When more data is required reads another entire + * chunk and serves future reads out of that. + * + * @param int $len How much data + */ + public function read($len) { + if (!$this->read_) { + return $this->transport_->read($len); + } + + if (strlen($this->rBuf_) === 0) { + $this->readFrame(); + } + + // Just return full buff + if ($len >= strlen($this->rBuf_)) { + $out = $this->rBuf_; + $this->rBuf_ = null; + return $out; + } + + // Return substr + $out = substr($this->rBuf_, 0, $len); + $this->rBuf_ = substr($this->rBuf_, $len); + return $out; + } + + /** + * Put previously read data back into the buffer + * + * @param string $data data to return + */ + public function putBack($data) { + if (strlen($this->rBuf_) === 0) { + $this->rBuf_ = $data; + } else { + $this->rBuf_ = ($data . $this->rBuf_); + } + } + + /** + * Reads a chunk of data into the internal read buffer. + */ + private function readFrame() { + $buf = $this->transport_->readAll(4); + $val = unpack('N', $buf); + $sz = $val[1]; + + $this->rBuf_ = $this->transport_->readAll($sz); + } + + /** + * Writes some data to the pending output buffer. + * + * @param string $buf The data + * @param int $len Limit of bytes to write + */ + public function write($buf, $len=null) { + if (!$this->write_) { + return $this->transport_->write($buf, $len); + } + + if ($len !== null && $len < strlen($buf)) { + $buf = substr($buf, 0, $len); + } + $this->wBuf_ .= $buf; + } + + /** + * Writes the output buffer to the stream in the format of a 4-byte length + * followed by the actual data. + */ + public function flush() { + if (!$this->write_) { + return $this->transport_->flush(); + } + + $out = pack('N', strlen($this->wBuf_)); + $out .= $this->wBuf_; + + // Note that we clear the internal wBuf_ prior to the underlying write + // to ensure we're in a sane state (i.e. internal buffer cleaned) + // if the underlying write throws up an exception + $this->wBuf_ = ''; + $this->transport_->write($out); + $this->transport_->flush(); + } + +} diff --git a/lib/php/src/transport/THttpClient.php b/lib/php/src/transport/THttpClient.php new file mode 100644 index 00000000..224d403b --- /dev/null +++ b/lib/php/src/transport/THttpClient.php @@ -0,0 +1,202 @@ + 0) && ($uri{0} != '/')) { + $uri = '/'.$uri; + } + $this->scheme_ = $scheme; + $this->host_ = $host; + $this->port_ = $port; + $this->uri_ = $uri; + $this->buf_ = ''; + $this->handle_ = null; + $this->timeout_ = null; + } + + /** + * Set read timeout + * + * @param float $timeout + */ + public function setTimeoutSecs($timeout) { + $this->timeout_ = $timeout; + } + + /** + * Whether this transport is open. + * + * @return boolean true if open + */ + public function isOpen() { + return true; + } + + /** + * Open the transport for reading/writing + * + * @throws TTransportException if cannot open + */ + public function open() {} + + /** + * Close the transport. + */ + public function close() { + if ($this->handle_) { + @fclose($this->handle_); + $this->handle_ = null; + } + } + + /** + * Read some data into the array. + * + * @param int $len How much to read + * @return string The data that has been read + * @throws TTransportException if cannot read any more data + */ + public function read($len) { + $data = @fread($this->handle_, $len); + if ($data === FALSE || $data === '') { + $md = stream_get_meta_data($this->handle_); + if ($md['timed_out']) { + throw new TTransportException('THttpClient: timed out reading '.$len.' bytes from '.$this->host_.':'.$this->port_.'/'.$this->uri_, TTransportException::TIMED_OUT); + } else { + throw new TTransportException('THttpClient: Could not read '.$len.' bytes from '.$this->host_.':'.$this->port_.'/'.$this->uri_, TTransportException::UNKNOWN); + } + } + return $data; + } + + /** + * Writes some data into the pending buffer + * + * @param string $buf The data to write + * @throws TTransportException if writing fails + */ + public function write($buf) { + $this->buf_ .= $buf; + } + + /** + * Opens and sends the actual request over the HTTP connection + * + * @throws TTransportException if a writing error occurs + */ + public function flush() { + // God, PHP really has some esoteric ways of doing simple things. + $host = $this->host_.($this->port_ != 80 ? ':'.$this->port_ : ''); + + $headers = array('Host: '.$host, + 'Accept: application/x-thrift', + 'User-Agent: PHP/THttpClient', + 'Content-Type: application/x-thrift', + 'Content-Length: '.strlen($this->buf_)); + + $options = array('method' => 'POST', + 'header' => implode("\r\n", $headers), + 'max_redirects' => 1, + 'content' => $this->buf_); + if ($this->timeout_ > 0) { + $options['timeout'] = $this->timeout_; + } + $this->buf_ = ''; + + $contextid = stream_context_create(array('http' => $options)); + $this->handle_ = @fopen($this->scheme_.'://'.$host.$this->uri_, 'r', false, $contextid); + + // Connect failed? + if ($this->handle_ === FALSE) { + $this->handle_ = null; + $error = 'THttpClient: Could not connect to '.$host.$this->uri_; + throw new TTransportException($error, TTransportException::NOT_OPEN); + } + } + +} + +?> diff --git a/lib/php/src/transport/TMemoryBuffer.php b/lib/php/src/transport/TMemoryBuffer.php new file mode 100644 index 00000000..01eb0f5a --- /dev/null +++ b/lib/php/src/transport/TMemoryBuffer.php @@ -0,0 +1,84 @@ +buf_ = $buf; + } + + protected $buf_ = ''; + + public function isOpen() { + return true; + } + + public function open() {} + + public function close() {} + + public function write($buf) { + $this->buf_ .= $buf; + } + + public function read($len) { + if (strlen($this->buf_) === 0) { + throw new TTransportException('TMemoryBuffer: Could not read ' . + $len . ' bytes from buffer.', + TTransportException::UNKNOWN); + } + + if (strlen($this->buf_) <= $len) { + $ret = $this->buf_; + $this->buf_ = ''; + return $ret; + } + + $ret = substr($this->buf_, 0, $len); + $this->buf_ = substr($this->buf_, $len); + + return $ret; + } + + function getBuffer() { + return $this->buf_; + } + + public function available() { + return strlen($this->buf_); + } +} + +?> diff --git a/lib/php/src/transport/TNullTransport.php b/lib/php/src/transport/TNullTransport.php new file mode 100644 index 00000000..bada5dfb --- /dev/null +++ b/lib/php/src/transport/TNullTransport.php @@ -0,0 +1,48 @@ + diff --git a/lib/php/src/transport/TPhpStream.php b/lib/php/src/transport/TPhpStream.php new file mode 100644 index 00000000..3a1c80b8 --- /dev/null +++ b/lib/php/src/transport/TPhpStream.php @@ -0,0 +1,111 @@ +read_ = $mode & self::MODE_R; + $this->write_ = $mode & self::MODE_W; + } + + public function open() { + if ($this->read_) { + $this->inStream_ = @fopen(self::inStreamName(), 'r'); + if (!is_resource($this->inStream_)) { + throw new TException('TPhpStream: Could not open php://input'); + } + } + if ($this->write_) { + $this->outStream_ = @fopen('php://output', 'w'); + if (!is_resource($this->outStream_)) { + throw new TException('TPhpStream: Could not open php://output'); + } + } + } + + public function close() { + if ($this->read_) { + @fclose($this->inStream_); + $this->inStream_ = null; + } + if ($this->write_) { + @fclose($this->outStream_); + $this->outStream_ = null; + } + } + + public function isOpen() { + return + (!$this->read_ || is_resource($this->inStream_)) && + (!$this->write_ || is_resource($this->outStream_)); + } + + public function read($len) { + $data = @fread($this->inStream_, $len); + if ($data === FALSE || $data === '') { + throw new TException('TPhpStream: Could not read '.$len.' bytes'); + } + return $data; + } + + public function write($buf) { + while (strlen($buf) > 0) { + $got = @fwrite($this->outStream_, $buf); + if ($got === 0 || $got === FALSE) { + throw new TException('TPhpStream: Could not write '.strlen($buf).' bytes'); + } + $buf = substr($buf, $got); + } + } + + public function flush() { + @fflush($this->outStream_); + } + + private static function inStreamName() { + if (php_sapi_name() == 'cli') { + return 'php://stdin'; + } + return 'php://input'; + } + +} + +?> diff --git a/lib/php/src/transport/TSocket.php b/lib/php/src/transport/TSocket.php new file mode 100644 index 00000000..ba3a6318 --- /dev/null +++ b/lib/php/src/transport/TSocket.php @@ -0,0 +1,312 @@ +host_ = $host; + $this->port_ = $port; + $this->persist_ = $persist; + $this->debugHandler_ = $debugHandler ? $debugHandler : 'error_log'; + } + + /** + * Sets the send timeout. + * + * @param int $timeout Timeout in milliseconds. + */ + public function setSendTimeout($timeout) { + $this->sendTimeout_ = $timeout; + } + + /** + * Sets the receive timeout. + * + * @param int $timeout Timeout in milliseconds. + */ + public function setRecvTimeout($timeout) { + $this->recvTimeout_ = $timeout; + } + + /** + * Sets debugging output on or off + * + * @param bool $debug + */ + public function setDebug($debug) { + $this->debug_ = $debug; + } + + /** + * Get the host that this socket is connected to + * + * @return string host + */ + public function getHost() { + return $this->host_; + } + + /** + * Get the remote port that this socket is connected to + * + * @return int port + */ + public function getPort() { + return $this->port_; + } + + /** + * Tests whether this is open + * + * @return bool true if the socket is open + */ + public function isOpen() { + return is_resource($this->handle_); + } + + /** + * Connects the socket. + */ + public function open() { + + if ($this->persist_) { + $this->handle_ = @pfsockopen($this->host_, + $this->port_, + $errno, + $errstr, + $this->sendTimeout_/1000.0); + } else { + $this->handle_ = @fsockopen($this->host_, + $this->port_, + $errno, + $errstr, + $this->sendTimeout_/1000.0); + } + + // Connect failed? + if ($this->handle_ === FALSE) { + $error = 'TSocket: Could not connect to '.$this->host_.':'.$this->port_.' ('.$errstr.' ['.$errno.'])'; + if ($this->debug_) { + call_user_func($this->debugHandler_, $error); + } + throw new TException($error); + } + + stream_set_timeout($this->handle_, 0, $this->sendTimeout_*1000); + $this->sendTimeoutSet_ = TRUE; + } + + /** + * Closes the socket. + */ + public function close() { + if (!$this->persist_) { + @fclose($this->handle_); + $this->handle_ = null; + } + } + + /** + * Uses stream get contents to do the reading + * + * @param int $len How many bytes + * @return string Binary data + */ + public function readAll($len) { + if ($this->sendTimeoutSet_) { + stream_set_timeout($this->handle_, 0, $this->recvTimeout_*1000); + $this->sendTimeoutSet_ = FALSE; + } + // This call does not obey stream_set_timeout values! + // $buf = @stream_get_contents($this->handle_, $len); + + $pre = null; + while (TRUE) { + $buf = @fread($this->handle_, $len); + if ($buf === FALSE || $buf === '') { + $md = stream_get_meta_data($this->handle_); + if ($md['timed_out']) { + throw new TException('TSocket: timed out reading '.$len.' bytes from '. + $this->host_.':'.$this->port_); + } else { + throw new TException('TSocket: Could not read '.$len.' bytes from '. + $this->host_.':'.$this->port_); + } + } else if (($sz = strlen($buf)) < $len) { + $md = stream_get_meta_data($this->handle_); + if ($md['timed_out']) { + throw new TException('TSocket: timed out reading '.$len.' bytes from '. + $this->host_.':'.$this->port_); + } else { + $pre .= $buf; + $len -= $sz; + } + } else { + return $pre.$buf; + } + } + } + + /** + * Read from the socket + * + * @param int $len How many bytes + * @return string Binary data + */ + public function read($len) { + if ($this->sendTimeoutSet_) { + stream_set_timeout($this->handle_, 0, $this->recvTimeout_*1000); + $this->sendTimeoutSet_ = FALSE; + } + $data = @fread($this->handle_, $len); + if ($data === FALSE || $data === '') { + $md = stream_get_meta_data($this->handle_); + if ($md['timed_out']) { + throw new TException('TSocket: timed out reading '.$len.' bytes from '. + $this->host_.':'.$this->port_); + } else { + throw new TException('TSocket: Could not read '.$len.' bytes from '. + $this->host_.':'.$this->port_); + } + } + return $data; + } + + /** + * Write to the socket. + * + * @param string $buf The data to write + */ + public function write($buf) { + if (!$this->sendTimeoutSet_) { + stream_set_timeout($this->handle_, 0, $this->sendTimeout_*1000); + $this->sendTimeoutSet_ = TRUE; + } + while (strlen($buf) > 0) { + $got = @fwrite($this->handle_, $buf); + if ($got === 0 || $got === FALSE) { + $md = stream_get_meta_data($this->handle_); + if ($md['timed_out']) { + throw new TException('TSocket: timed out writing '.strlen($buf).' bytes from '. + $this->host_.':'.$this->port_); + } else { + throw new TException('TSocket: Could not write '.strlen($buf).' bytes '. + $this->host_.':'.$this->port_); + } + } + $buf = substr($buf, $got); + } + } + + /** + * Flush output to the socket. + */ + public function flush() { + $ret = fflush($this->handle_); + if ($ret === FALSE) { + throw new TException('TSocket: Could not flush: '. + $this->host_.':'.$this->port_); + } + } +} + +?> diff --git a/lib/php/src/transport/TSocketPool.php b/lib/php/src/transport/TSocketPool.php new file mode 100644 index 00000000..7f1157cb --- /dev/null +++ b/lib/php/src/transport/TSocketPool.php @@ -0,0 +1,296 @@ + $val) { + $ports[$key] = $port; + } + } + + foreach ($hosts as $key => $host) { + $this->servers_ []= array('host' => $host, + 'port' => $ports[$key]); + } + } + + /** + * Add a server to the pool + * + * This function does not prevent you from adding a duplicate server entry. + * + * @param string $host hostname or IP + * @param int $port port + */ + public function addServer($host, $port) { + $this->servers_[] = array('host' => $host, 'port' => $port); + } + + /** + * Sets how many time to keep retrying a host in the connect function. + * + * @param int $numRetries + */ + public function setNumRetries($numRetries) { + $this->numRetries_ = $numRetries; + } + + /** + * Sets how long to wait until retrying a host if it was marked down + * + * @param int $numRetries + */ + public function setRetryInterval($retryInterval) { + $this->retryInterval_ = $retryInterval; + } + + /** + * Sets how many time to keep retrying a host before marking it as down. + * + * @param int $numRetries + */ + public function setMaxConsecutiveFailures($maxConsecutiveFailures) { + $this->maxConsecutiveFailures_ = $maxConsecutiveFailures; + } + + /** + * Turns randomization in connect order on or off. + * + * @param bool $randomize + */ + public function setRandomize($randomize) { + $this->randomize_ = $randomize; + } + + /** + * Whether to always try the last server. + * + * @param bool $alwaysTryLast + */ + public function setAlwaysTryLast($alwaysTryLast) { + $this->alwaysTryLast_ = $alwaysTryLast; + } + + + /** + * Connects the socket by iterating through all the servers in the pool + * and trying to find one that works. + */ + public function open() { + // Check if we want order randomization + if ($this->randomize_) { + shuffle($this->servers_); + } + + // Count servers to identify the "last" one + $numServers = count($this->servers_); + + for ($i = 0; $i < $numServers; ++$i) { + + // This extracts the $host and $port variables + extract($this->servers_[$i]); + + // Check APC cache for a record of this server being down + $failtimeKey = 'thrift_failtime:'.$host.':'.$port.'~'; + + // Cache miss? Assume it's OK + $lastFailtime = apc_fetch($failtimeKey); + if ($lastFailtime === FALSE) { + $lastFailtime = 0; + } + + $retryIntervalPassed = FALSE; + + // Cache hit...make sure enough the retry interval has elapsed + if ($lastFailtime > 0) { + $elapsed = time() - $lastFailtime; + if ($elapsed > $this->retryInterval_) { + $retryIntervalPassed = TRUE; + if ($this->debug_) { + call_user_func($this->debugHandler_, + 'TSocketPool: retryInterval '. + '('.$this->retryInterval_.') '. + 'has passed for host '.$host.':'.$port); + } + } + } + + // Only connect if not in the middle of a fail interval, OR if this + // is the LAST server we are trying, just hammer away on it + $isLastServer = FALSE; + if ($this->alwaysTryLast_) { + $isLastServer = ($i == ($numServers - 1)); + } + + if (($lastFailtime === 0) || + ($isLastServer) || + ($lastFailtime > 0 && $retryIntervalPassed)) { + + // Set underlying TSocket params to this one + $this->host_ = $host; + $this->port_ = $port; + + // Try up to numRetries_ connections per server + for ($attempt = 0; $attempt < $this->numRetries_; $attempt++) { + try { + // Use the underlying TSocket open function + parent::open(); + + // Only clear the failure counts if required to do so + if ($lastFailtime > 0) { + apc_store($failtimeKey, 0); + } + + // Successful connection, return now + return; + + } catch (TException $tx) { + // Connection failed + } + } + + // Mark failure of this host in the cache + $consecfailsKey = 'thrift_consecfails:'.$host.':'.$port.'~'; + + // Ignore cache misses + $consecfails = apc_fetch($consecfailsKey); + if ($consecfails === FALSE) { + $consecfails = 0; + } + + // Increment by one + $consecfails++; + + // Log and cache this failure + if ($consecfails >= $this->maxConsecutiveFailures_) { + if ($this->debug_) { + call_user_func($this->debugHandler_, + 'TSocketPool: marking '.$host.':'.$port. + ' as down for '.$this->retryInterval_.' secs '. + 'after '.$consecfails.' failed attempts.'); + } + // Store the failure time + apc_store($failtimeKey, time()); + + // Clear the count of consecutive failures + apc_store($consecfailsKey, 0); + } else { + apc_store($consecfailsKey, $consecfails); + } + } + } + + // Holy shit we failed them all. The system is totally ill! + $error = 'TSocketPool: All hosts in pool are down. '; + $hosts = array(); + foreach ($this->servers_ as $server) { + $hosts []= $server['host'].':'.$server['port']; + } + $hostlist = implode(',', $hosts); + $error .= '('.$hostlist.')'; + if ($this->debug_) { + call_user_func($this->debugHandler_, $error); + } + throw new TException($error); + } +} + +?> diff --git a/lib/php/src/transport/TTransport.php b/lib/php/src/transport/TTransport.php new file mode 100644 index 00000000..e2445259 --- /dev/null +++ b/lib/php/src/transport/TTransport.php @@ -0,0 +1,108 @@ +read($len); + + $data = ''; + $got = 0; + while (($got = strlen($data)) < $len) { + $data .= $this->read($len - $got); + } + return $data; + } + + /** + * Writes the given data out. + * + * @param string $buf The data to write + * @throws TTransportException if writing fails + */ + public abstract function write($buf); + + /** + * Flushes any pending data out of a buffer + * + * @throws TTransportException if a writing error occurs + */ + public function flush() {} +} + +?> diff --git a/lib/py/Makefile.am b/lib/py/Makefile.am new file mode 100644 index 00000000..6d3bbeaa --- /dev/null +++ b/lib/py/Makefile.am @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +DESTDIR ?= / +EXTRA_DIST = setup.py src + +all-local: + $(PYTHON) setup.py build + +# We're ignoring prefix here because site-packages seems to be +# the equivalent of /usr/local/lib in Python land. +# Old version (can't put inline because it's not portable). +#$(PYTHON) setup.py install --prefix=$(prefix) --root=$(DESTDIR) $(PYTHON_SETUPUTIL_ARGS) +install-exec-hook: + $(PYTHON) setup.py install --root=$(DESTDIR) --prefix=$(PY_PREFIX) $(PYTHON_SETUPUTIL_ARGS) + +clean-local: + $(RM) -r build + +check-local: all diff --git a/lib/py/README b/lib/py/README new file mode 100644 index 00000000..29b8c73c --- /dev/null +++ b/lib/py/README @@ -0,0 +1,35 @@ +Thrift Python Software Library + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Using Thrift with Python +======================== + +Thrift is provided as a set of Python packages. The top level package is +thrift, and there are subpackages for the protocol, transport, and server +code. Each package contains modules using standard Thrift naming conventions +(i.e. TProtocol, TTransport) and implementations in corresponding modules +(i.e. TSocket). There is also a subpackage reflection, which contains +the generated code for the reflection structures. + +The Python libraries can be installed manually using the provided setup.py +file, or automatically using the install hook provided via autoconf/automake. +To use the latter, become superuser and do make install. diff --git a/lib/py/setup.py b/lib/py/setup.py new file mode 100644 index 00000000..74837115 --- /dev/null +++ b/lib/py/setup.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from distutils.core import setup, Extension + +fastbinarymod = Extension('thrift.protocol.fastbinary', + sources = ['src/protocol/fastbinary.c'], + ) + +setup(name = 'Thrift', + version = '1.0', + description = 'Thrift Python Libraries', + author = ['Mark Slee'], + author_email = ['mcslee@facebook.com'], + url = 'http://code.facebook.com/thrift', + packages = [ + 'thrift', + 'thrift.protocol', + 'thrift.transport', + 'thrift.server', + ], + package_dir = {'thrift' : 'src'}, + ext_modules = [fastbinarymod], + ) + diff --git a/lib/py/src/TSCons.py b/lib/py/src/TSCons.py new file mode 100644 index 00000000..24046256 --- /dev/null +++ b/lib/py/src/TSCons.py @@ -0,0 +1,33 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from os import path +from SCons.Builder import Builder + +def scons_env(env, add=''): + opath = path.dirname(path.abspath('$TARGET')) + lstr = 'thrift --gen cpp -o ' + opath + ' ' + add + ' $SOURCE' + cppbuild = Builder(action = lstr) + env.Append(BUILDERS = {'ThriftCpp' : cppbuild}) + +def gen_cpp(env, dir, file): + scons_env(env) + suffixes = ['_types.h', '_types.cpp'] + targets = map(lambda s: 'gen-cpp/' + file + s, suffixes) + return env.ThriftCpp(targets, dir+file+'.thrift') diff --git a/lib/py/src/Thrift.py b/lib/py/src/Thrift.py new file mode 100644 index 00000000..21d7aa4e --- /dev/null +++ b/lib/py/src/Thrift.py @@ -0,0 +1,123 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +class TType: + STOP = 0 + VOID = 1 + BOOL = 2 + BYTE = 3 + I08 = 3 + DOUBLE = 4 + I16 = 6 + I32 = 8 + I64 = 10 + STRING = 11 + UTF7 = 11 + STRUCT = 12 + MAP = 13 + SET = 14 + LIST = 15 + UTF8 = 16 + UTF16 = 17 + +class TMessageType: + CALL = 1 + REPLY = 2 + EXCEPTION = 3 + ONEWAY = 4 + +class TProcessor: + + """Base class for procsessor, which works on two streams.""" + + def process(iprot, oprot): + pass + +class TException(Exception): + + """Base class for all thrift exceptions.""" + + def __init__(self, message=None): + Exception.__init__(self, message) + self.message = message + +class TApplicationException(TException): + + """Application level thrift exceptions.""" + + UNKNOWN = 0 + UNKNOWN_METHOD = 1 + INVALID_MESSAGE_TYPE = 2 + WRONG_METHOD_NAME = 3 + BAD_SEQUENCE_ID = 4 + MISSING_RESULT = 5 + + def __init__(self, type=UNKNOWN, message=None): + TException.__init__(self, message) + self.type = type + + def __str__(self): + if self.message: + return self.message + elif self.type == UNKNOWN_METHOD: + return 'Unknown method' + elif self.type == INVALID_MESSAGE_TYPE: + return 'Invalid message type' + elif self.type == WRONG_METHOD_NAME: + return 'Wrong method name' + elif self.type == BAD_SEQUENCE_ID: + return 'Bad sequence ID' + elif self.type == MISSING_RESULT: + return 'Missing result' + else: + return 'Default (unknown) TApplicationException' + + def read(self, iprot): + iprot.readStructBegin() + while True: + (fname, ftype, fid) = iprot.readFieldBegin() + if ftype == TType.STOP: + break + if fid == 1: + if ftype == TType.STRING: + self.message = iprot.readString(); + else: + iprot.skip(ftype) + elif fid == 2: + if ftype == TType.I32: + self.type = iprot.readI32(); + else: + iprot.skip(ftype) + else: + iprot.skip(ftype) + iprot.readFieldEnd() + iprot.readStructEnd() + + def write(self, oprot): + oprot.writeStructBegin('TApplicationException') + if self.message != None: + oprot.writeFieldBegin('message', TType.STRING, 1) + oprot.writeString(self.message) + oprot.writeFieldEnd() + if self.type != None: + oprot.writeFieldBegin('type', TType.I32, 2) + oprot.writeI32(self.type) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() diff --git a/lib/py/src/__init__.py b/lib/py/src/__init__.py new file mode 100644 index 00000000..48d659c4 --- /dev/null +++ b/lib/py/src/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['Thrift', 'TSCons'] diff --git a/lib/py/src/protocol/TBinaryProtocol.py b/lib/py/src/protocol/TBinaryProtocol.py new file mode 100644 index 00000000..db1a7a40 --- /dev/null +++ b/lib/py/src/protocol/TBinaryProtocol.py @@ -0,0 +1,259 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from TProtocol import * +from struct import pack, unpack + +class TBinaryProtocol(TProtocolBase): + + """Binary implementation of the Thrift protocol driver.""" + + # NastyHaxx. Python 2.4+ on 32-bit machines forces hex constants to be + # positive, converting this into a long. If we hardcode the int value + # instead it'll stay in 32 bit-land. + + # VERSION_MASK = 0xffff0000 + VERSION_MASK = -65536 + + # VERSION_1 = 0x80010000 + VERSION_1 = -2147418112 + + TYPE_MASK = 0x000000ff + + def __init__(self, trans, strictRead=False, strictWrite=True): + TProtocolBase.__init__(self, trans) + self.strictRead = strictRead + self.strictWrite = strictWrite + + def writeMessageBegin(self, name, type, seqid): + if self.strictWrite: + self.writeI32(TBinaryProtocol.VERSION_1 | type) + self.writeString(name) + self.writeI32(seqid) + else: + self.writeString(name) + self.writeByte(type) + self.writeI32(seqid) + + def writeMessageEnd(self): + pass + + def writeStructBegin(self, name): + pass + + def writeStructEnd(self): + pass + + def writeFieldBegin(self, name, type, id): + self.writeByte(type) + self.writeI16(id) + + def writeFieldEnd(self): + pass + + def writeFieldStop(self): + self.writeByte(TType.STOP); + + def writeMapBegin(self, ktype, vtype, size): + self.writeByte(ktype) + self.writeByte(vtype) + self.writeI32(size) + + def writeMapEnd(self): + pass + + def writeListBegin(self, etype, size): + self.writeByte(etype) + self.writeI32(size) + + def writeListEnd(self): + pass + + def writeSetBegin(self, etype, size): + self.writeByte(etype) + self.writeI32(size) + + def writeSetEnd(self): + pass + + def writeBool(self, bool): + if bool: + self.writeByte(1) + else: + self.writeByte(0) + + def writeByte(self, byte): + buff = pack("!b", byte) + self.trans.write(buff) + + def writeI16(self, i16): + buff = pack("!h", i16) + self.trans.write(buff) + + def writeI32(self, i32): + buff = pack("!i", i32) + self.trans.write(buff) + + def writeI64(self, i64): + buff = pack("!q", i64) + self.trans.write(buff) + + def writeDouble(self, dub): + buff = pack("!d", dub) + self.trans.write(buff) + + def writeString(self, str): + self.writeI32(len(str)) + self.trans.write(str) + + def readMessageBegin(self): + sz = self.readI32() + if sz < 0: + version = sz & TBinaryProtocol.VERSION_MASK + if version != TBinaryProtocol.VERSION_1: + raise TProtocolException(TProtocolException.BAD_VERSION, 'Bad version in readMessageBegin: %d' % (sz)) + type = sz & TBinaryProtocol.TYPE_MASK + name = self.readString() + seqid = self.readI32() + else: + if self.strictRead: + raise TProtocolException(TProtocolException.BAD_VERSION, 'No protocol version header') + name = self.trans.readAll(sz) + type = self.readByte() + seqid = self.readI32() + return (name, type, seqid) + + def readMessageEnd(self): + pass + + def readStructBegin(self): + pass + + def readStructEnd(self): + pass + + def readFieldBegin(self): + type = self.readByte() + if type == TType.STOP: + return (None, type, 0) + id = self.readI16() + return (None, type, id) + + def readFieldEnd(self): + pass + + def readMapBegin(self): + ktype = self.readByte() + vtype = self.readByte() + size = self.readI32() + return (ktype, vtype, size) + + def readMapEnd(self): + pass + + def readListBegin(self): + etype = self.readByte() + size = self.readI32() + return (etype, size) + + def readListEnd(self): + pass + + def readSetBegin(self): + etype = self.readByte() + size = self.readI32() + return (etype, size) + + def readSetEnd(self): + pass + + def readBool(self): + byte = self.readByte() + if byte == 0: + return False + return True + + def readByte(self): + buff = self.trans.readAll(1) + val, = unpack('!b', buff) + return val + + def readI16(self): + buff = self.trans.readAll(2) + val, = unpack('!h', buff) + return val + + def readI32(self): + buff = self.trans.readAll(4) + val, = unpack('!i', buff) + return val + + def readI64(self): + buff = self.trans.readAll(8) + val, = unpack('!q', buff) + return val + + def readDouble(self): + buff = self.trans.readAll(8) + val, = unpack('!d', buff) + return val + + def readString(self): + len = self.readI32() + str = self.trans.readAll(len) + return str + + +class TBinaryProtocolFactory: + def __init__(self, strictRead=False, strictWrite=True): + self.strictRead = strictRead + self.strictWrite = strictWrite + + def getProtocol(self, trans): + prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite) + return prot + + +class TBinaryProtocolAccelerated(TBinaryProtocol): + + """C-Accelerated version of TBinaryProtocol. + + This class does not override any of TBinaryProtocol's methods, + but the generated code recognizes it directly and will call into + our C module to do the encoding, bypassing this object entirely. + We inherit from TBinaryProtocol so that the normal TBinaryProtocol + encoding can happen if the fastbinary module doesn't work for some + reason. (TODO(dreiss): Make this happen sanely in more cases.) + + In order to take advantage of the C module, just use + TBinaryProtocolAccelerated instead of TBinaryProtocol. + + NOTE: This code was contributed by an external developer. + The internal Thrift team has reviewed and tested it, + but we cannot guarantee that it is production-ready. + Please feel free to report bugs and/or success stories + to the public mailing list. + """ + + pass + + +class TBinaryProtocolAcceleratedFactory: + def getProtocol(self, trans): + return TBinaryProtocolAccelerated(trans) diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py new file mode 100644 index 00000000..be3cb140 --- /dev/null +++ b/lib/py/src/protocol/TProtocol.py @@ -0,0 +1,205 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from thrift.Thrift import * + +class TProtocolException(TException): + + """Custom Protocol Exception class""" + + UNKNOWN = 0 + INVALID_DATA = 1 + NEGATIVE_SIZE = 2 + SIZE_LIMIT = 3 + BAD_VERSION = 4 + + def __init__(self, type=UNKNOWN, message=None): + TException.__init__(self, message) + self.type = type + +class TProtocolBase: + + """Base class for Thrift protocol driver.""" + + def __init__(self, trans): + self.trans = trans + + def writeMessageBegin(self, name, type, seqid): + pass + + def writeMessageEnd(self): + pass + + def writeStructBegin(self, name): + pass + + def writeStructEnd(self): + pass + + def writeFieldBegin(self, name, type, id): + pass + + def writeFieldEnd(self): + pass + + def writeFieldStop(self): + pass + + def writeMapBegin(self, ktype, vtype, size): + pass + + def writeMapEnd(self): + pass + + def writeListBegin(self, etype, size): + pass + + def writeListEnd(self): + pass + + def writeSetBegin(self, etype, size): + pass + + def writeSetEnd(self): + pass + + def writeBool(self, bool): + pass + + def writeByte(self, byte): + pass + + def writeI16(self, i16): + pass + + def writeI32(self, i32): + pass + + def writeI64(self, i64): + pass + + def writeDouble(self, dub): + pass + + def writeString(self, str): + pass + + def readMessageBegin(self): + pass + + def readMessageEnd(self): + pass + + def readStructBegin(self): + pass + + def readStructEnd(self): + pass + + def readFieldBegin(self): + pass + + def readFieldEnd(self): + pass + + def readMapBegin(self): + pass + + def readMapEnd(self): + pass + + def readListBegin(self): + pass + + def readListEnd(self): + pass + + def readSetBegin(self): + pass + + def readSetEnd(self): + pass + + def readBool(self): + pass + + def readByte(self): + pass + + def readI16(self): + pass + + def readI32(self): + pass + + def readI64(self): + pass + + def readDouble(self): + pass + + def readString(self): + pass + + def skip(self, type): + if type == TType.STOP: + return + elif type == TType.BOOL: + self.readBool() + elif type == TType.BYTE: + self.readByte() + elif type == TType.I16: + self.readI16() + elif type == TType.I32: + self.readI32() + elif type == TType.I64: + self.readI64() + elif type == TType.DOUBLE: + self.readDouble() + elif type == TType.STRING: + self.readString() + elif type == TType.STRUCT: + name = self.readStructBegin() + while True: + (name, type, id) = self.readFieldBegin() + if type == TType.STOP: + break + self.skip(type) + self.readFieldEnd() + self.readStructEnd() + elif type == TType.MAP: + (ktype, vtype, size) = self.readMapBegin() + for i in range(size): + self.skip(ktype) + self.skip(vtype) + self.readMapEnd() + elif type == TType.SET: + (etype, size) = self.readSetBegin() + for i in range(size): + self.skip(etype) + self.readSetEnd() + elif type == TType.LIST: + (etype, size) = self.readListBegin() + for i in range(size): + self.skip(etype) + self.readListEnd() + +class TProtocolFactory: + def getProtocol(self, trans): + pass diff --git a/lib/py/src/protocol/__init__.py b/lib/py/src/protocol/__init__.py new file mode 100644 index 00000000..01bfe18e --- /dev/null +++ b/lib/py/src/protocol/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary'] diff --git a/lib/py/src/protocol/fastbinary.c b/lib/py/src/protocol/fastbinary.c new file mode 100644 index 00000000..67b215a8 --- /dev/null +++ b/lib/py/src/protocol/fastbinary.c @@ -0,0 +1,1203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include "cStringIO.h" +#include +#include +#include + +/* Fix endianness issues on Solaris */ +#if defined (__SVR4) && defined (__sun) + #if defined(__i386) && !defined(__i386__) + #define __i386__ + #endif + + #ifndef BIG_ENDIAN + #define BIG_ENDIAN (4321) + #endif + #ifndef LITTLE_ENDIAN + #define LITTLE_ENDIAN (1234) + #endif + + /* I386 is LE, even on Solaris */ + #if !defined(BYTE_ORDER) && defined(__i386__) + #define BYTE_ORDER LITTLE_ENDIAN + #endif +#endif + +// TODO(dreiss): defval appears to be unused. Look into removing it. +// TODO(dreiss): Make parse_spec_args recursive, and cache the output +// permanently in the object. (Malloc and orphan.) +// TODO(dreiss): Why do we need cStringIO for reading, why not just char*? +// Can cStringIO let us work with a BufferedTransport? +// TODO(dreiss): Don't ignore the rv from cwrite (maybe). + +/* ====== BEGIN UTILITIES ====== */ + +#define INIT_OUTBUF_SIZE 128 + +// Stolen out of TProtocol.h. +// It would be a huge pain to have both get this from one place. +typedef enum TType { + T_STOP = 0, + T_VOID = 1, + T_BOOL = 2, + T_BYTE = 3, + T_I08 = 3, + T_I16 = 6, + T_I32 = 8, + T_U64 = 9, + T_I64 = 10, + T_DOUBLE = 4, + T_STRING = 11, + T_UTF7 = 11, + T_STRUCT = 12, + T_MAP = 13, + T_SET = 14, + T_LIST = 15, + T_UTF8 = 16, + T_UTF16 = 17 +} TType; + +#ifndef __BYTE_ORDER +# if defined(BYTE_ORDER) && defined(LITTLE_ENDIAN) && defined(BIG_ENDIAN) +# define __BYTE_ORDER BYTE_ORDER +# define __LITTLE_ENDIAN LITTLE_ENDIAN +# define __BIG_ENDIAN BIG_ENDIAN +# else +# error "Cannot determine endianness" +# endif +#endif + +// Same comment as the enum. Sorry. +#if __BYTE_ORDER == __BIG_ENDIAN +# define ntohll(n) (n) +# define htonll(n) (n) +#elif __BYTE_ORDER == __LITTLE_ENDIAN +# if defined(__GNUC__) && defined(__GLIBC__) +# include +# define ntohll(n) bswap_64(n) +# define htonll(n) bswap_64(n) +# else /* GNUC & GLIBC */ +# define ntohll(n) ( (((unsigned long long)ntohl(n)) << 32) + ntohl(n >> 32) ) +# define htonll(n) ( (((unsigned long long)htonl(n)) << 32) + htonl(n >> 32) ) +# endif /* GNUC & GLIBC */ +#else /* __BYTE_ORDER */ +# error "Can't define htonll or ntohll!" +#endif + +// Doing a benchmark shows that interning actually makes a difference, amazingly. +#define INTERN_STRING(value) _intern_ ## value + +#define INT_CONV_ERROR_OCCURRED(v) ( ((v) == -1) && PyErr_Occurred() ) +#define CHECK_RANGE(v, min, max) ( ((v) <= (max)) && ((v) >= (min)) ) + +// Py_ssize_t was not defined before Python 2.5 +#if (PY_VERSION_HEX < 0x02050000) +typedef int Py_ssize_t; +#endif + +/** + * A cache of the spec_args for a set or list, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { + TType element_type; + PyObject* typeargs; +} SetListTypeArgs; + +/** + * A cache of the spec_args for a map, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { + TType ktag; + TType vtag; + PyObject* ktypeargs; + PyObject* vtypeargs; +} MapTypeArgs; + +/** + * A cache of the spec_args for a struct, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { + PyObject* klass; + PyObject* spec; +} StructTypeArgs; + +/** + * A cache of the item spec from a struct specification, + * so we don't have to keep calling PyTuple_GET_ITEM. + */ +typedef struct { + int tag; + TType type; + PyObject* attrname; + PyObject* typeargs; + PyObject* defval; +} StructItemSpec; + +/** + * A cache of the two key attributes of a CReadableTransport, + * so we don't have to keep calling PyObject_GetAttr. + */ +typedef struct { + PyObject* stringiobuf; + PyObject* refill_callable; +} DecodeBuffer; + +/** Pointer to interned string to speed up attribute lookup. */ +static PyObject* INTERN_STRING(cstringio_buf); +/** Pointer to interned string to speed up attribute lookup. */ +static PyObject* INTERN_STRING(cstringio_refill); + +static inline bool +check_ssize_t_32(Py_ssize_t len) { + // error from getting the int + if (INT_CONV_ERROR_OCCURRED(len)) { + return false; + } + if (!CHECK_RANGE(len, 0, INT32_MAX)) { + PyErr_SetString(PyExc_OverflowError, "string size out of range"); + return false; + } + return true; +} + +static inline bool +parse_pyint(PyObject* o, int32_t* ret, int32_t min, int32_t max) { + long val = PyInt_AsLong(o); + + if (INT_CONV_ERROR_OCCURRED(val)) { + return false; + } + if (!CHECK_RANGE(val, min, max)) { + PyErr_SetString(PyExc_OverflowError, "int out of range"); + return false; + } + + *ret = (int32_t) val; + return true; +} + + +/* --- FUNCTIONS TO PARSE STRUCT SPECIFICATOINS --- */ + +static bool +parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs) { + if (PyTuple_Size(typeargs) != 2) { + PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for list/set type args"); + return false; + } + + dest->element_type = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)); + if (INT_CONV_ERROR_OCCURRED(dest->element_type)) { + return false; + } + + dest->typeargs = PyTuple_GET_ITEM(typeargs, 1); + + return true; +} + +static bool +parse_map_args(MapTypeArgs* dest, PyObject* typeargs) { + if (PyTuple_Size(typeargs) != 4) { + PyErr_SetString(PyExc_TypeError, "expecting 4 arguments for typeargs to map"); + return false; + } + + dest->ktag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0)); + if (INT_CONV_ERROR_OCCURRED(dest->ktag)) { + return false; + } + + dest->vtag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 2)); + if (INT_CONV_ERROR_OCCURRED(dest->vtag)) { + return false; + } + + dest->ktypeargs = PyTuple_GET_ITEM(typeargs, 1); + dest->vtypeargs = PyTuple_GET_ITEM(typeargs, 3); + + return true; +} + +static bool +parse_struct_args(StructTypeArgs* dest, PyObject* typeargs) { + if (PyTuple_Size(typeargs) != 2) { + PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for struct args"); + return false; + } + + dest->klass = PyTuple_GET_ITEM(typeargs, 0); + dest->spec = PyTuple_GET_ITEM(typeargs, 1); + + return true; +} + +static int +parse_struct_item_spec(StructItemSpec* dest, PyObject* spec_tuple) { + + // i'd like to use ParseArgs here, but it seems to be a bottleneck. + if (PyTuple_Size(spec_tuple) != 5) { + PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for spec tuple"); + return false; + } + + dest->tag = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 0)); + if (INT_CONV_ERROR_OCCURRED(dest->tag)) { + return false; + } + + dest->type = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 1)); + if (INT_CONV_ERROR_OCCURRED(dest->type)) { + return false; + } + + dest->attrname = PyTuple_GET_ITEM(spec_tuple, 2); + dest->typeargs = PyTuple_GET_ITEM(spec_tuple, 3); + dest->defval = PyTuple_GET_ITEM(spec_tuple, 4); + return true; +} + +/* ====== END UTILITIES ====== */ + + +/* ====== BEGIN WRITING FUNCTIONS ====== */ + +/* --- LOW-LEVEL WRITING FUNCTIONS --- */ + +static void writeByte(PyObject* outbuf, int8_t val) { + int8_t net = val; + PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int8_t)); +} + +static void writeI16(PyObject* outbuf, int16_t val) { + int16_t net = (int16_t)htons(val); + PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int16_t)); +} + +static void writeI32(PyObject* outbuf, int32_t val) { + int32_t net = (int32_t)htonl(val); + PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int32_t)); +} + +static void writeI64(PyObject* outbuf, int64_t val) { + int64_t net = (int64_t)htonll(val); + PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int64_t)); +} + +static void writeDouble(PyObject* outbuf, double dub) { + // Unfortunately, bitwise_cast doesn't work in C. Bad C! + union { + double f; + int64_t t; + } transfer; + transfer.f = dub; + writeI64(outbuf, transfer.t); +} + + +/* --- MAIN RECURSIVE OUTPUT FUCNTION -- */ + +static int +output_val(PyObject* output, PyObject* value, TType type, PyObject* typeargs) { + /* + * Refcounting Strategy: + * + * We assume that elements of the thrift_spec tuple are not going to be + * mutated, so we don't ref count those at all. Other than that, we try to + * keep a reference to all the user-created objects while we work with them. + * output_val assumes that a reference is already held. The *caller* is + * responsible for handling references + */ + + switch (type) { + + case T_BOOL: { + int v = PyObject_IsTrue(value); + if (v == -1) { + return false; + } + + writeByte(output, (int8_t) v); + break; + } + case T_I08: { + int32_t val; + + if (!parse_pyint(value, &val, INT8_MIN, INT8_MAX)) { + return false; + } + + writeByte(output, (int8_t) val); + break; + } + case T_I16: { + int32_t val; + + if (!parse_pyint(value, &val, INT16_MIN, INT16_MAX)) { + return false; + } + + writeI16(output, (int16_t) val); + break; + } + case T_I32: { + int32_t val; + + if (!parse_pyint(value, &val, INT32_MIN, INT32_MAX)) { + return false; + } + + writeI32(output, val); + break; + } + case T_I64: { + int64_t nval = PyLong_AsLongLong(value); + + if (INT_CONV_ERROR_OCCURRED(nval)) { + return false; + } + + if (!CHECK_RANGE(nval, INT64_MIN, INT64_MAX)) { + PyErr_SetString(PyExc_OverflowError, "int out of range"); + return false; + } + + writeI64(output, nval); + break; + } + + case T_DOUBLE: { + double nval = PyFloat_AsDouble(value); + if (nval == -1.0 && PyErr_Occurred()) { + return false; + } + + writeDouble(output, nval); + break; + } + + case T_STRING: { + Py_ssize_t len = PyString_Size(value); + + if (!check_ssize_t_32(len)) { + return false; + } + + writeI32(output, (int32_t) len); + PycStringIO->cwrite(output, PyString_AsString(value), (int32_t) len); + break; + } + + case T_LIST: + case T_SET: { + Py_ssize_t len; + SetListTypeArgs parsedargs; + PyObject *item; + PyObject *iterator; + + if (!parse_set_list_args(&parsedargs, typeargs)) { + return false; + } + + len = PyObject_Length(value); + + if (!check_ssize_t_32(len)) { + return false; + } + + writeByte(output, parsedargs.element_type); + writeI32(output, (int32_t) len); + + iterator = PyObject_GetIter(value); + if (iterator == NULL) { + return false; + } + + while ((item = PyIter_Next(iterator))) { + if (!output_val(output, item, parsedargs.element_type, parsedargs.typeargs)) { + Py_DECREF(item); + Py_DECREF(iterator); + return false; + } + Py_DECREF(item); + } + + Py_DECREF(iterator); + + if (PyErr_Occurred()) { + return false; + } + + break; + } + + case T_MAP: { + PyObject *k, *v; + Py_ssize_t pos = 0; + Py_ssize_t len; + + MapTypeArgs parsedargs; + + len = PyDict_Size(value); + if (!check_ssize_t_32(len)) { + return false; + } + + if (!parse_map_args(&parsedargs, typeargs)) { + return false; + } + + writeByte(output, parsedargs.ktag); + writeByte(output, parsedargs.vtag); + writeI32(output, len); + + // TODO(bmaurer): should support any mapping, not just dicts + while (PyDict_Next(value, &pos, &k, &v)) { + // TODO(dreiss): Think hard about whether these INCREFs actually + // turn any unsafe scenarios into safe scenarios. + Py_INCREF(k); + Py_INCREF(v); + + if (!output_val(output, k, parsedargs.ktag, parsedargs.ktypeargs) + || !output_val(output, v, parsedargs.vtag, parsedargs.vtypeargs)) { + Py_DECREF(k); + Py_DECREF(v); + return false; + } + Py_DECREF(k); + Py_DECREF(v); + } + break; + } + + // TODO(dreiss): Consider breaking this out as a function + // the way we did for decode_struct. + case T_STRUCT: { + StructTypeArgs parsedargs; + Py_ssize_t nspec; + Py_ssize_t i; + + if (!parse_struct_args(&parsedargs, typeargs)) { + return false; + } + + nspec = PyTuple_Size(parsedargs.spec); + + if (nspec == -1) { + return false; + } + + for (i = 0; i < nspec; i++) { + StructItemSpec parsedspec; + PyObject* spec_tuple; + PyObject* instval = NULL; + + spec_tuple = PyTuple_GET_ITEM(parsedargs.spec, i); + if (spec_tuple == Py_None) { + continue; + } + + if (!parse_struct_item_spec (&parsedspec, spec_tuple)) { + return false; + } + + instval = PyObject_GetAttr(value, parsedspec.attrname); + + if (!instval) { + return false; + } + + if (instval == Py_None) { + Py_DECREF(instval); + continue; + } + + writeByte(output, (int8_t) parsedspec.type); + writeI16(output, parsedspec.tag); + + if (!output_val(output, instval, parsedspec.type, parsedspec.typeargs)) { + Py_DECREF(instval); + return false; + } + + Py_DECREF(instval); + } + + writeByte(output, (int8_t)T_STOP); + break; + } + + case T_STOP: + case T_VOID: + case T_UTF16: + case T_UTF8: + case T_U64: + default: + PyErr_SetString(PyExc_TypeError, "Unexpected TType"); + return false; + + } + + return true; +} + + +/* --- TOP-LEVEL WRAPPER FOR OUTPUT -- */ + +static PyObject * +encode_binary(PyObject *self, PyObject *args) { + PyObject* enc_obj; + PyObject* type_args; + PyObject* buf; + PyObject* ret = NULL; + + if (!PyArg_ParseTuple(args, "OO", &enc_obj, &type_args)) { + return NULL; + } + + buf = PycStringIO->NewOutput(INIT_OUTBUF_SIZE); + if (output_val(buf, enc_obj, T_STRUCT, type_args)) { + ret = PycStringIO->cgetvalue(buf); + } + + Py_DECREF(buf); + return ret; +} + +/* ====== END WRITING FUNCTIONS ====== */ + + +/* ====== BEGIN READING FUNCTIONS ====== */ + +/* --- LOW-LEVEL READING FUNCTIONS --- */ + +static void +free_decodebuf(DecodeBuffer* d) { + Py_XDECREF(d->stringiobuf); + Py_XDECREF(d->refill_callable); +} + +static bool +decode_buffer_from_obj(DecodeBuffer* dest, PyObject* obj) { + dest->stringiobuf = PyObject_GetAttr(obj, INTERN_STRING(cstringio_buf)); + if (!dest->stringiobuf) { + return false; + } + + if (!PycStringIO_InputCheck(dest->stringiobuf)) { + free_decodebuf(dest); + PyErr_SetString(PyExc_TypeError, "expecting stringio input"); + return false; + } + + dest->refill_callable = PyObject_GetAttr(obj, INTERN_STRING(cstringio_refill)); + + if(!dest->refill_callable) { + free_decodebuf(dest); + return false; + } + + if (!PyCallable_Check(dest->refill_callable)) { + free_decodebuf(dest); + PyErr_SetString(PyExc_TypeError, "expecting callable"); + return false; + } + + return true; +} + +static bool readBytes(DecodeBuffer* input, char** output, int len) { + int read; + + // TODO(dreiss): Don't fear the malloc. Think about taking a copy of + // the partial read instead of forcing the transport + // to prepend it to its buffer. + + read = PycStringIO->cread(input->stringiobuf, output, len); + + if (read == len) { + return true; + } else if (read == -1) { + return false; + } else { + PyObject* newiobuf; + + // using building functions as this is a rare codepath + newiobuf = PyObject_CallFunction( + input->refill_callable, "s#i", *output, read, len, NULL); + if (newiobuf == NULL) { + return false; + } + + // must do this *AFTER* the call so that we don't deref the io buffer + Py_CLEAR(input->stringiobuf); + input->stringiobuf = newiobuf; + + read = PycStringIO->cread(input->stringiobuf, output, len); + + if (read == len) { + return true; + } else if (read == -1) { + return false; + } else { + // TODO(dreiss): This could be a valid code path for big binary blobs. + PyErr_SetString(PyExc_TypeError, + "refill claimed to have refilled the buffer, but didn't!!"); + return false; + } + } +} + +static int8_t readByte(DecodeBuffer* input) { + char* buf; + if (!readBytes(input, &buf, sizeof(int8_t))) { + return -1; + } + + return *(int8_t*) buf; +} + +static int16_t readI16(DecodeBuffer* input) { + char* buf; + if (!readBytes(input, &buf, sizeof(int16_t))) { + return -1; + } + + return (int16_t) ntohs(*(int16_t*) buf); +} + +static int32_t readI32(DecodeBuffer* input) { + char* buf; + if (!readBytes(input, &buf, sizeof(int32_t))) { + return -1; + } + return (int32_t) ntohl(*(int32_t*) buf); +} + + +static int64_t readI64(DecodeBuffer* input) { + char* buf; + if (!readBytes(input, &buf, sizeof(int64_t))) { + return -1; + } + + return (int64_t) ntohll(*(int64_t*) buf); +} + +static double readDouble(DecodeBuffer* input) { + union { + int64_t f; + double t; + } transfer; + + transfer.f = readI64(input); + if (transfer.f == -1) { + return -1; + } + return transfer.t; +} + +static bool +checkTypeByte(DecodeBuffer* input, TType expected) { + TType got = readByte(input); + if (INT_CONV_ERROR_OCCURRED(got)) { + return false; + } + + if (expected != got) { + PyErr_SetString(PyExc_TypeError, "got wrong ttype while reading field"); + return false; + } + return true; +} + +static bool +skip(DecodeBuffer* input, TType type) { +#define SKIPBYTES(n) \ + do { \ + if (!readBytes(input, &dummy_buf, (n))) { \ + return false; \ + } \ + } while(0) + + char* dummy_buf; + + switch (type) { + + case T_BOOL: + case T_I08: SKIPBYTES(1); break; + case T_I16: SKIPBYTES(2); break; + case T_I32: SKIPBYTES(4); break; + case T_I64: + case T_DOUBLE: SKIPBYTES(8); break; + + case T_STRING: { + // TODO(dreiss): Find out if these check_ssize_t32s are really necessary. + int len = readI32(input); + if (!check_ssize_t_32(len)) { + return false; + } + SKIPBYTES(len); + break; + } + + case T_LIST: + case T_SET: { + TType etype; + int len, i; + + etype = readByte(input); + if (etype == -1) { + return false; + } + + len = readI32(input); + if (!check_ssize_t_32(len)) { + return false; + } + + for (i = 0; i < len; i++) { + if (!skip(input, etype)) { + return false; + } + } + break; + } + + case T_MAP: { + TType ktype, vtype; + int len, i; + + ktype = readByte(input); + if (ktype == -1) { + return false; + } + + vtype = readByte(input); + if (vtype == -1) { + return false; + } + + len = readI32(input); + if (!check_ssize_t_32(len)) { + return false; + } + + for (i = 0; i < len; i++) { + if (!(skip(input, ktype) && skip(input, vtype))) { + return false; + } + } + break; + } + + case T_STRUCT: { + while (true) { + TType type; + + type = readByte(input); + if (type == -1) { + return false; + } + + if (type == T_STOP) + break; + + SKIPBYTES(2); // tag + if (!skip(input, type)) { + return false; + } + } + break; + } + + case T_STOP: + case T_VOID: + case T_UTF16: + case T_UTF8: + case T_U64: + default: + PyErr_SetString(PyExc_TypeError, "Unexpected TType"); + return false; + + } + + return true; + +#undef SKIPBYTES +} + + +/* --- HELPER FUNCTION FOR DECODE_VAL --- */ + +static PyObject* +decode_val(DecodeBuffer* input, TType type, PyObject* typeargs); + +static bool +decode_struct(DecodeBuffer* input, PyObject* output, PyObject* spec_seq) { + int spec_seq_len = PyTuple_Size(spec_seq); + if (spec_seq_len == -1) { + return false; + } + + while (true) { + TType type; + int16_t tag; + PyObject* item_spec; + PyObject* fieldval = NULL; + StructItemSpec parsedspec; + + type = readByte(input); + if (type == -1) { + return false; + } + if (type == T_STOP) { + break; + } + tag = readI16(input); + if (INT_CONV_ERROR_OCCURRED(tag)) { + return false; + } + if (tag >= 0 && tag < spec_seq_len) { + item_spec = PyTuple_GET_ITEM(spec_seq, tag); + } else { + item_spec = Py_None; + } + + if (item_spec == Py_None) { + if (!skip(input, type)) { + return false; + } else { + continue; + } + } + + if (!parse_struct_item_spec(&parsedspec, item_spec)) { + return false; + } + if (parsedspec.type != type) { + if (!skip(input, type)) { + PyErr_SetString(PyExc_TypeError, "struct field had wrong type while reading and can't be skipped"); + return false; + } else { + continue; + } + } + + fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs); + if (fieldval == NULL) { + return false; + } + + if (PyObject_SetAttr(output, parsedspec.attrname, fieldval) == -1) { + Py_DECREF(fieldval); + return false; + } + Py_DECREF(fieldval); + } + return true; +} + + +/* --- MAIN RECURSIVE INPUT FUCNTION --- */ + +// Returns a new reference. +static PyObject* +decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) { + switch (type) { + + case T_BOOL: { + int8_t v = readByte(input); + if (INT_CONV_ERROR_OCCURRED(v)) { + return NULL; + } + + switch (v) { + case 0: Py_RETURN_FALSE; + case 1: Py_RETURN_TRUE; + // Don't laugh. This is a potentially serious issue. + default: PyErr_SetString(PyExc_TypeError, "boolean out of range"); return NULL; + } + break; + } + case T_I08: { + int8_t v = readByte(input); + if (INT_CONV_ERROR_OCCURRED(v)) { + return NULL; + } + + return PyInt_FromLong(v); + } + case T_I16: { + int16_t v = readI16(input); + if (INT_CONV_ERROR_OCCURRED(v)) { + return NULL; + } + return PyInt_FromLong(v); + } + case T_I32: { + int32_t v = readI32(input); + if (INT_CONV_ERROR_OCCURRED(v)) { + return NULL; + } + return PyInt_FromLong(v); + } + + case T_I64: { + int64_t v = readI64(input); + if (INT_CONV_ERROR_OCCURRED(v)) { + return NULL; + } + // TODO(dreiss): Find out if we can take this fastpath always when + // sizeof(long) == sizeof(long long). + if (CHECK_RANGE(v, LONG_MIN, LONG_MAX)) { + return PyInt_FromLong((long) v); + } + + return PyLong_FromLongLong(v); + } + + case T_DOUBLE: { + double v = readDouble(input); + if (v == -1.0 && PyErr_Occurred()) { + return false; + } + return PyFloat_FromDouble(v); + } + + case T_STRING: { + Py_ssize_t len = readI32(input); + char* buf; + if (!readBytes(input, &buf, len)) { + return NULL; + } + + return PyString_FromStringAndSize(buf, len); + } + + case T_LIST: + case T_SET: { + SetListTypeArgs parsedargs; + int32_t len; + PyObject* ret = NULL; + int i; + + if (!parse_set_list_args(&parsedargs, typeargs)) { + return NULL; + } + + if (!checkTypeByte(input, parsedargs.element_type)) { + return NULL; + } + + len = readI32(input); + if (!check_ssize_t_32(len)) { + return NULL; + } + + ret = PyList_New(len); + if (!ret) { + return NULL; + } + + for (i = 0; i < len; i++) { + PyObject* item = decode_val(input, parsedargs.element_type, parsedargs.typeargs); + if (!item) { + Py_DECREF(ret); + return NULL; + } + PyList_SET_ITEM(ret, i, item); + } + + // TODO(dreiss): Consider biting the bullet and making two separate cases + // for list and set, avoiding this post facto conversion. + if (type == T_SET) { + PyObject* setret; +#if (PY_VERSION_HEX < 0x02050000) + // hack needed for older versions + setret = PyObject_CallFunctionObjArgs((PyObject*)&PySet_Type, ret, NULL); +#else + // official version + setret = PySet_New(ret); +#endif + Py_DECREF(ret); + return setret; + } + return ret; + } + + case T_MAP: { + int32_t len; + int i; + MapTypeArgs parsedargs; + PyObject* ret = NULL; + + if (!parse_map_args(&parsedargs, typeargs)) { + return NULL; + } + + if (!checkTypeByte(input, parsedargs.ktag)) { + return NULL; + } + if (!checkTypeByte(input, parsedargs.vtag)) { + return NULL; + } + + len = readI32(input); + if (!check_ssize_t_32(len)) { + return false; + } + + ret = PyDict_New(); + if (!ret) { + goto error; + } + + for (i = 0; i < len; i++) { + PyObject* k = NULL; + PyObject* v = NULL; + k = decode_val(input, parsedargs.ktag, parsedargs.ktypeargs); + if (k == NULL) { + goto loop_error; + } + v = decode_val(input, parsedargs.vtag, parsedargs.vtypeargs); + if (v == NULL) { + goto loop_error; + } + if (PyDict_SetItem(ret, k, v) == -1) { + goto loop_error; + } + + Py_DECREF(k); + Py_DECREF(v); + continue; + + // Yuck! Destructors, anyone? + loop_error: + Py_XDECREF(k); + Py_XDECREF(v); + goto error; + } + + return ret; + + error: + Py_XDECREF(ret); + return NULL; + } + + case T_STRUCT: { + StructTypeArgs parsedargs; + if (!parse_struct_args(&parsedargs, typeargs)) { + return NULL; + } + + PyObject* ret = PyObject_CallObject(parsedargs.klass, NULL); + if (!ret) { + return NULL; + } + + if (!decode_struct(input, ret, parsedargs.spec)) { + Py_DECREF(ret); + return NULL; + } + + return ret; + } + + case T_STOP: + case T_VOID: + case T_UTF16: + case T_UTF8: + case T_U64: + default: + PyErr_SetString(PyExc_TypeError, "Unexpected TType"); + return NULL; + } +} + + +/* --- TOP-LEVEL WRAPPER FOR INPUT -- */ + +static PyObject* +decode_binary(PyObject *self, PyObject *args) { + PyObject* output_obj = NULL; + PyObject* transport = NULL; + PyObject* typeargs = NULL; + StructTypeArgs parsedargs; + DecodeBuffer input = {}; + + if (!PyArg_ParseTuple(args, "OOO", &output_obj, &transport, &typeargs)) { + return NULL; + } + + if (!parse_struct_args(&parsedargs, typeargs)) { + return NULL; + } + + if (!decode_buffer_from_obj(&input, transport)) { + return NULL; + } + + if (!decode_struct(&input, output_obj, parsedargs.spec)) { + free_decodebuf(&input); + return NULL; + } + + free_decodebuf(&input); + + Py_RETURN_NONE; +} + +/* ====== END READING FUNCTIONS ====== */ + + +/* -- PYTHON MODULE SETUP STUFF --- */ + +static PyMethodDef ThriftFastBinaryMethods[] = { + + {"encode_binary", encode_binary, METH_VARARGS, ""}, + {"decode_binary", decode_binary, METH_VARARGS, ""}, + + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +PyMODINIT_FUNC +initfastbinary(void) { +#define INIT_INTERN_STRING(value) \ + do { \ + INTERN_STRING(value) = PyString_InternFromString(#value); \ + if(!INTERN_STRING(value)) return; \ + } while(0) + + INIT_INTERN_STRING(cstringio_buf); + INIT_INTERN_STRING(cstringio_refill); +#undef INIT_INTERN_STRING + + PycString_IMPORT; + if (PycStringIO == NULL) return; + + (void) Py_InitModule("thrift.protocol.fastbinary", ThriftFastBinaryMethods); +} diff --git a/lib/py/src/server/THttpServer.py b/lib/py/src/server/THttpServer.py new file mode 100644 index 00000000..21fc3141 --- /dev/null +++ b/lib/py/src/server/THttpServer.py @@ -0,0 +1,63 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import BaseHTTPServer + +from thrift.server import TServer +from thrift.transport import TTransport + +class THttpServer(TServer.TServer): + """A simple HTTP-based Thrift server + + This class is not very performant, but it is useful (for example) for + acting as a mock version of an Apache-based PHP Thrift endpoint.""" + + def __init__(self, processor, server_address, + inputProtocolFactory, outputProtocolFactory = None): + """Set up protocol factories and HTTP server. + + See BaseHTTPServer for server_address. + See TServer for protocol factories.""" + + if outputProtocolFactory is None: + outputProtocolFactory = inputProtocolFactory + + TServer.TServer.__init__(self, processor, None, None, None, + inputProtocolFactory, outputProtocolFactory) + + thttpserver = self + + class RequestHander(BaseHTTPServer.BaseHTTPRequestHandler): + def do_POST(self): + # Don't care about the request path. + self.send_response(200) + self.send_header("content-type", "application/x-thrift") + self.end_headers() + + itrans = TTransport.TFileObjectTransport(self.rfile) + otrans = TTransport.TFileObjectTransport(self.wfile) + iprot = thttpserver.inputProtocolFactory.getProtocol(itrans) + oprot = thttpserver.outputProtocolFactory.getProtocol(otrans) + thttpserver.processor.process(iprot, oprot) + otrans.flush() + + self.httpd = BaseHTTPServer.HTTPServer(server_address, RequestHander) + + def serve(self): + self.httpd.serve_forever() diff --git a/lib/py/src/server/TNonblockingServer.py b/lib/py/src/server/TNonblockingServer.py new file mode 100644 index 00000000..deec708a --- /dev/null +++ b/lib/py/src/server/TNonblockingServer.py @@ -0,0 +1,309 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +"""Implementation of non-blocking server. + +The main idea of the server is reciving and sending requests +only from main thread. + +It also makes thread pool server in tasks terms, not connections. +""" +import threading +import socket +import Queue +import select +import struct +import logging + +from thrift.transport import TTransport +from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory + +__all__ = ['TNonblockingServer'] + +class Worker(threading.Thread): + """Worker is a small helper to process incoming connection.""" + def __init__(self, queue): + threading.Thread.__init__(self) + self.queue = queue + + def run(self): + """Process queries from task queue, stop if processor is None.""" + while True: + try: + processor, iprot, oprot, otrans, callback = self.queue.get() + if processor is None: + break + processor.process(iprot, oprot) + callback(True, otrans.getvalue()) + except Exception: + logging.exception("Exception while processing request") + callback(False, '') + +WAIT_LEN = 0 +WAIT_MESSAGE = 1 +WAIT_PROCESS = 2 +SEND_ANSWER = 3 +CLOSED = 4 + +def locked(func): + "Decorator which locks self.lock." + def nested(self, *args, **kwargs): + self.lock.acquire() + try: + return func(self, *args, **kwargs) + finally: + self.lock.release() + return nested + +def socket_exception(func): + "Decorator close object on socket.error." + def read(self, *args, **kwargs): + try: + return func(self, *args, **kwargs) + except socket.error: + self.close() + return read + +class Connection: + """Basic class is represented connection. + + It can be in state: + WAIT_LEN --- connection is reading request len. + WAIT_MESSAGE --- connection is reading request. + WAIT_PROCESS --- connection has just read whole request and + waits for call ready routine. + SEND_ANSWER --- connection is sending answer string (including length + of answer). + CLOSED --- socket was closed and connection should be deleted. + """ + def __init__(self, new_socket, wake_up): + self.socket = new_socket + self.socket.setblocking(False) + self.status = WAIT_LEN + self.len = 0 + self.message = '' + self.lock = threading.Lock() + self.wake_up = wake_up + + def _read_len(self): + """Reads length of request. + + It's really paranoic routine and it may be replaced by + self.socket.recv(4).""" + read = self.socket.recv(4 - len(self.message)) + if len(read) == 0: + # if we read 0 bytes and self.message is empty, it means client close + # connection + if len(self.message) != 0: + logging.error("can't read frame size from socket") + self.close() + return + self.message += read + if len(self.message) == 4: + self.len, = struct.unpack('!i', self.message) + if self.len < 0: + logging.error("negative frame size, it seems client"\ + " doesn't use FramedTransport") + self.close() + elif self.len == 0: + logging.error("empty frame, it's really strange") + self.close() + else: + self.message = '' + self.status = WAIT_MESSAGE + + @socket_exception + def read(self): + """Reads data from stream and switch state.""" + assert self.status in (WAIT_LEN, WAIT_MESSAGE) + if self.status == WAIT_LEN: + self._read_len() + # go back to the main loop here for simplicity instead of + # falling through, even though there is a good chance that + # the message is already available + elif self.status == WAIT_MESSAGE: + read = self.socket.recv(self.len - len(self.message)) + if len(read) == 0: + logging.error("can't read frame from socket (get %d of %d bytes)" % + (len(self.message), self.len)) + self.close() + return + self.message += read + if len(self.message) == self.len: + self.status = WAIT_PROCESS + + @socket_exception + def write(self): + """Writes data from socket and switch state.""" + assert self.status == SEND_ANSWER + sent = self.socket.send(self.message) + if sent == len(self.message): + self.status = WAIT_LEN + self.message = '' + self.len = 0 + else: + self.message = self.message[sent:] + + @locked + def ready(self, all_ok, message): + """Callback function for switching state and waking up main thread. + + This function is the only function witch can be called asynchronous. + + The ready can switch Connection to three states: + WAIT_LEN if request was oneway. + SEND_ANSWER if request was processed in normal way. + CLOSED if request throws unexpected exception. + + The one wakes up main thread. + """ + assert self.status == WAIT_PROCESS + if not all_ok: + self.close() + self.wake_up() + return + self.len = '' + self.message = struct.pack('!i', len(message)) + message + if len(message) == 0: + # it was a oneway request, do not write answer + self.status = WAIT_LEN + else: + self.status = SEND_ANSWER + self.wake_up() + + @locked + def is_writeable(self): + "Returns True if connection should be added to write list of select." + return self.status == SEND_ANSWER + + # it's not necessary, but... + @locked + def is_readable(self): + "Returns True if connection should be added to read list of select." + return self.status in (WAIT_LEN, WAIT_MESSAGE) + + @locked + def is_closed(self): + "Returns True if connection is closed." + return self.status == CLOSED + + def fileno(self): + "Returns the file descriptor of the associated socket." + return self.socket.fileno() + + def close(self): + "Closes connection" + self.status = CLOSED + self.socket.close() + +class TNonblockingServer: + """Non-blocking server.""" + def __init__(self, processor, lsocket, inputProtocolFactory=None, + outputProtocolFactory=None, threads=10): + self.processor = processor + self.socket = lsocket + self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory() + self.out_protocol = outputProtocolFactory or self.in_protocol + self.threads = int(threads) + self.clients = {} + self.tasks = Queue.Queue() + self._read, self._write = socket.socketpair() + self.prepared = False + + def setNumThreads(self, num): + """Set the number of worker threads that should be created.""" + # implement ThreadPool interface + assert not self.prepared, "You can't change number of threads for working server" + self.threads = num + + def prepare(self): + """Prepares server for serve requests.""" + self.socket.listen() + for _ in xrange(self.threads): + thread = Worker(self.tasks) + thread.setDaemon(True) + thread.start() + self.prepared = True + + def wake_up(self): + """Wake up main thread. + + The server usualy waits in select call in we should terminate one. + The simplest way is using socketpair. + + Select always wait to read from the first socket of socketpair. + + In this case, we can just write anything to the second socket from + socketpair.""" + self._write.send('1') + + def _select(self): + """Does select on open connections.""" + readable = [self.socket.handle.fileno(), self._read.fileno()] + writable = [] + for i, connection in self.clients.items(): + if connection.is_readable(): + readable.append(connection.fileno()) + if connection.is_writeable(): + writable.append(connection.fileno()) + if connection.is_closed(): + del self.clients[i] + return select.select(readable, writable, readable) + + def handle(self): + """Handle requests. + + WARNING! You must call prepare BEFORE calling handle. + """ + assert self.prepared, "You have to call prepare before handle" + rset, wset, xset = self._select() + for readable in rset: + if readable == self._read.fileno(): + # don't care i just need to clean readable flag + self._read.recv(1024) + elif readable == self.socket.handle.fileno(): + client = self.socket.accept().handle + self.clients[client.fileno()] = Connection(client, self.wake_up) + else: + connection = self.clients[readable] + connection.read() + if connection.status == WAIT_PROCESS: + itransport = TTransport.TMemoryBuffer(connection.message) + otransport = TTransport.TMemoryBuffer() + iprot = self.in_protocol.getProtocol(itransport) + oprot = self.out_protocol.getProtocol(otransport) + self.tasks.put([self.processor, iprot, oprot, + otransport, connection.ready]) + for writeable in wset: + self.clients[writeable].write() + for oob in xset: + self.clients[oob].close() + del self.clients[oob] + + def close(self): + """Closes the server.""" + for _ in xrange(self.threads): + self.tasks.put([None, None, None, None, None]) + self.socket.close() + self.prepared = False + + def serve(self): + """Serve forever.""" + self.prepare() + while True: + self.handle() diff --git a/lib/py/src/server/TServer.py b/lib/py/src/server/TServer.py new file mode 100644 index 00000000..61529111 --- /dev/null +++ b/lib/py/src/server/TServer.py @@ -0,0 +1,270 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import logging +import sys +import os +import traceback +import threading +import Queue + +from thrift.Thrift import TProcessor +from thrift.transport import TTransport +from thrift.protocol import TBinaryProtocol + +class TServer: + + """Base interface for a server, which must have a serve method.""" + + """ 3 constructors for all servers: + 1) (processor, serverTransport) + 2) (processor, serverTransport, transportFactory, protocolFactory) + 3) (processor, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory)""" + def __init__(self, *args): + if (len(args) == 2): + self.__initArgs__(args[0], args[1], + TTransport.TTransportFactoryBase(), + TTransport.TTransportFactoryBase(), + TBinaryProtocol.TBinaryProtocolFactory(), + TBinaryProtocol.TBinaryProtocolFactory()) + elif (len(args) == 4): + self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3]) + elif (len(args) == 6): + self.__initArgs__(args[0], args[1], args[2], args[3], args[4], args[5]) + + def __initArgs__(self, processor, serverTransport, + inputTransportFactory, outputTransportFactory, + inputProtocolFactory, outputProtocolFactory): + self.processor = processor + self.serverTransport = serverTransport + self.inputTransportFactory = inputTransportFactory + self.outputTransportFactory = outputTransportFactory + self.inputProtocolFactory = inputProtocolFactory + self.outputProtocolFactory = outputProtocolFactory + + def serve(self): + pass + +class TSimpleServer(TServer): + + """Simple single-threaded server that just pumps around one transport.""" + + def __init__(self, *args): + TServer.__init__(self, *args) + + def serve(self): + self.serverTransport.listen() + while True: + client = self.serverTransport.accept() + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.outputProtocolFactory.getProtocol(otrans) + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException, tx: + pass + except Exception, x: + logging.exception(x) + + itrans.close() + otrans.close() + +class TThreadedServer(TServer): + + """Threaded server that spawns a new thread per each connection.""" + + def __init__(self, *args): + TServer.__init__(self, *args) + + def serve(self): + self.serverTransport.listen() + while True: + try: + client = self.serverTransport.accept() + t = threading.Thread(target = self.handle, args=(client,)) + t.start() + except KeyboardInterrupt: + raise + except Exception, x: + logging.exception(x) + + def handle(self, client): + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.outputProtocolFactory.getProtocol(otrans) + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException, tx: + pass + except Exception, x: + logging.exception(x) + + itrans.close() + otrans.close() + +class TThreadPoolServer(TServer): + + """Server with a fixed size pool of threads which service requests.""" + + def __init__(self, *args): + TServer.__init__(self, *args) + self.clients = Queue.Queue() + self.threads = 10 + + def setNumThreads(self, num): + """Set the number of worker threads that should be created""" + self.threads = num + + def serveThread(self): + """Loop around getting clients from the shared queue and process them.""" + while True: + try: + client = self.clients.get() + self.serveClient(client) + except Exception, x: + logging.exception(x) + + def serveClient(self, client): + """Process input/output from a client for as long as possible""" + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.outputProtocolFactory.getProtocol(otrans) + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException, tx: + pass + except Exception, x: + logging.exception(x) + + itrans.close() + otrans.close() + + def serve(self): + """Start a fixed number of worker threads and put client into a queue""" + for i in range(self.threads): + try: + t = threading.Thread(target = self.serveThread) + t.start() + except Exception, x: + logging.exception(x) + + # Pump the socket for clients + self.serverTransport.listen() + while True: + try: + client = self.serverTransport.accept() + self.clients.put(client) + except Exception, x: + logging.exception(x) + + +class TForkingServer(TServer): + + """A Thrift server that forks a new process for each request""" + """ + This is more scalable than the threaded server as it does not cause + GIL contention. + + Note that this has different semantics from the threading server. + Specifically, updates to shared variables will no longer be shared. + It will also not work on windows. + + This code is heavily inspired by SocketServer.ForkingMixIn in the + Python stdlib. + """ + + def __init__(self, *args): + TServer.__init__(self, *args) + self.children = [] + + def serve(self): + def try_close(file): + try: + file.close() + except IOError, e: + logging.warning(e, exc_info=True) + + + self.serverTransport.listen() + while True: + client = self.serverTransport.accept() + try: + pid = os.fork() + + if pid: # parent + # add before collect, otherwise you race w/ waitpid + self.children.append(pid) + self.collect_children() + + # Parent must close socket or the connection may not get + # closed promptly + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + try_close(itrans) + try_close(otrans) + else: + itrans = self.inputTransportFactory.getTransport(client) + otrans = self.outputTransportFactory.getTransport(client) + + iprot = self.inputProtocolFactory.getProtocol(itrans) + oprot = self.outputProtocolFactory.getProtocol(otrans) + + ecode = 0 + try: + try: + while True: + self.processor.process(iprot, oprot) + except TTransport.TTransportException, tx: + pass + except Exception, e: + logging.exception(e) + ecode = 1 + finally: + try_close(itrans) + try_close(otrans) + + os._exit(ecode) + + except TTransport.TTransportException, tx: + pass + except Exception, x: + logging.exception(x) + + + def collect_children(self): + while self.children: + try: + pid, status = os.waitpid(0, os.WNOHANG) + except os.error: + pid = None + + if pid: + self.children.remove(pid) + else: + break + + diff --git a/lib/py/src/server/__init__.py b/lib/py/src/server/__init__.py new file mode 100644 index 00000000..1bf6e254 --- /dev/null +++ b/lib/py/src/server/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['TServer', 'TNonblockingServer'] diff --git a/lib/py/src/transport/THttpClient.py b/lib/py/src/transport/THttpClient.py new file mode 100644 index 00000000..5086032b --- /dev/null +++ b/lib/py/src/transport/THttpClient.py @@ -0,0 +1,100 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from TTransport import * +from cStringIO import StringIO + +import urlparse +import httplib +import warnings + +class THttpClient(TTransportBase): + + """Http implementation of TTransport base.""" + + def __init__(self, uri_or_host, port=None, path=None): + """THttpClient supports two different types constructor parameters. + + THttpClient(host, port, path) - deprecated + THttpClient(uri) + + Only the second supports https.""" + + if port is not None: + warnings.warn("Please use the THttpClient('http://host:port/path') syntax", DeprecationWarning, stacklevel=2) + self.host = uri_or_host + self.port = port + assert path + self.path = path + self.scheme = 'http' + else: + parsed = urlparse.urlparse(uri_or_host) + self.scheme = parsed.scheme + assert self.scheme in ('http', 'https') + if self.scheme == 'http': + self.port = parsed.port or httplib.HTTP_PORT + elif self.scheme == 'https': + self.port = parsed.port or httplib.HTTPS_PORT + self.host = parsed.hostname + self.path = parsed.path + self.__wbuf = StringIO() + self.__http = None + + def open(self): + if self.scheme == 'http': + self.__http = httplib.HTTP(self.host, self.port) + else: + self.__http = httplib.HTTPS(self.host, self.port) + + def close(self): + self.__http.close() + self.__http = None + + def isOpen(self): + return self.__http != None + + def read(self, sz): + return self.__http.file.read(sz) + + def write(self, buf): + self.__wbuf.write(buf) + + def flush(self): + if self.isOpen(): + self.close() + self.open(); + + # Pull data out of buffer + data = self.__wbuf.getvalue() + self.__wbuf = StringIO() + + # HTTP request + self.__http.putrequest('POST', self.path) + + # Write headers + self.__http.putheader('Host', self.host) + self.__http.putheader('Content-Type', 'application/x-thrift') + self.__http.putheader('Content-Length', str(len(data))) + self.__http.endheaders() + + # Write payload + self.__http.send(data) + + # Get reply to flush the request + self.code, self.message, self.headers = self.__http.getreply() diff --git a/lib/py/src/transport/TSocket.py b/lib/py/src/transport/TSocket.py new file mode 100644 index 00000000..4645a023 --- /dev/null +++ b/lib/py/src/transport/TSocket.py @@ -0,0 +1,147 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from TTransport import * +import os +import errno +import socket + +class TSocketBase(TTransportBase): + def _resolveAddr(self): + if self._unix_socket is not None: + return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, self._unix_socket)] + else: + return socket.getaddrinfo(self.host, self.port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0, socket.AI_PASSIVE | socket.AI_ADDRCONFIG) + + def close(self): + if self.handle: + self.handle.close() + self.handle = None + +class TSocket(TSocketBase): + """Socket implementation of TTransport base.""" + + def __init__(self, host='localhost', port=9090, unix_socket=None): + """Initialize a TSocket + + @param host(str) The host to connect to. + @param port(int) The (TCP) port to connect to. + @param unix_socket(str) The filename of a unix socket to connect to. + (host and port will be ignored.) + """ + + self.host = host + self.port = port + self.handle = None + self._unix_socket = unix_socket + self._timeout = None + + def setHandle(self, h): + self.handle = h + + def isOpen(self): + return self.handle != None + + def setTimeout(self, ms): + if ms is None: + self._timeout = None + else: + self._timeout = ms/1000.0 + + if (self.handle != None): + self.handle.settimeout(self._timeout) + + def open(self): + try: + res0 = self._resolveAddr() + for res in res0: + self.handle = socket.socket(res[0], res[1]) + self.handle.settimeout(self._timeout) + try: + self.handle.connect(res[4]) + except socket.error, e: + if res is not res0[-1]: + continue + else: + raise e + break + except socket.error, e: + if self._unix_socket: + message = 'Could not connect to socket %s' % self._unix_socket + else: + message = 'Could not connect to %s:%d' % (self.host, self.port) + raise TTransportException(TTransportException.NOT_OPEN, message) + + def read(self, sz): + buff = self.handle.recv(sz) + if len(buff) == 0: + raise TTransportException('TSocket read 0 bytes') + return buff + + def write(self, buff): + sent = 0 + have = len(buff) + while sent < have: + plus = self.handle.send(buff) + if plus == 0: + raise TTransportException('TSocket sent 0 bytes') + sent += plus + buff = buff[plus:] + + def flush(self): + pass + +class TServerSocket(TSocketBase, TServerTransportBase): + """Socket implementation of TServerTransport base.""" + + def __init__(self, port=9090, unix_socket=None): + self.host = None + self.port = port + self._unix_socket = unix_socket + self.handle = None + + def listen(self): + res0 = self._resolveAddr() + for res in res0: + if res[0] is socket.AF_INET6 or res is res0[-1]: + break + + # We need remove the old unix socket if the file exists and + # nobody is listening on it. + if self._unix_socket: + tmp = socket.socket(res[0], res[1]) + try: + tmp.connect(res[4]) + except socket.error, err: + eno, message = err.args + if eno == errno.ECONNREFUSED: + os.unlink(res[4]) + + self.handle = socket.socket(res[0], res[1]) + self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if hasattr(self.handle, 'set_timeout'): + self.handle.set_timeout(None) + self.handle.bind(res[4]) + self.handle.listen(128) + + def accept(self): + client, addr = self.handle.accept() + result = TSocket() + result.setHandle(client) + return result diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py new file mode 100644 index 00000000..32553bdd --- /dev/null +++ b/lib/py/src/transport/TTransport.py @@ -0,0 +1,326 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from cStringIO import StringIO +from struct import pack,unpack +from thrift.Thrift import TException + +class TTransportException(TException): + + """Custom Transport Exception class""" + + UNKNOWN = 0 + NOT_OPEN = 1 + ALREADY_OPEN = 2 + TIMED_OUT = 3 + END_OF_FILE = 4 + + def __init__(self, type=UNKNOWN, message=None): + TException.__init__(self, message) + self.type = type + +class TTransportBase: + + """Base class for Thrift transport layer.""" + + def isOpen(self): + pass + + def open(self): + pass + + def close(self): + pass + + def read(self, sz): + pass + + def readAll(self, sz): + buff = '' + have = 0 + while (have < sz): + chunk = self.read(sz-have) + have += len(chunk) + buff += chunk + + if len(chunk) == 0: + raise EOFError() + + return buff + + def write(self, buf): + pass + + def flush(self): + pass + +# This class should be thought of as an interface. +class CReadableTransport: + """base class for transports that are readable from C""" + + # TODO(dreiss): Think about changing this interface to allow us to use + # a (Python, not c) StringIO instead, because it allows + # you to write after reading. + + # NOTE: This is a classic class, so properties will NOT work + # correctly for setting. + @property + def cstringio_buf(self): + """A cStringIO buffer that contains the current chunk we are reading.""" + pass + + def cstringio_refill(self, partialread, reqlen): + """Refills cstringio_buf. + + Returns the currently used buffer (which can but need not be the same as + the old cstringio_buf). partialread is what the C code has read from the + buffer, and should be inserted into the buffer before any more reads. The + return value must be a new, not borrowed reference. Something along the + lines of self._buf should be fine. + + If reqlen bytes can't be read, throw EOFError. + """ + pass + +class TServerTransportBase: + + """Base class for Thrift server transports.""" + + def listen(self): + pass + + def accept(self): + pass + + def close(self): + pass + +class TTransportFactoryBase: + + """Base class for a Transport Factory""" + + def getTransport(self, trans): + return trans + +class TBufferedTransportFactory: + + """Factory transport that builds buffered transports""" + + def getTransport(self, trans): + buffered = TBufferedTransport(trans) + return buffered + + +class TBufferedTransport(TTransportBase,CReadableTransport): + + """Class that wraps another transport and buffers its I/O.""" + + DEFAULT_BUFFER = 4096 + + def __init__(self, trans): + self.__trans = trans + self.__wbuf = StringIO() + self.__rbuf = StringIO("") + + def isOpen(self): + return self.__trans.isOpen() + + def open(self): + return self.__trans.open() + + def close(self): + return self.__trans.close() + + def read(self, sz): + ret = self.__rbuf.read(sz) + if len(ret) != 0: + return ret + + self.__rbuf = StringIO(self.__trans.read(max(sz, self.DEFAULT_BUFFER))) + return self.__rbuf.read(sz) + + def write(self, buf): + self.__wbuf.write(buf) + + def flush(self): + out = self.__wbuf.getvalue() + # reset wbuf before write/flush to preserve state on underlying failure + self.__wbuf = StringIO() + self.__trans.write(out) + self.__trans.flush() + + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + return self.__rbuf + + def cstringio_refill(self, partialread, reqlen): + retstring = partialread + if reqlen < self.DEFAULT_BUFFER: + # try to make a read of as much as we can. + retstring += self.__trans.read(self.DEFAULT_BUFFER) + + # but make sure we do read reqlen bytes. + if len(retstring) < reqlen: + retstring += self.__trans.readAll(reqlen - len(retstring)) + + self.__rbuf = StringIO(retstring) + return self.__rbuf + +class TMemoryBuffer(TTransportBase, CReadableTransport): + """Wraps a cStringIO object as a TTransport. + + NOTE: Unlike the C++ version of this class, you cannot write to it + then immediately read from it. If you want to read from a + TMemoryBuffer, you must either pass a string to the constructor. + TODO(dreiss): Make this work like the C++ version. + """ + + def __init__(self, value=None): + """value -- a value to read from for stringio + + If value is set, this will be a transport for reading, + otherwise, it is for writing""" + if value is not None: + self._buffer = StringIO(value) + else: + self._buffer = StringIO() + + def isOpen(self): + return not self._buffer.closed + + def open(self): + pass + + def close(self): + self._buffer.close() + + def read(self, sz): + return self._buffer.read(sz) + + def write(self, buf): + self._buffer.write(buf) + + def flush(self): + pass + + def getvalue(self): + return self._buffer.getvalue() + + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + return self._buffer + + def cstringio_refill(self, partialread, reqlen): + # only one shot at reading... + raise EOFError() + +class TFramedTransportFactory: + + """Factory transport that builds framed transports""" + + def getTransport(self, trans): + framed = TFramedTransport(trans) + return framed + + +class TFramedTransport(TTransportBase, CReadableTransport): + + """Class that wraps another transport and frames its I/O when writing.""" + + def __init__(self, trans,): + self.__trans = trans + self.__rbuf = StringIO() + self.__wbuf = StringIO() + + def isOpen(self): + return self.__trans.isOpen() + + def open(self): + return self.__trans.open() + + def close(self): + return self.__trans.close() + + def read(self, sz): + ret = self.__rbuf.read(sz) + if len(ret) != 0: + return ret + + self.readFrame() + return self.__rbuf.read(sz) + + def readFrame(self): + buff = self.__trans.readAll(4) + sz, = unpack('!i', buff) + self.__rbuf = StringIO(self.__trans.readAll(sz)) + + def write(self, buf): + self.__wbuf.write(buf) + + def flush(self): + wout = self.__wbuf.getvalue() + wsz = len(wout) + # reset wbuf before write/flush to preserve state on underlying failure + self.__wbuf = StringIO() + # N.B.: Doing this string concatenation is WAY cheaper than making + # two separate calls to the underlying socket object. Socket writes in + # Python turn out to be REALLY expensive, but it seems to do a pretty + # good job of managing string buffer operations without excessive copies + buf = pack("!i", wsz) + wout + self.__trans.write(buf) + self.__trans.flush() + + # Implement the CReadableTransport interface. + @property + def cstringio_buf(self): + return self.__rbuf + + def cstringio_refill(self, prefix, reqlen): + # self.__rbuf will already be empty here because fastbinary doesn't + # ask for a refill until the previous buffer is empty. Therefore, + # we can start reading new frames immediately. + while len(prefix) < reqlen: + readFrame() + prefix += self.__rbuf.getvalue() + self.__rbuf = StringIO(prefix) + return self.__rbuf + + +class TFileObjectTransport(TTransportBase): + """Wraps a file-like object to make it work as a Thrift transport.""" + + def __init__(self, fileobj): + self.fileobj = fileobj + + def isOpen(self): + return True + + def close(self): + self.fileobj.close() + + def read(self, sz): + return self.fileobj.read(sz) + + def write(self, buf): + self.fileobj.write(buf) + + def flush(self): + self.fileobj.flush() diff --git a/lib/py/src/transport/TTwisted.py b/lib/py/src/transport/TTwisted.py new file mode 100644 index 00000000..b5c2147b --- /dev/null +++ b/lib/py/src/transport/TTwisted.py @@ -0,0 +1,177 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from zope.interface import implements, Interface, Attribute +from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \ + connectionDone +from twisted.internet import defer +from twisted.protocols import basic +from twisted.python import log + + +from thrift.transport import TTransport +from cStringIO import StringIO + + +class TMessageSenderTransport(TTransport.TTransportBase): + + def __init__(self): + self.__wbuf = StringIO() + + def write(self, buf): + self.__wbuf.write(buf) + + def flush(self): + msg = self.__wbuf.getvalue() + self.__wbuf = StringIO() + self.sendMessage(msg) + + def sendMessage(self, message): + raise NotImplementedError + + +class TCallbackTransport(TMessageSenderTransport): + + def __init__(self, func): + TMessageSenderTransport.__init__(self) + self.func = func + + def sendMessage(self, message): + self.func(message) + + +class ThriftClientProtocol(basic.Int32StringReceiver): + + def __init__(self, client_class, iprot_factory, oprot_factory=None): + self._client_class = client_class + self._iprot_factory = iprot_factory + if oprot_factory is None: + self._oprot_factory = iprot_factory + else: + self._oprot_factory = oprot_factory + + self.recv_map = {} + self.started = defer.Deferred() + + def dispatch(self, msg): + self.sendString(msg) + + def connectionMade(self): + tmo = TCallbackTransport(self.dispatch) + self.client = self._client_class(tmo, self._oprot_factory) + self.started.callback(self.client) + + def connectionLost(self, reason=connectionDone): + for k,v in self.client._reqs.iteritems(): + tex = TTransport.TTransportException( + type=TTransport.TTransportException.END_OF_FILE, + message='Connection closed') + v.errback(tex) + + def stringReceived(self, frame): + tr = TTransport.TMemoryBuffer(frame) + iprot = self._iprot_factory.getProtocol(tr) + (fname, mtype, rseqid) = iprot.readMessageBegin() + + try: + method = self.recv_map[fname] + except KeyError: + method = getattr(self.client, 'recv_' + fname) + self.recv_map[fname] = method + + method(iprot, mtype, rseqid) + + +class ThriftServerProtocol(basic.Int32StringReceiver): + + def dispatch(self, msg): + self.sendString(msg) + + def processError(self, error): + self.transport.loseConnection() + + def processOk(self, _, tmo): + msg = tmo.getvalue() + + if len(msg) > 0: + self.dispatch(msg) + + def stringReceived(self, frame): + tmi = TTransport.TMemoryBuffer(frame) + tmo = TTransport.TMemoryBuffer() + + iprot = self.factory.iprot_factory.getProtocol(tmi) + oprot = self.factory.oprot_factory.getProtocol(tmo) + + d = self.factory.processor.process(iprot, oprot) + d.addCallbacks(self.processOk, self.processError, + callbackArgs=(tmo,)) + + +class IThriftServerFactory(Interface): + + processor = Attribute("Thrift processor") + + iprot_factory = Attribute("Input protocol factory") + + oprot_factory = Attribute("Output protocol factory") + + +class IThriftClientFactory(Interface): + + client_class = Attribute("Thrift client class") + + iprot_factory = Attribute("Input protocol factory") + + oprot_factory = Attribute("Output protocol factory") + + +class ThriftServerFactory(ServerFactory): + + implements(IThriftServerFactory) + + protocol = ThriftServerProtocol + + def __init__(self, processor, iprot_factory, oprot_factory=None): + self.processor = processor + self.iprot_factory = iprot_factory + if oprot_factory is None: + self.oprot_factory = iprot_factory + else: + self.oprot_factory = oprot_factory + + +class ThriftClientFactory(ClientFactory): + + implements(IThriftClientFactory) + + protocol = ThriftClientProtocol + + def __init__(self, client_class, iprot_factory, oprot_factory=None): + self.client_class = client_class + self.iprot_factory = iprot_factory + if oprot_factory is None: + self.oprot_factory = iprot_factory + else: + self.oprot_factory = oprot_factory + + def buildProtocol(self, addr): + p = self.protocol(self.client_class, self.iprot_factory, + self.oprot_factory) + p.factory = self + return p diff --git a/lib/py/src/transport/__init__.py b/lib/py/src/transport/__init__.py new file mode 100644 index 00000000..02c6048a --- /dev/null +++ b/lib/py/src/transport/__init__.py @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +__all__ = ['TTransport', 'TSocket', 'THttpClient'] diff --git a/lib/rb/CHANGELOG b/lib/rb/CHANGELOG new file mode 100644 index 00000000..b5dce2ae --- /dev/null +++ b/lib/rb/CHANGELOG @@ -0,0 +1 @@ +v0.0.1. Initial release diff --git a/lib/rb/Makefile.am b/lib/rb/Makefile.am new file mode 100644 index 00000000..9cfffc71 --- /dev/null +++ b/lib/rb/Makefile.am @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +EXTRA_DIST = \ + CHANGELOG \ + Rakefile \ + Manifest \ + setup.rb \ + lib \ + ext \ + benchmark \ + script \ + spec + +all-local: + $(RUBY) setup.rb config + $(RUBY) setup.rb setup + +install-exec-hook: + $(RUBY) setup.rb install + +# Make sure this doesn't fail if Ruby is not configured. +clean-local: + RUBY=$(RUBY) ; if test -z "$$RUBY" ; then RUBY=: ; fi ; \ + $$RUBY setup.rb clean + +check-local: all +if HAVE_RSPEC + rake spec +endif + diff --git a/lib/rb/Manifest b/lib/rb/Manifest new file mode 100644 index 00000000..7b4503fb --- /dev/null +++ b/lib/rb/Manifest @@ -0,0 +1,81 @@ +CHANGELOG +Manifest +Rakefile +README +setup.rb +benchmark/benchmark.rb +benchmark/Benchmark.thrift +benchmark/client.rb +benchmark/server.rb +benchmark/thin_server.rb +ext/binary_protocol_accelerated.c +ext/binary_protocol_accelerated.h +ext/compact_protocol.c +ext/compact_protocol.h +ext/constants.h +ext/extconf.rb +ext/macros.h +ext/memory_buffer.c +ext/memory_buffer.h +ext/protocol.c +ext/protocol.h +ext/struct.c +ext/struct.h +ext/thrift_native.c +lib/thrift.rb +lib/thrift/client.rb +lib/thrift/core_ext.rb +lib/thrift/exceptions.rb +lib/thrift/processor.rb +lib/thrift/struct.rb +lib/thrift/thrift_native.rb +lib/thrift/types.rb +lib/thrift/core_ext/fixnum.rb +lib/thrift/protocol/base_protocol.rb +lib/thrift/protocol/binary_protocol.rb +lib/thrift/protocol/binary_protocol_accelerated.rb +lib/thrift/protocol/compact_protocol.rb +lib/thrift/serializer/deserializer.rb +lib/thrift/serializer/serializer.rb +lib/thrift/server/base_server.rb +lib/thrift/server/mongrel_http_server.rb +lib/thrift/server/nonblocking_server.rb +lib/thrift/server/simple_server.rb +lib/thrift/server/thread_pool_server.rb +lib/thrift/server/threaded_server.rb +lib/thrift/transport/base_server_transport.rb +lib/thrift/transport/base_transport.rb +lib/thrift/transport/buffered_transport.rb +lib/thrift/transport/framed_transport.rb +lib/thrift/transport/http_client_transport.rb +lib/thrift/transport/io_stream_transport.rb +lib/thrift/transport/memory_buffer_transport.rb +lib/thrift/transport/server_socket.rb +lib/thrift/transport/socket.rb +lib/thrift/transport/unix_server_socket.rb +lib/thrift/transport/unix_socket.rb +script/proto_benchmark.rb +script/read_struct.rb +script/write_struct.rb +spec/base_protocol_spec.rb +spec/base_transport_spec.rb +spec/binary_protocol_accelerated_spec.rb +spec/binary_protocol_spec.rb +spec/binary_protocol_spec_shared.rb +spec/client_spec.rb +spec/compact_protocol_spec.rb +spec/exception_spec.rb +spec/http_client_spec.rb +spec/mongrel_http_server_spec.rb +spec/nonblocking_server_spec.rb +spec/processor_spec.rb +spec/serializer_spec.rb +spec/server_socket_spec.rb +spec/server_spec.rb +spec/socket_spec.rb +spec/socket_spec_shared.rb +spec/spec_helper.rb +spec/struct_spec.rb +spec/ThriftSpec.thrift +spec/types_spec.rb +spec/unix_socket_spec.rb diff --git a/lib/rb/README b/lib/rb/README new file mode 100644 index 00000000..d78e3527 --- /dev/null +++ b/lib/rb/README @@ -0,0 +1,43 @@ +Thrift Ruby Software Library + http://incubator.apache.org/thrift/ + +== LICENSE: + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +== DESCRIPTION: + +Thrift is a strongly-typed language-agnostic RPC system. +This library is the ruby implementation for both clients and servers. + +== INSTALL: + + $ gem install thrift + +== CAVEATS: + +This library provides the client and server implementations of thrift. +It does not provide the compiler for the .thrift files. To compile +.thrift files into language-specific implementations, please download the full +thrift software package. + +== USAGE: + +This section should get written by someone with the time and inclination. +In the meantime, look at existing code, such as the benchmark or the tutorial +in the full thrift distribution. diff --git a/lib/rb/Rakefile b/lib/rb/Rakefile new file mode 100644 index 00000000..1a9467a5 --- /dev/null +++ b/lib/rb/Rakefile @@ -0,0 +1,103 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'rubygems' +require 'rake' +require 'spec/rake/spectask' + +THRIFT = '../../compiler/cpp/thrift' + +task :default => [:spec] + +task :spec => [:'gen-rb', :realspec] + +Spec::Rake::SpecTask.new(:realspec) do |t| + t.spec_files = FileList['spec/**/*_spec.rb'] + t.spec_opts = ['--color'] +end + +Spec::Rake::SpecTask.new(:'spec:rcov') do |t| + t.spec_files = FileList['spec/**/*_spec.rb'] + t.spec_opts = ['--color'] + t.rcov = true + t.rcov_opts = ['--exclude', '^spec,/gems/'] +end + +desc 'Run the compiler tests (requires full thrift checkout)' +task :test do + # ensure this is a full thrift checkout and not a tarball of the ruby libs + cmd = 'head -1 ../../README 2>/dev/null | grep Thrift >/dev/null 2>/dev/null' + system(cmd) or fail "rake test requires a full thrift checkout" + sh 'make', '-C', File.dirname(__FILE__) + "/../../test/rb", "check" +end + +desc 'Compile the .thrift files for the specs' +task :'gen-rb' => [:'gen-rb:spec', :'gen-rb:benchmark', :'gen-rb:debug_proto'] + +namespace :'gen-rb' do + task :'spec' do + dir = File.dirname(__FILE__) + '/spec' + sh THRIFT, '--gen', 'rb', '-o', dir, "#{dir}/ThriftSpec.thrift" + end + + task :'benchmark' do + dir = File.dirname(__FILE__) + '/benchmark' + sh THRIFT, '--gen', 'rb', '-o', dir, "#{dir}/Benchmark.thrift" + end + + task :'debug_proto' do + sh "mkdir", "-p", "debug_proto_test" + sh THRIFT, '--gen', 'rb', "-o", "debug_proto_test", "../../test/DebugProtoTest.thrift" + end +end + +desc 'Run benchmarking of NonblockingServer' +task :benchmark do + ruby 'benchmark/benchmark.rb' +end + + +begin + require 'echoe' + + Echoe.new('thrift') do |p| + p.author = ['Kevin Ballard', 'Kevin Clark', 'Mark Slee'] + p.email = ['kevin@sb.org', 'kevin.clark@gmail.com', 'mcslee@facebook.com'] + p.summary = "Ruby libraries for Thrift (a language-agnostic RPC system)" + p.url = "http://incubator.apache.org/thrift/" + p.include_rakefile = true + p.version = "0.1.0" + end + + task :install => [:check_site_lib] + + require 'rbconfig' + task :check_site_lib do + if File.exist?(File.join(Config::CONFIG['sitelibdir'], 'thrift.rb')) + fail "thrift is already installed in site_ruby" + end + end +rescue LoadError + [:install, :package].each do |t| + desc "Stub for #{t}" + task t do + fail "The Echoe gem is required for this task" + end + end +end diff --git a/lib/rb/benchmark/Benchmark.thrift b/lib/rb/benchmark/Benchmark.thrift new file mode 100644 index 00000000..eb5ae38e --- /dev/null +++ b/lib/rb/benchmark/Benchmark.thrift @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +namespace rb ThriftBenchmark + +service BenchmarkService { + i32 fibonacci(1:byte n) +} diff --git a/lib/rb/benchmark/benchmark.rb b/lib/rb/benchmark/benchmark.rb new file mode 100644 index 00000000..3dc67dd8 --- /dev/null +++ b/lib/rb/benchmark/benchmark.rb @@ -0,0 +1,271 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'rubygems' +$:.unshift File.dirname(__FILE__) + '/../lib' +require 'thrift' +require 'stringio' + +HOST = '127.0.0.1' +PORT = 42587 + +############### +## Server +############### + +class Server + attr_accessor :serverclass + attr_accessor :interpreter + attr_accessor :host + attr_accessor :port + + def initialize(opts) + @serverclass = opts.fetch(:class, Thrift::NonblockingServer) + @interpreter = opts.fetch(:interpreter, "ruby") + @host = opts.fetch(:host, ::HOST) + @port = opts.fetch(:port, ::PORT) + end + + def start + return if @serverclass == Object + args = (File.basename(@interpreter) == "jruby" ? "-J-server" : "") + @pipe = IO.popen("#{@interpreter} #{args} #{File.dirname(__FILE__)}/server.rb #{@host} #{@port} #{@serverclass.name}", "r+") + Marshal.load(@pipe) # wait until the server has started + sleep 0.4 # give the server time to actually start spawning sockets + end + + def shutdown + return unless @pipe + Marshal.dump(:shutdown, @pipe) + begin + @pipe.read(10) # block until the server shuts down + rescue EOFError + end + @pipe.close + @pipe = nil + end +end + +class BenchmarkManager + def initialize(opts, server) + @socket = opts.fetch(:socket) do + @host = opts.fetch(:host, 'localhost') + @port = opts.fetch(:port) + nil + end + @num_processes = opts.fetch(:num_processes, 40) + @clients_per_process = opts.fetch(:clients_per_process, 10) + @calls_per_client = opts.fetch(:calls_per_client, 50) + @interpreter = opts.fetch(:interpreter, "ruby") + @server = server + @log_exceptions = opts.fetch(:log_exceptions, false) + end + + def run + @pool = [] + @benchmark_start = Time.now + puts "Spawning benchmark processes..." + @num_processes.times do + spawn + sleep 0.02 # space out spawns + end + collect_output + @benchmark_end = Time.now # we know the procs are done here + translate_output + analyze_output + report_output + end + + def spawn + pipe = IO.popen("#{@interpreter} #{File.dirname(__FILE__)}/client.rb #{"-log-exceptions" if @log_exceptions} #{@host} #{@port} #{@clients_per_process} #{@calls_per_client}") + @pool << pipe + end + + def socket_class + if @socket + Thrift::UNIXSocket + else + Thrift::Socket + end + end + + def collect_output + puts "Collecting output..." + # read from @pool until all sockets are closed + @buffers = Hash.new { |h,k| h[k] = '' } + until @pool.empty? + rd, = select(@pool) + next if rd.nil? + rd.each do |fd| + begin + @buffers[fd] << fd.readpartial(4096) + rescue EOFError + @pool.delete fd + end + end + end + end + + def translate_output + puts "Translating output..." + @output = [] + @buffers.each do |fd, buffer| + strio = StringIO.new(buffer) + logs = [] + begin + loop do + logs << Marshal.load(strio) + end + rescue EOFError + @output << logs + end + end + end + + def analyze_output + puts "Analyzing output..." + call_times = [] + client_times = [] + connection_failures = [] + connection_errors = [] + shortest_call = 0 + shortest_client = 0 + longest_call = 0 + longest_client = 0 + @output.each do |logs| + cur_call, cur_client = nil + logs.each do |tok, time| + case tok + when :start + cur_client = time + when :call_start + cur_call = time + when :call_end + delta = time - cur_call + call_times << delta + longest_call = delta unless longest_call > delta + shortest_call = delta if shortest_call == 0 or delta < shortest_call + cur_call = nil + when :end + delta = time - cur_client + client_times << delta + longest_client = delta unless longest_client > delta + shortest_client = delta if shortest_client == 0 or delta < shortest_client + cur_client = nil + when :connection_failure + connection_failures << time + when :connection_error + connection_errors << time + end + end + end + @report = {} + @report[:total_calls] = call_times.inject(0.0) { |a,t| a += t } + @report[:avg_calls] = @report[:total_calls] / call_times.size + @report[:total_clients] = client_times.inject(0.0) { |a,t| a += t } + @report[:avg_clients] = @report[:total_clients] / client_times.size + @report[:connection_failures] = connection_failures.size + @report[:connection_errors] = connection_errors.size + @report[:shortest_call] = shortest_call + @report[:shortest_client] = shortest_client + @report[:longest_call] = longest_call + @report[:longest_client] = longest_client + @report[:total_benchmark_time] = @benchmark_end - @benchmark_start + @report[:fastthread] = $".include?('fastthread.bundle') + end + + def report_output + fmt = "%.4f seconds" + puts + tabulate "%d", + [["Server class", "%s"], @server.serverclass == Object ? "" : @server.serverclass], + [["Server interpreter", "%s"], @server.interpreter], + [["Client interpreter", "%s"], @interpreter], + [["Socket class", "%s"], socket_class], + ["Number of processes", @num_processes], + ["Clients per process", @clients_per_process], + ["Calls per client", @calls_per_client], + [["Using fastthread", "%s"], @report[:fastthread] ? "yes" : "no"] + puts + failures = (@report[:connection_failures] > 0) + tabulate fmt, + [["Connection failures", "%d", [:red, :bold]], @report[:connection_failures]], + [["Connection errors", "%d", [:red, :bold]], @report[:connection_errors]], + ["Average time per call", @report[:avg_calls]], + ["Average time per client (%d calls)" % @calls_per_client, @report[:avg_clients]], + ["Total time for all calls", @report[:total_calls]], + ["Real time for benchmarking", @report[:total_benchmark_time]], + ["Shortest call time", @report[:shortest_call]], + ["Longest call time", @report[:longest_call]], + ["Shortest client time (%d calls)" % @calls_per_client, @report[:shortest_client]], + ["Longest client time (%d calls)" % @calls_per_client, @report[:longest_client]] + end + + ANSI = { + :reset => 0, + :bold => 1, + :black => 30, + :red => 31, + :green => 32, + :yellow => 33, + :blue => 34, + :magenta => 35, + :cyan => 36, + :white => 37 + } + + def tabulate(fmt, *labels_and_values) + labels = labels_and_values.map { |l| Array === l ? l.first : l } + label_width = labels.inject(0) { |w,l| l.size > w ? l.size : w } + labels_and_values.each do |(l,v)| + f = fmt + l, f, c = l if Array === l + fmtstr = "%-#{label_width+1}s #{f}" + if STDOUT.tty? and c and v.to_i > 0 + fmtstr = "\e[#{[*c].map { |x| ANSI[x] } * ";"}m" + fmtstr + "\e[#{ANSI[:reset]}m" + end + puts fmtstr % [l+":", v] + end + end +end + +def resolve_const(const) + const and const.split('::').inject(Object) { |k,c| k.const_get(c) } +end + +puts "Starting server..." +args = {} +args[:interpreter] = ENV['THRIFT_SERVER_INTERPRETER'] || ENV['THRIFT_INTERPRETER'] || "ruby" +args[:class] = resolve_const(ENV['THRIFT_SERVER']) || Thrift::NonblockingServer +args[:host] = ENV['THRIFT_HOST'] || HOST +args[:port] = (ENV['THRIFT_PORT'] || PORT).to_i +server = Server.new(args) +server.start + +args = {} +args[:host] = ENV['THRIFT_HOST'] || HOST +args[:port] = (ENV['THRIFT_PORT'] || PORT).to_i +args[:num_processes] = (ENV['THRIFT_NUM_PROCESSES'] || 40).to_i +args[:clients_per_process] = (ENV['THRIFT_NUM_CLIENTS'] || 5).to_i +args[:calls_per_client] = (ENV['THRIFT_NUM_CALLS'] || 50).to_i +args[:interpreter] = ENV['THRIFT_CLIENT_INTERPRETER'] || ENV['THRIFT_INTERPRETER'] || "ruby" +args[:log_exceptions] = !!ENV['THRIFT_LOG_EXCEPTIONS'] +BenchmarkManager.new(args, server).run + +server.shutdown diff --git a/lib/rb/benchmark/client.rb b/lib/rb/benchmark/client.rb new file mode 100644 index 00000000..703dc8f5 --- /dev/null +++ b/lib/rb/benchmark/client.rb @@ -0,0 +1,74 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +$:.unshift File.dirname(__FILE__) + '/../lib' +require 'thrift' +$:.unshift File.dirname(__FILE__) + "/gen-rb" +require 'benchmark_service' + +class Client + def initialize(host, port, clients_per_process, calls_per_client, log_exceptions) + @host = host + @port = port + @clients_per_process = clients_per_process + @calls_per_client = calls_per_client + @log_exceptions = log_exceptions + end + + def run + @clients_per_process.times do + socket = Thrift::Socket.new(@host, @port) + transport = Thrift::FramedTransport.new(socket) + protocol = Thrift::BinaryProtocol.new(transport) + client = ThriftBenchmark::BenchmarkService::Client.new(protocol) + begin + start = Time.now + transport.open + Marshal.dump [:start, start], STDOUT + rescue => e + Marshal.dump [:connection_failure, Time.now], STDOUT + print_exception e if @log_exceptions + else + begin + @calls_per_client.times do + Marshal.dump [:call_start, Time.now], STDOUT + client.fibonacci(15) + Marshal.dump [:call_end, Time.now], STDOUT + end + transport.close + Marshal.dump [:end, Time.now], STDOUT + rescue Thrift::TransportException => e + Marshal.dump [:connection_error, Time.now], STDOUT + print_exception e if @log_exceptions + end + end + end + end + + def print_exception(e) + STDERR.puts "ERROR: #{e.message}" + STDERR.puts "\t#{e.backtrace * "\n\t"}" + end +end + +log_exceptions = true if ARGV[0] == '-log-exceptions' and ARGV.shift + +host, port, clients_per_process, calls_per_client = ARGV + +Client.new(host, port.to_i, clients_per_process.to_i, calls_per_client.to_i, log_exceptions).run diff --git a/lib/rb/benchmark/server.rb b/lib/rb/benchmark/server.rb new file mode 100644 index 00000000..74e13f41 --- /dev/null +++ b/lib/rb/benchmark/server.rb @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +$:.unshift File.dirname(__FILE__) + '/../lib' +require 'thrift' +$:.unshift File.dirname(__FILE__) + "/gen-rb" +require 'benchmark_service' + +module Server + include Thrift + + class BenchmarkHandler + # 1-based index into the fibonacci sequence + def fibonacci(n) + seq = [1, 1] + 3.upto(n) do + seq << seq[-1] + seq[-2] + end + seq[n-1] # n is 1-based + end + end + + def self.start_server(host, port, serverClass) + handler = BenchmarkHandler.new + processor = ThriftBenchmark::BenchmarkService::Processor.new(handler) + transport = ServerSocket.new(host, port) + transport_factory = FramedTransportFactory.new + args = [processor, transport, transport_factory, nil, 20] + if serverClass == NonblockingServer + logger = Logger.new(STDERR) + logger.level = Logger::WARN + args << logger + end + server = serverClass.new(*args) + @server_thread = Thread.new do + server.serve + end + @server = server + end + + def self.shutdown + return if @server.nil? + if @server.respond_to? :shutdown + @server.shutdown + else + @server_thread.kill + end + end +end + +def resolve_const(const) + const and const.split('::').inject(Object) { |k,c| k.const_get(c) } +end + +host, port, serverklass = ARGV + +Server.start_server(host, port.to_i, resolve_const(serverklass)) + +# let our host know that the interpreter has started +# ideally we'd wait until the server was serving, but we don't have a hook for that +Marshal.dump(:started, STDOUT) +STDOUT.flush + +Marshal.load(STDIN) # wait until we're instructed to shut down + +Server.shutdown diff --git a/lib/rb/benchmark/thin_server.rb b/lib/rb/benchmark/thin_server.rb new file mode 100644 index 00000000..4de2eef3 --- /dev/null +++ b/lib/rb/benchmark/thin_server.rb @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +$:.unshift File.dirname(__FILE__) + '/../lib' +require 'thrift' +$:.unshift File.dirname(__FILE__) + "/gen-rb" +require 'benchmark_service' +HOST = 'localhost' +PORT = 42587 + +class BenchmarkHandler + # 1-based index into the fibonacci sequence + def fibonacci(n) + seq = [1, 1] + 3.upto(n) do + seq << seq[-1] + seq[-2] + end + seq[n-1] # n is 1-based + end +end + +handler = BenchmarkHandler.new +processor = ThriftBenchmark::BenchmarkService::Processor.new(handler) +transport = Thrift::ServerSocket.new(HOST, PORT) +transport_factory = Thrift::FramedTransportFactory.new +logger = Logger.new(STDERR) +logger.level = Logger::WARN +Thrift::NonblockingServer.new(processor, transport, transport_factory, nil, 20, logger).serve diff --git a/lib/rb/ext/binary_protocol_accelerated.c b/lib/rb/ext/binary_protocol_accelerated.c new file mode 100644 index 00000000..728a0572 --- /dev/null +++ b/lib/rb/ext/binary_protocol_accelerated.c @@ -0,0 +1,474 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include "macros.h" + +VALUE rb_thrift_binary_proto_native_qmark(VALUE self) { + return Qtrue; +} + + + +static int VERSION_1; +static int VERSION_MASK; +static int TYPE_MASK; +static int BAD_VERSION; + +static void write_byte_direct(VALUE trans, int8_t b) { + WRITE(trans, (char*)&b, 1); +} + +static void write_i16_direct(VALUE trans, int16_t value) { + char data[2]; + + data[1] = value; + data[0] = (value >> 8); + + WRITE(trans, data, 2); +} + +static void write_i32_direct(VALUE trans, int32_t value) { + char data[4]; + + data[3] = value; + data[2] = (value >> 8); + data[1] = (value >> 16); + data[0] = (value >> 24); + + WRITE(trans, data, 4); +} + + +static void write_i64_direct(VALUE trans, int64_t value) { + char data[8]; + + data[7] = value; + data[6] = (value >> 8); + data[5] = (value >> 16); + data[4] = (value >> 24); + data[3] = (value >> 32); + data[2] = (value >> 40); + data[1] = (value >> 48); + data[0] = (value >> 56); + + WRITE(trans, data, 8); +} + +static void write_string_direct(VALUE trans, VALUE str) { + write_i32_direct(trans, RSTRING_LEN(str)); + rb_funcall(trans, write_method_id, 1, str); +} + +//-------------------------------- +// interface writing methods +//-------------------------------- + +VALUE rb_thrift_binary_proto_write_message_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_struct_begin(VALUE self, VALUE name) { + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_struct_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_field_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_map_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_list_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_set_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_message_begin(VALUE self, VALUE name, VALUE type, VALUE seqid) { + VALUE trans = GET_TRANSPORT(self); + VALUE strict_write = GET_STRICT_WRITE(self); + + if (strict_write == Qtrue) { + write_i32_direct(trans, VERSION_1 | FIX2INT(type)); + write_string_direct(trans, name); + write_i32_direct(trans, FIX2INT(seqid)); + } else { + write_string_direct(trans, name); + write_byte_direct(trans, FIX2INT(type)); + write_i32_direct(trans, FIX2INT(seqid)); + } + + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_field_begin(VALUE self, VALUE name, VALUE type, VALUE id) { + VALUE trans = GET_TRANSPORT(self); + write_byte_direct(trans, FIX2INT(type)); + write_i16_direct(trans, FIX2INT(id)); + + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_field_stop(VALUE self) { + write_byte_direct(GET_TRANSPORT(self), TTYPE_STOP); + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_map_begin(VALUE self, VALUE ktype, VALUE vtype, VALUE size) { + VALUE trans = GET_TRANSPORT(self); + write_byte_direct(trans, FIX2INT(ktype)); + write_byte_direct(trans, FIX2INT(vtype)); + write_i32_direct(trans, FIX2INT(size)); + + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_list_begin(VALUE self, VALUE etype, VALUE size) { + VALUE trans = GET_TRANSPORT(self); + write_byte_direct(trans, FIX2INT(etype)); + write_i32_direct(trans, FIX2INT(size)); + + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_set_begin(VALUE self, VALUE etype, VALUE size) { + rb_thrift_binary_proto_write_list_begin(self, etype, size); + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_bool(VALUE self, VALUE b) { + write_byte_direct(GET_TRANSPORT(self), RTEST(b) ? 1 : 0); + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_byte(VALUE self, VALUE byte) { + CHECK_NIL(byte); + write_byte_direct(GET_TRANSPORT(self), NUM2INT(byte)); + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_i16(VALUE self, VALUE i16) { + CHECK_NIL(i16); + write_i16_direct(GET_TRANSPORT(self), FIX2INT(i16)); + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_i32(VALUE self, VALUE i32) { + CHECK_NIL(i32); + write_i32_direct(GET_TRANSPORT(self), NUM2INT(i32)); + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_i64(VALUE self, VALUE i64) { + CHECK_NIL(i64); + write_i64_direct(GET_TRANSPORT(self), NUM2LL(i64)); + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_double(VALUE self, VALUE dub) { + CHECK_NIL(dub); + // Unfortunately, bitwise_cast doesn't work in C. Bad C! + union { + double f; + int64_t t; + } transfer; + transfer.f = RFLOAT_VALUE(rb_Float(dub)); + write_i64_direct(GET_TRANSPORT(self), transfer.t); + + return Qnil; +} + +VALUE rb_thrift_binary_proto_write_string(VALUE self, VALUE str) { + CHECK_NIL(str); + VALUE trans = GET_TRANSPORT(self); + write_string_direct(trans, str); + return Qnil; +} + +//--------------------------------------- +// interface reading methods +//--------------------------------------- + +VALUE rb_thrift_binary_proto_read_string(VALUE self); +VALUE rb_thrift_binary_proto_read_byte(VALUE self); +VALUE rb_thrift_binary_proto_read_i32(VALUE self); +VALUE rb_thrift_binary_proto_read_i16(VALUE self); + +static char read_byte_direct(VALUE self) { + VALUE buf = READ(self, 1); + return RSTRING_PTR(buf)[0]; +} + +static int16_t read_i16_direct(VALUE self) { + VALUE buf = READ(self, 2); + return (int16_t)(((uint8_t)(RSTRING_PTR(buf)[1])) | ((uint16_t)((RSTRING_PTR(buf)[0]) << 8))); +} + +static int32_t read_i32_direct(VALUE self) { + VALUE buf = READ(self, 4); + return ((uint8_t)(RSTRING_PTR(buf)[3])) | + (((uint8_t)(RSTRING_PTR(buf)[2])) << 8) | + (((uint8_t)(RSTRING_PTR(buf)[1])) << 16) | + (((uint8_t)(RSTRING_PTR(buf)[0])) << 24); +} + +static int64_t read_i64_direct(VALUE self) { + uint64_t hi = read_i32_direct(self); + uint32_t lo = read_i32_direct(self); + return (hi << 32) | lo; +} + +static VALUE get_protocol_exception(VALUE code, VALUE message) { + VALUE args[2]; + args[0] = code; + args[1] = message; + return rb_class_new_instance(2, (VALUE*)&args, protocol_exception_class); +} + +VALUE rb_thrift_binary_proto_read_message_end(VALUE self) { + return Qnil; +} + +VALUE rb_thift_binary_proto_read_struct_begin(VALUE self) { + return Qnil; +} + +VALUE rb_thift_binary_proto_read_struct_end(VALUE self) { + return Qnil; +} + +VALUE rb_thift_binary_proto_read_field_end(VALUE self) { + return Qnil; +} + +VALUE rb_thift_binary_proto_read_map_end(VALUE self) { + return Qnil; +} + +VALUE rb_thift_binary_proto_read_list_end(VALUE self) { + return Qnil; +} + +VALUE rb_thift_binary_proto_read_set_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_binary_proto_read_message_begin(VALUE self) { + VALUE strict_read = GET_STRICT_READ(self); + VALUE name, seqid; + int type; + + int version = read_i32_direct(self); + + if (version < 0) { + if ((version & VERSION_MASK) != VERSION_1) { + rb_exc_raise(get_protocol_exception(INT2FIX(BAD_VERSION), rb_str_new2("Missing version identifier"))); + } + type = version & TYPE_MASK; + name = rb_thrift_binary_proto_read_string(self); + seqid = rb_thrift_binary_proto_read_i32(self); + } else { + if (strict_read == Qtrue) { + rb_exc_raise(get_protocol_exception(INT2FIX(BAD_VERSION), rb_str_new2("No version identifier, old protocol client?"))); + } + name = READ(self, version); + type = read_byte_direct(self); + seqid = rb_thrift_binary_proto_read_i32(self); + } + + return rb_ary_new3(3, name, INT2FIX(type), seqid); +} + +VALUE rb_thrift_binary_proto_read_field_begin(VALUE self) { + int type = read_byte_direct(self); + if (type == TTYPE_STOP) { + return rb_ary_new3(3, Qnil, INT2FIX(type), INT2FIX(0)); + } else { + VALUE id = rb_thrift_binary_proto_read_i16(self); + return rb_ary_new3(3, Qnil, INT2FIX(type), id); + } +} + +VALUE rb_thrift_binary_proto_read_map_begin(VALUE self) { + VALUE ktype = rb_thrift_binary_proto_read_byte(self); + VALUE vtype = rb_thrift_binary_proto_read_byte(self); + VALUE size = rb_thrift_binary_proto_read_i32(self); + return rb_ary_new3(3, ktype, vtype, size); +} + +VALUE rb_thrift_binary_proto_read_list_begin(VALUE self) { + VALUE etype = rb_thrift_binary_proto_read_byte(self); + VALUE size = rb_thrift_binary_proto_read_i32(self); + return rb_ary_new3(2, etype, size); +} + +VALUE rb_thrift_binary_proto_read_set_begin(VALUE self) { + return rb_thrift_binary_proto_read_list_begin(self); +} + +VALUE rb_thrift_binary_proto_read_bool(VALUE self) { + char byte = read_byte_direct(self); + return byte != 0 ? Qtrue : Qfalse; +} + +VALUE rb_thrift_binary_proto_read_byte(VALUE self) { + return INT2FIX(read_byte_direct(self)); +} + +VALUE rb_thrift_binary_proto_read_i16(VALUE self) { + return INT2FIX(read_i16_direct(self)); +} + +VALUE rb_thrift_binary_proto_read_i32(VALUE self) { + return INT2NUM(read_i32_direct(self)); +} + +VALUE rb_thrift_binary_proto_read_i64(VALUE self) { + return LL2NUM(read_i64_direct(self)); +} + +VALUE rb_thrift_binary_proto_read_double(VALUE self) { + union { + double f; + int64_t t; + } transfer; + transfer.t = read_i64_direct(self); + return rb_float_new(transfer.f); +} + +VALUE rb_thrift_binary_proto_read_string(VALUE self) { + int size = read_i32_direct(self); + return READ(self, size); +} + +void Init_binary_protocol_accelerated() { + VALUE thrift_binary_protocol_class = rb_const_get(thrift_module, rb_intern("BinaryProtocol")); + + VERSION_1 = rb_num2ll(rb_const_get(thrift_binary_protocol_class, rb_intern("VERSION_1"))); + VERSION_MASK = rb_num2ll(rb_const_get(thrift_binary_protocol_class, rb_intern("VERSION_MASK"))); + TYPE_MASK = rb_num2ll(rb_const_get(thrift_binary_protocol_class, rb_intern("TYPE_MASK"))); + + VALUE bpa_class = rb_define_class_under(thrift_module, "BinaryProtocolAccelerated", thrift_binary_protocol_class); + + rb_define_method(bpa_class, "native?", rb_thrift_binary_proto_native_qmark, 0); + + rb_define_method(bpa_class, "write_message_begin", rb_thrift_binary_proto_write_message_begin, 3); + rb_define_method(bpa_class, "write_field_begin", rb_thrift_binary_proto_write_field_begin, 3); + rb_define_method(bpa_class, "write_field_stop", rb_thrift_binary_proto_write_field_stop, 0); + rb_define_method(bpa_class, "write_map_begin", rb_thrift_binary_proto_write_map_begin, 3); + rb_define_method(bpa_class, "write_list_begin", rb_thrift_binary_proto_write_list_begin, 2); + rb_define_method(bpa_class, "write_set_begin", rb_thrift_binary_proto_write_set_begin, 2); + rb_define_method(bpa_class, "write_byte", rb_thrift_binary_proto_write_byte, 1); + rb_define_method(bpa_class, "write_bool", rb_thrift_binary_proto_write_bool, 1); + rb_define_method(bpa_class, "write_i16", rb_thrift_binary_proto_write_i16, 1); + rb_define_method(bpa_class, "write_i32", rb_thrift_binary_proto_write_i32, 1); + rb_define_method(bpa_class, "write_i64", rb_thrift_binary_proto_write_i64, 1); + rb_define_method(bpa_class, "write_double", rb_thrift_binary_proto_write_double, 1); + rb_define_method(bpa_class, "write_string", rb_thrift_binary_proto_write_string, 1); + // unused methods + rb_define_method(bpa_class, "write_message_end", rb_thrift_binary_proto_write_message_end, 0); + rb_define_method(bpa_class, "write_struct_begin", rb_thrift_binary_proto_write_struct_begin, 1); + rb_define_method(bpa_class, "write_struct_end", rb_thrift_binary_proto_write_struct_end, 0); + rb_define_method(bpa_class, "write_field_end", rb_thrift_binary_proto_write_field_end, 0); + rb_define_method(bpa_class, "write_map_end", rb_thrift_binary_proto_write_map_end, 0); + rb_define_method(bpa_class, "write_list_end", rb_thrift_binary_proto_write_list_end, 0); + rb_define_method(bpa_class, "write_set_end", rb_thrift_binary_proto_write_set_end, 0); + + + + rb_define_method(bpa_class, "read_message_begin", rb_thrift_binary_proto_read_message_begin, 0); + rb_define_method(bpa_class, "read_field_begin", rb_thrift_binary_proto_read_field_begin, 0); + rb_define_method(bpa_class, "read_map_begin", rb_thrift_binary_proto_read_map_begin, 0); + rb_define_method(bpa_class, "read_list_begin", rb_thrift_binary_proto_read_list_begin, 0); + rb_define_method(bpa_class, "read_set_begin", rb_thrift_binary_proto_read_set_begin, 0); + rb_define_method(bpa_class, "read_byte", rb_thrift_binary_proto_read_byte, 0); + rb_define_method(bpa_class, "read_bool", rb_thrift_binary_proto_read_bool, 0); + rb_define_method(bpa_class, "read_i16", rb_thrift_binary_proto_read_i16, 0); + rb_define_method(bpa_class, "read_i32", rb_thrift_binary_proto_read_i32, 0); + rb_define_method(bpa_class, "read_i64", rb_thrift_binary_proto_read_i64, 0); + rb_define_method(bpa_class, "read_double", rb_thrift_binary_proto_read_double, 0); + rb_define_method(bpa_class, "read_string", rb_thrift_binary_proto_read_string, 0); + // unused methods + rb_define_method(bpa_class, "read_message_end", rb_thrift_binary_proto_read_message_end, 0); + rb_define_method(bpa_class, "read_struct_begin", rb_thift_binary_proto_read_struct_begin, 0); + rb_define_method(bpa_class, "read_struct_end", rb_thift_binary_proto_read_struct_end, 0); + rb_define_method(bpa_class, "read_field_end", rb_thift_binary_proto_read_field_end, 0); + rb_define_method(bpa_class, "read_map_end", rb_thift_binary_proto_read_map_end, 0); + rb_define_method(bpa_class, "read_list_end", rb_thift_binary_proto_read_list_end, 0); + rb_define_method(bpa_class, "read_set_end", rb_thift_binary_proto_read_set_end, 0); + + // set up native method table + native_proto_method_table *npmt; + npmt = ALLOC(native_proto_method_table); + + npmt->write_field_begin = rb_thrift_binary_proto_write_field_begin; + npmt->write_field_stop = rb_thrift_binary_proto_write_field_stop; + npmt->write_map_begin = rb_thrift_binary_proto_write_map_begin; + npmt->write_list_begin = rb_thrift_binary_proto_write_list_begin; + npmt->write_set_begin = rb_thrift_binary_proto_write_set_begin; + npmt->write_byte = rb_thrift_binary_proto_write_byte; + npmt->write_bool = rb_thrift_binary_proto_write_bool; + npmt->write_i16 = rb_thrift_binary_proto_write_i16; + npmt->write_i32 = rb_thrift_binary_proto_write_i32; + npmt->write_i64 = rb_thrift_binary_proto_write_i64; + npmt->write_double = rb_thrift_binary_proto_write_double; + npmt->write_string = rb_thrift_binary_proto_write_string; + npmt->write_message_end = rb_thrift_binary_proto_write_message_end; + npmt->write_struct_begin = rb_thrift_binary_proto_write_struct_begin; + npmt->write_struct_end = rb_thrift_binary_proto_write_struct_end; + npmt->write_field_end = rb_thrift_binary_proto_write_field_end; + npmt->write_map_end = rb_thrift_binary_proto_write_map_end; + npmt->write_list_end = rb_thrift_binary_proto_write_list_end; + npmt->write_set_end = rb_thrift_binary_proto_write_set_end; + + npmt->read_message_begin = rb_thrift_binary_proto_read_message_begin; + npmt->read_field_begin = rb_thrift_binary_proto_read_field_begin; + npmt->read_map_begin = rb_thrift_binary_proto_read_map_begin; + npmt->read_list_begin = rb_thrift_binary_proto_read_list_begin; + npmt->read_set_begin = rb_thrift_binary_proto_read_set_begin; + npmt->read_byte = rb_thrift_binary_proto_read_byte; + npmt->read_bool = rb_thrift_binary_proto_read_bool; + npmt->read_i16 = rb_thrift_binary_proto_read_i16; + npmt->read_i32 = rb_thrift_binary_proto_read_i32; + npmt->read_i64 = rb_thrift_binary_proto_read_i64; + npmt->read_double = rb_thrift_binary_proto_read_double; + npmt->read_string = rb_thrift_binary_proto_read_string; + npmt->read_message_end = rb_thrift_binary_proto_read_message_end; + npmt->read_struct_begin = rb_thift_binary_proto_read_struct_begin; + npmt->read_struct_end = rb_thift_binary_proto_read_struct_end; + npmt->read_field_end = rb_thift_binary_proto_read_field_end; + npmt->read_map_end = rb_thift_binary_proto_read_map_end; + npmt->read_list_end = rb_thift_binary_proto_read_list_end; + npmt->read_set_end = rb_thift_binary_proto_read_set_end; + + VALUE method_table_object = Data_Wrap_Struct(rb_cObject, 0, free, npmt); + rb_const_set(bpa_class, rb_intern("@native_method_table"), method_table_object); +} diff --git a/lib/rb/ext/binary_protocol_accelerated.h b/lib/rb/ext/binary_protocol_accelerated.h new file mode 100644 index 00000000..37baf414 --- /dev/null +++ b/lib/rb/ext/binary_protocol_accelerated.h @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +void Init_binary_protocol_accelerated(); diff --git a/lib/rb/ext/compact_protocol.c b/lib/rb/ext/compact_protocol.c new file mode 100644 index 00000000..7966d3e3 --- /dev/null +++ b/lib/rb/ext/compact_protocol.c @@ -0,0 +1,665 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include "macros.h" + +#define LAST_ID(obj) FIX2INT(rb_ary_pop(rb_ivar_get(obj, last_field_id))) +#define SET_LAST_ID(obj, val) rb_ary_push(rb_ivar_get(obj, last_field_id), val) + +VALUE rb_thrift_compact_proto_native_qmark(VALUE self) { + return Qtrue; +} + +static ID last_field_id; +static ID boolean_field_id; +static ID bool_value_id; + +static int VERSION; +static int VERSION_MASK; +static int TYPE_MASK; +static int TYPE_SHIFT_AMOUNT; +static int PROTOCOL_ID; + +static VALUE thrift_compact_protocol_class; + +static int CTYPE_BOOLEAN_TRUE = 0x01; +static int CTYPE_BOOLEAN_FALSE = 0x02; +static int CTYPE_BYTE = 0x03; +static int CTYPE_I16 = 0x04; +static int CTYPE_I32 = 0x05; +static int CTYPE_I64 = 0x06; +static int CTYPE_DOUBLE = 0x07; +static int CTYPE_BINARY = 0x08; +static int CTYPE_LIST = 0x09; +static int CTYPE_SET = 0x0A; +static int CTYPE_MAP = 0x0B; +static int CTYPE_STRUCT = 0x0C; + +VALUE rb_thrift_compact_proto_write_i16(VALUE self, VALUE i16); + +// TODO: implement this +static int get_compact_type(VALUE type_value) { + int type = FIX2INT(type_value); + if (type == TTYPE_BOOL) { + return CTYPE_BOOLEAN_TRUE; + } else if (type == TTYPE_BYTE) { + return CTYPE_BYTE; + } else if (type == TTYPE_I16) { + return CTYPE_I16; + } else if (type == TTYPE_I32) { + return CTYPE_I32; + } else if (type == TTYPE_I64) { + return CTYPE_I64; + } else if (type == TTYPE_DOUBLE) { + return CTYPE_DOUBLE; + } else if (type == TTYPE_STRING) { + return CTYPE_BINARY; + } else if (type == TTYPE_LIST) { + return CTYPE_LIST; + } else if (type == TTYPE_SET) { + return CTYPE_SET; + } else if (type == TTYPE_MAP) { + return CTYPE_MAP; + } else if (type == TTYPE_STRUCT) { + return CTYPE_STRUCT; + } else { + char str[50]; + sprintf(str, "don't know what type: %d", type); + rb_raise(rb_eStandardError, str); + return 0; + } +} + +static void write_byte_direct(VALUE transport, int8_t b) { + WRITE(transport, (char*)&b, 1); +} + +static void write_field_begin_internal(VALUE self, VALUE type, VALUE id_value, VALUE type_override) { + int id = FIX2INT(id_value); + int last_id = LAST_ID(self); + VALUE transport = GET_TRANSPORT(self); + + // if there's a type override, use that. + int8_t type_to_write = RTEST(type_override) ? FIX2INT(type_override) : get_compact_type(type); + // check if we can use delta encoding for the field id + int diff = id - last_id; + if (diff > 0 && diff <= 15) { + // write them together + write_byte_direct(transport, diff << 4 | (type_to_write & 0x0f)); + } else { + // write them separate + write_byte_direct(transport, type_to_write & 0x0f); + rb_thrift_compact_proto_write_i16(self, id_value); + } + + SET_LAST_ID(self, id_value); +} + +static int32_t int_to_zig_zag(int32_t n) { + return (n << 1) ^ (n >> 31); +} + +static uint64_t ll_to_zig_zag(int64_t n) { + return (n << 1) ^ (n >> 63); +} + +static void write_varint32(VALUE transport, uint32_t n) { + while (true) { + if ((n & ~0x7F) == 0) { + write_byte_direct(transport, n & 0x7f); + break; + } else { + write_byte_direct(transport, (n & 0x7F) | 0x80); + n = n >> 7; + } + } +} + +static void write_varint64(VALUE transport, uint64_t n) { + while (true) { + if ((n & ~0x7F) == 0) { + write_byte_direct(transport, n & 0x7f); + break; + } else { + write_byte_direct(transport, (n & 0x7F) | 0x80); + n = n >> 7; + } + } +} + +static void write_collection_begin(VALUE transport, VALUE elem_type, VALUE size_value) { + int size = FIX2INT(size_value); + if (size <= 14) { + write_byte_direct(transport, size << 4 | get_compact_type(elem_type)); + } else { + write_byte_direct(transport, 0xf0 | get_compact_type(elem_type)); + write_varint32(transport, size); + } +} + + +//-------------------------------- +// interface writing methods +//-------------------------------- + +VALUE rb_thrift_compact_proto_write_i32(VALUE self, VALUE i32); +VALUE rb_thrift_compact_proto_write_string(VALUE self, VALUE str); + +VALUE rb_thrift_compact_proto_write_message_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_struct_begin(VALUE self, VALUE name) { + rb_ary_push(rb_ivar_get(self, last_field_id), INT2FIX(0)); + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_struct_end(VALUE self) { + rb_ary_pop(rb_ivar_get(self, last_field_id)); + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_field_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_map_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_list_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_set_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_message_begin(VALUE self, VALUE name, VALUE type, VALUE seqid) { + VALUE transport = GET_TRANSPORT(self); + write_byte_direct(transport, PROTOCOL_ID); + write_byte_direct(transport, (VERSION & VERSION_MASK) | ((FIX2INT(type) << TYPE_SHIFT_AMOUNT) & TYPE_MASK)); + write_varint32(transport, FIX2INT(seqid)); + rb_thrift_compact_proto_write_string(self, name); + + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_field_begin(VALUE self, VALUE name, VALUE type, VALUE id) { + if (FIX2INT(type) == TTYPE_BOOL) { + // we want to possibly include the value, so we'll wait. + rb_ivar_set(self, boolean_field_id, rb_ary_new3(2, type, id)); + } else { + write_field_begin_internal(self, type, id, Qnil); + } + + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_field_stop(VALUE self) { + write_byte_direct(GET_TRANSPORT(self), TTYPE_STOP); + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_map_begin(VALUE self, VALUE ktype, VALUE vtype, VALUE size_value) { + int size = FIX2INT(size_value); + VALUE transport = GET_TRANSPORT(self); + if (size == 0) { + write_byte_direct(transport, 0); + } else { + write_varint32(transport, size); + write_byte_direct(transport, get_compact_type(ktype) << 4 | get_compact_type(vtype)); + } + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_list_begin(VALUE self, VALUE etype, VALUE size) { + write_collection_begin(GET_TRANSPORT(self), etype, size); + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_set_begin(VALUE self, VALUE etype, VALUE size) { + write_collection_begin(GET_TRANSPORT(self), etype, size); + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_bool(VALUE self, VALUE b) { + int8_t type = b == Qtrue ? CTYPE_BOOLEAN_TRUE : CTYPE_BOOLEAN_FALSE; + VALUE boolean_field = rb_ivar_get(self, boolean_field_id); + if (NIL_P(boolean_field)) { + // we're not part of a field, so just write the value. + write_byte_direct(GET_TRANSPORT(self), type); + } else { + // we haven't written the field header yet + write_field_begin_internal(self, rb_ary_entry(boolean_field, 0), rb_ary_entry(boolean_field, 1), INT2FIX(type)); + rb_ivar_set(self, boolean_field_id, Qnil); + } + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_byte(VALUE self, VALUE byte) { + CHECK_NIL(byte); + write_byte_direct(GET_TRANSPORT(self), FIX2INT(byte)); + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_i16(VALUE self, VALUE i16) { + rb_thrift_compact_proto_write_i32(self, i16); + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_i32(VALUE self, VALUE i32) { + CHECK_NIL(i32); + write_varint32(GET_TRANSPORT(self), int_to_zig_zag(NUM2INT(i32))); + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_i64(VALUE self, VALUE i64) { + CHECK_NIL(i64); + write_varint64(GET_TRANSPORT(self), ll_to_zig_zag(NUM2LL(i64))); + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_double(VALUE self, VALUE dub) { + CHECK_NIL(dub); + // Unfortunately, bitwise_cast doesn't work in C. Bad C! + union { + double f; + int64_t l; + } transfer; + transfer.f = RFLOAT_VALUE(rb_Float(dub)); + char buf[8]; + buf[0] = transfer.l & 0xff; + buf[1] = (transfer.l >> 8) & 0xff; + buf[2] = (transfer.l >> 16) & 0xff; + buf[3] = (transfer.l >> 24) & 0xff; + buf[4] = (transfer.l >> 32) & 0xff; + buf[5] = (transfer.l >> 40) & 0xff; + buf[6] = (transfer.l >> 48) & 0xff; + buf[7] = (transfer.l >> 56) & 0xff; + WRITE(GET_TRANSPORT(self), buf, 8); + return Qnil; +} + +VALUE rb_thrift_compact_proto_write_string(VALUE self, VALUE str) { + VALUE transport = GET_TRANSPORT(self); + write_varint32(transport, RSTRING_LEN(str)); + WRITE(transport, RSTRING_PTR(str), RSTRING_LEN(str)); + return Qnil; +} + +//--------------------------------------- +// interface reading methods +//--------------------------------------- + +#define is_bool_type(ctype) (((ctype) & 0x0F) == CTYPE_BOOLEAN_TRUE || ((ctype) & 0x0F) == CTYPE_BOOLEAN_FALSE) + +VALUE rb_thrift_compact_proto_read_string(VALUE self); +VALUE rb_thrift_compact_proto_read_byte(VALUE self); +VALUE rb_thrift_compact_proto_read_i32(VALUE self); +VALUE rb_thrift_compact_proto_read_i16(VALUE self); + +static int8_t get_ttype(int8_t ctype) { + if (ctype == TTYPE_STOP) { + return TTYPE_STOP; + } else if (ctype == CTYPE_BOOLEAN_TRUE || ctype == CTYPE_BOOLEAN_FALSE) { + return TTYPE_BOOL; + } else if (ctype == CTYPE_BYTE) { + return TTYPE_BYTE; + } else if (ctype == CTYPE_I16) { + return TTYPE_I16; + } else if (ctype == CTYPE_I32) { + return TTYPE_I32; + } else if (ctype == CTYPE_I64) { + return TTYPE_I64; + } else if (ctype == CTYPE_DOUBLE) { + return TTYPE_DOUBLE; + } else if (ctype == CTYPE_BINARY) { + return TTYPE_STRING; + } else if (ctype == CTYPE_LIST) { + return TTYPE_LIST; + } else if (ctype == CTYPE_SET) { + return TTYPE_SET; + } else if (ctype == CTYPE_MAP) { + return TTYPE_MAP; + } else if (ctype == CTYPE_STRUCT) { + return TTYPE_STRUCT; + } else { + char str[50]; + sprintf(str, "don't know what type: %d", ctype); + rb_raise(rb_eStandardError, str); + return 0; + } +} + +static char read_byte_direct(VALUE self) { + VALUE buf = READ(self, 1); + return RSTRING_PTR(buf)[0]; +} + +static int64_t zig_zag_to_ll(int64_t n) { + return (((uint64_t)n) >> 1) ^ -(n & 1); +} + +static int32_t zig_zag_to_int(int32_t n) { + return (((uint32_t)n) >> 1) ^ -(n & 1); +} + +static int64_t read_varint64(VALUE self) { + int shift = 0; + int64_t result = 0; + while (true) { + int8_t b = read_byte_direct(self); + result = result | ((uint64_t)(b & 0x7f) << shift); + if ((b & 0x80) != 0x80) { + break; + } + shift += 7; + } + return result; +} + +static int16_t read_i16(VALUE self) { + return zig_zag_to_int((int32_t)read_varint64(self)); +} + +static VALUE get_protocol_exception(VALUE code, VALUE message) { + VALUE args[2]; + args[0] = code; + args[1] = message; + return rb_class_new_instance(2, (VALUE*)&args, protocol_exception_class); +} + +VALUE rb_thrift_compact_proto_read_message_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_compact_proto_read_struct_begin(VALUE self) { + rb_ary_push(rb_ivar_get(self, last_field_id), INT2FIX(0)); + return Qnil; +} + +VALUE rb_thrift_compact_proto_read_struct_end(VALUE self) { + rb_ary_pop(rb_ivar_get(self, last_field_id)); + return Qnil; +} + +VALUE rb_thrift_compact_proto_read_field_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_compact_proto_read_map_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_compact_proto_read_list_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_compact_proto_read_set_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_compact_proto_read_message_begin(VALUE self) { + int8_t protocol_id = read_byte_direct(self); + if (protocol_id != PROTOCOL_ID) { + char buf[100]; + int len = sprintf(buf, "Expected protocol id %d but got %d", PROTOCOL_ID, protocol_id); + buf[len] = 0; + rb_exc_raise(get_protocol_exception(INT2FIX(-1), rb_str_new2(buf))); + } + + int8_t version_and_type = read_byte_direct(self); + int8_t version = version_and_type & VERSION_MASK; + if (version != VERSION) { + char buf[100]; + int len = sprintf(buf, "Expected version id %d but got %d", version, VERSION); + buf[len] = 0; + rb_exc_raise(get_protocol_exception(INT2FIX(-1), rb_str_new2(buf))); + } + + int8_t type = (version_and_type >> TYPE_SHIFT_AMOUNT) & 0x03; + int32_t seqid = read_varint64(self); + VALUE messageName = rb_thrift_compact_proto_read_string(self); + return rb_ary_new3(3, messageName, INT2FIX(type), INT2NUM(seqid)); +} + +VALUE rb_thrift_compact_proto_read_field_begin(VALUE self) { + int8_t type = read_byte_direct(self); + // if it's a stop, then we can return immediately, as the struct is over. + if ((type & 0x0f) == TTYPE_STOP) { + return rb_ary_new3(3, Qnil, INT2FIX(0), INT2FIX(0)); + } else { + int field_id = 0; + + // mask off the 4 MSB of the type header. it could contain a field id delta. + uint8_t modifier = ((type & 0xf0) >> 4); + + if (modifier == 0) { + // not a delta. look ahead for the zigzag varint field id. + field_id = read_i16(self); + } else { + // has a delta. add the delta to the last read field id. + field_id = LAST_ID(self) + modifier; + } + + // if this happens to be a boolean field, the value is encoded in the type + if (is_bool_type(type)) { + // save the boolean value in a special instance variable. + rb_ivar_set(self, bool_value_id, (type & 0x0f) == CTYPE_BOOLEAN_TRUE ? Qtrue : Qfalse); + } + + // push the new field onto the field stack so we can keep the deltas going. + SET_LAST_ID(self, INT2FIX(field_id)); + return rb_ary_new3(3, Qnil, INT2FIX(get_ttype(type & 0x0f)), INT2FIX(field_id)); + } +} + +VALUE rb_thrift_compact_proto_read_map_begin(VALUE self) { + int32_t size = read_varint64(self); + uint8_t key_and_value_type = size == 0 ? 0 : read_byte_direct(self); + return rb_ary_new3(3, INT2FIX(get_ttype(key_and_value_type >> 4)), INT2FIX(get_ttype(key_and_value_type & 0xf)), INT2FIX(size)); +} + +VALUE rb_thrift_compact_proto_read_list_begin(VALUE self) { + uint8_t size_and_type = read_byte_direct(self); + int32_t size = (size_and_type >> 4) & 0x0f; + if (size == 15) { + size = read_varint64(self); + } + uint8_t type = get_ttype(size_and_type & 0x0f); + return rb_ary_new3(2, INT2FIX(type), INT2FIX(size)); +} + +VALUE rb_thrift_compact_proto_read_set_begin(VALUE self) { + return rb_thrift_compact_proto_read_list_begin(self); +} + +VALUE rb_thrift_compact_proto_read_bool(VALUE self) { + VALUE bool_value = rb_ivar_get(self, bool_value_id); + if (NIL_P(bool_value)) { + return read_byte_direct(self) == CTYPE_BOOLEAN_TRUE ? Qtrue : Qfalse; + } else { + rb_ivar_set(self, bool_value_id, Qnil); + return bool_value; + } +} + +VALUE rb_thrift_compact_proto_read_byte(VALUE self) { + return INT2FIX(read_byte_direct(self)); +} + +VALUE rb_thrift_compact_proto_read_i16(VALUE self) { + return INT2FIX(read_i16(self)); +} + +VALUE rb_thrift_compact_proto_read_i32(VALUE self) { + return INT2NUM(zig_zag_to_int(read_varint64(self))); +} + +VALUE rb_thrift_compact_proto_read_i64(VALUE self) { + return LL2NUM(zig_zag_to_ll(read_varint64(self))); +} + +VALUE rb_thrift_compact_proto_read_double(VALUE self) { + union { + double f; + int64_t l; + } transfer; + VALUE bytes = READ(self, 8); + uint32_t lo = ((uint8_t)(RSTRING_PTR(bytes)[0])) + | (((uint8_t)(RSTRING_PTR(bytes)[1])) << 8) + | (((uint8_t)(RSTRING_PTR(bytes)[2])) << 16) + | (((uint8_t)(RSTRING_PTR(bytes)[3])) << 24); + uint64_t hi = (((uint8_t)(RSTRING_PTR(bytes)[4]))) + | (((uint8_t)(RSTRING_PTR(bytes)[5])) << 8) + | (((uint8_t)(RSTRING_PTR(bytes)[6])) << 16) + | (((uint8_t)(RSTRING_PTR(bytes)[7])) << 24); + transfer.l = (hi << 32) | lo; + + return rb_float_new(transfer.f); +} + +VALUE rb_thrift_compact_proto_read_string(VALUE self) { + int64_t size = read_varint64(self); + return READ(self, size); +} + +static void Init_constants() { + thrift_compact_protocol_class = rb_const_get(thrift_module, rb_intern("CompactProtocol")); + + VERSION = rb_num2ll(rb_const_get(thrift_compact_protocol_class, rb_intern("VERSION"))); + VERSION_MASK = rb_num2ll(rb_const_get(thrift_compact_protocol_class, rb_intern("VERSION_MASK"))); + TYPE_MASK = rb_num2ll(rb_const_get(thrift_compact_protocol_class, rb_intern("TYPE_MASK"))); + TYPE_SHIFT_AMOUNT = FIX2INT(rb_const_get(thrift_compact_protocol_class, rb_intern("TYPE_SHIFT_AMOUNT"))); + PROTOCOL_ID = FIX2INT(rb_const_get(thrift_compact_protocol_class, rb_intern("PROTOCOL_ID"))); + + last_field_id = rb_intern("@last_field"); + boolean_field_id = rb_intern("@boolean_field"); + bool_value_id = rb_intern("@bool_value"); +} + +static void Init_rb_methods() { + rb_define_method(thrift_compact_protocol_class, "native?", rb_thrift_compact_proto_native_qmark, 0); + + rb_define_method(thrift_compact_protocol_class, "write_message_begin", rb_thrift_compact_proto_write_message_begin, 3); + rb_define_method(thrift_compact_protocol_class, "write_field_begin", rb_thrift_compact_proto_write_field_begin, 3); + rb_define_method(thrift_compact_protocol_class, "write_field_stop", rb_thrift_compact_proto_write_field_stop, 0); + rb_define_method(thrift_compact_protocol_class, "write_map_begin", rb_thrift_compact_proto_write_map_begin, 3); + rb_define_method(thrift_compact_protocol_class, "write_list_begin", rb_thrift_compact_proto_write_list_begin, 2); + rb_define_method(thrift_compact_protocol_class, "write_set_begin", rb_thrift_compact_proto_write_set_begin, 2); + rb_define_method(thrift_compact_protocol_class, "write_byte", rb_thrift_compact_proto_write_byte, 1); + rb_define_method(thrift_compact_protocol_class, "write_bool", rb_thrift_compact_proto_write_bool, 1); + rb_define_method(thrift_compact_protocol_class, "write_i16", rb_thrift_compact_proto_write_i16, 1); + rb_define_method(thrift_compact_protocol_class, "write_i32", rb_thrift_compact_proto_write_i32, 1); + rb_define_method(thrift_compact_protocol_class, "write_i64", rb_thrift_compact_proto_write_i64, 1); + rb_define_method(thrift_compact_protocol_class, "write_double", rb_thrift_compact_proto_write_double, 1); + rb_define_method(thrift_compact_protocol_class, "write_string", rb_thrift_compact_proto_write_string, 1); + + rb_define_method(thrift_compact_protocol_class, "write_message_end", rb_thrift_compact_proto_write_message_end, 0); + rb_define_method(thrift_compact_protocol_class, "write_struct_begin", rb_thrift_compact_proto_write_struct_begin, 1); + rb_define_method(thrift_compact_protocol_class, "write_struct_end", rb_thrift_compact_proto_write_struct_end, 0); + rb_define_method(thrift_compact_protocol_class, "write_field_end", rb_thrift_compact_proto_write_field_end, 0); + rb_define_method(thrift_compact_protocol_class, "write_map_end", rb_thrift_compact_proto_write_map_end, 0); + rb_define_method(thrift_compact_protocol_class, "write_list_end", rb_thrift_compact_proto_write_list_end, 0); + rb_define_method(thrift_compact_protocol_class, "write_set_end", rb_thrift_compact_proto_write_set_end, 0); + + + rb_define_method(thrift_compact_protocol_class, "read_message_begin", rb_thrift_compact_proto_read_message_begin, 0); + rb_define_method(thrift_compact_protocol_class, "read_field_begin", rb_thrift_compact_proto_read_field_begin, 0); + rb_define_method(thrift_compact_protocol_class, "read_map_begin", rb_thrift_compact_proto_read_map_begin, 0); + rb_define_method(thrift_compact_protocol_class, "read_list_begin", rb_thrift_compact_proto_read_list_begin, 0); + rb_define_method(thrift_compact_protocol_class, "read_set_begin", rb_thrift_compact_proto_read_set_begin, 0); + rb_define_method(thrift_compact_protocol_class, "read_byte", rb_thrift_compact_proto_read_byte, 0); + rb_define_method(thrift_compact_protocol_class, "read_bool", rb_thrift_compact_proto_read_bool, 0); + rb_define_method(thrift_compact_protocol_class, "read_i16", rb_thrift_compact_proto_read_i16, 0); + rb_define_method(thrift_compact_protocol_class, "read_i32", rb_thrift_compact_proto_read_i32, 0); + rb_define_method(thrift_compact_protocol_class, "read_i64", rb_thrift_compact_proto_read_i64, 0); + rb_define_method(thrift_compact_protocol_class, "read_double", rb_thrift_compact_proto_read_double, 0); + rb_define_method(thrift_compact_protocol_class, "read_string", rb_thrift_compact_proto_read_string, 0); + + rb_define_method(thrift_compact_protocol_class, "read_message_end", rb_thrift_compact_proto_read_message_end, 0); + rb_define_method(thrift_compact_protocol_class, "read_struct_begin", rb_thrift_compact_proto_read_struct_begin, 0); + rb_define_method(thrift_compact_protocol_class, "read_struct_end", rb_thrift_compact_proto_read_struct_end, 0); + rb_define_method(thrift_compact_protocol_class, "read_field_end", rb_thrift_compact_proto_read_field_end, 0); + rb_define_method(thrift_compact_protocol_class, "read_map_end", rb_thrift_compact_proto_read_map_end, 0); + rb_define_method(thrift_compact_protocol_class, "read_list_end", rb_thrift_compact_proto_read_list_end, 0); + rb_define_method(thrift_compact_protocol_class, "read_set_end", rb_thrift_compact_proto_read_set_end, 0); +} + +static void Init_npmt() { + native_proto_method_table *npmt; + npmt = ALLOC(native_proto_method_table); + + npmt->write_field_begin = rb_thrift_compact_proto_write_field_begin; + npmt->write_field_stop = rb_thrift_compact_proto_write_field_stop; + npmt->write_map_begin = rb_thrift_compact_proto_write_map_begin; + npmt->write_list_begin = rb_thrift_compact_proto_write_list_begin; + npmt->write_set_begin = rb_thrift_compact_proto_write_set_begin; + npmt->write_byte = rb_thrift_compact_proto_write_byte; + npmt->write_bool = rb_thrift_compact_proto_write_bool; + npmt->write_i16 = rb_thrift_compact_proto_write_i16; + npmt->write_i32 = rb_thrift_compact_proto_write_i32; + npmt->write_i64 = rb_thrift_compact_proto_write_i64; + npmt->write_double = rb_thrift_compact_proto_write_double; + npmt->write_string = rb_thrift_compact_proto_write_string; + npmt->write_message_end = rb_thrift_compact_proto_write_message_end; + npmt->write_struct_begin = rb_thrift_compact_proto_write_struct_begin; + npmt->write_struct_end = rb_thrift_compact_proto_write_struct_end; + npmt->write_field_end = rb_thrift_compact_proto_write_field_end; + npmt->write_map_end = rb_thrift_compact_proto_write_map_end; + npmt->write_list_end = rb_thrift_compact_proto_write_list_end; + npmt->write_set_end = rb_thrift_compact_proto_write_set_end; + + npmt->read_message_begin = rb_thrift_compact_proto_read_message_begin; + npmt->read_field_begin = rb_thrift_compact_proto_read_field_begin; + npmt->read_map_begin = rb_thrift_compact_proto_read_map_begin; + npmt->read_list_begin = rb_thrift_compact_proto_read_list_begin; + npmt->read_set_begin = rb_thrift_compact_proto_read_set_begin; + npmt->read_byte = rb_thrift_compact_proto_read_byte; + npmt->read_bool = rb_thrift_compact_proto_read_bool; + npmt->read_i16 = rb_thrift_compact_proto_read_i16; + npmt->read_i32 = rb_thrift_compact_proto_read_i32; + npmt->read_i64 = rb_thrift_compact_proto_read_i64; + npmt->read_double = rb_thrift_compact_proto_read_double; + npmt->read_string = rb_thrift_compact_proto_read_string; + npmt->read_message_end = rb_thrift_compact_proto_read_message_end; + npmt->read_struct_begin = rb_thrift_compact_proto_read_struct_begin; + npmt->read_struct_end = rb_thrift_compact_proto_read_struct_end; + npmt->read_field_end = rb_thrift_compact_proto_read_field_end; + npmt->read_map_end = rb_thrift_compact_proto_read_map_end; + npmt->read_list_end = rb_thrift_compact_proto_read_list_end; + npmt->read_set_end = rb_thrift_compact_proto_read_set_end; + + VALUE method_table_object = Data_Wrap_Struct(rb_cObject, 0, free, npmt); + rb_const_set(thrift_compact_protocol_class, rb_intern("@native_method_table"), method_table_object); +} + + + +void Init_compact_protocol() { + Init_constants(); + Init_rb_methods(); + Init_npmt(); +} diff --git a/lib/rb/ext/compact_protocol.h b/lib/rb/ext/compact_protocol.h new file mode 100644 index 00000000..163915e9 --- /dev/null +++ b/lib/rb/ext/compact_protocol.h @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +void Init_compact_protocol(); diff --git a/lib/rb/ext/constants.h b/lib/rb/ext/constants.h new file mode 100644 index 00000000..57df544b --- /dev/null +++ b/lib/rb/ext/constants.h @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern int TTYPE_STOP; +extern int TTYPE_BOOL; +extern int TTYPE_BYTE; +extern int TTYPE_I16; +extern int TTYPE_I32; +extern int TTYPE_I64; +extern int TTYPE_DOUBLE; +extern int TTYPE_STRING; +extern int TTYPE_MAP; +extern int TTYPE_SET; +extern int TTYPE_LIST; +extern int TTYPE_STRUCT; + +extern ID validate_method_id; +extern ID write_struct_begin_method_id; +extern ID write_struct_end_method_id; +extern ID write_field_begin_method_id; +extern ID write_field_end_method_id; +extern ID write_boolean_method_id; +extern ID write_byte_method_id; +extern ID write_i16_method_id; +extern ID write_i32_method_id; +extern ID write_i64_method_id; +extern ID write_double_method_id; +extern ID write_string_method_id; +extern ID write_map_begin_method_id; +extern ID write_map_end_method_id; +extern ID write_list_begin_method_id; +extern ID write_list_end_method_id; +extern ID write_set_begin_method_id; +extern ID write_set_end_method_id; +extern ID size_method_id; +extern ID read_bool_method_id; +extern ID read_byte_method_id; +extern ID read_i16_method_id; +extern ID read_i32_method_id; +extern ID read_i64_method_id; +extern ID read_string_method_id; +extern ID read_double_method_id; +extern ID read_map_begin_method_id; +extern ID read_map_end_method_id; +extern ID read_list_begin_method_id; +extern ID read_list_end_method_id; +extern ID read_set_begin_method_id; +extern ID read_set_end_method_id; +extern ID read_struct_begin_method_id; +extern ID read_struct_end_method_id; +extern ID read_field_begin_method_id; +extern ID read_field_end_method_id; +extern ID keys_method_id; +extern ID entries_method_id; +extern ID name_method_id; +extern ID sort_method_id; +extern ID write_field_stop_method_id; +extern ID skip_method_id; +extern ID write_method_id; +extern ID read_all_method_id; +extern ID native_qmark_method_id; + +extern ID fields_const_id; +extern ID transport_ivar_id; +extern ID strict_read_ivar_id; +extern ID strict_write_ivar_id; + +extern VALUE type_sym; +extern VALUE name_sym; +extern VALUE key_sym; +extern VALUE value_sym; +extern VALUE element_sym; +extern VALUE class_sym; + +extern VALUE rb_cSet; +extern VALUE thrift_module; +extern VALUE thrift_types_module; +extern VALUE class_thrift_protocol; +extern VALUE protocol_exception_class; diff --git a/lib/rb/ext/extconf.rb b/lib/rb/ext/extconf.rb new file mode 100644 index 00000000..07558b8a --- /dev/null +++ b/lib/rb/ext/extconf.rb @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'mkmf' + +$CFLAGS = "-g -O2 -Wall -Werror" + +have_func("strlcpy", "string.h") + +create_makefile 'thrift_native' diff --git a/lib/rb/ext/macros.h b/lib/rb/ext/macros.h new file mode 100644 index 00000000..265f6930 --- /dev/null +++ b/lib/rb/ext/macros.h @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#define GET_TRANSPORT(obj) rb_ivar_get(obj, transport_ivar_id) +#define GET_STRICT_READ(obj) rb_ivar_get(obj, strict_read_ivar_id) +#define GET_STRICT_WRITE(obj) rb_ivar_get(obj, strict_write_ivar_id) +#define WRITE(obj, data, length) rb_funcall(obj, write_method_id, 1, rb_str_new(data, length)) +#define CHECK_NIL(obj) if (NIL_P(obj)) { rb_raise(rb_eStandardError, "nil argument not allowed!");} +#define READ(obj, length) rb_funcall(GET_TRANSPORT(obj), read_all_method_id, 1, INT2FIX(length)) + +#ifndef RFLOAT_VALUE +# define RFLOAT_VALUE(v) RFLOAT(rb_Float(v))->value +#endif + +#ifndef RSTRING_LEN +# define RSTRING_LEN(v) RSTRING(rb_String(v))->len +#endif + +#ifndef RSTRING_PTR +# define RSTRING_PTR(v) RSTRING(rb_String(v))->ptr +#endif + +#ifndef RARRAY_LEN +# define RARRAY_LEN(v) RARRAY(rb_Array(v))->len +#endif diff --git a/lib/rb/ext/memory_buffer.c b/lib/rb/ext/memory_buffer.c new file mode 100644 index 00000000..624012d4 --- /dev/null +++ b/lib/rb/ext/memory_buffer.c @@ -0,0 +1,72 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include "macros.h" + +ID buf_ivar_id; +ID index_ivar_id; + +ID slice_method_id; + +int GARBAGE_BUFFER_SIZE; + +#define GET_BUF(self) rb_ivar_get(self, buf_ivar_id) + +VALUE rb_thrift_memory_buffer_write(VALUE self, VALUE str) { + VALUE buf = GET_BUF(self); + rb_str_buf_cat(buf, RSTRING_PTR(str), RSTRING_LEN(str)); + return Qnil; +} + +VALUE rb_thrift_memory_buffer_read(VALUE self, VALUE length_value) { + int length = FIX2INT(length_value); + + VALUE index_value = rb_ivar_get(self, index_ivar_id); + int index = FIX2INT(index_value); + + VALUE buf = GET_BUF(self); + VALUE data = rb_funcall(buf, slice_method_id, 2, index_value, length_value); + + index += length; + if (index > RSTRING_LEN(buf)) { + index = RSTRING_LEN(buf); + } + if (index >= GARBAGE_BUFFER_SIZE) { + rb_ivar_set(self, buf_ivar_id, rb_funcall(buf, slice_method_id, 2, INT2FIX(index), INT2FIX(RSTRING_LEN(buf) - 1))); + index = 0; + } + + rb_ivar_set(self, index_ivar_id, INT2FIX(index)); + return data; +} + +void Init_memory_buffer() { + VALUE thrift_memory_buffer_class = rb_const_get(thrift_module, rb_intern("MemoryBufferTransport")); + rb_define_method(thrift_memory_buffer_class, "write", rb_thrift_memory_buffer_write, 1); + rb_define_method(thrift_memory_buffer_class, "read", rb_thrift_memory_buffer_read, 1); + + buf_ivar_id = rb_intern("@buf"); + index_ivar_id = rb_intern("@index"); + + slice_method_id = rb_intern("slice"); + + GARBAGE_BUFFER_SIZE = FIX2INT(rb_const_get(thrift_memory_buffer_class, rb_intern("GARBAGE_BUFFER_SIZE"))); +} diff --git a/lib/rb/ext/memory_buffer.h b/lib/rb/ext/memory_buffer.h new file mode 100644 index 00000000..b277fa6f --- /dev/null +++ b/lib/rb/ext/memory_buffer.h @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +void Init_memory_buffer(); diff --git a/lib/rb/ext/protocol.c b/lib/rb/ext/protocol.c new file mode 100644 index 00000000..c1876541 --- /dev/null +++ b/lib/rb/ext/protocol.c @@ -0,0 +1,185 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +static VALUE skip(VALUE self, int ttype) { + if (ttype == TTYPE_STOP) { + return Qnil; + } else if (ttype == TTYPE_BOOL) { + rb_funcall(self, read_bool_method_id, 0); + } else if (ttype == TTYPE_BYTE) { + rb_funcall(self, read_byte_method_id, 0); + } else if (ttype == TTYPE_I16) { + rb_funcall(self, read_i16_method_id, 0); + } else if (ttype == TTYPE_I32) { + rb_funcall(self, read_i32_method_id, 0); + } else if (ttype == TTYPE_I64) { + rb_funcall(self, read_i64_method_id, 0); + } else if (ttype == TTYPE_DOUBLE) { + rb_funcall(self, read_double_method_id, 0); + } else if (ttype == TTYPE_STRING) { + rb_funcall(self, read_string_method_id, 0); + } else if (ttype == TTYPE_STRUCT) { + rb_funcall(self, read_struct_begin_method_id, 0); + while (true) { + VALUE field_header = rb_funcall(self, read_field_begin_method_id, 0); + if (NIL_P(field_header) || FIX2INT(rb_ary_entry(field_header, 1)) == TTYPE_STOP ) { + break; + } + skip(self, FIX2INT(rb_ary_entry(field_header, 1))); + rb_funcall(self, read_field_end_method_id, 0); + } + rb_funcall(self, read_struct_end_method_id, 0); + } else if (ttype == TTYPE_MAP) { + int i; + VALUE map_header = rb_funcall(self, read_map_begin_method_id, 0); + int ktype = FIX2INT(rb_ary_entry(map_header, 0)); + int vtype = FIX2INT(rb_ary_entry(map_header, 1)); + int size = FIX2INT(rb_ary_entry(map_header, 2)); + + for (i = 0; i < size; i++) { + skip(self, ktype); + skip(self, vtype); + } + rb_funcall(self, read_map_end_method_id, 0); + } else if (ttype == TTYPE_LIST || ttype == TTYPE_SET) { + int i; + VALUE collection_header = rb_funcall(self, ttype == TTYPE_LIST ? read_list_begin_method_id : read_set_begin_method_id, 0); + int etype = FIX2INT(rb_ary_entry(collection_header, 0)); + int size = FIX2INT(rb_ary_entry(collection_header, 1)); + for (i = 0; i < size; i++) { + skip(self, etype); + } + rb_funcall(self, ttype == TTYPE_LIST ? read_list_end_method_id : read_set_end_method_id, 0); + } else { + rb_raise(rb_eNotImpError, "don't know how to skip type %d", ttype); + } + + return Qnil; +} + +VALUE rb_thrift_protocol_native_qmark(VALUE self) { + return Qfalse; +} + +VALUE rb_thrift_protocol_skip(VALUE protocol, VALUE ttype) { + return skip(protocol, FIX2INT(ttype)); +} + +VALUE rb_thrift_write_message_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_write_struct_begin(VALUE self, VALUE name) { + return Qnil; +} + +VALUE rb_thrift_write_struct_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_write_field_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_write_map_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_write_list_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_write_set_end(VALUE self) { + return Qnil; +} + +VALUE rb_thrift_read_message_end(VALUE self) { + return Qnil; +} + +VALUE rb_thift_read_struct_begin(VALUE self) { + return Qnil; +} + +VALUE rb_thift_read_struct_end(VALUE self) { + return Qnil; +} + +VALUE rb_thift_read_field_end(VALUE self) { + return Qnil; +} + +VALUE rb_thift_read_map_end(VALUE self) { + return Qnil; +} + +VALUE rb_thift_read_list_end(VALUE self) { + return Qnil; +} + +VALUE rb_thift_read_set_end(VALUE self) { + return Qnil; +} + +void Init_protocol() { + VALUE c_protocol = rb_const_get(thrift_module, rb_intern("BaseProtocol")); + + rb_define_method(c_protocol, "skip", rb_thrift_protocol_skip, 1); + rb_define_method(c_protocol, "write_message_end", rb_thrift_write_message_end, 0); + rb_define_method(c_protocol, "write_struct_begin", rb_thrift_write_struct_begin, 1); + rb_define_method(c_protocol, "write_struct_end", rb_thrift_write_struct_end, 0); + rb_define_method(c_protocol, "write_field_end", rb_thrift_write_field_end, 0); + rb_define_method(c_protocol, "write_map_end", rb_thrift_write_map_end, 0); + rb_define_method(c_protocol, "write_list_end", rb_thrift_write_list_end, 0); + rb_define_method(c_protocol, "write_set_end", rb_thrift_write_set_end, 0); + rb_define_method(c_protocol, "read_message_end", rb_thrift_read_message_end, 0); + rb_define_method(c_protocol, "read_struct_begin", rb_thift_read_struct_begin, 0); + rb_define_method(c_protocol, "read_struct_end", rb_thift_read_struct_end, 0); + rb_define_method(c_protocol, "read_field_end", rb_thift_read_field_end, 0); + rb_define_method(c_protocol, "read_map_end", rb_thift_read_map_end, 0); + rb_define_method(c_protocol, "read_list_end", rb_thift_read_list_end, 0); + rb_define_method(c_protocol, "read_set_end", rb_thift_read_set_end, 0); + rb_define_method(c_protocol, "native?", rb_thrift_protocol_native_qmark, 0); + + // native_proto_method_table *npmt; + // npmt = ALLOC(native_proto_method_table); + // npmt->write_message_end = rb_thrift_write_message_end; + // npmt->write_struct_begin = rb_thrift_write_struct_begin; + // npmt->write_struct_end = rb_thrift_write_struct_end; + // npmt->write_field_end = rb_thrift_write_field_end; + // npmt->write_map_end = rb_thrift_write_map_end; + // npmt->write_list_end = rb_thrift_write_list_end; + // npmt->write_set_end = rb_thrift_write_set_end; + // npmt->read_message_end = rb_thrift_read_message_end; + // npmt->read_struct_begin = rb_thift_read_struct_begin; + // npmt->read_struct_end = rb_thift_read_struct_end; + // npmt->read_field_end = rb_thift_read_field_end; + // npmt->read_map_end = rb_thift_read_map_end; + // npmt->read_list_end = rb_thift_read_list_end; + // npmt->read_set_end = rb_thift_read_set_end; + // + // VALUE method_table_object = Data_Wrap_Struct(rb_cObject, 0, free, npmt); + // rb_const_set(c_protocol, rb_intern("@native_method_table"), method_table_object); +} diff --git a/lib/rb/ext/protocol.h b/lib/rb/ext/protocol.h new file mode 100644 index 00000000..53695303 --- /dev/null +++ b/lib/rb/ext/protocol.h @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +void Init_protocol(); diff --git a/lib/rb/ext/struct.c b/lib/rb/ext/struct.c new file mode 100644 index 00000000..fee285e9 --- /dev/null +++ b/lib/rb/ext/struct.c @@ -0,0 +1,605 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include "macros.h" + +#ifndef HAVE_STRLCPY + +static +size_t +strlcpy (char *dst, const char *src, size_t dst_sz) +{ + size_t n; + + for (n = 0; n < dst_sz; n++) { + if ((*dst++ = *src++) == '\0') + break; + } + + if (n < dst_sz) + return n; + if (n > 0) + *(dst - 1) = '\0'; + return n + strlen (src); +} + +#endif + +static native_proto_method_table *mt; +static native_proto_method_table *default_mt; +static VALUE last_proto_class = Qnil; + +#define IS_CONTAINER(ttype) ((ttype) == TTYPE_MAP || (ttype) == TTYPE_LIST || (ttype) == TTYPE_SET) +#define STRUCT_FIELDS(obj) rb_const_get(CLASS_OF(obj), fields_const_id) + +static void set_native_proto_function_pointers(VALUE protocol) { + VALUE method_table_object = rb_const_get(CLASS_OF(protocol), rb_intern("@native_method_table")); + // TODO: check nil? + Data_Get_Struct(method_table_object, native_proto_method_table, mt); +} + +static void check_native_proto_method_table(VALUE protocol) { + VALUE protoclass = CLASS_OF(protocol); + if (protoclass != last_proto_class) { + last_proto_class = protoclass; + if (rb_funcall(protocol, native_qmark_method_id, 0) == Qtrue) { + set_native_proto_function_pointers(protocol); + } else { + mt = default_mt; + } + } +} + +//------------------------------------------- +// Writing section +//------------------------------------------- + +// default fn pointers for protocol stuff here + +VALUE default_write_bool(VALUE protocol, VALUE value) { + rb_funcall(protocol, write_boolean_method_id, 1, value); + return Qnil; +} + +VALUE default_write_byte(VALUE protocol, VALUE value) { + rb_funcall(protocol, write_byte_method_id, 1, value); + return Qnil; +} + +VALUE default_write_i16(VALUE protocol, VALUE value) { + rb_funcall(protocol, write_i16_method_id, 1, value); + return Qnil; +} + +VALUE default_write_i32(VALUE protocol, VALUE value) { + rb_funcall(protocol, write_i32_method_id, 1, value); + return Qnil; +} + +VALUE default_write_i64(VALUE protocol, VALUE value) { + rb_funcall(protocol, write_i64_method_id, 1, value); + return Qnil; +} + +VALUE default_write_double(VALUE protocol, VALUE value) { + rb_funcall(protocol, write_double_method_id, 1, value); + return Qnil; +} + +VALUE default_write_string(VALUE protocol, VALUE value) { + rb_funcall(protocol, write_string_method_id, 1, value); + return Qnil; +} + +VALUE default_write_list_begin(VALUE protocol, VALUE etype, VALUE length) { + rb_funcall(protocol, write_list_begin_method_id, 2, etype, length); + return Qnil; +} + +VALUE default_write_list_end(VALUE protocol) { + rb_funcall(protocol, write_list_end_method_id, 0); + return Qnil; +} + +VALUE default_write_set_begin(VALUE protocol, VALUE etype, VALUE length) { + rb_funcall(protocol, write_set_begin_method_id, 2, etype, length); + return Qnil; +} + +VALUE default_write_set_end(VALUE protocol) { + rb_funcall(protocol, write_set_end_method_id, 0); + return Qnil; +} + +VALUE default_write_map_begin(VALUE protocol, VALUE ktype, VALUE vtype, VALUE length) { + rb_funcall(protocol, write_map_begin_method_id, 3, ktype, vtype, length); + return Qnil; +} + +VALUE default_write_map_end(VALUE protocol) { + rb_funcall(protocol, write_map_end_method_id, 0); + return Qnil; +} + +VALUE default_write_struct_begin(VALUE protocol, VALUE struct_name) { + rb_funcall(protocol, write_struct_begin_method_id, 1, struct_name); + return Qnil; +} + +VALUE default_write_struct_end(VALUE protocol) { + rb_funcall(protocol, write_struct_end_method_id, 0); + return Qnil; +} + +VALUE default_write_field_begin(VALUE protocol, VALUE name, VALUE type, VALUE id) { + rb_funcall(protocol, write_field_begin_method_id, 3, name, type, id); + return Qnil; +} + +VALUE default_write_field_end(VALUE protocol) { + rb_funcall(protocol, write_field_end_method_id, 0); + return Qnil; +} + +VALUE default_write_field_stop(VALUE protocol) { + rb_funcall(protocol, write_field_stop_method_id, 0); + return Qnil; +} + +VALUE default_read_field_begin(VALUE protocol) { + return rb_funcall(protocol, read_field_begin_method_id, 0); +} + +VALUE default_read_field_end(VALUE protocol) { + return rb_funcall(protocol, read_field_end_method_id, 0); +} + +VALUE default_read_map_begin(VALUE protocol) { + return rb_funcall(protocol, read_map_begin_method_id, 0); +} + +VALUE default_read_map_end(VALUE protocol) { + return rb_funcall(protocol, read_map_end_method_id, 0); +} + +VALUE default_read_list_begin(VALUE protocol) { + return rb_funcall(protocol, read_list_begin_method_id, 0); +} + +VALUE default_read_list_end(VALUE protocol) { + return rb_funcall(protocol, read_list_end_method_id, 0); +} + +VALUE default_read_set_begin(VALUE protocol) { + return rb_funcall(protocol, read_set_begin_method_id, 0); +} + +VALUE default_read_set_end(VALUE protocol) { + return rb_funcall(protocol, read_set_end_method_id, 0); +} + +VALUE default_read_byte(VALUE protocol) { + return rb_funcall(protocol, read_byte_method_id, 0); +} + +VALUE default_read_bool(VALUE protocol) { + return rb_funcall(protocol, read_bool_method_id, 0); +} + +VALUE default_read_i16(VALUE protocol) { + return rb_funcall(protocol, read_i16_method_id, 0); +} + +VALUE default_read_i32(VALUE protocol) { + return rb_funcall(protocol, read_i32_method_id, 0); +} + +VALUE default_read_i64(VALUE protocol) { + return rb_funcall(protocol, read_i64_method_id, 0); +} + +VALUE default_read_double(VALUE protocol) { + return rb_funcall(protocol, read_double_method_id, 0); +} + +VALUE default_read_string(VALUE protocol) { + return rb_funcall(protocol, read_string_method_id, 0); +} + +VALUE default_read_struct_begin(VALUE protocol) { + return rb_funcall(protocol, read_struct_begin_method_id, 0); +} + +VALUE default_read_struct_end(VALUE protocol) { + return rb_funcall(protocol, read_struct_end_method_id, 0); +} + +static void set_default_proto_function_pointers() { + default_mt = ALLOC(native_proto_method_table); + + default_mt->write_field_begin = default_write_field_begin; + default_mt->write_field_stop = default_write_field_stop; + default_mt->write_map_begin = default_write_map_begin; + default_mt->write_map_end = default_write_map_end; + default_mt->write_list_begin = default_write_list_begin; + default_mt->write_list_end = default_write_list_end; + default_mt->write_set_begin = default_write_set_begin; + default_mt->write_set_end = default_write_set_end; + default_mt->write_byte = default_write_byte; + default_mt->write_bool = default_write_bool; + default_mt->write_i16 = default_write_i16; + default_mt->write_i32 = default_write_i32; + default_mt->write_i64 = default_write_i64; + default_mt->write_double = default_write_double; + default_mt->write_string = default_write_string; + default_mt->write_struct_begin = default_write_struct_begin; + default_mt->write_struct_end = default_write_struct_end; + default_mt->write_field_end = default_write_field_end; + + default_mt->read_struct_begin = default_read_struct_begin; + default_mt->read_struct_end = default_read_struct_end; + default_mt->read_field_begin = default_read_field_begin; + default_mt->read_field_end = default_read_field_end; + default_mt->read_map_begin = default_read_map_begin; + default_mt->read_map_end = default_read_map_end; + default_mt->read_list_begin = default_read_list_begin; + default_mt->read_list_end = default_read_list_end; + default_mt->read_set_begin = default_read_set_begin; + default_mt->read_set_end = default_read_set_end; + default_mt->read_byte = default_read_byte; + default_mt->read_bool = default_read_bool; + default_mt->read_i16 = default_read_i16; + default_mt->read_i32 = default_read_i32; + default_mt->read_i64 = default_read_i64; + default_mt->read_double = default_read_double; + default_mt->read_string = default_read_string; +} + +// end default protocol methods + + +static VALUE rb_thrift_struct_write(VALUE self, VALUE protocol); +static void write_anything(int ttype, VALUE value, VALUE protocol, VALUE field_info); + +VALUE get_field_value(VALUE obj, VALUE field_name) { + char name_buf[RSTRING_LEN(field_name) + 1]; + + name_buf[0] = '@'; + strlcpy(&name_buf[1], RSTRING_PTR(field_name), sizeof(name_buf)); + + VALUE value = rb_ivar_get(obj, rb_intern(name_buf)); + + return value; +} + +static void write_container(int ttype, VALUE field_info, VALUE value, VALUE protocol) { + int sz, i; + + if (ttype == TTYPE_MAP) { + VALUE keys; + VALUE key; + VALUE val; + + Check_Type(value, T_HASH); + + VALUE key_info = rb_hash_aref(field_info, key_sym); + VALUE keytype_value = rb_hash_aref(key_info, type_sym); + int keytype = FIX2INT(keytype_value); + + VALUE value_info = rb_hash_aref(field_info, value_sym); + VALUE valuetype_value = rb_hash_aref(value_info, type_sym); + int valuetype = FIX2INT(valuetype_value); + + keys = rb_funcall(value, keys_method_id, 0); + + sz = RARRAY_LEN(keys); + + mt->write_map_begin(protocol, keytype_value, valuetype_value, INT2FIX(sz)); + + for (i = 0; i < sz; i++) { + key = rb_ary_entry(keys, i); + val = rb_hash_aref(value, key); + + if (IS_CONTAINER(keytype)) { + write_container(keytype, key_info, key, protocol); + } else { + write_anything(keytype, key, protocol, key_info); + } + + if (IS_CONTAINER(valuetype)) { + write_container(valuetype, value_info, val, protocol); + } else { + write_anything(valuetype, val, protocol, value_info); + } + } + + mt->write_map_end(protocol); + } else if (ttype == TTYPE_LIST) { + Check_Type(value, T_ARRAY); + + sz = RARRAY_LEN(value); + + VALUE element_type_info = rb_hash_aref(field_info, element_sym); + VALUE element_type_value = rb_hash_aref(element_type_info, type_sym); + int element_type = FIX2INT(element_type_value); + + mt->write_list_begin(protocol, element_type_value, INT2FIX(sz)); + for (i = 0; i < sz; ++i) { + VALUE val = rb_ary_entry(value, i); + if (IS_CONTAINER(element_type)) { + write_container(element_type, element_type_info, val, protocol); + } else { + write_anything(element_type, val, protocol, element_type_info); + } + } + mt->write_list_end(protocol); + } else if (ttype == TTYPE_SET) { + VALUE items; + + if (TYPE(value) == T_ARRAY) { + items = value; + } else { + if (rb_cSet == CLASS_OF(value)) { + items = rb_funcall(value, entries_method_id, 0); + } else { + Check_Type(value, T_HASH); + items = rb_funcall(value, keys_method_id, 0); + } + } + + sz = RARRAY_LEN(items); + + VALUE element_type_info = rb_hash_aref(field_info, element_sym); + VALUE element_type_value = rb_hash_aref(element_type_info, type_sym); + int element_type = FIX2INT(element_type_value); + + mt->write_set_begin(protocol, element_type_value, INT2FIX(sz)); + + for (i = 0; i < sz; i++) { + VALUE val = rb_ary_entry(items, i); + if (IS_CONTAINER(element_type)) { + write_container(element_type, element_type_info, val, protocol); + } else { + write_anything(element_type, val, protocol, element_type_info); + } + } + + mt->write_set_end(protocol); + } else { + rb_raise(rb_eNotImpError, "can't write container of type: %d", ttype); + } +} + +static void write_anything(int ttype, VALUE value, VALUE protocol, VALUE field_info) { + if (ttype == TTYPE_BOOL) { + mt->write_bool(protocol, value); + } else if (ttype == TTYPE_BYTE) { + mt->write_byte(protocol, value); + } else if (ttype == TTYPE_I16) { + mt->write_i16(protocol, value); + } else if (ttype == TTYPE_I32) { + mt->write_i32(protocol, value); + } else if (ttype == TTYPE_I64) { + mt->write_i64(protocol, value); + } else if (ttype == TTYPE_DOUBLE) { + mt->write_double(protocol, value); + } else if (ttype == TTYPE_STRING) { + mt->write_string(protocol, value); + } else if (IS_CONTAINER(ttype)) { + write_container(ttype, field_info, value, protocol); + } else if (ttype == TTYPE_STRUCT) { + rb_thrift_struct_write(value, protocol); + } else { + rb_raise(rb_eNotImpError, "Unknown type for binary_encoding: %d", ttype); + } +} + +static VALUE rb_thrift_struct_write(VALUE self, VALUE protocol) { + // call validate + rb_funcall(self, validate_method_id, 0); + + check_native_proto_method_table(protocol); + + // write struct begin + mt->write_struct_begin(protocol, rb_class_name(CLASS_OF(self))); + + // iterate through all the fields here + VALUE struct_fields = STRUCT_FIELDS(self); + VALUE struct_field_ids_unordered = rb_funcall(struct_fields, keys_method_id, 0); + VALUE struct_field_ids_ordered = rb_funcall(struct_field_ids_unordered, sort_method_id, 0); + + int i = 0; + for (i=0; i < RARRAY_LEN(struct_field_ids_ordered); i++) { + VALUE field_id = rb_ary_entry(struct_field_ids_ordered, i); + VALUE field_info = rb_hash_aref(struct_fields, field_id); + + VALUE ttype_value = rb_hash_aref(field_info, type_sym); + int ttype = FIX2INT(ttype_value); + VALUE field_name = rb_hash_aref(field_info, name_sym); + VALUE field_value = get_field_value(self, field_name); + + if (!NIL_P(field_value)) { + mt->write_field_begin(protocol, field_name, ttype_value, field_id); + + write_anything(ttype, field_value, protocol, field_info); + + mt->write_field_end(protocol); + } + } + + mt->write_field_stop(protocol); + + // write struct end + mt->write_struct_end(protocol); + + return Qnil; +} + +//------------------------------------------- +// Reading section +//------------------------------------------- + +static VALUE rb_thrift_struct_read(VALUE self, VALUE protocol); + +static void set_field_value(VALUE obj, VALUE field_name, VALUE value) { + char name_buf[RSTRING_LEN(field_name) + 1]; + + name_buf[0] = '@'; + strlcpy(&name_buf[1], RSTRING_PTR(field_name), sizeof(name_buf)); + + rb_ivar_set(obj, rb_intern(name_buf), value); +} + +static VALUE read_anything(VALUE protocol, int ttype, VALUE field_info) { + VALUE result = Qnil; + + if (ttype == TTYPE_BOOL) { + result = mt->read_bool(protocol); + } else if (ttype == TTYPE_BYTE) { + result = mt->read_byte(protocol); + } else if (ttype == TTYPE_I16) { + result = mt->read_i16(protocol); + } else if (ttype == TTYPE_I32) { + result = mt->read_i32(protocol); + } else if (ttype == TTYPE_I64) { + result = mt->read_i64(protocol); + } else if (ttype == TTYPE_STRING) { + result = mt->read_string(protocol); + } else if (ttype == TTYPE_DOUBLE) { + result = mt->read_double(protocol); + } else if (ttype == TTYPE_STRUCT) { + VALUE klass = rb_hash_aref(field_info, class_sym); + result = rb_class_new_instance(0, NULL, klass); + rb_thrift_struct_read(result, protocol); + } else if (ttype == TTYPE_MAP) { + int i; + + VALUE map_header = mt->read_map_begin(protocol); + int key_ttype = FIX2INT(rb_ary_entry(map_header, 0)); + int value_ttype = FIX2INT(rb_ary_entry(map_header, 1)); + int num_entries = FIX2INT(rb_ary_entry(map_header, 2)); + + VALUE key_info = rb_hash_aref(field_info, key_sym); + VALUE value_info = rb_hash_aref(field_info, value_sym); + + result = rb_hash_new(); + + for (i = 0; i < num_entries; ++i) { + VALUE key, val; + + key = read_anything(protocol, key_ttype, key_info); + val = read_anything(protocol, value_ttype, value_info); + + rb_hash_aset(result, key, val); + } + + mt->read_map_end(protocol); + } else if (ttype == TTYPE_LIST) { + int i; + + VALUE list_header = mt->read_list_begin(protocol); + int element_ttype = FIX2INT(rb_ary_entry(list_header, 0)); + int num_elements = FIX2INT(rb_ary_entry(list_header, 1)); + result = rb_ary_new2(num_elements); + + for (i = 0; i < num_elements; ++i) { + rb_ary_push(result, read_anything(protocol, element_ttype, rb_hash_aref(field_info, element_sym))); + } + + + mt->read_list_end(protocol); + } else if (ttype == TTYPE_SET) { + VALUE items; + int i; + + VALUE set_header = mt->read_set_begin(protocol); + int element_ttype = FIX2INT(rb_ary_entry(set_header, 0)); + int num_elements = FIX2INT(rb_ary_entry(set_header, 1)); + items = rb_ary_new2(num_elements); + + for (i = 0; i < num_elements; ++i) { + rb_ary_push(items, read_anything(protocol, element_ttype, rb_hash_aref(field_info, element_sym))); + } + + + mt->read_set_end(protocol); + + result = rb_class_new_instance(1, &items, rb_cSet); + } else { + rb_raise(rb_eNotImpError, "read_anything not implemented for type %d!", ttype); + } + + return result; +} + +static VALUE rb_thrift_struct_read(VALUE self, VALUE protocol) { + check_native_proto_method_table(protocol); + + // read struct begin + mt->read_struct_begin(protocol); + + VALUE struct_fields = STRUCT_FIELDS(self); + + // read each field + while (true) { + VALUE field_header = mt->read_field_begin(protocol); + VALUE field_type_value = rb_ary_entry(field_header, 1); + int field_type = FIX2INT(field_type_value); + + if (field_type == TTYPE_STOP) { + break; + } + + // make sure we got a type we expected + VALUE field_info = rb_hash_aref(struct_fields, rb_ary_entry(field_header, 2)); + + if (!NIL_P(field_info)) { + int specified_type = FIX2INT(rb_hash_aref(field_info, type_sym)); + if (field_type == specified_type) { + // read the value + VALUE name = rb_hash_aref(field_info, name_sym); + set_field_value(self, name, read_anything(protocol, field_type, field_info)); + } else { + rb_funcall(protocol, skip_method_id, 1, field_type_value); + } + } else { + rb_funcall(protocol, skip_method_id, 1, field_type_value); + } + + // read field end + mt->read_field_end(protocol); + } + + // read struct end + mt->read_struct_end(protocol); + + return Qnil; +} + +void Init_struct() { + VALUE struct_module = rb_const_get(thrift_module, rb_intern("Struct")); + + rb_define_method(struct_module, "write", rb_thrift_struct_write, 1); + rb_define_method(struct_module, "read", rb_thrift_struct_read, 1); + + set_default_proto_function_pointers(); +} + diff --git a/lib/rb/ext/struct.h b/lib/rb/ext/struct.h new file mode 100644 index 00000000..37b1b35b --- /dev/null +++ b/lib/rb/ext/struct.h @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +typedef struct native_proto_method_table { + VALUE (*write_bool)(VALUE, VALUE); + VALUE (*write_byte)(VALUE, VALUE); + VALUE (*write_i16)(VALUE, VALUE); + VALUE (*write_i32)(VALUE, VALUE); + VALUE (*write_i64)(VALUE, VALUE); + VALUE (*write_double)(VALUE, VALUE); + VALUE (*write_string)(VALUE, VALUE); + VALUE (*write_list_begin)(VALUE, VALUE, VALUE); + VALUE (*write_list_end)(VALUE); + VALUE (*write_set_begin)(VALUE, VALUE, VALUE); + VALUE (*write_set_end)(VALUE); + VALUE (*write_map_begin)(VALUE, VALUE, VALUE, VALUE); + VALUE (*write_map_end)(VALUE); + VALUE (*write_struct_begin)(VALUE, VALUE); + VALUE (*write_struct_end)(VALUE); + VALUE (*write_field_begin)(VALUE, VALUE, VALUE, VALUE); + VALUE (*write_field_end)(VALUE); + VALUE (*write_field_stop)(VALUE); + VALUE (*write_message_begin)(VALUE, VALUE, VALUE, VALUE); + VALUE (*write_message_end)(VALUE); + + VALUE (*read_message_begin)(VALUE); + VALUE (*read_message_end)(VALUE); + VALUE (*read_field_begin)(VALUE); + VALUE (*read_field_end)(VALUE); + VALUE (*read_map_begin)(VALUE); + VALUE (*read_map_end)(VALUE); + VALUE (*read_list_begin)(VALUE); + VALUE (*read_list_end)(VALUE); + VALUE (*read_set_begin)(VALUE); + VALUE (*read_set_end)(VALUE); + VALUE (*read_byte)(VALUE); + VALUE (*read_bool)(VALUE); + VALUE (*read_i16)(VALUE); + VALUE (*read_i32)(VALUE); + VALUE (*read_i64)(VALUE); + VALUE (*read_double)(VALUE); + VALUE (*read_string)(VALUE); + VALUE (*read_struct_begin)(VALUE); + VALUE (*read_struct_end)(VALUE); + +} native_proto_method_table; + +void Init_struct(); diff --git a/lib/rb/ext/thrift_native.c b/lib/rb/ext/thrift_native.c new file mode 100644 index 00000000..effa202c --- /dev/null +++ b/lib/rb/ext/thrift_native.c @@ -0,0 +1,194 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include + +// cached classes/modules +VALUE rb_cSet; +VALUE thrift_module; +VALUE thrift_types_module; + +// TType constants +int TTYPE_STOP; +int TTYPE_BOOL; +int TTYPE_BYTE; +int TTYPE_I16; +int TTYPE_I32; +int TTYPE_I64; +int TTYPE_DOUBLE; +int TTYPE_STRING; +int TTYPE_MAP; +int TTYPE_SET; +int TTYPE_LIST; +int TTYPE_STRUCT; + +// method ids +ID validate_method_id; +ID write_struct_begin_method_id; +ID write_struct_end_method_id; +ID write_field_begin_method_id; +ID write_field_end_method_id; +ID write_boolean_method_id; +ID write_byte_method_id; +ID write_i16_method_id; +ID write_i32_method_id; +ID write_i64_method_id; +ID write_double_method_id; +ID write_string_method_id; +ID write_map_begin_method_id; +ID write_map_end_method_id; +ID write_list_begin_method_id; +ID write_list_end_method_id; +ID write_set_begin_method_id; +ID write_set_end_method_id; +ID size_method_id; +ID read_bool_method_id; +ID read_byte_method_id; +ID read_i16_method_id; +ID read_i32_method_id; +ID read_i64_method_id; +ID read_string_method_id; +ID read_double_method_id; +ID read_map_begin_method_id; +ID read_map_end_method_id; +ID read_list_begin_method_id; +ID read_list_end_method_id; +ID read_set_begin_method_id; +ID read_set_end_method_id; +ID read_struct_begin_method_id; +ID read_struct_end_method_id; +ID read_field_begin_method_id; +ID read_field_end_method_id; +ID keys_method_id; +ID entries_method_id; +ID name_method_id; +ID sort_method_id; +ID write_field_stop_method_id; +ID skip_method_id; +ID write_method_id; +ID read_all_method_id; +ID native_qmark_method_id; + +// constant ids +ID fields_const_id; +ID transport_ivar_id; +ID strict_read_ivar_id; +ID strict_write_ivar_id; + +// cached symbols +VALUE type_sym; +VALUE name_sym; +VALUE key_sym; +VALUE value_sym; +VALUE element_sym; +VALUE class_sym; +VALUE protocol_exception_class; + +void Init_thrift_native() { + // cached classes + thrift_module = rb_const_get(rb_cObject, rb_intern("Thrift")); + thrift_types_module = rb_const_get(thrift_module, rb_intern("Types")); + rb_cSet = rb_const_get(rb_cObject, rb_intern("Set")); + protocol_exception_class = rb_const_get(thrift_module, rb_intern("ProtocolException")); + + // Init ttype constants + TTYPE_BOOL = FIX2INT(rb_const_get(thrift_types_module, rb_intern("BOOL"))); + TTYPE_BYTE = FIX2INT(rb_const_get(thrift_types_module, rb_intern("BYTE"))); + TTYPE_I16 = FIX2INT(rb_const_get(thrift_types_module, rb_intern("I16"))); + TTYPE_I32 = FIX2INT(rb_const_get(thrift_types_module, rb_intern("I32"))); + TTYPE_I64 = FIX2INT(rb_const_get(thrift_types_module, rb_intern("I64"))); + TTYPE_DOUBLE = FIX2INT(rb_const_get(thrift_types_module, rb_intern("DOUBLE"))); + TTYPE_STRING = FIX2INT(rb_const_get(thrift_types_module, rb_intern("STRING"))); + TTYPE_MAP = FIX2INT(rb_const_get(thrift_types_module, rb_intern("MAP"))); + TTYPE_SET = FIX2INT(rb_const_get(thrift_types_module, rb_intern("SET"))); + TTYPE_LIST = FIX2INT(rb_const_get(thrift_types_module, rb_intern("LIST"))); + TTYPE_STRUCT = FIX2INT(rb_const_get(thrift_types_module, rb_intern("STRUCT"))); + + // method ids + validate_method_id = rb_intern("validate"); + write_struct_begin_method_id = rb_intern("write_struct_begin"); + write_struct_end_method_id = rb_intern("write_struct_end"); + write_field_begin_method_id = rb_intern("write_field_begin"); + write_field_end_method_id = rb_intern("write_field_end"); + write_boolean_method_id = rb_intern("write_bool"); + write_byte_method_id = rb_intern("write_byte"); + write_i16_method_id = rb_intern("write_i16"); + write_i32_method_id = rb_intern("write_i32"); + write_i64_method_id = rb_intern("write_i64"); + write_double_method_id = rb_intern("write_double"); + write_string_method_id = rb_intern("write_string"); + write_map_begin_method_id = rb_intern("write_map_begin"); + write_map_end_method_id = rb_intern("write_map_end"); + write_list_begin_method_id = rb_intern("write_list_begin"); + write_list_end_method_id = rb_intern("write_list_end"); + write_set_begin_method_id = rb_intern("write_set_begin"); + write_set_end_method_id = rb_intern("write_set_end"); + size_method_id = rb_intern("size"); + read_bool_method_id = rb_intern("read_bool"); + read_byte_method_id = rb_intern("read_byte"); + read_i16_method_id = rb_intern("read_i16"); + read_i32_method_id = rb_intern("read_i32"); + read_i64_method_id = rb_intern("read_i64"); + read_string_method_id = rb_intern("read_string"); + read_double_method_id = rb_intern("read_double"); + read_map_begin_method_id = rb_intern("read_map_begin"); + read_map_end_method_id = rb_intern("read_map_end"); + read_list_begin_method_id = rb_intern("read_list_begin"); + read_list_end_method_id = rb_intern("read_list_end"); + read_set_begin_method_id = rb_intern("read_set_begin"); + read_set_end_method_id = rb_intern("read_set_end"); + read_struct_begin_method_id = rb_intern("read_struct_begin"); + read_struct_end_method_id = rb_intern("read_struct_end"); + read_field_begin_method_id = rb_intern("read_field_begin"); + read_field_end_method_id = rb_intern("read_field_end"); + keys_method_id = rb_intern("keys"); + entries_method_id = rb_intern("entries"); + name_method_id = rb_intern("name"); + sort_method_id = rb_intern("sort"); + write_field_stop_method_id = rb_intern("write_field_stop"); + skip_method_id = rb_intern("skip"); + write_method_id = rb_intern("write"); + read_all_method_id = rb_intern("read_all"); + native_qmark_method_id = rb_intern("native?"); + + // constant ids + fields_const_id = rb_intern("FIELDS"); + transport_ivar_id = rb_intern("@trans"); + strict_read_ivar_id = rb_intern("@strict_read"); + strict_write_ivar_id = rb_intern("@strict_write"); + + // cached symbols + type_sym = ID2SYM(rb_intern("type")); + name_sym = ID2SYM(rb_intern("name")); + key_sym = ID2SYM(rb_intern("key")); + value_sym = ID2SYM(rb_intern("value")); + element_sym = ID2SYM(rb_intern("element")); + class_sym = ID2SYM(rb_intern("class")); + + Init_protocol(); + Init_struct(); + Init_binary_protocol_accelerated(); + Init_compact_protocol(); + Init_memory_buffer(); +} diff --git a/lib/rb/lib/thrift.rb b/lib/rb/lib/thrift.rb new file mode 100644 index 00000000..88562e13 --- /dev/null +++ b/lib/rb/lib/thrift.rb @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +$:.unshift File.dirname(__FILE__) + +require 'thrift/core_ext' +require 'thrift/exceptions' +require 'thrift/types' +require 'thrift/processor' +require 'thrift/client' +require 'thrift/struct' + +# serializer +require 'thrift/serializer/serializer' +require 'thrift/serializer/deserializer' + +# protocol +require 'thrift/protocol/base_protocol' +require 'thrift/protocol/binary_protocol' +require 'thrift/protocol/binary_protocol_accelerated' +require 'thrift/protocol/compact_protocol' + +# transport +require 'thrift/transport/base_transport' +require 'thrift/transport/base_server_transport' +require 'thrift/transport/socket' +require 'thrift/transport/server_socket' +require 'thrift/transport/unix_socket' +require 'thrift/transport/unix_server_socket' +require 'thrift/transport/buffered_transport' +require 'thrift/transport/framed_transport' +require 'thrift/transport/http_client_transport' +require 'thrift/transport/io_stream_transport' +require 'thrift/transport/memory_buffer_transport' + +# server +require 'thrift/server/base_server' +require 'thrift/server/nonblocking_server' +require 'thrift/server/simple_server' +require 'thrift/server/threaded_server' +require 'thrift/server/thread_pool_server' + +require 'thrift/thrift_native' \ No newline at end of file diff --git a/lib/rb/lib/thrift/client.rb b/lib/rb/lib/thrift/client.rb new file mode 100644 index 00000000..5b30f015 --- /dev/null +++ b/lib/rb/lib/thrift/client.rb @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + module Client + def initialize(iprot, oprot=nil) + @iprot = iprot + @oprot = oprot || iprot + @seqid = 0 + end + + def send_message(name, args_class, args = {}) + @oprot.write_message_begin(name, MessageTypes::CALL, @seqid) + data = args_class.new + args.each do |k, v| + data.send("#{k.to_s}=", v) + end + begin + data.write(@oprot) + rescue StandardError => e + @oprot.trans.close + raise e + end + @oprot.write_message_end + @oprot.trans.flush + end + + def receive_message(result_klass) + fname, mtype, rseqid = @iprot.read_message_begin + handle_exception(mtype) + result = result_klass.new + result.read(@iprot) + @iprot.read_message_end + result + end + + def handle_exception(mtype) + if mtype == MessageTypes::EXCEPTION + x = ApplicationException.new + x.read(@iprot) + @iprot.read_message_end + raise x + end + end + end +end diff --git a/lib/rb/lib/thrift/core_ext.rb b/lib/rb/lib/thrift/core_ext.rb new file mode 100644 index 00000000..f763cd53 --- /dev/null +++ b/lib/rb/lib/thrift/core_ext.rb @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +Dir[File.dirname(__FILE__) + "/core_ext/*.rb"].each do |file| + name = File.basename(file, '.rb') + require "thrift/core_ext/#{name}" +end diff --git a/lib/rb/lib/thrift/core_ext/fixnum.rb b/lib/rb/lib/thrift/core_ext/fixnum.rb new file mode 100644 index 00000000..b4fc90dd --- /dev/null +++ b/lib/rb/lib/thrift/core_ext/fixnum.rb @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# Versions of ruby pre 1.8.7 do not have an .ord method available in the Fixnum +# class. +# +if RUBY_VERSION < "1.8.7" + class Fixnum + def ord + self + end + end +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/exceptions.rb b/lib/rb/lib/thrift/exceptions.rb new file mode 100644 index 00000000..dda70894 --- /dev/null +++ b/lib/rb/lib/thrift/exceptions.rb @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class Exception < StandardError + def initialize(message) + super + @message = message + end + + attr_reader :message + end + + class ApplicationException < Exception + + UNKNOWN = 0 + UNKNOWN_METHOD = 1 + INVALID_MESSAGE_TYPE = 2 + WRONG_METHOD_NAME = 3 + BAD_SEQUENCE_ID = 4 + MISSING_RESULT = 5 + + attr_reader :type + + def initialize(type=UNKNOWN, message=nil) + super(message) + @type = type + end + + def read(iprot) + iprot.read_struct_begin + while true + fname, ftype, fid = iprot.read_field_begin + if ftype == Types::STOP + break + end + if fid == 1 and ftype == Types::STRING + @message = iprot.read_string + elsif fid == 2 and ftype == Types::I32 + @type = iprot.read_i32 + else + iprot.skip(ftype) + end + iprot.read_field_end + end + iprot.read_struct_end + end + + def write(oprot) + oprot.write_struct_begin('Thrift::ApplicationException') + unless @message.nil? + oprot.write_field_begin('message', Types::STRING, 1) + oprot.write_string(@message) + oprot.write_field_end + end + unless @type.nil? + oprot.write_field_begin('type', Types::I32, 2) + oprot.write_i32(@type) + oprot.write_field_end + end + oprot.write_field_stop + oprot.write_struct_end + end + + end +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/processor.rb b/lib/rb/lib/thrift/processor.rb new file mode 100644 index 00000000..5d9e0a11 --- /dev/null +++ b/lib/rb/lib/thrift/processor.rb @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + module Processor + def initialize(handler) + @handler = handler + end + + def process(iprot, oprot) + name, type, seqid = iprot.read_message_begin + if respond_to?("process_#{name}") + send("process_#{name}", seqid, iprot, oprot) + true + else + iprot.skip(Types::STRUCT) + iprot.read_message_end + x = ApplicationException.new(ApplicationException::UNKNOWN_METHOD, 'Unknown function '+name) + oprot.write_message_begin(name, MessageTypes::EXCEPTION, seqid) + x.write(oprot) + oprot.write_message_end + oprot.trans.flush + false + end + end + + def read_args(iprot, args_class) + args = args_class.new + args.read(iprot) + iprot.read_message_end + args + end + + def write_result(result, oprot, name, seqid) + oprot.write_message_begin(name, MessageTypes::REPLY, seqid) + result.write(oprot) + oprot.write_message_end + oprot.trans.flush + end + end +end diff --git a/lib/rb/lib/thrift/protocol/base_protocol.rb b/lib/rb/lib/thrift/protocol/base_protocol.rb new file mode 100644 index 00000000..b19909d5 --- /dev/null +++ b/lib/rb/lib/thrift/protocol/base_protocol.rb @@ -0,0 +1,290 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# this require is to make generated struct definitions happy +require 'set' + +module Thrift + class ProtocolException < Exception + + UNKNOWN = 0 + INVALID_DATA = 1 + NEGATIVE_SIZE = 2 + SIZE_LIMIT = 3 + BAD_VERSION = 4 + + attr_reader :type + + def initialize(type=UNKNOWN, message=nil) + super(message) + @type = type + end + end + + class BaseProtocol + + attr_reader :trans + + def initialize(trans) + @trans = trans + end + + def native? + puts "wrong method is being called!" + false + end + + def write_message_begin(name, type, seqid) + raise NotImplementedError + end + + def write_message_end; nil; end + + def write_struct_begin(name) + raise NotImplementedError + end + + def write_struct_end; nil; end + + def write_field_begin(name, type, id) + raise NotImplementedError + end + + def write_field_end; nil; end + + def write_field_stop + raise NotImplementedError + end + + def write_map_begin(ktype, vtype, size) + raise NotImplementedError + end + + def write_map_end; nil; end + + def write_list_begin(etype, size) + raise NotImplementedError + end + + def write_list_end; nil; end + + def write_set_begin(etype, size) + raise NotImplementedError + end + + def write_set_end; nil; end + + def write_bool(bool) + raise NotImplementedError + end + + def write_byte(byte) + raise NotImplementedError + end + + def write_i16(i16) + raise NotImplementedError + end + + def write_i32(i32) + raise NotImplementedError + end + + def write_i64(i64) + raise NotImplementedError + end + + def write_double(dub) + raise NotImplementedError + end + + def write_string(str) + raise NotImplementedError + end + + def read_message_begin + raise NotImplementedError + end + + def read_message_end; nil; end + + def read_struct_begin + raise NotImplementedError + end + + def read_struct_end; nil; end + + def read_field_begin + raise NotImplementedError + end + + def read_field_end; nil; end + + def read_map_begin + raise NotImplementedError + end + + def read_map_end; nil; end + + def read_list_begin + raise NotImplementedError + end + + def read_list_end; nil; end + + def read_set_begin + raise NotImplementedError + end + + def read_set_end; nil; end + + def read_bool + raise NotImplementedError + end + + def read_byte + raise NotImplementedError + end + + def read_i16 + raise NotImplementedError + end + + def read_i32 + raise NotImplementedError + end + + def read_i64 + raise NotImplementedError + end + + def read_double + raise NotImplementedError + end + + def read_string + raise NotImplementedError + end + + def write_field(name, type, fid, value) + write_field_begin(name, type, fid) + write_type(type, value) + write_field_end + end + + def write_type(type, value) + case type + when Types::BOOL + write_bool(value) + when Types::BYTE + write_byte(value) + when Types::DOUBLE + write_double(value) + when Types::I16 + write_i16(value) + when Types::I32 + write_i32(value) + when Types::I64 + write_i64(value) + when Types::STRING + write_string(value) + when Types::STRUCT + value.write(self) + else + raise NotImplementedError + end + end + + def read_type(type) + case type + when Types::BOOL + read_bool + when Types::BYTE + read_byte + when Types::DOUBLE + read_double + when Types::I16 + read_i16 + when Types::I32 + read_i32 + when Types::I64 + read_i64 + when Types::STRING + read_string + else + raise NotImplementedError + end + end + + def skip(type) + case type + when Types::STOP + nil + when Types::BOOL + read_bool + when Types::BYTE + read_byte + when Types::I16 + read_i16 + when Types::I32 + read_i32 + when Types::I64 + read_i64 + when Types::DOUBLE + read_double + when Types::STRING + read_string + when Types::STRUCT + read_struct_begin + while true + name, type, id = read_field_begin + break if type == Types::STOP + skip(type) + read_field_end + end + read_struct_end + when Types::MAP + ktype, vtype, size = read_map_begin + size.times do + skip(ktype) + skip(vtype) + end + read_map_end + when Types::SET + etype, size = read_set_begin + size.times do + skip(etype) + end + read_set_end + when Types::LIST + etype, size = read_list_begin + size.times do + skip(etype) + end + read_list_end + end + end + end + + class BaseProtocolFactory + def get_protocol(trans) + raise NotImplementedError + end + end +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/protocol/binary_protocol.rb b/lib/rb/lib/thrift/protocol/binary_protocol.rb new file mode 100644 index 00000000..04d149ac --- /dev/null +++ b/lib/rb/lib/thrift/protocol/binary_protocol.rb @@ -0,0 +1,225 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class BinaryProtocol < BaseProtocol + VERSION_MASK = 0xffff0000 + VERSION_1 = 0x80010000 + TYPE_MASK = 0x000000ff + + attr_reader :strict_read, :strict_write + + def initialize(trans, strict_read=true, strict_write=true) + super(trans) + @strict_read = strict_read + @strict_write = strict_write + end + + def write_message_begin(name, type, seqid) + # this is necessary because we added (needed) bounds checking to + # write_i32, and 0x80010000 is too big for that. + if strict_write + write_i16(VERSION_1 >> 16) + write_i16(type) + write_string(name) + write_i32(seqid) + else + write_string(name) + write_byte(type) + write_i32(seqid) + end + end + + def write_struct_begin(name); nil; end + + def write_field_begin(name, type, id) + write_byte(type) + write_i16(id) + end + + def write_field_stop + write_byte(Thrift::Types::STOP) + end + + def write_map_begin(ktype, vtype, size) + write_byte(ktype) + write_byte(vtype) + write_i32(size) + end + + def write_list_begin(etype, size) + write_byte(etype) + write_i32(size) + end + + def write_set_begin(etype, size) + write_byte(etype) + write_i32(size) + end + + def write_bool(bool) + write_byte(bool ? 1 : 0) + end + + def write_byte(byte) + raise RangeError if byte < -2**31 || byte >= 2**32 + trans.write([byte].pack('c')) + end + + def write_i16(i16) + trans.write([i16].pack('n')) + end + + def write_i32(i32) + raise RangeError if i32 < -2**31 || i32 >= 2**31 + trans.write([i32].pack('N')) + end + + def write_i64(i64) + raise RangeError if i64 < -2**63 || i64 >= 2**64 + hi = i64 >> 32 + lo = i64 & 0xffffffff + trans.write([hi, lo].pack('N2')) + end + + def write_double(dub) + trans.write([dub].pack('G')) + end + + def write_string(str) + write_i32(str.length) + trans.write(str) + end + + def read_message_begin + version = read_i32 + if version < 0 + if (version & VERSION_MASK != VERSION_1) + raise ProtocolException.new(ProtocolException::BAD_VERSION, 'Missing version identifier') + end + type = version & TYPE_MASK + name = read_string + seqid = read_i32 + [name, type, seqid] + else + if strict_read + raise ProtocolException.new(ProtocolException::BAD_VERSION, 'No version identifier, old protocol client?') + end + name = trans.read_all(version) + type = read_byte + seqid = read_i32 + [name, type, seqid] + end + end + + def read_struct_begin; nil; end + + def read_field_begin + type = read_byte + if (type == Types::STOP) + [nil, type, 0] + else + id = read_i16 + [nil, type, id] + end + end + + def read_map_begin + ktype = read_byte + vtype = read_byte + size = read_i32 + [ktype, vtype, size] + end + + def read_list_begin + etype = read_byte + size = read_i32 + [etype, size] + end + + def read_set_begin + etype = read_byte + size = read_i32 + [etype, size] + end + + def read_bool + byte = read_byte + byte != 0 + end + + def read_byte + dat = trans.read_all(1) + val = dat[0].ord + if (val > 0x7f) + val = 0 - ((val - 1) ^ 0xff) + end + val + end + + def read_i16 + dat = trans.read_all(2) + val, = dat.unpack('n') + if (val > 0x7fff) + val = 0 - ((val - 1) ^ 0xffff) + end + val + end + + def read_i32 + dat = trans.read_all(4) + val, = dat.unpack('N') + if (val > 0x7fffffff) + val = 0 - ((val - 1) ^ 0xffffffff) + end + val + end + + def read_i64 + dat = trans.read_all(8) + hi, lo = dat.unpack('N2') + if (hi > 0x7fffffff) + hi ^= 0xffffffff + lo ^= 0xffffffff + 0 - (hi << 32) - lo - 1 + else + (hi << 32) + lo + end + end + + def read_double + dat = trans.read_all(8) + val = dat.unpack('G').first + val + end + + def read_string + sz = read_i32 + dat = trans.read_all(sz) + dat + end + + end + + class BinaryProtocolFactory < BaseProtocolFactory + def get_protocol(trans) + return Thrift::BinaryProtocol.new(trans) + end + end +end diff --git a/lib/rb/lib/thrift/protocol/binary_protocol_accelerated.rb b/lib/rb/lib/thrift/protocol/binary_protocol_accelerated.rb new file mode 100644 index 00000000..eaf64f6b --- /dev/null +++ b/lib/rb/lib/thrift/protocol/binary_protocol_accelerated.rb @@ -0,0 +1,35 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +=begin +The only change required for a transport to support BinaryProtocolAccelerated is to implement 2 methods: + * borrow(size), which takes an optional argument and returns atleast _size_ bytes from the transport, + or the default buffer size if no argument is given + * consume!(size), which removes size bytes from the front of the buffer + +See MemoryBuffer and BufferedTransport for examples. +=end + +module Thrift + class BinaryProtocolAcceleratedFactory < BaseProtocolFactory + def get_protocol(trans) + BinaryProtocolAccelerated.new(trans) + end + end +end diff --git a/lib/rb/lib/thrift/protocol/compact_protocol.rb b/lib/rb/lib/thrift/protocol/compact_protocol.rb new file mode 100644 index 00000000..c8f43655 --- /dev/null +++ b/lib/rb/lib/thrift/protocol/compact_protocol.rb @@ -0,0 +1,422 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class CompactProtocol < BaseProtocol + + PROTOCOL_ID = [0x82].pack('c').unpack('c').first + VERSION = 1 + VERSION_MASK = 0x1f + TYPE_MASK = 0xE0 + TYPE_SHIFT_AMOUNT = 5 + + TSTOP = ["", Types::STOP, 0] + + # + # All of the on-wire type codes. + # + class CompactTypes + BOOLEAN_TRUE = 0x01 + BOOLEAN_FALSE = 0x02 + BYTE = 0x03 + I16 = 0x04 + I32 = 0x05 + I64 = 0x06 + DOUBLE = 0x07 + BINARY = 0x08 + LIST = 0x09 + SET = 0x0A + MAP = 0x0B + STRUCT = 0x0C + + def self.is_bool_type?(b) + (b & 0x0f) == BOOLEAN_TRUE || (b & 0x0f) == BOOLEAN_FALSE + end + + COMPACT_TO_TTYPE = { + Types::STOP => Types::STOP, + BOOLEAN_FALSE => Types::BOOL, + BOOLEAN_TRUE => Types::BOOL, + BYTE => Types::BYTE, + I16 => Types::I16, + I32 => Types::I32, + I64 => Types::I64, + DOUBLE => Types::DOUBLE, + BINARY => Types::STRING, + LIST => Types::LIST, + SET => Types::SET, + MAP => Types::MAP, + STRUCT => Types::STRUCT + } + + TTYPE_TO_COMPACT = { + Types::STOP => Types::STOP, + Types::BOOL => BOOLEAN_TRUE, + Types::BYTE => BYTE, + Types::I16 => I16, + Types::I32 => I32, + Types::I64 => I64, + Types::DOUBLE => DOUBLE, + Types::STRING => BINARY, + Types::LIST => LIST, + Types::SET => SET, + Types::MAP => MAP, + Types::STRUCT => STRUCT + } + + def self.get_ttype(compact_type) + val = COMPACT_TO_TTYPE[compact_type & 0x0f] + raise "don't know what type: #{compact_type & 0x0f}" unless val + val + end + + def self.get_compact_type(ttype) + val = TTYPE_TO_COMPACT[ttype] + raise "don't know what type: #{ttype & 0x0f}" unless val + val + end + end + + def initialize(transport) + super(transport) + + @last_field = [0] + @boolean_value = nil + end + + def write_message_begin(name, type, seqid) + write_byte(PROTOCOL_ID) + write_byte((VERSION & VERSION_MASK) | ((type << TYPE_SHIFT_AMOUNT) & TYPE_MASK)) + write_varint32(seqid) + write_string(name) + nil + end + + def write_struct_begin(name) + @last_field.push(0) + nil + end + + def write_struct_end + @last_field.pop + nil + end + + def write_field_begin(name, type, id) + if type == Types::BOOL + # we want to possibly include the value, so we'll wait. + @boolean_field = [type, id] + else + write_field_begin_internal(type, id) + end + nil + end + + # + # The workhorse of writeFieldBegin. It has the option of doing a + # 'type override' of the type header. This is used specifically in the + # boolean field case. + # + def write_field_begin_internal(type, id, type_override=nil) + last_id = @last_field.pop + + # if there's a type override, use that. + typeToWrite = type_override || CompactTypes.get_compact_type(type) + + # check if we can use delta encoding for the field id + if id > last_id && id - last_id <= 15 + # write them together + write_byte((id - last_id) << 4 | typeToWrite) + else + # write them separate + write_byte(typeToWrite) + write_i16(id) + end + + @last_field.push(id) + nil + end + + def write_field_stop + write_byte(Types::STOP) + end + + def write_map_begin(ktype, vtype, size) + if (size == 0) + write_byte(0) + else + write_varint32(size) + write_byte(CompactTypes.get_compact_type(ktype) << 4 | CompactTypes.get_compact_type(vtype)) + end + end + + def write_list_begin(etype, size) + write_collection_begin(etype, size) + end + + def write_set_begin(etype, size) + write_collection_begin(etype, size); + end + + def write_bool(bool) + type = bool ? CompactTypes::BOOLEAN_TRUE : CompactTypes::BOOLEAN_FALSE + unless @boolean_field.nil? + # we haven't written the field header yet + write_field_begin_internal(@boolean_field.first, @boolean_field.last, type) + @boolean_field = nil + else + # we're not part of a field, so just write the value. + write_byte(type) + end + end + + def write_byte(byte) + @trans.write([byte].pack('c')) + end + + def write_i16(i16) + write_varint32(int_to_zig_zag(i16)) + end + + def write_i32(i32) + write_varint32(int_to_zig_zag(i32)) + end + + def write_i64(i64) + write_varint64(long_to_zig_zag(i64)) + end + + def write_double(dub) + @trans.write([dub].pack("G").reverse) + end + + def write_string(str) + write_varint32(str.length) + @trans.write(str) + end + + def read_message_begin + protocol_id = read_byte() + if protocol_id != PROTOCOL_ID + raise ProtocolException.new("Expected protocol id #{PROTOCOL_ID} but got #{protocol_id}") + end + + version_and_type = read_byte() + version = version_and_type & VERSION_MASK + if (version != VERSION) + raise ProtocolException.new("Expected version #{VERSION} but got #{version}"); + end + + type = (version_and_type >> TYPE_SHIFT_AMOUNT) & 0x03 + seqid = read_varint32() + messageName = read_string() + [messageName, type, seqid] + end + + def read_struct_begin + @last_field.push(0) + "" + end + + def read_struct_end + @last_field.pop() + nil + end + + def read_field_begin + type = read_byte() + + # if it's a stop, then we can return immediately, as the struct is over. + if (type & 0x0f) == Types::STOP + TSTOP + else + field_id = nil + + # mask off the 4 MSB of the type header. it could contain a field id delta. + modifier = (type & 0xf0) >> 4 + if modifier == 0 + # not a delta. look ahead for the zigzag varint field id. + field_id = read_i16() + else + # has a delta. add the delta to the last read field id. + field_id = @last_field.pop + modifier + end + + # if this happens to be a boolean field, the value is encoded in the type + if CompactTypes.is_bool_type?(type) + # save the boolean value in a special instance variable. + @bool_value = (type & 0x0f) == CompactTypes::BOOLEAN_TRUE + end + + # push the new field onto the field stack so we can keep the deltas going. + @last_field.push(field_id) + ["", CompactTypes.get_ttype(type & 0x0f), field_id] + end + end + + def read_map_begin + size = read_varint32() + key_and_value_type = size == 0 ? 0 : read_byte() + [CompactTypes.get_ttype(key_and_value_type >> 4), CompactTypes.get_ttype(key_and_value_type & 0xf), size] + end + + def read_list_begin + size_and_type = read_byte() + size = (size_and_type >> 4) & 0x0f + if size == 15 + size = read_varint32() + end + type = CompactTypes.get_ttype(size_and_type) + [type, size] + end + + def read_set_begin + read_list_begin + end + + def read_bool + unless @bool_value.nil? + bv = @bool_value + @bool_value = nil + bv + else + read_byte() == CompactTypes::BOOLEAN_TRUE + end + end + + def read_byte + dat = trans.read_all(1) + val = dat[0] + if (val > 0x7f) + val = 0 - ((val - 1) ^ 0xff) + end + val + end + + def read_i16 + zig_zag_to_int(read_varint32()) + end + + def read_i32 + zig_zag_to_int(read_varint32()) + end + + def read_i64 + zig_zag_to_long(read_varint64()) + end + + def read_double + dat = trans.read_all(8) + val = dat.reverse.unpack('G').first + val + end + + def read_string + size = read_varint32() + trans.read_all(size) + end + + + private + + # + # Abstract method for writing the start of lists and sets. List and sets on + # the wire differ only by the type indicator. + # + def write_collection_begin(elem_type, size) + if size <= 14 + write_byte(size << 4 | CompactTypes.get_compact_type(elem_type)) + else + write_byte(0xf0 | CompactTypes.get_compact_type(elem_type)) + write_varint32(size) + end + end + + def write_varint32(n) + # int idx = 0; + while true + if (n & ~0x7F) == 0 + # i32buf[idx++] = (byte)n; + write_byte(n) + break + # return; + else + # i32buf[idx++] = (byte)((n & 0x7F) | 0x80); + write_byte((n & 0x7F) | 0x80) + n = n >> 7 + end + end + # trans_.write(i32buf, 0, idx); + end + + SEVEN_BIT_MASK = 0x7F + EVERYTHING_ELSE_MASK = ~SEVEN_BIT_MASK + + def write_varint64(n) + while true + if (n & EVERYTHING_ELSE_MASK) == 0 #TODO need to find a way to make this into a long... + write_byte(n) + break + else + write_byte((n & SEVEN_BIT_MASK) | 0x80) + n >>= 7 + end + end + end + + def read_varint32() + read_varint64() + end + + def read_varint64() + shift = 0 + result = 0 + while true + b = read_byte() + result |= (b & 0x7f) << shift + break if (b & 0x80) != 0x80 + shift += 7 + end + result + end + + def int_to_zig_zag(n) + (n << 1) ^ (n >> 31) + end + + def long_to_zig_zag(l) + # puts "zz encoded #{l} to #{(l << 1) ^ (l >> 63)}" + (l << 1) ^ (l >> 63) + end + + def zig_zag_to_int(n) + (n >> 1) ^ -(n & 1) + end + + def zig_zag_to_long(n) + (n >> 1) ^ -(n & 1) + end + end + + class CompactProtocolFactory < BaseProtocolFactory + def get_protocol(trans) + CompactProtocol.new(trans) + end + end +end diff --git a/lib/rb/lib/thrift/serializer/deserializer.rb b/lib/rb/lib/thrift/serializer/deserializer.rb new file mode 100644 index 00000000..d2ee325a --- /dev/null +++ b/lib/rb/lib/thrift/serializer/deserializer.rb @@ -0,0 +1,33 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class Deserializer + def initialize(protocol_factory = BinaryProtocolFactory.new) + @transport = MemoryBufferTransport.new + @protocol = protocol_factory.get_protocol(@transport) + end + + def deserialize(base, buffer) + @transport.reset_buffer(buffer) + base.read(@protocol) + base + end + end +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/serializer/serializer.rb b/lib/rb/lib/thrift/serializer/serializer.rb new file mode 100644 index 00000000..22316395 --- /dev/null +++ b/lib/rb/lib/thrift/serializer/serializer.rb @@ -0,0 +1,34 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class Serializer + def initialize(protocol_factory = BinaryProtocolFactory.new) + @transport = MemoryBufferTransport.new + @protocol = protocol_factory.get_protocol(@transport) + end + + def serialize(base) + @transport.reset_buffer + base.write(@protocol) + @transport.read(@transport.available) + end + end +end + diff --git a/lib/rb/lib/thrift/server/base_server.rb b/lib/rb/lib/thrift/server/base_server.rb new file mode 100644 index 00000000..1ee12133 --- /dev/null +++ b/lib/rb/lib/thrift/server/base_server.rb @@ -0,0 +1,31 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class BaseServer + def initialize(processor, server_transport, transport_factory=nil, protocol_factory=nil) + @processor = processor + @server_transport = server_transport + @transport_factory = transport_factory ? transport_factory : Thrift::BaseTransportFactory.new + @protocol_factory = protocol_factory ? protocol_factory : Thrift::BinaryProtocolFactory.new + end + + def serve; nil; end + end +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/server/mongrel_http_server.rb b/lib/rb/lib/thrift/server/mongrel_http_server.rb new file mode 100644 index 00000000..84eacf0d --- /dev/null +++ b/lib/rb/lib/thrift/server/mongrel_http_server.rb @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'mongrel' + +## Sticks a service on a URL, using mongrel to do the HTTP work +module Thrift + class MongrelHTTPServer < BaseServer + class Handler < Mongrel::HttpHandler + def initialize(processor, protocol_factory) + @processor = processor + @protocol_factory = protocol_factory + end + + def process(request, response) + if request.params["REQUEST_METHOD"] == "POST" + response.start(200) do |head, out| + head["Content-Type"] = "application/x-thrift" + transport = IOStreamTransport.new request.body, out + protocol = @protocol_factory.get_protocol transport + @processor.process protocol, protocol + end + else + response.start(404) { } + end + end + end + + def initialize(processor, opts={}) + port = opts[:port] || 80 + ip = opts[:ip] || "0.0.0.0" + path = opts[:path] || "" + protocol_factory = opts[:protocol_factory] || BinaryProtocolFactory.new + @server = Mongrel::HttpServer.new ip, port + @server.register "/#{path}", Handler.new(processor, protocol_factory) + end + + def serve + @server.run.join + end + end +end diff --git a/lib/rb/lib/thrift/server/nonblocking_server.rb b/lib/rb/lib/thrift/server/nonblocking_server.rb new file mode 100644 index 00000000..5425f6de --- /dev/null +++ b/lib/rb/lib/thrift/server/nonblocking_server.rb @@ -0,0 +1,296 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'logger' +require 'thread' + +module Thrift + # this class expects to always use a FramedTransport for reading messages + class NonblockingServer < BaseServer + def initialize(processor, server_transport, transport_factory=nil, protocol_factory=nil, num=20, logger=nil) + super(processor, server_transport, transport_factory, protocol_factory) + @num_threads = num + if logger.nil? + @logger = Logger.new(STDERR) + @logger.level = Logger::WARN + else + @logger = logger + end + @shutdown_semaphore = Mutex.new + @transport_semaphore = Mutex.new + end + + def serve + @logger.info "Starting #{self}" + @server_transport.listen + @io_manager = start_io_manager + + begin + loop do + break if @server_transport.closed? + rd, = select([@server_transport], nil, nil, 0.1) + next if rd.nil? + socket = @server_transport.accept + @logger.debug "Accepted socket: #{socket.inspect}" + @io_manager.add_connection socket + end + rescue IOError => e + end + # we must be shutting down + @logger.info "#{self} is shutting down, goodbye" + ensure + @transport_semaphore.synchronize do + @server_transport.close + end + @io_manager.ensure_closed unless @io_manager.nil? + end + + def shutdown(timeout = 0, block = true) + @shutdown_semaphore.synchronize do + return if @is_shutdown + @is_shutdown = true + end + # nonblocking is intended for calling from within a Handler + # but we can't change the order of operations here, so lets thread + shutdown_proc = lambda do + @io_manager.shutdown(timeout) + @transport_semaphore.synchronize do + @server_transport.close # this will break the accept loop + end + end + if block + shutdown_proc.call + else + Thread.new &shutdown_proc + end + end + + private + + def start_io_manager + iom = IOManager.new(@processor, @server_transport, @transport_factory, @protocol_factory, @num_threads, @logger) + iom.spawn + iom + end + + class IOManager # :nodoc: + DEFAULT_BUFFER = 2**20 + + def initialize(processor, server_transport, transport_factory, protocol_factory, num, logger) + @processor = processor + @server_transport = server_transport + @transport_factory = transport_factory + @protocol_factory = protocol_factory + @num_threads = num + @logger = logger + @connections = [] + @buffers = Hash.new { |h,k| h[k] = '' } + @signal_queue = Queue.new + @signal_pipes = IO.pipe + @signal_pipes[1].sync = true + @worker_queue = Queue.new + @shutdown_queue = Queue.new + end + + def add_connection(socket) + signal [:connection, socket] + end + + def spawn + @iom_thread = Thread.new do + @logger.debug "Starting #{self}" + run + end + end + + def shutdown(timeout = 0) + @logger.debug "#{self} is shutting down workers" + @worker_queue.clear + @num_threads.times { @worker_queue.push [:shutdown] } + signal [:shutdown, timeout] + @shutdown_queue.pop + @signal_pipes[0].close + @signal_pipes[1].close + @logger.debug "#{self} is shutting down, goodbye" + end + + def ensure_closed + kill_worker_threads if @worker_threads + @iom_thread.kill + end + + private + + def run + spin_worker_threads + + loop do + rd, = select([@signal_pipes[0], *@connections]) + if rd.delete @signal_pipes[0] + break if read_signals == :shutdown + end + rd.each do |fd| + if fd.handle.eof? + remove_connection fd + else + read_connection fd + end + end + end + join_worker_threads(@shutdown_timeout) + ensure + @shutdown_queue.push :shutdown + end + + def read_connection(fd) + @buffers[fd] << fd.read(DEFAULT_BUFFER) + frame = slice_frame!(@buffers[fd]) + if frame + @logger.debug "#{self} is processing a frame" + @worker_queue.push [:frame, fd, frame] + end + end + + def spin_worker_threads + @logger.debug "#{self} is spinning up worker threads" + @worker_threads = [] + @num_threads.times do + @worker_threads << spin_thread + end + end + + def spin_thread + Worker.new(@processor, @transport_factory, @protocol_factory, @logger, @worker_queue).spawn + end + + def signal(msg) + @signal_queue << msg + @signal_pipes[1].write " " + end + + def read_signals + # clear the signal pipe + # note that since read_nonblock is broken in jruby, + # we can only read up to a set number of signals at once + sigstr = @signal_pipes[0].readpartial(1024) + # now read the signals + begin + sigstr.length.times do + signal, obj = @signal_queue.pop(true) + case signal + when :connection + @connections << obj + when :shutdown + @shutdown_timeout = obj + return :shutdown + end + end + rescue ThreadError + # out of signals + # note that in a perfect world this would never happen, since we're + # only reading the number of signals pushed on the pipe, but given the lack + # of locks, in theory we could clear the pipe/queue while a new signal is being + # placed on the pipe, at which point our next read_signals would hit this error + end + end + + def remove_connection(fd) + # don't explicitly close it, a thread may still be writing to it + @connections.delete fd + @buffers.delete fd + end + + def join_worker_threads(shutdown_timeout) + start = Time.now + @worker_threads.each do |t| + if shutdown_timeout > 0 + timeout = (start + shutdown_timeout) - Time.now + break if timeout <= 0 + t.join(timeout) + else + t.join + end + end + kill_worker_threads + end + + def kill_worker_threads + @worker_threads.each do |t| + t.kill if t.status + end + @worker_threads.clear + end + + def slice_frame!(buf) + if buf.length >= 4 + size = buf.unpack('N').first + if buf.length >= size + 4 + buf.slice!(0, size + 4) + else + nil + end + else + nil + end + end + + class Worker # :nodoc: + def initialize(processor, transport_factory, protocol_factory, logger, queue) + @processor = processor + @transport_factory = transport_factory + @protocol_factory = protocol_factory + @logger = logger + @queue = queue + end + + def spawn + Thread.new do + @logger.debug "#{self} is spawning" + run + end + end + + private + + def run + loop do + cmd, *args = @queue.pop + case cmd + when :shutdown + @logger.debug "#{self} is shutting down, goodbye" + break + when :frame + fd, frame = args + begin + otrans = @transport_factory.get_transport(fd) + oprot = @protocol_factory.get_protocol(otrans) + membuf = MemoryBufferTransport.new(frame) + itrans = @transport_factory.get_transport(membuf) + iprot = @protocol_factory.get_protocol(itrans) + @processor.process(iprot, oprot) + rescue => e + @logger.error "#{Thread.current.inspect} raised error: #{e.inspect}\n#{e.backtrace.join("\n")}" + end + end + end + end + end + end + end +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/server/simple_server.rb b/lib/rb/lib/thrift/server/simple_server.rb new file mode 100644 index 00000000..21e86592 --- /dev/null +++ b/lib/rb/lib/thrift/server/simple_server.rb @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class SimpleServer < BaseServer + def serve + begin + @server_transport.listen + loop do + client = @server_transport.accept + trans = @transport_factory.get_transport(client) + prot = @protocol_factory.get_protocol(trans) + begin + loop do + @processor.process(prot, prot) + end + rescue Thrift::TransportException, Thrift::ProtocolException + ensure + trans.close + end + end + ensure + @server_transport.close + end + end + end +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/server/thread_pool_server.rb b/lib/rb/lib/thrift/server/thread_pool_server.rb new file mode 100644 index 00000000..8cec805a --- /dev/null +++ b/lib/rb/lib/thrift/server/thread_pool_server.rb @@ -0,0 +1,75 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'thread' + +module Thrift + class ThreadPoolServer < BaseServer + def initialize(processor, server_transport, transport_factory=nil, protocol_factory=nil, num=20) + super(processor, server_transport, transport_factory, protocol_factory) + @thread_q = SizedQueue.new(num) + @exception_q = Queue.new + @running = false + end + + ## exceptions that happen in worker threads will be relayed here and + ## must be caught. 'retry' can be used to continue. (threads will + ## continue to run while the exception is being handled.) + def rescuable_serve + Thread.new { serve } unless @running + @running = true + raise @exception_q.pop + end + + ## exceptions that happen in worker threads simply cause that thread + ## to die and another to be spawned in its place. + def serve + @server_transport.listen + + begin + loop do + @thread_q.push(:token) + Thread.new do + begin + loop do + client = @server_transport.accept + trans = @transport_factory.get_transport(client) + prot = @protocol_factory.get_protocol(trans) + begin + loop do + @processor.process(prot, prot) + end + rescue Thrift::TransportException, Thrift::ProtocolException => e + ensure + trans.close + end + end + rescue => e + @exception_q.push(e) + ensure + @thread_q.pop # thread died! + end + end + end + ensure + @server_transport.close + end + end + end +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/server/threaded_server.rb b/lib/rb/lib/thrift/server/threaded_server.rb new file mode 100644 index 00000000..a2c917cb --- /dev/null +++ b/lib/rb/lib/thrift/server/threaded_server.rb @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'thread' + +module Thrift + class ThreadedServer < BaseServer + def serve + begin + @server_transport.listen + loop do + client = @server_transport.accept + trans = @transport_factory.get_transport(client) + prot = @protocol_factory.get_protocol(trans) + Thread.new(prot, trans) do |p, t| + begin + loop do + @processor.process(p, p) + end + rescue Thrift::TransportException, Thrift::ProtocolException + ensure + t.close + end + end + end + ensure + @server_transport.close + end + end + end +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/struct.rb b/lib/rb/lib/thrift/struct.rb new file mode 100644 index 00000000..01aae56b --- /dev/null +++ b/lib/rb/lib/thrift/struct.rb @@ -0,0 +1,294 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'set' + +module Thrift + module Struct + def initialize(d={}) + # get a copy of the default values to work on, removing defaults in favor of arguments + fields_with_defaults = fields_with_default_values.dup + + # check if the defaults is empty, or if there are no parameters for this + # instantiation, and if so, don't bother overriding defaults. + unless fields_with_defaults.empty? || d.empty? + d.each_key do |name| + fields_with_defaults.delete(name.to_s) + end + end + + # assign all the user-specified arguments + unless d.empty? + d.each do |name, value| + unless name_to_id(name.to_s) + raise Exception, "Unknown key given to #{self.class}.new: #{name}" + end + Thrift.check_type(value, struct_fields[name_to_id(name.to_s)], name) if Thrift.type_checking + instance_variable_set("@#{name}", value) + end + end + + # assign all the default values + unless fields_with_defaults.empty? + fields_with_defaults.each do |name, default_value| + instance_variable_set("@#{name}", (default_value.dup rescue default_value)) + end + end + end + + def fields_with_default_values + fields_with_default_values = self.class.instance_variable_get("@fields_with_default_values") + unless fields_with_default_values + fields_with_default_values = {} + struct_fields.each do |fid, field_def| + unless field_def[:default].nil? + fields_with_default_values[field_def[:name]] = field_def[:default] + end + end + self.class.instance_variable_set("@fields_with_default_values", fields_with_default_values) + end + fields_with_default_values + end + + def name_to_id(name) + names_to_ids = self.class.instance_variable_get("@names_to_ids") + unless names_to_ids + names_to_ids = {} + struct_fields.each do |fid, field_def| + names_to_ids[field_def[:name]] = fid + end + self.class.instance_variable_set("@names_to_ids", names_to_ids) + end + names_to_ids[name] + end + + def each_field + struct_fields.keys.sort.each do |fid| + data = struct_fields[fid] + yield fid, data + end + end + + def inspect(skip_optional_nulls = true) + fields = [] + each_field do |fid, field_info| + name = field_info[:name] + value = instance_variable_get("@#{name}") + unless skip_optional_nulls && field_info[:optional] && value.nil? + fields << "#{name}:#{value.inspect}" + end + end + "<#{self.class} #{fields.join(", ")}>" + end + + def read(iprot) + iprot.read_struct_begin + loop do + fname, ftype, fid = iprot.read_field_begin + break if (ftype == Types::STOP) + handle_message(iprot, fid, ftype) + iprot.read_field_end + end + iprot.read_struct_end + validate + end + + def write(oprot) + validate + oprot.write_struct_begin(self.class.name) + each_field do |fid, field_info| + name = field_info[:name] + type = field_info[:type] + if (value = instance_variable_get("@#{name}")) + if is_container? type + oprot.write_field_begin(name, type, fid) + write_container(oprot, value, field_info) + oprot.write_field_end + else + oprot.write_field(name, type, fid, value) + end + end + end + oprot.write_field_stop + oprot.write_struct_end + end + + def ==(other) + each_field do |fid, field_info| + name = field_info[:name] + return false unless self.instance_variable_get("@#{name}") == other.instance_variable_get("@#{name}") + end + true + end + + def eql?(other) + self.class == other.class && self == other + end + + # for the time being, we're ok with a naive hash. this could definitely be improved upon. + def hash + 0 + end + + def differences(other) + diffs = [] + unless other.is_a?(self.class) + diffs << "Different class!" + else + each_field do |fid, field_info| + name = field_info[:name] + diffs << "#{name} differs!" unless self.instance_variable_get("@#{name}") == other.instance_variable_get("@#{name}") + end + end + diffs + end + + def self.field_accessor(klass, *fields) + fields.each do |field| + klass.send :attr_reader, field + klass.send :define_method, "#{field}=" do |value| + Thrift.check_type(value, klass::FIELDS.values.find { |f| f[:name].to_s == field.to_s }, field) if Thrift.type_checking + instance_variable_set("@#{field}", value) + end + end + end + + protected + + def self.append_features(mod) + if mod.ancestors.include? ::Exception + mod.send :class_variable_set, :'@@__thrift_struct_real_initialize', mod.instance_method(:initialize) + super + # set up our custom initializer so `raise Xception, 'message'` works + mod.send :define_method, :struct_initialize, mod.instance_method(:initialize) + mod.send :define_method, :initialize, mod.instance_method(:exception_initialize) + else + super + end + end + + def exception_initialize(*args, &block) + if args.size == 1 and args.first.is_a? Hash + # looks like it's a regular Struct initialize + method(:struct_initialize).call(args.first) + else + # call the Struct initializer first with no args + # this will set our field default values + method(:struct_initialize).call() + # now give it to the exception + self.class.send(:class_variable_get, :'@@__thrift_struct_real_initialize').bind(self).call(*args, &block) if args.size > 0 + # self.class.instance_method(:initialize).bind(self).call(*args, &block) + end + end + + def handle_message(iprot, fid, ftype) + field = struct_fields[fid] + if field and field[:type] == ftype + value = read_field(iprot, field) + instance_variable_set("@#{field[:name]}", value) + else + iprot.skip(ftype) + end + end + + def read_field(iprot, field = {}) + case field[:type] + when Types::STRUCT + value = field[:class].new + value.read(iprot) + when Types::MAP + key_type, val_type, size = iprot.read_map_begin + value = {} + size.times do + k = read_field(iprot, field_info(field[:key])) + v = read_field(iprot, field_info(field[:value])) + value[k] = v + end + iprot.read_map_end + when Types::LIST + e_type, size = iprot.read_list_begin + value = Array.new(size) do |n| + read_field(iprot, field_info(field[:element])) + end + iprot.read_list_end + when Types::SET + e_type, size = iprot.read_set_begin + value = Set.new + size.times do + element = read_field(iprot, field_info(field[:element])) + value << element + end + iprot.read_set_end + else + value = iprot.read_type(field[:type]) + end + value + end + + def write_data(oprot, value, field) + if is_container? field[:type] + write_container(oprot, value, field) + else + oprot.write_type(field[:type], value) + end + end + + def write_container(oprot, value, field = {}) + case field[:type] + when Types::MAP + oprot.write_map_begin(field[:key][:type], field[:value][:type], value.size) + value.each do |k, v| + write_data(oprot, k, field[:key]) + write_data(oprot, v, field[:value]) + end + oprot.write_map_end + when Types::LIST + oprot.write_list_begin(field[:element][:type], value.size) + value.each do |elem| + write_data(oprot, elem, field[:element]) + end + oprot.write_list_end + when Types::SET + oprot.write_set_begin(field[:element][:type], value.size) + value.each do |v,| # the , is to preserve compatibility with the old Hash-style sets + write_data(oprot, v, field[:element]) + end + oprot.write_set_end + else + raise "Not a container type: #{field[:type]}" + end + end + + CONTAINER_TYPES = [] + CONTAINER_TYPES[Types::LIST] = true + CONTAINER_TYPES[Types::MAP] = true + CONTAINER_TYPES[Types::SET] = true + def is_container?(type) + CONTAINER_TYPES[type] + end + + def field_info(field) + { :type => field[:type], + :class => field[:class], + :key => field[:key], + :value => field[:value], + :element => field[:element] } + end + end +end diff --git a/lib/rb/lib/thrift/thrift_native.rb b/lib/rb/lib/thrift/thrift_native.rb new file mode 100644 index 00000000..4d8df61f --- /dev/null +++ b/lib/rb/lib/thrift/thrift_native.rb @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +begin + require "thrift_native" +rescue LoadError + puts "Unable to load thrift_native extension. Defaulting to pure Ruby libraries." +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/transport/base_server_transport.rb b/lib/rb/lib/thrift/transport/base_server_transport.rb new file mode 100644 index 00000000..68c5af07 --- /dev/null +++ b/lib/rb/lib/thrift/transport/base_server_transport.rb @@ -0,0 +1,37 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class BaseServerTransport + def listen + raise NotImplementedError + end + + def accept + raise NotImplementedError + end + + def close; nil; end + + def closed? + raise NotImplementedError + end + end +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/transport/base_transport.rb b/lib/rb/lib/thrift/transport/base_transport.rb new file mode 100644 index 00000000..08a71dab --- /dev/null +++ b/lib/rb/lib/thrift/transport/base_transport.rb @@ -0,0 +1,70 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class TransportException < Exception + UNKNOWN = 0 + NOT_OPEN = 1 + ALREADY_OPEN = 2 + TIMED_OUT = 3 + END_OF_FILE = 4 + + attr_reader :type + + def initialize(type=UNKNOWN, message=nil) + super(message) + @type = type + end + end + + class BaseTransport + def open?; end + + def open; end + + def close; end + + def read(sz) + raise NotImplementedError + end + + def read_all(size) + buf = '' + + while (buf.length < size) + chunk = read(size - buf.length) + buf << chunk + end + + buf + end + + def write(buf); end + alias_method :<<, :write + + def flush; end + end + + class BaseTransportFactory + def get_transport(trans) + return trans + end + end +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/transport/buffered_transport.rb b/lib/rb/lib/thrift/transport/buffered_transport.rb new file mode 100644 index 00000000..8dead4e0 --- /dev/null +++ b/lib/rb/lib/thrift/transport/buffered_transport.rb @@ -0,0 +1,77 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class BufferedTransport < BaseTransport + DEFAULT_BUFFER = 4096 + + def initialize(transport) + @transport = transport + @wbuf = '' + @rbuf = '' + @index = 0 + end + + def open? + return @transport.open? + end + + def open + @transport.open + end + + def close + flush + @transport.close + end + + def read(sz) + @index += sz + ret = @rbuf.slice(@index - sz, sz) || '' + + if ret.length == 0 + @rbuf = @transport.read([sz, DEFAULT_BUFFER].max) + @index = sz + ret = @rbuf.slice(0, sz) || '' + end + + ret + end + + def write(buf) + @wbuf << buf + end + + def flush + if @wbuf != '' + @transport.write(@wbuf) + @wbuf = '' + end + + @transport.flush + end + end + + class BufferedTransportFactory < BaseTransportFactory + def get_transport(transport) + return BufferedTransport.new(transport) + end + end +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/transport/framed_transport.rb b/lib/rb/lib/thrift/transport/framed_transport.rb new file mode 100644 index 00000000..558af744 --- /dev/null +++ b/lib/rb/lib/thrift/transport/framed_transport.rb @@ -0,0 +1,90 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class FramedTransport < BaseTransport + def initialize(transport, read=true, write=true) + @transport = transport + @rbuf = '' + @wbuf = '' + @read = read + @write = write + @index = 0 + end + + def open? + @transport.open? + end + + def open + @transport.open + end + + def close + @transport.close + end + + def read(sz) + return @transport.read(sz) unless @read + + return '' if sz <= 0 + + read_frame if @index >= @rbuf.length + + @index += sz + @rbuf.slice(@index - sz, sz) || '' + end + + def write(buf,sz=nil) + return @transport.write(buf) unless @write + + @wbuf << (sz ? buf[0...sz] : buf) + end + + # + # Writes the output buffer to the stream in the format of a 4-byte length + # followed by the actual data. + # + def flush + return @transport.flush unless @write + + out = [@wbuf.length].pack('N') + out << @wbuf + @transport.write(out) + @transport.flush + @wbuf = '' + end + + private + + def read_frame + sz = @transport.read_all(4).unpack('N').first + + @index = 0 + @rbuf = @transport.read_all(sz) + end + end + + class FramedTransportFactory < BaseTransportFactory + def get_transport(transport) + return FramedTransport.new(transport) + end + end +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/transport/http_client_transport.rb b/lib/rb/lib/thrift/transport/http_client_transport.rb new file mode 100644 index 00000000..a190a983 --- /dev/null +++ b/lib/rb/lib/thrift/transport/http_client_transport.rb @@ -0,0 +1,45 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'net/http' +require 'net/https' +require 'uri' +require 'stringio' + +module Thrift + class HTTPClientTransport < BaseTransport + def initialize(url) + @url = URI url + @outbuf = "" + end + + def open?; true end + def read(sz); @inbuf.read sz end + def write(buf); @outbuf << buf end + def flush + http = Net::HTTP.new @url.host, @url.port + http.use_ssl = @url.scheme == "https" + headers = { 'Content-Type' => 'application/x-thrift' } + resp, data = http.post(@url.path, @outbuf, headers) + @inbuf = StringIO.new data + @outbuf = "" + end + end +end diff --git a/lib/rb/lib/thrift/transport/io_stream_transport.rb b/lib/rb/lib/thrift/transport/io_stream_transport.rb new file mode 100644 index 00000000..be348aa0 --- /dev/null +++ b/lib/rb/lib/thrift/transport/io_stream_transport.rb @@ -0,0 +1,39 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# Very very simple implementation of wrapping two objects, one with a #read +# method and one with a #write method, into a transport for thrift. +# +# Assumes both objects are open, remain open, don't require flushing, etc. +# +module Thrift + class IOStreamTransport < BaseTransport + def initialize(input, output) + @input = input + @output = output + end + + def open?; not @input.closed? or not @output.closed? end + def read(sz); @input.read(sz) end + def write(buf); @output.write(buf) end + def close; @input.close; @output.close end + def to_io; @input end # we're assuming this is used in a IO.select for reading + end +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/transport/memory_buffer_transport.rb b/lib/rb/lib/thrift/transport/memory_buffer_transport.rb new file mode 100644 index 00000000..33d732d1 --- /dev/null +++ b/lib/rb/lib/thrift/transport/memory_buffer_transport.rb @@ -0,0 +1,93 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +module Thrift + class MemoryBufferTransport < BaseTransport + GARBAGE_BUFFER_SIZE = 4*(2**10) # 4kB + + # If you pass a string to this, you should #dup that string + # unless you want it to be modified by #read and #write + #-- + # this behavior is no longer required. If you wish to change it + # go ahead, just make sure the specs pass + def initialize(buffer = nil) + @buf = buffer || '' + @index = 0 + end + + def open? + return true + end + + def open + end + + def close + end + + def peek + @index < @buf.size + end + + # this method does not use the passed object directly but copies it + def reset_buffer(new_buf = '') + @buf.replace new_buf + @index = 0 + end + + def available + @buf.length - @index + end + + def read(len) + data = @buf.slice(@index, len) + @index += len + @index = @buf.size if @index > @buf.size + if @index >= GARBAGE_BUFFER_SIZE + @buf = @buf.slice(@index..-1) + @index = 0 + end + data + end + + def write(wbuf) + @buf << wbuf + end + + def flush + end + + def inspect_buffer + out = [] + for idx in 0...(@buf.size) + # if idx != 0 + # out << " " + # end + + if idx == @index + out << ">" + end + + out << @buf[idx].to_s(16) + end + out.join(" ") + end + end +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/transport/server_socket.rb b/lib/rb/lib/thrift/transport/server_socket.rb new file mode 100644 index 00000000..7feb9ab0 --- /dev/null +++ b/lib/rb/lib/thrift/transport/server_socket.rb @@ -0,0 +1,63 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'socket' + +module Thrift + class ServerSocket < BaseServerTransport + # call-seq: initialize(host = nil, port) + def initialize(host_or_port, port = nil) + if port + @host = host_or_port + @port = port + else + @host = nil + @port = host_or_port + end + @handle = nil + end + + attr_reader :handle + + def listen + @handle = TCPServer.new(@host, @port) + end + + def accept + unless @handle.nil? + sock = @handle.accept + trans = Socket.new + trans.handle = sock + trans + end + end + + def close + @handle.close unless @handle.nil? or @handle.closed? + @handle = nil + end + + def closed? + @handle.nil? or @handle.closed? + end + + alias to_io handle + end +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/transport/socket.rb b/lib/rb/lib/thrift/transport/socket.rb new file mode 100644 index 00000000..06c937e5 --- /dev/null +++ b/lib/rb/lib/thrift/transport/socket.rb @@ -0,0 +1,136 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'socket' + +module Thrift + class Socket < BaseTransport + def initialize(host='localhost', port=9090, timeout=nil) + @host = host + @port = port + @timeout = timeout + @desc = "#{host}:#{port}" + @handle = nil + end + + attr_accessor :handle, :timeout + + def open + begin + addrinfo = ::Socket::getaddrinfo(@host, @port).first + @handle = ::Socket.new(addrinfo[4], ::Socket::SOCK_STREAM, 0) + sockaddr = ::Socket.sockaddr_in(addrinfo[1], addrinfo[3]) + begin + @handle.connect_nonblock(sockaddr) + rescue Errno::EINPROGRESS + unless IO.select(nil, [ @handle ], nil, @timeout) + raise TransportException.new(TransportException::NOT_OPEN, "Connection timeout to #{@desc}") + end + begin + @handle.connect_nonblock(sockaddr) + rescue Errno::EISCONN + end + end + @handle + rescue StandardError => e + raise TransportException.new(TransportException::NOT_OPEN, "Could not connect to #{@desc}: #{e}") + end + end + + def open? + !@handle.nil? and !@handle.closed? + end + + def write(str) + raise IOError, "closed stream" unless open? + begin + if @timeout.nil? or @timeout == 0 + @handle.write(str) + else + len = 0 + start = Time.now + while Time.now - start < @timeout + rd, wr, = IO.select(nil, [@handle], nil, @timeout) + if wr and not wr.empty? + len += @handle.write_nonblock(str[len..-1]) + break if len >= str.length + end + end + if len < str.length + raise TransportException.new(TransportException::TIMED_OUT, "Socket: Timed out writing #{str.length} bytes to #{@desc}") + else + len + end + end + rescue TransportException => e + # pass this on + raise e + rescue StandardError => e + @handle.close + @handle = nil + raise TransportException.new(TransportException::NOT_OPEN, e.message) + end + end + + def read(sz) + raise IOError, "closed stream" unless open? + + begin + if @timeout.nil? or @timeout == 0 + data = @handle.readpartial(sz) + else + # it's possible to interrupt select for something other than the timeout + # so we need to ensure we've waited long enough + start = Time.now + rd = nil # scoping + loop do + rd, = IO.select([@handle], nil, nil, @timeout) + break if (rd and not rd.empty?) or Time.now - start >= @timeout + end + if rd.nil? or rd.empty? + raise TransportException.new(TransportException::TIMED_OUT, "Socket: Timed out reading #{sz} bytes from #{@desc}") + else + data = @handle.readpartial(sz) + end + end + rescue TransportException => e + # don't let this get caught by the StandardError handler + raise e + rescue StandardError => e + @handle.close unless @handle.closed? + @handle = nil + raise TransportException.new(TransportException::NOT_OPEN, e.message) + end + if (data.nil? or data.length == 0) + raise TransportException.new(TransportException::UNKNOWN, "Socket: Could not read #{sz} bytes from #{@desc}") + end + data + end + + def close + @handle.close unless @handle.nil? or @handle.closed? + @handle = nil + end + + def to_io + @handle + end + end +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/transport/unix_server_socket.rb b/lib/rb/lib/thrift/transport/unix_server_socket.rb new file mode 100644 index 00000000..a135d25f --- /dev/null +++ b/lib/rb/lib/thrift/transport/unix_server_socket.rb @@ -0,0 +1,60 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'socket' + +module Thrift + class UNIXServerSocket < BaseServerTransport + def initialize(path) + @path = path + @handle = nil + end + + attr_accessor :handle + + def listen + @handle = ::UNIXServer.new(@path) + end + + def accept + unless @handle.nil? + sock = @handle.accept + trans = UNIXSocket.new(nil) + trans.handle = sock + trans + end + end + + def close + if @handle + @handle.close unless @handle.closed? + @handle = nil + # UNIXServer doesn't delete the socket file, so we have to do it ourselves + File.delete(@path) + end + end + + def closed? + @handle.nil? or @handle.closed? + end + + alias to_io handle + end +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/transport/unix_socket.rb b/lib/rb/lib/thrift/transport/unix_socket.rb new file mode 100644 index 00000000..8f692e4c --- /dev/null +++ b/lib/rb/lib/thrift/transport/unix_socket.rb @@ -0,0 +1,40 @@ +# encoding: ascii-8bit +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'socket' + +module Thrift + class UNIXSocket < Socket + def initialize(path, timeout=nil) + @path = path + @timeout = timeout + @desc = @path # for read()'s error + @handle = nil + end + + def open + begin + @handle = ::UNIXSocket.new(@path) + rescue StandardError + raise TransportException.new(TransportException::NOT_OPEN, "Could not open UNIX socket at #{@path}") + end + end + end +end \ No newline at end of file diff --git a/lib/rb/lib/thrift/types.rb b/lib/rb/lib/thrift/types.rb new file mode 100644 index 00000000..20e4ca2c --- /dev/null +++ b/lib/rb/lib/thrift/types.rb @@ -0,0 +1,101 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'set' + +module Thrift + module Types + STOP = 0 + VOID = 1 + BOOL = 2 + BYTE = 3 + DOUBLE = 4 + I16 = 6 + I32 = 8 + I64 = 10 + STRING = 11 + STRUCT = 12 + MAP = 13 + SET = 14 + LIST = 15 + end + + class << self + attr_accessor :type_checking + end + + class TypeError < Exception + end + + def self.check_type(value, field, name, skip_nil=true) + return if value.nil? and skip_nil + klasses = case field[:type] + when Types::VOID + NilClass + when Types::BOOL + [TrueClass, FalseClass] + when Types::BYTE, Types::I16, Types::I32, Types::I64 + Integer + when Types::DOUBLE + Float + when Types::STRING + String + when Types::STRUCT + Struct + when Types::MAP + Hash + when Types::SET + Set + when Types::LIST + Array + end + valid = klasses && [*klasses].any? { |klass| klass === value } + raise TypeError, "Expected #{type_name(field[:type])}, received #{value.class} for field #{name}" unless valid + # check elements now + case field[:type] + when Types::MAP + value.each_pair do |k,v| + check_type(k, field[:key], "#{name}.key", false) + check_type(v, field[:value], "#{name}.value", false) + end + when Types::SET, Types::LIST + value.each do |el| + check_type(el, field[:element], "#{name}.element", false) + end + when Types::STRUCT + raise TypeError, "Expected #{field[:class]}, received #{value.class} for field #{name}" unless field[:class] == value.class + end + end + + def self.type_name(type) + Types.constants.each do |const| + return "Types::#{const}" if Types.const_get(const) == type + end + nil + end + + module MessageTypes + CALL = 1 + REPLY = 2 + EXCEPTION = 3 + ONEWAY = 4 + end +end + +Thrift.type_checking = false if Thrift.type_checking.nil? diff --git a/lib/rb/script/proto_benchmark.rb b/lib/rb/script/proto_benchmark.rb new file mode 100644 index 00000000..bb49e2e4 --- /dev/null +++ b/lib/rb/script/proto_benchmark.rb @@ -0,0 +1,121 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + "/../spec/spec_helper.rb" + +require "benchmark" +# require "ruby-prof" + +obj = Fixtures::COMPACT_PROTOCOL_TEST_STRUCT + +HOW_MANY = 1_000 + +binser = Thrift::Serializer.new +bin_data = binser.serialize(obj) +bindeser = Thrift::Deserializer.new +accel_bin_ser = Thrift::Serializer.new(Thrift::BinaryProtocolAcceleratedFactory.new) +accel_bin_deser = Thrift::Deserializer.new(Thrift::BinaryProtocolAcceleratedFactory.new) + +compact_ser = Thrift::Serializer.new(Thrift::CompactProtocolFactory.new) +compact_data = compact_ser.serialize(obj) +compact_deser = Thrift::Deserializer.new(Thrift::CompactProtocolFactory.new) + +Benchmark.bm(60) do |reporter| + reporter.report("binary protocol, write") do + HOW_MANY.times do + binser.serialize(obj) + end + end + + reporter.report("accelerated binary protocol, write") do + HOW_MANY.times do + accel_bin_ser.serialize(obj) + end + end + + reporter.report("compact protocol, write") do + # RubyProf.start + HOW_MANY.times do + compact_ser.serialize(obj) + end + # result = RubyProf.stop + # printer = RubyProf::GraphHtmlPrinter.new(result) + # file = File.open("profile.html", "w+") + # printer.print(file, 0) + # file.close + end + + reporter.report("binary protocol, read") do + HOW_MANY.times do + bindeser.deserialize(obj, bin_data) + end + end + + reporter.report("accelerated binary protocol, read") do + HOW_MANY.times do + accel_bin_deser.deserialize(obj, bin_data) + end + end + + reporter.report("compact protocol, read") do + HOW_MANY.times do + compact_deser.deserialize(obj, compact_data) + end + end + + + # f = File.new("/tmp/testfile", "w") + # proto = Thrift::BinaryProtocolAccelerated.new(Thrift::IOStreamTransport.new(Thrift::MemoryBufferTransport.new, f)) + # reporter.report("accelerated binary protocol, write (to disk)") do + # HOW_MANY.times do + # obj.write(proto) + # end + # f.flush + # end + # f.close + # + # f = File.new("/tmp/testfile", "r") + # proto = Thrift::BinaryProtocolAccelerated.new(Thrift::IOStreamTransport.new(f, Thrift::MemoryBufferTransport.new)) + # reporter.report("accelerated binary protocol, read (from disk)") do + # HOW_MANY.times do + # obj.read(proto) + # end + # end + # f.close + # + # f = File.new("/tmp/testfile", "w") + # reporter.report("compact protocol, write (to disk)") do + # proto = Thrift::CompactProtocol.new(Thrift::IOStreamTransport.new(Thrift::MemoryBufferTransport.new, f)) + # HOW_MANY.times do + # obj.write(proto) + # end + # f.flush + # end + # f.close + # + # f = File.new("/tmp/testfile", "r") + # reporter.report("compact protocol, read (from disk)") do + # proto = Thrift::CompactProtocol.new(Thrift::IOStreamTransport.new(f, Thrift::MemoryBufferTransport.new)) + # HOW_MANY.times do + # obj.read(proto) + # end + # end + # f.close + +end diff --git a/lib/rb/script/read_struct.rb b/lib/rb/script/read_struct.rb new file mode 100644 index 00000000..831fcec9 --- /dev/null +++ b/lib/rb/script/read_struct.rb @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require "spec/spec_helper" + +path, factory_class = ARGV + +factory = eval(factory_class).new + +deser = Thrift::Deserializer.new(factory) + +cpts = CompactProtoTestStruct.new +CompactProtoTestStruct.constants.each do |const| + cpts.instance_variable_set("@#{const}", nil) +end + +data = File.read(path) + +deser.deserialize(cpts, data) + +if cpts == Fixtures::COMPACT_PROTOCOL_TEST_STRUCT + puts "Object verified successfully!" +else + puts "Object failed verification! Expected #{Fixtures::COMPACT_PROTOCOL_TEST_STRUCT.inspect} but got #{cpts.inspect}" + + puts cpts.differences(Fixtures::COMPACT_PROTOCOL_TEST_STRUCT) +end diff --git a/lib/rb/script/write_struct.rb b/lib/rb/script/write_struct.rb new file mode 100644 index 00000000..da142197 --- /dev/null +++ b/lib/rb/script/write_struct.rb @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require "spec/spec_helper" + +path, factory_class = ARGV + +factory = eval(factory_class).new + +ser = Thrift::Serializer.new(factory) + +File.open(path, "w") do |file| + file.write(ser.serialize(Fixtures::COMPACT_PROTOCOL_TEST_STRUCT)) +end \ No newline at end of file diff --git a/lib/rb/setup.rb b/lib/rb/setup.rb new file mode 100644 index 00000000..9f0c8267 --- /dev/null +++ b/lib/rb/setup.rb @@ -0,0 +1,1585 @@ +# +# setup.rb +# +# Copyright (c) 2000-2005 Minero Aoki +# +# This program is free software. +# You can distribute/modify this program under the terms of +# the GNU LGPL, Lesser General Public License version 2.1. +# + +unless Enumerable.method_defined?(:map) # Ruby 1.4.6 + module Enumerable + alias map collect + end +end + +unless File.respond_to?(:read) # Ruby 1.6 + def File.read(fname) + open(fname) {|f| + return f.read + } + end +end + +unless Errno.const_defined?(:ENOTEMPTY) # Windows? + module Errno + class ENOTEMPTY + # We do not raise this exception, implementation is not needed. + end + end +end + +def File.binread(fname) + open(fname, 'rb') {|f| + return f.read + } +end + +# for corrupted Windows' stat(2) +def File.dir?(path) + File.directory?((path[-1,1] == '/') ? path : path + '/') +end + + +class ConfigTable + + include Enumerable + + def initialize(rbconfig) + @rbconfig = rbconfig + @items = [] + @table = {} + # options + @install_prefix = nil + @config_opt = nil + @verbose = true + @no_harm = false + end + + attr_accessor :install_prefix + attr_accessor :config_opt + + attr_writer :verbose + + def verbose? + @verbose + end + + attr_writer :no_harm + + def no_harm? + @no_harm + end + + def [](key) + lookup(key).resolve(self) + end + + def []=(key, val) + lookup(key).set val + end + + def names + @items.map {|i| i.name } + end + + def each(&block) + @items.each(&block) + end + + def key?(name) + @table.key?(name) + end + + def lookup(name) + @table[name] or setup_rb_error "no such config item: #{name}" + end + + def add(item) + @items.push item + @table[item.name] = item + end + + def remove(name) + item = lookup(name) + @items.delete_if {|i| i.name == name } + @table.delete_if {|name, i| i.name == name } + item + end + + def load_script(path, inst = nil) + if File.file?(path) + MetaConfigEnvironment.new(self, inst).instance_eval File.read(path), path + end + end + + def savefile + '.config' + end + + def load_savefile + begin + File.foreach(savefile()) do |line| + k, v = *line.split(/=/, 2) + self[k] = v.strip + end + rescue Errno::ENOENT + setup_rb_error $!.message + "\n#{File.basename($0)} config first" + end + end + + def save + @items.each {|i| i.value } + File.open(savefile(), 'w') {|f| + @items.each do |i| + f.printf "%s=%s\n", i.name, i.value if i.value? and i.value + end + } + end + + def load_standard_entries + standard_entries(@rbconfig).each do |ent| + add ent + end + end + + def standard_entries(rbconfig) + c = rbconfig + + rubypath = File.join(c['bindir'], c['ruby_install_name'] + c['EXEEXT']) + + major = c['MAJOR'].to_i + minor = c['MINOR'].to_i + teeny = c['TEENY'].to_i + version = "#{major}.#{minor}" + + # ruby ver. >= 1.4.4? + newpath_p = ((major >= 2) or + ((major == 1) and + ((minor >= 5) or + ((minor == 4) and (teeny >= 4))))) + + if c['rubylibdir'] + # V > 1.6.3 + libruby = "#{c['prefix']}/lib/ruby" + librubyver = c['rubylibdir'] + librubyverarch = c['archdir'] + siteruby = c['sitedir'] + siterubyver = c['sitelibdir'] + siterubyverarch = c['sitearchdir'] + elsif newpath_p + # 1.4.4 <= V <= 1.6.3 + libruby = "#{c['prefix']}/lib/ruby" + librubyver = "#{c['prefix']}/lib/ruby/#{version}" + librubyverarch = "#{c['prefix']}/lib/ruby/#{version}/#{c['arch']}" + siteruby = c['sitedir'] + siterubyver = "$siteruby/#{version}" + siterubyverarch = "$siterubyver/#{c['arch']}" + else + # V < 1.4.4 + libruby = "#{c['prefix']}/lib/ruby" + librubyver = "#{c['prefix']}/lib/ruby/#{version}" + librubyverarch = "#{c['prefix']}/lib/ruby/#{version}/#{c['arch']}" + siteruby = "#{c['prefix']}/lib/ruby/#{version}/site_ruby" + siterubyver = siteruby + siterubyverarch = "$siterubyver/#{c['arch']}" + end + parameterize = lambda {|path| + path.sub(/\A#{Regexp.quote(c['prefix'])}/, '$prefix') + } + + if arg = c['configure_args'].split.detect {|arg| /--with-make-prog=/ =~ arg } + makeprog = arg.sub(/'/, '').split(/=/, 2)[1] + else + makeprog = 'make' + end + + [ + ExecItem.new('installdirs', 'std/site/home', + 'std: install under libruby; site: install under site_ruby; home: install under $HOME')\ + {|val, table| + case val + when 'std' + table['rbdir'] = '$librubyver' + table['sodir'] = '$librubyverarch' + when 'site' + table['rbdir'] = '$siterubyver' + table['sodir'] = '$siterubyverarch' + when 'home' + setup_rb_error '$HOME was not set' unless ENV['HOME'] + table['prefix'] = ENV['HOME'] + table['rbdir'] = '$libdir/ruby' + table['sodir'] = '$libdir/ruby' + end + }, + PathItem.new('prefix', 'path', c['prefix'], + 'path prefix of target environment'), + PathItem.new('bindir', 'path', parameterize.call(c['bindir']), + 'the directory for commands'), + PathItem.new('libdir', 'path', parameterize.call(c['libdir']), + 'the directory for libraries'), + PathItem.new('datadir', 'path', parameterize.call(c['datadir']), + 'the directory for shared data'), + PathItem.new('mandir', 'path', parameterize.call(c['mandir']), + 'the directory for man pages'), + PathItem.new('sysconfdir', 'path', parameterize.call(c['sysconfdir']), + 'the directory for system configuration files'), + PathItem.new('localstatedir', 'path', parameterize.call(c['localstatedir']), + 'the directory for local state data'), + PathItem.new('libruby', 'path', libruby, + 'the directory for ruby libraries'), + PathItem.new('librubyver', 'path', librubyver, + 'the directory for standard ruby libraries'), + PathItem.new('librubyverarch', 'path', librubyverarch, + 'the directory for standard ruby extensions'), + PathItem.new('siteruby', 'path', siteruby, + 'the directory for version-independent aux ruby libraries'), + PathItem.new('siterubyver', 'path', siterubyver, + 'the directory for aux ruby libraries'), + PathItem.new('siterubyverarch', 'path', siterubyverarch, + 'the directory for aux ruby binaries'), + PathItem.new('rbdir', 'path', '$siterubyver', + 'the directory for ruby scripts'), + PathItem.new('sodir', 'path', '$siterubyverarch', + 'the directory for ruby extentions'), + PathItem.new('rubypath', 'path', rubypath, + 'the path to set to #! line'), + ProgramItem.new('rubyprog', 'name', rubypath, + 'the ruby program using for installation'), + ProgramItem.new('makeprog', 'name', makeprog, + 'the make program to compile ruby extentions'), + SelectItem.new('shebang', 'all/ruby/never', 'ruby', + 'shebang line (#!) editing mode'), + BoolItem.new('without-ext', 'yes/no', 'no', + 'does not compile/install ruby extentions') + ] + end + private :standard_entries + + def load_multipackage_entries + multipackage_entries().each do |ent| + add ent + end + end + + def multipackage_entries + [ + PackageSelectionItem.new('with', 'name,name...', '', 'ALL', + 'package names that you want to install'), + PackageSelectionItem.new('without', 'name,name...', '', 'NONE', + 'package names that you do not want to install') + ] + end + private :multipackage_entries + + ALIASES = { + 'std-ruby' => 'librubyver', + 'stdruby' => 'librubyver', + 'rubylibdir' => 'librubyver', + 'archdir' => 'librubyverarch', + 'site-ruby-common' => 'siteruby', # For backward compatibility + 'site-ruby' => 'siterubyver', # For backward compatibility + 'bin-dir' => 'bindir', + 'bin-dir' => 'bindir', + 'rb-dir' => 'rbdir', + 'so-dir' => 'sodir', + 'data-dir' => 'datadir', + 'ruby-path' => 'rubypath', + 'ruby-prog' => 'rubyprog', + 'ruby' => 'rubyprog', + 'make-prog' => 'makeprog', + 'make' => 'makeprog' + } + + def fixup + ALIASES.each do |ali, name| + @table[ali] = @table[name] + end + @items.freeze + @table.freeze + @options_re = /\A--(#{@table.keys.join('|')})(?:=(.*))?\z/ + end + + def parse_opt(opt) + m = @options_re.match(opt) or setup_rb_error "config: unknown option #{opt}" + m.to_a[1,2] + end + + def dllext + @rbconfig['DLEXT'] + end + + def value_config?(name) + lookup(name).value? + end + + class Item + def initialize(name, template, default, desc) + @name = name.freeze + @template = template + @value = default + @default = default + @description = desc + end + + attr_reader :name + attr_reader :description + + attr_accessor :default + alias help_default default + + def help_opt + "--#{@name}=#{@template}" + end + + def value? + true + end + + def value + @value + end + + def resolve(table) + @value.gsub(%r<\$([^/]+)>) { table[$1] } + end + + def set(val) + @value = check(val) + end + + private + + def check(val) + setup_rb_error "config: --#{name} requires argument" unless val + val + end + end + + class BoolItem < Item + def config_type + 'bool' + end + + def help_opt + "--#{@name}" + end + + private + + def check(val) + return 'yes' unless val + case val + when /\Ay(es)?\z/i, /\At(rue)?\z/i then 'yes' + when /\An(o)?\z/i, /\Af(alse)\z/i then 'no' + else + setup_rb_error "config: --#{@name} accepts only yes/no for argument" + end + end + end + + class PathItem < Item + def config_type + 'path' + end + + private + + def check(path) + setup_rb_error "config: --#{@name} requires argument" unless path + path[0,1] == '$' ? path : File.expand_path(path) + end + end + + class ProgramItem < Item + def config_type + 'program' + end + end + + class SelectItem < Item + def initialize(name, selection, default, desc) + super + @ok = selection.split('/') + end + + def config_type + 'select' + end + + private + + def check(val) + unless @ok.include?(val.strip) + setup_rb_error "config: use --#{@name}=#{@template} (#{val})" + end + val.strip + end + end + + class ExecItem < Item + def initialize(name, selection, desc, &block) + super name, selection, nil, desc + @ok = selection.split('/') + @action = block + end + + def config_type + 'exec' + end + + def value? + false + end + + def resolve(table) + setup_rb_error "$#{name()} wrongly used as option value" + end + + undef set + + def evaluate(val, table) + v = val.strip.downcase + unless @ok.include?(v) + setup_rb_error "invalid option --#{@name}=#{val} (use #{@template})" + end + @action.call v, table + end + end + + class PackageSelectionItem < Item + def initialize(name, template, default, help_default, desc) + super name, template, default, desc + @help_default = help_default + end + + attr_reader :help_default + + def config_type + 'package' + end + + private + + def check(val) + unless File.dir?("packages/#{val}") + setup_rb_error "config: no such package: #{val}" + end + val + end + end + + class MetaConfigEnvironment + def initialize(config, installer) + @config = config + @installer = installer + end + + def config_names + @config.names + end + + def config?(name) + @config.key?(name) + end + + def bool_config?(name) + @config.lookup(name).config_type == 'bool' + end + + def path_config?(name) + @config.lookup(name).config_type == 'path' + end + + def value_config?(name) + @config.lookup(name).config_type != 'exec' + end + + def add_config(item) + @config.add item + end + + def add_bool_config(name, default, desc) + @config.add BoolItem.new(name, 'yes/no', default ? 'yes' : 'no', desc) + end + + def add_path_config(name, default, desc) + @config.add PathItem.new(name, 'path', default, desc) + end + + def set_config_default(name, default) + @config.lookup(name).default = default + end + + def remove_config(name) + @config.remove(name) + end + + # For only multipackage + def packages + raise '[setup.rb fatal] multi-package metaconfig API packages() called for single-package; contact application package vendor' unless @installer + @installer.packages + end + + # For only multipackage + def declare_packages(list) + raise '[setup.rb fatal] multi-package metaconfig API declare_packages() called for single-package; contact application package vendor' unless @installer + @installer.packages = list + end + end + +end # class ConfigTable + + +# This module requires: #verbose?, #no_harm? +module FileOperations + + def mkdir_p(dirname, prefix = nil) + dirname = prefix + File.expand_path(dirname) if prefix + $stderr.puts "mkdir -p #{dirname}" if verbose? + return if no_harm? + + # Does not check '/', it's too abnormal. + dirs = File.expand_path(dirname).split(%r<(?=/)>) + if /\A[a-z]:\z/i =~ dirs[0] + disk = dirs.shift + dirs[0] = disk + dirs[0] + end + dirs.each_index do |idx| + path = dirs[0..idx].join('') + Dir.mkdir path unless File.dir?(path) + end + end + + def rm_f(path) + $stderr.puts "rm -f #{path}" if verbose? + return if no_harm? + force_remove_file path + end + + def rm_rf(path) + $stderr.puts "rm -rf #{path}" if verbose? + return if no_harm? + remove_tree path + end + + def remove_tree(path) + if File.symlink?(path) + remove_file path + elsif File.dir?(path) + remove_tree0 path + else + force_remove_file path + end + end + + def remove_tree0(path) + Dir.foreach(path) do |ent| + next if ent == '.' + next if ent == '..' + entpath = "#{path}/#{ent}" + if File.symlink?(entpath) + remove_file entpath + elsif File.dir?(entpath) + remove_tree0 entpath + else + force_remove_file entpath + end + end + begin + Dir.rmdir path + rescue Errno::ENOTEMPTY + # directory may not be empty + end + end + + def move_file(src, dest) + force_remove_file dest + begin + File.rename src, dest + rescue + File.open(dest, 'wb') {|f| + f.write File.binread(src) + } + File.chmod File.stat(src).mode, dest + File.unlink src + end + end + + def force_remove_file(path) + begin + remove_file path + rescue + end + end + + def remove_file(path) + File.chmod 0777, path + File.unlink path + end + + def install(from, dest, mode, prefix = nil) + $stderr.puts "install #{from} #{dest}" if verbose? + return if no_harm? + + realdest = prefix ? prefix + File.expand_path(dest) : dest + realdest = File.join(realdest, File.basename(from)) if File.dir?(realdest) + str = File.binread(from) + if diff?(str, realdest) + verbose_off { + rm_f realdest if File.exist?(realdest) + } + File.open(realdest, 'wb') {|f| + f.write str + } + File.chmod mode, realdest + + File.open("#{objdir_root()}/InstalledFiles", 'a') {|f| + if prefix + f.puts realdest.sub(prefix, '') + else + f.puts realdest + end + } + end + end + + def diff?(new_content, path) + return true unless File.exist?(path) + new_content != File.binread(path) + end + + def command(*args) + $stderr.puts args.join(' ') if verbose? + system(*args) or raise RuntimeError, + "system(#{args.map{|a| a.inspect }.join(' ')}) failed" + end + + def ruby(*args) + command config('rubyprog'), *args + end + + def make(task = nil) + command(*[config('makeprog'), task].compact) + end + + def extdir?(dir) + File.exist?("#{dir}/MANIFEST") or File.exist?("#{dir}/extconf.rb") + end + + def files_of(dir) + Dir.open(dir) {|d| + return d.select {|ent| File.file?("#{dir}/#{ent}") } + } + end + + DIR_REJECT = %w( . .. CVS SCCS RCS CVS.adm .svn ) + + def directories_of(dir) + Dir.open(dir) {|d| + return d.select {|ent| File.dir?("#{dir}/#{ent}") } - DIR_REJECT + } + end + +end + + +# This module requires: #srcdir_root, #objdir_root, #relpath +module HookScriptAPI + + def get_config(key) + @config[key] + end + + alias config get_config + + # obsolete: use metaconfig to change configuration + def set_config(key, val) + @config[key] = val + end + + # + # srcdir/objdir (works only in the package directory) + # + + def curr_srcdir + "#{srcdir_root()}/#{relpath()}" + end + + def curr_objdir + "#{objdir_root()}/#{relpath()}" + end + + def srcfile(path) + "#{curr_srcdir()}/#{path}" + end + + def srcexist?(path) + File.exist?(srcfile(path)) + end + + def srcdirectory?(path) + File.dir?(srcfile(path)) + end + + def srcfile?(path) + File.file?(srcfile(path)) + end + + def srcentries(path = '.') + Dir.open("#{curr_srcdir()}/#{path}") {|d| + return d.to_a - %w(. ..) + } + end + + def srcfiles(path = '.') + srcentries(path).select {|fname| + File.file?(File.join(curr_srcdir(), path, fname)) + } + end + + def srcdirectories(path = '.') + srcentries(path).select {|fname| + File.dir?(File.join(curr_srcdir(), path, fname)) + } + end + +end + + +class ToplevelInstaller + + Version = '3.4.1' + Copyright = 'Copyright (c) 2000-2005 Minero Aoki' + + TASKS = [ + [ 'all', 'do config, setup, then install' ], + [ 'config', 'saves your configurations' ], + [ 'show', 'shows current configuration' ], + [ 'setup', 'compiles ruby extentions and others' ], + [ 'install', 'installs files' ], + [ 'test', 'run all tests in test/' ], + [ 'clean', "does `make clean' for each extention" ], + [ 'distclean',"does `make distclean' for each extention" ] + ] + + def ToplevelInstaller.invoke + config = ConfigTable.new(load_rbconfig()) + config.load_standard_entries + config.load_multipackage_entries if multipackage? + config.fixup + klass = (multipackage?() ? ToplevelInstallerMulti : ToplevelInstaller) + klass.new(File.dirname($0), config).invoke + end + + def ToplevelInstaller.multipackage? + File.dir?(File.dirname($0) + '/packages') + end + + def ToplevelInstaller.load_rbconfig + if arg = ARGV.detect {|arg| /\A--rbconfig=/ =~ arg } + ARGV.delete(arg) + load File.expand_path(arg.split(/=/, 2)[1]) + $".push 'rbconfig.rb' + else + require 'rbconfig' + end + ::Config::CONFIG + end + + def initialize(ardir_root, config) + @ardir = File.expand_path(ardir_root) + @config = config + # cache + @valid_task_re = nil + end + + def config(key) + @config[key] + end + + def inspect + "#<#{self.class} #{__id__()}>" + end + + def invoke + run_metaconfigs + case task = parsearg_global() + when nil, 'all' + parsearg_config + init_installers + exec_config + exec_setup + exec_install + else + case task + when 'config', 'test' + ; + when 'clean', 'distclean' + @config.load_savefile if File.exist?(@config.savefile) + else + @config.load_savefile + end + __send__ "parsearg_#{task}" + init_installers + __send__ "exec_#{task}" + end + end + + def run_metaconfigs + @config.load_script "#{@ardir}/metaconfig" + end + + def init_installers + @installer = Installer.new(@config, @ardir, File.expand_path('.')) + end + + # + # Hook Script API bases + # + + def srcdir_root + @ardir + end + + def objdir_root + '.' + end + + def relpath + '.' + end + + # + # Option Parsing + # + + def parsearg_global + while arg = ARGV.shift + case arg + when /\A\w+\z/ + setup_rb_error "invalid task: #{arg}" unless valid_task?(arg) + return arg + when '-q', '--quiet' + @config.verbose = false + when '--verbose' + @config.verbose = true + when '--help' + print_usage $stdout + exit 0 + when '--version' + puts "#{File.basename($0)} version #{Version}" + exit 0 + when '--copyright' + puts Copyright + exit 0 + else + setup_rb_error "unknown global option '#{arg}'" + end + end + nil + end + + def valid_task?(t) + valid_task_re() =~ t + end + + def valid_task_re + @valid_task_re ||= /\A(?:#{TASKS.map {|task,desc| task }.join('|')})\z/ + end + + def parsearg_no_options + unless ARGV.empty? + task = caller(0).first.slice(%r<`parsearg_(\w+)'>, 1) + setup_rb_error "#{task}: unknown options: #{ARGV.join(' ')}" + end + end + + alias parsearg_show parsearg_no_options + alias parsearg_setup parsearg_no_options + alias parsearg_test parsearg_no_options + alias parsearg_clean parsearg_no_options + alias parsearg_distclean parsearg_no_options + + def parsearg_config + evalopt = [] + set = [] + @config.config_opt = [] + while i = ARGV.shift + if /\A--?\z/ =~ i + @config.config_opt = ARGV.dup + break + end + name, value = *@config.parse_opt(i) + if @config.value_config?(name) + @config[name] = value + else + evalopt.push [name, value] + end + set.push name + end + evalopt.each do |name, value| + @config.lookup(name).evaluate value, @config + end + # Check if configuration is valid + set.each do |n| + @config[n] if @config.value_config?(n) + end + end + + def parsearg_install + @config.no_harm = false + @config.install_prefix = '' + while a = ARGV.shift + case a + when '--no-harm' + @config.no_harm = true + when /\A--prefix=/ + path = a.split(/=/, 2)[1] + path = File.expand_path(path) unless path[0,1] == '/' + @config.install_prefix = path + else + setup_rb_error "install: unknown option #{a}" + end + end + end + + def print_usage(out) + out.puts 'Typical Installation Procedure:' + out.puts " $ ruby #{File.basename $0} config" + out.puts " $ ruby #{File.basename $0} setup" + out.puts " # ruby #{File.basename $0} install (may require root privilege)" + out.puts + out.puts 'Detailed Usage:' + out.puts " ruby #{File.basename $0} " + out.puts " ruby #{File.basename $0} [] []" + + fmt = " %-24s %s\n" + out.puts + out.puts 'Global options:' + out.printf fmt, '-q,--quiet', 'suppress message outputs' + out.printf fmt, ' --verbose', 'output messages verbosely' + out.printf fmt, ' --help', 'print this message' + out.printf fmt, ' --version', 'print version and quit' + out.printf fmt, ' --copyright', 'print copyright and quit' + out.puts + out.puts 'Tasks:' + TASKS.each do |name, desc| + out.printf fmt, name, desc + end + + fmt = " %-24s %s [%s]\n" + out.puts + out.puts 'Options for CONFIG or ALL:' + @config.each do |item| + out.printf fmt, item.help_opt, item.description, item.help_default + end + out.printf fmt, '--rbconfig=path', 'rbconfig.rb to load',"running ruby's" + out.puts + out.puts 'Options for INSTALL:' + out.printf fmt, '--no-harm', 'only display what to do if given', 'off' + out.printf fmt, '--prefix=path', 'install path prefix', '' + out.puts + end + + # + # Task Handlers + # + + def exec_config + @installer.exec_config + @config.save # must be final + end + + def exec_setup + @installer.exec_setup + end + + def exec_install + @installer.exec_install + end + + def exec_test + @installer.exec_test + end + + def exec_show + @config.each do |i| + printf "%-20s %s\n", i.name, i.value if i.value? + end + end + + def exec_clean + @installer.exec_clean + end + + def exec_distclean + @installer.exec_distclean + end + +end # class ToplevelInstaller + + +class ToplevelInstallerMulti < ToplevelInstaller + + include FileOperations + + def initialize(ardir_root, config) + super + @packages = directories_of("#{@ardir}/packages") + raise 'no package exists' if @packages.empty? + @root_installer = Installer.new(@config, @ardir, File.expand_path('.')) + end + + def run_metaconfigs + @config.load_script "#{@ardir}/metaconfig", self + @packages.each do |name| + @config.load_script "#{@ardir}/packages/#{name}/metaconfig" + end + end + + attr_reader :packages + + def packages=(list) + raise 'package list is empty' if list.empty? + list.each do |name| + raise "directory packages/#{name} does not exist"\ + unless File.dir?("#{@ardir}/packages/#{name}") + end + @packages = list + end + + def init_installers + @installers = {} + @packages.each do |pack| + @installers[pack] = Installer.new(@config, + "#{@ardir}/packages/#{pack}", + "packages/#{pack}") + end + with = extract_selection(config('with')) + without = extract_selection(config('without')) + @selected = @installers.keys.select {|name| + (with.empty? or with.include?(name)) \ + and not without.include?(name) + } + end + + def extract_selection(list) + a = list.split(/,/) + a.each do |name| + setup_rb_error "no such package: #{name}" unless @installers.key?(name) + end + a + end + + def print_usage(f) + super + f.puts 'Inluded packages:' + f.puts ' ' + @packages.sort.join(' ') + f.puts + end + + # + # Task Handlers + # + + def exec_config + run_hook 'pre-config' + each_selected_installers {|inst| inst.exec_config } + run_hook 'post-config' + @config.save # must be final + end + + def exec_setup + run_hook 'pre-setup' + each_selected_installers {|inst| inst.exec_setup } + run_hook 'post-setup' + end + + def exec_install + run_hook 'pre-install' + each_selected_installers {|inst| inst.exec_install } + run_hook 'post-install' + end + + def exec_test + run_hook 'pre-test' + each_selected_installers {|inst| inst.exec_test } + run_hook 'post-test' + end + + def exec_clean + rm_f @config.savefile + run_hook 'pre-clean' + each_selected_installers {|inst| inst.exec_clean } + run_hook 'post-clean' + end + + def exec_distclean + rm_f @config.savefile + run_hook 'pre-distclean' + each_selected_installers {|inst| inst.exec_distclean } + run_hook 'post-distclean' + end + + # + # lib + # + + def each_selected_installers + Dir.mkdir 'packages' unless File.dir?('packages') + @selected.each do |pack| + $stderr.puts "Processing the package `#{pack}' ..." if verbose? + Dir.mkdir "packages/#{pack}" unless File.dir?("packages/#{pack}") + Dir.chdir "packages/#{pack}" + yield @installers[pack] + Dir.chdir '../..' + end + end + + def run_hook(id) + @root_installer.run_hook id + end + + # module FileOperations requires this + def verbose? + @config.verbose? + end + + # module FileOperations requires this + def no_harm? + @config.no_harm? + end + +end # class ToplevelInstallerMulti + + +class Installer + + FILETYPES = %w( bin lib ext data conf man ) + + include FileOperations + include HookScriptAPI + + def initialize(config, srcroot, objroot) + @config = config + @srcdir = File.expand_path(srcroot) + @objdir = File.expand_path(objroot) + @currdir = '.' + end + + def inspect + "#<#{self.class} #{File.basename(@srcdir)}>" + end + + def noop(rel) + end + + # + # Hook Script API base methods + # + + def srcdir_root + @srcdir + end + + def objdir_root + @objdir + end + + def relpath + @currdir + end + + # + # Config Access + # + + # module FileOperations requires this + def verbose? + @config.verbose? + end + + # module FileOperations requires this + def no_harm? + @config.no_harm? + end + + def verbose_off + begin + save, @config.verbose = @config.verbose?, false + yield + ensure + @config.verbose = save + end + end + + # + # TASK config + # + + def exec_config + exec_task_traverse 'config' + end + + alias config_dir_bin noop + alias config_dir_lib noop + + def config_dir_ext(rel) + extconf if extdir?(curr_srcdir()) + end + + alias config_dir_data noop + alias config_dir_conf noop + alias config_dir_man noop + + def extconf + ruby "#{curr_srcdir()}/extconf.rb", *@config.config_opt + end + + # + # TASK setup + # + + def exec_setup + exec_task_traverse 'setup' + end + + def setup_dir_bin(rel) + files_of(curr_srcdir()).each do |fname| + update_shebang_line "#{curr_srcdir()}/#{fname}" + end + end + + alias setup_dir_lib noop + + def setup_dir_ext(rel) + make if extdir?(curr_srcdir()) + end + + alias setup_dir_data noop + alias setup_dir_conf noop + alias setup_dir_man noop + + def update_shebang_line(path) + return if no_harm? + return if config('shebang') == 'never' + old = Shebang.load(path) + if old + $stderr.puts "warning: #{path}: Shebang line includes too many args. It is not portable and your program may not work." if old.args.size > 1 + new = new_shebang(old) + return if new.to_s == old.to_s + else + return unless config('shebang') == 'all' + new = Shebang.new(config('rubypath')) + end + $stderr.puts "updating shebang: #{File.basename(path)}" if verbose? + open_atomic_writer(path) {|output| + File.open(path, 'rb') {|f| + f.gets if old # discard + output.puts new.to_s + output.print f.read + } + } + end + + def new_shebang(old) + if /\Aruby/ =~ File.basename(old.cmd) + Shebang.new(config('rubypath'), old.args) + elsif File.basename(old.cmd) == 'env' and old.args.first == 'ruby' + Shebang.new(config('rubypath'), old.args[1..-1]) + else + return old unless config('shebang') == 'all' + Shebang.new(config('rubypath')) + end + end + + def open_atomic_writer(path, &block) + tmpfile = File.basename(path) + '.tmp' + begin + File.open(tmpfile, 'wb', &block) + File.rename tmpfile, File.basename(path) + ensure + File.unlink tmpfile if File.exist?(tmpfile) + end + end + + class Shebang + def Shebang.load(path) + line = nil + File.open(path) {|f| + line = f.gets + } + return nil unless /\A#!/ =~ line + parse(line) + end + + def Shebang.parse(line) + cmd, *args = *line.strip.sub(/\A\#!/, '').split(' ') + new(cmd, args) + end + + def initialize(cmd, args = []) + @cmd = cmd + @args = args + end + + attr_reader :cmd + attr_reader :args + + def to_s + "#! #{@cmd}" + (@args.empty? ? '' : " #{@args.join(' ')}") + end + end + + # + # TASK install + # + + def exec_install + rm_f 'InstalledFiles' + exec_task_traverse 'install' + end + + def install_dir_bin(rel) + install_files targetfiles(), "#{config('bindir')}/#{rel}", 0755 + end + + def install_dir_lib(rel) + install_files libfiles(), "#{config('rbdir')}/#{rel}", 0644 + end + + def install_dir_ext(rel) + return unless extdir?(curr_srcdir()) + install_files rubyextentions('.'), + "#{config('sodir')}/#{File.dirname(rel)}", + 0555 + end + + def install_dir_data(rel) + install_files targetfiles(), "#{config('datadir')}/#{rel}", 0644 + end + + def install_dir_conf(rel) + # FIXME: should not remove current config files + # (rename previous file to .old/.org) + install_files targetfiles(), "#{config('sysconfdir')}/#{rel}", 0644 + end + + def install_dir_man(rel) + install_files targetfiles(), "#{config('mandir')}/#{rel}", 0644 + end + + def install_files(list, dest, mode) + mkdir_p dest, @config.install_prefix + list.each do |fname| + install fname, dest, mode, @config.install_prefix + end + end + + def libfiles + glob_reject(%w(*.y *.output), targetfiles()) + end + + def rubyextentions(dir) + ents = glob_select("*.#{@config.dllext}", targetfiles()) + if ents.empty? + setup_rb_error "no ruby extention exists: 'ruby #{$0} setup' first" + end + ents + end + + def targetfiles + mapdir(existfiles() - hookfiles()) + end + + def mapdir(ents) + ents.map {|ent| + if File.exist?(ent) + then ent # objdir + else "#{curr_srcdir()}/#{ent}" # srcdir + end + } + end + + # picked up many entries from cvs-1.11.1/src/ignore.c + JUNK_FILES = %w( + core RCSLOG tags TAGS .make.state + .nse_depinfo #* .#* cvslog.* ,* .del-* *.olb + *~ *.old *.bak *.BAK *.orig *.rej _$* *$ + + *.org *.in .* + ) + + def existfiles + glob_reject(JUNK_FILES, (files_of(curr_srcdir()) | files_of('.'))) + end + + def hookfiles + %w( pre-%s post-%s pre-%s.rb post-%s.rb ).map {|fmt| + %w( config setup install clean ).map {|t| sprintf(fmt, t) } + }.flatten + end + + def glob_select(pat, ents) + re = globs2re([pat]) + ents.select {|ent| re =~ ent } + end + + def glob_reject(pats, ents) + re = globs2re(pats) + ents.reject {|ent| re =~ ent } + end + + GLOB2REGEX = { + '.' => '\.', + '$' => '\$', + '#' => '\#', + '*' => '.*' + } + + def globs2re(pats) + /\A(?:#{ + pats.map {|pat| pat.gsub(/[\.\$\#\*]/) {|ch| GLOB2REGEX[ch] } }.join('|') + })\z/ + end + + # + # TASK test + # + + TESTDIR = 'test' + + def exec_test + unless File.directory?('test') + $stderr.puts 'no test in this package' if verbose? + return + end + $stderr.puts 'Running tests...' if verbose? + begin + require 'test/unit' + rescue LoadError + setup_rb_error 'test/unit cannot loaded. You need Ruby 1.8 or later to invoke this task.' + end + runner = Test::Unit::AutoRunner.new(true) + runner.to_run << TESTDIR + runner.run + end + + # + # TASK clean + # + + def exec_clean + exec_task_traverse 'clean' + rm_f @config.savefile + rm_f 'InstalledFiles' + end + + alias clean_dir_bin noop + alias clean_dir_lib noop + alias clean_dir_data noop + alias clean_dir_conf noop + alias clean_dir_man noop + + def clean_dir_ext(rel) + return unless extdir?(curr_srcdir()) + make 'clean' if File.file?('Makefile') + end + + # + # TASK distclean + # + + def exec_distclean + exec_task_traverse 'distclean' + rm_f @config.savefile + rm_f 'InstalledFiles' + end + + alias distclean_dir_bin noop + alias distclean_dir_lib noop + + def distclean_dir_ext(rel) + return unless extdir?(curr_srcdir()) + make 'distclean' if File.file?('Makefile') + end + + alias distclean_dir_data noop + alias distclean_dir_conf noop + alias distclean_dir_man noop + + # + # Traversing + # + + def exec_task_traverse(task) + run_hook "pre-#{task}" + FILETYPES.each do |type| + if type == 'ext' and config('without-ext') == 'yes' + $stderr.puts 'skipping ext/* by user option' if verbose? + next + end + traverse task, type, "#{task}_dir_#{type}" + end + run_hook "post-#{task}" + end + + def traverse(task, rel, mid) + dive_into(rel) { + run_hook "pre-#{task}" + __send__ mid, rel.sub(%r[\A.*?(?:/|\z)], '') + directories_of(curr_srcdir()).each do |d| + traverse task, "#{rel}/#{d}", mid + end + run_hook "post-#{task}" + } + end + + def dive_into(rel) + return unless File.dir?("#{@srcdir}/#{rel}") + + dir = File.basename(rel) + Dir.mkdir dir unless File.dir?(dir) + prevdir = Dir.pwd + Dir.chdir dir + $stderr.puts '---> ' + rel if verbose? + @currdir = rel + yield + Dir.chdir prevdir + $stderr.puts '<--- ' + rel if verbose? + @currdir = File.dirname(rel) + end + + def run_hook(id) + path = [ "#{curr_srcdir()}/#{id}", + "#{curr_srcdir()}/#{id}.rb" ].detect {|cand| File.file?(cand) } + return unless path + begin + instance_eval File.read(path), path, 1 + rescue + raise if $DEBUG + setup_rb_error "hook #{path} failed:\n" + $!.message + end + end + +end # class Installer + + +class SetupError < StandardError; end + +def setup_rb_error(msg) + raise SetupError, msg +end + +if $0 == __FILE__ + begin + ToplevelInstaller.invoke + rescue SetupError + raise if $DEBUG + $stderr.puts $!.message + $stderr.puts "Try 'ruby #{$0} --help' for detailed usage." + exit 1 + end +end diff --git a/lib/rb/spec/ThriftSpec.thrift b/lib/rb/spec/ThriftSpec.thrift new file mode 100644 index 00000000..fe5a8aae --- /dev/null +++ b/lib/rb/spec/ThriftSpec.thrift @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +namespace rb SpecNamespace + +struct Hello { + 1: string greeting = "hello world" +} + +struct Foo { + 1: i32 simple = 53, + 2: string words = "words", + 3: Hello hello = {'greeting' : "hello, world!"}, + 4: list ints = [1, 2, 2, 3], + 5: map> complex, + 6: set shorts = [5, 17, 239], + 7: optional string opt_string +} + +struct BoolStruct { + 1: bool yesno = 1 +} + +struct SimpleList { + 1: list bools, + 2: list bytes, + 3: list i16s, + 4: list i32s, + 5: list i64s, + 6: list doubles, + 7: list strings, + 8: list> maps, + 9: list> lists, + 10: list> sets, + 11: list hellos +} + +exception Xception { + 1: string message, + 2: i32 code = 1 +} + +service NonblockingService { + Hello greeting(1:bool english) + bool block() + oneway void unblock(1:i32 n) + oneway void shutdown() + void sleep(1:double seconds) +} diff --git a/lib/rb/spec/base_protocol_spec.rb b/lib/rb/spec/base_protocol_spec.rb new file mode 100644 index 00000000..efb16d8c --- /dev/null +++ b/lib/rb/spec/base_protocol_spec.rb @@ -0,0 +1,160 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +class ThriftBaseProtocolSpec < Spec::ExampleGroup + include Thrift + + before(:each) do + @trans = mock("MockTransport") + @prot = BaseProtocol.new(@trans) + end + + describe BaseProtocol do + # most of the methods are stubs, so we can ignore them + + it "should make trans accessible" do + @prot.trans.should eql(@trans) + end + + it "should write out a field nicely" do + @prot.should_receive(:write_field_begin).with('field', 'type', 'fid').ordered + @prot.should_receive(:write_type).with('type', 'value').ordered + @prot.should_receive(:write_field_end).ordered + @prot.write_field('field', 'type', 'fid', 'value') + end + + it "should write out the different types" do + @prot.should_receive(:write_bool).with('bool').ordered + @prot.should_receive(:write_byte).with('byte').ordered + @prot.should_receive(:write_double).with('double').ordered + @prot.should_receive(:write_i16).with('i16').ordered + @prot.should_receive(:write_i32).with('i32').ordered + @prot.should_receive(:write_i64).with('i64').ordered + @prot.should_receive(:write_string).with('string').ordered + struct = mock('Struct') + struct.should_receive(:write).with(@prot).ordered + @prot.write_type(Types::BOOL, 'bool') + @prot.write_type(Types::BYTE, 'byte') + @prot.write_type(Types::DOUBLE, 'double') + @prot.write_type(Types::I16, 'i16') + @prot.write_type(Types::I32, 'i32') + @prot.write_type(Types::I64, 'i64') + @prot.write_type(Types::STRING, 'string') + @prot.write_type(Types::STRUCT, struct) + # all other types are not implemented + [Types::STOP, Types::VOID, Types::MAP, Types::SET, Types::LIST].each do |type| + lambda { @prot.write_type(type, type.to_s) }.should raise_error(NotImplementedError) + end + end + + it "should read the different types" do + @prot.should_receive(:read_bool).ordered + @prot.should_receive(:read_byte).ordered + @prot.should_receive(:read_i16).ordered + @prot.should_receive(:read_i32).ordered + @prot.should_receive(:read_i64).ordered + @prot.should_receive(:read_double).ordered + @prot.should_receive(:read_string).ordered + @prot.read_type(Types::BOOL) + @prot.read_type(Types::BYTE) + @prot.read_type(Types::I16) + @prot.read_type(Types::I32) + @prot.read_type(Types::I64) + @prot.read_type(Types::DOUBLE) + @prot.read_type(Types::STRING) + # all other types are not implemented + [Types::STOP, Types::VOID, Types::MAP, Types::SET, Types::LIST].each do |type| + lambda { @prot.read_type(type) }.should raise_error(NotImplementedError) + end + end + + it "should skip the basic types" do + @prot.should_receive(:read_bool).ordered + @prot.should_receive(:read_byte).ordered + @prot.should_receive(:read_i16).ordered + @prot.should_receive(:read_i32).ordered + @prot.should_receive(:read_i64).ordered + @prot.should_receive(:read_double).ordered + @prot.should_receive(:read_string).ordered + @prot.skip(Types::BOOL) + @prot.skip(Types::BYTE) + @prot.skip(Types::I16) + @prot.skip(Types::I32) + @prot.skip(Types::I64) + @prot.skip(Types::DOUBLE) + @prot.skip(Types::STRING) + @prot.skip(Types::STOP) # should do absolutely nothing + end + + it "should skip structs" do + real_skip = @prot.method(:skip) + @prot.should_receive(:read_struct_begin).ordered + @prot.should_receive(:read_field_begin).exactly(4).times.and_return( + ['field 1', Types::STRING, 1], + ['field 2', Types::I32, 2], + ['field 3', Types::MAP, 3], + [nil, Types::STOP, 0] + ) + @prot.should_receive(:read_field_end).exactly(3).times + @prot.should_receive(:read_string).exactly(3).times + @prot.should_receive(:read_i32).ordered + @prot.should_receive(:read_map_begin).ordered.and_return([Types::STRING, Types::STRING, 1]) + # @prot.should_receive(:read_string).exactly(2).times + @prot.should_receive(:read_map_end).ordered + @prot.should_receive(:read_struct_end).ordered + real_skip.call(Types::STRUCT) + end + + it "should skip maps" do + real_skip = @prot.method(:skip) + @prot.should_receive(:read_map_begin).ordered.and_return([Types::STRING, Types::STRUCT, 1]) + @prot.should_receive(:read_string).ordered + @prot.should_receive(:read_struct_begin).ordered.and_return(["some_struct"]) + @prot.should_receive(:read_field_begin).ordered.and_return([nil, Types::STOP, nil]); + @prot.should_receive(:read_struct_end).ordered + @prot.should_receive(:read_map_end).ordered + real_skip.call(Types::MAP) + end + + it "should skip sets" do + real_skip = @prot.method(:skip) + @prot.should_receive(:read_set_begin).ordered.and_return([Types::I64, 9]) + @prot.should_receive(:read_i64).ordered.exactly(9).times + @prot.should_receive(:read_set_end) + real_skip.call(Types::SET) + end + + it "should skip lists" do + real_skip = @prot.method(:skip) + @prot.should_receive(:read_list_begin).ordered.and_return([Types::DOUBLE, 11]) + @prot.should_receive(:read_double).ordered.exactly(11).times + @prot.should_receive(:read_list_end) + real_skip.call(Types::LIST) + end + end + + describe BaseProtocolFactory do + it "should raise NotImplementedError" do + # returning nil since Protocol is just an abstract class + lambda {BaseProtocolFactory.new.get_protocol(mock("MockTransport"))}.should raise_error(NotImplementedError) + end + end +end diff --git a/lib/rb/spec/base_transport_spec.rb b/lib/rb/spec/base_transport_spec.rb new file mode 100644 index 00000000..71897759 --- /dev/null +++ b/lib/rb/spec/base_transport_spec.rb @@ -0,0 +1,344 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +class ThriftBaseTransportSpec < Spec::ExampleGroup + include Thrift + + describe TransportException do + it "should make type accessible" do + exc = TransportException.new(TransportException::ALREADY_OPEN, "msg") + exc.type.should == TransportException::ALREADY_OPEN + exc.message.should == "msg" + end + end + + describe BaseTransport do + it "should read the specified size" do + transport = BaseTransport.new + transport.should_receive(:read).with(40).ordered.and_return("10 letters") + transport.should_receive(:read).with(30).ordered.and_return("fifteen letters") + transport.should_receive(:read).with(15).ordered.and_return("more characters") + transport.read_all(40).should == "10 lettersfifteen lettersmore characters" + end + + it "should stub out the rest of the methods" do + # can't test for stubbiness, so just make sure they're defined + [:open?, :open, :close, :read, :write, :flush].each do |sym| + BaseTransport.method_defined?(sym).should be_true + end + end + + it "should alias << to write" do + BaseTransport.instance_method(:<<).should == BaseTransport.instance_method(:write) + end + end + + describe BaseServerTransport do + it "should stub out its methods" do + [:listen, :accept, :close].each do |sym| + BaseServerTransport.method_defined?(sym).should be_true + end + end + end + + describe BaseTransportFactory do + it "should return the transport it's given" do + transport = mock("Transport") + BaseTransportFactory.new.get_transport(transport).should eql(transport) + end + end + + describe BufferedTransport do + it "should pass through everything but write/flush/read" do + trans = mock("Transport") + trans.should_receive(:open?).ordered.and_return("+ open?") + trans.should_receive(:open).ordered.and_return("+ open") + trans.should_receive(:flush).ordered # from the close + trans.should_receive(:close).ordered.and_return("+ close") + btrans = BufferedTransport.new(trans) + btrans.open?.should == "+ open?" + btrans.open.should == "+ open" + btrans.close.should == "+ close" + end + + it "should buffer reads in chunks of #{BufferedTransport::DEFAULT_BUFFER}" do + trans = mock("Transport") + trans.should_receive(:read).with(BufferedTransport::DEFAULT_BUFFER).and_return("lorum ipsum dolor emet") + btrans = BufferedTransport.new(trans) + btrans.read(6).should == "lorum " + btrans.read(6).should == "ipsum " + btrans.read(6).should == "dolor " + btrans.read(6).should == "emet" + end + + it "should buffer writes and send them on flush" do + trans = mock("Transport") + btrans = BufferedTransport.new(trans) + btrans.write("one/") + btrans.write("two/") + btrans.write("three/") + trans.should_receive(:write).with("one/two/three/").ordered + trans.should_receive(:flush).ordered + btrans.flush + end + + it "should only send buffered data once" do + trans = mock("Transport") + btrans = BufferedTransport.new(trans) + btrans.write("one/") + btrans.write("two/") + btrans.write("three/") + trans.should_receive(:write).with("one/two/three/") + trans.stub!(:flush) + btrans.flush + # Nothing to flush with no data + btrans.flush + end + + it "should flush on close" do + trans = mock("Transport") + trans.should_receive(:close) + btrans = BufferedTransport.new(trans) + btrans.should_receive(:flush) + btrans.close + end + + it "should not write to socket if there's no data" do + trans = mock("Transport") + trans.should_receive(:flush) + btrans = BufferedTransport.new(trans) + btrans.flush + end + end + + describe BufferedTransportFactory do + it "should wrap the given transport in a BufferedTransport" do + trans = mock("Transport") + btrans = mock("BufferedTransport") + BufferedTransport.should_receive(:new).with(trans).and_return(btrans) + BufferedTransportFactory.new.get_transport(trans).should == btrans + end + end + + describe FramedTransport do + before(:each) do + @trans = mock("Transport") + end + + it "should pass through open?/open/close" do + ftrans = FramedTransport.new(@trans) + @trans.should_receive(:open?).ordered.and_return("+ open?") + @trans.should_receive(:open).ordered.and_return("+ open") + @trans.should_receive(:close).ordered.and_return("+ close") + ftrans.open?.should == "+ open?" + ftrans.open.should == "+ open" + ftrans.close.should == "+ close" + end + + it "should pass through read when read is turned off" do + ftrans = FramedTransport.new(@trans, false, true) + @trans.should_receive(:read).with(17).ordered.and_return("+ read") + ftrans.read(17).should == "+ read" + end + + it "should pass through write/flush when write is turned off" do + ftrans = FramedTransport.new(@trans, true, false) + @trans.should_receive(:write).with("foo").ordered.and_return("+ write") + @trans.should_receive(:flush).ordered.and_return("+ flush") + ftrans.write("foo").should == "+ write" + ftrans.flush.should == "+ flush" + end + + it "should return a full frame if asked for >= the frame's length" do + frame = "this is a frame" + @trans.should_receive(:read_all).with(4).and_return("\000\000\000\017") + @trans.should_receive(:read_all).with(frame.length).and_return(frame) + FramedTransport.new(@trans).read(frame.length + 10).should == frame + end + + it "should return slices of the frame when asked for < the frame's length" do + frame = "this is a frame" + @trans.should_receive(:read_all).with(4).and_return("\000\000\000\017") + @trans.should_receive(:read_all).with(frame.length).and_return(frame) + ftrans = FramedTransport.new(@trans) + ftrans.read(4).should == "this" + ftrans.read(4).should == " is " + ftrans.read(16).should == "a frame" + end + + it "should return nothing if asked for <= 0" do + FramedTransport.new(@trans).read(-2).should == "" + end + + it "should pull a new frame when the first is exhausted" do + frame = "this is a frame" + frame2 = "yet another frame" + @trans.should_receive(:read_all).with(4).and_return("\000\000\000\017", "\000\000\000\021") + @trans.should_receive(:read_all).with(frame.length).and_return(frame) + @trans.should_receive(:read_all).with(frame2.length).and_return(frame2) + ftrans = FramedTransport.new(@trans) + ftrans.read(4).should == "this" + ftrans.read(8).should == " is a fr" + ftrans.read(6).should == "ame" + ftrans.read(4).should == "yet " + ftrans.read(16).should == "another frame" + end + + it "should buffer writes" do + ftrans = FramedTransport.new(@trans) + @trans.should_not_receive(:write) + ftrans.write("foo") + ftrans.write("bar") + ftrans.write("this is a frame") + end + + it "should write slices of the buffer" do + ftrans = FramedTransport.new(@trans) + ftrans.write("foobar", 3) + ftrans.write("barfoo", 1) + @trans.stub!(:flush) + @trans.should_receive(:write).with("\000\000\000\004foob") + ftrans.flush + end + + it "should flush frames with a 4-byte header" do + ftrans = FramedTransport.new(@trans) + @trans.should_receive(:write).with("\000\000\000\035one/two/three/this is a frame").ordered + @trans.should_receive(:flush).ordered + ftrans.write("one/") + ftrans.write("two/") + ftrans.write("three/") + ftrans.write("this is a frame") + ftrans.flush + end + + it "should not flush the same buffered data twice" do + ftrans = FramedTransport.new(@trans) + @trans.should_receive(:write).with("\000\000\000\007foo/bar") + @trans.stub!(:flush) + ftrans.write("foo") + ftrans.write("/bar") + ftrans.flush + @trans.should_receive(:write).with("\000\000\000\000") + ftrans.flush + end + end + + describe FramedTransportFactory do + it "should wrap the given transport in a FramedTransport" do + trans = mock("Transport") + FramedTransport.should_receive(:new).with(trans) + FramedTransportFactory.new.get_transport(trans) + end + end + + describe MemoryBufferTransport do + before(:each) do + @buffer = MemoryBufferTransport.new + end + + it "should accept a buffer on input and use it directly" do + s = "this is a test" + @buffer = MemoryBufferTransport.new(s) + @buffer.read(4).should == "this" + s.slice!(-4..-1) + @buffer.read(@buffer.available).should == " is a " + end + + it "should always remain open" do + @buffer.should be_open + @buffer.close + @buffer.should be_open + end + + it "should respond to peek and available" do + @buffer.write "some data" + @buffer.peek.should be_true + @buffer.available.should == 9 + @buffer.read(4) + @buffer.peek.should be_true + @buffer.available.should == 5 + @buffer.read(16) + @buffer.peek.should be_false + @buffer.available.should == 0 + end + + it "should be able to reset the buffer" do + @buffer.write "test data" + @buffer.reset_buffer("foobar") + @buffer.available.should == 6 + @buffer.read(10).should == "foobar" + @buffer.reset_buffer + @buffer.available.should == 0 + end + + it "should copy the given string whne resetting the buffer" do + s = "this is a test" + @buffer.reset_buffer(s) + @buffer.available.should == 14 + @buffer.read(10) + @buffer.available.should == 4 + s.should == "this is a test" + end + + it "should return from read what was given in write" do + @buffer.write "test data" + @buffer.read(4).should == "test" + @buffer.read(10).should == " data" + @buffer.read(10).should == "" + @buffer.write "foo" + @buffer.write " bar" + @buffer.read(10).should == "foo bar" + end + end + + describe IOStreamTransport do + before(:each) do + @input = mock("Input", :closed? => false) + @output = mock("Output", :closed? => false) + @trans = IOStreamTransport.new(@input, @output) + end + + it "should be open as long as both input or output are open" do + @trans.should be_open + @input.stub!(:closed?).and_return(true) + @trans.should be_open + @input.stub!(:closed?).and_return(false) + @output.stub!(:closed?).and_return(true) + @trans.should be_open + @input.stub!(:closed?).and_return(true) + @trans.should_not be_open + end + + it "should pass through read/write to input/output" do + @input.should_receive(:read).with(17).and_return("+ read") + @output.should_receive(:write).with("foobar").and_return("+ write") + @trans.read(17).should == "+ read" + @trans.write("foobar").should == "+ write" + end + + it "should close both input and output when closed" do + @input.should_receive(:close) + @output.should_receive(:close) + @trans.close + end + end +end diff --git a/lib/rb/spec/binary_protocol_accelerated_spec.rb b/lib/rb/spec/binary_protocol_accelerated_spec.rb new file mode 100644 index 00000000..0306cf5f --- /dev/null +++ b/lib/rb/spec/binary_protocol_accelerated_spec.rb @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' +require File.dirname(__FILE__) + '/binary_protocol_spec_shared' +require File.dirname(__FILE__) + '/gen-rb/thrift_spec_types' + +class ThriftBinaryProtocolAcceleratedSpec < Spec::ExampleGroup + include Thrift + + describe Thrift::BinaryProtocolAccelerated do + # since BinaryProtocolAccelerated should be directly equivalent to + # BinaryProtocol, we don't need any custom specs! + it_should_behave_like 'a binary protocol' + + def protocol_class + BinaryProtocolAccelerated + end + end + + describe BinaryProtocolAcceleratedFactory do + it "should create a BinaryProtocolAccelerated" do + BinaryProtocolAcceleratedFactory.new.get_protocol(mock("MockTransport")).should be_instance_of(BinaryProtocolAccelerated) + end + end +end diff --git a/lib/rb/spec/binary_protocol_spec.rb b/lib/rb/spec/binary_protocol_spec.rb new file mode 100644 index 00000000..0abccb89 --- /dev/null +++ b/lib/rb/spec/binary_protocol_spec.rb @@ -0,0 +1,63 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' +require File.dirname(__FILE__) + '/binary_protocol_spec_shared' + +class ThriftBinaryProtocolSpec < Spec::ExampleGroup + include Thrift + + describe BinaryProtocol do + it_should_behave_like 'a binary protocol' + + def protocol_class + BinaryProtocol + end + + it "should read a message header" do + @trans.should_receive(:read_all).exactly(2).times.and_return( + [protocol_class.const_get(:VERSION_1) | Thrift::MessageTypes::REPLY].pack('N'), + [42].pack('N') + ) + @prot.should_receive(:read_string).and_return('testMessage') + @prot.read_message_begin.should == ['testMessage', Thrift::MessageTypes::REPLY, 42] + end + + it "should raise an exception if the message header has the wrong version" do + @prot.should_receive(:read_i32).and_return(-1) + lambda { @prot.read_message_begin }.should raise_error(Thrift::ProtocolException, 'Missing version identifier') do |e| + e.type == Thrift::ProtocolException::BAD_VERSION + end + end + + it "should raise an exception if the message header does not exist and strict_read is enabled" do + @prot.should_receive(:read_i32).and_return(42) + @prot.should_receive(:strict_read).and_return(true) + lambda { @prot.read_message_begin }.should raise_error(Thrift::ProtocolException, 'No version identifier, old protocol client?') do |e| + e.type == Thrift::ProtocolException::BAD_VERSION + end + end + end + + describe BinaryProtocolFactory do + it "should create a BinaryProtocol" do + BinaryProtocolFactory.new.get_protocol(mock("MockTransport")).should be_instance_of(BinaryProtocol) + end + end +end diff --git a/lib/rb/spec/binary_protocol_spec_shared.rb b/lib/rb/spec/binary_protocol_spec_shared.rb new file mode 100644 index 00000000..c6608e01 --- /dev/null +++ b/lib/rb/spec/binary_protocol_spec_shared.rb @@ -0,0 +1,375 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +shared_examples_for 'a binary protocol' do + before(:each) do + @trans = Thrift::MemoryBufferTransport.new + @prot = protocol_class.new(@trans) + end + + it "should define the proper VERSION_1, VERSION_MASK AND TYPE_MASK" do + protocol_class.const_get(:VERSION_MASK).should == 0xffff0000 + protocol_class.const_get(:VERSION_1).should == 0x80010000 + protocol_class.const_get(:TYPE_MASK).should == 0x000000ff + end + + it "should make strict_read readable" do + @prot.strict_read.should eql(true) + end + + it "should make strict_write readable" do + @prot.strict_write.should eql(true) + end + + it "should write the message header" do + @prot.write_message_begin('testMessage', Thrift::MessageTypes::CALL, 17) + @trans.read(1000).should == [protocol_class.const_get(:VERSION_1) | Thrift::MessageTypes::CALL, "testMessage".size, "testMessage", 17].pack("NNa11N") + end + + it "should write the message header without version when writes are not strict" do + @prot = protocol_class.new(@trans, true, false) # no strict write + @prot.write_message_begin('testMessage', Thrift::MessageTypes::CALL, 17) + @trans.read(1000).should == "\000\000\000\vtestMessage\001\000\000\000\021" + end + + it "should write the message header with a version when writes are strict" do + @prot = protocol_class.new(@trans) # strict write + @prot.write_message_begin('testMessage', Thrift::MessageTypes::CALL, 17) + @trans.read(1000).should == "\200\001\000\001\000\000\000\vtestMessage\000\000\000\021" + end + + + # message footer is a noop + + it "should write the field header" do + @prot.write_field_begin('foo', Thrift::Types::DOUBLE, 3) + @trans.read(1000).should == [Thrift::Types::DOUBLE, 3].pack("cn") + end + + # field footer is a noop + + it "should write the STOP field" do + @prot.write_field_stop + @trans.read(1).should == "\000" + end + + it "should write the map header" do + @prot.write_map_begin(Thrift::Types::STRING, Thrift::Types::LIST, 17) + @trans.read(1000).should == [Thrift::Types::STRING, Thrift::Types::LIST, 17].pack("ccN"); + end + + # map footer is a noop + + it "should write the list header" do + @prot.write_list_begin(Thrift::Types::I16, 42) + @trans.read(1000).should == [Thrift::Types::I16, 42].pack("cN") + end + + # list footer is a noop + + it "should write the set header" do + @prot.write_set_begin(Thrift::Types::I16, 42) + @trans.read(1000).should == [Thrift::Types::I16, 42].pack("cN") + end + + it "should write a bool" do + @prot.write_bool(true) + @prot.write_bool(false) + @trans.read(1000).should == "\001\000" + end + + it "should treat a nil bool as false" do + @prot.write_bool(nil) + @trans.read(1).should == "\000" + end + + it "should write a byte" do + # byte is small enough, let's check -128..127 + (-128..127).each do |i| + @prot.write_byte(i) + @trans.read(1).should == [i].pack('c') + end + # handing it numbers out of signed range should clip + @trans.rspec_verify + (128..255).each do |i| + @prot.write_byte(i) + @trans.read(1).should == [i].pack('c') + end + # and lastly, a Bignum is going to error out + lambda { @prot.write_byte(2**65) }.should raise_error(RangeError) + end + + it "should error gracefully when trying to write a nil byte" do + lambda { @prot.write_byte(nil) }.should raise_error + end + + it "should write an i16" do + # try a random scattering of values + # include the signed i16 minimum/maximum + [-2**15, -1024, 17, 0, -10000, 1723, 2**15-1].each do |i| + @prot.write_i16(i) + end + # and try something out of signed range, it should clip + @prot.write_i16(2**15 + 5) + + @trans.read(1000).should == "\200\000\374\000\000\021\000\000\330\360\006\273\177\377\200\005" + + # a Bignum should error + # lambda { @prot.write_i16(2**65) }.should raise_error(RangeError) + end + + it "should error gracefully when trying to write a nil i16" do + lambda { @prot.write_i16(nil) }.should raise_error + end + + it "should write an i32" do + # try a random scattering of values + # include the signed i32 minimum/maximum + [-2**31, -123123, -2532, -3, 0, 2351235, 12331, 2**31-1].each do |i| + @prot.write_i32(i) + end + # try something out of signed range, it should clip + @trans.read(1000).should == "\200\000\000\000" + "\377\376\037\r" + "\377\377\366\034" + "\377\377\377\375" + "\000\000\000\000" + "\000#\340\203" + "\000\0000+" + "\177\377\377\377" + [2 ** 31 + 5, 2 ** 65 + 5].each do |i| + lambda { @prot.write_i32(i) }.should raise_error(RangeError) + end + end + + it "should error gracefully when trying to write a nil i32" do + lambda { @prot.write_i32(nil) }.should raise_error + end + + it "should write an i64" do + # try a random scattering of values + # try the signed i64 minimum/maximum + [-2**63, -12356123612323, -23512351, -234, 0, 1231, 2351236, 12361236213, 2**63-1].each do |i| + @prot.write_i64(i) + end + # try something out of signed range, it should clip + @trans.read(1000).should == ["\200\000\000\000\000\000\000\000", + "\377\377\364\303\035\244+]", + "\377\377\377\377\376\231:\341", + "\377\377\377\377\377\377\377\026", + "\000\000\000\000\000\000\000\000", + "\000\000\000\000\000\000\004\317", + "\000\000\000\000\000#\340\204", + "\000\000\000\002\340\311~\365", + "\177\377\377\377\377\377\377\377"].join("") + lambda { @prot.write_i64(2 ** 65 + 5) }.should raise_error(RangeError) + end + + it "should error gracefully when trying to write a nil i64" do + lambda { @prot.write_i64(nil) }.should raise_error + end + + it "should write a double" do + # try a random scattering of values, including min/max + values = [Float::MIN,-1231.15325, -123123.23, -23.23515123, 0, 12351.1325, 523.23, Float::MAX] + values.each do |f| + @prot.write_double(f) + @trans.read(1000).should == [f].pack("G") + end + end + + it "should error gracefully when trying to write a nil double" do + lambda { @prot.write_double(nil) }.should raise_error + end + + it "should write a string" do + str = "hello world" + @prot.write_string(str) + @trans.read(1000).should == [str.size].pack("N") + str + end + + it "should error gracefully when trying to write a nil string" do + lambda { @prot.write_string(nil) }.should raise_error + end + + it "should write the message header without version when writes are not strict" do + @prot = protocol_class.new(@trans, true, false) # no strict write + @prot.write_message_begin('testMessage', Thrift::MessageTypes::CALL, 17) + @trans.read(1000).should == "\000\000\000\vtestMessage\001\000\000\000\021" + end + + it "should write the message header with a version when writes are strict" do + @prot = protocol_class.new(@trans) # strict write + @prot.write_message_begin('testMessage', Thrift::MessageTypes::CALL, 17) + @trans.read(1000).should == "\200\001\000\001\000\000\000\vtestMessage\000\000\000\021" + end + + # message footer is a noop + + it "should read a field header" do + @trans.write([Thrift::Types::STRING, 3].pack("cn")) + @prot.read_field_begin.should == [nil, Thrift::Types::STRING, 3] + end + + # field footer is a noop + + it "should read a stop field" do + @trans.write([Thrift::Types::STOP].pack("c")); + @prot.read_field_begin.should == [nil, Thrift::Types::STOP, 0] + end + + it "should read a map header" do + @trans.write([Thrift::Types::DOUBLE, Thrift::Types::I64, 42].pack("ccN")) + @prot.read_map_begin.should == [Thrift::Types::DOUBLE, Thrift::Types::I64, 42] + end + + # map footer is a noop + + it "should read a list header" do + @trans.write([Thrift::Types::STRING, 17].pack("cN")) + @prot.read_list_begin.should == [Thrift::Types::STRING, 17] + end + + # list footer is a noop + + it "should read a set header" do + @trans.write([Thrift::Types::STRING, 17].pack("cN")) + @prot.read_set_begin.should == [Thrift::Types::STRING, 17] + end + + # set footer is a noop + + it "should read a bool" do + @trans.write("\001\000"); + @prot.read_bool.should == true + @prot.read_bool.should == false + end + + it "should read a byte" do + [-128, -57, -3, 0, 17, 24, 127].each do |i| + @trans.write([i].pack("c")) + @prot.read_byte.should == i + end + end + + it "should read an i16" do + # try a scattering of values, including min/max + [-2**15, -5237, -353, 0, 1527, 2234, 2**15-1].each do |i| + @trans.write([i].pack("n")); + @prot.read_i16.should == i + end + end + + it "should read an i32" do + # try a scattering of values, including min/max + [-2**31, -235125, -6236, 0, 2351, 123123, 2**31-1].each do |i| + @trans.write([i].pack("N")) + @prot.read_i32.should == i + end + end + + it "should read an i64" do + # try a scattering of values, including min/max + [-2**63, -123512312, -6346, 0, 32, 2346322323, 2**63-1].each do |i| + @trans.write([i >> 32, i & 0xFFFFFFFF].pack("NN")) + @prot.read_i64.should == i + end + end + + it "should read a double" do + # try a random scattering of values, including min/max + [Float::MIN, -231231.12351, -323.233513, 0, 123.2351235, 2351235.12351235, Float::MAX].each do |f| + @trans.write([f].pack("G")); + @prot.read_double.should == f + end + end + + it "should read a string" do + str = "hello world" + @trans.write([str.size].pack("N") + str) + @prot.read_string.should == str + end + + it "should perform a complete rpc with no args or return" do + srv_test( + proc {|client| client.send_voidMethod()}, + proc {|client| client.recv_voidMethod.should == nil} + ) + end + + it "should perform a complete rpc with a primitive return type" do + srv_test( + proc {|client| client.send_primitiveMethod()}, + proc {|client| client.recv_primitiveMethod.should == 1} + ) + end + + it "should perform a complete rpc with a struct return type" do + srv_test( + proc {|client| client.send_structMethod()}, + proc {|client| + result = client.recv_structMethod + result.set_byte_map = nil + result.map_byte_map = nil + result.should == Fixtures::COMPACT_PROTOCOL_TEST_STRUCT + } + ) + end + + def get_socket_connection + server = Thrift::ServerSocket.new("localhost", 9090) + server.listen + + clientside = Thrift::Socket.new("localhost", 9090) + clientside.open + serverside = server.accept + [clientside, serverside, server] + end + + def srv_test(firstblock, secondblock) + clientside, serverside, server = get_socket_connection + + clientproto = protocol_class.new(clientside) + serverproto = protocol_class.new(serverside) + + processor = Srv::Processor.new(SrvHandler.new) + + client = Srv::Client.new(clientproto, clientproto) + + # first block + firstblock.call(client) + + processor.process(serverproto, serverproto) + + # second block + secondblock.call(client) + ensure + clientside.close + serverside.close + server.close + end + + class SrvHandler + def voidMethod() + end + + def primitiveMethod + 1 + end + + def structMethod + Fixtures::COMPACT_PROTOCOL_TEST_STRUCT + end + end +end diff --git a/lib/rb/spec/client_spec.rb b/lib/rb/spec/client_spec.rb new file mode 100644 index 00000000..e707d816 --- /dev/null +++ b/lib/rb/spec/client_spec.rb @@ -0,0 +1,100 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +class ThriftClientSpec < Spec::ExampleGroup + include Thrift + + class ClientSpec + include Thrift::Client + end + + before(:each) do + @prot = mock("MockProtocol") + @client = ClientSpec.new(@prot) + end + + describe Client do + it "should re-use iprot for oprot if not otherwise specified" do + @client.instance_variable_get(:'@iprot').should eql(@prot) + @client.instance_variable_get(:'@oprot').should eql(@prot) + end + + it "should send a test message" do + @prot.should_receive(:write_message_begin).with('testMessage', MessageTypes::CALL, 0) + mock_args = mock('#') + mock_args.should_receive(:foo=).with('foo') + mock_args.should_receive(:bar=).with(42) + mock_args.should_receive(:write).with(@prot) + @prot.should_receive(:write_message_end) + @prot.should_receive(:trans) do + mock('trans').tee do |trans| + trans.should_receive(:flush) + end + end + klass = stub("TestMessage_args", :new => mock_args) + @client.send_message('testMessage', klass, :foo => 'foo', :bar => 42) + end + + it "should increment the sequence id when sending messages" do + pending "it seems sequence ids are completely ignored right now" do + @prot.should_receive(:write_message_begin).with('testMessage', MessageTypes::CALL, 0).ordered + @prot.should_receive(:write_message_begin).with('testMessage2', MessageTypes::CALL, 1).ordered + @prot.should_receive(:write_message_begin).with('testMessage3', MessageTypes::CALL, 2).ordered + @prot.stub!(:write_message_end) + @prot.stub!(:trans).and_return mock("trans").as_null_object + @client.send_message('testMessage', mock("args class").as_null_object) + @client.send_message('testMessage2', mock("args class").as_null_object) + @client.send_message('testMessage3', mock("args class").as_null_object) + end + end + + it "should receive a test message" do + @prot.should_receive(:read_message_begin).and_return [nil, MessageTypes::CALL, 0] + @prot.should_receive(:read_message_end) + mock_klass = mock("#") + mock_klass.should_receive(:read).with(@prot) + @client.receive_message(stub("MockClass", :new => mock_klass)) + end + + it "should handle received exceptions" do + @prot.should_receive(:read_message_begin).and_return [nil, MessageTypes::EXCEPTION, 0] + @prot.should_receive(:read_message_end) + ApplicationException.should_receive(:new).and_return do + StandardError.new.tee do |mock_exc| + mock_exc.should_receive(:read).with(@prot) + end + end + lambda { @client.receive_message(nil) }.should raise_error(StandardError) + end + + it "should close the transport if an error occurs while sending a message" do + @prot.stub!(:write_message_begin) + @prot.should_not_receive(:write_message_end) + mock_args = mock("#") + mock_args.should_receive(:write).with(@prot).and_raise(StandardError) + trans = mock("MockTransport") + @prot.stub!(:trans).and_return(trans) + trans.should_receive(:close) + klass = mock("TestMessage_args", :new => mock_args) + lambda { @client.send_message("testMessage", klass) }.should raise_error(StandardError) + end + end +end diff --git a/lib/rb/spec/compact_protocol_spec.rb b/lib/rb/spec/compact_protocol_spec.rb new file mode 100644 index 00000000..b9a79810 --- /dev/null +++ b/lib/rb/spec/compact_protocol_spec.rb @@ -0,0 +1,117 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +describe Thrift::CompactProtocol do + TESTS = { + :byte => (-127..127).to_a, + :i16 => (0..14).map {|shift| [1 << shift, -(1 << shift)]}.flatten.sort, + :i32 => (0..30).map {|shift| [1 << shift, -(1 << shift)]}.flatten.sort, + :i64 => (0..62).map {|shift| [1 << shift, -(1 << shift)]}.flatten.sort, + :string => ["", "1", "short", "fourteen123456", "fifteen12345678", "1" * 127, "1" * 3000], + :binary => ["", "\001", "\001" * 5, "\001" * 14, "\001" * 15, "\001" * 127, "\001" * 3000], + :double => [0.0, 1.0, -1.0, 1.1, -1.1, 10000000.1, 1.0/0.0, -1.0/0.0], + :bool => [true, false] + } + + it "should encode and decode naked primitives correctly" do + TESTS.each_pair do |primitive_type, test_values| + test_values.each do |value| + # puts "testing #{value}" if primitive_type == :i64 + trans = Thrift::MemoryBufferTransport.new + proto = Thrift::CompactProtocol.new(trans) + + proto.send(writer(primitive_type), value) + # puts "buf: #{trans.inspect_buffer}" if primitive_type == :i64 + read_back = proto.send(reader(primitive_type)) + read_back.should == value + end + end + end + + it "should encode and decode primitives in fields correctly" do + TESTS.each_pair do |primitive_type, test_values| + final_primitive_type = primitive_type == :binary ? :string : primitive_type + thrift_type = Thrift::Types.const_get(final_primitive_type.to_s.upcase) + # puts primitive_type + test_values.each do |value| + trans = Thrift::MemoryBufferTransport.new + proto = Thrift::CompactProtocol.new(trans) + + proto.write_field_begin(nil, thrift_type, 15) + proto.send(writer(primitive_type), value) + proto.write_field_end + + proto = Thrift::CompactProtocol.new(trans) + name, type, id = proto.read_field_begin + type.should == thrift_type + id.should == 15 + read_back = proto.send(reader(primitive_type)) + read_back.should == value + proto.read_field_end + end + end + end + + it "should encode and decode a monster struct correctly" do + trans = Thrift::MemoryBufferTransport.new + proto = Thrift::CompactProtocol.new(trans) + + struct = CompactProtoTestStruct.new + # sets and maps don't hash well... not sure what to do here. + struct.write(proto) + + struct2 = CompactProtoTestStruct.new + struct2.read(proto) + struct2.should == struct + end + + it "should make method calls correctly" do + client_out_trans = Thrift::MemoryBufferTransport.new + client_out_proto = Thrift::CompactProtocol.new(client_out_trans) + + client_in_trans = Thrift::MemoryBufferTransport.new + client_in_proto = Thrift::CompactProtocol.new(client_in_trans) + + processor = Srv::Processor.new(JankyHandler.new) + + client = Srv::Client.new(client_in_proto, client_out_proto) + client.send_Janky(1) + # puts client_out_trans.inspect_buffer + processor.process(client_out_proto, client_in_proto) + client.recv_Janky.should == 2 + end + + class JankyHandler + def Janky(i32arg) + i32arg * 2 + end + end + + def writer(sym) + sym = sym == :binary ? :string : sym + "write_#{sym.to_s}" + end + + def reader(sym) + sym = sym == :binary ? :string : sym + "read_#{sym.to_s}" + end +end \ No newline at end of file diff --git a/lib/rb/spec/exception_spec.rb b/lib/rb/spec/exception_spec.rb new file mode 100644 index 00000000..fc321378 --- /dev/null +++ b/lib/rb/spec/exception_spec.rb @@ -0,0 +1,142 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +class ThriftExceptionSpec < Spec::ExampleGroup + include Thrift + + describe Exception do + it "should have an accessible message" do + e = Exception.new("test message") + e.message.should == "test message" + end + end + + describe ApplicationException do + it "should inherit from Thrift::Exception" do + ApplicationException.superclass.should == Exception + end + + it "should have an accessible type and message" do + e = ApplicationException.new + e.type.should == ApplicationException::UNKNOWN + e.message.should be_nil + e = ApplicationException.new(ApplicationException::UNKNOWN_METHOD, "test message") + e.type.should == ApplicationException::UNKNOWN_METHOD + e.message.should == "test message" + end + + it "should read a struct off of a protocol" do + prot = mock("MockProtocol") + prot.should_receive(:read_struct_begin).ordered + prot.should_receive(:read_field_begin).exactly(3).times.and_return( + ["message", Types::STRING, 1], + ["type", Types::I32, 2], + [nil, Types::STOP, 0] + ) + prot.should_receive(:read_string).ordered.and_return "test message" + prot.should_receive(:read_i32).ordered.and_return ApplicationException::BAD_SEQUENCE_ID + prot.should_receive(:read_field_end).exactly(2).times + prot.should_receive(:read_struct_end).ordered + + e = ApplicationException.new + e.read(prot) + e.message.should == "test message" + e.type.should == ApplicationException::BAD_SEQUENCE_ID + end + + it "should skip bad fields when reading a struct" do + prot = mock("MockProtocol") + prot.should_receive(:read_struct_begin).ordered + prot.should_receive(:read_field_begin).exactly(5).times.and_return( + ["type", Types::I32, 2], + ["type", Types::STRING, 2], + ["message", Types::MAP, 1], + ["message", Types::STRING, 3], + [nil, Types::STOP, 0] + ) + prot.should_receive(:read_i32).and_return ApplicationException::INVALID_MESSAGE_TYPE + prot.should_receive(:skip).with(Types::STRING).twice + prot.should_receive(:skip).with(Types::MAP) + prot.should_receive(:read_field_end).exactly(4).times + prot.should_receive(:read_struct_end).ordered + + e = ApplicationException.new + e.read(prot) + e.message.should be_nil + e.type.should == ApplicationException::INVALID_MESSAGE_TYPE + end + + it "should write a Thrift::ApplicationException struct to the oprot" do + prot = mock("MockProtocol") + prot.should_receive(:write_struct_begin).with("Thrift::ApplicationException").ordered + prot.should_receive(:write_field_begin).with("message", Types::STRING, 1).ordered + prot.should_receive(:write_string).with("test message").ordered + prot.should_receive(:write_field_begin).with("type", Types::I32, 2).ordered + prot.should_receive(:write_i32).with(ApplicationException::UNKNOWN_METHOD).ordered + prot.should_receive(:write_field_end).twice + prot.should_receive(:write_field_stop).ordered + prot.should_receive(:write_struct_end).ordered + + e = ApplicationException.new(ApplicationException::UNKNOWN_METHOD, "test message") + e.write(prot) + end + + it "should skip nil fields when writing to the oprot" do + prot = mock("MockProtocol") + prot.should_receive(:write_struct_begin).with("Thrift::ApplicationException").ordered + prot.should_receive(:write_field_begin).with("message", Types::STRING, 1).ordered + prot.should_receive(:write_string).with("test message").ordered + prot.should_receive(:write_field_end).ordered + prot.should_receive(:write_field_stop).ordered + prot.should_receive(:write_struct_end).ordered + + e = ApplicationException.new(nil, "test message") + e.write(prot) + + prot = mock("MockProtocol") + prot.should_receive(:write_struct_begin).with("Thrift::ApplicationException").ordered + prot.should_receive(:write_field_begin).with("type", Types::I32, 2).ordered + prot.should_receive(:write_i32).with(ApplicationException::BAD_SEQUENCE_ID).ordered + prot.should_receive(:write_field_end).ordered + prot.should_receive(:write_field_stop).ordered + prot.should_receive(:write_struct_end).ordered + + e = ApplicationException.new(ApplicationException::BAD_SEQUENCE_ID) + e.write(prot) + + prot = mock("MockProtocol") + prot.should_receive(:write_struct_begin).with("Thrift::ApplicationException").ordered + prot.should_receive(:write_field_stop).ordered + prot.should_receive(:write_struct_end).ordered + + e = ApplicationException.new(nil) + e.write(prot) + end + end + + describe ProtocolException do + it "should have an accessible type" do + prot = ProtocolException.new(ProtocolException::SIZE_LIMIT, "message") + prot.type.should == ProtocolException::SIZE_LIMIT + prot.message.should == "message" + end + end +end diff --git a/lib/rb/spec/http_client_spec.rb b/lib/rb/spec/http_client_spec.rb new file mode 100644 index 00000000..94526deb --- /dev/null +++ b/lib/rb/spec/http_client_spec.rb @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +class ThriftHTTPClientTransportSpec < Spec::ExampleGroup + include Thrift + + describe HTTPClientTransport do + before(:each) do + @client = HTTPClientTransport.new("http://my.domain.com/path/to/service") + end + + it "should always be open" do + @client.should be_open + @client.close + @client.should be_open + end + + it "should post via HTTP and return the results" do + @client.write "a test" + @client.write " frame" + Net::HTTP.should_receive(:new).with("my.domain.com", 80).and_return do + mock("Net::HTTP").tee do |http| + http.should_receive(:use_ssl=).with(false) + http.should_receive(:post).with("/path/to/service", "a test frame", {"Content-Type"=>"application/x-thrift"}).and_return([nil, "data"]) + end + end + @client.flush + @client.read(10).should == "data" + end + end +end diff --git a/lib/rb/spec/mongrel_http_server_spec.rb b/lib/rb/spec/mongrel_http_server_spec.rb new file mode 100644 index 00000000..c994491c --- /dev/null +++ b/lib/rb/spec/mongrel_http_server_spec.rb @@ -0,0 +1,117 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' +require 'thrift/server/mongrel_http_server' + +class ThriftHTTPServerSpec < Spec::ExampleGroup + include Thrift + + Handler = MongrelHTTPServer::Handler + + describe MongrelHTTPServer do + it "should have appropriate defaults" do + mock_factory = mock("BinaryProtocolFactory") + mock_proc = mock("Processor") + BinaryProtocolFactory.should_receive(:new).and_return(mock_factory) + Mongrel::HttpServer.should_receive(:new).with("0.0.0.0", 80).and_return do + mock("Mongrel::HttpServer").tee do |mock| + handler = mock("Handler") + Handler.should_receive(:new).with(mock_proc, mock_factory).and_return(handler) + mock.should_receive(:register).with("/", handler) + end + end + MongrelHTTPServer.new(mock_proc) + end + + it "should understand :ip, :port, :path, and :protocol_factory" do + mock_proc = mock("Processor") + mock_factory = mock("ProtocolFactory") + Mongrel::HttpServer.should_receive(:new).with("1.2.3.4", 1234).and_return do + mock("Mongrel::HttpServer").tee do |mock| + handler = mock("Handler") + Handler.should_receive(:new).with(mock_proc, mock_factory).and_return(handler) + mock.should_receive(:register).with("/foo", handler) + end + end + MongrelHTTPServer.new(mock_proc, :ip => "1.2.3.4", :port => 1234, :path => "foo", + :protocol_factory => mock_factory) + end + + it "should serve using Mongrel::HttpServer" do + BinaryProtocolFactory.stub!(:new) + Mongrel::HttpServer.should_receive(:new).and_return do + mock("Mongrel::HttpServer").tee do |mock| + Handler.stub!(:new) + mock.stub!(:register) + mock.should_receive(:run).and_return do + mock("Mongrel::HttpServer.run").tee do |runner| + runner.should_receive(:join) + end + end + end + end + MongrelHTTPServer.new(nil).serve + end + end + + describe MongrelHTTPServer::Handler do + before(:each) do + @processor = mock("Processor") + @factory = mock("ProtocolFactory") + @handler = Handler.new(@processor, @factory) + end + + it "should return 404 for non-POST requests" do + request = mock("request", :params => {"REQUEST_METHOD" => "GET"}) + response = mock("response") + response.should_receive(:start).with(404) + response.should_not_receive(:start).with(200) + @handler.process(request, response) + end + + it "should serve using application/x-thrift" do + request = mock("request", :params => {"REQUEST_METHOD" => "POST"}, :body => nil) + response = mock("response") + head = mock("head") + head.should_receive(:[]=).with("Content-Type", "application/x-thrift") + IOStreamTransport.stub!(:new) + @factory.stub!(:get_protocol) + @processor.stub!(:process) + response.should_receive(:start).with(200).and_yield(head, nil) + @handler.process(request, response) + end + + it "should use the IOStreamTransport" do + body = mock("body") + request = mock("request", :params => {"REQUEST_METHOD" => "POST"}, :body => body) + response = mock("response") + head = mock("head") + head.stub!(:[]=) + out = mock("out") + protocol = mock("protocol") + transport = mock("transport") + IOStreamTransport.should_receive(:new).with(body, out).and_return(transport) + @factory.should_receive(:get_protocol).with(transport).and_return(protocol) + @processor.should_receive(:process).with(protocol, protocol) + response.should_receive(:start).with(200).and_yield(head, out) + @handler.process(request, response) + end + end +end diff --git a/lib/rb/spec/nonblocking_server_spec.rb b/lib/rb/spec/nonblocking_server_spec.rb new file mode 100644 index 00000000..a0e86cf2 --- /dev/null +++ b/lib/rb/spec/nonblocking_server_spec.rb @@ -0,0 +1,266 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' +require File.dirname(__FILE__) + '/gen-rb/nonblocking_service' + +class ThriftNonblockingServerSpec < Spec::ExampleGroup + include Thrift + include SpecNamespace + + class Handler + def initialize + @queue = Queue.new + end + + attr_accessor :server + + def greeting(english) + if english + SpecNamespace::Hello.new + else + SpecNamespace::Hello.new(:greeting => "Aloha!") + end + end + + def block + @queue.pop + end + + def unblock(n) + n.times { @queue.push true } + end + + def sleep(time) + Kernel.sleep time + end + + def shutdown + @server.shutdown(0, false) + end + end + + class SpecTransport < BaseTransport + def initialize(transport, queue) + @transport = transport + @queue = queue + @flushed = false + end + + def open? + @transport.open? + end + + def open + @transport.open + end + + def close + @transport.close + end + + def read(sz) + @transport.read(sz) + end + + def write(buf,sz=nil) + @transport.write(buf, sz) + end + + def flush + @queue.push :flushed unless @flushed or @queue.nil? + @flushed = true + @transport.flush + end + end + + class SpecServerSocket < ServerSocket + def initialize(host, port, queue) + super(host, port) + @queue = queue + end + + def listen + super + @queue.push :listen + end + end + + describe Thrift::NonblockingServer do + before(:each) do + @port = 43251 + handler = Handler.new + processor = NonblockingService::Processor.new(handler) + queue = Queue.new + @transport = SpecServerSocket.new('localhost', @port, queue) + transport_factory = FramedTransportFactory.new + logger = Logger.new(STDERR) + logger.level = Logger::WARN + @server = NonblockingServer.new(processor, @transport, transport_factory, nil, 5, logger) + handler.server = @server + @server_thread = Thread.new(Thread.current) do |master_thread| + begin + @server.serve + rescue => e + p e + puts e.backtrace * "\n" + master_thread.raise e + end + end + queue.pop + + @clients = [] + @catch_exceptions = false + end + + after(:each) do + @clients.each { |client, trans| trans.close } + # @server.shutdown(1) + @server_thread.kill + @transport.close + end + + def setup_client(queue = nil) + transport = SpecTransport.new(FramedTransport.new(Socket.new('localhost', @port)), queue) + protocol = BinaryProtocol.new(transport) + client = NonblockingService::Client.new(protocol) + transport.open + @clients << [client, transport] + client + end + + def setup_client_thread(result) + queue = Queue.new + Thread.new do + begin + client = setup_client + while (cmd = queue.pop) + msg, *args = cmd + case msg + when :block + result << client.block + when :unblock + client.unblock(args.first) + when :hello + result << client.greeting(true) # ignore result + when :sleep + client.sleep(args[0] || 0.5) + result << :slept + when :shutdown + client.shutdown + when :exit + result << :done + break + end + end + @clients.each { |c,t| t.close and break if c == client } #close the transport + rescue => e + raise e unless @catch_exceptions + end + end + queue + end + + it "should handle basic message passing" do + client = setup_client + client.greeting(true).should == Hello.new + client.greeting(false).should == Hello.new(:greeting => 'Aloha!') + @server.shutdown + end + + it "should handle concurrent clients" do + queue = Queue.new + trans_queue = Queue.new + 4.times do + Thread.new(Thread.current) do |main_thread| + begin + queue.push setup_client(trans_queue).block + rescue => e + main_thread.raise e + end + end + end + 4.times { trans_queue.pop } + setup_client.unblock(4) + 4.times { queue.pop.should be_true } + @server.shutdown + end + + it "should handle messages from more than 5 long-lived connections" do + queues = [] + result = Queue.new + 7.times do |i| + queues << setup_client_thread(result) + Thread.pass if i == 4 # give the server time to accept connections + end + client = setup_client + # block 4 connections + 4.times { |i| queues[i] << :block } + queues[4] << :hello + queues[5] << :hello + queues[6] << :hello + 3.times { result.pop.should == Hello.new } + client.greeting(true).should == Hello.new + queues[5] << [:unblock, 4] + 4.times { result.pop.should be_true } + queues[2] << :hello + result.pop.should == Hello.new + client.greeting(false).should == Hello.new(:greeting => 'Aloha!') + 7.times { queues.shift << :exit } + client.greeting(true).should == Hello.new + @server.shutdown + end + + it "should shut down when asked" do + # connect first to ensure it's running + client = setup_client + client.greeting(false) # force a message pass + @server.shutdown + @server_thread.join(2).should be_an_instance_of(Thread) + end + + it "should continue processing active messages when shutting down" do + result = Queue.new + client = setup_client_thread(result) + client << :sleep + sleep 0.1 # give the server time to start processing the client's message + @server.shutdown + @server_thread.join(2).should be_an_instance_of(Thread) + result.pop.should == :slept + end + + it "should kill active messages when they don't expire while shutting down" do + result = Queue.new + client = setup_client_thread(result) + client << [:sleep, 10] + sleep 0.1 # start processing the client's message + @server.shutdown(1) + @catch_exceptions = true + @server_thread.join(3).should_not be_nil + result.should be_empty + end + + it "should allow shutting down in response to a message" do + client = setup_client + client.greeting(true).should == Hello.new + client.shutdown + @server_thread.join(2).should_not be_nil + end + end +end diff --git a/lib/rb/spec/processor_spec.rb b/lib/rb/spec/processor_spec.rb new file mode 100644 index 00000000..d35f6528 --- /dev/null +++ b/lib/rb/spec/processor_spec.rb @@ -0,0 +1,83 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +class ThriftProcessorSpec < Spec::ExampleGroup + include Thrift + + class ProcessorSpec + include Thrift::Processor + end + + describe "Processor" do + before(:each) do + @processor = ProcessorSpec.new(mock("MockHandler")) + @prot = mock("MockProtocol") + end + + def mock_trans(obj) + obj.should_receive(:trans).ordered.and_return do + mock("trans").tee do |trans| + trans.should_receive(:flush).ordered + end + end + end + + it "should call process_ when it receives that message" do + @prot.should_receive(:read_message_begin).ordered.and_return ['testMessage', MessageTypes::CALL, 17] + @processor.should_receive(:process_testMessage).with(17, @prot, @prot).ordered + @processor.process(@prot, @prot).should == true + end + + it "should raise an ApplicationException when the received message cannot be processed" do + @prot.should_receive(:read_message_begin).ordered.and_return ['testMessage', MessageTypes::CALL, 4] + @prot.should_receive(:skip).with(Types::STRUCT).ordered + @prot.should_receive(:read_message_end).ordered + @prot.should_receive(:write_message_begin).with('testMessage', MessageTypes::EXCEPTION, 4).ordered + ApplicationException.should_receive(:new).with(ApplicationException::UNKNOWN_METHOD, "Unknown function testMessage").and_return do + mock(ApplicationException).tee do |e| + e.should_receive(:write).with(@prot).ordered + end + end + @prot.should_receive(:write_message_end).ordered + mock_trans(@prot) + @processor.process(@prot, @prot) + end + + it "should pass args off to the args class" do + args_class = mock("MockArgsClass") + args = mock("#").tee do |args| + args.should_receive(:read).with(@prot).ordered + end + args_class.should_receive(:new).and_return args + @prot.should_receive(:read_message_end).ordered + @processor.read_args(@prot, args_class).should eql(args) + end + + it "should write out a reply when asked" do + @prot.should_receive(:write_message_begin).with('testMessage', MessageTypes::REPLY, 23).ordered + result = mock("MockResult") + result.should_receive(:write).with(@prot).ordered + @prot.should_receive(:write_message_end).ordered + mock_trans(@prot) + @processor.write_result(result, @prot, 'testMessage', 23) + end + end +end diff --git a/lib/rb/spec/serializer_spec.rb b/lib/rb/spec/serializer_spec.rb new file mode 100644 index 00000000..82f374b1 --- /dev/null +++ b/lib/rb/spec/serializer_spec.rb @@ -0,0 +1,70 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' +require File.dirname(__FILE__) + '/gen-rb/thrift_spec_types' + +class ThriftSerializerSpec < Spec::ExampleGroup + include Thrift + include SpecNamespace + + describe Serializer do + it "should serialize structs to binary by default" do + serializer = Serializer.new(Thrift::BinaryProtocolAcceleratedFactory.new) + data = serializer.serialize(Hello.new(:greeting => "'Ello guv'nor!")) + data.should == "\x0B\x00\x01\x00\x00\x00\x0E'Ello guv'nor!\x00" + end + + it "should serialize structs to the given protocol" do + protocol = BaseProtocol.new(mock("transport")) + protocol.should_receive(:write_struct_begin).with("SpecNamespace::Hello") + protocol.should_receive(:write_field_begin).with("greeting", Types::STRING, 1) + protocol.should_receive(:write_string).with("Good day") + protocol.should_receive(:write_field_end) + protocol.should_receive(:write_field_stop) + protocol.should_receive(:write_struct_end) + protocol_factory = mock("ProtocolFactory") + protocol_factory.stub!(:get_protocol).and_return(protocol) + serializer = Serializer.new(protocol_factory) + serializer.serialize(Hello.new(:greeting => "Good day")) + end + end + + describe Deserializer do + it "should deserialize structs from binary by default" do + deserializer = Deserializer.new + data = "\x0B\x00\x01\x00\x00\x00\x0E'Ello guv'nor!\x00" + deserializer.deserialize(Hello.new, data).should == Hello.new(:greeting => "'Ello guv'nor!") + end + + it "should deserialize structs from the given protocol" do + protocol = BaseProtocol.new(mock("transport")) + protocol.should_receive(:read_struct_begin).and_return("SpecNamespace::Hello") + protocol.should_receive(:read_field_begin).and_return(["greeting", Types::STRING, 1], + [nil, Types::STOP, 0]) + protocol.should_receive(:read_string).and_return("Good day") + protocol.should_receive(:read_field_end) + protocol.should_receive(:read_struct_end) + protocol_factory = mock("ProtocolFactory") + protocol_factory.stub!(:get_protocol).and_return(protocol) + deserializer = Deserializer.new(protocol_factory) + deserializer.deserialize(Hello.new, "").should == Hello.new(:greeting => "Good day") + end + end +end diff --git a/lib/rb/spec/server_socket_spec.rb b/lib/rb/spec/server_socket_spec.rb new file mode 100644 index 00000000..fce50134 --- /dev/null +++ b/lib/rb/spec/server_socket_spec.rb @@ -0,0 +1,80 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' +require File.dirname(__FILE__) + "/socket_spec_shared" + +class ThriftServerSocketSpec < Spec::ExampleGroup + include Thrift + + describe ServerSocket do + before(:each) do + @socket = ServerSocket.new(1234) + end + + it "should create a handle when calling listen" do + TCPServer.should_receive(:new).with(nil, 1234) + @socket.listen + end + + it "should accept an optional host argument" do + @socket = ServerSocket.new('localhost', 1234) + TCPServer.should_receive(:new).with('localhost', 1234) + @socket.listen + end + + it "should create a Thrift::Socket to wrap accepted sockets" do + handle = mock("TCPServer") + TCPServer.should_receive(:new).with(nil, 1234).and_return(handle) + @socket.listen + sock = mock("sock") + handle.should_receive(:accept).and_return(sock) + trans = mock("Socket") + Socket.should_receive(:new).and_return(trans) + trans.should_receive(:handle=).with(sock) + @socket.accept.should == trans + end + + it "should close the handle when closed" do + handle = mock("TCPServer", :closed? => false) + TCPServer.should_receive(:new).with(nil, 1234).and_return(handle) + @socket.listen + handle.should_receive(:close) + @socket.close + end + + it "should return nil when accepting if there is no handle" do + @socket.accept.should be_nil + end + + it "should return true for closed? when appropriate" do + handle = mock("TCPServer", :closed? => false) + TCPServer.stub!(:new).and_return(handle) + @socket.listen + @socket.should_not be_closed + handle.stub!(:close) + @socket.close + @socket.should be_closed + @socket.listen + @socket.should_not be_closed + handle.stub!(:closed?).and_return(true) + @socket.should be_closed + end + end +end diff --git a/lib/rb/spec/server_spec.rb b/lib/rb/spec/server_spec.rb new file mode 100644 index 00000000..ffe9bffa --- /dev/null +++ b/lib/rb/spec/server_spec.rb @@ -0,0 +1,160 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +class ThriftServerSpec < Spec::ExampleGroup + include Thrift + + describe BaseServer do + it "should default to BaseTransportFactory and BinaryProtocolFactory when not specified" do + server = BaseServer.new(mock("Processor"), mock("BaseServerTransport")) + server.instance_variable_get(:'@transport_factory').should be_an_instance_of(BaseTransportFactory) + server.instance_variable_get(:'@protocol_factory').should be_an_instance_of(BinaryProtocolFactory) + end + + # serve is a noop, so can't test that + end + + shared_examples_for "servers" do + before(:each) do + @processor = mock("Processor") + @serverTrans = mock("ServerTransport") + @trans = mock("BaseTransport") + @prot = mock("BaseProtocol") + @client = mock("Client") + @server = server_type.new(@processor, @serverTrans, @trans, @prot) + end + end + + describe SimpleServer do + it_should_behave_like "servers" + + def server_type + SimpleServer + end + + it "should serve in the main thread" do + @serverTrans.should_receive(:listen).ordered + @serverTrans.should_receive(:accept).exactly(3).times.and_return(@client) + @trans.should_receive(:get_transport).exactly(3).times.with(@client).and_return(@trans) + @prot.should_receive(:get_protocol).exactly(3).times.with(@trans).and_return(@prot) + x = 0 + @processor.should_receive(:process).exactly(3).times.with(@prot, @prot).and_return do + case (x += 1) + when 1 then raise Thrift::TransportException + when 2 then raise Thrift::ProtocolException + when 3 then throw :stop + end + end + @trans.should_receive(:close).exactly(3).times + @serverTrans.should_receive(:close).ordered + lambda { @server.serve }.should throw_symbol(:stop) + end + end + + describe ThreadedServer do + it_should_behave_like "servers" + + def server_type + ThreadedServer + end + + it "should serve using threads" do + @serverTrans.should_receive(:listen).ordered + @serverTrans.should_receive(:accept).exactly(3).times.and_return(@client) + @trans.should_receive(:get_transport).exactly(3).times.with(@client).and_return(@trans) + @prot.should_receive(:get_protocol).exactly(3).times.with(@trans).and_return(@prot) + Thread.should_receive(:new).with(@prot, @trans).exactly(3).times.and_yield(@prot, @trans) + x = 0 + @processor.should_receive(:process).exactly(3).times.with(@prot, @prot).and_return do + case (x += 1) + when 1 then raise Thrift::TransportException + when 2 then raise Thrift::ProtocolException + when 3 then throw :stop + end + end + @trans.should_receive(:close).exactly(3).times + @serverTrans.should_receive(:close).ordered + lambda { @server.serve }.should throw_symbol(:stop) + end + end + + describe ThreadPoolServer do + it_should_behave_like "servers" + + def server_type + # put this stuff here so it runs before the server is created + @threadQ = mock("SizedQueue") + SizedQueue.should_receive(:new).with(20).and_return(@threadQ) + @excQ = mock("Queue") + Queue.should_receive(:new).and_return(@excQ) + ThreadPoolServer + end + + it "should set up the queues" do + @server.instance_variable_get(:'@thread_q').should be(@threadQ) + @server.instance_variable_get(:'@exception_q').should be(@excQ) + end + + it "should serve inside a thread" do + Thread.should_receive(:new).and_return do |block| + @server.should_receive(:serve) + block.call + @server.rspec_verify + end + @excQ.should_receive(:pop).and_throw(:popped) + lambda { @server.rescuable_serve }.should throw_symbol(:popped) + end + + it "should avoid running the server twice when retrying rescuable_serve" do + Thread.should_receive(:new).and_return do |block| + @server.should_receive(:serve) + block.call + @server.rspec_verify + end + @excQ.should_receive(:pop).twice.and_throw(:popped) + lambda { @server.rescuable_serve }.should throw_symbol(:popped) + lambda { @server.rescuable_serve }.should throw_symbol(:popped) + end + + it "should serve using a thread pool" do + @serverTrans.should_receive(:listen).ordered + @threadQ.should_receive(:push).with(:token) + @threadQ.should_receive(:pop) + Thread.should_receive(:new).and_yield + @serverTrans.should_receive(:accept).exactly(3).times.and_return(@client) + @trans.should_receive(:get_transport).exactly(3).times.and_return(@trans) + @prot.should_receive(:get_protocol).exactly(3).times.and_return(@prot) + x = 0 + error = RuntimeError.new("Stopped") + @processor.should_receive(:process).exactly(3).times.with(@prot, @prot).and_return do + case (x += 1) + when 1 then raise Thrift::TransportException + when 2 then raise Thrift::ProtocolException + when 3 then raise error + end + end + @trans.should_receive(:close).exactly(3).times + @excQ.should_receive(:push).with(error).and_throw(:stop) + @serverTrans.should_receive(:close) + lambda { @server.serve }.should throw_symbol(:stop) + end + end +end diff --git a/lib/rb/spec/socket_spec.rb b/lib/rb/spec/socket_spec.rb new file mode 100644 index 00000000..dd8b0f92 --- /dev/null +++ b/lib/rb/spec/socket_spec.rb @@ -0,0 +1,61 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' +require File.dirname(__FILE__) + "/socket_spec_shared" + +class ThriftSocketSpec < Spec::ExampleGroup + include Thrift + + describe Socket do + before(:each) do + @socket = Socket.new + @handle = mock("Handle", :closed? => false) + @handle.stub!(:close) + @handle.stub!(:connect_nonblock) + ::Socket.stub!(:new).and_return(@handle) + end + + it_should_behave_like "a socket" + + it "should raise a TransportException when it cannot open a socket" do + ::Socket.should_receive(:new).and_raise(StandardError) + lambda { @socket.open }.should raise_error(Thrift::TransportException) { |e| e.type.should == Thrift::TransportException::NOT_OPEN } + end + + it "should open a ::Socket with default args" do + ::Socket.should_receive(:new).and_return(mock("Handle", :connect_nonblock => true)) + ::Socket.should_receive(:getaddrinfo).with("localhost", 9090).and_return([[]]) + ::Socket.should_receive(:sockaddr_in) + @socket.open + end + + it "should accept host/port options" do + ::Socket.should_receive(:new).and_return(mock("Handle", :connect_nonblock => true)) + ::Socket.should_receive(:getaddrinfo).with("my.domain", 1234).and_return([[]]) + ::Socket.should_receive(:sockaddr_in) + Socket.new('my.domain', 1234).open + end + + it "should accept an optional timeout" do + ::Socket.stub!(:new) + Socket.new('localhost', 8080, 5).timeout.should == 5 + end + end +end diff --git a/lib/rb/spec/socket_spec_shared.rb b/lib/rb/spec/socket_spec_shared.rb new file mode 100644 index 00000000..96b433b8 --- /dev/null +++ b/lib/rb/spec/socket_spec_shared.rb @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' + +shared_examples_for "a socket" do + it "should open a socket" do + @socket.open.should == @handle + end + + it "should be open whenever it has a handle" do + @socket.should_not be_open + @socket.open + @socket.should be_open + @socket.handle = nil + @socket.should_not be_open + @socket.handle = @handle + @socket.close + @socket.should_not be_open + end + + it "should write data to the handle" do + @socket.open + @handle.should_receive(:write).with("foobar") + @socket.write("foobar") + @handle.should_receive(:write).with("fail").and_raise(StandardError) + lambda { @socket.write("fail") }.should raise_error(Thrift::TransportException) { |e| e.type.should == Thrift::TransportException::NOT_OPEN } + end + + it "should raise an error when it cannot read from the handle" do + @socket.open + @handle.should_receive(:readpartial).with(17).and_raise(StandardError) + lambda { @socket.read(17) }.should raise_error(Thrift::TransportException) { |e| e.type.should == Thrift::TransportException::NOT_OPEN } + end + + it "should return the data read when reading from the handle works" do + @socket.open + @handle.should_receive(:readpartial).with(17).and_return("test data") + @socket.read(17).should == "test data" + end + + it "should declare itself as closed when it has an error" do + @socket.open + @handle.should_receive(:write).with("fail").and_raise(StandardError) + @socket.should be_open + lambda { @socket.write("fail") }.should raise_error + @socket.should_not be_open + end + + it "should raise an error when the stream is closed" do + @socket.open + @handle.stub!(:closed?).and_return(true) + @socket.should_not be_open + lambda { @socket.write("fail") }.should raise_error(IOError, "closed stream") + lambda { @socket.read(10) }.should raise_error(IOError, "closed stream") + end + + it "should support the timeout accessor for read" do + @socket.timeout = 3 + @socket.open + IO.should_receive(:select).with([@handle], nil, nil, 3).and_return([[@handle], [], []]) + @handle.should_receive(:readpartial).with(17).and_return("test data") + @socket.read(17).should == "test data" + end + + it "should support the timeout accessor for write" do + @socket.timeout = 3 + @socket.open + IO.should_receive(:select).with(nil, [@handle], nil, 3).twice.and_return([[], [@handle], []]) + @handle.should_receive(:write_nonblock).with("test data").and_return(4) + @handle.should_receive(:write_nonblock).with(" data").and_return(5) + @socket.write("test data").should == 9 + end + + it "should raise an error when read times out" do + @socket.timeout = 0.5 + @socket.open + IO.should_receive(:select).with([@handle], nil, nil, 0.5).at_least(1).times.and_return(nil) + lambda { @socket.read(17) }.should raise_error(Thrift::TransportException) { |e| e.type.should == Thrift::TransportException::TIMED_OUT } + end + + it "should raise an error when write times out" do + @socket.timeout = 0.5 + @socket.open + IO.should_receive(:select).with(nil, [@handle], nil, 0.5).any_number_of_times.and_return(nil) + lambda { @socket.write("test data") }.should raise_error(Thrift::TransportException) { |e| e.type.should == Thrift::TransportException::TIMED_OUT } + end +end diff --git a/lib/rb/spec/spec_helper.rb b/lib/rb/spec/spec_helper.rb new file mode 100644 index 00000000..3c20bf99 --- /dev/null +++ b/lib/rb/spec/spec_helper.rb @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'rubygems' +# require at least 1.1.4 to fix a bug with describing Modules +gem 'rspec', '>= 1.1.4' +require 'spec' + +$:.unshift File.join(File.dirname(__FILE__), *%w[.. ext]) + +# pretend we already loaded fastthread, otherwise the nonblocking_server_spec +# will get screwed up +# $" << 'fastthread.bundle' + +require File.dirname(__FILE__) + '/../lib/thrift' + +class Object + # tee is a useful method, so let's let our tests have it + def tee(&block) + block.call(self) + self + end +end + +Spec::Runner.configure do |configuration| + configuration.before(:each) do + Thrift.type_checking = true + end +end + +require File.dirname(__FILE__) + "/../debug_proto_test/gen-rb/Srv" +require File.dirname(__FILE__) + "/../debug_proto_test/gen-rb/debug_proto_test_constants" + +module Fixtures + COMPACT_PROTOCOL_TEST_STRUCT = COMPACT_TEST.dup + COMPACT_PROTOCOL_TEST_STRUCT.a_binary = [0,1,2,3,4,5,6,7,8].pack('c*') + COMPACT_PROTOCOL_TEST_STRUCT.set_byte_map = nil + COMPACT_PROTOCOL_TEST_STRUCT.map_byte_map = nil +end \ No newline at end of file diff --git a/lib/rb/spec/struct_spec.rb b/lib/rb/spec/struct_spec.rb new file mode 100644 index 00000000..23a701ec --- /dev/null +++ b/lib/rb/spec/struct_spec.rb @@ -0,0 +1,253 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' +require File.dirname(__FILE__) + '/gen-rb/thrift_spec_types' + +class ThriftStructSpec < Spec::ExampleGroup + include Thrift + include SpecNamespace + + describe Struct do + it "should iterate over all fields properly" do + fields = {} + Foo.new.each_field { |fid,field_info| fields[fid] = field_info } + fields.should == Foo::FIELDS + end + + it "should initialize all fields to defaults" do + struct = Foo.new + struct.simple.should == 53 + struct.words.should == "words" + struct.hello.should == Hello.new(:greeting => 'hello, world!') + struct.ints.should == [1, 2, 2, 3] + struct.complex.should be_nil + struct.shorts.should == Set.new([5, 17, 239]) + end + + it "should not share default values between instances" do + begin + struct = Foo.new + struct.ints << 17 + Foo.new.ints.should == [1,2,2,3] + ensure + # ensure no leakage to other tests + Foo::FIELDS[4][:default] = [1,2,2,3] + end + end + + it "should properly initialize boolean values" do + struct = BoolStruct.new(:yesno => false) + struct.yesno.should be_false + end + + it "should have proper == semantics" do + Foo.new.should_not == Hello.new + Foo.new.should == Foo.new + Foo.new(:simple => 52).should_not == Foo.new + end + + it "should read itself off the wire" do + struct = Foo.new + prot = BaseProtocol.new(mock("transport")) + prot.should_receive(:read_struct_begin).twice + prot.should_receive(:read_struct_end).twice + prot.should_receive(:read_field_begin).and_return( + ['complex', Types::MAP, 5], # Foo + ['words', Types::STRING, 2], # Foo + ['hello', Types::STRUCT, 3], # Foo + ['greeting', Types::STRING, 1], # Hello + [nil, Types::STOP, 0], # Hello + ['simple', Types::I32, 1], # Foo + ['ints', Types::LIST, 4], # Foo + ['shorts', Types::SET, 6], # Foo + [nil, Types::STOP, 0] # Hello + ) + prot.should_receive(:read_field_end).exactly(7).times + prot.should_receive(:read_map_begin).and_return( + [Types::I32, Types::MAP, 2], # complex + [Types::STRING, Types::DOUBLE, 2], # complex/1/value + [Types::STRING, Types::DOUBLE, 1] # complex/2/value + ) + prot.should_receive(:read_map_end).exactly(3).times + prot.should_receive(:read_list_begin).and_return([Types::I32, 4]) + prot.should_receive(:read_list_end) + prot.should_receive(:read_set_begin).and_return([Types::I16, 2]) + prot.should_receive(:read_set_end) + prot.should_receive(:read_i32).and_return( + 1, 14, # complex keys + 42, # simple + 4, 23, 4, 29 # ints + ) + prot.should_receive(:read_string).and_return("pi", "e", "feigenbaum", "apple banana", "what's up?") + prot.should_receive(:read_double).and_return(Math::PI, Math::E, 4.669201609) + prot.should_receive(:read_i16).and_return(2, 3) + prot.should_not_receive(:skip) + struct.read(prot) + + struct.simple.should == 42 + struct.complex.should == {1 => {"pi" => Math::PI, "e" => Math::E}, 14 => {"feigenbaum" => 4.669201609}} + struct.hello.should == Hello.new(:greeting => "what's up?") + struct.words.should == "apple banana" + struct.ints.should == [4, 23, 4, 29] + struct.shorts.should == Set.new([3, 2]) + end + + it "should skip unexpected fields in structs and use default values" do + struct = Foo.new + prot = BaseProtocol.new(mock("transport")) + prot.should_receive(:read_struct_begin) + prot.should_receive(:read_struct_end) + prot.should_receive(:read_field_begin).and_return( + ['simple', Types::I32, 1], + ['complex', Types::STRUCT, 5], + ['thinz', Types::MAP, 7], + ['foobar', Types::I32, 3], + ['words', Types::STRING, 2], + [nil, Types::STOP, 0] + ) + prot.should_receive(:read_field_end).exactly(5).times + prot.should_receive(:read_i32).and_return(42) + prot.should_receive(:read_string).and_return("foobar") + prot.should_receive(:skip).with(Types::STRUCT) + prot.should_receive(:skip).with(Types::MAP) + # prot.should_receive(:read_map_begin).and_return([Types::I32, Types::I32, 0]) + # prot.should_receive(:read_map_end) + prot.should_receive(:skip).with(Types::I32) + struct.read(prot) + + struct.simple.should == 42 + struct.complex.should be_nil + struct.words.should == "foobar" + struct.hello.should == Hello.new(:greeting => 'hello, world!') + struct.ints.should == [1, 2, 2, 3] + struct.shorts.should == Set.new([5, 17, 239]) + end + + it "should write itself to the wire" do + prot = BaseProtocol.new(mock("transport")) #mock("Protocol") + prot.should_receive(:write_struct_begin).with("SpecNamespace::Foo") + prot.should_receive(:write_struct_begin).with("SpecNamespace::Hello") + prot.should_receive(:write_struct_end).twice + prot.should_receive(:write_field_begin).with('ints', Types::LIST, 4) + prot.should_receive(:write_i32).with(1) + prot.should_receive(:write_i32).with(2).twice + prot.should_receive(:write_i32).with(3) + prot.should_receive(:write_field_begin).with('complex', Types::MAP, 5) + prot.should_receive(:write_i32).with(5) + prot.should_receive(:write_string).with('foo') + prot.should_receive(:write_double).with(1.23) + prot.should_receive(:write_field_begin).with('shorts', Types::SET, 6) + prot.should_receive(:write_i16).with(5) + prot.should_receive(:write_i16).with(17) + prot.should_receive(:write_i16).with(239) + prot.should_receive(:write_field_stop).twice + prot.should_receive(:write_field_end).exactly(6).times + prot.should_receive(:write_field_begin).with('simple', Types::I32, 1) + prot.should_receive(:write_i32).with(53) + prot.should_receive(:write_field_begin).with('hello', Types::STRUCT, 3) + prot.should_receive(:write_field_begin).with('greeting', Types::STRING, 1) + prot.should_receive(:write_string).with('hello, world!') + prot.should_receive(:write_map_begin).with(Types::I32, Types::MAP, 1) + prot.should_receive(:write_map_begin).with(Types::STRING, Types::DOUBLE, 1) + prot.should_receive(:write_map_end).twice + prot.should_receive(:write_list_begin).with(Types::I32, 4) + prot.should_receive(:write_list_end) + prot.should_receive(:write_set_begin).with(Types::I16, 3) + prot.should_receive(:write_set_end) + + struct = Foo.new + struct.words = nil + struct.complex = {5 => {"foo" => 1.23}} + struct.write(prot) + end + + it "should raise an exception if presented with an unknown container" do + # yeah this is silly, but I'm going for code coverage here + struct = Foo.new + lambda { struct.send :write_container, nil, nil, {:type => "foo"} }.should raise_error(StandardError, "Not a container type: foo") + end + + it "should support optional type-checking in Thrift::Struct.new" do + Thrift.type_checking = true + begin + lambda { Hello.new(:greeting => 3) }.should raise_error(TypeError, "Expected Types::STRING, received Fixnum for field greeting") + ensure + Thrift.type_checking = false + end + lambda { Hello.new(:greeting => 3) }.should_not raise_error(TypeError) + end + + it "should support optional type-checking in field accessors" do + Thrift.type_checking = true + begin + hello = Hello.new + lambda { hello.greeting = 3 }.should raise_error(TypeError, "Expected Types::STRING, received Fixnum for field greeting") + ensure + Thrift.type_checking = false + end + lambda { hello.greeting = 3 }.should_not raise_error(TypeError) + end + + it "should raise an exception when unknown types are given to Thrift::Struct.new" do + lambda { Hello.new(:fish => 'salmon') }.should raise_error(Exception, "Unknown key given to SpecNamespace::Hello.new: fish") + end + + it "should support `raise Xception, 'message'` for Exception structs" do + begin + raise Xception, "something happened" + rescue Thrift::Exception => e + e.message.should == "something happened" + e.code.should == 1 + # ensure it gets serialized properly, this is the really important part + prot = BaseProtocol.new(mock("trans")) + prot.should_receive(:write_struct_begin).with("SpecNamespace::Xception") + prot.should_receive(:write_struct_end) + prot.should_receive(:write_field_begin).with('message', Types::STRING, 1)#, "something happened") + prot.should_receive(:write_string).with("something happened") + prot.should_receive(:write_field_begin).with('code', Types::I32, 2)#, 1) + prot.should_receive(:write_i32).with(1) + prot.should_receive(:write_field_stop) + prot.should_receive(:write_field_end).twice + + e.write(prot) + end + end + + it "should support the regular initializer for exception structs" do + begin + raise Xception, :message => "something happened", :code => 5 + rescue Thrift::Exception => e + e.message.should == "something happened" + e.code.should == 5 + prot = BaseProtocol.new(mock("trans")) + prot.should_receive(:write_struct_begin).with("SpecNamespace::Xception") + prot.should_receive(:write_struct_end) + prot.should_receive(:write_field_begin).with('message', Types::STRING, 1) + prot.should_receive(:write_string).with("something happened") + prot.should_receive(:write_field_begin).with('code', Types::I32, 2) + prot.should_receive(:write_i32).with(5) + prot.should_receive(:write_field_stop) + prot.should_receive(:write_field_end).twice + + e.write(prot) + end + end + end +end diff --git a/lib/rb/spec/types_spec.rb b/lib/rb/spec/types_spec.rb new file mode 100644 index 00000000..d979cfb1 --- /dev/null +++ b/lib/rb/spec/types_spec.rb @@ -0,0 +1,117 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' +require File.dirname(__FILE__) + '/gen-rb/thrift_spec_types' + +class ThriftTypesSpec < Spec::ExampleGroup + include Thrift + + before(:each) do + Thrift.type_checking = true + end + + after(:each) do + Thrift.type_checking = false + end + + describe "Type checking" do + it "should return the proper name for each type" do + Thrift.type_name(Types::I16).should == "Types::I16" + Thrift.type_name(Types::VOID).should == "Types::VOID" + Thrift.type_name(Types::LIST).should == "Types::LIST" + Thrift.type_name(42).should be_nil + end + + it "should check types properly" do + # lambda { Thrift.check_type(nil, Types::STOP) }.should raise_error(TypeError) + lambda { Thrift.check_type(3, {:type => Types::STOP}, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::VOID}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type(3, {:type => Types::VOID}, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type(true, {:type => Types::BOOL}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type(3, {:type => Types::BOOL}, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type(42, {:type => Types::BYTE}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type(42, {:type => Types::I16}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type(42, {:type => Types::I32}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type(42, {:type => Types::I64}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type(3.14, {:type => Types::I32}, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type(3.14, {:type => Types::DOUBLE}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type(3, {:type => Types::DOUBLE}, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type("3", {:type => Types::STRING}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type(3, {:type => Types::STRING}, :foo) }.should raise_error(TypeError) + hello = SpecNamespace::Hello.new + lambda { Thrift.check_type(hello, {:type => Types::STRUCT, :class => SpecNamespace::Hello}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type("foo", {:type => Types::STRUCT}, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type({:foo => 1}, {:type => Types::MAP}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type([1], {:type => Types::MAP}, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type([1], {:type => Types::LIST}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type({:foo => 1}, {:type => Types::LIST}, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type(Set.new([1,2]), {:type => Types::SET}, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type([1,2], {:type => Types::SET}, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type({:foo => true}, {:type => Types::SET}, :foo) }.should raise_error(TypeError) + end + + it "should error out if nil is passed and skip_types is false" do + lambda { Thrift.check_type(nil, {:type => Types::BOOL}, :foo, false) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::BYTE}, :foo, false) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::I16}, :foo, false) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::I32}, :foo, false) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::I64}, :foo, false) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::DOUBLE}, :foo, false) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::STRING}, :foo, false) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::STRUCT}, :foo, false) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::LIST}, :foo, false) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::SET}, :foo, false) }.should raise_error(TypeError) + lambda { Thrift.check_type(nil, {:type => Types::MAP}, :foo, false) }.should raise_error(TypeError) + end + + it "should check element types on containers" do + field = {:type => Types::LIST, :element => {:type => Types::I32}} + lambda { Thrift.check_type([1, 2], field, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type([1, nil, 2], field, :foo) }.should raise_error(TypeError) + field = {:type => Types::MAP, :key => {:type => Types::I32}, :value => {:type => Types::STRING}} + lambda { Thrift.check_type({1 => "one", 2 => "two"}, field, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type({1 => "one", nil => "nil"}, field, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type({1 => nil, 2 => "two"}, field, :foo) }.should raise_error(TypeError) + field = {:type => Types::SET, :element => {:type => Types::I32}} + lambda { Thrift.check_type(Set.new([1, 2]), field, :foo) }.should_not raise_error(TypeError) + lambda { Thrift.check_type(Set.new([1, nil, 2]), field, :foo) }.should raise_error(TypeError) + lambda { Thrift.check_type(Set.new([1, 2.3, 2]), field, :foo) }.should raise_error(TypeError) + + field = {:type => Types::STRUCT, :class => SpecNamespace::Hello} + lambda { Thrift.check_type(SpecNamespace::BoolStruct, field, :foo) }.should raise_error(TypeError) + end + + it "should give the TypeError a readable message" do + msg = "Expected Types::STRING, received Fixnum for field foo" + lambda { Thrift.check_type(3, {:type => Types::STRING}, :foo) }.should raise_error(TypeError, msg) + msg = "Expected Types::STRING, received Fixnum for field foo.element" + field = {:type => Types::LIST, :element => {:type => Types::STRING}} + lambda { Thrift.check_type([3], field, :foo) }.should raise_error(TypeError, msg) + msg = "Expected Types::I32, received NilClass for field foo.element.key" + field = {:type => Types::LIST, + :element => {:type => Types::MAP, + :key => {:type => Types::I32}, + :value => {:type => Types::I32}}} + lambda { Thrift.check_type([{nil => 3}], field, :foo) }.should raise_error(TypeError, msg) + msg = "Expected Types::I32, received NilClass for field foo.element.value" + lambda { Thrift.check_type([{1 => nil}], field, :foo) }.should raise_error(TypeError, msg) + end + end +end diff --git a/lib/rb/spec/unix_socket_spec.rb b/lib/rb/spec/unix_socket_spec.rb new file mode 100644 index 00000000..df239d7a --- /dev/null +++ b/lib/rb/spec/unix_socket_spec.rb @@ -0,0 +1,108 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.dirname(__FILE__) + '/spec_helper' +require File.dirname(__FILE__) + "/socket_spec_shared" + +class ThriftUNIXSocketSpec < Spec::ExampleGroup + include Thrift + + describe UNIXSocket do + before(:each) do + @path = '/tmp/thrift_spec_socket' + @socket = UNIXSocket.new(@path) + @handle = mock("Handle", :closed? => false) + @handle.stub!(:close) + ::UNIXSocket.stub!(:new).and_return(@handle) + end + + it_should_behave_like "a socket" + + it "should raise a TransportException when it cannot open a socket" do + ::UNIXSocket.should_receive(:new).and_raise(StandardError) + lambda { @socket.open }.should raise_error(Thrift::TransportException) { |e| e.type.should == Thrift::TransportException::NOT_OPEN } + end + + it "should accept an optional timeout" do + ::UNIXSocket.stub!(:new) + UNIXSocket.new(@path, 5).timeout.should == 5 + end + end + + describe UNIXServerSocket do + before(:each) do + @path = '/tmp/thrift_spec_socket' + @socket = UNIXServerSocket.new(@path) + end + + it "should create a handle when calling listen" do + UNIXServer.should_receive(:new).with(@path) + @socket.listen + end + + it "should create a Thrift::UNIXSocket to wrap accepted sockets" do + handle = mock("UNIXServer") + UNIXServer.should_receive(:new).with(@path).and_return(handle) + @socket.listen + sock = mock("sock") + handle.should_receive(:accept).and_return(sock) + trans = mock("UNIXSocket") + UNIXSocket.should_receive(:new).and_return(trans) + trans.should_receive(:handle=).with(sock) + @socket.accept.should == trans + end + + it "should close the handle when closed" do + handle = mock("UNIXServer", :closed? => false) + UNIXServer.should_receive(:new).with(@path).and_return(handle) + @socket.listen + handle.should_receive(:close) + File.stub!(:delete) + @socket.close + end + + it "should delete the socket when closed" do + handle = mock("UNIXServer", :closed? => false) + UNIXServer.should_receive(:new).with(@path).and_return(handle) + @socket.listen + handle.stub!(:close) + File.should_receive(:delete).with(@path) + @socket.close + end + + it "should return nil when accepting if there is no handle" do + @socket.accept.should be_nil + end + + it "should return true for closed? when appropriate" do + handle = mock("UNIXServer", :closed? => false) + UNIXServer.stub!(:new).and_return(handle) + File.stub!(:delete) + @socket.listen + @socket.should_not be_closed + handle.stub!(:close) + @socket.close + @socket.should be_closed + @socket.listen + @socket.should_not be_closed + handle.stub!(:closed?).and_return(true) + @socket.should be_closed + end + end +end diff --git a/lib/st/README b/lib/st/README new file mode 100644 index 00000000..be865b8f --- /dev/null +++ b/lib/st/README @@ -0,0 +1,35 @@ +Thrift SmallTalk Software Library + +Last updated Nov 2007 + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Library +======= + +To get started, just file in thrift.st with Squeak, run thrift -st +on the tutorial .thrift files (and file in the resulting code), and +then: + +calc := CalculatorClient binaryOnHost: 'localhost' port: '9090' +calc addNum1: 10 num2: 15 + +Tested in Squeak 3.7, but should work fine with anything later. diff --git a/lib/st/thrift.st b/lib/st/thrift.st new file mode 100644 index 00000000..6883539f --- /dev/null +++ b/lib/st/thrift.st @@ -0,0 +1,812 @@ +" +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +License); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +AS IS BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +" + +SystemOrganization addCategory: #Thrift! +SystemOrganization addCategory: #'Thrift-Protocol'! +SystemOrganization addCategory: #'Thrift-Transport'! + +Error subclass: #TError + instanceVariableNames: 'code' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift'! + +!TError class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:28'! +signalWithCode: anInteger + self new code: anInteger; signal! ! + +!TError methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:28'! +code + ^ code! ! + +!TError methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:28'! +code: anInteger + code := anInteger! ! + +TError subclass: #TProtocolError + instanceVariableNames: '' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +!TProtocolError class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 18:39'! +badVersion + ^ 4! ! + +!TProtocolError class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 18:39'! +invalidData + ^ 1! ! + +!TProtocolError class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 18:39'! +negativeSize + ^ 2! ! + +!TProtocolError class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 18:40'! +sizeLimit + ^ 3! ! + +!TProtocolError class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 18:40'! +unknown + ^ 0! ! + +TError subclass: #TTransportError + instanceVariableNames: '' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Transport'! + +TTransportError subclass: #TTransportClosedError + instanceVariableNames: '' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Transport'! + +Object subclass: #TClient + instanceVariableNames: 'iprot oprot seqid remoteSeqid' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift'! + +!TClient class methodsFor: 'as yet unclassified' stamp: 'pc 11/7/2007 06:00'! +binaryOnHost: aString port: anInteger + | sock | + sock := TSocket new host: aString; port: anInteger; open; yourself. + ^ self new + inProtocol: (TBinaryProtocol new transport: sock); + yourself! ! + +!TClient methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 23:03'! +inProtocol: aProtocol + iprot := aProtocol. + oprot ifNil: [oprot := aProtocol]! ! + +!TClient methodsFor: 'as yet unclassified' stamp: 'pc 10/26/2007 04:28'! +nextSeqid + ^ seqid + ifNil: [seqid := 0] + ifNotNil: [seqid := seqid + 1]! ! + +!TClient methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 22:51'! +outProtocol: aProtocol + oprot := aProtocol! ! + +!TClient methodsFor: 'as yet unclassified' stamp: 'pc 10/28/2007 15:32'! +validateRemoteMessage: aMsg + remoteSeqid + ifNil: [remoteSeqid := aMsg seqid] + ifNotNil: + [(remoteSeqid + 1) = aMsg seqid ifFalse: + [TProtocolError signal: 'Bad seqid: ', aMsg seqid asString, + '; wanted: ', remoteSeqid asString]. + remoteSeqid := aMsg seqid]! ! + +Object subclass: #TField + instanceVariableNames: 'name type id' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +!TField methodsFor: 'accessing' stamp: 'pc 10/24/2007 20:05'! +id + ^ id ifNil: [0]! ! + +!TField methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:44'! +id: anInteger + id := anInteger! ! + +!TField methodsFor: 'accessing' stamp: 'pc 10/24/2007 20:04'! +name + ^ name ifNil: ['']! ! + +!TField methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:44'! +name: anObject + name := anObject! ! + +!TField methodsFor: 'accessing' stamp: 'pc 10/24/2007 20:05'! +type + ^ type ifNil: [TType stop]! ! + +!TField methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:44'! +type: anInteger + type := anInteger! ! + +Object subclass: #TMessage + instanceVariableNames: 'name seqid type' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +TMessage subclass: #TCallMessage + instanceVariableNames: '' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +!TCallMessage methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 22:53'! +type + ^ 1! ! + +!TMessage methodsFor: 'accessing' stamp: 'pc 10/24/2007 20:05'! +name + ^ name ifNil: ['']! ! + +!TMessage methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:35'! +name: aString + name := aString! ! + +!TMessage methodsFor: 'accessing' stamp: 'pc 10/24/2007 20:05'! +seqid + ^ seqid ifNil: [0]! ! + +!TMessage methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:35'! +seqid: anInteger + seqid := anInteger! ! + +!TMessage methodsFor: 'accessing' stamp: 'pc 10/24/2007 20:06'! +type + ^ type ifNil: [0]! ! + +!TMessage methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:35'! +type: anInteger + type := anInteger! ! + +Object subclass: #TProtocol + instanceVariableNames: 'transport' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +TProtocol subclass: #TBinaryProtocol + instanceVariableNames: '' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 11/1/2007 04:24'! +intFromByteArray: buf + | vals | + vals := Array new: buf size. + 1 to: buf size do: [:n | vals at: n put: ((buf at: n) bitShift: (buf size - n) * 8)]. + ^ vals sum! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 18:46'! +readBool + ^ self readByte isZero not! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 10/25/2007 00:02'! +readByte + ^ (self transport read: 1) first! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 10/28/2007 16:24'! +readDouble + | val | + val := Float new: 2. + ^ val basicAt: 1 put: (self readRawInt: 4); + basicAt: 2 put: (self readRawInt: 4); + yourself! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 20:02'! +readFieldBegin + | field | + field := TField new type: self readByte. + + ^ field type = TType stop + ifTrue: [field] + ifFalse: [field id: self readI16; yourself]! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:15'! +readI16 + ^ self readInt: 2! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:20'! +readI32 + ^ self readInt: 4! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:20'! +readI64 + ^ self readInt: 8! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 11/1/2007 02:35'! +readInt: size + | buf val | + buf := transport read: size. + val := self intFromByteArray: buf. + ^ buf first > 16r7F + ifTrue: [self unsignedInt: val size: size] + ifFalse: [val]! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:57'! +readListBegin + ^ TList new + elemType: self readByte; + size: self readI32! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:58'! +readMapBegin + ^ TMap new + keyType: self readByte; + valueType: self readByte; + size: self readI32! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 11/1/2007 04:22'! +readMessageBegin + | version | + version := self readI32. + + (version bitAnd: self versionMask) = self version1 + ifFalse: [TProtocolError signalWithCode: TProtocolError badVersion]. + + ^ TMessage new + type: (version bitAnd: 16r000000FF); + name: self readString; + seqid: self readI32! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 10/28/2007 16:24'! +readRawInt: size + ^ self intFromByteArray: (transport read: size)! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 11/1/2007 00:59'! +readSetBegin + "element type, size" + ^ TSet new + elemType: self readByte; + size: self readI32! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 02/07/2009 19:00'! +readString +readString + | sz | + sz := self readI32. + ^ sz > 0 ifTrue: [(transport read: sz) asString] ifFalse: ['']! ! + +!TBinaryProtocol methodsFor: 'reading' stamp: 'pc 11/1/2007 04:22'! +unsignedInt: val size: size + ^ 0 - ((val - 1) bitXor: ((2 raisedTo: (size * 8)) - 1))! ! + +!TBinaryProtocol methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 22:13'! +version1 + ^ 16r80010000 ! ! + +!TBinaryProtocol methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 18:01'! +versionMask + ^ 16rFFFF0000! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 18:35'! +write: aString + transport write: aString! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:23'! +writeBool: bool + bool ifTrue: [self writeByte: 1] + ifFalse: [self writeByte: 0]! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 10/26/2007 09:31'! +writeByte: aNumber + aNumber > 16rFF ifTrue: [TError signal: 'writeByte too big']. + transport write: (Array with: aNumber)! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 10/28/2007 16:16'! +writeDouble: aDouble + self writeI32: (aDouble basicAt: 1); + writeI32: (aDouble basicAt: 2)! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:56'! +writeField: aField + self writeByte: aField type; + writeI16: aField id! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 10/25/2007 00:01'! +writeFieldBegin: aField + self writeByte: aField type. + self writeI16: aField id! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 18:04'! +writeFieldStop + self writeByte: TType stop! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 11/1/2007 02:06'! +writeI16: i16 + self writeInt: i16 size: 2! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 11/1/2007 02:06'! +writeI32: i32 + self writeInt: i32 size: 4! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 11/1/2007 02:06'! +writeI64: i64 + self writeInt: i64 size: 8! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 11/1/2007 04:23'! +writeInt: val size: size + 1 to: size do: [:n | self writeByte: ((val bitShift: (size negated + n) * 8) bitAnd: 16rFF)]! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 11/1/2007 00:48'! +writeListBegin: aList + self writeByte: aList elemType; writeI32: aList size! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:55'! +writeMapBegin: aMap + self writeByte: aMap keyType; + writeByte: aMap valueType; + writeI32: aMap size! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 20:36'! +writeMessageBegin: msg + self writeI32: (self version1 bitOr: msg type); + writeString: msg name; + writeI32: msg seqid! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 11/1/2007 00:56'! +writeSetBegin: aSet + self writeByte: aSet elemType; writeI32: aSet size! ! + +!TBinaryProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 18:35'! +writeString: aString + self writeI32: aString size; + write: aString! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readBool! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readByte! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readDouble! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readFieldBegin! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readFieldEnd! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readI16! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readI32! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readI64! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readListBegin! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readListEnd! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readMapBegin! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readMapEnd! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:39'! +readMessageBegin! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:39'! +readMessageEnd! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readSetBegin! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readSetEnd! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/25/2007 16:10'! +readSimpleType: aType + aType = TType bool ifTrue: [^ self readBool]. + aType = TType byte ifTrue: [^ self readByte]. + aType = TType double ifTrue: [^ self readDouble]. + aType = TType i16 ifTrue: [^ self readI16]. + aType = TType i32 ifTrue: [^ self readI32]. + aType = TType i64 ifTrue: [^ self readI64]. + aType = TType list ifTrue: [^ self readBool].! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readString! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readStructBegin + ! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/24/2007 19:40'! +readStructEnd! ! + +!TProtocol methodsFor: 'reading' stamp: 'pc 10/26/2007 21:34'! +skip: aType + aType = TType stop ifTrue: [^ self]. + aType = TType bool ifTrue: [^ self readBool]. + aType = TType byte ifTrue: [^ self readByte]. + aType = TType i16 ifTrue: [^ self readI16]. + aType = TType i32 ifTrue: [^ self readI32]. + aType = TType i64 ifTrue: [^ self readI64]. + aType = TType string ifTrue: [^ self readString]. + aType = TType double ifTrue: [^ self readDouble]. + aType = TType struct ifTrue: + [| field | + self readStructBegin. + [(field := self readFieldBegin) type = TType stop] whileFalse: + [self skip: field type. self readFieldEnd]. + ^ self readStructEnd]. + aType = TType map ifTrue: + [| map | + map := self readMapBegin. + map size timesRepeat: [self skip: map keyType. self skip: map valueType]. + ^ self readMapEnd]. + aType = TType list ifTrue: + [| list | + list := self readListBegin. + list size timesRepeat: [self skip: list elemType]. + ^ self readListEnd]. + aType = TType set ifTrue: + [| set | + set := self readSetBegin. + set size timesRepeat: [self skip: set elemType]. + ^ self readSetEnd]. + + self error: 'Unknown type'! ! + +!TProtocol methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 23:02'! +transport + ^ transport! ! + +!TProtocol methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:28'! +transport: aTransport + transport := aTransport! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeBool: aBool! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeByte: aByte! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:38'! +writeDouble: aFloat! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:38'! +writeFieldBegin: aField! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeFieldEnd! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeFieldStop! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeI16: i16! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeI32: i32! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeI64: i64! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:39'! +writeListBegin: aList! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeListEnd! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:39'! +writeMapBegin: aMap! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeMapEnd! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:36'! +writeMessageBegin! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:36'! +writeMessageEnd! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:39'! +writeSetBegin: aSet! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeSetEnd! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:38'! +writeString: aString! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:38'! +writeStructBegin: aStruct! ! + +!TProtocol methodsFor: 'writing' stamp: 'pc 10/24/2007 19:37'! +writeStructEnd! ! + +Object subclass: #TResult + instanceVariableNames: 'success oprot iprot exception' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift'! + +!TResult methodsFor: 'as yet unclassified' stamp: 'pc 10/26/2007 21:35'! +exception + ^ exception! ! + +!TResult methodsFor: 'as yet unclassified' stamp: 'pc 10/26/2007 21:35'! +exception: anError + exception := anError! ! + +!TResult methodsFor: 'as yet unclassified' stamp: 'pc 10/26/2007 14:43'! +success + ^ success! ! + +!TResult methodsFor: 'as yet unclassified' stamp: 'pc 10/26/2007 14:43'! +success: anObject + success := anObject! ! + +Object subclass: #TSizedObject + instanceVariableNames: 'size' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +TSizedObject subclass: #TList + instanceVariableNames: 'elemType' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +!TList methodsFor: 'accessing' stamp: 'pc 10/24/2007 20:04'! +elemType + ^ elemType ifNil: [TType stop]! ! + +!TList methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:42'! +elemType: anInteger + elemType := anInteger! ! + +TList subclass: #TSet + instanceVariableNames: '' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +TSizedObject subclass: #TMap + instanceVariableNames: 'keyType valueType' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +!TMap methodsFor: 'accessing' stamp: 'pc 10/24/2007 20:04'! +keyType + ^ keyType ifNil: [TType stop]! ! + +!TMap methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:45'! +keyType: anInteger + keyType := anInteger! ! + +!TMap methodsFor: 'accessing' stamp: 'pc 10/24/2007 20:04'! +valueType + ^ valueType ifNil: [TType stop]! ! + +!TMap methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:45'! +valueType: anInteger + valueType := anInteger! ! + +!TSizedObject methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 20:03'! +size + ^ size ifNil: [0]! ! + +!TSizedObject methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 20:06'! +size: anInteger + size := anInteger! ! + +Object subclass: #TSocket + instanceVariableNames: 'host port stream' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Transport'! + +!TSocket methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 22:34'! +close + self isOpen ifTrue: [stream close]! ! + +!TSocket methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 22:23'! +connect + ^ (self socketStream openConnectionToHost: + (NetNameResolver addressForName: host) port: port) + timeout: 180; + binary; + yourself! ! + +!TSocket methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 20:35'! +flush + stream flush! ! + +!TSocket methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:08'! +host: aString + host := aString! ! + +!TSocket methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 20:34'! +isOpen + ^ stream isNil not + and: [stream socket isConnected] + and: [stream socket isOtherEndClosed not]! ! + +!TSocket methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 22:22'! +open + stream := self connect! ! + +!TSocket methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:09'! +port: anInteger + port := anInteger! ! + +!TSocket methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:17'! +read: size + | data | + [data := stream next: size. + data isEmpty ifTrue: [TTransportError signal: 'Could not read ', size asString, ' bytes']. + ^ data] + on: ConnectionClosed + do: [TTransportClosedError signal]! ! + +!TSocket methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 22:18'! +socketStream + ^ Smalltalk at: #FastSocketStream ifAbsent: [SocketStream] ! ! + +!TSocket methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 22:17'! +write: aCollection + [stream nextPutAll: aCollection] + on: ConnectionClosed + do: [TTransportClosedError signal]! ! + +Object subclass: #TStruct + instanceVariableNames: 'name' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Protocol'! + +!TStruct methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:47'! +name + ^ name! ! + +!TStruct methodsFor: 'accessing' stamp: 'pc 10/24/2007 19:47'! +name: aString + name := aString! ! + +Object subclass: #TTransport + instanceVariableNames: '' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift-Transport'! + +!TTransport methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:18'! +close + self subclassResponsibility! ! + +!TTransport methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:22'! +flush + self subclassResponsibility! ! + +!TTransport methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:18'! +isOpen + self subclassResponsibility! ! + +!TTransport methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:18'! +open + self subclassResponsibility! ! + +!TTransport methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:18'! +read: anInteger + self subclassResponsibility! ! + +!TTransport methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:22'! +readAll: anInteger + ^ String streamContents: [:str | + [str size < anInteger] whileTrue: + [str nextPutAll: (self read: anInteger - str size)]]! ! + +!TTransport methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:22'! +write: aString + self subclassResponsibility! ! + +Object subclass: #TType + instanceVariableNames: '' + classVariableNames: '' + poolDictionaries: '' + category: 'Thrift'! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:03'! +bool + ^ 2! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:03'! +byte + ^ 3! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/25/2007 15:55'! +codeOf: aTypeName + self typeMap do: [:each | each first = aTypeName ifTrue: [^ each second]]. + ^ nil! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:03'! +double + ^ 4! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:04'! +i16 + ^ 6! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:04'! +i32 + ^ 8! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:04'! +i64 + ^ 10! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:04'! +list + ^ 15! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:04'! +map + ^ 13! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/25/2007 15:56'! +nameOf: aTypeCode + self typeMap do: [:each | each second = aTypeCode ifTrue: [^ each first]]. + ^ nil! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:04'! +set + ^ 14! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:03'! +stop + ^ 0! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:04'! +string + ^ 11! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:04'! +struct + ^ 12! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/25/2007 15:51'! +typeMap + ^ #((bool 2) (byte 3) (double 4) (i16 6) (i32 8) (i64 10) (list 15) + (map 13) (set 15) (stop 0) (string 11) (struct 12) (void 1))! ! + +!TType class methodsFor: 'as yet unclassified' stamp: 'pc 10/24/2007 17:03'! +void + ^ 1! ! diff --git a/print_version.sh b/print_version.sh new file mode 100755 index 00000000..f36fa5ea --- /dev/null +++ b/print_version.sh @@ -0,0 +1,47 @@ +#!/bin/sh + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +cd "`dirname "$0"`" + +# Computing both the version and the revision on every invocation is wasteful, +# but it is cheap and avoids the use of nonportable shell functions. + +VERSION=`sed -ne 's/^AC_INIT(\[thrift\], \[\(.*\)\])$/\1/p' configure.ac` + +if test -d .svn ; then + REVISION="r`svnversion`" +elif test -d .git ; then + SHA1=`git rev-list --max-count=1 --grep='^git-svn-id:' HEAD` + REVISION=`git cat-file commit $SHA1 | sed -ne 's/^git-svn-id:[^@]*@\([0-9][0-9]*\).*/r\1/p'` + OFFSET=`git rev-list ^$SHA1 HEAD | wc -l | tr -d ' '` + if test $OFFSET != 0 ; then + REVISION="$REVISION-$OFFSET-`git rev-parse --verify HEAD | cut -c 1-7`" + fi +else + REVISION="exported" +fi + +case "$1" in + -v) echo $VERSION ;; + -r) echo $REVISION ;; + -a) echo "$VERSION-$REVISION" ;; + *) echo "Usage: $0 -v|-r|-a"; exit 1;; +esac diff --git a/test/AllProtocolTests.cpp b/test/AllProtocolTests.cpp new file mode 100644 index 00000000..db29cccf --- /dev/null +++ b/test/AllProtocolTests.cpp @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +#include +#include +#include +#include "AllProtocolTests.tcc" + +using namespace apache::thrift; +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; + +char errorMessage[ERR_LEN]; + +int main(int argc, char** argv) { + try { + testProtocol("TBinaryProtocol"); + testProtocol("TCompactProtocol"); + } catch (TException e) { + printf("%s\n", e.what()); + return 1; + } + return 0; +} diff --git a/test/AllProtocolTests.tcc b/test/AllProtocolTests.tcc new file mode 100644 index 00000000..a5a31156 --- /dev/null +++ b/test/AllProtocolTests.tcc @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TEST_GENERICPROTOCOLTEST_TCC_ +#define _THRIFT_TEST_GENERICPROTOCOLTEST_TCC_ 1 + +#include + +#include +#include +#include + +#include "GenericHelpers.h" + +using boost::shared_ptr; +using namespace apache::thrift; +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; + +#define ERR_LEN 512 +extern char errorMessage[ERR_LEN]; + +template +void testNaked(Val val) { + shared_ptr transport(new TMemoryBuffer()); + shared_ptr protocol(new TProto(transport)); + + GenericIO::write(protocol, val); + Val out; + GenericIO::read(protocol, out); + if (out != val) { + snprintf(errorMessage, ERR_LEN, "Invalid naked test (type: %s)", ClassNames::getName()); + throw TException(errorMessage); + } +} + +template +void testField(const Val val) { + shared_ptr transport(new TMemoryBuffer()); + shared_ptr protocol(new TProto(transport)); + + protocol->writeStructBegin("test_struct"); + protocol->writeFieldBegin("test_field", type, (int16_t)15); + + GenericIO::write(protocol, val); + + protocol->writeFieldEnd(); + protocol->writeStructEnd(); + + std::string name; + TType fieldType; + int16_t fieldId; + + protocol->readStructBegin(name); + protocol->readFieldBegin(name, fieldType, fieldId); + + if (fieldId != 15) { + snprintf(errorMessage, ERR_LEN, "Invalid ID (type: %s)", typeid(val).name()); + throw TException(errorMessage); + } + if (fieldType != type) { + snprintf(errorMessage, ERR_LEN, "Invalid Field Type (type: %s)", typeid(val).name()); + throw TException(errorMessage); + } + + Val out; + GenericIO::read(protocol, out); + + if (out != val) { + snprintf(errorMessage, ERR_LEN, "Invalid value read (type: %s)", typeid(val).name()); + throw TException(errorMessage); + } + + protocol->readFieldEnd(); + protocol->readStructEnd(); +} + +template +void testMessage() { + struct TMessage { + const char* name; + TMessageType type; + int32_t seqid; + } messages[4] = { + {"short message name", T_CALL, 0}, + {"1", T_REPLY, 12345}, + {"loooooooooooooooooooooooooooooooooong", T_EXCEPTION, 1 << 16}, + {"Janky", T_CALL, 0} + }; + + for (int i = 0; i < 4; i++) { + shared_ptr transport(new TMemoryBuffer()); + shared_ptr protocol(new TProto(transport)); + + protocol->writeMessageBegin(messages[i].name, + messages[i].type, + messages[i].seqid); + protocol->writeMessageEnd(); + + std::string name; + TMessageType type; + int32_t seqid; + + protocol->readMessageBegin(name, type, seqid); + if (name != messages[i].name || + type != messages[i].type || + seqid != messages[i].seqid) { + throw TException("readMessageBegin failed."); + } + } +} + +template +void testProtocol(const char* protoname) { + try { + testNaked((int8_t)123); + + for (int32_t i = 0; i < 128; i++) { + testField((int8_t)i); + testField((int8_t)-i); + } + + testNaked((int16_t)0); + testNaked((int16_t)1); + testNaked((int16_t)15000); + testNaked((int16_t)0x7fff); + testNaked((int16_t)-1); + testNaked((int16_t)-15000); + testNaked((int16_t)-0x7fff); + testNaked(std::numeric_limits::min()); + testNaked(std::numeric_limits::max()); + + testField((int16_t)0); + testField((int16_t)1); + testField((int16_t)7); + testField((int16_t)150); + testField((int16_t)15000); + testField((int16_t)0x7fff); + testField((int16_t)-1); + testField((int16_t)-7); + testField((int16_t)-150); + testField((int16_t)-15000); + testField((int16_t)-0x7fff); + + testNaked(0); + testNaked(1); + testNaked(15000); + testNaked(0xffff); + testNaked(-1); + testNaked(-15000); + testNaked(-0xffff); + testNaked(std::numeric_limits::min()); + testNaked(std::numeric_limits::max()); + + testField(0); + testField(1); + testField(7); + testField(150); + testField(15000); + testField(31337); + testField(0xffff); + testField(0xffffff); + testField(-1); + testField(-7); + testField(-150); + testField(-15000); + testField(-0xffff); + testField(-0xffffff); + testNaked(std::numeric_limits::min()); + testNaked(std::numeric_limits::max()); + testNaked(std::numeric_limits::min() + 10); + testNaked(std::numeric_limits::max() - 16); + testNaked(std::numeric_limits::min()); + testNaked(std::numeric_limits::max()); + + + testNaked(0); + for (int64_t i = 0; i < 62; i++) { + testNaked(1L << i); + testNaked(-(1L << i)); + } + + testField(0); + for (int i = 0; i < 62; i++) { + testField(1L << i); + testField(-(1L << i)); + } + + testNaked(123.456); + + testNaked(""); + testNaked("short"); + testNaked("borderlinetiny"); + testNaked("a bit longer than the smallest possible"); + testNaked("\x1\x2\x3\x4\x5\x6\x7\x8\x9\xA"); //kinda binary test + + testField(""); + testField("short"); + testField("borderlinetiny"); + testField("a bit longer than the smallest possible"); + + testMessage(); + + printf("%s => OK\n", protoname); + } catch (TException e) { + snprintf(errorMessage, ERR_LEN, "%s => Test FAILED: %s", protoname, e.what()); + throw TException(errorMessage); + } +} + +#endif diff --git a/test/AnnotationTest.thrift b/test/AnnotationTest.thrift new file mode 100644 index 00000000..dcc41b0b --- /dev/null +++ b/test/AnnotationTest.thrift @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +typedef list ( cpp.template = "std::list" ) int_linked_list + +struct foo { + 1: i32 bar; + 2: i32 baz; + 3: i32 qux; + 4: i32 bop; +} ( + cpp.type = "DenseFoo", + python.type = "DenseFoo", + java.final = "", +) diff --git a/test/Benchmark.cpp b/test/Benchmark.cpp new file mode 100644 index 00000000..d315fca7 --- /dev/null +++ b/test/Benchmark.cpp @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include "gen-cpp/DebugProtoTest_types.h" +#include +#include "../lib/cpp/src/protocol/TDebugProtocol.h" +#include + +class Timer { +public: + timeval vStart; + + Timer() { + gettimeofday(&vStart, 0); + } + void start() { + gettimeofday(&vStart, 0); + } + + double frame() { + timeval vEnd; + gettimeofday(&vEnd, 0); + double dstart = vStart.tv_sec + ((double)vStart.tv_usec / 1000000.0); + double dend = vEnd.tv_sec + ((double)vEnd.tv_usec / 1000000.0); + return dend - dstart; + } + +}; + +int main() { + using namespace std; + using namespace thrift::test::debug; + using namespace apache::thrift::transport; + using namespace apache::thrift::protocol; + using namespace boost; + + OneOfEach ooe; + ooe.im_true = true; + ooe.im_false = false; + ooe.a_bite = 0xd6; + ooe.integer16 = 27000; + ooe.integer32 = 1<<24; + ooe.integer64 = (uint64_t)6000 * 1000 * 1000; + ooe.double_precision = M_PI; + ooe.some_characters = "JSON THIS! \"\1"; + ooe.zomg_unicode = "\xd7\n\a\t"; + ooe.base64 = "\1\2\3\255"; + + shared_ptr buf(new TMemoryBuffer()); + + int num = 1000000; + + { + Timer timer; + + for (int i = 0; i < num; i ++) { + buf->resetBuffer(); + TBinaryProtocol prot(buf); + ooe.write(&prot); + } + cout << "Write: " << num / (1000 * timer.frame()) << " kHz" << endl; + } + + uint8_t* data; + uint32_t datasize; + + buf->getBuffer(&data, &datasize); + + { + + Timer timer; + + for (int i = 0; i < num; i ++) { + OneOfEach ooe2; + shared_ptr buf2(new TMemoryBuffer(data, datasize)); + //buf2->resetBuffer(data, datasize); + TBinaryProtocol prot(buf2); + ooe2.read(&prot); + + //cout << apache::thrift::ThriftDebugString(ooe2) << endl << endl; + } + cout << " Read: " << num / (1000 * timer.frame()) << " kHz" << endl; + } + + + return 0; +} diff --git a/test/BrokenConstants.thrift b/test/BrokenConstants.thrift new file mode 100644 index 00000000..c5aab4ab --- /dev/null +++ b/test/BrokenConstants.thrift @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +const i64 myint = 68719476736 +const i64 broken = 9876543210987654321 // A little over 2^63 + +enum foo { + bar = 68719476736 +} diff --git a/test/ConstantsDemo.thrift b/test/ConstantsDemo.thrift new file mode 100644 index 00000000..7e97f02c --- /dev/null +++ b/test/ConstantsDemo.thrift @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +namespace cpp yozone + +struct thing { + 1: i32 hello, + 2: i32 goodbye +} + +enum enumconstants { + ONE = 1, + TWO = 2 +} + +struct thing2 { + 1: enumconstants val = TWO +} + +typedef i32 myIntType +const myIntType myInt = 3 + +const map GEN_ENUM_NAMES = {ONE : "HOWDY", TWO: PARTNER} + +const i32 hex_const = 0x0001F + +const i32 GEN_ME = -3523553 +const double GEn_DUB = 325.532 +const double GEn_DU = 085.2355 +const string GEN_STRING = "asldkjasfd" + +const map GEN_MAP = { 35532 : 233, 43523 : 853 } +const list GEN_LIST = [ 235235, 23598352, 3253523 ] + +const map> GEN_MAPMAP = { 235 : { 532 : 53255, 235:235}} + +const map GEN_MAP2 = { "hello" : 233, "lkj98d" : 853, 'lkjsdf' : 098325 } + +const thing GEN_THING = { 'hello' : 325, 'goodbye' : 325352 } + +const map GEN_WHAT = { 35 : { 'hello' : 325, 'goodbye' : 325352 } } + +const set GEN_SET = [ 235, 235, 53235 ] + +exception Blah { + 1: i32 bing } + +exception Gak {} + +service yowza { + void blingity(), + i32 blangity() throws (1: Blah hoot ) +} diff --git a/test/DebugProtoTest.cpp b/test/DebugProtoTest.cpp new file mode 100644 index 00000000..8c9aabec --- /dev/null +++ b/test/DebugProtoTest.cpp @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include "gen-cpp/DebugProtoTest_types.h" +#include "../lib/cpp/src/protocol/TDebugProtocol.h" + +int main() { + using std::cout; + using std::endl; + using namespace thrift::test::debug; + + + OneOfEach ooe; + ooe.im_true = true; + ooe.im_false = false; + ooe.a_bite = 0xd6; + ooe.integer16 = 27000; + ooe.integer32 = 1<<24; + ooe.integer64 = (uint64_t)6000 * 1000 * 1000; + ooe.double_precision = M_PI; + ooe.some_characters = "Debug THIS!"; + ooe.zomg_unicode = "\xd7\n\a\t"; + + cout << apache::thrift::ThriftDebugString(ooe) << endl << endl; + + + Nesting n; + n.my_ooe = ooe; + n.my_ooe.integer16 = 16; + n.my_ooe.integer32 = 32; + n.my_ooe.integer64 = 64; + n.my_ooe.double_precision = (std::sqrt(5)+1)/2; + n.my_ooe.some_characters = ":R (me going \"rrrr\")"; + n.my_ooe.zomg_unicode = "\xd3\x80\xe2\x85\xae\xce\x9d\x20" + "\xd0\x9d\xce\xbf\xe2\x85\xbf\xd0\xbe\xc9\xa1\xd0\xb3\xd0\xb0\xcf\x81\xe2\x84\x8e" + "\x20\xce\x91\x74\x74\xce\xb1\xe2\x85\xbd\xce\xba\xc7\x83\xe2\x80\xbc"; + n.my_bonk.type = 31337; + n.my_bonk.message = "I am a bonk... xor!"; + + cout << apache::thrift::ThriftDebugString(n) << endl << endl; + + + HolyMoley hm; + + hm.big.push_back(ooe); + hm.big.push_back(n.my_ooe); + hm.big[0].a_bite = 0x22; + hm.big[1].a_bite = 0x33; + + std::vector stage1; + stage1.push_back("and a one"); + stage1.push_back("and a two"); + hm.contain.insert(stage1); + stage1.clear(); + stage1.push_back("then a one, two"); + stage1.push_back("three!"); + stage1.push_back("FOUR!!"); + hm.contain.insert(stage1); + stage1.clear(); + hm.contain.insert(stage1); + + std::vector stage2; + hm.bonks["nothing"] = stage2; + stage2.resize(stage2.size()+1); + stage2.back().type = 1; + stage2.back().message = "Wait."; + stage2.resize(stage2.size()+1); + stage2.back().type = 2; + stage2.back().message = "What?"; + hm.bonks["something"] = stage2; + stage2.clear(); + stage2.resize(stage2.size()+1); + stage2.back().type = 3; + stage2.back().message = "quoth"; + stage2.resize(stage2.size()+1); + stage2.back().type = 4; + stage2.back().message = "the raven"; + stage2.resize(stage2.size()+1); + stage2.back().type = 5; + stage2.back().message = "nevermore"; + hm.bonks["poe"] = stage2; + + cout << apache::thrift::ThriftDebugString(hm) << endl << endl; + + + return 0; +} diff --git a/test/DebugProtoTest.thrift b/test/DebugProtoTest.thrift new file mode 100644 index 00000000..d3d25802 --- /dev/null +++ b/test/DebugProtoTest.thrift @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +namespace cpp thrift.test.debug +namespace java thrift.test + +struct Doubles { + 1: double nan, + 2: double inf, + 3: double neginf, + 4: double repeating, + 5: double big, + 6: double small, + 7: double zero, + 8: double negzero, +} + +struct OneOfEach { + 1: bool im_true, + 2: bool im_false, + 3: byte a_bite = 200, + 4: i16 integer16 = 33000, + 5: i32 integer32, + 6: i64 integer64 = 10000000000, + 7: double double_precision, + 8: string some_characters, + 9: string zomg_unicode, + 10: bool what_who, + 11: binary base64, + 12: list byte_list = [1, 2, 3], + 13: list i16_list = [1,2,3], + 14: list i64_list = [1,2,3] +} + +struct Bonk { + 1: i32 type, + 2: string message, +} + +struct Nesting { + 1: Bonk my_bonk, + 2: OneOfEach my_ooe, +} + +struct HolyMoley { + 1: list big, + 2: set> contain, + 3: map> bonks, +} + +struct Backwards { + 2: i32 first_tag2, + 1: i32 second_tag1, +} + +struct Empty { +} + +struct Wrapper { + 1: Empty foo +} + +struct RandomStuff { + 1: i32 a, + 2: i32 b, + 3: i32 c, + 4: i32 d, + 5: list myintlist, + 6: map maps, + 7: i64 bigint, + 8: double triple, +} + +struct Base64 { + 1: i32 a, + 2: binary b1, + 3: binary b2, + 4: binary b3, + 5: binary b4, + 6: binary b5, + 7: binary b6, +} + +struct CompactProtoTestStruct { + // primitive fields + 1: byte a_byte; + 2: i16 a_i16; + 3: i32 a_i32; + 4: i64 a_i64; + 5: double a_double; + 6: string a_string; + 7: binary a_binary; + 8: bool true_field; + 9: bool false_field; + 10: Empty empty_struct_field; + + // primitives in lists + 11: list byte_list; + 12: list i16_list; + 13: list i32_list; + 14: list i64_list; + 15: list double_list; + 16: list string_list; + 17: list binary_list; + 18: list boolean_list; + 19: list struct_list; + + // primitives in sets + 20: set byte_set; + 21: set i16_set; + 22: set i32_set; + 23: set i64_set; + 24: set double_set; + 25: set string_set; + 26: set binary_set; + 27: set boolean_set; + 28: set struct_set; + + // maps + // primitives as keys + 29: map byte_byte_map; + 30: map i16_byte_map; + 31: map i32_byte_map; + 32: map i64_byte_map; + 33: map double_byte_map; + 34: map string_byte_map; + 35: map binary_byte_map; + 36: map boolean_byte_map; + // primitives as values + 37: map byte_i16_map; + 38: map byte_i32_map; + 39: map byte_i64_map; + 40: map byte_double_map; + 41: map byte_string_map; + 42: map byte_binary_map; + 43: map byte_boolean_map; + // collections as keys + 44: map, byte> list_byte_map; + 45: map, byte> set_byte_map; + 46: map, byte> map_byte_map; + // collections as values + 47: map> byte_map_map; + 48: map> byte_set_map; + 49: map> byte_list_map; +} + + +const CompactProtoTestStruct COMPACT_TEST = { + 'a_byte' : 127, + 'a_i16' : 32000, + 'a_i32' : 1000000000, + 'a_i64' : 0xffffffffff, + 'a_double' : 5.6789, + 'a_string' : "my string", +//'a_binary,' + 'true_field' : 1, + 'false_field' : 0, + 'empty_struct_field' : {}, + 'byte_list' : [-127, -1, 0, 1, 127], + 'i16_list' : [-1, 0, 1, 0x7fff], + 'i32_list' : [-1, 0, 0xff, 0xffff, 0xffffff, 0x7fffffff], + 'i64_list' : [-1, 0, 0xff, 0xffff, 0xffffff, 0xffffffff, 0xffffffffff, 0xffffffffffff, 0xffffffffffffff, 0x7fffffffffffffff], + 'double_list' : [0.1, 0.2, 0.3], + 'string_list' : ["first", "second", "third"], +//'binary_list,' + 'boolean_list' : [1, 1, 1, 0, 0, 0], + 'struct_list' : [{}, {}], + 'byte_set' : [-127, -1, 0, 1, 127], + 'i16_set' : [-1, 0, 1, 0x7fff], + 'i32_set' : [1, 2, 3], + 'i64_set' : [-1, 0, 0xff, 0xffff, 0xffffff, 0xffffffff, 0xffffffffff, 0xffffffffffff, 0xffffffffffffff, 0x7fffffffffffffff], + 'double_set' : [0.1, 0.2, 0.3], + 'string_set' : ["first", "second", "third"], +//'binary_set,' + 'boolean_set' : [1, 0], + 'struct_set' : [{}], + 'byte_byte_map' : {1 : 2}, + 'i16_byte_map' : {1 : 1, -1 : 1, 0x7fff : 1}, + 'i32_byte_map' : {1 : 1, -1 : 1, 0x7fffffff : 1}, + 'i64_byte_map' : {0 : 1, 1 : 1, -1 : 1, 0x7fffffffffffffff : 1}, + 'double_byte_map' : {-1.1 : 1, 1.1 : 1}, + 'string_byte_map' : {"first" : 1, "second" : 2, "third" : 3, "" : 0}, +//'binary_byte_map,' + 'boolean_byte_map' : {1 : 1, 0 : 0}, + 'byte_i16_map' : {1 : 1, 2 : -1, 3 : 0x7fff}, + 'byte_i32_map' : {1 : 1, 2 : -1, 3 : 0x7fffffff}, + 'byte_i64_map' : {1 : 1, 2 : -1, 3 : 0x7fffffffffffffff}, + 'byte_double_map' : {1 : 0.1, 2 : -0.1, 3 : 1000000.1}, + 'byte_string_map' : {1 : "", 2 : "blah", 3 : "loooooooooooooong string"}, +//'byte_binary_map,' + 'byte_boolean_map' : {1 : 1, 2 : 0}, + 'list_byte_map' : {[1, 2, 3] : 1, [0, 1] : 2, [] : 0}, + 'set_byte_map' : {[1, 2, 3] : 1, [0, 1] : 2, [] : 0}, + 'map_byte_map' : {{1 : 1} : 1, {2 : 2} : 2, {} : 0}, + 'byte_map_map' : {0 : {}, 1 : {1 : 1}, 2 : {1 : 1, 2 : 2}}, + 'byte_set_map' : {0 : [], 1 : [1], 2 : [1, 2]}, + 'byte_list_map' : {0 : [], 1 : [1], 2 : [1, 2]}, +} + + + +service Srv { + i32 Janky(1: i32 arg); + + // return type only methods + + void voidMethod(); + i32 primitiveMethod(); + CompactProtoTestStruct structMethod(); +} + +service Inherited extends Srv { + i32 identity(1: i32 arg) +} + +service EmptyService {} + +// The only purpose of this thing is to increase the size of the generated code +// so that ZlibTest has more highly compressible data to play with. +struct BlowUp { + 1: map,set>> b1; + 2: map,set>> b2; + 3: map,set>> b3; + 4: map,set>> b4; +} + + +struct ReverseOrderStruct { + 4: string first; + 3: i16 second; + 2: i32 third; + 1: i64 fourth; +} + +service ReverseOrderService { + void myMethod(4: string first, 3: i16 second, 2: i32 third, 1: i64 fourth); +} \ No newline at end of file diff --git a/test/DebugProtoTest_extras.cpp b/test/DebugProtoTest_extras.cpp new file mode 100644 index 00000000..e68c544b --- /dev/null +++ b/test/DebugProtoTest_extras.cpp @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// Extra functions required for DebugProtoTest_types to work + +#include "gen-cpp/DebugProtoTest_types.h" + + +namespace thrift { namespace test { namespace debug { + +bool Empty::operator<(Empty const& other) const { + // It is empty, so all are equal. + return false; +} + +}}} diff --git a/test/DenseLinkingTest.thrift b/test/DenseLinkingTest.thrift new file mode 100644 index 00000000..cf61496a --- /dev/null +++ b/test/DenseLinkingTest.thrift @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* +../compiler/cpp/thrift -gen cpp:dense DebugProtoTest.thrift +../compiler/cpp/thrift -gen cpp:dense DenseLinkingTest.thrift +g++ -Wall -g -I../lib/cpp/src -I/usr/local/include/boost-1_33_1 \ + DebugProtoTest.cpp gen-cpp/DebugProtoTest_types.cpp \ + gen-cpp/DenseLinkingTest_types.cpp \ + ../lib/cpp/.libs/libthrift.a -o DebugProtoTest +./DebugProtoTest +*/ + +/* +The idea of this test is that everything is structurally identical to DebugProtoTest. +If I messed up the naming of the reflection local typespecs, +then compiling this should give errors because of doubly defined symbols. +*/ + +namespace cpp thrift.test + +struct OneOfEachZZ { + 1: bool im_true, + 2: bool im_false, + 3: byte a_bite, + 4: i16 integer16, + 5: i32 integer32, + 6: i64 integer64, + 7: double double_precision, + 8: string some_characters, + 9: string zomg_unicode, + 10: bool what_who, +} + +struct BonkZZ { + 1: i32 type, + 2: string message, +} + +struct NestingZZ { + 1: BonkZZ my_bonk, + 2: OneOfEachZZ my_ooe, +} + +struct HolyMoleyZZ { + 1: list big, + 2: set> contain, + 3: map> bonks, +} + +struct BackwardsZZ { + 2: i32 first_tag2, + 1: i32 second_tag1, +} + +struct EmptyZZ { +} + +struct WrapperZZ { + 1: EmptyZZ foo +} + +struct RandomStuffZZ { + 1: i32 a, + 2: i32 b, + 3: i32 c, + 4: i32 d, + 5: list myintlist, + 6: map maps, + 7: i64 bigint, + 8: double triple, +} + +service Srv { + i32 Janky(1: i32 arg) +} diff --git a/test/DenseProtoTest.cpp b/test/DenseProtoTest.cpp new file mode 100644 index 00000000..99f78655 --- /dev/null +++ b/test/DenseProtoTest.cpp @@ -0,0 +1,384 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* +../compiler/cpp/thrift --gen cpp:dense DebugProtoTest.thrift +../compiler/cpp/thrift --gen cpp:dense OptionalRequiredTest.thrift +g++ -Wall -g -I../lib/cpp/src -I/usr/local/include/boost-1_33_1 \ + gen-cpp/OptionalRequiredTest_types.cpp \ + gen-cpp/DebugProtoTest_types.cpp \ + DenseProtoTest.cpp ../lib/cpp/.libs/libthrift.a -o DenseProtoTest +./DenseProtoTest +*/ + +// I do this to reach into the guts of TDenseProtocol. Sorry. +#define private public +#define inline + +#undef NDEBUG +#include +#include +#include +#include +#include +#include "gen-cpp/DebugProtoTest_types.h" +#include "gen-cpp/OptionalRequiredTest_types.h" +#include +#include + + +// Can't use memcmp here. GCC is too smart. +bool my_memeq(const char* str1, const char* str2, int len) { + for (int i = 0; i < len; i++) { + if (str1[i] != str2[i]) { + return false; + } + } + return true; +} + + +int main() { + using std::string; + using std::cout; + using std::endl; + using boost::shared_ptr; + using namespace thrift::test::debug; + using namespace apache::thrift::transport; + using namespace apache::thrift::protocol; + + + OneOfEach ooe; + ooe.im_true = true; + ooe.im_false = false; + ooe.a_bite = 0xd6; + ooe.integer16 = 27000; + ooe.integer32 = 1<<24; + ooe.integer64 = (uint64_t)6000 * 1000 * 1000; + ooe.double_precision = M_PI; + ooe.some_characters = "Debug THIS!"; + ooe.zomg_unicode = "\xd7\n\a\t"; + + //cout << apache::thrift::ThriftDebugString(ooe) << endl << endl; + + + Nesting n; + n.my_ooe = ooe; + n.my_ooe.integer16 = 16; + n.my_ooe.integer32 = 32; + n.my_ooe.integer64 = 64; + n.my_ooe.double_precision = (std::sqrt(5)+1)/2; + n.my_ooe.some_characters = ":R (me going \"rrrr\")"; + n.my_ooe.zomg_unicode = "\xd3\x80\xe2\x85\xae\xce\x9d\x20" + "\xd0\x9d\xce\xbf\xe2\x85\xbf\xd0\xbe\xc9\xa1\xd0\xb3\xd0\xb0\xcf\x81\xe2\x84\x8e" + "\x20\xce\x91\x74\x74\xce\xb1\xe2\x85\xbd\xce\xba\xc7\x83\xe2\x80\xbc"; + n.my_bonk.type = 31337; + n.my_bonk.message = "I am a bonk... xor!"; + + //cout << apache::thrift::ThriftDebugString(n) << endl << endl; + + + HolyMoley hm; + + hm.big.push_back(ooe); + hm.big.push_back(n.my_ooe); + hm.big[0].a_bite = 0x22; + hm.big[1].a_bite = 0x33; + + std::vector stage1; + stage1.push_back("and a one"); + stage1.push_back("and a two"); + hm.contain.insert(stage1); + stage1.clear(); + stage1.push_back("then a one, two"); + stage1.push_back("three!"); + stage1.push_back("FOUR!!"); + hm.contain.insert(stage1); + stage1.clear(); + hm.contain.insert(stage1); + + std::vector stage2; + hm.bonks["nothing"] = stage2; + stage2.resize(stage2.size()+1); + stage2.back().type = 1; + stage2.back().message = "Wait."; + stage2.resize(stage2.size()+1); + stage2.back().type = 2; + stage2.back().message = "What?"; + hm.bonks["something"] = stage2; + stage2.clear(); + stage2.resize(stage2.size()+1); + stage2.back().type = 3; + stage2.back().message = "quoth"; + stage2.resize(stage2.size()+1); + stage2.back().type = 4; + stage2.back().message = "the raven"; + stage2.resize(stage2.size()+1); + stage2.back().type = 5; + stage2.back().message = "nevermore"; + hm.bonks["poe"] = stage2; + + //cout << apache::thrift::ThriftDebugString(hm) << endl << endl; + + shared_ptr buffer(new TMemoryBuffer()); + shared_ptr proto(new TDenseProtocol(buffer)); + proto->setTypeSpec(HolyMoley::local_reflection); + + hm.write(proto.get()); + HolyMoley hm2; + hm2.read(proto.get()); + + assert(hm == hm2); + + + // Let's test out the variable-length ints, shall we? + uint64_t vlq; + #define checkout(i, c) { \ + buffer->resetBuffer(); \ + proto->vlqWrite(i); \ + proto->getTransport()->flush(); \ + assert(my_memeq(buffer->getBufferAsString().data(), c, sizeof(c)-1)); \ + proto->vlqRead(vlq); \ + assert(vlq == i); \ + } + + checkout(0x00000000, "\x00"); + checkout(0x00000040, "\x40"); + checkout(0x0000007F, "\x7F"); + checkout(0x00000080, "\x81\x00"); + checkout(0x00002000, "\xC0\x00"); + checkout(0x00003FFF, "\xFF\x7F"); + checkout(0x00004000, "\x81\x80\x00"); + checkout(0x00100000, "\xC0\x80\x00"); + checkout(0x001FFFFF, "\xFF\xFF\x7F"); + checkout(0x00200000, "\x81\x80\x80\x00"); + checkout(0x08000000, "\xC0\x80\x80\x00"); + checkout(0x0FFFFFFF, "\xFF\xFF\xFF\x7F"); + checkout(0x10000000, "\x81\x80\x80\x80\x00"); + checkout(0x20000000, "\x82\x80\x80\x80\x00"); + checkout(0x1FFFFFFF, "\x81\xFF\xFF\xFF\x7F"); + checkout(0xFFFFFFFF, "\x8F\xFF\xFF\xFF\x7F"); + + checkout(0x0000000100000000ull, "\x90\x80\x80\x80\x00"); + checkout(0x0000000200000000ull, "\xA0\x80\x80\x80\x00"); + checkout(0x0000000300000000ull, "\xB0\x80\x80\x80\x00"); + checkout(0x0000000700000000ull, "\xF0\x80\x80\x80\x00"); + checkout(0x00000007F0000000ull, "\xFF\x80\x80\x80\x00"); + checkout(0x00000007FFFFFFFFull, "\xFF\xFF\xFF\xFF\x7F"); + checkout(0x0000000800000000ull, "\x81\x80\x80\x80\x80\x00"); + checkout(0x1FFFFFFFFFFFFFFFull, "\x9F\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x7F"); + checkout(0x7FFFFFFFFFFFFFFFull, "\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x7F"); + checkout(0xFFFFFFFFFFFFFFFFull, "\x81\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x7F"); + + // Test out the slow path with a TBufferedTransport. + shared_ptr buff_trans(new TBufferedTransport(buffer, 3)); + proto.reset(new TDenseProtocol(buff_trans)); + checkout(0x0000000100000000ull, "\x90\x80\x80\x80\x00"); + checkout(0x0000000200000000ull, "\xA0\x80\x80\x80\x00"); + checkout(0x0000000300000000ull, "\xB0\x80\x80\x80\x00"); + checkout(0x0000000700000000ull, "\xF0\x80\x80\x80\x00"); + checkout(0x00000007F0000000ull, "\xFF\x80\x80\x80\x00"); + checkout(0x00000007FFFFFFFFull, "\xFF\xFF\xFF\xFF\x7F"); + checkout(0x0000000800000000ull, "\x81\x80\x80\x80\x80\x00"); + checkout(0x1FFFFFFFFFFFFFFFull, "\x9F\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x7F"); + checkout(0x7FFFFFFFFFFFFFFFull, "\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x7F"); + checkout(0xFFFFFFFFFFFFFFFFull, "\x81\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x7F"); + + // Test optional stuff. + proto.reset(new TDenseProtocol(buffer)); + proto->setTypeSpec(ManyOpt::local_reflection); + ManyOpt mo1, mo2, mo3, mo4, mo5, mo6; + mo1.opt1 = 923759347; + mo1.opt2 = 392749274; + mo1.opt3 = 395739402; + mo1.def4 = 294730928; + mo1.opt5 = 394309218; + mo1.opt6 = 832194723; + mo1.__isset.opt1 = true; + mo1.__isset.opt2 = true; + mo1.__isset.opt3 = true; + mo1.__isset.def4 = true; + mo1.__isset.opt5 = true; + mo1.__isset.opt6 = true; + + mo1.write(proto.get()); + mo2.read(proto.get()); + + assert(mo2.__isset.opt1 == true); + assert(mo2.__isset.opt2 == true); + assert(mo2.__isset.opt3 == true); + assert(mo2.__isset.def4 == true); + assert(mo2.__isset.opt5 == true); + assert(mo2.__isset.opt6 == true); + + assert(mo1 == mo2); + + mo1.__isset.opt1 = false; + mo1.__isset.opt3 = false; + mo1.__isset.opt5 = false; + + mo1.write(proto.get()); + mo3.read(proto.get()); + + assert(mo3.__isset.opt1 == false); + assert(mo3.__isset.opt2 == true); + assert(mo3.__isset.opt3 == false); + assert(mo3.__isset.def4 == true); + assert(mo3.__isset.opt5 == false); + assert(mo3.__isset.opt6 == true); + + assert(mo1 == mo3); + + mo1.__isset.opt1 = true; + mo1.__isset.opt3 = true; + mo1.__isset.opt5 = true; + mo1.__isset.opt2 = false; + mo1.__isset.opt6 = false; + + mo1.write(proto.get()); + mo4.read(proto.get()); + + assert(mo4.__isset.opt1 == true); + assert(mo4.__isset.opt2 == false); + assert(mo4.__isset.opt3 == true); + assert(mo4.__isset.def4 == true); + assert(mo4.__isset.opt5 == true); + assert(mo4.__isset.opt6 == false); + + assert(mo1 == mo4); + + mo1.__isset.opt1 = false; + mo1.__isset.opt5 = false; + + mo1.write(proto.get()); + mo5.read(proto.get()); + + assert(mo5.__isset.opt1 == false); + assert(mo5.__isset.opt2 == false); + assert(mo5.__isset.opt3 == true); + assert(mo5.__isset.def4 == true); + assert(mo5.__isset.opt5 == false); + assert(mo5.__isset.opt6 == false); + + assert(mo1 == mo5); + + mo1.__isset.opt3 = false; + + mo1.write(proto.get()); + mo6.read(proto.get()); + + assert(mo6.__isset.opt1 == false); + assert(mo6.__isset.opt2 == false); + assert(mo6.__isset.opt3 == false); + assert(mo6.__isset.def4 == true); + assert(mo6.__isset.opt5 == false); + assert(mo6.__isset.opt6 == false); + + assert(mo1 == mo6); + + + // Test fingerprint checking stuff. + + { + // Default and required have the same fingerprint. + Tricky1 t1; + Tricky3 t3; + assert(string(Tricky1::ascii_fingerprint) == Tricky3::ascii_fingerprint); + proto->setTypeSpec(Tricky1::local_reflection); + t1.im_default = 227; + t1.write(proto.get()); + proto->setTypeSpec(Tricky3::local_reflection); + t3.read(proto.get()); + assert(t3.im_required == 227); + } + + { + // Optional changes things. + Tricky1 t1; + Tricky2 t2; + assert(string(Tricky1::ascii_fingerprint) != Tricky2::ascii_fingerprint); + proto->setTypeSpec(Tricky1::local_reflection); + t1.im_default = 227; + t1.write(proto.get()); + try { + proto->setTypeSpec(Tricky2::local_reflection); + t2.read(proto.get()); + assert(false); + } catch (TProtocolException& ex) { + buffer->resetBuffer(); + } + } + + { + // Holy cow. We can use the Tricky1 typespec with the Tricky2 structure. + Tricky1 t1; + Tricky2 t2; + proto->setTypeSpec(Tricky1::local_reflection); + t1.im_default = 227; + t1.write(proto.get()); + t2.read(proto.get()); + assert(t2.__isset.im_optional == true); + assert(t2.im_optional == 227); + } + + { + // And totally off the wall. + Tricky1 t1; + OneOfEach ooe2; + assert(string(Tricky1::ascii_fingerprint) != OneOfEach::ascii_fingerprint); + proto->setTypeSpec(Tricky1::local_reflection); + t1.im_default = 227; + t1.write(proto.get()); + try { + proto->setTypeSpec(OneOfEach::local_reflection); + ooe2.read(proto.get()); + assert(false); + } catch (TProtocolException& ex) { + buffer->resetBuffer(); + } + } + + // Okay, this is really off the wall. + // Just don't crash. + cout << "Starting fuzz test. This takes a while. (20 dots.)" << endl; + std::srand(12345); + for (int i = 0; i < 2000; i++) { + if (i % 100 == 0) { + cout << "."; + cout.flush(); + } + buffer->resetBuffer(); + // Make sure the fingerprint prefix is right. + buffer->write(Nesting::binary_fingerprint, 4); + for (int j = 0; j < 1024*1024; j++) { + uint8_t r = std::rand(); + buffer->write(&r, 1); + } + Nesting n; + proto->setTypeSpec(OneOfEach::local_reflection); + try { + n.read(proto.get()); + } catch (TProtocolException& ex) { + } catch (TTransportException& ex) { + } + } + cout << endl; + + return 0; +} diff --git a/test/DocTest.thrift b/test/DocTest.thrift new file mode 100644 index 00000000..cb355ae3 --- /dev/null +++ b/test/DocTest.thrift @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Program doctext. + * + * Seriously, this is the documentation for this whole program. + */ + +namespace java thrift.test +namespace cpp thrift.test + +// C++ comment +/* c style comment */ + +# the new unix comment + +/** Some doc text goes here. Wow I am [nesting these] (no more nesting.) */ +enum Numberz +{ + + /** This is how to document a parameter */ + ONE = 1, + + /** And this is a doc for a parameter that has no specific value assigned */ + TWO, + + THREE, + FIVE = 5, + SIX, + EIGHT = 8 +} + +/** This is how you would do a typedef doc */ +typedef i64 UserId + +/** And this is where you would document a struct */ +struct Xtruct +{ + + /** And the members of a struct */ + 1: string string_thing + + /** doct text goes before a comma */ + 4: byte byte_thing, + + 9: i32 i32_thing, + 11: i64 i64_thing +} + +/** + * You can document constants now too. Yeehaw! + */ +const i32 INT32CONSTANT = 9853 +const i16 INT16CONSTANT = 1616 +/** Everyone get in on the docu-action! */ +const map MAPCONSTANT = {'hello':'world', 'goodnight':'moon'} + +struct Xtruct2 +{ + 1: byte byte_thing, + 2: Xtruct struct_thing, + 3: i32 i32_thing +} + +/** Struct insanity */ +struct Insanity +{ + + /** This is doc for field 1 */ + 1: map userMap, + + /** And this is doc for field 2 */ + 2: list xtructs +} + +exception Xception { + 1: i32 errorCode, + 2: string message +} + +exception Xception2 { + 1: i32 errorCode, + 2: Xtruct struct_thing +} + +/* C1 */ +/** Doc */ +/* C2 */ +/* C3 */ +struct EmptyStruct {} + +struct OneField { + 1: EmptyStruct field +} + +/** This is where you would document a Service */ +service ThriftTest +{ + + /** And this is how you would document functions in a service */ + void testVoid(), + string testString(1: string thing), + byte testByte(1: byte thing), + i32 testI32(1: i32 thing), + + /** Like this one */ + i64 testI64(1: i64 thing), + double testDouble(1: double thing), + Xtruct testStruct(1: Xtruct thing), + Xtruct2 testNest(1: Xtruct2 thing), + map testMap(1: map thing), + set testSet(1: set thing), + list testList(1: list thing), + + /** This is an example of a function with params documented */ + Numberz testEnum( + + /** This param is a thing */ + 1: Numberz thing + + ), + + UserId testTypedef(1: UserId thing), + + map> testMapMap(1: i32 hello), + + /* So you think you've got this all worked, out eh? */ + map> testInsanity(1: Insanity argument), + +} + +/// This style of Doxy-comment doesn't work. +typedef i32 SorryNoGo + +/** + * This is a trivial example of a multiline docstring. + */ +typedef i32 TrivialMultiLine + +/** + * This is the cannonical example + * of a multiline docstring. + */ +typedef i32 StandardMultiLine + +/** + * The last line is non-blank. + * I said non-blank! */ +typedef i32 LastLine + +/** Both the first line + * are non blank. ;-) + * and the last line */ +typedef i32 FirstAndLastLine + +/** + * INDENTED TITLE + * The text is less indented. + */ +typedef i32 IndentedTitle + +/** First line indented. + * Unfortunately, this does not get indented. + */ +typedef i32 FirstLineIndent + + +/** + * void code_in_comment() { + * printf("hooray code!"); + * } + */ +typedef i32 CodeInComment + + /** + * Indented Docstring. + * This whole docstring is indented. + * This line is indented further. + */ +typedef i32 IndentedDocstring + +/** Irregular docstring. + * We will have to punt + * on this thing */ +typedef i32 Irregular1 + +/** + * note the space + * before these lines +* but not this + * one + */ +typedef i32 Irregular2 + +/** +* Flush against +* the left. +*/ +typedef i32 Flush + +/** + No stars in this one. + It should still work fine, though. + Including indenting. + */ +typedef i32 NoStars + +/** Trailing whitespace +Sloppy trailing whitespace +is truncated. */ +typedef i32 TrailingWhitespace + +/** + * This is a big one. + * + * We'll have some blank lines in it. + * + * void as_well_as(some code) { + * puts("YEEHAW!"); + * } + */ +typedef i32 BigDog + +/** +* +* +*/ +typedef i32 TotallyDegenerate + +/* THE END */ diff --git a/test/FastbinaryTest.py b/test/FastbinaryTest.py new file mode 100755 index 00000000..7f6efae6 --- /dev/null +++ b/test/FastbinaryTest.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +r""" +thrift --gen py DebugProtoTest.thrift +./FastbinaryTest.py +""" + +# TODO(dreiss): Test error cases. Check for memory leaks. + +import sys +sys.path.append('./gen-py') + +import math +from DebugProtoTest import Srv +from DebugProtoTest.ttypes import * +from thrift.transport import TTransport +from thrift.protocol import TBinaryProtocol + +import timeit +from cStringIO import StringIO +from copy import deepcopy +from pprint import pprint + +class TDevNullTransport(TTransport.TTransportBase): + def __init__(self): + pass + def isOpen(self): + return True + +ooe1 = OneOfEach() +ooe1.im_true = True; +ooe1.im_false = False; +ooe1.a_bite = 0xd6; +ooe1.integer16 = 27000; +ooe1.integer32 = 1<<24; +ooe1.integer64 = 6000 * 1000 * 1000; +ooe1.double_precision = math.pi; +ooe1.some_characters = "Debug THIS!"; +ooe1.zomg_unicode = "\xd7\n\a\t"; + +ooe2 = OneOfEach(); +ooe2.integer16 = 16; +ooe2.integer32 = 32; +ooe2.integer64 = 64; +ooe2.double_precision = (math.sqrt(5)+1)/2; +ooe2.some_characters = ":R (me going \"rrrr\")"; +ooe2.zomg_unicode = "\xd3\x80\xe2\x85\xae\xce\x9d\x20"\ + "\xd0\x9d\xce\xbf\xe2\x85\xbf\xd0\xbe"\ + "\xc9\xa1\xd0\xb3\xd0\xb0\xcf\x81\xe2\x84\x8e"\ + "\x20\xce\x91\x74\x74\xce\xb1\xe2\x85\xbd\xce\xba"\ + "\xc7\x83\xe2\x80\xbc"; + +hm = HolyMoley({"big":[], "contain":set(), "bonks":{}}) +hm.big.append(ooe1) +hm.big.append(ooe2) +hm.big[0].a_bite = 0x22; +hm.big[1].a_bite = 0x22; + +hm.contain.add(("and a one", "and a two")) +hm.contain.add(("then a one, two", "three!", "FOUR!")) +hm.contain.add(()) + +hm.bonks["nothing"] = []; +hm.bonks["something"] = [ + Bonk({"type":1, "message":"Wait."}), + Bonk({"type":2, "message":"What?"}), +] +hm.bonks["poe"] = [ + Bonk({"type":3, "message":"quoth"}), + Bonk({"type":4, "message":"the raven"}), + Bonk({"type":5, "message":"nevermore"}), +] + +rs = RandomStuff() +rs.a = 1 +rs.b = 2 +rs.c = 3 +rs.myintlist = range(20) +rs.maps = {1:Wrapper({"foo":Empty()}),2:Wrapper({"foo":Empty()})} +rs.bigint = 124523452435L +rs.triple = 3.14 + +# make sure this splits two buffers in a buffered protocol +rshuge = RandomStuff() +rshuge.myintlist=range(10000) + +my_zero = Srv.Janky_result({"arg":5}) + +def checkWrite(o): + trans_fast = TTransport.TMemoryBuffer() + trans_slow = TTransport.TMemoryBuffer() + prot_fast = TBinaryProtocol.TBinaryProtocolAccelerated(trans_fast) + prot_slow = TBinaryProtocol.TBinaryProtocol(trans_slow) + + o.write(prot_fast) + o.write(prot_slow) + ORIG = trans_slow.getvalue() + MINE = trans_fast.getvalue() + if ORIG != MINE: + print "mine: %s\norig: %s" % (repr(MINE), repr(ORIG)) + +def checkRead(o): + prot = TBinaryProtocol.TBinaryProtocol(TTransport.TMemoryBuffer()) + o.write(prot) + + slow_version_binary = prot.trans.getvalue() + + prot = TBinaryProtocol.TBinaryProtocolAccelerated( + TTransport.TMemoryBuffer(slow_version_binary)) + c = o.__class__() + c.read(prot) + if c != o: + print "copy: " + pprint(eval(repr(c))) + print "orig: " + pprint(eval(repr(o))) + + prot = TBinaryProtocol.TBinaryProtocolAccelerated( + TTransport.TBufferedTransport( + TTransport.TMemoryBuffer(slow_version_binary))) + c = o.__class__() + c.read(prot) + if c != o: + print "copy: " + pprint(eval(repr(c))) + print "orig: " + pprint(eval(repr(o))) + + +def doTest(): + checkWrite(hm) + no_set = deepcopy(hm) + no_set.contain = set() + checkRead(no_set) + checkWrite(rs) + checkRead(rs) + checkWrite(rshuge) + checkRead(rshuge) + checkWrite(my_zero) + checkRead(my_zero) + checkRead(Backwards({"first_tag2":4, "second_tag1":2})) + + # One case where the serialized form changes, but only superficially. + o = Backwards({"first_tag2":4, "second_tag1":2}) + trans_fast = TTransport.TMemoryBuffer() + trans_slow = TTransport.TMemoryBuffer() + prot_fast = TBinaryProtocol.TBinaryProtocolAccelerated(trans_fast) + prot_slow = TBinaryProtocol.TBinaryProtocol(trans_slow) + + o.write(prot_fast) + o.write(prot_slow) + ORIG = trans_slow.getvalue() + MINE = trans_fast.getvalue() + if ORIG == MINE: + print "That shouldn't happen." + + + prot = TBinaryProtocol.TBinaryProtocolAccelerated(TTransport.TMemoryBuffer()) + o.write(prot) + prot = TBinaryProtocol.TBinaryProtocol( + TTransport.TMemoryBuffer( + prot.trans.getvalue())) + c = o.__class__() + c.read(prot) + if c != o: + print "copy: " + pprint(eval(repr(c))) + print "orig: " + pprint(eval(repr(o))) + + + +def doBenchmark(): + + iters = 25000 + + setup = """ +from __main__ import hm, rs, TDevNullTransport +from thrift.protocol import TBinaryProtocol +trans = TDevNullTransport() +prot = TBinaryProtocol.TBinaryProtocol%s(trans) +""" + + setup_fast = setup % "Accelerated" + setup_slow = setup % "" + + print "Starting Benchmarks" + + print "HolyMoley Standard = %f" % \ + timeit.Timer('hm.write(prot)', setup_slow).timeit(number=iters) + print "HolyMoley Acceler. = %f" % \ + timeit.Timer('hm.write(prot)', setup_fast).timeit(number=iters) + + print "FastStruct Standard = %f" % \ + timeit.Timer('rs.write(prot)', setup_slow).timeit(number=iters) + print "FastStruct Acceler. = %f" % \ + timeit.Timer('rs.write(prot)', setup_fast).timeit(number=iters) + + + +doTest() +doBenchmark() + diff --git a/test/GenericHelpers.h b/test/GenericHelpers.h new file mode 100644 index 00000000..d661d8ba --- /dev/null +++ b/test/GenericHelpers.h @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef _THRIFT_TEST_GENERICHELPERS_H_ +#define _THRIFT_TEST_GENERICHELPERS_H_ 1 + +#include +#include +#include + +using boost::shared_ptr; +using namespace apache::thrift::protocol; + +/* ClassName Helper for cleaner exceptions */ +class ClassNames { + public: + template + static const char* getName() { return "Unknown type"; } +}; + +template <> const char* ClassNames::getName() { return "byte"; } +template <> const char* ClassNames::getName() { return "short"; } +template <> const char* ClassNames::getName() { return "int"; } +template <> const char* ClassNames::getName() { return "long"; } +template <> const char* ClassNames::getName() { return "double"; } +template <> const char* ClassNames::getName() { return "string"; } + +/* Generic Protocol I/O function for tests */ +class GenericIO { + public: + + /* Write functions */ + + static uint32_t write(shared_ptr proto, const int8_t& val) { + return proto->writeByte(val); + } + + static uint32_t write(shared_ptr proto, const int16_t& val) { + return proto->writeI16(val); + } + + static uint32_t write(shared_ptr proto, const int32_t& val) { + return proto->writeI32(val); + } + + static uint32_t write(shared_ptr proto, const double& val) { + return proto->writeDouble(val); + } + + static uint32_t write(shared_ptr proto, const int64_t& val) { + return proto->writeI64(val); + } + + static uint32_t write(shared_ptr proto, const std::string& val) { + return proto->writeString(val); + } + + /* Read functions */ + + static uint32_t read(shared_ptr proto, int8_t& val) { + return proto->readByte(val); + } + + static uint32_t read(shared_ptr proto, int16_t& val) { + return proto->readI16(val); + } + + static uint32_t read(shared_ptr proto, int32_t& val) { + return proto->readI32(val); + } + + static uint32_t read(shared_ptr proto, int64_t& val) { + return proto->readI64(val); + } + + static uint32_t read(shared_ptr proto, double& val) { + return proto->readDouble(val); + } + + static uint32_t read(shared_ptr proto, std::string& val) { + return proto->readString(val); + } + +}; + +#endif diff --git a/test/JSONProtoTest.cpp b/test/JSONProtoTest.cpp new file mode 100644 index 00000000..66813569 --- /dev/null +++ b/test/JSONProtoTest.cpp @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include "gen-cpp/DebugProtoTest_types.h" + +int main() { + using std::cout; + using std::endl; + using namespace thrift::test::debug; + using apache::thrift::transport::TMemoryBuffer; + using apache::thrift::protocol::TJSONProtocol; + + OneOfEach ooe; + ooe.im_true = true; + ooe.im_false = false; + ooe.a_bite = 0xd6; + ooe.integer16 = 27000; + ooe.integer32 = 1<<24; + ooe.integer64 = (uint64_t)6000 * 1000 * 1000; + ooe.double_precision = M_PI; + ooe.some_characters = "JSON THIS! \"\1"; + ooe.zomg_unicode = "\xd7\n\a\t"; + ooe.base64 = "\1\2\3\255"; + cout << apache::thrift::ThriftJSONString(ooe) << endl << endl; + + + Nesting n; + n.my_ooe = ooe; + n.my_ooe.integer16 = 16; + n.my_ooe.integer32 = 32; + n.my_ooe.integer64 = 64; + n.my_ooe.double_precision = (std::sqrt(5)+1)/2; + n.my_ooe.some_characters = ":R (me going \"rrrr\")"; + n.my_ooe.zomg_unicode = "\xd3\x80\xe2\x85\xae\xce\x9d\x20" + "\xd0\x9d\xce\xbf\xe2\x85\xbf\xd0\xbe\xc9\xa1\xd0\xb3\xd0\xb0\xcf\x81\xe2\x84\x8e" + "\x20\xce\x91\x74\x74\xce\xb1\xe2\x85\xbd\xce\xba\xc7\x83\xe2\x80\xbc"; + n.my_bonk.type = 31337; + n.my_bonk.message = "I am a bonk... xor!"; + + cout << apache::thrift::ThriftJSONString(n) << endl << endl; + + + HolyMoley hm; + + hm.big.push_back(ooe); + hm.big.push_back(n.my_ooe); + hm.big[0].a_bite = 0x22; + hm.big[1].a_bite = 0x33; + + std::vector stage1; + stage1.push_back("and a one"); + stage1.push_back("and a two"); + hm.contain.insert(stage1); + stage1.clear(); + stage1.push_back("then a one, two"); + stage1.push_back("three!"); + stage1.push_back("FOUR!!"); + hm.contain.insert(stage1); + stage1.clear(); + hm.contain.insert(stage1); + + std::vector stage2; + hm.bonks["nothing"] = stage2; + stage2.resize(stage2.size()+1); + stage2.back().type = 1; + stage2.back().message = "Wait."; + stage2.resize(stage2.size()+1); + stage2.back().type = 2; + stage2.back().message = "What?"; + hm.bonks["something"] = stage2; + stage2.clear(); + stage2.resize(stage2.size()+1); + stage2.back().type = 3; + stage2.back().message = "quoth"; + stage2.resize(stage2.size()+1); + stage2.back().type = 4; + stage2.back().message = "the raven"; + stage2.resize(stage2.size()+1); + stage2.back().type = 5; + stage2.back().message = "nevermore"; + hm.bonks["poe"] = stage2; + + cout << apache::thrift::ThriftJSONString(hm) << endl << endl; + + boost::shared_ptr buffer(new TMemoryBuffer()); + boost::shared_ptr proto(new TJSONProtocol(buffer)); + + + cout << "Testing ooe" << endl; + + ooe.write(proto.get()); + OneOfEach ooe2; + ooe2.read(proto.get()); + + assert(ooe == ooe2); + + + cout << "Testing hm" << endl; + + hm.write(proto.get()); + HolyMoley hm2; + hm2.read(proto.get()); + + assert(hm == hm2); + + hm2.big[0].a_bite = 0xFF; + + assert(hm != hm2); + + Doubles dub; + dub.nan = HUGE_VAL/HUGE_VAL; + dub.inf = HUGE_VAL; + dub.neginf = -HUGE_VAL; + dub.repeating = 10.0/3.0; + dub.big = 1E+305; + dub.small = 1E-305; + dub.zero = 0.0; + dub.negzero = -0.0; + cout << apache::thrift::ThriftJSONString(dub) << endl << endl; + + cout << "Testing base" << endl; + + Base64 base; + base.a = 123; + base.b1 = "1"; + base.b2 = "12"; + base.b3 = "123"; + base.b4 = "1234"; + base.b5 = "12345"; + base.b6 = "123456"; + + base.write(proto.get()); + Base64 base2; + base2.read(proto.get()); + + assert(base == base2); + + return 0; +} diff --git a/test/JavaBeansTest.thrift b/test/JavaBeansTest.thrift new file mode 100644 index 00000000..02bf98d6 --- /dev/null +++ b/test/JavaBeansTest.thrift @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +namespace java thrift.test + +struct OneOfEachBeans { + 1: bool boolean_field, + 2: byte a_bite, + 3: i16 integer16, + 4: i32 integer32, + 5: i64 integer64, + 6: double double_precision, + 7: string some_characters, + 8: binary base64, + 9: list byte_list, + 10: list i16_list, + 11: list i64_list +} diff --git a/test/Makefile.am b/test/Makefile.am new file mode 100644 index 00000000..1226935d --- /dev/null +++ b/test/Makefile.am @@ -0,0 +1,178 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +SUBDIRS = + +if WITH_PYTHON +SUBDIRS += py +endif + +if WITH_RUBY +SUBDIRS += rb +endif + +noinst_LTLIBRARIES = libtestgencpp.la +libtestgencpp_la_SOURCES = \ + gen-cpp/DebugProtoTest_types.cpp \ + gen-cpp/OptionalRequiredTest_types.cpp \ + gen-cpp/DebugProtoTest_types.cpp \ + gen-cpp/ThriftTest_types.cpp \ + gen-cpp/DebugProtoTest_types.h \ + gen-cpp/OptionalRequiredTest_types.h \ + gen-cpp/ThriftTest_types.h \ + ThriftTest_extras.cpp \ + DebugProtoTest_extras.cpp + +ThriftTest_extras.o: gen-cpp/ThriftTest_types.h +DebugProtoTest_extras.o: gen-cpp/DebugProtoTest_types.h + +libtestgencpp_la_LIBADD = $(top_builddir)/lib/cpp/libthrift.la + +noinst_PROGRAMS = Benchmark + +Benchmark_SOURCES = \ + Benchmark.cpp + +Benchmark_LDADD = libtestgencpp.la + +check_PROGRAMS = \ + TFDTransportTest \ + TPipedTransportTest \ + DebugProtoTest \ + JSONProtoTest \ + OptionalRequiredTest \ + AllProtocolsTest \ + UnitTests + +TESTS = \ + $(check_PROGRAMS) + +UnitTests_SOURCES = \ + UnitTestMain.cpp \ + TMemoryBufferTest.cpp \ + TBufferBaseTest.cpp + +UnitTests_LDADD = libtestgencpp.la + +# +# TFDTransportTest +# +TFDTransportTest_SOURCES = \ + TFDTransportTest.cpp + +TFDTransportTest_LDADD = \ + $(top_builddir)/lib/cpp/libthrift.la + + +# +# TPipedTransportTest +# +TPipedTransportTest_SOURCES = \ + TPipedTransportTest.cpp + +TPipedTransportTest_LDADD = \ + $(top_builddir)/lib/cpp/libthrift.la + +# +# AllProtocolsTest +# +AllProtocolsTest_SOURCES = \ + AllProtocolTests.cpp \ + AllProtocolTests.tcc \ + GenericHelpers.h + +AllProtocolsTest_LDADD = libtestgencpp.la + +# +# DebugProtoTest +# +DebugProtoTest_SOURCES = \ + DebugProtoTest.cpp + +DebugProtoTest_LDADD = libtestgencpp.la + + +# +# JSONProtoTest +# +JSONProtoTest_SOURCES = \ + JSONProtoTest.cpp + +JSONProtoTest_LDADD = libtestgencpp.la + +# +# OptionalRequiredTest +# +OptionalRequiredTest_SOURCES = \ + OptionalRequiredTest.cpp + +OptionalRequiredTest_LDADD = libtestgencpp.la + + +# +# Common thrift code generation rules +# +THRIFT = $(top_builddir)/compiler/cpp/thrift + +gen-cpp/DebugProtoTest_types.cpp gen-cpp/DebugProtoTest_types.h: DebugProtoTest.thrift + $(THRIFT) --gen cpp:dense $< + +gen-cpp/OptionalRequiredTest_types.cpp gen-cpp/OptionalRequiredTest_types.h: OptionalRequiredTest.thrift + $(THRIFT) --gen cpp:dense $< + +gen-cpp/Service.cpp gen-cpp/StressTest_types.cpp: StressTest.thrift + $(THRIFT) --gen cpp:dense $< + +gen-cpp/SecondService.cpp gen-cpp/ThriftTest_constants.cpp gen-cpp/ThriftTest.cpp gen-cpp/ThriftTest_types.cpp gen-cpp/ThriftTest_types.h: ThriftTest.thrift + $(THRIFT) --gen cpp:dense $< + +INCLUDES = \ + -I$(top_srcdir)/lib/cpp/src + +AM_CPPFLAGS = $(BOOST_CPPFLAGS) + +clean-local: + $(RM) -r gen-cpp + +EXTRA_DIST = \ + cpp \ + threads \ + csharp \ + py \ + rb \ + perl \ + php \ + erl \ + hs \ + ocaml \ + AnnotationTest.thrift \ + BrokenConstants.thrift \ + ConstantsDemo.thrift \ + DebugProtoTest.thrift \ + DenseLinkingTest.thrift \ + DocTest.thrift \ + JavaBeansTest.thrift \ + ManyTypedefs.thrift \ + OptionalRequiredTest.thrift \ + SmallTest.thrift \ + StressTest.thrift \ + ThriftTest.thrift \ + ZlibTest.cpp \ + DenseProtoTest.cpp \ + FastbinaryTest.py diff --git a/test/ManyTypedefs.thrift b/test/ManyTypedefs.thrift new file mode 100644 index 00000000..d194b63c --- /dev/null +++ b/test/ManyTypedefs.thrift @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// This is to make sure you don't mess something up when you change typedef code. +// Generate it with the old and new thrift and make sure they are the same. +/* +rm -rf gen-* orig-* +mkdir old new +thrift --gen cpp --gen java --gen php --gen phpi --gen py --gen rb --gen xsd --gen perl --gen ocaml --gen erl --gen hs --strict ManyTypedefs.thrift +mv gen-* old +../compiler/cpp/thrift --gen cpp --gen java --gen php --gen phpi --gen py --gen rb --gen xsd --gen perl --gen ocaml --gen erl --gen hs --strict ManyTypedefs.thrift +mv gen-* new +diff -ur old new +rm -rf old new +# There should be no output. +*/ + +typedef i32 int32 +typedef list> biglist + +struct struct1 { + 1: int32 myint; + 2: biglist mylist; +} + +exception exception1 { + 1: biglist alist; + 2: struct1 mystruct; +} + +service AService { + struct1 method1(1: int32 myint) throws (1: exception1 exn); + biglist method2(); +} diff --git a/test/OptionalRequiredTest.cpp b/test/OptionalRequiredTest.cpp new file mode 100644 index 00000000..5743ce30 --- /dev/null +++ b/test/OptionalRequiredTest.cpp @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include "gen-cpp/OptionalRequiredTest_types.h" + +using std::cout; +using std::endl; +using std::map; +using std::string; +using namespace thrift::test; +using namespace apache::thrift; +using namespace apache::thrift::transport; +using namespace apache::thrift::protocol; + + +/* +template +void trywrite(const Struct& s, bool should_work) { + bool worked; + try { + TBinaryProtocol protocol(boost::shared_ptr(new TMemoryBuffer)); + s.write(&protocol); + worked = true; + } catch (TProtocolException & ex) { + worked = false; + } + assert(worked == should_work); +} +*/ + +template +void write_to_read(const Struct1 & w, Struct2 & r) { + TBinaryProtocol protocol(boost::shared_ptr(new TMemoryBuffer)); + w.write(&protocol); + r.read(&protocol); +} + + +int main() { + + cout << "This old school struct should have three fields." << endl; + { + OldSchool o; + cout << ThriftDebugString(o) << endl; + } + cout << endl; + + cout << "Setting a value before setting isset." << endl; + { + Simple s; + cout << ThriftDebugString(s) << endl; + s.im_optional = 10; + cout << ThriftDebugString(s) << endl; + s.__isset.im_optional = true; + cout << ThriftDebugString(s) << endl; + } + cout << endl; + + cout << "Setting isset before setting a value." << endl; + { + Simple s; + cout << ThriftDebugString(s) << endl; + s.__isset.im_optional = true; + cout << ThriftDebugString(s) << endl; + s.im_optional = 10; + cout << ThriftDebugString(s) << endl; + } + cout << endl; + + // Write-to-read with optional fields. + { + Simple s1, s2, s3; + s1.im_optional = 10; + assert(!s1.__isset.im_default); + //assert(!s1.__isset.im_required); // Compile error. + assert(!s1.__isset.im_optional); + + write_to_read(s1, s2); + + assert( s2.__isset.im_default); + //assert( s2.__isset.im_required); // Compile error. + assert(!s2.__isset.im_optional); + assert(s3.im_optional == 0); + + s1.__isset.im_optional = true; + write_to_read(s1, s3); + + assert( s3.__isset.im_default); + //assert( s3.__isset.im_required); // Compile error. + assert( s3.__isset.im_optional); + assert(s3.im_optional == 10); + } + + // Writing between optional and default. + { + Tricky1 t1; + Tricky2 t2; + + t2.im_optional = 10; + write_to_read(t2, t1); + write_to_read(t1, t2); + assert(!t1.__isset.im_default); + assert( t2.__isset.im_optional); + assert(t1.im_default == t2.im_optional); + assert(t1.im_default == 0); + } + + // Writing between default and required. + { + Tricky1 t1; + Tricky3 t3; + write_to_read(t1, t3); + write_to_read(t3, t1); + assert(t1.__isset.im_default); + } + + // Writing between optional and required. + { + Tricky2 t2; + Tricky3 t3; + t2.__isset.im_optional = true; + write_to_read(t2, t3); + write_to_read(t3, t2); + } + + // Mu-hu-ha-ha-ha! + { + Tricky2 t2; + Tricky3 t3; + try { + write_to_read(t2, t3); + abort(); + } + catch (TProtocolException& ex) {} + + write_to_read(t3, t2); + assert(t2.__isset.im_optional); + } + + cout << "Complex struct, simple test." << endl; + { + Complex c; + cout << ThriftDebugString(c) << endl; + } + + + { + Tricky1 t1; + Tricky2 t2; + // Compile error. + //(void)(t1 == t2); + } + + { + OldSchool o1, o2, o3; + assert(o1 == o2); + o1.im_int = o2.im_int = 10; + assert(o1 == o2); + o1.__isset.im_int = true; + o2.__isset.im_int = false; + assert(o1 == o2); + o1.im_int = 20; + o1.__isset.im_int = false; + assert(o1 != o2); + o1.im_int = 10; + assert(o1 == o2); + o1.im_str = o2.im_str = "foo"; + assert(o1 == o2); + o1.__isset.im_str = o2.__isset.im_str = true; + assert(o1 == o2); + map mymap; + mymap[1] = "bar"; + mymap[2] = "baz"; + o1.im_big.push_back(map()); + assert(o1 != o2); + o2.im_big.push_back(map()); + assert(o1 == o2); + o2.im_big.push_back(mymap); + assert(o1 != o2); + o1.im_big.push_back(mymap); + assert(o1 == o2); + + TBinaryProtocol protocol(boost::shared_ptr(new TMemoryBuffer)); + o1.write(&protocol); + + o1.im_big.push_back(mymap); + mymap[3] = "qux"; + o2.im_big.push_back(mymap); + assert(o1 != o2); + o1.im_big.back()[3] = "qux"; + assert(o1 == o2); + + o3.read(&protocol); + o3.im_big.push_back(mymap); + assert(o1 == o3); + + //cout << ThriftDebugString(o3) << endl; + } + + { + Tricky2 t1, t2; + assert(t1.__isset.im_optional == false); + assert(t2.__isset.im_optional == false); + assert(t1 == t2); + t1.im_optional = 5; + assert(t1 == t2); + t2.im_optional = 5; + assert(t1 == t2); + t1.__isset.im_optional = true; + assert(t1 != t2); + t2.__isset.im_optional = true; + assert(t1 == t2); + t1.im_optional = 10; + assert(t1 != t2); + t2.__isset.im_optional = false; + assert(t1 != t2); + } + + return 0; +} diff --git a/test/OptionalRequiredTest.thrift b/test/OptionalRequiredTest.thrift new file mode 100644 index 00000000..f7d1fd6a --- /dev/null +++ b/test/OptionalRequiredTest.thrift @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +namespace cpp thrift.test +namespace java thrift.test + +struct OldSchool { + 1: i16 im_int; + 2: string im_str; + 3: list> im_big; +} + +struct Simple { + 1: /* :) */ i16 im_default; + 2: required i16 im_required; + 3: optional i16 im_optional; +} + +struct Tricky1 { + 1: /* :) */ i16 im_default; +} + +struct Tricky2 { + 1: optional i16 im_optional; +} + +struct Tricky3 { + 1: required i16 im_required; +} + +struct Complex { + 1: i16 cp_default; + 2: required i16 cp_required; + 3: optional i16 cp_optional; + 4: map the_map; + 5: required Simple req_simp; + 6: optional Simple opt_simp; +} + +struct ManyOpt { + 1: optional i32 opt1; + 2: optional i32 opt2; + 3: optional i32 opt3; + 4: i32 def4; + 5: optional i32 opt5; + 6: optional i32 opt6; +} + +struct JavaTestHelper { + 1: required i32 req_int; + 2: optional i32 opt_int; + 3: required string req_obj; + 4: optional string opt_obj; + 5: required binary req_bin; + 6: optional binary opt_bin; +} diff --git a/test/SmallTest.thrift b/test/SmallTest.thrift new file mode 100644 index 00000000..d0821c7f --- /dev/null +++ b/test/SmallTest.thrift @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + +namespace rb TestNamespace + +struct Goodbyez { + 1: i32 val = 325; +} + +senum Thinger { + "ASDFKJ", + "r32)*F#@", + "ASDFLJASDF" +} + +struct BoolPasser { + 1: bool value = 1 +} + +struct Hello { + 1: i32 simple = 53, + 2: map complex = {23:532, 6243:632, 2355:532}, + 3: map> complexer, + 4: string words = "words", + 5: Goodbyez thinz = {'val' : 36632} +} + +const map> CMAP = { 235: {235:235}, 53:{53:53} } +const i32 CINT = 325; +const Hello WHOA = {'simple' : 532} + +exception Goodbye { + 1: i32 simple, + 2: map complex, + 3: map> complexer, +} + +service SmallService { + Thinger testThinger(1:Thinger bootz), + Hello testMe(1:i32 hello=64, 2: Hello wonk) throws (1: Goodbye g), + void testVoid() throws (1: Goodbye g), + i32 testI32(1:i32 boo) +} diff --git a/test/StressTest.thrift b/test/StressTest.thrift new file mode 100644 index 00000000..87c6e476 --- /dev/null +++ b/test/StressTest.thrift @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +namespace cpp test.stress + +service Service { + + void echoVoid(), + byte echoByte(1: byte arg), + i32 echoI32(1: i32 arg), + i64 echoI64(1: i64 arg), + string echoString(1: string arg), + list echoList(1: list arg), + set echoSet(1: set arg), + map echoMap(1: map arg), +} + diff --git a/test/TBufferBaseTest.cpp b/test/TBufferBaseTest.cpp new file mode 100644 index 00000000..da3ce856 --- /dev/null +++ b/test/TBufferBaseTest.cpp @@ -0,0 +1,639 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +using std::string; +using boost::shared_ptr; +using apache::thrift::transport::TMemoryBuffer; +using apache::thrift::transport::TBufferedTransport; +using apache::thrift::transport::TFramedTransport; +using apache::thrift::transport::test::TShortReadTransport; + +#define foreach BOOST_FOREACH + +// Shamelessly copied from ZlibTransport. TODO: refactor. +unsigned int dist[][5000] = { + { 1<<15 }, + + { + 5,13,9,1,8,9,11,13,18,48,24,13,21,13,5,11,35,2,4,20,17,72,27,14,15,4,7,26, + 12,1,14,9,2,16,29,41,7,24,4,27,14,4,1,4,25,3,6,34,10,8,50,2,14,13,55,29,3, + 43,53,49,14,4,10,32,27,48,1,3,1,11,5,17,16,51,17,30,15,11,9,2,2,11,52,12,2, + 13,94,1,19,1,38,2,8,43,8,33,7,30,8,17,22,2,15,14,12,34,2,12,6,37,29,74,3, + 165,16,11,17,5,14,3,10,7,37,11,24,7,1,3,12,37,8,9,34,17,12,8,21,13,37,1,4, + 30,14,78,4,15,2,40,37,17,12,36,82,14,4,1,4,7,17,11,16,88,77,2,3,15,3,34,11, + 5,79,22,34,8,4,4,40,22,24,28,9,13,3,34,27,9,16,39,16,39,13,2,4,3,41,26,10,4, + 33,4,7,12,5,6,3,10,30,8,21,16,58,19,9,0,47,7,13,11,19,15,7,53,57,2,13,28,22, + 3,16,9,25,33,12,40,7,12,64,7,14,24,44,9,2,14,11,2,58,1,26,30,11,9,5,24,7,9, + 94,2,10,21,5,5,4,5,6,179,9,18,2,7,13,31,41,17,4,36,3,21,6,26,8,15,18,44,27, + 11,9,25,7,0,14,2,12,20,23,13,2,163,9,5,15,65,2,14,6,8,98,11,15,14,34,2,3,10, + 22,9,92,7,10,32,67,13,3,4,35,8,2,1,5,0,26,381,7,27,8,2,16,93,4,19,5,8,25,9, + 31,14,4,21,5,3,9,22,56,4,18,3,11,18,6,4,3,40,12,16,110,8,35,14,1,18,40,9,12, + 14,3,11,7,57,13,18,116,53,19,22,7,16,11,5,8,21,16,1,75,21,20,1,28,2,6,1,7, + 19,38,5,6,9,9,4,1,7,55,36,62,5,4,4,24,15,1,12,35,48,20,5,17,1,5,26,15,4,54, + 13,5,5,15,5,19,32,29,31,7,6,40,7,80,11,18,8,128,48,6,12,84,13,4,7,2,13,9,16, + 17,3,254,1,4,181,8,44,7,6,24,27,9,23,14,34,16,22,25,10,3,3,4,4,12,2,12,6,7, + 13,58,13,6,11,19,53,11,66,18,19,10,4,13,2,5,49,58,1,67,7,21,64,14,11,14,8,3, + 26,33,91,31,20,7,9,42,39,4,3,55,11,10,0,7,4,75,8,12,0,27,3,8,9,0,12,12,23, + 28,23,20,4,13,30,2,22,20,19,30,6,22,2,6,4,24,7,19,55,86,5,33,2,161,6,7,1,62, + 13,3,72,12,12,9,7,12,10,5,10,29,1,5,22,13,13,5,2,12,3,7,14,18,2,3,46,21,17, + 15,19,3,27,5,16,45,31,10,8,17,18,18,3,7,24,6,55,9,3,6,12,10,12,8,91,9,4,4,4, + 27,29,16,5,7,22,43,28,11,14,8,11,28,109,55,71,40,3,8,22,26,15,44,3,25,29,5, + 3,32,17,12,3,29,27,25,15,11,8,40,39,38,17,3,9,11,2,32,11,6,20,48,75,27,3,7, + 54,12,95,12,7,24,23,2,13,8,15,16,5,12,4,17,7,19,88,2,6,13,115,45,12,21,2,86, + 74,9,7,5,16,32,16,2,21,18,6,34,5,18,260,7,12,16,44,19,92,31,7,8,2,9,0,0,15, + 8,38,4,8,20,18,2,83,3,3,4,9,5,3,10,3,5,29,15,7,11,8,48,17,23,2,17,4,11,22, + 21,64,8,8,4,19,95,0,17,28,9,11,20,71,5,11,18,12,13,45,49,4,1,33,32,23,13,5, + 52,2,2,16,3,4,7,12,2,1,12,6,24,1,22,155,21,3,45,4,12,44,26,5,40,36,9,9,8,20, + 35,31,3,2,32,50,10,8,37,2,75,35,22,15,192,8,11,23,1,4,29,6,8,8,5,12,18,32,4, + 7,12,2,0,0,9,5,48,11,35,3,1,123,6,29,8,11,8,23,51,16,6,63,12,2,5,4,14,2,15, + 7,14,3,2,7,17,32,8,8,10,1,23,62,2,49,6,49,47,23,3,20,7,11,39,10,24,6,15,5,5, + 11,8,16,36,8,13,20,3,10,44,7,52,7,10,36,6,15,10,5,11,4,14,19,17,10,12,3,6, + 23,4,13,94,70,7,36,7,38,7,28,8,4,15,3,19,4,33,39,21,109,4,80,6,40,4,432,4,4, + 7,8,3,31,8,28,37,34,10,2,21,5,22,0,7,36,14,12,6,24,1,21,5,9,2,29,20,54,113, + 13,31,39,27,6,0,27,4,5,2,43,7,8,57,8,62,7,9,12,22,90,30,6,19,7,10,20,6,5,58, + 32,30,41,4,10,25,13,3,8,7,10,2,9,6,151,44,16,12,16,20,8,3,18,11,17,4,10,45, + 15,8,56,38,52,25,40,14,4,17,15,8,2,19,7,8,26,30,2,3,180,8,26,17,38,35,5,16, + 28,5,15,56,13,14,18,9,15,83,27,3,9,4,11,8,27,27,44,10,12,8,3,48,14,7,9,4,4, + 8,4,5,9,122,8,14,12,19,17,21,4,29,63,21,17,10,12,18,47,10,10,53,4,18,16,4,8, + 118,9,5,12,9,11,9,3,12,32,3,23,2,15,3,3,30,3,17,235,15,22,9,299,14,17,1,5, + 16,8,3,7,3,13,2,7,6,4,8,66,2,13,6,15,16,47,3,36,5,7,10,24,1,9,9,8,13,16,26, + 12,7,24,21,18,49,23,39,10,41,4,13,4,27,11,12,12,19,4,147,8,10,9,40,21,2,83, + 10,5,6,11,25,9,50,57,40,12,12,21,1,3,24,23,9,3,9,13,2,3,12,57,8,11,13,15,26, + 15,10,47,36,4,25,1,5,8,5,4,0,12,49,5,19,4,6,16,14,6,10,69,10,33,29,7,8,61, + 12,4,0,3,7,6,3,16,29,27,38,4,21,0,24,3,2,1,19,16,22,2,8,138,11,7,7,3,12,22, + 3,16,5,7,3,53,9,10,32,14,5,7,3,6,22,9,59,26,8,7,58,5,16,11,55,7,4,11,146,91, + 8,13,18,14,6,8,8,31,26,22,6,11,30,11,30,15,18,31,3,48,17,7,6,4,9,2,25,3,35, + 13,13,7,8,4,31,10,8,10,4,3,45,10,23,2,7,259,17,21,13,14,3,26,3,8,27,4,18,9, + 66,7,12,5,8,17,4,23,55,41,51,2,32,26,66,4,21,14,12,65,16,22,17,5,14,2,29,24, + 7,3,36,2,43,53,86,5,28,4,58,13,49,121,6,2,73,2,1,47,4,2,27,10,35,28,27,10, + 17,10,56,7,10,14,28,20,24,40,7,4,7,3,10,11,32,6,6,3,15,11,54,573,2,3,6,2,3, + 14,64,4,16,12,16,42,10,26,4,6,11,69,18,27,2,2,17,22,9,13,22,11,6,1,15,49,3, + 14,1 + }, + + { + 11,11,11,15,47,1,3,1,23,5,8,18,3,23,15,21,1,7,19,10,26,1,17,11,31,21,41,18, + 34,4,9,58,19,3,3,36,5,18,13,3,14,4,9,10,4,19,56,15,3,5,3,11,27,9,4,10,13,4, + 11,6,9,2,18,3,10,19,11,4,53,4,2,2,3,4,58,16,3,0,5,30,2,11,93,10,2,14,10,6,2, + 115,2,25,16,22,38,101,4,18,13,2,145,51,45,15,14,15,13,20,7,24,5,13,14,30,40, + 10,4,107,12,24,14,39,12,6,13,20,7,7,11,5,18,18,45,22,6,39,3,2,1,51,9,11,4, + 13,9,38,44,8,11,9,15,19,9,23,17,17,17,13,9,9,1,10,4,18,6,2,9,5,27,32,72,8, + 37,9,4,10,30,17,20,15,17,66,10,4,73,35,37,6,4,16,117,45,13,4,75,5,24,65,10, + 4,9,4,13,46,5,26,29,10,4,4,52,3,13,18,63,6,14,9,24,277,9,88,2,48,27,123,14, + 61,7,5,10,8,7,90,3,10,3,3,48,17,13,10,18,33,2,19,36,6,21,1,16,12,5,6,2,16, + 15,29,88,28,2,15,6,11,4,6,11,3,3,4,18,9,53,5,4,3,33,8,9,8,6,7,36,9,62,14,2, + 1,10,1,16,7,32,7,23,20,11,10,23,2,1,0,9,16,40,2,81,5,22,8,5,4,37,51,37,10, + 19,57,11,2,92,31,6,39,10,13,16,8,20,6,9,3,10,18,25,23,12,30,6,2,26,7,64,18, + 6,30,12,13,27,7,10,5,3,33,24,99,4,23,4,1,27,7,27,49,8,20,16,3,4,13,9,22,67, + 28,3,10,16,3,2,10,4,8,1,8,19,3,85,6,21,1,9,16,2,30,10,33,12,4,9,3,1,60,38,6, + 24,32,3,14,3,40,8,34,115,5,9,27,5,96,3,40,6,15,5,8,22,112,5,5,25,17,58,2,7, + 36,21,52,1,3,95,12,21,4,11,8,59,24,5,21,4,9,15,8,7,21,3,26,5,11,6,7,17,65, + 14,11,10,2,17,5,12,22,4,4,2,21,8,112,3,34,63,35,2,25,1,2,15,65,23,0,3,5,15, + 26,27,9,5,48,11,15,4,9,5,33,20,15,1,18,19,11,24,40,10,21,74,6,6,32,30,40,5, + 4,7,44,10,25,46,16,12,5,40,7,18,5,18,9,12,8,4,25,5,6,36,4,43,8,9,12,35,17,4, + 8,9,11,27,5,10,17,40,8,12,4,18,9,18,12,20,25,39,42,1,24,13,22,15,7,112,35,3, + 7,17,33,2,5,5,19,8,4,12,24,14,13,2,1,13,6,5,19,11,7,57,0,19,6,117,48,14,8, + 10,51,17,12,14,2,5,8,9,15,4,48,53,13,22,4,25,12,11,19,45,5,2,6,54,22,9,15,9, + 13,2,7,11,29,82,16,46,4,26,14,26,40,22,4,26,6,18,13,4,4,20,3,3,7,12,17,8,9, + 23,6,20,7,25,23,19,5,15,6,23,15,11,19,11,3,17,59,8,18,41,4,54,23,44,75,13, + 20,6,11,2,3,1,13,10,3,7,12,3,4,7,8,30,6,6,7,3,32,9,5,28,6,114,42,13,36,27, + 59,6,93,13,74,8,69,140,3,1,17,48,105,6,11,5,15,1,10,10,14,8,53,0,8,24,60,2, + 6,35,2,12,32,47,16,17,75,2,5,4,37,28,10,5,9,57,4,59,5,12,13,7,90,5,11,5,24, + 22,13,30,1,2,10,9,6,19,3,18,47,2,5,7,9,35,15,3,6,1,21,14,14,18,14,9,12,8,73, + 6,19,3,32,9,14,17,17,5,55,23,6,16,28,3,11,48,4,6,6,6,12,16,30,10,30,27,51, + 18,29,2,3,15,1,76,0,16,33,4,27,3,62,4,10,2,4,8,15,9,41,26,22,2,4,20,4,49,0, + 8,1,57,13,12,39,3,63,10,19,34,35,2,7,8,29,72,4,10,0,77,8,6,7,9,15,21,9,4,1, + 20,23,1,9,18,9,15,36,4,7,6,15,5,7,7,40,2,9,22,2,3,20,4,12,34,13,6,18,15,1, + 38,20,12,7,16,3,19,85,12,16,18,16,2,17,1,13,8,6,12,15,97,17,12,9,3,21,15,12, + 23,44,81,26,30,2,5,17,6,6,0,22,42,19,6,19,41,14,36,7,3,56,7,9,3,2,6,9,69,3, + 15,4,30,28,29,7,9,15,17,17,6,1,6,153,9,33,5,12,14,16,28,3,8,7,14,12,4,6,36, + 9,24,13,13,4,2,9,15,19,9,53,7,13,4,150,17,9,2,6,12,7,3,5,58,19,58,28,8,14,3, + 20,3,0,32,56,7,5,4,27,1,68,4,29,13,5,58,2,9,65,41,27,16,15,12,14,2,10,9,24, + 3,2,9,2,2,3,14,32,10,22,3,13,11,4,6,39,17,0,10,5,5,10,35,16,19,14,1,8,63,19, + 14,8,56,10,2,12,6,12,6,7,16,2,9,9,12,20,73,25,13,21,17,24,5,32,8,12,25,8,14, + 16,5,23,3,7,6,3,11,24,6,30,4,21,13,28,4,6,29,15,5,17,6,26,8,15,8,3,7,7,50, + 11,30,6,2,28,56,16,24,25,23,24,89,31,31,12,7,22,4,10,17,3,3,8,11,13,5,3,27, + 1,12,1,14,8,10,29,2,5,2,2,20,10,0,31,10,21,1,48,3,5,43,4,5,18,13,5,18,25,34, + 18,3,5,22,16,3,4,20,3,9,3,25,6,6,44,21,3,12,7,5,42,3,2,14,4,36,5,3,45,51,15, + 9,11,28,9,7,6,6,12,26,5,14,10,11,42,55,13,21,4,28,6,7,23,27,11,1,41,36,0,32, + 15,26,2,3,23,32,11,2,15,7,29,26,144,33,20,12,7,21,10,7,11,65,46,10,13,20,32, + 4,4,5,19,2,19,15,49,41,1,75,10,11,25,1,2,45,11,8,27,18,10,60,28,29,12,30,19, + 16,4,24,11,19,27,17,49,18,7,40,13,19,22,8,55,12,11,3,6,5,11,8,10,22,5,9,9, + 25,7,17,7,64,1,24,2,12,17,44,4,12,27,21,11,10,7,47,5,9,13,12,38,27,21,7,29, + 7,1,17,3,3,5,48,62,10,3,11,17,15,15,6,3,8,10,8,18,19,13,3,9,7,6,44,9,10,4, + 43,8,6,6,14,20,38,24,2,4,5,5,7,5,9,39,8,44,40,9,19,7,3,15,25,2,37,18,15,9,5, + 8,32,10,5,18,4,7,46,20,17,23,4,11,16,18,31,11,3,11,1,14,1,25,4,27,13,13,39, + 14,6,6,35,6,16,13,11,122,21,15,20,24,10,5,152,15,39,5,20,16,9,14,7,53,6,3,8, + 19,63,32,6,2,3,20,1,19,5,13,42,15,4,6,68,31,46,11,38,10,24,5,5,8,9,12,3,35, + 46,26,16,2,8,4,74,16,44,4,5,1,16,4,14,23,16,69,15,42,31,14,7,7,6,97,14,40,1, + 8,7,34,9,39,19,13,15,10,21,18,10,5,15,38,7,5,12,7,20,15,4,11,6,14,5,17,7,39, + 35,36,18,20,26,22,4,2,36,21,64,0,5,9,10,6,4,1,7,3,1,3,3,4,10,20,90,2,22,48, + 16,23,2,33,40,1,21,21,17,20,8,8,12,4,83,14,48,4,21,3,9,27,5,11,40,15,9,3,16, + 17,9,11,4,24,31,17,3,4,2,11,1,8,4,8,6,41,17,4,13,3,7,17,8,27,5,13,6,10,7,13, + 12,18,13,60,18,3,8,1,12,125,2,7,16,2,11,2,4,7,26,5,9,14,14,16,8,14,7,14,6,9, + 13,9,6,4,26,35,49,36,55,3,9,6,40,26,23,31,19,41,2,10,31,6,54,5,69,16,7,8,16, + 1,5,7,4,22,7,7,5,4,48,11,13,3,98,4,11,19,4,2,14,7,34,7,10,3,2,12,7,6,2,5,118 + }, +}; + +uint8_t data[1<<15]; +string data_str; +void init_data() { + static bool initted = false; + if (initted) return; + initted = true; + + // Repeatability. Kind of. + std::srand(42); + for (int i = 0; i < (int)(sizeof(data)/sizeof(data[0])); ++i) { + data[i] = (uint8_t)rand(); + } + + data_str.assign((char*)data, sizeof(data)); +} + + +BOOST_AUTO_TEST_SUITE( TBufferBaseTest ) + +BOOST_AUTO_TEST_CASE( test_MemoryBuffer_Write_GetBuffer ) { + init_data(); + + for (int d1 = 0; d1 < 3; d1++) { + TMemoryBuffer buffer(16); + int offset = 0; + int index = 0; + + while (offset < 1<<15) { + buffer.write(&data[offset], dist[d1][index]); + offset += dist[d1][index]; + index++; + } + + string output = buffer.getBufferAsString(); + BOOST_CHECK_EQUAL(data_str, output); + } +} + +BOOST_AUTO_TEST_CASE( test_MemoryBuffer_Write_Read ) { + init_data(); + + for (int d1 = 0; d1 < 3; d1++) { + for (int d2 = 0; d2 < 3; d2++) { + TMemoryBuffer buffer(16); + uint8_t data_out[1<<15]; + int offset; + int index; + + offset = 0; + index = 0; + while (offset < 1<<15) { + buffer.write(&data[offset], dist[d1][index]); + offset += dist[d1][index]; + index++; + } + + offset = 0; + index = 0; + while (offset < 1<<15) { + unsigned int got = buffer.read(&data_out[offset], dist[d2][index]); + BOOST_CHECK_EQUAL(got, dist[d2][index]); + offset += dist[d2][index]; + index++; + } + + BOOST_CHECK(!memcmp(data, data_out, sizeof(data))); + } + } +} + +BOOST_AUTO_TEST_CASE( test_MemoryBuffer_Write_ReadString ) { + init_data(); + + for (int d1 = 0; d1 < 3; d1++) { + for (int d2 = 0; d2 < 3; d2++) { + TMemoryBuffer buffer(16); + string output; + int offset; + int index; + + offset = 0; + index = 0; + while (offset < 1<<15) { + buffer.write(&data[offset], dist[d1][index]); + offset += dist[d1][index]; + index++; + } + + offset = 0; + index = 0; + while (offset < 1<<15) { + unsigned int got = buffer.readAppendToString(output, dist[d2][index]); + BOOST_CHECK_EQUAL(got, dist[d2][index]); + offset += dist[d2][index]; + index++; + } + + BOOST_CHECK_EQUAL(output, data_str); + } + } +} + +BOOST_AUTO_TEST_CASE( test_MemoryBuffer_Write_Read_Multi1 ) { + init_data(); + + // Do shorter writes and reads so we don't align to power-of-two boundaries. + + for (int d1 = 0; d1 < 3; d1++) { + for (int d2 = 0; d2 < 3; d2++) { + TMemoryBuffer buffer(16); + uint8_t data_out[1<<15]; + int offset; + int index; + + for (int iter = 0; iter < 6; iter++) { + offset = 0; + index = 0; + while (offset < (1<<15)-42) { + buffer.write(&data[offset], dist[d1][index]); + offset += dist[d1][index]; + index++; + } + + offset = 0; + index = 0; + while (offset < (1<<15)-42) { + buffer.read(&data_out[offset], dist[d2][index]); + offset += dist[d2][index]; + index++; + } + + BOOST_CHECK(!memcmp(data, data_out, (1<<15)-42)); + + // Pull out the extra data. + buffer.read(data_out, 42); + } + } + } +} + +BOOST_AUTO_TEST_CASE( test_MemoryBuffer_Write_Read_Multi2 ) { + init_data(); + + // Do shorter writes and reads so we don't align to power-of-two boundaries. + // Pull the buffer out of the loop so its state gets worked harder. + TMemoryBuffer buffer(16); + + for (int d1 = 0; d1 < 3; d1++) { + for (int d2 = 0; d2 < 3; d2++) { + uint8_t data_out[1<<15]; + int offset; + int index; + + for (int iter = 0; iter < 6; iter++) { + offset = 0; + index = 0; + while (offset < (1<<15)-42) { + buffer.write(&data[offset], dist[d1][index]); + offset += dist[d1][index]; + index++; + } + + offset = 0; + index = 0; + while (offset < (1<<15)-42) { + buffer.read(&data_out[offset], dist[d2][index]); + offset += dist[d2][index]; + index++; + } + + BOOST_CHECK(!memcmp(data, data_out, (1<<15)-42)); + + // Pull out the extra data. + buffer.read(data_out, 42); + } + } + } +} + +BOOST_AUTO_TEST_CASE( test_MemoryBuffer_Write_Read_Incomplete ) { + init_data(); + + // Do shorter writes and reads so we don't align to power-of-two boundaries. + // Pull the buffer out of the loop so its state gets worked harder. + + for (int d1 = 0; d1 < 3; d1++) { + for (int d2 = 0; d2 < 3; d2++) { + TMemoryBuffer buffer(16); + uint8_t data_out[1<<13]; + + int write_offset = 0; + int write_index = 0; + unsigned int to_write = (1<<14)-42; + while (to_write > 0) { + int write_amt = std::min(dist[d1][write_index], to_write); + buffer.write(&data[write_offset], write_amt); + write_offset += write_amt; + write_index++; + to_write -= write_amt; + } + + int read_offset = 0; + int read_index = 0; + unsigned int to_read = (1<<13)-42; + while (to_read > 0) { + int read_amt = std::min(dist[d2][read_index], to_read); + int got = buffer.read(&data_out[read_offset], read_amt); + BOOST_CHECK_EQUAL(got, read_amt); + read_offset += read_amt; + read_index++; + to_read -= read_amt; + } + + BOOST_CHECK(!memcmp(data, data_out, (1<<13)-42)); + + int second_offset = write_offset; + int second_index = write_index-1; + unsigned int to_second = (1<<14)+42; + while (to_second > 0) { + int second_amt = std::min(dist[d1][second_index], to_second); + //printf("%d\n", second_amt); + buffer.write(&data[second_offset], second_amt); + second_offset += second_amt; + second_index++; + to_second -= second_amt; + } + + string output = buffer.getBufferAsString(); + BOOST_CHECK_EQUAL(data_str.substr((1<<13)-42), output); + } + } +} + +BOOST_AUTO_TEST_CASE( test_BufferedTransport_Write ) { + init_data(); + + int sizes[] = { + 12, 15, 16, 17, 20, + 501, 512, 523, + 2000, 2048, 2096, + 1<<14, 1<<17, + }; + + foreach (int size, sizes) { + for (int d1 = 0; d1 < 3; d1++) { + shared_ptr buffer(new TMemoryBuffer(16)); + TBufferedTransport trans(buffer, size); + + int offset = 0; + int index = 0; + while (offset < 1<<15) { + trans.write(&data[offset], dist[d1][index]); + offset += dist[d1][index]; + index++; + } + trans.flush(); + + string output = buffer->getBufferAsString(); + BOOST_CHECK_EQUAL(data_str, output); + } + } +} + +BOOST_AUTO_TEST_CASE( test_BufferedTransport_Read_Full ) { + init_data(); + + int sizes[] = { + 12, 15, 16, 17, 20, + 501, 512, 523, + 2000, 2048, 2096, + 1<<14, 1<<17, + }; + + foreach (int size, sizes) { + for (int d1 = 0; d1 < 3; d1++) { + shared_ptr buffer(new TMemoryBuffer(data, sizeof(data))); + TBufferedTransport trans(buffer, size); + uint8_t data_out[1<<15]; + + int offset = 0; + int index = 0; + while (offset < 1<<15) { + // Note: this doesn't work with "read" because TBufferedTransport + // doesn't try loop over reads, so we get short reads. We don't + // check the return value, so that messes us up. + trans.readAll(&data_out[offset], dist[d1][index]); + offset += dist[d1][index]; + index++; + } + + BOOST_CHECK(!memcmp(data, data_out, sizeof(data))); + } + } +} + +BOOST_AUTO_TEST_CASE( test_BufferedTransport_Read_Short ) { + init_data(); + + int sizes[] = { + 12, 15, 16, 17, 20, + 501, 512, 523, + 2000, 2048, 2096, + 1<<14, 1<<17, + }; + + foreach (int size, sizes) { + for (int d1 = 0; d1 < 3; d1++) { + shared_ptr buffer(new TMemoryBuffer(data, sizeof(data))); + shared_ptr tshort(new TShortReadTransport(buffer, 0.125)); + TBufferedTransport trans(buffer, size); + uint8_t data_out[1<<15]; + + int offset = 0; + int index = 0; + while (offset < 1<<15) { + // Note: this doesn't work with "read" because TBufferedTransport + // doesn't try loop over reads, so we get short reads. We don't + // check the return value, so that messes us up. + trans.readAll(&data_out[offset], dist[d1][index]); + offset += dist[d1][index]; + index++; + } + + BOOST_CHECK(!memcmp(data, data_out, sizeof(data))); + } + } +} + +BOOST_AUTO_TEST_CASE( test_FramedTransport_Write ) { + init_data(); + + int sizes[] = { + 12, 15, 16, 17, 20, + 501, 512, 523, + 2000, 2048, 2096, + 1<<14, 1<<17, + }; + + foreach (int size, sizes) { + for (int d1 = 0; d1 < 3; d1++) { + shared_ptr buffer(new TMemoryBuffer(16)); + TFramedTransport trans(buffer, size); + + int offset = 0; + int index = 0; + while (offset < 1<<15) { + trans.write(&data[offset], dist[d1][index]); + offset += dist[d1][index]; + index++; + } + trans.flush(); + + int32_t frame_size = -1; + buffer->read(reinterpret_cast(&frame_size), sizeof(frame_size)); + frame_size = (int32_t)ntohl((uint32_t)frame_size); + BOOST_CHECK_EQUAL(frame_size, 1<<15); + BOOST_CHECK_EQUAL(data_str.size(), (unsigned int)frame_size); + string output = buffer->getBufferAsString(); + BOOST_CHECK_EQUAL(data_str, output); + } + } +} + +BOOST_AUTO_TEST_CASE( test_FramedTransport_Read ) { + init_data(); + + for (int d1 = 0; d1 < 3; d1++) { + uint8_t data_out[1<<15]; + shared_ptr buffer(new TMemoryBuffer()); + TFramedTransport trans(buffer); + int32_t length = sizeof(data); + length = (int32_t)htonl((uint32_t)length); + buffer->write(reinterpret_cast(&length), sizeof(length)); + buffer->write(data, sizeof(data)); + + int offset = 0; + int index = 0; + while (offset < 1<<15) { + // This should work with read because we have one huge frame. + trans.read(&data_out[offset], dist[d1][index]); + offset += dist[d1][index]; + index++; + } + + BOOST_CHECK(!memcmp(data, data_out, sizeof(data))); + } +} + +BOOST_AUTO_TEST_CASE( test_FramedTransport_Write_Read ) { + init_data(); + + int sizes[] = { + 12, 15, 16, 17, 20, + 501, 512, 523, + 2000, 2048, 2096, + 1<<14, 1<<17, + }; + + int probs[] = { 1, 2, 4, 8, 16, 32, }; + + foreach (int size, sizes) { + foreach (int prob, probs) { + for (int d1 = 0; d1 < 3; d1++) { + shared_ptr buffer(new TMemoryBuffer(16)); + TFramedTransport trans(buffer, size); + uint8_t data_out[1<<15]; + std::vector flush_sizes; + + int write_offset = 0; + int write_index = 0; + int flush_size = 0; + while (write_offset < 1<<15) { + trans.write(&data[write_offset], dist[d1][write_index]); + write_offset += dist[d1][write_index]; + flush_size += dist[d1][write_index]; + write_index++; + if (flush_size > 0 && rand()%prob == 0) { + flush_sizes.push_back(flush_size); + flush_size = 0; + trans.flush(); + } + } + if (flush_size != 0) { + flush_sizes.push_back(flush_size); + flush_size = 0; + trans.flush(); + } + + int read_offset = 0; + int read_index = 0; + foreach (int fsize, flush_sizes) { + // We are exploiting an implementation detail of TFramedTransport. + // The read buffer starts empty and it will never do more than one + // readFrame per read, so we should always get exactly one frame. + int got = trans.read(&data_out[read_offset], 1<<15); + BOOST_CHECK_EQUAL(got, fsize); + read_offset += got; + read_index++; + } + + BOOST_CHECK_EQUAL((unsigned int)read_offset, sizeof(data)); + BOOST_CHECK(!memcmp(data, data_out, sizeof(data))); + } + } + } +} + +BOOST_AUTO_TEST_CASE( test_FramedTransport_Empty_Flush ) { + init_data(); + + string output1("\x00\x00\x00\x01""a", 5); + string output2("\x00\x00\x00\x01""a\x00\x00\x00\x02""bc", 11); + + shared_ptr buffer(new TMemoryBuffer()); + TFramedTransport trans(buffer); + + BOOST_CHECK_EQUAL(buffer->getBufferAsString(), ""); + trans.flush(); + BOOST_CHECK_EQUAL(buffer->getBufferAsString(), ""); + trans.flush(); + trans.flush(); + BOOST_CHECK_EQUAL(buffer->getBufferAsString(), ""); + trans.write((const uint8_t*)"a", 1); + BOOST_CHECK_EQUAL(buffer->getBufferAsString(), ""); + trans.flush(); + BOOST_CHECK_EQUAL(buffer->getBufferAsString(), output1); + trans.flush(); + trans.flush(); + BOOST_CHECK_EQUAL(buffer->getBufferAsString(), output1); + trans.write((const uint8_t*)"bc", 2); + BOOST_CHECK_EQUAL(buffer->getBufferAsString(), output1); + trans.flush(); + BOOST_CHECK_EQUAL(buffer->getBufferAsString(), output2); + trans.flush(); + trans.flush(); + BOOST_CHECK_EQUAL(buffer->getBufferAsString(), output2); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/test/TFDTransportTest.cpp b/test/TFDTransportTest.cpp new file mode 100644 index 00000000..1ec538e3 --- /dev/null +++ b/test/TFDTransportTest.cpp @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +using apache::thrift::transport::TTransportException; +using apache::thrift::transport::TFDTransport; + +class DummyException : std::exception { +}; + +int main() { + { + TFDTransport t(256, TFDTransport::NO_CLOSE_ON_DESTROY); + } + + try { + { + TFDTransport t(256, TFDTransport::CLOSE_ON_DESTROY); + } + std::abort(); + } catch (TTransportException) { + } + + try { + { + TFDTransport t(256, TFDTransport::CLOSE_ON_DESTROY); + throw DummyException(); + } + std::abort(); + } catch (TTransportException&) { + abort(); + } catch (DummyException&) { + } + + return 0; + +} diff --git a/test/TMemoryBufferTest.cpp b/test/TMemoryBufferTest.cpp new file mode 100644 index 00000000..49bd10b5 --- /dev/null +++ b/test/TMemoryBufferTest.cpp @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include "gen-cpp/ThriftTest_types.h" + +BOOST_AUTO_TEST_SUITE( TMemoryBufferTest ) + +BOOST_AUTO_TEST_CASE( test_roundtrip ) { + using apache::thrift::transport::TMemoryBuffer; + using apache::thrift::protocol::TBinaryProtocol; + using boost::shared_ptr; + + shared_ptr strBuffer(new TMemoryBuffer()); + shared_ptr binaryProtcol(new TBinaryProtocol(strBuffer)); + + thrift::test::Xtruct a; + a.i32_thing = 10; + a.i64_thing = 30; + a.string_thing ="holla back a"; + + a.write(binaryProtcol.get()); + std::string serialized = strBuffer->getBufferAsString(); + + shared_ptr strBuffer2(new TMemoryBuffer()); + shared_ptr binaryProtcol2(new TBinaryProtocol(strBuffer2)); + + strBuffer2->resetBuffer((uint8_t*)serialized.data(), serialized.length()); + thrift::test::Xtruct a2; + a2.read(binaryProtcol2.get()); + + assert(a == a2); + } + +BOOST_AUTO_TEST_CASE( test_copy ) + { + using apache::thrift::transport::TMemoryBuffer; + using std::string; + using std::cout; + using std::endl; + + string* str1 = new string("abcd1234"); + const char* data1 = str1->data(); + TMemoryBuffer buf((uint8_t*)str1->data(), str1->length(), TMemoryBuffer::COPY); + delete str1; + string* str2 = new string("plsreuse"); + bool obj_reuse = (str1 == str2); + bool dat_reuse = (data1 == str2->data()); + cout << "Object reuse: " << obj_reuse << " Data reuse: " << dat_reuse + << ((obj_reuse && dat_reuse) ? " YAY!" : "") << endl; + delete str2; + + string str3 = "wxyz", str4 = "6789"; + buf.readAppendToString(str3, 4); + buf.readAppendToString(str4, INT_MAX); + + assert(str3 == "wxyzabcd"); + assert(str4 == "67891234"); + } + +BOOST_AUTO_TEST_CASE( test_exceptions ) + { + using apache::thrift::transport::TTransportException; + using apache::thrift::transport::TMemoryBuffer; + using std::string; + + char data[] = "foo\0bar"; + + TMemoryBuffer buf1((uint8_t*)data, 7, TMemoryBuffer::OBSERVE); + string str = buf1.getBufferAsString(); + assert(str.length() == 7); + buf1.resetBuffer(); + try { + buf1.write((const uint8_t*)"foo", 3); + assert(false); + } catch (TTransportException& ex) {} + + TMemoryBuffer buf2((uint8_t*)data, 7, TMemoryBuffer::COPY); + try { + buf2.write((const uint8_t*)"bar", 3); + } catch (TTransportException& ex) { + assert(false); + } + } + +BOOST_AUTO_TEST_SUITE_END() diff --git a/test/TPipedTransportTest.cpp b/test/TPipedTransportTest.cpp new file mode 100644 index 00000000..5708fd21 --- /dev/null +++ b/test/TPipedTransportTest.cpp @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +using namespace std; +using boost::shared_ptr; +using apache::thrift::transport::TTransportException; +using apache::thrift::transport::TPipedTransport; +using apache::thrift::transport::TMemoryBuffer; + +int main() { + shared_ptr underlying(new TMemoryBuffer); + shared_ptr pipe(new TMemoryBuffer); + shared_ptr trans(new TPipedTransport(underlying, pipe)); + + uint8_t buffer[4]; + + underlying->write((uint8_t*)"abcd", 4); + trans->readAll(buffer, 2); + assert( string((char*)buffer, 2) == "ab" ); + trans->readEnd(); + assert( pipe->getBufferAsString() == "ab" ); + pipe->resetBuffer(); + underlying->write((uint8_t*)"ef", 2); + trans->readAll(buffer, 2); + assert( string((char*)buffer, 2) == "cd" ); + trans->readAll(buffer, 2); + assert( string((char*)buffer, 2) == "ef" ); + trans->readEnd(); + assert( pipe->getBufferAsString() == "cdef" ); + + return 0; + +} diff --git a/test/ThriftTest.thrift b/test/ThriftTest.thrift new file mode 100644 index 00000000..3517640a --- /dev/null +++ b/test/ThriftTest.thrift @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +namespace java thrift.test +namespace cpp thrift.test +namespace rb Thrift.Test +namespace perl ThriftTest +namespace csharp Thrift.Test + +enum Numberz +{ + ONE = 1, + TWO, + THREE, + FIVE = 5, + SIX, + EIGHT = 8 +} + +typedef i64 UserId + +struct Bonk +{ + 1: string message, + 2: i32 type +} + +struct Bools { + 1: bool im_true, + 2: bool im_false, +} + +struct Xtruct +{ + 1: string string_thing, + 4: byte byte_thing, + 9: i32 i32_thing, + 11: i64 i64_thing +} + +struct Xtruct2 +{ + 1: byte byte_thing, + 2: Xtruct struct_thing, + 3: i32 i32_thing +} + +struct Xtruct3 +{ + 1: string string_thing, + 4: i32 changed, + 9: i32 i32_thing, + 11: i64 i64_thing +} + + +struct Insanity +{ + 1: map userMap, + 2: list xtructs +} + +struct CrazyNesting { + 1: string string_field, + 2: optional set set_field, + 3: required list< map,map>>>>> list_field +} + +exception Xception { + 1: i32 errorCode, + 2: string message +} + +exception Xception2 { + 1: i32 errorCode, + 2: Xtruct struct_thing +} + +struct EmptyStruct {} + +struct OneField { + 1: EmptyStruct field +} + +service ThriftTest +{ + void testVoid(), + string testString(1: string thing), + byte testByte(1: byte thing), + i32 testI32(1: i32 thing), + i64 testI64(1: i64 thing), + double testDouble(1: double thing), + Xtruct testStruct(1: Xtruct thing), + Xtruct2 testNest(1: Xtruct2 thing), + map testMap(1: map thing), + set testSet(1: set thing), + list testList(1: list thing), + Numberz testEnum(1: Numberz thing), + UserId testTypedef(1: UserId thing), + + map> testMapMap(1: i32 hello), + + /* So you think you've got this all worked, out eh? */ + map> testInsanity(1: Insanity argument), + + /* Multiple parameters */ + Xtruct testMulti(1: byte arg0, 2: i32 arg1, 3: i64 arg2, 4: map arg3, 5: Numberz arg4, 6: UserId arg5), + + /* Exception specifier */ + + void testException(1: string arg) throws(1: Xception err1), + + /* Multiple exceptions specifier */ + + Xtruct testMultiException(1: string arg0, 2: string arg1) throws(1: Xception err1, 2: Xception2 err2) + + /* Test oneway void */ + oneway void testOneway(1:i32 secondsToSleep) +} + +service SecondService +{ + void blahBlah() +} + +struct VersioningTestV1 { + 1: i32 begin_in_both, + 3: string old_string, + 12: i32 end_in_both +} + +struct VersioningTestV2 { + 1: i32 begin_in_both, + + 2: i32 newint, + 3: byte newbyte, + 4: i16 newshort, + 5: i64 newlong, + 6: double newdouble + 7: Bonk newstruct, + 8: list newlist, + 9: set newset, + 10: map newmap, + 11: string newstring, + 12: i32 end_in_both +} + +struct ListTypeVersioningV1 { + 1: list myints; + 2: string hello; +} + +struct ListTypeVersioningV2 { + 1: list strings; + 2: string hello; +} diff --git a/test/ThriftTest_extras.cpp b/test/ThriftTest_extras.cpp new file mode 100644 index 00000000..b78f2763 --- /dev/null +++ b/test/ThriftTest_extras.cpp @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// Extra functions required for ThriftTest_types to work + +#include +#include "gen-cpp/ThriftTest_types.h" + + +namespace thrift { namespace test { + +bool Insanity::operator<(thrift::test::Insanity const& other) const { + using apache::thrift::ThriftDebugString; + return ThriftDebugString(*this) < ThriftDebugString(other); +} + +}} diff --git a/test/UnitTestMain.cpp b/test/UnitTestMain.cpp new file mode 100644 index 00000000..d90c54f4 --- /dev/null +++ b/test/UnitTestMain.cpp @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#define BOOST_TEST_MODULE thrift +#include diff --git a/test/ZlibTest.cpp b/test/ZlibTest.cpp new file mode 100644 index 00000000..45d3ecc4 --- /dev/null +++ b/test/ZlibTest.cpp @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* +thrift --gen cpp DebugProtoTest.thrift +g++ -Wall -g -I../lib/cpp/src -I/usr/local/include/boost-1_33_1 \ + ZlibTest.cpp \ + ../lib/cpp/.libs/libthriftz.a ../lib/cpp/.libs/libthrift.a \ + -lz -o ZlibTest +./ZlibTest +*/ + +#include +#include +#include +#include +#include +#include + + +// Distributions of reads and writes meant to approximate a real load, +// mixing up small and large while also hitting various boundary conditions. +// Generated by Python: int(random.lognormvariate(2.5, 1)) +unsigned int dist[][5000] = { + { 1<<15 }, + + { + 5,13,9,1,8,9,11,13,18,48,24,13,21,13,5,11,35,2,4,20,17,72,27,14,15,4,7,26, + 12,1,14,9,2,16,29,41,7,24,4,27,14,4,1,4,25,3,6,34,10,8,50,2,14,13,55,29,3, + 43,53,49,14,4,10,32,27,48,1,3,1,11,5,17,16,51,17,30,15,11,9,2,2,11,52,12,2, + 13,94,1,19,1,38,2,8,43,8,33,7,30,8,17,22,2,15,14,12,34,2,12,6,37,29,74,3, + 165,16,11,17,5,14,3,10,7,37,11,24,7,1,3,12,37,8,9,34,17,12,8,21,13,37,1,4, + 30,14,78,4,15,2,40,37,17,12,36,82,14,4,1,4,7,17,11,16,88,77,2,3,15,3,34,11, + 5,79,22,34,8,4,4,40,22,24,28,9,13,3,34,27,9,16,39,16,39,13,2,4,3,41,26,10,4, + 33,4,7,12,5,6,3,10,30,8,21,16,58,19,9,0,47,7,13,11,19,15,7,53,57,2,13,28,22, + 3,16,9,25,33,12,40,7,12,64,7,14,24,44,9,2,14,11,2,58,1,26,30,11,9,5,24,7,9, + 94,2,10,21,5,5,4,5,6,179,9,18,2,7,13,31,41,17,4,36,3,21,6,26,8,15,18,44,27, + 11,9,25,7,0,14,2,12,20,23,13,2,163,9,5,15,65,2,14,6,8,98,11,15,14,34,2,3,10, + 22,9,92,7,10,32,67,13,3,4,35,8,2,1,5,0,26,381,7,27,8,2,16,93,4,19,5,8,25,9, + 31,14,4,21,5,3,9,22,56,4,18,3,11,18,6,4,3,40,12,16,110,8,35,14,1,18,40,9,12, + 14,3,11,7,57,13,18,116,53,19,22,7,16,11,5,8,21,16,1,75,21,20,1,28,2,6,1,7, + 19,38,5,6,9,9,4,1,7,55,36,62,5,4,4,24,15,1,12,35,48,20,5,17,1,5,26,15,4,54, + 13,5,5,15,5,19,32,29,31,7,6,40,7,80,11,18,8,128,48,6,12,84,13,4,7,2,13,9,16, + 17,3,254,1,4,181,8,44,7,6,24,27,9,23,14,34,16,22,25,10,3,3,4,4,12,2,12,6,7, + 13,58,13,6,11,19,53,11,66,18,19,10,4,13,2,5,49,58,1,67,7,21,64,14,11,14,8,3, + 26,33,91,31,20,7,9,42,39,4,3,55,11,10,0,7,4,75,8,12,0,27,3,8,9,0,12,12,23, + 28,23,20,4,13,30,2,22,20,19,30,6,22,2,6,4,24,7,19,55,86,5,33,2,161,6,7,1,62, + 13,3,72,12,12,9,7,12,10,5,10,29,1,5,22,13,13,5,2,12,3,7,14,18,2,3,46,21,17, + 15,19,3,27,5,16,45,31,10,8,17,18,18,3,7,24,6,55,9,3,6,12,10,12,8,91,9,4,4,4, + 27,29,16,5,7,22,43,28,11,14,8,11,28,109,55,71,40,3,8,22,26,15,44,3,25,29,5, + 3,32,17,12,3,29,27,25,15,11,8,40,39,38,17,3,9,11,2,32,11,6,20,48,75,27,3,7, + 54,12,95,12,7,24,23,2,13,8,15,16,5,12,4,17,7,19,88,2,6,13,115,45,12,21,2,86, + 74,9,7,5,16,32,16,2,21,18,6,34,5,18,260,7,12,16,44,19,92,31,7,8,2,9,0,0,15, + 8,38,4,8,20,18,2,83,3,3,4,9,5,3,10,3,5,29,15,7,11,8,48,17,23,2,17,4,11,22, + 21,64,8,8,4,19,95,0,17,28,9,11,20,71,5,11,18,12,13,45,49,4,1,33,32,23,13,5, + 52,2,2,16,3,4,7,12,2,1,12,6,24,1,22,155,21,3,45,4,12,44,26,5,40,36,9,9,8,20, + 35,31,3,2,32,50,10,8,37,2,75,35,22,15,192,8,11,23,1,4,29,6,8,8,5,12,18,32,4, + 7,12,2,0,0,9,5,48,11,35,3,1,123,6,29,8,11,8,23,51,16,6,63,12,2,5,4,14,2,15, + 7,14,3,2,7,17,32,8,8,10,1,23,62,2,49,6,49,47,23,3,20,7,11,39,10,24,6,15,5,5, + 11,8,16,36,8,13,20,3,10,44,7,52,7,10,36,6,15,10,5,11,4,14,19,17,10,12,3,6, + 23,4,13,94,70,7,36,7,38,7,28,8,4,15,3,19,4,33,39,21,109,4,80,6,40,4,432,4,4, + 7,8,3,31,8,28,37,34,10,2,21,5,22,0,7,36,14,12,6,24,1,21,5,9,2,29,20,54,113, + 13,31,39,27,6,0,27,4,5,2,43,7,8,57,8,62,7,9,12,22,90,30,6,19,7,10,20,6,5,58, + 32,30,41,4,10,25,13,3,8,7,10,2,9,6,151,44,16,12,16,20,8,3,18,11,17,4,10,45, + 15,8,56,38,52,25,40,14,4,17,15,8,2,19,7,8,26,30,2,3,180,8,26,17,38,35,5,16, + 28,5,15,56,13,14,18,9,15,83,27,3,9,4,11,8,27,27,44,10,12,8,3,48,14,7,9,4,4, + 8,4,5,9,122,8,14,12,19,17,21,4,29,63,21,17,10,12,18,47,10,10,53,4,18,16,4,8, + 118,9,5,12,9,11,9,3,12,32,3,23,2,15,3,3,30,3,17,235,15,22,9,299,14,17,1,5, + 16,8,3,7,3,13,2,7,6,4,8,66,2,13,6,15,16,47,3,36,5,7,10,24,1,9,9,8,13,16,26, + 12,7,24,21,18,49,23,39,10,41,4,13,4,27,11,12,12,19,4,147,8,10,9,40,21,2,83, + 10,5,6,11,25,9,50,57,40,12,12,21,1,3,24,23,9,3,9,13,2,3,12,57,8,11,13,15,26, + 15,10,47,36,4,25,1,5,8,5,4,0,12,49,5,19,4,6,16,14,6,10,69,10,33,29,7,8,61, + 12,4,0,3,7,6,3,16,29,27,38,4,21,0,24,3,2,1,19,16,22,2,8,138,11,7,7,3,12,22, + 3,16,5,7,3,53,9,10,32,14,5,7,3,6,22,9,59,26,8,7,58,5,16,11,55,7,4,11,146,91, + 8,13,18,14,6,8,8,31,26,22,6,11,30,11,30,15,18,31,3,48,17,7,6,4,9,2,25,3,35, + 13,13,7,8,4,31,10,8,10,4,3,45,10,23,2,7,259,17,21,13,14,3,26,3,8,27,4,18,9, + 66,7,12,5,8,17,4,23,55,41,51,2,32,26,66,4,21,14,12,65,16,22,17,5,14,2,29,24, + 7,3,36,2,43,53,86,5,28,4,58,13,49,121,6,2,73,2,1,47,4,2,27,10,35,28,27,10, + 17,10,56,7,10,14,28,20,24,40,7,4,7,3,10,11,32,6,6,3,15,11,54,573,2,3,6,2,3, + 14,64,4,16,12,16,42,10,26,4,6,11,69,18,27,2,2,17,22,9,13,22,11,6,1,15,49,3, + 14,1 + }, + + { + 11,11,11,15,47,1,3,1,23,5,8,18,3,23,15,21,1,7,19,10,26,1,17,11,31,21,41,18, + 34,4,9,58,19,3,3,36,5,18,13,3,14,4,9,10,4,19,56,15,3,5,3,11,27,9,4,10,13,4, + 11,6,9,2,18,3,10,19,11,4,53,4,2,2,3,4,58,16,3,0,5,30,2,11,93,10,2,14,10,6,2, + 115,2,25,16,22,38,101,4,18,13,2,145,51,45,15,14,15,13,20,7,24,5,13,14,30,40, + 10,4,107,12,24,14,39,12,6,13,20,7,7,11,5,18,18,45,22,6,39,3,2,1,51,9,11,4, + 13,9,38,44,8,11,9,15,19,9,23,17,17,17,13,9,9,1,10,4,18,6,2,9,5,27,32,72,8, + 37,9,4,10,30,17,20,15,17,66,10,4,73,35,37,6,4,16,117,45,13,4,75,5,24,65,10, + 4,9,4,13,46,5,26,29,10,4,4,52,3,13,18,63,6,14,9,24,277,9,88,2,48,27,123,14, + 61,7,5,10,8,7,90,3,10,3,3,48,17,13,10,18,33,2,19,36,6,21,1,16,12,5,6,2,16, + 15,29,88,28,2,15,6,11,4,6,11,3,3,4,18,9,53,5,4,3,33,8,9,8,6,7,36,9,62,14,2, + 1,10,1,16,7,32,7,23,20,11,10,23,2,1,0,9,16,40,2,81,5,22,8,5,4,37,51,37,10, + 19,57,11,2,92,31,6,39,10,13,16,8,20,6,9,3,10,18,25,23,12,30,6,2,26,7,64,18, + 6,30,12,13,27,7,10,5,3,33,24,99,4,23,4,1,27,7,27,49,8,20,16,3,4,13,9,22,67, + 28,3,10,16,3,2,10,4,8,1,8,19,3,85,6,21,1,9,16,2,30,10,33,12,4,9,3,1,60,38,6, + 24,32,3,14,3,40,8,34,115,5,9,27,5,96,3,40,6,15,5,8,22,112,5,5,25,17,58,2,7, + 36,21,52,1,3,95,12,21,4,11,8,59,24,5,21,4,9,15,8,7,21,3,26,5,11,6,7,17,65, + 14,11,10,2,17,5,12,22,4,4,2,21,8,112,3,34,63,35,2,25,1,2,15,65,23,0,3,5,15, + 26,27,9,5,48,11,15,4,9,5,33,20,15,1,18,19,11,24,40,10,21,74,6,6,32,30,40,5, + 4,7,44,10,25,46,16,12,5,40,7,18,5,18,9,12,8,4,25,5,6,36,4,43,8,9,12,35,17,4, + 8,9,11,27,5,10,17,40,8,12,4,18,9,18,12,20,25,39,42,1,24,13,22,15,7,112,35,3, + 7,17,33,2,5,5,19,8,4,12,24,14,13,2,1,13,6,5,19,11,7,57,0,19,6,117,48,14,8, + 10,51,17,12,14,2,5,8,9,15,4,48,53,13,22,4,25,12,11,19,45,5,2,6,54,22,9,15,9, + 13,2,7,11,29,82,16,46,4,26,14,26,40,22,4,26,6,18,13,4,4,20,3,3,7,12,17,8,9, + 23,6,20,7,25,23,19,5,15,6,23,15,11,19,11,3,17,59,8,18,41,4,54,23,44,75,13, + 20,6,11,2,3,1,13,10,3,7,12,3,4,7,8,30,6,6,7,3,32,9,5,28,6,114,42,13,36,27, + 59,6,93,13,74,8,69,140,3,1,17,48,105,6,11,5,15,1,10,10,14,8,53,0,8,24,60,2, + 6,35,2,12,32,47,16,17,75,2,5,4,37,28,10,5,9,57,4,59,5,12,13,7,90,5,11,5,24, + 22,13,30,1,2,10,9,6,19,3,18,47,2,5,7,9,35,15,3,6,1,21,14,14,18,14,9,12,8,73, + 6,19,3,32,9,14,17,17,5,55,23,6,16,28,3,11,48,4,6,6,6,12,16,30,10,30,27,51, + 18,29,2,3,15,1,76,0,16,33,4,27,3,62,4,10,2,4,8,15,9,41,26,22,2,4,20,4,49,0, + 8,1,57,13,12,39,3,63,10,19,34,35,2,7,8,29,72,4,10,0,77,8,6,7,9,15,21,9,4,1, + 20,23,1,9,18,9,15,36,4,7,6,15,5,7,7,40,2,9,22,2,3,20,4,12,34,13,6,18,15,1, + 38,20,12,7,16,3,19,85,12,16,18,16,2,17,1,13,8,6,12,15,97,17,12,9,3,21,15,12, + 23,44,81,26,30,2,5,17,6,6,0,22,42,19,6,19,41,14,36,7,3,56,7,9,3,2,6,9,69,3, + 15,4,30,28,29,7,9,15,17,17,6,1,6,153,9,33,5,12,14,16,28,3,8,7,14,12,4,6,36, + 9,24,13,13,4,2,9,15,19,9,53,7,13,4,150,17,9,2,6,12,7,3,5,58,19,58,28,8,14,3, + 20,3,0,32,56,7,5,4,27,1,68,4,29,13,5,58,2,9,65,41,27,16,15,12,14,2,10,9,24, + 3,2,9,2,2,3,14,32,10,22,3,13,11,4,6,39,17,0,10,5,5,10,35,16,19,14,1,8,63,19, + 14,8,56,10,2,12,6,12,6,7,16,2,9,9,12,20,73,25,13,21,17,24,5,32,8,12,25,8,14, + 16,5,23,3,7,6,3,11,24,6,30,4,21,13,28,4,6,29,15,5,17,6,26,8,15,8,3,7,7,50, + 11,30,6,2,28,56,16,24,25,23,24,89,31,31,12,7,22,4,10,17,3,3,8,11,13,5,3,27, + 1,12,1,14,8,10,29,2,5,2,2,20,10,0,31,10,21,1,48,3,5,43,4,5,18,13,5,18,25,34, + 18,3,5,22,16,3,4,20,3,9,3,25,6,6,44,21,3,12,7,5,42,3,2,14,4,36,5,3,45,51,15, + 9,11,28,9,7,6,6,12,26,5,14,10,11,42,55,13,21,4,28,6,7,23,27,11,1,41,36,0,32, + 15,26,2,3,23,32,11,2,15,7,29,26,144,33,20,12,7,21,10,7,11,65,46,10,13,20,32, + 4,4,5,19,2,19,15,49,41,1,75,10,11,25,1,2,45,11,8,27,18,10,60,28,29,12,30,19, + 16,4,24,11,19,27,17,49,18,7,40,13,19,22,8,55,12,11,3,6,5,11,8,10,22,5,9,9, + 25,7,17,7,64,1,24,2,12,17,44,4,12,27,21,11,10,7,47,5,9,13,12,38,27,21,7,29, + 7,1,17,3,3,5,48,62,10,3,11,17,15,15,6,3,8,10,8,18,19,13,3,9,7,6,44,9,10,4, + 43,8,6,6,14,20,38,24,2,4,5,5,7,5,9,39,8,44,40,9,19,7,3,15,25,2,37,18,15,9,5, + 8,32,10,5,18,4,7,46,20,17,23,4,11,16,18,31,11,3,11,1,14,1,25,4,27,13,13,39, + 14,6,6,35,6,16,13,11,122,21,15,20,24,10,5,152,15,39,5,20,16,9,14,7,53,6,3,8, + 19,63,32,6,2,3,20,1,19,5,13,42,15,4,6,68,31,46,11,38,10,24,5,5,8,9,12,3,35, + 46,26,16,2,8,4,74,16,44,4,5,1,16,4,14,23,16,69,15,42,31,14,7,7,6,97,14,40,1, + 8,7,34,9,39,19,13,15,10,21,18,10,5,15,38,7,5,12,7,20,15,4,11,6,14,5,17,7,39, + 35,36,18,20,26,22,4,2,36,21,64,0,5,9,10,6,4,1,7,3,1,3,3,4,10,20,90,2,22,48, + 16,23,2,33,40,1,21,21,17,20,8,8,12,4,83,14,48,4,21,3,9,27,5,11,40,15,9,3,16, + 17,9,11,4,24,31,17,3,4,2,11,1,8,4,8,6,41,17,4,13,3,7,17,8,27,5,13,6,10,7,13, + 12,18,13,60,18,3,8,1,12,125,2,7,16,2,11,2,4,7,26,5,9,14,14,16,8,14,7,14,6,9, + 13,9,6,4,26,35,49,36,55,3,9,6,40,26,23,31,19,41,2,10,31,6,54,5,69,16,7,8,16, + 1,5,7,4,22,7,7,5,4,48,11,13,3,98,4,11,19,4,2,14,7,34,7,10,3,2,12,7,6,2,5,118 + }, +}; + + +int main() { + using namespace std; + using namespace boost; + using namespace apache::thrift::transport; + + char *file_names[] = { + // Highly compressible. + "./gen-cpp/DebugProtoTest_types.cpp", + // Uncompressible. + "/dev/urandom", + // Null-terminated. + NULL, + }; + + + for (char** fnamep = &file_names[0]; *fnamep != NULL; fnamep++) { + ifstream file(*fnamep); + char buf[32*1024]; + file.read(buf, sizeof(buf)); + vector content(buf, buf+file.gcount()); + vector mirror; + file.close(); + + assert(content.size() == 32*1024); + + // Let's just start with the big dog! + { + mirror.clear(); + shared_ptr membuf(new TMemoryBuffer()); + shared_ptr zlib_trans(new TZlibTransport(membuf, false)); + zlib_trans->write(&content[0], content.size()); + zlib_trans->flush(); + mirror.resize(content.size()); + uint32_t got = zlib_trans->read(&mirror[0], mirror.size()); + assert(got == content.size()); + assert(mirror == content); + zlib_trans->verifyChecksum(); + } + + // This one is tricky. I separate the last byte of the stream out + // into a separate crbuf_. The last byte is part of the checksum, + // so the entire read goes fine, but when I go to verify the checksum + // it isn't there. The original implementation complained that + // the stream was not complete. I'm about to go fix that. + // It worked. Awesome. + { + mirror.clear(); + shared_ptr membuf(new TMemoryBuffer()); + shared_ptr zlib_trans(new TZlibTransport(membuf, false)); + zlib_trans->write(&content[0], content.size()); + zlib_trans->flush(); + string tmp_buf; + membuf->appendBufferToString(tmp_buf); + zlib_trans.reset(new TZlibTransport(membuf, false, + TZlibTransport::DEFAULT_URBUF_SIZE, + tmp_buf.length()-1)); + mirror.resize(content.size()); + uint32_t got = zlib_trans->read(&mirror[0], mirror.size()); + assert(got == content.size()); + assert(mirror == content); + zlib_trans->verifyChecksum(); + } + + // Make sure we still get that "not complete" error if + // it really isn't complete. + { + mirror.clear(); + shared_ptr membuf(new TMemoryBuffer()); + shared_ptr zlib_trans(new TZlibTransport(membuf, false)); + zlib_trans->write(&content[0], content.size()); + zlib_trans->flush(); + string tmp_buf; + membuf->appendBufferToString(tmp_buf); + tmp_buf.erase(tmp_buf.length() - 1); + membuf->resetFromString(tmp_buf); + mirror.resize(content.size()); + uint32_t got = zlib_trans->read(&mirror[0], mirror.size()); + assert(got == content.size()); + assert(mirror == content); + try { + zlib_trans->verifyChecksum(); + assert(false); + } catch (TTransportException& ex) { + assert(ex.getType() == TTransportException::CORRUPTED_DATA); + } + } + + // Try it with a mix of read/write sizes. + for (int d1 = 0; d1 < 3; d1++) { + for (int d2 = 0; d2 < 3; d2++) { + mirror.clear(); + shared_ptr membuf(new TMemoryBuffer()); + shared_ptr zlib_trans(new TZlibTransport(membuf, false)); + int idx; + unsigned int tot; + + idx = 0; + tot = 0; + while (tot < content.size()) { + zlib_trans->write(&content[tot], dist[d1][idx]); + tot += dist[d1][idx]; + idx++; + } + + zlib_trans->flush(); + mirror.resize(content.size()); + + idx = 0; + tot = 0; + while (tot < mirror.size()) { + uint32_t got = zlib_trans->read(&mirror[tot], dist[d2][idx]); + assert(got == dist[d2][idx]); + tot += dist[d2][idx]; + idx++; + } + + assert(mirror == content); + zlib_trans->verifyChecksum(); + } + } + + // Verify checksum checking. + { + mirror.clear(); + shared_ptr membuf(new TMemoryBuffer()); + shared_ptr zlib_trans(new TZlibTransport(membuf, false)); + zlib_trans->write(&content[0], content.size()); + zlib_trans->flush(); + string tmp_buf; + membuf->appendBufferToString(tmp_buf); + tmp_buf[57]++; + membuf->resetFromString(tmp_buf); + mirror.resize(content.size()); + try { + zlib_trans->read(&mirror[0], mirror.size()); + zlib_trans->verifyChecksum(); + assert(false); + } catch (TZlibTransportException& ex) { + assert(ex.getType() == TTransportException::INTERNAL_ERROR); + } + } + } + + return 0; +} diff --git a/test/cpp/Stress-test.mk b/test/cpp/Stress-test.mk new file mode 100644 index 00000000..1a753b3a --- /dev/null +++ b/test/cpp/Stress-test.mk @@ -0,0 +1,71 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +ifndef thrift_home +thrift_home=../.. +endif #thrift_home + +target: all + +ifndef boost_home +#boost_home=../../../../../thirdparty/boost_1_33_1 +boost_home=/usr/local/include/boost-1_33_1 +endif #boost_home +target: all + +include_paths = $(thrift_home)/lib/cpp/src \ + $(thrift_home)/lib/cpp \ + $(thrift_home)/ \ + $(boost_home) + +include_flags = $(patsubst %,-I%, $(include_paths)) + +# Tools +ifndef THRIFT +THRIFT = ../../compiler/cpp/thrift +endif # THRIFT + +CC = g++ +LD = g++ + +# Compiler flags +DCFL = -Wall -O3 -g -I./gen-cpp $(include_flags) -L$(thrift_home)/lib/cpp/.libs -lthrift -lthriftnb -levent +CFL = -Wall -O3 -I./gen-cpp $(include_flags) -L$(thrift_home)/lib/cpp/.libs -lthrift -lthriftnb -levent + +all: stress-test stress-test-nb + +debug: stress-test-debug stress-test-debug-nb + +stubs: ../StressTest.thrift + $(THRIFT) --gen cpp --gen php ../StressTest.thrift + +stress-test-debug-nb: stubs + g++ -o stress-test-nb $(DCFL) src/nb-main.cpp ./gen-cpp/Service.cpp gen-cpp/StressTest_types.cpp + +stress-test-nb: stubs + g++ -o stress-test-nb $(CFL) src/nb-main.cpp ./gen-cpp/Service.cpp gen-cpp/StressTest_types.cpp + +stress-test-debug: stubs + g++ -o stress-test $(DCFL) src/main.cpp ./gen-cpp/Service.cpp gen-cpp/StressTest_types.cpp + +stress-test: stubs + g++ -o stress-test $(CFL) src/main.cpp ./gen-cpp/Service.cpp gen-cpp/StressTest_types.cpp + +clean: + rm -fr stress-test stress-test-nb gen-cpp diff --git a/test/cpp/Thrift-test.mk b/test/cpp/Thrift-test.mk new file mode 100644 index 00000000..fb3e38bc --- /dev/null +++ b/test/cpp/Thrift-test.mk @@ -0,0 +1,78 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# Makefile for Thrift test project. +# Default target is everything + +ifndef thrift_home +thrift_home=../.. +endif #thrift_home + +target: all + +ifndef boost_home +#boost_home=../../../../../thirdparty/boost_1_33_1 +boost_home=/usr/local/include/boost-1_33_1 +endif #boost_home +target: all + +include_paths = $(thrift_home)/lib/cpp/src \ + $(boost_home) + +include_flags = $(patsubst %,-I%, $(include_paths)) + +# Tools +ifndef THRIFT +THRIFT = ../../compiler/cpp/thrift +endif # THRIFT + +CC = g++ +LD = g++ + +# Compiler flags +DCFL = -Wall -O3 -g -I. -I./gen-cpp $(include_flags) -L$(thrift_home)/lib/cpp/.libs -lthrift -lthriftnb -levent +LFL = -L$(thrift_home)/lib/cpp/.libs -lthrift -lthriftnb -levent +CCFL = -Wall -O3 -I. -I./gen-cpp $(include_flags) +CFL = $(CCFL) $(LFL) + +all: server client + +debug: server-debug client-debug + +stubs: ../ThriftTest.thrift + $(THRIFT) --gen cpp ../ThriftTest.thrift + +server-debug: stubs + g++ -o TestServer $(DCFL) src/TestServer.cpp ./gen-cpp/ThriftTest.cpp ./gen-cpp/ThriftTest_types.cpp ../ThriftTest_extras.cpp + +client-debug: stubs + g++ -o TestClient $(DCFL) src/TestClient.cpp ./gen-cpp/ThriftTest.cpp ./gen-cpp/ThriftTest_types.cpp ../ThriftTest_extras.cpp + +server: stubs + g++ -o TestServer $(CFL) src/TestServer.cpp ./gen-cpp/ThriftTest.cpp ./gen-cpp/ThriftTest_types.cpp ../ThriftTest_extras.cpp + +client: stubs + g++ -o TestClient $(CFL) src/TestClient.cpp ./gen-cpp/ThriftTest.cpp ./gen-cpp/ThriftTest_types.cpp ../ThriftTest_extras.cpp + +small: + $(THRIFT) --gen cpp ../SmallTest.thrift + g++ -c $(CCFL) ./gen-cpp/SmallService.cpp ./gen-cpp/SmallTest_types.cpp + +clean: + rm -fr *.o TestServer TestClient gen-cpp diff --git a/test/cpp/realloc/Makefile b/test/cpp/realloc/Makefile new file mode 100644 index 00000000..f89bbb3c --- /dev/null +++ b/test/cpp/realloc/Makefile @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# This probably should not go into "make check", because it is an experiment, +# not a test. Specifically, it is meant to determine how likely realloc is +# to avoid a copy. This is poorly documented. + +run: realloc_test + for it in 1 4 64 ; do \ + for nb in 1 8 64 512 ; do \ + for mins in 64 512 ; do \ + for maxs in 2048 262144 ; do \ + for db in 8 64 ; do \ + ./realloc_test $$nb $$mins $$maxs $$db $$it \ + ; done \ + ; done \ + ; done \ + ; done \ + ; done \ + > raw_stats + +CFLAGS = -Wall -g -std=c99 +LDLIBS = -ldl +realloc_test: realloc_test.c diff --git a/test/cpp/realloc/realloc_test.c b/test/cpp/realloc/realloc_test.c new file mode 100644 index 00000000..f9763adf --- /dev/null +++ b/test/cpp/realloc/realloc_test.c @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#define _GNU_SOURCE +#include +#include +#include +#include + +int copies; +int non_copies; + +void *realloc(void *ptr, size_t size) { + static void *(*real_realloc)(void*, size_t) = NULL; + if (real_realloc == NULL) { + real_realloc = (void* (*) (void*, size_t)) dlsym(RTLD_NEXT, "realloc"); + } + + void *ret_ptr = (*real_realloc)(ptr, size); + + if (ret_ptr == ptr) { + non_copies++; + } else { + copies++; + } + + return ret_ptr; +} + + +struct TMemoryBuffer { + void* ptr; + int size; +}; + +int main(int argc, char *argv[]) { + int num_buffers; + int init_size; + int max_size; + int doublings; + int iterations; + + if (argc < 6 || + argc > 7 || + (num_buffers = atoi(argv[1])) == 0 || + (init_size = atoi(argv[2])) == 0 || + (max_size = atoi(argv[3])) == 0 || + init_size > max_size || + (iterations = atoi(argv[4])) == 0 || + (doublings = atoi(argv[5])) == 0 || + (argc == 7 && atoi(argv[6]) == 0)) { + fprintf(stderr, "usage: realloc_test [seed]\n"); + exit(EXIT_FAILURE); + } + + for ( int i = 0 ; i < argc ; i++ ) { + printf("%s ", argv[i]); + } + printf("\n"); + + if (argc == 7) { + srand(atoi(argv[6])); + } else { + srand(time(NULL)); + } + + struct TMemoryBuffer* buffers = calloc(num_buffers, sizeof(*buffers)); + if (buffers == NULL) abort(); + + for ( int i = 0 ; i < num_buffers ; i++ ) { + buffers[i].size = max_size; + } + + while (iterations --> 0) { + for ( int i = 0 ; i < doublings * num_buffers ; i++ ) { + struct TMemoryBuffer* buf = &buffers[rand() % num_buffers]; + buf->size *= 2; + if (buf->size <= max_size) { + buf->ptr = realloc(buf->ptr, buf->size); + } else { + free(buf->ptr); + buf->size = init_size; + buf->ptr = malloc(buf->size); + } + if (buf->ptr == NULL) abort(); + } + } + + printf("Non-copied %d/%d (%.2f%%)\n", non_copies, copies + non_copies, 100.0 * non_copies / (copies + non_copies)); + return 0; +} diff --git a/test/cpp/src/TestClient.cpp b/test/cpp/src/TestClient.cpp new file mode 100644 index 00000000..5ddfa060 --- /dev/null +++ b/test/cpp/src/TestClient.cpp @@ -0,0 +1,497 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include "ThriftTest.h" + +#define __STDC_FORMAT_MACROS +#include + +using namespace boost; +using namespace std; +using namespace apache::thrift; +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; +using namespace thrift::test; + +//extern uint32_t g_socket_syscalls; + +// Current time, microseconds since the epoch +uint64_t now() +{ + long long ret; + struct timeval tv; + + gettimeofday(&tv, NULL); + ret = tv.tv_sec; + ret = ret*1000*1000 + tv.tv_usec; + return ret; +} + +int main(int argc, char** argv) { + string host = "localhost"; + int port = 9090; + int numTests = 1; + bool framed = false; + + for (int i = 0; i < argc; ++i) { + if (strcmp(argv[i], "-h") == 0) { + char* pch = strtok(argv[++i], ":"); + if (pch != NULL) { + host = string(pch); + } + pch = strtok(NULL, ":"); + if (pch != NULL) { + port = atoi(pch); + } + } else if (strcmp(argv[i], "-n") == 0) { + numTests = atoi(argv[++i]); + } else if (strcmp(argv[i], "-f") == 0) { + framed = true; + } + } + + + shared_ptr transport; + + shared_ptr socket(new TSocket(host, port)); + + if (framed) { + shared_ptr framedSocket(new TFramedTransport(socket)); + transport = framedSocket; + } else { + shared_ptr bufferedSocket(new TBufferedTransport(socket)); + transport = bufferedSocket; + } + + shared_ptr protocol(new TBinaryProtocol(transport)); + ThriftTestClient testClient(protocol); + + uint64_t time_min = 0; + uint64_t time_max = 0; + uint64_t time_tot = 0; + + int test = 0; + for (test = 0; test < numTests; ++test) { + + try { + transport->open(); + } catch (TTransportException& ttx) { + printf("Connect failed: %s\n", ttx.what()); + continue; + } + + /** + * CONNECT TEST + */ + printf("Test #%d, connect %s:%d\n", test+1, host.c_str(), port); + + uint64_t start = now(); + + /** + * VOID TEST + */ + try { + printf("testVoid()"); + testClient.testVoid(); + printf(" = void\n"); + } catch (TApplicationException tax) { + printf("%s\n", tax.what()); + } + + /** + * STRING TEST + */ + printf("testString(\"Test\")"); + string s; + testClient.testString(s, "Test"); + printf(" = \"%s\"\n", s.c_str()); + + /** + * BYTE TEST + */ + printf("testByte(1)"); + uint8_t u8 = testClient.testByte(1); + printf(" = %d\n", (int)u8); + + /** + * I32 TEST + */ + printf("testI32(-1)"); + int32_t i32 = testClient.testI32(-1); + printf(" = %d\n", i32); + + /** + * I64 TEST + */ + printf("testI64(-34359738368)"); + int64_t i64 = testClient.testI64(-34359738368LL); + printf(" = %"PRId64"\n", i64); + + /** + * DOUBLE TEST + */ + printf("testDouble(-5.2098523)"); + double dub = testClient.testDouble(-5.2098523); + printf(" = %lf\n", dub); + + /** + * STRUCT TEST + */ + printf("testStruct({\"Zero\", 1, -3, -5})"); + Xtruct out; + out.string_thing = "Zero"; + out.byte_thing = 1; + out.i32_thing = -3; + out.i64_thing = -5; + Xtruct in; + testClient.testStruct(in, out); + printf(" = {\"%s\", %d, %d, %"PRId64"}\n", + in.string_thing.c_str(), + (int)in.byte_thing, + in.i32_thing, + in.i64_thing); + + /** + * NESTED STRUCT TEST + */ + printf("testNest({1, {\"Zero\", 1, -3, -5}), 5}"); + Xtruct2 out2; + out2.byte_thing = 1; + out2.struct_thing = out; + out2.i32_thing = 5; + Xtruct2 in2; + testClient.testNest(in2, out2); + in = in2.struct_thing; + printf(" = {%d, {\"%s\", %d, %d, %"PRId64"}, %d}\n", + in2.byte_thing, + in.string_thing.c_str(), + (int)in.byte_thing, + in.i32_thing, + in.i64_thing, + in2.i32_thing); + + /** + * MAP TEST + */ + map mapout; + for (int32_t i = 0; i < 5; ++i) { + mapout.insert(make_pair(i, i-10)); + } + printf("testMap({"); + map::const_iterator m_iter; + bool first = true; + for (m_iter = mapout.begin(); m_iter != mapout.end(); ++m_iter) { + if (first) { + first = false; + } else { + printf(", "); + } + printf("%d => %d", m_iter->first, m_iter->second); + } + printf("})"); + map mapin; + testClient.testMap(mapin, mapout); + printf(" = {"); + first = true; + for (m_iter = mapin.begin(); m_iter != mapin.end(); ++m_iter) { + if (first) { + first = false; + } else { + printf(", "); + } + printf("%d => %d", m_iter->first, m_iter->second); + } + printf("}\n"); + + /** + * SET TEST + */ + set setout; + for (int32_t i = -2; i < 3; ++i) { + setout.insert(i); + } + printf("testSet({"); + set::const_iterator s_iter; + first = true; + for (s_iter = setout.begin(); s_iter != setout.end(); ++s_iter) { + if (first) { + first = false; + } else { + printf(", "); + } + printf("%d", *s_iter); + } + printf("})"); + set setin; + testClient.testSet(setin, setout); + printf(" = {"); + first = true; + for (s_iter = setin.begin(); s_iter != setin.end(); ++s_iter) { + if (first) { + first = false; + } else { + printf(", "); + } + printf("%d", *s_iter); + } + printf("}\n"); + + /** + * LIST TEST + */ + vector listout; + for (int32_t i = -2; i < 3; ++i) { + listout.push_back(i); + } + printf("testList({"); + vector::const_iterator l_iter; + first = true; + for (l_iter = listout.begin(); l_iter != listout.end(); ++l_iter) { + if (first) { + first = false; + } else { + printf(", "); + } + printf("%d", *l_iter); + } + printf("})"); + vector listin; + testClient.testList(listin, listout); + printf(" = {"); + first = true; + for (l_iter = listin.begin(); l_iter != listin.end(); ++l_iter) { + if (first) { + first = false; + } else { + printf(", "); + } + printf("%d", *l_iter); + } + printf("}\n"); + + /** + * ENUM TEST + */ + printf("testEnum(ONE)"); + Numberz ret = testClient.testEnum(ONE); + printf(" = %d\n", ret); + + printf("testEnum(TWO)"); + ret = testClient.testEnum(TWO); + printf(" = %d\n", ret); + + printf("testEnum(THREE)"); + ret = testClient.testEnum(THREE); + printf(" = %d\n", ret); + + printf("testEnum(FIVE)"); + ret = testClient.testEnum(FIVE); + printf(" = %d\n", ret); + + printf("testEnum(EIGHT)"); + ret = testClient.testEnum(EIGHT); + printf(" = %d\n", ret); + + /** + * TYPEDEF TEST + */ + printf("testTypedef(309858235082523)"); + UserId uid = testClient.testTypedef(309858235082523LL); + printf(" = %"PRId64"\n", uid); + + /** + * NESTED MAP TEST + */ + printf("testMapMap(1)"); + map > mm; + testClient.testMapMap(mm, 1); + printf(" = {"); + map >::const_iterator mi; + for (mi = mm.begin(); mi != mm.end(); ++mi) { + printf("%d => {", mi->first); + map::const_iterator mi2; + for (mi2 = mi->second.begin(); mi2 != mi->second.end(); ++mi2) { + printf("%d => %d, ", mi2->first, mi2->second); + } + printf("}, "); + } + printf("}\n"); + + /** + * INSANITY TEST + */ + Insanity insane; + insane.userMap.insert(make_pair(FIVE, 5000)); + Xtruct truck; + truck.string_thing = "Truck"; + truck.byte_thing = 8; + truck.i32_thing = 8; + truck.i64_thing = 8; + insane.xtructs.push_back(truck); + printf("testInsanity()"); + map > whoa; + testClient.testInsanity(whoa, insane); + printf(" = {"); + map >::const_iterator i_iter; + for (i_iter = whoa.begin(); i_iter != whoa.end(); ++i_iter) { + printf("%"PRId64" => {", i_iter->first); + map::const_iterator i2_iter; + for (i2_iter = i_iter->second.begin(); + i2_iter != i_iter->second.end(); + ++i2_iter) { + printf("%d => {", i2_iter->first); + map userMap = i2_iter->second.userMap; + map::const_iterator um; + printf("{"); + for (um = userMap.begin(); um != userMap.end(); ++um) { + printf("%d => %"PRId64", ", um->first, um->second); + } + printf("}, "); + + vector xtructs = i2_iter->second.xtructs; + vector::const_iterator x; + printf("{"); + for (x = xtructs.begin(); x != xtructs.end(); ++x) { + printf("{\"%s\", %d, %d, %"PRId64"}, ", + x->string_thing.c_str(), + (int)x->byte_thing, + x->i32_thing, + x->i64_thing); + } + printf("}"); + + printf("}, "); + } + printf("}, "); + } + printf("}\n"); + + /* test exception */ + + try { + printf("testClient.testException(\"Xception\") =>"); + testClient.testException("Xception"); + printf(" void\nFAILURE\n"); + + } catch(Xception& e) { + printf(" {%u, \"%s\"}\n", e.errorCode, e.message.c_str()); + } + + try { + printf("testClient.testException(\"success\") =>"); + testClient.testException("success"); + printf(" void\n"); + } catch(...) { + printf(" exception\nFAILURE\n"); + } + + /* test multi exception */ + + try { + printf("testClient.testMultiException(\"Xception\", \"test 1\") =>"); + Xtruct result; + testClient.testMultiException(result, "Xception", "test 1"); + printf(" result\nFAILURE\n"); + } catch(Xception& e) { + printf(" {%u, \"%s\"}\n", e.errorCode, e.message.c_str()); + } + + try { + printf("testClient.testMultiException(\"Xception2\", \"test 2\") =>"); + Xtruct result; + testClient.testMultiException(result, "Xception2", "test 2"); + printf(" result\nFAILURE\n"); + + } catch(Xception2& e) { + printf(" {%u, {\"%s\"}}\n", e.errorCode, e.struct_thing.string_thing.c_str()); + } + + try { + printf("testClient.testMultiException(\"success\", \"test 3\") =>"); + Xtruct result; + testClient.testMultiException(result, "success", "test 3"); + printf(" {{\"%s\"}}\n", result.string_thing.c_str()); + } catch(...) { + printf(" exception\nFAILURE\n"); + } + + /* test oneway void */ + { + printf("testClient.testOneway(3) =>"); + uint64_t startOneway = now(); + testClient.testOneway(3); + uint64_t elapsed = now() - startOneway; + if (elapsed > 200 * 1000) { // 0.2 seconds + printf(" FAILURE - took %.2f ms\n", (double)elapsed/1000.0); + } else { + printf(" success - took %.2f ms\n", (double)elapsed/1000.0); + } + } + + /** + * redo a simple test after the oneway to make sure we aren't "off by one" -- + * if the server treated oneway void like normal void, this next test will + * fail since it will get the void confirmation rather than the correct + * result. In this circumstance, the client will throw the exception: + * + * TApplicationException: Wrong method namea + */ + /** + * I32 TEST + */ + printf("re-test testI32(-1)"); + i32 = testClient.testI32(-1); + printf(" = %d\n", i32); + + + uint64_t stop = now(); + uint64_t tot = stop-start; + + printf("Total time: %"PRIu64" us\n", stop-start); + + time_tot += tot; + if (time_min == 0 || tot < time_min) { + time_min = tot; + } + if (tot > time_max) { + time_max = tot; + } + + transport->close(); + } + + // printf("\nSocket syscalls: %u", g_socket_syscalls); + printf("\nAll tests done.\n"); + + uint64_t time_avg = time_tot / numTests; + + printf("Min time: %"PRIu64" us\n", time_min); + printf("Max time: %"PRIu64" us\n", time_max); + printf("Avg time: %"PRIu64" us\n", time_avg); + + return 0; +} diff --git a/test/cpp/src/TestServer.cpp b/test/cpp/src/TestServer.cpp new file mode 100644 index 00000000..83454dca --- /dev/null +++ b/test/cpp/src/TestServer.cpp @@ -0,0 +1,424 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ThriftTest.h" + +#include +#include +#include + +#define __STDC_FORMAT_MACROS +#include + +using namespace std; +using namespace boost; + +using namespace apache::thrift; +using namespace apache::thrift::concurrency; +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; +using namespace apache::thrift::server; + +using namespace thrift::test; + +class TestHandler : public ThriftTestIf { + public: + TestHandler() {} + + void testVoid() { + printf("testVoid()\n"); + } + + void testString(string& out, const string &thing) { + printf("testString(\"%s\")\n", thing.c_str()); + out = thing; + } + + int8_t testByte(const int8_t thing) { + printf("testByte(%d)\n", (int)thing); + return thing; + } + + int32_t testI32(const int32_t thing) { + printf("testI32(%d)\n", thing); + return thing; + } + + int64_t testI64(const int64_t thing) { + printf("testI64(%"PRId64")\n", thing); + return thing; + } + + double testDouble(const double thing) { + printf("testDouble(%lf)\n", thing); + return thing; + } + + void testStruct(Xtruct& out, const Xtruct &thing) { + printf("testStruct({\"%s\", %d, %d, %"PRId64"})\n", thing.string_thing.c_str(), (int)thing.byte_thing, thing.i32_thing, thing.i64_thing); + out = thing; + } + + void testNest(Xtruct2& out, const Xtruct2& nest) { + const Xtruct &thing = nest.struct_thing; + printf("testNest({%d, {\"%s\", %d, %d, %"PRId64"}, %d})\n", (int)nest.byte_thing, thing.string_thing.c_str(), (int)thing.byte_thing, thing.i32_thing, thing.i64_thing, nest.i32_thing); + out = nest; + } + + void testMap(map &out, const map &thing) { + printf("testMap({"); + map::const_iterator m_iter; + bool first = true; + for (m_iter = thing.begin(); m_iter != thing.end(); ++m_iter) { + if (first) { + first = false; + } else { + printf(", "); + } + printf("%d => %d", m_iter->first, m_iter->second); + } + printf("})\n"); + out = thing; + } + + void testSet(set &out, const set &thing) { + printf("testSet({"); + set::const_iterator s_iter; + bool first = true; + for (s_iter = thing.begin(); s_iter != thing.end(); ++s_iter) { + if (first) { + first = false; + } else { + printf(", "); + } + printf("%d", *s_iter); + } + printf("})\n"); + out = thing; + } + + void testList(vector &out, const vector &thing) { + printf("testList({"); + vector::const_iterator l_iter; + bool first = true; + for (l_iter = thing.begin(); l_iter != thing.end(); ++l_iter) { + if (first) { + first = false; + } else { + printf(", "); + } + printf("%d", *l_iter); + } + printf("})\n"); + out = thing; + } + + Numberz testEnum(const Numberz thing) { + printf("testEnum(%d)\n", thing); + return thing; + } + + UserId testTypedef(const UserId thing) { + printf("testTypedef(%"PRId64")\n", thing); + return thing; + } + + void testMapMap(map > &mapmap, const int32_t hello) { + printf("testMapMap(%d)\n", hello); + + map pos; + map neg; + for (int i = 1; i < 5; i++) { + pos.insert(make_pair(i,i)); + neg.insert(make_pair(-i,-i)); + } + + mapmap.insert(make_pair(4, pos)); + mapmap.insert(make_pair(-4, neg)); + + } + + void testInsanity(map > &insane, const Insanity &argument) { + printf("testInsanity()\n"); + + Xtruct hello; + hello.string_thing = "Hello2"; + hello.byte_thing = 2; + hello.i32_thing = 2; + hello.i64_thing = 2; + + Xtruct goodbye; + goodbye.string_thing = "Goodbye4"; + goodbye.byte_thing = 4; + goodbye.i32_thing = 4; + goodbye.i64_thing = 4; + + Insanity crazy; + crazy.userMap.insert(make_pair(EIGHT, 8)); + crazy.xtructs.push_back(goodbye); + + Insanity looney; + crazy.userMap.insert(make_pair(FIVE, 5)); + crazy.xtructs.push_back(hello); + + map first_map; + map second_map; + + first_map.insert(make_pair(TWO, crazy)); + first_map.insert(make_pair(THREE, crazy)); + + second_map.insert(make_pair(SIX, looney)); + + insane.insert(make_pair(1, first_map)); + insane.insert(make_pair(2, second_map)); + + printf("return"); + printf(" = {"); + map >::const_iterator i_iter; + for (i_iter = insane.begin(); i_iter != insane.end(); ++i_iter) { + printf("%"PRId64" => {", i_iter->first); + map::const_iterator i2_iter; + for (i2_iter = i_iter->second.begin(); + i2_iter != i_iter->second.end(); + ++i2_iter) { + printf("%d => {", i2_iter->first); + map userMap = i2_iter->second.userMap; + map::const_iterator um; + printf("{"); + for (um = userMap.begin(); um != userMap.end(); ++um) { + printf("%d => %"PRId64", ", um->first, um->second); + } + printf("}, "); + + vector xtructs = i2_iter->second.xtructs; + vector::const_iterator x; + printf("{"); + for (x = xtructs.begin(); x != xtructs.end(); ++x) { + printf("{\"%s\", %d, %d, %"PRId64"}, ", x->string_thing.c_str(), (int)x->byte_thing, x->i32_thing, x->i64_thing); + } + printf("}"); + + printf("}, "); + } + printf("}, "); + } + printf("}\n"); + + + } + + void testMulti(Xtruct &hello, const int8_t arg0, const int32_t arg1, const int64_t arg2, const std::map &arg3, const Numberz arg4, const UserId arg5) { + printf("testMulti()\n"); + + hello.string_thing = "Hello2"; + hello.byte_thing = arg0; + hello.i32_thing = arg1; + hello.i64_thing = (int64_t)arg2; + } + + void testException(const std::string &arg) + throw(Xception, apache::thrift::TException) + { + printf("testException(%s)\n", arg.c_str()); + if (arg.compare("Xception") == 0) { + Xception e; + e.errorCode = 1001; + e.message = arg; + throw e; + } else if (arg.compare("ApplicationException") == 0) { + apache::thrift::TException e; + throw e; + } else { + Xtruct result; + result.string_thing = arg; + return; + } + } + + void testMultiException(Xtruct &result, const std::string &arg0, const std::string &arg1) throw(Xception, Xception2) { + + printf("testMultiException(%s, %s)\n", arg0.c_str(), arg1.c_str()); + + if (arg0.compare("Xception") == 0) { + Xception e; + e.errorCode = 1001; + e.message = "This is an Xception"; + throw e; + } else if (arg0.compare("Xception2") == 0) { + Xception2 e; + e.errorCode = 2002; + e.struct_thing.string_thing = "This is an Xception2"; + throw e; + } else { + result.string_thing = arg1; + return; + } + } + + void testOneway(int sleepFor) { + printf("testOneway(%d): Sleeping...\n", sleepFor); + sleep(sleepFor); + printf("testOneway(%d): done sleeping!\n", sleepFor); + } +}; + +int main(int argc, char **argv) { + + int port = 9090; + string serverType = "simple"; + string protocolType = "binary"; + size_t workerCount = 4; + + ostringstream usage; + + usage << + argv[0] << " [--port=] [--server-type=] [--protocol-type=] [--workers=]" << endl << + + "\t\tserver-type\t\ttype of server, \"simple\", \"thread-pool\", \"threaded\", or \"nonblocking\". Default is " << serverType << endl << + + "\t\tprotocol-type\t\ttype of protocol, \"binary\", \"ascii\", or \"xml\". Default is " << protocolType << endl << + + "\t\tworkers\t\tNumber of thread pools workers. Only valid for thread-pool server type. Default is " << workerCount << endl; + + map args; + + for (int ix = 1; ix < argc; ix++) { + string arg(argv[ix]); + if (arg.compare(0,2, "--") == 0) { + size_t end = arg.find_first_of("=", 2); + if (end != string::npos) { + args[string(arg, 2, end - 2)] = string(arg, end + 1); + } else { + args[string(arg, 2)] = "true"; + } + } else { + throw invalid_argument("Unexcepted command line token: "+arg); + } + } + + try { + + if (!args["port"].empty()) { + port = atoi(args["port"].c_str()); + } + + if (!args["server-type"].empty()) { + serverType = args["server-type"]; + if (serverType == "simple") { + } else if (serverType == "thread-pool") { + } else if (serverType == "threaded") { + } else if (serverType == "nonblocking") { + } else { + throw invalid_argument("Unknown server type "+serverType); + } + } + + if (!args["protocol-type"].empty()) { + protocolType = args["protocol-type"]; + if (protocolType == "binary") { + } else if (protocolType == "ascii") { + throw invalid_argument("ASCII protocol not supported"); + } else if (protocolType == "xml") { + throw invalid_argument("XML protocol not supported"); + } else { + throw invalid_argument("Unknown protocol type "+protocolType); + } + } + + if (!args["workers"].empty()) { + workerCount = atoi(args["workers"].c_str()); + } + } catch (exception& e) { + cerr << e.what() << endl; + cerr << usage; + } + + // Dispatcher + shared_ptr protocolFactory(new TBinaryProtocolFactory()); + + shared_ptr testHandler(new TestHandler()); + + shared_ptr testProcessor(new ThriftTestProcessor(testHandler)); + + // Transport + shared_ptr serverSocket(new TServerSocket(port)); + + // Factory + shared_ptr transportFactory(new TBufferedTransportFactory()); + + if (serverType == "simple") { + + // Server + TSimpleServer simpleServer(testProcessor, + serverSocket, + transportFactory, + protocolFactory); + + printf("Starting the server on port %d...\n", port); + simpleServer.serve(); + + } else if (serverType == "thread-pool") { + + shared_ptr threadManager = + ThreadManager::newSimpleThreadManager(workerCount); + + shared_ptr threadFactory = + shared_ptr(new PosixThreadFactory()); + + threadManager->threadFactory(threadFactory); + + threadManager->start(); + + TThreadPoolServer threadPoolServer(testProcessor, + serverSocket, + transportFactory, + protocolFactory, + threadManager); + + printf("Starting the server on port %d...\n", port); + threadPoolServer.serve(); + + } else if (serverType == "threaded") { + + TThreadedServer threadedServer(testProcessor, + serverSocket, + transportFactory, + protocolFactory); + + printf("Starting the server on port %d...\n", port); + threadedServer.serve(); + + } else if (serverType == "nonblocking") { + TNonblockingServer nonblockingServer(testProcessor, port); + printf("Starting the nonblocking server on port %d...\n", port); + nonblockingServer.serve(); + } + + printf("done.\n"); + return 0; +} diff --git a/test/cpp/src/main.cpp b/test/cpp/src/main.cpp new file mode 100644 index 00000000..46ee950d --- /dev/null +++ b/test/cpp/src/main.cpp @@ -0,0 +1,509 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "Service.h" + +#include +#include +#include +#include + +#include +#include +using __gnu_cxx::hash_map; +using __gnu_cxx::hash; + +using namespace std; +using namespace boost; + +using namespace apache::thrift; +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; +using namespace apache::thrift::server; +using namespace apache::thrift::concurrency; + +using namespace test::stress; + +struct eqstr { + bool operator()(const char* s1, const char* s2) const { + return strcmp(s1, s2) == 0; + } +}; + +struct ltstr { + bool operator()(const char* s1, const char* s2) const { + return strcmp(s1, s2) < 0; + } +}; + + +// typedef hash_map, eqstr> count_map; +typedef map count_map; + +class Server : public ServiceIf { + public: + Server() {} + + void count(const char* method) { + Guard m(lock_); + int ct = counts_[method]; + counts_[method] = ++ct; + } + + void echoVoid() { + count("echoVoid"); + return; + } + + count_map getCount() { + Guard m(lock_); + return counts_; + } + + int8_t echoByte(const int8_t arg) {return arg;} + int32_t echoI32(const int32_t arg) {return arg;} + int64_t echoI64(const int64_t arg) {return arg;} + void echoString(string& out, const string &arg) { + if (arg != "hello") { + T_ERROR_ABORT("WRONG STRING!!!!"); + } + out = arg; + } + void echoList(vector &out, const vector &arg) { out = arg; } + void echoSet(set &out, const set &arg) { out = arg; } + void echoMap(map &out, const map &arg) { out = arg; } + +private: + count_map counts_; + Mutex lock_; + +}; + +class ClientThread: public Runnable { +public: + + ClientThread(shared_ptrtransport, shared_ptr client, Monitor& monitor, size_t& workerCount, size_t loopCount, TType loopType) : + _transport(transport), + _client(client), + _monitor(monitor), + _workerCount(workerCount), + _loopCount(loopCount), + _loopType(loopType) + {} + + void run() { + + // Wait for all worker threads to start + + {Synchronized s(_monitor); + while(_workerCount == 0) { + _monitor.wait(); + } + } + + _startTime = Util::currentTime(); + + _transport->open(); + + switch(_loopType) { + case T_VOID: loopEchoVoid(); break; + case T_BYTE: loopEchoByte(); break; + case T_I32: loopEchoI32(); break; + case T_I64: loopEchoI64(); break; + case T_STRING: loopEchoString(); break; + default: cerr << "Unexpected loop type" << _loopType << endl; break; + } + + _endTime = Util::currentTime(); + + _transport->close(); + + _done = true; + + {Synchronized s(_monitor); + + _workerCount--; + + if (_workerCount == 0) { + + _monitor.notify(); + } + } + } + + void loopEchoVoid() { + for (size_t ix = 0; ix < _loopCount; ix++) { + _client->echoVoid(); + } + } + + void loopEchoByte() { + for (size_t ix = 0; ix < _loopCount; ix++) { + int8_t arg = 1; + int8_t result; + result =_client->echoByte(arg); + assert(result == arg); + } + } + + void loopEchoI32() { + for (size_t ix = 0; ix < _loopCount; ix++) { + int32_t arg = 1; + int32_t result; + result =_client->echoI32(arg); + assert(result == arg); + } + } + + void loopEchoI64() { + for (size_t ix = 0; ix < _loopCount; ix++) { + int64_t arg = 1; + int64_t result; + result =_client->echoI64(arg); + assert(result == arg); + } + } + + void loopEchoString() { + for (size_t ix = 0; ix < _loopCount; ix++) { + string arg = "hello"; + string result; + _client->echoString(result, arg); + assert(result == arg); + } + } + + shared_ptr _transport; + shared_ptr _client; + Monitor& _monitor; + size_t& _workerCount; + size_t _loopCount; + TType _loopType; + long long _startTime; + long long _endTime; + bool _done; + Monitor _sleep; +}; + + +int main(int argc, char **argv) { + + int port = 9091; + string serverType = "thread-pool"; + string protocolType = "binary"; + size_t workerCount = 4; + size_t clientCount = 20; + size_t loopCount = 50000; + TType loopType = T_VOID; + string callName = "echoVoid"; + bool runServer = true; + bool logRequests = false; + string requestLogPath = "./requestlog.tlog"; + bool replayRequests = false; + + ostringstream usage; + + usage << + argv[0] << " [--port=] [--server] [--server-type=] [--protocol-type=] [--workers=] [--clients=] [--loop=]" << endl << + "\tclients Number of client threads to create - 0 implies no clients, i.e. server only. Default is " << clientCount << endl << + "\thelp Prints this help text." << endl << + "\tcall Service method to call. Default is " << callName << endl << + "\tloop The number of remote thrift calls each client makes. Default is " << loopCount << endl << + "\tport The port the server and clients should bind to for thrift network connections. Default is " << port << endl << + "\tserver Run the Thrift server in this process. Default is " << runServer << endl << + "\tserver-type Type of server, \"simple\" or \"thread-pool\". Default is " << serverType << endl << + "\tprotocol-type Type of protocol, \"binary\", \"ascii\", or \"xml\". Default is " << protocolType << endl << + "\tlog-request Log all request to ./requestlog.tlog. Default is " << logRequests << endl << + "\treplay-request Replay requests from log file (./requestlog.tlog) Default is " << replayRequests << endl << + "\tworkers Number of thread pools workers. Only valid for thread-pool server type. Default is " << workerCount << endl; + + + map args; + + for (int ix = 1; ix < argc; ix++) { + + string arg(argv[ix]); + + if (arg.compare(0,2, "--") == 0) { + + size_t end = arg.find_first_of("=", 2); + + string key = string(arg, 2, end - 2); + + if (end != string::npos) { + args[key] = string(arg, end + 1); + } else { + args[key] = "true"; + } + } else { + throw invalid_argument("Unexcepted command line token: "+arg); + } + } + + try { + + if (!args["clients"].empty()) { + clientCount = atoi(args["clients"].c_str()); + } + + if (!args["help"].empty()) { + cerr << usage.str(); + return 0; + } + + if (!args["loop"].empty()) { + loopCount = atoi(args["loop"].c_str()); + } + + if (!args["call"].empty()) { + callName = args["call"]; + } + + if (!args["port"].empty()) { + port = atoi(args["port"].c_str()); + } + + if (!args["server"].empty()) { + runServer = args["server"] == "true"; + } + + if (!args["log-request"].empty()) { + logRequests = args["log-request"] == "true"; + } + + if (!args["replay-request"].empty()) { + replayRequests = args["replay-request"] == "true"; + } + + if (!args["server-type"].empty()) { + serverType = args["server-type"]; + + if (serverType == "simple") { + + } else if (serverType == "thread-pool") { + + } else if (serverType == "threaded") { + + } else { + + throw invalid_argument("Unknown server type "+serverType); + } + } + + if (!args["workers"].empty()) { + workerCount = atoi(args["workers"].c_str()); + } + + } catch(exception& e) { + cerr << e.what() << endl; + cerr << usage; + } + + shared_ptr threadFactory = shared_ptr(new PosixThreadFactory()); + + // Dispatcher + shared_ptr serviceHandler(new Server()); + + if (replayRequests) { + shared_ptr serviceHandler(new Server()); + shared_ptr serviceProcessor(new ServiceProcessor(serviceHandler)); + + // Transports + shared_ptr fileTransport(new TFileTransport(requestLogPath)); + fileTransport->setChunkSize(2 * 1024 * 1024); + fileTransport->setMaxEventSize(1024 * 16); + fileTransport->seekToEnd(); + + // Protocol Factory + shared_ptr protocolFactory(new TBinaryProtocolFactory()); + + TFileProcessor fileProcessor(serviceProcessor, + protocolFactory, + fileTransport); + + fileProcessor.process(0, true); + exit(0); + } + + + if (runServer) { + + shared_ptr serviceProcessor(new ServiceProcessor(serviceHandler)); + + // Transport + shared_ptr serverSocket(new TServerSocket(port)); + + // Transport Factory + shared_ptr transportFactory(new TBufferedTransportFactory()); + + // Protocol Factory + shared_ptr protocolFactory(new TBinaryProtocolFactory()); + + if (logRequests) { + // initialize the log file + shared_ptr fileTransport(new TFileTransport(requestLogPath)); + fileTransport->setChunkSize(2 * 1024 * 1024); + fileTransport->setMaxEventSize(1024 * 16); + + transportFactory = + shared_ptr(new TPipedTransportFactory(fileTransport)); + } + + shared_ptr serverThread; + + if (serverType == "simple") { + + serverThread = threadFactory->newThread(shared_ptr(new TSimpleServer(serviceProcessor, serverSocket, transportFactory, protocolFactory))); + + } else if (serverType == "threaded") { + + serverThread = threadFactory->newThread(shared_ptr(new TThreadedServer(serviceProcessor, serverSocket, transportFactory, protocolFactory))); + + } else if (serverType == "thread-pool") { + + shared_ptr threadManager = ThreadManager::newSimpleThreadManager(workerCount); + + threadManager->threadFactory(threadFactory); + threadManager->start(); + serverThread = threadFactory->newThread(shared_ptr(new TThreadPoolServer(serviceProcessor, serverSocket, transportFactory, protocolFactory, threadManager))); + } + + cerr << "Starting the server on port " << port << endl; + + serverThread->start(); + + // If we aren't running clients, just wait forever for external clients + + if (clientCount == 0) { + serverThread->join(); + } + } + + if (clientCount > 0) { + + Monitor monitor; + + size_t threadCount = 0; + + set > clientThreads; + + if (callName == "echoVoid") { loopType = T_VOID;} + else if (callName == "echoByte") { loopType = T_BYTE;} + else if (callName == "echoI32") { loopType = T_I32;} + else if (callName == "echoI64") { loopType = T_I64;} + else if (callName == "echoString") { loopType = T_STRING;} + else {throw invalid_argument("Unknown service call "+callName);} + + for (size_t ix = 0; ix < clientCount; ix++) { + + shared_ptr socket(new TSocket("127.0.01", port)); + shared_ptr bufferedSocket(new TBufferedTransport(socket, 2048)); + shared_ptr protocol(new TBinaryProtocol(bufferedSocket)); + shared_ptr serviceClient(new ServiceClient(protocol)); + + clientThreads.insert(threadFactory->newThread(shared_ptr(new ClientThread(socket, serviceClient, monitor, threadCount, loopCount, loopType)))); + } + + for (std::set >::const_iterator thread = clientThreads.begin(); thread != clientThreads.end(); thread++) { + (*thread)->start(); + } + + long long time00; + long long time01; + + {Synchronized s(monitor); + threadCount = clientCount; + + cerr << "Launch "<< clientCount << " client threads" << endl; + + time00 = Util::currentTime(); + + monitor.notifyAll(); + + while(threadCount > 0) { + monitor.wait(); + } + + time01 = Util::currentTime(); + } + + long long firstTime = 9223372036854775807LL; + long long lastTime = 0; + + double averageTime = 0; + long long minTime = 9223372036854775807LL; + long long maxTime = 0; + + for (set >::iterator ix = clientThreads.begin(); ix != clientThreads.end(); ix++) { + + shared_ptr client = dynamic_pointer_cast((*ix)->runnable()); + + long long delta = client->_endTime - client->_startTime; + + assert(delta > 0); + + if (client->_startTime < firstTime) { + firstTime = client->_startTime; + } + + if (client->_endTime > lastTime) { + lastTime = client->_endTime; + } + + if (delta < minTime) { + minTime = delta; + } + + if (delta > maxTime) { + maxTime = delta; + } + + averageTime+= delta; + } + + averageTime /= clientCount; + + + cout << "workers :" << workerCount << ", client : " << clientCount << ", loops : " << loopCount << ", rate : " << (clientCount * loopCount * 1000) / ((double)(time01 - time00)) << endl; + + count_map count = serviceHandler->getCount(); + count_map::iterator iter; + for (iter = count.begin(); iter != count.end(); ++iter) { + printf("%s => %d\n", iter->first, iter->second); + } + cerr << "done." << endl; + } + + return 0; +} diff --git a/test/cpp/src/nb-main.cpp b/test/cpp/src/nb-main.cpp new file mode 100644 index 00000000..8c74a815 --- /dev/null +++ b/test/cpp/src/nb-main.cpp @@ -0,0 +1,502 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "Service.h" + +#include +#include + +#include +#include +#include +#include + +#include +#include +using __gnu_cxx::hash_map; +using __gnu_cxx::hash; + +using namespace std; +using namespace boost; + +using namespace apache::thrift; +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; +using namespace apache::thrift::server; +using namespace apache::thrift::concurrency; + +using namespace test::stress; + +struct eqstr { + bool operator()(const char* s1, const char* s2) const { + return strcmp(s1, s2) == 0; + } +}; + +struct ltstr { + bool operator()(const char* s1, const char* s2) const { + return strcmp(s1, s2) < 0; + } +}; + + +// typedef hash_map, eqstr> count_map; +typedef map count_map; + +class Server : public ServiceIf { + public: + Server() {} + + void count(const char* method) { + Guard m(lock_); + int ct = counts_[method]; + counts_[method] = ++ct; + } + + void echoVoid() { + count("echoVoid"); + // Sleep to simulate work + usleep(5000); + return; + } + + count_map getCount() { + Guard m(lock_); + return counts_; + } + + int8_t echoByte(const int8_t arg) {return arg;} + int32_t echoI32(const int32_t arg) {return arg;} + int64_t echoI64(const int64_t arg) {return arg;} + void echoString(string& out, const string &arg) { + if (arg != "hello") { + T_ERROR_ABORT("WRONG STRING!!!!"); + } + out = arg; + } + void echoList(vector &out, const vector &arg) { out = arg; } + void echoSet(set &out, const set &arg) { out = arg; } + void echoMap(map &out, const map &arg) { out = arg; } + +private: + count_map counts_; + Mutex lock_; + +}; + +class ClientThread: public Runnable { +public: + + ClientThread(shared_ptrtransport, shared_ptr client, Monitor& monitor, size_t& workerCount, size_t loopCount, TType loopType) : + _transport(transport), + _client(client), + _monitor(monitor), + _workerCount(workerCount), + _loopCount(loopCount), + _loopType(loopType) + {} + + void run() { + + // Wait for all worker threads to start + + {Synchronized s(_monitor); + while(_workerCount == 0) { + _monitor.wait(); + } + } + + _startTime = Util::currentTime(); + + _transport->open(); + + switch(_loopType) { + case T_VOID: loopEchoVoid(); break; + case T_BYTE: loopEchoByte(); break; + case T_I32: loopEchoI32(); break; + case T_I64: loopEchoI64(); break; + case T_STRING: loopEchoString(); break; + default: cerr << "Unexpected loop type" << _loopType << endl; break; + } + + _endTime = Util::currentTime(); + + _transport->close(); + + _done = true; + + {Synchronized s(_monitor); + + _workerCount--; + + if (_workerCount == 0) { + + _monitor.notify(); + } + } + } + + void loopEchoVoid() { + for (size_t ix = 0; ix < _loopCount; ix++) { + _client->echoVoid(); + } + } + + void loopEchoByte() { + for (size_t ix = 0; ix < _loopCount; ix++) { + int8_t arg = 1; + int8_t result; + result =_client->echoByte(arg); + assert(result == arg); + } + } + + void loopEchoI32() { + for (size_t ix = 0; ix < _loopCount; ix++) { + int32_t arg = 1; + int32_t result; + result =_client->echoI32(arg); + assert(result == arg); + } + } + + void loopEchoI64() { + for (size_t ix = 0; ix < _loopCount; ix++) { + int64_t arg = 1; + int64_t result; + result =_client->echoI64(arg); + assert(result == arg); + } + } + + void loopEchoString() { + for (size_t ix = 0; ix < _loopCount; ix++) { + string arg = "hello"; + string result; + _client->echoString(result, arg); + assert(result == arg); + } + } + + shared_ptr _transport; + shared_ptr _client; + Monitor& _monitor; + size_t& _workerCount; + size_t _loopCount; + TType _loopType; + long long _startTime; + long long _endTime; + bool _done; + Monitor _sleep; +}; + + +int main(int argc, char **argv) { + + int port = 9091; + string serverType = "simple"; + string protocolType = "binary"; + size_t workerCount = 4; + size_t clientCount = 20; + size_t loopCount = 50000; + TType loopType = T_VOID; + string callName = "echoVoid"; + bool runServer = true; + bool logRequests = false; + string requestLogPath = "./requestlog.tlog"; + bool replayRequests = false; + + ostringstream usage; + + usage << + argv[0] << " [--port=] [--server] [--server-type=] [--protocol-type=] [--workers=] [--clients=] [--loop=]" << endl << + "\tclients Number of client threads to create - 0 implies no clients, i.e. server only. Default is " << clientCount << endl << + "\thelp Prints this help text." << endl << + "\tcall Service method to call. Default is " << callName << endl << + "\tloop The number of remote thrift calls each client makes. Default is " << loopCount << endl << + "\tport The port the server and clients should bind to for thrift network connections. Default is " << port << endl << + "\tserver Run the Thrift server in this process. Default is " << runServer << endl << + "\tserver-type Type of server, \"simple\" or \"thread-pool\". Default is " << serverType << endl << + "\tprotocol-type Type of protocol, \"binary\", \"ascii\", or \"xml\". Default is " << protocolType << endl << + "\tlog-request Log all request to ./requestlog.tlog. Default is " << logRequests << endl << + "\treplay-request Replay requests from log file (./requestlog.tlog) Default is " << replayRequests << endl << + "\tworkers Number of thread pools workers. Only valid for thread-pool server type. Default is " << workerCount << endl; + + + map args; + + for (int ix = 1; ix < argc; ix++) { + + string arg(argv[ix]); + + if (arg.compare(0,2, "--") == 0) { + + size_t end = arg.find_first_of("=", 2); + + string key = string(arg, 2, end - 2); + + if (end != string::npos) { + args[key] = string(arg, end + 1); + } else { + args[key] = "true"; + } + } else { + throw invalid_argument("Unexcepted command line token: "+arg); + } + } + + try { + + if (!args["clients"].empty()) { + clientCount = atoi(args["clients"].c_str()); + } + + if (!args["help"].empty()) { + cerr << usage.str(); + return 0; + } + + if (!args["loop"].empty()) { + loopCount = atoi(args["loop"].c_str()); + } + + if (!args["call"].empty()) { + callName = args["call"]; + } + + if (!args["port"].empty()) { + port = atoi(args["port"].c_str()); + } + + if (!args["server"].empty()) { + runServer = args["server"] == "true"; + } + + if (!args["log-request"].empty()) { + logRequests = args["log-request"] == "true"; + } + + if (!args["replay-request"].empty()) { + replayRequests = args["replay-request"] == "true"; + } + + if (!args["server-type"].empty()) { + serverType = args["server-type"]; + } + + if (!args["workers"].empty()) { + workerCount = atoi(args["workers"].c_str()); + } + + } catch(exception& e) { + cerr << e.what() << endl; + cerr << usage; + } + + shared_ptr threadFactory = shared_ptr(new PosixThreadFactory()); + + // Dispatcher + shared_ptr serviceHandler(new Server()); + + if (replayRequests) { + shared_ptr serviceHandler(new Server()); + shared_ptr serviceProcessor(new ServiceProcessor(serviceHandler)); + + // Transports + shared_ptr fileTransport(new TFileTransport(requestLogPath)); + fileTransport->setChunkSize(2 * 1024 * 1024); + fileTransport->setMaxEventSize(1024 * 16); + fileTransport->seekToEnd(); + + // Protocol Factory + shared_ptr protocolFactory(new TBinaryProtocolFactory()); + + TFileProcessor fileProcessor(serviceProcessor, + protocolFactory, + fileTransport); + + fileProcessor.process(0, true); + exit(0); + } + + + if (runServer) { + + shared_ptr serviceProcessor(new ServiceProcessor(serviceHandler)); + + // Protocol Factory + shared_ptr protocolFactory(new TBinaryProtocolFactory()); + + // Transport Factory + shared_ptr transportFactory; + + if (logRequests) { + // initialize the log file + shared_ptr fileTransport(new TFileTransport(requestLogPath)); + fileTransport->setChunkSize(2 * 1024 * 1024); + fileTransport->setMaxEventSize(1024 * 16); + + transportFactory = + shared_ptr(new TPipedTransportFactory(fileTransport)); + } + + shared_ptr serverThread; + shared_ptr serverThread2; + + if (serverType == "simple") { + + serverThread = threadFactory->newThread(shared_ptr(new TNonblockingServer(serviceProcessor, protocolFactory, port))); + serverThread2 = threadFactory->newThread(shared_ptr(new TNonblockingServer(serviceProcessor, protocolFactory, port+1))); + + } else if (serverType == "thread-pool") { + + shared_ptr threadManager = ThreadManager::newSimpleThreadManager(workerCount); + + threadManager->threadFactory(threadFactory); + threadManager->start(); + serverThread = threadFactory->newThread(shared_ptr(new TNonblockingServer(serviceProcessor, protocolFactory, port, threadManager))); + serverThread2 = threadFactory->newThread(shared_ptr(new TNonblockingServer(serviceProcessor, protocolFactory, port+1, threadManager))); + } + + cerr << "Starting the server on port " << port << " and " << (port + 1) << endl; + serverThread->start(); + serverThread2->start(); + + // If we aren't running clients, just wait forever for external clients + + if (clientCount == 0) { + serverThread->join(); + serverThread2->join(); + } + } + sleep(1); + + if (clientCount > 0) { + + Monitor monitor; + + size_t threadCount = 0; + + set > clientThreads; + + if (callName == "echoVoid") { loopType = T_VOID;} + else if (callName == "echoByte") { loopType = T_BYTE;} + else if (callName == "echoI32") { loopType = T_I32;} + else if (callName == "echoI64") { loopType = T_I64;} + else if (callName == "echoString") { loopType = T_STRING;} + else {throw invalid_argument("Unknown service call "+callName);} + + for (size_t ix = 0; ix < clientCount; ix++) { + + shared_ptr socket(new TSocket("127.0.0.1", port + (ix % 2))); + shared_ptr framedSocket(new TFramedTransport(socket)); + shared_ptr protocol(new TBinaryProtocol(framedSocket)); + shared_ptr serviceClient(new ServiceClient(protocol)); + + clientThreads.insert(threadFactory->newThread(shared_ptr(new ClientThread(socket, serviceClient, monitor, threadCount, loopCount, loopType)))); + } + + for (std::set >::const_iterator thread = clientThreads.begin(); thread != clientThreads.end(); thread++) { + (*thread)->start(); + } + + long long time00; + long long time01; + + {Synchronized s(monitor); + threadCount = clientCount; + + cerr << "Launch "<< clientCount << " client threads" << endl; + + time00 = Util::currentTime(); + + monitor.notifyAll(); + + while(threadCount > 0) { + monitor.wait(); + } + + time01 = Util::currentTime(); + } + + long long firstTime = 9223372036854775807LL; + long long lastTime = 0; + + double averageTime = 0; + long long minTime = 9223372036854775807LL; + long long maxTime = 0; + + for (set >::iterator ix = clientThreads.begin(); ix != clientThreads.end(); ix++) { + + shared_ptr client = dynamic_pointer_cast((*ix)->runnable()); + + long long delta = client->_endTime - client->_startTime; + + assert(delta > 0); + + if (client->_startTime < firstTime) { + firstTime = client->_startTime; + } + + if (client->_endTime > lastTime) { + lastTime = client->_endTime; + } + + if (delta < minTime) { + minTime = delta; + } + + if (delta > maxTime) { + maxTime = delta; + } + + averageTime+= delta; + } + + averageTime /= clientCount; + + + cout << "workers :" << workerCount << ", client : " << clientCount << ", loops : " << loopCount << ", rate : " << (clientCount * loopCount * 1000) / ((double)(time01 - time00)) << endl; + + count_map count = serviceHandler->getCount(); + count_map::iterator iter; + for (iter = count.begin(); iter != count.end(); ++iter) { + printf("%s => %d\n", iter->first, iter->second); + } + cerr << "done." << endl; + } + + return 0; +} diff --git a/test/csharp/CSharpClient.cs b/test/csharp/CSharpClient.cs new file mode 100644 index 00000000..641d5c9b --- /dev/null +++ b/test/csharp/CSharpClient.cs @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using Thrift; +using Thrift.Protocol; +using Thrift.Server; +using Thrift.Transport; + + +namespace CSharpTutorial +{ + public class CSharpClient + { + public static void Main() + { + try + { + TTransport transport = new TSocket("localhost", 9090); + TProtocol protocol = new TBinaryProtocol(transport); + Calculator.Client client = new Calculator.Client(protocol); + + transport.Open(); + + client.ping(); + Console.WriteLine("ping()"); + + int sum = client.add(1, 1); + Console.WriteLine("1+1={0}", sum); + + Work work = new Work(); + + work.op = Operation.DIVIDE; + work.num1 = 1; + work.num2 = 0; + try + { + int quotient = client.calculate(1, work); + Console.WriteLine("Whoa we can divide by 0"); + } + catch (InvalidOperation io) + { + Console.WriteLine("Invalid operation: " + io.why); + } + + work.op = Operation.SUBTRACT; + work.num1 = 15; + work.num2 = 10; + try + { + int diff = client.calculate(1, work); + Console.WriteLine("15-10={0}", diff); + } + catch (InvalidOperation io) + { + Console.WriteLine("Invalid operation: " + io.why); + } + + SharedStruct log = client.getStruct(1); + Console.WriteLine("Check log: {0}", log.value); + + transport.Close(); + } + catch (TApplicationException x) + { + Console.WriteLine(x.StackTrace); + } + + } + } +} diff --git a/test/csharp/CSharpServer.cs b/test/csharp/CSharpServer.cs new file mode 100644 index 00000000..f9ab8fd2 --- /dev/null +++ b/test/csharp/CSharpServer.cs @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections.Generic; +using Thrift.Server; +using Thrift.Transport; + +namespace CSharpTutorial +{ + public class CalculatorHandler : Calculator.Iface + { + Dictionary log; + + public CalculatorHandler() + { + log = new Dictionary(); + } + + public void ping() + { + Console.WriteLine("ping()"); + } + + public int add(int n1, int n2) + { + Console.WriteLine("add({0},{1})", n1, n2); + return n1 + n2; + } + + public int calculate(int logid, Work work) + { + Console.WriteLine("calculate({0}, [{1},{2},{3}])", logid, work.op, work.num1, work.num2); + int val = 0; + switch (work.op) + { + case Operation.ADD: + val = work.num1 + work.num2; + break; + + case Operation.SUBTRACT: + val = work.num1 - work.num2; + break; + + case Operation.MULTIPLY: + val = work.num1 * work.num2; + break; + + case Operation.DIVIDE: + if (work.num2 == 0) + { + InvalidOperation io = new InvalidOperation(); + io.what = (int)work.op; + io.why = "Cannot divide by 0"; + throw io; + } + val = work.num1 / work.num2; + break; + + default: + { + InvalidOperation io = new InvalidOperation(); + io.what = (int)work.op; + io.why = "Unknown operation"; + throw io; + } + } + + SharedStruct entry = new SharedStruct(); + entry.key = logid; + entry.value = val.ToString(); + log[logid] = entry; + + return val; + } + + public SharedStruct getStruct(int key) + { + Console.WriteLine("getStruct({0})", key); + return log[key]; + } + + public void zip() + { + Console.WriteLine("zip()"); + } + } + + public class CSharpServer + { + public static void Main() + { + try + { + CalculatorHandler handler = new CalculatorHandler(); + Calculator.Processor processor = new Calculator.Processor(handler); + TServerTransport serverTransport = new TServerSocket(9090); + TServer server = new TSimpleServer(processor, serverTransport); + + // Use this for a multithreaded server + // server = new TThreadPoolServer(processor, serverTransport); + + Console.WriteLine("Starting the server..."); + server.Serve(); + } + catch (Exception x) + { + Console.WriteLine(x.StackTrace); + } + Console.WriteLine("done."); + } + } +} diff --git a/test/csharp/ThriftTest/Program.cs b/test/csharp/ThriftTest/Program.cs new file mode 100644 index 00000000..4c63ca47 --- /dev/null +++ b/test/csharp/ThriftTest/Program.cs @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// Distributed under the Thrift Software License +// +// See accompanying file LICENSE or visit the Thrift site at: +// http://developers.facebook.com/thrift/ + +using System; +using Thrift.Transport; +using Thrift.Protocol; +using Thrift.Test; //generated code + +namespace Test +{ + class Program + { + static void Main(string[] args) + { + if (args.Length == 0) + { + Console.WriteLine("must provide 'server' or 'client' arg"); + return; + } + + string[] subArgs = new string[args.Length - 1]; + for(int i = 1; i < args.Length; i++) + { + subArgs[i-1] = args[i]; + } + if (args[0] == "client") + { + TestClient.Execute(subArgs); + } + else if (args[0] == "server") + { + TestServer.Execute(subArgs); + } + else + { + Console.WriteLine("first argument must be 'server' or 'client'"); + } + } + } +} diff --git a/test/csharp/ThriftTest/Properties/AssemblyInfo.cs b/test/csharp/ThriftTest/Properties/AssemblyInfo.cs new file mode 100644 index 00000000..504ca8de --- /dev/null +++ b/test/csharp/ThriftTest/Properties/AssemblyInfo.cs @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +// General Information about an assembly is controlled through the following +// set of attributes. Change these attribute values to modify the information +// associated with an assembly. +[assembly: AssemblyTitle("ThriftTest")] +[assembly: AssemblyDescription("")] +[assembly: AssemblyConfiguration("")] +[assembly: AssemblyCompany("")] +[assembly: AssemblyProduct("ThriftTest")] +[assembly: AssemblyCopyright("Copyright © 2009 The Apache Software Foundation")] +[assembly: AssemblyTrademark("")] +[assembly: AssemblyCulture("")] + +// Setting ComVisible to false makes the types in this assembly not visible +// to COM components. If you need to access a type in this assembly from +// COM, set the ComVisible attribute to true on that type. +[assembly: ComVisible(false)] + +// The following GUID is for the ID of the typelib if this project is exposed to COM +[assembly: Guid("f41b193b-f1ab-48ee-8843-f88e43084e26")] + +// Version information for an assembly consists of the following four values: +// +// Major Version +// Minor Version +// Build Number +// Revision +// +// You can specify all the values or you can default the Build and Revision Numbers +// by using the '*' as shown below: +// [assembly: AssemblyVersion("1.0.*")] +[assembly: AssemblyVersion("1.0.0.0")] +[assembly: AssemblyFileVersion("1.0.0.0")] diff --git a/test/csharp/ThriftTest/TestClient.cs b/test/csharp/ThriftTest/TestClient.cs new file mode 100644 index 00000000..2d278b7f --- /dev/null +++ b/test/csharp/ThriftTest/TestClient.cs @@ -0,0 +1,425 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System; +using System.Collections.Generic; +using System.Threading; +using Thrift.Collections; +using Thrift.Protocol; +using Thrift.Transport; +using Thrift.Test; + +namespace Test +{ + public class TestClient + { + private static int numIterations = 1; + + public static void Execute(string[] args) + { + try + { + string host = "localhost"; + int port = 9090; + string url = null; + int numThreads = 1; + bool buffered = false; + + try + { + for (int i = 0; i < args.Length; i++) + { + if (args[i] == "-h") + { + string[] hostport = args[++i].Split(':'); + host = hostport[0]; + if (hostport.Length > 1) + { + port = Convert.ToInt32(hostport[1]); + } + } + else if (args[i] == "-u") + { + url = args[++i]; + } + else if (args[i] == "-n") + { + numIterations = Convert.ToInt32(args[++i]); + } + else if (args[i] == "-b" || args[i] == "-buffered") + { + buffered = true; + Console.WriteLine("Using buffered sockets"); + } + else if (args[i] == "-t") + { + numThreads = Convert.ToInt32(args[++i]); + } + } + } + catch (Exception e) + { + Console.WriteLine(e.StackTrace); + } + + + + //issue tests on separate threads simultaneously + Thread[] threads = new Thread[numThreads]; + DateTime start = DateTime.Now; + for (int test = 0; test < numThreads; test++) + { + Thread t = new Thread(new ParameterizedThreadStart(ClientThread)); + threads[test] = t; + TSocket socket = new TSocket(host, port); + if (buffered) + { + TBufferedTransport buffer = new TBufferedTransport(socket); + t.Start(buffer); + } + else + { + t.Start(socket); + } + } + + for (int test = 0; test < numThreads; test++) + { + threads[test].Join(); + } + Console.Write("Total time: " + (DateTime.Now - start)); + } + catch (Exception outerEx) + { + Console.WriteLine(outerEx.Message + " ST: " + outerEx.StackTrace); + } + + Console.WriteLine(); + Console.WriteLine(); + } + + public static void ClientThread(object obj) + { + TTransport transport = (TTransport)obj; + for (int i = 0; i < numIterations; i++) + { + ClientTest(transport); + } + transport.Close(); + } + + public static void ClientTest(TTransport transport) + { + TBinaryProtocol binaryProtocol = new TBinaryProtocol(transport); + + ThriftTest.Client client = new ThriftTest.Client(binaryProtocol); + try + { + if (!transport.IsOpen) + { + transport.Open(); + } + } + catch (TTransportException ttx) + { + Console.WriteLine("Connect failed: " + ttx.Message); + return; + } + + long start = DateTime.Now.ToFileTime(); + + Console.Write("testVoid()"); + client.testVoid(); + Console.WriteLine(" = void"); + + Console.Write("testString(\"Test\")"); + string s = client.testString("Test"); + Console.WriteLine(" = \"" + s + "\""); + + Console.Write("testByte(1)"); + byte i8 = client.testByte((byte)1); + Console.WriteLine(" = " + i8); + + Console.Write("testI32(-1)"); + int i32 = client.testI32(-1); + Console.WriteLine(" = " + i32); + + Console.Write("testI64(-34359738368)"); + long i64 = client.testI64(-34359738368); + Console.WriteLine(" = " + i64); + + Console.Write("testDouble(5.325098235)"); + double dub = client.testDouble(5.325098235); + Console.WriteLine(" = " + dub); + + Console.Write("testStruct({\"Zero\", 1, -3, -5})"); + Xtruct o = new Xtruct(); + o.String_thing = "Zero"; + o.Byte_thing = (byte)1; + o.I32_thing = -3; + o.I64_thing = -5; + Xtruct i = client.testStruct(o); + Console.WriteLine(" = {\"" + i.String_thing + "\", " + i.Byte_thing + ", " + i.I32_thing + ", " + i.I64_thing + "}"); + + Console.Write("testNest({1, {\"Zero\", 1, -3, -5}, 5})"); + Xtruct2 o2 = new Xtruct2(); + o2.Byte_thing = (byte)1; + o2.Struct_thing = o; + o2.I32_thing = 5; + Xtruct2 i2 = client.testNest(o2); + i = i2.Struct_thing; + Console.WriteLine(" = {" + i2.Byte_thing + ", {\"" + i.String_thing + "\", " + i.Byte_thing + ", " + i.I32_thing + ", " + i.I64_thing + "}, " + i2.I32_thing + "}"); + + Dictionary mapout = new Dictionary(); + for (int j = 0; j < 5; j++) + { + mapout[j] = j - 10; + } + Console.Write("testMap({"); + bool first = true; + foreach (int key in mapout.Keys) + { + if (first) + { + first = false; + } + else + { + Console.Write(", "); + } + Console.Write(key + " => " + mapout[key]); + } + Console.Write("})"); + + Dictionary mapin = client.testMap(mapout); + + Console.Write(" = {"); + first = true; + foreach (int key in mapin.Keys) + { + if (first) + { + first = false; + } + else + { + Console.Write(", "); + } + Console.Write(key + " => " + mapin[key]); + } + Console.WriteLine("}"); + + List listout = new List(); + for (int j = -2; j < 3; j++) + { + listout.Add(j); + } + Console.Write("testList({"); + first = true; + foreach (int j in listout) + { + if (first) + { + first = false; + } + else + { + Console.Write(", "); + } + Console.Write(j); + } + Console.Write("})"); + + List listin = client.testList(listout); + + Console.Write(" = {"); + first = true; + foreach (int j in listin) + { + if (first) + { + first = false; + } + else + { + Console.Write(", "); + } + Console.Write(j); + } + Console.WriteLine("}"); + + //set + THashSet setout = new THashSet(); + for (int j = -2; j < 3; j++) + { + setout.Add(j); + } + Console.Write("testSet({"); + first = true; + foreach (int j in setout) + { + if (first) + { + first = false; + } + else + { + Console.Write(", "); + } + Console.Write(j); + } + Console.Write("})"); + + THashSet setin = client.testSet(setout); + + Console.Write(" = {"); + first = true; + foreach (int j in setin) + { + if (first) + { + first = false; + } + else + { + Console.Write(", "); + } + Console.Write(j); + } + Console.WriteLine("}"); + + + Console.Write("testEnum(ONE)"); + Numberz ret = client.testEnum(Numberz.ONE); + Console.WriteLine(" = " + ret); + + Console.Write("testEnum(TWO)"); + ret = client.testEnum(Numberz.TWO); + Console.WriteLine(" = " + ret); + + Console.Write("testEnum(THREE)"); + ret = client.testEnum(Numberz.THREE); + Console.WriteLine(" = " + ret); + + Console.Write("testEnum(FIVE)"); + ret = client.testEnum(Numberz.FIVE); + Console.WriteLine(" = " + ret); + + Console.Write("testEnum(EIGHT)"); + ret = client.testEnum(Numberz.EIGHT); + Console.WriteLine(" = " + ret); + + Console.Write("testTypedef(309858235082523)"); + long uid = client.testTypedef(309858235082523L); + Console.WriteLine(" = " + uid); + + Console.Write("testMapMap(1)"); + Dictionary> mm = client.testMapMap(1); + Console.Write(" = {"); + foreach (int key in mm.Keys) + { + Console.Write(key + " => {"); + Dictionary m2 = mm[key]; + foreach (int k2 in m2.Keys) + { + Console.Write(k2 + " => " + m2[k2] + ", "); + } + Console.Write("}, "); + } + Console.WriteLine("}"); + + Insanity insane = new Insanity(); + insane.UserMap = new Dictionary(); + insane.UserMap[Numberz.FIVE] = 5000L; + Xtruct truck = new Xtruct(); + truck.String_thing = "Truck"; + truck.Byte_thing = (byte)8; + truck.I32_thing = 8; + truck.I64_thing = 8; + insane.Xtructs = new List(); + insane.Xtructs.Add(truck); + Console.Write("testInsanity()"); + Dictionary> whoa = client.testInsanity(insane); + Console.Write(" = {"); + foreach (long key in whoa.Keys) + { + Dictionary val = whoa[key]; + Console.Write(key + " => {"); + + foreach (Numberz k2 in val.Keys) + { + Insanity v2 = val[k2]; + + Console.Write(k2 + " => {"); + Dictionary userMap = v2.UserMap; + + Console.Write("{"); + if (userMap != null) + { + foreach (Numberz k3 in userMap.Keys) + { + Console.Write(k3 + " => " + userMap[k3] + ", "); + } + } + else + { + Console.Write("null"); + } + Console.Write("}, "); + + List xtructs = v2.Xtructs; + + Console.Write("{"); + if (xtructs != null) + { + foreach (Xtruct x in xtructs) + { + Console.Write("{\"" + x.String_thing + "\", " + x.Byte_thing + ", " + x.I32_thing + ", " + x.I32_thing + "}, "); + } + } + else + { + Console.Write("null"); + } + Console.Write("}"); + + Console.Write("}, "); + } + Console.Write("}, "); + } + Console.WriteLine("}"); + + + byte arg0 = 1; + int arg1 = 2; + long arg2 = long.MaxValue; + Dictionary multiDict = new Dictionary(); + multiDict[1] = "one"; + Numberz arg4 = Numberz.FIVE; + long arg5 = 5000000; + Console.Write("Test Multi(" + arg0 + "," + arg1 + "," + arg2 + "," + multiDict + "," + arg4 + "," + arg5 + ")"); + Xtruct multiResponse = client.testMulti(arg0, arg1, arg2, multiDict, arg4, arg5); + Console.Write(" = Xtruct(byte_thing:" + multiResponse.Byte_thing + ",String_thing:" + multiResponse.String_thing + + ",i32_thing:" + multiResponse.I32_thing + ",i64_thing:" + multiResponse.I64_thing + ")\n"); + + Console.WriteLine("Test Oneway(1)"); + client.testOneway(1); + } + } +} diff --git a/test/csharp/ThriftTest/TestServer.cs b/test/csharp/ThriftTest/TestServer.cs new file mode 100644 index 00000000..e3706404 --- /dev/null +++ b/test/csharp/ThriftTest/TestServer.cs @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// Distributed under the Thrift Software License +// +// See accompanying file LICENSE or visit the Thrift site at: +// http://developers.facebook.com/thrift/ +using System; +using System.Collections.Generic; +using Thrift.Collections; +using Thrift.Test; //generated code +using Thrift.Transport; +using Thrift.Protocol; +using Thrift.Server; + +namespace Test +{ + public class TestServer + { + public class TestHandler : ThriftTest.Iface + { + public TServer server; + + public TestHandler() { } + + public void testVoid() + { + Console.WriteLine("testVoid()"); + } + + public string testString(string thing) + { + Console.WriteLine("teststring(\"" + thing + "\")"); + return thing; + } + + public byte testByte(byte thing) + { + Console.WriteLine("testByte(" + thing + ")"); + return thing; + } + + public int testI32(int thing) + { + Console.WriteLine("testI32(" + thing + ")"); + return thing; + } + + public long testI64(long thing) + { + Console.WriteLine("testI64(" + thing + ")"); + return thing; + } + + public double testDouble(double thing) + { + Console.WriteLine("testDouble(" + thing + ")"); + return thing; + } + + public Xtruct testStruct(Xtruct thing) + { + Console.WriteLine("testStruct({" + + "\"" + thing.String_thing + "\", " + + thing.Byte_thing + ", " + + thing.I32_thing + ", " + + thing.I64_thing + "})"); + return thing; + } + + public Xtruct2 testNest(Xtruct2 nest) + { + Xtruct thing = nest.Struct_thing; + Console.WriteLine("testNest({" + + nest.Byte_thing + ", {" + + "\"" + thing.String_thing + "\", " + + thing.Byte_thing + ", " + + thing.I32_thing + ", " + + thing.I64_thing + "}, " + + nest.I32_thing + "})"); + return nest; + } + + public Dictionary testMap(Dictionary thing) + { + Console.WriteLine("testMap({"); + bool first = true; + foreach (int key in thing.Keys) + { + if (first) + { + first = false; + } + else + { + Console.WriteLine(", "); + } + Console.WriteLine(key + " => " + thing[key]); + } + Console.WriteLine("})"); + return thing; + } + + public THashSet testSet(THashSet thing) + { + Console.WriteLine("testSet({"); + bool first = true; + foreach (int elem in thing) + { + if (first) + { + first = false; + } + else + { + Console.WriteLine(", "); + } + Console.WriteLine(elem); + } + Console.WriteLine("})"); + return thing; + } + + public List testList(List thing) + { + Console.WriteLine("testList({"); + bool first = true; + foreach (int elem in thing) + { + if (first) + { + first = false; + } + else + { + Console.WriteLine(", "); + } + Console.WriteLine(elem); + } + Console.WriteLine("})"); + return thing; + } + + public Numberz testEnum(Numberz thing) + { + Console.WriteLine("testEnum(" + thing + ")"); + return thing; + } + + public long testTypedef(long thing) + { + Console.WriteLine("testTypedef(" + thing + ")"); + return thing; + } + + public Dictionary> testMapMap(int hello) + { + Console.WriteLine("testMapMap(" + hello + ")"); + Dictionary> mapmap = + new Dictionary>(); + + Dictionary pos = new Dictionary(); + Dictionary neg = new Dictionary(); + for (int i = 1; i < 5; i++) + { + pos[i] = i; + neg[-i] = -i; + } + + mapmap[4] = pos; + mapmap[-4] = neg; + + return mapmap; + } + + public Dictionary> testInsanity(Insanity argument) + { + Console.WriteLine("testInsanity()"); + + Xtruct hello = new Xtruct(); + hello.String_thing = "Hello2"; + hello.Byte_thing = 2; + hello.I32_thing = 2; + hello.I64_thing = 2; + + Xtruct goodbye = new Xtruct(); + goodbye.String_thing = "Goodbye4"; + goodbye.Byte_thing = (byte)4; + goodbye.I32_thing = 4; + goodbye.I64_thing = (long)4; + + Insanity crazy = new Insanity(); + crazy.UserMap = new Dictionary(); + crazy.UserMap[Numberz.EIGHT] = (long)8; + crazy.Xtructs = new List(); + crazy.Xtructs.Add(goodbye); + + Insanity looney = new Insanity(); + crazy.UserMap[Numberz.FIVE] = (long)5; + crazy.Xtructs.Add(hello); + + Dictionary first_map = new Dictionary(); + Dictionary second_map = new Dictionary(); ; + + first_map[Numberz.TWO] = crazy; + first_map[Numberz.THREE] = crazy; + + second_map[Numberz.SIX] = looney; + + Dictionary> insane = + new Dictionary>(); + insane[(long)1] = first_map; + insane[(long)2] = second_map; + + return insane; + } + + public Xtruct testMulti(byte arg0, int arg1, long arg2, Dictionary arg3, Numberz arg4, long arg5) + { + Console.WriteLine("testMulti()"); + + Xtruct hello = new Xtruct(); ; + hello.String_thing = "Hello2"; + hello.Byte_thing = arg0; + hello.I32_thing = arg1; + hello.I64_thing = arg2; + return hello; + } + + public void testException(string arg) + { + Console.WriteLine("testException(" + arg + ")"); + if (arg == "Xception") + { + Xception x = new Xception(); + x.ErrorCode = 1001; + x.Message = "This is an Xception"; + throw x; + } + return; + } + + public Xtruct testMultiException(string arg0, string arg1) + { + Console.WriteLine("testMultiException(" + arg0 + ", " + arg1 + ")"); + if (arg0 == "Xception") + { + Xception x = new Xception(); + x.ErrorCode = 1001; + x.Message = "This is an Xception"; + throw x; + } + else if (arg0 == "Xception2") + { + Xception2 x = new Xception2(); + x.ErrorCode = 2002; + x.Struct_thing = new Xtruct(); + x.Struct_thing.String_thing = "This is an Xception2"; + throw x; + } + + Xtruct result = new Xtruct(); + result.String_thing = arg1; + return result; + } + + public void testStop() + { + if (server != null) + { + server.Stop(); + } + } + + public void testOneway(int arg) + { + Console.WriteLine("testOneway(" + arg + "), sleeping..."); + System.Threading.Thread.Sleep(arg * 1000); + Console.WriteLine("testOneway finished"); + } + + } // class TestHandler + + public static void Execute(string[] args) + { + try + { + bool useBufferedSockets = false; + int port = 9090; + if (args.Length > 0) + { + port = int.Parse(args[0]); + + if (args.Length > 1) + { + bool.TryParse(args[1], out useBufferedSockets); + } + } + + // Processor + TestHandler testHandler = new TestHandler(); + ThriftTest.Processor testProcessor = new ThriftTest.Processor(testHandler); + + // Transport + TServerSocket tServerSocket = new TServerSocket(port, 0, useBufferedSockets); + + TServer serverEngine; + + // Simple Server + serverEngine = new TSimpleServer(testProcessor, tServerSocket); + + // ThreadPool Server + // serverEngine = new TThreadPoolServer(testProcessor, tServerSocket); + + // Threaded Server + // serverEngine = new TThreadedServer(testProcessor, tServerSocket); + + testHandler.server = serverEngine; + + // Run it + Console.WriteLine("Starting the server on port " + port + (useBufferedSockets ? " with buffered socket" : "") + "..."); + serverEngine.Serve(); + + } + catch (Exception x) + { + Console.Error.Write(x); + } + Console.WriteLine("done."); + } + } +} diff --git a/test/csharp/ThriftTest/ThriftTest.csproj b/test/csharp/ThriftTest/ThriftTest.csproj new file mode 100644 index 00000000..3f427fd7 --- /dev/null +++ b/test/csharp/ThriftTest/ThriftTest.csproj @@ -0,0 +1,111 @@ + + + + Debug + AnyCPU + 9.0.21022 + 2.0 + {48DD757F-CA95-4DD7-BDA4-58DB6F108C2C} + Exe + Properties + ThriftTest + ThriftTest + v3.5 + 512 + publish\ + true + Disk + false + Foreground + 7 + Days + false + false + true + 0 + 1.0.0.%2a + false + false + true + SAK + SAK + SAK + SAK + + + true + full + false + bin\Debug\ + DEBUG;TRACE + prompt + 4 + + + pdbonly + true + bin\Release\ + TRACE + prompt + 4 + + + + + False + .\ThriftImpl.dll + + + + + + + + + + + False + .NET Framework 2.0 %28x86%29 + false + + + False + .NET Framework 3.0 %28x86%29 + false + + + False + .NET Framework 3.5 + true + + + False + Windows Installer 3.1 + true + + + + + {499EB63C-D74C-47E8-AE48-A2FC94538E9D} + Thrift + + + + + + rmdir /s /q $(ProjectDir)gen-csharp +del /f /q $(ProjectDir)ThriftImpl.dll + +$(ProjectDir)\..\..\..\compiler\cpp\thrift.exe -csharp -o $(ProjectDir) $(ProjectDir)\..\..\ThriftTest.thrift + +cd $(ProjectDir) + +$(MSBuildToolsPath)\Csc.exe /t:library /out:.\ThriftImpl.dll /recurse:.\gen-csharp\* /reference:$(ProjectDir)..\..\..\lib\csharp\src\bin\Debug\Thrift.dll + + diff --git a/test/csharp/ThriftTest/maketest.sh b/test/csharp/ThriftTest/maketest.sh new file mode 100755 index 00000000..5580de80 --- /dev/null +++ b/test/csharp/ThriftTest/maketest.sh @@ -0,0 +1,23 @@ +#!/bin/sh + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +../../../compiler/cpp/thrift --gen csharp -o . ../../ThriftTest.thrift +gmcs /t:library /out:./ThriftImpl.dll /recurse:./gen-csharp/* /reference:../../../lib/csharp/Thrift.dll diff --git a/test/erl/Makefile b/test/erl/Makefile new file mode 100644 index 00000000..17e30da6 --- /dev/null +++ b/test/erl/Makefile @@ -0,0 +1,66 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +GENDIR=gen +GEN_INCLUDEDIR=$(GENDIR)/include +GEN_SRCDIR=$(GENDIR)/src +GEN_TARGETDIR=$(GENDIR)/ebin + +INCLUDEDIR=include +TARGETDIR=ebin +SRCDIR=src + +ALL_INCLUDEDIR=$(GEN_INCLUDEDIR) $(INCLUDEDIR) ../../lib/erl/include +INCLUDEFLAGS=$(patsubst %,-I%, ${ALL_INCLUDEDIR}) + +MODULES = stress_server test_server test_disklog test_membuffer + +INCLUDES = +TARGETS = $(patsubst %,${TARGETDIR}/%.beam,${MODULES}) +HEADERS = $(patsubst %,${INCLUDEDIR}/%.hrl,${INCLUDES}) + +all: ${GEN_TARGETDIR}/ ${TARGETS} + +TEST_RPCFILE = ../ThriftTest.thrift +STRESS_RPCFILE = ../StressTest.thrift +THRIFT = ../../compiler/cpp/thrift + +${GENDIR}/: ${RPCFILE} + rm -rf ${GENDIR} + ${THRIFT} --gen erl ${TEST_RPCFILE} + ${THRIFT} --gen erl ${STRESS_RPCFILE} + mkdir -p ${GEN_INCLUDEDIR} + mkdir -p ${GEN_SRCDIR} + mkdir -p ${GEN_TARGETDIR} + mv -t ${GEN_INCLUDEDIR} gen-erl/*.hrl + mv -t ${GEN_SRCDIR} gen-erl/*.erl + rm -rf gen-erl + +${GEN_TARGETDIR}/: ${GENDIR}/ + rm -rf ${GEN_TARGETDIR} + mkdir -p ${GEN_TARGETDIR} + erlc ${INCLUDEFLAGS} -o ${GEN_TARGETDIR} ${GEN_SRCDIR}/*.erl + +$(TARGETS): ${TARGETDIR}/%.beam: ${SRCDIR}/%.erl ${GEN_INCLUDEDIR}/ ${HEADERS} + mkdir -p ${TARGETDIR} + erlc ${INCLUDEFLAGS} -o ${TARGETDIR} $< + +clean: + rm -f ${TARGETDIR}/*.beam + rm -rf ${GENDIR} diff --git a/test/erl/src/stress_server.erl b/test/erl/src/stress_server.erl new file mode 100644 index 00000000..35fff069 --- /dev/null +++ b/test/erl/src/stress_server.erl @@ -0,0 +1,64 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(stress_server). + + +-export([start_link/1, + + handle_function/2, + + echoVoid/0, + echoByte/1, + echoI32/1, + echoI64/1, + echoString/1, + echoList/1, + echoSet/1, + echoMap/1 + ]). + +start_link(Port) -> + thrift_server:start_link(Port, service_thrift, ?MODULE). + + +handle_function(Function, Args) -> + case apply(?MODULE, Function, tuple_to_list(Args)) of + ok -> + ok; + Else -> {reply, Else} + end. + + +echoVoid() -> + ok. +echoByte(X) -> + X. +echoI32(X) -> + X. +echoI64(X) -> + X. +echoString(X) -> + X. +echoList(X) -> + X. +echoSet(X) -> + X. +echoMap(X) -> + X. diff --git a/test/erl/src/test_disklog.erl b/test/erl/src/test_disklog.erl new file mode 100644 index 00000000..7b0be72d --- /dev/null +++ b/test/erl/src/test_disklog.erl @@ -0,0 +1,81 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(test_disklog). + +-compile(export_all). + +t() -> + {ok, TransportFactory} = + thrift_disk_log_transport:new_transport_factory( + test_disklog, + [{file, "/tmp/test_log"}, + {size, {1024*1024, 10}}]), + {ok, ProtocolFactory} = thrift_binary_protocol:new_protocol_factory( + TransportFactory, []), + {ok, Client} = thrift_client:start_link(ProtocolFactory, thriftTest_thrift), + + io:format("Client started~n"), + + % We have to make oneway calls into this client only since otherwise it will try + % to read from the disklog and go boom. + {ok, ok} = thrift_client:call(Client, testOneway, [16#deadbeef]), + io:format("Call written~n"), + + % Use the send_call method to write a non-oneway call into the log + ok = thrift_client:send_call(Client, testString, [<<"hello world">>]), + io:format("Non-oneway call sent~n"), + + ok = thrift_client:close(Client), + io:format("Client closed~n"), + + ok. + + + +t_base64() -> + {ok, TransportFactory} = + thrift_disk_log_transport:new_transport_factory( + test_disklog, + [{file, "/tmp/test_b64_log"}, + {size, {1024*1024, 10}}]), + {ok, B64Factory} = + thrift_base64_transport:new_transport_factory(TransportFactory), + {ok, BufFactory} = + thrift_buffered_transport:new_transport_factory(B64Factory), + {ok, ProtocolFactory} = thrift_binary_protocol:new_protocol_factory( + BufFactory, []), + {ok, Client} = thrift_client:start_link(ProtocolFactory, thriftTest_thrift), + + io:format("Client started~n"), + + % We have to make oneway calls into this client only since otherwise it will try + % to read from the disklog and go boom. + {ok, ok} = thrift_client:call(Client, testOneway, [16#deadbeef]), + io:format("Call written~n"), + + % Use the send_call method to write a non-oneway call into the log + ok = thrift_client:send_call(Client, testString, [<<"hello world">>]), + io:format("Non-oneway call sent~n"), + + ok = thrift_client:close(Client), + io:format("Client closed~n"), + + ok. + diff --git a/test/erl/src/test_membuffer.erl b/test/erl/src/test_membuffer.erl new file mode 100644 index 00000000..7bd23a0f --- /dev/null +++ b/test/erl/src/test_membuffer.erl @@ -0,0 +1,81 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(test_membuffer). +-export([t/0]). + +-include("thriftTest_types.hrl"). + +test_data() -> + #xtruct{string_thing = <<"foobar">>, + byte_thing = 123, + i32_thing = 1234567, + i64_thing = 12345678900}. + +t1() -> + {ok, Transport} = thrift_memory_buffer:new(), + {ok, Protocol} = thrift_binary_protocol:new(Transport), + TestData = test_data(), + ok = thrift_protocol:write(Protocol, + {{struct, element(2, thriftTest_types:struct_info('xtruct'))}, + TestData}), + {ok, Result} = thrift_protocol:read(Protocol, + {struct, element(2, thriftTest_types:struct_info('xtruct'))}, + 'xtruct'), + + Result = TestData. + + +t2() -> + {ok, Transport} = thrift_memory_buffer:new(), + {ok, Protocol} = thrift_binary_protocol:new(Transport), + TestData = test_data(), + ok = thrift_protocol:write(Protocol, + {{struct, element(2, thriftTest_types:struct_info('xtruct'))}, + TestData}), + {ok, Result} = thrift_protocol:read(Protocol, + {struct, element(2, thriftTest_types:struct_info('xtruct3'))}, + 'xtruct3'), + + Result = #xtruct3{string_thing = TestData#xtruct.string_thing, + changed = undefined, + i32_thing = TestData#xtruct.i32_thing, + i64_thing = TestData#xtruct.i64_thing}. + + +t3() -> + {ok, Transport} = thrift_memory_buffer:new(), + {ok, Protocol} = thrift_binary_protocol:new(Transport), + TestData = #bools{im_true = true, im_false = false}, + ok = thrift_protocol:write(Protocol, + {{struct, element(2, thriftTest_types:struct_info('bools'))}, + TestData}), + {ok, Result} = thrift_protocol:read(Protocol, + {struct, element(2, thriftTest_types:struct_info('bools'))}, + 'bools'), + + true = TestData#bools.im_true =:= Result#bools.im_true, + true = TestData#bools.im_false =:= Result#bools.im_false. + + +t() -> + t1(), + t2(), + t3(). + diff --git a/test/erl/src/test_server.erl b/test/erl/src/test_server.erl new file mode 100644 index 00000000..cd439ccd --- /dev/null +++ b/test/erl/src/test_server.erl @@ -0,0 +1,174 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(test_server). + +-export([start_link/1, handle_function/2]). + +-include("thriftTest_types.hrl"). + +start_link(Port) -> + thrift_server:start_link(Port, thriftTest_thrift, ?MODULE). + + +handle_function(testVoid, {}) -> + io:format("testVoid~n"), + ok; + +handle_function(testString, {S}) when is_binary(S) -> + io:format("testString: ~p~n", [S]), + {reply, S}; + +handle_function(testByte, {I8}) when is_integer(I8) -> + io:format("testByte: ~p~n", [I8]), + {reply, I8}; + +handle_function(testI32, {I32}) when is_integer(I32) -> + io:format("testI32: ~p~n", [I32]), + {reply, I32}; + +handle_function(testI64, {I64}) when is_integer(I64) -> + io:format("testI64: ~p~n", [I64]), + {reply, I64}; + +handle_function(testDouble, {Double}) when is_float(Double) -> + io:format("testDouble: ~p~n", [Double]), + {reply, Double}; + +handle_function(testStruct, + {Struct = #xtruct{string_thing = String, + byte_thing = Byte, + i32_thing = I32, + i64_thing = I64}}) +when is_binary(String), + is_integer(Byte), + is_integer(I32), + is_integer(I64) -> + io:format("testStruct: ~p~n", [Struct]), + {reply, Struct}; + +handle_function(testNest, + {Nest}) when is_record(Nest, xtruct2), + is_record(Nest#xtruct2.struct_thing, xtruct) -> + io:format("testNest: ~p~n", [Nest]), + {reply, Nest}; + +handle_function(testMap, {Map}) -> + io:format("testMap: ~p~n", [dict:to_list(Map)]), + {reply, Map}; + +handle_function(testSet, {Set}) -> + true = sets:is_set(Set), + io:format("testSet: ~p~n", [sets:to_list(Set)]), + {reply, Set}; + +handle_function(testList, {List}) when is_list(List) -> + io:format("testList: ~p~n", [List]), + {reply, List}; + +handle_function(testEnum, {Enum}) when is_integer(Enum) -> + io:format("testEnum: ~p~n", [Enum]), + {reply, Enum}; + +handle_function(testTypedef, {UserID}) when is_integer(UserID) -> + io:format("testTypedef: ~p~n", [UserID]), + {reply, UserID}; + +handle_function(testMapMap, {Hello}) -> + io:format("testMapMap: ~p~n", [Hello]), + + PosList = [{I, I} || I <- lists:seq(1, 5)], + NegList = [{-I, -I} || I <- lists:seq(1, 5)], + + MapMap = dict:from_list([{4, dict:from_list(PosList)}, + {-4, dict:from_list(NegList)}]), + {reply, MapMap}; + +handle_function(testInsanity, {Insanity}) when is_record(Insanity, insanity) -> + Hello = #xtruct{string_thing = <<"Hello2">>, + byte_thing = 2, + i32_thing = 2, + i64_thing = 2}, + + Goodbye = #xtruct{string_thing = <<"Goodbye4">>, + byte_thing = 4, + i32_thing = 4, + i64_thing = 4}, + Crazy = #insanity{ + userMap = dict:from_list([{?thriftTest_EIGHT, 8}]), + xtructs = [Goodbye] + }, + + Looney = #insanity{ + userMap = dict:from_list([{?thriftTest_FIVE, 5}]), + xtructs = [Hello] + }, + + FirstMap = dict:from_list([{?thriftTest_TWO, Crazy}, + {?thriftTest_THREE, Crazy}]), + + SecondMap = dict:from_list([{?thriftTest_SIX, Looney}]), + + Insane = dict:from_list([{1, FirstMap}, + {2, SecondMap}]), + + io:format("Return = ~p~n", [Insane]), + + {reply, Insane}; + +handle_function(testMulti, Args = {Arg0, Arg1, Arg2, _Arg3, Arg4, Arg5}) + when is_integer(Arg0), + is_integer(Arg1), + is_integer(Arg2), + is_integer(Arg4), + is_integer(Arg5) -> + + io:format("testMulti(~p)~n", [Args]), + {reply, #xtruct{string_thing = <<"Hello2">>, + byte_thing = Arg0, + i32_thing = Arg1, + i64_thing = Arg2}}; + +handle_function(testException, {String}) when is_binary(String) -> + io:format("testException(~p)~n", [String]), + case String of + <<"Xception">> -> + throw(#xception{errorCode = 1001, + message = <<"This is an Xception">>}); + _ -> + ok + end; + +handle_function(testMultiException, {Arg0, Arg1}) -> + io:format("testMultiException(~p, ~p)~n", [Arg0, Arg1]), + case Arg0 of + <<"Xception">> -> + throw(#xception{errorCode = 1001, + message = <<"This is an Xception">>}); + <<"Xception2">> -> + throw(#xception2{errorCode = 2002, + struct_thing = + #xtruct{string_thing = <<"This is an Xception2">>}}); + _ -> + {reply, #xtruct{string_thing = Arg1}} + end; + +handle_function(testOneway, {Seconds}) -> + timer:sleep(1000 * Seconds), + ok. diff --git a/test/hs/Client.hs b/test/hs/Client.hs new file mode 100644 index 00000000..c5e4d907 --- /dev/null +++ b/test/hs/Client.hs @@ -0,0 +1,58 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +module Client where + +import ThriftTest_Client +import ThriftTest_Types +import qualified Data.Map as Map +import qualified Data.Set as Set +import Control.Monad +import Control.Exception as CE + +import Network + +import Thrift +import Thrift.Transport.Handle +import Thrift.Protocol.Binary + + +serverAddress = ("127.0.0.1", PortNumber 9090) + +main = do to <- hOpen serverAddress + let p = BinaryProtocol to + let ps = (p,p) + print =<< testString ps "bya" + print =<< testByte ps 8 + print =<< testByte ps (-8) + print =<< testI32 ps 32 + print =<< testI32 ps (-32) + print =<< testI64 ps 64 + print =<< testI64 ps (-64) + print =<< testDouble ps 3.14 + print =<< testDouble ps (-3.14) + print =<< testMap ps (Map.fromList [(1,1),(2,2),(3,3)]) + print =<< testList ps [1,2,3,4,5] + print =<< testSet ps (Set.fromList [1,2,3,4,5]) + print =<< testStruct ps (Xtruct (Just "hi") (Just 4) (Just 5) Nothing) + CE.catch (testException ps "e" >> print "bad") (\e -> print (e :: Xception)) + CE.catch (testMultiException ps "e" "e2" >> print "ok") (\e -> print (e :: Xception)) + CE.catch (CE.catch (testMultiException ps "e" "e2">> print "bad") (\e -> print (e :: Xception2))) (\(e :: SomeException) -> print "ok") + tClose to + diff --git a/test/hs/Server.hs b/test/hs/Server.hs new file mode 100644 index 00000000..0ca9d9fe --- /dev/null +++ b/test/hs/Server.hs @@ -0,0 +1,57 @@ +-- +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +module Server where + +import ThriftTest +import ThriftTest_Iface +import Data.Map as Map +import Control.Exception +import ThriftTest_Types + +import Thrift +import Thrift.Server + + +data TestHandler = TestHandler +instance ThriftTest_Iface TestHandler where + testVoid a = return () + testString a (Just s) = do print s; return s + testByte a (Just x) = do print x; return x + testI32 a (Just x) = do print x; return x + testI64 a (Just x) = do print x; return x + testDouble a (Just x) = do print x; return x + testStruct a (Just x) = do print x; return x + testNest a (Just x) = do print x; return x + testMap a (Just x) = do print x; return x + testSet a (Just x) = do print x; return x + testList a (Just x) = do print x; return x + testEnum a (Just x) = do print x; return x + testTypedef a (Just x) = do print x; return x + testMapMap a (Just x) = return (Map.fromList [(1,Map.fromList [(2,2)])]) + testInsanity a (Just x) = return (Map.fromList [(1,Map.fromList [(ONE,x)])]) + testMulti a a1 a2 a3 a4 a5 a6 = return (Xtruct Nothing Nothing Nothing Nothing) + testException a c = throw (Xception (Just 1) (Just "bya")) + testMultiException a c1 c2 = throw (Xception (Just 1) (Just "xyz")) + testOneway a (Just i) = do print i + + +main = do (runBasicServer TestHandler process 9090) + `Control.Exception.catch` + (\(TransportExn s t) -> print s) diff --git a/test/hs/runclient.sh b/test/hs/runclient.sh new file mode 100644 index 00000000..b93bbb14 --- /dev/null +++ b/test/hs/runclient.sh @@ -0,0 +1,26 @@ +#!/bin/sh + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +if [ -z $BASE ]; then + BASE=../.. +fi + +ghci -fglasgow-exts -i$BASE/lib/hs/src -i$BASE/test/hs/gen-hs Client.hs diff --git a/test/hs/runserver.sh b/test/hs/runserver.sh new file mode 100644 index 00000000..b23301b4 --- /dev/null +++ b/test/hs/runserver.sh @@ -0,0 +1,27 @@ +#!/bin/sh + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +if [ -z $BASE ]; then + BASE=../.. +fi + +printf "Starting server... " +ghc -fglasgow-exts -i$BASE/lib/hs/src -i$BASE/test/hs/gen-hs Server.hs -e "putStrLn \"ready.\" >> Server.main" diff --git a/test/ocaml/Makefile b/test/ocaml/Makefile new file mode 100644 index 00000000..a543ce58 --- /dev/null +++ b/test/ocaml/Makefile @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +all: + cd client; make; cd ..; cd server; make +clean: + cd client; make clean; cd ..; cd server; make clean + diff --git a/test/ocaml/client/Makefile b/test/ocaml/client/Makefile new file mode 100644 index 00000000..806ed20a --- /dev/null +++ b/test/ocaml/client/Makefile @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +SOURCES = ../gen-ocaml/ThriftTest_types.ml ../gen-ocaml/ThriftTest_consts.ml ../gen-ocaml/SecondService.ml ../gen-ocaml/ThriftTest.ml TestClient.ml +RESULT = tc +INCDIRS = "../../../lib/ocaml/src/" "../gen-ocaml/" +LIBS = unix thrift +all: nc +OCAMLMAKEFILE = ../../../lib/ocaml/OCamlMakefile +include $(OCAMLMAKEFILE) diff --git a/test/ocaml/client/TestClient.ml b/test/ocaml/client/TestClient.ml new file mode 100644 index 00000000..91783ae4 --- /dev/null +++ b/test/ocaml/client/TestClient.ml @@ -0,0 +1,82 @@ +(* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*) + +open Thrift;; +open ThriftTest_types;; + +let s = new TSocket.t "127.0.0.1" 9090;; +let p = new TBinaryProtocol.t s;; +let c = new ThriftTest.client p p;; +let sod = function + Some v -> v + | None -> raise Thrift_error;; + +s#opn; +print_string (c#testString "bya"); +print_char '\n'; +print_int (c#testByte 8); +print_char '\n'; +print_int (c#testByte (-8)); +print_char '\n'; +print_int (c#testI32 32); +print_char '\n'; +print_string (Int64.to_string (c#testI64 64L)); +print_char '\n'; +print_float (c#testDouble 3.14); +print_char '\n'; + +let l = [1;2;3;4] in + if l = (c#testList l) then print_string "list ok\n" else print_string "list fail\n";; +let h = Hashtbl.create 5 in +let a = Hashtbl.add h in + for i=1 to 10 do + a i (10*i) + done; + let r = c#testMap h in + for i=1 to 10 do + try + let g = Hashtbl.find r i in + print_int i; + print_char ' '; + print_int g; + print_char '\n' + with Not_found -> print_string ("Can't find "^(string_of_int i)^"\n") + done;; + +let s = Hashtbl.create 5 in +let a = Hashtbl.add s in + for i = 1 to 10 do + a i true + done; + let r = c#testSet s in + for i = 1 to 10 do + try + let g = Hashtbl.find r i in + print_int i; + print_char '\n' + with Not_found -> print_string ("Can't find "^(string_of_int i)^"\n") + done;; +try + c#testException "Xception" +with Xception _ -> print_string "testException ok\n";; +try + ignore(c#testMultiException "Xception" "bya") +with Xception e -> Printf.printf "%d %s\n" (sod e#get_errorCode) (sod e#get_message);; + + diff --git a/test/ocaml/server/Makefile b/test/ocaml/server/Makefile new file mode 100644 index 00000000..44dcac76 --- /dev/null +++ b/test/ocaml/server/Makefile @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +SOURCES = ../gen-ocaml/ThriftTest_types.ml ../gen-ocaml/ThriftTest_consts.ml ../gen-ocaml/SecondService.ml ../gen-ocaml/ThriftTest.ml TestServer.ml +RESULT = ts +INCDIRS = "../../../lib/ocaml/src/" "../gen-ocaml/" +LIBS = thrift +THREADS = yes +all: nc +OCAMLMAKEFILE = ../../../lib/ocaml/OCamlMakefile +include $(OCAMLMAKEFILE) diff --git a/test/ocaml/server/TestServer.ml b/test/ocaml/server/TestServer.ml new file mode 100644 index 00000000..3f5c9ee1 --- /dev/null +++ b/test/ocaml/server/TestServer.ml @@ -0,0 +1,136 @@ +(* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*) + +open Thrift +open ThriftTest_types + +let p = Printf.printf;; +exception Die;; +let sod = function + Some v -> v + | None -> raise Die;; + + +class test_handler = +object (self) + inherit ThriftTest.iface + method testVoid = p "testVoid()\n" + method testString x = p "testString(%s)\n" (sod x); (sod x) + method testByte x = p "testByte(%d)\n" (sod x); (sod x) + method testI32 x = p "testI32(%d)\n" (sod x); (sod x) + method testI64 x = p "testI64(%s)\n" (Int64.to_string (sod x)); (sod x) + method testDouble x = p "testDouble(%f)\n" (sod x); (sod x) + method testStruct x = p "testStruct(---)\n"; (sod x) + method testNest x = p "testNest(---)\n"; (sod x) + method testMap x = p "testMap(---)\n"; (sod x) + method testSet x = p "testSet(---)\n"; (sod x) + method testList x = p "testList(---)\n"; (sod x) + method testEnum x = p "testEnum(---)\n"; (sod x) + method testTypedef x = p "testTypedef(---)\n"; (sod x) + method testMapMap x = p "testMapMap(%d)\n" (sod x); + let mm = Hashtbl.create 3 in + let pos = Hashtbl.create 7 in + let neg = Hashtbl.create 7 in + for i=1 to 4 do + Hashtbl.add pos i i; + Hashtbl.add neg (-i) (-i); + done; + Hashtbl.add mm 4 pos; + Hashtbl.add mm (-4) neg; + mm + method testInsanity x = p "testInsanity()\n"; + p "testinsanity()\n"; + let hello = new xtruct in + let goodbye = new xtruct in + let crazy = new insanity in + let looney = new insanity in + let cumap = Hashtbl.create 7 in + let insane = Hashtbl.create 7 in + let firstmap = Hashtbl.create 7 in + let secondmap = Hashtbl.create 7 in + hello#set_string_thing "Hello2"; + hello#set_byte_thing 2; + hello#set_i32_thing 2; + hello#set_i64_thing 2L; + goodbye#set_string_thing "Goodbye4"; + goodbye#set_byte_thing 4; + goodbye#set_i32_thing 4; + goodbye#set_i64_thing 4L; + Hashtbl.add cumap Numberz.EIGHT 8L; + Hashtbl.add cumap Numberz.FIVE 5L; + crazy#set_userMap cumap; + crazy#set_xtructs [goodbye; hello]; + Hashtbl.add firstmap Numberz.TWO crazy; + Hashtbl.add firstmap Numberz.THREE crazy; + Hashtbl.add secondmap Numberz.SIX looney; + Hashtbl.add insane 1L firstmap; + Hashtbl.add insane 2L secondmap; + insane + method testMulti a0 a1 a2 a3 a4 a5 = + p "testMulti()\n"; + let hello = new xtruct in + hello#set_string_thing "Hello2"; + hello#set_byte_thing (sod a0); + hello#set_i32_thing (sod a1); + hello#set_i64_thing (sod a2); + hello + method testException s = + p "testException(%S)\n" (sod s); + if (sod s) = "Xception" then + let x = new xception in + x#set_errorCode 1001; + x#set_message "This is an Xception"; + raise (Xception x) + else () + method testMultiException a0 a1 = + p "testMultiException(%S, %S)\n" (sod a0) (sod a1); + if (sod a0) = "Xception" then + let x = new xception in + x#set_errorCode 1001; + x#set_message "This is an Xception"; + raise (Xception x) + else (if (sod a0) = "Xception2" then + let x = new xception2 in + let s = new xtruct in + x#set_errorCode 2002; + s#set_string_thing "This as an Xception2"; + x#set_struct_thing s; + raise (Xception2 x) + else ()); + let res = new xtruct in + res#set_string_thing (sod a1); + res + method testOneway i = + Unix.sleep (sod i) +end;; + +let h = new test_handler in +let proc = new ThriftTest.processor h in +let port = 9090 in +let pf = new TBinaryProtocol.factory in +let server = new TThreadedServer.t + proc + (new TServerSocket.t port) + (new Transport.factory) + pf + pf +in + server#serve + + diff --git a/test/perl/Makefile b/test/perl/Makefile new file mode 100644 index 00000000..e2d81d45 --- /dev/null +++ b/test/perl/Makefile @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# Default target is everything +target: all + +# Tools +THRIFT = ../../compiler/cpp/thrift + +all: ../ThriftTest.thrift + $(THRIFT) --gen perl ../ThriftTest.thrift + +clean: + $(RM) -r gen-perl diff --git a/test/perl/TestClient.pl b/test/perl/TestClient.pl new file mode 100644 index 00000000..af80d469 --- /dev/null +++ b/test/perl/TestClient.pl @@ -0,0 +1,338 @@ +#!/usr/bin/env perl + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 5.6.0; +use strict; +use warnings; +use Data::Dumper; +use Time::HiRes qw(gettimeofday); + +use lib '../../lib/perl/lib'; +use lib 'gen-perl'; + +use Thrift; +use Thrift::BinaryProtocol; +use Thrift::Socket; +use Thrift::BufferedTransport; + +use ThriftTest::ThriftTest; +use ThriftTest::Types; + +$|++; + +my $host = 'localhost'; +my $port = 9090; + + +my $socket = new Thrift::Socket($host, $port); + +my $bufferedSocket = new Thrift::BufferedTransport($socket, 1024, 1024); +my $transport = $bufferedSocket; +my $protocol = new Thrift::BinaryProtocol($transport); +my $testClient = new ThriftTest::ThriftTestClient($protocol); + +eval{ +$transport->open(); +}; if($@){ + die(Dumper($@)); +} +my $start = gettimeofday(); + +# +# VOID TEST +# +print("testVoid()"); +$testClient->testVoid(); +print(" = void\n"); + +# +# STRING TEST +# +print("testString(\"Test\")"); +my $s = $testClient->testString("Test"); +print(" = \"$s\"\n"); + +# +# BYTE TEST +# +print("testByte(1)"); +my $u8 = $testClient->testByte(1); +print(" = $u8\n"); + +# +# I32 TEST +# +print("testI32(-1)"); +my $i32 = $testClient->testI32(-1); +print(" = $i32\n"); + +# +#I64 TEST +# +print("testI64(-34359738368)"); +my $i64 = $testClient->testI64(-34359738368); +print(" = $i64\n"); + +# +# DOUBLE TEST +# +print("testDouble(-852.234234234)"); +my $dub = $testClient->testDouble(-852.234234234); +print(" = $dub\n"); + +# +# STRUCT TEST +# +print("testStruct({\"Zero\", 1, -3, -5})"); +my $out = new ThriftTest::Xtruct(); +$out->string_thing("Zero"); +$out->byte_thing(1); +$out->i32_thing(-3); +$out->i64_thing(-5); +my $in = $testClient->testStruct($out); +print(" = {\"".$in->string_thing."\", ". + $in->byte_thing.", ". + $in->i32_thing.", ". + $in->i64_thing."}\n"); + +# +# NESTED STRUCT TEST +# +print("testNest({1, {\"Zero\", 1, -3, -5}, 5}"); +my $out2 = new ThriftTest::Xtruct2(); +$out2->byte_thing(1); +$out2->struct_thing($out); +$out2->i32_thing(5); +my $in2 = $testClient->testNest($out2); +$in = $in2->struct_thing; +print(" = {".$in2->byte_thing.", {\"". + $in->string_thing."\", ". + $in->byte_thing.", ". + $in->i32_thing.", ". + $in->i64_thing."}, ". + $in2->i32_thing."}\n"); + +# +# MAP TEST +# +my $mapout = {}; +for (my $i = 0; $i < 5; ++$i) { + $mapout->{$i} = $i-10; +} +print("testMap({"); +my $first = 1; +while( my($key,$val) = each %$mapout) { + if ($first) { + $first = 0; + } else { + print(", "); + } + print("$key => $val"); +} +print("})"); + + +my $mapin = $testClient->testMap($mapout); +print(" = {"); + +$first = 1; +while( my($key,$val) = each %$mapin){ + if ($first) { + $first = 0; + } else { + print(", "); + } + print("$key => $val"); +} +print("}\n"); + +# +# SET TEST +# +my $setout = []; +for (my $i = -2; $i < 3; ++$i) { + push(@$setout, $i); +} + +print("testSet({".join(",",@$setout)."})"); + +my $setin = $testClient->testSet($setout); + +print(" = {".join(",",@$setout)."}\n"); + +# +# LIST TEST +# +my $listout = []; +for (my $i = -2; $i < 3; ++$i) { + push(@$listout, $i); +} + +print("testList({".join(",",@$listout)."})"); + +my $listin = $testClient->testList($listout); + +print(" = {".join(",",@$listin)."}\n"); + +# +# ENUM TEST +# +print("testEnum(ONE)"); +my $ret = $testClient->testEnum(ThriftTest::Numberz::ONE); +print(" = $ret\n"); + +print("testEnum(TWO)"); +$ret = $testClient->testEnum(ThriftTest::Numberz::TWO); +print(" = $ret\n"); + +print("testEnum(THREE)"); +$ret = $testClient->testEnum(ThriftTest::Numberz::THREE); +print(" = $ret\n"); + +print("testEnum(FIVE)"); +$ret = $testClient->testEnum(ThriftTest::Numberz::FIVE); +print(" = $ret\n"); + +print("testEnum(EIGHT)"); +$ret = $testClient->testEnum(ThriftTest::Numberz::EIGHT); +print(" = $ret\n"); + +# +# TYPEDEF TEST +# +print("testTypedef(309858235082523)"); +my $uid = $testClient->testTypedef(309858235082523); +print(" = $uid\n"); + +# +# NESTED MAP TEST +# +print("testMapMap(1)"); +my $mm = $testClient->testMapMap(1); +print(" = {"); +while( my ($key,$val) = each %$mm) { + print("$key => {"); + while( my($k2,$v2) = each %$val) { + print("$k2 => $v2, "); + } + print("}, "); +} +print("}\n"); + +# +# INSANITY TEST +# +my $insane = new ThriftTest::Insanity(); +$insane->{userMap}->{ThriftTest::Numberz::FIVE} = 5000; +my $truck = new ThriftTest::Xtruct(); +$truck->string_thing("Truck"); +$truck->byte_thing(8); +$truck->i32_thing(8); +$truck->i64_thing(8); +push(@{$insane->{xtructs}}, $truck); + +print("testInsanity()"); +my $whoa = $testClient->testInsanity($insane); +print(" = {"); +while( my ($key,$val) = each %$whoa) { + print("$key => {"); + while( my($k2,$v2) = each %$val) { + print("$k2 => {"); + my $userMap = $v2->{userMap}; + print("{"); + if (ref($userMap) eq "HASH") { + while( my($k3,$v3) = each %$userMap) { + print("$k3 => $v3, "); + } + } + print("}, "); + + my $xtructs = $v2->{xtructs}; + print("{"); + if (ref($xtructs) eq "ARRAY") { + foreach my $x (@$xtructs) { + print("{\"".$x->{string_thing}."\", ". + $x->{byte_thing}.", ".$x->{i32_thing}.", ".$x->{i64_thing}."}, "); + } + } + print("}"); + + print("}, "); + } + print("}, "); +} +print("}\n"); + +# +# EXCEPTION TEST +# +print("testException('Xception')"); +eval { + $testClient->testException('Xception'); + print(" void\nFAILURE\n"); +}; if($@ && $@->UNIVERSAL::isa('ThriftTest::Xception')) { + print(' caught xception '.$@->{errorCode}.': '.$@->{message}."\n"); +} + + +# +# Normal tests done. +# +my $stop = gettimeofday(); +my $elp = sprintf("%d",1000*($stop - $start), 0); +print("Total time: $elp ms\n"); + +# +# Extraneous "I don't trust PHP to pack/unpack integer" tests +# + +# Max I32 +my $num = 2**30 + 2**30 - 1; +my $num2 = $testClient->testI32($num); +if ($num != $num2) { + print "Missed max32 $num = $num2\n"; +} + +# Min I32 +$num = 0 - 2**31; +$num2 = $testClient->testI32($num); +if ($num != $num2) { + print "Missed min32 $num = $num2\n"; +} + +# Max Number I can get out of my perl +$num = 2**40; +$num2 = $testClient->testI64($num); +if ($num != $num2) { + print "Missed max64 $num = $num2\n"; +} + +# Max Number I can get out of my perl +$num = 0 - 2**40; +$num2 = $testClient->testI64($num); +if ($num != $num2) { + print "Missed min64 $num = $num2\n"; +} + +$transport->close(); + + + diff --git a/test/php/Makefile b/test/php/Makefile new file mode 100644 index 00000000..aa35c6e9 --- /dev/null +++ b/test/php/Makefile @@ -0,0 +1,39 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# Default target is everything +target: all + +# Tools +THRIFT = ../../compiler/cpp/thrift + +all: normal inline + +normal: stubs + +inline: stubs-inline + +stubs: ../ThriftTest.thrift + $(THRIFT) --gen php ../ThriftTest.thrift + +stubs-inline: ../ThriftTest.thrift + $(THRIFT) --gen php:inlined ../ThriftTest.thrift + +clean: + $(RM) -r gen-php gen-phpi diff --git a/test/php/TestClient.php b/test/php/TestClient.php new file mode 100644 index 00000000..6d640dac --- /dev/null +++ b/test/php/TestClient.php @@ -0,0 +1,398 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + 1) { + $host = $argv[0]; +} + +if ($argc > 2) { + $host = $argv[1]; +} + +$hosts = array('localhost'); + +$socket = new TSocket($host, $port); +$socket = new TSocketPool($hosts, $port); +$socket->setDebug(TRUE); + +if ($MODE == 'inline') { + $transport = $socket; + $testClient = new ThriftTestClient($transport); +} else { + $bufferedSocket = new TBufferedTransport($socket, 1024, 1024); + $transport = $bufferedSocket; + $protocol = new TBinaryProtocol($transport); + $testClient = new ThriftTestClient($protocol); +} + +$transport->open(); + +$start = microtime(true); + +/** + * VOID TEST + */ +print_r("testVoid()"); +$testClient->testVoid(); +print_r(" = void\n"); + +/** + * STRING TEST + */ +print_r("testString(\"Test\")"); +$s = $testClient->testString("Test"); +print_r(" = \"$s\"\n"); + +/** + * BYTE TEST + */ +print_r("testByte(1)"); +$u8 = $testClient->testByte(1); +print_r(" = $u8\n"); + +/** + * I32 TEST + */ +print_r("testI32(-1)"); +$i32 = $testClient->testI32(-1); +print_r(" = $i32\n"); + +/** + * I64 TEST + */ +print_r("testI64(-34359738368)"); +$i64 = $testClient->testI64(-34359738368); +print_r(" = $i64\n"); + +/** + * DOUBLE TEST + */ +print_r("testDouble(-852.234234234)"); +$dub = $testClient->testDouble(-852.234234234); +print_r(" = $dub\n"); + +/** + * STRUCT TEST + */ +print_r("testStruct({\"Zero\", 1, -3, -5})"); +$out = new Xtruct(); +$out->string_thing = "Zero"; +$out->byte_thing = 1; +$out->i32_thing = -3; +$out->i64_thing = -5; +$in = $testClient->testStruct($out); +print_r(" = {\"".$in->string_thing."\", ". + $in->byte_thing.", ". + $in->i32_thing.", ". + $in->i64_thing."}\n"); + +/** + * NESTED STRUCT TEST + */ +print_r("testNest({1, {\"Zero\", 1, -3, -5}), 5}"); +$out2 = new Xtruct2(); +$out2->byte_thing = 1; +$out2->struct_thing = $out; +$out2->i32_thing = 5; +$in2 = $testClient->testNest($out2); +$in = $in2->struct_thing; +print_r(" = {".$in2->byte_thing.", {\"". + $in->string_thing."\", ". + $in->byte_thing.", ". + $in->i32_thing.", ". + $in->i64_thing."}, ". + $in2->i32_thing."}\n"); + +/** + * MAP TEST + */ +$mapout = array(); +for ($i = 0; $i < 5; ++$i) { + $mapout[$i] = $i-10; +} +print_r("testMap({"); +$first = true; +foreach ($mapout as $key => $val) { + if ($first) { + $first = false; + } else { + print_r(", "); + } + print_r("$key => $val"); +} +print_r("})"); + +$mapin = $testClient->testMap($mapout); +print_r(" = {"); +$first = true; +foreach ($mapin as $key => $val) { + if ($first) { + $first = false; + } else { + print_r(", "); + } + print_r("$key => $val"); +} +print_r("}\n"); + +/** + * SET TEST + */ +$setout = array();; +for ($i = -2; $i < 3; ++$i) { + $setout []= $i; +} +print_r("testSet({"); +$first = true; +foreach ($setout as $val) { + if ($first) { + $first = false; + } else { + print_r(", "); + } + print_r($val); +} +print_r("})"); +$setin = $testClient->testSet($setout); +print_r(" = {"); +$first = true; +foreach ($setin as $val) { + if ($first) { + $first = false; + } else { + print_r(", "); + } + print_r($val); +} +print_r("}\n"); + +/** + * LIST TEST + */ +$listout = array(); +for ($i = -2; $i < 3; ++$i) { + $listout []= $i; +} +print_r("testList({"); +$first = true; +foreach ($listout as $val) { + if ($first) { + $first = false; + } else { + print_r(", "); + } + print_r($val); +} +print_r("})"); +$listin = $testClient->testList($listout); +print_r(" = {"); +$first = true; +foreach ($listin as $val) { + if ($first) { + $first = false; + } else { + print_r(", "); + } + print_r($val); +} +print_r("}\n"); + +/** + * ENUM TEST + */ +print_r("testEnum(ONE)"); +$ret = $testClient->testEnum(Numberz::ONE); +print_r(" = $ret\n"); + +print_r("testEnum(TWO)"); +$ret = $testClient->testEnum(Numberz::TWO); +print_r(" = $ret\n"); + +print_r("testEnum(THREE)"); +$ret = $testClient->testEnum(Numberz::THREE); +print_r(" = $ret\n"); + +print_r("testEnum(FIVE)"); +$ret = $testClient->testEnum(Numberz::FIVE); +print_r(" = $ret\n"); + +print_r("testEnum(EIGHT)"); +$ret = $testClient->testEnum(Numberz::EIGHT); +print_r(" = $ret\n"); + +/** + * TYPEDEF TEST + */ +print_r("testTypedef(309858235082523)"); +$uid = $testClient->testTypedef(309858235082523); +print_r(" = $uid\n"); + +/** + * NESTED MAP TEST + */ +print_r("testMapMap(1)"); +$mm = $testClient->testMapMap(1); +print_r(" = {"); +foreach ($mm as $key => $val) { + print_r("$key => {"); + foreach ($val as $k2 => $v2) { + print_r("$k2 => $v2, "); + } + print_r("}, "); +} +print_r("}\n"); + +/** + * INSANITY TEST + */ +$insane = new Insanity(); +$insane->userMap[Numberz::FIVE] = 5000; +$truck = new Xtruct(); +$truck->string_thing = "Truck"; +$truck->byte_thing = 8; +$truck->i32_thing = 8; +$truck->i64_thing = 8; +$insane->xtructs []= $truck; +print_r("testInsanity()"); +$whoa = $testClient->testInsanity($insane); +print_r(" = {"); +foreach ($whoa as $key => $val) { + print_r("$key => {"); + foreach ($val as $k2 => $v2) { + print_r("$k2 => {"); + $userMap = $v2->userMap; + print_r("{"); + if (is_array($usermap)) { + foreach ($userMap as $k3 => $v3) { + print_r("$k3 => $v3, "); + } + } + print_r("}, "); + + $xtructs = $v2->xtructs; + print_r("{"); + if (is_array($xtructs)) { + foreach ($xtructs as $x) { + print_r("{\"".$x->string_thing."\", ". + $x->byte_thing.", ".$x->i32_thing.", ".$x->i64_thing."}, "); + } + } + print_r("}"); + + print_r("}, "); + } + print_r("}, "); +} +print_r("}\n"); + +/** + * EXCEPTION TEST + */ +print_r("testException('Xception')"); +try { + $testClient->testException('Xception'); + print_r(" void\nFAILURE\n"); +} catch (Xception $x) { + print_r(' caught xception '.$x->errorCode.': '.$x->message."\n"); +} + + +/** + * Normal tests done. + */ + +$stop = microtime(true); +$elp = round(1000*($stop - $start), 0); +print_r("Total time: $elp ms\n"); + +/** + * Extraneous "I don't trust PHP to pack/unpack integer" tests + */ + +// Max I32 +$num = pow(2, 30) + (pow(2, 30) - 1); +$num2 = $testClient->testI32($num); +if ($num != $num2) { + print "Missed $num = $num2\n"; +} + +// Min I32 +$num = 0 - pow(2, 31); +$num2 = $testClient->testI32($num); +if ($num != $num2) { + print "Missed $num = $num2\n"; +} + +// Max I64 +$num = pow(2, 62) + (pow(2, 62) - 1); +$num2 = $testClient->testI64($num); +if ($num != $num2) { + print "Missed $num = $num2\n"; +} + +// Min I64 +$num = 0 - pow(2, 63); +$num2 = $testClient->testI64($num); +if ($num != $num2) { + print "Missed $num = $num2\n"; +} + +$transport->close(); +return; + +?> diff --git a/test/php/TestInline.php b/test/php/TestInline.php new file mode 100644 index 00000000..7066c461 --- /dev/null +++ b/test/php/TestInline.php @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + diff --git a/test/py/Makefile.am b/test/py/Makefile.am new file mode 100644 index 00000000..63b7a890 --- /dev/null +++ b/test/py/Makefile.am @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +THRIFT = $(top_srcdir)/compiler/cpp/thrift + +py_unit_tests = \ + SerializationTest.py \ + TestEof.py \ + TestSyntax.py \ + RunClientServer.py + +thrift_gen = \ + gen-py/ThriftTest/__init__.py \ + gen-py/DebugProtoTest/__init__.py + +helper_scripts= \ + TestClient.py \ + TestServer.py + +check_SCRIPTS= \ + $(thrift_gen) \ + $(py_unit_tests) \ + $(helper_scripts) + +TESTS= $(py_unit_tests) + + +gen-py/%/__init__.py: ../%.thrift + $(THRIFT) --gen py $< + +clean-local: + $(RM) -r gen-py diff --git a/test/py/RunClientServer.py b/test/py/RunClientServer.py new file mode 100755 index 00000000..2bd6094d --- /dev/null +++ b/test/py/RunClientServer.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import time +import subprocess +import sys +import os +import signal + +def relfile(fname): + return os.path.join(os.path.dirname(__file__), fname) + +FRAMED = ["TNonblockingServer"] + +def runTest(server_class): + print "Testing ", server_class + serverproc = subprocess.Popen([sys.executable, relfile("TestServer.py"), server_class]) + time.sleep(0.25) + try: + argv = [sys.executable, relfile("TestClient.py")] + if server_class in FRAMED: + argv.append('--framed') + if server_class == 'THttpServer': + argv.append('--http=/') + ret = subprocess.call(argv) + if ret != 0: + raise Exception("subprocess failed") + finally: + # fixme: should check that server didn't die + os.kill(serverproc.pid, signal.SIGKILL) + + # wait for shutdown + time.sleep(1) + +map(runTest, [ + "TSimpleServer", + "TThreadedServer", + "TThreadPoolServer", + "TForkingServer", + "TNonblockingServer", + "THttpServer", + ]) diff --git a/test/py/SerializationTest.py b/test/py/SerializationTest.py new file mode 100755 index 00000000..52bedd5e --- /dev/null +++ b/test/py/SerializationTest.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys, glob +sys.path.insert(0, './gen-py') +sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0]) + +from ThriftTest.ttypes import * +from thrift.transport import TTransport +from thrift.transport import TSocket +from thrift.protocol import TBinaryProtocol +import unittest +import time + +class AbstractTest(unittest.TestCase): + + def setUp(self): + self.v1obj = VersioningTestV1( + begin_in_both=12345, + old_string='aaa', + end_in_both=54321, + ) + + self.v2obj = VersioningTestV2( + begin_in_both=12345, + newint=1, + newbyte=2, + newshort=3, + newlong=4, + newdouble=5.0, + newstruct=Bonk(message="Hello!", type=123), + newlist=[7,8,9], + newset=[42,1,8], + newmap={1:2,2:3}, + newstring="Hola!", + end_in_both=54321, + ) + + def _serialize(self, obj): + trans = TTransport.TMemoryBuffer() + prot = self.protocol_factory.getProtocol(trans) + obj.write(prot) + return trans.getvalue() + + def _deserialize(self, objtype, data): + prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data)) + ret = objtype() + ret.read(prot) + return ret + + def testForwards(self): + obj = self._deserialize(VersioningTestV2, self._serialize(self.v1obj)) + self.assertEquals(obj.begin_in_both, self.v1obj.begin_in_both) + self.assertEquals(obj.end_in_both, self.v1obj.end_in_both) + + def testBackwards(self): + obj = self._deserialize(VersioningTestV1, self._serialize(self.v2obj)) + self.assertEquals(obj.begin_in_both, self.v2obj.begin_in_both) + self.assertEquals(obj.end_in_both, self.v2obj.end_in_both) + + +class NormalBinaryTest(AbstractTest): + protocol_factory = TBinaryProtocol.TBinaryProtocolFactory() + +class AcceleratedBinaryTest(AbstractTest): + protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory() + + +class AcceleratedFramedTest(unittest.TestCase): + def testSplit(self): + """Test FramedTransport and BinaryProtocolAccelerated + + Tests that TBinaryProtocolAccelerated and TFramedTransport + play nicely together when a read spans a frame""" + + protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory() + bigstring = "".join(chr(byte) for byte in range(ord("a"), ord("z")+1)) + + databuf = TTransport.TMemoryBuffer() + prot = protocol_factory.getProtocol(databuf) + prot.writeI32(42) + prot.writeString(bigstring) + prot.writeI16(24) + data = databuf.getvalue() + cutpoint = len(data)/2 + parts = [ data[:cutpoint], data[cutpoint:] ] + + framed_buffer = TTransport.TMemoryBuffer() + framed_writer = TTransport.TFramedTransport(framed_buffer) + for part in parts: + framed_writer.write(part) + framed_writer.flush() + self.assertEquals(len(framed_buffer.getvalue()), len(data) + 8) + + # Recreate framed_buffer so we can read from it. + framed_buffer = TTransport.TMemoryBuffer(framed_buffer.getvalue()) + framed_reader = TTransport.TFramedTransport(framed_buffer) + prot = protocol_factory.getProtocol(framed_reader) + self.assertEqual(prot.readI32(), 42) + self.assertEqual(prot.readString(), bigstring) + self.assertEqual(prot.readI16(), 24) + + + +def suite(): + suite = unittest.TestSuite() + loader = unittest.TestLoader() + + suite.addTest(loader.loadTestsFromTestCase(NormalBinaryTest)) + suite.addTest(loader.loadTestsFromTestCase(AcceleratedBinaryTest)) + suite.addTest(loader.loadTestsFromTestCase(AcceleratedFramedTest)) + return suite + +if __name__ == "__main__": + unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2)) diff --git a/test/py/TestClient.py b/test/py/TestClient.py new file mode 100755 index 00000000..64e5e872 --- /dev/null +++ b/test/py/TestClient.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys, glob +sys.path.insert(0, './gen-py') +sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0]) + +from ThriftTest import ThriftTest +from ThriftTest.ttypes import * +from thrift.transport import TTransport +from thrift.transport import TSocket +from thrift.transport import THttpClient +from thrift.protocol import TBinaryProtocol +import unittest +import time +from optparse import OptionParser + + +parser = OptionParser() +parser.set_defaults(framed=False, http_path=None, verbose=1, host='localhost', port=9090) +parser.add_option("--port", type="int", dest="port", + help="connect to server at port") +parser.add_option("--host", type="string", dest="host", + help="connect to server") +parser.add_option("--framed", action="store_true", dest="framed", + help="use framed transport") +parser.add_option("--http", dest="http_path", + help="Use the HTTP transport with the specified path") +parser.add_option('-v', '--verbose', action="store_const", + dest="verbose", const=2, + help="verbose output") +parser.add_option('-q', '--quiet', action="store_const", + dest="verbose", const=0, + help="minimal output") + +options, args = parser.parse_args() + +class AbstractTest(unittest.TestCase): + def setUp(self): + if options.http_path: + self.transport = THttpClient.THttpClient( + options.host, options.port, options.http_path) + else: + socket = TSocket.TSocket(options.host, options.port) + + # frame or buffer depending upon args + if options.framed: + self.transport = TTransport.TFramedTransport(socket) + else: + self.transport = TTransport.TBufferedTransport(socket) + + self.transport.open() + + protocol = self.protocol_factory.getProtocol(self.transport) + self.client = ThriftTest.Client(protocol) + + def tearDown(self): + # Close! + self.transport.close() + + def testVoid(self): + self.client.testVoid() + + def testString(self): + self.assertEqual(self.client.testString('Python'), 'Python') + + def testByte(self): + self.assertEqual(self.client.testByte(63), 63) + + def testI32(self): + self.assertEqual(self.client.testI32(-1), -1) + self.assertEqual(self.client.testI32(0), 0) + + def testI64(self): + self.assertEqual(self.client.testI64(-34359738368), -34359738368) + + def testDouble(self): + self.assertEqual(self.client.testDouble(-5.235098235), -5.235098235) + + def testStruct(self): + x = Xtruct() + x.string_thing = "Zero" + x.byte_thing = 1 + x.i32_thing = -3 + x.i64_thing = -5 + y = self.client.testStruct(x) + + self.assertEqual(y.string_thing, "Zero") + self.assertEqual(y.byte_thing, 1) + self.assertEqual(y.i32_thing, -3) + self.assertEqual(y.i64_thing, -5) + + def testException(self): + self.client.testException('Safe') + try: + self.client.testException('Xception') + self.fail("should have gotten exception") + except Xception, x: + self.assertEqual(x.errorCode, 1001) + self.assertEqual(x.message, 'Xception') + + try: + self.client.testException("throw_undeclared") + self.fail("should have thrown exception") + except Exception: # type is undefined + pass + + def testOneway(self): + start = time.time() + self.client.testOneway(0.5) + end = time.time() + self.assertTrue(end - start < 0.2, + "oneway sleep took %f sec" % (end - start)) + +class NormalBinaryTest(AbstractTest): + protocol_factory = TBinaryProtocol.TBinaryProtocolFactory() + +class AcceleratedBinaryTest(AbstractTest): + protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory() + +def suite(): + suite = unittest.TestSuite() + loader = unittest.TestLoader() + + suite.addTest(loader.loadTestsFromTestCase(NormalBinaryTest)) + suite.addTest(loader.loadTestsFromTestCase(AcceleratedBinaryTest)) + return suite + +class OwnArgsTestProgram(unittest.TestProgram): + def parseArgs(self, argv): + if args: + self.testNames = args + else: + self.testNames = (self.defaultTest,) + self.createTests() + +if __name__ == "__main__": + OwnArgsTestProgram(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2)) diff --git a/test/py/TestEof.py b/test/py/TestEof.py new file mode 100755 index 00000000..7d64289d --- /dev/null +++ b/test/py/TestEof.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys, glob +sys.path.insert(0, './gen-py') +sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0]) + +from ThriftTest import ThriftTest +from ThriftTest.ttypes import * +from thrift.transport import TTransport +from thrift.transport import TSocket +from thrift.protocol import TBinaryProtocol +import unittest +import time + +class TestEof(unittest.TestCase): + + def setUp(self): + trans = TTransport.TMemoryBuffer() + prot = TBinaryProtocol.TBinaryProtocol(trans) + + x = Xtruct() + x.string_thing = "Zero" + x.byte_thing = 0 + + x.write(prot) + + x = Xtruct() + x.string_thing = "One" + x.byte_thing = 1 + + x.write(prot) + + self.data = trans.getvalue() + + def testTransportReadAll(self): + """Test that readAll on any type of transport throws an EOFError""" + trans = TTransport.TMemoryBuffer(self.data) + trans.readAll(1) + + try: + trans.readAll(10000) + except EOFError: + return + + self.fail("Should have gotten EOFError") + + def eofTestHelper(self, pfactory): + trans = TTransport.TMemoryBuffer(self.data) + prot = pfactory.getProtocol(trans) + + x = Xtruct() + x.read(prot) + self.assertEqual(x.string_thing, "Zero") + self.assertEqual(x.byte_thing, 0) + + x = Xtruct() + x.read(prot) + self.assertEqual(x.string_thing, "One") + self.assertEqual(x.byte_thing, 1) + + try: + x = Xtruct() + x.read(prot) + except EOFError: + return + + self.fail("Should have gotten EOFError") + + def eofTestHelperStress(self, pfactory): + """Teest the ability of TBinaryProtocol to deal with the removal of every byte in the file""" + # TODO: we should make sure this covers more of the code paths + + for i in xrange(0, len(self.data) + 1): + trans = TTransport.TMemoryBuffer(self.data[0:i]) + prot = pfactory.getProtocol(trans) + try: + x = Xtruct() + x.read(prot) + x.read(prot) + x.read(prot) + except EOFError: + continue + self.fail("Should have gotten an EOFError") + + def testBinaryProtocolEof(self): + """Test that TBinaryProtocol throws an EOFError when it reaches the end of the stream""" + self.eofTestHelper(TBinaryProtocol.TBinaryProtocolFactory()) + self.eofTestHelperStress(TBinaryProtocol.TBinaryProtocolFactory()) + + def testBinaryProtocolAcceleratedEof(self): + """Test that TBinaryProtocolAccelerated throws an EOFError when it reaches the end of the stream""" + self.eofTestHelper(TBinaryProtocol.TBinaryProtocolAcceleratedFactory()) + self.eofTestHelperStress(TBinaryProtocol.TBinaryProtocolAcceleratedFactory()) + +def suite(): + suite = unittest.TestSuite() + loader = unittest.TestLoader() + suite.addTest(loader.loadTestsFromTestCase(TestEof)) + return suite + +if __name__ == "__main__": + unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2)) diff --git a/test/py/TestServer.py b/test/py/TestServer.py new file mode 100755 index 00000000..3d379eae --- /dev/null +++ b/test/py/TestServer.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys, glob, time +sys.path.insert(0, './gen-py') +sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0]) + +from ThriftTest import ThriftTest +from ThriftTest.ttypes import * +from thrift.transport import TTransport +from thrift.transport import TSocket +from thrift.protocol import TBinaryProtocol +from thrift.server import TServer, TNonblockingServer, THttpServer + +class TestHandler: + + def testVoid(self): + print 'testVoid()' + + def testString(self, str): + print 'testString(%s)' % str + return str + + def testByte(self, byte): + print 'testByte(%d)' % byte + return byte + + def testI16(self, i16): + print 'testI16(%d)' % i16 + return i16 + + def testI32(self, i32): + print 'testI32(%d)' % i32 + return i32 + + def testI64(self, i64): + print 'testI64(%d)' % i64 + return i64 + + def testDouble(self, dub): + print 'testDouble(%f)' % dub + return dub + + def testStruct(self, thing): + print 'testStruct({%s, %d, %d, %d})' % (thing.string_thing, thing.byte_thing, thing.i32_thing, thing.i64_thing) + return thing + + def testException(self, str): + print 'testException(%s)' % str + if str == 'Xception': + x = Xception() + x.errorCode = 1001 + x.message = str + raise x + elif str == "throw_undeclared": + raise ValueError("foo") + + def testOneway(self, seconds): + print 'testOneway(%d) => sleeping...' % seconds + time.sleep(seconds) + print 'done sleeping' + + def testNest(self, thing): + return thing + + def testMap(self, thing): + return thing + + def testSet(self, thing): + return thing + + def testList(self, thing): + return thing + + def testEnum(self, thing): + return thing + + def testTypedef(self, thing): + return thing + +pfactory = TBinaryProtocol.TBinaryProtocolFactory() +handler = TestHandler() +processor = ThriftTest.Processor(handler) + +if sys.argv[1] == "THttpServer": + server = THttpServer.THttpServer(processor, ('', 9090), pfactory) +else: + transport = TSocket.TServerSocket(9090) + tfactory = TTransport.TBufferedTransportFactory() + + if sys.argv[1] == "TNonblockingServer": + server = TNonblockingServer.TNonblockingServer(processor, transport) + else: + ServerClass = getattr(TServer, sys.argv[1]) + server = ServerClass(processor, transport, tfactory, pfactory) + +server.serve() diff --git a/test/py/TestSocket.py b/test/py/TestSocket.py new file mode 100755 index 00000000..2f7353fb --- /dev/null +++ b/test/py/TestSocket.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys, glob +sys.path.insert(0, './gen-py') +sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0]) + +from ThriftTest import ThriftTest +from ThriftTest.ttypes import * +from thrift.transport import TTransport +from thrift.transport import TSocket +from thrift.protocol import TBinaryProtocol +import unittest +import time +import socket +import random +from optparse import OptionParser + +class TimeoutTest(unittest.TestCase): + def setUp(self): + for i in xrange(50): + try: + # find a port we can use + self.listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.port = random.randint(10000, 30000) + self.listen_sock.bind(('localhost', self.port)) + self.listen_sock.listen(5) + break + except: + if i == 49: + raise + + def testConnectTimeout(self): + starttime = time.time() + + try: + leaky = [] + for i in xrange(100): + socket = TSocket.TSocket('localhost', self.port) + socket.setTimeout(10) + socket.open() + leaky.append(socket) + except: + self.assert_(time.time() - starttime < 5.0) + + def testWriteTimeout(self): + starttime = time.time() + + try: + socket = TSocket.TSocket('localhost', self.port) + socket.setTimeout(10) + socket.open() + lsock = self.listen_sock.accept() + while True: + socket.write("hi" * 100) + + except: + self.assert_(time.time() - starttime < 5.0) + +suite = unittest.TestSuite() +loader = unittest.TestLoader() + +suite.addTest(loader.loadTestsFromTestCase(TimeoutTest)) + +testRunner = unittest.TextTestRunner(verbosity=2) +testRunner.run(suite) diff --git a/test/py/TestSyntax.py b/test/py/TestSyntax.py new file mode 100755 index 00000000..df67d485 --- /dev/null +++ b/test/py/TestSyntax.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys, glob +sys.path.insert(0, './gen-py') +sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0]) + +# Just import these generated files to make sure they are syntactically valid +from DebugProtoTest import EmptyService +from DebugProtoTest import Inherited diff --git a/test/py/explicit_module/runtest.sh b/test/py/explicit_module/runtest.sh new file mode 100755 index 00000000..2e5a4f1b --- /dev/null +++ b/test/py/explicit_module/runtest.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +rm -rf gen-py +../../../compiler/cpp/thrift --gen py test1.thrift || exit 1 +../../../compiler/cpp/thrift --gen py test2.thrift || exit 1 +PYTHONPATH=./gen-py python -c 'import foo.bar.baz' || exit 1 +PYTHONPATH=./gen-py python -c 'import test2' || exit 1 +PYTHONPATH=./gen-py python -c 'import test1' &>/dev/null && exit 1 # Should fail. +cp -r gen-py simple +../../../compiler/cpp/thrift -r --gen py test2.thrift || exit 1 +PYTHONPATH=./gen-py python -c 'import test2' || exit 1 +diff -ur simple gen-py > thediffs +file thediffs | grep -s -q empty || exit 1 +rm -rf simple thediffs +echo 'All tests pass!' diff --git a/test/py/explicit_module/test1.thrift b/test/py/explicit_module/test1.thrift new file mode 100644 index 00000000..ec600d7d --- /dev/null +++ b/test/py/explicit_module/test1.thrift @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +namespace py foo.bar.baz + +struct astruct { + 1: i32 how_unoriginal; +} diff --git a/test/py/explicit_module/test2.thrift b/test/py/explicit_module/test2.thrift new file mode 100644 index 00000000..68f9da4d --- /dev/null +++ b/test/py/explicit_module/test2.thrift @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +include "test1.thrift" + +struct another { + 1: test1.astruct something; +} diff --git a/test/rb/Makefile.am b/test/rb/Makefile.am new file mode 100644 index 00000000..a6f431c8 --- /dev/null +++ b/test/rb/Makefile.am @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +THRIFT = $(top_srcdir)/compiler/cpp/thrift + +stubs: ../ThriftTest.thrift ../SmallTest.thrift + $(THRIFT) --gen rb ../ThriftTest.thrift + $(THRIFT) --gen rb ../SmallTest.thrift + +check: stubs + $(RUBY) test_suite.rb + diff --git a/test/rb/benchmarks/protocol_benchmark.rb b/test/rb/benchmarks/protocol_benchmark.rb new file mode 100644 index 00000000..05a8ee53 --- /dev/null +++ b/test/rb/benchmarks/protocol_benchmark.rb @@ -0,0 +1,174 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +$LOAD_PATH.unshift File.join(File.dirname(__FILE__), *%w[.. .. .. lib rb lib]) +$LOAD_PATH.unshift File.join(File.dirname(__FILE__), *%w[.. .. .. lib rb ext]) + +require 'thrift' + +require 'benchmark' +require 'rubygems' +require 'set' +require 'pp' + +# require 'ruby-debug' +# require 'ruby-prof' + +require File.join(File.dirname(__FILE__), '../fixtures/structs') + +transport1 = Thrift::MemoryBuffer.new +ruby_binary_protocol = Thrift::BinaryProtocol.new(transport1) + +transport2 = Thrift::MemoryBuffer.new +c_fast_binary_protocol = Thrift::BinaryProtocolAccelerated.new(transport2) + + +ooe = Fixtures::Structs::OneOfEach.new +ooe.im_true = true +ooe.im_false = false +ooe.a_bite = -42 +ooe.integer16 = 27000 +ooe.integer32 = 1<<24 +ooe.integer64 = 6000 * 1000 * 1000 +ooe.double_precision = Math::PI +ooe.some_characters = "Debug THIS!" +ooe.zomg_unicode = "\xd7\n\a\t" + +n1 = Fixtures::Structs::Nested1.new +n1.a_list = [] +n1.a_list << ooe << ooe << ooe << ooe +n1.i32_map = {} +n1.i32_map[1234] = ooe +n1.i32_map[46345] = ooe +n1.i32_map[-34264] = ooe +n1.i64_map = {} +n1.i64_map[43534986783945] = ooe +n1.i64_map[-32434639875122] = ooe +n1.dbl_map = {} +n1.dbl_map[324.65469834] = ooe +n1.dbl_map[-9458672340.4986798345112] = ooe +n1.str_map = {} +n1.str_map['sdoperuix'] = ooe +n1.str_map['pwoerxclmn'] = ooe + +n2 = Fixtures::Structs::Nested2.new +n2.a_list = [] +n2.a_list << n1 << n1 << n1 << n1 << n1 +n2.i32_map = {} +n2.i32_map[398345] = n1 +n2.i32_map[-2345] = n1 +n2.i32_map[12312] = n1 +n2.i64_map = {} +n2.i64_map[2349843765934] = n1 +n2.i64_map[-123234985495] = n1 +n2.i64_map[0] = n1 +n2.dbl_map = {} +n2.dbl_map[23345345.38927834] = n1 +n2.dbl_map[-1232349.5489345] = n1 +n2.dbl_map[-234984574.23498725] = n1 +n2.str_map = {} +n2.str_map[''] = n1 +n2.str_map['sdflkertpioux'] = n1 +n2.str_map['sdfwepwdcjpoi'] = n1 + +n3 = Fixtures::Structs::Nested3.new +n3.a_list = [] +n3.a_list << n2 << n2 << n2 << n2 << n2 +n3.i32_map = {} +n3.i32_map[398345] = n2 +n3.i32_map[-2345] = n2 +n3.i32_map[12312] = n2 +n3.i64_map = {} +n3.i64_map[2349843765934] = n2 +n3.i64_map[-123234985495] = n2 +n3.i64_map[0] = n2 +n3.dbl_map = {} +n3.dbl_map[23345345.38927834] = n2 +n3.dbl_map[-1232349.5489345] = n2 +n3.dbl_map[-234984574.23498725] = n2 +n3.str_map = {} +n3.str_map[''] = n2 +n3.str_map['sdflkertpioux'] = n2 +n3.str_map['sdfwepwdcjpoi'] = n2 + +n4 = Fixtures::Structs::Nested4.new +n4.a_list = [] +n4.a_list << n3 +n4.i32_map = {} +n4.i32_map[-2345] = n3 +n4.i64_map = {} +n4.i64_map[2349843765934] = n3 +n4.dbl_map = {} +n4.dbl_map[-1232349.5489345] = n3 +n4.str_map = {} +n4.str_map[''] = n3 + + +# prof = RubyProf.profile do +# n4.write(c_fast_binary_protocol) +# Fixtures::Structs::Nested4.new.read(c_fast_binary_protocol) +# end +# +# printer = RubyProf::GraphHtmlPrinter.new(prof) +# printer.print(STDOUT, :min_percent=>0) + +Benchmark.bmbm do |x| + x.report("ruby write large (1MB) structure once") do + n4.write(ruby_binary_protocol) + end + + x.report("ruby read large (1MB) structure once") do + Fixtures::Structs::Nested4.new.read(ruby_binary_protocol) + end + + x.report("c write large (1MB) structure once") do + n4.write(c_fast_binary_protocol) + end + + x.report("c read large (1MB) structure once") do + Fixtures::Structs::Nested4.new.read(c_fast_binary_protocol) + end + + + + x.report("ruby write 10_000 small structures") do + 10_000.times do + ooe.write(ruby_binary_protocol) + end + end + + x.report("ruby read 10_000 small structures") do + 10_000.times do + Fixtures::Structs::OneOfEach.new.read(ruby_binary_protocol) + end + end + + x.report("c write 10_000 small structures") do + 10_000.times do + ooe.write(c_fast_binary_protocol) + end + end + + x.report("c read 10_000 small structures") do + 10_000.times do + Fixtures::Structs::OneOfEach.new.read(c_fast_binary_protocol) + end + end + +end diff --git a/test/rb/core/test_backwards_compatability.rb b/test/rb/core/test_backwards_compatability.rb new file mode 100644 index 00000000..0577515d --- /dev/null +++ b/test/rb/core/test_backwards_compatability.rb @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.join(File.dirname(__FILE__), '../test_helper') + +require 'thrift' + +class TestThriftException < Test::Unit::TestCase + def test_has_accessible_message + msg = "hi there thrift" + assert_equal msg, Thrift::Exception.new(msg).message + end +end + diff --git a/test/rb/core/test_exceptions.rb b/test/rb/core/test_exceptions.rb new file mode 100644 index 00000000..f41587a7 --- /dev/null +++ b/test/rb/core/test_exceptions.rb @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.join(File.dirname(__FILE__), '../test_helper') + +require 'thrift' + +class TestException < Test::Unit::TestCase + def test_has_accessible_message + msg = "hi there thrift" + assert_equal msg, Thrift::Exception.new(msg).message + end +end + diff --git a/test/rb/core/transport/test_transport.rb b/test/rb/core/transport/test_transport.rb new file mode 100644 index 00000000..52755c1d --- /dev/null +++ b/test/rb/core/transport/test_transport.rb @@ -0,0 +1,70 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.join(File.dirname(__FILE__), '../../test_helper') + +require 'thrift' + +class DummyTransport < Thrift::BaseTransport + def initialize(data) + @data = data + end + + def read(size) + @data.slice!(0, size) + end +end + +# TTransport is basically an abstract class, but isn't raising NotImplementedError +class TestThriftTransport < Test::Unit::TestCase + def setup + @trans = Thrift::BaseTransport.new + end + + def test_open? + assert_nil @trans.open? + end + + def test_open + assert_nil @trans.open + end + + def test_close + assert_nil @trans.close + end + + # TODO: + # This doesn't necessarily test he right thing. + # It _looks_ like read isn't guarenteed to return the length + # you ask for and read_all is. This means our test needs to check + # for blocking. -- Kevin Clark 3/27/08 + def test_read_all + # Implements read + t = DummyTransport.new("hello") + assert_equal "hello", t.read_all(5) + end + + def test_write + assert_nil @trans.write(5) # arbitrary value + end + + def test_flush + assert_nil @trans.flush + end +end diff --git a/test/rb/fixtures/structs.rb b/test/rb/fixtures/structs.rb new file mode 100644 index 00000000..ebbeb0a7 --- /dev/null +++ b/test/rb/fixtures/structs.rb @@ -0,0 +1,298 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require 'thrift' + +module Fixtures + module Structs + class OneBool + include Thrift::Struct + attr_accessor :bool + FIELDS = { + 1 => {:type => Thrift::Types::BOOL, :name => 'bool'} + } + + def validate + end + end + + class OneByte + include Thrift::Struct + attr_accessor :byte + FIELDS = { + 1 => {:type => Thrift::Types::BYTE, :name => 'byte'} + } + + def validate + end + end + + class OneI16 + include Thrift::Struct + attr_accessor :i16 + FIELDS = { + 1 => {:type => Thrift::Types::I16, :name => 'i16'} + } + + def validate + end + end + + class OneI32 + include Thrift::Struct + attr_accessor :i32 + FIELDS = { + 1 => {:type => Thrift::Types::I32, :name => 'i32'} + } + + def validate + end + end + + class OneI64 + include Thrift::Struct + attr_accessor :i64 + FIELDS = { + 1 => {:type => Thrift::Types::I64, :name => 'i64'} + } + + def validate + end + end + + class OneDouble + include Thrift::Struct + attr_accessor :double + FIELDS = { + 1 => {:type => Thrift::Types::DOUBLE, :name => 'double'} + } + + def validate + end + end + + class OneString + include Thrift::Struct + attr_accessor :string + FIELDS = { + 1 => {:type => Thrift::Types::STRING, :name => 'string'} + } + + def validate + end + end + + class OneMap + include Thrift::Struct + attr_accessor :map + FIELDS = { + 1 => {:type => Thrift::Types::MAP, :name => 'map', :key => {:type => Thrift::Types::STRING}, :value => {:type => Thrift::Types::STRING}} + } + + def validate + end + end + + class NestedMap + include Thrift::Struct + attr_accessor :map + FIELDS = { + 0 => {:type => Thrift::Types::MAP, :name => 'map', :key => {:type => Thrift::Types::I32}, :value => {:type => Thrift::Types::MAP, :key => {:type => Thrift::Types::I32}, :value => {:type => Thrift::Types::I32}}} + } + + def validate + end + end + + class OneList + include Thrift::Struct + attr_accessor :list + FIELDS = { + 1 => {:type => Thrift::Types::LIST, :name => 'list', :element => {:type => Thrift::Types::STRING}} + } + + def validate + end + end + + class NestedList + include Thrift::Struct + attr_accessor :list + FIELDS = { + 0 => {:type => Thrift::Types::LIST, :name => 'list', :element => {:type => Thrift::Types::LIST, :element => { :type => Thrift::Types::I32 } } } + } + + def validate + end + end + + class OneSet + include Thrift::Struct + attr_accessor :set + FIELDS = { + 1 => {:type => Thrift::Types::SET, :name => 'set', :element => {:type => Thrift::Types::STRING}} + } + + def validate + end + end + + class NestedSet + include Thrift::Struct + attr_accessor :set + FIELDS = { + 1 => {:type => Thrift::Types::SET, :name => 'set', :element => {:type => Thrift::Types::SET, :element => { :type => Thrift::Types::STRING } }} + } + + def validate + end + end + + # struct OneOfEach { + # 1: bool im_true, + # 2: bool im_false, + # 3: byte a_bite, + # 4: i16 integer16, + # 5: i32 integer32, + # 6: i64 integer64, + # 7: double double_precision, + # 8: string some_characters, + # 9: string zomg_unicode, + # 10: bool what_who, + # 11: binary base64, + # } + class OneOfEach + include Thrift::Struct + attr_accessor :im_true, :im_false, :a_bite, :integer16, :integer32, :integer64, :double_precision, :some_characters, :zomg_unicode, :what_who, :base64 + FIELDS = { + 1 => {:type => Thrift::Types::BOOL, :name => 'im_true'}, + 2 => {:type => Thrift::Types::BOOL, :name => 'im_false'}, + 3 => {:type => Thrift::Types::BYTE, :name => 'a_bite'}, + 4 => {:type => Thrift::Types::I16, :name => 'integer16'}, + 5 => {:type => Thrift::Types::I32, :name => 'integer32'}, + 6 => {:type => Thrift::Types::I64, :name => 'integer64'}, + 7 => {:type => Thrift::Types::DOUBLE, :name => 'double_precision'}, + 8 => {:type => Thrift::Types::STRING, :name => 'some_characters'}, + 9 => {:type => Thrift::Types::STRING, :name => 'zomg_unicode'}, + 10 => {:type => Thrift::Types::BOOL, :name => 'what_who'}, + 11 => {:type => Thrift::Types::STRING, :name => 'base64'} + } + + # Added for assert_equal + def ==(other) + [:im_true, :im_false, :a_bite, :integer16, :integer32, :integer64, :double_precision, :some_characters, :zomg_unicode, :what_who, :base64].each do |f| + var = "@#{f}" + return false if instance_variable_get(var) != other.instance_variable_get(var) + end + true + end + + def validate + end + end + + # struct Nested1 { + # 1: list a_list + # 2: map i32_map + # 3: map i64_map + # 4: map dbl_map + # 5: map str_map + # } + class Nested1 + include Thrift::Struct + attr_accessor :a_list, :i32_map, :i64_map, :dbl_map, :str_map + FIELDS = { + 1 => {:type => Thrift::Types::LIST, :name => 'a_list', :element => {:type => Thrift::Types::STRUCT, :class => OneOfEach}}, + 2 => {:type => Thrift::Types::MAP, :name => 'i32_map', :key => {:type => Thrift::Types::I32}, :value => {:type => Thrift::Types::STRUCT, :class => OneOfEach}}, + 3 => {:type => Thrift::Types::MAP, :name => 'i64_map', :key => {:type => Thrift::Types::I64}, :value => {:type => Thrift::Types::STRUCT, :class => OneOfEach}}, + 4 => {:type => Thrift::Types::MAP, :name => 'dbl_map', :key => {:type => Thrift::Types::DOUBLE}, :value => {:type => Thrift::Types::STRUCT, :class => OneOfEach}}, + 5 => {:type => Thrift::Types::MAP, :name => 'str_map', :key => {:type => Thrift::Types::STRING}, :value => {:type => Thrift::Types::STRUCT, :class => OneOfEach}} + } + + def validate + end + end + + # struct Nested2 { + # 1: list a_list + # 2: map i32_map + # 3: map i64_map + # 4: map dbl_map + # 5: map str_map + # } + class Nested2 + include Thrift::Struct + attr_accessor :a_list, :i32_map, :i64_map, :dbl_map, :str_map + FIELDS = { + 1 => {:type => Thrift::Types::LIST, :name => 'a_list', :element => {:type => Thrift::Types::STRUCT, :class => Nested1}}, + 2 => {:type => Thrift::Types::MAP, :name => 'i32_map', :key => {:type => Thrift::Types::I32}, :value => {:type => Thrift::Types::STRUCT, :class => Nested1}}, + 3 => {:type => Thrift::Types::MAP, :name => 'i64_map', :key => {:type => Thrift::Types::I64}, :value => {:type => Thrift::Types::STRUCT, :class => Nested1}}, + 4 => {:type => Thrift::Types::MAP, :name => 'dbl_map', :key => {:type => Thrift::Types::DOUBLE}, :value => {:type => Thrift::Types::STRUCT, :class => Nested1}}, + 5 => {:type => Thrift::Types::MAP, :name => 'str_map', :key => {:type => Thrift::Types::STRING}, :value => {:type => Thrift::Types::STRUCT, :class => Nested1}} + } + + def validate + end + end + + # struct Nested3 { + # 1: list a_list + # 2: map i32_map + # 3: map i64_map + # 4: map dbl_map + # 5: map str_map + # } + class Nested3 + include Thrift::Struct + attr_accessor :a_list, :i32_map, :i64_map, :dbl_map, :str_map + FIELDS = { + 1 => {:type => Thrift::Types::LIST, :name => 'a_list', :element => {:type => Thrift::Types::STRUCT, :class => Nested2}}, + 2 => {:type => Thrift::Types::MAP, :name => 'i32_map', :key => {:type => Thrift::Types::I32}, :value => {:type => Thrift::Types::STRUCT, :class => Nested2}}, + 3 => {:type => Thrift::Types::MAP, :name => 'i64_map', :key => {:type => Thrift::Types::I64}, :value => {:type => Thrift::Types::STRUCT, :class => Nested2}}, + 4 => {:type => Thrift::Types::MAP, :name => 'dbl_map', :key => {:type => Thrift::Types::DOUBLE}, :value => {:type => Thrift::Types::STRUCT, :class => Nested2}}, + 5 => {:type => Thrift::Types::MAP, :name => 'str_map', :key => {:type => Thrift::Types::STRING}, :value => {:type => Thrift::Types::STRUCT, :class => Nested2}} + } + + def validate + end + end + + # struct Nested4 { + # 1: list a_list + # 2: map i32_map + # 3: map i64_map + # 4: map dbl_map + # 5: map str_map + # } + class Nested4 + include Thrift::Struct + attr_accessor :a_list, :i32_map, :i64_map, :dbl_map, :str_map + FIELDS = { + 1 => {:type => Thrift::Types::LIST, :name => 'a_list', :element => {:type => Thrift::Types::STRUCT, :class => Nested3}}, + 2 => {:type => Thrift::Types::MAP, :name => 'i32_map', :key => {:type => Thrift::Types::I32}, :value => {:type => Thrift::Types::STRUCT, :class => Nested3}}, + 3 => {:type => Thrift::Types::MAP, :name => 'i64_map', :key => {:type => Thrift::Types::I64}, :value => {:type => Thrift::Types::STRUCT, :class => Nested3}}, + 4 => {:type => Thrift::Types::MAP, :name => 'dbl_map', :key => {:type => Thrift::Types::DOUBLE}, :value => {:type => Thrift::Types::STRUCT, :class => Nested3}}, + 5 => {:type => Thrift::Types::MAP, :name => 'str_map', :key => {:type => Thrift::Types::STRING}, :value => {:type => Thrift::Types::STRUCT, :class => Nested3}} + } + + def validate + end + end + end +end diff --git a/test/rb/generation/test_enum.rb b/test/rb/generation/test_enum.rb new file mode 100644 index 00000000..7d3f08ba --- /dev/null +++ b/test/rb/generation/test_enum.rb @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.join(File.dirname(__FILE__), '../test_helper') +require 'thrift_test' + +class TestEnumGeneration < Test::Unit::TestCase + include Thrift::Test + def test_enum_valid_values + assert_equal(Numberz::VALID_VALUES, Set.new([Numberz::ONE, Numberz::TWO, Numberz::THREE, Numberz::FIVE, Numberz::SIX, Numberz::EIGHT])) + end +end \ No newline at end of file diff --git a/test/rb/generation/test_struct.rb b/test/rb/generation/test_struct.rb new file mode 100644 index 00000000..3bd4fc9b --- /dev/null +++ b/test/rb/generation/test_struct.rb @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.join(File.dirname(__FILE__), '../test_helper') +require 'small_service' + +class TestStructGeneration < Test::Unit::TestCase + + def test_default_values + hello = TestNamespace::Hello.new + + assert_kind_of(TestNamespace::Hello, hello) + assert_nil(hello.complexer) + + assert_equal(hello.simple, 53) + assert_equal(hello.words, 'words') + + assert_kind_of(TestNamespace::Goodbyez, hello.thinz) + assert_equal(hello.thinz.val, 36632) + + assert_kind_of(Hash, hello.complex) + assert_equal(hello.complex, { 6243 => 632, 2355 => 532, 23 => 532}) + + bool_passer = TestNamespace::BoolPasser.new(:value => false) + assert_equal false, bool_passer.value + end + + def test_goodbyez + assert_equal(TestNamespace::Goodbyez.new.val, 325) + end + +end diff --git a/test/rb/integration/accelerated_buffered_client.rb b/test/rb/integration/accelerated_buffered_client.rb new file mode 100644 index 00000000..7cec1df5 --- /dev/null +++ b/test/rb/integration/accelerated_buffered_client.rb @@ -0,0 +1,163 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.join(File.dirname(__FILE__), '../test_helper') + +require 'thrift' +require 'ThriftTest' + +class AcceleratedBufferedClientTest < Test::Unit::TestCase + def setup + unless @socket + @socket = Thrift::Socket.new('localhost', 9090) + @protocol = Thrift::BinaryProtocolAccelerated.new(Thrift::BufferedTransport.new(@socket)) + @client = Thrift::Test::ThriftTest::Client.new(@protocol) + @socket.open + end + end + + def test_string + assert_equal(@client.testString('string'), 'string') + end + + def test_byte + val = 8 + assert_equal(@client.testByte(val), val) + assert_equal(@client.testByte(-val), -val) + end + + def test_i32 + val = 32 + assert_equal(@client.testI32(val), val) + assert_equal(@client.testI32(-val), -val) + end + + def test_i64 + val = 64 + assert_equal(@client.testI64(val), val) + assert_equal(@client.testI64(-val), -val) + end + + def test_double + val = 3.14 + assert_equal(@client.testDouble(val), val) + assert_equal(@client.testDouble(-val), -val) + assert_kind_of(Float, @client.testDouble(val)) + end + + def test_map + val = {1 => 1, 2 => 2, 3 => 3} + assert_equal(@client.testMap(val), val) + assert_kind_of(Hash, @client.testMap(val)) + end + + def test_list + val = [1,2,3,4,5] + assert_equal(@client.testList(val), val) + assert_kind_of(Array, @client.testList(val)) + end + + def test_enum + val = Thrift::Test::Numberz::SIX + ret = @client.testEnum(val) + + assert_equal(ret, 6) + assert_kind_of(Fixnum, ret) + end + + def test_typedef + #UserId testTypedef(1: UserId thing), + true + end + + def test_set + val = Set.new([1,2,3]) + assert_equal(@client.testSet(val), val) + assert_kind_of(Set, @client.testSet(val)) + end + + def get_struct + Thrift::Test::Xtruct.new({'string_thing' => 'hi!', 'i32_thing' => 4 }) + end + + def test_struct + ret = @client.testStruct(get_struct) + + assert_nil(ret.byte_thing, nil) + assert_nil(ret.i64_thing, nil) + assert_equal(ret.string_thing, 'hi!') + assert_equal(ret.i32_thing, 4) + assert_kind_of(Thrift::Test::Xtruct, ret) + end + + def test_nest + struct2 = Thrift::Test::Xtruct2.new({'struct_thing' => get_struct, 'i32_thing' => 10}) + + ret = @client.testNest(struct2) + + assert_nil(ret.struct_thing.byte_thing, nil) + assert_nil(ret.struct_thing.i64_thing, nil) + assert_equal(ret.struct_thing.string_thing, 'hi!') + assert_equal(ret.struct_thing.i32_thing, 4) + assert_equal(ret.i32_thing, 10) + + assert_kind_of(Thrift::Test::Xtruct, ret.struct_thing) + assert_kind_of(Thrift::Test::Xtruct2, ret) + end + + def test_insane + insane = Thrift::Test::Insanity.new({ + 'userMap' => { Thrift::Test::Numberz::ONE => 44 }, + 'xtructs' => [get_struct, + Thrift::Test::Xtruct.new({ + 'string_thing' => 'hi again', + 'i32_thing' => 12 + }) + ] + }) + + ret = @client.testInsanity(insane) + + assert_not_nil(ret[44]) + assert_not_nil(ret[44][1]) + + struct = ret[44][1] + + assert_equal(struct.userMap[Thrift::Test::Numberz::ONE], 44) + assert_equal(struct.xtructs[1].string_thing, 'hi again') + assert_equal(struct.xtructs[1].i32_thing, 12) + + assert_kind_of(Hash, struct.userMap) + assert_kind_of(Array, struct.xtructs) + assert_kind_of(Thrift::Test::Insanity, struct) + end + + def test_map_map + ret = @client.testMapMap(4) + assert_kind_of(Hash, ret) + assert_equal(ret, { 4 => { 4 => 4}}) + end + + def test_exception + assert_raise Thrift::Test::Xception do + @client.testException('foo') + end + end +end + diff --git a/test/rb/integration/accelerated_buffered_server.rb b/test/rb/integration/accelerated_buffered_server.rb new file mode 100644 index 00000000..1ca66e54 --- /dev/null +++ b/test/rb/integration/accelerated_buffered_server.rb @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +$:.push File.dirname(__FILE__) + '/../gen-rb' +$:.push File.join(File.dirname(__FILE__), '../../../lib/rb/lib') +$:.push File.join(File.dirname(__FILE__), '../../../lib/rb/ext') + +require 'thrift' +require 'ThriftTest' + +class SimpleHandler + [:testString, :testByte, :testI32, :testI64, :testDouble, + :testStruct, :testMap, :testSet, :testList, :testNest, + :testEnum, :testTypedef].each do |meth| + + define_method(meth) do |thing| + thing + end + + end + + def testInsanity(thing) + num, uid = thing.userMap.find { true } + return {uid => {num => thing}} + end + + def testMapMap(thing) + return {thing => {thing => thing}} + end + + def testEnum(thing) + return thing + end + + def testTypedef(thing) + return thing + end + + def testException(thing) + raise Thrift::Test::Xception, :message => 'error' + end +end + +@handler = SimpleHandler.new +@processor = Thrift::Test::ThriftTest::Processor.new(@handler) +@transport = Thrift::ServerSocket.new(9090) +@server = Thrift::ThreadedServer.new(@processor, @transport, Thrift::BufferedTransportFactory.new, Thrift::BinaryProtocolAcceleratedFactory.new) + +@server.serve diff --git a/test/rb/integration/buffered_client.rb b/test/rb/integration/buffered_client.rb new file mode 100644 index 00000000..1a925ccf --- /dev/null +++ b/test/rb/integration/buffered_client.rb @@ -0,0 +1,163 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.join(File.dirname(__FILE__), '../test_helper') + +require 'thrift' +require 'ThriftTest' + +class BufferedClientTest < Test::Unit::TestCase + def setup + unless @socket + @socket = Thrift::Socket.new('localhost', 9090) + @protocol = Thrift::BinaryProtocol.new(Thrift::BufferedTransport.new(@socket)) + @client = Thrift::Test::ThriftTest::Client.new(@protocol) + @socket.open + end + end + + def test_string + assert_equal(@client.testString('string'), 'string') + end + + def test_byte + val = 8 + assert_equal(@client.testByte(val), val) + assert_equal(@client.testByte(-val), -val) + end + + def test_i32 + val = 32 + assert_equal(@client.testI32(val), val) + assert_equal(@client.testI32(-val), -val) + end + + def test_i64 + val = 64 + assert_equal(@client.testI64(val), val) + assert_equal(@client.testI64(-val), -val) + end + + def test_double + val = 3.14 + assert_equal(@client.testDouble(val), val) + assert_equal(@client.testDouble(-val), -val) + assert_kind_of(Float, @client.testDouble(val)) + end + + def test_map + val = {1 => 1, 2 => 2, 3 => 3} + assert_equal(@client.testMap(val), val) + assert_kind_of(Hash, @client.testMap(val)) + end + + def test_list + val = [1,2,3,4,5] + assert_equal(@client.testList(val), val) + assert_kind_of(Array, @client.testList(val)) + end + + def test_enum + val = Thrift::Test::Numberz::SIX + ret = @client.testEnum(val) + + assert_equal(ret, 6) + assert_kind_of(Fixnum, ret) + end + + def test_typedef + #UserId testTypedef(1: UserId thing), + true + end + + def test_set + val = Set.new([1,2,3]) + assert_equal(@client.testSet(val), val) + assert_kind_of(Set, @client.testSet(val)) + end + + def get_struct + Thrift::Test::Xtruct.new({'string_thing' => 'hi!', 'i32_thing' => 4 }) + end + + def test_struct + ret = @client.testStruct(get_struct) + + assert_nil(ret.byte_thing, nil) + assert_nil(ret.i64_thing, nil) + assert_equal(ret.string_thing, 'hi!') + assert_equal(ret.i32_thing, 4) + assert_kind_of(Thrift::Test::Xtruct, ret) + end + + def test_nest + struct2 = Thrift::Test::Xtruct2.new({'struct_thing' => get_struct, 'i32_thing' => 10}) + + ret = @client.testNest(struct2) + + assert_nil(ret.struct_thing.byte_thing, nil) + assert_nil(ret.struct_thing.i64_thing, nil) + assert_equal(ret.struct_thing.string_thing, 'hi!') + assert_equal(ret.struct_thing.i32_thing, 4) + assert_equal(ret.i32_thing, 10) + + assert_kind_of(Thrift::Test::Xtruct, ret.struct_thing) + assert_kind_of(Thrift::Test::Xtruct2, ret) + end + + def test_insane + insane = Thrift::Test::Insanity.new({ + 'userMap' => { Thrift::Test::Numberz::ONE => 44 }, + 'xtructs' => [get_struct, + Thrift::Test::Xtruct.new({ + 'string_thing' => 'hi again', + 'i32_thing' => 12 + }) + ] + }) + + ret = @client.testInsanity(insane) + + assert_not_nil(ret[44]) + assert_not_nil(ret[44][1]) + + struct = ret[44][1] + + assert_equal(struct.userMap[Thrift::Test::Numberz::ONE], 44) + assert_equal(struct.xtructs[1].string_thing, 'hi again') + assert_equal(struct.xtructs[1].i32_thing, 12) + + assert_kind_of(Hash, struct.userMap) + assert_kind_of(Array, struct.xtructs) + assert_kind_of(Thrift::Test::Insanity, struct) + end + + def test_map_map + ret = @client.testMapMap(4) + assert_kind_of(Hash, ret) + assert_equal(ret, { 4 => { 4 => 4}}) + end + + def test_exception + assert_raise Thrift::Test::Xception do + @client.testException('foo') + end + end +end + diff --git a/test/rb/integration/simple_client.rb b/test/rb/integration/simple_client.rb new file mode 100644 index 00000000..1064822a --- /dev/null +++ b/test/rb/integration/simple_client.rb @@ -0,0 +1,163 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.join(File.dirname(__FILE__), '../test_helper') + +require 'thrift' +require 'ThriftTest' + +class SimpleClientTest < Test::Unit::TestCase + def setup + unless @socket + @socket = Thrift::Socket.new('localhost', 9090) + @protocol = Thrift::BinaryProtocol.new(@socket) + @client = Thrift::Test::ThriftTest::Client.new(@protocol) + @socket.open + end + end + + def test_string + assert_equal(@client.testString('string'), 'string') + end + + def test_byte + val = 8 + assert_equal(@client.testByte(val), val) + assert_equal(@client.testByte(-val), -val) + end + + def test_i32 + val = 32 + assert_equal(@client.testI32(val), val) + assert_equal(@client.testI32(-val), -val) + end + + def test_i64 + val = 64 + assert_equal(@client.testI64(val), val) + assert_equal(@client.testI64(-val), -val) + end + + def test_double + val = 3.14 + assert_equal(@client.testDouble(val), val) + assert_equal(@client.testDouble(-val), -val) + assert_kind_of(Float, @client.testDouble(val)) + end + + def test_map + val = {1 => 1, 2 => 2, 3 => 3} + assert_equal(@client.testMap(val), val) + assert_kind_of(Hash, @client.testMap(val)) + end + + def test_list + val = [1,2,3,4,5] + assert_equal(@client.testList(val), val) + assert_kind_of(Array, @client.testList(val)) + end + + def test_enum + val = Thrift::Test::Numberz::SIX + ret = @client.testEnum(val) + + assert_equal(ret, 6) + assert_kind_of(Fixnum, ret) + end + + def test_typedef + #UserId testTypedef(1: UserId thing), + true + end + + def test_set + val = Set.new([1,2,3]) + assert_equal(@client.testSet(val), val) + assert_kind_of(Set, @client.testSet(val)) + end + + def get_struct + Thrift::Test::Xtruct.new({'string_thing' => 'hi!', 'i32_thing' => 4 }) + end + + def test_struct + ret = @client.testStruct(get_struct) + + assert_nil(ret.byte_thing, nil) + assert_nil(ret.i64_thing, nil) + assert_equal(ret.string_thing, 'hi!') + assert_equal(ret.i32_thing, 4) + assert_kind_of(Thrift::Test::Xtruct, ret) + end + + def test_nest + struct2 = Thrift::Test::Xtruct2.new({'struct_thing' => get_struct, 'i32_thing' => 10}) + + ret = @client.testNest(struct2) + + assert_nil(ret.struct_thing.byte_thing, nil) + assert_nil(ret.struct_thing.i64_thing, nil) + assert_equal(ret.struct_thing.string_thing, 'hi!') + assert_equal(ret.struct_thing.i32_thing, 4) + assert_equal(ret.i32_thing, 10) + + assert_kind_of(Thrift::Test::Xtruct, ret.struct_thing) + assert_kind_of(Thrift::Test::Xtruct2, ret) + end + + def test_insane + insane = Thrift::Test::Insanity.new({ + 'userMap' => { Thrift::Test::Numberz::ONE => 44 }, + 'xtructs' => [get_struct, + Thrift::Test::Xtruct.new({ + 'string_thing' => 'hi again', + 'i32_thing' => 12 + }) + ] + }) + + ret = @client.testInsanity(insane) + + assert_not_nil(ret[44]) + assert_not_nil(ret[44][1]) + + struct = ret[44][1] + + assert_equal(struct.userMap[Thrift::Test::Numberz::ONE], 44) + assert_equal(struct.xtructs[1].string_thing, 'hi again') + assert_equal(struct.xtructs[1].i32_thing, 12) + + assert_kind_of(Hash, struct.userMap) + assert_kind_of(Array, struct.xtructs) + assert_kind_of(Thrift::Test::Insanity, struct) + end + + def test_map_map + ret = @client.testMapMap(4) + assert_kind_of(Hash, ret) + assert_equal(ret, { 4 => { 4 => 4}}) + end + + def test_exception + assert_raise Thrift::Test::Xception do + @client.testException('foo') + end + end +end + diff --git a/test/rb/integration/simple_server.rb b/test/rb/integration/simple_server.rb new file mode 100644 index 00000000..3518d2e1 --- /dev/null +++ b/test/rb/integration/simple_server.rb @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +$:.push File.dirname(__FILE__) + '/../gen-rb' +$:.push File.join(File.dirname(__FILE__), '../../../lib/rb/lib') + +require 'thrift' +require 'ThriftTest' + +class SimpleHandler + [:testString, :testByte, :testI32, :testI64, :testDouble, + :testStruct, :testMap, :testSet, :testList, :testNest, + :testEnum, :testTypedef].each do |meth| + + define_method(meth) do |thing| + thing + end + + end + + def testInsanity(thing) + num, uid = thing.userMap.find { true } + return {uid => {num => thing}} + end + + def testMapMap(thing) + return {thing => {thing => thing}} + end + + def testEnum(thing) + return thing + end + + def testTypedef(thing) + return thing + end + + def testException(thing) + raise Thrift::Test::Xception, :message => 'error' + end +end + +@handler = SimpleHandler.new +@processor = Thrift::Test::ThriftTest::Processor.new(@handler) +@transport = Thrift::ServerSocket.new(9090) +@server = Thrift::ThreadedServer.new(@processor, @transport) + +@server.serve diff --git a/test/rb/integration/test_simple_handler.rb b/test/rb/integration/test_simple_handler.rb new file mode 100644 index 00000000..c34aa7e5 --- /dev/null +++ b/test/rb/integration/test_simple_handler.rb @@ -0,0 +1,211 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +require File.join(File.dirname(__FILE__), '../test_helper') + +require 'thrift' +require 'ThriftTest' + +class TestHandler + [:testString, :testByte, :testI32, :testI64, :testDouble, + :testStruct, :testMap, :testSet, :testList, :testNest, + :testEnum, :testTypedef].each do |meth| + + define_method(meth) do |thing| + thing + end + + end + + def testInsanity(thing) + num, uid = thing.userMap.find { true } + return {uid => {num => thing}} + end + + def testMapMap(thing) + return {thing => {thing => thing}} + end + + def testEnum(thing) + return thing + end + + def testTypedef(thing) + return thing + end + + def testException(thing) + raise Thrift::Test::Xception, :message => 'error' + end + +end +class TestThrift < Test::Unit::TestCase + + @@INIT = nil + + def setup + if @@INIT.nil? + # Initialize the server + @handler = TestHandler.new() + @processor = Thrift::Test::ThriftTest::Processor.new(@handler) + @transport = Thrift::ServerSocket.new(9090) + @server = Thrift::ThreadedServer.new(@processor, @transport) + + @thread = Thread.new { @server.serve } + + # And the Client + @socket = Thrift::Socket.new('localhost', 9090) + @protocol = Thrift::BinaryProtocol.new(@socket) + @client = Thrift::Test::ThriftTest::Client.new(@protocol) + @socket.open + end + end + + def test_string + assert_equal(@client.testString('string'), 'string') + end + + def test_byte + val = 8 + assert_equal(@client.testByte(val), val) + assert_equal(@client.testByte(-val), -val) + end + + def test_i32 + val = 32 + assert_equal(@client.testI32(val), val) + assert_equal(@client.testI32(-val), -val) + end + + def test_i64 + val = 64 + assert_equal(@client.testI64(val), val) + assert_equal(@client.testI64(-val), -val) + end + + def test_double + val = 3.14 + assert_equal(@client.testDouble(val), val) + assert_equal(@client.testDouble(-val), -val) + assert_kind_of(Float, @client.testDouble(val)) + end + + def test_map + val = {1 => 1, 2 => 2, 3 => 3} + assert_equal(@client.testMap(val), val) + assert_kind_of(Hash, @client.testMap(val)) + end + + def test_list + val = [1,2,3,4,5] + assert_equal(@client.testList(val), val) + assert_kind_of(Array, @client.testList(val)) + end + + def test_enum + val = Thrift::Test::Numberz::SIX + ret = @client.testEnum(val) + + assert_equal(ret, 6) + assert_kind_of(Fixnum, ret) + end + + def test_typedef + #UserId testTypedef(1: UserId thing), + true + end + + def test_set + val = Set.new([1, 2, 3]) + assert_equal(val, @client.testSet(val)) + assert_kind_of(Set, @client.testSet(val)) + end + + def get_struct + Thrift::Test::Xtruct.new({'string_thing' => 'hi!', 'i32_thing' => 4 }) + end + + def test_struct + ret = @client.testStruct(get_struct) + + assert_nil(ret.byte_thing, nil) + assert_nil(ret.i64_thing, nil) + assert_equal(ret.string_thing, 'hi!') + assert_equal(ret.i32_thing, 4) + assert_kind_of(Thrift::Test::Xtruct, ret) + end + + def test_nest + struct2 = Thrift::Test::Xtruct2.new({'struct_thing' => get_struct, 'i32_thing' => 10}) + + ret = @client.testNest(struct2) + + assert_nil(ret.struct_thing.byte_thing, nil) + assert_nil(ret.struct_thing.i64_thing, nil) + assert_equal(ret.struct_thing.string_thing, 'hi!') + assert_equal(ret.struct_thing.i32_thing, 4) + assert_equal(ret.i32_thing, 10) + + assert_kind_of(Thrift::Test::Xtruct, ret.struct_thing) + assert_kind_of(Thrift::Test::Xtruct2, ret) + end + + def test_insane + insane = Thrift::Test::Insanity.new({ + 'userMap' => { Thrift::Test::Numberz::ONE => 44 }, + 'xtructs' => [get_struct, + Thrift::Test::Xtruct.new({ + 'string_thing' => 'hi again', + 'i32_thing' => 12 + }) + ] + }) + + ret = @client.testInsanity(insane) + + assert_not_nil(ret[44]) + assert_not_nil(ret[44][1]) + + struct = ret[44][1] + + assert_equal(struct.userMap[Thrift::Test::Numberz::ONE], 44) + assert_equal(struct.xtructs[1].string_thing, 'hi again') + assert_equal(struct.xtructs[1].i32_thing, 12) + + assert_kind_of(Hash, struct.userMap) + assert_kind_of(Array, struct.xtructs) + assert_kind_of(Thrift::Test::Insanity, struct) + end + + def test_map_map + ret = @client.testMapMap(4) + assert_kind_of(Hash, ret) + assert_equal(ret, { 4 => { 4 => 4}}) + end + + def test_exception + assert_raise Thrift::Test::Xception do + @client.testException('foo') + end + end + + def teardown + end + +end diff --git a/test/rb/test_helper.rb b/test/rb/test_helper.rb new file mode 100644 index 00000000..c1ed779e --- /dev/null +++ b/test/rb/test_helper.rb @@ -0,0 +1,35 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +$:.unshift File.dirname(__FILE__) + '/gen-rb' +$:.unshift File.join(File.dirname(__FILE__), '../../lib/rb/lib') +$:.unshift File.join(File.dirname(__FILE__), '../../lib/rb/ext') + +require 'test/unit' + +module Thrift + module Struct + def ==(other) + return false unless other.is_a? self.class + self.class.const_get(:FIELDS).collect {|fid, data| data[:name] }.all? do |field| + send(field) == other.send(field) + end + end + end +end diff --git a/test/rb/test_suite.rb b/test/rb/test_suite.rb new file mode 100644 index 00000000..b157c2c5 --- /dev/null +++ b/test/rb/test_suite.rb @@ -0,0 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +Dir["{core,generation}/**/*.rb"].each {|f| require f } \ No newline at end of file diff --git a/test/threads/Makefile b/test/threads/Makefile new file mode 100644 index 00000000..14f1a589 --- /dev/null +++ b/test/threads/Makefile @@ -0,0 +1,63 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# Default target is everything + +ifndef thrift_home +thrift_home=../../ +endif #thrift_home + +target: all + +ifndef boost_home +boost_home=/usr/local/include/boost-1_33_1 +endif #boost_home +target: all + +include_paths = $(thrift_home)/lib/cpp/src \ + $(boost_home) + +include_flags = $(patsubst %,-I%, $(include_paths)) + +# Tools +ifndef THRIFT +THRIFT = ../../compiler/cpp/thrift +endif # THRIFT + +CC = g++ +LD = g++ + +# Compiler flags +LFL = -L$(thrift_home)/lib/cpp/.libs -lthrift +CCFL = -Wall -O3 -g -I./gen-cpp $(include_flags) +CFL = $(CCFL) $(LFL) + +all: server client + +stubs: ThreadsTest.thrift + $(THRIFT) --gen cpp --gen py ThreadsTest.thrift + +server: stubs + g++ -o ThreadsServer $(CFL) ThreadsServer.cpp ./gen-cpp/ThreadsTest.cpp ./gen-cpp/ThreadsTest_types.cpp + +client: stubs + g++ -o ThreadsClient $(CFL) ThreadsClient.cpp ./gen-cpp/ThreadsTest.cpp ./gen-cpp/ThreadsTest_types.cpp + +clean: + $(RM) -r *.o ThreadsServer ThreadsClient gen-cpp gen-py diff --git a/test/threads/ThreadsClient.cpp b/test/threads/ThreadsClient.cpp new file mode 100644 index 00000000..85274a63 --- /dev/null +++ b/test/threads/ThreadsClient.cpp @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// This autogenerated skeleton file illustrates how to build a server. +// You should copy it to another filename to avoid overwriting it. + +#include "ThreadsTest.h" +#include +#include +#include +#include +#include +#include +#include + +using boost::shared_ptr; +using namespace apache::thrift; +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; +using namespace apache::thrift::server; +using namespace apache::thrift::concurrency; + +int main(int argc, char **argv) { + int port = 9090; + std::string host = "localhost"; + + shared_ptr transport(new TSocket(host, port)); + shared_ptr protocol(new TBinaryProtocol(transport)); + + transport->open(); + + ThreadsTestClient client(protocol); + int val; + val = client.threadOne(5); + fprintf(stderr, "%d\n", val); + val = client.stop(); + fprintf(stderr, "%d\n", val); + val = client.threadTwo(5); + fprintf(stderr, "%d\n", val); + + transport->close(); + + fprintf(stderr, "done.\n"); + + return 0; +} + diff --git a/test/threads/ThreadsServer.cpp b/test/threads/ThreadsServer.cpp new file mode 100644 index 00000000..8734ee89 --- /dev/null +++ b/test/threads/ThreadsServer.cpp @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// This autogenerated skeleton file illustrates how to build a server. +// You should copy it to another filename to avoid overwriting it. + +#include "ThreadsTest.h" +#include +#include +#include +#include +#include +#include +#include +#include + +using boost::shared_ptr; +using namespace apache::thrift; +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; +using namespace apache::thrift::server; +using namespace apache::thrift::concurrency; + + +class ThreadsTestHandler : virtual public ThreadsTestIf { + public: + ThreadsTestHandler() { + // Your initialization goes here + } + + int32_t threadOne(const int32_t sleep) { + // Your implementation goes here + printf("threadOne\n"); + go2sleep(1, sleep); + return 1; + } + + int32_t threadTwo(const int32_t sleep) { + // Your implementation goes here + printf("threadTwo\n"); + go2sleep(2, sleep); + return 1; + } + + int32_t threadThree(const int32_t sleep) { + // Your implementation goes here + printf("threadThree\n"); + go2sleep(3, sleep); + return 1; + } + + int32_t threadFour(const int32_t sleep) { + // Your implementation goes here + printf("threadFour\n"); + go2sleep(4, sleep); + return 1; + } + + int32_t stop() { + printf("stop\n"); + server_->stop(); + return 1; + } + + void setServer(boost::shared_ptr server) { + server_ = server; + } + +protected: + void go2sleep(int thread, int seconds) { + Monitor m; + for (int i = 0; i < seconds; ++i) { + fprintf(stderr, "Thread %d: sleep %d\n", thread, i); + try { + m.wait(1000); + } catch(TimedOutException& e) { + } + } + fprintf(stderr, "THREAD %d DONE\n", thread); + } + +private: + boost::shared_ptr server_; + +}; + +int main(int argc, char **argv) { + int port = 9090; + shared_ptr handler(new ThreadsTestHandler()); + shared_ptr processor(new ThreadsTestProcessor(handler)); + shared_ptr serverTransport(new TServerSocket(port)); + shared_ptr transportFactory(new TBufferedTransportFactory()); + shared_ptr protocolFactory(new TBinaryProtocolFactory()); + + /* + shared_ptr threadManager = + ThreadManager::newSimpleThreadManager(10); + shared_ptr threadFactory = + shared_ptr(new PosixThreadFactory()); + threadManager->threadFactory(threadFactory); + threadManager->start(); + + shared_ptr server = + shared_ptr(new TThreadPoolServer(processor, + serverTransport, + transportFactory, + protocolFactory, + threadManager)); + */ + + shared_ptr server = + shared_ptr(new TThreadedServer(processor, + serverTransport, + transportFactory, + protocolFactory)); + + handler->setServer(server); + + server->serve(); + + fprintf(stderr, "done.\n"); + + return 0; +} + diff --git a/test/threads/ThreadsTest.thrift b/test/threads/ThreadsTest.thrift new file mode 100644 index 00000000..caa93460 --- /dev/null +++ b/test/threads/ThreadsTest.thrift @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +service ThreadsTest { + i32 threadOne(1: i32 sleep=15), + i32 threadTwo(2: i32 sleep=15), + i32 threadThree(3: i32 sleep=15), + i32 threadFour(4: i32 sleep=15) + + i32 stop(); + +} diff --git a/tutorial/README b/tutorial/README new file mode 100644 index 00000000..a29f977b --- /dev/null +++ b/tutorial/README @@ -0,0 +1,42 @@ +Thrift Tutorial + +License +======= + +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +Tutorial +======== + +1) First things first, you'll need to install the Thrift compiler and the + language libraries. Do that using the instructions in the top level + README file. + +2) Read tutorial.thrift to learn about the syntax of a Thrift file + +3) Compile the code for the language of your choice: + + $ thrift + $ thrift -r --gen cpp tutorial.thrift + +4) Take a look at the generated code. + +5) Look in the language directories for sample client/server code. + +6) That's about it for now. This tutorial is intentionally brief. It should be + just enough to get you started and ready to build your own project. diff --git a/tutorial/cpp/CppClient.cpp b/tutorial/cpp/CppClient.cpp new file mode 100644 index 00000000..a3f17fee --- /dev/null +++ b/tutorial/cpp/CppClient.cpp @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +#include +#include +#include + +#include "../gen-cpp/Calculator.h" + +using namespace std; +using namespace apache::thrift; +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; + +using namespace tutorial; +using namespace shared; + +using namespace boost; + +int main(int argc, char** argv) { + shared_ptr socket(new TSocket("localhost", 9090)); + shared_ptr transport(new TBufferedTransport(socket)); + shared_ptr protocol(new TBinaryProtocol(transport)); + CalculatorClient client(protocol); + + try { + transport->open(); + + client.ping(); + printf("ping()\n"); + + int32_t sum = client.add(1,1); + printf("1+1=%d\n", sum); + + Work work; + work.op = DIVIDE; + work.num1 = 1; + work.num2 = 0; + + try { + int32_t quotient = client.calculate(1, work); + printf("Whoa? We can divide by zero!\n"); + } catch (InvalidOperation &io) { + printf("InvalidOperation: %s\n", io.why.c_str()); + } + + work.op = SUBTRACT; + work.num1 = 15; + work.num2 = 10; + int32_t diff = client.calculate(1, work); + printf("15-10=%d\n", diff); + + // Note that C++ uses return by reference for complex types to avoid + // costly copy construction + SharedStruct ss; + client.getStruct(ss, 1); + printf("Check log: %s\n", ss.value.c_str()); + + transport->close(); + } catch (TException &tx) { + printf("ERROR: %s\n", tx.what()); + } + +} diff --git a/tutorial/cpp/CppServer.cpp b/tutorial/cpp/CppServer.cpp new file mode 100644 index 00000000..23c2b833 --- /dev/null +++ b/tutorial/cpp/CppServer.cpp @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../gen-cpp/Calculator.h" + +using namespace std; +using namespace apache::thrift; +using namespace apache::thrift::protocol; +using namespace apache::thrift::transport; +using namespace apache::thrift::server; + +using namespace tutorial; +using namespace shared; + +using namespace boost; + +class CalculatorHandler : public CalculatorIf { + public: + CalculatorHandler() {} + + void ping() { + printf("ping()\n"); + } + + int32_t add(const int32_t n1, const int32_t n2) { + printf("add(%d,%d)\n", n1, n2); + return n1 + n2; + } + + int32_t calculate(const int32_t logid, const Work &work) { + printf("calculate(%d,{%d,%d,%d})\n", logid, work.op, work.num1, work.num2); + int32_t val; + + switch (work.op) { + case ADD: + val = work.num1 + work.num2; + break; + case SUBTRACT: + val = work.num1 - work.num2; + break; + case MULTIPLY: + val = work.num1 * work.num2; + break; + case DIVIDE: + if (work.num2 == 0) { + InvalidOperation io; + io.what = work.op; + io.why = "Cannot divide by 0"; + throw io; + } + val = work.num1 / work.num2; + break; + default: + InvalidOperation io; + io.what = work.op; + io.why = "Invalid Operation"; + throw io; + } + + SharedStruct ss; + ss.key = logid; + char buffer[12]; + snprintf(buffer, sizeof(buffer), "%d", val); + ss.value = buffer; + + log[logid] = ss; + + return val; + } + + void getStruct(SharedStruct &ret, const int32_t logid) { + printf("getStruct(%d)\n", logid); + ret = log[logid]; + } + + void zip() { + printf("zip()\n"); + } + +protected: + map log; + +}; + +int main(int argc, char **argv) { + + shared_ptr protocolFactory(new TBinaryProtocolFactory()); + shared_ptr handler(new CalculatorHandler()); + shared_ptr processor(new CalculatorProcessor(handler)); + shared_ptr serverTransport(new TServerSocket(9090)); + shared_ptr transportFactory(new TBufferedTransportFactory()); + + TSimpleServer server(processor, + serverTransport, + transportFactory, + protocolFactory); + + + /** + * Or you could do one of these + + shared_ptr threadManager = + ThreadManager::newSimpleThreadManager(workerCount); + shared_ptr threadFactory = + shared_ptr(new PosixThreadFactory()); + threadManager->threadFactory(threadFactory); + threadManager->start(); + TThreadPoolServer server(processor, + serverTransport, + transportFactory, + protocolFactory, + threadManager); + + TThreadedServer server(processor, + serverTransport, + transportFactory, + protocolFactory); + + */ + + printf("Starting the server...\n"); + server.serve(); + printf("done.\n"); + return 0; +} diff --git a/tutorial/cpp/Makefile b/tutorial/cpp/Makefile new file mode 100644 index 00000000..e834dee2 --- /dev/null +++ b/tutorial/cpp/Makefile @@ -0,0 +1,35 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +BOOST_DIR = /usr/local/boost/include/boost-1_33_1/ +THRIFT_DIR = /usr/local/include/thrift +LIB_DIR = /usr/local/lib + +GEN_SRC = ../gen-cpp/SharedService.cpp ../gen-cpp/shared_types.cpp ../gen-cpp/tutorial_types.cpp ../gen-cpp/Calculator.cpp + +default: server client + +server: CppServer.cpp + g++ -o CppServer -I${THRIFT_DIR} -I${BOOST_DIR} -I../gen-cpp -L${LIB_DIR} -lthrift CppServer.cpp ${GEN_SRC} + +client: CppClient.cpp + g++ -o CppClient -I${THRIFT_DIR} -I${BOOST_DIR} -I../gen-cpp -L${LIB_DIR} -lthrift CppClient.cpp ${GEN_SRC} + +clean: + $(RM) -r CppClient CppServer diff --git a/tutorial/erl/client.erl b/tutorial/erl/client.erl new file mode 100644 index 00000000..97803349 --- /dev/null +++ b/tutorial/erl/client.erl @@ -0,0 +1,74 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(client). + +-include("calculator_thrift.hrl"). + +-export([t/0]). + +p(X) -> + io:format("~p~n", [X]), + ok. + +t() -> + Port = 9999, + + {ok, Client} = thrift_client:start_link("127.0.0.1", + Port, + calculator_thrift), + + thrift_client:call(Client, ping, []), + io:format("ping~n", []), + + {ok, Sum} = thrift_client:call(Client, add, [1, 1]), + io:format("1+1=~p~n", [Sum]), + + {ok, Sum1} = thrift_client:call(Client, add, [1, 4]), + io:format("1+4=~p~n", [Sum1]), + + Work = #work{op=?tutorial_SUBTRACT, + num1=15, + num2=10}, + {ok, Diff} = thrift_client:call(Client, calculate, [1, Work]), + io:format("15-10=~p~n", [Diff]), + + {ok, Log} = thrift_client:call(Client, getStruct, [1]), + io:format("Log: ~p~n", [Log]), + + try + Work1 = #work{op=?tutorial_DIVIDE, + num1=1, + num2=0}, + {ok, _Quot} = thrift_client:call(Client, calculate, [2, Work1]), + + io:format("LAME: exception handling is broken~n", []) + catch + Z -> + io:format("Got exception where expecting - the " ++ + "following is NOT a problem!!!~n"), + p(Z) + end, + + + {ok, ok} = thrift_client:call(Client, zip, []), + io:format("zip~n", []), + + ok = thrift_client:close(Client), + ok. diff --git a/tutorial/erl/client.sh b/tutorial/erl/client.sh new file mode 120000 index 00000000..a417e0da --- /dev/null +++ b/tutorial/erl/client.sh @@ -0,0 +1 @@ +server.sh \ No newline at end of file diff --git a/tutorial/erl/server.erl b/tutorial/erl/server.erl new file mode 100644 index 00000000..5a994ce7 --- /dev/null +++ b/tutorial/erl/server.erl @@ -0,0 +1,82 @@ +%% +%% Licensed to the Apache Software Foundation (ASF) under one +%% or more contributor license agreements. See the NOTICE file +%% distributed with this work for additional information +%% regarding copyright ownership. The ASF licenses this file +%% to you under the Apache License, Version 2.0 (the +%% "License"); you may not use this file except in compliance +%% with the License. You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, +%% software distributed under the License is distributed on an +%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +%% KIND, either express or implied. See the License for the +%% specific language governing permissions and limitations +%% under the License. +%% + +-module(server). + +-include("calculator_thrift.hrl"). + +-export([start/0, start/1, handle_function/2, + stop/1, ping/0, add/2, calculate/2, getStruct/1, zip/0]). + +debug(Format, Data) -> + error_logger:info_msg(Format, Data). + +ping() -> + debug("ping()",[]), + ok. + +add(N1, N2) -> + debug("add(~p,~p)",[N1,N2]), + N1+N2. + +calculate(Logid, Work) -> + { Op, Num1, Num2 } = { Work#work.op, Work#work.num1, Work#work.num2 }, + debug("calculate(~p, {~p,~p,~p})", [Logid, Op, Num1, Num2]), + case Op of + ?tutorial_ADD -> Num1 + Num2; + ?tutorial_SUBTRACT -> Num1 - Num2; + ?tutorial_MULTIPLY -> Num1 * Num2; + + ?tutorial_DIVIDE when Num2 == 0 -> + throw(#invalidOperation{what=Op, why="Cannot divide by 0"}); + ?tutorial_DIVIDE -> + Num1 div Num2; + + _Else -> + throw(#invalidOperation{what=Op, why="Invalid operation"}) + end. + +getStruct(Key) -> + debug("getStruct(~p)", [Key]), + #sharedStruct{key=Key, value="RARG"}. + +zip() -> + debug("zip", []), + ok. + +%% + +start() -> + start(9999). + +start(Port) -> + Handler = ?MODULE, + thrift_socket_server:start([{handler, Handler}, + {service, calculator_thrift}, + {port, Port}, + {name, tutorial_server}]). + +stop(Server) -> + thrift_socket_server:stop(Server). + +handle_function(Function, Args) when is_atom(Function), is_tuple(Args) -> + case apply(?MODULE, Function, tuple_to_list(Args)) of + ok -> ok; + Reply -> {reply, Reply} + end. diff --git a/tutorial/erl/server.sh b/tutorial/erl/server.sh new file mode 100755 index 00000000..106c89e9 --- /dev/null +++ b/tutorial/erl/server.sh @@ -0,0 +1,37 @@ +#!/bin/sh + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +ERL_THRIFT=../../lib/erl + +if ! [ -d ${ERL_THRIFT}/ebin ]; then + echo "Please build the Thrift library by running \`make' in ${ERL_THRIFT}" + exit 1 +fi + +if ! [ -d ../gen-erl ]; then + echo "Please run thrift first to generate ../gen-erl/" + exit 1 +fi + + +erlc -I ${ERL_THRIFT}/include -I ../gen-erl -o ../gen-erl ../gen-erl/*.erl && + erlc -I ${ERL_THRIFT}/include -I ../gen-erl *.erl && + erl +K true -pa ${ERL_THRIFT}/ebin -pa ../gen-erl diff --git a/tutorial/java/JavaClient b/tutorial/java/JavaClient new file mode 100755 index 00000000..68d87b81 --- /dev/null +++ b/tutorial/java/JavaClient @@ -0,0 +1,22 @@ +#!/bin/sh + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +java -cp tutorial.jar:/usr/local/lib/libthrift.jar JavaClient diff --git a/tutorial/java/JavaServer b/tutorial/java/JavaServer new file mode 100755 index 00000000..89616001 --- /dev/null +++ b/tutorial/java/JavaServer @@ -0,0 +1,22 @@ +#!/bin/sh + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +java -cp tutorial.jar:/usr/local/lib/libthrift.jar JavaServer diff --git a/tutorial/java/build.xml b/tutorial/java/build.xml new file mode 100644 index 00000000..0ec1ea40 --- /dev/null +++ b/tutorial/java/build.xml @@ -0,0 +1,47 @@ + + + + Thrift Tutorial + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tutorial/java/src/JavaClient.java b/tutorial/java/src/JavaClient.java new file mode 100644 index 00000000..5dc70ed5 --- /dev/null +++ b/tutorial/java/src/JavaClient.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// Generated code +import tutorial.*; +import shared.*; + +import org.apache.thrift.TException; +import org.apache.thrift.transport.TTransport; +import org.apache.thrift.transport.TSocket; +import org.apache.thrift.transport.TTransportException; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TProtocol; + +import java.util.AbstractMap; +import java.util.HashMap; +import java.util.HashSet; +import java.util.ArrayList; + +public class JavaClient { + public static void main(String [] args) { + try { + + TTransport transport = new TSocket("localhost", 9090); + TProtocol protocol = new TBinaryProtocol(transport); + Calculator.Client client = new Calculator.Client(protocol); + + transport.open(); + + client.ping(); + System.out.println("ping()"); + + int sum = client.add(1,1); + System.out.println("1+1=" + sum); + + Work work = new Work(); + + work.op = Operation.DIVIDE; + work.num1 = 1; + work.num2 = 0; + try { + int quotient = client.calculate(1, work); + System.out.println("Whoa we can divide by 0"); + } catch (InvalidOperation io) { + System.out.println("Invalid operation: " + io.why); + } + + work.op = Operation.SUBTRACT; + work.num1 = 15; + work.num2 = 10; + try { + int diff = client.calculate(1, work); + System.out.println("15-10=" + diff); + } catch (InvalidOperation io) { + System.out.println("Invalid operation: " + io.why); + } + + SharedStruct log = client.getStruct(1); + System.out.println("Check log: " + log.value); + + transport.close(); + + } catch (TException x) { + x.printStackTrace(); + } + + } + +} diff --git a/tutorial/java/src/JavaServer.java b/tutorial/java/src/JavaServer.java new file mode 100644 index 00000000..14440eb7 --- /dev/null +++ b/tutorial/java/src/JavaServer.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import org.apache.thrift.TException; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TProtocol; +import org.apache.thrift.server.TServer; +import org.apache.thrift.server.TSimpleServer; +import org.apache.thrift.transport.TServerSocket; +import org.apache.thrift.transport.TServerTransport; + +// Generated code +import tutorial.*; +import shared.*; + +import java.util.HashMap; + +public class JavaServer { + + public static class CalculatorHandler implements Calculator.Iface { + + private HashMap log; + + public CalculatorHandler() { + log = new HashMap(); + } + + public void ping() { + System.out.println("ping()"); + } + + public int add(int n1, int n2) { + System.out.println("add(" + n1 + "," + n2 + ")"); + return n1 + n2; + } + + public int calculate(int logid, Work work) throws InvalidOperation { + System.out.println("calculate(" + logid + ", {" + work.op + "," + work.num1 + "," + work.num2 + "})"); + int val = 0; + switch (work.op) { + case Operation.ADD: + val = work.num1 + work.num2; + break; + case Operation.SUBTRACT: + val = work.num1 - work.num2; + break; + case Operation.MULTIPLY: + val = work.num1 * work.num2; + break; + case Operation.DIVIDE: + if (work.num2 == 0) { + InvalidOperation io = new InvalidOperation(); + io.what = work.op; + io.why = "Cannot divide by 0"; + throw io; + } + val = work.num1 / work.num2; + break; + default: + InvalidOperation io = new InvalidOperation(); + io.what = work.op; + io.why = "Unknown operation"; + throw io; + } + + SharedStruct entry = new SharedStruct(); + entry.key = logid; + entry.value = Integer.toString(val); + log.put(logid, entry); + + return val; + } + + public SharedStruct getStruct(int key) { + System.out.println("getStruct(" + key + ")"); + return log.get(key); + } + + public void zip() { + System.out.println("zip()"); + } + + } + + public static void main(String [] args) { + try { + CalculatorHandler handler = new CalculatorHandler(); + Calculator.Processor processor = new Calculator.Processor(handler); + TServerTransport serverTransport = new TServerSocket(9090); + TServer server = new TSimpleServer(processor, serverTransport); + + // Use this for a multithreaded server + // server = new TThreadPoolServer(processor, serverTransport); + + System.out.println("Starting the server..."); + server.serve(); + + } catch (Exception x) { + x.printStackTrace(); + } + System.out.println("done."); + } +} diff --git a/tutorial/perl/PerlClient.pl b/tutorial/perl/PerlClient.pl new file mode 100644 index 00000000..1d596568 --- /dev/null +++ b/tutorial/perl/PerlClient.pl @@ -0,0 +1,82 @@ +#!/usr/bin/env perl + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +use strict; +use warnings; + +use lib '../../lib/perl/lib'; +use lib '../gen-perl'; + +use Thrift; +use Thrift::BinaryProtocol; +use Thrift::Socket; +use Thrift::BufferedTransport; + +use shared::SharedService; +use tutorial::Calculator; +use shared::Types; +use tutorial::Types; + +use Data::Dumper; + +my $socket = new Thrift::Socket('localhost',9090); +my $transport = new Thrift::BufferedTransport($socket,1024,1024); +my $protocol = new Thrift::BinaryProtocol($transport); +my $client = new tutorial::CalculatorClient($protocol); + + +eval{ + $transport->open(); + + $client->ping(); + print "ping()\n"; + + + my $sum = $client->add(1,1); + print "1+1=$sum\n"; + + my $work = new tutorial::Work(); + + $work->op(tutorial::Operation::DIVIDE); + $work->num1(1); + $work->num2(0); + + eval { + $client->calculate(1, $work); + print "Whoa! We can divide by zero?\n"; + }; if($@) { + warn "InvalidOperation: ".Dumper($@); + } + + $work->op(tutorial::Operation::SUBTRACT); + $work->num1(15); + $work->num2(10); + my $diff = $client->calculate(1, $work); + print "15-10=$diff\n"; + + my $log = $client->getStruct(1); + print "Log: $log->{value}\n"; + + $transport->close(); + +}; if($@){ + warn(Dumper($@)); +} diff --git a/tutorial/php/PhpClient.php b/tutorial/php/PhpClient.php new file mode 100755 index 00000000..c5c08101 --- /dev/null +++ b/tutorial/php/PhpClient.php @@ -0,0 +1,92 @@ +#!/usr/bin/env php +open(); + + $client->ping(); + print "ping()\n"; + + $sum = $client->add(1,1); + print "1+1=$sum\n"; + + $work = new tutorial_Work(); + + $work->op = tutorial_Operation::DIVIDE; + $work->num1 = 1; + $work->num2 = 0; + + try { + $client->calculate(1, $work); + print "Whoa! We can divide by zero?\n"; + } catch (tutorial_InvalidOperation $io) { + print "InvalidOperation: $io->why\n"; + } + + $work->op = tutorial_Operation::SUBTRACT; + $work->num1 = 15; + $work->num2 = 10; + $diff = $client->calculate(1, $work); + print "15-10=$diff\n"; + + $log = $client->getStruct(1); + print "Log: $log->value\n"; + + $transport->close(); + +} catch (TException $tx) { + print 'TException: '.$tx->getMessage()."\n"; +} + +?> diff --git a/tutorial/php/PhpServer.php b/tutorial/php/PhpServer.php new file mode 100755 index 00000000..9482c649 --- /dev/null +++ b/tutorial/php/PhpServer.php @@ -0,0 +1,132 @@ +#!/usr/bin/env php +op}, {$w->num1}, {$w->num2}})"); + switch ($w->op) { + case tutorial_Operation::ADD: + $val = $w->num1 + $w->num2; + break; + case tutorial_Operation::SUBTRACT: + $val = $w->num1 - $w->num2; + break; + case tutorial_Operation::MULTIPLY: + $val = $w->num1 * $w->num2; + break; + case tutorial_Operation::DIVIDE: + if ($w->num2 == 0) { + $io = new tutorial_InvalidOperation(); + $io->what = $w->op; + $io->why = "Cannot divide by 0"; + throw $io; + } + $val = $w->num1 / $w->num2; + break; + default: + $io = new tutorial_InvalidOperation(); + $io->what = $w->op; + $io->why = "Invalid Operation"; + throw $io; + } + + $log = new SharedStruct(); + $log->key = $logid; + $log->value = (string)$val; + $this->log[$logid] = $log; + + return $val; + } + + public function getStruct($key) { + error_log("getStruct({$key})"); + // This actually doesn't work because the PHP interpreter is + // restarted for every request. + //return $this->log[$key]; + return new SharedStruct(array("key" => $key, "value" => "PHP is stateless!")); + } + + public function zip() { + error_log("zip()"); + } + +}; + +header('Content-Type', 'application/x-thrift'); +if (php_sapi_name() == 'cli') { + echo "\r\n"; +} + +$handler = new CalculatorHandler(); +$processor = new CalculatorProcessor($handler); + +$transport = new TBufferedTransport(new TPhpStream(TPhpStream::MODE_R | TPhpStream::MODE_W)); +$protocol = new TBinaryProtocol($transport, true, true); + +$transport->open(); +$processor->process($protocol, $protocol); +$transport->close(); diff --git a/tutorial/php/runserver.py b/tutorial/php/runserver.py new file mode 100755 index 00000000..ae29fed9 --- /dev/null +++ b/tutorial/php/runserver.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import os +import BaseHTTPServer +import CGIHTTPServer + +# chdir(2) into the tutorial directory. +os.chdir(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) + +class Handler(CGIHTTPServer.CGIHTTPRequestHandler): + cgi_directories = ['/php'] + +BaseHTTPServer.HTTPServer(('', 8080), Handler).serve_forever() diff --git a/tutorial/py/PythonClient.py b/tutorial/py/PythonClient.py new file mode 100755 index 00000000..916e9157 --- /dev/null +++ b/tutorial/py/PythonClient.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys +sys.path.append('../gen-py') + +from tutorial import Calculator +from tutorial.ttypes import * + +from thrift import Thrift +from thrift.transport import TSocket +from thrift.transport import TTransport +from thrift.protocol import TBinaryProtocol + +try: + + # Make socket + transport = TSocket.TSocket('localhost', 9090) + + # Buffering is critical. Raw sockets are very slow + transport = TTransport.TBufferedTransport(transport) + + # Wrap in a protocol + protocol = TBinaryProtocol.TBinaryProtocol(transport) + + # Create a client to use the protocol encoder + client = Calculator.Client(protocol) + + # Connect! + transport.open() + + client.ping() + print 'ping()' + + sum = client.add(1,1) + print '1+1=%d' % (sum) + + work = Work() + + work.op = Operation.DIVIDE + work.num1 = 1 + work.num2 = 0 + + try: + quotient = client.calculate(1, work) + print 'Whoa? You know how to divide by zero?' + except InvalidOperation, io: + print 'InvalidOperation: %r' % io + + work.op = Operation.SUBTRACT + work.num1 = 15 + work.num2 = 10 + + diff = client.calculate(1, work) + print '15-10=%d' % (diff) + + log = client.getStruct(1) + print 'Check log: %s' % (log.value) + + # Close! + transport.close() + +except Thrift.TException, tx: + print '%s' % (tx.message) diff --git a/tutorial/py/PythonServer.py b/tutorial/py/PythonServer.py new file mode 100755 index 00000000..63f993bc --- /dev/null +++ b/tutorial/py/PythonServer.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys +sys.path.append('../gen-py') + +from tutorial import Calculator +from tutorial.ttypes import * + +from shared.ttypes import SharedStruct + +from thrift.transport import TSocket +from thrift.transport import TTransport +from thrift.protocol import TBinaryProtocol +from thrift.server import TServer + +class CalculatorHandler: + def __init__(self): + self.log = {} + + def ping(self): + print 'ping()' + + def add(self, n1, n2): + print 'add(%d,%d)' % (n1, n2) + return n1+n2 + + def calculate(self, logid, work): + print 'calculate(%d, %r)' % (logid, work) + + if work.op == Operation.ADD: + val = work.num1 + work.num2 + elif work.op == Operation.SUBTRACT: + val = work.num1 - work.num2 + elif work.op == Operation.MULTIPLY: + val = work.num1 * work.num2 + elif work.op == Operation.DIVIDE: + if work.num2 == 0: + x = InvalidOperation() + x.what = work.op + x.why = 'Cannot divide by 0' + raise x + val = work.num1 / work.num2 + else: + x = InvalidOperation() + x.what = work.op + x.why = 'Invalid operation' + raise x + + log = SharedStruct() + log.key = logid + log.value = '%d' % (val) + self.log[logid] = log + + return val + + def getStruct(self, key): + print 'getStruct(%d)' % (key) + return self.log[key] + + def zip(self): + print 'zip()' + +handler = CalculatorHandler() +processor = Calculator.Processor(handler) +transport = TSocket.TServerSocket(9090) +tfactory = TTransport.TBufferedTransportFactory() +pfactory = TBinaryProtocol.TBinaryProtocolFactory() + +server = TServer.TSimpleServer(processor, transport, tfactory, pfactory) + +# You could do one of these for a multithreaded server +#server = TServer.TThreadedServer(processor, transport, tfactory, pfactory) +#server = TServer.TThreadPoolServer(processor, transport, tfactory, pfactory) + +print 'Starting the server...' +server.serve() +print 'done.' diff --git a/tutorial/rb/RubyClient.rb b/tutorial/rb/RubyClient.rb new file mode 100755 index 00000000..8971fed9 --- /dev/null +++ b/tutorial/rb/RubyClient.rb @@ -0,0 +1,75 @@ +#!/usr/bin/env ruby + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +$:.push('../gen-rb') +$:.unshift '../../lib/rb/lib' + +require 'thrift' + +require 'Calculator' + +begin + port = ARGV[0] || 9090 + + transport = Thrift::BufferedTransport.new(Thrift::Socket.new('localhost', port)) + protocol = Thrift::BinaryProtocol.new(transport) + client = Calculator::Client.new(protocol) + + transport.open() + + client.ping() + print "ping()\n" + + sum = client.add(1,1) + print "1+1=", sum, "\n" + + sum = client.add(1,4) + print "1+4=", sum, "\n" + + work = Work.new() + + work.op = Operation::SUBTRACT + work.num1 = 15 + work.num2 = 10 + diff = client.calculate(1, work) + print "15-10=", diff, "\n" + + log = client.getStruct(1) + print "Log: ", log.value, "\n" + + begin + work.op = Operation::DIVIDE + work.num1 = 1 + work.num2 = 0 + quot = client.calculate(1, work) + puts "Whoa, we can divide by 0 now?" + rescue InvalidOperation => io + print "InvalidOperation: ", io.why, "\n" + end + + client.zip() + print "zip\n" + + transport.close() + +rescue Thrift::Exception => tx + print 'Thrift::Exception: ', tx.message, "\n" +end diff --git a/tutorial/rb/RubyServer.rb b/tutorial/rb/RubyServer.rb new file mode 100755 index 00000000..89eb3738 --- /dev/null +++ b/tutorial/rb/RubyServer.rb @@ -0,0 +1,95 @@ +#!/usr/bin/env ruby + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +$:.push('../gen-rb') +$:.unshift '../../lib/rb/lib' + +require 'thrift' + +require 'Calculator' +require 'shared_types' + +class CalculatorHandler + def initialize() + @log = {} + end + + def ping() + puts "ping()" + end + + def add(n1, n2) + print "add(", n1, ",", n2, ")\n" + return n1 + n2 + end + + def calculate(logid, work) + print "calculate(", logid, ", {", work.op, ",", work.num1, ",", work.num2,"})\n" + if work.op == Operation::ADD + val = work.num1 + work.num2 + elsif work.op == Operation::SUBTRACT + val = work.num1 - work.num2 + elsif work.op == Operation::MULTIPLY + val = work.num1 * work.num2 + elsif work.op == Operation::DIVIDE + if work.num2 == 0 + x = InvalidOperation.new() + x.what = work.op + x.why = "Cannot divide by 0" + raise x + end + val = work.num1 / work.num2 + else + x = InvalidOperation.new() + x.what = work.op + x.why = "Invalid operation" + raise x + end + + entry = SharedStruct.new() + entry.key = logid + entry.value = "#{val}" + @log[logid] = entry + + return val + + end + + def getStruct(key) + print "getStruct(", key, ")\n" + return @log[key] + end + + def zip() + print "zip\n" + end + +end + +handler = CalculatorHandler.new() +processor = Calculator::Processor.new(handler) +transport = Thrift::ServerSocket.new(9090) +transportFactory = Thrift::BufferedTransportFactory.new() +server = Thrift::SimpleServer.new(processor, transport, transportFactory) + +puts "Starting the server..." +server.serve() +puts "done." diff --git a/tutorial/shared.thrift b/tutorial/shared.thrift new file mode 100644 index 00000000..475e7f80 --- /dev/null +++ b/tutorial/shared.thrift @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * This Thrift file can be included by other Thrift files that want to share + * these definitions. + */ + +namespace cpp shared +namespace java shared +namespace perl shared + +struct SharedStruct { + 1: i32 key + 2: string value +} + +service SharedService { + SharedStruct getStruct(1: i32 key) +} diff --git a/tutorial/tutorial.thrift b/tutorial/tutorial.thrift new file mode 100644 index 00000000..86e433dd --- /dev/null +++ b/tutorial/tutorial.thrift @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +# Thrift Tutorial +# Mark Slee (mcslee@facebook.com) +# +# This file aims to teach you how to use Thrift, in a .thrift file. Neato. The +# first thing to notice is that .thrift files support standard shell comments. +# This lets you make your thrift file executable and include your Thrift build +# step on the top line. And you can place comments like this anywhere you like. +# +# Before running this file, you will need to have installed the thrift compiler +# into /usr/local/bin. + +/** + * The first thing to know about are types. The available types in Thrift are: + * + * bool Boolean, one byte + * byte Signed byte + * i16 Signed 16-bit integer + * i32 Signed 32-bit integer + * i64 Signed 64-bit integer + * double 64-bit floating point value + * string String + * binary Blob (byte array) + * map Map from one type to another + * list Ordered list of one type + * set Set of unique elements of one type + * + * Did you also notice that Thrift supports C style comments? + */ + +// Just in case you were wondering... yes. We support simple C comments too. + +/** + * Thrift files can reference other Thrift files to include common struct + * and service definitions. These are found using the current path, or by + * searching relative to any paths specified with the -I compiler flag. + * + * Included objects are accessed using the name of the .thrift file as a + * prefix. i.e. shared.SharedObject + */ +include "shared.thrift" + +/** + * Thrift files can namespace, package, or prefix their output in various + * target languages. + */ +namespace cpp tutorial +namespace java tutorial +namespace php tutorial +namespace perl tutorial +namespace smalltalk.category Thrift.Tutorial + +/** + * Thrift lets you do typedefs to get pretty names for your types. Standard + * C style here. + */ +typedef i32 MyInteger + +/** + * Thrift also lets you define constants for use across languages. Complex + * types and structs are specified using JSON notation. + */ +const i32 INT32CONSTANT = 9853 +const map MAPCONSTANT = {'hello':'world', 'goodnight':'moon'} + +/** + * You can define enums, which are just 32 bit integers. Values are optional + * and start at 1 if not supplied, C style again. + */ +enum Operation { + ADD = 1, + SUBTRACT = 2, + MULTIPLY = 3, + DIVIDE = 4 +} + +/** + * Structs are the basic complex data structures. They are comprised of fields + * which each have an integer identifier, a type, a symbolic name, and an + * optional default value. + * + * Fields can be declared "optional", which ensures they will not be included + * in the serialized output if they aren't set. Note that this requires some + * manual management in some languages. + */ +struct Work { + 1: i32 num1 = 0, + 2: i32 num2, + 3: Operation op, + 4: optional string comment, +} + +/** + * Structs can also be exceptions, if they are nasty. + */ +exception InvalidOperation { + 1: i32 what, + 2: string why +} + +/** + * Ahh, now onto the cool part, defining a service. Services just need a name + * and can optionally inherit from another service using the extends keyword. + */ +service Calculator extends shared.SharedService { + + /** + * A method definition looks like C code. It has a return type, arguments, + * and optionally a list of exceptions that it may throw. Note that argument + * lists and exception lists are specified using the exact same syntax as + * field lists in struct or exception definitions. + */ + + void ping(), + + i32 add(1:i32 num1, 2:i32 num2), + + i32 calculate(1:i32 logid, 2:Work w) throws (1:InvalidOperation ouch), + + /** + * This method has a oneway modifier. That means the client only makes + * a request and does not listen for any response at all. Oneway methods + * must be void. + */ + oneway void zip() + +} + +/** + * That just about covers the basics. Take a look in the test/ folder for more + * detailed examples. After you run this file, your generated code shows up + * in folders with names gen-. The generated code isn't too scary + * to look at. It even has pretty indentation. + */