| David Reiss | abafd79 | 2010-09-27 17:28:15 +0000 | [diff] [blame] | 1 | from TProtocol import * | 
|  | 2 | from struct import pack, unpack | 
|  | 3 |  | 
# Public API exported by `from ... import *`.
__all__ = 'TCompactProtocol TCompactProtocolFactory'.split()
|  | 5 |  | 
# States of TCompactProtocol's internal state machine, numbered 0..8.
# CLEAR: outside any message/struct.  *_WRITE / *_READ: expecting a field
# header, a field value, a container element, or a deferred bool value.
(CLEAR,
 FIELD_WRITE, VALUE_WRITE, CONTAINER_WRITE, BOOL_WRITE,
 FIELD_READ, CONTAINER_READ, VALUE_READ, BOOL_READ) = range(9)
|  | 15 |  | 
def make_helper(v_from, container):
    """Return a decorator that checks ``self.state`` before calling a method.

    The decorated method may only run while ``self.state`` equals ``v_from``
    or ``container`` (e.g. VALUE_WRITE and CONTAINER_WRITE).  Any other
    state fails an ``assert`` whose message is
    ``(self.state, v_from, container)``.
    """
    import functools  # local import keeps this block self-contained

    def helper(func):
        # functools.wraps preserves the wrapped method's __name__ and
        # __doc__, so tracebacks and help() show e.g. "writeI32" rather
        # than "nested" for every decorated method.
        @functools.wraps(func)
        def nested(self, *args, **kwargs):
            assert self.state in (v_from, container), (self.state, v_from, container)
            return func(self, *args, **kwargs)
        return nested
    return helper
# Decorators for the primitive value methods: a write is legal while a field
# value or a container element is expected, and likewise for a read.
writer, reader = (make_helper(VALUE_WRITE, CONTAINER_WRITE),
                  make_helper(VALUE_READ, CONTAINER_READ))
|  | 25 |  | 
def makeZigZag(n, bits):
    """ZigZag-encode the signed integer ``n`` of width ``bits``.

    Maps 0, -1, 1, -2, ... to 0, 1, 2, 3, ... so small magnitudes of either
    sign become small non-negative numbers.
    """
    sign_fill = n >> (bits - 1)   # 0 for non-negative n, -1 for negative n
    doubled = n << 1
    return doubled ^ sign_fill
|  | 28 |  | 
def fromZigZag(n):
    """Invert makeZigZag: map 0, 1, 2, 3, ... back to 0, -1, 1, -2, ..."""
    half = n >> 1
    if n & 1:
        # Odd codes are negatives: half ^ -1 is the same as ~half.
        return ~half
    return half
|  | 31 |  | 
def writeVarint(trans, n):
    """Write the non-negative integer ``n`` to ``trans`` as a base-128 varint.

    Seven bits are emitted per byte, least-significant group first, with
    the high bit set on every byte except the last.  The bytes go out in a
    single ``trans.write`` call.

    Raises:
        ValueError: if ``n`` is negative.  Right-shifting a negative int
            never reaches zero, so the encoding loop would otherwise never
            terminate.
    """
    if n < 0:
        raise ValueError('writeVarint requires a non-negative integer, got %r' % (n,))
    out = []
    while n > 0x7f:
        out.append((n & 0x7f) | 0x80)   # low 7 bits + continuation flag
        n >>= 7
    out.append(n)                        # final group, high bit clear
    trans.write(''.join(map(chr, out)))
|  | 42 |  | 
def readVarint(trans):
    """Read one base-128 varint from ``trans`` and return it as an int.

    Bytes are consumed one at a time via ``trans.readAll(1)``.  Each byte
    contributes its low 7 bits, least-significant group first, until a byte
    with the high bit clear ends the value.
    """
    value = 0
    bits_consumed = 0
    keep_going = True
    while keep_going:
        octet = ord(trans.readAll(1))
        value |= (octet & 0x7f) << bits_consumed
        keep_going = (octet >> 7) != 0
        bits_consumed += 7
    return value
|  | 53 |  | 
class CompactType:
    """Four-bit type codes used on the wire by the compact protocol.

    The codes run consecutively from STOP = 0x00 to STRUCT = 0x0C.  TRUE and
    FALSE double as the type nibble of a bool field, carrying its value.
    """
    (STOP, TRUE, FALSE, BYTE, I16, I32, I64,
     DOUBLE, BINARY, LIST, SET, MAP, STRUCT) = range(13)
|  | 68 |  | 
# Generic TType -> compact wire type code.
CTYPES = {
    TType.STOP: CompactType.STOP,
    TType.BOOL: CompactType.TRUE,  # used as the element type of bool collections
    TType.BYTE: CompactType.BYTE,
    TType.I16: CompactType.I16,
    TType.I32: CompactType.I32,
    TType.I64: CompactType.I64,
    TType.DOUBLE: CompactType.DOUBLE,
    TType.STRING: CompactType.BINARY,
    TType.STRUCT: CompactType.STRUCT,
    TType.LIST: CompactType.LIST,
    TType.SET: CompactType.SET,
    TType.MAP: CompactType.MAP,
}

# Compact wire type code -> generic TType: the inverse of CTYPES, plus
# FALSE, because a bool field header may carry either TRUE or FALSE.
TTYPES = dict((ctype, ttype) for ttype, ctype in CTYPES.items())
TTYPES[CompactType.FALSE] = TType.BOOL
|  | 89 |  | 
class TCompactProtocol(TProtocolBase):
    """Compact implementation of the Thrift protocol driver.

    Integers are zigzag + varint encoded.  A field id is written as a 4-bit
    delta from the previous field id of the same struct when the delta is
    1..15.  A bool field's value is carried in its field header's type
    nibble.  Public methods assert the expected state of a small state
    machine (the module-level CLEAR / FIELD_* / VALUE_* / CONTAINER_* /
    BOOL_* constants) to catch calls made out of order.
    """

    PROTOCOL_ID = 0x82
    VERSION = 1
    VERSION_MASK = 0x1f      # low 5 bits of the version/type byte
    TYPE_MASK = 0xe0         # high 3 bits hold the message type
    TYPE_SHIFT_AMOUNT = 5

    def __init__(self, trans):
        TProtocolBase.__init__(self, trans)
        self.state = CLEAR
        # Id of the last field header written/read in the current struct.
        self.__last_fid = 0
        # Id of a bool field whose header is deferred until writeBool().
        self.__bool_fid = None
        # Value taken from the current bool field's header by readFieldBegin().
        self.__bool_value = None
        # (state, last_fid) pairs saved on struct entry, restored on exit.
        self.__structs = []
        # States saved on container entry, restored on exit.
        self.__containers = []

    def __writeVarint(self, n):
        writeVarint(self.trans, n)

    def writeMessageBegin(self, name, type, seqid):
        """Write protocol id, version|type byte, seqid varint and name."""
        assert self.state == CLEAR
        self.__writeUByte(self.PROTOCOL_ID)
        self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
        # The seqid is a signed i32 but is sent as an unsigned varint, and
        # a negative value would never terminate writeVarint's loop.  Send
        # its 32-bit two's complement; readMessageBegin maps it back.
        if seqid < 0:
            seqid += 1 << 32
        self.__writeVarint(seqid)
        self.__writeString(name)
        self.state = VALUE_WRITE

    def writeMessageEnd(self):
        assert self.state == VALUE_WRITE
        self.state = CLEAR

    def writeStructBegin(self, name):
        """Enter a struct; field-id deltas restart from 0 inside it."""
        assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_WRITE
        self.__last_fid = 0

    def writeStructEnd(self):
        assert self.state == FIELD_WRITE
        self.state, self.__last_fid = self.__structs.pop()

    def writeFieldStop(self):
        self.__writeByte(0)

    def __writeFieldHeader(self, type, fid):
        """Write the header for a field of compact type ``type`` and id ``fid``.

        Uses the one-byte form ``delta << 4 | type`` when ``fid`` is 1..15
        above the previous field id.  Otherwise writes the type byte
        followed by the zigzag-encoded id.
        """
        delta = fid - self.__last_fid
        if 0 < delta <= 15:
            self.__writeUByte(delta << 4 | type)
        else:
            self.__writeByte(type)
            self.__writeI16(fid)
        self.__last_fid = fid

    def writeFieldBegin(self, name, type, fid):
        assert self.state == FIELD_WRITE, self.state
        if type == TType.BOOL:
            # The header's type nibble encodes the value, which only
            # arrives with writeBool(); defer writing it until then.
            self.state = BOOL_WRITE
            self.__bool_fid = fid
        else:
            self.state = VALUE_WRITE
            self.__writeFieldHeader(CTYPES[type], fid)

    def writeFieldEnd(self):
        assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state
        self.state = FIELD_WRITE

    def __writeUByte(self, byte):
        self.trans.write(pack('!B', byte))

    def __writeByte(self, byte):
        self.trans.write(pack('!b', byte))

    def __writeI16(self, i16):
        self.__writeVarint(makeZigZag(i16, 16))

    def __writeSize(self, i32):
        self.__writeVarint(i32)

    def writeCollectionBegin(self, etype, size):
        """Begin a list/set; sizes up to 14 share one byte with the element type."""
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size <= 14:
            self.__writeUByte(size << 4 | CTYPES[etype])
        else:
            # Size nibble 0xf: the real size follows as a varint.
            self.__writeUByte(0xf0 | CTYPES[etype])
            self.__writeSize(size)
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE
    writeSetBegin = writeCollectionBegin
    writeListBegin = writeCollectionBegin

    def writeMapBegin(self, ktype, vtype, size):
        """Begin a map: a lone 0 byte if empty, else varint size + types byte."""
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size == 0:
            self.__writeByte(0)
        else:
            self.__writeSize(size)
            self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype])
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE

    def writeCollectionEnd(self):
        assert self.state == CONTAINER_WRITE, self.state
        self.state = self.__containers.pop()
    writeMapEnd = writeCollectionEnd
    writeSetEnd = writeCollectionEnd
    writeListEnd = writeCollectionEnd

    def writeBool(self, bool):
        """Write a bool: folded into a deferred field header, or one element byte."""
        if self.state == BOOL_WRITE:
            if bool:
                ctype = CompactType.TRUE
            else:
                ctype = CompactType.FALSE
            self.__writeFieldHeader(ctype, self.__bool_fid)
        elif self.state == CONTAINER_WRITE:
            # Container elements carry the value as a full byte: 1 or 0.
            self.__writeByte(int(bool))
        else:
            raise AssertionError("Invalid state in compact protocol")

    writeByte = writer(__writeByte)
    writeI16 = writer(__writeI16)

    @writer
    def writeI32(self, i32):
        self.__writeVarint(makeZigZag(i32, 32))

    @writer
    def writeI64(self, i64):
        self.__writeVarint(makeZigZag(i64, 64))

    @writer
    def writeDouble(self, dub):
        self.trans.write(pack('!d', dub))

    def __writeString(self, s):
        self.__writeSize(len(s))
        self.trans.write(s)
    writeString = writer(__writeString)

    def readFieldBegin(self):
        """Read a field header; return (None, ttype, fid).

        Returns (None, 0, 0) at the STOP byte.  For a bool field the value
        is taken from the header and later returned by readBool().
        """
        assert self.state == FIELD_READ, self.state
        type = self.__readUByte()
        if type & 0x0f == CompactType.STOP:
            return (None, 0, 0)
        delta = type >> 4
        if delta == 0:
            # Long form: the full field id follows as a zigzag varint.
            fid = self.__readI16()
        else:
            fid = self.__last_fid + delta
        self.__last_fid = fid
        type = type & 0x0f
        if type == CompactType.TRUE:
            self.state = BOOL_READ
            self.__bool_value = True
        elif type == CompactType.FALSE:
            self.state = BOOL_READ
            self.__bool_value = False
        else:
            self.state = VALUE_READ
        return (None, self.__getTType(type), fid)

    def readFieldEnd(self):
        assert self.state in (VALUE_READ, BOOL_READ), self.state
        self.state = FIELD_READ

    def __readUByte(self):
        result, = unpack('!B', self.trans.readAll(1))
        return result

    def __readByte(self):
        result, = unpack('!b', self.trans.readAll(1))
        return result

    def __readVarint(self):
        return readVarint(self.trans)

    def __readZigZag(self):
        return fromZigZag(self.__readVarint())

    def __readSize(self):
        result = self.__readVarint()
        if result < 0:
            raise TException("Length < 0")
        return result

    def readMessageBegin(self):
        """Read the message header and return (name, type, seqid).

        The state is left at CLEAR; readStructBegin accepts CLEAR and
        readMessageEnd asserts it.
        """
        assert self.state == CLEAR
        proto_id = self.__readUByte()
        if proto_id != self.PROTOCOL_ID:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     'Bad protocol id in the message: %d' % proto_id)
        ver_type = self.__readUByte()
        type = (ver_type & self.TYPE_MASK) >> self.TYPE_SHIFT_AMOUNT
        version = ver_type & self.VERSION_MASK
        if version != self.VERSION:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     'Bad version: %d (expect %d)' % (version, self.VERSION))
        seqid = self.__readVarint()
        # The varint is unsigned; restore the sign of a negative i32 seqid
        # sent as its 32-bit two's complement.
        if seqid > 0x7fffffff:
            seqid -= 1 << 32
        name = self.__readString()
        return (name, type, seqid)

    def readMessageEnd(self):
        assert self.state == CLEAR
        assert len(self.__structs) == 0

    def readStructBegin(self):
        assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_READ
        self.__last_fid = 0

    def readStructEnd(self):
        assert self.state == FIELD_READ
        self.state, self.__last_fid = self.__structs.pop()

    def readCollectionBegin(self):
        """Begin a list/set; return (element ttype, size)."""
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size_type = self.__readUByte()
        size = size_type >> 4
        type = self.__getTType(size_type)
        if size == 15:
            # Size nibble 0xf: the real size follows as a varint.
            size = self.__readSize()
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return type, size
    readSetBegin = readCollectionBegin
    readListBegin = readCollectionBegin

    def readMapBegin(self):
        """Begin a map; return (key ttype, value ttype, size)."""
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size = self.__readSize()
        types = 0
        if size > 0:
            # An empty map has no types byte on the wire.
            types = self.__readUByte()
        vtype = self.__getTType(types)
        ktype = self.__getTType(types >> 4)
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return (ktype, vtype, size)

    def readCollectionEnd(self):
        assert self.state == CONTAINER_READ, self.state
        self.state = self.__containers.pop()
    readSetEnd = readCollectionEnd
    readListEnd = readCollectionEnd
    readMapEnd = readCollectionEnd

    def readBool(self):
        """Return a bool from the current field header or the next element byte."""
        if self.state == BOOL_READ:
            return self.__bool_value
        elif self.state == CONTAINER_READ:
            # Only CompactType.TRUE (1) means true.  writeBool above emits
            # 0 for false, but other writers may emit CompactType.FALSE (2);
            # bool() would wrongly decode that as True.
            return self.__readByte() == CompactType.TRUE
        else:
            raise AssertionError("Invalid state in compact protocol: %d" % self.state)

    readByte = reader(__readByte)
    __readI16 = __readZigZag
    readI16 = reader(__readZigZag)
    readI32 = reader(__readZigZag)
    readI64 = reader(__readZigZag)

    @reader
    def readDouble(self):
        buff = self.trans.readAll(8)
        val, = unpack('!d', buff)
        return val

    def __readString(self):
        size = self.__readSize()
        return self.trans.readAll(size)
    readString = reader(__readString)

    def __getTType(self, byte):
        # Only the low nibble holds the compact type code.
        return TTYPES[byte & 0x0f]
|  | 366 |  | 
|  | 367 |  | 
class TCompactProtocolFactory:
    """Stateless factory that wraps a transport in a TCompactProtocol."""

    def __init__(self):
        """Take no configuration; there is nothing to store."""

    def getProtocol(self, trans):
        """Return a new TCompactProtocol bound to ``trans``."""
        protocol = TCompactProtocol(trans)
        return protocol