"""
PPO implementation with PyTorch.

Code based on:
https://github.com/vwxyzjn/cleanrl
"""

from typing import Union
import numpy.typing as npt

import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from torch.distributions import Categorical, Normal

from agents.base import AbstractAgent
from utils.rollout_buffer import RolloutBuffer


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and return it.

    The weight matrix receives an orthogonal initialisation scaled by
    ``std`` (used as the gain) and every bias entry is set to
    ``bias_const``. The same layer object is returned so the call can be
    nested directly inside an ``nn.Sequential(...)`` definition.
    """
    init = torch.nn.init
    init.orthogonal_(layer.weight, gain=std)
    init.constant_(layer.bias, val=bias_const)
    return layer


class Agent(nn.Module):
    """Actor and critic MLPs used by PPO.

    Both networks are two-hidden-layer Tanh MLPs of width ``hidden_width``.
    The critic maps a state to a single value. The actor maps a state to
    ``n_actions`` outputs which are interpreted as

    * the mean of a diagonal Gaussian (passed through a Sigmoid) when
      ``continuous_actions`` is true, with a learned, state-independent
      log standard deviation ``actor_logstd``; or
    * the logits of a categorical distribution otherwise.
    """

    def __init__(
        self,
        state_dim: int,
        n_actions: int,
        hidden_width: int,
        continuous_actions: bool = False,
        device: torch.device = torch.device("cpu"),
    ):
        """Build the critic and actor networks on ``device``.

        Args:
            state_dim: Size of the input state vector.
            n_actions: Number of actor outputs (action dimensions for the
                continuous case, number of discrete actions otherwise).
            hidden_width: Width of both hidden layers in each network.
            continuous_actions: Select the Gaussian (True) or categorical
                (False) policy head.
            device: Device the networks and inputs are moved to.
        """
        super().__init__()

        self.device = device
        self.continuous_actions = continuous_actions
        self.n_actions = n_actions

        # The critic is created before the actor so the order in which the
        # initialisers draw random numbers matches the original layout.
        self.critic = nn.Sequential(
            layer_init(nn.Linear(state_dim, hidden_width)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_width, hidden_width)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_width, 1), std=1.0),
        ).to(device)

        # Shared actor trunk; the small std on the last layer keeps the
        # initial outputs close to zero.
        actor_layers = [
            layer_init(nn.Linear(state_dim, hidden_width)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_width, hidden_width)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_width, n_actions), std=0.01),
        ]
        if continuous_actions:
            # Squash the Gaussian mean into (0, 1).
            actor_layers.append(nn.Sigmoid())
        self.actor = nn.Sequential(*actor_layers).to(device)

        if continuous_actions:
            # State-independent log standard deviation, one entry per action.
            self.actor_logstd = nn.Parameter(torch.zeros(1, n_actions).to(device))

    def _gaussian_policy(self, state: torch.Tensor) -> Normal:
        """Return the Normal action distribution for ``state`` (continuous case)."""
        action_mean = self.actor(state).view(-1, self.n_actions)
        action_logstd = self.actor_logstd.expand_as(action_mean)
        return Normal(action_mean, torch.exp(action_logstd))

    def get_value(self, state: torch.Tensor) -> torch.Tensor:
        """Return the critic's value estimate for ``state``."""
        state = state.to(self.device)
        return self.critic(state)

    def act(self, state: Union[torch.Tensor, npt.NDArray]) -> npt.NDArray:
        """Choose an action for ``state`` without tracking gradients.

        Continuous case: a stochastic sample from the Gaussian policy,
        flattened to one dimension. Discrete case: the deterministic argmax
        of the actor logits. The argmax is taken over the flattened logits,
        so the discrete branch is meant for a single state rather than a
        batch.

        Args:
            state: State as a tensor or array; it is converted to a float32
                tensor on ``self.device``.

        Returns:
            The chosen action as a NumPy array on the CPU.
        """
        # ``as_tensor`` accepts arrays as well as tensors that already live
        # on another device, which the legacy ``torch.Tensor(...)``
        # constructor does not handle.
        state = torch.as_tensor(state, dtype=torch.float32, device=self.device)

        # Acting never needs a backward pass, so skip building the graph.
        with torch.no_grad():
            if self.continuous_actions:
                # stochastic action
                action = self._gaussian_policy(state).sample().view(-1)
            else:
                # Lending environment
                # deterministic action
                logits = self.actor(state)
                action = torch.argmax(logits)
        return action.cpu().numpy()

    def get_action_and_value(
        self, state: torch.Tensor, action: Union[torch.Tensor, None] = None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Evaluate the policy and the critic on ``state``.

        Args:
            state: Batch of states; moved to ``self.device``.
            action: Actions to evaluate. When ``None`` an action is sampled
                from the current policy.

        Returns:
            ``(action, log_prob, entropy, value)``. In the continuous case
            the log-probability and entropy are summed over dimension 1
            (the action dimensions).
        """
        state = state.to(self.device)

        if self.continuous_actions:
            probs = self._gaussian_policy(state)
            if action is None:
                action = probs.sample()
            return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(state)

        # Lending environment
        logits = self.actor(state)
        probs = Categorical(logits=logits)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(state)


class PPO(AbstractAgent):
    """Proximal Policy Optimization agent.

    Owns an actor-critic ``Agent`` network and an Adam optimiser, and
    implements the clipped-surrogate PPO update with GAE advantages,
    optional advantage normalisation, optional value-loss clipping,
    an entropy bonus, gradient-norm clipping and optional KL early stopping.
    """

    def __init__(
        self,
        state_dim: int,
        n_actions: int,
        continuous_actions: bool = False,
        hidden_width: int = 256,
        learning_rate: float = 2.5e-4,
        final_learning_rate: float = 1e-4,
        batch_size: int = 128,
        mini_batch_size: int = 32,
        update_epochs: int = 4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_coef: float = 0.2,
        norm_adv: bool = True,
        clip_vloss: bool = True,
        ent_coef: float = 0.01,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        target_kl: Union[float, None] = None,
        use_anneal_lr: bool = True,
        device: torch.device = torch.device("cpu"),
    ):
        """Store the hyper-parameters and build the network and optimiser.

        Args:
            state_dim: Size of the state vector.
            n_actions: Number of actor outputs.
            continuous_actions: Use the Gaussian policy head instead of the
                categorical one.
            hidden_width: Hidden layer width of the actor and critic.
            learning_rate: Initial Adam learning rate.
            final_learning_rate: Learning rate reached at the end of linear
                annealing.
            batch_size: Number of rollout steps consumed per ``update``.
            mini_batch_size: Number of samples per gradient step.
            update_epochs: Passes over the batch per ``update``.
            gamma: Discount factor.
            gae_lambda: GAE lambda.
            clip_coef: Clipping range for the probability ratio (and for
                the value update when ``clip_vloss`` is set).
            norm_adv: Normalise advantages within each minibatch.
            clip_vloss: Use the clipped value loss.
            ent_coef: Entropy bonus coefficient.
            vf_coef: Value loss coefficient.
            max_grad_norm: Maximum gradient norm for clipping.
            target_kl: If set, stop the update epochs early once the
                approximate KL exceeds this value.
            use_anneal_lr: Enable linear learning-rate annealing in
                ``anneal_lr``.
            device: Device used for the network and the advantage buffer.
        """
        self.state_dim = state_dim
        self.learning_rate = learning_rate
        self.final_learning_rate = final_learning_rate
        self.continuous_actions = continuous_actions
        self.batch_size = batch_size
        self.mini_batch_size = mini_batch_size
        self.update_epochs = update_epochs
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.ent_coef = ent_coef
        self.clip_coef = clip_coef
        self.norm_adv = norm_adv
        self.clip_vloss = clip_vloss
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        self.target_kl = target_kl
        self.use_anneal_lr = use_anneal_lr
        self.device = device

        self.agent = Agent(state_dim, n_actions, hidden_width, continuous_actions, device).to(device)
        self.optimizer = optim.Adam(self.agent.parameters(), lr=learning_rate, eps=1e-5)

    def get_action_and_value(
        self, state: torch.Tensor, action: Union[torch.Tensor, None] = None
    ):
        """Delegate to ``Agent.get_action_and_value``."""
        return self.agent.get_action_and_value(state, action)

    def act(self, state: Union[torch.Tensor, np.ndarray]) -> npt.NDArray:
        """Delegate to ``Agent.act``; the action comes back as a NumPy array."""
        return self.agent.act(state)

    def anneal_lr(self, current_step: int, total_steps: int) -> None:
        """Linearly interpolate the learning rate towards its final value.

        The rate moves from ``learning_rate`` at step 0 to
        ``final_learning_rate`` at ``total_steps``. Does nothing when
        ``use_anneal_lr`` is false.
        """
        if not self.use_anneal_lr:
            return
        progress = current_step / total_steps
        lr = self.learning_rate - (self.learning_rate - self.final_learning_rate) * progress
        for group in self.optimizer.param_groups:
            group["lr"] = lr

    def _compute_gae(
        self,
        rewards: torch.Tensor,
        dones: torch.Tensor,
        values: torch.Tensor,
        last_state: torch.Tensor,
        last_done: torch.Tensor,
    ):
        """Compute GAE advantages and returns for one rollout, without gradients.

        NOTE(review): assumes the rollout tensors are indexed by time step
        along dimension 0 and hold ``batch_size`` steps — confirm against
        ``RolloutBuffer.get_data``.
        """
        with torch.no_grad():
            # bootstrap value if not done
            next_value = self.agent.get_value(last_state).reshape(1, -1)
            advantages = torch.zeros_like(rewards).to(self.device)
            lastgaelam = 0
            # Walk the rollout backwards so each step can reuse the
            # accumulated advantage of its successor.
            for t in reversed(range(self.batch_size)):
                if t == self.batch_size - 1:
                    nextnonterminal = 1.0 - last_done
                    nextvalues = next_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                delta = (
                    rewards[t] + self.gamma * nextvalues * nextnonterminal - values[t]
                )
                advantages[t] = lastgaelam = (
                    delta + self.gamma * self.gae_lambda * nextnonterminal * lastgaelam
                )
            returns = advantages + values
        return advantages, returns

    def update(
        self, buffer: RolloutBuffer, last_state: torch.Tensor, last_done: torch.Tensor
    ) -> dict[str, float]:
        """Run the PPO optimisation on the data held in ``buffer``.

        Args:
            buffer: Rollout buffer; ``get_data()`` must return
                ``(states, actions, logprobs, rewards, dones, values)``.
            last_state: State following the final stored step, used to
                bootstrap the value of the rollout tail.
            last_done: Done flag belonging to ``last_state``.

        Returns:
            Metrics dictionary. The loss, entropy and KL entries come from
            the last minibatch that was processed; ``clipfrac`` is averaged
            over all minibatches.
        """
        states, actions, logprobs, rewards, dones, values = buffer.get_data()

        advantages, returns = self._compute_gae(
            rewards, dones, values, last_state, last_done
        )

        # flatten the batch
        b_states = states.reshape(-1, self.state_dim)
        b_logprobs = logprobs.reshape(-1)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        if self.continuous_actions:
            b_actions = actions.reshape(-1, self.agent.n_actions)
        else:
            # Categorical.log_prob needs integer indices; convert once here
            # instead of once per minibatch.
            b_actions = actions.reshape(-1).long()

        # Optimizing the policy and value network
        b_inds = np.arange(self.batch_size)
        clipfracs = []
        for epoch in range(self.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, self.batch_size, self.mini_batch_size):
                end = start + self.mini_batch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = self.agent.get_action_and_value(
                    b_states[mb_inds], b_actions[mb_inds]
                )
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [
                        ((ratio - 1.0).abs() > self.clip_coef).float().mean().item()
                    ]

                mb_advantages = b_advantages[mb_inds]
                if self.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (
                        mb_advantages.std() + 1e-8
                    )

                # Policy loss: pessimistic maximum of the unclipped and
                # clipped surrogate losses.
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(
                    ratio, 1 - self.clip_coef, 1 + self.clip_coef
                )
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if self.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    # Limit how far the new value may move from the value
                    # recorded during the rollout.
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -self.clip_coef,
                        self.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - self.ent_coef * entropy_loss + v_loss * self.vf_coef

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.agent.parameters(), self.max_grad_norm)
                self.optimizer.step()

            # Early stopping is evaluated once per epoch, using the KL
            # estimate of that epoch's last minibatch.
            if self.target_kl is not None:
                if approx_kl > self.target_kl:
                    break

        # Explained variance of the rollout values with respect to returns.
        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        metrics = {
            "charts/learning_rate": self.optimizer.param_groups[0]["lr"],
            "losses/value_loss": v_loss.item(),
            "losses/policy_loss": pg_loss.item(),
            "losses/entropy": entropy_loss.item(),
            "losses/old_approx_kl": old_approx_kl.item(),
            "losses/approx_kl": approx_kl.item(),
            "losses/clipfrac": np.mean(clipfracs),
            "losses/explained_variance": explained_var,
        }
        return metrics

    def save(self, save_path: str) -> None:
        """Write the network weights to ``<save_path>/ppo.pt``, creating the directory if needed."""
        os.makedirs(save_path, exist_ok=True)
        torch.save(self.agent.state_dict(), os.path.join(save_path, "ppo.pt"))

    def load(self, load_path: str) -> None:
        """Load network weights from ``<load_path>/ppo.pt``.

        The checkpoint is mapped onto ``self.device`` so that weights saved
        on a GPU can be restored on a machine without one.
        """
        checkpoint = torch.load(
            os.path.join(load_path, "ppo.pt"), map_location=self.device
        )
        self.agent.load_state_dict(checkpoint)
