from typing import Optional, Union
from numbers import Number
from collections import OrderedDict
import torch

# Public API of this module.
__all__ = ["PadoMemo", "MEMO_DTYPE"]

# Type alias for values PadoMemo accepts: a torch.Tensor, any numbers.Number
# (e.g. int, float), a str, or None.
MEMO_DTYPE = Optional[Union[torch.Tensor, Number, str]]


class PadoMemo(OrderedDict):
    """
    Memorize key(str) and value(torch.Tensor / scalar / string / None).
    * Used to gather intermediate values during forward propagation.
    * Will be cleared out for each forward propagation.
    * Does not support inter-GPU communication.

    Entries keep insertion order (OrderedDict semantics); assigning to an
    existing key overwrites its value without changing its position.
    """

    # Annotations are quoted so MEMO_DTYPE is resolved lazily rather than
    # evaluated at class-creation time.

    def __getitem__(self, key: str) -> "MEMO_DTYPE":
        """Return the value stored under ``key``.

        Raises:
            KeyError: if ``key`` has not been registered; the message names the key.
        """
        try:
            # Delegate to OrderedDict. Calling self.__getitem__ here would
            # re-enter this method and recurse until RecursionError.
            return super().__getitem__(key)
        except KeyError:
            # `from None` drops the redundant inner KeyError from the traceback.
            raise KeyError(f"Memo key {key} is not registered.") from None

    def __setitem__(self, key: str, value: "MEMO_DTYPE") -> None:
        """Store ``value`` under ``key``, overwriting any existing entry.

        Raises:
            ValueError: if ``value`` is not None and is not a torch.Tensor,
                a numbers.Number, or a str.
        """
        if (value is not None) and (not isinstance(value, (torch.Tensor, Number, str))):
            raise ValueError(f"Memo key {key} got unsupported value type: {type(value)}.")
        # Delegate the actual storage to OrderedDict (self.__setitem__ would recurse).
        super().__setitem__(key, value)
