From 6849f2014d21ca1c49220039453c699eab11fb68 Mon Sep 17 00:00:00 2001 From: Roger Meier Date: Fri, 18 May 2012 07:35:19 +0000 Subject: [PATCH] THRIFT-1598 Update Haskell generated code to use Text, Hash{Map,Set}, Vector Patch: Itai Zukerman git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1340014 13f79535-47bb-0310-9956-ffa450edef68 --- compiler/cpp/src/generate/t_hs_generator.cc | 67 +++++++++++++-------- lib/hs/README | 16 ++++- lib/hs/Thrift.cabal | 26 +++++--- lib/hs/src/Thrift.hs | 6 +- lib/hs/src/Thrift/Protocol.hs | 19 +++--- lib/hs/src/Thrift/Protocol/Binary.hs | 9 ++- lib/hs/src/Thrift/Transport/Handle.hs | 4 +- lib/hs/src/Thrift/Types.hs | 34 +++++++++++ 8 files changed, 129 insertions(+), 52 deletions(-) create mode 100644 lib/hs/src/Thrift/Types.hs diff --git a/compiler/cpp/src/generate/t_hs_generator.cc b/compiler/cpp/src/generate/t_hs_generator.cc index 9fd16862..f2752618 100644 --- a/compiler/cpp/src/generate/t_hs_generator.cc +++ b/compiler/cpp/src/generate/t_hs_generator.cc @@ -168,7 +168,7 @@ class t_hs_generator : public t_oop_generator { string type_to_enum(t_type* ttype); string render_hs_type(t_type* type, - bool needs_parens = true); + bool needs_parens); private: ofstream f_types_; @@ -211,6 +211,7 @@ void t_hs_generator::init_generator() { string t_hs_generator::hs_language_pragma() { return string("{-# LANGUAGE DeriveDataTypeable #-}\n" + "{-# LANGUAGE OverloadedStrings #-}\n" "{-# OPTIONS_GHC -fno-warn-missing-fields #-}\n" "{-# OPTIONS_GHC -fno-warn-missing-signatures #-}\n" "{-# OPTIONS_GHC -fno-warn-name-shadowing #-}\n" @@ -238,16 +239,21 @@ string t_hs_generator::hs_imports() { "import Prelude ( Bool(..), Enum, Double, String, Maybe(..),\n" " Eq, Show, Ord,\n" " return, length, IO, fromIntegral, fromEnum, toEnum,\n" - " (&&), (||), (==), (++), ($), (-) )\n" + " (.), (&&), (||), (==), (++), ($), (-) )\n" "\n" "import Control.Exception\n" "import Data.ByteString.Lazy\n" + "import Data.Hashable\n" "import Data.Int\n" + "import Data.Text.Lazy ( Text )\n" + "import qualified Data.Text.Lazy as TL\n" "import Data.Typeable ( Typeable )\n" - "import qualified Data.Map as Map\n" - "import qualified Data.Set as Set\n" + "import qualified Data.HashMap.Strict as Map\n" + "import qualified Data.HashSet as Set\n" + "import qualified Data.Vector as Vector\n" "\n" "import Thrift\n" + "import Thrift.Types ()\n" "\n"); for (size_t i = 0; i < includes.size(); ++i) @@ -303,22 +309,19 @@ void t_hs_generator::generate_enum(t_enum* tenum) { indent_down(); string ename = capitalize(tenum->get_name()); + indent(f_types_) << "instance Enum " << ename << " where" << endl; indent_up(); - indent(f_types_) << "fromEnum t = case t of" << endl; indent_up(); - for (c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { int value = (*c_iter)->get_value(); string name = capitalize((*c_iter)->get_name()); indent(f_types_) << name << " -> " << value << endl; } indent_down(); - indent(f_types_) << "toEnum t = case t of" << endl; indent_up(); - for(c_iter = constants.begin(); c_iter != constants.end(); ++c_iter) { int value = (*c_iter)->get_value(); string name = capitalize((*c_iter)->get_name()); @@ -327,6 +330,11 @@ void t_hs_generator::generate_enum(t_enum* tenum) { indent(f_types_) << "_ -> throw ThriftException" << endl; indent_down(); indent_down(); + + indent(f_types_) << "instance Hashable " << ename << " where" << endl; + indent_up(); + indent(f_types_) << "hashWithSalt salt = hashWithSalt salt . fromEnum" << endl; + indent_down(); } /** @@ -463,9 +471,9 @@ string t_hs_generator::render_const_value(t_type* type, t_const_value* value) { vector::const_iterator v_iter; if (type->is_set()) - out << "(Set.fromList "; - - out << "["; + out << "(Set.fromList ["; + else + out << "(Vector.fromList "; bool first = true; for (v_iter = val.begin(); v_iter != val.end(); ++v_iter) { @@ -474,10 +482,7 @@ string t_hs_generator::render_const_value(t_type* type, t_const_value* value) { first = false; } - out << "]"; - - if (type->is_set()) - out << ")"; + out << "])"; } else { throw "CANNOT GENERATE CONSTANT FOR TYPE: " + type->get_name(); @@ -535,17 +540,27 @@ void t_hs_generator::generate_hs_struct_definition(ofstream& out, for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { string mname = (*m_iter)->get_name(); out << (first ? "" : ","); - out << "f_" << tname << "_" << mname << " :: Maybe " << render_hs_type((*m_iter)->get_type()); + out << "f_" << tname << "_" << mname << " :: Maybe " << render_hs_type((*m_iter)->get_type(), true); first = false; } out << "}"; } - out << " deriving (Show,Eq,Ord,Typeable)" << endl; + out << " deriving (Show,Eq,Typeable)" << endl; if (is_exception) out << "instance Exception " << tname << endl; + indent(out) << "instance Hashable " << tname << " where" << endl; + indent_up(); + indent(out) << "hashWithSalt salt record = salt"; + for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { + string mname = (*m_iter)->get_name(); + indent(out) << " `hashWithSalt` " << "f_" << tname << "_" << mname << " record"; + } + indent(out) << endl; + indent_down(); + generate_hs_struct_writer(out, tstruct); generate_hs_struct_reader(out, tstruct); } @@ -971,7 +986,7 @@ void t_hs_generator::generate_service_server(t_service* tservice) { indent(f_service_) << "skip iprot T_STRUCT" << endl; indent(f_service_) << "readMessageEnd iprot" << endl; indent(f_service_) << "writeMessageBegin oprot (name,M_EXCEPTION,seqid)" << endl; - indent(f_service_) << "writeAppExn oprot (AppExn AE_UNKNOWN_METHOD (\"Unknown function \" ++ name))" << endl; + indent(f_service_) << "writeAppExn oprot (AppExn AE_UNKNOWN_METHOD (\"Unknown function \" ++ TL.unpack name))" << endl; indent(f_service_) << "writeMessageEnd oprot" << endl; indent(f_service_) << "tFlush (getTransport oprot)" << endl; indent_down(); @@ -1210,9 +1225,9 @@ void t_hs_generator::generate_deserialize_container(ofstream &out, out << ";r <- f (n-1); return $ v:r}} in do {(" << etype << "," << size << ") <- readSetBegin iprot; l <- f " << size << "; return $ Set.fromList l})"; } else if (ttype->is_list()) { - out << "(let {f 0 = return []; f n = do {v <- "; + out << "(let f n = Vector.replicateM (fromIntegral n) ("; generate_deserialize_type(out,((t_map*)ttype)->get_key_type()); - out << ";r <- f (n-1); return $ v:r}} in do {(" << etype << "," << size << ") <- readListBegin iprot; f " << size << "})"; + out << ") in do {(" << etype << "," << size << ") <- readListBegin iprot; f " << size << "})"; } } @@ -1323,9 +1338,9 @@ void t_hs_generator::generate_serialize_container(ofstream &out, } else if (ttype->is_list()) { string v = tmp("_viter"); - out << "(let {f [] = return (); f (" << v << ":t) = do {"; + out << "(let f = Vector.mapM_ (\\" << v << " -> "; generate_serialize_list_element(out, (t_list*)ttype, v); - out << ";f t}} in do {writeListBegin oprot (" << type_to_enum(((t_list*)ttype)->get_elem_type()) << ",fromIntegral $ Prelude.length " << prefix << "); f " << prefix << ";writeListEnd oprot})"; + out << ") in do {writeListBegin oprot (" << type_to_enum(((t_list*)ttype)->get_elem_type()) << ",fromIntegral $ Vector.length " << prefix << "); f " << prefix << ";writeListEnd oprot})"; } } @@ -1450,7 +1465,7 @@ string t_hs_generator::render_hs_type(t_type* type, bool needs_parens) { t_base_type::t_base tbase = ((t_base_type*)type)->get_base(); switch (tbase) { case t_base_type::TYPE_VOID: return "()"; - case t_base_type::TYPE_STRING: return (((t_base_type*)type)->is_binary() ? "ByteString" : "String"); + case t_base_type::TYPE_STRING: return (((t_base_type*)type)->is_binary() ? "ByteString" : "Text"); case t_base_type::TYPE_BOOL: return "Bool"; case t_base_type::TYPE_BYTE: return "Int8"; case t_base_type::TYPE_I16: return "Int16"; @@ -1468,15 +1483,15 @@ string t_hs_generator::render_hs_type(t_type* type, bool needs_parens) { } else if (type->is_map()) { t_type* ktype = ((t_map*)type)->get_key_type(); t_type* vtype = ((t_map*)type)->get_val_type(); - type_repr = "Map.Map " + render_hs_type(ktype, true) + " " + render_hs_type(vtype, true); + type_repr = "Map.HashMap " + render_hs_type(ktype, true) + " " + render_hs_type(vtype, true); } else if (type->is_set()) { t_type* etype = ((t_set*)type)->get_elem_type(); - type_repr = "Set.Set " + render_hs_type(etype, true) ; + type_repr = "Set.HashSet " + render_hs_type(etype, true) ; } else if (type->is_list()) { t_type* etype = ((t_list*)type)->get_elem_type(); - return "[" + render_hs_type(etype, false) + "]"; + type_repr = "Vector.Vector " + render_hs_type(etype, true); } else { throw "INVALID TYPE IN type_to_enum: " + type->get_name(); diff --git a/lib/hs/README b/lib/hs/README index bbfe6997..fe525bd8 100644 --- a/lib/hs/README +++ b/lib/hs/README @@ -43,7 +43,7 @@ The mapping from Thrift types to Haskell's is: * i16 -> Data.Int.Int16 * i32 -> Data.Int.Int32 * i64 -> Data.Int.Int64 - * string -> String + * string -> Text * binary -> Data.ByteString.Lazy * bool -> Boolean @@ -52,6 +52,17 @@ Enums Become Haskell 'data' types. Use fromEnum to get out the int value. +Lists +===== + +Become Data.Vector.Vector from the vector package. + +Maps and Sets +============= + +Become Data.HashMap.Strict.Map and Data.HashSet.Set from the +unordered-containers package. + Structs ======= @@ -61,7 +72,7 @@ fields are Maybe types. Exceptions ========== -Identical to structs. Throw them with throwDyn. Catch them with catchDyn. +Identical to structs. Use them with throw and catch from Control.Exception. Client ====== @@ -86,4 +97,3 @@ Processor Just a function that takes a handler label, protocols. It calls the superclasses process if there is a superclass. - diff --git a/lib/hs/Thrift.cabal b/lib/hs/Thrift.cabal index 393e064f..cf02e123 100644 --- a/lib/hs/Thrift.cabal +++ b/lib/hs/Thrift.cabal @@ -36,13 +36,23 @@ Library Hs-Source-Dirs: src Build-Depends: - base >= 4, base < 5, network, ghc-prim, binary, bytestring, HTTP + base >= 4, base < 5, network, ghc-prim, binary, bytestring, hashable, HTTP, text, unordered-containers, vector Exposed-Modules: - Thrift, Thrift.Protocol, Thrift.Protocol.Binary, Thrift.Transport, - Thrift.Transport.Framed, Thrift.Transport.Handle, - Thrift.Transport.HttpClient, Thrift.Server + Thrift, + Thrift.Protocol, + Thrift.Protocol.Binary, + Thrift.Server, + Thrift.Transport, + Thrift.Transport.Framed, + Thrift.Transport.Handle, + Thrift.Transport.HttpClient, + Thrift.Types Extensions: - DeriveDataTypeable, ExistentialQuantification, FlexibleInstances, - KindSignatures, MagicHash, RankNTypes, - ScopedTypeVariables, TypeSynonymInstances - + DeriveDataTypeable, + ExistentialQuantification, + FlexibleInstances, + KindSignatures, + MagicHash, + RankNTypes, + ScopedTypeVariables, + TypeSynonymInstances diff --git a/lib/hs/src/Thrift.hs b/lib/hs/src/Thrift.hs index e57cff58..42f5d321 100644 --- a/lib/hs/src/Thrift.hs +++ b/lib/hs/src/Thrift.hs @@ -1,5 +1,6 @@ {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE KindSignatures #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} -- -- Licensed to the Apache Software Foundation (ASF) under one @@ -33,6 +34,7 @@ module Thrift import Control.Monad ( when ) import Control.Exception +import Data.Text.Lazy ( pack, unpack ) import Data.Typeable ( Typeable ) import Thrift.Transport @@ -84,7 +86,7 @@ writeAppExn pt ae = do when (ae_message ae /= "") $ do writeFieldBegin pt ("message", T_STRING , 1) - writeString pt (ae_message ae) + writeString pt (pack $ ae_message ae) writeFieldEnd pt writeFieldBegin pt ("type", T_I32, 2); @@ -108,7 +110,7 @@ readAppExnFields pt record = do else case tag of 1 -> if ft == T_STRING then do s <- readString pt - readAppExnFields pt record{ae_message = s} + readAppExnFields pt record{ae_message = unpack s} else do skip pt ft readAppExnFields pt record 2 -> if ft == T_I32 then diff --git a/lib/hs/src/Thrift/Protocol.hs b/lib/hs/src/Thrift/Protocol.hs index 1a319327..f3b342a1 100644 --- a/lib/hs/src/Thrift/Protocol.hs +++ b/lib/hs/src/Thrift/Protocol.hs @@ -29,9 +29,10 @@ module Thrift.Protocol import Control.Monad ( replicateM_, unless ) import Control.Exception +import Data.ByteString.Lazy import Data.Int +import Data.Text.Lazy ( Text ) import Data.Typeable ( Typeable ) -import Data.ByteString.Lazy import Thrift.Transport @@ -102,12 +103,12 @@ instance Enum MessageType where class Protocol a where getTransport :: Transport t => a t -> t - writeMessageBegin :: Transport t => a t -> (String, MessageType, Int32) -> IO () + writeMessageBegin :: Transport t => a t -> (Text, MessageType, Int32) -> IO () writeMessageEnd :: Transport t => a t -> IO () - writeStructBegin :: Transport t => a t -> String -> IO () + writeStructBegin :: Transport t => a t -> Text -> IO () writeStructEnd :: Transport t => a t -> IO () - writeFieldBegin :: Transport t => a t -> (String, ThriftType, Int16) -> IO () + writeFieldBegin :: Transport t => a t -> (Text, ThriftType, Int16) -> IO () writeFieldEnd :: Transport t => a t -> IO () writeFieldStop :: Transport t => a t -> IO () writeMapBegin :: Transport t => a t -> (ThriftType, ThriftType, Int32) -> IO () @@ -123,16 +124,16 @@ class Protocol a where writeI32 :: Transport t => a t -> Int32 -> IO () writeI64 :: Transport t => a t -> Int64 -> IO () writeDouble :: Transport t => a t -> Double -> IO () - writeString :: Transport t => a t -> String -> IO () + writeString :: Transport t => a t -> Text -> IO () writeBinary :: Transport t => a t -> ByteString -> IO () - readMessageBegin :: Transport t => a t -> IO (String, MessageType, Int32) + readMessageBegin :: Transport t => a t -> IO (Text, MessageType, Int32) readMessageEnd :: Transport t => a t -> IO () - readStructBegin :: Transport t => a t -> IO String + readStructBegin :: Transport t => a t -> IO Text readStructEnd :: Transport t => a t -> IO () - readFieldBegin :: Transport t => a t -> IO (String, ThriftType, Int16) + readFieldBegin :: Transport t => a t -> IO (Text, ThriftType, Int16) readFieldEnd :: Transport t => a t -> IO () readMapBegin :: Transport t => a t -> IO (ThriftType, ThriftType, Int32) readMapEnd :: Transport t => a t -> IO () @@ -147,7 +148,7 @@ class Protocol a where readI32 :: Transport t => a t -> IO Int32 readI64 :: Transport t => a t -> IO Int64 readDouble :: Transport t => a t -> IO Double - readString :: Transport t => a t -> IO String + readString :: Transport t => a t -> IO Text readBinary :: Transport t => a t -> IO ByteString diff --git a/lib/hs/src/Thrift/Protocol/Binary.hs b/lib/hs/src/Thrift/Protocol/Binary.hs index c55ea5a2..1bc9add4 100644 --- a/lib/hs/src/Thrift/Protocol/Binary.hs +++ b/lib/hs/src/Thrift/Protocol/Binary.hs @@ -1,5 +1,6 @@ {-# LANGUAGE ExistentialQuantification #-} {-# LANGUAGE MagicHash #-} +{-# LANGUAGE OverloadedStrings #-} -- -- Licensed to the Apache Software Foundation (ASF) under one -- or more contributor license agreements. See the NOTICE file @@ -30,6 +31,7 @@ import Control.Monad ( liftM ) import qualified Data.Binary import Data.Bits import Data.Int +import Data.Text.Lazy.Encoding ( decodeUtf8, encodeUtf8 ) import GHC.Exts import GHC.Word @@ -38,7 +40,6 @@ import Thrift.Protocol import Thrift.Transport import qualified Data.ByteString.Lazy as LBS -import qualified Data.ByteString.Lazy.Char8 as LBSChar8 version_mask :: Int32 version_mask = 0xffff0000 @@ -76,7 +77,9 @@ instance Protocol BinaryProtocol where writeI32 p b = tWrite (getTransport p) $ Data.Binary.encode b writeI64 p b = tWrite (getTransport p) $ Data.Binary.encode b writeDouble p d = writeI64 p (fromIntegral $ floatBits d) - writeString p s = writeI32 p (fromIntegral $ length s) >> tWrite (getTransport p) (LBSChar8.pack s) + writeString p s = writeI32 p (fromIntegral $ LBS.length s') >> tWrite (getTransport p) s' + where + s' = encodeUtf8 s writeBinary p s = writeI32 p (fromIntegral $ LBS.length s) >> tWrite (getTransport p) s readMessageBegin p = do @@ -136,7 +139,7 @@ instance Protocol BinaryProtocol where readString p = do i <- readI32 p - LBSChar8.unpack `liftM` tReadAll (getTransport p) (fromIntegral i) + decodeUtf8 `liftM` tReadAll (getTransport p) (fromIntegral i) readBinary p = do i <- readI32 p diff --git a/lib/hs/src/Thrift/Transport/Handle.hs b/lib/hs/src/Thrift/Transport/Handle.hs index 70d39e70..cf4822bd 100644 --- a/lib/hs/src/Thrift/Transport/Handle.hs +++ b/lib/hs/src/Thrift/Transport/Handle.hs @@ -27,7 +27,9 @@ module Thrift.Transport.Handle , HandleSource(..) ) where -import Control.Exception ( throw ) +import Prelude hiding ( catch ) + +import Control.Exception ( catch, throw ) import Control.Monad () import Network diff --git a/lib/hs/src/Thrift/Types.hs b/lib/hs/src/Thrift/Types.hs new file mode 100644 index 00000000..e917e39e --- /dev/null +++ b/lib/hs/src/Thrift/Types.hs @@ -0,0 +1,34 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. +-- + +module Thrift.Types where + +import Data.Foldable (foldl') +import Data.Hashable ( Hashable, hashWithSalt ) +import qualified Data.HashMap.Strict as Map +import qualified Data.HashSet as Set +import qualified Data.Vector as Vector + +instance (Hashable k, Hashable v) => Hashable (Map.HashMap k v) where + hashWithSalt salt = foldl' hashWithSalt salt . Map.toList + +instance (Hashable a) => Hashable (Set.HashSet a) where + hashWithSalt salt = foldl' hashWithSalt salt + +instance (Hashable a) => Hashable (Vector.Vector a) where + hashWithSalt salt = Vector.foldl' hashWithSalt salt -- 2.17.1