This commit is contained in:
commit
aee7f91373
@ -79,6 +79,8 @@ void test_checkCircle() {
|
||||
SinglyLinkedNode* node = malloc(sizeof(SinglyLinkedNode));
|
||||
node->data = i;
|
||||
current->next = node;
|
||||
//reset current node
|
||||
current = node;
|
||||
}
|
||||
current->next = h;
|
||||
|
||||
|
@ -1,133 +1,70 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
#
|
||||
# 1) Insertion, deletion and random access of array
|
||||
# 2) Assumes int for element type
|
||||
#
|
||||
# Author: Wenru
|
||||
#
|
||||
|
||||
|
||||
class MyArray:
|
||||
"""A simple wrapper around List.
|
||||
You cannot have -1 in the array.
|
||||
"""
|
||||
|
||||
def __init__(self, capacity: int):
|
||||
|
||||
self._data = []
|
||||
self._count = 0
|
||||
self._capacity = capacity
|
||||
|
||||
def __getitem__(self, position: int) -> int:
|
||||
|
||||
"""Support for subscript.
|
||||
Perhaps better than the find() method below.
|
||||
"""
|
||||
def __getitem__(self, position: int) -> object:
|
||||
return self._data[position]
|
||||
|
||||
def find(self, index: int) -> Optional[int]:
|
||||
def __setitem__(self, index: int, value: object):
|
||||
self._data[index] = value
|
||||
|
||||
if index >= self._count or index <= -self._count:
|
||||
def __len__(self) -> int:
|
||||
return len(self._data)
|
||||
|
||||
def __iter__(self):
|
||||
for item in self._data:
|
||||
yield item
|
||||
|
||||
def find(self, index: int) -> object:
|
||||
try:
|
||||
return self._data[index]
|
||||
except IndexError:
|
||||
return None
|
||||
return self._data[index]
|
||||
|
||||
def delete(self, index: int) -> bool:
|
||||
|
||||
if index >= self._count or index <= -self._count:
|
||||
try:
|
||||
self._data.pop(index)
|
||||
return True
|
||||
except IndexError:
|
||||
return False
|
||||
|
||||
self._data[index:-1] = self._data[index + 1:]
|
||||
self._count -= 1
|
||||
# 真正将数据删除并覆盖原来的数据 ,这个需要增加
|
||||
self._data = self._data[0:self._count]
|
||||
print('delete function', self._data)
|
||||
return True
|
||||
|
||||
def insert(self, index: int, value: int) -> bool:
|
||||
|
||||
# if index >= self._count or index <= -self._count: return False
|
||||
if self._capacity == self._count:
|
||||
if len(self) >= self._capacity:
|
||||
return False
|
||||
# 如果还有空间,那么插入位置大于当前的元素个数,可以插入最后的位置
|
||||
if index >= self._count:
|
||||
self._data.append(value)
|
||||
# 同上,如果位置小于0 可以插入第0个位置.
|
||||
if index < 0:
|
||||
print(index)
|
||||
self._data.insert(0, value)
|
||||
|
||||
self._count += 1
|
||||
return True
|
||||
|
||||
def insert_v2(self, index: int, value: int) -> bool:
|
||||
"""
|
||||
支持任意位置插入
|
||||
:param index:
|
||||
:param value:
|
||||
:return:
|
||||
"""
|
||||
# 数组空间已满
|
||||
if self._capacity == self._count:
|
||||
return False
|
||||
|
||||
# 插入位置大于当前的元素个数,可以插入最后的位置
|
||||
if index >= self._count:
|
||||
self._data.append(value)
|
||||
elif index < 0:
|
||||
# 位置小于 0 可以插入第 0 个位置
|
||||
self._data.insert(0, value)
|
||||
else:
|
||||
# 挪动 index 至 _count 位到 index+1 至 _count+1 位
|
||||
# 插入第 index
|
||||
self._data[index+1:self._count+1] = self._data[index:self._count]
|
||||
self._data[index] = value
|
||||
|
||||
self._count += 1
|
||||
return True
|
||||
|
||||
def insert_to_tail(self, value: int) -> bool:
|
||||
|
||||
if self._count == self._capacity:
|
||||
return False
|
||||
if self._count == len(self._data):
|
||||
self._data.append(value)
|
||||
else:
|
||||
self._data[self._count] = value
|
||||
self._count += 1
|
||||
return True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
||||
return " ".join(str(num) for num in self._data[:self._count])
|
||||
return self._data.insert(index, value)
|
||||
|
||||
def print_all(self):
|
||||
|
||||
for num in self._data[:self._count]:
|
||||
print(f"{num}", end=" ")
|
||||
print("\n", flush=True)
|
||||
for item in self:
|
||||
print(item)
|
||||
|
||||
|
||||
def test_myarray():
|
||||
array_a = MyArray(6)
|
||||
for num in range(6):
|
||||
array_a.insert_to_tail(num)
|
||||
assert array_a.find(0) == 0
|
||||
assert array_a[0] == 0
|
||||
array_a.delete(0)
|
||||
assert array_a[0] == 1
|
||||
array = MyArray(5)
|
||||
array.insert(0, 3)
|
||||
array.insert(0, 4)
|
||||
array.insert(1, 5)
|
||||
array.insert(3, 9)
|
||||
array.insert(3, 10)
|
||||
assert array.insert(0, 100) is False
|
||||
assert len(array) == 5
|
||||
assert array.find(1) == 5
|
||||
assert array.delete(4) is True
|
||||
array.print_all()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
a = MyArray(6)
|
||||
for i in range(6):
|
||||
a.insert_to_tail(i)
|
||||
a.delete(2)
|
||||
print(a)
|
||||
a.insert_to_tail(7)
|
||||
print(a)
|
||||
print('origin', a)
|
||||
a.delete(4)
|
||||
print('delete ', a)
|
||||
|
||||
a.insert(100, 10000)
|
||||
print(a)
|
||||
test_myarray()
|
||||
|
@ -23,7 +23,7 @@ def counting_sort(a: List[int]):
|
||||
a_sorted[index] = num
|
||||
counts[num] -= 1
|
||||
|
||||
a = a_sorted
|
||||
a[:] = a_sorted
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -37,4 +37,4 @@ if __name__ == "__main__":
|
||||
|
||||
a3 = [4, 5, 0, 9, 3, 3, 1, 9, 8, 7]
|
||||
counting_sort(a3)
|
||||
print(a3)
|
||||
print(a3)
|
||||
|
51
python/38_divide_and_conquer/merge_sort_counting.py
Normal file
51
python/38_divide_and_conquer/merge_sort_counting.py
Normal file
@ -0,0 +1,51 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: UTF-8 -*-
|
||||
|
||||
inversion_num = 0
|
||||
|
||||
|
||||
def merge_sort_counting(nums, start, end):
|
||||
if start >= end:
|
||||
return
|
||||
|
||||
mid = (start + end)//2
|
||||
merge_sort_counting(nums, start, mid)
|
||||
merge_sort_counting(nums, mid+1, end)
|
||||
merge(nums, start, mid, end)
|
||||
|
||||
|
||||
def merge(nums, start, mid, end):
|
||||
global inversion_num
|
||||
i = start
|
||||
j = mid+1
|
||||
tmp = []
|
||||
while i <= mid and j <= end:
|
||||
if nums[i] <= nums[j]:
|
||||
inversion_num += j - mid - 1
|
||||
tmp.append(nums[i])
|
||||
i += 1
|
||||
else:
|
||||
tmp.append(nums[j])
|
||||
j += 1
|
||||
|
||||
while i <= mid:
|
||||
# 这时nums[i]的逆序数是整个nums[mid+1: end+1]的长度
|
||||
inversion_num += end - mid
|
||||
tmp.append(nums[i])
|
||||
i += 1
|
||||
|
||||
while j <= end:
|
||||
tmp.append(nums[j])
|
||||
j += 1
|
||||
|
||||
nums[start: end+1] = tmp
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('--- count inversion number using merge sort ---')
|
||||
# nums = [5, 0, 4, 2, 3, 1, 6, 8, 7]
|
||||
nums = [5, 0, 4, 2, 3, 1, 3, 3, 3, 6, 8, 7]
|
||||
print('nums : {}'.format(nums))
|
||||
merge_sort_counting(nums, 0, len(nums)-1)
|
||||
print('sorted: {}'.format(nums))
|
||||
print('inversion number: {}'.format(inversion_num))
|
59
python/39_back_track/01_bag.py
Normal file
59
python/39_back_track/01_bag.py
Normal file
@ -0,0 +1,59 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: UTF-8 -*-
|
||||
|
||||
from typing import List
|
||||
|
||||
# 背包选取的物品列表
|
||||
picks = []
|
||||
picks_with_max_value = []
|
||||
|
||||
|
||||
def bag(capacity: int, cur_weight: int, items_info: List, pick_idx: int):
|
||||
"""
|
||||
回溯法解01背包,穷举
|
||||
:param capacity: 背包容量
|
||||
:param cur_weight: 背包当前重量
|
||||
:param items_info: 物品的重量和价值信息
|
||||
:param pick_idx: 当前物品的索引
|
||||
:return:
|
||||
"""
|
||||
# 考察完所有物品,或者在中途已经装满
|
||||
if pick_idx >= len(items_info) or cur_weight == capacity:
|
||||
global picks_with_max_value
|
||||
if get_value(items_info, picks) > \
|
||||
get_value(items_info, picks_with_max_value):
|
||||
picks_with_max_value = picks.copy()
|
||||
else:
|
||||
item_weight = items_info[pick_idx][0]
|
||||
if cur_weight + item_weight <= capacity: # 选
|
||||
picks[pick_idx] = 1
|
||||
bag(capacity, cur_weight + item_weight, items_info, pick_idx + 1)
|
||||
|
||||
picks[pick_idx] = 0 # 不选
|
||||
bag(capacity, cur_weight, items_info, pick_idx + 1)
|
||||
|
||||
|
||||
def get_value(items_info: List, pick_items: List):
|
||||
values = [_[1] for _ in items_info]
|
||||
return sum([a*b for a, b in zip(values, pick_items)])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# [(weight, value), ...]
|
||||
items_info = [(3, 5), (2, 2), (1, 4), (1, 2), (4, 10)]
|
||||
capacity = 8
|
||||
|
||||
print('--- items info ---')
|
||||
print(items_info)
|
||||
|
||||
print('\n--- capacity ---')
|
||||
print(capacity)
|
||||
|
||||
picks = [0] * len(items_info)
|
||||
bag(capacity, 0, items_info, 0)
|
||||
|
||||
print('\n--- picks ---')
|
||||
print(picks_with_max_value)
|
||||
|
||||
print('\n--- value ---')
|
||||
print(get_value(items_info, picks_with_max_value))
|
57
python/39_back_track/eight_queens.py
Normal file
57
python/39_back_track/eight_queens.py
Normal file
@ -0,0 +1,57 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: UTF-8 -*-
|
||||
|
||||
# 棋盘尺寸
|
||||
BOARD_SIZE = 8
|
||||
|
||||
solution_count = 0
|
||||
queen_list = [0] * BOARD_SIZE
|
||||
|
||||
|
||||
def eight_queens(cur_column: int):
|
||||
"""
|
||||
输出所有符合要求的八皇后序列
|
||||
用一个长度为8的数组代表棋盘的列,数组的数字则为当前列上皇后所在的行数
|
||||
:return:
|
||||
"""
|
||||
if cur_column >= BOARD_SIZE:
|
||||
global solution_count
|
||||
solution_count += 1
|
||||
# 解
|
||||
print(queen_list)
|
||||
else:
|
||||
for i in range(BOARD_SIZE):
|
||||
if is_valid_pos(cur_column, i):
|
||||
queen_list[cur_column] = i
|
||||
eight_queens(cur_column + 1)
|
||||
|
||||
|
||||
def is_valid_pos(cur_column: int, pos: int) -> bool:
|
||||
"""
|
||||
因为采取的是每列放置1个皇后的做法
|
||||
所以检查的时候不必检查列的合法性,只需要检查行和对角
|
||||
1. 行:检查数组在下标为cur_column之前的元素是否已存在pos
|
||||
2. 对角:检查数组在下标为cur_column之前的元素,其行的间距pos - QUEEN_LIST[i]
|
||||
和列的间距cur_column - i是否一致
|
||||
:param cur_column:
|
||||
:param pos:
|
||||
:return:
|
||||
"""
|
||||
i = 0
|
||||
while i < cur_column:
|
||||
# 同行
|
||||
if queen_list[i] == pos:
|
||||
return False
|
||||
# 对角线
|
||||
if cur_column - i == abs(pos - queen_list[i]):
|
||||
return False
|
||||
i += 1
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('--- eight queens sequence ---')
|
||||
eight_queens(0)
|
||||
|
||||
print('\n--- solution count ---')
|
||||
print(solution_count)
|
42
python/39_back_track/permutations.py
Normal file
42
python/39_back_track/permutations.py
Normal file
@ -0,0 +1,42 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: UTF-8 -*-
|
||||
|
||||
from typing import List
|
||||
|
||||
permutations_list = [] # 全局变量,用于记录每个输出
|
||||
|
||||
|
||||
def permutations(nums: List, n: int, pick_count: int):
|
||||
"""
|
||||
从nums选取n个数的全排列
|
||||
|
||||
回溯法,用一个栈记录当前路径信息
|
||||
当满足n==0时,说明栈中的数已足够,输出并终止遍历
|
||||
:param nums:
|
||||
:param n:
|
||||
:param pick_count:
|
||||
:return:
|
||||
"""
|
||||
if n == 0:
|
||||
print(permutations_list)
|
||||
else:
|
||||
for i in range(len(nums) - pick_count):
|
||||
permutations_list[pick_count] = nums[i]
|
||||
nums[i], nums[len(nums) - pick_count - 1] = nums[len(nums) - pick_count - 1], nums[i]
|
||||
permutations(nums, n-1, pick_count+1)
|
||||
nums[i], nums[len(nums) - pick_count - 1] = nums[len(nums) - pick_count - 1], nums[i]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
nums = [1, 2, 3, 4]
|
||||
n = 3
|
||||
print('--- list ---')
|
||||
print(nums)
|
||||
|
||||
print('\n--- pick num ---')
|
||||
print(n)
|
||||
|
||||
print('\n--- permutation list ---')
|
||||
permutations_list = [0] * n
|
||||
permutations(nums, n, 0)
|
||||
|
35
python/39_back_track/regex.py
Normal file
35
python/39_back_track/regex.py
Normal file
@ -0,0 +1,35 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: UTF-8 -*-
|
||||
|
||||
is_match = False
|
||||
|
||||
|
||||
def rmatch(r_idx: int, m_idx: int, regex: str, main: str):
|
||||
global is_match
|
||||
if is_match:
|
||||
return
|
||||
|
||||
if r_idx >= len(regex): # 正则串全部匹配好了
|
||||
is_match = True
|
||||
return
|
||||
|
||||
if m_idx >= len(main) and r_idx < len(regex): # 正则串没匹配完,但是主串已经没得匹配了
|
||||
is_match = False
|
||||
return
|
||||
|
||||
if regex[r_idx] == '*': # * 匹配1个或多个任意字符,递归搜索每一种情况
|
||||
for i in range(m_idx, len(main)):
|
||||
rmatch(r_idx+1, i+1, regex, main)
|
||||
elif regex[r_idx] == '?': # ? 匹配0个或1个任意字符,两种情况
|
||||
rmatch(r_idx+1, m_idx+1, regex, main)
|
||||
rmatch(r_idx+1, m_idx, regex, main)
|
||||
else: # 非特殊字符需要精确匹配
|
||||
if regex[r_idx] == main[m_idx]:
|
||||
rmatch(r_idx+1, m_idx+1, regex, main)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
regex = 'ab*eee?d'
|
||||
main = 'abcdsadfkjlekjoiwjiojieeecd'
|
||||
rmatch(0, 0, regex, main)
|
||||
print(is_match)
|
25
python/39_backtracking/backtracking.py
Normal file
25
python/39_backtracking/backtracking.py
Normal file
@ -0,0 +1,25 @@
|
||||
"""
|
||||
Author: Wenru Dong
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
def eight_queens() -> None:
|
||||
solutions = []
|
||||
|
||||
def backtracking(queens_at_column: List[int], index_sums: List[int], index_diffs: List[int]) -> None:
|
||||
row = len(queens_at_column)
|
||||
if row == 8:
|
||||
solutions.append(queens_at_column)
|
||||
return
|
||||
for col in range(8):
|
||||
if col in queens_at_column or row + col in index_sums or row - col in index_diffs: continue
|
||||
backtracking(queens_at_column + [col], index_sums + [row + col], index_diffs + [row - col])
|
||||
|
||||
backtracking([], [], [])
|
||||
print(*(" " + " ".join("*" * i + "Q" + "*" * (8 - i - 1) + "\n" for i in solution) for solution in solutions), sep="\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
eight_queens()
|
66
python/40_dynamic_programming/01_bag.py
Normal file
66
python/40_dynamic_programming/01_bag.py
Normal file
@ -0,0 +1,66 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: UTF-8 -*-
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
def bag(items_info: List[int], capacity: int) -> int:
|
||||
"""
|
||||
固定容量的背包,计算能装进背包的物品组合的最大重量
|
||||
|
||||
:param items_info: 每个物品的重量
|
||||
:param capacity: 背包容量
|
||||
:return: 最大装载重量
|
||||
"""
|
||||
n = len(items_info)
|
||||
memo = [[-1]*(capacity+1) for i in range(n)]
|
||||
memo[0][0] = 1
|
||||
if items_info[0] <= capacity:
|
||||
memo[0][items_info[0]] = 1
|
||||
|
||||
for i in range(1, n):
|
||||
for cur_weight in range(capacity+1):
|
||||
if memo[i-1][cur_weight] != -1:
|
||||
memo[i][cur_weight] = memo[i-1][cur_weight] # 不选
|
||||
if cur_weight + items_info[i] <= capacity: # 选
|
||||
memo[i][cur_weight + items_info[i]] = 1
|
||||
|
||||
for w in range(capacity, -1, -1):
|
||||
if memo[-1][w] != -1:
|
||||
return w
|
||||
|
||||
|
||||
def bag_with_max_value(items_info: List[Tuple[int, int]], capacity: int) -> int:
|
||||
"""
|
||||
固定容量的背包,计算能装进背包的物品组合的最大价值
|
||||
|
||||
:param items_info: 物品的重量和价值
|
||||
:param capacity: 背包容量
|
||||
:return: 最大装载价值
|
||||
"""
|
||||
n = len(items_info)
|
||||
memo = [[-1]*(capacity+1) for i in range(n)]
|
||||
memo[0][0] = 0
|
||||
if items_info[0][0] <= capacity:
|
||||
memo[0][items_info[0][0]] = items_info[0][1]
|
||||
|
||||
for i in range(1, n):
|
||||
for cur_weight in range(capacity+1):
|
||||
if memo[i-1][cur_weight] != -1:
|
||||
memo[i][cur_weight] = memo[i-1][cur_weight]
|
||||
if cur_weight + items_info[i][0] <= capacity:
|
||||
memo[i][cur_weight + items_info[i][0]] = max(memo[i][cur_weight + items_info[i][0]],
|
||||
memo[i-1][cur_weight] + items_info[i][1])
|
||||
return max(memo[-1])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# [weight, ...]
|
||||
items_info = [2, 2, 4, 6, 3]
|
||||
capacity = 9
|
||||
print(bag(items_info, capacity))
|
||||
|
||||
# [(weight, value), ...]
|
||||
items_info = [(3, 5), (2, 2), (1, 4), (1, 2), (4, 10)]
|
||||
capacity = 8
|
||||
print(bag_with_max_value(items_info, capacity))
|
30
python/40_dynamic_programming/knapsack.py
Normal file
30
python/40_dynamic_programming/knapsack.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""
|
||||
Author: Wenru Dong
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
def knapsack01(weights: List[int], values: List[int], capacity: int) -> int:
|
||||
# Denote the state as (i, c), where i is the stage number,
|
||||
# and c is the capacity available. Denote f(i, c) to be the
|
||||
# maximum value when the capacity available is c, and Item 0
|
||||
# to Item i-1 are to be packed.
|
||||
# The goal is to find f(n-1, W), where W is the total capacity.
|
||||
# Then the DP functional equation is:
|
||||
# f(i, c) = max(xᵢvᵢ + f(i-1, c-xᵢwᵢ)), xᵢ ∈ D, i ≥ 0,
|
||||
# f(-1, c) = 0, 0 ≤ c ≤ W,
|
||||
# where
|
||||
# / {0}, if wᵢ > c
|
||||
# D = D(i, c) =
|
||||
# \ {0, 1}, if wᵢ ≤ c
|
||||
|
||||
prev = [0] * (capacity + 1)
|
||||
for w, v in zip(weights, values):
|
||||
prev = [c >= w and max(prev[c], prev[c-w] + v) for c in range(capacity + 1)]
|
||||
return prev[-1]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# To find the maximum weight that can be packed,
|
||||
# set values equal to the weights
|
||||
print(knapsack01([2, 2, 4, 6, 3], [2, 2, 4, 6, 3], 9))
|
64
python/40_dynamic_programming/yh_triangle.py
Normal file
64
python/40_dynamic_programming/yh_triangle.py
Normal file
@ -0,0 +1,64 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: UTF-8 -*-
|
||||
|
||||
from typing import List
|
||||
|
||||
Layer_nums = List[int]
|
||||
|
||||
|
||||
def yh_triangle(nums: List[Layer_nums]) -> int:
|
||||
"""
|
||||
从根节点开始向下走,过程中经过的节点,只需存储经过它时最小的路径和
|
||||
:param nums:
|
||||
:return:
|
||||
"""
|
||||
assert len(nums) > 0
|
||||
n = len(nums) # 层数
|
||||
memo = [[0]*n for i in range(n)]
|
||||
memo[0][0] = nums[0][0]
|
||||
|
||||
for i in range(1, n):
|
||||
for j in range(i+1):
|
||||
# 每一层首尾两个数字,只有一条路径可以到达
|
||||
if j == 0:
|
||||
memo[i][j] = memo[i-1][j] + nums[i][j]
|
||||
elif j == i:
|
||||
memo[i][j] = memo[i-1][j-1] + nums[i][j]
|
||||
else:
|
||||
memo[i][j] = min(memo[i-1][j-1] + nums[i][j], memo[i-1][j] + nums[i][j])
|
||||
return min(memo[n-1])
|
||||
|
||||
|
||||
def yh_triangle_space_optimization(nums: List[Layer_nums]) -> int:
|
||||
assert len(nums) > 0
|
||||
n = len(nums)
|
||||
memo = [0] * n
|
||||
memo[0] = nums[0][0]
|
||||
|
||||
for i in range(1, n):
|
||||
for j in range(i, -1, -1):
|
||||
if j == i:
|
||||
memo[j] = memo[j-1] + nums[i][j]
|
||||
elif j == 0:
|
||||
memo[j] = memo[j] + nums[i][j]
|
||||
else:
|
||||
memo[j] = min(memo[j-1] + nums[i][j], memo[j] + nums[i][j])
|
||||
return min(memo)
|
||||
|
||||
|
||||
def yh_triangle_bottom_up(nums: List[Layer_nums]) -> int:
|
||||
assert len(nums) > 0
|
||||
n = len(nums)
|
||||
memo = nums[-1].copy()
|
||||
|
||||
for i in range(n-1, 0, -1):
|
||||
for j in range(i):
|
||||
memo[j] = min(memo[j] + nums[i-1][j], memo[j+1] + nums[i-1][j])
|
||||
return memo[0]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
nums = [[3], [2, 6], [5, 4, 2], [6, 0, 3, 2]]
|
||||
print(yh_triangle(nums))
|
||||
print(yh_triangle_space_optimization(nums))
|
||||
print(yh_triangle_bottom_up(nums))
|
45
python/41_dynamic_programming/coins_problem.py
Normal file
45
python/41_dynamic_programming/coins_problem.py
Normal file
@ -0,0 +1,45 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: UTF-8 -*-
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
def coins_dp(values: List[int], target: int) -> int:
|
||||
# memo[i]表示target为i的时候,所需的最少硬币数
|
||||
memo = [0] * (target+1)
|
||||
# 0元的时候为0个
|
||||
memo[0] = 0
|
||||
|
||||
for i in range(1, target+1):
|
||||
min_num = 999999
|
||||
# 对于values中的所有n
|
||||
# memo[i]为 min(memo[i-n1], memo[i-n2], ...) + 1
|
||||
for n in values:
|
||||
if i >= n:
|
||||
min_num = min(min_num, 1 + memo[i-n])
|
||||
else: # values中的数值要从小到大排序
|
||||
break
|
||||
memo[i] = min_num
|
||||
|
||||
# print(memo)
|
||||
return memo[-1]
|
||||
|
||||
|
||||
min_num = 999999
|
||||
def coins_backtracking(values: List[int], target: int, cur_value: int, coins_count: int):
|
||||
if cur_value == target:
|
||||
global min_num
|
||||
min_num = min(coins_count, min_num)
|
||||
else:
|
||||
for n in values:
|
||||
if cur_value + n <= target:
|
||||
coins_backtracking(values, target, cur_value+n, coins_count+1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
values = [1, 3, 5]
|
||||
target = 23
|
||||
print(coins_dp(values, target))
|
||||
coins_backtracking(values, target, 0, 0)
|
||||
print(min_num)
|
||||
|
39
python/41_dynamic_programming/min_dist.py
Normal file
39
python/41_dynamic_programming/min_dist.py
Normal file
@ -0,0 +1,39 @@
|
||||
"""
|
||||
Author: Wenru Dong
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from itertools import accumulate
|
||||
|
||||
def min_dist(weights: List[List[int]]) -> int:
|
||||
"""Find the minimum weight path from the weights matrix."""
|
||||
m, n = len(weights), len(weights[0])
|
||||
table = [[0] * n for _ in range(m)]
|
||||
# table[i][j] is the minimum distance (weight) when
|
||||
# there are i vertical moves and j horizontal moves
|
||||
# left.
|
||||
table[0] = list(accumulate(reversed(weights[-1])))
|
||||
for i, v in enumerate(accumulate(row[-1] for row in reversed(weights))):
|
||||
table[i][0] = v
|
||||
for i in range(1, m):
|
||||
for j in range(1, n):
|
||||
table[i][j] = weights[~i][~j] + min(table[i - 1][j], table[i][j - 1])
|
||||
return table[-1][-1]
|
||||
|
||||
|
||||
def min_dist_recur(weights: List[List[int]]) -> int:
|
||||
m, n = len(weights), len(weights[0])
|
||||
table = [[0] * n for _ in range(m)]
|
||||
def min_dist_to(i: int, j: int) -> int:
|
||||
if i == j == 0: return weights[0][0]
|
||||
if table[i][j]: return table[i][j]
|
||||
min_left = float("inf") if j - 1 < 0 else min_dist_to(i, j - 1)
|
||||
min_up = float("inf") if i - 1 < 0 else min_dist_to(i - 1, j)
|
||||
return weights[i][j] + min(min_left, min_up)
|
||||
return min_dist_to(m - 1, n - 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
weights = [[1, 3, 5, 9], [2, 1, 3, 4], [5, 2, 6, 7], [6, 8, 4, 3]]
|
||||
print(min_dist(weights))
|
||||
print(min_dist_recur(weights))
|
@ -0,0 +1,58 @@
|
||||
#!/usr/bin/python
|
||||
# -*- coding: UTF-8 -*-
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
def longest_increasing_subsequence(nums: List[int]) -> int:
|
||||
"""
|
||||
最长子上升序列的一种DP解法,从回溯解法转化,思路类似于有限物品的背包问题
|
||||
每一次决策都算出当前可能的lis的长度,重复子问题合并,合并策略是lis的末尾元素最小
|
||||
时间复杂度:O(n^2)
|
||||
空间复杂度:O(n^2),可优化至O(n)
|
||||
|
||||
没leetcode上的参考答案高效,提供另一种思路作为参考
|
||||
https://leetcode.com/problems/longest-increasing-subsequence/solution/
|
||||
:param nums:
|
||||
:return:
|
||||
"""
|
||||
if not nums:
|
||||
return 0
|
||||
|
||||
n = len(nums)
|
||||
# memo[i][j] 表示第i次决策,长度为j的lis的 最小的 末尾元素数值
|
||||
# 每次决策都根据上次决策的所有可能转化,空间上可以类似背包优化为O(n)
|
||||
memo = [[-1] * (n+1) for _ in range(n)]
|
||||
|
||||
# 第一列全赋值为0,表示每次决策都不选任何数
|
||||
for i in range(n):
|
||||
memo[i][0] = 0
|
||||
# 第一次决策选数组中的第一个数
|
||||
memo[0][1] = nums[0]
|
||||
|
||||
for i in range(1, n):
|
||||
for j in range(1, n+1):
|
||||
# case 1: 长度为j的lis在上次决策后存在,nums[i]比长度为j-1的lis末尾元素大
|
||||
if memo[i-1][j] != -1 and nums[i] > memo[i-1][j-1]:
|
||||
memo[i][j] = min(nums[i], memo[i-1][j])
|
||||
|
||||
# case 2: 长度为j的lis在上次决策后存在,nums[i]比长度为j-1的lis末尾元素小/等
|
||||
if memo[i-1][j] != -1 and nums[i] <= memo[i-1][j-1]:
|
||||
memo[i][j] = memo[i-1][j]
|
||||
|
||||
if memo[i-1][j] == -1:
|
||||
# case 3: 长度为j的lis不存在,nums[i]比长度为j-1的lis末尾元素大
|
||||
if nums[i] > memo[i-1][j-1]:
|
||||
memo[i][j] = nums[i]
|
||||
# case 4: 长度为j的lis不存在,nums[i]比长度为j-1的lis末尾元素小/等
|
||||
break
|
||||
|
||||
for i in range(n, -1, -1):
|
||||
if memo[-1][i] != -1:
|
||||
return i
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 要求输入的都是大于0的正整数(可优化至支持任意整数)
|
||||
nums = [2, 9, 3, 6, 5, 1, 7]
|
||||
print(longest_increasing_subsequence(nums))
|
31
python/42_dynamic_programming/min_edit_dist.py
Normal file
31
python/42_dynamic_programming/min_edit_dist.py
Normal file
@ -0,0 +1,31 @@
|
||||
"""
|
||||
Author: Wenru Dong
|
||||
"""
|
||||
|
||||
def levenshtein_dp(s: str, t: str) -> int:
|
||||
m, n = len(s), len(t)
|
||||
table = [[0] * (n + 1) for _ in range(m + 1)]
|
||||
table[0] = [j for j in range(m + 1)]
|
||||
for i in range(m + 1):
|
||||
table[i][0] = i
|
||||
for i in range(1, m + 1):
|
||||
for j in range(1, n + 1):
|
||||
table[i][j] = min(1 + table[i - 1][j], 1 + table[i][j - 1], int(s[i - 1] != t[j - 1]) + table[i - 1][j - 1])
|
||||
return table[-1][-1]
|
||||
|
||||
|
||||
def common_substring_dp(s: str, t: str) -> int:
|
||||
m, n = len(s), len(t)
|
||||
table = [[0] * (n + 1) for _ in range(m + 1)]
|
||||
for i in range(1, m + 1):
|
||||
for j in range(1, n + 1):
|
||||
table[i][j] = max(table[i - 1][j], table[i][j - 1], int(s[i - 1] == t[j - 1]) + table[i - 1][j - 1])
|
||||
return table[-1][-1]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
s = "mitcmu"
|
||||
t = "mtacnu"
|
||||
|
||||
print(levenshtein_dp(s, t))
|
||||
print(common_substring_dp(s, t))
|
63
python/43_topological_sorting/topological_sorting.py
Normal file
63
python/43_topological_sorting/topological_sorting.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""
|
||||
Author: Wenru Dong
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
from itertools import filterfalse
|
||||
|
||||
class Graph:
|
||||
def __init__(self, num_vertices: int):
|
||||
self._num_vertices = num_vertices
|
||||
self._adjacency = [[] for _ in range(num_vertices)]
|
||||
|
||||
def add_edge(self, s: int, t: int) -> None:
|
||||
self._adjacency[s].append(t)
|
||||
|
||||
def tsort_by_kahn(self) -> None:
|
||||
in_degree = [0] * self._num_vertices
|
||||
for v in range(self._num_vertices):
|
||||
if len(self._adjacency[v]):
|
||||
for neighbour in self._adjacency[v]:
|
||||
in_degree[neighbour] += 1
|
||||
q = deque(filterfalse(lambda x: in_degree[x], range(self._num_vertices)))
|
||||
while q:
|
||||
v = q.popleft()
|
||||
print(f"{v} -> ", end="")
|
||||
for neighbour in self._adjacency[v]:
|
||||
in_degree[neighbour] -= 1
|
||||
if not in_degree[neighbour]:
|
||||
q.append(neighbour)
|
||||
print("\b\b\b ")
|
||||
|
||||
def tsort_by_dfs(self) -> None:
|
||||
inverse_adjacency = [[] for _ in range(self._num_vertices)]
|
||||
for v in range(self._num_vertices):
|
||||
if len(self._adjacency[v]):
|
||||
for neighbour in self._adjacency[v]:
|
||||
inverse_adjacency[neighbour].append(v)
|
||||
visited = [False] * self._num_vertices
|
||||
|
||||
def dfs(vertex: int) -> None:
|
||||
if len(inverse_adjacency[vertex]):
|
||||
for v in inverse_adjacency[vertex]:
|
||||
if not visited[v]:
|
||||
visited[v] = True
|
||||
dfs(v)
|
||||
print(f"{vertex} -> ", end="")
|
||||
|
||||
for v in range(self._num_vertices):
|
||||
if not visited[v]:
|
||||
visited[v] = True
|
||||
dfs(v)
|
||||
|
||||
print("\b\b\b ")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
dag = Graph(4)
|
||||
dag.add_edge(1, 0)
|
||||
dag.add_edge(2, 1)
|
||||
dag.add_edge(1, 3)
|
||||
dag.tsort_by_kahn()
|
||||
dag.tsort_by_dfs()
|
64
python/44_shortest_path/shortest_path.py
Normal file
64
python/44_shortest_path/shortest_path.py
Normal file
@ -0,0 +1,64 @@
|
||||
"""
|
||||
Dijkstra algorithm
|
||||
|
||||
Author: Wenru Dong
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from queue import PriorityQueue
|
||||
|
||||
@dataclass
|
||||
class Edge:
|
||||
start_id: int
|
||||
end_id: int
|
||||
weight: int
|
||||
|
||||
@dataclass(order=True)
|
||||
class Vertex:
|
||||
distance_to_start = float("inf")
|
||||
vertex_id: int
|
||||
|
||||
class Graph:
|
||||
def __init__(self, num_vertices: int):
|
||||
self._num_vertices = num_vertices
|
||||
self._adjacency = [[] for _ in range(num_vertices)]
|
||||
|
||||
def add_edge(self, from_vertex: int, to_vertex: int, weight: int) -> None:
|
||||
self._adjacency[from_vertex].append(Edge(from_vertex, to_vertex, weight))
|
||||
|
||||
def dijkstra(self, from_vertex: int, to_vertex: int) -> None:
|
||||
vertices = [Vertex(i) for i in range(self._num_vertices)]
|
||||
vertices[from_vertex].distance_to_start = 0
|
||||
visited = [False] * self._num_vertices
|
||||
predecessor = [-1] * self._num_vertices
|
||||
q = PriorityQueue()
|
||||
q.put(vertices[from_vertex])
|
||||
visited[from_vertex] = True
|
||||
while not q.empty():
|
||||
min_vertex = q.get()
|
||||
if min_vertex.vertex_id == to_vertex:
|
||||
break
|
||||
for edge in self._adjacency[min_vertex.vertex_id]:
|
||||
next_vertex = vertices[edge.end_id]
|
||||
if min_vertex.distance_to_start + edge.weight < next_vertex.distance_to_start:
|
||||
next_vertex.distance_to_start = min_vertex.distance_to_start + edge.weight
|
||||
predecessor[next_vertex.vertex_id] = min_vertex.vertex_id
|
||||
if not visited[next_vertex.vertex_id]:
|
||||
q.put(next_vertex)
|
||||
visited[next_vertex.vertex_id] = True
|
||||
|
||||
path = lambda x: path(predecessor[x]) + [str(x)] if from_vertex != x else [str(from_vertex)]
|
||||
print("->".join(path(to_vertex)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
graph = Graph(6)
|
||||
graph.add_edge(0, 1, 10)
|
||||
graph.add_edge(0, 4, 15)
|
||||
graph.add_edge(1, 2, 15)
|
||||
graph.add_edge(1, 3, 2)
|
||||
graph.add_edge(2, 5, 5)
|
||||
graph.add_edge(3, 2, 1)
|
||||
graph.add_edge(3, 5, 12)
|
||||
graph.add_edge(4, 5, 10)
|
||||
graph.dijkstra(0, 5)
|
@ -8,63 +8,63 @@ import scala.util.control.Breaks.{break, breakable}
|
||||
* Author: yangchuz
|
||||
*/
|
||||
object Sorts {
|
||||
def main(args: Array[String]): Unit ={
|
||||
// println(bubbleSort(Array(0, 6, 2, 3, 8, 5, 6, 7), 8).mkString(", "))
|
||||
// println(insertSort(Array(0, 6, 2, 3, 8, 5, 6, 7), 8).mkString(", "))
|
||||
println(selectionSort(Array(0, 6, 2, 3, 8, 5, 6, 7), 8).mkString(", "))
|
||||
}
|
||||
|
||||
def bubbleSort(arr: Array[Int], n:Int): Array[Int] = {
|
||||
val n = arr.length
|
||||
def bubbleSort(items: Array[Int]): Array[Int] = {
|
||||
val length = items.length
|
||||
breakable {
|
||||
for(i <- (n-1) to (1, -1)){
|
||||
var flag = false
|
||||
for(j <- 0 until i){
|
||||
if(arr(j) > arr(j+1)){
|
||||
val tmp = arr(j)
|
||||
arr(j) = arr(j+1)
|
||||
arr(j+1) = tmp
|
||||
flag = true
|
||||
for (i <- Range(0, length)) {
|
||||
var exit = true
|
||||
for (j <- Range(0, length - i - 1)) {
|
||||
if (items(j + 1) < items(j)) {
|
||||
val temp = items(j + 1)
|
||||
items(j + 1) = items(j)
|
||||
items(j) = temp
|
||||
exit = false
|
||||
}
|
||||
}
|
||||
if(!flag){
|
||||
if (exit) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
arr
|
||||
items
|
||||
}
|
||||
|
||||
def insertSort(arr: Array[Int], n:Int): Array[Int] = {
|
||||
for(i <- 1 until n){
|
||||
val tmp = arr(i)
|
||||
breakable{
|
||||
for(j <- (i-1) to (0, -1)){
|
||||
if(tmp < arr(j)){
|
||||
arr(j+1) = arr(j)
|
||||
}else{
|
||||
arr(j+1) = tmp
|
||||
def insertSort(items: Array[Int]): Array[Int] = {
|
||||
val length = items.length
|
||||
for (i <- Range(1, length)) {
|
||||
val value = items(i)
|
||||
var j = i - 1
|
||||
breakable {
|
||||
while (j >= 0) {
|
||||
if (items(j) > value) {
|
||||
items(j + 1) = items(j)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
j -= 1
|
||||
}
|
||||
}
|
||||
items(j + 1) = value
|
||||
}
|
||||
arr
|
||||
items
|
||||
}
|
||||
|
||||
def selectionSort(arr: Array[Int], n:Int): Array[Int] = {
|
||||
for(i <- 0 until n){
|
||||
var min = i
|
||||
for(j <- (i + 1) until n){
|
||||
if(arr(j) < arr(min)){
|
||||
min = j
|
||||
def selectionSort(items: Array[Int]): Array[Int] = {
|
||||
val length = items.length
|
||||
for (i <- Range(0, length)) {
|
||||
var minIndex = i
|
||||
for (j <- Range(i + 1, length)) {
|
||||
if (items(j) < items(minIndex)) {
|
||||
minIndex = j
|
||||
}
|
||||
}
|
||||
|
||||
val tmp = arr(i)
|
||||
arr(i) = arr(min)
|
||||
arr(min) = tmp
|
||||
//put the min value to the front
|
||||
val temp = items(i)
|
||||
items(i) = items(minIndex)
|
||||
items(minIndex) = temp
|
||||
}
|
||||
arr
|
||||
items
|
||||
}
|
||||
}
|
||||
|
60
scala/src/main/scala/ch12_sorts/MergeSort.scala
Normal file
60
scala/src/main/scala/ch12_sorts/MergeSort.scala
Normal file
@ -0,0 +1,60 @@
|
||||
package ch12_sorts
|
||||
|
||||
object MergeSort {
|
||||
|
||||
def mergeSort(items: Array[Int]): Array[Int] = {
|
||||
_mergeSort(items, 0, items.length - 1)
|
||||
items
|
||||
}
|
||||
|
||||
|
||||
private[this] def _mergeSort(items: Array[Int], p: Int, r: Int): Unit = {
|
||||
if (p >= r) {
|
||||
return
|
||||
}
|
||||
|
||||
val q = p + (r - p) / 2
|
||||
_mergeSort(items, p, q)
|
||||
_mergeSort(items, q + 1, r)
|
||||
_merge(items, p, q, r)
|
||||
|
||||
}
|
||||
|
||||
private[this] def _merge(items: Array[Int], p: Int, q: Int, r: Int): Unit = {
|
||||
//start of first half
|
||||
var i = p
|
||||
//start of second half
|
||||
var j = q + 1
|
||||
var k = 0
|
||||
//temp array to hold the data
|
||||
val tempArray = new Array[Int](r - p + 1)
|
||||
while (i <= q && j <= r) {
|
||||
if (items(i) <= items(j)) {
|
||||
tempArray(k) = items(i)
|
||||
i += 1
|
||||
} else {
|
||||
tempArray(k) = items(j)
|
||||
j += 1
|
||||
}
|
||||
k += 1
|
||||
}
|
||||
|
||||
var start = i
|
||||
var end = q
|
||||
|
||||
if (j <= r) {
|
||||
start = j
|
||||
end = r
|
||||
}
|
||||
|
||||
for (n <- start to end) {
|
||||
tempArray(k) = items(n)
|
||||
k += 1
|
||||
}
|
||||
|
||||
//copy tempArray back to items
|
||||
for (n <- 0 to r - p) {
|
||||
items(p + n) = tempArray(n)
|
||||
}
|
||||
}
|
||||
}
|
54
scala/src/main/scala/ch12_sorts/QuickSort.scala
Normal file
54
scala/src/main/scala/ch12_sorts/QuickSort.scala
Normal file
@ -0,0 +1,54 @@
|
||||
package ch12_sorts
|
||||
|
||||
object QuickSort {
|
||||
|
||||
//find the K th smallest element int the array
|
||||
def findKthElement(items: Array[Int], k: Int): Int = {
|
||||
_findKthElement(items, k, 0, items.length - 1)
|
||||
}
|
||||
|
||||
private[this] def _findKthElement(items: Array[Int], k: Int, p: Int, r: Int): Int = {
|
||||
val q = _partition(items, p, r)
|
||||
|
||||
if (k == q + 1) {
|
||||
items(q)
|
||||
} else if (k < q + 1) {
|
||||
_findKthElement(items, k, p, q - 1)
|
||||
} else {
|
||||
_findKthElement(items, k, q + 1, r)
|
||||
}
|
||||
}
|
||||
|
||||
def quickSort(items: Array[Int]): Array[Int] = {
|
||||
_quickSort(items, 0, items.length - 1)
|
||||
items
|
||||
}
|
||||
|
||||
private[this] def _quickSort(items: Array[Int], p: Int, r: Int): Unit = {
|
||||
if (p >= r) {
|
||||
return
|
||||
}
|
||||
val q = _partition(items, p, r)
|
||||
_quickSort(items, p, q - 1)
|
||||
_quickSort(items, q + 1, r)
|
||||
}
|
||||
|
||||
private[this] def _partition(items: Array[Int], p: Int, r: Int): Int = {
|
||||
val pivot = items(r)
|
||||
var i = p
|
||||
for (j <- Range(p, r)) {
|
||||
if (items(j) < pivot) {
|
||||
val temp = items(i)
|
||||
items(i) = items(j)
|
||||
items(j) = temp
|
||||
i += 1
|
||||
}
|
||||
}
|
||||
|
||||
val temp = items(i)
|
||||
items(i) = items(r)
|
||||
items(r) = temp
|
||||
|
||||
i
|
||||
}
|
||||
}
|
@ -1,20 +1,50 @@
|
||||
package ch15_bsearch
|
||||
|
||||
object BSearch {
|
||||
def search(nums: Array[Int], target: Int): Int = {
|
||||
var low = 0
|
||||
var high = nums.length - 1
|
||||
while(low <= high){
|
||||
val mid = low + ((high - low) >> 2)
|
||||
if(nums(mid) > target){
|
||||
high = mid - 1
|
||||
} else if (nums(mid) < target){
|
||||
low = mid + 1
|
||||
} else {
|
||||
return mid
|
||||
}
|
||||
}
|
||||
import scala.math.abs
|
||||
|
||||
return -1
|
||||
object BSearch {
|
||||
|
||||
def search(items: Array[Int], target: Int): Int = {
|
||||
var low = 0
|
||||
var high = items.length - 1
|
||||
while (low <= high) {
|
||||
val mid = low + (high - low) / 2
|
||||
if (items(mid) == target) {
|
||||
return mid
|
||||
} else if (items(mid) > target) {
|
||||
high = mid - 1
|
||||
} else {
|
||||
low = mid + 1
|
||||
}
|
||||
}
|
||||
|
||||
-1
|
||||
}
|
||||
|
||||
def sqrt(x: Double, precision: Double): Double = {
|
||||
|
||||
require(precision > 0, "precision must > 0")
|
||||
require(x > 0, "input value for sqrt must > 0")
|
||||
var low = 0.0
|
||||
var high = x
|
||||
val actualPrecision = precision / 10
|
||||
|
||||
if (x > 0 && x < 1) {
|
||||
low = x
|
||||
high = 1
|
||||
}
|
||||
while (high - low > actualPrecision) {
|
||||
val mid = low + (high - low) / 2
|
||||
if (abs(mid * mid - x) < actualPrecision) {
|
||||
//find it
|
||||
return mid
|
||||
} else if (mid * mid > x) {
|
||||
high = mid
|
||||
} else {
|
||||
low = mid
|
||||
}
|
||||
}
|
||||
throw new IllegalStateException("could not determine the sqrt value for " + x)
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -1,22 +1,23 @@
|
||||
package ch15_bsearch
|
||||
|
||||
object BSearchRecursive {
|
||||
def search(nums: Array[Int], target: Int): Int = {
|
||||
return searchInternal(nums, target, 0, nums.length - 1)
|
||||
|
||||
def search(items: Array[Int], target: Int): Int = {
|
||||
_search(items, target, 0, items.length - 1)
|
||||
}
|
||||
|
||||
private[this] def _search(items: Array[Int], target: Int, low: Int, high: Int): Int = {
|
||||
if (low > high) {
|
||||
return -1
|
||||
}
|
||||
|
||||
def searchInternal(nums:Array[Int], target: Int, low: Int, high: Int): Int = {
|
||||
if(low <= high){
|
||||
val mid = low + ((high - low) >> 2)
|
||||
if(nums(mid) > target){
|
||||
searchInternal(nums, target, low, mid - 1)
|
||||
} else if (nums(mid) < target){
|
||||
searchInternal(nums, target, mid + 1, high)
|
||||
} else {
|
||||
return mid
|
||||
}
|
||||
}else{
|
||||
return -1
|
||||
}
|
||||
val mid = low + (high - low) / 2
|
||||
if (items(mid) == target) {
|
||||
mid
|
||||
} else if (items(mid) > target) {
|
||||
_search(items, target, low, mid - 1)
|
||||
} else {
|
||||
_search(items, target, mid + 1, high)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
87
scala/src/main/scala/ch16_bsearch/BSearch.scala
Normal file
87
scala/src/main/scala/ch16_bsearch/BSearch.scala
Normal file
@ -0,0 +1,87 @@
|
||||
package ch16_bsearch
|
||||
|
||||
object BSearch {
|
||||
|
||||
//find the first index of given value
|
||||
//-1 if not found
|
||||
def findFirstValue(items: Array[Int], target: Int): Int = {
|
||||
require(items.length > 0, "given array is empty")
|
||||
var low = 0
|
||||
var high = items.length - 1
|
||||
while (low <= high) {
|
||||
val mid = low + (high - low) / 2
|
||||
if (items(mid) > target) {
|
||||
high = mid - 1
|
||||
} else if (items(mid) < target) {
|
||||
low = mid + 1
|
||||
} else {
|
||||
//find the value in the array
|
||||
if (mid == 0 || items(mid - 1) != target) {
|
||||
return mid
|
||||
} else {
|
||||
high = mid - 1
|
||||
}
|
||||
}
|
||||
}
|
||||
-1
|
||||
}
|
||||
|
||||
def findLastValue(items: Array[Int], target: Int): Int = {
|
||||
var low = 0
|
||||
var high = items.length - 1
|
||||
while (low <= high) {
|
||||
val mid = low + (high - low) / 2
|
||||
if (items(mid) > target) {
|
||||
high = mid - 1
|
||||
} else if (items(mid) < target) {
|
||||
low = mid + 1
|
||||
} else {
|
||||
//find the target value
|
||||
if (mid == items.length - 1 || items(mid + 1) != target) {
|
||||
return mid
|
||||
} else {
|
||||
low = mid + 1
|
||||
}
|
||||
}
|
||||
}
|
||||
-1
|
||||
}
|
||||
|
||||
def findFirstGreaterThan(items: Array[Int], target: Int): Int = {
|
||||
var low = 0
|
||||
var high = items.length
|
||||
while (low <= high) {
|
||||
val mid = low + (high - low) / 2
|
||||
if (items(mid) >= target) {
|
||||
//find the range
|
||||
if (mid == 0 || items(mid - 1) < target) {
|
||||
return mid
|
||||
} else {
|
||||
high = mid - 1
|
||||
}
|
||||
} else {
|
||||
low = mid + 1
|
||||
}
|
||||
}
|
||||
-1
|
||||
}
|
||||
|
||||
def findLastSmallerThan(items: Array[Int], target: Int): Int = {
|
||||
var low = 0
|
||||
var high = items.length - 1
|
||||
while (low <= high) {
|
||||
var mid = low + (high - low) / 2
|
||||
if (items(mid) <= target) {
|
||||
//find the range
|
||||
if (mid == items.length - 1 || items(mid + 1) > target) {
|
||||
return mid
|
||||
} else {
|
||||
low = mid + 1
|
||||
}
|
||||
} else {
|
||||
high = mid - 1
|
||||
}
|
||||
}
|
||||
-1
|
||||
}
|
||||
}
|
99
scala/src/main/scala/ch17_skip_list/SkipList.scala
Normal file
99
scala/src/main/scala/ch17_skip_list/SkipList.scala
Normal file
@ -0,0 +1,99 @@
|
||||
package ch17_skip_list
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
class Node(var data: Int, var forwards: Array[Node], var maxLevel: Int)
|
||||
|
||||
class SkipList(var head: Node, var skipListLevel: Int) {
|
||||
|
||||
def this() = this(new Node(-1, new Array[Node](16), 0), 1)
|
||||
|
||||
val MAX_LEVEL = 16
|
||||
val random = new Random()
|
||||
|
||||
def find(value: Int): Option[Node] = {
|
||||
var p = head
|
||||
for (i <- skipListLevel - 1 to 0 by -1) {
|
||||
while (p.forwards(i) != null && p.forwards(i).data < value) {
|
||||
p = p.forwards(i)
|
||||
}
|
||||
}
|
||||
if (p.forwards(0) != null && p.forwards(0).data == value) {
|
||||
Some(p.forwards(0))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
def insert(value: Int): Unit = {
|
||||
//init the new node
|
||||
val level = randomLevel()
|
||||
val newNode = new Node(value, new Array[Node](level), level)
|
||||
|
||||
//use updtes array to record all nodes in all level before the inserted node
|
||||
val updates: Array[Node] = new Array[Node](level)
|
||||
var p = head
|
||||
for (i <- level - 1 to 0 by -1) {
|
||||
while (p.forwards(i) != null && p.forwards(i).data < value) {
|
||||
p = p.forwards(i)
|
||||
}
|
||||
updates(i) = p
|
||||
}
|
||||
|
||||
for (i <- Range(0, level)) {
|
||||
newNode.forwards(i) = updates(i).forwards(i)
|
||||
updates(i).forwards(i) = newNode
|
||||
}
|
||||
|
||||
if (level > skipListLevel) {
|
||||
skipListLevel = level
|
||||
}
|
||||
}
|
||||
|
||||
def delete(value: Int): Unit = {
|
||||
var p = head
|
||||
val updates: Array[Node] = new Array[Node](skipListLevel)
|
||||
|
||||
//try to locate the given node with the value
|
||||
for (i <- skipListLevel - 1 to 0 by -1) {
|
||||
while (p.forwards(i) != null && p.forwards(i).data < value) {
|
||||
p = p.forwards(i)
|
||||
}
|
||||
updates(i) = p
|
||||
}
|
||||
|
||||
if (p.forwards(0) != null && p.forwards(0).data == value) {
|
||||
//find the value node, start to delete the node from the skip list
|
||||
for (i <- skipListLevel - 1 to 0 by -1) {
|
||||
if (updates(i).forwards(i) != null && updates(i).forwards(i).data == value) {
|
||||
updates(i).forwards(i) = updates(i).forwards(i).forwards(i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
def randomLevel(): Int = {
|
||||
var level = 1
|
||||
for (i <- Range(1, MAX_LEVEL)) {
|
||||
if (random.nextInt() % 2 == 1) {
|
||||
level += 1
|
||||
}
|
||||
}
|
||||
|
||||
level
|
||||
}
|
||||
|
||||
def mkString(): String = {
|
||||
val builder = new StringBuilder
|
||||
var p = head
|
||||
while (p.forwards(0) != null) {
|
||||
p = p.forwards(0)
|
||||
builder.append(p.data)
|
||||
}
|
||||
|
||||
builder.mkString
|
||||
}
|
||||
}
|
||||
|
||||
|
149
scala/src/main/scala/ch20_linked_hash_map/LRUCache.scala
Normal file
149
scala/src/main/scala/ch20_linked_hash_map/LRUCache.scala
Normal file
@ -0,0 +1,149 @@
|
||||
package ch20_linked_hash_map
|
||||
|
||||
class Node[K, V](var key: Option[K], var data: Option[V], var prev: Option[Node[K, V]], var next: Option[Node[K, V]],
|
||||
var hNext: Option[Node[K, V]]) {
|
||||
|
||||
def this(key: Option[K], data: Option[V]) = this(key, data, None, None, None)
|
||||
}
|
||||
|
||||
/**
|
||||
* LRU cache - https://leetcode.com/problems/lru-cache/ see unit test from LRUCacheTest
|
||||
*
|
||||
* @author email2liyang@gmail.com
|
||||
*/
|
||||
class LRUCache[K, V](var head: Node[K, V], var tail: Node[K, V], var table: Array[Node[K, V]],
|
||||
capacity: Int = 1000, var elementCount: Int = 0) {
|
||||
|
||||
head.next = Some(tail)
|
||||
tail.prev = Some(head)
|
||||
|
||||
def this(capacity: Int) = this(new Node(None, None), new Node(None, None), new Array[Node[K, V]](capacity), capacity)
|
||||
|
||||
def get(key: K): Option[V] = {
|
||||
val index = indexFor(key.hashCode())
|
||||
var hNode = table(index)
|
||||
if (hNode == null) {
|
||||
None
|
||||
} else {
|
||||
while (!hNode.key.get.equals(key) && hNode.hNext.isDefined) {
|
||||
hNode = hNode.hNext.get
|
||||
}
|
||||
if (hNode.key.get.equals(key)) {
|
||||
//move this to the end of the linked list
|
||||
moveHNodeToTail(hNode)
|
||||
hNode.data
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//put data into the linked hash map
|
||||
//1: check if the data exist in the linked list
|
||||
//2: if it's not exist , append it in the linked list
|
||||
//3: if it's exist in the list, move it to the tail of the linked list
|
||||
//4: return old value if it's exist
|
||||
def put(key: K, value: V): Option[V] = {
|
||||
|
||||
if (elementCount == capacity) {
|
||||
deleteLRUElement()
|
||||
}
|
||||
|
||||
val node = new Node(Some(key), Some(value))
|
||||
val index = indexFor(key.hashCode())
|
||||
var hNode = table(index)
|
||||
var result: Option[V] = None
|
||||
if (hNode == null) {
|
||||
//if it's not exist , append it in the linked list
|
||||
node.prev = tail.prev
|
||||
node.next = Some(tail)
|
||||
tail.prev.get.next = Some(node)
|
||||
tail.prev = Some(node)
|
||||
table(index) = node
|
||||
elementCount += 1
|
||||
} else {
|
||||
//we find a key conflict in the hash table
|
||||
//start to loop the hNode to match the key
|
||||
while (!hNode.key.get.equals(key) && hNode.hNext.isDefined) {
|
||||
hNode = hNode.hNext.get
|
||||
}
|
||||
|
||||
if (hNode.key.get.equals(key)) {
|
||||
//find the old data from the hash table
|
||||
result = hNode.data
|
||||
hNode.data = Some(value)
|
||||
//move the node to the tail of the linked list
|
||||
moveHNodeToTail(hNode)
|
||||
//hNext pointer stay untouched
|
||||
} else {
|
||||
//could not find the old data
|
||||
//put the new node into the tail of the linked list
|
||||
node.prev = tail.prev
|
||||
node.next = Some(tail)
|
||||
tail.prev.get.next = Some(node)
|
||||
tail.prev = Some(node)
|
||||
|
||||
//put it the tail of the hash table's list
|
||||
//iterator to the end of hNode
|
||||
while (hNode.hNext.isDefined) {
|
||||
hNode = hNode.hNext.get
|
||||
}
|
||||
hNode.hNext = Some(node)
|
||||
elementCount += 1
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
private[this] def moveHNodeToTail(hNode: Node[K, V]) = {
|
||||
hNode.prev.get.next = hNode.next
|
||||
hNode.next.get.prev = hNode.prev
|
||||
hNode.prev = tail.prev
|
||||
hNode.next = Some(tail)
|
||||
tail.prev.get.next = Some(hNode)
|
||||
tail.prev = Some(hNode)
|
||||
}
|
||||
|
||||
private[this] def deleteLRUElement(): Unit = {
|
||||
//cache is full, start to delete element from the head
|
||||
val node = head.next.get
|
||||
|
||||
//delete it from head
|
||||
node.next.get.prev = Some(head)
|
||||
head.next = node.next
|
||||
|
||||
//deal with hNext in the table
|
||||
val index = indexFor(node.key.get.hashCode())
|
||||
var hNode = table(index)
|
||||
//deal with first element in the hash table
|
||||
if (hNode.key.get.equals(node.key.get)) {
|
||||
hNode.hNext match {
|
||||
case Some(n) => table(index) = n
|
||||
case None => table(index) = null
|
||||
}
|
||||
} else {
|
||||
//deal with not first element in the hash table
|
||||
var hNodePrev = hNode
|
||||
hNode = hNode.next.get
|
||||
while (!hNode.key.get.equals(node.key.get)) {
|
||||
hNode = hNode.next.get
|
||||
hNodePrev = hNodePrev.next.get
|
||||
}
|
||||
//now hNodePrev is the previous hNode in the hashtable
|
||||
//remove the hNode
|
||||
hNodePrev.next = hNode.next
|
||||
|
||||
hNode.next match {
|
||||
case Some(n) => n.prev = Some(hNodePrev)
|
||||
case None =>
|
||||
}
|
||||
}
|
||||
|
||||
elementCount -= 1
|
||||
}
|
||||
|
||||
private[this] def indexFor(hash: Int): Int = {
|
||||
hash % table.length
|
||||
}
|
||||
}
|
64
scala/src/test/scala/ch11_sorts/SortsTest.scala
Normal file
64
scala/src/test/scala/ch11_sorts/SortsTest.scala
Normal file
@ -0,0 +1,64 @@
|
||||
package ch11_sorts
|
||||
|
||||
import ch12_sorts.{MergeSort, QuickSort}
|
||||
import org.scalatest.{FlatSpec, Matchers}
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
class SortsTest extends FlatSpec with Matchers {
|
||||
|
||||
behavior of "SortsTest in ch11"
|
||||
|
||||
it should "bubbleSort int arrays" in {
|
||||
var array = Array(4, 5, 6, 3, 2, 1)
|
||||
array = Sorts.bubbleSort(array)
|
||||
array.mkString("") should equal("123456")
|
||||
|
||||
array = Array(4)
|
||||
array = Sorts.bubbleSort(array)
|
||||
array.mkString("") should equal("4")
|
||||
}
|
||||
|
||||
it should "insertSort int arrays" in {
|
||||
var array = Array(4, 5, 6, 1, 3, 2)
|
||||
array = Sorts.insertSort(array)
|
||||
array.mkString("") should equal("123456")
|
||||
|
||||
array = Array(4)
|
||||
array = Sorts.insertSort(array)
|
||||
array.mkString("") should equal("4")
|
||||
}
|
||||
|
||||
it should "selectionSort int arrays" in {
|
||||
var array = Array(4, 5, 6, 1, 3, 2)
|
||||
array = Sorts.insertSort(array)
|
||||
array.mkString("") should equal("123456")
|
||||
|
||||
array = Array(4)
|
||||
array = Sorts.insertSort(array)
|
||||
array.mkString("") should equal("4")
|
||||
}
|
||||
|
||||
|
||||
it should "compare the sort algo" in {
|
||||
val length = 50000
|
||||
val array = new Array[Int](length)
|
||||
val rnd = new Random()
|
||||
for (i <- Range(0, length)) {
|
||||
array(i) = rnd.nextInt()
|
||||
}
|
||||
|
||||
timed("bubbleSort", Sorts.bubbleSort, array.clone())
|
||||
timed("insertSort", Sorts.insertSort, array.clone())
|
||||
timed("selectionSort", Sorts.selectionSort, array.clone())
|
||||
timed("mergeSort", MergeSort.mergeSort, array.clone())
|
||||
timed("quickSort", QuickSort.quickSort, array.clone())
|
||||
}
|
||||
|
||||
def reportElapsed(name: String, time: Long): Unit = println(name + " takes in " + time + "ms")
|
||||
|
||||
def timed(name: String, f: (Array[Int]) => Unit, array: Array[Int]): Unit = {
|
||||
val start = System.currentTimeMillis()
|
||||
try f(array) finally reportElapsed(name, System.currentTimeMillis - start)
|
||||
}
|
||||
}
|
22
scala/src/test/scala/ch12_sorts/MergeSortTest.scala
Normal file
22
scala/src/test/scala/ch12_sorts/MergeSortTest.scala
Normal file
@ -0,0 +1,22 @@
|
||||
package ch12_sorts
|
||||
|
||||
|
||||
import org.scalatest.{FlatSpec, Matchers}
|
||||
|
||||
class MergeSortTest extends FlatSpec with Matchers {
|
||||
behavior of "SortsTest in ch12"
|
||||
|
||||
it should "mergeSort int arrays" in {
|
||||
var array = Array(4, 5, 6, 3, 2, 1)
|
||||
array = MergeSort.mergeSort(array)
|
||||
array.mkString("") should equal("123456")
|
||||
|
||||
array = Array(4)
|
||||
array = MergeSort.mergeSort(array)
|
||||
array.mkString("") should equal("4")
|
||||
|
||||
array = Array(4, 2)
|
||||
array = MergeSort.mergeSort(array)
|
||||
array.mkString("") should equal("24")
|
||||
}
|
||||
}
|
31
scala/src/test/scala/ch12_sorts/QuickSortTest.scala
Normal file
31
scala/src/test/scala/ch12_sorts/QuickSortTest.scala
Normal file
@ -0,0 +1,31 @@
|
||||
package ch12_sorts
|
||||
|
||||
import org.scalatest.{FlatSpec, Matchers}
|
||||
|
||||
class QuickSortTest extends FlatSpec with Matchers {
|
||||
|
||||
behavior of "QuickSortTest"
|
||||
|
||||
it should "quickSort" in {
|
||||
var array = Array(4, 5, 6, 3, 2, 1)
|
||||
array = QuickSort.quickSort(array)
|
||||
array.mkString("") should equal("123456")
|
||||
|
||||
array = Array(4)
|
||||
array = QuickSort.quickSort(array)
|
||||
array.mkString("") should equal("4")
|
||||
|
||||
array = Array(4, 2)
|
||||
array = QuickSort.quickSort(array)
|
||||
array.mkString("") should equal("24")
|
||||
}
|
||||
|
||||
it should "find the Kth element in the array" in {
|
||||
val array = Array(4, 2, 5, 12, 3)
|
||||
|
||||
QuickSort.findKthElement(array, 3) should equal(4)
|
||||
QuickSort.findKthElement(array, 5) should equal(12)
|
||||
QuickSort.findKthElement(array, 1) should equal(2)
|
||||
}
|
||||
|
||||
}
|
25
scala/src/test/scala/ch15_bsearch/BSearchRecursiveTest.scala
Normal file
25
scala/src/test/scala/ch15_bsearch/BSearchRecursiveTest.scala
Normal file
@ -0,0 +1,25 @@
|
||||
package ch15_bsearch
|
||||
|
||||
import ch12_sorts.QuickSort
|
||||
import org.scalatest.{FlatSpec, Matchers}
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
class BSearchRecursiveTest extends FlatSpec with Matchers {
|
||||
|
||||
behavior of "BSearchRecursiveTest"
|
||||
|
||||
it should "search with exist value" in {
|
||||
val length = 50000
|
||||
val array = new Array[Int](length)
|
||||
val rnd = new Random()
|
||||
for (i <- Range(0, length)) {
|
||||
array(i) = rnd.nextInt()
|
||||
}
|
||||
|
||||
val target = array(2698)
|
||||
|
||||
BSearchRecursive.search(QuickSort.quickSort(array), target) should be > -1
|
||||
}
|
||||
|
||||
}
|
37
scala/src/test/scala/ch15_bsearch/BSearchTest.scala
Normal file
37
scala/src/test/scala/ch15_bsearch/BSearchTest.scala
Normal file
@ -0,0 +1,37 @@
|
||||
package ch15_bsearch
|
||||
|
||||
import ch12_sorts.QuickSort
|
||||
import org.scalatest.{FlatSpec, Matchers}
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
class BSearchTest extends FlatSpec with Matchers {
|
||||
|
||||
behavior of "BSearchTest"
|
||||
|
||||
it should "search with exist value" in {
|
||||
val length = 50000
|
||||
val array = new Array[Int](length)
|
||||
val rnd = new Random()
|
||||
for (i <- Range(0, length)) {
|
||||
array(i) = rnd.nextInt()
|
||||
}
|
||||
|
||||
val target = array(2698)
|
||||
|
||||
BSearch.search(QuickSort.quickSort(array), target) should be > -1
|
||||
}
|
||||
|
||||
it should "calculate sqrt value -1 " in {
|
||||
val x = 4
|
||||
val precision = 0.000001
|
||||
BSearch.sqrt(x, precision) should equal(2.0)
|
||||
}
|
||||
|
||||
it should "calculate sqrt value -2 " in {
|
||||
val x = 0.04
|
||||
val precision = 0.000001
|
||||
BSearch.sqrt(x, precision) should equal(0.2 +- precision)
|
||||
}
|
||||
|
||||
}
|
31
scala/src/test/scala/ch16_bsearch/BSearchTest.scala
Normal file
31
scala/src/test/scala/ch16_bsearch/BSearchTest.scala
Normal file
@ -0,0 +1,31 @@
|
||||
package ch16_bsearch
|
||||
|
||||
import org.scalatest.{FlatSpec, Matchers}
|
||||
|
||||
class BSearchTest extends FlatSpec with Matchers {
|
||||
|
||||
behavior of "BSearchTest"
|
||||
|
||||
it should "findFirstValue" in {
|
||||
val items = Array(1, 3, 4, 5, 6, 8, 8, 8, 11, 18)
|
||||
BSearch.findFirstValue(items, 8) should equal(5)
|
||||
}
|
||||
|
||||
it should "findLastValue" in {
|
||||
val items = Array(1, 3, 4, 5, 6, 8, 8, 8, 11, 18)
|
||||
BSearch.findLastValue(items, 8) should equal(7)
|
||||
}
|
||||
|
||||
it should "findFirstGreaterThan" in {
|
||||
val items = Array(1, 3, 4, 5, 6, 8, 8, 8, 11, 18)
|
||||
BSearch.findFirstGreaterThan(items, 2) should equal(1)
|
||||
BSearch.findFirstGreaterThan(items, 8) should equal(5)
|
||||
}
|
||||
|
||||
it should "findLastSmallerThan" in {
|
||||
val items = Array(1, 3, 4, 5, 6, 8, 8, 8, 11, 18)
|
||||
BSearch.findLastSmallerThan(items, 2) should equal(0)
|
||||
BSearch.findLastSmallerThan(items, 8) should equal(7)
|
||||
}
|
||||
|
||||
}
|
44
scala/src/test/scala/ch17_skip_list/SkipListTest.scala
Normal file
44
scala/src/test/scala/ch17_skip_list/SkipListTest.scala
Normal file
@ -0,0 +1,44 @@
|
||||
package ch17_skip_list
|
||||
|
||||
import org.scalatest.{FlatSpec, Matchers}
|
||||
|
||||
import scala.util.Random
|
||||
|
||||
class SkipListTest extends FlatSpec with Matchers {
|
||||
|
||||
behavior of "SkipListTest"
|
||||
|
||||
it should "insert skip list" in {
|
||||
val list = new SkipList()
|
||||
for (i <- Range(0, 10)) {
|
||||
list.insert(i)
|
||||
}
|
||||
|
||||
list.mkString() should equal("0123456789")
|
||||
}
|
||||
|
||||
it should "delete skip list" in {
|
||||
val list = new SkipList()
|
||||
for (i <- Range(0, 10)) {
|
||||
list.insert(i)
|
||||
}
|
||||
|
||||
list.delete(5)
|
||||
list.mkString() should equal("012346789")
|
||||
}
|
||||
|
||||
it should "find value in skip list" in {
|
||||
val list = new SkipList()
|
||||
val length = 5000
|
||||
val array = new Array[Int](length)
|
||||
val rnd = new Random()
|
||||
for (i <- Range(0, length)) {
|
||||
array(i) = rnd.nextInt(length)
|
||||
list.insert(array(i))
|
||||
}
|
||||
|
||||
assert(list.find(array(rnd.nextInt(length - 1))).isDefined)
|
||||
assert(list.find(array(rnd.nextInt(length - 1)) + length + 1).isEmpty)
|
||||
|
||||
}
|
||||
}
|
22
scala/src/test/scala/ch20_linked_hash_map/LRUCacheTest.scala
Normal file
22
scala/src/test/scala/ch20_linked_hash_map/LRUCacheTest.scala
Normal file
@ -0,0 +1,22 @@
|
||||
package ch20_linked_hash_map
|
||||
|
||||
import org.scalatest.{FlatSpec, Matchers}
|
||||
|
||||
class LRUCacheTest extends FlatSpec with Matchers {
|
||||
|
||||
behavior of "LRUCacheTest"
|
||||
|
||||
it should "put data and get back" in {
|
||||
val cache = new LRUCache[Int, Int](2)
|
||||
cache.put(1, 1)
|
||||
cache.put(2, 2)
|
||||
cache.get(1) should equal(Some(1)) // returns 1
|
||||
cache.put(3, 3) // evicts key 2
|
||||
cache.get(2) should equal(None) //should not find
|
||||
cache.put(4, 4) // evicts key 1
|
||||
cache.get(1) should equal(None) //should not find
|
||||
cache.get(3) should equal(Some(3)) // returns 3
|
||||
cache.get(4) should equal(Some(4)) // returns 4
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user