/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19
20package thrift
21
22import (
23 "net"
24 "time"
25 "crypto/tls"
26)
27
// TSSLSocket is a net.Conn-backed TTransport whose connection is established
// with crypto/tls.
type TSSLSocket struct {
	// conn is the live connection; nil means the socket is not open.
	conn net.Conn
	// addr is the remote address that Open dials.
	addr net.Addr
	// timeout is applied as a Read/Write deadline when positive; a zero or
	// negative value means no deadline.
	timeout time.Duration
	// cfg is the TLS configuration handed to the dialer in Open.
	cfg *tls.Config
}
34
35// NewTSSLSocket creates a net.Conn-backed TTransport, given a host and port and tls Configuration
36//
37// Example:
38// trans, err := thrift.NewTSocket("localhost:9090")
39func NewTSSLSocket(hostPort string, cfg *tls.Config) (*TSSLSocket, error) {
40 return NewTSSLSocketTimeout(hostPort, cfg, 0)
41}
42
43// NewTSSLSocketTimeout creates a net.Conn-backed TTransport, given a host and port
44// it also accepts a tls Configuration and a timeout as a time.Duration
45func NewTSSLSocketTimeout(hostPort string, cfg *tls.Config, timeout time.Duration) (*TSSLSocket, error) {
46 //conn, err := net.DialTimeout(network, address, timeout)
47 addr, err := net.ResolveTCPAddr("tcp", hostPort)
48 if err != nil {
49 return nil, err
50 }
51 return NewTSSLSocketFromAddrTimeout(addr, cfg, timeout), nil
52}
53
54// Creates a TSSLSocket from a net.Addr
55func NewTSSLSocketFromAddrTimeout(addr net.Addr, cfg *tls.Config, timeout time.Duration) *TSSLSocket {
56 return &TSSLSocket{addr: addr, timeout: timeout, cfg: cfg}
57}
58
59// Creates a TSSLSocket from an existing net.Conn
60func NewTSSLSocketFromConnTimeout(conn net.Conn, cfg *tls.Config, timeout time.Duration) *TSSLSocket {
61 return &TSSLSocket{conn: conn, addr: conn.RemoteAddr(), timeout: timeout, cfg: cfg}
62}
63
64// Sets the socket timeout
65func (p *TSSLSocket) SetTimeout(timeout time.Duration) error {
66 p.timeout = timeout
67 return nil
68}
69
70func (p *TSSLSocket) pushDeadline(read, write bool) {
71 var t time.Time
72 if p.timeout > 0 {
73 t = time.Now().Add(time.Duration(p.timeout))
74 }
75 if read && write {
76 p.conn.SetDeadline(t)
77 } else if read {
78 p.conn.SetReadDeadline(t)
79 } else if write {
80 p.conn.SetWriteDeadline(t)
81 }
82}
83
84// Connects the socket, creating a new socket object if necessary.
85func (p *TSSLSocket) Open() error {
86 if p.IsOpen() {
87 return NewTTransportException(ALREADY_OPEN, "Socket already connected.")
88 }
89 if p.addr == nil {
90 return NewTTransportException(NOT_OPEN, "Cannot open nil address.")
91 }
92 if len(p.addr.Network()) == 0 {
93 return NewTTransportException(NOT_OPEN, "Cannot open bad network name.")
94 }
95 if len(p.addr.String()) == 0 {
96 return NewTTransportException(NOT_OPEN, "Cannot open bad address.")
97 }
98 var err error
99 if p.conn, err = tls.Dial(p.addr.Network(), p.addr.String(), p.cfg); err != nil {
100 return NewTTransportException(NOT_OPEN, err.Error())
101 }
102 return nil
103}
104
105// Retreive the underlying net.Conn
106func (p *TSSLSocket) Conn() net.Conn {
107 return p.conn
108}
109
110// Returns true if the connection is open
111func (p *TSSLSocket) IsOpen() bool {
112 if p.conn == nil {
113 return false
114 }
115 return true
116}
117
118// Closes the socket.
119func (p *TSSLSocket) Close() error {
120 // Close the socket
121 if p.conn != nil {
122 err := p.conn.Close()
123 if err != nil {
124 return err
125 }
126 p.conn = nil
127 }
128 return nil
129}
130
131func (p *TSSLSocket) Read(buf []byte) (int, error) {
132 if !p.IsOpen() {
133 return 0, NewTTransportException(NOT_OPEN, "Connection not open")
134 }
135 p.pushDeadline(true, false)
136 n, err := p.conn.Read(buf)
137 return n, NewTTransportExceptionFromError(err)
138}
139
140func (p *TSSLSocket) Write(buf []byte) (int, error) {
141 if !p.IsOpen() {
142 return 0, NewTTransportException(NOT_OPEN, "Connection not open")
143 }
144 p.pushDeadline(false, true)
145 return p.conn.Write(buf)
146}
147
148func (p *TSSLSocket) Peek() bool {
149 return p.IsOpen()
150}
151
152func (p *TSSLSocket) Flush() error {
153 return nil
154}
155
156func (p *TSSLSocket) Interrupt() error {
157 if !p.IsOpen() {
158 return nil
159 }
160 return p.conn.Close()
161}