binary search tree

This commit is contained in:
树哥 2018-11-16 12:20:46 +08:00
parent 5ebfdc90d8
commit 73f0db5b5e

View File

@ -0,0 +1,292 @@
#!/usr/bin/python
# -*- coding: UTF-8 -*-
from queue import Queue
import math
class TreeNode:
def __init__(self, val=None):
self.val = val
self.left = None
self.right = None
self.parent = None
class BinarySearchTree:
def __init__(self, val_list=[]):
self.root = None
for n in val_list:
self.insert(n)
def insert(self, data):
"""
插入
:param data:
:return:
"""
assert(isinstance(data, int))
if self.root is None:
self.root = TreeNode(data)
else:
n = self.root
while n:
p = n
if data < n.val:
n = n.left
else:
n = n.right
new_node = TreeNode(data)
new_node.parent = p
if data < p.val:
p.left = new_node
else:
p.right = new_node
return True
def search(self, data):
"""
搜索
返回bst中所有值为data的节点列表
:param data:
:return:
"""
assert(isinstance(data, int))
# 所有搜索到的节点
ret = []
n = self.root
while n:
if data < n.val:
n = n.left
else:
if data == n.val:
ret.append(n)
n = n.right
return ret
def delete(self, data):
"""
删除
:param data:
:return:
"""
assert (isinstance(data, int))
# 通过搜索得到需要删除的节点
del_list = self.search(data)
for n in del_list:
# 父节点为空,又不是根节点,已经不在树上,不用再删除
if n.parent is None and n != self.root:
continue
else:
self._del(n)
def _del(self, node):
"""
删除
所删除的节点N存在以下情况
1. 没有子节点直接删除N的父节点指针
2. 有一个子节点将N父节点指针指向N的子节点
3. 有两个子节点找到右子树的最小节点M将值赋给N然后删除M
:param data:
:return:
"""
# 1
if node.left is None and node.right is None:
# 情况1和2根节点和普通节点的处理方式不同
if node == self.root:
self.root = None
else:
if node.val < node.parent.val:
node.parent.left = None
else:
node.parent.right = None
node.parent = None
# 2
elif node.left is None and node.right is not None:
if node == self.root:
self.root = node.right
self.root.parent = None
node.right = None
else:
if node.val < node.parent.val:
node.parent.left = node.right
else:
node.parent.right = node.right
node.right.parent = node.parent
node.parent = None
node.right = None
elif node.left is not None and node.right is None:
if node == self.root:
self.root = node.left
self.root.parent = None
node.left = None
else:
if node.val < node.parent.val:
node.parent.left = node.left
else:
node.parent.right = node.left
node.left.parent = node.parent
node.parent = None
node.left = None
# 3
else:
min_node = node.right
# 找到右子树的最小值节点
if min_node.left:
min_node = min_node.left
if node.val != min_node.val:
node.val = min_node.val
self._del(min_node)
# 右子树的最小值节点与被删除节点的值相等,再次删除原节点
else:
self._del(min_node)
self._del(node)
def get_min(self):
"""
返回最小值节点
:return:
"""
if self.root is None:
return None
n = self.root
while n.left:
n = n.left
return n.val
def get_max(self):
"""
返回最大值节点
:return:
"""
if self.root is None:
return None
n = self.root
while n.right:
n = n.right
return n.val
def in_order(self):
"""
中序遍历
:return:
"""
if self.root is None:
return []
return self._in_order(self.root)
def _in_order(self, node):
if node is None:
return []
ret = []
n = node
ret.extend(self._in_order(n.left))
ret.append(n.val)
ret.extend(self._in_order(n.right))
return ret
def __repr__(self):
# return str(self.in_order())
print(str(self.in_order()))
return self._draw_tree()
def _bfs(self):
"""
bfs
通过父子关系记录节点编号
:return:
"""
if self.root is None:
return []
ret = []
q = Queue()
# 队列[节点,编号]
q.put((self.root, 1))
while not q.empty():
n = q.get()
if n[0] is not None:
ret.append((n[0].val, n[1]))
q.put((n[0].left, n[1]*2))
q.put((n[0].right, n[1]*2+1))
return ret
def _draw_tree(self):
"""
可视化
:return:
"""
nodes = self._bfs()
if not nodes:
print('This tree has no nodes.')
return
layer_num = int(math.log(nodes[-1][1], 2)) + 1
prt_nums = []
for i in range(layer_num):
prt_nums.append([None]*2**i)
for v, p in nodes:
row = int(math.log(p ,2))
col = p % 2**row
prt_nums[row][col] = v
prt_str = ''
for l in prt_nums:
prt_str += str(l)[1:-1] + '\n'
return prt_str
if __name__ == '__main__':
nums = [4, 2, 5, 6, 1, 7, 3]
bst = BinarySearchTree(nums)
print(bst)
# 插入
bst.insert(1)
bst.insert(4)
print(bst)
# 搜索
for n in bst.search(2):
print(n.parent.val, n.val)
# 删除
bst.insert(6)
bst.insert(7)
print(bst)
bst.delete(7)
print(bst)
bst.delete(6)
print(bst)
bst.delete(4)
print(bst)
# min max
print(bst.get_max())
print(bst.get_min())