import torch
import torch.jit as jit
from algorithms.base import DQN, SoftQ


class MIFQ(DQN):
    """Mixing-network Q-learner trained with an inverse soft-Q style objective.

    ``compute_loss`` builds ``delta(V-residual) + phi(Q-residual)``, where both
    residuals are taken against a bootstrapped target ``gamma * (1 - done) * V'``.
    Presumably this is an IQ-Learn-style objective adapted to value
    decomposition; confirm against the training config / paper.

    Args:
        ob_dim, st_dim, ac_dim, n_agents, h_dim, activation: forwarded unchanged
            to ``DQN.__init__``.
        value_expert: if True, ``delta`` averages only over expert samples;
            otherwise over all active samples.
        chi_expert: if True, the squared term of ``phi`` averages only over
            expert samples; otherwise over all active samples.
    """

    def __init__(self, ob_dim, st_dim, ac_dim, n_agents, h_dim, activation="elu", value_expert=False, chi_expert=False):
        super().__init__(ob_dim, st_dim, ac_dim, n_agents, h_dim, activation)
        self.value_expert = value_expert
        self.chi_expert = chi_expert

    @jit.script_method
    def phi(self, rewards, masks, is_experts):
        """Return the reward-side loss term ``0.5 * mean(r**2) - mean(r_expert)``.

        ``rewards`` holds per-sample residuals r. ``is_experts`` selects the
        expert rows, and ``masks`` then keeps only the active entries. The
        squared term is averaged over expert entries when ``self.chi_expert``
        is set, and over all active entries otherwise. The linear term always
        uses the active expert entries.

        NOTE: ``mean()`` of an empty selection is NaN, so a batch with no
        active expert entries produces a NaN loss.
        """
        ex_rewards = rewards[is_experts]
        ex_masks = masks[is_experts]
        ex_rewards = ex_rewards[ex_masks]
        if self.chi_expert:
            return 1/2 * ex_rewards.pow(2).mean() - ex_rewards.mean()
        else:
            rewards = rewards[masks]
            return 1/2 * rewards.pow(2).mean() - ex_rewards.mean()

    @jit.script_method
    def delta(self, rewards, masks, is_experts):
        """Return the value-side loss term: the mean residual over selected entries.

        If ``self.value_expert`` is set, only active expert entries are
        averaged. Otherwise all active entries are averaged.
        """
        if self.value_expert:
            ex_rewards = rewards[is_experts]
            ex_masks = masks[is_experts]
            ex_rewards = ex_rewards[ex_masks]
            return ex_rewards.mean()
        else:
            rewards = rewards[masks]
            return rewards.mean()

    @jit.script_method
    def compute_loss(self, mb_obs, mb_states, mb_avails, mb_actions, mb_rewards, mb_dones, mb_actives, gamma: float):
        """Compute the loss on one minibatch, backpropagate it, and return both parts.

        ``mb_obs`` is unpacked as ``(batch, seq, n_agents, feat)``. Time is
        dim 1 for ``mb_states`` and ``mb_avails``. Assumes ``mb_actions``,
        ``mb_dones``, ``mb_actives`` and ``mb_rewards`` are aligned with the
        first ``seq - 1`` steps (TODO confirm with the replay buffer).

        ``mb_rewards`` is not used as a reward here. After ``squeeze(-1)`` it
        is used as the boolean expert selector in ``phi``/``delta``, so it
        must be a bool tensor (confirm with caller).

        Calls ``loss.backward()``, so gradients accumulate on the online
        networks. Returns ``(delta_loss, phi_loss)`` as Python numbers.
        """
        device = self.eval_Q_net.fc1.weight.device
        mb_obs = mb_obs.to(device)
        mb_states = mb_states.to(device)
        # Assuming a 0/1 availability mask: log maps unavailable actions to -inf,
        # so they never win the argmax below.
        mb_avails = mb_avails.to(device).log()
        mb_actions = mb_actions.to(device).unsqueeze(-1)
        mb_rewards = mb_rewards.to(device)
        mb_dones = mb_dones.to(device).float()
        mb_actives = mb_actives.to(device)
        self.eval_Q_net.reset()
        self.target_Q_net.reset()
        batch_size, seq_length, n_agents, _ = mb_obs.shape
        # (B, T, N, D) -> (B*N, T, D): each agent's trajectory becomes one sequence.
        my_obs = mb_obs.transpose(1, 2).flatten(0, 1)
        q_all = self.eval_Q_net.forward(my_obs).reshape(batch_size, n_agents, seq_length, -1).transpose(1, 2)
        curr_q = q_all[:, :-1]
        next_q = q_all[:, 1:]
        curr_s = mb_states[:, :-1]
        next_s = mb_states[:, 1:]
        q_evals_local = torch.gather(curr_q, -1, mb_actions).squeeze(-1)
        q_evals_global = self.eval_mix_net.forward(q_evals_local, curr_s)
        # V(s) = Q(s, argmax over available actions) from the online network.
        a_argmax_eval = torch.argmax(curr_q + mb_avails[:, :-1], -1, True)
        v_evals_local = torch.gather(curr_q, -1, a_argmax_eval).squeeze(-1)
        v_evals_global = self.eval_mix_net.forward(v_evals_local, curr_s)
        with torch.no_grad():
            # Double-Q style: the online net picks the next action, the target net scores it.
            a_argmax_target = torch.argmax(next_q + mb_avails[:, 1:], -1, True)
            # Run the target net over the FULL sequence, as the online net does, and
            # only then drop step 0. Feeding my_obs[:, 1:] would hide step 0 from
            # any per-timestep state the network keeps (it exposes reset()).
            q_targets_all = self.target_Q_net.forward(my_obs).reshape(batch_size, n_agents, seq_length, -1).transpose(1, 2)
            next_q_targets = q_targets_all[:, 1:]
            q_targets_local = torch.gather(next_q_targets, -1, a_argmax_target).squeeze(-1)
            q_targets_global = self.target_mix_net.forward(q_targets_local, next_s)
            y_global = gamma * (1 - mb_dones) * q_targets_global
        masks = mb_actives.squeeze(-1)
        is_experts = mb_rewards.squeeze(-1)
        loss_1 = self.delta(v_evals_global - y_global, masks, is_experts)
        loss_2 = self.phi(q_evals_global - y_global, masks, is_experts)
        loss = loss_1 + loss_2
        loss.backward()
        return loss_1.item(), loss_2.item()


class SoftMIFQ(SoftQ):
    """Soft (entropy-regularised) variant of the mixing-network inverse-Q learner.

    Both state values come from a temperature-scaled log-sum-exp over the
    available actions: ``alpha * logsumexp(Q / alpha + log_avail)``. The
    objective is ``delta(V-residual) + phi(Q-residual)`` against the bootstrap
    ``gamma * (1 - done) * V_target(s')``.

    Args:
        ob_dim, st_dim, ac_dim, n_agents, h_dim, activation: forwarded unchanged
            to ``SoftQ.__init__``.
        value_expert: restrict ``delta`` to expert samples (default True).
        chi_expert: restrict the squared term of ``phi`` to expert samples
            (default True).
    """

    def __init__(self, ob_dim, st_dim, ac_dim, n_agents, h_dim, activation="elu", value_expert=True, chi_expert=True):
        super().__init__(ob_dim, st_dim, ac_dim, n_agents, h_dim, activation)
        self.value_expert = value_expert
        self.chi_expert = chi_expert

    @jit.script_method
    def phi(self, rewards, masks, is_experts):
        """Return ``0.5 * mean(s**2) - mean(e)`` for residuals ``rewards``.

        ``e`` is the set of active expert entries: rows picked by
        ``is_experts``, then entries picked by ``masks``. ``s`` equals ``e``
        when ``self.chi_expert`` is set, otherwise it is every active entry.
        An empty selection makes ``mean()`` NaN.
        """
        expert_vals = rewards[is_experts][masks[is_experts]]
        squared_vals = expert_vals if self.chi_expert else rewards[masks]
        return 0.5 * squared_vals.pow(2).mean() - expert_vals.mean()

    @jit.script_method
    def delta(self, rewards, masks, is_experts):
        """Return the mean residual over active entries (expert-only if ``self.value_expert``)."""
        if not self.value_expert:
            return rewards[masks].mean()
        return rewards[is_experts][masks[is_experts]].mean()

    @jit.script_method
    def compute_loss(self, mb_obs, mb_states, mb_avails, mb_actions, mb_rewards, mb_dones, mb_actives, gamma: float):
        """Build the soft inverse-Q loss for one minibatch, backprop it, and return its parts.

        ``mb_obs`` is unpacked as ``(batch, seq, n_agents, feat)``. Time is
        dim 1 for states and availabilities. ``mb_rewards`` is consumed only
        as the expert selector (after ``squeeze(-1)``), so it is presumably a
        bool tensor (confirm with caller).

        Returns ``(delta_loss, phi_loss)`` as Python numbers after calling
        ``backward()`` on their sum.
        """
        dev = self.eval_Q_net.fc1.weight.device
        obs = mb_obs.to(dev)
        states = mb_states.to(dev)
        log_avails = mb_avails.to(dev).log()  # assumed 0/1 mask -> 0 / -inf
        taken = mb_actions.to(dev).unsqueeze(-1)
        expert_flags = mb_rewards.to(dev).squeeze(-1)
        not_done = 1 - mb_dones.to(dev).float()
        active = mb_actives.to(dev).squeeze(-1)

        self.eval_Q_net.reset()
        self.target_Q_net.reset()

        prev_states = states[:, :-1]
        succ_states = states[:, 1:]

        batch_size, seq_length, n_agents, _ = obs.shape
        # One sequence per (batch, agent) pair: (B, T, N, D) -> (B*N, T, D).
        agent_seqs = obs.transpose(1, 2).flatten(0, 1)

        online_q = self.eval_Q_net.forward(agent_seqs).reshape(batch_size, n_agents, seq_length, -1).transpose(1, 2)[:, :-1]
        chosen_q = self.eval_mix_net.forward(torch.gather(online_q, -1, taken).squeeze(-1), prev_states)
        soft_v_local = self.alpha * torch.logsumexp(online_q / self.alpha + log_avails[:, :-1], -1)
        soft_v = self.eval_mix_net.forward(soft_v_local, prev_states)

        with torch.no_grad():
            # Full-sequence target pass, then keep the next-step slice.
            succ_q = self.target_Q_net.forward(agent_seqs).reshape(batch_size, n_agents, seq_length, -1).transpose(1, 2)[:, 1:]
            succ_v_local = self.alpha * torch.logsumexp(succ_q / self.alpha + log_avails[:, 1:], -1)
            succ_v = self.target_mix_net.forward(succ_v_local, succ_states)
            bootstrap = gamma * not_done * succ_v

        value_loss = self.delta(soft_v - bootstrap, active, expert_flags)
        reward_loss = self.phi(chosen_q - bootstrap, active, expert_flags)
        (value_loss + reward_loss).backward()
        return value_loss.item(), reward_loss.item()
