import torch
import torch.jit as jit
from algorithms.base import DQN


class IQVDN(DQN):
    """VDN-style multi-agent Q-learning trained with an implicit-reward objective.

    Per-agent Q-values are summed over agents into a single global Q (VDN-style
    decomposition). The loss appears to follow an IQ-Learn-style
    (chi^2-regularised) inverse soft-Q objective -- NOTE(review): confirm
    against the reference algorithm. The reward slot of a minibatch is not used
    as an environment reward: ``compute_loss`` reads it as a per-sample expert
    indicator.
    """

    def __init__(self, ob_dim, st_dim, ac_dim, n_agents, h_dim):
        super().__init__(ob_dim, st_dim, ac_dim, n_agents, h_dim)

    @jit.script_method
    def _safe_mean(self, values: torch.Tensor) -> torch.Tensor:
        """Return ``values.mean()``, or 0 when ``values`` is empty.

        ``Tensor.mean()`` of an empty tensor is NaN, which would poison the
        whole loss and every gradient. ``Tensor.sum()`` of an empty tensor is 0
        and stays attached to the autograd graph, so ``backward()`` still works.
        """
        if values.numel() == 0:
            return values.sum()
        return values.mean()

    @jit.script_method
    def phi(self, rewards: torch.Tensor, masks: torch.Tensor, is_experts: torch.Tensor) -> torch.Tensor:
        """Implicit-reward term: ``0.5 * mean_valid(r^2) - mean_valid_expert(r)``.

        Args:
            rewards: implicit rewards ``r = Q(s, a) - y``, where ``y`` is the
                discounted bootstrapped target value.
            masks: selection of valid (active) samples; presumably boolean,
                since it is used for masked indexing.
            is_experts: selection of expert samples; presumably boolean, and
                indexes both ``rewards`` and ``masks``.

        Returns:
            Scalar tensor. An empty selection contributes 0 instead of NaN.
        """
        # Expert samples that are also valid.
        ex_rewards = rewards[is_experts][masks[is_experts]]
        valid_rewards = rewards[masks]
        return 0.5 * self._safe_mean(valid_rewards.pow(2)) - self._safe_mean(ex_rewards)

    @jit.script_method
    def delta(self, rewards: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        """Return the mean of ``rewards`` over the entries selected by ``masks``.

        Returns 0 (not NaN) when ``masks`` selects nothing.
        """
        return self._safe_mean(rewards[masks])

    @jit.script_method
    def compute_loss(self, mb_obs, mb_states, mb_avails, mb_actions, mb_rewards, mb_dones, mb_actives, gamma: float):
        """Compute the loss on one minibatch, backpropagate it, and return its parts.

        ``loss.backward()`` is called here, so gradients accumulate into the
        parameters that produced the loss. The optimizer step and zero_grad
        presumably happen in the caller.

        Args:
            mb_obs: observations, unpacked as
                (batch_size, seq_length, n_agents, obs_dim).
            mb_states: moved nowhere and unused by this loss. It is kept for
                interface compatibility.
            mb_avails: action-availability values. Their log is added to the
                Q-values before argmax, so 0 entries become -inf (presumably a
                0/1 mask of shape (batch, seq_length, n_agents, ac_dim)).
            mb_actions: chosen action indices, gathered against the Q-values of
                steps 0..seq_length-2. Presumably (batch, seq_length-1, n_agents).
            mb_rewards: NOT an environment reward. It is squeezed on the last
                dim and used as the expert-sample selector.
            mb_dones: termination flags. They zero out the bootstrapped target.
            mb_actives: validity flags. They are squeezed on the last dim and
                used as the sample mask.
            gamma: discount factor.

        Returns:
            ``(loss_1, loss_2)`` as Python numbers: the ``delta`` (value) term
            and the ``phi`` (implicit-reward) term.
        """
        device = self.eval_Q_net.fc1.weight.device
        mb_obs = mb_obs.to(device)
        # log(1) = 0 and log(0) = -inf: adding this to Q-values removes
        # unavailable actions from the argmax.
        mb_avails = mb_avails.to(device).log()
        mb_actions = mb_actions.to(device).unsqueeze(-1)
        mb_rewards = mb_rewards.to(device)
        mb_dones = mb_dones.to(device).float()
        mb_actives = mb_actives.to(device)
        self.eval_Q_net.reset()
        self.target_Q_net.reset()
        batch_size, seq_length, n_agents, _ = mb_obs.shape
        # (batch, seq, agents, dim) -> (batch * agents, seq, dim): one sequence per agent.
        my_obs = mb_obs.transpose(1, 2).flatten(0, 1)
        q_all = self.eval_Q_net.forward(my_obs).reshape(batch_size, n_agents, seq_length, -1).transpose(1, 2)
        curr_q = q_all[:, :-1]
        next_q = q_all[:, 1:]
        q_evals_local = torch.gather(curr_q, -1, mb_actions).squeeze(-1)
        # VDN: the global Q-value is the sum of the per-agent Q-values.
        q_evals_global = torch.sum(q_evals_local, -1, True)
        a_argmax_eval = torch.argmax(curr_q + mb_avails[:, :-1], -1, True)
        v_evals_local = torch.gather(curr_q, -1, a_argmax_eval).squeeze(-1)
        v_evals_global = torch.sum(v_evals_local, -1, True)
        with torch.no_grad():
            # Double-DQN style: the eval net picks the next action and the
            # target net evaluates it.
            a_argmax_target = torch.argmax(next_q + mb_avails[:, 1:], -1, True)
            # Run the target net over the full sequence, exactly like the eval
            # net, and drop step 0 afterwards. After reset() both nets have
            # then seen the same history at every step. Feeding my_obs[:, 1:]
            # would start the target net one observation late whenever it
            # carries state between steps; for a stateless net both are equal.
            q_targets_all = self.target_Q_net.forward(my_obs).reshape(batch_size, n_agents, seq_length, -1).transpose(1, 2)[:, 1:]
            q_targets_local = torch.gather(q_targets_all, -1, a_argmax_target).squeeze(-1)
            q_targets_global = torch.sum(q_targets_local, -1, True)
            y_global = gamma * (1 - mb_dones) * q_targets_global
        masks = mb_actives.squeeze(-1)
        # The reward slot carries the expert indicator, not an environment reward.
        is_experts = mb_rewards.squeeze(-1)
        loss_1 = self.delta(v_evals_global - y_global, masks)
        loss_2 = self.phi(q_evals_global - y_global, masks, is_experts)
        loss = loss_1 + loss_2
        loss.backward()
        return loss_1.item(), loss_2.item()