from typing import Optional
import torch
import torch.nn as nn

from pado.core.memo import PadoMemo, MEMO_DTYPE

# Public API exported by `from <this module> import *`.
__all__ = ["PadoModule", "PadoModuleMixin"]


class PadoModuleMixin(object):
    """
    Mixin that adds extra functionality to nn.Module.
    * name (to distinguish modules) and memo (to memorize values).
    * freeze/unfreeze, parameter/buffer counting and more...

    PadoModuleMixin is separated from PadoModule for cases where a class must inherit
    another nn.Module subclass directly. In such cases, don't forget to call both
    nn.Module.__init__(self, ...) and PadoModuleMixin.__init__(self, ...).

    Every method below assumes ``self`` is also an ``nn.Module``.
    """

    def __init__(self):
        self._name: Optional[str] = None
        self._memo: Optional["PadoMemo"] = None

    @property
    def name(self) -> Optional[str]:
        """Name assigned via ``set_name`` (possibly called on a parent module); None until then."""
        return self._name

    @property
    def device(self) -> Optional[torch.device]:
        """Device of the first parameter, or of the first buffer if there are no parameters.

        Returns None when the module has neither parameters nor buffers.
        """
        self: nn.Module
        first = next(self.parameters(), None)
        if first is None:
            # Modules holding only buffers still live on a device.
            first = next(self.buffers(), None)
        return None if first is None else first.device

    def set_name(self, prefix: str = "") -> None:
        """Assign dotted names to this module, its submodules, parameters and buffers.

        Every module yielded by ``named_modules()`` gets ``_name`` set to ``prefix`` and
        its qualified name joined by ".". The root module gets ``prefix`` itself, and an
        empty prefix yields the plain qualified names (the root becomes ""). Parameters
        and buffers get their qualified names without the prefix.

        Args:
            prefix: Name prefix; a single trailing "." is optional ("model" == "model.").
        """
        self: nn.Module
        # Normalize "model." -> "model". Keep an empty prefix empty instead of turning it
        # into ".", which produced names like ".conv1".
        if prefix.endswith("."):
            prefix = prefix[:-1]
        for module_name, module in self.named_modules():
            module._name = ".".join(part for part in (prefix, module_name) if part)
        for param_name, param in self.named_parameters():
            param._name = param_name
        for buf_name, buf in self.named_buffers():
            buf._name = buf_name

    def num_parameters(self, recurse: bool = True, trainable_only: bool = True) -> int:
        """Count parameter elements.

        Args:
            recurse: Include parameters of submodules.
            trainable_only: Count only parameters with ``requires_grad=True`` (the
                default, so a frozen module reports 0). Pass False to count all.

        Returns:
            Total number of elements (``numel``) over the selected parameters.
        """
        self: nn.Module
        return sum(
            p.numel()
            for p in self.parameters(recurse=recurse)
            if not trainable_only or p.requires_grad
        )

    def num_buffers(self, recurse: bool = True) -> int:
        """Count buffer elements (``numel``), including submodules when ``recurse`` is True."""
        self: nn.Module
        return sum(b.numel() for b in self.buffers(recurse=recurse) if b is not None)

    def freeze(self, recurse: bool = True) -> None:
        """Disable gradients for parameters and switch to eval mode.

        Args:
            recurse: If True, freeze the parameters of all submodules too; otherwise
                freeze only the parameters registered directly on this module.

        NOTE(review): ``self.eval()`` switches this module and all submodules to eval
        mode regardless of ``recurse``.
        """
        self: nn.Module
        for p in self.parameters(recurse=recurse):
            p.requires_grad = False
        self.eval()

    def unfreeze(self, recurse: bool = True) -> None:
        """Enable gradients for parameters and switch to train mode.

        Args:
            recurse: If True, unfreeze the parameters of all submodules too; otherwise
                unfreeze only the parameters registered directly on this module.

        NOTE(review): ``self.train()`` switches this module and all submodules to train
        mode regardless of ``recurse``.
        """
        self: nn.Module
        for p in self.parameters(recurse=recurse):
            p.requires_grad = True
        self.train()

    def set_memo(self, memo: "PadoMemo") -> None:
        """Attach ``memo`` to this module and every submodule so they all share it."""
        self: nn.Module
        for module in self.modules():
            module._memo = memo

    def memorize(self, key: str, value: "MEMO_DTYPE") -> None:
        """Store ``value`` under ``key`` in the attached memo.

        Raises:
            ValueError: If no memo has been attached (see ``set_memo``).
        """
        if self._memo is None:
            raise ValueError(f"Module {self._name} called memorize({key}), but memo is None.")
        self._memo[key] = value

    @property
    def memo(self) -> Optional["PadoMemo"]:
        """The attached memo, or None if ``set_memo`` has not reached this module."""
        return self._memo


class PadoModule(nn.Module, PadoModuleMixin):
    """
    Base nn.Module for the Pado framework with the PadoModuleMixin features built in.

    Subclass this directly when possible. When a class has to derive from another
    nn.Module subclass instead, inherit that class together with PadoModuleMixin.
    """

    def __init__(self):
        # Initialize each base explicitly, nn.Module first, as PadoModuleMixin requires.
        for base in (nn.Module, PadoModuleMixin):
            base.__init__(self)

    def forward(self, *args, **kwargs):
        """Subclasses must override this to define the computation."""
        raise NotImplementedError
