THRIFT-2109 Secure connections should be supported in Go
authorJens Geyer <jensg@apache.org>
Tue, 13 Aug 2013 19:34:17 +0000 (21:34 +0200)
committerJens Geyer <jensg@apache.org>
Tue, 13 Aug 2013 19:34:17 +0000 (21:34 +0200)
Patch: Justin Judd

lib/go/thrift/ssl_server_socket.go [new file with mode: 0644]
lib/go/thrift/ssl_socket.go [new file with mode: 0644]
tutorial/go/Makefile.am
tutorial/go/server.crt [new file with mode: 0644]
tutorial/go/server.key [new file with mode: 0644]
tutorial/go/src/client.go
tutorial/go/src/main.go
tutorial/go/src/server.go

diff --git a/lib/go/thrift/ssl_server_socket.go b/lib/go/thrift/ssl_server_socket.go
new file mode 100644 (file)
index 0000000..58f859b
--- /dev/null
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+       "net"
+       "time"
+       "crypto/tls"
+)
+
+type TSSLServerSocket struct {
+       listener      net.Listener
+       addr          net.Addr
+       clientTimeout time.Duration
+       interrupted   bool
+       cfg           *tls.Config
+}
+
+func NewTSSLServerSocket(listenAddr string, cfg *tls.Config) (*TSSLServerSocket, error) {
+       return NewTSSLServerSocketTimeout(listenAddr, cfg, 0)
+}
+
+func NewTSSLServerSocketTimeout(listenAddr string, cfg *tls.Config, clientTimeout time.Duration) (*TSSLServerSocket, error) {
+       addr, err := net.ResolveTCPAddr("tcp", listenAddr)
+       if err != nil {
+               return nil, err
+       }
+       return &TSSLServerSocket{addr: addr, clientTimeout: clientTimeout, cfg: cfg}, nil
+}
+
+func (p *TSSLServerSocket) Listen() error {
+       if p.IsListening() {
+               return nil
+       }
+       l, err := tls.Listen(p.addr.Network(), p.addr.String(), p.cfg)
+       if err != nil {
+               return err
+       }
+       p.listener = l
+       return nil
+}
+
+func (p *TSSLServerSocket) Accept() (TTransport, error) {
+       if p.interrupted {
+               return nil, errTransportInterrupted
+       }
+       if p.listener == nil {
+               return nil, NewTTransportException(NOT_OPEN, "No underlying server socket")
+       }
+       conn, err := p.listener.Accept()
+       if err != nil {
+               return nil, NewTTransportExceptionFromError(err)
+       }
+       return NewTSSLSocketFromConnTimeout(conn, p.cfg, p.clientTimeout), nil
+}
+
+// Checks whether the socket is listening.
+func (p *TSSLServerSocket) IsListening() bool {
+       return p.listener != nil
+}
+
+// Connects the socket, creating a new socket object if necessary.
+func (p *TSSLServerSocket) Open() error {
+       if p.IsListening() {
+               return NewTTransportException(ALREADY_OPEN, "Server socket already open")
+       }
+       if l, err := tls.Listen(p.addr.Network(), p.addr.String(), p.cfg); err != nil {
+               return err
+       } else {
+               p.listener = l
+       }
+       return nil
+}
+
+func (p *TSSLServerSocket) Addr() net.Addr {
+       return p.addr
+}
+
+func (p *TSSLServerSocket) Close() error {
+       defer func() {
+               p.listener = nil
+       }()
+       if p.IsListening() {
+               return p.listener.Close()
+       }
+       return nil
+}
+
+func (p *TSSLServerSocket) Interrupt() error {
+       p.interrupted = true
+       return nil
+}
diff --git a/lib/go/thrift/ssl_socket.go b/lib/go/thrift/ssl_socket.go
new file mode 100644 (file)
index 0000000..943bd90
--- /dev/null
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+       "net"
+       "time"
+       "crypto/tls"
+)
+
+type TSSLSocket struct {
+       conn    net.Conn
+       addr    net.Addr
+       timeout time.Duration
+       cfg     *tls.Config
+}
+
+// NewTSSLSocket creates a net.Conn-backed TTransport, given a host and port and tls Configuration
+//
+// Example:
+//     trans, err := thrift.NewTSocket("localhost:9090")
+func NewTSSLSocket(hostPort string, cfg *tls.Config) (*TSSLSocket, error) {
+       return NewTSSLSocketTimeout(hostPort, cfg, 0)
+}
+
+// NewTSSLSocketTimeout creates a net.Conn-backed TTransport, given a host and port
+// it also accepts a tls Configuration and a timeout as a time.Duration
+func NewTSSLSocketTimeout(hostPort string, cfg *tls.Config, timeout time.Duration) (*TSSLSocket, error) {
+       //conn, err := net.DialTimeout(network, address, timeout)
+       addr, err := net.ResolveTCPAddr("tcp", hostPort)
+       if err != nil {
+               return nil, err
+       }
+       return NewTSSLSocketFromAddrTimeout(addr, cfg, timeout), nil
+}
+
+// Creates a TSSLSocket from a net.Addr
+func NewTSSLSocketFromAddrTimeout(addr net.Addr, cfg *tls.Config, timeout time.Duration) *TSSLSocket {
+       return &TSSLSocket{addr: addr, timeout: timeout, cfg: cfg}
+}
+
+// Creates a TSSLSocket from an existing net.Conn
+func NewTSSLSocketFromConnTimeout(conn net.Conn, cfg *tls.Config, timeout time.Duration) *TSSLSocket {
+       return &TSSLSocket{conn: conn, addr: conn.RemoteAddr(), timeout: timeout, cfg: cfg}
+}
+
+// Sets the socket timeout
+func (p *TSSLSocket) SetTimeout(timeout time.Duration) error {
+       p.timeout = timeout
+       return nil
+}
+
+func (p *TSSLSocket) pushDeadline(read, write bool) {
+       var t time.Time
+       if p.timeout > 0 {
+               t = time.Now().Add(time.Duration(p.timeout))
+       }
+       if read && write {
+               p.conn.SetDeadline(t)
+       } else if read {
+               p.conn.SetReadDeadline(t)
+       } else if write {
+               p.conn.SetWriteDeadline(t)
+       }
+}
+
+// Connects the socket, creating a new socket object if necessary.
+func (p *TSSLSocket) Open() error {
+       if p.IsOpen() {
+               return NewTTransportException(ALREADY_OPEN, "Socket already connected.")
+       }
+       if p.addr == nil {
+               return NewTTransportException(NOT_OPEN, "Cannot open nil address.")
+       }
+       if len(p.addr.Network()) == 0 {
+               return NewTTransportException(NOT_OPEN, "Cannot open bad network name.")
+       }
+       if len(p.addr.String()) == 0 {
+               return NewTTransportException(NOT_OPEN, "Cannot open bad address.")
+       }
+       var err error
+       if p.conn, err = tls.Dial(p.addr.Network(), p.addr.String(), p.cfg); err != nil {
+               return NewTTransportException(NOT_OPEN, err.Error())
+       }
+       return nil
+}
+
+// Retreive the underlying net.Conn
+func (p *TSSLSocket) Conn() net.Conn {
+       return p.conn
+}
+
+// Returns true if the connection is open
+func (p *TSSLSocket) IsOpen() bool {
+       if p.conn == nil {
+               return false
+       }
+       return true
+}
+
+// Closes the socket.
+func (p *TSSLSocket) Close() error {
+       // Close the socket
+       if p.conn != nil {
+               err := p.conn.Close()
+               if err != nil {
+                       return err
+               }
+               p.conn = nil
+       }
+       return nil
+}
+
+func (p *TSSLSocket) Read(buf []byte) (int, error) {
+       if !p.IsOpen() {
+               return 0, NewTTransportException(NOT_OPEN, "Connection not open")
+       }
+       p.pushDeadline(true, false)
+       n, err := p.conn.Read(buf)
+       return n, NewTTransportExceptionFromError(err)
+}
+
+func (p *TSSLSocket) Write(buf []byte) (int, error) {
+       if !p.IsOpen() {
+               return 0, NewTTransportException(NOT_OPEN, "Connection not open")
+       }
+       p.pushDeadline(false, true)
+       return p.conn.Write(buf)
+}
+
+func (p *TSSLSocket) Peek() bool {
+       return p.IsOpen()
+}
+
+func (p *TSSLSocket) Flush() error {
+       return nil
+}
+
+func (p *TSSLSocket) Interrupt() error {
+       if !p.IsOpen() {
+               return nil
+       }
+       return p.conn.Close()
+}
index 5b1f710..5df065a 100644 (file)
@@ -40,7 +40,13 @@ tutorialserver: all
        GOPATH=`pwd` $(GO) run src/*.go -server=true
 
 tutorialclient: all
-       GOPATH=`pwd` $(GO) run src/*.go -client=true
+       GOPATH=`pwd` $(GO) run src/*.go 
+
+tutorialsecureserver: all
+       GOPATH=`pwd` $(GO) run src/*.go -server=true -secure=true
+
+tutorialsecureclient: all
+       GOPATH=`pwd` $(GO) run src/*.go -secure=true
 
 clean-local:
        $(RM) -r gen-*
diff --git a/tutorial/go/server.crt b/tutorial/go/server.crt
new file mode 100644 (file)
index 0000000..b345bf0
--- /dev/null
@@ -0,0 +1,20 @@
+-----BEGIN CERTIFICATE-----
+MIIDRjCCAi4CCQC7ZHL2gkCrNDANBgkqhkiG9w0BAQUFADBlMQswCQYDVQQGEwJV
+UzEPMA0GA1UECAwGVGhyaWZ0MQ8wDQYDVQQHDAZUaHJpZnQxDzANBgNVBAoMBkFw
+YWNoZTEPMA0GA1UECwwGVGhyaWZ0MRIwEAYDVQQDDAlUaHJpZnQtR28wHhcNMTMw
+ODAyMTM0MzM0WhcNMTQwODAyMTM0MzM0WjBlMQswCQYDVQQGEwJVUzEPMA0GA1UE
+CAwGVGhyaWZ0MQ8wDQYDVQQHDAZUaHJpZnQxDzANBgNVBAoMBkFwYWNoZTEPMA0G
+A1UECwwGVGhyaWZ0MRIwEAYDVQQDDAlUaHJpZnQtR28wggEiMA0GCSqGSIb3DQEB
+AQUAA4IBDwAwggEKAoIBAQDBIkwFBgvHNS4yR8VYqYgd6bZtE1bJSdUHAOmASYwF
+Q5dDGltiwwWJiyLQ6njgcugqJ+5Icn2i7zd3kWXuTi6uNHlDzAy253uRj0skhXIA
+CYcMpNB5KI/bZd94VYg8ZG5x/or9mhpQNaZBKMpQ6bb1MvAlHfJO08y6YH2mZjfW
+SlpjEem51R8OK/3AM6mWZfWHeuSX+nzbChgRDZH4m9leWutgKyUgQtU7b5tEndsP
+qzGNeedaWGcyLT2dtD5PmsFbJ/RQXE3NEWACelJh7B1JjwB42HvZtl7m83GuY7ew
+eKlJP2HQAmmUkNTLdSa0yTzLuNitIKoh7RW5q4bl4zyNAgMBAAEwDQYJKoZIhvcN
+AQEFBQADggEBAB7r7lXn2M3SdyuXH+U6wbiNKPq8SX3sgncpaOluC36Phfdu38XJ
+NLovB05BIlkkExkv/IvjUZxGByd9WvNZBgajqll/FaK3Vv8cTo53yjxbQexFVK/m
+J4G9q/dGIV+B+8soVedoMTZOSmKhowM//9Oshs70foJLBJoHA5UdTdBxMvYcZBXV
+S9vUaVNEFd2cN0tyvguY8JNIPU8yEOUspR/uBeRRk3pRTbgcACC8+zYGUSuBiEf3
+SEKO2BHYBCkoodHWBWeMiiksYd3I5xOE9yS+Wn4eZlJBMTIjgxSddR1HH2cDSwET
+FzuW8WjtE1A28JL/hR6YcxaEDTulaFhaKs8=
+-----END CERTIFICATE-----
diff --git a/tutorial/go/server.key b/tutorial/go/server.key
new file mode 100644 (file)
index 0000000..27d52f4
--- /dev/null
@@ -0,0 +1,27 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIEogIBAAKCAQEAwSJMBQYLxzUuMkfFWKmIHem2bRNWyUnVBwDpgEmMBUOXQxpb
+YsMFiYsi0Op44HLoKifuSHJ9ou83d5Fl7k4urjR5Q8wMtud7kY9LJIVyAAmHDKTQ
+eSiP22XfeFWIPGRucf6K/ZoaUDWmQSjKUOm29TLwJR3yTtPMumB9pmY31kpaYxHp
+udUfDiv9wDOplmX1h3rkl/p82woYEQ2R+JvZXlrrYCslIELVO2+bRJ3bD6sxjXnn
+WlhnMi09nbQ+T5rBWyf0UFxNzRFgAnpSYewdSY8AeNh72bZe5vNxrmO3sHipST9h
+0AJplJDUy3UmtMk8y7jYrSCqIe0VuauG5eM8jQIDAQABAoIBAEU9zpNef4qD/nP4
+V0BaR3qx971TWaIA3mcMZKqhs5mPigN8x5a45JtTTsAnz/5oM+QpPLysj26C5Rfx
+AOJXFVVPasprtYM9qoedIAuP7DcnM0vNKxDFAg5ej6fMwnMkbpRf9eTGAvkOwvRJ
+c39ey0FNadtkySKJvLR1M5ccvpgMnybCMDYsjDH0tAqWJcWsCX+/htk4rpg4V7yG
+JDg23yx4An+WWmPuR1zSQNx5mZBSg4RXYykr1MEKsHo+TDQ6IK6Wq2rtLUM0/0M6
+CJ80EswX6uY0Uh6eJH1o1BLJeAfNGk/a9MUUqPaWj7ospa5XJ0adG3Qq5MmF21Ft
+VbhRhQECgYEA/+2CcKlRlxoBzCRg8DdFf5OHE45EiUFAEX8/J9RWdLQSI/EwA6K2
+Q9CGy6WWKEFMHBHsxyV8Hx6dS+M+2UpM4h28Atiu/HLs3UZXrRv7FLJuXP7j+iBv
+oNo5+8sxcxL1GgJ3zcmSSHhMcZmbBowsYmx+lMDSSxGgXo2jEY3mZmkCgYEAwTBA
+KkO22D9263MYA/Exjcto+t5O7Q26gs/j2UscyBUn7fbcKofTvrWBBjMYMZvFokO8
+HM0PNTIpr0F9FPoUB2oky7VLmuoGf9smyZtl5fwIl6R/4MmStyAEUWnkI3qt/EOr
+5ZwrdzFdru6Z4zLYaW6bms+8A1G7GWnTNen+yIUCgYB57Lb14VRjfhpZHQOprUtI
+ygnSATcZhKJ3M33tBbXih18VDHRpZv0aNZ/iKRLuPp15yfhZr7wAP1+EpdBtSH50
+QuItIPnMfxvlFvvyFqB5bcAyQaRup0FHCnARSu5V+jQWnhJhUaSFLfqNLDa02dbT
+VQjA6VPGO7GBGk0TsdyP8QKBgE0OzvlMyzkUj325CeJAqdByS2yNkhPSPwwAmlTJ
+NjDE54lux0EbrqVKRq3PYZ4gEUP5GqauUJuaZ7AlQhxE6ApRF1498WtYX8FOC/ms
+x4dl8ZNzJSLnpGLxHWfQAhT40T9nSsCqe1fu0/x75dwPIu1jFiQ5Kjh0uFmZsYq2
+zE71AoGAIgX3hfqFTcQ0vQJ2bSk4Q2IBMRjW9maZDK/O009cWoVA1pq7qUFXX/Rl
+ADA5FD/HOZZ1QYEfRaIEItPZP6cbnza4mPVql6YwSqE5IV1DIefUkPzQVTWxSyPi
+FlH3RwTKS7V/qQ+7tPL7lGAsI6W/mtPM0TneDRcrpr6C9fghSDs=
+-----END RSA PRIVATE KEY-----
index 7f8d28f..543d7fb 100644 (file)
@@ -23,6 +23,7 @@ import (
        "fmt"
        "git.apache.org/thrift.git/lib/go/thrift"
        "tutorial"
+       "crypto/tls"
 )
 
 func handleClient(client *tutorial.CalculatorClient) (err error) {
@@ -69,9 +70,16 @@ func handleClient(client *tutorial.CalculatorClient) (err error) {
        return err
 }
 
-func runClient(transportFactory thrift.TTransportFactory, protocolFactory thrift.TProtocolFactory, addr string) error {
+func runClient(transportFactory thrift.TTransportFactory, protocolFactory thrift.TProtocolFactory, addr string, secure bool) error {
        var transport thrift.TTransport
-       transport, err := thrift.NewTSocket(addr)
+       var err error
+       if secure {
+               cfg := new(tls.Config)
+               cfg.InsecureSkipVerify = true
+               transport, err = thrift.NewTSSLSocket(addr, cfg)
+       } else {
+               transport, err = thrift.NewTSocket(addr)
+       }
        if err != nil {
                fmt.Println("Error opening socket:", err)
                return err
index d371394..96e5ec9 100644 (file)
@@ -39,6 +39,7 @@ func main() {
        framed := flag.Bool("framed", false, "Use framed transport")
        buffered := flag.Bool("buffered", false, "Use buffered transport")
        addr := flag.String("addr", "localhost:9090", "Address to listen to")
+       secure := flag.Bool("secure", false, "Use tls secure transport")
 
        flag.Parse()
 
@@ -70,11 +71,11 @@ func main() {
        }
 
        if *server {
-               if err := runServer(transportFactory, protocolFactory, *addr); err != nil {
+               if err := runServer(transportFactory, protocolFactory, *addr, *secure); err != nil {
                        fmt.Println("error running server:", err)
                }
        } else {
-               if err := runClient(transportFactory, protocolFactory, *addr); err != nil {
+               if err := runClient(transportFactory, protocolFactory, *addr, *secure); err != nil {
                        fmt.Println("error running client:", err)
                }
        }
index aea749e..0374cde 100644 (file)
@@ -23,17 +23,34 @@ import (
        "fmt"
        "git.apache.org/thrift.git/lib/go/thrift"
        "tutorial"
+       "crypto/tls"
 )
 
-func runServer(transportFactory thrift.TTransportFactory, protocolFactory thrift.TProtocolFactory, addr string) error {
-       transport, err := thrift.NewTServerSocket(addr)
+func runServer(transportFactory thrift.TTransportFactory, protocolFactory thrift.TProtocolFactory, addr string, secure bool) error {
+       var transport thrift.TServerTransport
+       var err error
+       if secure {
+               cfg := new(tls.Config)
+               if cert, err := tls.LoadX509KeyPair("server.crt", "server.key"); err == nil {
+                       cfg.Certificates = append(cfg.Certificates, cert)
+               }
+               if err != nil {
+                       fmt.Println("Unable to load server certificate and key")
+                       return err
+               }
+               transport, err = thrift.NewTSSLServerSocket(addr, cfg)
+       } else {
+               transport, err = thrift.NewTServerSocket(addr)
+       }
+       
        if err != nil {
                return err
        }
+       fmt.Printf("%T\n", transport)
        handler := NewCalculatorHandler()
        processor := tutorial.NewCalculatorProcessor(handler)
        server := thrift.NewTSimpleServer4(processor, transport, transportFactory, protocolFactory)
 
-       fmt.Println("Starting the simple server... on ", transport.Addr())
+       fmt.Println("Starting the simple server... on ", addr)
        return server.Serve()
 }