1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
| from functools import total_ordering
@total_ordering
class Node:
    """A binary-search-tree node holding a key/value pair.

    Nodes compare by ``key`` only.  ``==`` and ``>`` are defined here, and
    ``functools.total_ordering`` derives ``<``, ``<=`` and ``>=``.
    ``value`` never takes part in comparisons.

    Attributes:
        key: the ordering key.
        value: the payload stored under ``key``.
        left, right: child nodes, or None.
        parent: parent-node slot; initialised to None here.
        n: count supplied by the caller.  BinarySearchTree stores the size
            of the subtree rooted at this node (itself included) here.
    """

    # Defining __eq__ already makes instances unhashable; state it explicitly
    # so the intent is visible.
    __hash__ = None

    def __init__(self, key, value, n):
        self.key = key
        self.value = value
        self.left = None
        self.right = None
        self.parent = None
        self.n = n

    def __gt__(self, node):
        """Return True if this node's key is greater than ``node``'s key."""
        # For a foreign operand, NotImplemented lets Python try the reflected
        # operation and then raise TypeError.  Without it, the ``.key``
        # lookup would fail with AttributeError.
        if not isinstance(node, Node):
            return NotImplemented
        return self.key > node.key

    def __eq__(self, node):
        """Return True if both nodes have equal keys."""
        # NotImplemented makes ``node == None`` (or any foreign type) fall
        # back to identity comparison.  That evaluates to False instead of
        # raising AttributeError.
        if not isinstance(node, Node):
            return NotImplemented
        return self.key == node.key

    def __repr__(self):
        return 'Node<{}:{}>'.format(self.key, self.value)
class BinarySearchTree:
    """Unbalanced binary search tree mapping ordered keys to values.

    Every node caches ``n``, the size of the subtree rooted at it.  This
    makes ``size()`` a constant-time lookup.  Nodes are created with
    ``Node(key, value, 1)``.
    """

    def __init__(self):
        self._root = None

    def __repr__(self):
        # Shows only the root's key/value (None/None for an empty tree).
        key = self._root.key if self._root is not None else None
        value = self._root.value if self._root is not None else None
        return 'Node<{}:{}>'.format(key, value)

    def size(self):
        """Return the number of entries stored in the tree."""
        return self._size(self._root)

    def _size(self, node):
        """Return the cached subtree size of ``node`` (0 for None)."""
        if node is None:
            return 0
        return node.n

    def search(self, key):
        """Return the value stored under ``key``, or None if it is absent."""
        return self._search(self._root, key)

    def _search(self, node, key):
        """Walk down from ``node`` and return the value for ``key``, or None."""
        while node is not None:
            if node.key > key:
                node = node.left
            elif node.key < key:
                node = node.right
            else:
                return node.value
        return None

    def search_and_set(self, key, default):
        # TODO(review): unimplemented stub; always returns None.  The name
        # suggests dict.setdefault-like behaviour; confirm before implementing.
        pass

    def _search_and_set(self, node, key, default):
        # TODO(review): unimplemented helper for search_and_set; returns None.
        pass

    def insert(self, key, value):
        """Insert ``key`` -> ``value``, replacing the value of an existing key."""
        self._root = self._insert(self._root, key, value)

    def _insert(self, node, key, value):
        """Insert into the subtree at ``node`` and return that subtree's root."""
        if node is None:
            return Node(key, value, 1)
        if node.key > key:
            node.left = self._insert(node.left, key, value)
        elif node.key < key:
            node.right = self._insert(node.right, key, value)
        else:
            node.value = value
        node.n = self._size(node.left) + self._size(node.right) + 1
        return node

    def min(self):
        """Return the smallest key, or None if the tree is empty."""
        node = self._min(self._root)
        return None if node is None else node.key

    def _min(self, node):
        """Return the node with the smallest key under ``node`` (None if empty)."""
        if node is None:
            return None
        while node.left is not None:
            node = node.left
        return node

    def max(self):
        """Return the largest key, or None if the tree is empty."""
        node = self._max(self._root)
        return None if node is None else node.key

    def _max(self, node):
        """Return the node with the largest key under ``node`` (None if empty)."""
        if node is None:
            return None
        while node.right is not None:
            node = node.right
        return node

    def delete_max(self):
        """Remove the entry with the largest key; no-op on an empty tree."""
        self._root = self._delete_max(self._root)

    def _delete_max(self, node):
        """Remove the maximum of the subtree at ``node``; return its new root."""
        if node is None:
            return None
        if node.right is None:
            # This node is the maximum; its left subtree takes its place.
            return node.left
        node.right = self._delete_max(node.right)
        node.n = self._size(node.left) + self._size(node.right) + 1
        return node

    def delete_min(self):
        """Remove the entry with the smallest key; no-op on an empty tree."""
        self._root = self._delete_min(self._root)

    def _delete_min(self, node):
        """Remove the minimum of the subtree at ``node``; return its new root."""
        if node is None:
            return None
        if node.left is None:
            # This node is the minimum; its right subtree takes its place.
            return node.right
        node.left = self._delete_min(node.left)
        node.n = self._size(node.left) + self._size(node.right) + 1
        return node

    def delete(self, key):
        """Remove ``key`` from the tree; a missing key is ignored."""
        self._root = self._delete(self._root, key)

    def _delete(self, node, key):
        """Remove ``key`` from the subtree at ``node``; return its new root."""
        new_root, _removed = self._delete_return(node, key)
        return new_root

    def delete_return(self, key):
        """Remove ``key`` and return its value, or None if it was absent."""
        self._root, target = self._delete_return(self._root, key)
        return None if target is None else target.value

    def _delete_return(self, node, key):
        """Remove ``key`` from the subtree at ``node``.

        Returns:
            A pair ``(new_subtree_root, removed_node)``.  ``removed_node``
            is None when ``key`` is not present.
        """
        if node is None:
            return None, None
        if node.key > key:
            node.left, target = self._delete_return(node.left, key)
        elif node.key < key:
            node.right, target = self._delete_return(node.right, key)
        else:
            if node.right is None:
                return node.left, node
            if node.left is None:
                return node.right, node
            # Two children (Hibbard deletion): replace this node with its
            # in-order successor, the minimum of the right subtree.
            target = node
            node = self._min(target.right)
            node.right = self._delete_min(target.right)
            node.left = target.left
        node.n = self._size(node.left) + self._size(node.right) + 1
        return node, target
|