from typing import Any, Dict, Optional
import logging
import torch
from pado.core import PadoModule
from pado.core.memo import PadoMemo

__all__ = ["PadoWrapper"]


class PadoWrapper(PadoModule):
    """
    Top-level wrapper around a network.

    Subclasses implement :meth:`forward`, which must return a dictionary
    mapping ``str`` keys to outputs (typically ``torch.Tensor``). Calling the
    wrapper clears the memo attached in ``__init__``, runs ``forward`` and
    promotes every scalar output (Python ``float``/``int`` or 0-dim tensor)
    to a 1-D tensor of shape ``(1,)`` so results can be handled uniformly.

    ``state_dict`` / ``load_state_dict`` delegate directly to the wrapped
    network, so checkpoint keys carry no ``network.`` prefix.
    """

    def __init__(self, network: PadoModule, logger: Optional[logging.Logger] = None) -> None:
        """
        Args:
            network: the module being wrapped; stored as ``self.network``.
            logger: logger to use; defaults to ``logging.getLogger("pado")``.
        """
        super().__init__()
        self.network = network
        # NOTE(review): set_name() comes from PadoModule; presumably assigns
        # names to the network's submodules — confirm.
        self.network.set_name()

        # Attach a fresh memo; it is cleared at the start of every __call__.
        memo = PadoMemo()
        self.set_memo(memo=memo)

        if logger is None:
            logger = logging.getLogger("pado")
        self.logger = logger

    def __call__(self, *args, **kwargs) -> Dict[str, Any]:
        """Run :meth:`forward` and promote scalar outputs to 1-D tensors.

        Returns:
            The dictionary returned by ``forward``, modified in place:

            * ``float`` -> ``torch.float32`` tensor of shape ``(1,)``
            * ``int`` -> ``torch.long`` tensor of shape ``(1,)``; note that
              ``bool`` is a subclass of ``int`` and is converted the same way
            * 0-dim ``torch.Tensor`` -> a view of shape ``(1,)``

            Any other value is left untouched. New tensors are created on
            ``self.network.device``.
        """
        self.memo.clear()
        results = self.forward(*args, **kwargs)

        # NOTE(review): `device` is presumably a PadoModule property — confirm.
        device = self.network.device
        for k, v in results.items():
            # Only existing keys are reassigned (dict size is unchanged), so
            # mutating while iterating is safe here.
            if isinstance(v, float):  # wrap to 1-D tensor
                results[k] = torch.tensor([v], dtype=torch.float32, device=device)
            elif isinstance(v, int):  # wrap to 1-D tensor (also catches bool)
                results[k] = torch.tensor([v], dtype=torch.long, device=device)
            elif isinstance(v, torch.Tensor) and (v.ndim == 0):
                results[k] = v.view(1)  # wrap to 1-D tensor
        return results

    def forward(self, *args, **kwargs) -> Dict[str, Any]:
        """Compute the model output.

        Must be overridden by subclasses. The result should be a dictionary
        of ``{str: torch.Tensor}``; Python scalars and 0-dim tensors are also
        accepted and are promoted to 1-D tensors by ``__call__``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def state_dict(self, destination=None, prefix: str = "", keep_vars: bool = False):
        """Return the wrapped network's state dict (no ``network.`` key prefix)."""
        # Pass `destination` by keyword: recent PyTorch versions deprecate
        # positional arguments to Module.state_dict and warn on every call.
        return self.network.state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )

    def load_state_dict(self, state_dict, strict: bool = True):
        """Load ``state_dict`` directly into the wrapped network.

        Keys are expected without a ``network.`` prefix, matching what
        :meth:`state_dict` produces. Returns whatever the network's
        ``load_state_dict`` returns.
        """
        return self.network.load_state_dict(state_dict, strict=strict)
