# Listing metadata: 95 lines, 3.5 KiB, Python
from typing import Any, Sequence, Set, Union as PyUnion
|
|
from hashlib import sha256
|
|
from remerkleable.basic import uint64
|
|
from remerkleable.byte_arrays import Bytes32
|
|
from remerkleable.core import Path
|
|
from remerkleable.tree import Gindex as GeneralizedIndex, gindex_bit_iter
|
|
|
|
# Alias for ``str``: the type of the named (non-integer) path elements
# accepted by ``get_generalized_index``.
SSZVariableName = str
|
|
|
|
class Root(Bytes32):
    """Bytes32 subclass used to label Merkle root values; adds no behavior."""
|
|
|
|
def floorlog2(x: int) -> uint64:
    """Return floor(log2(x)) wrapped in ``uint64``.

    Raises:
        ValueError: if ``x`` is smaller than 1.
    """
    if x < 1:
        raise ValueError(f"floorlog2 accepts only positive values, x={x}")
    # For a positive int, the index of the highest set bit is floor(log2(x)).
    highest_set_bit = x.bit_length() - 1
    return uint64(highest_set_bit)
|
|
|
|
def get_generalized_index(ssz_class: Any, *path: PyUnion[int, SSZVariableName]) -> GeneralizedIndex:
    """Resolve a path inside ``ssz_class`` to its generalized index.

    Args:
        ssz_class: The value handed to ``Path(...)`` as the starting point.
        *path: Path elements, applied one at a time and in order with the
            ``/`` operator. Each element is an ``int`` or an
            ``SSZVariableName`` (the annotation on ``*path`` describes a
            single element, not the whole tuple).

    Returns:
        ``GeneralizedIndex`` built from the final path's ``gindex()``.
    """
    ssz_path = Path(ssz_class)
    for item in path:
        ssz_path = ssz_path / item
    return GeneralizedIndex(ssz_path.gindex())
|
|
|
|
def build_proof(anchor, leaf_index):
    """Collect sibling merkle roots on the way from ``anchor`` to ``leaf_index``.

    Siblings are gathered while descending from ``anchor`` and then returned
    in reverse order, so the sibling closest to the target node comes first.
    An index of 1 or less yields an empty list (nothing to prove / invalid).
    """
    if leaf_index <= 1:
        return []  # Nothing to prove / invalid index

    siblings = []
    current = anchor
    # Walk down from the top; each bit selects the child to descend into.
    path_bits, _ = gindex_bit_iter(leaf_index)
    for go_right in path_bits:
        # The proof always takes the hand opposite to the one we descend into:
        # a set bit means the target is on the right, so record the left root.
        if go_right:
            sibling_root = current.get_left().merkle_root()
            next_node = current.get_right()
        else:
            sibling_root = current.get_right().merkle_root()
            next_node = current.get_left()
        siblings.append(sibling_root)
        current = next_node

    return siblings[::-1]
|
|
|
|
def hash(x: PyUnion[bytes, bytearray, memoryview]) -> Bytes32:
    """Return the SHA-256 digest of ``x`` wrapped as ``Bytes32``.

    Note: within this module the name shadows the builtin ``hash``.
    """
    digest = sha256(x).digest()
    return Bytes32(digest)
|
|
|
|
def generalized_index_parent(index: GeneralizedIndex) -> GeneralizedIndex:
    """Return the generalized index obtained by floor-halving ``index``."""
    halved = index // 2
    return GeneralizedIndex(halved)
|
|
|
|
def generalized_index_sibling(index: GeneralizedIndex) -> GeneralizedIndex:
    """Return ``index`` with its lowest bit flipped (2k <-> 2k + 1)."""
    flipped = index ^ 1
    return GeneralizedIndex(flipped)
|
|
|
|
def get_branch_indices(tree_index: GeneralizedIndex) -> Sequence[GeneralizedIndex]:
    """List the sibling indices met while climbing from ``tree_index``.

    Starts with the sibling of ``tree_index`` and repeatedly moves to the
    sibling of the parent. Collection stops at the first value that is not
    greater than 1, and that value is not included.
    """
    branch = []
    sibling = generalized_index_sibling(tree_index)
    while sibling > 1:
        branch.append(sibling)
        sibling = generalized_index_sibling(generalized_index_parent(sibling))
    return branch
|
|
|
|
def get_path_indices(tree_index: GeneralizedIndex) -> Sequence[GeneralizedIndex]:
    """List ``tree_index`` and its successive parents, excluding the end point.

    Climbs parent by parent while the current index is greater than 1; the
    first index that is not greater than 1 is left out of the result.
    """
    path = []
    current = tree_index
    while current > 1:
        path.append(current)
        current = generalized_index_parent(current)
    return path
|
|
|
|
def get_helper_indices(indices: Sequence[GeneralizedIndex]) -> Sequence[GeneralizedIndex]:
    """Return the branch indices of ``indices`` that lie on none of their paths.

    The result is a list sorted in descending order.
    """
    branch_nodes: Set[GeneralizedIndex] = set()
    path_nodes: Set[GeneralizedIndex] = set()
    for gindex in indices:
        branch_nodes.update(get_branch_indices(gindex))
        path_nodes.update(get_path_indices(gindex))

    # Anything already on a path is not needed as a helper.
    helpers = list(branch_nodes - path_nodes)
    helpers.sort(reverse=True)
    return helpers
|
|
|
|
def calculate_multi_merkle_root(leaves: Sequence[Bytes32],
                                proof: Sequence[Bytes32],
                                indices: Sequence[GeneralizedIndex],
                                helper_indices: Sequence[GeneralizedIndex]) -> Root:
    """Recompute the node at generalized index 1 from leaves and helper nodes.

    Args:
        leaves: Node values, paired positionally with ``indices``.
        proof: Helper node values, paired positionally with ``helper_indices``.
        indices: Generalized indices of ``leaves``.
        helper_indices: Generalized indices of ``proof``.

    Returns:
        The value stored under generalized index 1 once every derivable
        parent has been hashed.

    Raises:
        AssertionError: if the paired sequences differ in length.
        KeyError: if index 1 was neither supplied nor derivable.
    """
    assert len(leaves) == len(indices)
    assert len(proof) == len(helper_indices)

    # Known nodes keyed by generalized index; helper entries are written last,
    # so they win if an index appears in both inputs.
    nodes = {}
    for gindex, leaf in zip(indices, leaves):
        nodes[gindex] = Bytes32(leaf)
    for gindex, helper in zip(helper_indices, proof):
        nodes[gindex] = Bytes32(helper)

    # Process indices from highest to lowest; newly derived parents are queued
    # at the end so they get visited as well.
    worklist = sorted(nodes, reverse=True)
    cursor = 0
    while cursor < len(worklist):
        current = worklist[cursor]
        cursor += 1
        half = current // 2
        if current not in nodes or (current ^ 1) not in nodes or half in nodes:
            continue
        right_index = current | 1
        left_index = right_index ^ 1
        parent = GeneralizedIndex(half)
        # Parent value = hash(left child followed by right child).
        nodes[parent] = hash(
            nodes[GeneralizedIndex(left_index)] +
            nodes[GeneralizedIndex(right_index)]
        )
        worklist.append(parent)
    return nodes[GeneralizedIndex(1)]
|