"""
Loss function for SAC
Reference: https://github.com/haarnoja/sac/blob/master/sac/algos/sac.py

Ref 1: https://github.com/ray-project/ray/blob/master/rllib/agents/sac/sac_torch_policy.py
Ref 2: https://github.com/rail-berkeley/rlkit/blob/master/rlkit/torch/sac/sac.py
"""

from os import stat
from numpy import testing
from expground.algorithms.sac.policy import SAC
import torch

import torch.nn.functional as F

from expground.types import DataArray, Dict, Any, Tensor, Tuple
from expground.algorithms.sac.config import DEFAULT_CONFIG
from expground.logger import Log
from expground.utils import data
from expground.utils.data import EpisodeKeys
from expground.algorithms import misc
from expground.algorithms.loss_func import LossFunc


class SACLoss(LossFunc):
    """Soft Actor-Critic (SAC) loss with twin critics.

    Calling an instance on a training batch performs, in order: one update of
    both critics (policy evaluation), an optional temperature (``alpha``)
    update when ``auto_alpha`` is enabled, and one actor update (policy
    improvement). Discrete and continuous action spaces are both supported;
    the branch is selected by ``policy._is_discrete_action``.
    """

    def __init__(self, mute_critic_loss: bool = False):
        super(SACLoss, self).__init__(mute_critic_loss=mute_critic_loss)
        # Overlay the SAC training defaults from DEFAULT_CONFIG onto the
        # parameters initialised by the base class.
        self._params.update(DEFAULT_CONFIG["training_config"])

    def zero_grad(self):
        """Clear the gradients held by every registered optimizer."""
        assert len(self.optimizers) > 0
        for optimizer in self.optimizers.values():
            optimizer.zero_grad()

    def step(self) -> Any:
        """No-op: all optimizer steps are taken inside ``_update_ops``."""

    def setup_extras(self):
        """SAC needs no extra state beyond its optimizers."""

    def setup_optimizers(self, *args, **kwargs):
        """Create one optimizer per trainable component of the policy.

        The optimizer class is looked up on ``torch.optim`` by the name stored
        in ``self._params["optimizer"]``. An ``"alpha"`` optimizer for the
        log-temperature module is added only when ``auto_alpha`` is enabled.
        Positional/keyword arguments are accepted for interface compatibility
        and ignored.
        """
        optim_cls = getattr(torch.optim, self._params["optimizer"])
        self.optimizers = {
            "actor": optim_cls(
                self.policy._actor.parameters(), lr=self._params["actor_lr"]
            ),
            "critic1": optim_cls(
                self.policy._critic_1.parameters(), lr=self._params["critic_lr"]
            ),
            "critic2": optim_cls(
                self.policy._critic_2.parameters(), lr=self._params["critic_lr"]
            ),
        }

        if self._params.get("auto_alpha", False):
            self.optimizers["alpha"] = optim_cls(
                self.policy._log_alpha.parameters(), lr=self._params["alpha_lr"]
            )

    def _one_hot_actions(self, actions: Tensor) -> Tensor:
        """Encode integer actions as float32 one-hot rows on the policy device.

        The number of rows follows ``actions`` itself rather than the
        configured ``batch_size``, so batches of any length are handled.
        """
        index = torch.reshape(actions, (-1,)).type(torch.int64)
        one_hot = F.one_hot(index, num_classes=self.policy._action_space.n)
        return one_hot.to(self.policy.device, dtype=torch.float32)

    def _compute_target_q(self, batch: Dict[str, Tensor], alpha: Tensor) -> Tensor:
        """Return the soft Bellman target ``r + gamma * (1 - done) * V(s')``.

        ``V(s')`` is the minimum of the two target critics minus the entropy
        term ``alpha * log_pi``. For discrete actions it is the expectation
        under the current policy's action probabilities. No gradient is
        tracked.
        """
        next_state = batch[EpisodeKeys.NEXT_OBSERVATION.value]
        done = batch[EpisodeKeys.DONE.value]
        reward = batch[EpisodeKeys.REWARD.value]
        is_discrete = self.policy._is_discrete_action

        with torch.no_grad():
            # NOTE(review): the mask stored under ACTION_MASK is applied to the
            # *next* observation here — confirm it belongs to next_state.
            next_action, next_log_pi = self.policy.compute_actions(
                next_state,
                action_mask=batch.get(EpisodeKeys.ACTION_MASK.value, None),
                with_log_pis=True,
                explore=True,
            )

            if is_discrete:
                # Discrete critics take only the state and score every action.
                nxt_critic_state = next_state
            else:
                nxt_critic_state = torch.cat([next_state, next_action], dim=-1)

            q1_next_target, q2_next_target = self.policy.target_value_function(
                nxt_critic_state
            )
            min_q_next_target = torch.min(q1_next_target, q2_next_target)

            if is_discrete:
                # Exact expectation over actions, weighted by pi(a|s').
                next_value = (
                    (min_q_next_target - alpha * next_log_pi) * next_log_pi.exp()
                ).sum(dim=-1)
            else:
                # squeeze(-1) (not squeeze()) keeps the batch dim when B == 1.
                next_value = (
                    min_q_next_target - alpha * next_log_pi.unsqueeze(-1)
                ).squeeze(-1)

            return (
                reward + (1.0 - done) * self._params["gamma"] * next_value
            ).unsqueeze(-1)

    def _update_ops(
        self, cur_critic_state: Tensor, batch: Dict[str, Tensor]
    ) -> Dict[str, float]:
        """Run one SAC optimisation step and return scalar diagnostics.

        Update order: both critics, then (if ``auto_alpha``) the temperature,
        then the actor.

        Args:
            cur_critic_state (Tensor): The batched critic inputs — observations
                for discrete actions, observation/action concatenation otherwise.
            batch (Dict[str, Tensor]): A dict of training batches.

        Returns:
            Dict[str, float]: Losses and Q/reward statistics for logging.
        """
        assert isinstance(self.policy, SAC)
        is_discrete = self.policy._is_discrete_action
        auto_alpha = self._params.get("auto_alpha", False)
        # The temperature is a constant inside both the critic and actor losses.
        alpha = self.policy._log_alpha().exp().detach()

        # ------------------------- critic update -------------------------
        q1 = self.policy._critic_1(cur_critic_state)
        q2 = self.policy._critic_2(cur_critic_state)
        reward = batch[EpisodeKeys.REWARD.value]
        target_q = self._compute_target_q(batch, alpha)

        if is_discrete:
            # Discrete critics output Q for every action; select the taken one.
            taken = self._one_hot_actions(batch[EpisodeKeys.ACTION.value])
            td_loss1 = F.mse_loss((q1 * taken).sum(dim=-1, keepdim=True), target_q)
            td_loss2 = F.mse_loss((q2 * taken).sum(dim=-1, keepdim=True), target_q)
        else:
            assert target_q.shape == q1.shape == q2.shape, (
                target_q.shape,
                q1.shape,
                q2.shape,
                reward.shape,
                batch[EpisodeKeys.DONE.value].shape,
            )
            td_loss1 = F.mse_loss(q1, target_q)
            td_loss2 = F.mse_loss(q2, target_q)
        td_loss = td_loss1 + td_loss2

        self.optimizers["critic1"].zero_grad()
        self.optimizers["critic2"].zero_grad()
        td_loss.backward()
        self.optimizers["critic1"].step()
        self.optimizers["critic2"].step()

        diagnos = {
            "td_loss1": td_loss1.detach().item(),
            "td_loss2": td_loss2.detach().item(),
            "td_loss": td_loss.detach().item(),
            "target_q_min": target_q.min().item(),
            "target_q_mean": target_q.mean().item(),
            "target_q_max": target_q.max().item(),
            "reward_min": reward.min().item(),
            "reward_mean": reward.mean().item(),
            "reward_max": reward.max().item(),
            "q1_min": q1.min().item(),
            "q1_mean": q1.mean().item(),
            "q1_max": q1.max().item(),
            "q2_min": q2.min().item(),
            "q2_mean": q2.mean().item(),
            "q2_max": q2.max().item(),
            "q1_q2_diff": F.mse_loss(q1, q2).item(),
        }

        # --------------------- actor / temperature update ---------------------
        state = batch[EpisodeKeys.OBSERVATION.value]
        # NOTE(review): no action mask is passed for the current observation,
        # unlike the target computation — confirm this is intended.
        action, log_pi = self.policy.compute_actions(
            state, with_log_pis=True, explore=True
        )

        alpha_loss = None
        if is_discrete:
            pi = log_pi.exp()
            # The expectation over actions is explicit, so Q is a constant here.
            # These are the critic outputs computed before the critic step.
            min_q = torch.min(q1, q2).detach()
            policy_loss = (pi * (alpha * log_pi - min_q)).sum(dim=-1).mean()
            if auto_alpha:
                alpha_loss = (
                    (
                        -pi.detach()
                        * self.policy._log_alpha()
                        * (log_pi + self.policy._target_entropy).detach()
                    )
                    .sum(dim=-1)
                    .mean()
                )
        else:
            # Reparameterised action: the actor gradient must flow through
            # Q(s, a), so Q is evaluated WITH autograd enabled. Gradients that
            # reach the critic parameters are cleared by the critic
            # zero_grad() before the next critic step.
            q1_pi, q2_pi = self.policy.value_function(
                torch.cat([state, action], dim=-1)
            )
            min_q = torch.min(q1_pi, q2_pi)
            policy_loss = (alpha * log_pi.flatten() - min_q.flatten()).mean()
            if auto_alpha:
                alpha_loss = -(
                    self.policy._log_alpha()
                    * (log_pi + self.policy._target_entropy).detach()
                ).mean()

        diagnos["policy_loss"] = policy_loss.detach().item()
        diagnos["log_pi"] = log_pi.detach().mean().item()
        # Always report a Python float, whether or not alpha is learned.
        diagnos["alpha"] = alpha.item()

        if alpha_loss is not None:
            diagnos["alpha_loss"] = alpha_loss.detach().item()
            self.optimizers["alpha"].zero_grad()
            alpha_loss.backward()
            self.optimizers["alpha"].step()

        self.optimizers["actor"].zero_grad()
        policy_loss.backward()
        self.optimizers["actor"].step()

        if torch.isnan(policy_loss) or torch.isnan(td_loss):
            print(
                "find nan in loss:",
                torch.isnan(policy_loss),
                torch.isnan(td_loss1),
                torch.isnan(td_loss2),
            )

        return diagnos

    @data.tensor_cast(callback=None)  # lambda x: Log.debug(f"Training info: {x}"))
    def __call__(self, batch: Dict[str, DataArray]) -> Dict[str, Any]:
        """Run one SAC update on ``batch`` and return the training diagnostics."""
        self.loss = []

        observation = batch[EpisodeKeys.OBSERVATION.value]
        with torch.no_grad():
            if self.policy.is_discrete_action:
                # Discrete critics take only the observation.
                cur_critic_state = observation
            else:
                cur_critic_state = torch.cat(
                    [observation, batch[EpisodeKeys.ACTION.value]], dim=-1
                )

        # Policy evaluation + improvement.
        return dict(self._update_ops(cur_critic_state, batch))
