"""
ELBERT-PO implementation with PyTorch.

Code based on:
https://github.com/umd-huang-lab/ELBERT
"""

from typing import Union
import numpy.typing as npt
import copy

import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical

from agents import AbstractAgent
from agents.ppo import layer_init
from utils.rollout_buffer import ELBERTRolloutBuffer

from fair_gym import AcceptRejectAction


def get_supply_demand(group, action, success):
    """Compute the per-group supply and demand indicators for one step.

    The demand entry of the individual's group is 1 when ``success`` is true.
    The supply entry of that group is 1 when, in addition, ``action`` equals
    ``AcceptRejectAction.ACCEPT.value``. Every other entry is 0.

    Args:
        group: Group indicator vector; the group index is ``np.argmax(group)``
            (presumably a one-hot encoding -- confirm against the caller).
            Only group indices 0 and 1 are supported.
        action: Action taken, compared with ``AcceptRejectAction.ACCEPT.value``.
        success: Boolean outcome flag. Both ``bool`` and ``np.bool_`` are
            accepted, since NumPy comparisons yield the latter.

    Returns:
        A pair ``([r_U_0, r_U_1], [r_B_0, r_B_1])``: the supply indicators
        followed by the demand indicators, one entry per group.

    Raises:
        AssertionError: If ``success`` is not a boolean.
        ValueError: If the group index is neither 0 nor 1.
    """
    assert isinstance(success, (bool, np.bool_)), 'success should be a boolean'

    group_id = int(np.argmax(group))
    if group_id not in (0, 1):
        raise ValueError('invalid group_id')

    supply = [0, 0]
    demand = [0, 0]
    if success:
        demand[group_id] = 1
        # Supply can only be counted when there is demand and it was accepted.
        if action == AcceptRejectAction.ACCEPT.value:
            supply[group_id] = 1
    return supply, demand


def smooth_max(x, beta):
    '''
    Smooth maximum of a flat tensor, computed with the log-sum-exp trick:
    logsumexp(beta * x) / beta.

    A negative beta turns the result into a smooth minimum.

    x: (num_groups,), the R_U/R_B values of each group
    beta: scale factor applied before the log-sum-exp and divided out after
    returns: a zero-dimensional tensor
    '''
    assert isinstance(x, torch.Tensor), 'x should be a torch.Tensor'
    assert x.dim() == 1, 'x should be flat'
    return torch.logsumexp(beta * x, dim=0) / beta

def soft_bias_value_and_gradient(x, beta):
    '''
    soft_bias = smooth_max(x, beta) - smooth_max(x, -beta)
    compute the value of soft_bias and the gradient of soft_bias w.r.t. the input x

    If num_group == 2, use hard bias instead (max - min, with gradient +1 / -1).

    x: per-group values. The tensor is flattened internally, so both a flat
       (num_groups,) tensor and a (1, num_groups) tensor are accepted. The
       caller's tensor is never modified.
    beta: smoothing coefficient forwarded to smooth_max (unused for two groups)

    returns: (bias, bias_grad) where bias is a zero-dimensional tensor and
             bias_grad is a flat tensor of shape (num_groups,). Both are
             detached from any autograd graph.
    '''
    assert beta is not None, 'beta for computing the soft bias is None. Please specify it'
    assert isinstance(x, torch.Tensor), 'the type of input should be torch.Tensor'

    # Work on a flat, detached view: smooth_max requires a 1-D tensor, and the
    # returned gradient is always flat, whatever the layout of the input.
    flat = x.detach().reshape(-1)
    num_groups = flat.numel()
    assert num_groups > 1, 'There should be at least two groups in the environment'

    if num_groups == 2:
        # Hard bias: the gradient is +1 at the larger entry and -1 at the smaller.
        bias = flat.max() - flat.min()
        bias_grad = torch.ones(2, device=flat.device)
        bias_grad[torch.argmin(flat)] = -1
        return bias, bias_grad

    # following: num_groups > 2
    # Differentiate on a private leaf copy so the caller's tensor keeps its
    # requires_grad flag. enable_grad makes this work under torch.no_grad().
    leaf = flat.clone().requires_grad_(True)
    with torch.enable_grad():
        soft_bias = smooth_max(leaf, beta) - smooth_max(leaf, -beta)
        (soft_bias_grad,) = torch.autograd.grad(outputs=soft_bias, inputs=leaf)

    return soft_bias.detach(), soft_bias_grad.detach()


class Agent(nn.Module):
    """Actor-critic network with a reward critic and per-group supply/demand critics.

    All critic heads read the output of one shared feature extractor. The
    actor is a separate MLP that maps the state directly to action logits.
    """

    def __init__(
        self,
        state_dim: int,
        n_actions: int,
        group_dim: int,
        hidden_width: int,
        device: torch.device = torch.device("cpu"),
    ):
        """Build every sub-network and move it to ``device``.

        Args:
            state_dim: Size of the state vector fed to the actor and critics.
            n_actions: Number of discrete actions (size of the logits).
            group_dim: Number of groups; one supply head and one demand head
                is created per group.
            hidden_width: Width of the hidden layers.
            device: Device that holds the parameters and receives the inputs.
        """
        super().__init__()

        self.device = device
        self.n_actions = n_actions
        self.group_dim = group_dim

        # Trunk shared by the reward critic and all supply/demand critics.
        self.feature_extractor = nn.Sequential(
            layer_init(nn.Linear(state_dim, hidden_width)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_width, hidden_width)),
            nn.Tanh(),
        ).to(device)

        self.critic = nn.Sequential(
            layer_init(nn.Linear(hidden_width, 1), std=1.0),
        ).to(device)

        # The heads are created interleaved (supply, demand, supply, ...) so
        # the order in which layers are initialised stays stable.
        supply_critic = []
        demand_critic = []
        for _ in range(self.group_dim):
            supply_critic.append(nn.Sequential(
                layer_init(nn.Linear(hidden_width, 1), std=1.0),
            ).to(device))
            demand_critic.append(nn.Sequential(
                layer_init(nn.Linear(hidden_width, 1), std=1.0),
            ).to(device))
        self.supply_critic = nn.ModuleList(supply_critic)
        self.demand_critic = nn.ModuleList(demand_critic)

        self.actor = nn.Sequential(
            layer_init(nn.Linear(state_dim, hidden_width)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_width, hidden_width)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_width, n_actions), std=0.01),
        ).to(device)

    def get_value(self, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Evaluate all critics on ``state``.

        Returns:
            ``(value, supply_values, demand_values)``. ``value`` has a trailing
            dimension of 1. The other two concatenate the per-group heads on
            the last dimension, giving a trailing dimension of ``group_dim``.
        """
        features = self.feature_extractor(state.to(self.device))

        value = self.critic(features)
        supply_values = torch.cat(
            [head(features) for head in self.supply_critic], dim=-1
        )
        demand_values = torch.cat(
            [head(features) for head in self.demand_critic], dim=-1
        )
        return value, supply_values, demand_values

    def act(self, state: Union[torch.Tensor, npt.NDArray]) -> npt.NDArray:
        """Return the greedy (highest-logit) action for ``state``.

        The argmax is taken over the last (action) dimension, so a single
        state yields a zero-dimensional array and a batch of states yields
        one action per row.
        """
        state = torch.as_tensor(state, dtype=torch.float32).to(self.device)
        # Pure inference: no autograd graph is needed for a greedy action.
        with torch.no_grad():
            logits = self.actor(state)
        return torch.argmax(logits, dim=-1).cpu().numpy()

    def get_action_and_value(
        self, state: torch.Tensor, action: Union[torch.Tensor, None] = None
    ):
        """Sample (or score) an action and evaluate all critics.

        Args:
            state: Input state(s).
            action: Action(s) to score. When ``None`` an action is sampled
                from the categorical policy.

        Returns:
            ``(action, log_prob, entropy, value, supply_values, demand_values)``.
        """
        state = state.to(self.device)

        policy = Categorical(logits=self.actor(state))
        if action is None:
            action = policy.sample()
        value, supply_values, demand_values = self.get_value(state)
        return (
            action,
            policy.log_prob(action),
            policy.entropy(),
            value,
            supply_values,
            demand_values,
        )


class ELBERT(AbstractAgent):
    """PPO-style agent whose advantage is augmented with an ELBERT fairness term.

    Alongside the usual reward critic, the network learns per-group supply and
    demand value heads. Their advantages are combined, through the chain rule
    on the supply/demand ratio, into a bias-penalty term that is added to the
    reward advantage before the clipped policy update.
    """

    def __init__(
        self,
        state_dim: int,
        n_actions: int,
        group_dim: int, 
        continuous_actions: bool = False,
        hidden_width: int = 256,
        learning_rate: float = 2.5e-4,
        final_learning_rate: float = 1e-4,
        batch_size: int = 128,
        mini_batch_size: int = 32,
        update_epochs: int = 4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_coef: float = 0.2,
        norm_adv: bool = True,
        ent_coef: float = 0.01,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        target_kl: Union[float, None] = None,
        use_anneal_lr: bool = True,
        bias_coef: float = 200000,
        beta_smooth: float = 20,
        device: torch.device = torch.device("cpu"),
    ):
        """Store the hyper-parameters and build the network and optimizer.

        Args:
            state_dim: Size of the state vector.
            n_actions: Number of discrete actions.
            group_dim: Number of groups tracked by the supply/demand critics.
            continuous_actions: Stored on the instance; not used by this class.
            hidden_width: Hidden layer width of the networks.
            learning_rate: Initial Adam learning rate.
            final_learning_rate: Learning rate reached at the end of annealing.
            batch_size: Number of rollout steps consumed by one ``update``.
            mini_batch_size: Number of samples per gradient step.
            update_epochs: Number of passes over the batch per ``update``.
            gamma: Discount factor.
            gae_lambda: GAE lambda.
            clip_coef: PPO ratio clipping coefficient.
            norm_adv: Whether to normalise advantages per mini-batch.
            ent_coef: Entropy bonus coefficient.
            vf_coef: Value loss coefficient.
            max_grad_norm: Gradient-norm clipping threshold.
            target_kl: Optional KL threshold that stops the epoch loop early.
            use_anneal_lr: Whether ``anneal_lr`` changes the learning rate.
            bias_coef: Weight of the fairness term added to the advantage.
            beta_smooth: Smoothing coefficient for the soft bias.
            device: Device used for the network and the computations.
        """
        self.state_dim = state_dim
        self.group_dim = group_dim
        self.learning_rate = learning_rate
        self.final_learning_rate = final_learning_rate
        self.continuous_actions = continuous_actions
        self.batch_size = batch_size
        self.mini_batch_size = mini_batch_size
        self.update_epochs = update_epochs
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.ent_coef = ent_coef
        self.clip_coef = clip_coef
        self.norm_adv = norm_adv
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        self.target_kl = target_kl
        self.use_anneal_lr = use_anneal_lr
        self.bias_coef = bias_coef
        self.beta_smooth = beta_smooth
        self.device = device

        self.agent = Agent(state_dim, n_actions, group_dim, hidden_width, device).to(device)
        self.optimizer = optim.Adam(self.agent.parameters(), lr=learning_rate, eps=1e-5)

    def get_action_and_value(
        self, state: torch.Tensor, action: Union[torch.Tensor, None] = None
    ):
        """Delegate to ``Agent.get_action_and_value``."""
        return self.agent.get_action_and_value(state, action)

    def act(self, state: Union[torch.Tensor, np.ndarray]) -> npt.NDArray:
        """Return the action chosen by ``Agent.act`` for ``state``, as a NumPy array."""
        return self.agent.act(state)

    def anneal_lr(self, current_step: int, total_steps: int) -> None:
        """Linearly interpolate the learning rate between its initial and final value.

        Does nothing when ``use_anneal_lr`` is false.
        """
        if self.use_anneal_lr:
            lr = self.learning_rate - (self.learning_rate - self.final_learning_rate) * (current_step / total_steps)
            self.optimizer.param_groups[0]["lr"] = lr

    def update(
        self, buffer: ELBERTRolloutBuffer, last_state: torch.Tensor, last_done: torch.Tensor
    ) -> dict[str, float]:
        """Run one policy/value update on the rollout stored in ``buffer``.

        Steps: compute GAE for the reward, supply and demand signals; build
        the fairness-augmented advantage; then run ``update_epochs`` passes of
        clipped-surrogate mini-batch updates.

        Args:
            buffer: Rollout buffer; ``buffer.get_data()`` must yield the ten
                tensors unpacked below (presumably ``batch_size`` steps along
                the first dimension -- confirm against the buffer).
            last_state: State following the final stored step, used to
                bootstrap the value estimates.
            last_done: Done flag for ``last_state``.

        Returns:
            Metrics from the last mini-batch plus batch-level statistics,
            keyed by logging name.
        """
        (
            states,
            actions,
            logprobs,
            rewards,
            dones,
            values,
            supply_rewards,
            demand_rewards,
            supply_values,
            demand_values,
        ) = buffer.get_data()

        # bootstrap value if not done
        with torch.no_grad():
            next_value, next_supply_value, next_demand_value = self.agent.get_value(last_state)
            next_value = next_value.reshape(-1, 1)
            next_supply_value = next_supply_value.reshape(-1, self.group_dim)
            next_demand_value = next_demand_value.reshape(-1, self.group_dim)

            advantages = torch.zeros_like(rewards).to(self.device)
            supply_advantages = torch.zeros_like(supply_rewards).to(self.device)
            demand_advantages = torch.zeros_like(demand_rewards).to(self.device)

            lastgaelam = 0
            last_supply_gaelam = 0
            last_demand_gaelam = 0

            for t in reversed(range(self.batch_size)):
                if t == self.batch_size - 1:
                    nextnonterminal = 1.0 - last_done
                    nextvalues = next_value
                    # Each signal bootstraps from its own critic head, not
                    # from the reward critic.
                    next_supply_values = next_supply_value
                    next_demand_values = next_demand_value
                else:
                    nextnonterminal = 1.0 - dones[t + 1]
                    nextvalues = values[t + 1]
                    next_supply_values = supply_values[t + 1]
                    next_demand_values = demand_values[t + 1]

                # One-step TD errors for the three signals.
                delta = (
                    rewards[t] + self.gamma * nextvalues * nextnonterminal - values[t]
                )
                supply_delta = (
                    supply_rewards[t] + self.gamma * next_supply_values * nextnonterminal - supply_values[t]
                )
                demand_delta = (
                    demand_rewards[t] + self.gamma * next_demand_values * nextnonterminal - demand_values[t]
                )

                # GAE recursion, run backwards in time.
                advantages[t] = lastgaelam = (
                    delta + self.gamma * self.gae_lambda * nextnonterminal * lastgaelam
                )
                supply_advantages[t] = last_supply_gaelam = (
                    supply_delta + self.gamma * self.gae_lambda * nextnonterminal * last_supply_gaelam
                )
                demand_advantages[t] = last_demand_gaelam = (
                    demand_delta + self.gamma * self.gae_lambda * nextnonterminal * last_demand_gaelam
                )

            returns = advantages + values
            supply_returns = supply_advantages + supply_values
            demand_returns = demand_advantages + demand_values

        # flatten the batch
        b_states = states.reshape(-1, self.state_dim)
        b_logprobs = logprobs.reshape(-1)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_supply_returns = supply_returns.reshape(-1, self.group_dim)
        b_demand_returns = demand_returns.reshape(-1, self.group_dim)
        b_values = values.reshape(-1)
        b_actions = actions.reshape(-1)

        # Get fairness ratio: per-group supply / demand summed over the rollout.
        supply_value_estimate = torch.sum(supply_rewards, dim=0)
        demand_value_estimate = torch.sum(demand_rewards, dim=0)
        fairness_ratio = supply_value_estimate / demand_value_estimate
        # 0 / 0 (a group with no demand) gives NaN; treat it as a zero ratio.
        fairness_ratio[torch.isnan(fairness_ratio)] = 0

        # soft_bias_grad: gradient of soft bias w.r.t the ratio
        # A detached copy is passed so the helper cannot alter fairness_ratio.
        soft_bias, soft_bias_grad = soft_bias_value_and_gradient(
            fairness_ratio.detach().clone(), self.beta_smooth
        )
        # In the paper, h = soft_bias**2, so partial_h/partial_z = 2 * soft_bias * soft_bias_grad
        grad_h = 2 * soft_bias * soft_bias_grad

        # advantage version of gradient of U/B (using chain rule formula of grad_U/B)
        advantages_grad_ratio_U_B = (
            (1 / demand_value_estimate) * supply_advantages
            - (supply_value_estimate / demand_value_estimate**2) * demand_advantages
        )
        advantages_grad_ratio_U_B = advantages_grad_ratio_U_B.reshape(-1, self.group_dim)
        advantages_grad_ratio_U_B[torch.isnan(advantages_grad_ratio_U_B)] = 0

        advantages_fair = b_advantages + torch.matmul(advantages_grad_ratio_U_B, grad_h) * (self.bias_coef)
        advantages_fair[torch.isnan(advantages_fair)] = 0

        # Optimizing the policy and value network
        b_inds = np.arange(self.batch_size)
        clipfracs = []
        for _ in range(self.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, self.batch_size, self.mini_batch_size):
                end = start + self.mini_batch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue, new_supply_value, new_demand_value = self.agent.get_action_and_value(
                    b_states[mb_inds], b_actions.long()[mb_inds]
                )
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [
                        ((ratio - 1.0).abs() > self.clip_coef).float().mean().item()
                    ]

                mb_advantages_fair = advantages_fair[mb_inds]
                if self.norm_adv:
                    mb_advantages_fair = (mb_advantages_fair - mb_advantages_fair.mean()) / (
                        mb_advantages_fair.std() + 1e-8
                    )

                # Policy loss (clipped surrogate)
                pg_loss1 = -mb_advantages_fair * ratio
                pg_loss2 = -mb_advantages_fair * torch.clamp(
                    ratio, 1 - self.clip_coef, 1 + self.clip_coef
                )
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss: reward critic plus supply and demand critics
                newvalue = newvalue.view(-1)
                new_supply_value = new_supply_value.view(-1, self.group_dim)
                new_demand_value = new_demand_value.view(-1, self.group_dim)

                v_loss = 0
                v_loss += 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()
                v_loss += 0.5 * ((new_supply_value - b_supply_returns[mb_inds]) ** 2).mean()
                v_loss += 0.5 * ((new_demand_value - b_demand_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - self.ent_coef * entropy_loss + v_loss * self.vf_coef

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.agent.parameters(), self.max_grad_norm)
                self.optimizer.step()

            # Early stop, based on the KL estimate of the epoch's last mini-batch.
            if self.target_kl is not None:
                if approx_kl > self.target_kl:
                    break

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        metrics = {
            "charts/learning_rate": self.optimizer.param_groups[0]["lr"],
            "losses/value_loss": v_loss.item(),
            "losses/policy_loss": pg_loss.item(),
            "losses/entropy": entropy_loss.item(),
            "losses/old_approx_kl": old_approx_kl.item(),
            "losses/approx_kl": approx_kl.item(),
            "losses/clipfrac": np.mean(clipfracs),
            "losses/explained_variance": explained_var,
        }
        return metrics

    def save(self, save_path: str) -> None:
        """Write the network weights to ``<save_path>/ppo.pt``, creating the directory."""
        os.makedirs(save_path, exist_ok=True)
        torch.save(self.agent.state_dict(), f"{save_path}/ppo.pt")

    def load(self, load_path: str) -> None:
        """Load the network weights from ``<load_path>/ppo.pt`` onto ``self.device``."""
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be restored on the device this agent runs on.
        state_dict = torch.load(f"{load_path}/ppo.pt", map_location=self.device)
        self.agent.load_state_dict(state_dict)
