algo/python/23_binarytree/binary_search_tree.py
2018-11-16 12:20:46 +08:00

293 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/python
# -*- coding: UTF-8 -*-
import math
from collections import deque
from queue import Queue
class TreeNode:
    """One tree node: a value plus links to its left child, right child and parent."""

    def __init__(self, val=None):
        self.val = val
        # All links start detached; the tree code attaches them on insert.
        self.left = self.right = self.parent = None
class BinarySearchTree:
    """Binary search tree of ints that keeps duplicate values.

    Ordering rule used by ``insert`` and ``search``: a value smaller than a
    node goes into its left subtree; a value greater than or equal to it goes
    into its right subtree. Every node also records its ``parent``.
    """

    def __init__(self, val_list=None):
        """Create a tree and insert each value of ``val_list`` in order.

        :param val_list: optional iterable of ints. Defaults to ``None``
            (empty tree) rather than a shared mutable ``[]`` default.
        """
        self.root = None
        if val_list is not None:
            for n in val_list:
                self.insert(n)

    def insert(self, data):
        """Insert ``data`` as a new leaf node.

        :param data: value to insert; must be an ``int`` (checked by assert).
        :return: always ``True``.
        """
        assert isinstance(data, int)
        new_node = TreeNode(data)
        if self.root is None:
            self.root = new_node
            return True
        # Walk down to the leaf position; equal values go right.
        p = None
        n = self.root
        while n is not None:
            p = n
            n = n.left if data < n.val else n.right
        new_node.parent = p
        if data < p.val:
            p.left = new_node
        else:
            p.right = new_node
        return True

    def search(self, data):
        """Return every node whose value equals ``data``.

        Duplicates live in the right subtree, so the walk keeps going right
        after each match.

        :param data: value to look for; must be an ``int`` (checked by assert).
        :return: list of matching nodes, ordered from the top of the tree down.
        """
        assert isinstance(data, int)
        ret = []
        n = self.root
        while n is not None:
            if data < n.val:
                n = n.left
            else:
                if data == n.val:
                    ret.append(n)
                n = n.right
        return ret

    def delete(self, data):
        """Delete every node whose value equals ``data``.

        :param data: value to remove; must be an ``int`` (checked by assert).
        """
        assert isinstance(data, int)
        # Collect all matching nodes first, then unlink them one by one.
        for n in self.search(data):
            # A parentless non-root node was already unlinked while an earlier
            # duplicate was deleted (case 3 of ``_del``), so skip it.
            if n.parent is None and n is not self.root:
                continue
            self._del(n)

    def _replace_in_parent(self, node, child):
        """Put ``child`` (a node or ``None``) in the place where ``node`` hangs."""
        parent = node.parent
        if node is self.root:
            self.root = child
        # Compare identity, not values: it is exact even with duplicates.
        elif parent.left is node:
            parent.left = child
        else:
            parent.right = child
        if child is not None:
            child.parent = parent

    def _del(self, node):
        """Unlink one node from the tree.

        Cases for the node N being removed:
        1. No children: detach N from its parent (or empty the tree).
        2. One child: put that child in N's place.
        3. Two children: find the minimum node M of N's right subtree, copy
           M's value into N, then delete M instead. If M has the same value
           as N (a duplicate), M is removed and N is processed again.

        :param node: a node currently attached to this tree.
        """
        if node.left is not None and node.right is not None:
            # Case 3. Follow left links all the way down. A single step is
            # not enough when the right subtree's left spine is deeper.
            min_node = node.right
            while min_node.left is not None:
                min_node = min_node.left
            if node.val != min_node.val:
                node.val = min_node.val
                # min_node has no left child, so this is case 1 or 2.
                self._del(min_node)
            else:
                # The successor is a duplicate of N: drop it, then retry N.
                self._del(min_node)
                self._del(node)
            return
        # Cases 1 and 2: at most one child, which may be None.
        child = node.left if node.left is not None else node.right
        self._replace_in_parent(node, child)
        node.parent = None
        node.left = None
        node.right = None

    def get_min(self):
        """Return the smallest value (not the node), or ``None`` if empty."""
        if self.root is None:
            return None
        n = self.root
        while n.left is not None:
            n = n.left
        return n.val

    def get_max(self):
        """Return the largest value (not the node), or ``None`` if empty."""
        if self.root is None:
            return None
        n = self.root
        while n.right is not None:
            n = n.right
        return n.val

    def in_order(self):
        """Return all values in in-order (ascending) sequence as a list."""
        if self.root is None:
            return []
        return self._in_order(self.root)

    def _in_order(self, node):
        """Return the in-order values of the subtree rooted at ``node``.

        Iterative with an explicit stack, so degenerate (list-like) trees do
        not hit the interpreter's recursion limit.
        """
        ret = []
        stack = []
        n = node
        while stack or n is not None:
            while n is not None:
                stack.append(n)
                n = n.left
            n = stack.pop()
            ret.append(n.val)
            n = n.right
        return ret

    def __repr__(self):
        """Return the in-order value list, a newline, then the level drawing."""
        return str(self.in_order()) + '\n' + self._draw_tree()

    def _bfs(self):
        """Breadth-first traversal that numbers nodes heap-style.

        The root is number 1; the children of node ``k`` are ``2k`` (left)
        and ``2k + 1`` (right).

        :return: list of ``(value, number)`` pairs in level order, or ``[]``.
        """
        if self.root is None:
            return []
        ret = []
        queue = deque([(self.root, 1)])
        while queue:
            node, num = queue.popleft()
            if node is not None:
                ret.append((node.val, num))
                queue.append((node.left, num * 2))
                queue.append((node.right, num * 2 + 1))
        return ret

    def _draw_tree(self):
        """Render the tree one level per line.

        Level ``i`` is printed as ``2**i`` comma-separated slots, with
        ``None`` for empty positions. Memory and output therefore grow as
        ``2**depth``, so this is only practical for shallow trees.

        :return: the drawing (each line ends in ``'\\n'``), or the message
            ``'This tree has no nodes.'`` for an empty tree.
        """
        nodes = self._bfs()
        if not nodes:
            return 'This tree has no nodes.'
        # BFS yields the deepest level last; bit_length gives an exact
        # floor(log2(num)) + 1 without floating-point rounding.
        layer_num = nodes[-1][1].bit_length()
        rows = [[None] * (1 << i) for i in range(layer_num)]
        for v, num in nodes:
            row = num.bit_length() - 1
            rows[row][num - (1 << row)] = v
        return ''.join(str(r)[1:-1] + '\n' for r in rows)
if __name__ == '__main__':
    demo_values = [4, 2, 5, 6, 1, 7, 3]
    tree = BinarySearchTree(demo_values)
    print(tree)

    # Insert duplicates of values already in the tree.
    for value in (1, 4):
        tree.insert(value)
    print(tree)

    # Search: show each match next to its parent's value.
    for found in tree.search(2):
        print(found.parent.val, found.val)

    # Delete: add more duplicates first, then remove values one at a time.
    for value in (6, 7):
        tree.insert(value)
    print(tree)
    for value in (7, 6, 4):
        tree.delete(value)
        print(tree)

    # Extremes of the remaining tree.
    print(tree.get_max())
    print(tree.get_min())