# DCIPs/assets/eip-6404/tests/proof_helpers.py
#
# 95 lines
# 3.5 KiB
# Python
# Raw Normal View History

from typing import Any, Sequence, Set, Union as PyUnion
from hashlib import sha256
from remerkleable.basic import uint64
from remerkleable.byte_arrays import Bytes32
from remerkleable.core import Path
from remerkleable.tree import Gindex as GeneralizedIndex, gindex_bit_iter
# Alias for a named path step (presumably an SSZ container field name);
# accepted alongside int steps in get_generalized_index's *path.
SSZVariableName = str
class Root(Bytes32):
    """A 32-byte Merkle root, kept as a distinct nominal subtype of ``Bytes32``."""
def floorlog2(x: int) -> uint64:
    """Return ``floor(log2(x))`` for a positive integer ``x`` as a ``uint64``.

    Raises:
        ValueError: if ``x`` is less than 1.
    """
    if x < 1:
        raise ValueError(f"floorlog2 accepts only positive values, x={x}")
    # For x >= 1 the highest set bit sits at position bit_length() - 1.
    highest_bit = x.bit_length() - 1
    return uint64(highest_bit)
def get_generalized_index(ssz_class: Any, *path: PyUnion[int, SSZVariableName]) -> GeneralizedIndex:
    """Return the generalized index of the node reached by walking ``path`` from ``ssz_class``.

    Args:
        ssz_class: the SSZ type the path is resolved against (wrapped in ``Path``).
        *path: successive path steps, each an ``int`` or an ``SSZVariableName``
            (``str``). Each one is applied with ``Path.__truediv__``. Presumably ints
            index list/vector elements and strings name container fields —
            confirm against remerkleable's ``Path``.

    Returns:
        The generalized index of the addressed node.
    """
    # NOTE: on a var-positional parameter the annotation types each element,
    # so it is the step type itself, not a Sequence of steps.
    ssz_path = Path(ssz_class)
    for item in path:
        ssz_path = ssz_path / item
    return GeneralizedIndex(ssz_path.gindex())
def build_proof(anchor, leaf_index):
    """Collect the Merkle branch proving the node at ``leaf_index`` under ``anchor``.

    Args:
        anchor: root node of the tree. It must provide ``get_left()``, ``get_right()``
            and ``merkle_root()`` on its inner nodes (the calls made below).
        leaf_index: generalized index of the node to prove.

    Returns:
        List of sibling roots ordered bottom-up: the leaf's sibling first, the
        root's child last. Returns an empty list when ``leaf_index <= 1``.
    """
    if leaf_index <= 1:
        return []  # Nothing to prove / invalid index
    siblings = []
    current = anchor
    # The bits walk the path from the root down to the leaf;
    # a set bit means the path continues into the right child.
    bits, _depth = gindex_bit_iter(leaf_index)
    for goes_right in bits:
        # The proof always records the child we do NOT descend into.
        if goes_right:
            off_path, current = current.get_left(), current.get_right()
        else:
            off_path, current = current.get_right(), current.get_left()
        siblings.append(off_path.merkle_root())
    # Siblings were gathered top-down; proofs are consumed bottom-up.
    siblings.reverse()
    return siblings
def hash(x: PyUnion[bytes, bytearray, memoryview]) -> Bytes32:
    """Return the SHA-256 digest of ``x`` wrapped as a ``Bytes32``.

    NOTE: this definition shadows the builtin ``hash`` for the rest of the
    module. Use ``builtins.hash`` if the builtin is ever needed here.
    """
    hasher = sha256()
    hasher.update(x)
    return Bytes32(hasher.digest())
def generalized_index_parent(index: GeneralizedIndex) -> GeneralizedIndex:
    """Return the generalized index one level above ``index`` (floor-divide by two)."""
    parent = index // 2
    return GeneralizedIndex(parent)
def generalized_index_sibling(index: GeneralizedIndex) -> GeneralizedIndex:
    """Return the generalized index that shares a parent with ``index``."""
    # Toggling the lowest bit swaps between the left (even) and right (odd) child.
    flipped = index ^ 1
    return GeneralizedIndex(flipped)
def get_branch_indices(tree_index: GeneralizedIndex) -> Sequence[GeneralizedIndex]:
    """Return the sibling indices along the path from ``tree_index`` up to the root.

    The list is ordered bottom-up and excludes anything at or above index 1.
    It is empty when ``tree_index <= 1``.
    """
    branch = []
    node = tree_index
    # Stop once the walk reaches the root (1); the root has no sibling.
    while node > 1:
        branch.append(generalized_index_sibling(node))
        node = generalized_index_parent(node)
    return branch
def get_path_indices(tree_index: GeneralizedIndex) -> Sequence[GeneralizedIndex]:
    """Return ``tree_index`` followed by its ancestors, bottom-up, excluding the root 1.

    The list is empty when ``tree_index <= 1``.
    """
    path = []
    node = tree_index
    while node > 1:
        path.append(node)
        node = generalized_index_parent(node)
    return path
def get_helper_indices(indices: Sequence[GeneralizedIndex]) -> Sequence[GeneralizedIndex]:
    """Return the extra node indices a multiproof for ``indices`` must supply.

    This is the union of the sibling ("branch") indices of every requested
    index, minus any index that lies on one of the requested paths. Those
    are the requested indices themselves and their ancestors below the root.

    Returns:
        The remaining indices as a list sorted in descending order.
    """
    all_helper_indices: Set[GeneralizedIndex] = set()
    all_path_indices: Set[GeneralizedIndex] = set()
    for index in indices:
        # In-place update: the previous union() copied both growing sets on
        # every iteration, making the loop quadratic in the number of indices.
        all_helper_indices.update(get_branch_indices(index))
        all_path_indices.update(get_path_indices(index))
    return sorted(all_helper_indices - all_path_indices, reverse=True)
def calculate_multi_merkle_root(leaves: Sequence[Bytes32],
                                proof: Sequence[Bytes32],
                                indices: Sequence[GeneralizedIndex],
                                helper_indices: Sequence[GeneralizedIndex]) -> Root:
    """Recompute the Merkle root from a multiproof.

    Sibling pairs whose parent is not yet known are hashed upward until
    nothing more can be combined. The node at generalized index 1 is then
    returned.

    Args:
        leaves: leaf chunks, one per entry of ``indices``.
        proof: helper chunks, one per entry of ``helper_indices``.
        indices: generalized indices of ``leaves``.
        helper_indices: generalized indices of ``proof``. Presumably produced by
            ``get_helper_indices(indices)`` — confirm against callers.

    Returns:
        The reconstructed root. NOTE: the value is a ``Bytes32`` (every stored
        node is built with ``Bytes32(...)`` or ``hash(...)``), not a ``Root``
        instance as the annotation suggests.

    Raises:
        AssertionError: if ``leaves``/``indices`` or ``proof``/``helper_indices``
            differ in length. The error is raised explicitly rather than with
            ``assert`` so the check still runs under ``python -O``; otherwise
            ``zip`` would silently truncate.
        KeyError: if the supplied nodes are not enough to reach index 1.
    """
    if len(leaves) != len(indices):
        raise AssertionError(
            f"leaves/indices length mismatch: {len(leaves)} != {len(indices)}")
    if len(proof) != len(helper_indices):
        raise AssertionError(
            f"proof/helper_indices length mismatch: {len(proof)} != {len(helper_indices)}")
    objects = {index: Bytes32(node) for index, node in zip(indices, leaves)}
    # Helper nodes are merged second, so they win on a duplicate index.
    objects.update({index: Bytes32(node) for index, node in zip(helper_indices, proof)})
    # Process deepest (largest) indices first; computed parents are appended
    # to the work list so they get a chance to combine further up.
    keys = sorted(objects.keys(), reverse=True)
    pos = 0
    while pos < len(keys):
        k = keys[pos]
        # Every k already lives in `objects`; combine only when the sibling is
        # known and the parent has not been computed or supplied yet.
        if k ^ 1 in objects and k // 2 not in objects:
            objects[GeneralizedIndex(k // 2)] = hash(
                objects[GeneralizedIndex((k | 1) ^ 1)] +  # left (even) child
                objects[GeneralizedIndex(k | 1)]          # right (odd) child
            )
            keys.append(GeneralizedIndex(k // 2))
        pos += 1
    return objects[GeneralizedIndex(1)]