import torch
import torch.nn as nn
import torch.jit as jit


class Net(jit.ScriptModule):
    """Feed-forward scorer: ``Linear(ob_dim, h_dim) -> tanh -> Linear(h_dim, ac_dim)``.

    Args:
        ob_dim: size of the last (feature) dimension of the input.
        ac_dim: number of output scores produced per input row.
        h_dim: width of the hidden layer.
    """

    def __init__(self, ob_dim, ac_dim, h_dim):
        super().__init__()
        # Attribute names (and their registration order) determine the
        # state_dict keys; BC also reads ``fc1.weight`` to find the device.
        self.fc1 = nn.Linear(ob_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, ac_dim)

    @jit.script_method
    def reset(self) -> None:
        """No-op hook: this network holds no internal state to clear."""
        return None

    @jit.script_method
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Map ``inputs`` (features on the last axis) to ``ac_dim`` scores."""
        hidden = self.fc1(inputs).tanh()
        return self.fc2(hidden)


class BC(jit.ScriptModule):
    """Behaviour-cloning policy whose action scores come from a single ``Net``.

    ``mode``/``sample`` pick the highest-scoring available action, and
    ``compute_loss`` minimises the negative log-likelihood (log-softmax over
    the scores) of the demonstrated actions.
    """

    def __init__(self, ob_dim, st_dim, ac_dim, n_agents, h_dim):
        super().__init__()
        # NOTE(review): st_dim and n_agents are accepted but unused here;
        # presumably kept for constructor parity with other agents -- confirm.
        self.eval_Q_net = Net(ob_dim, ac_dim, h_dim)

    @jit.script_method
    def mode(self, obs, avails):
        """Return the greedy action index (on CPU) among available actions.

        ``log(avails)`` is added to the scores, so entries where ``avails`` is
        0 become ``-inf`` and are never chosen by ``argmax`` (assumes
        ``avails`` is a 0/1 availability mask -- TODO confirm).
        Runs under ``no_grad``.
        """
        with torch.no_grad():
            device = self.eval_Q_net.fc1.weight.device
            # Insert/remove a singleton axis at -2 around the network call;
            # for the feed-forward Net this leaves the values unchanged.
            q_value = self.eval_Q_net.forward(obs.to(device).unsqueeze(-2)).squeeze(-2)
            q_value = q_value + torch.log(avails.to(device))
            actions = q_value.argmax(-1).cpu()
            return actions

    @jit.script_method
    def sample(self, obs, avails):
        """Deterministic: identical to ``mode``."""
        return self.mode(obs, avails)

    @jit.script_method
    def compute_loss(self, mb_obs, mb_avails, mb_actions, mb_actives):
        """Compute the behaviour-cloning NLL loss and call ``backward`` on it.

        Args:
            mb_obs: observations, unpacked as (batch, seq, n_agents, ob_dim).
            mb_avails: accepted for interface compatibility; not used.
            mb_actions: int64 action indices; a trailing index axis is added
                and they are gathered from the log-probs of steps ``[:-1]``.
            mb_actives: accepted for interface compatibility; not used.

        Returns:
            ``(loss_value, 0.0)``: the scalar loss as a Python number and a
            constant second slot.

        Gradients are accumulated into the parameters' ``.grad``; this method
        does not zero them or step an optimiser.

        NOTE(review): earlier code computed ``mb_avails.log()`` and
        ``mb_actives.squeeze(-1)`` but never used them, so the loss neither
        masks unavailable actions nor excludes inactive steps. Those dead
        computations were removed; if masking is intended it must be added
        with the tensor shapes confirmed.
        """
        device = self.eval_Q_net.fc1.weight.device
        mb_obs = mb_obs.to(device)
        mb_actions = mb_actions.to(device).unsqueeze(-1)
        batch_size, seq_length, n_agents, _ = mb_obs.shape
        # Fold agents into the batch axis: (B, T, N, D) -> (B * N, T, D).
        my_obs = mb_obs.transpose(1, 2).flatten(0, 1)
        q_flat = self.eval_Q_net.forward(my_obs)
        # Restore the (B, T, N, ac_dim) layout.
        q_all = q_flat.reshape(batch_size, n_agents, seq_length, -1).transpose(1, 2)
        # The last step is dropped; assumes mb_actions covers steps [:-1]
        # (gather requires matching leading dims) -- TODO confirm.
        log_probs = q_all[:, :-1].log_softmax(-1)
        log_probs = log_probs.gather(-1, mb_actions)
        loss = -log_probs.mean()
        loss.backward()
        return loss.item(), 0.0