"""
Lagrangian PPO agent
https://openreview.net/forum?id=JkIH4MeOc3

Code based on:
https://github.com/hercky/group-fairness-in-RL
"""

from typing import Union

import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from agents import AbstractAgent, Agent
from utils.rollout_buffer import RolloutBuffer


class LagrangianPPO(AbstractAgent):
    """PPO agent with a Lagrangian reweighting of the policy-gradient term.

    The agent keeps a ``(group_dim - 1, 2)`` array of Lagrange multipliers
    ``nu``. Each call to :meth:`update` does two things:

    1. It takes a projected dual-descent step on ``nu`` using the supplied
       return differences.
    2. It runs clipped PPO, with the advantages scaled by
       ``1 + sum_z (nu[z, 1] - nu[z, 0])``.
    """

    def __init__(
        self,
        state_dim: int,
        n_actions: int,
        group_dim: int,
        group: int,
        continuous_actions: bool = False,
        hidden_width: int = 256,
        learning_rate: float = 2.5e-4,
        final_learning_rate: float = 1e-4,
        batch_size: int = 128,
        mini_batch_size: int = 32,
        update_epochs: int = 4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_coef: float = 0.2,
        norm_adv: bool = True,
        clip_vloss: bool = True,
        ent_coef: float = 0.01,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        target_kl: Union[float, None] = None,
        use_anneal_lr: bool = True,
        nu_init: float = 0.0,
        nu_max: float = 2.0,
        nu_lr: float = 0.01,
        epsilon: float = 10,
        device: torch.device = torch.device("cpu"),
    ):
        """
        :param state_dim: width used to flatten observations in ``update``.
        :param n_actions: forwarded to ``Agent``. For continuous actions it is
            also the action dimension used when flattening the batch.
        :param group_dim: number of groups. One multiplier pair is kept for each
            of the ``group_dim - 1`` return differences.
        :param group: index of the group this agent serves. It appears in
            metric names as ``group + 1``.
        :param nu_init: initial value of every multiplier.
        :param nu_max: upper bound of the projection range ``[0, nu_max]``.
        :param nu_lr: step size of the dual update.
        :param epsilon: tolerance band for each return difference.
        The remaining arguments are the usual PPO hyper-parameters.
        """
        self.state_dim = state_dim
        self.group = group
        self.continuous_actions = continuous_actions
        self.num_groups = group_dim
        self.learning_rate = learning_rate
        self.final_learning_rate = final_learning_rate
        self.batch_size = batch_size
        self.mini_batch_size = mini_batch_size
        self.update_epochs = update_epochs
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.ent_coef = ent_coef
        self.clip_coef = clip_coef
        self.norm_adv = norm_adv
        self.clip_vloss = clip_vloss
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        self.target_kl = target_kl
        self.use_anneal_lr = use_anneal_lr
        self.device = device

        # Row z holds the multiplier pair for return_diff[z]:
        # column 0 grows while return_diff[z] > epsilon,
        # column 1 grows while return_diff[z] < -epsilon.
        self.nu = np.ones((group_dim - 1, 2)) * nu_init
        self.nu_max = nu_max
        self.nu_lr = nu_lr
        self.epsilon = epsilon

        self.agent = Agent(state_dim, n_actions, hidden_width, continuous_actions, device).to(device)
        self.optimizer = optim.Adam(self.agent.parameters(), lr=learning_rate, eps=1e-5)

    def get_action_and_value(
        self, state: torch.Tensor, action: Union[torch.Tensor, None] = None
    ):
        """Delegate to ``Agent.get_action_and_value``."""
        return self.agent.get_action_and_value(state, action)

    def act(self, state: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Delegate to ``Agent.act``."""
        return self.agent.act(state)

    def anneal_lr(self, current_step: int, total_steps: int) -> None:
        """Linearly interpolate the learning rate from ``learning_rate`` to
        ``final_learning_rate`` as ``current_step`` approaches ``total_steps``.

        The call is a no-op when ``use_anneal_lr`` is False.
        """
        if self.use_anneal_lr:
            lr = self.learning_rate - (self.learning_rate - self.final_learning_rate) * (current_step / total_steps)
            self.optimizer.param_groups[0]["lr"] = lr

    def project_nu(self, nu):
        """
        Project nu onto the [0, nu_max] range.

        :param nu: multiplier value after a dual step.
        :return: nu clipped to [0, nu_max].
        """
        return min(max(nu, 0.0), self.nu_max)

    def _compute_gae(self, rewards, dones, values, last_state, last_done):
        """Return ``(advantages, returns)`` using GAE, bootstrapped from ``last_state``."""
        with torch.no_grad():
            next_value = self.agent.get_value(last_state).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(self.device)
            lastgaelam = 0
            # Assumes the rollout is time-major, with batch_size steps along
            # dim 0. TODO: confirm against RolloutBuffer.
            for t in reversed(range(self.batch_size)):
                if t == self.batch_size - 1:
                    nextnonterminal = 1.0 - last_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = rewards[t] + self.gamma * nextvalues * nextnonterminal - values[t]
                advantages[t] = lastgaelam = (
                    delta + self.gamma * self.gae_lambda * nextnonterminal * lastgaelam
                )
            returns = advantages + values
        return advantages, returns

    def _update_nu(self, return_diff) -> None:
        """Take one projected dual-descent step on every multiplier pair."""
        assert len(return_diff) == self.num_groups - 1, "Incorrect number of return difference between subgroups"
        for z, diff in enumerate(return_diff):
            # float() accepts Python/NumPy scalars and 0-d tensors, including
            # CUDA or grad-tracking tensors. Those cannot be written into a
            # NumPy array directly.
            diff = float(diff)
            self.nu[z, 0] = self.project_nu(self.nu[z, 0] - self.nu_lr * (self.epsilon - diff))
            self.nu[z, 1] = self.project_nu(self.nu[z, 1] - self.nu_lr * (self.epsilon + diff))

    def _advantage_coefficient(self) -> float:
        """Return the advantage multiplier ``1 + sum_z (nu[z, 1] - nu[z, 0])``."""
        coefficient = 1.0
        for nu_low, nu_high in self.nu:
            coefficient += -nu_low + nu_high
        return float(coefficient)

    def update(
        self,
        buffer: RolloutBuffer,
        last_state: torch.Tensor,
        last_done: torch.Tensor,
        return_diff: list[Union[torch.Tensor, float]],
    ) -> dict[str, float]:
        """Run one dual step on ``nu`` followed by a full PPO update.

        :param buffer: rollout storage. ``get_data()`` must return
            ``(states, actions, logprobs, rewards, dones, values)``.
        :param last_state: observation after the last stored step, used to
            bootstrap the value.
        :param last_done: done flag after the last stored step.
        :param return_diff: ``num_groups - 1`` scalar return differences, one
            per multiplier pair. Presumably this group's return minus each
            other group's return; confirm with the caller.
        :return: dict of training metrics keyed by chart/loss name.
        """
        states, actions, logprobs, rewards, dones, values = buffer.get_data()
        advantages, returns = self._compute_gae(rewards, dones, values, last_state, last_done)

        # flatten the batch
        b_states = states.reshape(-1, self.state_dim)
        b_logprobs = logprobs.reshape(-1)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)
        if self.continuous_actions:
            b_actions = actions.reshape(-1, self.agent.n_actions)
        else:
            # Cast once here rather than re-casting the whole tensor per minibatch.
            b_actions = actions.reshape(-1).long()

        # Dual step, then freeze the advantage scale for the whole policy
        # update: nu does not change inside the loops below.
        self._update_nu(return_diff)
        adv_coefficient = self._advantage_coefficient()

        # Optimizing the policy and value network
        b_inds = np.arange(self.batch_size)
        clipfracs = []
        for _ in range(self.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, self.batch_size, self.mini_batch_size):
                end = start + self.mini_batch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = self.agent.get_action_and_value(
                    b_states[mb_inds], b_actions[mb_inds]
                )
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [
                        ((ratio - 1.0).abs() > self.clip_coef).float().mean().item()
                    ]

                mb_advantages = b_advantages[mb_inds]
                if self.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (
                        mb_advantages.std() + 1e-8
                    )

                # Policy loss (clipped surrogate on Lagrangian-scaled advantages)
                pg_loss1 = -mb_advantages * adv_coefficient * ratio
                pg_loss2 = -mb_advantages * adv_coefficient * torch.clamp(
                    ratio, 1 - self.clip_coef, 1 + self.clip_coef
                )
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if self.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -self.clip_coef,
                        self.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - self.ent_coef * entropy_loss + v_loss * self.vf_coef

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.agent.parameters(), self.max_grad_norm)
                self.optimizer.step()

            # Early stop is checked only between epochs, using the KL of the last minibatch.
            if self.target_kl is not None and approx_kl > self.target_kl:
                break

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        metrics = {
            f"charts/agent_{self.group + 1}/learning_rate": self.optimizer.param_groups[0]["lr"],
            f"losses/agent_{self.group + 1}/value_loss": v_loss.item(),
            f"losses/agent_{self.group + 1}/policy_loss": pg_loss.item(),
            f"losses/agent_{self.group + 1}/entropy": entropy_loss.item(),
            f"losses/agent_{self.group + 1}/old_approx_kl": old_approx_kl.item(),
            f"losses/agent_{self.group + 1}/approx_kl": approx_kl.item(),
            f"losses/agent_{self.group + 1}/clipfrac": np.mean(clipfracs),
            f"losses/agent_{self.group + 1}/explained_variance": explained_var,
        }
        return metrics

    def save(self, save_path: str) -> None:
        """Write the network weights to ``<save_path>/ppo.pt``, creating the directory.

        Only the network weights are saved; the multipliers ``nu`` are not.
        """
        os.makedirs(save_path, exist_ok=True)
        torch.save(self.agent.state_dict(), os.path.join(save_path, "ppo.pt"))

    def load(self, load_path: str) -> None:
        """Load network weights from ``<load_path>/ppo.pt`` onto ``self.device``."""
        # map_location lets checkpoints saved on another device (e.g. GPU) load here.
        state_dict = torch.load(os.path.join(load_path, "ppo.pt"), map_location=self.device)
        self.agent.load_state_dict(state_dict)
