From ead3382915e3e50845a2c6d0acdd75dc00dd3be3 Mon Sep 17 00:00:00 2001 From: Kevin Clark Date: Wed, 4 Feb 2009 22:43:59 +0000 Subject: [PATCH] THRIFT-254. rb: Add optional strict version support to binary protocols Author: Michael Stockton git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@740930 13f79535-47bb-0310-9956-ffa450edef68 --- lib/rb/ext/binary_protocol_accelerated.c | 43 +++++++++++++++---- lib/rb/ext/constants.h | 2 + lib/rb/ext/thrift_native.c | 4 ++ lib/rb/lib/thrift/protocol/binaryprotocol.rb | 45 +++++++++++++++----- lib/rb/spec/binaryprotocol_spec.rb | 15 ++++++- lib/rb/spec/binaryprotocol_spec_shared.rb | 24 ++++++++++- 6 files changed, 111 insertions(+), 22 deletions(-) diff --git a/lib/rb/ext/binary_protocol_accelerated.c b/lib/rb/ext/binary_protocol_accelerated.c index 8a6757f8..fc4b675b 100644 --- a/lib/rb/ext/binary_protocol_accelerated.c +++ b/lib/rb/ext/binary_protocol_accelerated.c @@ -5,6 +5,8 @@ #include #define GET_TRANSPORT(obj) rb_ivar_get(obj, transport_ivar_id) +#define GET_STRICT_READ(obj) rb_ivar_get(obj, strict_read_ivar_id) +#define GET_STRICT_WRITE(obj) rb_ivar_get(obj, strict_write_ivar_id) #define WRITE(obj, data, length) rb_funcall(obj, write_method_id, 1, rb_str_new(data, length)) #define CHECK_NIL(obj) if (NIL_P(obj)) { rb_raise(rb_eStandardError, "nil argument not allowed!");} @@ -16,6 +18,7 @@ VALUE rb_thrift_binary_proto_native_qmark(VALUE self) { static int VERSION_1; static int VERSION_MASK; +static int TYPE_MASK; static int BAD_VERSION; static void write_byte_direct(VALUE trans, int8_t b) { @@ -97,9 +100,17 @@ VALUE rb_thrift_binary_proto_write_set_end(VALUE self) { VALUE rb_thrift_binary_proto_write_message_begin(VALUE self, VALUE name, VALUE type, VALUE seqid) { VALUE trans = GET_TRANSPORT(self); - write_i32_direct(trans, VERSION_1 | FIX2INT(type)); - write_string_direct(trans, name); - write_i32_direct(trans, FIX2INT(seqid)); + VALUE strict_write = GET_STRICT_WRITE(self); + + if (strict_write == Qtrue) { + write_i32_direct(trans, VERSION_1 | FIX2INT(type)); + write_string_direct(trans, name); + write_i32_direct(trans, FIX2INT(seqid)); + } else { + write_string_direct(trans, name); + write_byte_direct(trans, type); + write_i32_direct(trans, FIX2INT(seqid)); + } return Qnil; } @@ -260,14 +271,27 @@ VALUE rb_thift_binary_proto_read_set_end(VALUE self) { } VALUE rb_thrift_binary_proto_read_message_begin(VALUE self) { + VALUE strict_read = GET_STRICT_READ(self); + VALUE name, seqid; + int type; + int version = read_i32_direct(self); - if ((version & VERSION_MASK) != VERSION_1) { - rb_exc_raise(get_protocol_exception(INT2FIX(BAD_VERSION), rb_str_new2("Missing version identifier"))); - } - int type = version & 0x000000ff; - VALUE name = rb_thrift_binary_proto_read_string(self); - VALUE seqid = rb_thrift_binary_proto_read_i32(self); + if (version < 0) { + if ((version & VERSION_MASK) != VERSION_1) { + rb_exc_raise(get_protocol_exception(INT2FIX(BAD_VERSION), rb_str_new2("Missing version identifier"))); + } + type = version & TYPE_MASK; + name = rb_thrift_binary_proto_read_string(self); + seqid = rb_thrift_binary_proto_read_i32(self); + } else { + if (strict_read == Qtrue) { + rb_exc_raise(get_protocol_exception(INT2FIX(BAD_VERSION), rb_str_new2("No version identifier, old protocol client?"))); + } + name = READ(self, version); + type = rb_thrift_binary_proto_read_byte(self); + seqid = rb_thrift_binary_proto_read_i32(self); + } return rb_ary_new3(3, name, INT2FIX(type), seqid); } @@ -339,6 +363,7 @@ void Init_binary_protocol_accelerated() { VERSION_1 = rb_num2ll(rb_const_get(thrift_binary_protocol_class, rb_intern("VERSION_1"))); VERSION_MASK = rb_num2ll(rb_const_get(thrift_binary_protocol_class, rb_intern("VERSION_MASK"))); + TYPE_MASK = rb_num2ll(rb_const_get(thrift_binary_protocol_class, rb_intern("TYPE_MASK"))); VALUE bpa_class = rb_define_class_under(thrift_module, "BinaryProtocolAccelerated", thrift_binary_protocol_class); diff --git a/lib/rb/ext/constants.h b/lib/rb/ext/constants.h index e540234d..1922fb16 100644 --- a/lib/rb/ext/constants.h +++ b/lib/rb/ext/constants.h @@ -60,6 +60,8 @@ extern ID native_qmark_method_id; extern ID fields_const_id; extern ID transport_ivar_id; +extern ID strict_read_ivar_id; +extern ID strict_write_ivar_id; extern VALUE type_sym; extern VALUE name_sym; diff --git a/lib/rb/ext/thrift_native.c b/lib/rb/ext/thrift_native.c index 89d32c52..4d5623d4 100644 --- a/lib/rb/ext/thrift_native.c +++ b/lib/rb/ext/thrift_native.c @@ -73,6 +73,8 @@ ID native_qmark_method_id; // constant ids ID fields_const_id; ID transport_ivar_id; +ID strict_read_ivar_id; +ID strict_write_ivar_id; // cached symbols VALUE type_sym; @@ -153,6 +155,8 @@ void Init_thrift_native() { // constant ids fields_const_id = rb_intern("FIELDS"); transport_ivar_id = rb_intern("@trans"); + strict_read_ivar_id = rb_intern("@strict_read"); + strict_write_ivar_id = rb_intern("@strict_write"); // cached symbols type_sym = ID2SYM(rb_intern("type")); diff --git a/lib/rb/lib/thrift/protocol/binaryprotocol.rb b/lib/rb/lib/thrift/protocol/binaryprotocol.rb index f0fa3ad5..c3d927ac 100644 --- a/lib/rb/lib/thrift/protocol/binaryprotocol.rb +++ b/lib/rb/lib/thrift/protocol/binaryprotocol.rb @@ -13,14 +13,29 @@ module Thrift class BinaryProtocol < Protocol VERSION_MASK = 0xffff0000 VERSION_1 = 0x80010000 + TYPE_MASK = 0x000000ff + + attr_reader :strict_read, :strict_write + + def initialize(trans, strict_read=true, strict_write=true) + super(trans) + @strict_read = strict_read + @strict_write = strict_write + end def write_message_begin(name, type, seqid) # this is necessary because we added (needed) bounds checking to # write_i32, and 0x80010000 is too big for that. - write_i16(VERSION_1 >> 16) - write_i16(type) - write_string(name) - write_i32(seqid) + if strict_write + write_i16(VERSION_1 >> 16) + write_i16(type) + write_string(name) + write_i32(seqid) + else + write_string(name) + write_byte(type) + write_i32(seqid) + end end def write_field_begin(name, type, id) @@ -82,13 +97,23 @@ module Thrift def read_message_begin version = read_i32 - if (version & VERSION_MASK != VERSION_1) - raise ProtocolException.new(ProtocolException::BAD_VERSION, 'Missing version identifier') + if version < 0 + if (version & VERSION_MASK != VERSION_1) + raise ProtocolException.new(ProtocolException::BAD_VERSION, 'Missing version identifier') + end + type = version & TYPE_MASK + name = read_string + seqid = read_i32 + [name, type, seqid] + else + if strict_read + raise ProtocolException.new(ProtocolException::BAD_VERSION, 'No version identifier, old protocol client?') + end + name = trans.read_all(version) + type = read_byte + seqid = read_i32 + [name, type, seqid] end - type = version & 0x000000ff - name = read_string - seqid = read_i32 - [name, type, seqid] end def read_field_begin diff --git a/lib/rb/spec/binaryprotocol_spec.rb b/lib/rb/spec/binaryprotocol_spec.rb index b85f096d..2d5b3751 100644 --- a/lib/rb/spec/binaryprotocol_spec.rb +++ b/lib/rb/spec/binaryprotocol_spec.rb @@ -13,17 +13,28 @@ class ThriftBinaryProtocolSpec < Spec::ExampleGroup end it "should read a message header" do - @prot.should_receive(:read_i32).and_return(protocol_class.const_get(:VERSION_1) | Thrift::MessageTypes::REPLY, 42) + @trans.should_receive(:read_all).exactly(2).times.and_return( + [protocol_class.const_get(:VERSION_1) | Thrift::MessageTypes::REPLY].pack('N'), + [42].pack('N') + ) @prot.should_receive(:read_string).and_return('testMessage') @prot.read_message_begin.should == ['testMessage', Thrift::MessageTypes::REPLY, 42] end it "should raise an exception if the message header has the wrong version" do - @prot.should_receive(:read_i32).and_return(42) + @prot.should_receive(:read_i32).and_return(-1) lambda { @prot.read_message_begin }.should raise_error(Thrift::ProtocolException, 'Missing version identifier') do |e| e.type == Thrift::ProtocolException::BAD_VERSION end end + + it "should raise an exception if the message header does not exist and strict_read is enabled" do + @prot.should_receive(:read_i32).and_return(42) + @prot.should_receive(:strict_read).and_return(true) + lambda { @prot.read_message_begin }.should raise_error(Thrift::ProtocolException, 'No version identifier, old protocol client?') do |e| + e.type == Thrift::ProtocolException::BAD_VERSION + end + end end describe BinaryProtocolFactory do diff --git a/lib/rb/spec/binaryprotocol_spec_shared.rb b/lib/rb/spec/binaryprotocol_spec_shared.rb index 78e2ccba..1d685b65 100644 --- a/lib/rb/spec/binaryprotocol_spec_shared.rb +++ b/lib/rb/spec/binaryprotocol_spec_shared.rb @@ -6,15 +6,37 @@ shared_examples_for 'a binary protocol' do @prot = protocol_class.new(@trans) end - it "should define the proper VERSION_1 and VERSION_MASK" do + it "should define the proper VERSION_1, VERSION_MASK AND TYPE_MASK" do protocol_class.const_get(:VERSION_MASK).should == 0xffff0000 protocol_class.const_get(:VERSION_1).should == 0x80010000 + protocol_class.const_get(:TYPE_MASK).should == 0x000000ff end + it "should make strict_read readable" do + @prot.strict_read.should eql(true) + end + + it "should make strict_write readable" do + @prot.strict_write.should eql(true) + end + it "should write the message header" do @prot.write_message_begin('testMessage', Thrift::MessageTypes::CALL, 17) @trans.read(1000).should == [protocol_class.const_get(:VERSION_1) | Thrift::MessageTypes::CALL, "testMessage".size, "testMessage", 17].pack("NNa11N") end + + it "should write the message header without version when writes are not strict" do + @prot = protocol_class.new(@trans, true, false) # no strict write + @prot.write_message_begin('testMessage', Thrift::MessageTypes::CALL, 17) + @trans.read(1000).should == "\000\000\000\vtestMessage\001\000\000\000\021" + end + + it "should write the message header with a version when writes are strict" do + @prot = protocol_class.new(@trans) # strict write + @prot.write_message_begin('testMessage', Thrift::MessageTypes::CALL, 17) + @trans.read(1000).should == "\200\001\000\001\000\000\000\vtestMessage\000\000\000\021" + end + # message footer is a noop -- 2.17.1