Python rich comparisons
Python 3 has switch to Rich Comparisons.
With Python 2 is was enough to implement a single __cmp__(self, other)
method, now you have to implement
__lt__(self, other)
__le__(self, other)
__eq__(self, other)
__ne__(self, other)
__ge__(self, other)
__gt__(self, other)
__hash__(self)
@functools.total_ordering helps with this, but this is still painful.
I often have classes, which work like namedtuple. Previously I was done with
class Base(object):
def __init__(self, a, b):
self.a = a
self.b = b
def __hash__(self):
return hash((self.a, self.b))
class Old(Base):
def __cmp__(self, other):
if not isinstance(other, Base):
return NotImplemented
return cmp((self.a, self.b), (other.a, other.b))
Now I have to implement it like this:
class New(Base):
def __lt__(self, other):
if not isinstance(other, Base):
return NotImplemented
return (self.a, self.b) < (other.a, other.b)
def __le__(self, other):
if not isinstance(other, Base):
return NotImplemented
return (self.a, self.b) <= (other.a, other.b)
def __eq__(self, other) -> bool:
return isinstance(other, Base) and (self.a, self.b) == (other.a, other.b)
def __ne__(self, other) -> bool:
return not isinstance(other, Base) or (self.a, self.b) != (other.a, other.b)
def __ge__(self, other):
if not isinstance(other, Base):
return NotImplemented
return (self.a, self.b) >= (other.a, other.b)
def __gt__(self, other):
if not isinstance(other, Base):
return NotImplemented
return (self.a, self.b) > (other.a, other.b)
This is a lot more work and also error prone due the repetition of 6 times nearly the same code with subtile differences. I have seen wrong implementation doing it like this:
class Wrong(Base):
def __lt__(self, other):
if not isinstance(other, Base):
return NotImplemented
if self.a < other.a:
return True
if self.b < other.b
return True
return False
This is wrong as the test for b
must also check for self.a == other.a
:
class StillIncomplete(Base):
def __lt__(self, other):
if not isinstance(other, Base):
return NotImplemented
if self.a < other.a:
return True
if self.a == other.a and self.b < other.b
return True
return False
This still lacks support for NotImplemented, which is required for handling the case, where a sub-class wants to implement special handling when an instance of that class is compared to a super- (or distinct) class. (Do not confuse this with NotImplementedError, which is an exception for abstract methods.)
You can write it more compact like
class Ugly(Base):
def __lt__(self, other):
return (self.a < other.a) or ((self.a == self.a) and (self.b < other.b)) \
if isinstance(self, Base) else NotImplemented
but this does not get prettier when you compare more then two values.
Re-using the idea of total_ordering
you can write your own decorator for classes, which adds the missing functions based on some attribute:
from operator import attrgetter
def ordering(getter):
def decorator(cls):
ops = {
'__lt__': lambda self, other: getter(self) < getter(other) if isinstance(other, cls) else NotImplemented,
'__le__': lambda self, other: getter(self) <= getter(other) if isinstance(other, cls) else NotImplemented,
'__eq__': lambda self, other: isinstance(other, cls) and getter(self) == getter(other),
'__ne__': lambda self, other: not isinstance(other, cls) or getter(self) != getter(other),
'__ge__': lambda self, other: getter(self) >= getter(other) if isinstance(other, cls) else NotImplemented,
'__gt__': lambda self, other: getter(self) > getter(other) if isinstance(other, cls) else NotImplemented,
'__hash__': lambda self: hash(getter(self)),
}
root = set(dir(cls))
for opname, opfunc in ops.items():
if opname not in root or getattr(cls, opname, None) is getattr(object, opname, None):
opfunc.__name__ = opname
setattr(cls, opname, opfunc)
return cls
return decorator
This can be used like this:
@ordering(attrgetter("val"))
class foo(object):
def __init__(self, val):
self.val = val
x, y = foo(1), foo(2)
assert x == x
assert not x == y
assert x != y
assert x <= x
assert x < y
assert x in {x}
class bar(foo): pass
z = bar(1)
assert x == z
assert z == x
assert z != y
assert z in {x}
This also work for more complex structures like tuples:
x, y = foo((1, 1)), foo((1, 2))
z = bar((1, 1))
If you have a class with multiple components, you can implement __iter__()
to return those parts and use them for comparison:
@ordering(lambda obj: tuple(iter(obj)))
class foo(object):
def __init__(self, a, b):
self.a = a
self.b = b
def __iter__(self):
yield self.a
yield self.b
x, y = foo(0, "x"), foo(0, "y")
This works with both Python 2 and 3.