From: Bryan Duxbury Date: Mon, 29 Mar 2010 23:57:09 +0000 (+0000) Subject: THRIFT-723: java: Thrift buffers with set and map types in Java should implement... X-Git-Tag: 0.3.0~48 X-Git-Url: https://source.supwisdom.com/gerrit/gitweb?a=commitdiff_plain;h=5557beffaecbd7b97a90ed38afc49c2a091aadba;p=common%2Fthrift.git THRIFT-723: java: Thrift buffers with set and map types in Java should implement Comparable This makes structs that contain sets and maps in their hierarchy Comparable. git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@928944 13f79535-47bb-0310-9956-ffa450edef68 --- diff --git a/compiler/cpp/src/generate/t_java_generator.cc b/compiler/cpp/src/generate/t_java_generator.cc index c2479213..bc2ac494 100644 --- a/compiler/cpp/src/generate/t_java_generator.cc +++ b/compiler/cpp/src/generate/t_java_generator.cc @@ -199,9 +199,6 @@ class t_java_generator : public t_oop_generator { void generate_deep_copy_container(std::ofstream& out, std::string source_name_p1, std::string source_name_p2, std::string result_name, t_type* type); void generate_deep_copy_non_container(std::ofstream& out, std::string source_name, std::string dest_name, t_type* type); - bool is_comparable(t_struct* tstruct); - bool is_comparable(t_type* type); - bool has_bit_vector(t_struct* tstruct); /** @@ -703,9 +700,7 @@ void t_java_generator::generate_java_union(t_struct* tstruct) { "public " << (is_final ? "final " : "") << "class " << tstruct->get_name() << " extends TUnion<" << tstruct->get_name() << "._Fields> "; - if (is_comparable(tstruct)) { - f_struct << "implements Comparable<" << type_name(tstruct) << "> "; - } + f_struct << "implements Comparable<" << type_name(tstruct) << "> "; scope_up(f_struct); @@ -1002,22 +997,20 @@ void t_java_generator::generate_union_comparisons(ofstream& out, t_struct* tstru indent(out) << "}" << endl; out << endl; - if (is_comparable(tstruct)) { - indent(out) << "@Override" << endl; - indent(out) << "public int compareTo(" << type_name(tstruct) << " other) {" << endl; - indent(out) << " int lastComparison = TBaseHelper.compareTo(getSetField(), other.getSetField());" << endl; - indent(out) << " if (lastComparison == 0) {" << endl; - indent(out) << " Object myValue = getFieldValue();" << endl; - indent(out) << " if (myValue instanceof byte[]) {" << endl; - indent(out) << " return TBaseHelper.compareTo((byte[])myValue, (byte[])other.getFieldValue());" << endl; - indent(out) << " } else {" << endl; - indent(out) << " return TBaseHelper.compareTo((Comparable)myValue, (Comparable)other.getFieldValue());" << endl; - indent(out) << " }" << endl; - indent(out) << " }" << endl; - indent(out) << " return lastComparison;" << endl; - indent(out) << "}" << endl; - out << endl; - } + indent(out) << "@Override" << endl; + indent(out) << "public int compareTo(" << type_name(tstruct) << " other) {" << endl; + indent(out) << " int lastComparison = TBaseHelper.compareTo(getSetField(), other.getSetField());" << endl; + indent(out) << " if (lastComparison == 0) {" << endl; + indent(out) << " Object myValue = getFieldValue();" << endl; + indent(out) << " if (myValue instanceof byte[]) {" << endl; + indent(out) << " return TBaseHelper.compareTo((byte[])myValue, (byte[])other.getFieldValue());" << endl; + indent(out) << " } else {" << endl; + indent(out) << " return TBaseHelper.compareTo((Comparable)myValue, (Comparable)other.getFieldValue());" << endl; + indent(out) << " }" << endl; + indent(out) << " }" << endl; + indent(out) << " return lastComparison;" << endl; + indent(out) << "}" << endl; + out << endl; } void t_java_generator::generate_union_hashcode(ofstream& out, t_struct* tstruct) { @@ -1077,9 +1070,7 @@ void t_java_generator::generate_java_struct_definition(ofstream &out, } out << "implements TBase<" << tstruct->get_name() << "._Fields>, java.io.Serializable, Cloneable"; - if (is_comparable(tstruct)) { - out << ", Comparable<" << type_name(tstruct) << ">"; - } + out << ", Comparable<" << type_name(tstruct) << ">"; out << " "; @@ -1241,9 +1232,7 @@ void t_java_generator::generate_java_struct_definition(ofstream &out, generate_generic_isset_method(out, tstruct); generate_java_struct_equality(out, tstruct); - if (is_comparable(tstruct)) { - generate_java_struct_compare_to(out, tstruct); - } + generate_java_struct_compare_to(out, tstruct); generate_java_struct_reader(out, tstruct); if (is_result) { @@ -3606,32 +3595,6 @@ void t_java_generator::generate_field_name_constants(ofstream& out, t_struct* ts indent(out) << "}" << endl; } -bool t_java_generator::is_comparable(t_struct* tstruct) { - const vector& members = tstruct->get_members(); - vector::const_iterator m_iter; - - for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) { - if (!is_comparable(get_true_type((*m_iter)->get_type()))) { - return false; - } - } - return true; -} - -bool t_java_generator::is_comparable(t_type* type) { - if (type->is_container()) { - if (type->is_list()) { - return is_comparable(get_true_type(((t_list*)type)->get_elem_type())); - } else { - return false; - } - } else if (type->is_struct() || type->is_xception()) { - return is_comparable((t_struct*)type); - } else { - return true; - } -} - bool t_java_generator::has_bit_vector(t_struct* tstruct) { const vector& members = tstruct->get_members(); vector::const_iterator m_iter; diff --git a/lib/java/src/org/apache/thrift/TBaseHelper.java b/lib/java/src/org/apache/thrift/TBaseHelper.java index b41daae5..fccece83 100644 --- a/lib/java/src/org/apache/thrift/TBaseHelper.java +++ b/lib/java/src/org/apache/thrift/TBaseHelper.java @@ -17,14 +17,24 @@ */ package org.apache.thrift; +import java.util.Comparator; +import java.util.Iterator; import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.SortedMap; +import java.util.SortedSet; +import java.util.TreeMap; +import java.util.TreeSet; public class TBaseHelper { - + + private static final Comparator comparator = new NestedStructureComparator(); + public static int compareTo(boolean a, boolean b) { return Boolean.valueOf(a).compareTo(b); } - + public static int compareTo(byte a, byte b) { if (a < b) { return -1; @@ -44,7 +54,7 @@ public class TBaseHelper { return 0; } } - + public static int compareTo(int a, int b) { if (a < b) { return -1; @@ -54,7 +64,7 @@ public class TBaseHelper { return 0; } } - + public static int compareTo(long a, long b) { if (a < b) { return -1; @@ -64,7 +74,7 @@ public class TBaseHelper { return 0; } } - + public static int compareTo(double a, double b) { if (a < b) { return -1; @@ -74,11 +84,11 @@ public class TBaseHelper { return 0; } } - + public static int compareTo(String a, String b) { return a.compareTo(b); } - + public static int compareTo(byte[] a, byte[] b) { int sizeCompare = compareTo(a.length, b.length); if (sizeCompare != 0) { @@ -92,28 +102,103 @@ public class TBaseHelper { } return 0; } - + public static int compareTo(Comparable a, Comparable b) { return a.compareTo(b); } - + public static int compareTo(List a, List b) { int lastComparison = compareTo(a.size(), b.size()); if (lastComparison != 0) { return lastComparison; } for (int i = 0; i < a.size(); i++) { - Object oA = a.get(i); - Object oB = b.get(i); - if (oA instanceof List) { - lastComparison = compareTo((List)oA, (List)oB); - } else { - lastComparison = compareTo((Comparable)oA, (Comparable)oB); + lastComparison = comparator.compare(a.get(i), b.get(i)); + if (lastComparison != 0) { + return lastComparison; } + } + return 0; + } + + public static int compareTo(Set a, Set b) { + int lastComparison = compareTo(a.size(), b.size()); + if (lastComparison != 0) { + return lastComparison; + } + SortedSet sortedA = new TreeSet(comparator); + sortedA.addAll(a); + SortedSet sortedB = new TreeSet(comparator); + sortedB.addAll(b); + + Iterator iterA = sortedA.iterator(); + Iterator iterB = sortedB.iterator(); + + // Compare each item. + while (iterA.hasNext() && iterB.hasNext()) { + lastComparison = comparator.compare(iterA.next(), iterB.next()); if (lastComparison != 0) { return lastComparison; } } + return 0; } + + public static int compareTo(Map a, Map b) { + int lastComparison = compareTo(a.size(), b.size()); + if (lastComparison != 0) { + return lastComparison; + } + + // Sort a and b so we can compare them. + SortedMap sortedA = new TreeMap(comparator); + sortedA.putAll(a); + Iterator iterA = sortedA.entrySet().iterator(); + SortedMap sortedB = new TreeMap(comparator); + sortedB.putAll(b); + Iterator iterB = sortedB.entrySet().iterator(); + + // Compare each item. + while (iterA.hasNext() && iterB.hasNext()) { + Map.Entry entryA = iterA.next(); + Map.Entry entryB = iterB.next(); + lastComparison = comparator.compare(entryA.getKey(), entryB.getKey()); + if (lastComparison != 0) { + return lastComparison; + } + lastComparison = comparator.compare(entryA.getValue(), entryB.getValue()); + if (lastComparison != 0) { + return lastComparison; + } + } + + return 0; + } + + /** + * Comparator to compare items inside a structure (e.g. a list, set, or map). + */ + private static class NestedStructureComparator implements Comparator { + public int compare(Object oA, Object oB) { + if (oA == null && oB == null) { + return 0; + } else if (oA == null) { + return -1; + } else if (oB == null) { + return 1; + } else if (oA instanceof List) { + return compareTo((List)oA, (List)oB); + } else if (oA instanceof Set) { + return compareTo((Set)oA, (Set)oB); + } else if (oA instanceof Map) { + return compareTo((Map)oA, (Map)oB); + } else if (oA instanceof byte[]) { + return compareTo((byte[])oA, (byte[])oB); + } else { + return compareTo((Comparable)oA, (Comparable)oB); + } + } + } + } diff --git a/lib/java/test/org/apache/thrift/TestStruct.java b/lib/java/test/org/apache/thrift/TestStruct.java index 94650902..6ba48a47 100644 --- a/lib/java/test/org/apache/thrift/TestStruct.java +++ b/lib/java/test/org/apache/thrift/TestStruct.java @@ -4,6 +4,7 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import java.util.HashMap; import junit.framework.TestCase; @@ -11,7 +12,9 @@ import org.apache.thrift.protocol.TBinaryProtocol; import thrift.test.Bonk; import thrift.test.HolyMoley; +import thrift.test.Insanity; import thrift.test.Nesting; +import thrift.test.Numberz; import thrift.test.OneOfEach; public class TestStruct extends TestCase { @@ -135,4 +138,52 @@ public class TestStruct extends TestCase { bonk2.setMessage("m"); assertEquals(0, bonk1.compareTo(bonk2)); } + + public void testCompareToWithDataStructures() { + Insanity insanity1 = new Insanity(); + Insanity insanity2 = new Insanity(); + + // Both empty. + expectEquals(insanity1, insanity2); + + insanity1.setUserMap(new HashMap()); + // insanity1.map = {}, insanity2.map = null + expectGreaterThan(insanity1, insanity2); + + // insanity1.map = {2:1}, insanity2.map = null + insanity1.getUserMap().put(Numberz.TWO, 1l); + expectGreaterThan(insanity1, insanity2); + + // insanity1.map = {2:1}, insanity2.map = {} + insanity2.setUserMap(new HashMap()); + expectGreaterThan(insanity1, insanity2); + + // insanity1.map = {2:1}, insanity2.map = {2:2} + insanity2.getUserMap().put(Numberz.TWO, 2l); + expectLessThan(insanity1, insanity2); + + // insanity1.map = {2:1, 3:5}, insanity2.map = {2:2} + insanity1.getUserMap().put(Numberz.THREE, 5l); + expectGreaterThan(insanity1, insanity2); + + // insanity1.map = {2:1, 3:5}, insanity2.map = {2:1, 4:5} + insanity2.getUserMap().put(Numberz.TWO, 1l); + insanity2.getUserMap().put(Numberz.FIVE, 5l); + expectLessThan(insanity1, insanity2); + } + + private void expectLessThan(Insanity insanity1, Insanity insanity2) { + int compareTo = insanity1.compareTo(insanity2); + assertTrue(insanity1 + " should be less than " + insanity2 + ", but is: " + compareTo, compareTo < 0); + } + + private void expectGreaterThan(Insanity insanity1, Insanity insanity2) { + int compareTo = insanity1.compareTo(insanity2); + assertTrue(insanity1 + " should be greater than " + insanity2 + ", but is: " + compareTo, compareTo > 0); + } + + private void expectEquals(Insanity insanity1, Insanity insanity2) { + int compareTo = insanity1.compareTo(insanity2); + assertEquals(insanity1 + " should be equal to " + insanity2 + ", but is: " + compareTo, 0, compareTo); + } } diff --git a/lib/java/test/org/apache/thrift/TestTBaseHelper.java b/lib/java/test/org/apache/thrift/TestTBaseHelper.java new file mode 100644 index 00000000..e2d78690 --- /dev/null +++ b/lib/java/test/org/apache/thrift/TestTBaseHelper.java @@ -0,0 +1,124 @@ +package org.apache.thrift; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import junit.framework.TestCase; + +public class TestTBaseHelper extends TestCase { + public void testByteArrayComparison() { + assertTrue(TBaseHelper.compareTo(new byte[]{'a','b'}, new byte[]{'a','c'}) < 0); + } + + public void testSets() { + Set a = new HashSet(); + Set b = new HashSet(); + + assertTrue(TBaseHelper.compareTo(a, b) == 0); + + a.add("test"); + + assertTrue(TBaseHelper.compareTo(a, b) > 0); + + b.add("test"); + + assertTrue(TBaseHelper.compareTo(a, b) == 0); + + b.add("aardvark"); + + assertTrue(TBaseHelper.compareTo(a, b) < 0); + + a.add("test2"); + + assertTrue(TBaseHelper.compareTo(a, b) > 0); + } + + public void testNestedStructures() { + Set> a = new HashSet>(); + Set> b = new HashSet>(); + + a.add(Arrays.asList(new String[] {"a","b"})); + b.add(Arrays.asList(new String[] {"a","b", "c"})); + a.add(Arrays.asList(new String[] {"a","b"})); + b.add(Arrays.asList(new String[] {"a","b", "c"})); + + assertTrue(TBaseHelper.compareTo(a, b) < 0); + } + + public void testMapsInSets() { + Set> a = new HashSet>(); + Set> b = new HashSet>(); + + assertTrue(TBaseHelper.compareTo(a, b) == 0); + + Map innerA = new HashMap(); + Map innerB = new HashMap(); + a.add(innerA); + b.add(innerB); + + innerA.put("a", 1l); + innerB.put("a", 2l); + + assertTrue(TBaseHelper.compareTo(a, b) < 0); + } + + public void testByteArraysInMaps() { + Map a = new HashMap(); + Map b = new HashMap(); + + assertTrue(TBaseHelper.compareTo(a, b) == 0); + + a.put(new byte[]{'a','b'}, 1000L); + b.put(new byte[]{'a','b'}, 1000L); + a.put(new byte[]{'a','b', 'd'}, 1000L); + b.put(new byte[]{'a','b', 'a'}, 1000L); + assertTrue(TBaseHelper.compareTo(a, b) > 0); + } + + public void testMapsWithNulls() { + Map a = new HashMap(); + Map b = new HashMap(); + a.put("a", null); + a.put("b", null); + b.put("a", null); + b.put("b", null); + + assertTrue(TBaseHelper.compareTo(a, b) == 0); + } + + public void testMapKeyComparison() { + Map a = new HashMap(); + Map b = new HashMap(); + a.put("a", "a"); + b.put("b", "a"); + + assertTrue(TBaseHelper.compareTo(a, b) < 0); + } + + public void testMapValueComparison() { + Map a = new HashMap(); + Map b = new HashMap(); + a.put("a", "b"); + b.put("a", "a"); + + assertTrue(TBaseHelper.compareTo(a, b) > 0); + } + + public void testByteArraysInSets() { + Set a = new HashSet(); + Set b = new HashSet(); + + if (TBaseHelper.compareTo(a, b) != 0) + throw new RuntimeException("Set compare failed:" + a + " vs. " + b); + + a.add(new byte[]{'a','b'}); + b.add(new byte[]{'a','b'}); + a.add(new byte[]{'a','b', 'd'}); + b.add(new byte[]{'a','b', 'a'}); + assertTrue(TBaseHelper.compareTo(a, b) > 0); + } +}