From d4837129729c85d49476cae08fb0c5ec84a2a811 Mon Sep 17 00:00:00 2001 From: Bryan Duxbury Date: Wed, 8 Sep 2010 00:06:35 +0000 Subject: [PATCH] THRIFT-876. java: Add SASL support This patch adds support for a SASL-secured transport to the Java library. In its current form, it only works for the blocking-IO servers. Patch: Aaron T Meyers git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@993563 13f79535-47bb-0310-9956-ffa450edef68 --- doc/thrift-sasl-spec.txt | 107 ++++ .../src/org/apache/thrift/EncodingUtils.java | 85 ++++ .../transport/TSaslClientTransport.java | 108 ++++ .../transport/TSaslServerTransport.java | 233 +++++++++ .../thrift/transport/TSaslTransport.java | 470 ++++++++++++++++++ .../apache/thrift/server/ServerTestBase.java | 8 +- .../thrift/server/TestNonblockingServer.java | 7 +- .../thrift/transport/TestTSaslTransports.java | 225 +++++++++ 8 files changed, 1232 insertions(+), 11 deletions(-) create mode 100644 doc/thrift-sasl-spec.txt create mode 100644 lib/java/src/org/apache/thrift/EncodingUtils.java create mode 100644 lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java create mode 100644 lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java create mode 100644 lib/java/src/org/apache/thrift/transport/TSaslTransport.java create mode 100644 lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java diff --git a/doc/thrift-sasl-spec.txt b/doc/thrift-sasl-spec.txt new file mode 100644 index 00000000..59bfcf98 --- /dev/null +++ b/doc/thrift-sasl-spec.txt @@ -0,0 +1,107 @@ +A Thrift SASL message shall be a byte array of one of the following forms: + +| 1-byte START status code | 1-byte mechanism name length | variable length mechanism name | 4-byte payload length | variable-length payload | +| 1-byte status code | 4-byte payload length | variable-length payload | + +The length fields shall be interpreted as integers, with the high byte sent +first. This indicates the length of the field immediately following it, not +including the status code or the length bytes. + +The possible status codes are: + +0x01 - START - Hello, let's go on a date. +0x02 - OK - Everything's been going alright so far, let's see each other again. +0x03 - BAD - I understand what you're saying. I really do. I just don't like it. We have to break up. +0x04 - ERROR - We can't go on like this. It's like you're speaking another language. +0x05 - COMPLETE - Will you marry me? + +The Thrift SASL communication will proceed as follows: + +1. The client is configured at instantiation of the transport with a single +underlying SASL security mechanism that it supports. + +2. The server is configured with a mapping of underlying security mechanism +name -> mechanism options. + +3. At connection time, the client will initiate communication by sending the +server a START byte, followed by a 1-byte field indicating the length in bytes +of the underlying security mechanism name that the client would like to use. +This mechanism name shall be 1-20 characters in length, and follow the +specifications for SASL mechanism names specified in RFC 2222. This mechanism +name shall be followed by a 4-byte, potentially zero-value message length word, +followed by a potentially zero-length payload. The payload is determined by the +output byte array of the underlying actual security mechanism, and will be +empty except for those underlying security protocols which implement the +optional SASL initial response. + +4. The server receives this message and, if the mechanism name provided is +among the set of mechanisms this server transport is configured to accept, +appropriate initialization of the underlying security mechanism may take place. +If the mechanism name is not one which the server is configured to support, the +server shall return the BAD byte, followed by a 4-byte, potentially zero-value +message length, followed by the potentially zero-length payload which may be a +status code or message indicating failure. No further communication may take +place via this transport. If the mechanism name is one which the server +supports, then proceed to step 5. + +5. The server then provides the byte array of the payload received to its +underlying security mechanism. A challenge is generated by the underlying +security mechanism on the server, and this is used as the payload for a message +sent to the client. This message shall consist of an OK byte, followed by the +non-zero message length word, followed by the payload. + +6. The client receives this message from the server and passes the payload to +its underlying security mechanism to generate a response. The client then sends +the server an OK byte, followed by the non-zero-value length of the response, +followed by the bytes of the response as the payload. + +7. Steps 5 and 6 are repeated until both security mechanisms are satisfied with +the challenge/response exchange. When either side has completed its security +protocol, its next message shall be the COMPLETE byte, followed by a 4-byte +potentially zero-value length word, followed by a potentially zero-length +payload. This payload will be empty except for those underlying security +mechanisms which provide additional data with success. + +If at any point in time either side is able to interpret the challenge or +response sent by the other, but is dissatisfied with the contents thereof, this +side should send the other a BAD byte, followed by a 4-byte potentially +zero-value length word, followed by an optional, potentially zero-length +message encoded in UTF-8 indicating failure. This message should be passed to +the protocol above the thrift transport by whatever mechanism is appropriate +and idiomatic for the particular language these thrift bindings are for. + +If at any point in time either side fails to interpret the challenge or +response sent by the other, this side should send the other an ERROR byte, +followed by a 4-byte potentially zero-value length word, followed by an +optional, potentially zero-length message encoded in UTF-8. This message should +be passed to the protocol above the thrift transport by whatever mechanism is +appropriate and idiomatic for the particular language these thrift bindings are +for. + +If step 7 completes successfully, then the communication is considered +authenticated and subsequent communication may commence. + +If step 7 fails to complete successfully, then no further communication may +take place via this transport. + +8. All writes to the underlying transport must be prefixed by the 4-byte length +of the payload data, followed by the payload. All reads from this transport +should read the 4-byte length word, then read the full quantity of bytes +specified by this length word. + +If no SASL QOP (quality of protection) is negotiated during steps 5 and 6, then +all subsequent writes to/reads from this transport are written/read unaltered, +save for the length prefix, to the underlying transport. + +If a SASL QOP is negotiated, then this must be used by the Thrift transport for +all subsequent communication. This is done by wrapping subsequent writes to the +transport using the underlying security mechanism, and unwrapping subsequent +reads from the underlying transport. Note that in this case, the length prefix +of the write to the underlying transport is the length of the data after it has +been wrapped by the underlying security mechanism. Note that the complete +message must be read before giving this data to the underlying security +mechanism for unwrapping. + +If at any point in time reading of a message fails either because of a +malformed length word or failure to unwrap by the underlying security +mechanism, then all further communication on this transport must cease. diff --git a/lib/java/src/org/apache/thrift/EncodingUtils.java b/lib/java/src/org/apache/thrift/EncodingUtils.java new file mode 100644 index 00000000..072de93c --- /dev/null +++ b/lib/java/src/org/apache/thrift/EncodingUtils.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift; + +/** + * Utility methods for use when encoding/decoding raw data as byte arrays. + */ +public class EncodingUtils { + + /** + * Encode integer as a series of 4 bytes into buf + * starting at position 0 within that buffer. + * + * @param integer + * The integer to encode. + * @param buf + * The buffer to write to. + */ + public static final void encodeBigEndian(final int integer, final byte[] buf) { + encodeBigEndian(integer, buf, 0); + } + + /** + * Encode integer as a series of 4 bytes into buf + * starting at position offset. + * + * @param integer + * The integer to encode. + * @param buf + * The buffer to write to. + * @param offset + * The offset within buf to start the encoding. + */ + public static final void encodeBigEndian(final int integer, final byte[] buf, int offset) { + buf[offset] = (byte) (0xff & (integer >> 24)); + buf[offset + 1] = (byte) (0xff & (integer >> 16)); + buf[offset + 2] = (byte) (0xff & (integer >> 8)); + buf[offset + 3] = (byte) (0xff & (integer)); + } + + /** + * Decode a series of 4 bytes from buf, starting at position 0, + * and interpret them as an integer. + * + * @param buf + * The buffer to read from. + * @return An integer, as read from the buffer. + */ + public static final int decodeBigEndian(final byte[] buf) { + return decodeBigEndian(buf, 0); + } + + /** + * Decode a series of 4 bytes from buf, start at + * offset, and interpret them as an integer. + * + * @param buf + * The buffer to read from. + * @param offset + * The offset with buf to start the decoding. + * @return An integer, as read from the buffer. + */ + public static final int decodeBigEndian(final byte[] buf, int offset) { + return ((buf[offset] & 0xff) << 24) | ((buf[offset + 1] & 0xff) << 16) + | ((buf[offset + 2] & 0xff) << 8) | ((buf[offset + 3] & 0xff)); + } + +} diff --git a/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java b/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java new file mode 100644 index 00000000..fc8a3ea2 --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import java.util.Map; + +import javax.security.auth.callback.CallbackHandler; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; + +import org.apache.thrift.EncodingUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Wraps another Thrift TTransport, but performs SASL client + * negotiation on the call to open(). This class will wrap ensuing + * communication over it, if a SASL QOP is negotiated with the other party. + */ +public class TSaslClientTransport extends TSaslTransport { + + private static final Logger LOGGER = LoggerFactory.getLogger(TSaslClientTransport.class); + + /** + * The name of the mechanism this client supports. + */ + private final String mechanism; + + /** + * Uses the given SaslClient. + * + * @param saslClient + * The SaslClient to use for the subsequent SASL + * negotiation. + * @param transport + * Transport underlying this one. + */ + public TSaslClientTransport(SaslClient saslClient, TTransport transport) { + super(saslClient, transport); + mechanism = saslClient.getMechanismName(); + } + + /** + * Creates a SaslClient using the given SASL-specific parameters. + * See the Java documentation for Sasl.createSaslClient for the + * details of the parameters. + * + * @param transport + * The underlying Thrift transport. + * @throws SaslException + */ + public TSaslClientTransport(String mechanism, String authorizationId, String protocol, + String serverName, Map props, CallbackHandler cbh, TTransport transport) + throws SaslException { + super(Sasl.createSaslClient(new String[] { mechanism }, authorizationId, protocol, serverName, + props, cbh), transport); + this.mechanism = mechanism; + } + + /** + * Performs the client side of the initial portion of the Thrift SASL + * protocol. Generates and sends the initial response to the server, including + * which mechanism this client wants to use. + */ + @Override + protected void handleSaslStartMessage() throws TTransportException, SaslException { + SaslClient saslClient = getSaslClient(); + + byte[] initialResponse = new byte[0]; + if (saslClient.hasInitialResponse()) + initialResponse = saslClient.evaluateChallenge(initialResponse); + + byte[] mechanismBytes = mechanism.getBytes(); + byte[] messageHeader = new byte[STATUS_BYTES + MECHANISM_NAME_BYTES + mechanismBytes.length + + PAYLOAD_LENGTH_BYTES]; + + messageHeader[0] = START; + messageHeader[1] = (byte) (0xff & mechanismBytes.length); + System.arraycopy(mechanismBytes, 0, messageHeader, STATUS_BYTES + MECHANISM_NAME_BYTES, + mechanismBytes.length); + EncodingUtils.encodeBigEndian(initialResponse.length, messageHeader, STATUS_BYTES + + MECHANISM_NAME_BYTES + mechanismBytes.length); + + LOGGER.debug("Sending mechanism name {} and initial response of length {}", mechanism, + initialResponse.length); + underlyingTransport.write(messageHeader); + underlyingTransport.write(initialResponse); + underlyingTransport.flush(); + } +} diff --git a/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java b/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java new file mode 100644 index 00000000..b07e5972 --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.WeakHashMap; + +import javax.security.auth.callback.CallbackHandler; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Wraps another Thrift TTransport, but performs SASL server + * negotiation on the call to open(). This class will wrap ensuing + * communication over it, if a SASL QOP is negotiated with the other party. + */ +public class TSaslServerTransport extends TSaslTransport { + + private static final Logger LOGGER = LoggerFactory.getLogger(TSaslServerTransport.class); + + /** + * Mapping from SASL mechanism name -> all the parameters required to + * instantiate a SASL server. + */ + private Map serverDefinitionMap = new HashMap(); + + /** + * Contains all the parameters used to define a SASL server implementation. + */ + private static class TSaslServerDefinition { + public String mechanism; + public String protocol; + public String serverName; + public Map props; + public CallbackHandler cbh; + + public TSaslServerDefinition(String mechanism, String protocol, String serverName, + Map props, CallbackHandler cbh) { + this.mechanism = mechanism; + this.protocol = protocol; + this.serverName = serverName; + this.props = props; + this.cbh = cbh; + } + } + + /** + * Uses the given underlying transport. Assumes that addServerDefinition is + * called later. + * + * @param transport + * Transport underlying this one. + */ + public TSaslServerTransport(TTransport transport) { + super(transport); + } + + /** + * Creates a SaslServer using the given SASL-specific parameters. + * See the Java documentation for Sasl.createSaslServer for the + * details of the parameters. + * + * @param transport + * The underlying Thrift transport. + */ + public TSaslServerTransport(String mechanism, String protocol, String serverName, + Map props, CallbackHandler cbh, TTransport transport) { + super(transport); + addServerDefinition(mechanism, protocol, serverName, props, cbh); + } + + private TSaslServerTransport(Map serverDefinitionMap, TTransport transport) { + super(transport); + this.serverDefinitionMap.putAll(serverDefinitionMap); + } + + /** + * Add a supported server definition to this transport. See the Java + * documentation for Sasl.createSaslServer for the details of the + * parameters. + */ + public void addServerDefinition(String mechanism, String protocol, String serverName, + Map props, CallbackHandler cbh) { + serverDefinitionMap.put(mechanism, new TSaslServerDefinition(mechanism, protocol, serverName, + props, cbh)); + } + + /** + * Performs the server side of the initial portion of the Thrift SASL protocol. + * Receives the initial response from the client, creates a SASL server using + * the mechanism requested by the client (if this server supports it), and + * sends the first challenge back to the client. + */ + @Override + protected void handleSaslStartMessage() throws TTransportException, SaslException { + // Get the status byte and length of the mechanism name. + byte[] messageHeader = new byte[STATUS_BYTES + MECHANISM_NAME_BYTES]; + underlyingTransport.readAll(messageHeader, 0, messageHeader.length); + LOGGER.debug("Received status {} and mechanism name length {}", messageHeader[0], + messageHeader[1]); + if (messageHeader[0] != START) { + sendAndThrowMessage(ERROR, "Expecting START status, received " + messageHeader[0]); + } + + // Get the mechanism name. + byte[] mechanismBytes = new byte[messageHeader[1]]; + underlyingTransport.readAll(mechanismBytes, 0, mechanismBytes.length); + + String mechanismName = new String(mechanismBytes); + TSaslServerDefinition serverDefinition = serverDefinitionMap.get(new String(mechanismBytes)); + LOGGER.debug("Received mechanism name '{}'", mechanismName); + + if (serverDefinition == null) { + sendAndThrowMessage(BAD, "Unsupported mechanism type " + mechanismName); + } + SaslServer saslServer = Sasl.createSaslServer(serverDefinition.mechanism, + serverDefinition.protocol, serverDefinition.serverName, serverDefinition.props, + serverDefinition.cbh); + + // Evaluate the initial response and send the first challenge. + byte[] initialResponse = new byte[readLength()]; + sendSaslMessage(saslServer.isComplete() ? COMPLETE : OK, saslServer + .evaluateResponse(initialResponse)); + + setSaslServer(saslServer); + } + + /** + * TTransportFactory to create + * TSaslServerTransports. Ensures that a given + * underlying TTransport instance receives the same + * TSaslServerTransport. This is kind of an awful hack to work + * around the fact that Thrift is designed assuming that + * TTransport instances are stateless, and thus the existing + * TServers use different TTransport instances for + * input and output. + */ + public static class Factory extends TTransportFactory { + + /** + * This is the implementation of the awful hack described above. + * WeakHashMap is used to ensure that we don't leak memory. + */ + private static Map transportMap = + Collections.synchronizedMap(new WeakHashMap()); + + /** + * Mapping from SASL mechanism name -> all the parameters required to + * instantiate a SASL server. + */ + private Map serverDefinitionMap = new HashMap(); + + /** + * Create a new Factory. Assumes that addServerDefinition will + * be called later. + */ + public Factory() { + super(); + } + + /** + * Create a new Factory, initially with the single server + * definition given. You may still call addServerDefinition + * later. See the Java documentation for Sasl.createSaslServer + * for the details of the parameters. + */ + public Factory(String mechanism, String protocol, String serverName, + Map props, CallbackHandler cbh) { + super(); + addServerDefinition(mechanism, protocol, serverName, props, cbh); + } + + /** + * Add a supported server definition to the transports created by this + * factory. See the Java documentation for + * Sasl.createSaslServer for the details of the parameters. + */ + public void addServerDefinition(String mechanism, String protocol, String serverName, + Map props, CallbackHandler cbh) { + serverDefinitionMap.put(mechanism, new TSaslServerDefinition(mechanism, protocol, serverName, + props, cbh)); + } + + /** + * Get a new TSaslServerTransport instance, or reuse the + * existing one if a TSaslServerTransport has already been + * created before using the given TTransport as an underlying + * transport. This ensures that a given underlying transport instance + * receives the same TSaslServerTransport. + */ + @Override + public TTransport getTransport(TTransport base) { + TSaslServerTransport ret = transportMap.get(base); + if (ret == null) { + LOGGER.debug("transport map does not contain key", base); + ret = new TSaslServerTransport(serverDefinitionMap, base); + try { + ret.open(); + } catch (TTransportException e) { + LOGGER.debug("failed to open server transport", e); + return null; + } + transportMap.put(base, ret); + } else { + LOGGER.debug("transport map does contain key {}", base); + } + return ret; + } + } +} diff --git a/lib/java/src/org/apache/thrift/transport/TSaslTransport.java b/lib/java/src/org/apache/thrift/transport/TSaslTransport.java new file mode 100644 index 00000000..b5eadb74 --- /dev/null +++ b/lib/java/src/org/apache/thrift/transport/TSaslTransport.java @@ -0,0 +1,470 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import java.io.UnsupportedEncodingException; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; + +import org.apache.thrift.EncodingUtils; +import org.apache.thrift.TByteArrayOutputStream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A superclass for SASL client/server thrift transports. A subclass need only + * implement the open method. + */ +abstract class TSaslTransport extends TTransport { + + private static final Logger LOGGER = LoggerFactory.getLogger(TSaslTransport.class); + + protected static final int DEFAULT_MAX_LENGTH = 0x7FFFFFFF; + + protected static final int MECHANISM_NAME_BYTES = 1; + protected static final int STATUS_BYTES = 1; + protected static final int PAYLOAD_LENGTH_BYTES = 4; + + /** + * Status bytes used during the initial Thrift SASL handshake. + */ + protected static final byte START = 0x01; + protected static final byte OK = 0x02; + protected static final byte BAD = 0x03; + protected static final byte ERROR = 0x04; + protected static final byte COMPLETE = 0x05; + + protected static final Set VALID_STATUSES = new HashSet(Arrays.asList(START, OK, BAD, ERROR, COMPLETE)); + + /** + * Transport underlying this one. + */ + protected TTransport underlyingTransport; + + /** + * Either a SASL client or a SASL server. + */ + private SaslParticipant sasl; + + /** + * Whether or not we should wrap/unwrap reads/writes. Determined by whether or + * not a QOP is negotiated during the SASL handshake. + */ + private boolean shouldWrap = false; + + /** + * Buffer for input. + */ + private TMemoryInputTransport readBuffer = new TMemoryInputTransport(); + + /** + * Buffer for output. + */ + private final TByteArrayOutputStream writeBuffer = new TByteArrayOutputStream(1024); + + /** + * Create a TSaslTransport. It's assumed that setSaslServer will be called + * later to initialize the SASL endpoint underlying this transport. + * + * @param underlyingTransport + * The thrift transport which this transport is wrapping. + */ + protected TSaslTransport(TTransport underlyingTransport) { + this.underlyingTransport = underlyingTransport; + } + + /** + * Create a TSaslTransport which acts as a client. + * + * @param saslClient + * The SaslClient which this transport will use for SASL + * negotiation. + * @param underlyingTransport + * The thrift transport which this transport is wrapping. + */ + protected TSaslTransport(SaslClient saslClient, TTransport underlyingTransport) { + sasl = new SaslParticipant(saslClient); + this.underlyingTransport = underlyingTransport; + } + + protected void setSaslServer(SaslServer saslServer) { + sasl = new SaslParticipant(saslServer); + } + + // Used to read the status byte and payload length. + private final byte[] messageHeader = new byte[STATUS_BYTES + PAYLOAD_LENGTH_BYTES]; + + /** + * Send a complete Thrift SASL message. + * + * @param status + * The status to send. + * @param payload + * The data to send as the payload of this message. + * @throws TTransportException + */ + protected void sendSaslMessage(byte status, byte[] payload) throws TTransportException { + if (payload == null) + payload = new byte[0]; + + messageHeader[0] = status; + EncodingUtils.encodeBigEndian(payload.length, messageHeader, STATUS_BYTES); + + LOGGER.debug("Writing message with status {} and payload length {}", status, payload.length); + underlyingTransport.write(messageHeader); + underlyingTransport.write(payload); + underlyingTransport.flush(); + } + + /** + * Read a complete Thrift SASL message. + * + * @return The SASL status and payload from this message. + * @throws TTransportException + * Thrown if there is a failure reading from the underlying + * transport, or if a status code of BAD or ERROR is encountered. + */ + protected SaslResponse receiveSaslMessage() throws TTransportException { + underlyingTransport.readAll(messageHeader, 0, messageHeader.length); + + byte status = messageHeader[0]; + byte[] payload = new byte[EncodingUtils.decodeBigEndian(messageHeader, STATUS_BYTES)]; + underlyingTransport.readAll(payload, 0, payload.length); + + if (!VALID_STATUSES.contains(status)) + sendAndThrowMessage(ERROR, "Invalid status " + status); + else if (status == BAD || status == ERROR) { + try { + throw new TTransportException(new String(payload, "UTF-8")); + } catch (UnsupportedEncodingException e) { + throw new TTransportException(e); + } + } + + LOGGER.debug("Received message with status {} and payload length {}", status, payload.length); + return new SaslResponse(status, payload); + } + + /** + * Send a Thrift SASL message with the given status (usaully BAD or ERROR) and + * string message, and then throw a TTransportException with the given + * message. + * + * @param status + * The Thrift SASL status code to send. Usually BAD or ERROR. + * @param message + * The optional message to send to the other side. + * @throws TTransportException + * Always thrown with the message provided. + */ + protected void sendAndThrowMessage(byte status, String message) throws TTransportException { + sendSaslMessage(status, message.getBytes()); + throw new TTransportException(message); + } + + /** + * Implemented by subclasses to start the Thrift SASL handshake process. When + * this method completes, the SaslParticipant in this class is + * assumed to be initialized. + * + * @throws TTransportException + * @throws SaslException + */ + abstract protected void handleSaslStartMessage() throws TTransportException, SaslException; + + /** + * Opens the underlying transport if it's not already open and then performs + * SASL negotiation. If a QOP is negoiated during this SASL handshake, it used + * for all communication on this transport after this call is complete. + */ + @Override + public void open() throws TTransportException { + LOGGER.debug("opening transport {}", this); + if (sasl != null && sasl.isComplete()) + throw new TTransportException("SASL transport already open"); + + if (!underlyingTransport.isOpen()) + underlyingTransport.open(); + + try { + handleSaslStartMessage(); + + SaslResponse message; + do { + message = receiveSaslMessage(); + if (message.status != COMPLETE && message.status != OK) { + throw new TTransportException("Expected COMPLETE or OK, got " + message.status); + } + + if (sasl.isComplete() && message.status == COMPLETE) + break; + + byte[] challenge = sasl.evaluateChallengeOrResponse(message.payload); + sendSaslMessage(sasl.isComplete() ? COMPLETE : OK, challenge); + } while (!(sasl.isComplete() && message.status == COMPLETE)); + } catch (SaslException e) { + underlyingTransport.close(); + sendAndThrowMessage(BAD, e.getMessage()); + } + + String qop = (String) sasl.getNegotiatedProperty(Sasl.QOP); + if (qop != null && !qop.equalsIgnoreCase("auth")) + shouldWrap = true; + } + + /** + * Get the underlying SaslClient. + * + * @return The SaslClient, or null if this transport + * is backed by a SaslServer. + */ + protected SaslClient getSaslClient() { + return sasl.saslClient; + } + + /** + * Get the underlying SaslServer. + * + * @return The SaslServer, or null if this transport + * is backed by a SaslClient. + */ + protected SaslServer getSaslServer() { + return sasl.saslServer; + } + + /** + * Read a 4-byte word from the underlying transport and interpret it as an + * integer. + * + * @return The length prefix of the next SASL message to read. + * @throws TTransportException + * Thrown if reading from the underlying transport fails. + */ + protected int readLength() throws TTransportException { + byte[] lenBuf = new byte[4]; + underlyingTransport.readAll(lenBuf, 0, lenBuf.length); + return EncodingUtils.decodeBigEndian(lenBuf); + } + + /** + * Write the given integer as 4 bytes to the underlying transport. + * + * @param length + * The length prefix of the next SASL message to write. + * @throws TTransportException + * Thrown if writing to the underlying transport fails. + */ + protected void writeLength(int length) throws TTransportException { + byte[] lenBuf = new byte[4]; + TFramedTransport.encodeFrameSize(length, lenBuf); + underlyingTransport.write(lenBuf); + } + + // Below is the SASL implementation of the TTransport interface. + + /** + * Closes the underlying transport and disposes of the SASL implementation + * underlying this transport. + */ + @Override + public void close() { + underlyingTransport.close(); + try { + sasl.dispose(); + } catch (SaslException e) { + // Not much we can do here. + } + } + + /** + * True if the underlying transport is open and the SASL handshake is + * complete. + */ + @Override + public boolean isOpen() { + return underlyingTransport.isOpen() && sasl != null && sasl.isComplete(); + } + + /** + * Read from the underlying transport. Unwraps the contents if a QOP was + * negotiated during the SASL handshake. + */ + @Override + public int read(byte[] buf, int off, int len) throws TTransportException { + if (!isOpen()) + throw new TTransportException("SASL authentication not complete"); + + int got = readBuffer.read(buf, off, len); + if (got > 0) { + return got; + } + + // Read another frame of data + try { + readFrame(); + } catch (SaslException e) { + throw new TTransportException(e); + } + + return readBuffer.read(buf, off, len); + } + + /** + * Read a single frame of data from the underlying transport, unwrapping if + * necessary. + * + * @throws TTransportException + * Thrown if there's an error reading from the underlying transport. + * @throws SaslException + * Thrown if there's an error unwrapping the data. + */ + private void readFrame() throws TTransportException, SaslException { + int dataLength = readLength(); + + if (dataLength < 0) + throw new TTransportException("Read a negative frame size (" + dataLength + ")!"); + + byte[] buff = new byte[dataLength]; + LOGGER.debug("reading data length: {}", dataLength); + underlyingTransport.readAll(buff, 0, dataLength); + if (shouldWrap) { + buff = sasl.unwrap(buff, 0, buff.length); + LOGGER.debug("data length after unwrap: {}", buff.length); + } + readBuffer.reset(buff); + } + + /** + * Write to the underlying transport. + */ + @Override + public void write(byte[] buf, int off, int len) throws TTransportException { + if (!isOpen()) + throw new TTransportException("SASL authentication not complete"); + + writeBuffer.write(buf, off, len); + } + + /** + * Flushes to the underlying transport. Wraps the contents if a QOP was + * negotiated during the SASL handshake. + */ + @Override + public void flush() throws TTransportException { + byte[] buf = writeBuffer.get(); + int dataLength = writeBuffer.len(); + writeBuffer.reset(); + + if (shouldWrap) { + LOGGER.debug("data length before wrap: {}", dataLength); + try { + buf = sasl.wrap(buf, 0, dataLength); + } catch (SaslException e) { + throw new TTransportException(e); + } + dataLength = buf.length; + } + LOGGER.debug("writing data length: {}", dataLength); + writeLength(dataLength); + underlyingTransport.write(buf, 0, dataLength); + underlyingTransport.flush(); + } + + /** + * Used exclusively by readSaslMessage to return both a status and data. + */ + private static class SaslResponse { + public byte status; + public byte[] payload; + + public SaslResponse(byte status, byte[] payload) { + this.status = status; + this.payload = payload; + } + } + + /** + * Used to abstract over the SaslServer and + * SaslClient classes, which share a lot of their interface, but + * unfortunately don't share a common superclass. + */ + private static class SaslParticipant { + // One of these will always be null. + public SaslServer saslServer; + public SaslClient saslClient; + + public SaslParticipant(SaslServer saslServer) { + this.saslServer = saslServer; + } + + public SaslParticipant(SaslClient saslClient) { + this.saslClient = saslClient; + } + + public byte[] evaluateChallengeOrResponse(byte[] challengeOrResponse) throws SaslException { + if (saslClient != null) { + return saslClient.evaluateChallenge(challengeOrResponse); + } else { + return saslServer.evaluateResponse(challengeOrResponse); + } + } + + public boolean isComplete() { + if (saslClient != null) + return saslClient.isComplete(); + else + return saslServer.isComplete(); + } + + public void dispose() throws SaslException { + if (saslClient != null) + saslClient.dispose(); + else + saslServer.dispose(); + } + + public byte[] unwrap(byte[] buf, int off, int len) throws SaslException { + if (saslClient != null) + return saslClient.unwrap(buf, off, len); + else + return saslServer.unwrap(buf, off, len); + } + + public byte[] wrap(byte[] buf, int off, int len) throws SaslException { + if (saslClient != null) + return saslClient.wrap(buf, off, len); + else + return saslServer.wrap(buf, off, len); + } + + public Object getNegotiatedProperty(String propName) { + if (saslClient != null) + return saslClient.getNegotiatedProperty(propName); + else + return saslServer.getNegotiatedProperty(propName); + } + } +} diff --git a/lib/java/test/org/apache/thrift/server/ServerTestBase.java b/lib/java/test/org/apache/thrift/server/ServerTestBase.java index 88430e6c..3bfc8d7e 100644 --- a/lib/java/test/org/apache/thrift/server/ServerTestBase.java +++ b/lib/java/test/org/apache/thrift/server/ServerTestBase.java @@ -34,7 +34,6 @@ import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.protocol.TCompactProtocol; import org.apache.thrift.protocol.TProtocol; import org.apache.thrift.protocol.TProtocolFactory; -import org.apache.thrift.transport.TFramedTransport; import org.apache.thrift.transport.TSocket; import org.apache.thrift.transport.TTransport; @@ -286,7 +285,7 @@ public abstract class ServerTestBase extends TestCase { public abstract void stopServer() throws Exception; - public abstract TTransport getTransport() throws Exception; + public abstract TTransport getClientTransport(TTransport underlyingTransport) throws Exception; private void testByte(ThriftTest.Client testClient) throws TException { byte i8 = testClient.testByte((byte)1); @@ -374,12 +373,9 @@ public abstract class ServerTestBase extends TestCase { startServer(processor, protoFactory); - TTransport transport; - TSocket socket = new TSocket(HOST, PORT); socket.setTimeout(SOCKET_TIMEOUT); - transport = socket; - transport = new TFramedTransport(transport); + TTransport transport = getClientTransport(socket); TProtocol protocol = protoFactory.getProtocol(transport); ThriftTest.Client testClient = new ThriftTest.Client(protocol); diff --git a/lib/java/test/org/apache/thrift/server/TestNonblockingServer.java b/lib/java/test/org/apache/thrift/server/TestNonblockingServer.java index c43b4731..e2024351 100644 --- a/lib/java/test/org/apache/thrift/server/TestNonblockingServer.java +++ b/lib/java/test/org/apache/thrift/server/TestNonblockingServer.java @@ -23,7 +23,6 @@ import org.apache.thrift.TProcessor; import org.apache.thrift.protocol.TProtocolFactory; import org.apache.thrift.transport.TFramedTransport; import org.apache.thrift.transport.TNonblockingServerSocket; -import org.apache.thrift.transport.TSocket; import org.apache.thrift.transport.TTransport; public class TestNonblockingServer extends ServerTestBase { @@ -68,9 +67,7 @@ public class TestNonblockingServer extends ServerTestBase { } @Override - public TTransport getTransport() throws Exception { - TSocket socket = new TSocket(HOST, PORT); - socket.setTimeout(SOCKET_TIMEOUT); - return new TFramedTransport(socket); + public TTransport getClientTransport(TTransport underlyingTransport) throws Exception { + return new TFramedTransport(underlyingTransport); } } diff --git a/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java b/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java new file mode 100644 index 00000000..812028d1 --- /dev/null +++ b/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.transport; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.AuthorizeCallback; +import javax.security.sasl.RealmCallback; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; + +import org.apache.thrift.TProcessor; +import org.apache.thrift.protocol.TProtocolFactory; +import org.apache.thrift.server.ServerTestBase; +import org.apache.thrift.server.TServer; +import org.apache.thrift.server.TSimpleServer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import junit.framework.TestCase; + +public class TestTSaslTransports extends TestCase { + + private static final Logger LOGGER = LoggerFactory.getLogger(TestTSaslTransports.class); + + private static final String HOST = "localhost"; + private static final String SERVICE = "thrift-test"; + private static final String PRINCIPAL = "thrift-test-principal"; + private static final String PASSWORD = "super secret password"; + private static final String REALM = "thrift-test-realm"; + + private static final String UNWRAPPED_MECHANISM = "CRAM-MD5"; + private static final Map UNWRAPPED_PROPS = null; + + private static final String WRAPPED_MECHANISM = "DIGEST-MD5"; + private static final Map WRAPPED_PROPS = new HashMap(); + + static { + WRAPPED_PROPS.put(Sasl.QOP, "auth-int"); + WRAPPED_PROPS.put("com.sun.security.sasl.digest.realm", REALM); + } + + private static final String testMessage1 = "Hello, world! Also, four " + + "score and seven years ago our fathers brought forth on this " + + "continent a new nation, conceived in liberty, and dedicated to the " + + "proposition that all men are created equal."; + + private static final String testMessage2 = "I have a dream that one day " + + "this nation will rise up and live out the true meaning of its creed: " + + "'We hold these truths to be self-evident, that all men are created equal.'"; + + + private static class TestSaslCallbackHandler implements CallbackHandler { + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + for (Callback c : callbacks) { + if (c instanceof NameCallback) { + ((NameCallback) c).setName(PRINCIPAL); + } else if (c instanceof PasswordCallback) { + ((PasswordCallback) c).setPassword(PASSWORD.toCharArray()); + } else if (c instanceof AuthorizeCallback) { + ((AuthorizeCallback) c).setAuthorized(true); + } else if (c instanceof RealmCallback) { + ((RealmCallback) c).setText(REALM); + } else { + throw new UnsupportedCallbackException(c); + } + } + } + } + + private void testSaslOpen(final String mechanism, final Map props) + throws SaslException, TTransportException { + Thread serverThread = new Thread() { + public void run() { + try { + TServerSocket serverSocket = new TServerSocket(ServerTestBase.PORT); + TTransport serverTransport = serverSocket.accept(); + TTransport saslServerTransport = new TSaslServerTransport(mechanism, SERVICE, HOST, + props, new TestSaslCallbackHandler(), serverTransport); + + saslServerTransport.open(); + + byte[] inBuf = new byte[testMessage1.getBytes().length]; + // Deliberately read less than the full buffer to ensure + // that TSaslTransport is correctly buffering reads. This + // will fail for the WRAPPED test, if it doesn't work. + saslServerTransport.readAll(inBuf, 0, 5); + saslServerTransport.readAll(inBuf, 5, 10); + saslServerTransport.readAll(inBuf, 15, inBuf.length - 15); + LOGGER.debug("server got: {}", new String(inBuf)); + assertEquals(new String(inBuf), testMessage1); + + LOGGER.debug("server writing: {}", testMessage2); + saslServerTransport.write(testMessage2.getBytes()); + saslServerTransport.flush(); + + serverSocket.close(); + saslServerTransport.close(); + } catch (TTransportException e) { + fail(e.toString()); + } + } + }; + serverThread.start(); + + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + // Ah well. + } + + TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT); + TTransport saslClientTransport = new TSaslClientTransport(mechanism, + PRINCIPAL, SERVICE, HOST, props, new TestSaslCallbackHandler(), clientSocket); + saslClientTransport.open(); + LOGGER.debug("client writing: {}", testMessage1); + saslClientTransport.write(testMessage1.getBytes()); + saslClientTransport.flush(); + + byte[] inBuf = new byte[testMessage2.getBytes().length]; + saslClientTransport.readAll(inBuf, 0, inBuf.length); + LOGGER.debug("client got: {}", new String(inBuf)); + assertEquals(new String(inBuf), testMessage2); + + TTransportException expectedException = null; + try { + saslClientTransport.open(); + } catch (TTransportException e) { + expectedException = e; + } + assertNotNull(expectedException); + + saslClientTransport.close(); + + try { + serverThread.join(); + } catch (InterruptedException e) { + // Ah well. + } + } + + public void testUnwrappedOpen() throws SaslException, TTransportException { + testSaslOpen(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS); + } + + public void testWrappedOpen() throws SaslException, TTransportException { + testSaslOpen(WRAPPED_MECHANISM, WRAPPED_PROPS); + } + + public void testWithServer() throws Exception { + new TestTSaslTransportsWithServer().testIt(); + } + + private static class TestTSaslTransportsWithServer extends ServerTestBase { + + private Thread serverThread; + private TServer server; + + @Override + public TTransport getClientTransport(TTransport underlyingTransport) throws Exception { + return new TSaslClientTransport(WRAPPED_MECHANISM, + PRINCIPAL, SERVICE, HOST, WRAPPED_PROPS, new TestSaslCallbackHandler(), underlyingTransport); + } + + @Override + public void startServer(final TProcessor processor, final TProtocolFactory protoFactory) throws Exception { + serverThread = new Thread() { + public void run() { + try { + // Transport + TServerSocket socket = new TServerSocket(PORT); + + TTransportFactory factory = new TSaslServerTransport.Factory(WRAPPED_MECHANISM, + SERVICE, HOST, WRAPPED_PROPS, new TestSaslCallbackHandler()); + server = new TSimpleServer(processor, socket, factory, protoFactory); + + // Run it + LOGGER.debug("Starting the server on port {}", PORT); + server.serve(); + } catch (Exception e) { + e.printStackTrace(); + fail(); + } + } + }; + serverThread.start(); + Thread.sleep(1000); + } + + @Override + public void stopServer() throws Exception { + server.stop(); + try { + serverThread.join(); + } catch (InterruptedException e) {} + } + + } + +} -- 2.17.1