import copy
from components.episode_buffer import EpisodeBatch
from modules.mixers.vdn import VDNMixer
from modules.mixers.qmix import QMixer
from modules.mixers.qopt import QoptMixer
import torch as th
from torch.optim import RMSprop
from torch.distributions import Categorical
from torch.distributions import Beta
from torch.distributions import Dirichlet


class QLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger

        self.params = list(mac.parameters())

        self.last_target_update_episode = 0

        self.mixer = None
        if args.mixer is not None:
            if args.mixer == "vdn":
                self.mixer = VDNMixer()
            elif args.mixer == "qmix":
                self.mixer = QMixer(args)
            else:
                raise ValueError("Mixer {} not recognised.".format(args.mixer))

            self.mixer = QoptMixer(args)
            self.params += list(self.mixer.parameters())
            self.target_mixer = copy.deepcopy(self.mixer)

        self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps)

        # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC
        self.target_mac = copy.deepcopy(mac)

        self.log_stats_t = -self.args.learner_log_interval - 1

    def train(self, batch: EpisodeBatch, t_env: int, episode_num: int):
        # Get the relevant quantities
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]

        mac_out = []
        mac_hidden_states = []

        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
            mac_hidden_states.append(self.mac.hidden_states)


        mac_out = th.stack(mac_out, dim=1)  # Concat over time

        mac_hidden_states = th.stack(mac_hidden_states, dim=1)
        mac_hidden_states = mac_hidden_states.reshape(batch.batch_size, self.args.n_agents, batch.max_seq_length, -1).transpose(1,2) #btav

        # Pick the Q-Values for the actions taken by each agent
        chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3)  # Remove the last dim


        # Calculate the Q-Values necessary for the target
        target_mac_out = []
        target_mac_hidden_states = []


        self.target_mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            target_agent_outs = self.target_mac.forward(batch, t=t)
            target_mac_out.append(target_agent_outs)
            target_mac_hidden_states.append(self.target_mac.hidden_states)


        target_mac_hidden_states = th.stack(target_mac_hidden_states, dim=1)
        target_mac_hidden_states = target_mac_hidden_states.reshape(batch.batch_size, self.args.n_agents, batch.max_seq_length, -1).transpose(1,2) #btav

        # We don't need the first timesteps Q-Value estimate for calculating targets
        target_mac_out = th.stack(target_mac_out[1:], dim=1)  # Concat across time
        
        # Mask out unavailable actions
        target_mac_out[avail_actions[:, 1:] == 0] = -9999999


        mac_out_detach = mac_out.clone().detach()
        mac_out_detach[avail_actions == 0] = -9999999
        cur_max_actions = mac_out_detach.max(dim=3, keepdim=True)[1]


        # Max over target Q-Values
        if self.args.double_q:
            # Get actions that maximise live Q (for double q-learning)
            target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions[:, 1:]).squeeze(3)
        else:
            target_max_qvals = target_mac_out.max(dim=3)[0]
        
        max_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=cur_max_actions[:, :-1]).squeeze(3)

        # Mix
        if self.mixer is not None:
            chosen_action_qvals, chosen_action_qvals2 = self.mixer(chosen_action_qvals, batch["state"][:, :-1])
            max_action_qvals, max_action_qvals2 = self.mixer(max_action_qvals, batch["state"][:, :-1])
            target_max_qvals, _ = self.target_mixer(target_max_qvals, batch["state"][:, 1:])

 
        # v17
        targets = (rewards + self.args.gamma * (1 - terminated) * target_max_qvals).detach()
        targets2 = chosen_action_qvals
        # targets2 = chosen_action_qvals

        targets4 = max_action_qvals.detach()



        # Td-error
        td_error = (chosen_action_qvals - targets)

        td_error2_0 = (chosen_action_qvals2 - targets2)
        td_error2 = th.where(td_error2_0 > 0, td_error2_0 * 0, td_error2_0 * 1)
        td_error2_1 = (chosen_action_qvals2 - targets4)
        td_error2 = th.where(td_error2_1 > 0, td_error2_1 * 1, td_error2)
        td_error2 = th.where(targets2 >= targets4, td_error2_0, td_error2)

        td_error3 = max_action_qvals2 - max_action_qvals


        mask1 = mask.expand_as(td_error)
        mask2 = mask.expand_as(td_error2)
        mask3 = mask.expand_as(td_error3)

        # 0-out the targets that came from padded data
        masked_td_error = td_error * mask1
        masked_td_error2 = td_error2 * mask2
        masked_td_error3 = td_error3 * mask3

        # Normal L2 loss, take mean over actual data
        loss = 1 * (masked_td_error ** 2).sum() / mask1.sum()
        loss += 1 * (masked_td_error2 ** 2).sum() / mask2.sum()
        loss += 2 * (masked_td_error3 ** 2).sum() / mask3.sum()

        # Optimise
        self.optimiser.zero_grad()
        loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip)
        self.optimiser.step()

        if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_episode = episode_num

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("loss", loss.item(), t_env)
            self.logger.log_stat("grad_norm", grad_norm, t_env)
            mask_elems1 = mask1.sum().item()
            mask_elems2 = mask2.sum().item()
            mask_elems3 = mask3.sum().item()



            self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item()/mask_elems1), t_env)
            self.logger.log_stat("td_error_abs2", (masked_td_error2.abs().sum().item()/mask_elems2), t_env)
            self.logger.log_stat("td_error_abs3", (masked_td_error3.abs().sum().item()/mask_elems3), t_env)

            self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item()/(mask_elems1), t_env)
            self.logger.log_stat("target_mean", (targets * mask1).sum().item()/(mask_elems1), t_env)



            self.log_stats_t = t_env

    def _update_targets(self):
        self.target_mac.load_state(self.mac)
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.target_mac.cuda()
        if self.mixer is not None:
            self.mixer.cuda()
            self.target_mixer.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        if self.mixer is not None:
            th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        th.save(self.optimiser.state_dict(), "{}/opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        # Not quite right but I don't want to save target networks
        self.target_mac.load_models(path)
        if self.mixer is not None:
            self.mixer.load_state_dict(th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage))
        self.optimiser.load_state_dict(th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))
