
"""Implementation of CSPO: Constraint-Sensitive Policy Optimization."""

import torch
from torch.nn.utils.clip_grad import clip_grad_norm_

from omnisafe.algorithms import registry
from omnisafe.algorithms.on_policy.base.ppo import PPO
from omnisafe.utils import distributed
from omnisafe.common.lagrange import Lagrange
from torch.utils.data import DataLoader, TensorDataset
from rich.progress import track


@registry.register
class CSPO(PPO):
    """Constraint-Sensitive Policy Optimization (CSPO).

    A PPO variant whose actor loss adds a cost term weighted by
    ``multiplier + correction``:

    - ``multiplier`` is the Lagrange multiplier, updated once per ``_update``
      from the logged ``Metrics/EpCost``.
    - ``correction`` is ``alpha * w * phi`` and is non-zero only while
      ``phi = EpCost - cost_limit`` is positive.
    - ``w`` is an EMA of the clamped inverse squared norm of the gradient of
      the (unclipped) cost surrogate w.r.t. the actor parameters. The weight
      is large when the constraint is insensitive to the policy parameters.
    """

    def _init(self) -> None:
        """Create the Lagrange multiplier, read ``alpha`` and reset the EMA weight."""
        super()._init()
        self._lagrange: Lagrange = Lagrange(**self._cfgs.lagrange_cfgs)
        # Scale of the sensitivity-based correction applied to the cost term.
        self.alpha = self._cfgs.algo_cfgs.alpha
        # Running EMA of the sensitivity weight ``w``; starts at 1.0.
        self._w_ema = torch.tensor(1.0, device=self._cfgs.train_cfgs.device)

    def _init_log(self) -> None:
        """Register the CSPO-specific logger keys (all with ``delta=True``).

        - ``Loss/penalty``: cost part of the actor loss,
          ``(multiplier + correction) * surr_cadv``.
        - ``Loss/correction``: ``alpha * w_ema * phi`` when ``phi > 0``, else 0.
        - ``Geom/w_raw``: clamped ``1 / (||grad||^2 + geo_eps)`` of the minibatch.
        - ``Geom/w_ema``: EMA of ``w_raw`` actually used in the correction.
        - ``Geom/grad_norm_sq``: squared L2 norm of the cost-surrogate gradient.
        - ``Geom/phi``: ``EpCost - cost_limit`` (constraint violation).
        - ``Misc/LagrangeMultiplier``: Lagrange multiplier after its update.
        - ``Misc/alpha``: the correction scale ``alpha``.
        """
        super()._init_log()
        self._logger.register_key('Loss/penalty', delta=True)
        self._logger.register_key('Loss/correction', delta=True)
        self._logger.register_key("Geom/w_raw", delta=True)
        self._logger.register_key("Geom/w_ema", delta=True)
        self._logger.register_key("Geom/grad_norm_sq", delta=True)
        self._logger.register_key("Geom/phi", delta=True)
        self._logger.register_key("Misc/LagrangeMultiplier", delta=True)
        self._logger.register_key("Misc/alpha", delta=True)

    def _ema(self, prev, new, beta):
        """Return the exponential moving average ``beta * prev + (1 - beta) * new``."""
        return beta * prev + (1 - beta) * new

    def _flatten_grads(self, grads):
        """Concatenate per-parameter gradients into one detached 1-D tensor.

        ``None`` entries are skipped. These come from parameters that do not
        influence the loss, because the gradient is taken with
        ``allow_unused=True``. If every entry is ``None``, a scalar zero is
        returned.
        """
        parts = [g.detach().reshape(-1) for g in grads if g is not None]
        if not parts:
            # Same device as ``_w_ema`` (see ``_init``).
            return torch.tensor(0.0, device=self._cfgs.train_cfgs.device)
        return torch.cat(parts)

    def _loss_pi_cost(
        self,
        obs: torch.Tensor,
        act: torch.Tensor,
        logp: torch.Tensor,
        adv_c: torch.Tensor,
    ) -> 'tuple[torch.Tensor, torch.Tensor | float]':
        """Compute the cost part of the actor loss and the CSPO correction.

        Side effects: advances ``self._w_ema`` once per call (i.e. once per
        minibatch) and stores the ``Loss/*`` and ``Geom/*`` logger entries.

        Args:
            obs: Observations of the minibatch.
            act: Actions of the minibatch.
            logp: Log-probabilities of ``act`` stored in the buffer.
            adv_c: Cost advantages of the minibatch.

        Returns:
            A ``(loss_cost, correction)`` tuple.

            - ``loss_cost`` is ``(multiplier + correction) * surr_cadv``, where
              ``surr_cadv`` is the pessimistic (``max``) clipped cost surrogate.
            - ``correction`` is the tensor ``alpha * w_ema * phi`` when
              ``phi > 0``, and ``0.0`` otherwise.
        """
        algo_cfgs = self._cfgs.algo_cfgs

        self._actor_critic.actor(obs)
        logp_ = self._actor_critic.actor.log_prob(act)

        ratio = torch.exp(logp_ - logp)
        ratio_clipped = torch.clamp(ratio, 1 - algo_cfgs.clip, 1 + algo_cfgs.clip)

        # Sensitivity of the constraint: gradient of the unclipped cost
        # surrogate w.r.t. the actor parameters. ``retain_graph`` is needed
        # because ``loss_cost`` below shares this graph (via ``ratio``) and is
        # back-propagated later in ``_update_actor``.
        surr_cadv_unclipped = (ratio * adv_c).mean()
        grads = torch.autograd.grad(
            -surr_cadv_unclipped,
            self._actor_critic.actor.parameters(),
            retain_graph=True,
            create_graph=False,
            allow_unused=True,
        )
        delta_g = self._flatten_grads(grads)
        g_norm = (delta_g**2).sum()

        # Sensitivity weight: inverse squared gradient norm, clamped, then EMA-smoothed.
        w_raw = 1.0 / (g_norm + algo_cfgs.geo_eps)
        w_raw = torch.clamp(w_raw, algo_cfgs.geo_w_clip_min, algo_cfgs.geo_w_clip_max)
        self._w_ema = self._ema(self._w_ema, w_raw, algo_cfgs.geo_w_ema)
        w = self._w_ema.detach()

        # Constraint violation measured on the logged episode cost.
        # NOTE(review): this uses ``algo_cfgs.cost_limit`` while the Lagrange
        # multiplier is configured from ``lagrange_cfgs``; confirm they agree.
        phi = self._logger.get_stats('Metrics/EpCost')[0] - algo_cfgs.cost_limit
        # Detached so the policy loss does not push gradients into the
        # multiplier parameter; the multiplier is updated only in ``_update``.
        multiplier = self._lagrange.lagrangian_multiplier.detach()

        correction = self.alpha * w * phi if phi > 0.0 else 0.0

        surr_cadv = torch.max(ratio * adv_c, ratio_clipped * adv_c).mean()
        loss_cost = (multiplier + correction) * surr_cadv

        self._logger.store(
            {
                'Loss/penalty': loss_cost.mean().detach(),
                'Loss/correction': correction,
                'Geom/w_raw': w_raw.detach(),
                'Geom/w_ema': w,
                'Geom/grad_norm_sq': g_norm.detach(),
                'Geom/phi': phi,
            },
        )
        return loss_cost, correction

    def _update_actor(
        self,
        obs: torch.Tensor,
        act: torch.Tensor,
        logp: torch.Tensor,
        adv_r: torch.Tensor,
        adv_c: torch.Tensor,
    ) -> None:
        """Take one actor optimizer step on the combined reward/cost loss.

        The loss is ``(loss_reward + loss_cost) / (1 + correction)``, where:

        - ``loss_reward`` comes from ``self._loss_pi``;
        - ``loss_cost`` and ``correction`` come from ``self._loss_pi_cost``.

        Gradients are optionally norm-clipped and averaged across processes
        before the step.
        """
        loss_reward = self._loss_pi(obs, act, logp, adv_r)
        loss_cost, correction = self._loss_pi_cost(obs, act, logp, adv_c)

        # NOTE(review): the normalizer contains only ``correction``, not the
        # Lagrange multiplier; confirm this is intended.
        loss = (loss_reward + loss_cost) / (1 + correction)

        self._actor_critic.actor_optimizer.zero_grad()
        loss.backward()
        if self._cfgs.algo_cfgs.use_max_grad_norm:
            clip_grad_norm_(
                self._actor_critic.actor.parameters(),
                self._cfgs.algo_cfgs.max_grad_norm,
            )
        distributed.avg_grads(self._actor_critic.actor)
        self._actor_critic.actor_optimizer.step()

        # TODO: update penalty factor (sigma) based on the violation degree

    def _update(self) -> None:
        """Update the Lagrange multiplier, then the critics and the actor.

        Steps:

        1. Update the multiplier once from the logged ``Metrics/EpCost``.
        2. For up to ``update_iters`` epochs, each minibatch updates the
           reward critic, the cost critic (if ``use_cost``), and the actor.
        3. After each epoch, measure the KL between the pre-update policy and
           the current policy on the whole buffer. Stop early when
           ``kl_early_stop`` is set and the KL exceeds ``target_kl``.
        """
        data = self._buf.get()
        obs, act, logp, target_value_r, target_value_c, adv_r, adv_c = (
            data['obs'],
            data['act'],
            data['logp'],
            data['target_value_r'],
            data['target_value_c'],
            data['adv_r'],
            data['adv_c'],
        )

        # Reference policy for the KL check; no autograd graph is needed.
        with torch.no_grad():
            old_distribution = self._actor_critic.actor(obs)

        dataloader = DataLoader(
            dataset=TensorDataset(obs, act, logp, target_value_r, target_value_c, adv_r, adv_c),
            batch_size=self._cfgs.algo_cfgs.batch_size,
            shuffle=True,
        )

        update_counts = 0
        # Scalar so the logged KL is well-formed even if no epoch runs.
        final_kl = 0.0

        # Update lagrangian multiplier
        ep_cost = self._logger.get_stats('Metrics/EpCost')[0]
        self._lagrange.update_lagrange_multiplier(ep_cost)

        # Update actor/critic
        for i in track(range(self._cfgs.algo_cfgs.update_iters), description='Updating...'):
            # Distinct minibatch names so the full-buffer tensors are not shadowed.
            for (
                mb_obs,
                mb_act,
                mb_logp,
                mb_target_value_r,
                mb_target_value_c,
                mb_adv_r,
                mb_adv_c,
            ) in dataloader:
                self._update_reward_critic(mb_obs, mb_target_value_r)
                if self._cfgs.algo_cfgs.use_cost:
                    self._update_cost_critic(mb_obs, mb_target_value_c)
                self._update_actor(mb_obs, mb_act, mb_logp, mb_adv_r, mb_adv_c)

            with torch.no_grad():
                new_distribution = self._actor_critic.actor(obs)
                kl = (
                    torch.distributions.kl.kl_divergence(old_distribution, new_distribution)
                    .sum(-1, keepdim=True)
                    .mean()
                )
            # Average across processes as a tensor, then convert to a float.
            kl = distributed.dist_avg(kl).item()

            final_kl = kl
            update_counts += 1

            if self._cfgs.algo_cfgs.kl_early_stop and kl > self._cfgs.algo_cfgs.target_kl:
                self._logger.log(f'Early stopping at iter {i + 1} due to reaching max kl')
                break

        self._logger.store(
            {
                'Train/StopIter': update_counts,
                # Mean advantage over the whole buffer, not just the last minibatch.
                'Value/Adv': adv_r.mean().item(),
                'Train/KL': final_kl,
                'Misc/LagrangeMultiplier': self._lagrange.lagrangian_multiplier.item(),
                'Misc/alpha': self.alpha,
            },
        )
