From: Roger Meier Date: Wed, 26 Dec 2012 21:09:55 +0000 (+0100) Subject: THRIFT-1797 Python implementation of TSimpleJSONProtocol X-Git-Tag: 0.9.1~216 X-Git-Url: https://source.supwisdom.com/gerrit/gitweb?a=commitdiff_plain;h=0895dfe6c0f74f77cea1ed5c4e28ac0c0b27d525;p=common%2Fthrift.git THRIFT-1797 Python implementation of TSimpleJSONProtocol Patch: Avi Flamholz --- diff --git a/.gitignore b/.gitignore old mode 100755 new mode 100644 index 9aa31aa4..12a6b0b5 --- a/.gitignore +++ b/.gitignore @@ -245,3 +245,5 @@ gen-* /tutorial/py/Makefile /tutorial/py/Makefile.in /ylwrap +.project +.pydevproject diff --git a/lib/py/src/protocol/TJSONProtocol.py b/lib/py/src/protocol/TJSONProtocol.py index 5fb3ec77..3048197d 100644 --- a/lib/py/src/protocol/TJSONProtocol.py +++ b/lib/py/src/protocol/TJSONProtocol.py @@ -17,10 +17,15 @@ # under the License. # -from TProtocol import * -import json, base64, sys +from TProtocol import TType, TProtocolBase, TProtocolException +import base64 +import json +import math -__all__ = ['TJSONProtocol', 'TJSONProtocolFactory'] +__all__ = ['TJSONProtocol', + 'TJSONProtocolFactory', + 'TSimpleJSONProtocol', + 'TSimpleJSONProtocolFactory'] VERSION = 1 @@ -74,6 +79,9 @@ class JSONBaseContext(object): def escapeNum(self): return False + def __str__(self): + return self.__class__.__name__ + class JSONListContext(JSONBaseContext): @@ -91,14 +99,17 @@ class JSONListContext(JSONBaseContext): class JSONPairContext(JSONBaseContext): - colon = True + + def __init__(self, protocol): + super(JSONPairContext, self).__init__(protocol) + self.colon = True def doIO(self, function): - if self.first is True: + if self.first: self.first = False self.colon = True else: - function(COLON if self.colon == True else COMMA) + function(COLON if self.colon else COMMA) self.colon = not self.colon def write(self): @@ -110,6 +121,9 @@ class JSONPairContext(JSONBaseContext): def escapeNum(self): return self.colon + def __str__(self): + return '%s, colon=%s' % (self.__class__.__name__, self.colon) + class LookaheadReader(): hasData = False @@ -139,8 +153,8 @@ class TJSONProtocolBase(TProtocolBase): self.resetReadContext() def resetWriteContext(self): - self.contextStack = [] - self.context = JSONBaseContext(self) + self.context = JSONBaseContext(self) + self.contextStack = [self.context] def resetReadContext(self): self.resetWriteContext() @@ -152,6 +166,10 @@ class TJSONProtocolBase(TProtocolBase): def popContext(self): self.contextStack.pop() + if self.contextStack: + self.context = self.contextStack[-1] + else: + self.context = JSONBaseContext(self) def writeJSONString(self, string): self.context.write() @@ -210,7 +228,7 @@ class TJSONProtocolBase(TProtocolBase): self.readJSONSyntaxChar(ZERO) character = json.JSONDecoder().decode('"\u00%s"' % self.trans.read(2)) else: - off = ESCAPE_CHAR.find(char) + off = ESCAPE_CHAR.find(character) if off == -1: raise TProtocolException(TProtocolException.INVALID_DATA, "Expected control char") @@ -251,7 +269,9 @@ class TJSONProtocolBase(TProtocolBase): string = self.readJSONString(True) try: double = float(string) - if self.context.escapeNum is False and double != inf and double != nan: + if (self.context.escapeNum is False and + not math.isinf(double) and + not math.isnan(double)): raise TProtocolException(TProtocolException.INVALID_DATA, "Numeric data unexpectedly quoted") return double @@ -445,9 +465,86 @@ class TJSONProtocol(TJSONProtocolBase): def writeBinary(self, binary): self.writeJSONBase64(binary) + class TJSONProtocolFactory: - def __init__(self): - pass def getProtocol(self, trans): return TJSONProtocol(trans) + + +class TSimpleJSONProtocol(TJSONProtocolBase): + """Simple, readable, write-only JSON protocol. + + Useful for interacting with scripting languages. + """ + + def readMessageBegin(self): + raise NotImplementedError() + + def readMessageEnd(self): + raise NotImplementedError() + + def readStructBegin(self): + raise NotImplementedError() + + def readStructEnd(self): + raise NotImplementedError() + + def writeMessageBegin(self, name, request_type, seqid): + self.resetWriteContext() + + def writeMessageEnd(self): + pass + + def writeStructBegin(self, name): + self.writeJSONObjectStart() + + def writeStructEnd(self): + self.writeJSONObjectEnd() + + def writeFieldBegin(self, name, ttype, fid): + self.writeJSONString(name) + + def writeFieldEnd(self): + pass + + def writeMapBegin(self, ktype, vtype, size): + self.writeJSONObjectStart() + + def writeMapEnd(self): + self.writeJSONObjectEnd() + + def _writeCollectionBegin(self, etype, size): + self.writeJSONArrayStart() + + def _writeCollectionEnd(self): + self.writeJSONArrayEnd() + writeListBegin = _writeCollectionBegin + writeListEnd = _writeCollectionEnd + writeSetBegin = _writeCollectionBegin + writeSetEnd = _writeCollectionEnd + + def writeInteger(self, integer): + self.writeJSONNumber(integer) + writeByte = writeInteger + writeI16 = writeInteger + writeI32 = writeInteger + writeI64 = writeInteger + + def writeBool(self, boolean): + self.writeJSONNumber(1 if boolean is True else 0) + + def writeDouble(self, dbl): + self.writeJSONNumber(dbl) + + def writeString(self, string): + self.writeJSONString(string) + + def writeBinary(self, binary): + self.writeJSONBase64(binary) + + +class TSimpleJSONProtocolFactory(object): + + def getProtocol(self, trans): + return TSimpleJSONProtocol(trans) diff --git a/test/py/RunClientServer.py b/test/py/RunClientServer.py index f9121c8b..db0bfa46 100755 --- a/test/py/RunClientServer.py +++ b/test/py/RunClientServer.py @@ -46,7 +46,11 @@ generated_dirs = [] for gp_dir in options.genpydirs.split(','): generated_dirs.append('gen-py-%s' % (gp_dir)) -SCRIPTS = ['SerializationTest.py', 'TestEof.py', 'TestSyntax.py', 'TestSocket.py'] +SCRIPTS = ['TSimpleJSONProtocolTest.py', + 'SerializationTest.py', + 'TestEof.py', + 'TestSyntax.py', + 'TestSocket.py'] FRAMED = ["TNonblockingServer"] SKIP_ZLIB = ['TNonblockingServer', 'THttpServer'] SKIP_SSL = ['TNonblockingServer', 'THttpServer'] diff --git a/test/py/TSimpleJSONProtocolTest.py b/test/py/TSimpleJSONProtocolTest.py new file mode 100644 index 00000000..080293a0 --- /dev/null +++ b/test/py/TSimpleJSONProtocolTest.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import sys +import glob +from optparse import OptionParser +parser = OptionParser() +parser.add_option('--genpydir', type='string', dest='genpydir', default='gen-py') +options, args = parser.parse_args() +del sys.argv[1:] # clean up hack so unittest doesn't complain +sys.path.insert(0, options.genpydir) +sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0]) + +from ThriftTest.ttypes import * +from thrift.protocol import TJSONProtocol +from thrift.transport import TTransport + +import json +import unittest + + +class SimpleJSONProtocolTest(unittest.TestCase): + protocol_factory = TJSONProtocol.TSimpleJSONProtocolFactory() + + def _assertDictEqual(self, a ,b, msg=None): + if hasattr(self, 'assertDictEqual'): + # assertDictEqual only in Python 2.7. Depends on your machine. + self.assertDictEqual(a, b, msg) + return + + # Substitute implementation not as good as unittest library's + self.assertEquals(len(a), len(b), msg) + for k, v in a.iteritems(): + self.assertTrue(k in b, msg) + self.assertEquals(b.get(k), v, msg) + + def _serialize(self, obj): + trans = TTransport.TMemoryBuffer() + prot = self.protocol_factory.getProtocol(trans) + obj.write(prot) + return trans.getvalue() + + def _deserialize(self, objtype, data): + prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data)) + ret = objtype() + ret.read(prot) + return ret + + def testWriteOnly(self): + self.assertRaises(NotImplementedError, + self._deserialize, VersioningTestV1, '{}') + + def testSimpleMessage(self): + v1obj = VersioningTestV1( + begin_in_both=12345, + old_string='aaa', + end_in_both=54321) + expected = dict(begin_in_both=v1obj.begin_in_both, + old_string=v1obj.old_string, + end_in_both=v1obj.end_in_both) + actual = json.loads(self._serialize(v1obj)) + + self._assertDictEqual(expected, actual) + + def testComplicated(self): + v2obj = VersioningTestV2( + begin_in_both=12345, + newint=1, + newbyte=2, + newshort=3, + newlong=4, + newdouble=5.0, + newstruct=Bonk(message="Hello!", type=123), + newlist=[7,8,9], + newset=set([42,1,8]), + newmap={1:2,2:3}, + newstring="Hola!", + end_in_both=54321) + expected = dict(begin_in_both=v2obj.begin_in_both, + newint=v2obj.newint, + newbyte=v2obj.newbyte, + newshort=v2obj.newshort, + newlong=v2obj.newlong, + newdouble=v2obj.newdouble, + newstruct=dict(message=v2obj.newstruct.message, + type=v2obj.newstruct.type), + newlist=v2obj.newlist, + newset=list(v2obj.newset), + newmap=v2obj.newmap, + newstring=v2obj.newstring, + end_in_both=v2obj.end_in_both) + + # Need to load/dump because map keys get escaped. + expected = json.loads(json.dumps(expected)) + actual = json.loads(self._serialize(v2obj)) + self._assertDictEqual(expected, actual) + + +if __name__ == '__main__': + unittest.main() +