"""
Advantage regularized PPO (A-PPO) agent 
https://arxiv.org/abs/2210.12546

Code based on:
https://github.com/ericyangyu/pocar
"""

from typing import Union

import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from agents import AbstractAgent, Agent
from utils.rollout_buffer import APPORolloutBuffer


class APPO(AbstractAgent):
    """Advantage-regularized PPO (A-PPO) agent.

    Wraps an actor-critic ``Agent`` network and an Adam optimizer, and performs
    clipped-surrogate PPO updates in which the GAE advantages are augmented
    with two fairness penalty terms built from the ``deltas`` and
    ``delta_deltas`` tensors supplied by the rollout buffer.
    """

    def __init__(
        self,
        state_dim: int,
        n_actions: int,
        continuous_actions: bool = False,
        hidden_width: int = 256,
        learning_rate: float = 2.5e-4,
        final_learning_rate: float = 1e-4,
        batch_size: int = 128,
        mini_batch_size: int = 32,
        update_epochs: int = 4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_coef: float = 0.2,
        norm_adv: bool = True,
        clip_vloss: bool = True,
        ent_coef: float = 0.01,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        target_kl: Union[float, None] = None,
        use_anneal_lr: bool = True,
        omega: float = 0.005,
        beta_0: float = 1.0,
        beta_1: float = 0.25,
        beta_2: float = 0.25,
        device: torch.device = torch.device("cpu"),
    ):
        """Store the hyper-parameters and build the network and optimizer.

        Args:
            state_dim: Size of a flattened state; used to reshape the batch.
            n_actions: Forwarded to ``Agent``.
            continuous_actions: Selects how stored actions are reshaped and
                passed to the network during updates.
            hidden_width: Forwarded to ``Agent``.
            learning_rate: Initial Adam learning rate.
            final_learning_rate: Learning rate reached when annealing ends.
            batch_size: Number of rollout steps consumed by one ``update``.
            mini_batch_size: Number of samples per gradient step.
            update_epochs: Passes over the batch per ``update``.
            gamma: Discount factor used by GAE.
            gae_lambda: GAE lambda.
            clip_coef: PPO clipping range (also used for value clipping).
            norm_adv: Standardize advantages per mini-batch when True.
            clip_vloss: Use the clipped value loss when True.
            ent_coef: Weight of the entropy bonus.
            vf_coef: Weight of the value loss.
            max_grad_norm: Gradient-norm clipping threshold.
            target_kl: If set, stop the epoch loop once the last mini-batch's
                approximate KL exceeds this value.
            use_anneal_lr: Enable linear learning-rate annealing.
            omega: Threshold applied to ``deltas`` in both penalty terms.
            beta_0: Weight of the (scaled) GAE advantage.
            beta_1: Weight of the value-thresholding term.
            beta_2: Weight of the decrease-in-violation term.
            device: Device the network is moved to.
        """
        self.state_dim = state_dim
        self.continuous_actions = continuous_actions
        self.learning_rate = learning_rate
        self.final_learning_rate = final_learning_rate
        self.batch_size = batch_size
        self.mini_batch_size = mini_batch_size
        self.update_epochs = update_epochs
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.ent_coef = ent_coef
        self.clip_coef = clip_coef
        self.norm_adv = norm_adv
        self.clip_vloss = clip_vloss
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        self.target_kl = target_kl
        self.use_anneal_lr = use_anneal_lr
        self.omega = omega
        self.beta_0 = beta_0
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.device = device

        self.agent = Agent(state_dim, n_actions, hidden_width, continuous_actions, device).to(device)
        self.optimizer = optim.Adam(self.agent.parameters(), lr=learning_rate, eps=1e-5)

    def get_action_and_value(
        self, state: torch.Tensor, action: Union[torch.Tensor, None] = None
    ):
        """Delegate to ``Agent.get_action_and_value``.

        ``update`` unpacks the result as a 4-tuple
        ``(action, log_prob, entropy, value)``.
        """
        return self.agent.get_action_and_value(state, action)

    def act(self, state: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Delegate action selection to ``Agent.act``."""
        return self.agent.act(state)

    def anneal_lr(self, current_step: int, total_steps: int) -> None:
        """Linearly interpolate the learning rate towards its final value.

        Does nothing when ``use_anneal_lr`` is False.

        Args:
            current_step: Progress counter, expected in ``[0, total_steps]``.
            total_steps: Step count at which ``final_learning_rate`` is reached.
        """
        if not self.use_anneal_lr:
            return
        progress = current_step / total_steps
        lr = self.learning_rate - (self.learning_rate - self.final_learning_rate) * progress
        # Update every group so the schedule keeps working if groups are added.
        for group in self.optimizer.param_groups:
            group["lr"] = lr

    @staticmethod
    def _min_max_scale(x: torch.Tensor) -> torch.Tensor:
        """Rescale ``x`` to ``[0, 1]`` using its global min and max."""
        low = torch.min(x)
        high = torch.max(x)
        # The epsilon keeps the division finite when all entries are equal.
        return (x - low) / (high - low + 1e-8)

    def _compute_gae(
        self,
        rewards: torch.Tensor,
        dones: torch.Tensor,
        values: torch.Tensor,
        last_state: torch.Tensor,
        last_done: torch.Tensor,
    ) -> torch.Tensor:
        """Compute GAE(lambda) advantages over ``batch_size`` steps.

        The value of ``last_state`` bootstraps the final step; ``last_done``
        and ``dones[t + 1]`` mask the bootstrap across episode boundaries.
        """
        next_value = self.agent.get_value(last_state).reshape(1, -1)
        advantages = torch.zeros_like(rewards).to(self.device)
        lastgaelam = 0
        final_step = self.batch_size - 1
        for t in reversed(range(self.batch_size)):
            if t == final_step:
                nextnonterminal = 1.0 - last_done
                nextvalues = next_value
            else:
                nextnonterminal = 1.0 - dones[t + 1]
                nextvalues = values[t + 1]
            delta = rewards[t] + self.gamma * nextvalues * nextnonterminal - values[t]
            lastgaelam = delta + self.gamma * self.gae_lambda * nextnonterminal * lastgaelam
            advantages[t] = lastgaelam
        return advantages

    def _regularize_advantages(
        self,
        advantages: torch.Tensor,
        deltas: torch.Tensor,
        delta_deltas: torch.Tensor,
    ) -> torch.Tensor:
        """Combine GAE advantages with the two fairness penalty terms.

        Implements the advantage regularization referred to as Eq. 3 of the
        A-PPO paper: each of the three terms is min-max scaled to ``[0, 1]``
        and the results are mixed with weights ``beta_0``, ``beta_1``,
        ``beta_2``. The inputs are not modified.
        """
        # Value-thresholding term: min(0, omega - delta), i.e. non-zero
        # (negative) only where delta exceeds omega.
        vt_term = torch.min(torch.zeros_like(deltas), -deltas + self.omega)

        # Decrease-in-violation term: min(0, -delta_delta), applied only on
        # the steps where delta exceeds omega.
        div_cond = (deltas > self.omega).float()
        div_term = torch.min(torch.zeros_like(delta_deltas), -div_cond * delta_deltas)

        # Bring the 3 terms to the same scale for numerical stability.
        scaled_adv = self._min_max_scale(advantages)
        scaled_vt = self._min_max_scale(vt_term)
        scaled_div = self._min_max_scale(div_term)

        return self.beta_0 * scaled_adv + self.beta_1 * scaled_vt + self.beta_2 * scaled_div

    def update(
        self, buffer: APPORolloutBuffer, last_state: torch.Tensor, last_done: torch.Tensor
    ) -> dict[str, float]:
        """Run one A-PPO update on the data held by ``buffer``.

        Args:
            buffer: Rollout buffer whose ``get_data()`` yields
                ``(states, actions, logprobs, rewards, dones, values, deltas,
                delta_deltas)``; expected to hold ``batch_size`` steps.
            last_state: State following the final stored step, used to
                bootstrap the value estimate.
            last_done: Done flag for ``last_state``.

        Returns:
            Scalar training metrics. Loss and KL entries come from the last
            mini-batch processed; ``clipfrac`` is averaged over all of them.
        """
        states, actions, logprobs, rewards, dones, values, deltas, delta_deltas = buffer.get_data()

        with torch.no_grad():
            gae = self._compute_gae(rewards, dones, values, last_state, last_done)

            # Critic targets are the TD(lambda) returns built from the *raw*
            # GAE estimate. They must be taken before the advantages are
            # rescaled and mixed with the fairness terms; otherwise the value
            # network would regress onto quantities that are not returns.
            returns = gae + values

            # Advantage regularization for fairness (policy loss only).
            advantages = self._regularize_advantages(gae, deltas, delta_deltas)

        # flatten the batch
        b_states = states.reshape(-1, self.state_dim)
        b_logprobs = logprobs.reshape(-1)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        if self.continuous_actions:
            b_actions = actions.reshape(-1, self.agent.n_actions)
        else:
            # Discrete actions are passed to the network as integer indices.
            b_actions = actions.reshape(-1).long()

        # Optimizing the policy and value network
        b_inds = np.arange(self.batch_size)
        clipfracs = []
        for epoch in range(self.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, self.batch_size, self.mini_batch_size):
                mb_inds = b_inds[start:start + self.mini_batch_size]

                _, newlogprob, entropy, newvalue = self.agent.get_action_and_value(
                    b_states[mb_inds], b_actions[mb_inds]
                )

                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs.append(
                        ((ratio - 1.0).abs() > self.clip_coef).float().mean().item()
                    )

                mb_advantages = b_advantages[mb_inds]
                if self.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (
                        mb_advantages.std() + 1e-8
                    )

                # Policy loss (clipped surrogate, written as a loss to minimize)
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(
                    ratio, 1 - self.clip_coef, 1 + self.clip_coef
                )
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                mb_returns = b_returns[mb_inds]
                if self.clip_vloss:
                    mb_values = b_values[mb_inds]
                    v_loss_unclipped = (newvalue - mb_returns) ** 2
                    # Limit how far the new estimate may move from the old one.
                    v_clipped = mb_values + torch.clamp(
                        newvalue - mb_values,
                        -self.clip_coef,
                        self.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - mb_returns) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - self.ent_coef * entropy_loss + v_loss * self.vf_coef

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.agent.parameters(), self.max_grad_norm)
                self.optimizer.step()

            # Early stopping is decided on the KL of the epoch's last mini-batch.
            if self.target_kl is not None and approx_kl > self.target_kl:
                break

        # Fraction of the return variance explained by the pre-update values.
        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        metrics = {
            "charts/learning_rate": self.optimizer.param_groups[0]["lr"],
            "losses/value_loss": v_loss.item(),
            "losses/policy_loss": pg_loss.item(),
            "losses/entropy": entropy_loss.item(),
            "losses/old_approx_kl": old_approx_kl.item(),
            "losses/approx_kl": approx_kl.item(),
            "losses/clipfrac": np.mean(clipfracs),
            "losses/explained_variance": explained_var,
        }
        return metrics

    def save(self, save_path: str) -> None:
        """Write the network weights to ``<save_path>/ppo.pt``.

        The directory is created if it does not exist. Optimizer state is not
        saved.
        """
        os.makedirs(save_path, exist_ok=True)
        torch.save(self.agent.state_dict(), os.path.join(save_path, "ppo.pt"))

    def load(self, load_path: str) -> None:
        """Load network weights from ``<load_path>/ppo.pt``.

        Tensors are mapped onto ``self.device`` so that a checkpoint written
        on a GPU can be restored on a CPU-only machine (and vice versa).
        """
        state_dict = torch.load(os.path.join(load_path, "ppo.pt"), map_location=self.device)
        self.agent.load_state_dict(state_dict)
