import math
import torch
import torch.nn as nn
import torch.jit as jit
from typing import List, Tuple
import torch.nn.functional as F


class Linear(nn.Linear):
    """``nn.Linear`` with orthogonal weight init and a constant bias init.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        gain: Scaling factor passed to ``nn.init.orthogonal_``.
        init_bias: Value every bias element is filled with (ignored when
            ``bias`` is False).
        bias: Whether the layer has a learnable bias.
        device: Forwarded to ``nn.Linear``.
        dtype: Forwarded to ``nn.Linear``.
    """

    def __init__(self, in_features, out_features, gain=math.sqrt(2), init_bias=0.0, bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        nn.init.orthogonal_(self.weight, gain=gain)
        # nn.Linear registers ``bias`` as None when bias=False, so only
        # initialise it when it actually exists.
        if self.bias is not None:
            self.bias.data.fill_(init_bias)


class RNN(jit.ScriptModule):
    """Single-layer LSTM followed by LayerNorm, batched over envs and agents.

    The env and agent axes are folded into one LSTM batch axis of size
    ``n_envs * n_agents``. The recurrent state is a ``(h, c)`` tuple whose
    tensors are viewed as ``(1, n_envs * n_agents, h_dim)`` in ``forward_all``.

    Args:
        h_dim: Input and hidden size of the LSTM (and LayerNorm width).
    """

    def __init__(self, h_dim=128):
        super().__init__()
        self.h_dim = h_dim
        self.rnn = nn.LSTM(h_dim, h_dim)
        self.norm = nn.LayerNorm(h_dim)

        # Move the LSTM to the GPU only when one is present; an unconditional
        # ``.cuda()`` makes the module impossible to construct on CPU-only
        # hosts. ``flatten_parameters`` is safe to call on CPU weights.
        if torch.cuda.is_available():
            self.rnn.cuda()
        self.rnn.flatten_parameters()

    @jit.script_method
    def forward(self, x, rnn_state: Tuple[torch.Tensor, torch.Tensor]):
        """Advance the LSTM by one time step.

        Args:
            x: Tensor of shape ``(n_envs, n_agents, h_dim)``.
            rnn_state: ``(h, c)`` LSTM state for the folded batch.

        Returns:
            Tuple of the layer-normalised output with shape
            ``(n_envs, n_agents, h_dim)`` and the new ``(h, c)`` state.
        """
        n_envs, n_agents, h_dim = x.shape
        n_batches = n_envs * n_agents
        # Fold envs and agents into the batch axis; sequence length is 1.
        x = x.view(1, n_batches, h_dim)
        x, rnn_state = self.rnn.forward(x, rnn_state)
        x = self.norm.forward(x)
        x = x.view(n_envs, n_agents, -1)
        return x, rnn_state

    @jit.script_method
    def forward_all(self, x, rnn_state: Tuple[torch.Tensor, torch.Tensor], dones):
        """Run the LSTM over a whole rollout, resetting state at episode ends.

        The sequence is cut after every step at which ``dones`` is set for
        some env. Each segment is fed to the LSTM in one call, and between
        segments the hidden and cell state of the finished envs (all of their
        agents) is zeroed.

        Args:
            x: Tensor of shape ``(n_steps, n_envs, n_agents, h_dim)``.
            rnn_state: ``(h, c)`` LSTM state at the start of the rollout.
            dones: Per-step, per-env termination mask; presumably a bool
                tensor of shape ``(n_steps, n_envs)`` -- TODO confirm with
                the caller.

        Returns:
            Tuple of the layer-normalised outputs with shape
            ``(n_steps, n_envs, n_agents, h_dim)`` and the final ``(h, c)``
            state. A done flag on the very last step does not reset the
            returned state, because the loop exits before the reset.
        """
        n_steps, n_envs, n_agents, h_dim = x.shape
        n_batches = n_envs * n_agents
        x = x.view(n_steps, n_batches, h_dim)
        outputs = []
        # Indices of the steps at which at least one env terminated.
        mask_ids: List[int] = dones.sum(-1).nonzero().squeeze(-1).tolist()
        # Sentinel so the trailing segment after the last done is processed.
        mask_ids.append(n_steps)
        s_id = 0
        for e_id in mask_ids:
            # The segment includes the done step itself.
            x_slice = x[s_id:e_id+1]
            output, rnn_state = self.rnn.forward(x_slice, rnn_state)
            outputs.append(output)
            s_id = e_id+1
            if s_id >= n_steps:
                break
            masks = dones[e_id]
            h_state, c_state = rnn_state
            # Unfold the batch axis so the per-env mask can index axis 0.
            h_state = h_state.view(n_envs, n_agents, h_dim)
            c_state = c_state.view(n_envs, n_agents, h_dim)
            h_state[masks] = 0.0
            c_state[masks] = 0.0
            h_state = h_state.view(1, n_batches, h_dim)
            c_state = c_state.view(1, n_batches, h_dim)
            rnn_state = (h_state, c_state)
        x = torch.cat(outputs)
        x = self.norm.forward(x)
        x = x.view(n_steps, n_envs, n_agents, h_dim)
        return x, rnn_state


class Actor(jit.ScriptModule):
    """Recurrent policy network producing categorical action logits.

    Pipeline: ``Linear + ReLU`` encoder -> ``RNN`` core -> ``Linear`` head.
    Unavailable actions are masked by adding ``log(avails)`` to the logits.

    Args:
        ob_dim: Size of the observation feature axis.
        ac_dim: Number of discrete actions (size of the logits axis).
        h_dim: Hidden width of the encoder and recurrent core.
        clip_ratio: Clipping range for the probability ratio in the
            surrogate loss.
        entropy_coef: Weight of the entropy bonus in the actor loss.
    """

    def __init__(self, ob_dim, ac_dim, h_dim=128, clip_ratio=0.2, entropy_coef=0.01):
        super().__init__()
        self.h_dim = h_dim
        self.fc_in = nn.Sequential(Linear(ob_dim, h_dim), nn.ReLU())
        self.rnn = RNN(h_dim)
        self.fc_out = Linear(h_dim, ac_dim, gain=1.0)

        # Most negative finite float32; used to replace -inf log-probs so
        # that ``0 * log(0)`` evaluates to 0 instead of NaN in the entropy.
        self.min_real = torch.finfo(torch.float32).min
        self.clip_ratio = float(clip_ratio)
        self.entropy_coef = float(entropy_coef)

    @jit.script_method
    def forward(self, obs, rnn_state: Tuple[torch.Tensor, torch.Tensor]):
        """Compute unmasked logits for a single time step.

        Args:
            obs: Tensor of shape ``(n_envs, n_agents, ob_dim)``.
            rnn_state: ``(h, c)`` state of the recurrent core.

        Returns:
            Tuple of logits ``(n_envs, n_agents, ac_dim)`` and the new state.
        """
        x = self.fc_in.forward(obs)
        x, rnn_state = self.rnn.forward(x, rnn_state)
        x = self.fc_out.forward(x)
        return x, rnn_state

    @jit.script_method
    def forward_all(self, obs, rnn_state: Tuple[torch.Tensor, torch.Tensor], dones):
        """Compute unmasked logits for a whole rollout.

        Args:
            obs: Tensor of shape ``(n_steps, n_envs, n_agents, ob_dim)``.
            rnn_state: ``(h, c)`` state at the start of the rollout.
            dones: Termination mask forwarded to ``RNN.forward_all``.

        Returns:
            Tuple of logits ``(n_steps, n_envs, n_agents, ac_dim)`` and the
            final state.
        """
        x = self.fc_in.forward(obs)
        x, rnn_state = self.rnn.forward_all(x, rnn_state, dones)
        x = self.fc_out.forward(x)
        return x, rnn_state

    @jit.script_method
    def sample(self, obs, avails, rnn_state: Tuple[torch.Tensor, torch.Tensor], deterministic: bool = False):
        """Pick one action per agent for a single time step.

        Args:
            obs: Tensor of shape ``(n_envs, n_agents, ob_dim)``.
            avails: Action-availability tensor broadcastable to the logits;
                presumably 1 for allowed and 0 for forbidden actions --
                TODO confirm with the caller.
            rnn_state: ``(h, c)`` state of the recurrent core.
            deterministic: Take the arg-max action instead of sampling.

        Returns:
            Tuple ``(actions, log_probs, rnn_state)`` where ``actions`` and
            ``log_probs`` both have shape ``(n_envs, n_agents)``.
        """
        with torch.no_grad():
            logits, rnn_state = self.forward(obs, rnn_state)
        # log(0) = -inf removes unavailable actions from the distribution.
        logits = logits + avails.log()
        # Normalise so that ``logits`` are log-probabilities.
        logits = logits - logits.logsumexp(-1, True)
        n_envs, n_agents, ac_dim = logits.shape
        logits_2d = logits.view(-1, ac_dim)
        probs_2d = logits_2d.softmax(-1)
        if deterministic:
            actions_2d = probs_2d.argmax(-1, True)
        else:
            actions_2d = probs_2d.multinomial(1)
        actions = actions_2d.view(n_envs, n_agents)
        log_probs_2d = logits_2d.gather(-1, actions_2d)
        log_probs = log_probs_2d.view(n_envs, n_agents)
        return actions, log_probs, rnn_state

    @jit.script_method
    def compute_loss(self, obs, avails, actions, old_log_probs, dones, advantages, rnn_state: Tuple[torch.Tensor, torch.Tensor]):
        """Compute the clipped-surrogate actor loss and back-propagate it.

        Calls ``backward()`` on the loss, so gradients are accumulated into
        the parameters as a side effect; no optimizer step is taken here.

        Args:
            obs: Tensor of shape ``(n_steps, n_envs, n_agents, ob_dim)``.
            avails: Action-availability tensor broadcastable to the logits.
            actions: Actions taken; reshaped to one index per
                (step, env, agent).
            old_log_probs: Log-probabilities of ``actions`` under the policy
                that collected the data.
            dones: Termination mask forwarded to ``RNN.forward_all``.
            advantages: Advantage estimates, multiplied element-wise with the
                ``(n_steps, n_envs, n_agents)`` ratio.
            rnn_state: ``(h, c)`` state at the start of the rollout.

        Returns:
            Tuple ``(pg_loss, entropy, rnn_state)``: the policy-gradient loss
            and mean entropy as Python scalars, and the detached final state.
        """
        # NOTE: advantages are used exactly as given; no normalisation is
        # applied here.
        logits, rnn_state = self.forward_all(obs, rnn_state, dones)
        logits = logits + avails.log()
        logits = logits - logits.logsumexp(-1, True)
        n_steps, n_envs, n_agents, ac_dim = logits.shape
        logits_2d = logits.view(-1, ac_dim)
        probs_2d = logits_2d.softmax(-1)
        actions_2d = actions.view(-1, 1)
        log_probs_2d = logits_2d.gather(-1, actions_2d)
        log_probs = log_probs_2d.view(n_steps, n_envs, n_agents)

        # Replace -inf with a finite value so masked actions contribute
        # 0 * finite = 0 to the entropy rather than 0 * -inf = NaN.
        logits_2d = logits_2d.clamp_min(self.min_real)
        entropy = -(probs_2d * logits_2d).sum(-1)

        ratio = (log_probs - old_log_probs).exp()
        # NOTE: advantages must already match the shape of ``ratio`` (or
        # broadcast against it); no trailing axis is added here.
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * advantages
        pg_loss = -torch.minimum(surr1, surr2).mean()
        # Despite the name this is the mean entropy; it is subtracted below
        # so that higher entropy lowers the loss.
        entropy_loss = entropy.mean()
        actor_loss = pg_loss - self.entropy_coef * entropy_loss
        actor_loss.backward()
        # Detach so the next call does not back-propagate into this graph.
        rnn_state = (rnn_state[0].detach(), rnn_state[1].detach())
        return pg_loss.item(), entropy_loss.item(), rnn_state


class Critic(jit.ScriptModule):
    """Recurrent value network producing one scalar value per env.

    Pipeline: ``Linear + ReLU`` encoder -> ``RNN`` core -> ``Linear`` head.
    A singleton agent axis is inserted before the recurrent core, which
    expects an agent axis, and removed again afterwards.

    Args:
        st_dim: Size of the state feature axis.
        h_dim: Hidden width of the encoder and recurrent core.
        clip_ratio: Clipping range for the value update in the loss.
    """

    def __init__(self, st_dim, h_dim=128, clip_ratio=0.2):
        super().__init__()
        self.h_dim = h_dim
        self.fc_in = nn.Sequential(Linear(st_dim, h_dim), nn.ReLU())
        self.rnn = RNN(h_dim)
        self.fc_out = Linear(h_dim, 1, gain=0.01)

        self.clip_ratio = float(clip_ratio)

    @jit.script_method
    def forward(self, states, rnn_state: Tuple[torch.Tensor, torch.Tensor]):
        """Compute values for a single time step.

        Args:
            states: Tensor of shape ``(n_envs, st_dim)``.
            rnn_state: ``(h, c)`` state of the recurrent core.

        Returns:
            Tuple of values with shape ``(n_envs,)`` and the new state.
        """
        x = self.fc_in.forward(states)
        # Insert the singleton agent axis just before the feature axis.
        # Using dim 2 here would put it after the features of a 2-D input
        # and break the LSTM input size.
        x = x.unsqueeze(-2)
        x, rnn_state = self.rnn.forward(x, rnn_state)
        x = x.squeeze(-2)
        x = self.fc_out.forward(x).squeeze(-1)
        return x, rnn_state

    @jit.script_method
    def forward_all(self, states, rnn_state: Tuple[torch.Tensor, torch.Tensor], dones):
        """Compute values for a whole rollout.

        Args:
            states: Tensor of shape ``(n_steps, n_envs, st_dim)``.
            rnn_state: ``(h, c)`` state at the start of the rollout.
            dones: Termination mask forwarded to ``RNN.forward_all``.

        Returns:
            Tuple of values with shape ``(n_steps, n_envs)`` and the final
            state.
        """
        x = self.fc_in.forward(states)
        # Singleton agent axis just before the feature axis (dim 2 for the
        # 3-D input this method requires).
        x = x.unsqueeze(-2)
        x, rnn_state = self.rnn.forward_all(x, rnn_state, dones)
        x = x.squeeze(-2)
        x = self.fc_out.forward(x).squeeze(-1)
        return x, rnn_state

    @jit.script_method
    def compute_loss(self, states, rnn_state: Tuple[torch.Tensor, torch.Tensor], dones, value_preds, returns):
        """Compute the clipped value loss and back-propagate it.

        Calls ``backward()`` on the loss, so gradients are accumulated into
        the parameters as a side effect; no optimizer step is taken here.

        Args:
            states: Tensor of shape ``(n_steps, n_envs, st_dim)``.
            rnn_state: ``(h, c)`` state at the start of the rollout.
            dones: Termination mask forwarded to ``RNN.forward_all``.
            value_preds: Value predictions recorded when the data was
                collected; same shape as the new values.
            returns: Regression targets; same shape as the new values.

        Returns:
            Tuple ``(vf_loss, rnn_state)``: the loss as a Python scalar and
            the detached final state.
        """
        values, rnn_state = self.forward_all(states, rnn_state, dones)
        # Limit how far the new prediction may move from the old one.
        value_pred_clipped = value_preds + (values - value_preds).clamp(-self.clip_ratio, self.clip_ratio)
        vf_loss1 = F.mse_loss(values, returns, reduction='none')
        vf_loss2 = F.mse_loss(value_pred_clipped, returns, reduction='none')
        # Element-wise maximum of the clipped and unclipped squared errors.
        vf_loss = 0.5 * torch.maximum(vf_loss1, vf_loss2).mean()
        vf_loss.backward()
        # Detach so the next call does not back-propagate into this graph.
        rnn_state = (rnn_state[0].detach(), rnn_state[1].detach())
        return vf_loss.item(), rnn_state