# @file tree.007.py
# @ingroup experimental
# Recursive red & black tree.
# @date 01/05/2023
class Color:
    """Node color constants.

    DOUBLE_BLACK is a transient marker used during removal to tell a parent
    that a subtree lost one black level; it never survives a completed
    remove().
    """
    RED, BLACK, DOUBLE_BLACK = range(3)
class Node:
    """Binary tree node; ``node[0]`` / ``node[1]`` are its left/right child.

    The helper functions below take the node as an explicit first argument
    (which may be None) and are called through the class, e.g.
    ``Node.is_red(x)``.
    """

    def __init__(self, data):
        self.data = data
        # Newly created nodes start red.
        self.color = Color.RED
        self.child = [None, None]

    def __setitem__(self, i, v):
        self.child[i] = v

    def __getitem__(self, i):
        return self.child[i]

    def is_red(x):
        """Return a truthy value iff *x* is a node colored RED.

        A missing node (None) is returned as-is, i.e. falsy.
        """
        if not x:
            return x
        return x.color == Color.RED

    def is_double_black(x):
        """Return True if *x* is missing or colored DOUBLE_BLACK.

        Context dependent: a missing child only means "one black level
        short" where the caller has already excluded the terminal-node case.
        """
        if not x:
            return True
        return x.color == Color.DOUBLE_BLACK

    def try_setcolor(x, v):
        """Set the color of *x* to *v* if *x* exists.

        Returns True when a color was assigned, False when *x* is None.
        """
        if x:
            x.color = v
            return True
        return False

    def rotate(x, R):
        """Rotate the subtree rooted at *x* toward side *R*.

        The child on the opposite side, x[1 - R], becomes the new subtree
        root; its former side-*R* child is re-hung under *x*, and *x* becomes
        the new root's side-*R* child. Returns the new root.
        """
        pivot = x[1 - R]
        x[1 - R] = pivot[R]
        pivot[R] = x
        return pivot
class Tree:
    """Recursive red-black tree over unique, mutually comparable keys.

    Balancing keeps every black node with at most one red child (the
    levelorder() validator rejects a black node with two red children), so
    the tree mirrors a 2-3 tree whose red links may lean either way.
    Temporary 4-nodes formed during insertion are split immediately.

    The recursive helpers are static: each takes a subtree root and returns
    the (possibly different) root of that subtree.
    """

    def __init__(self):
        self.root = None

    ## Insert. ##
    def insert(self, key):
        """Insert *key* into the tree.

        Raises:
            KeyError: if *key* is already present; the tree is left unchanged.
        """
        self.root = Tree._insert(self.root, key)
        # The overall root is always black.
        self.root.color = Color.BLACK

    @staticmethod
    def _insert(root, key):
        """Insert *key* below *root*; return the rebalanced subtree root."""
        if not root:
            return Node(key)
        elif key < root.data:
            root[0] = Tree._insert(root[0], key)
        elif key > root.data:
            root[1] = Tree._insert(root[1], key)
        else:
            raise KeyError('Key already exists!')
        return Tree._insert_balance(root)

    @staticmethod
    def _insert_balance(root):
        """Repair red-red violations directly below a black *root*.

        A red *root* is returned untouched so that a black ancestor performs
        the repair higher up.
        """
        if Node.is_red(root):
            pass
        elif Node.is_red(root[0]):
            root = Tree._insert_balance_lower(root, 0)
        elif Node.is_red(root[1]):
            root = Tree._insert_balance_lower(root, 1)
        return root

    @staticmethod
    def _insert_balance_lower(root, R):
        """Rebalance black *root* whose side-*R* child is red.

        A red grandchild under root[R] is rotated into a black node with two
        red children; whenever *root* ends up with two red children, that
        4-node is split by a color flip that turns the subtree root red.
        """
        L = 1 - R
        if Node.is_red(root[R][L]):
            # 4-node normalize (1): rotate the inner red grandchild up so
            # the two reds form a straight line on side R.
            root[R] = Node.rotate(root[R], R)
        if Node.is_red(root[R][R]):
            # 4-node normalize (2): rotate the red child up and recolor,
            # giving a black node with two red children.
            root[R].color = Color.BLACK
            root.color = Color.RED
            root = Node.rotate(root, L)
        if Node.is_red(root[L]):
            # 4-node split: both children are red here; push the middle key
            # up by flipping colors.
            root[L].color = Color.BLACK
            root[R].color = Color.BLACK
            root.color = Color.RED
        return root

    ## Remove. ##
    def remove(self, key):
        """Remove *key* from the tree.

        Raises:
            KeyError: if *key* is not present; the tree is left unchanged.
        """
        self.root = Tree._remove(self.root, key)
        # Clears a DOUBLE_BLACK "short" signal at the top; no-op if empty.
        Node.try_setcolor(self.root, Color.BLACK)

    @staticmethod
    def _remove(root, key):
        """Remove *key* below *root*; return the rebalanced subtree root.

        A returned root colored DOUBLE_BLACK tells the caller that the
        subtree lost one black level.
        """
        if not root:
            raise KeyError('Key does not exist!')
        elif key < root.data:
            root[0] = Tree._remove(root[0], key)
        elif key > root.data:
            root[1] = Tree._remove(root[1], key)
        elif root[1]:
            # Interior node: copy the inorder successor's key into this node
            # and delete the successor's node instead.
            root[1] = Tree._delete_minimum(root[1], root)
        else:
            return Tree._delete_node(root)
        return Tree._remove_balance(root)

    @staticmethod
    def _delete_minimum(root, interior):
        """Delete the minimum node below *root*, moving its key to *interior*.

        Returns the rebalanced subtree root.
        """
        if root[0]:
            root[0] = Tree._delete_minimum(root[0], interior)
        else:
            interior.data = root.data
            return Tree._delete_node(root)
        return Tree._remove_balance(root)

    @staticmethod
    def _delete_node(root):
        """Unlink *root*, which has at most one child.

        A lone child is recolored black and returned in its place; a leaf is
        replaced by None.
        """
        if Node.try_setcolor(root[0], Color.BLACK):
            return root[0]
        if Node.try_setcolor(root[1], Color.BLACK):
            return root[1]
        return None

    @staticmethod
    def _remove_balance(root):
        """Repair *root* after one of its subtrees may have lost a black level.

        Returns the new subtree root, colored DOUBLE_BLACK when the whole
        subtree is now one black level short so that its parent repairs it.
        """
        is_short = False
        if not (root[0] or root[1]):
            # Root is terminal node.
            pass
        # NOTE: Node.is_double_black also treats a None child as short.
        # Given the tree invariants, a non-terminal node only has a None
        # child here after a black leaf was deleted from that side.
        elif Node.is_double_black(root[0]):
            Node.try_setcolor(root[0], Color.BLACK)
            root, is_short = Tree._remove_balance_lower(root, 0)
        elif Node.is_double_black(root[1]):
            Node.try_setcolor(root[1], Color.BLACK)
            root, is_short = Tree._remove_balance_lower(root, 1)
        if is_short:
            # Signal parent.
            root.color = Color.DOUBLE_BLACK
        return root

    @staticmethod
    def _remove_balance_lower(root, R):
        """Fix *root* whose side-*R* subtree is one black level short.

        Returns (new_root, is_short).
        """
        L = 1 - R
        if Node.is_red(root[L]):
            # 3-node parent; sibling is 'far' middle. Rotate the red sibling
            # up so the short side gets a red parent and a black sibling,
            # then repair that lower level. A red parent always absorbs the
            # deficit, so the result is never short.
            root[L].color = Color.BLACK
            root.color = Color.RED
            root = Node.rotate(root, R)
            root[R], _ = Tree._remove_balance_close(root[R], R, L)
            return root, False
        return Tree._remove_balance_close(root, R, L)

    @staticmethod
    def _remove_balance_close(root, R, L):
        """Fix *root* whose side-*R* subtree is short and whose sibling
        root[L] is black.

        Returns (new_root, is_short). is_short is True only when the deficit
        could not be absorbed locally and must propagate to the parent.
        """
        is_short = False
        if Node.is_red(root[L][R]):
            # Borrow sibling (1): rotate the inner red nephew up, which makes
            # the old sibling a red outer child.
            root[L][R].color = Color.BLACK
            root[L].color = Color.RED
            root[L] = Node.rotate(root[L], L)
        if Node.is_red(root[L][L]):
            # Borrow sibling (2): rotate the sibling into root's place; it
            # inherits root's color and both of its new children are black.
            root[L][L].color = Color.BLACK
            root[L].color = root.color
            root.color = Color.BLACK
            root = Node.rotate(root, R)
        else:
            # Borrow parent; shorten sibling. The subtree is still short
            # only if root was already black.
            is_short = not Node.is_red(root)
            root[L].color = Color.RED
            root.color = Color.BLACK
        return root, is_short

    ## Utility. ##
    def inorder(self):
        """Return a generator over the keys in ascending order."""
        return Tree._inorder(self.root)

    @staticmethod
    def _inorder(root):
        """Yield the keys of the subtree at *root* in ascending order."""
        if root:
            yield from Tree._inorder(root[0])
            yield root.data
            yield from Tree._inorder(root[1])

    def levelorder(self):
        """Yield the tree level by level as a 2-3 tree, validating invariants.

        Each yielded item is a list with one entry per 2-3 node on that
        level: ``[key]`` for a 2-node, ``[low, high]`` for a 3-node.

        Raises:
            AssertionError: 'Tree unbalanced!' if a level is only partially
                filled; 'Color violation!' if a node expected to be black is
                not (this includes a black node with two red children).
        """
        q = [self.root]
        while any(q):
            nodes, q, r = q, [], []
            for node in nodes:
                # Explicit raises rather than `assert`, so the validation
                # still runs under `python -O`.
                if not node:
                    raise AssertionError('Tree unbalanced!')
                if node.color != Color.BLACK:
                    raise AssertionError('Color violation!')
                if Node.is_red(node[0]):
                    # 3-node; left-leaning.
                    r.append([node[0].data, node.data])
                    q.extend((node[0][0], node[0][1], node[1]))
                elif Node.is_red(node[1]):
                    # 3-node; right-leaning.
                    r.append([node.data, node[1].data])
                    q.extend((node[0], node[1][0], node[1][1]))
                else:
                    # 2-node.
                    r.append([node.data])
                    q.extend((node[0], node[1]))
            yield r
### Testing. ###
import random
def pretty_print_tree(t):
    """Print *t* one level per line.

    The first line is prefixed with '>' (or shows '> None' for an empty
    tree); subsequent lines are indented to line up beneath it. Each level's
    node lists are concatenated via their str() forms.
    """
    rows = (''.join(str(group) for group in level) for level in t.levelorder())
    first = next(rows, None)
    print('>', first)
    for row in rows:
        print(' ', row)
# Smoke test: only runs when the file is executed directly, not on import.
if __name__ == '__main__':
    n = 2**3-1
    a = list(range(1, 1+n))
    t = Tree()
    b = random.sample(a, k=n)
    print('Insert:', b)
    for i in b:
        t.insert(i)
        # levelorder() raises on balance/color violations; drain it after
        # every step so a failure points at the operation that caused it.
        list(t.levelorder())
    pretty_print_tree(t)
    b = list(t.inorder())
    assert a == b, 'Bad order!'
    b = random.sample(a, k=n)
    print('Remove:', b)
    for i in b:
        t.remove(i)
        list(t.levelorder())
    pretty_print_tree(t)
    # Every key was removed, so nothing may be left behind.
    assert t.root is None, 'Tree not empty!'