commit
ddbd68d13a
292
python/23_binarytree/binary_search_tree.py
Normal file
292
python/23_binarytree/binary_search_tree.py
Normal file
@ -0,0 +1,292 @@
|
|||||||
|
#!/usr/bin/python
|
||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
|
||||||
|
from queue import Queue
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
class TreeNode:
|
||||||
|
def __init__(self, val=None):
|
||||||
|
self.val = val
|
||||||
|
self.left = None
|
||||||
|
self.right = None
|
||||||
|
self.parent = None
|
||||||
|
|
||||||
|
|
||||||
|
class BinarySearchTree:
|
||||||
|
def __init__(self, val_list=[]):
|
||||||
|
self.root = None
|
||||||
|
for n in val_list:
|
||||||
|
self.insert(n)
|
||||||
|
|
||||||
|
def insert(self, data):
|
||||||
|
"""
|
||||||
|
插入
|
||||||
|
:param data:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
assert(isinstance(data, int))
|
||||||
|
|
||||||
|
if self.root is None:
|
||||||
|
self.root = TreeNode(data)
|
||||||
|
else:
|
||||||
|
n = self.root
|
||||||
|
while n:
|
||||||
|
p = n
|
||||||
|
if data < n.val:
|
||||||
|
n = n.left
|
||||||
|
else:
|
||||||
|
n = n.right
|
||||||
|
|
||||||
|
new_node = TreeNode(data)
|
||||||
|
new_node.parent = p
|
||||||
|
|
||||||
|
if data < p.val:
|
||||||
|
p.left = new_node
|
||||||
|
else:
|
||||||
|
p.right = new_node
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def search(self, data):
|
||||||
|
"""
|
||||||
|
搜索
|
||||||
|
返回bst中所有值为data的节点列表
|
||||||
|
:param data:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
assert(isinstance(data, int))
|
||||||
|
|
||||||
|
# 所有搜索到的节点
|
||||||
|
ret = []
|
||||||
|
|
||||||
|
n = self.root
|
||||||
|
while n:
|
||||||
|
if data < n.val:
|
||||||
|
n = n.left
|
||||||
|
else:
|
||||||
|
if data == n.val:
|
||||||
|
ret.append(n)
|
||||||
|
n = n.right
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def delete(self, data):
|
||||||
|
"""
|
||||||
|
删除
|
||||||
|
:param data:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
assert (isinstance(data, int))
|
||||||
|
|
||||||
|
# 通过搜索得到需要删除的节点
|
||||||
|
del_list = self.search(data)
|
||||||
|
|
||||||
|
for n in del_list:
|
||||||
|
# 父节点为空,又不是根节点,已经不在树上,不用再删除
|
||||||
|
if n.parent is None and n != self.root:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
self._del(n)
|
||||||
|
|
||||||
|
def _del(self, node):
|
||||||
|
"""
|
||||||
|
删除
|
||||||
|
所删除的节点N存在以下情况:
|
||||||
|
1. 没有子节点:直接删除N的父节点指针
|
||||||
|
2. 有一个子节点:将N父节点指针指向N的子节点
|
||||||
|
3. 有两个子节点:找到右子树的最小节点M,将值赋给N,然后删除M
|
||||||
|
:param data:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# 1
|
||||||
|
if node.left is None and node.right is None:
|
||||||
|
# 情况1和2,根节点和普通节点的处理方式不同
|
||||||
|
if node == self.root:
|
||||||
|
self.root = None
|
||||||
|
else:
|
||||||
|
if node.val < node.parent.val:
|
||||||
|
node.parent.left = None
|
||||||
|
else:
|
||||||
|
node.parent.right = None
|
||||||
|
|
||||||
|
node.parent = None
|
||||||
|
# 2
|
||||||
|
elif node.left is None and node.right is not None:
|
||||||
|
if node == self.root:
|
||||||
|
self.root = node.right
|
||||||
|
self.root.parent = None
|
||||||
|
node.right = None
|
||||||
|
else:
|
||||||
|
if node.val < node.parent.val:
|
||||||
|
node.parent.left = node.right
|
||||||
|
else:
|
||||||
|
node.parent.right = node.right
|
||||||
|
|
||||||
|
node.right.parent = node.parent
|
||||||
|
node.parent = None
|
||||||
|
node.right = None
|
||||||
|
elif node.left is not None and node.right is None:
|
||||||
|
if node == self.root:
|
||||||
|
self.root = node.left
|
||||||
|
self.root.parent = None
|
||||||
|
node.left = None
|
||||||
|
else:
|
||||||
|
if node.val < node.parent.val:
|
||||||
|
node.parent.left = node.left
|
||||||
|
else:
|
||||||
|
node.parent.right = node.left
|
||||||
|
|
||||||
|
node.left.parent = node.parent
|
||||||
|
node.parent = None
|
||||||
|
node.left = None
|
||||||
|
# 3
|
||||||
|
else:
|
||||||
|
min_node = node.right
|
||||||
|
# 找到右子树的最小值节点
|
||||||
|
if min_node.left:
|
||||||
|
min_node = min_node.left
|
||||||
|
|
||||||
|
if node.val != min_node.val:
|
||||||
|
node.val = min_node.val
|
||||||
|
self._del(min_node)
|
||||||
|
# 右子树的最小值节点与被删除节点的值相等,再次删除原节点
|
||||||
|
else:
|
||||||
|
self._del(min_node)
|
||||||
|
self._del(node)
|
||||||
|
|
||||||
|
def get_min(self):
|
||||||
|
"""
|
||||||
|
返回最小值节点
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if self.root is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
n = self.root
|
||||||
|
while n.left:
|
||||||
|
n = n.left
|
||||||
|
return n.val
|
||||||
|
|
||||||
|
def get_max(self):
|
||||||
|
"""
|
||||||
|
返回最大值节点
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if self.root is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
n = self.root
|
||||||
|
while n.right:
|
||||||
|
n = n.right
|
||||||
|
return n.val
|
||||||
|
|
||||||
|
def in_order(self):
|
||||||
|
"""
|
||||||
|
中序遍历
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if self.root is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return self._in_order(self.root)
|
||||||
|
|
||||||
|
def _in_order(self, node):
|
||||||
|
if node is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
ret = []
|
||||||
|
n = node
|
||||||
|
ret.extend(self._in_order(n.left))
|
||||||
|
ret.append(n.val)
|
||||||
|
ret.extend(self._in_order(n.right))
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
# return str(self.in_order())
|
||||||
|
print(str(self.in_order()))
|
||||||
|
return self._draw_tree()
|
||||||
|
|
||||||
|
def _bfs(self):
|
||||||
|
"""
|
||||||
|
bfs
|
||||||
|
通过父子关系记录节点编号
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if self.root is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
ret = []
|
||||||
|
q = Queue()
|
||||||
|
# 队列[节点,编号]
|
||||||
|
q.put((self.root, 1))
|
||||||
|
|
||||||
|
while not q.empty():
|
||||||
|
n = q.get()
|
||||||
|
|
||||||
|
if n[0] is not None:
|
||||||
|
ret.append((n[0].val, n[1]))
|
||||||
|
q.put((n[0].left, n[1]*2))
|
||||||
|
q.put((n[0].right, n[1]*2+1))
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def _draw_tree(self):
|
||||||
|
"""
|
||||||
|
可视化
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
nodes = self._bfs()
|
||||||
|
|
||||||
|
if not nodes:
|
||||||
|
print('This tree has no nodes.')
|
||||||
|
return
|
||||||
|
|
||||||
|
layer_num = int(math.log(nodes[-1][1], 2)) + 1
|
||||||
|
|
||||||
|
prt_nums = []
|
||||||
|
|
||||||
|
for i in range(layer_num):
|
||||||
|
prt_nums.append([None]*2**i)
|
||||||
|
|
||||||
|
for v, p in nodes:
|
||||||
|
row = int(math.log(p ,2))
|
||||||
|
col = p % 2**row
|
||||||
|
prt_nums[row][col] = v
|
||||||
|
|
||||||
|
prt_str = ''
|
||||||
|
for l in prt_nums:
|
||||||
|
prt_str += str(l)[1:-1] + '\n'
|
||||||
|
|
||||||
|
return prt_str
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
nums = [4, 2, 5, 6, 1, 7, 3]
|
||||||
|
bst = BinarySearchTree(nums)
|
||||||
|
print(bst)
|
||||||
|
|
||||||
|
# 插入
|
||||||
|
bst.insert(1)
|
||||||
|
bst.insert(4)
|
||||||
|
print(bst)
|
||||||
|
|
||||||
|
# 搜索
|
||||||
|
for n in bst.search(2):
|
||||||
|
print(n.parent.val, n.val)
|
||||||
|
|
||||||
|
# 删除
|
||||||
|
bst.insert(6)
|
||||||
|
bst.insert(7)
|
||||||
|
print(bst)
|
||||||
|
bst.delete(7)
|
||||||
|
print(bst)
|
||||||
|
bst.delete(6)
|
||||||
|
print(bst)
|
||||||
|
bst.delete(4)
|
||||||
|
print(bst)
|
||||||
|
|
||||||
|
# min max
|
||||||
|
print(bst.get_max())
|
||||||
|
print(bst.get_min())
|
Loading…
Reference in New Issue
Block a user