# # Copyright (c) 2017 Ilya Kostrikov
# #
# # Licensed under the MIT License;
# # you may not use this file except in compliance with the License.
# # You may obtain a copy of the License at
# #
# #     https://opensource.org/licenses/MIT
# #
# # This file is a modified version of:
# # https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/master/a2c_ppo_acktr/algo/ppo.py


# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import torch.optim as optim
# import random

# class PPO():
#     """
#     Vanilla PPO
#     """
#     def __init__(self,
#                  actor_critic,
#                  clip_param,
#                  ppo_epoch,
#                  num_mini_batch,
#                  value_loss_coef,
#                  entropy_coef,
#                  kl_loss_coef=0.0,
#                  lr=None,
#                  eps=None,
#                  max_grad_norm=None,
#                  clip_value_loss=True,
#                  log_grad_norm=False):

#         self.actor_critic = actor_critic

#         self.clip_param = clip_param
#         self.ppo_epoch = ppo_epoch
#         self.num_mini_batch = num_mini_batch
#         self.clip_value_loss = clip_value_loss

#         self.value_loss_coef = value_loss_coef
#         self.entropy_coef = entropy_coef
#         self.kl_loss_coef = kl_loss_coef

#         self.max_grad_norm = max_grad_norm

#         self.optimizer = optim.Adam(actor_critic.parameters(), lr=lr, eps=eps)

#         self.log_grad_norm = log_grad_norm

#     def _grad_norm(self):
#         total_norm = 0
#         for p in self.actor_critic.parameters():
#             if p.grad is not None:
#                 param_norm = p.grad.data.norm(2)
#                 total_norm += param_norm.item() ** 2
#         total_norm = total_norm ** (1. / 2)
#         return total_norm

#     def update(self, rollouts, discard_grad=False, kl_dict=None):
#         use_kl_loss = (kl_dict is not None) and (self.kl_loss_coef > 0.0) and (discard_grad is False)

#         if rollouts.use_popart:
#             value_preds = rollouts.denorm_value_preds
#         else:
#             value_preds = rollouts.value_preds

#         advantages = rollouts.returns[:-1] - value_preds[:-1]
#         advantages = (advantages - advantages.mean()) / (
#             advantages.std() + 1e-5)

#         value_loss_epoch = 0
#         action_loss_epoch = 0
#         dist_entropy_epoch = 0
#         if use_kl_loss:
#             kl_loss_epoch = 0

#         if self.log_grad_norm:
#             grad_norms = []

#         for e in range(self.ppo_epoch):
#             if self.actor_critic.is_recurrent:
#                 data_generator = rollouts.recurrent_generator(
#                     advantages, self.num_mini_batch)
#             else:
#                 data_generator = rollouts.feed_forward_generator(
#                     advantages, self.num_mini_batch)

#             for sample in data_generator:
#                 obs_batch, recurrent_hidden_states_batch, actions_batch, \
#                 value_preds_batch, return_batch, masks_batch, old_action_log_probs_batch, \
#                         adv_targ = sample

#                 if use_kl_loss:
#                     values, action_log_probs, dist_entropy, _, dist_protagonist = self.actor_critic.evaluate_actions(
#                         obs_batch, recurrent_hidden_states_batch, masks_batch,
#                         actions_batch, return_policy_logits=True)
#                     with torch.no_grad():
#                         _, _, _, _, dist_antagonist = kl_dict['antagonist_model'].evaluate_actions(
#                         obs_batch, recurrent_hidden_states_batch, masks_batch,
#                         actions_batch, return_policy_logits=True)
#                 else:
#                     values, action_log_probs, dist_entropy, _ = self.actor_critic.evaluate_actions(
#                         obs_batch, recurrent_hidden_states_batch, masks_batch,
#                         actions_batch)

#                 ratio = torch.exp(action_log_probs -
#                                   old_action_log_probs_batch)
#                 surr1 = ratio * adv_targ
#                 surr2 = torch.clamp(ratio, 1.0 - self.clip_param,
#                                     1.0 + self.clip_param) * adv_targ
#                 action_loss = -torch.min(surr1, surr2).mean()

#                 if rollouts.use_popart:
#                     self.actor_critic.popart.update(return_batch)
#                     return_batch = self.actor_critic.popart.normalize(return_batch)

#                 if self.clip_value_loss:
#                     value_pred_clipped = value_preds_batch + \
#                         (values - value_preds_batch).clamp(-self.clip_param, self.clip_param)
#                     value_losses = (values - return_batch).pow(2)
#                     value_losses_clipped = (
#                         value_pred_clipped - return_batch).pow(2)
#                     value_loss = 0.5 * torch.max(value_losses,
#                                                     value_losses_clipped).mean()
#                 else:
#                     value_loss = F.smooth_l1_loss(values, return_batch)

#                 if use_kl_loss:
#                     kl_div = torch.distributions.kl.kl_divergence(dist_antagonist, dist_protagonist)
#                     bs = kl_div.shape[0]
#                     kl_loss = kl_div.sum() / bs

#                 self.optimizer.zero_grad()
#                 loss = (value_loss*self.value_loss_coef + action_loss - dist_entropy*self.entropy_coef)
#                 if use_kl_loss:
#                     loss += (self.kl_loss_coef*kl_loss)

#                 loss.backward()

#                 if self.log_grad_norm:
#                     grad_norms.append(self._grad_norm())

#                 if self.max_grad_norm is not None and self.max_grad_norm > 0:
#                     nn.utils.clip_grad_norm_(self.actor_critic.parameters(),
#                                             self.max_grad_norm)

#                 if not discard_grad:
#                     self.optimizer.step()

#                 value_loss_epoch += value_loss.item()
#                 action_loss_epoch += action_loss.item()
#                 dist_entropy_epoch += dist_entropy.item()
#                 if use_kl_loss:
#                     kl_loss_epoch += kl_loss.item()

#         num_updates = self.ppo_epoch * self.num_mini_batch

#         value_loss_epoch /= num_updates
#         action_loss_epoch /= num_updates
#         dist_entropy_epoch /= num_updates
#         if use_kl_loss:
#             kl_loss_epoch /= num_updates

#         info = {}
#         if self.log_grad_norm:
#             info = {'grad_norms': grad_norms}
#         if use_kl_loss:
#             info['kl_loss'] = kl_loss_epoch

#         return value_loss_epoch, action_loss_epoch, dist_entropy_epoch, info


# Copyright (c) 2017 Ilya Kostrikov
#
# Licensed under the MIT License;
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://opensource.org/licenses/MIT
#
# This file is a modified version of:
# https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/master/a2c_ppo_acktr/algo/ppo.py

# Copyright (c) 2017 Ilya Kostrikov - Modified
# MIT License

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np


def _safe_flatten_grad(params):
    """Concatenate all finite grads into a single 1D float32 tensor (pre-clip). Non-finite set to 0."""
    flats = []
    for p in params:
        if p.grad is None:
            continue
        g = p.grad.detach().to(torch.float32).view(-1)
        if not torch.isfinite(g).all():
            g = g.clone()
            g[~torch.isfinite(g)] = 0.0
        flats.append(g)
    if len(flats) == 0:
        return torch.zeros(0, dtype=torch.float32, device=params[0].device)
    return torch.cat(flats)


def _total_l2_norm(params):
    """L2 norm over all params' grads (assumes pre-clip or post-clip based on call site)."""
    s = 0.0
    for p in params:
        if p.grad is None:
            continue
        g = p.grad.detach().to(torch.float32)
        if not torch.isfinite(g).all():
            g = g.clone()
            g[~torch.isfinite(g)] = 0.0
        s += float((g * g).sum().item())
    return s**0.5


class PPO:
    """
    Vanilla PPO with optional gradient diagnostics.

    ``update(..., compute_grad_metrics=True)`` (the default) adds
    ``info["grad_stats"]`` with, for that call:
      - grad_var: variance of all (sanitized, pre-clip) gradient entries, pooled
        over every minibatch and every parameter that received a gradient.
      - pre_clip_norm_mean / post_clip_norm_mean: mean global gradient L2 norm
        before / after gradient clipping.
      - clip_fraction: fraction of minibatches whose norm exceeded ``max_grad_norm``.
      - grad_dir_cos: cosine similarity between this call's and the previous
        call's mean pre-clip gradient vector (None when there is no comparable
        previous vector).
      - num_minibatches: number of minibatches processed.
    Non-finite gradient entries count as 0 in these metrics.

    With ``log_grad_norm=True``, ``info["grad_norms"]`` lists the raw
    (unsanitized) pre-clip global gradient norm of every minibatch.

    An AMP gradient scaler (e.g. ``torch.cuda.amp.GradScaler``) may be attached
    as ``self.scaler`` after construction. ``update`` then scales the loss,
    unscales the gradients before measuring/clipping them, and steps the
    optimizer through the scaler.
    """

    def __init__(
        self,
        actor_critic,
        clip_param,
        ppo_epoch,
        num_mini_batch,
        value_loss_coef,
        entropy_coef,
        kl_loss_coef=0.0,
        lr=None,
        eps=None,
        max_grad_norm=None,
        clip_value_loss=True,
        log_grad_norm=False,
    ):
        self.actor_critic = actor_critic
        self.clip_param = clip_param
        self.ppo_epoch = ppo_epoch
        self.num_mini_batch = num_mini_batch
        self.clip_value_loss = clip_value_loss
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.kl_loss_coef = kl_loss_coef
        self.max_grad_norm = max_grad_norm
        self.optimizer = optim.Adam(actor_critic.parameters(), lr=lr, eps=eps)
        self.log_grad_norm = log_grad_norm

        # for grad_dir_cos: previous update's mean pre-clip grad vector
        # (1D float32 on CPU, one entry per parameter element, in parameters() order)
        self._prev_update_grad = None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _sanitized_grad(p):
        """Return ``p.grad`` detached, flattened, cast to float32, with nan/inf set to 0.

        ``nan_to_num`` is out-of-place, so ``p.grad`` itself is not modified.
        """
        return torch.nan_to_num(
            p.grad.detach().reshape(-1).to(torch.float32),
            nan=0.0,
            posinf=0.0,
            neginf=0.0,
        )

    def _grad_norm(self):
        """Return the raw global L2 norm of the current gradients.

        Non-finite entries are NOT sanitized, so a diverging step shows up as
        nan/inf in ``info["grad_norms"]``.
        """
        total = 0.0
        for p in self.actor_critic.parameters():
            if p.grad is not None:
                total += p.grad.detach().norm(2).item() ** 2
        return total ** 0.5

    def _sanitized_grad_norm(self):
        """Return the global L2 norm of the current gradients, nan/inf counted as 0."""
        sq = 0.0
        for p in self.actor_critic.parameters():
            if p.grad is not None:
                g = self._sanitized_grad(p)
                sq += float(torch.dot(g, g).item())
        return sq ** 0.5

    def _grad_snapshot(self):
        """Collect sanitized gradient statistics for the current minibatch.

        Returns ``(n, s, ss, flat)``:
          - ``flat``: 1D float32 CPU tensor concatenating every parameter's
            sanitized gradient in ``parameters()`` order. Parameters whose grad
            is None contribute zeros, so the layout is identical for every
            minibatch and every call to ``update``.
          - ``n``: number of entries belonging to parameters that have a grad.
          - ``s`` / ``ss``: sum and sum of squares of ``flat`` in float64. The
            zero padding contributes nothing, so these are the sums over those
            ``n`` entries.
        """
        chunks = []
        n = 0
        for p in self.actor_critic.parameters():
            if p.grad is None:
                chunks.append(
                    torch.zeros(p.numel(), dtype=torch.float32, device=p.device)
                )
            else:
                g = self._sanitized_grad(p)
                chunks.append(g)
                n += g.numel()
        # One device->host transfer per minibatch; statistics are taken on CPU.
        flat = (
            torch.cat(chunks).cpu() if chunks else torch.zeros(0, dtype=torch.float32)
        )
        flat64 = flat.double()
        return n, float(flat64.sum()), float(flat64.dot(flat64)), flat

    def _update_grad_direction(self, grad_vec_sum, mb_cnt):
        """Return the cosine between this call's and the previous call's mean gradient.

        Also stores the current mean gradient for the next call. Returns None
        when there is no current vector or no previous vector of the same size.
        """
        if grad_vec_sum is None or grad_vec_sum.numel() == 0:
            return None
        g_avg = grad_vec_sum / float(mb_cnt)
        prev = self._prev_update_grad
        cos = None
        if prev is not None and prev.numel() == g_avg.numel():
            a = g_avg.double()
            b = prev.double()
            denom = float(a.norm()) * float(b.norm()) + 1e-12
            cos = float(a.dot(b)) / denom
        self._prev_update_grad = g_avg
        return cos

    def _minibatch_losses(self, sample, use_popart, use_kl_loss, kl_dict):
        """Compute the PPO losses for one minibatch.

        Returns ``(value_loss, action_loss, dist_entropy, kl_loss)``; ``kl_loss``
        is None when the KL term is disabled. With PopArt, the PopArt statistics
        are updated from this minibatch's returns before the value loss is taken.
        """
        (
            obs_batch,
            recurrent_hidden_states_batch,
            actions_batch,
            value_preds_batch,
            return_batch,
            masks_batch,
            old_action_log_probs_batch,
            adv_targ,
        ) = sample
        eval_args = (
            obs_batch,
            recurrent_hidden_states_batch,
            masks_batch,
            actions_batch,
        )

        if use_kl_loss:
            values, action_log_probs, dist_entropy, _, dist_protagonist = (
                self.actor_critic.evaluate_actions(*eval_args, return_policy_logits=True)
            )
            with torch.no_grad():
                _, _, _, _, dist_antagonist = kl_dict["antagonist_model"].evaluate_actions(
                    *eval_args, return_policy_logits=True
                )
        else:
            values, action_log_probs, dist_entropy, _ = (
                self.actor_critic.evaluate_actions(*eval_args)
            )

        # clipped surrogate policy objective
        ratio = torch.exp(action_log_probs - old_action_log_probs_batch)
        surr1 = ratio * adv_targ
        surr2 = (
            torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
        )
        action_loss = -torch.min(surr1, surr2).mean()

        if use_popart:
            self.actor_critic.popart.update(return_batch)
            return_batch = self.actor_critic.popart.normalize(return_batch)

        if self.clip_value_loss:
            value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(
                -self.clip_param, self.clip_param
            )
            value_losses = (values - return_batch).pow(2)
            value_losses_clipped = (value_pred_clipped - return_batch).pow(2)
            value_loss = 0.5 * torch.max(value_losses, value_losses_clipped).mean()
        else:
            value_loss = F.smooth_l1_loss(values, return_batch)

        kl_loss = None
        if use_kl_loss:
            kl_div = torch.distributions.kl.kl_divergence(
                dist_antagonist, dist_protagonist
            )
            kl_loss = kl_div.sum() / kl_div.shape[0]

        return value_loss, action_loss, dist_entropy, kl_loss

    # ------------------------------------------------------------------ update

    def update(self, rollouts, discard_grad=False, kl_dict=None, compute_grad_metrics=True):
        """Run ``ppo_epoch`` epochs of minibatch PPO updates on ``rollouts``.

        Args:
            rollouts: rollout storage. This method reads ``use_popart``,
                ``returns``, ``value_preds`` / ``denorm_value_preds`` and calls
                ``recurrent_generator`` / ``feed_forward_generator``.
            discard_grad: if truthy, gradients are computed (and measured) but
                the optimizer is never stepped, and the KL term is disabled.
            kl_dict: optional dict whose ``"antagonist_model"`` policy is the KL
                target; used only when ``kl_loss_coef > 0``.
            compute_grad_metrics: whether to populate ``info["grad_stats"]``.

        Returns:
            ``(value_loss, action_loss, dist_entropy, info)``. The losses are
            averaged over the minibatches processed. ``info`` may contain
            ``"kl_loss"``, ``"grad_norms"`` and ``"grad_stats"``.
        """
        use_kl_loss = (
            kl_dict is not None and self.kl_loss_coef > 0.0 and not discard_grad
        )
        use_popart = rollouts.use_popart
        scaler = getattr(self, "scaler", None)
        clip_enabled = self.max_grad_norm is not None and self.max_grad_norm > 0

        value_preds = rollouts.denorm_value_preds if use_popart else rollouts.value_preds
        advantages = rollouts.returns[:-1] - value_preds[:-1]
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-5)

        value_loss_sum = 0.0
        action_loss_sum = 0.0
        dist_entropy_sum = 0.0
        kl_loss_sum = 0.0
        mb_cnt = 0
        grad_norms = []  # raw pre-clip norms, reported when log_grad_norm is set

        # === gradient-metric accumulators (over all minibatches of this call) ===
        n_total = 0  # number of gradient entries, for the pooled variance
        sum_total = 0.0  # sum of gradient entries
        sumsq_total = 0.0  # sum of squared gradient entries
        pre_norm_list = []
        post_norm_list = []
        clipped_cnt = 0
        grad_vec_sum = None  # CPU sum of the per-minibatch pre-clip grad vectors

        for _ in range(self.ppo_epoch):
            if self.actor_critic.is_recurrent:
                data_generator = rollouts.recurrent_generator(
                    advantages, self.num_mini_batch
                )
            else:
                data_generator = rollouts.feed_forward_generator(
                    advantages, self.num_mini_batch
                )

            for sample in data_generator:
                value_loss, action_loss, dist_entropy, kl_loss = self._minibatch_losses(
                    sample, use_popart, use_kl_loss, kl_dict
                )
                loss = (
                    value_loss * self.value_loss_coef
                    + action_loss
                    - dist_entropy * self.entropy_coef
                )
                if kl_loss is not None:
                    loss = loss + self.kl_loss_coef * kl_loss

                self.optimizer.zero_grad(set_to_none=True)
                if scaler is not None:
                    # The loss must be scaled for unscale_ to recover the true
                    # gradients; unscaling happens before metrics/clipping read them.
                    scaler.scale(loss).backward()
                    scaler.unscale_(self.optimizer)
                else:
                    loss.backward()

                if self.log_grad_norm:
                    grad_norms.append(self._grad_norm())

                pre_norm = None
                if compute_grad_metrics:
                    # Snapshot BEFORE clipping: the variance, the pre-clip norm and
                    # the direction vector all describe the unclipped gradient.
                    n, s, ss, flat = self._grad_snapshot()
                    n_total += n
                    sum_total += s
                    sumsq_total += ss
                    pre_norm = ss ** 0.5
                    pre_norm_list.append(pre_norm)
                    if grad_vec_sum is None:
                        grad_vec_sum = flat  # fresh tensor from torch.cat; safe to mutate
                    else:
                        grad_vec_sum.add_(flat)

                did_clip = False
                if clip_enabled:
                    total_norm = nn.utils.clip_grad_norm_(
                        self.actor_critic.parameters(), self.max_grad_norm
                    )
                    did_clip = float(total_norm) > float(self.max_grad_norm) + 1e-9

                if compute_grad_metrics:
                    # Without clipping the gradients are unchanged: reuse the pre-clip norm.
                    post_norm_list.append(
                        self._sanitized_grad_norm() if clip_enabled else pre_norm
                    )
                    if did_clip:
                        clipped_cnt += 1

                if not discard_grad:
                    if scaler is not None:
                        scaler.step(self.optimizer)
                        scaler.update()
                    else:
                        self.optimizer.step()
                elif scaler is not None:
                    # unscale_ may be called only once per scaler.update(); reset the
                    # scaler even though no step is taken (this also advances its
                    # scale growth/backoff schedule).
                    scaler.update()

                value_loss_sum += float(value_loss.item())
                action_loss_sum += float(action_loss.item())
                dist_entropy_sum += float(dist_entropy.item())
                if kl_loss is not None:
                    kl_loss_sum += float(kl_loss.item())
                mb_cnt += 1

        # Average over the minibatches actually processed (a generator may yield
        # fewer than ppo_epoch * num_mini_batch).
        num_updates = max(1, mb_cnt)

        info = {}
        if self.log_grad_norm:
            info["grad_norms"] = grad_norms
        if use_kl_loss:
            info["kl_loss"] = kl_loss_sum / num_updates

        # === finalize metrics ===
        if compute_grad_metrics and mb_cnt > 0 and n_total > 0:
            mu = sum_total / n_total
            var = max(0.0, sumsq_total / n_total - mu * mu)  # element-wise pooled variance
            grad_dir_cos = self._update_grad_direction(grad_vec_sum, mb_cnt)
            info["grad_stats"] = {
                "grad_var": float(var),
                "pre_clip_norm_mean": sum(pre_norm_list) / len(pre_norm_list),
                "post_clip_norm_mean": sum(post_norm_list) / len(post_norm_list),
                "clip_fraction": clipped_cnt / mb_cnt,
                "grad_dir_cos": grad_dir_cos,
                "num_minibatches": int(mb_cnt),
            }

        return (
            value_loss_sum / num_updates,
            action_loss_sum / num_updates,
            dist_entropy_sum / num_updates,
            info,
        )
