from functools import partial
from typing import Callable, Dict

import torch

from algorithms.airl import AIRL
from algorithms.bc import BC
from algorithms.iiq import IIQ
from algorithms.iqvdn import IQVDN
from algorithms.sqil import SQIL
from algorithms.gail import GAIL
from algorithms.qmix import QMIX
from algorithms.mifq import MIFQ, SoftMIFQ
from trainer.base import DQN_SMAC


class QMIX_SMAC(DQN_SMAC):
    """Trainer wrapping the ``QMIX`` learner.

    Trains only on minibatches drawn from the policy replay buffer. The
    ``ex_buffer`` argument of :meth:`train` is accepted so that this trainer
    shares the ``train(pi_buffer, ex_buffer)`` call shape of its siblings,
    but it is never read.
    """

    def __init__(self, n_agents, ob_dim, st_dim, ac_dim, args):
        super().__init__(n_agents, ob_dim, st_dim, ac_dim, args)
        learner = QMIX(self.input_dim, self.st_dim, self.ac_dim, self.n_agents, self.h_dim)
        self.qmix = learner.to(self.device)
        # The optimiser covers eval_mix_net followed by eval_Q_net (order kept
        # stable so optimiser state_dicts stay compatible).
        mixer_params = list(self.qmix.eval_mix_net.parameters())
        agent_params = list(self.qmix.eval_Q_net.parameters())
        self.eval_parameters = mixer_params + agent_params
        self.optimizer = torch.optim.Adam(self.eval_parameters, lr=self.lr)

    def train(self, pi_buffer, ex_buffer):
        """Run one optimisation step on a policy-buffer minibatch.

        Args:
            pi_buffer: object whose ``sample(batch_size)`` yields the 7-tuple
                (obs, states, avails, actions, rewards, dones, actives).
            ex_buffer: unused.

        Returns:
            Whatever ``self.qmix.compute_loss`` returns.
        """
        self.train_step += 1
        obs, states, avails, actions, rewards, dones, actives = pi_buffer.sample(self.batch_size)
        self.optimizer.zero_grad()
        # NOTE(review): no backward() is issued in this method, so compute_loss
        # presumably back-propagates internally before the clip/step below.
        loss_vals = self.qmix.compute_loss(
            obs, states, avails, actions, rewards, dones, actives, self.gamma
        )
        torch.nn.utils.clip_grad_norm_(self.eval_parameters, self.max_grad_norm)
        self.optimizer.step()
        self.update_targets()
        return loss_vals
    

class MIFQ_SMAC(DQN_SMAC):
    """Trainer wrapping the ``MIFQ`` learner.

    Each step samples a policy minibatch, passes it to the expert buffer's
    ``sample_with`` to obtain the minibatch that is actually trained on, and
    takes one optimiser step on the mixer and agent networks.
    """

    def __init__(self, n_agents, ob_dim, st_dim, ac_dim, args, activation="elu"):
        super().__init__(n_agents, ob_dim, st_dim, ac_dim, args)
        # ``activation`` is forwarded verbatim to the MIFQ learner.
        model = MIFQ(
            self.input_dim, self.st_dim, self.ac_dim, self.n_agents, self.h_dim, activation
        )
        self.qmix = model.to(self.device)
        self.eval_parameters = [
            *self.qmix.eval_mix_net.parameters(),
            *self.qmix.eval_Q_net.parameters(),
        ]
        self.optimizer = torch.optim.Adam(self.eval_parameters, lr=self.lr)

    def train(self, pi_buffer, ex_buffer):
        """Run one optimisation step and return the learner's loss values.

        Args:
            pi_buffer: policy buffer; ``sample(batch_size)`` yields a 7-tuple.
            ex_buffer: expert buffer; ``sample_with(batch_size, *policy_batch)``
                yields the 7-tuple that is fed to ``compute_loss``.

        Returns:
            Whatever ``self.qmix.compute_loss`` returns.
        """
        self.train_step += 1
        size = self.batch_size
        p_obs, p_states, p_avails, p_actions, p_rewards, p_dones, p_actives = pi_buffer.sample(size)
        # NOTE(review): sample_with presumably combines expert samples with the
        # policy samples passed in - confirm in the buffer implementation.
        obs, states, avails, actions, rewards, dones, actives = ex_buffer.sample_with(
            size, p_obs, p_states, p_avails, p_actions, p_rewards, p_dones, p_actives
        )
        self.optimizer.zero_grad()
        loss_vals = self.qmix.compute_loss(
            obs, states, avails, actions, rewards, dones, actives, self.gamma
        )
        torch.nn.utils.clip_grad_norm_(self.eval_parameters, self.max_grad_norm)
        self.optimizer.step()
        self.update_targets()
        return loss_vals


class SoftMIFQ_SMAC(DQN_SMAC):
    """Trainer wrapping the ``SoftMIFQ`` learner.

    Same training procedure as the hard MIFQ trainer: the policy minibatch is
    routed through ``ex_buffer.sample_with`` and one optimiser step is taken
    on the mixer and agent networks, followed by ``update_targets()``.
    """

    def __init__(self, n_agents, ob_dim, st_dim, ac_dim, args, activation="elu"):
        super().__init__(n_agents, ob_dim, st_dim, ac_dim, args)
        net = SoftMIFQ(
            self.input_dim, self.st_dim, self.ac_dim, self.n_agents, self.h_dim, activation
        )
        self.qmix = net.to(self.device)
        trainable = list(self.qmix.eval_mix_net.parameters())
        trainable += list(self.qmix.eval_Q_net.parameters())
        self.eval_parameters = trainable
        self.optimizer = torch.optim.Adam(self.eval_parameters, lr=self.lr)

    def train(self, pi_buffer, ex_buffer):
        """Take one gradient step; return what ``self.qmix.compute_loss`` returns."""
        self.train_step += 1
        n = self.batch_size
        (pol_obs, pol_states, pol_avails, pol_actions,
         pol_rewards, pol_dones, pol_actives) = pi_buffer.sample(n)
        (obs, states, avails, actions,
         rewards, dones, actives) = ex_buffer.sample_with(
            n, pol_obs, pol_states, pol_avails, pol_actions, pol_rewards, pol_dones, pol_actives
        )
        self.optimizer.zero_grad()
        loss_vals = self.qmix.compute_loss(
            obs, states, avails, actions, rewards, dones, actives, self.gamma
        )
        # Clip only the parameters owned by self.optimizer.
        torch.nn.utils.clip_grad_norm_(self.eval_parameters, self.max_grad_norm)
        self.optimizer.step()
        self.update_targets()
        return loss_vals


class IQVDN_SMAC(DQN_SMAC):
    """Trainer wrapping the ``IQVDN`` learner.

    Uses the policy minibatch routed through ``ex_buffer.sample_with`` and
    optimises ``eval_mix_net`` and ``eval_Q_net`` jointly.
    """

    def __init__(self, n_agents, ob_dim, st_dim, ac_dim, args):
        super().__init__(n_agents, ob_dim, st_dim, ac_dim, args)
        learner = IQVDN(self.input_dim, self.st_dim, self.ac_dim, self.n_agents, self.h_dim)
        self.qmix = learner.to(self.device)
        self.eval_parameters = [
            *self.qmix.eval_mix_net.parameters(),
            *self.qmix.eval_Q_net.parameters(),
        ]
        self.optimizer = torch.optim.Adam(self.eval_parameters, lr=self.lr)

    def train(self, pi_buffer, ex_buffer):
        """Run one optimisation step and return the learner's loss values."""
        self.train_step += 1
        o, s, av, a, r, d, act = pi_buffer.sample(self.batch_size)
        o, s, av, a, r, d, act = ex_buffer.sample_with(self.batch_size, o, s, av, a, r, d, act)
        self.optimizer.zero_grad()
        loss_vals = self.qmix.compute_loss(o, s, av, a, r, d, act, self.gamma)
        torch.nn.utils.clip_grad_norm_(self.eval_parameters, self.max_grad_norm)
        self.optimizer.step()
        self.update_targets()
        return loss_vals
    

class SQIL_SMAC(DQN_SMAC):
    """Trainer wrapping the ``SQIL`` learner.

    Builds its training minibatch by sampling the policy buffer and passing
    that sample to ``ex_buffer.sample_with``; the resulting 7-tuple is handed
    to ``SQIL.compute_loss``.
    """

    def __init__(self, n_agents, ob_dim, st_dim, ac_dim, args):
        super().__init__(n_agents, ob_dim, st_dim, ac_dim, args)
        model = SQIL(self.input_dim, self.st_dim, self.ac_dim, self.n_agents, self.h_dim)
        self.qmix = model.to(self.device)
        mix_params = list(self.qmix.eval_mix_net.parameters())
        q_params = list(self.qmix.eval_Q_net.parameters())
        self.eval_parameters = mix_params + q_params
        self.optimizer = torch.optim.Adam(self.eval_parameters, lr=self.lr)

    def train(self, pi_buffer, ex_buffer):
        """Run one optimisation step; return what ``compute_loss`` returns."""
        self.train_step += 1
        pol = pi_buffer.sample(self.batch_size)
        pol_obs, pol_states, pol_avails, pol_actions, pol_rewards, pol_dones, pol_actives = pol
        mixed = ex_buffer.sample_with(
            self.batch_size,
            pol_obs, pol_states, pol_avails, pol_actions, pol_rewards, pol_dones, pol_actives,
        )
        obs, states, avails, actions, rewards, dones, actives = mixed
        self.optimizer.zero_grad()
        loss_vals = self.qmix.compute_loss(
            obs, states, avails, actions, rewards, dones, actives, self.gamma
        )
        torch.nn.utils.clip_grad_norm_(self.eval_parameters, self.max_grad_norm)
        self.optimizer.step()
        self.update_targets()
        return loss_vals


class GAIL_SMAC(DQN_SMAC):
    """Trainer pairing the ``GAIL`` learner's discriminator with a Q update.

    Each :meth:`train` call draws equal-sized minibatches from the policy and
    expert buffers, discards their stored rewards, and replaces them with
    rewards from ``disc.calculate_reward``. It then takes one step on the Q
    networks followed by one discriminator step on the same samples.
    """

    def __init__(self, n_agents, ob_dim, st_dim, ac_dim, args):
        super().__init__(n_agents, ob_dim, st_dim, ac_dim, args)
        model = GAIL(self.input_dim, self.st_dim, self.ac_dim, self.n_agents, self.h_dim)
        self.qmix = model.to(self.device)
        self.eval_parameters = [
            *self.qmix.eval_mix_net.parameters(),
            *self.qmix.eval_Q_net.parameters(),
        ]
        self.optimizer = torch.optim.Adam(self.eval_parameters, lr=self.lr)
        # Separate optimiser so discriminator and Q networks are stepped independently.
        self.disc_optimizer = torch.optim.Adam(self.qmix.disc.parameters(), lr=self.lr)

    def train(self, pi_buffer, ex_buffer):
        """Run one Q step and one discriminator step; return the Q loss values."""
        self.train_step += 1
        # Buffer rewards (5th field) are ignored in favour of discriminator rewards.
        pol_obs, pol_states, pol_avails, pol_actions, _, pol_dones, pol_actives = \
            pi_buffer.sample(self.batch_size)
        exp_obs, exp_states, exp_avails, exp_actions, _, exp_dones, exp_actives = \
            ex_buffer.sample(self.batch_size)

        # Only the tensors the discriminator consumes are moved explicitly.
        pol_states, pol_actions, exp_states, exp_actions, pol_actives, exp_actives = (
            t.to(self.device)
            for t in (pol_states, pol_actions, exp_states, exp_actions, pol_actives, exp_actives)
        )

        # Joint minibatch: policy samples first, expert samples second.
        pairs = (
            (pol_obs, exp_obs),
            (pol_states, exp_states),
            (pol_avails, exp_avails),
            (pol_actions, exp_actions),
            (pol_dones, exp_dones),
            (pol_actives, exp_actives),
        )
        obs, states, avails, actions, dones, actives = (torch.cat(pair) for pair in pairs)
        rewards = self.qmix.disc.calculate_reward(states, actions)

        self.optimizer.zero_grad()
        loss_vals = self.qmix.compute_loss(
            obs, states, avails, actions, rewards, dones, actives, self.gamma
        )
        torch.nn.utils.clip_grad_norm_(self.eval_parameters, self.max_grad_norm)
        self.optimizer.step()

        # zero_grad here also clears any gradient the Q update left on the discriminator.
        self.disc_optimizer.zero_grad()
        self.qmix.disc.compute_loss(
            pol_states, pol_actions, exp_states, exp_actions, pol_actives, exp_actives
        )
        self.disc_optimizer.step()
        self.update_targets()
        return loss_vals


class AIRL_SMAC(DQN_SMAC):
    """Trainer pairing the ``AIRL`` learner's discriminator with a Q update.

    Rewards for the Q update come from ``disc.calculate_reward`` fed with
    states, dones, the learner's action log-probabilities (computed without
    gradient), and states shifted by one along dim 1 (presumably time -
    confirm against the buffer layout).
    """

    def __init__(self, n_agents, ob_dim, st_dim, ac_dim, args):
        super().__init__(n_agents, ob_dim, st_dim, ac_dim, args)
        model = AIRL(self.input_dim, self.st_dim, self.ac_dim, self.n_agents, self.h_dim)
        self.qmix = model.to(self.device)
        self.eval_parameters = [
            *self.qmix.eval_mix_net.parameters(),
            *self.qmix.eval_Q_net.parameters(),
        ]
        self.optimizer = torch.optim.Adam(self.eval_parameters, lr=self.lr)
        self.disc_optimizer = torch.optim.Adam(self.qmix.disc.parameters(), lr=self.lr)

    def train(self, pi_buffer, ex_buffer):
        """Run one Q step and one discriminator step; return the Q loss values."""
        self.train_step += 1
        # Buffer rewards (5th field) are discarded; the discriminator supplies them.
        pol_obs, pol_states, pol_avails, pol_actions, _, pol_dones, pol_actives = \
            pi_buffer.sample(self.batch_size)
        exp_obs, exp_states, exp_avails, exp_actions, _, exp_dones, exp_actives = \
            ex_buffer.sample(self.batch_size)

        pol_states, pol_actions, exp_states, exp_actions, pol_actives, exp_actives = (
            t.to(self.device)
            for t in (pol_states, pol_actions, exp_states, exp_actions, pol_actives, exp_actives)
        )

        # Joint minibatch: policy samples first, expert samples second.
        obs = torch.cat((pol_obs, exp_obs))
        states = torch.cat((pol_states, exp_states))
        avails = torch.cat((pol_avails, exp_avails))
        actions = torch.cat((pol_actions, exp_actions))
        dones = torch.cat((pol_dones, exp_dones))
        actives = torch.cat((pol_actives, exp_actives))

        # Log-probabilities are inputs to the discriminator only; no gradient needed.
        with torch.no_grad():
            pol_log_pis = self.qmix.evaluate_log_pi(pol_obs, pol_actions)
            exp_log_pis = self.qmix.evaluate_log_pi(exp_obs, exp_actions)
        log_pis = torch.cat((pol_log_pis, exp_log_pis))

        rewards = self.qmix.disc.calculate_reward(
            states[:, :-1], dones.float(), log_pis, states[:, 1:]
        )

        self.optimizer.zero_grad()
        loss_vals = self.qmix.compute_loss(
            obs, states, avails, actions, rewards, dones, actives, self.gamma
        )
        torch.nn.utils.clip_grad_norm_(self.eval_parameters, self.max_grad_norm)
        self.optimizer.step()

        self.disc_optimizer.zero_grad()
        self.qmix.disc.compute_loss(
            pol_states[:, :-1], pol_dones.float(), pol_log_pis, pol_states[:, 1:],
            exp_states[:, :-1], exp_dones.float(), exp_log_pis, exp_states[:, 1:],
            pol_actives, exp_actives,
        )
        self.disc_optimizer.step()
        self.update_targets()
        return loss_vals
    

class IIQ_SMAC(DQN_SMAC):
    """Trainer wrapping the ``IIQ`` learner.

    The minibatch is built like the other imitation trainers here (policy
    sample routed through ``ex_buffer.sample_with``). ``IIQ.compute_loss`` is
    called without the global-state tensor.
    """

    def __init__(self, n_agents, ob_dim, st_dim, ac_dim, args):
        super().__init__(n_agents, ob_dim, st_dim, ac_dim, args)
        learner = IIQ(self.input_dim, self.st_dim, self.ac_dim, self.n_agents, self.h_dim)
        self.qmix = learner.to(self.device)
        params = list(self.qmix.eval_mix_net.parameters())
        params.extend(self.qmix.eval_Q_net.parameters())
        self.eval_parameters = params
        self.optimizer = torch.optim.Adam(self.eval_parameters, lr=self.lr)

    def train(self, pi_buffer, ex_buffer):
        """Run one optimisation step and return the learner's loss values."""
        self.train_step += 1
        p_obs, p_states, p_avails, p_actions, p_rewards, p_dones, p_actives = \
            pi_buffer.sample(self.batch_size)
        obs, states, avails, actions, rewards, dones, actives = ex_buffer.sample_with(
            self.batch_size, p_obs, p_states, p_avails, p_actions, p_rewards, p_dones, p_actives
        )
        self.optimizer.zero_grad()
        # ``states`` is sampled but deliberately not forwarded to compute_loss.
        loss_vals = self.qmix.compute_loss(obs, avails, actions, rewards, dones, actives, self.gamma)
        torch.nn.utils.clip_grad_norm_(self.eval_parameters, self.max_grad_norm)
        self.optimizer.step()
        self.update_targets()
        return loss_vals
    

class BC_SMAC(DQN_SMAC):
    """Trainer wrapping the ``BC`` learner, trained on expert data only.

    Only ``eval_Q_net`` parameters are optimised. Unlike the other trainers in
    this module, :meth:`train` takes a single expert buffer and never calls
    ``update_targets``.
    """

    def __init__(self, n_agents, ob_dim, st_dim, ac_dim, args, max_lr=2e-5):
        """Build the BC learner and its optimiser.

        Args:
            n_agents, ob_dim, st_dim, ac_dim, args: forwarded to ``DQN_SMAC``.
            max_lr: upper bound on the Adam learning rate. The effective rate
                is ``min(self.lr, max_lr)``; the default reproduces the
                previously hard-coded cap of ``2e-5``.
        """
        super().__init__(n_agents, ob_dim, st_dim, ac_dim, args)
        self.qmix = BC(self.input_dim, self.st_dim, self.ac_dim, self.n_agents, self.h_dim).to(self.device)
        self.eval_parameters = list(self.qmix.eval_Q_net.parameters())
        self.optimizer = torch.optim.Adam(self.eval_parameters, lr=min(self.lr, max_lr))

    def train(self, ex_buffer):
        """Take one optimisation step on an expert minibatch.

        States, rewards and dones returned by ``ex_buffer.sample`` are
        discarded; only observations, availabilities, actions and actives are
        passed to ``compute_loss``.

        Returns:
            Whatever ``self.qmix.compute_loss`` returns.
        """
        self.train_step += 1
        mb_obs, _, mb_avails, mb_actions, _, _, mb_actives = ex_buffer.sample(self.batch_size)
        self.optimizer.zero_grad()
        loss_vals = self.qmix.compute_loss(mb_obs, mb_avails, mb_actions, mb_actives)
        torch.nn.utils.clip_grad_norm_(self.eval_parameters, self.max_grad_norm)
        self.optimizer.step()
        return loss_vals


# Registry of trainer factories keyed by algorithm name. Each value is called
# with the trainer constructor arguments (n_agents, ob_dim, st_dim, ac_dim,
# args) and returns a DQN_SMAC subclass instance. The variant entries use
# functools.partial (rather than ``lambda *x``) so callers may also pass
# keyword arguments through to the underlying constructor.
ALGOS: Dict[str, Callable[..., DQN_SMAC]] = {
    "QMIX": QMIX_SMAC,
    "MIFQ-DQN": MIFQ_SMAC,
    "MIFQ-SOFT": SoftMIFQ_SMAC,
    "MIFQ-DQN-TANH": partial(MIFQ_SMAC, activation="tanh"),
    "MIFQ-DQN-SIGMOID": partial(MIFQ_SMAC, activation="sigmoid"),
    "MIFQ-SOFT-TANH": partial(SoftMIFQ_SMAC, activation="tanh"),
    "MIFQ-SOFT-SIGMOID": partial(SoftMIFQ_SMAC, activation="sigmoid"),
    "IQVDN": IQVDN_SMAC,
    "SQIL": SQIL_SMAC,
    "GAIL": GAIL_SMAC,
    "AIRL": AIRL_SMAC,
    "IIQ": IIQ_SMAC,
    "BC": BC_SMAC,
}