trie in python
This commit is contained in:
parent
17bbd628fa
commit
e1e0673c89
170
python/35_trie/trie_.py
Normal file
170
python/35_trie/trie_.py
Normal file
@ -0,0 +1,170 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: UTF-8 -*-
|
||||
|
||||
from queue import Queue
|
||||
import pygraphviz as pgv
|
||||
|
||||
OUTPUT_PATH = 'E:/'
|
||||
|
||||
|
||||
class Node:
|
||||
def __init__(self, c):
|
||||
self.data = c
|
||||
self.is_ending_char = False
|
||||
# 使用有序数组,降低空间消耗,支持更多字符
|
||||
self.children = []
|
||||
|
||||
def insert_child(self, c):
|
||||
"""
|
||||
插入一个子节点
|
||||
:param c:
|
||||
:return:
|
||||
"""
|
||||
v = ord(c)
|
||||
idx = self._find_insert_idx(v)
|
||||
length = len(self.children)
|
||||
|
||||
node = Node(c)
|
||||
if idx == length:
|
||||
self.children.append(node)
|
||||
else:
|
||||
self.children.append(None)
|
||||
for i in range(length, idx, -1):
|
||||
self.children[i] = self.children[i-1]
|
||||
self.children[idx] = node
|
||||
|
||||
def get_child(self, c):
|
||||
"""
|
||||
搜索子节点并返回
|
||||
:param c:
|
||||
:return:
|
||||
"""
|
||||
start = 0
|
||||
end = len(self.children) - 1
|
||||
v = ord(c)
|
||||
|
||||
while start <= end:
|
||||
mid = (start + end)//2
|
||||
if v == ord(self.children[mid].data):
|
||||
return self.children[mid]
|
||||
elif v < ord(self.children[mid].data):
|
||||
end = mid - 1
|
||||
else:
|
||||
start = mid + 1
|
||||
# 找不到返回None
|
||||
return None
|
||||
|
||||
def _find_insert_idx(self, v):
|
||||
"""
|
||||
二分查找,找到有序数组的插入位置
|
||||
:param v:
|
||||
:return:
|
||||
"""
|
||||
start = 0
|
||||
end = len(self.children) - 1
|
||||
|
||||
while start <= end:
|
||||
mid = (start + end)//2
|
||||
if v < ord(self.children[mid].data):
|
||||
end = mid - 1
|
||||
else:
|
||||
if mid + 1 == len(self.children) or v < ord(self.children[mid+1].data):
|
||||
return mid + 1
|
||||
else:
|
||||
start = mid + 1
|
||||
# v < self.children[0]
|
||||
return 0
|
||||
|
||||
def __repr__(self):
|
||||
return 'node value: {}'.format(self.data) + '\n' \
|
||||
+ 'children:{}'.format([n.data for n in self.children])
|
||||
|
||||
|
||||
class Trie:
|
||||
def __init__(self):
|
||||
self.root = Node(None)
|
||||
|
||||
def gen_tree(self, string_list):
|
||||
"""
|
||||
创建trie树
|
||||
|
||||
1. 遍历每个字符串的字符,从根节点开始,如果没有对应子节点,则创建
|
||||
2. 每一个串的末尾节点标注为红色(is_ending_char)
|
||||
:param string_list:
|
||||
:return:
|
||||
"""
|
||||
for string in string_list:
|
||||
n = self.root
|
||||
for c in string:
|
||||
if n.get_child(c) is None:
|
||||
n.insert_child(c)
|
||||
n = n.get_child(c)
|
||||
n.is_ending_char = True
|
||||
|
||||
def search(self, pattern):
|
||||
"""
|
||||
搜索
|
||||
|
||||
1. 遍历模式串的字符,从根节点开始搜索,如果途中子节点不存在,返回False
|
||||
2. 遍历完模式串,则说明模式串存在,再检查树中最后一个节点是否为红色,是
|
||||
则返回True,否则False
|
||||
:param pattern:
|
||||
:return:
|
||||
"""
|
||||
assert type(pattern) is str and len(pattern) > 0
|
||||
|
||||
n = self.root
|
||||
for c in pattern:
|
||||
if n.get_child(c) is None:
|
||||
return False
|
||||
n = n.get_child(c)
|
||||
|
||||
return True if n.is_ending_char is True else False
|
||||
|
||||
def draw_img(self, img_name='Trie.png'):
|
||||
"""
|
||||
画出trie树
|
||||
:param img_name:
|
||||
:return:
|
||||
"""
|
||||
if self.root is None:
|
||||
return
|
||||
|
||||
tree = pgv.AGraph('graph foo {}', strict=False, directed=False)
|
||||
|
||||
# root
|
||||
nid = 0
|
||||
color = 'black'
|
||||
tree.add_node(nid, color=color, label='None')
|
||||
|
||||
q = Queue()
|
||||
q.put((self.root, nid))
|
||||
while not q.empty():
|
||||
n, pid = q.get()
|
||||
for c in n.children:
|
||||
nid += 1
|
||||
q.put((c, nid))
|
||||
color = 'red' if c.is_ending_char is True else 'black'
|
||||
tree.add_node(nid, color=color, label=c.data)
|
||||
tree.add_edge(pid, nid)
|
||||
|
||||
tree.graph_attr['epsilon'] = '0.01'
|
||||
tree.layout('dot')
|
||||
tree.draw(OUTPUT_PATH + img_name)
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
string_list = ['abc', 'abd', 'abcc', 'accd', 'acml', 'P@trick', 'data', 'structure', 'algorithm']
|
||||
|
||||
print('--- gen trie ---')
|
||||
print(string_list)
|
||||
trie = Trie()
|
||||
trie.gen_tree(string_list)
|
||||
# trie.draw_img()
|
||||
|
||||
print('\n')
|
||||
print('--- search result ---')
|
||||
search_string = ['a', 'ab', 'abc', 'abcc', 'abe', 'P@trick', 'P@tric', 'Patrick']
|
||||
for ss in search_string:
|
||||
print('[pattern]: {}'.format(ss), '[result]: {}'.format(trie.search(ss)))
|
Loading…
Reference in New Issue
Block a user