import torch
import torch.nn as nn
import torch.jit as jit
import torch.nn.functional as F


class Linear(nn.Linear):
    """``nn.Linear`` with orthogonal weight init and a constant bias init.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        gain: scaling factor passed to ``nn.init.orthogonal_`` for the weight.
        init_bias: constant the bias is filled with; ignored when ``bias=False``.
        bias: if ``False``, the layer has no additive bias (``self.bias is None``).
        device: forwarded to ``nn.Linear``.
        dtype: forwarded to ``nn.Linear``.
    """

    def __init__(self, in_features, out_features, gain=1.0, init_bias=0.0, bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        # Overwrite nn.Linear's default (Kaiming-uniform) weight init.
        nn.init.orthogonal_(self.weight, gain=gain)
        # With bias=False nn.Linear registers ``bias`` as None, so the
        # constant init only applies when a bias parameter actually exists.
        if self.bias is not None:
            nn.init.constant_(self.bias, init_bias)


class Discriminator(jit.ScriptModule):
    """TorchScript MLP discriminator over (state, one-hot joint action) inputs.

    The network maps ``st_dim + ac_dim * n_agents`` input features through one
    ReLU hidden layer of width ``h_dim`` to a single raw logit. In
    ``compute_loss`` expert (``ex_*``) samples are treated as label 1 and
    policy (``pi_*``) samples as label 0 under a sigmoid model of the logit.
    """

    def __init__(self, st_dim: int, ac_dim: int, n_agents: int, h_dim: int):
        super().__init__()
        # Input width = state features + flattened one-hot actions of all
        # agents (see forward for how the two are concatenated).
        self.net = nn.Sequential(Linear(st_dim + ac_dim*n_agents, h_dim), nn.ReLU(), Linear(h_dim, 1))
        # Number of discrete actions per agent; num_classes for F.one_hot.
        self.ac_dim = ac_dim

    @jit.script_method
    def forward(self, states, actions):
        # Returns raw (pre-sigmoid) logits of shape (n_batch, n_steps, 1).
        # actions: 3-D int64 tensor (n_batch, n_steps, k); F.one_hot requires
        # an integer (Long) tensor. k * ac_dim must equal ac_dim * n_agents for
        # the first Linear layer, so k is presumably the number of agents.
        n_batch, n_steps, _ = actions.shape
        # (n_batch, n_steps, k) -> (n_batch, n_steps, k, ac_dim)
        actions = F.one_hot(actions, self.ac_dim)
        # Flatten per-agent one-hots into one joint-action feature vector.
        actions = actions.reshape(n_batch, n_steps, -1)
        # states[:, :-1] drops the last time step; for torch.cat to succeed,
        # states must therefore have n_steps + 1 steps along dim 1. The int64
        # one-hot block is type-promoted by torch.cat to the states' dtype.
        return self.net.forward(torch.cat((states[:, :-1], actions), -1))

    @jit.script_method
    def calculate_reward(self, states, actions):
        # Reward = -log(1 - sigmoid(logit)) = -logsigmoid(-logit), computed
        # without autograd tracking and returned as a CPU tensor.
        with torch.no_grad():
            return -F.logsigmoid(-self.forward(states, actions)).detach().cpu()

    @jit.script_method
    def compute_loss(self, pi_states, pi_actions, ex_states, ex_actions, pi_actives, ex_actives):
        # pi_actives / ex_actives are squeezed on their last dim and used to
        # index the (n_batch, n_steps, 1) logits -- presumably boolean masks
        # of valid time steps; TODO confirm with caller. Unlike
        # AIRLDiscriminator, no inputs are moved to the module's device here.
        logits_pi = self.forward(pi_states, pi_actions)[pi_actives.squeeze(-1)]
        logits_exp = self.forward(ex_states, ex_actions)[ex_actives.squeeze(-1)]
        # Policy samples: -log(1 - D); expert samples: -log D, D = sigmoid(logit).
        # NOTE: if a mask selects no elements, .mean() of the empty tensor is NaN.
        loss_pi = -F.logsigmoid(-logits_pi).mean()
        loss_exp = -F.logsigmoid(logits_exp).mean()
        loss_disc = loss_pi + loss_exp
        # Accumulates gradients into the parameters' .grad; this method does
        # not zero grads or step an optimizer.
        loss_disc.backward()
        # Loss components returned as Python numbers for logging.
        return loss_pi.item(), loss_exp.item()


class AIRLDiscriminator(jit.ScriptModule):
    """TorchScript AIRL-style discriminator built from two state-only MLPs.

    The raw logit is

        g(s) + gamma * (1 - done) * h(s') - h(s) - log_pi

    where ``g`` and ``h`` are one-hidden-layer ReLU MLPs over state features.
    In ``compute_loss`` expert (``ex_*``) samples are treated as label 1 and
    policy (``pi_*``) samples as label 0 under a sigmoid model of the logit.
    """

    def __init__(self, st_dim: int, h_dim: int, gamma: float = 0.995):
        super().__init__()
        # g(s): the per-state term (``rs``) of the logit.
        self.g = nn.Sequential(Linear(st_dim, h_dim), nn.ReLU(), Linear(h_dim, 1))
        # h(s): potential evaluated at s and s' (enters as gamma*h(s') - h(s)).
        self.h = nn.Sequential(Linear(st_dim, h_dim), nn.ReLU(), Linear(h_dim, 1))
        # NOTE(review): this LayerNorm is never applied in the code visible
        # here; it is only read for its weight's device in forward and
        # compute_loss. Its parameters are still in state_dict, so removing
        # it would presumably break loading existing checkpoints.
        self.norm = nn.LayerNorm(h_dim)
        # Discount factor applied to h(s') in forward.
        self.gamma = gamma

    @jit.script_method
    def forward(self, states, dones, log_pis, next_states):
        # Returns the raw (pre-sigmoid) logit
        #   g(s) + gamma * (1 - dones) * h(s') - h(s) - log_pis.
        # Shapes are not checked: g/h outputs end in a size-1 dim, and dones /
        # log_pis presumably carry a matching trailing size-1 dim so the sum
        # broadcasts element-wise -- TODO confirm with callers.
        # dones must be numeric: `1 - dones` raises for a bool tensor.
        # All inputs are moved to the module's device (taken from
        # self.norm.weight), so callers may pass tensors on another device.
        device = self.norm.weight.device
        states = states.to(device)
        dones = dones.to(device)
        log_pis = log_pis.to(device)
        next_states = next_states.to(device)
        rs = self.g.forward(states)
        vs = self.h.forward(states)
        next_vs = self.h.forward(next_states)
        # (1 - dones) removes the h(s') term on terminal transitions.
        return rs + self.gamma * (1 - dones) * next_vs - vs - log_pis

    @jit.script_method
    def calculate_reward(self, states, dones, log_pis, next_states):
        # Reward = -log(1 - sigmoid(logit)) = -logsigmoid(-logit), computed
        # without autograd tracking and returned as a CPU tensor.
        with torch.no_grad():
            return -F.logsigmoid(-self.forward(states, dones, log_pis, next_states)).detach().cpu()

    @jit.script_method
    def compute_loss(self, pi_states, pi_dones, pi_log_pis, pi_next_states, ex_states, ex_dones, ex_log_exs, ex_next_states, pi_actives, ex_actives):
        # Masks are moved to the module's device so they can index the logits
        # (forward already moves the other inputs).
        device = self.norm.weight.device
        pi_actives = pi_actives.to(device)
        ex_actives = ex_actives.to(device)
        # pi_actives / ex_actives are squeezed on their last dim and used as
        # an index -- presumably boolean masks of valid time steps; TODO
        # confirm with caller.
        logits_pi = self.forward(pi_states, pi_dones, pi_log_pis, pi_next_states)[pi_actives.squeeze(-1)]
        logits_exp = self.forward(ex_states, ex_dones, ex_log_exs, ex_next_states)[ex_actives.squeeze(-1)]
        # Policy samples: -log(1 - D); expert samples: -log D, D = sigmoid(logit).
        # NOTE: if a mask selects no elements, .mean() of the empty tensor is NaN.
        loss_pi = -F.logsigmoid(-logits_pi).mean()
        loss_exp = -F.logsigmoid(logits_exp).mean()
        loss_disc = loss_pi + loss_exp
        # Accumulates gradients into the parameters' .grad; this method does
        # not zero grads or step an optimizer.
        loss_disc.backward()
        # Loss components returned as Python numbers for logging.
        return loss_pi.item(), loss_exp.item()