import torch
import torch.jit as jit
from algorithms.base import DQN
from network.utils import Discriminator


class GAIL(DQN):
    """DQN-family multi-agent learner that additionally owns a ``Discriminator``.

    NOTE(review): ``self.disc`` is constructed in ``__init__`` but is not used
    by ``compute_loss``. The TD targets are built from ``mb_rewards`` and the
    second value returned by ``compute_loss`` is a hard-coded ``0.0``. Confirm
    whether a discriminator-based reward / loss is still to be wired in.
    """

    def __init__(self, ob_dim, st_dim, ac_dim, n_agents, h_dim):
        # All sizes are forwarded unchanged to the DQN base class. The names
        # presumably mean observation / state / action sizes, agent count and
        # hidden width (TODO confirm against DQN).
        super().__init__(ob_dim, st_dim, ac_dim, n_agents, h_dim)
        # The discriminator receives the state/action sizes, agent count and
        # hidden width (not ob_dim).
        self.disc = Discriminator(st_dim, ac_dim, n_agents, h_dim)

    @jit.script_method
    def compute_loss(self, mb_obs, mb_states, mb_avails, mb_actions, mb_rewards, mb_dones, mb_actives, gamma: float):
        """Compute the masked, mixed double-Q TD loss for a batch and backprop it.

        Actions for the bootstrap target are chosen by the online Q-net and
        evaluated by the target Q-net (double Q-learning). Per-agent values are
        combined by ``eval_mix_net`` / ``target_mix_net`` using the global state.

        Args:
            mb_obs: 4-D tensor unpacked as (batch, seq, n_agents, ob_dim).
            mb_states: global states; steps ``[:-1]`` and ``[1:]`` along dim 1
                feed the online and target mixers respectively.
            mb_avails: action-availability mask. Its ``log()`` is added to the
                next-step Q values before argmax, so presumably 0/1 with
                unavailable -> -inf (TODO confirm dtype/values).
            mb_actions: taken action indices, gathered from the current-step Q
                values along the last dim after ``unsqueeze(-1)``.
            mb_rewards: rewards used in the TD target.
            mb_dones: termination flags (cast to float). A done step drops the
                bootstrap term.
            mb_actives: mask multiplied into the TD error. Its sum normalises
                the loss.
            gamma: discount factor.

        Returns:
            ``(loss.item(), 0.0)``. The second element is a placeholder.

        Side effects:
            Calls ``reset()`` on both Q-nets and ``loss.backward()``, so
            gradients are accumulated into the trainable parameters.
        """
        device = self.eval_Q_net.fc1.weight.device
        mb_obs = mb_obs.to(device)
        mb_states = mb_states.to(device)
        # Turn the availability mask into an additive mask for argmax below.
        mb_avails = mb_avails.to(device).log()
        mb_actions = mb_actions.to(device).unsqueeze(-1)
        mb_rewards = mb_rewards.to(device)
        mb_dones = mb_dones.to(device).float()
        mb_actives = mb_actives.to(device)
        self.eval_Q_net.reset()
        self.target_Q_net.reset()
        batch_size, seq_length, n_agents, _ = mb_obs.shape
        # (batch, seq, agents, ob) -> (batch * agents, seq, ob): one sequence per agent.
        my_obs = mb_obs.transpose(1, 2).flatten(0, 1)
        # Back to (batch, seq, agents, n_actions).
        q_all = self.eval_Q_net.forward(my_obs).reshape(batch_size, n_agents, seq_length, -1).transpose(1, 2)
        curr_q = q_all[:, :-1]
        next_q = q_all[:, 1:]
        curr_s = mb_states[:, :-1]
        next_s = mb_states[:, 1:]
        q_evals_local = torch.gather(curr_q, -1, mb_actions).squeeze(-1)
        q_evals_global = self.eval_mix_net.forward(q_evals_local, curr_s)
        with torch.no_grad():
            # The online net picks the greedy next action among available ones.
            a_argmax_target = torch.argmax(next_q + mb_avails[:, 1:], -1, True)
            # Run the target net over the FULL sequence from the same reset
            # state as the online net, then drop step 0. Feeding only steps
            # 1: would build any internal (recurrent) state without
            # observation 0, misaligning it with the online net.
            q_targets_all = self.target_Q_net.forward(my_obs).reshape(batch_size, n_agents, seq_length, -1).transpose(1, 2)[:, 1:]
            q_targets_local = torch.gather(q_targets_all, -1, a_argmax_target).squeeze(-1)
            q_targets_global = self.target_mix_net.forward(q_targets_local, next_s)
            y_global = mb_rewards + gamma * (1 - mb_dones) * q_targets_global
        td_error = q_evals_global - y_global
        mask_td_error = td_error * mb_actives
        # Clamp the normaliser so a batch with no active entries yields a 0 loss
        # instead of 0/0 = NaN, which would poison every gradient in backward().
        loss = mask_td_error.pow(2).sum() / mb_actives.sum().clamp(min=1)
        loss.backward()
        return loss.item(), 0.0