From: Jens Geyer Date: Tue, 18 Mar 2014 22:37:10 +0000 (+0200) Subject: THRIFT-2377 Allow addition of custom HTTP Headers to an HTTP Transport X-Git-Url: https://source.supwisdom.com/gerrit/gitweb?a=commitdiff_plain;h=706cb4e4cb7426d25722b1166e0c8b102b20051e;p=common%2Fthrift.git THRIFT-2377 Allow addition of custom HTTP Headers to an HTTP Transport Patch: Sheran Gunasekera --- diff --git a/lib/go/thrift/http_client.go b/lib/go/thrift/http_client.go index 18b1671d..9f609928 100644 --- a/lib/go/thrift/http_client.go +++ b/lib/go/thrift/http_client.go @@ -30,6 +30,7 @@ type THttpClient struct { response *http.Response url *url.URL requestBuffer *bytes.Buffer + header http.Header nsecConnectTimeout int64 nsecReadTimeout int64 } @@ -85,7 +86,37 @@ func NewTHttpPostClient(urlstr string) (TTransport, error) { return nil, err } buf := make([]byte, 0, 1024) - return &THttpClient{url: parsedURL, requestBuffer: bytes.NewBuffer(buf)}, nil + return &THttpClient{url: parsedURL, requestBuffer: bytes.NewBuffer(buf), header: http.Header{}}, nil +} + +// Set the HTTP Header for this specific Thrift Transport +// It is important that you first assert the TTransport as a THttpClient type +// like so: +// +// httpTrans := trans.(THttpClient) +// httpTrans.SetHeader("User-Agent","Thrift Client 1.0") +func (p *THttpClient) SetHeader(key string, value string) { + p.header.Add(key, value) +} + +// Get the HTTP Header represented by the supplied Header Key for this specific Thrift Transport +// It is important that you first assert the TTransport as a THttpClient type +// like so: +// +// httpTrans := trans.(THttpClient) +// hdrValue := httpTrans.GetHeader("User-Agent") +func (p *THttpClient) GetHeader(key string) string { + return p.header.Get(key) +} + +// Deletes the HTTP Header given a Header Key for this specific Thrift Transport +// It is important that you first assert the TTransport as a THttpClient type +// like so: +// +// httpTrans := trans.(THttpClient) +// httpTrans.DelHeader("User-Agent") +func (p *THttpClient) DelHeader(key string) { + p.header.Del(key) } func (p *THttpClient) Open() error { @@ -128,7 +159,14 @@ func (p *THttpClient) Write(buf []byte) (int, error) { } func (p *THttpClient) Flush() error { - response, err := http.Post(p.url.String(), "application/x-thrift", p.requestBuffer) + client := &http.Client{} + req, err := http.NewRequest("POST", p.url.String(), p.requestBuffer) + if err != nil { + return NewTTransportExceptionFromError(err) + } + p.header.Add("Content-Type", "application/x-thrift") + req.Header = p.header + response, err := client.Do(req) if err != nil { return NewTTransportExceptionFromError(err) } diff --git a/lib/go/thrift/http_client_test.go b/lib/go/thrift/http_client_test.go index 041faecc..0c2cb28a 100644 --- a/lib/go/thrift/http_client_test.go +++ b/lib/go/thrift/http_client_test.go @@ -35,3 +35,16 @@ func TestHttpClient(t *testing.T) { } TransportTest(t, trans, trans) } + +func TestHttpClientHeaders(t *testing.T) { + l, addr := HttpClientSetupForTest(t) + if l != nil { + defer l.Close() + } + trans, err := NewTHttpPostClient("http://" + addr.String()) + if err != nil { + l.Close() + t.Fatalf("Unable to connect to %s: %s", addr.String(), err) + } + TransportHeaderTest(t, trans, trans) +} diff --git a/lib/go/thrift/protocol_test.go b/lib/go/thrift/protocol_test.go index 632098cc..d88afedc 100644 --- a/lib/go/thrift/protocol_test.go +++ b/lib/go/thrift/protocol_test.go @@ -58,6 +58,7 @@ func init() { } type HTTPEchoServer struct{} +type HTTPHeaderEchoServer struct{} func (p *HTTPEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { buf, err := ioutil.ReadAll(req.Body) @@ -70,6 +71,17 @@ func (p *HTTPEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { } } +func (p *HTTPHeaderEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) { + buf, err := ioutil.ReadAll(req.Body) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write(buf) + } else { + w.WriteHeader(http.StatusOK) + w.Write(buf) + } +} + func HttpClientSetupForTest(t *testing.T) (net.Listener, net.Addr) { addr, err := FindAvailableTCPServerPort(40000) if err != nil { @@ -85,6 +97,21 @@ func HttpClientSetupForTest(t *testing.T) (net.Listener, net.Addr) { return l, addr } +func HttpClientSetupForHeaderTest(t *testing.T) (net.Listener, net.Addr) { + addr, err := FindAvailableTCPServerPort(40000) + if err != nil { + t.Fatalf("Unable to find available tcp port addr: %s", err) + return nil, addr + } + l, err := net.Listen(addr.Network(), addr.String()) + if err != nil { + t.Fatalf("Unable to setup tcp listener on %s: %s", addr.String(), err) + return l, addr + } + go http.Serve(l, &HTTPHeaderEchoServer{}) + return l, addr +} + func ReadWriteProtocolTest(t *testing.T, protocolFactory TProtocolFactory) { buf := bytes.NewBuffer(make([]byte, 0, 1024)) l, addr := HttpClientSetupForTest(t) @@ -145,13 +172,13 @@ func ReadWriteProtocolTest(t *testing.T, protocolFactory TProtocolFactory) { } for _, tf := range transports { - trans := tf.GetTransport(nil) - p := protocolFactory.GetProtocol(trans); - ReadWriteI64(t, p, trans); - ReadWriteDouble(t, p, trans); - ReadWriteBinary(t, p, trans); - ReadWriteByte(t, p, trans); - trans.Close() + trans := tf.GetTransport(nil) + p := protocolFactory.GetProtocol(trans) + ReadWriteI64(t, p, trans) + ReadWriteDouble(t, p, trans) + ReadWriteBinary(t, p, trans) + ReadWriteByte(t, p, trans) + trans.Close() } } diff --git a/lib/go/thrift/transport_test.go b/lib/go/thrift/transport_test.go index c9f1d56c..864958a9 100644 --- a/lib/go/thrift/transport_test.go +++ b/lib/go/thrift/transport_test.go @@ -29,7 +29,8 @@ import ( const TRANSPORT_BINARY_DATA_SIZE = 4096 var ( - transport_bdata []byte // test data for writing; same as data + transport_bdata []byte // test data for writing; same as data + transport_header map[string]string ) func init() { @@ -37,6 +38,8 @@ func init() { for i := 0; i < TRANSPORT_BINARY_DATA_SIZE; i++ { transport_bdata[i] = byte((i + 'a') % 255) } + transport_header = map[string]string{"key": "User-Agent", + "value": "Mozilla/5.0 (Windows NT 6.2; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/32.0.1667.0 Safari/537.36"} } func TransportTest(t *testing.T, writeTrans TTransport, readTrans TTransport) { @@ -94,6 +97,50 @@ func TransportTest(t *testing.T, writeTrans TTransport, readTrans TTransport) { } } +func TransportHeaderTest(t *testing.T, writeTrans TTransport, readTrans TTransport) { + buf := make([]byte, TRANSPORT_BINARY_DATA_SIZE) + if !writeTrans.IsOpen() { + t.Fatalf("Transport %T not open: %s", writeTrans, writeTrans) + } + if !readTrans.IsOpen() { + t.Fatalf("Transport %T not open: %s", readTrans, readTrans) + } + // Need to assert type of TTransport to THttpClient to expose the Setter + httpWPostTrans := writeTrans.(*THttpClient) + httpWPostTrans.SetHeader(transport_header["key"], transport_header["value"]) + + _, err := writeTrans.Write(transport_bdata) + if err != nil { + t.Fatalf("Transport %T cannot write binary data of length %d: %s", writeTrans, len(transport_bdata), err) + } + err = writeTrans.Flush() + if err != nil { + t.Fatalf("Transport %T cannot flush write of binary data: %s", writeTrans, err) + } + // Need to assert type of TTransport to THttpClient to expose the Getter + httpRPostTrans := readTrans.(*THttpClient) + readHeader := httpRPostTrans.GetHeader(transport_header["key"]) + if err != nil { + t.Errorf("Transport %T cannot read HTTP Header Value", httpRPostTrans) + } + + if transport_header["value"] != readHeader { + t.Errorf("Expected HTTP Header Value %s, got %s", transport_header["value"], readHeader) + } + n, err := io.ReadFull(readTrans, buf) + if err != nil { + t.Errorf("Transport %T cannot read binary data of length %d: %s", readTrans, TRANSPORT_BINARY_DATA_SIZE, err) + } + if n != TRANSPORT_BINARY_DATA_SIZE { + t.Errorf("Transport %T read only %d instead of %d bytes of binary data", readTrans, n, TRANSPORT_BINARY_DATA_SIZE) + } + for k, v := range buf { + if v != transport_bdata[k] { + t.Fatalf("Transport %T read %d instead of %d for index %d of binary data 2", readTrans, v, transport_bdata[k], k) + } + } +} + func CloseTransports(t *testing.T, readTrans TTransport, writeTrans TTransport) { err := readTrans.Close() if err != nil { @@ -118,3 +165,12 @@ func FindAvailableTCPServerPort(startPort int) (net.Addr, error) { } return nil, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "Could not find available server port") } + +func valueInSlice(value string, slice []string) bool { + for _, v := range slice { + if value == v { + return true + } + } + return false +}