import torch

import torch.nn.functional as F

from expground.types import DataArray, Dict, Any
from expground.logger import Log
from expground.utils import data
from expground.utils.data import EpisodeKeys
from expground.algorithms import misc
from expground.algorithms.loss_func import LossFunc


class DDPGLoss(LossFunc):
    """Loss computation and parameter update for an actor-critic (DDPG) policy.

    Calling an instance with a batch performs one critic update (TD
    regression towards a target-network estimate) followed by one actor
    update (maximising the critic's value of the policy's own actions), and
    returns scalar statistics about both.

    Gradients are not cleared inside ``__call__``.
    NOTE(review): callers are presumably expected to invoke ``zero_grad()``
    before each call -- confirm against the trainer.
    """

    def __init__(self, mute_critic_loss: bool = False):
        """Forward ``mute_critic_loss`` to the ``LossFunc`` initialiser.

        NOTE(review): ``__call__`` reads the flag back as ``self.mute_critic``;
        presumably the base class stores it under that name -- confirm.
        """
        super().__init__(mute_critic_loss)

    def zero_grad(self) -> None:
        """Clear the gradients held by every registered optimizer."""
        assert len(self.optimizers) > 0, "optimizers have not been set up"
        for optimizer in self.optimizers.values():
            optimizer.zero_grad()

    def step(self) -> None:
        """Do nothing: the optimizer steps are taken inside ``__call__``."""

    def setup_optimizers(self, *args, **kwargs) -> None:
        """Create the actor/critic optimizers, or re-bind the existing ones.

        When ``self.optimizers`` is ``None`` a new optimizer of class
        ``torch.optim.<self._params["optimizer"]>`` is built for the actor
        (learning rate ``self._params["actor_lr"]``) and for the critic
        (learning rate ``self._params["critic_lr"]``). Otherwise the existing
        optimizers are kept and their parameter groups are replaced with the
        policy's current actor/critic parameters.

        ``*args`` and ``**kwargs`` are accepted but not used.
        """
        if self.optimizers is None:
            optim_cls = getattr(torch.optim, self._params["optimizer"])
            # Use the same public accessors as the re-bind path and __call__
            # so that every code path optimises the same modules.
            self.optimizers = {
                "actor": optim_cls(
                    self.policy.actor.parameters(), lr=self._params["actor_lr"]
                ),
                "critic": optim_cls(
                    self.policy.critic.parameters(), lr=self._params["critic_lr"]
                ),
            }
            return

        networks = (("actor", self.policy.actor), ("critic", self.policy.critic))
        for name, network in networks:
            optimizer = self.optimizers[name]
            # Drop the old groups, then register the current parameters;
            # add_param_group fills options that are not given here (such as
            # the learning rate) from the optimizer's defaults.
            optimizer.param_groups = []
            optimizer.add_param_group({"params": network.parameters()})

    @data.tensor_cast(callback=lambda x: Log.debug(f"Training info: {x}"))
    def __call__(self, batch: Dict[str, DataArray]) -> Dict[str, Any]:
        """Run one critic update and one actor update on ``batch``.

        Args:
            batch: Mapping keyed by ``EpisodeKeys`` values. The entries for
                ``OBSERVATION``, ``NEXT_OBSERVATION``, ``ACTION_DIST``,
                ``REWARD`` and ``DONE`` are required; ``ACTION_MASK`` and
                ``"next_action_mask"`` are optional and default to ``None``.

        Returns:
            A dict of Python scalars: ``policy_loss``, ``value_loss`` and
            mean/max/min summaries of the target value, the critic's estimate
            and the reward. An empty dict is returned when
            ``self.mute_critic`` is set.
        """
        self.loss = []
        cliprange = self._params["grad_norm_clipping"]
        gamma = self._params["gamma"]

        if self.mute_critic:
            # NOTE(review): with the critic muted no loss is computed in this
            # call, so the actor optimizer steps on whatever gradients its
            # parameters already hold. Behaviour kept as-is -- confirm that
            # this is intended.
            torch.nn.utils.clip_grad_norm_(self.policy.actor.parameters(), cliprange)
            self.optimizers["actor"].step()
            return {}

        # ---------------- critic loss compute and step -------------------
        next_obs = batch[EpisodeKeys.NEXT_OBSERVATION.value]
        next_action, _, _ = self.policy.compute_actions(
            next_obs,
            use_target=True,
            action_mask=batch.get("next_action_mask", None),
        )
        target_vf_in = torch.cat([next_obs, next_action], dim=-1)
        next_value = self.policy.target_critic(target_vf_in).view(-1)

        # Flatten reward and done so they line up element-wise with the
        # flattened value estimates; a trailing singleton dimension would
        # otherwise broadcast the target to a (batch, batch) matrix.
        reward = batch[EpisodeKeys.REWARD.value].reshape(-1)
        done = batch[EpisodeKeys.DONE.value].reshape(-1)
        # One-step TD target; detached so that no gradient flows through the
        # target computation.
        target_value = (reward + gamma * next_value * (1.0 - done)).detach()

        cur_obs = batch[EpisodeKeys.OBSERVATION.value]
        taken_actions = batch[EpisodeKeys.ACTION_DIST.value]
        actual_value = self.policy.critic(
            torch.cat([cur_obs, taken_actions], dim=-1)
        ).view(-1)

        value_loss = F.mse_loss(actual_value, target_value)
        value_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy.critic.parameters(), cliprange)
        self.optimizers["critic"].step()

        # ------------------ actor loss compute and step --------------------
        actions, _, _ = self.policy.compute_actions(
            cur_obs, action_mask=batch.get(EpisodeKeys.ACTION_MASK.value, None)
        )
        # Maximise the critic's value of the policy's own actions. The
        # backward pass also accumulates gradients in the critic parameters,
        # but the critic optimizer has already stepped, so they are not
        # applied here.
        # TODO: consider an action-magnitude regulariser, e.g.
        # 0.001 * mean(actions ** 2).
        policy_loss = (
            -self.policy.critic(torch.cat([cur_obs, actions], dim=-1)).view(-1).mean()
        )
        policy_loss.backward()

        stats = {
            "policy_loss": policy_loss.detach().item(),
            "value_loss": value_loss.detach().item(),
            "target_value_est": target_value.mean().item(),
            "eval_value_est": actual_value.mean().item(),
            "target_value_max": target_value.max().item(),
            "eval_value_max": actual_value.max().item(),
            "target_value_min": target_value.min().item(),
            "eval_value_min": actual_value.min().item(),
            "reward_max": reward.max().item(),
            "reward_mean": reward.mean().item(),
            "reward_min": reward.min().item(),
        }

        torch.nn.utils.clip_grad_norm_(self.policy.actor.parameters(), cliprange)
        self.optimizers["actor"].step()

        return stats
