THRIFT-2377 Allow addition of custom HTTP Headers to an HTTP Transport
authorJens Geyer <jensg@apache.org>
Tue, 18 Mar 2014 22:37:10 +0000 (00:37 +0200)
committerJens Geyer <jensg@apache.org>
Tue, 18 Mar 2014 22:37:10 +0000 (00:37 +0200)
Patch: Sheran Gunasekera

lib/go/thrift/http_client.go
lib/go/thrift/http_client_test.go
lib/go/thrift/protocol_test.go
lib/go/thrift/transport_test.go

index 18b1671..9f60992 100644 (file)
@@ -30,6 +30,7 @@ type THttpClient struct {
        response           *http.Response
        url                *url.URL
        requestBuffer      *bytes.Buffer
+       header             http.Header
        nsecConnectTimeout int64
        nsecReadTimeout    int64
 }
@@ -85,7 +86,37 @@ func NewTHttpPostClient(urlstr string) (TTransport, error) {
                return nil, err
        }
        buf := make([]byte, 0, 1024)
-       return &THttpClient{url: parsedURL, requestBuffer: bytes.NewBuffer(buf)}, nil
+       return &THttpClient{url: parsedURL, requestBuffer: bytes.NewBuffer(buf), header: http.Header{}}, nil
+}
+
+// Set the HTTP Header for this specific Thrift Transport
+// It is important that you first assert the TTransport as a THttpClient type
+// like so:
+//
+// httpTrans := trans.(THttpClient)
+// httpTrans.SetHeader("User-Agent","Thrift Client 1.0")
+func (p *THttpClient) SetHeader(key string, value string) {
+       p.header.Add(key, value)
+}
+
+// Get the HTTP Header represented by the supplied Header Key for this specific Thrift Transport
+// It is important that you first assert the TTransport as a THttpClient type
+// like so:
+//
+// httpTrans := trans.(THttpClient)
+// hdrValue := httpTrans.GetHeader("User-Agent")
+func (p *THttpClient) GetHeader(key string) string {
+       return p.header.Get(key)
+}
+
+// Deletes the HTTP Header given a Header Key for this specific Thrift Transport
+// It is important that you first assert the TTransport as a THttpClient type
+// like so:
+//
+// httpTrans := trans.(THttpClient)
+// httpTrans.DelHeader("User-Agent")
+func (p *THttpClient) DelHeader(key string) {
+       p.header.Del(key)
 }
 
 func (p *THttpClient) Open() error {
@@ -128,7 +159,14 @@ func (p *THttpClient) Write(buf []byte) (int, error) {
 }
 
 func (p *THttpClient) Flush() error {
-       response, err := http.Post(p.url.String(), "application/x-thrift", p.requestBuffer)
+       client := &http.Client{}
+       req, err := http.NewRequest("POST", p.url.String(), p.requestBuffer)
+       if err != nil {
+               return NewTTransportExceptionFromError(err)
+       }
+       p.header.Add("Content-Type", "application/x-thrift")
+       req.Header = p.header
+       response, err := client.Do(req)
        if err != nil {
                return NewTTransportExceptionFromError(err)
        }
index 041faec..0c2cb28 100644 (file)
@@ -35,3 +35,16 @@ func TestHttpClient(t *testing.T) {
        }
        TransportTest(t, trans, trans)
 }
+
+func TestHttpClientHeaders(t *testing.T) {
+       l, addr := HttpClientSetupForTest(t)
+       if l != nil {
+               defer l.Close()
+       }
+       trans, err := NewTHttpPostClient("http://" + addr.String())
+       if err != nil {
+               l.Close()
+               t.Fatalf("Unable to connect to %s: %s", addr.String(), err)
+       }
+       TransportHeaderTest(t, trans, trans)
+}
index 632098c..d88afed 100644 (file)
@@ -58,6 +58,7 @@ func init() {
 }
 
 type HTTPEchoServer struct{}
+type HTTPHeaderEchoServer struct{}
 
 func (p *HTTPEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
        buf, err := ioutil.ReadAll(req.Body)
@@ -70,6 +71,17 @@ func (p *HTTPEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
        }
 }
 
+func (p *HTTPHeaderEchoServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
+       buf, err := ioutil.ReadAll(req.Body)
+       if err != nil {
+               w.WriteHeader(http.StatusBadRequest)
+               w.Write(buf)
+       } else {
+               w.WriteHeader(http.StatusOK)
+               w.Write(buf)
+       }
+}
+
 func HttpClientSetupForTest(t *testing.T) (net.Listener, net.Addr) {
        addr, err := FindAvailableTCPServerPort(40000)
        if err != nil {
@@ -85,6 +97,21 @@ func HttpClientSetupForTest(t *testing.T) (net.Listener, net.Addr) {
        return l, addr
 }
 
+func HttpClientSetupForHeaderTest(t *testing.T) (net.Listener, net.Addr) {
+       addr, err := FindAvailableTCPServerPort(40000)
+       if err != nil {
+               t.Fatalf("Unable to find available tcp port addr: %s", err)
+               return nil, addr
+       }
+       l, err := net.Listen(addr.Network(), addr.String())
+       if err != nil {
+               t.Fatalf("Unable to setup tcp listener on %s: %s", addr.String(), err)
+               return l, addr
+       }
+       go http.Serve(l, &HTTPHeaderEchoServer{})
+       return l, addr
+}
+
 func ReadWriteProtocolTest(t *testing.T, protocolFactory TProtocolFactory) {
        buf := bytes.NewBuffer(make([]byte, 0, 1024))
        l, addr := HttpClientSetupForTest(t)
@@ -145,13 +172,13 @@ func ReadWriteProtocolTest(t *testing.T, protocolFactory TProtocolFactory) {
        }
 
        for _, tf := range transports {
-         trans := tf.GetTransport(nil)
-         p := protocolFactory.GetProtocol(trans);
-         ReadWriteI64(t, p, trans);
-         ReadWriteDouble(t, p, trans);
-         ReadWriteBinary(t, p, trans);
-         ReadWriteByte(t, p, trans);
-         trans.Close()
+               trans := tf.GetTransport(nil)
+               p := protocolFactory.GetProtocol(trans)
+               ReadWriteI64(t, p, trans)
+               ReadWriteDouble(t, p, trans)
+               ReadWriteBinary(t, p, trans)
+               ReadWriteByte(t, p, trans)
+               trans.Close()
        }
 
 }
index c9f1d56..864958a 100644 (file)
@@ -29,7 +29,8 @@ import (
 const TRANSPORT_BINARY_DATA_SIZE = 4096
 
 var (
-       transport_bdata []byte // test data for writing; same as data
+       transport_bdata  []byte // test data for writing; same as data
+       transport_header map[string]string
 )
 
 func init() {
@@ -37,6 +38,8 @@ func init() {
        for i := 0; i < TRANSPORT_BINARY_DATA_SIZE; i++ {
                transport_bdata[i] = byte((i + 'a') % 255)
        }
+       transport_header = map[string]string{"key": "User-Agent",
+               "value": "Mozilla/5.0 (Windows NT 6.2; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/32.0.1667.0 Safari/537.36"}
 }
 
 func TransportTest(t *testing.T, writeTrans TTransport, readTrans TTransport) {
@@ -94,6 +97,50 @@ func TransportTest(t *testing.T, writeTrans TTransport, readTrans TTransport) {
        }
 }
 
+func TransportHeaderTest(t *testing.T, writeTrans TTransport, readTrans TTransport) {
+       buf := make([]byte, TRANSPORT_BINARY_DATA_SIZE)
+       if !writeTrans.IsOpen() {
+               t.Fatalf("Transport %T not open: %s", writeTrans, writeTrans)
+       }
+       if !readTrans.IsOpen() {
+               t.Fatalf("Transport %T not open: %s", readTrans, readTrans)
+       }
+       // Need to assert type of TTransport to THttpClient to expose the Setter
+       httpWPostTrans := writeTrans.(*THttpClient)
+       httpWPostTrans.SetHeader(transport_header["key"], transport_header["value"])
+
+       _, err := writeTrans.Write(transport_bdata)
+       if err != nil {
+               t.Fatalf("Transport %T cannot write binary data of length %d: %s", writeTrans, len(transport_bdata), err)
+       }
+       err = writeTrans.Flush()
+       if err != nil {
+               t.Fatalf("Transport %T cannot flush write of binary data: %s", writeTrans, err)
+       }
+       // Need to assert type of TTransport to THttpClient to expose the Getter
+       httpRPostTrans := readTrans.(*THttpClient)
+       readHeader := httpRPostTrans.GetHeader(transport_header["key"])
+       if err != nil {
+               t.Errorf("Transport %T cannot read HTTP Header Value", httpRPostTrans)
+       }
+
+       if transport_header["value"] != readHeader {
+               t.Errorf("Expected HTTP Header Value %s, got %s", transport_header["value"], readHeader)
+       }
+       n, err := io.ReadFull(readTrans, buf)
+       if err != nil {
+               t.Errorf("Transport %T cannot read binary data of length %d: %s", readTrans, TRANSPORT_BINARY_DATA_SIZE, err)
+       }
+       if n != TRANSPORT_BINARY_DATA_SIZE {
+               t.Errorf("Transport %T read only %d instead of %d bytes of binary data", readTrans, n, TRANSPORT_BINARY_DATA_SIZE)
+       }
+       for k, v := range buf {
+               if v != transport_bdata[k] {
+                       t.Fatalf("Transport %T read %d instead of %d for index %d of binary data 2", readTrans, v, transport_bdata[k], k)
+               }
+       }
+}
+
 func CloseTransports(t *testing.T, readTrans TTransport, writeTrans TTransport) {
        err := readTrans.Close()
        if err != nil {
@@ -118,3 +165,12 @@ func FindAvailableTCPServerPort(startPort int) (net.Addr, error) {
        }
        return nil, NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, "Could not find available server port")
 }
+
+func valueInSlice(value string, slice []string) bool {
+       for _, v := range slice {
+               if value == v {
+                       return true
+               }
+       }
+       return false
+}