
import torch
from torch.nn.utils.clip_grad import clip_grad_norm_
from rich.progress import track
from omnisafe.algorithms import registry
from omnisafe.algorithms.on_policy.base.ppo import PPO
from omnisafe.utils import distributed
from omnisafe.common.lagrange import Lagrange
from torch.utils.data import DataLoader, TensorDataset


@registry.register
class APPO(PPO):
    """PPO with a penalty-factor cost term (APPO).

    On top of :class:`PPO`, the actor loss adds a clipped cost surrogate scaled
    by ``factor = max(0, lambda + sigma * (EpCost - cost_limit))``. Here
    ``lambda`` is a Lagrange multiplier, updated once per :meth:`_update`, and
    ``sigma`` is a penalty coefficient. The combined loss is divided by
    ``1 + factor`` to keep its scale bounded.
    """

    def _init(self) -> None:
        """Create the Lagrange multiplier and initialise the penalty coefficient."""
        super()._init()
        self._lagrange: Lagrange = Lagrange(**self._cfgs.lagrange_cfgs)
        # Penalty coefficient; update_penalty_factor() scales it by (1 + rho).
        self.sigma = self._cfgs.algo_cfgs.penalty_coef
        # NOTE(review): ``lam`` and ``_w_ema`` are not read anywhere in this
        # class; kept in case other code relies on them.
        self.lam = 0.001
        self._w_ema = torch.tensor(1.0, device=self._cfgs.train_cfgs.device)

    def _init_log(self) -> None:
        """Log the APPO specific information.

        +-----------------------------+---------------------------------------+
        | Things to log               | Description                           |
        +=============================+=======================================+
        | Loss/penalty                | The factor-scaled cost surrogate.     |
        +-----------------------------+---------------------------------------+
        | Misc/sigma                  | The current penalty coefficient.      |
        +-----------------------------+---------------------------------------+
        | Misc/LagrangeMultiplier     | The current Lagrange multiplier.      |
        +-----------------------------+---------------------------------------+
        """
        super()._init_log()
        self._logger.register_key('Loss/penalty', delta=True)
        self._logger.register_key('Misc/sigma')
        self._logger.register_key('Misc/LagrangeMultiplier')

    def _loss_pi_cost(
        self,
        obs: torch.Tensor,
        act: torch.Tensor,
        logp: torch.Tensor,
        adv_c: torch.Tensor,
    ) -> 'tuple[torch.Tensor, float]':
        r"""Compute the cost penalty term for APPO.

        The penalty is ``factor * surr_cadv``, where ``surr_cadv`` is the
        pessimistic (max) clipped cost surrogate. ``factor`` is
        ``max(0, lambda + sigma * phi)``, and ``phi`` is the latest logged mean
        episode cost minus ``cost_limit``.

        Args:
            obs: Minibatch observations.
            act: Actions taken in the minibatch.
            logp: Log-probabilities of ``act`` recorded in the buffer.
            adv_c: Cost advantages.

        Returns:
            A ``(penalty, factor)`` pair: the scalar penalty loss tensor and
            the non-negative float factor used to scale it.
        """
        self._actor_critic.actor(obs)
        logp_ = self._actor_critic.actor.log_prob(act)

        ratio = torch.exp(logp_ - logp)
        ratio_clipped = torch.clamp(
            ratio,
            1 - self._cfgs.algo_cfgs.clip,
            1 + self._cfgs.algo_cfgs.clip,
        )
        # Cost is minimised, so the pessimistic surrogate is the larger one.
        surr_cadv = torch.max(ratio * adv_c, ratio_clipped * adv_c).mean()

        # Read the multiplier as a plain float so that backprop through the
        # actor loss does not push gradients into the Lagrange parameter; it is
        # updated separately in _update().
        multiplier = self._lagrange.lagrangian_multiplier.item()
        phi = self._logger.get_stats('Metrics/EpCost')[0] - self._cfgs.algo_cfgs.cost_limit

        # For sigma > 0 this equals "phi + multiplier / sigma > 0 ? multiplier
        # + sigma * phi : 0", but it avoids dividing by sigma.
        factor = max(0.0, float(multiplier + self.sigma * phi))

        penalty = factor * surr_cadv

        self._logger.store({'Loss/penalty': penalty.item()})
        return penalty, factor

    def update_penalty_factor(self) -> None:
        """Scale the penalty coefficient ``sigma`` by ``(1 + rho)``."""
        self.sigma = (1 + self._cfgs.algo_cfgs.rho) * self.sigma

    def _update_actor(
        self,
        obs: torch.Tensor,
        act: torch.Tensor,
        logp: torch.Tensor,
        adv_r: torch.Tensor,
        adv_c: torch.Tensor,
    ) -> None:
        """Take one optimiser step on the actor with the APPO objective.

        The loss is ``(loss_pi + penalty) / (1 + factor)``. Gradients are
        optionally norm-clipped, then averaged across processes before stepping.

        Args:
            obs: Minibatch observations.
            act: Actions taken in the minibatch.
            logp: Log-probabilities of ``act`` recorded in the buffer.
            adv_r: Reward advantages.
            adv_c: Cost advantages.
        """
        loss_reward = self._loss_pi(obs, act, logp, adv_r)
        cost_penalty, factor = self._loss_pi_cost(obs, act, logp, adv_c)

        # Normalise so the loss scale does not grow with the penalty factor.
        loss = (loss_reward + cost_penalty) / (1 + factor)

        self._actor_critic.actor_optimizer.zero_grad()
        loss.backward()
        if self._cfgs.algo_cfgs.use_max_grad_norm:
            clip_grad_norm_(
                self._actor_critic.actor.parameters(),
                self._cfgs.algo_cfgs.max_grad_norm,
            )
        distributed.avg_grads(self._actor_critic.actor)
        self._actor_critic.actor_optimizer.step()

    def _update(self) -> None:
        """Update the Lagrange multiplier, then the critics and actor.

        The multiplier is updated once from the latest mean episode cost. Then
        ``update_iters`` passes are made over shuffled minibatches of the
        buffer. After each pass, the KL divergence from the pre-update policy
        is computed; the pass loop stops early when it exceeds ``target_kl``
        (if ``kl_early_stop`` is set).
        """
        data = self._buf.get()
        obs, act, logp, target_value_r, target_value_c, adv_r, adv_c = (
            data['obs'],
            data['act'],
            data['logp'],
            data['target_value_r'],
            data['target_value_c'],
            data['adv_r'],
            data['adv_c'],
        )

        original_obs = obs
        # Only used as a fixed reference for the KL check: no graph needed.
        with torch.no_grad():
            old_distribution = self._actor_critic.actor(obs)

        dataloader = DataLoader(
            dataset=TensorDataset(obs, act, logp, target_value_r, target_value_c, adv_r, adv_c),
            batch_size=self._cfgs.algo_cfgs.batch_size,
            shuffle=True,
        )

        update_counts = 0
        final_kl = torch.ones_like(old_distribution.loc)

        # Update the Lagrange multiplier once per update, before the policy steps.
        ep_cost = self._logger.get_stats('Metrics/EpCost')[0]
        self._lagrange.update_lagrange_multiplier(ep_cost)

        # Update actor/critic.
        for i in track(range(self._cfgs.algo_cfgs.update_iters), description='Updating...'):
            # NOTE: these loop targets intentionally rebind the full-buffer
            # names, so ``adv_r`` logged below is the last minibatch's.
            for (
                obs,
                act,
                logp,
                target_value_r,
                target_value_c,
                adv_r,
                adv_c,
            ) in dataloader:
                self._update_reward_critic(obs, target_value_r)
                if self._cfgs.algo_cfgs.use_cost:
                    self._update_cost_critic(obs, target_value_c)
                self._update_actor(obs, act, logp, adv_r, adv_c)

            # The KL value is only used as a Python scalar: skip autograd.
            with torch.no_grad():
                new_distribution = self._actor_critic.actor(original_obs)
                kl = (
                    torch.distributions.kl.kl_divergence(old_distribution, new_distribution)
                    .sum(-1, keepdim=True)
                    .mean()
                    .item()
                )
            kl = distributed.dist_avg(kl)

            final_kl = kl
            update_counts += 1

            if self._cfgs.algo_cfgs.kl_early_stop and kl > self._cfgs.algo_cfgs.target_kl:
                self._logger.log(f'Early stopping at iter {i + 1} due to reaching max kl')
                break

        self._logger.store(
            {
                'Train/StopIter': update_counts,
                'Value/Adv': adv_r.mean().item(),
                'Train/KL': final_kl,
                'Misc/LagrangeMultiplier': self._lagrange.lagrangian_multiplier.item(),
                'Misc/sigma': self.sigma,
            },
        )
