二叉搜索树(binary search tree)是一种特殊的二叉树,满足如下性质:
(1)每个结点的关键码大于其左子树结点的关键码(如果有的情况)
(2)每个结点的关键码小于其右子树结点的关键码(如果有的情况)
(3)根据(1)和(2)BST不允许有重复的键值

由于是树形结构,在理想的情况下(树基本平衡时),搜索BST的时间复杂度为O(lgN)。如果我们沿竖直向下的方向投影BST,投影结果是一个递增序列,我们可以使用归纳法的思路证明这一点或者使用递归的思路证明。不难发现,在三个结点两层的BST的时候,竖直投影的递增性是显然的。

BST操作

BST的基本操作包括:搜索、插入、删除。在这些基本操作上可以实现满足具体业务场景的需求。我们定义其结点如下:

1
2
3
4
5
6
7
8
class Node:
    """A single node of a binary search tree.

    Attributes are plain public fields: ``key`` and ``value`` hold the stored
    entry, while ``left``, ``right`` and ``parent`` are links to other nodes
    and all start out as ``None``.
    """

    def __init__(self, key, value):
        """Create a detached node holding ``key`` and ``value``."""
        self.key = key
        self.value = value
        # A fresh node is not linked to any other node yet.
        self.left = None
        self.right = None
        self.parent = None

搜索

搜索BST和二叉树的遍历并没有差别,只是在搜索BST时利用了BST的有序信息,这样实现快速定位需要查找的关键码。搜索过程:
(1)如果搜索目标和当前结点关键码相等,返回搜索结果
(2)如果搜索目标大于当前结点关键码,搜索当前结点的右子树
(3)如果搜索目标小于当前结点关键码,搜索当前结点的左子树
(4)如果当前搜索结点(包括根结点)为空,则没有搜索到,直接返回

根据上述思路,不难实现递归的BST搜索

1
2
3
4
5
6
7
8
9
def search_rescursive(bst, target):
    """Recursively look up ``target`` in the subtree rooted at ``bst``.

    Args:
        bst: Root node of the (sub)tree, or ``None`` for an empty tree.
            Nodes must expose ``key``, ``value``, ``left`` and ``right``.
        target: The key to look for.

    Returns:
        The ``value`` stored under ``target``, or ``None`` when the key is
        not present.
    """
    if bst is None:
        return None
    if bst.key == target:
        return bst.value
    # Keys greater than the current key can only be in the right subtree,
    # smaller ones only in the left subtree.
    child = bst.right if bst.key < target else bst.left
    # The result of the recursive call has to be returned, otherwise every
    # match below the root would be lost.
    return search_rescursive(child, target)

循环的实现如下

1
2
3
4
5
6
7
8
9
10
11
def search(bst, target):
    """Iteratively look up ``target`` in the tree rooted at ``bst``.

    Args:
        bst: Root node of the tree, or ``None`` for an empty tree.  Nodes
            must expose ``key``, ``value``, ``left`` and ``right``.
        target: The key to look for.

    Returns:
        The ``value`` stored under ``target``, or ``None`` when the key is
        not present.
    """
    current = bst
    # An empty tree simply skips the loop, so no separate guard is needed.
    while current is not None:
        if current.key == target:
            return current.value
        # Larger keys live to the right, smaller keys to the left.
        current = current.right if current.key < target else current.left
    return None

插入

BST插入过程分两部分:(1)搜索(2)插入。搜索过程就是前面讲述的过程,但这里的搜索目的是找到适合插入的位置,该适合的位置要满足BST定义的条件。找到适合的位置后,把该位置的父结点指向插入的结点。

这里给出循环的实现方法(递归的实现可以仿照前面的递归搜索写出)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def insert(bst, node):
    """Insert ``node`` into the tree rooted at ``bst`` and return the root.

    Args:
        bst: Root node of the tree, or ``None`` for an empty tree.
        node: The node to insert; it must expose ``key`` and ``value`` and is
            linked in as a leaf.

    Returns:
        The root of the tree after insertion.  For an empty tree this is
        ``node`` itself.  If a node with the same key already exists, its
        value is overwritten and ``node`` is not linked into the tree.
    """
    if bst is None:
        return node
    current = bst
    while True:
        if current.key == node.key:
            # Duplicate keys are not allowed: update the value in place.
            current.value = node.value
            return bst
        if current.key < node.key:
            if current.right is None:
                current.right = node
                return bst
            current = current.right
        else:
            if current.left is None:
                current.left = node
                return bst
            current = current.left

有两个细节:(1)如果树为空,就直接用插入结点替换空树或者直接弹出错误,看具体实现需求(2)如果插入的结点已经存在,替换关键码对应的值。

上面的实现独立于搜索操作的实现,如果考虑到代码的简洁,插入操作应该复用搜索操作的代码

删除

删除操作是BST搜索、插入、删除三种操作中最难的。为了分解难度,考虑不同的情况:
(1)删除的结点拥有左右孩子
(2)删除的结点没有左右孩子
(3)删除的结点有左孩子或右孩子

对于情况(2),把父结点的指针指向NULL,释放删除结点的内存即可
对于情况(3),删除结点后把父结点指针指向删除结点的孩子即可

不难看出(2)是情况(3)的特例,可以合并为一种情况

对于(1),考虑到一个特点:BST在竖直方向下投影为递增序列,我们只要找到要删除的结点的后继(即其右子树中最小的结点),用后继结点替换要删除的结点,那么BST的特性就不变了。

删除的实现如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class BinarySearchTree:
    """Deletion-related part of the binary search tree (other members elided).

    The methods here rely on ``self._root``, ``self._min`` and ``self._size``,
    none of which are shown in this excerpt.
    """

    # ...

    def delete(self, key):
        """Remove the entry stored under ``key``; a missing key changes nothing."""
        self._root = self._delete(self._root, key)

    def _delete(self, node, key):
        """Delete ``key`` from the subtree at ``node`` and return its new root."""
        if node is None:
            return None
        if node.key > key:
            node.left = self._delete(node.left, key)
        elif node.key < key:
            node.right = self._delete(node.right, key)
        else:
            # At most one child: the remaining child (or None) takes the place
            # of the deleted node.
            if node.right is None:
                return node.left
            if node.left is None:
                return node.right
            # Two children: the successor (minimum of the right subtree)
            # replaces the deleted node, which keeps the ordering intact.
            t = node
            node = self._min(t.right)
            node.right = self._delete_min(t.right)
            node.left = t.left
        node.n = self._size(node.left) + self._size(node.right) + 1
        return node

    def delete_min(self):
        """Remove the entry with the smallest key; an empty tree is left as is."""
        # Must go through _delete_min: _delete requires a key argument.
        self._root = self._delete_min(self._root)

    def _delete_min(self, node):
        """Delete the minimum of the subtree at ``node`` and return its new root."""
        if node is None:
            return None
        if node.left is None:
            # ``node`` is the minimum; its right subtree takes its place.
            return node.right
        node.left = self._delete_min(node.left)
        node.n = self._size(node.left) + self._size(node.right) + 1
        return node

综合的实现

根据上面的讨论我们实现完整的BinarySearchTree

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from functools import total_ordering

@total_ordering
class Node:
    """BST node that is ordered and compared by its ``key``.

    ``n`` is the number of nodes in the subtree rooted at this node (the
    caller supplies the initial value).  ``left``, ``right`` and ``parent``
    start out as ``None``.
    """

    def __init__(self, key, value, n):
        """Create a detached node for ``key``/``value`` with subtree size ``n``."""
        self.key = key
        self.value = value
        self.left = None
        self.right = None
        self.parent = None
        self.n = n

    def __gt__(self, node):
        """Return whether this node's key is greater than ``node``'s key."""
        # Returning NotImplemented lets Python handle foreign operands instead
        # of failing with AttributeError on ``node.key``.
        if not isinstance(node, Node):
            return NotImplemented
        return self.key > node.key

    def __eq__(self, node):
        """Return whether both nodes carry equal keys (values are ignored)."""
        if not isinstance(node, Node):
            return NotImplemented
        return self.key == node.key

    # Defining __eq__ without __hash__ already makes instances unhashable;
    # spell it out so the intent is visible.
    __hash__ = None

    def __repr__(self):
        return 'Node<{}:{}>'.format(self.key, self.value)

class BinarySearchTree:
    """Recursive, unbalanced binary search tree mapping unique keys to values.

    Every node carries ``n``, the number of nodes in the subtree rooted at
    it, so ``size()`` is read from the root without a traversal.  Methods
    with a leading underscore work on a subtree and return its new root.
    """

    def __init__(self):
        self._root = None

    def __repr__(self):
        # Only the root entry is shown.
        key = self._root.key if self._root else None
        value = self._root.value if self._root else None
        return 'Node<{}:{}>'.format(key, value)

    def size(self):
        """Return the number of entries in the tree."""
        return self._size(self._root)

    def _size(self, node):
        """Return the size of the subtree at ``node`` (0 for ``None``)."""
        if node is None:
            return 0
        return node.n

    def search(self, key):
        """Return the value stored under ``key``, or ``None`` if absent."""
        return self._search(self._root, key)

    def _search(self, node, key):
        """Search the subtree at ``node`` for ``key``."""
        if node is None:
            return None
        # The recursive results must be returned, otherwise only a match at
        # the subtree root would ever be reported.
        if node.key > key:
            return self._search(node.left, key)
        if node.key < key:
            return self._search(node.right, key)
        return node.value

    def search_and_set(self, key, default):
        # TODO: not implemented yet; currently a no-op that returns None.
        pass

    def _search_and_set(self, node, key, default):
        # TODO: not implemented yet; currently a no-op that returns None.
        pass

    def insert(self, key, value):
        """Store ``value`` under ``key``, overwriting an existing entry."""
        self._root = self._insert(self._root, key, value)

    def _insert(self, node, key, value):
        """Insert into the subtree at ``node`` and return its root."""
        if node is None:
            return Node(key, value, 1)
        if node.key > key:
            node.left = self._insert(node.left, key, value)
        elif node.key < key:
            node.right = self._insert(node.right, key, value)
        else:
            node.value = value
        node.n = self._size(node.left) + self._size(node.right) + 1
        return node

    def min(self):
        """Return the smallest key, or ``None`` for an empty tree."""
        node = self._min(self._root)
        return None if node is None else node.key

    def _min(self, node):
        """Return the leftmost node of the subtree at ``node`` (or ``None``)."""
        if node is None:
            return None
        if node.left is None:
            return node
        return self._min(node.left)

    def max(self):
        """Return the largest key, or ``None`` for an empty tree."""
        node = self._max(self._root)
        return None if node is None else node.key

    def _max(self, node):
        """Return the rightmost node of the subtree at ``node`` (or ``None``)."""
        if node is None:
            return None
        if node.right is None:
            return node
        return self._max(node.right)

    def delete_max(self):
        """Remove the entry with the largest key; an empty tree is left as is."""
        self._root = self._delete_max(self._root)

    def _delete_max(self, node):
        """Delete the maximum of the subtree at ``node`` and return its new root."""
        if node is None:
            return None
        if node.right is None:
            # ``node`` is the maximum; its left subtree takes its place.
            return node.left
        # The maximum lives in the right subtree, so that is the link to fix.
        node.right = self._delete_max(node.right)
        node.n = self._size(node.left) + self._size(node.right) + 1
        return node

    def delete_min(self):
        """Remove the entry with the smallest key; an empty tree is left as is."""
        # Must go through _delete_min: _delete requires a key argument.
        self._root = self._delete_min(self._root)

    def _delete_min(self, node):
        """Delete the minimum of the subtree at ``node`` and return its new root."""
        if node is None:
            return None
        if node.left is None:
            # ``node`` is the minimum; its right subtree takes its place.
            return node.right
        node.left = self._delete_min(node.left)
        node.n = self._size(node.left) + self._size(node.right) + 1
        return node

    def delete(self, key):
        """Remove the entry stored under ``key``; a missing key changes nothing."""
        self._root = self._delete(self._root, key)

    def _delete(self, node, key):
        """Delete ``key`` from the subtree at ``node`` and return its new root."""
        if node is None:
            return None
        if node.key > key:
            node.left = self._delete(node.left, key)
        elif node.key < key:
            node.right = self._delete(node.right, key)
        else:
            # At most one child: the remaining child (or None) replaces it.
            if node.right is None:
                return node.left
            if node.left is None:
                return node.right
            # Two children: the successor (minimum of the right subtree)
            # replaces the deleted node, which keeps the ordering intact.
            t = node
            node = self._min(t.right)
            node.right = self._delete_min(t.right)
            node.left = t.left
        node.n = self._size(node.left) + self._size(node.right) + 1
        return node

    def delete_return(self, key):
        """Remove ``key`` and return its value, or ``None`` if it was absent."""
        self._root, target = self._delete_return(self._root, key)
        return None if target is None else target.value

    def _delete_return(self, node, key):
        """Delete ``key`` from the subtree at ``node``.

        Returns:
            A ``(new_subtree_root, removed_node)`` pair; ``removed_node`` is
            ``None`` when ``key`` was not found.
        """
        if node is None:
            return None, None
        if node.key > key:
            # Recurse with the returning variant so the removed node is
            # carried back up to the caller.
            node.left, target = self._delete_return(node.left, key)
        elif node.key < key:
            node.right, target = self._delete_return(node.right, key)
        else:
            target = node
            if node.right is None:
                return node.left, target
            if node.left is None:
                return node.right, target
            node = self._min(target.right)
            node.right = self._delete_min(target.right)
            node.left = target.left
        node.n = self._size(node.left) + self._size(node.right) + 1
        return node, target

应用

使用BST实现字典数据结构和优先队列。下面是字典的实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
class TreeIterMixin:
    """Traversal generators for classes that keep a binary tree in ``self._root``.

    Every generator yields node objects (not keys); nodes are expected to
    expose ``left`` and ``right`` attributes.  An empty tree yields nothing.
    """

    def preorder_traverse(self):
        """Yield nodes in root, left subtree, right subtree order."""
        yield from self._preorder(self._root)

    def _preorder(self, node):
        if node is not None:
            yield node
            yield from self._preorder(node.left)
            yield from self._preorder(node.right)

    def inorder_traverse(self):
        """Yield nodes in left subtree, root, right subtree order."""
        yield from self._inorder(self._root)

    def _inorder(self, node):
        if node is not None:
            # Recurse with _inorder itself so the whole subtree, not only the
            # top level, is visited in order.
            yield from self._inorder(node.left)
            yield node
            yield from self._inorder(node.right)

    def postorder_traverse(self):
        """Yield nodes in left subtree, right subtree, root order."""
        yield from self._postorder(self._root)

    def _postorder(self, node):
        if node is not None:
            yield from self._postorder(node.left)
            yield from self._postorder(node.right)
            yield node

    def depth_first_order_traverse(self):
        """Yield nodes depth-first using an explicit stack.

        The left child is pushed before the right child, so the right
        subtree of every node is visited before its left subtree.
        """
        if self._root is None:
            return
        # A plain list is a LIFO stack: append pushes, pop removes the top.
        stack = [self._root]
        while stack:
            node = stack.pop()
            if node.left is not None:
                stack.append(node.left)
            if node.right is not None:
                stack.append(node.right)
            yield node

    def breadth_first_order_traverse(self):
        """Yield nodes level by level, left to right within a level."""
        from collections import deque

        if self._root is None:
            return
        # deque gives O(1) removal from the front, i.e. a FIFO queue.
        queue = deque([self._root])
        while queue:
            node = queue.popleft()
            if node.left is not None:
                queue.append(node.left)
            if node.right is not None:
                queue.append(node.right)
            yield node

class Dict(TreeIterMixin, BinarySearchTree):
    """``dict``-like mapping stored in a binary search tree.

    Keys must be mutually comparable.  Lookups go through ``_find_node`` so
    that stored values which are falsy (``0``, ``''``, ``None`` ...) are still
    recognised as present.
    """

    def __init__(self, **kwargs):
        """Create the mapping, optionally pre-filled from keyword arguments."""
        super().__init__()
        for key, value in kwargs.items():
            self.insert(key, value)

    def _find_node(self, key):
        """Return the node holding ``key``, or ``None`` when it is absent."""
        node = self._root
        while node is not None:
            if node.key == key:
                return node
            node = node.right if node.key < key else node.left
        return None

    def clear(self):
        """Remove all entries."""
        self._root = None

    def copy(self):
        """Return a new mapping of the same class with the same entries."""
        clone = type(self)()
        # Preorder re-inserts every parent before its children, so the copy
        # ends up with the same tree shape.
        for node in self.preorder_traverse():
            clone.insert(node.key, node.value)
        return clone

    def fromkeys(self, seq, value=None):
        """Return a new mapping with the keys of ``seq``, all set to ``value``."""
        fresh = type(self)()
        for key in seq:
            fresh.insert(key, value)
        return fresh

    def get(self, k, d=None):
        """Return the value for ``k``, or ``d`` when ``k`` is absent."""
        node = self._find_node(k)
        return d if node is None else node.value

    def items(self):
        """Return ``(key, value)`` pairs in ``inorder_traverse`` order."""
        return [(node.key, node.value) for node in self.inorder_traverse()]

    def keys(self):
        """Return the keys in ``inorder_traverse`` order."""
        return [node.key for node in self.inorder_traverse()]

    def values(self):
        """Return the values in ``inorder_traverse`` order."""
        return [node.value for node in self.inorder_traverse()]

    def pop(self, k, d=None):
        """Remove ``k`` and return its value; return ``d`` if ``k`` is absent."""
        node = self._find_node(k)
        if node is None:
            return d
        value = node.value
        self.delete(k)
        return value

    def popitem(self):
        """Remove and return the ``(key, value)`` pair stored at the root.

        Raises:
            KeyError: If the mapping is empty.
        """
        if self._root is None:
            raise KeyError("dict is empty")
        key, value = self._root.key, self._root.value
        self.delete(key)
        return (key, value)

    def setdefault(self, k, d=None):
        """Return the value for ``k``; insert ``d`` first if ``k`` is absent."""
        node = self._find_node(k)
        if node is not None:
            return node.value
        self.insert(k, d)
        return d

    def update(self, dict=None, **kwargs):
        """Insert all entries from ``dict`` (anything with ``items()``) and ``kwargs``."""
        # The parameter keeps its original name for keyword callers even
        # though it shadows the builtin inside this method.
        if dict is not None:
            for key, value in dict.items():
                self.insert(key, value)
        for key, value in kwargs.items():
            self.insert(key, value)

    def __len__(self):
        return self.size()

    def __bool__(self):
        return self._root is not None

    def __contains__(self, key):
        return self._find_node(key) is not None

    def __repr__(self):
        return str(dict(self.items()))

    def __setitem__(self, key, value):
        self.insert(key, value)

    def __getitem__(self, key):
        node = self._find_node(key)
        if node is None:
            raise KeyError(key)
        return node.value

    def __delitem__(self, key):
        # Check presence first so a missing key raises KeyError regardless of
        # what value was stored.
        if self._find_node(key) is None:
            raise KeyError(key)
        self.delete(key)

优先队列的实现

1
2
3
4
5
6
7
8
9
10
class PriorityQueue(BinarySearchTree):
    """Priority queue keyed by priority, backed by the binary search tree.

    Because the tree keeps one entry per key, pushing an item with a priority
    that is already present overwrites the item stored under that priority.
    """

    def push(self, priority, item):
        """Store ``item`` under ``priority``."""
        self.insert(priority, item)

    def pop(self):
        """Remove the entry with the smallest priority.

        Returns:
            The smallest priority (the key), or ``None`` if the queue is
            empty.  NOTE(review): the stored item itself is not returned;
            confirm whether callers expect the item instead.
        """
        if self._root is None:
            return None
        lowest = self.min()
        # Remove by key rather than testing truthiness, so a priority of 0
        # (or any other falsy key) is removed as well.
        self.delete(lowest)
        return lowest

完。

完。