From: Bryan Duxbury Date: Tue, 22 Feb 2011 18:12:06 +0000 (+0000) Subject: THRIFT-151. cpp: TSSLServerSocket and TSSLSocket implementation X-Git-Tag: 0.7.0~180 X-Git-Url: https://source.supwisdom.com/gerrit/gitweb?a=commitdiff_plain;h=cd9aea1136d9a51b2ce53a3de5da09359c9756e2;p=common%2Fthrift.git THRIFT-151. cpp: TSSLServerSocket and TSSLSocket implementation This patch adds an implementation of the above ssl sockets. Patch: Ping Li, Kevin Worth, Rowan Kerr git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1073441 13f79535-47bb-0310-9956-ffa450edef68 --- diff --git a/README.SSL b/README.SSL new file mode 100644 index 00000000..33a4dfec --- /dev/null +++ b/README.SSL @@ -0,0 +1,135 @@ +Notes on Thrift/SSL + +Author: Ping Li + +1. Scope + + This SSL only supports blocking mode socket I/O. It can only be used with + TSimpleServer, TThreadedServer, and TThreadPoolServer. + +2. Implementation + + There're two main classes TSSLSocketFactory and TSSLSocket. Instances of + TSSLSocket are always created from TSSLSocketFactory. + + PosixSSLThreadFactory creates PosixSSLThread. The only difference from the + PthreadThread type is that it cleanups OpenSSL error queue upon exiting + the thread. Ideally, OpenSSL APIs should only be called from PosixSSLThread. + +3. How to use SSL APIs + + // This is for demo. In real code, typically only one TSSLSocketFactory + // instance is needed. + shared_ptr getSSLSocketFactory() { + shared_ptr factory(new TSSLSocketFactory()); + // client: load trusted certificates + factory->loadTrustedCertificates("my-trusted-ca-certificates.pem"); + // client: optionally set your own access manager, otherwise, + // the default client access manager will be loaded. + + factory->loadCertificate("my-certificate-signed-by-ca.pem"); + factory->loadPrivateKey("my-private-key.pem"); + // server: optionally setup access manager + // shared_ptr accessManager(new MyAccessManager); + // factory->access(acessManager); + ... + } + + // client code sample + shared_ptr factory = getSSLScoketFactory(); + shared_ptr socket = factory.createSocket(host, port); + shared_ptr transport(new TBufferedTransport(socket)); + ... + + // server code sample + shared_ptr factory = getSSLSocketFactory(); + shared_ptr socket(new TSSLServerSocket(port, factory)); + shared_ptr transportFactory(new TBufferedTransportFactory)); + ... + +4. AccessManager + + AccessManager defines a callback interface. It has three callback methods: + + (a) Decision verify(const sockaddr_storage& sa); + (b) Decision verify(const string& host, const char* name, int size); + (c) Decision verify(const sockaddr_storage& sa, const char* data, int size); + + After SSL handshake completes, additional checks are conducted. Application + is given the chance to decide whether or not to continue the conversation + with the remote. Application is inqueried through the above three "verify" + method. They are called at different points of the verification process. + + Decisions can be one of ALLOW, DENY, and SKIP. ALLOW and DENY means the + conversation should be continued or disconnected, respectively. ALLOW and + DENY decision stops the verification process. SKIP means there's no decision + based on the given input, continue the verification process. + + First, (a) is called with the remote IP. It is called once at the beginning. + "sa" is the IP address of the remote peer. + + Then, the certificate of remote peer is loaded. SubjectAltName extensions + are extracted and sent to application for verification. When a DNS + subjectAltName field is extracted, (b) is called. When an IP subjectAltName + field is extracted, (c) is called. + + The "host" in (b) is the value from TSocket::getHost() if this is a client + side socket, or TScoket::getPeerHost() if this is a server side socket. The + reason is client side socket initiates the connection. TSocket::getHost() + is the remote host name. On server side, the remote host name is unknown + unless it's retrieved through TSocket::getPeerHost(). Either way, "host" + should be the remote host name. Keep in mind, if TSocket::getPeerHost() + failed, it would return the remote host name in numeric format. + + If all subjectAltName extensions were "skipped", the common name field would + be checked. It is sent to application through (c), where "sa" is the remote + IP address. "data" is the IP address extracted from subjectAltName IP + extension, and "size" is the length of the extension data. + + If any of the above "verify" methods returned a decision ALLOW or DENY, the + verification process would be stopped. + + If any of the above "verify" methods returned SKIP, that decision would be + ignored and the verification process would move on till the last item is + examined. At that point, if there's still no decision, the connection is + terminated. + + Thread safety, an access manager should not store state information if it's + to be used by many SSL sockets. + +5. SIGPIPE signal + + Applications running OpenSSL over network connections may crash if SIGPIPE + is not ignored. This happens when they receive a connection reset by remote + peer exception, which somehow triggers a SIGPIPE signal. If not handled, + this signal would kill the application. + +6. How to run test client/server in SSL mode + + The server expects the followings from the current working directory, + - "server-certificate.pem" + - "server-private-key.pem" + + The client loads "trusted-ca-certificate.pem" from current directory. + + The file names are hard coded in the source code. You need to create these + certificates before you can run the test code in SSL mode. Make sure at least + one of the followings is included in "server-certificate.pem", + - subjectAltName, DNS localhost + - subjectAltName, IP 127.0.0.1 + - common name, localhost + + Run, + - "./test_server --ssl" to start server + - "./test_client --ssl" to run client + + If "-h " is used to run client, the above "localhost" in the above + server-certificate.pem has to be replaced with that host name. + +7. TSSLSocketFactory::randomize() + + The default implementation of OpenSSLSocketFactory::randomize() simply calls + OpenSSL's RAND_poll() when OpenSSL library is first initialized. + + The PRNG seed is key to the application security. This method should be + overriden if it's not strong enough for you. diff --git a/lib/cpp/Makefile.am b/lib/cpp/Makefile.am index 402a294a..b085ac40 100644 --- a/lib/cpp/Makefile.am +++ b/lib/cpp/Makefile.am @@ -60,8 +60,10 @@ libthrift_la_SOURCES = src/Thrift.cpp \ src/transport/THttpClient.cpp \ src/transport/THttpServer.cpp \ src/transport/TSocket.cpp \ + src/transport/TSSLSocket.cpp \ src/transport/TSocketPool.cpp \ src/transport/TServerSocket.cpp \ + src/transport/TSSLServerSocket.cpp \ src/transport/TTransportUtils.cpp \ src/transport/TBufferTransports.cpp \ src/server/TServer.cpp \ @@ -125,11 +127,13 @@ include_transport_HEADERS = \ src/transport/TFileTransport.h \ src/transport/TSimpleFileTransport.h \ src/transport/TServerSocket.h \ + src/transport/TSSLServerSocket.h \ src/transport/TServerTransport.h \ src/transport/THttpTransport.h \ src/transport/THttpClient.h \ src/transport/THttpServer.h \ src/transport/TSocket.h \ + src/transport/TSSLSocket.h \ src/transport/TSocketPool.h \ src/transport/TVirtualTransport.h \ src/transport/TTransport.h \ diff --git a/lib/cpp/src/transport/TSSLServerSocket.cpp b/lib/cpp/src/transport/TSSLServerSocket.cpp new file mode 100644 index 00000000..ed4b648b --- /dev/null +++ b/lib/cpp/src/transport/TSSLServerSocket.cpp @@ -0,0 +1,36 @@ +// Copyright (c) 2009- Facebook +// Distributed under the Thrift Software License +// +// See accompanying file LICENSE or visit the Thrift site at: +// http://developers.facebook.com/thrift/ + +#include "TSSLServerSocket.h" +#include "TSSLSocket.h" + +namespace apache { namespace thrift { namespace transport { + +using namespace boost; + +/** + * SSL server socket implementation. + * + * @author Ping Li + */ +TSSLServerSocket::TSSLServerSocket(int port, + shared_ptr factory): + TServerSocket(port), factory_(factory) { + factory_->server(true); +} + +TSSLServerSocket::TSSLServerSocket(int port, int sendTimeout, int recvTimeout, + shared_ptr factory): + TServerSocket(port, sendTimeout, recvTimeout), + factory_(factory) { + factory_->server(true); +} + +shared_ptr TSSLServerSocket::createSocket(int client) { + return factory_->createSocket(client); +} + +}}} diff --git a/lib/cpp/src/transport/TSSLServerSocket.h b/lib/cpp/src/transport/TSSLServerSocket.h new file mode 100644 index 00000000..36f895ca --- /dev/null +++ b/lib/cpp/src/transport/TSSLServerSocket.h @@ -0,0 +1,48 @@ +// Copyright (c) 2009- Facebook +// Distributed under the Thrift Software License +// +// See accompanying file LICENSE or visit the Thrift site at: +// http://developers.facebook.com/thrift/ + +#ifndef _THRIFT_TRANSPORT_TSSLSERVERSOCKET_H_ +#define _THRIFT_TRANSPORT_TSSLSERVERSOCKET_H_ 1 + +#include +#include "TServerSocket.h" + +namespace apache { namespace thrift { namespace transport { + +class TSSLSocketFactory; + +/** + * Server socket that accepts SSL connections. + * + * @author Ping Li + */ +class TSSLServerSocket: public TServerSocket { + public: + /** + * Constructor. + * + * @param port Listening port + * @param factory SSL socket factory implementation + */ + TSSLServerSocket(int port, boost::shared_ptr factory); + /** + * Constructor. + * + * @param port Listening port + * @param sendTimeout Socket send timeout + * @param recvTimeout Socket receive timeout + * @param factory SSL socket factory implementation + */ + TSSLServerSocket(int port, int sendTimeout, int recvTimeout, + boost::shared_ptr factory); + protected: + boost::shared_ptr createSocket(int socket); + boost::shared_ptr factory_; +}; + +}}} + +#endif diff --git a/lib/cpp/src/transport/TSSLSocket.cpp b/lib/cpp/src/transport/TSSLSocket.cpp new file mode 100644 index 00000000..f84f8067 --- /dev/null +++ b/lib/cpp/src/transport/TSSLSocket.cpp @@ -0,0 +1,645 @@ +// Copyright (c) 2009- Facebook +// Distributed under the Thrift Software License +// +// See accompanying file LICENSE or visit the Thrift site at: +// http://developers.facebook.com/thrift/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "concurrency/Mutex.h" +#include "TSSLSocket.h" + +#define OPENSSL_VERSION_NO_THREAD_ID 0x10000000L + +using namespace std; +using namespace boost; +using namespace apache::thrift::concurrency; + +struct CRYPTO_dynlock_value { + Mutex mutex; +}; + +namespace apache { namespace thrift { namespace transport { + + +static void buildErrors(string& message, int error = 0); +static bool matchName(const char* host, const char* pattern, int size); +static char uppercase(char c); + +// SSLContext implementation +SSLContext::SSLContext() { + ctx_ = SSL_CTX_new(TLSv1_method()); + if (ctx_ == NULL) { + string errors; + buildErrors(errors); + throw TSSLException("SSL_CTX_new: " + errors); + } + SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY); +} + +SSLContext::~SSLContext() { + if (ctx_ != NULL) { + SSL_CTX_free(ctx_); + ctx_ = NULL; + } +} + +SSL* SSLContext::createSSL() { + SSL* ssl = SSL_new(ctx_); + if (ssl == NULL) { + string errors; + buildErrors(errors); + throw TSSLException("SSL_new: " + errors); + } + return ssl; +} + +// TSSLSocket implementation +TSSLSocket::TSSLSocket(shared_ptr ctx): + TSocket(), server_(false), ssl_(NULL), ctx_(ctx) { +} + +TSSLSocket::TSSLSocket(shared_ptr ctx, int socket): + TSocket(socket), server_(false), ssl_(NULL), ctx_(ctx) { +} + +TSSLSocket::TSSLSocket(shared_ptr ctx, string host, int port): + TSocket(host, port), server_(false), ssl_(NULL), ctx_(ctx) { +} + +TSSLSocket::~TSSLSocket() { + close(); +} + +bool TSSLSocket::isOpen() { + if (ssl_ == NULL || !TSocket::isOpen()) { + return false; + } + int shutdown = SSL_get_shutdown(ssl_); + bool shutdownReceived = (shutdown & SSL_RECEIVED_SHUTDOWN); + bool shutdownSent = (shutdown & SSL_SENT_SHUTDOWN); + if (shutdownReceived && shutdownSent) { + return false; + } + return true; +} + +bool TSSLSocket::peek() { + if (!isOpen()) { + return false; + } + checkHandshake(); + int rc; + uint8_t byte; + rc = SSL_peek(ssl_, &byte, 1); + if (rc < 0) { + int errno_copy = errno; + string errors; + buildErrors(errors, errno_copy); + throw TSSLException("SSL_peek: " + errors); + } + if (rc == 0) { + ERR_clear_error(); + } + return (rc > 0); +} + +void TSSLSocket::open() { + if (isOpen() || server()) { + throw TTransportException(TTransportException::BAD_ARGS); + } + TSocket::open(); +} + +void TSSLSocket::close() { + if (ssl_ != NULL) { + int rc = SSL_shutdown(ssl_); + if (rc == 0) { + rc = SSL_shutdown(ssl_); + } + if (rc < 0) { + int errno_copy = errno; + string errors; + buildErrors(errors, errno_copy); + GlobalOutput(("SSL_shutdown: " + errors).c_str()); + } + SSL_free(ssl_); + ssl_ = NULL; + ERR_remove_state(0); + } + TSocket::close(); +} + +uint32_t TSSLSocket::read(uint8_t* buf, uint32_t len) { + checkHandshake(); + int32_t bytes = 0; + for (int32_t retries = 0; retries < maxRecvRetries_; retries++){ + bytes = SSL_read(ssl_, buf, len); + if (bytes >= 0) + break; + int errno_copy = errno; + if (SSL_get_error(ssl_, bytes) == SSL_ERROR_SYSCALL) { + if (ERR_get_error() == 0 && errno_copy == EINTR) { + continue; + } + } + string errors; + buildErrors(errors, errno_copy); + throw TSSLException("SSL_read: " + errors); + } + return bytes; +} + +void TSSLSocket::write(const uint8_t* buf, uint32_t len) { + checkHandshake(); + // loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX. + uint32_t written = 0; + while (written < len) { + int32_t bytes = SSL_write(ssl_, &buf[written], len - written); + if (bytes <= 0) { + int errno_copy = errno; + string errors; + buildErrors(errors, errno_copy); + throw TSSLException("SSL_write: " + errors); + } + written += bytes; + } +} + +void TSSLSocket::flush() { + // Don't throw exception if not open. Thrift servers close socket twice. + if (ssl_ == NULL) { + return; + } + checkHandshake(); + BIO* bio = SSL_get_wbio(ssl_); + if (bio == NULL) { + throw TSSLException("SSL_get_wbio returns NULL"); + } + if (BIO_flush(bio) != 1) { + int errno_copy = errno; + string errors; + buildErrors(errors, errno_copy); + throw TSSLException("BIO_flush: " + errors); + } +} + +void TSSLSocket::checkHandshake() { + if (!TSocket::isOpen()) { + throw TTransportException(TTransportException::NOT_OPEN); + } + if (ssl_ != NULL) { + return; + } + ssl_ = ctx_->createSSL(); + SSL_set_fd(ssl_, socket_); + int rc; + if (server()) { + rc = SSL_accept(ssl_); + } else { + rc = SSL_connect(ssl_); + } + if (rc <= 0) { + int errno_copy = errno; + string fname(server() ? "SSL_accept" : "SSL_connect"); + string errors; + buildErrors(errors, errno_copy); + throw TSSLException(fname + ": " + errors); + } + authorize(); +} + +void TSSLSocket::authorize() { + int rc = SSL_get_verify_result(ssl_); + if (rc != X509_V_OK) { // verify authentication result + throw TSSLException(string("SSL_get_verify_result(), ") + + X509_verify_cert_error_string(rc)); + } + + X509* cert = SSL_get_peer_certificate(ssl_); + if (cert == NULL) { + // certificate is not present + if (SSL_get_verify_mode(ssl_) & SSL_VERIFY_FAIL_IF_NO_PEER_CERT) { + throw TSSLException("authorize: required certificate not present"); + } + // certificate was optional: didn't intend to authorize remote + if (server() && access_ != NULL) { + throw TSSLException("authorize: certificate required for authorization"); + } + return; + } + // certificate is present + if (access_ == NULL) { + X509_free(cert); + return; + } + // both certificate and access manager are present + + string host; + sockaddr_storage sa = {}; + socklen_t saLength = sizeof(sa); + + if (getpeername(socket_, (sockaddr*)&sa, &saLength) != 0) { + sa.ss_family = AF_UNSPEC; + } + + AccessManager::Decision decision = access_->verify(sa); + + if (decision != AccessManager::SKIP) { + X509_free(cert); + if (decision != AccessManager::ALLOW) { + throw TSSLException("authorize: access denied based on remote IP"); + } + return; + } + + // extract subjectAlternativeName + STACK_OF(GENERAL_NAME)* alternatives = (STACK_OF(GENERAL_NAME)*) + X509_get_ext_d2i(cert, NID_subject_alt_name, NULL, NULL); + if (alternatives != NULL) { + const int count = sk_GENERAL_NAME_num(alternatives); + for (int i = 0; decision == AccessManager::SKIP && i < count; i++) { + const GENERAL_NAME* name = sk_GENERAL_NAME_value(alternatives, i); + if (name == NULL) { + continue; + } + char* data = (char*)ASN1_STRING_data(name->d.ia5); + int length = ASN1_STRING_length(name->d.ia5); + switch (name->type) { + case GEN_DNS: + if (host.empty()) { + host = (server() ? getPeerHost() : getHost()); + } + decision = access_->verify(host, data, length); + break; + case GEN_IPADD: + decision = access_->verify(sa, data, length); + break; + } + } + sk_GENERAL_NAME_pop_free(alternatives, GENERAL_NAME_free); + } + + if (decision != AccessManager::SKIP) { + X509_free(cert); + if (decision != AccessManager::ALLOW) { + throw TSSLException("authorize: access denied"); + } + return; + } + + // extract commonName + X509_NAME* name = X509_get_subject_name(cert); + if (name != NULL) { + X509_NAME_ENTRY* entry; + unsigned char* utf8; + int last = -1; + while (decision == AccessManager::SKIP) { + last = X509_NAME_get_index_by_NID(name, NID_commonName, last); + if (last == -1) + break; + entry = X509_NAME_get_entry(name, last); + if (entry == NULL) + continue; + ASN1_STRING* common = X509_NAME_ENTRY_get_data(entry); + int size = ASN1_STRING_to_UTF8(&utf8, common); + if (host.empty()) { + host = (server() ? getHost() : getHost()); + } + decision = access_->verify(host, (char*)utf8, size); + OPENSSL_free(utf8); + } + } + X509_free(cert); + if (decision != AccessManager::ALLOW) { + throw TSSLException("authorize: cannot authorize peer"); + } +} + +// TSSLSocketFactory implementation +bool TSSLSocketFactory::initialized = false; +uint64_t TSSLSocketFactory::count_ = 0; +Mutex TSSLSocketFactory::mutex_; + +TSSLSocketFactory::TSSLSocketFactory(): server_(false) { + Guard guard(mutex_); + if (count_ == 0) { + initializeOpenSSL(); + randomize(); + } + count_++; + ctx_ = shared_ptr(new SSLContext); +} + +TSSLSocketFactory::~TSSLSocketFactory() { + Guard guard(mutex_); + count_--; + if (count_ == 0) { + cleanupOpenSSL(); + } +} + +shared_ptr TSSLSocketFactory::createSocket() { + shared_ptr ssl(new TSSLSocket(ctx_)); + setup(ssl); + return ssl; +} + +shared_ptr TSSLSocketFactory::createSocket(int socket) { + shared_ptr ssl(new TSSLSocket(ctx_, socket)); + setup(ssl); + return ssl; +} + +shared_ptr TSSLSocketFactory::createSocket(const string& host, + int port) { + shared_ptr ssl(new TSSLSocket(ctx_, host, port)); + setup(ssl); + return ssl; +} + +void TSSLSocketFactory::setup(shared_ptr ssl) { + ssl->server(server()); + if (access_ == NULL && !server()) { + access_ = shared_ptr(new DefaultClientAccessManager); + } + if (access_ != NULL) { + ssl->access(access_); + } +} + +void TSSLSocketFactory::ciphers(const string& enable) { + int rc = SSL_CTX_set_cipher_list(ctx_->get(), enable.c_str()); + if (ERR_peek_error() != 0) { + string errors; + buildErrors(errors); + throw TSSLException("SSL_CTX_set_cipher_list: " + errors); + } + if (rc == 0) { + throw TSSLException("None of specified ciphers are supported"); + } +} + +void TSSLSocketFactory::authenticate(bool required) { + int mode; + if (required) { + mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE; + } else { + mode = SSL_VERIFY_NONE; + } + SSL_CTX_set_verify(ctx_->get(), mode, NULL); +} + +void TSSLSocketFactory::loadCertificate(const char* path, const char* format) { + if (path == NULL || format == NULL) { + throw TTransportException(TTransportException::BAD_ARGS, + "loadCertificateChain: either or is NULL"); + } + if (strcmp(format, "PEM") == 0) { + if (SSL_CTX_use_certificate_chain_file(ctx_->get(), path) == 0) { + int errno_copy = errno; + string errors; + buildErrors(errors, errno_copy); + throw TSSLException("SSL_CTX_use_certificate_chain_file: " + errors); + } + } else { + throw TSSLException("Unsupported certificate format: " + string(format)); + } +} + +void TSSLSocketFactory::loadPrivateKey(const char* path, const char* format) { + if (path == NULL || format == NULL) { + throw TTransportException(TTransportException::BAD_ARGS, + "loadPrivateKey: either or is NULL"); + } + if (strcmp(format, "PEM") == 0) { + if (SSL_CTX_use_PrivateKey_file(ctx_->get(), path, SSL_FILETYPE_PEM) == 0) { + int errno_copy = errno; + string errors; + buildErrors(errors, errno_copy); + throw TSSLException("SSL_CTX_use_PrivateKey_file: " + errors); + } + } +} + +void TSSLSocketFactory::loadTrustedCertificates(const char* path) { + if (path == NULL) { + throw TTransportException(TTransportException::BAD_ARGS, + "loadTrustedCertificates: is NULL"); + } + if (SSL_CTX_load_verify_locations(ctx_->get(), path, NULL) == 0) { + int errno_copy = errno; + string errors; + buildErrors(errors, errno_copy); + throw TSSLException("SSL_CTX_load_verify_locations: " + errors); + } +} + +void TSSLSocketFactory::randomize() { + RAND_poll(); +} + +void TSSLSocketFactory::overrideDefaultPasswordCallback() { + SSL_CTX_set_default_passwd_cb(ctx_->get(), passwordCallback); + SSL_CTX_set_default_passwd_cb_userdata(ctx_->get(), this); +} + +int TSSLSocketFactory::passwordCallback(char* password, + int size, + int, + void* data) { + TSSLSocketFactory* factory = (TSSLSocketFactory*)data; + string userPassword; + factory->getPassword(userPassword, size); + int length = userPassword.size(); + if (length > size) { + length = size; + } + strncpy(password, userPassword.c_str(), length); + return length; +} + +static shared_array mutexes; + +static void callbackLocking(int mode, int n, const char*, int) { + if (mode & CRYPTO_LOCK) { + mutexes[n].lock(); + } else { + mutexes[n].unlock(); + } +} + +#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID) +static unsigned long callbackThreadID() { + return reinterpret_cast(pthread_self()); +} +#endif + +static CRYPTO_dynlock_value* dyn_create(const char*, int) { + return new CRYPTO_dynlock_value; +} + +static void dyn_lock(int mode, + struct CRYPTO_dynlock_value* lock, + const char*, int) { + if (lock != NULL) { + if (mode & CRYPTO_LOCK) { + lock->mutex.lock(); + } else { + lock->mutex.unlock(); + } + } +} + +static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) { + delete lock; +} + +void TSSLSocketFactory::initializeOpenSSL() { + if (initialized) { + return; + } + initialized = true; + SSL_library_init(); + SSL_load_error_strings(); + // static locking + mutexes = shared_array(new Mutex[::CRYPTO_num_locks()]); + if (mutexes == NULL) { + throw TTransportException(TTransportException::INTERNAL_ERROR, + "initializeOpenSSL() failed, " + "out of memory while creating mutex array"); + } +#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID) + CRYPTO_set_id_callback(callbackThreadID); +#endif + CRYPTO_set_locking_callback(callbackLocking); + // dynamic locking + CRYPTO_set_dynlock_create_callback(dyn_create); + CRYPTO_set_dynlock_lock_callback(dyn_lock); + CRYPTO_set_dynlock_destroy_callback(dyn_destroy); +} + +void TSSLSocketFactory::cleanupOpenSSL() { + if (!initialized) { + return; + } + initialized = false; +#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID) + CRYPTO_set_id_callback(NULL); +#endif + CRYPTO_set_locking_callback(NULL); + CRYPTO_set_dynlock_create_callback(NULL); + CRYPTO_set_dynlock_lock_callback(NULL); + CRYPTO_set_dynlock_destroy_callback(NULL); + CRYPTO_cleanup_all_ex_data(); + ERR_free_strings(); + EVP_cleanup(); + ERR_remove_state(0); + mutexes.reset(); +} + +// extract error messages from error queue +void buildErrors(string& errors, int errno_copy) { + unsigned long errorCode; + char message[256]; + + errors.reserve(512); + while ((errorCode = ERR_get_error()) != 0) { + if (!errors.empty()) { + errors += "; "; + } + const char* reason = ERR_reason_error_string(errorCode); + if (reason == NULL) { + snprintf(message, sizeof(message) - 1, "SSL error # %lu", errorCode); + reason = message; + } + errors += reason; + } + if (errors.empty()) { + if (errno_copy != 0) { + errors += TOutput::strerror_s(errno_copy); + } + } + if (errors.empty()) { + errors = "error code: " + lexical_cast(errno_copy); + } +} + +/** + * Default implementation of AccessManager + */ +Decision DefaultClientAccessManager::verify(const sockaddr_storage& sa) + throw() { return SKIP; } + +Decision DefaultClientAccessManager::verify(const string& host, + const char* name, + int size) throw() { + if (host.empty() || name == NULL || size <= 0) { + return SKIP; + } + return (matchName(host.c_str(), name, size) ? ALLOW : SKIP); +} + +Decision DefaultClientAccessManager::verify(const sockaddr_storage& sa, + const char* data, + int size) throw() { + bool match = false; + if (sa.ss_family == AF_INET && size == sizeof(in_addr)) { + match = (memcmp(&((sockaddr_in*)&sa)->sin_addr, data, size) == 0); + } else if (sa.ss_family == AF_INET6 && size == sizeof(in6_addr)) { + match = (memcmp(&((sockaddr_in6*)&sa)->sin6_addr, data, size) == 0); + } + return (match ? ALLOW : SKIP); +} + +/** + * Match a name with a pattern. The pattern may include wildcard. A single + * wildcard "*" can match up to one component in the domain name. + * + * @param host Host name, typically the name of the remote host + * @param pattern Name retrieved from certificate + * @param size Size of "pattern" + * @return True, if "host" matches "pattern". False otherwise. + */ +bool matchName(const char* host, const char* pattern, int size) { + bool match = false; + int i = 0, j = 0; + while (i < size && host[j] != '\0') { + if (uppercase(pattern[i]) == uppercase(host[j])) { + i++; + j++; + continue; + } + if (pattern[i] == '*') { + while (host[j] != '.' && host[j] != '\0') { + j++; + } + i++; + continue; + } + break; + } + if (i == size && host[j] == '\0') { + match = true; + } + return match; + +} + +// This is to work around the Turkish locale issue, i.e., +// toupper('i') != toupper('I') if locale is "tr_TR" +char uppercase (char c) { + if ('a' <= c && c <= 'z') { + return c + ('A' - 'a'); + } + return c; +} + +}}} diff --git a/lib/cpp/src/transport/TSSLSocket.h b/lib/cpp/src/transport/TSSLSocket.h new file mode 100644 index 00000000..58e39345 --- /dev/null +++ b/lib/cpp/src/transport/TSSLSocket.h @@ -0,0 +1,304 @@ +// Copyright (c) 2009- Facebook +// Distributed under the Thrift Software License +// +// See accompanying file LICENSE or visit the Thrift site at: +// http://developers.facebook.com/thrift/ + +#ifndef _THRIFT_TRANSPORT_TSSLSOCKET_H_ +#define _THRIFT_TRANSPORT_TSSLSOCKET_H_ 1 + +#include +#include +#include +#include "concurrency/Mutex.h" +#include "TSocket.h" + +namespace apache { namespace thrift { namespace transport { + +class AccessManager; +class SSLContext; + +/** + * OpenSSL implementation for SSL socket interface. + * + * @author Ping Li + */ +class TSSLSocket: public TSocket { + public: + ~TSSLSocket(); + /** + * TTransport interface. + */ + bool isOpen(); + bool peek(); + void open(); + void close(); + uint32_t read(uint8_t* buf, uint32_t len); + void write(const uint8_t* buf, uint32_t len); + void flush(); + /** + * Set whether to use client or server side SSL handshake protocol. + * + * @param flag Use server side handshake protocol if true. + */ + void server(bool flag) { server_ = flag; } + /** + * Determine whether the SSL socket is server or client mode. + */ + bool server() const { return server_; } + /** + * Set AccessManager. + * + * @param manager Instance of AccessManager + */ + virtual void access(boost::shared_ptr manager) { + access_ = manager; + } +protected: + /** + * Constructor. + */ + TSSLSocket(boost::shared_ptr ctx); + /** + * Constructor, create an instance of TSSLSocket given an existing socket. + * + * @param socket An existing socket + */ + TSSLSocket(boost::shared_ptr ctx, int socket); + /** + * Constructor. + * + * @param host Remote host name + * @param port Remote port number + */ + TSSLSocket(boost::shared_ptr ctx, + std::string host, + int port); + /** + * Authorize peer access after SSL handshake completes. + */ + virtual void authorize(); + /** + * Initiate SSL handshake if not already initiated. + */ + void checkHandshake(); + + bool server_; + SSL* ssl_; + boost::shared_ptr ctx_; + boost::shared_ptr access_; + friend class TSSLSocketFactory; +}; + +/** + * SSL socket factory. SSL sockets should be created via SSL factory. + */ +class TSSLSocketFactory { + public: + /** + * Constructor/Destructor + */ + TSSLSocketFactory(); + virtual ~TSSLSocketFactory(); + /** + * Create an instance of TSSLSocket with a fresh new socket. + */ + virtual boost::shared_ptr createSocket(); + /** + * Create an instance of TSSLSocket with the given socket. + * + * @param socket An existing socket. + */ + virtual boost::shared_ptr createSocket(int socket); + /** + * Create an instance of TSSLSocket. + * + * @param host Remote host to be connected to + * @param port Remote port to be connected to + */ + virtual boost::shared_ptr createSocket(const std::string& host, + int port); + /** + * Set ciphers to be used in SSL handshake process. + * + * @param ciphers A list of ciphers + */ + virtual void ciphers(const std::string& enable); + /** + * Enable/Disable authentication. + * + * @param required Require peer to present valid certificate if true + */ + virtual void authenticate(bool required); + /** + * Load server certificate. + * + * @param path Path to the certificate file + * @param format Certificate file format + */ + virtual void loadCertificate(const char* path, const char* format = "PEM"); + /** + * Load private key. + * + * @param path Path to the private key file + * @param format Private key file format + */ + virtual void loadPrivateKey(const char* path, const char* format = "PEM"); + /** + * Load trusted certificates from specified file. + * + * @param path Path to trusted certificate file + */ + virtual void loadTrustedCertificates(const char* path); + /** + * Default randomize method. + */ + virtual void randomize(); + /** + * Override default OpenSSL password callback with getPassword(). + */ + void overrideDefaultPasswordCallback(); + /** + * Set/Unset server mode. + * + * @param flag Server mode if true + */ + virtual void server(bool flag) { server_ = flag; } + /** + * Determine whether the socket is in server or client mode. + * + * @return true, if server mode, or, false, if client mode + */ + virtual bool server() const { return server_; } + /** + * Set AccessManager. + * + * @param manager The AccessManager instance + */ + virtual void access(boost::shared_ptr manager) { + access_ = manager; + } + protected: + boost::shared_ptr ctx_; + + static void initializeOpenSSL(); + static void cleanupOpenSSL(); + /** + * Override this method for custom password callback. It may be called + * multiple times at any time during a session as necessary. + * + * @param password Pass collected password to OpenSSL + * @param size Maximum length of password including NULL character + */ + virtual void getPassword(std::string& password, int size) { } + private: + bool server_; + boost::shared_ptr access_; + static bool initialized; + static concurrency::Mutex mutex_; + static uint64_t count_; + void setup(boost::shared_ptr ssl); + static int passwordCallback(char* password, int size, int, void* data); +}; + +/** + * SSL exception. + */ +class TSSLException: public TTransportException { + public: + TSSLException(const std::string& message): + TTransportException(TTransportException::INTERNAL_ERROR, message) {} + + virtual const char* what() const throw() { + if (message_.empty()) { + return "TSSLException"; + } else { + return message_.c_str(); + } + } +}; + +/** + * Wrap OpenSSL SSL_CTX into a class. + */ +class SSLContext { + public: + SSLContext(); + virtual ~SSLContext(); + SSL* createSSL(); + SSL_CTX* get() { return ctx_; } + private: + SSL_CTX* ctx_; +}; + +/** + * Callback interface for access control. It's meant to verify the remote host. + * It's constructed when application starts and set to TSSLSocketFactory + * instance. It's passed onto all TSSLSocket instances created by this factory + * object. + */ +class AccessManager { + public: + enum Decision { + DENY = -1, // deny access + SKIP = 0, // cannot make decision, move on to next (if any) + ALLOW = 1, // allow access + }; + /** + * Destructor + */ + virtual ~AccessManager() {} + /** + * Determine whether the peer should be granted access or not. It's called + * once after the SSL handshake completes successfully, before peer certificate + * is examined. + * + * If a valid decision (ALLOW or DENY) is returned, the peer certificate is + * not to be verified. + * + * @param sa Peer IP address + * @return True if the peer is trusted, false otherwise + */ + virtual Decision verify(const sockaddr_storage& sa) throw() { return DENY; } + /** + * Determine whether the peer should be granted access or not. It's called + * every time a DNS subjectAltName/common name is extracted from peer's + * certificate. + * + * @param host Client mode: host name returned by TSocket::getHost() + * Server mode: host name returned by TSocket::getPeerHost() + * @param name SubjectAltName or common name extracted from peer certificate + * @param size Length of name + * @return True if the peer is trusted, false otherwise + * + * Note: The "name" parameter may be UTF8 encoded. + */ + virtual Decision verify(const std::string& host, const char* name, int size) + throw() { return DENY; } + /** + * Determine whether the peer should be granted access or not. It's called + * every time an IP subjectAltName is extracted from peer's certificate. + * + * @param sa Peer IP address retrieved from the underlying socket + * @param data IP address extracted from certificate + * @param size Length of the IP address + * @return True if the peer is trusted, false otherwise + */ + virtual Decision verify(const sockaddr_storage& sa, const char* data, int size) + throw() { return DENY; } +}; + +typedef AccessManager::Decision Decision; + +class DefaultClientAccessManager: public AccessManager { + public: + // AccessManager interface + Decision verify(const sockaddr_storage& sa) throw(); + Decision verify(const std::string& host, const char* name, int size) throw(); + Decision verify(const sockaddr_storage& sa, const char* data, int size) throw(); +}; + + +}}} + +#endif diff --git a/lib/cpp/src/transport/TServerSocket.cpp b/lib/cpp/src/transport/TServerSocket.cpp index 8608898f..276b060b 100644 --- a/lib/cpp/src/transport/TServerSocket.cpp +++ b/lib/cpp/src/transport/TServerSocket.cpp @@ -386,7 +386,7 @@ shared_ptr TServerSocket::acceptImpl() { throw TTransportException(TTransportException::UNKNOWN, "fcntl(F_SETFL)", errno_copy); } - shared_ptr client(new TSocket(clientSocket)); + shared_ptr client = createSocket(clientSocket); if (sendTimeout_ > 0) { client->setSendTimeout(sendTimeout_); } @@ -398,6 +398,10 @@ shared_ptr TServerSocket::acceptImpl() { return client; } +shared_ptr TServerSocket::createSocket(int clientSocket) { + return shared_ptr(new TSocket(clientSocket)); +} + void TServerSocket::interrupt() { if (intSock1_ >= 0) { int8_t byte = 0; diff --git a/lib/cpp/src/transport/TServerSocket.h b/lib/cpp/src/transport/TServerSocket.h index 8cd521fb..40a1148e 100644 --- a/lib/cpp/src/transport/TServerSocket.h +++ b/lib/cpp/src/transport/TServerSocket.h @@ -56,6 +56,7 @@ class TServerSocket : public TServerTransport { protected: boost::shared_ptr acceptImpl(); + virtual boost::shared_ptr createSocket(int client); private: int port_; diff --git a/lib/cpp/src/transport/TSocket.h b/lib/cpp/src/transport/TSocket.h index e89059f3..55214916 100644 --- a/lib/cpp/src/transport/TSocket.h +++ b/lib/cpp/src/transport/TSocket.h @@ -70,12 +70,12 @@ class TSocket : public TVirtualTransport { * * @return Is the socket alive? */ - bool isOpen(); + virtual bool isOpen(); /** * Calls select on the socket to see if there is more data available. */ - bool peek(); + virtual bool peek(); /** * Creates and opens the UNIX socket. @@ -92,12 +92,12 @@ class TSocket : public TVirtualTransport { /** * Reads from the underlying socket. */ - uint32_t read(uint8_t* buf, uint32_t len); + virtual uint32_t read(uint8_t* buf, uint32_t len); /** * Writes to the underlying socket. Loops until done or fail. */ - void write(const uint8_t* buf, uint32_t len); + virtual void write(const uint8_t* buf, uint32_t len); /** * Writes to the underlying socket. Does single send() and returns result. diff --git a/test/cpp/src/TestClient.cpp b/test/cpp/src/TestClient.cpp index 897153ab..7e37e856 100644 --- a/test/cpp/src/TestClient.cpp +++ b/test/cpp/src/TestClient.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include "ThriftTest.h" @@ -56,6 +57,7 @@ int main(int argc, char** argv) { int port = 9090; int numTests = 1; bool framed = false; + bool ssl = false; for (int i = 0; i < argc; ++i) { if (strcmp(argv[i], "-h") == 0) { @@ -71,9 +73,22 @@ int main(int argc, char** argv) { numTests = atoi(argv[++i]); } else if (strcmp(argv[i], "-f") == 0) { framed = true; + } else if (strcmp(argv[i], "--ssl") == 0) { + ssl = true; } } + shared_ptr socket; + shared_ptr factory; + if (ssl) { + factory = shared_ptr(new TSSLSocketFactory()); + factory->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + factory->loadTrustedCertificates("./trusted-ca-certificate.pem"); + factory->authenticate(true); + socket = factory->createSocket(host, port); + } else { + socket = shared_ptr(new TSocket(host, port)); + } shared_ptr transport; diff --git a/test/cpp/src/TestServer.cpp b/test/cpp/src/TestServer.cpp index d30475b9..047dd331 100644 --- a/test/cpp/src/TestServer.cpp +++ b/test/cpp/src/TestServer.cpp @@ -34,6 +34,7 @@ #define __STDC_FORMAT_MACROS #include +#include using namespace std; using namespace boost; @@ -326,6 +327,7 @@ int main(int argc, char **argv) { string serverType = "simple"; string protocolType = "binary"; size_t workerCount = 4; + bool ssl = false; ostringstream usage; @@ -391,6 +393,11 @@ int main(int argc, char **argv) { cerr << usage; } + if (args["ssl"] == "true") { + ssl = true; + signal(SIGPIPE, SIG_IGN); + } + // Dispatcher shared_ptr protocolFactory( new TBinaryProtocolFactoryT()); @@ -407,8 +414,18 @@ int main(int argc, char **argv) { } // Transport - shared_ptr serverSocket(new TServerSocket(port)); - + shared_ptr sslSocketFactory; + shared_ptr serverSocket; + + if (ssl) { + sslSocketFactory = shared_ptr(new TSSLSocketFactory()); + sslSocketFactory->loadCertificate("./server-certificate.pem"); + sslSocketFactory->loadPrivateKey("./server-private-key.pem"); + sslSocketFactory->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); + serverSocket = shared_ptr(new TSSLServerSocket(port, sslSocketFactory)); + } else { + serverSocket = shared_ptr(new TServerSocket(port)); + } // Factory shared_ptr transportFactory(new TBufferedTransportFactory());