fork download
  1. # @file tree.007.py
  2. # @ingroup experimental
  3. # Recursive red & black tree.
  4. # @date 01/05/2023
  5.  
  6. class Color:
  7. RED = 0
  8. BLACK = 1
  9. DOUBLE_BLACK = 2
  10.  
  11. class Node:
  12. def __init__(self, data):
  13. self.child = [None] * 2
  14. self.color = Color.RED
  15. self.data = data
  16.  
  17. def __setitem__(self, i, v):
  18. self.child[i] = v
  19.  
  20. def __getitem__(self, i):
  21. return self.child[i]
  22.  
  23. def is_red(x):
  24. return x and x.color == Color.RED
  25.  
  26. def is_double_black(x):
  27. # Context dependent.
  28. return not x or x.color == Color.DOUBLE_BLACK
  29.  
  30. def try_setcolor(x, v):
  31. if not x:
  32. return False
  33. x.color = v
  34. return True
  35.  
  36. def rotate(x, R):
  37. L = 1-R
  38. y = x[L]
  39. x[L] = y[R]
  40. y[R] = x
  41. return y
  42.  
  43. class Tree:
  44. def __init__(self):
  45. self.root = None
  46.  
  47. ## Insert. ##
  48.  
  49. def insert(self, key):
  50. self.root = Tree._insert(self.root, key)
  51. self.root.color = Color.BLACK
  52.  
  53. def _insert(root, key):
  54. if not root:
  55. return Node(key)
  56. elif key < root.data:
  57. root[0] = Tree._insert(root[0], key)
  58. elif key > root.data:
  59. root[1] = Tree._insert(root[1], key)
  60. else:
  61. raise KeyError('Key already exists!')
  62. return Tree._insert_balance(root)
  63.  
  64. def _insert_balance(root):
  65. if Node.is_red(root):
  66. pass
  67. elif Node.is_red(root[0]):
  68. root = Tree._insert_balance_lower(root, 0)
  69. elif Node.is_red(root[1]):
  70. root = Tree._insert_balance_lower(root, 1)
  71. return root
  72.  
  73. def _insert_balance_lower(root, R):
  74. L = 1-R
  75. if Node.is_red(root[R][L]):
  76. # 4-node normalize (1).
  77. root[R] = Node.rotate(root[R], R)
  78. if Node.is_red(root[R][R]):
  79. # 4-node normalize (2).
  80. root[R].color = Color.BLACK
  81. root.color = Color.RED
  82. root = Node.rotate(root, L)
  83. if Node.is_red(root[L]):
  84. # 4-node split.
  85. root[L].color = Color.BLACK
  86. root[R].color = Color.BLACK
  87. root.color = Color.RED
  88. return root
  89.  
  90. ## Remove. ##
  91.  
  92. def remove(self, key):
  93. self.root = Tree._remove(self.root, key)
  94. Node.try_setcolor(self.root, Color.BLACK)
  95.  
  96. def _remove(root, key):
  97. if not root:
  98. raise KeyError('Key does not exist!')
  99. elif key < root.data:
  100. root[0] = Tree._remove(root[0], key)
  101. elif key > root.data:
  102. root[1] = Tree._remove(root[1], key)
  103. elif root[1]:
  104. # Interior node; delete inorder successor.
  105. root[1] = Tree._delete_minimum(root[1], root)
  106. else:
  107. return Tree._delete_node(root)
  108. return Tree._remove_balance(root)
  109.  
  110. def _delete_minimum(root, interior):
  111. if root[0]:
  112. root[0] = Tree._delete_minimum(root[0], interior)
  113. else:
  114. interior.data = root.data
  115. return Tree._delete_node(root)
  116. return Tree._remove_balance(root)
  117.  
  118. def _delete_node(root):
  119. if Node.try_setcolor(root[0], Color.BLACK):
  120. return root[0]
  121. if Node.try_setcolor(root[1], Color.BLACK):
  122. return root[1]
  123. return None
  124.  
  125. def _remove_balance(root):
  126. is_short = False
  127. if not (root[0] or root[1]):
  128. # Root is terminal node.
  129. pass
  130. elif Node.is_double_black(root[0]):
  131. Node.try_setcolor(root[0], Color.BLACK)
  132. root, is_short = Tree._remove_balance_lower(root, 0)
  133. elif Node.is_double_black(root[1]):
  134. Node.try_setcolor(root[1], Color.BLACK)
  135. root, is_short = Tree._remove_balance_lower(root, 1)
  136. if is_short:
  137. # Signal parent.
  138. root.color = Color.DOUBLE_BLACK
  139. return root
  140.  
  141. def _remove_balance_lower(root, R):
  142. L = 1-R
  143. if Node.is_red(root[L]):
  144. # 3-node parent; sibling is 'far' middle.
  145. root[L].color = Color.BLACK
  146. root.color = Color.RED
  147. root = Node.rotate(root, R)
  148. root[R], _ = Tree._remove_balance_close(root[R], R, L)
  149. return root, False
  150. return Tree._remove_balance_close(root, R, L)
  151.  
  152. def _remove_balance_close(root, R, L):
  153. is_short = False
  154. if Node.is_red(root[L][R]):
  155. # Borrow sibling (1).
  156. root[L][R].color = Color.BLACK
  157. root[L].color = Color.RED
  158. root[L] = Node.rotate(root[L], L)
  159. if Node.is_red(root[L][L]):
  160. # Borrow sibling (2).
  161. root[L][L].color = Color.BLACK
  162. root[L].color = root.color
  163. root.color = Color.BLACK
  164. root = Node.rotate(root, R)
  165. else:
  166. # Borrow parent; shorten sibling.
  167. is_short = not Node.is_red(root)
  168. root[L].color = Color.RED
  169. root.color = Color.BLACK
  170. return root, is_short
  171.  
  172. ## Utility. ##
  173.  
  174. def inorder(self):
  175. return Tree._inorder(self.root)
  176.  
  177. def _inorder(root):
  178. if root:
  179. yield from Tree._inorder(root[0])
  180. yield root.data
  181. yield from Tree._inorder(root[1])
  182.  
  183. def levelorder(self):
  184. q = [self.root]
  185. while q.count(None) != len(q):
  186. nodes = q
  187. q = []
  188. r = []
  189. for node in nodes:
  190. assert node, 'Tree unbalanced!'
  191. assert node.color == Color.BLACK, 'Color violation!'
  192. if Node.is_red(node[0]):
  193. # 3-node; left-leaning.
  194. r.append([node[0].data, node.data])
  195. q.extend((node[0][0], node[0][1], node[1]))
  196. elif Node.is_red(node[1]):
  197. # 3-node; right-leaning.
  198. r.append([node.data, node[1].data])
  199. q.extend((node[0], node[1][0], node[1][1]))
  200. else:
  201. # 2-node.
  202. r.append([node.data])
  203. q.extend((node[0], node[1]))
  204. yield r
  205.  
  206. ### Testing. ###
  207.  
  208. import random
  209.  
  210. def pretty_print_tree(t):
  211. lines = map(lambda x: ''.join(map(str, x)), t.levelorder())
  212. print('>', next(lines, None))
  213. for line in lines:
  214. print(' ', line)
  215.  
  216. n = 2**3-1
  217. a = list(range(1, 1+n))
  218. t = Tree()
  219.  
  220. b = random.sample(a, k=n)
  221. print('Insert:', b)
  222. for i in b:
  223. t.insert(i)
  224. pretty_print_tree(t)
  225.  
  226. b = list(t.inorder())
  227. assert a == b, 'Bad order!'
  228.  
  229. b = random.sample(a, k=n)
  230. print('Remove:', b)
  231. for i in b:
  232. t.remove(i)
  233. pretty_print_tree(t)
Success #stdin #stdout 0.11s 14340KB
stdin
Standard input is empty
stdout
Insert: [4, 5, 2, 6, 7, 3, 1]
> [4]
> [4, 5]
> [4]
  [2][5]
> [4]
  [2][5, 6]
> [4, 6]
  [2][5][7]
> [4, 6]
  [2, 3][5][7]
> [4]
  [2][6]
  [1][3][5][7]
Remove: [4, 7, 6, 3, 5, 2, 1]
> [2, 5]
  [1][3][6, 7]
> [2, 5]
  [1][3][6]
> [2]
  [1][3, 5]
> [2]
  [1][5]
> [1, 2]
> [1]
> None