algo/python/26_red_black_tree/red_black_tree.py

450 lines
12 KiB
Python
Raw Normal View History

#!/usr/bin/python
# -*- coding: UTF-8 -*-
from queue import Queue
import pygraphviz as pgv
import random
OUTPUT_PATH = 'E:/'
class TreeNode:
def __init__(self, val=None, color=None):
self.val = val
assert color in ['r', 'b']
self.color = 'red' if color == 'r' else 'black'
self.left = None
self.right = None
self.parent = None
def is_black(self):
return self.color == 'black'
def set_black(self):
self.color = 'black'
return
def set_red(self):
self.color = 'red'
class RedBlackTree:
"""
红黑树实现
参考资料
1. 算法导论
第13章 红黑树
13.3 插入 p178
13.4 删除 p183
2. 红黑树删除
https://zhuanlan.zhihu.com/p/25402654
"""
def __init__(self, val_list=None):
self.root = None
self.black_leaf = TreeNode(color='b') # 共用的黑色叶子节点
# 可用数组初始化
if type(val_list) is list:
for n in val_list:
assert type(n) is int
self.insert(n)
def search(self, val):
"""
搜索
:param val:
:return:
"""
if self.root is None:
return None
n = self.root
while n != self.black_leaf:
if val < n.val:
n = n.left
elif val > n.val:
n = n.right
else:
return n
return None
def insert(self, val):
"""
插入
:param val:
:return:
"""
assert type(val) is int
new_node = TreeNode(val, 'r') # 新插入的节点为红色
# 根节点
if self.root is None:
self.root = new_node
else:
n = self.root
while n != self.black_leaf: # 黑色叶子节点
p = n
if val < n.val:
n = n.left
elif val > n.val:
n = n.right
else:
raise KeyError('val:{} already exists') # 该值已存在,插入失败
if val < p.val:
p.left = new_node
else:
p.right = new_node
new_node.parent = p
new_node.left = new_node.right = self.black_leaf
# 插入后调整
self._insert_fixup(new_node)
def _insert_fixup(self, node):
"""
插入调整
参考资料算法导论 13.3 p178-179
:param node:
:return:
"""
n = node
while n is not self.root and not n.parent.is_black():
# 父p 叔u 祖父g
p = self.parent(n)
u = self.bro(p)
g = self.parent(p)
if not u.is_black(): # case 1
p.set_black() # case 1
u.set_black() # case 1
g.set_red() # case 1
n = g # case 1
continue
if p == g.left: # p为左结点
if n == p.right: # case 2
self.rotate_l(p) # case 2
n, p = p, n # case 2
p.set_black() # case 3
g.set_red() # case 3
self.rotate_r(g) # case 3
else: # p为右节点
if n == p.left: # case 2
self.rotate_r(p) # case 2
n, p = p, n # case 2
p.set_black() # case 3
g.set_red() # case 3
self.rotate_l(g) # case 3
# 根节点强制置黑,有两种情况根节点是红色:
# 1. 新插入时是红色
# 2. 经过case 1调整过后变红色
self.root.color = 'black'
def delete(self, val):
"""
删除
:param val:
:return:
"""
assert type(val) is int
n = self.search(val)
if n is None:
print('can not find any nodes with value: {}'.format(val))
return
self._delete_node(n)
def _delete_node(self, node):
"""
删除节点内部实现
参考资料算法导论 13.4 p183-184
实现方式有微调当n有2个子节点时将s拷贝至n转为删除s(s最多有一个子节点)
:param node:
:return:
"""
n = node
# n的子节点个数等于2
if self.children_count(n) == 2:
# 寻找n的后继s
s = n.right
while s.left != self.black_leaf:
s = s.left
n.val = s.val
# 将删除n转化为删除s
n = s
# n的子节点个数小于2
if n.left == self.black_leaf:
c = n.right
else:
c = n.left
self._transplant(n, c)
# 删除的节点是黑色,需要调整
if n.is_black():
self._delete_fixup(c)
return
def _delete_fixup(self, node):
"""
删除调整
参考资料算法导论 13.4 p185-187
:param node:
:return:
"""
n = node
while n != self.root and n.is_black():
p = self.parent(n)
b = self.bro(n)
# 左右节点对称
if p.left == n:
if not b.is_black():
b.set_black() # case 1
p.set_red() # case 1
self.rotate_l(p) # case 1
# new bro after rotate
b = self.bro(n) # case 1
if b.left.is_black() and b.right.is_black():
b.set_red() # case 2
n = p # case 2
else:
if b.right.is_black():
b.left.set_black() # case 3
b.set_red() # case 3
self.rotate_r(b) # case 3
# new bro after rotate
b = self.bro(n) # case 3
# 注意因为p可能是红或黑所以不能直接赋值颜色只能copy
b.color = p.color # case 4
p.set_black() # case 4
b.right.set_black() # case 4
self.rotate_l(p) # case 4
# trick, 调整结束跳出while
n = self.root # case 4
else:
if not b.is_black():
b.set_black() # case 1
p.set_red() # case 1
self.rotate_r(p) # case 1
# new bro after rotate
b = self.bro(n) # case 1
if b.left.is_black() and b.right.is_black():
b.set_red() # case 2
n = p # case 2
else:
if b.left.is_black():
b.right.set_black() # case 3
b.set_red() # case 3
self.rotate_l(b) # case 3
# new bro after rotate
b = self.bro(n) # case 3
# 注意因为p可能是红或黑所以不能直接赋值颜色只能copy
b.color = p.color # case 4
p.set_black() # case 4
2018-11-25 16:18:57 +08:00
b.left.set_black() # case 4
self.rotate_r(p) # case 4
# trick, 调整结束跳出while
n = self.root # case 4
# 将n设为黑色从上面while循环跳出情况有两种
# 1. n是根节点直接无视附加的黑色
# 2. n是红色的节点则染黑
n.set_black()
def _transplant(self, n1, n2):
"""
节点移植 n2 -> n1
:param n1: 原节点
:param n2: 移植节点
:return:
"""
if n1 == self.root:
if n2 != self.black_leaf:
self.root = n2
n2.parent = None
else:
self.root = None # 只有删除根节点时会进来
else:
p = self.parent(n1)
if p.left == n1:
p.left = n2
else:
p.right = n2
n2.parent = p
def rotate_l(self, node):
"""
左旋
:param node:
:return:
"""
if node is None:
return
if node.right is self.black_leaf:
return
# raise Exception('try rotate left , but the node "{}" has no right child'.format(node.val))
p = self.parent(node)
x = node
y = node.right
# node为根节点时p为None旋转后要更新根节点指向
if p is not None:
if x == p.left:
p.left = y
else:
p.right = y
else:
self.root = y
x.parent, y.parent = y, p
if y.left != self.black_leaf:
y.left.parent = x
x.right, y.left = y.left, x
def rotate_r(self, node):
"""
右旋
:param node:
:return:
"""
if node is None:
return
if node.left is self.black_leaf:
return
2018-11-25 16:18:57 +08:00
# raise Exception('try rotate right , but the node "{}" has no left child'.format(node.val))
p = self.parent(node)
x = node
y = node.left
# 同左旋
if p is not None:
if x == p.left:
p.left = y
else:
p.right = y
else:
self.root = y
x.parent, y.parent = y, p
if y.right is not None:
y.right.parent = x
x.left, y.right = y.right, x
@staticmethod
def bro(node):
"""
获取兄弟节点
:param node:
:return:
"""
if node is None or node.parent is None:
return None
else:
p = node.parent
if node == p.left:
return p.right
else:
return p.left
@staticmethod
def parent(node):
"""
获取父节点
:param node:
:return:
"""
if node is None:
return None
else:
return node.parent
def children_count(self, node):
"""
获取子节点个数
:param node:
:return:
"""
return 2 - [node.left, node.right].count(self.black_leaf)
def draw_img(self, img_name='Red_Black_Tree.png'):
"""
画图
用pygraphviz画出节点和箭头
箭头的红色和黑色分别代表左和右
:param img_name:
:return:
"""
if self.root is None:
return
tree = pgv.AGraph(directed=True, strict=True)
q = Queue()
q.put(self.root)
while not q.empty():
n = q.get()
if n != self.black_leaf: # 黑色叶子的连线由各个节点自己画
tree.add_node(n.val, color=n.color)
# 画父节点箭头
# if n.parent is not None:
# tree.add_edge(n.val, n.parent.val)
for c in [n.left, n.right]:
q.put(c)
color = 'red' if c == n.left else 'black'
if c != self.black_leaf:
tree.add_edge(n.val, c.val, color=color)
else:
tree.add_edge(n.val, 'None', color=color)
tree.graph_attr['epsilon'] = '0.01'
tree.layout('dot')
tree.draw(OUTPUT_PATH + img_name)
return True
if __name__ == '__main__':
rbt = RedBlackTree()
# insert
nums = list(range(1, 25))
# random.shuffle(nums)
for num in nums:
rbt.insert(num)
# search
search_num = 23
n = rbt.search(search_num)
if n is not None:
print(n)
else:
print('node {} not found'.format(search_num))
# delete
rbt.delete(4)
# draw image
rbt.draw_img('rbt.png')