THRIFT-2388 GoLang - Fix data races in simple_server and server_socket
authorJens Geyer <jensg@apache.org>
Thu, 6 Mar 2014 20:11:46 +0000 (21:11 +0100)
committerJens Geyer <jensg@apache.org>
Thu, 6 Mar 2014 20:11:46 +0000 (21:11 +0100)
Patch: Chris Bannister

lib/go/thrift/server_socket.go
lib/go/thrift/simple_server.go

index 1a01095..4c80714 100644 (file)
@@ -21,6 +21,7 @@ package thrift
 
 import (
        "net"
+       "sync"
        "time"
 )
 
@@ -28,7 +29,10 @@ type TServerSocket struct {
        listener      net.Listener
        addr          net.Addr
        clientTimeout time.Duration
-       interrupted   bool
+
+       // Protects the interrupted value to make it thread safe.
+       mu          sync.RWMutex
+       interrupted bool
 }
 
 func NewTServerSocket(listenAddr string) (*TServerSocket, error) {
@@ -56,7 +60,11 @@ func (p *TServerSocket) Listen() error {
 }
 
 func (p *TServerSocket) Accept() (TTransport, error) {
-       if p.interrupted {
+       p.mu.RLock()
+       interrupted := p.interrupted
+       p.mu.RUnlock()
+
+       if interrupted {
                return nil, errTransportInterrupted
        }
        if p.listener == nil {
@@ -102,6 +110,9 @@ func (p *TServerSocket) Close() error {
 }
 
 func (p *TServerSocket) Interrupt() error {
+       p.mu.Lock()
        p.interrupted = true
+       p.mu.Unlock()
+
        return nil
 }
index b5cb0e1..521394c 100644 (file)
@@ -25,7 +25,7 @@ import (
 
 // Simple, non-concurrent server for testing.
 type TSimpleServer struct {
-       stopped bool
+       quit chan struct{}
 
        processorFactory       TProcessorFactory
        serverTransport        TServerTransport
@@ -78,12 +78,14 @@ func NewTSimpleServerFactory4(processorFactory TProcessorFactory, serverTranspor
 }
 
 func NewTSimpleServerFactory6(processorFactory TProcessorFactory, serverTransport TServerTransport, inputTransportFactory TTransportFactory, outputTransportFactory TTransportFactory, inputProtocolFactory TProtocolFactory, outputProtocolFactory TProtocolFactory) *TSimpleServer {
-       return &TSimpleServer{processorFactory: processorFactory,
+       return &TSimpleServer{
+               processorFactory:       processorFactory,
                serverTransport:        serverTransport,
                inputTransportFactory:  inputTransportFactory,
                outputTransportFactory: outputTransportFactory,
                inputProtocolFactory:   inputProtocolFactory,
                outputProtocolFactory:  outputProtocolFactory,
+               quit: make(chan struct{}, 1),
        }
 }
 
@@ -112,12 +114,19 @@ func (p *TSimpleServer) OutputProtocolFactory() TProtocolFactory {
 }
 
 func (p *TSimpleServer) Serve() error {
-       p.stopped = false
        err := p.serverTransport.Listen()
        if err != nil {
                return err
        }
-       for !p.stopped {
+
+loop:
+       for {
+               select {
+               case <-p.quit:
+                       break loop
+               default:
+               }
+
                client, err := p.serverTransport.Accept()
                if err != nil {
                        log.Println("Accept err: ", err)
@@ -134,7 +143,7 @@ func (p *TSimpleServer) Serve() error {
 }
 
 func (p *TSimpleServer) Stop() error {
-       p.stopped = true
+       p.quit <- struct{}{}
        p.serverTransport.Interrupt()
        return nil
 }