diff --git a/python/23_binarytree/binary_search_tree.py b/python/23_binarytree/binary_search_tree.py new file mode 100644 index 0000000..5500c5c --- /dev/null +++ b/python/23_binarytree/binary_search_tree.py @@ -0,0 +1,292 @@ +#!/usr/bin/python +# -*- coding: UTF-8 -*- + +from queue import Queue +import math + + +class TreeNode: + def __init__(self, val=None): + self.val = val + self.left = None + self.right = None + self.parent = None + + +class BinarySearchTree: + def __init__(self, val_list=[]): + self.root = None + for n in val_list: + self.insert(n) + + def insert(self, data): + """ + 插入 + :param data: + :return: + """ + assert(isinstance(data, int)) + + if self.root is None: + self.root = TreeNode(data) + else: + n = self.root + while n: + p = n + if data < n.val: + n = n.left + else: + n = n.right + + new_node = TreeNode(data) + new_node.parent = p + + if data < p.val: + p.left = new_node + else: + p.right = new_node + + return True + + def search(self, data): + """ + 搜索 + 返回bst中所有值为data的节点列表 + :param data: + :return: + """ + assert(isinstance(data, int)) + + # 所有搜索到的节点 + ret = [] + + n = self.root + while n: + if data < n.val: + n = n.left + else: + if data == n.val: + ret.append(n) + n = n.right + + return ret + + def delete(self, data): + """ + 删除 + :param data: + :return: + """ + assert (isinstance(data, int)) + + # 通过搜索得到需要删除的节点 + del_list = self.search(data) + + for n in del_list: + # 父节点为空,又不是根节点,已经不在树上,不用再删除 + if n.parent is None and n != self.root: + continue + else: + self._del(n) + + def _del(self, node): + """ + 删除 + 所删除的节点N存在以下情况: + 1. 没有子节点:直接删除N的父节点指针 + 2. 有一个子节点:将N父节点指针指向N的子节点 + 3. 有两个子节点:找到右子树的最小节点M,将值赋给N,然后删除M + :param data: + :return: + """ + # 1 + if node.left is None and node.right is None: + # 情况1和2,根节点和普通节点的处理方式不同 + if node == self.root: + self.root = None + else: + if node.val < node.parent.val: + node.parent.left = None + else: + node.parent.right = None + + node.parent = None + # 2 + elif node.left is None and node.right is not None: + if node == self.root: + self.root = node.right + self.root.parent = None + node.right = None + else: + if node.val < node.parent.val: + node.parent.left = node.right + else: + node.parent.right = node.right + + node.right.parent = node.parent + node.parent = None + node.right = None + elif node.left is not None and node.right is None: + if node == self.root: + self.root = node.left + self.root.parent = None + node.left = None + else: + if node.val < node.parent.val: + node.parent.left = node.left + else: + node.parent.right = node.left + + node.left.parent = node.parent + node.parent = None + node.left = None + # 3 + else: + min_node = node.right + # 找到右子树的最小值节点 + if min_node.left: + min_node = min_node.left + + if node.val != min_node.val: + node.val = min_node.val + self._del(min_node) + # 右子树的最小值节点与被删除节点的值相等,再次删除原节点 + else: + self._del(min_node) + self._del(node) + + def get_min(self): + """ + 返回最小值节点 + :return: + """ + if self.root is None: + return None + + n = self.root + while n.left: + n = n.left + return n.val + + def get_max(self): + """ + 返回最大值节点 + :return: + """ + if self.root is None: + return None + + n = self.root + while n.right: + n = n.right + return n.val + + def in_order(self): + """ + 中序遍历 + :return: + """ + if self.root is None: + return [] + + return self._in_order(self.root) + + def _in_order(self, node): + if node is None: + return [] + + ret = [] + n = node + ret.extend(self._in_order(n.left)) + ret.append(n.val) + ret.extend(self._in_order(n.right)) + + return ret + + def __repr__(self): + # return str(self.in_order()) + print(str(self.in_order())) + return self._draw_tree() + + def _bfs(self): + """ + bfs + 通过父子关系记录节点编号 + :return: + """ + if self.root is None: + return [] + + ret = [] + q = Queue() + # 队列[节点,编号] + q.put((self.root, 1)) + + while not q.empty(): + n = q.get() + + if n[0] is not None: + ret.append((n[0].val, n[1])) + q.put((n[0].left, n[1]*2)) + q.put((n[0].right, n[1]*2+1)) + + return ret + + def _draw_tree(self): + """ + 可视化 + :return: + """ + nodes = self._bfs() + + if not nodes: + print('This tree has no nodes.') + return + + layer_num = int(math.log(nodes[-1][1], 2)) + 1 + + prt_nums = [] + + for i in range(layer_num): + prt_nums.append([None]*2**i) + + for v, p in nodes: + row = int(math.log(p ,2)) + col = p % 2**row + prt_nums[row][col] = v + + prt_str = '' + for l in prt_nums: + prt_str += str(l)[1:-1] + '\n' + + return prt_str + + +if __name__ == '__main__': + nums = [4, 2, 5, 6, 1, 7, 3] + bst = BinarySearchTree(nums) + print(bst) + + # 插入 + bst.insert(1) + bst.insert(4) + print(bst) + + # 搜索 + for n in bst.search(2): + print(n.parent.val, n.val) + + # 删除 + bst.insert(6) + bst.insert(7) + print(bst) + bst.delete(7) + print(bst) + bst.delete(6) + print(bst) + bst.delete(4) + print(bst) + + # min max + print(bst.get_max()) + print(bst.get_min())