from typing import Any, BinaryIO, Optional as PyOptional, TypeVar, Type, cast

from remerkleable.basic import uint256
from remerkleable.complex import MonoSubtreeView
from remerkleable.core import BasicView, View, ViewHook
from remerkleable.tree import (
    RIGHT_GINDEX,
    Gindex,
    Node,
    PairNode,
    get_depth,
    subtree_fill_to_contents,
    to_gindex,
    zero_node,
)

T = TypeVar('T', bound="Optional")


class Optional(MonoSubtreeView):
    """A view holding either nothing (``None``) or exactly one element.

    The backing tree is a pair node: the left child is the contents subtree
    (room for a single element, since ``limit()`` is 1) and the right child is
    a ``uint256`` length mix-in, 0 when empty and 1 when a value is present.
    Concrete types are built with ``Optional[ElementType]``.
    """
    __slots__ = ()

    def __new__(cls, value: PyOptional[Any] = None, backing: PyOptional[Node] = None,
                hook: PyOptional[ViewHook] = None, **kwargs):
        """Create an Optional view from either an existing ``backing`` or a ``value``.

        :param value: the element to store (a View, or anything the element
            type's ``coerce_view`` accepts); ``None`` creates an empty Optional.
        :param backing: an existing backing node to wrap as-is.
        :param hook: view hook forwarded to the base view.
        :raises Exception: if both ``backing`` and ``value`` are given.
        """
        if backing is not None:
            if value is not None:
                raise Exception("cannot have both a backing and a value to init Optional")
            return super().__new__(cls, backing=backing, hook=hook, **kwargs)

        elem_cls = cls.element_cls()
        assert cls.limit() == 1
        input_views = []
        if value is not None:
            if isinstance(value, View):
                input_views.append(value)
            else:
                input_views.append(elem_cls.coerce_view(value))
        input_nodes = cls.views_into_chunks(input_views)
        contents = subtree_fill_to_contents(input_nodes, cls.contents_depth())
        # Length mix-in: 0 for None, 1 when a value was supplied.
        backing = PairNode(contents, uint256(len(input_views)).get_backing())
        return super().__new__(cls, backing=backing, hook=hook, **kwargs)

    def __class_getitem__(cls, element_type) -> Type["Optional"]:
        """Build the concrete ``Optional[element_type]`` view class.

        :raises Exception: if ``element_type`` can serialize to zero bytes.
            ``deserialize`` treats an empty scope as ``None``, so such an
            element would be indistinguishable from an absent value.
        """
        if element_type.min_byte_length() == 0:
            raise Exception(f"Invalid option type: {element_type}")

        limit = 1
        contents_depth = get_depth(limit)
        # element_type is a class, so it must be tested with issubclass;
        # isinstance(<class>, BasicView) is always False.
        packed = isinstance(element_type, type) and issubclass(element_type, BasicView)

        class SpecialOptionView(Optional):
            @classmethod
            def is_packed(cls) -> bool:
                return packed

            @classmethod
            def contents_depth(cls) -> int:
                return contents_depth

            @classmethod
            def element_cls(cls) -> Type[View]:
                return element_type

            @classmethod
            def limit(cls) -> int:
                return limit

        SpecialOptionView.__name__ = SpecialOptionView.type_repr()
        return SpecialOptionView

    def length(self) -> int:
        """Return the length mix-in stored in the right child (0 = empty, 1 = set)."""
        ll_node = super().get_backing().get_right()
        ll = cast(uint256, uint256.view_from_backing(node=ll_node, hook=None))
        return int(ll)

    def value_byte_length(self) -> int:
        """Return the serialized byte length of the contained value, or 0 when empty."""
        value = self.get()
        if value is None:
            return 0
        elem_cls = self.__class__.element_cls()
        if elem_cls.is_fixed_byte_length():
            return elem_cls.type_byte_length()
        return value.value_byte_length()

    def get(self) -> PyOptional[View]:
        """Return the contained element view, or ``None`` when empty."""
        if self.length() == 0:
            return None
        return super().get(0)

    def set(self, v: PyOptional[View]) -> None:
        """Store ``v`` (coerced to the element type if needed); ``None`` clears the value."""
        if v is None:
            if self.length() == 0:
                return
            # Empty state == zeroed contents subtree + length mix-in of 0,
            # which is exactly the type's default node.
            self.set_backing(self.__class__.default_node())
            return

        elem_type: Type[View] = self.__class__.element_cls()
        if not isinstance(v, elem_type):
            v = elem_type.coerce_view(v)

        if self.length() == 1:
            super().set(0, v)
            return

        # Insert the single element into the (possibly zeroed) contents
        # subtree, then bump the length mix-in from 0 to 1.
        target = to_gindex(0, self.__class__.tree_depth())
        set_elem = self.get_backing().setter(target, expand=True)
        next_backing = set_elem(v.get_backing())
        next_backing = next_backing.rebind_right(uint256(1).get_backing())
        self.set_backing(next_backing)

    def __repr__(self):
        value = self.get()
        if value is None:
            return f"{self.type_repr()}(None)"
        return f"{self.type_repr()}(Some({repr(value)}))"

    @classmethod
    def type_repr(cls) -> str:
        return f"Optional[{cls.element_cls().__name__}]"

    @classmethod
    def is_packed(cls) -> bool:
        # Provided by the concrete class built in __class_getitem__.
        raise NotImplementedError

    @classmethod
    def contents_depth(cls) -> int:
        # Provided by the concrete class built in __class_getitem__.
        raise NotImplementedError

    @classmethod
    def tree_depth(cls) -> int:
        return cls.contents_depth() + 1  # 1 extra for length mix-in

    @classmethod
    def limit(cls) -> int:
        # Provided by the concrete class built in __class_getitem__.
        raise NotImplementedError

    @classmethod
    def deserialize(cls: Type[T], stream: BinaryIO, scope: int) -> T:
        """Read an Optional from ``stream``: an empty scope means ``None``."""
        if scope == 0:
            return cls()
        return cls(cls.element_cls().deserialize(stream, scope))

    def serialize(self, stream: BinaryIO) -> int:
        """Write the contained value (nothing for ``None``); return bytes written."""
        v = self.get()
        if v is None:
            return 0
        return v.serialize(stream)

    @classmethod
    def navigate_type(cls, key: Any) -> Type[View]:
        if key >= cls.limit():
            raise KeyError
        return super().navigate_type(key)

    @classmethod
    def key_to_static_gindex(cls, key: Any) -> Gindex:
        """Map a key to its generalized index; ``'__is_some__'`` is the length mix-in."""
        if key == '__is_some__':
            return RIGHT_GINDEX
        if key >= cls.limit():
            raise KeyError
        return super().key_to_static_gindex(key)

    @classmethod
    def default_node(cls) -> Node:
        return PairNode(zero_node(cls.contents_depth()), zero_node(0))  # mix-in 0 as list length

    @classmethod
    def is_fixed_byte_length(cls) -> bool:
        return False

    @classmethod
    def min_byte_length(cls) -> int:
        # None serializes to zero bytes.
        return 0

    @classmethod
    def max_byte_length(cls) -> int:
        return cls.element_cls().max_byte_length()