import torch

from abc import ABCMeta, abstractmethod

from expground.types import Dict, Any, Sequence
from expground.logger import Log
from expground.algorithms.base_policy import Policy
from expground.utils import data


class LossFunc(metaclass=ABCMeta):
    """Base class bundling a loss computation with the optimizers that apply it.

    Flowchart:
        1. create an instance of a concrete subclass: loss = SomeLossFunc(...)
        2. bind a policy and hyper-parameters: loss.reset(policy, configs)
           (this resets the policy, then calls ``setup_optimizers()`` and
           ``setup_extras()``)
        3. zero grads: loss.zero_grad()
        4. calculate loss and get returned statistics: statistics = loss(batch)
        5. do optimization (step): loss.step()

    **NOTE**: to compute the loss for a different policy, call
    ``loss.reset(new_policy, configs)`` again; this rebuilds the optimizers.
    """

    def __init__(self, mute_critic_loss: bool = False):
        """Create an instance of loss function.

        Args:
            mute_critic_loss (bool): Whether this loss should mute critic updates
                (exposed through the ``mute_critic`` property).
        """

        self._policy = None
        self._mute_critic_loss = mute_critic_loss
        # Expected to be populated by ``setup_optimizers``: a single
        # torch.optim.Optimizer, a sequence of them, or a dict of them.
        self.optimizers = None
        self.loss = []
        # Hyper-parameters; ``reset`` merges new configs into this dict.
        self._params = {}
        # Gradients accumulated via ``push_gradients``; cleared by ``zero_grad``.
        self._gradients = []

    @property
    def mute_critic(self) -> bool:
        """Return whether this loss mutes the update of the critic."""

        return self._mute_critic_loss

    @property
    def stacked_gradients(self):
        """Return the gradients pushed since construction or the last ``zero_grad``.

        The internal list is returned directly, not a copy.
        """

        return self._gradients

    def push_gradients(self, grad):
        """Append ``grad`` to the stacked gradients."""

        self._gradients.append(grad)

    @property
    def optim_cls(self) -> type:
        """Return the optimizer class named by the ``"optimizer"`` param.

        Falls back to ``torch.optim.Adam`` when the param is not specified.

        Raises:
            AttributeError: if the configured name is not an attribute of
                ``torch.optim``.
        """

        return getattr(torch.optim, self._params.get("optimizer", "Adam"))

    @property
    def policy(self):
        """Return the policy bound by the most recent ``reset`` (None before that)."""
        return self._policy

    @abstractmethod
    def setup_optimizers(self, *args, **kwargs):
        """Set up optimizers (and loss-related state) for the current policy.

        ``reset`` calls this without arguments after binding a new policy.
        """

    def setup_extras(self):
        """Hook for extra setup that ``reset`` runs after ``setup_optimizers``.

        No-op by default; override when other operators depend on the policy.
        """

    @abstractmethod
    # NOTE(review): ``data.tensor_cast`` is defined elsewhere; presumably it
    # converts call arguments to tensors — confirm against expground.utils.data.
    @data.tensor_cast
    def __call__(self, *args, **kwargs) -> Dict[str, Any]:
        """Compute the loss and return statistics, without optimizing."""

    @abstractmethod
    def step(self) -> Any:
        """Step gradients and apply any other update operators."""

    def zero_grad(self):
        """Clear stacked gradients and zero the gradients of every optimizer.

        Raises:
            TypeError: if ``self.optimizers`` is not a Sequence, a Dict, or a
                ``torch.optim.Optimizer`` (e.g. it is still None before ``reset``).
        """

        self._gradients = []
        # Normalize the three supported container shapes to one iterable.
        if isinstance(self.optimizers, Sequence):
            optimizers = self.optimizers
        elif isinstance(self.optimizers, Dict):
            optimizers = self.optimizers.values()
        elif isinstance(self.optimizers, torch.optim.Optimizer):
            optimizers = (self.optimizers,)
        else:
            raise TypeError(
                f"Unexpected optimizers type: {type(self.optimizers)}, expected are included: Sequence, Dict, and torch.optim.Optimizer"
            )
        for optimizer in optimizers:
            optimizer.zero_grad()

    def reset(self, policy: Policy, configs: Dict[str, Any]):
        """Bind a policy and merge hyper-parameters for loss computation.

        ``configs`` is merged into the existing params, so values accumulate
        across calls. When ``policy`` is not None, it is reset and bound to this
        loss, and the optimizers, ``self.loss`` and extras are rebuilt for it.

        Args:
            policy (Policy): A policy instance, or None to only update configs.
            configs (Dict[str, Any]): The hyper-parameters for loss computing;
                a falsy value (e.g. None) is treated as empty.
        """

        self._params.update(configs or {})
        Log.debug(f"reset loss configs: {self._params}")
        if policy is not None:
            policy.reset()  # reset the policy's parameters before binding it
            self._policy = policy
            # Drop state tied to the previous policy before rebuilding it.
            self.optimizers = None
            self.loss = []
            self.setup_optimizers()
            self.setup_extras()

    def _output_nan(self, check_table):
        """Debug helper: print each entry of ``check_table`` that contains a NaN.

        Args:
            check_table: Mapping from a name to a tensor.
        """
        for k in check_table:
            value = check_table[k]
            # ``any`` (unlike ``max``) is also defined for empty tensors.
            if torch.isnan(value).any():
                print("{} ".format(k), value)
        print()
