"""
Loss function for PPO

Ref 1: https://github.com/ray-project/ray/blob/master/rllib/agents/ppo/ppo_torch_policy.py
Ref 2: https://github.com/astooke/rlpyt/blob/master/rlpyt/algos/pg/ppo.py
Ref 3: https://github.com/openai/baselines/blob/master/baselines/ppo2/model.py
"""

from os import stat
from turtle import pd
from numpy import testing
from expground.algorithms.ppo.policy import PPO
import torch

import torch.nn.functional as F

from expground.types import DataArray, Dict, Any, Tensor, Tuple
from expground.algorithms.ppo.config import DEFAULT_CONFIG
from expground.logger import Log
from expground.utils import data
from expground.utils.data import EpisodeKeys
from expground.algorithms import misc
from expground.algorithms.loss_func import LossFunc


class PPOLoss(LossFunc):
    """Clipped-surrogate PPO loss optimised with a single optimizer.

    The policy term, the value-function term and the entropy bonus are
    combined into one scalar and back-propagated through
    ``self.policy.parameters()`` in a single optimizer step.

    Hyper-parameters are read from ``DEFAULT_CONFIG["training_config"]``:
    ``optimizer``, ``learning_rate``, ``gamma``, ``clip_param``, ``vf_coef``
    and ``ent_coef``.
    """

    def __init__(self, mute_critic_loss: bool = False):
        super().__init__(mute_critic_loss=mute_critic_loss)
        self._params.update(DEFAULT_CONFIG["training_config"])

    def zero_grad(self):
        """Clear the gradients held by every registered optimizer."""
        assert len(self.optimizers) > 0
        for optimizer in self.optimizers.values():
            optimizer.zero_grad()

    def step(self) -> Any:
        """Do nothing: the optimizer step is taken inside ``_update_ops``."""

    def setup_extras(self):
        """Do nothing: this loss needs no extra setup."""

    def setup_optimizers(self, *args, **kwargs):
        """Create the single ``"total"`` optimizer over all policy parameters.

        The optimizer class is looked up by name in ``torch.optim`` using the
        ``optimizer`` entry of the training config. Positional and keyword
        arguments are accepted for interface compatibility and ignored.
        """
        optim_cls = getattr(torch.optim, self._params["optimizer"])
        self.optimizers = {
            "total": optim_cls(
                self.policy.parameters(),
                lr=self._params["learning_rate"],
            ),
        }

    def _update_ops(self, batch: Dict[str, Tensor]) -> Dict[str, float]:
        """Compute the PPO loss on one batch and apply one optimizer step.

        Args:
            batch (Dict[str, Tensor]): Training batch keyed by
                ``EpisodeKeys`` values. Observation, next observation, done,
                reward, behaviour action distribution, action and action
                mask entries are read.

        Returns:
            Dict[str, float]: Scalar diagnostics (losses, value/reward/target
            statistics, entropy, mean advantage and mean log-probability of
            the taken actions).
        """
        ent_coef = self._params["ent_coef"]
        vf_coef = self._params["vf_coef"]
        gamma = self._params["gamma"]
        ratio_clip = self._params["clip_param"]

        device = "cuda" if self.policy.use_cuda else "cpu"
        state = batch[EpisodeKeys.OBSERVATION.value].to(device)
        next_state = batch[EpisodeKeys.NEXT_OBSERVATION.value].to(device)
        done = batch[EpisodeKeys.DONE.value].reshape(-1).to(device)
        reward = batch[EpisodeKeys.REWARD.value].reshape(-1).to(device)
        old_pi = batch[EpisodeKeys.ACTION_DIST.value].to(device)
        action = (
            batch[EpisodeKeys.ACTION.value].reshape(-1).to(dtype=torch.int64).to(device)
        )
        # One-hot selector used to pick the taken action out of a distribution.
        action = F.one_hot(action, num_classes=self.policy._action_space.n)
        # Keep the mask on the same device as every other input of the policy.
        action_mask = batch[EpisodeKeys.ACTION_MASK.value].to(device)

        _, log_pi, entropy = self.policy.compute_actions(
            state, action_mask=action_mask, with_log_pis=True
        )

        # ---- policy loss -------------------------------------------------
        value = self.policy.value_function(state)
        with torch.no_grad():
            next_value = self.policy.value_function(next_state)
            assert reward.shape == next_value.shape == done.shape == value.shape, (
                reward.shape,
                next_value.shape,
                done.shape,
                value.shape,
                next_state.shape,
            )
            # One-step TD advantage; built under no_grad, so it acts as a
            # constant in the surrogate objective.
            adv = reward + next_value * (1.0 - done) * gamma - value
        # NOTE: advantage normalisation (as in openai/baselines) is not applied.

        # Select the taken action's probability *before* dividing. Dividing
        # the full distributions first yields 0/0 = NaN for zero-probability
        # entries, and NaN * 0 stays NaN in the sum.
        new_prob = (log_pi.exp() * action).sum(dim=-1)
        old_prob = (old_pi * action).sum(dim=-1)
        # `ratio` must have the same shape as `adv` (one entry per sample);
        # a trailing singleton axis would broadcast the product to (B, B).
        ratio = new_prob / old_prob
        clipped_ratio = torch.clamp(ratio, 1.0 - ratio_clip, 1.0 + ratio_clip)
        policy_loss = -torch.min(ratio * adv, clipped_ratio * adv).mean()

        # Mean log-probability of the taken actions (diagnostics only).
        mean_log_prob = (log_pi * action).sum(dim=-1).mean()
        likely = -mean_log_prob

        # ---- value loss --------------------------------------------------
        # NOTE: the clipped value loss used by RLlib is not applied.
        with torch.no_grad():
            target_v1 = self.policy.target_value_function(next_state)
        assert reward.shape == target_v1.shape == done.shape == value.shape, (
            reward.shape,
            target_v1.shape,
            done.shape,
            value.shape,
        )
        value_loss = F.mse_loss(reward + target_v1 * (1.0 - done) * gamma, value)

        # NOTE: the KL penalty term used by RLlib is not applied.
        entropy = entropy.mean()

        loss = policy_loss + vf_coef * value_loss - ent_coef * entropy

        optimizer = self.optimizers["total"]
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        diagnos = {
            "value_loss": value_loss.detach().cpu().item(),
            "target_value_min": target_v1.min().cpu().item(),
            "target_value_mean": target_v1.mean().cpu().item(),
            "target_value_max": target_v1.max().cpu().item(),
            "reward_min": reward.min().cpu().item(),
            "reward_mean": reward.mean().cpu().item(),
            "reward_max": reward.max().cpu().item(),
            "value_min": value.min().cpu().item(),
            "value_mean": value.mean().cpu().item(),
            "value_max": value.max().cpu().item(),
            "entropy": entropy.cpu().item(),
            "adv": adv.mean().cpu().item(),
            "likely": likely.cpu().item(),
            "policy_loss": policy_loss.detach().cpu().item(),
            "log_prob": mean_log_prob.cpu().item(),
        }
        return diagnos

    @data.tensor_cast(callback=None)  # lambda x: Log.debug(f"Training info: {x}"))
    def __call__(self, batch: Dict[str, DataArray]) -> Dict[str, Any]:
        """Run one PPO update on ``batch`` and return its diagnostics.

        Args:
            batch (Dict[str, DataArray]): Training batch; passed through the
                ``data.tensor_cast`` decorator before reaching this body.

        Returns:
            Dict[str, Any]: A fresh dict holding the diagnostics produced by
            ``_update_ops``.
        """
        self.loss = []
        return dict(self._update_ops(batch))
