import copy
from modules.mixers.lmix import LMixer
from modules.mixers.qattn import QattnMixer, MtQattnMixer
from modules.critics.mlp import MLPCritic
from modules.encoders.transition_encoder import TransformerTransitionEncoder, TransformerPriorRoleEncoder, TransformerTransitionRoleEncoder
from modules.encoders.temporal_encoder import TransformerTemporalEncoder, TransformerTemporalRoleEncoder
from modules.encoders.local_encoder import LocalEncoder, LocalRoleEncoder
from modules.encoders.global_encoder import GlobalEncoder
from modules.critics.transformer import TransformerCritic, MtTransformerCritic
import torch as th
from torch.distributions import Categorical
from torch.optim import RMSprop, Adam, AdamW
import torch.nn.functional as F
from components.standarize_stream import RunningMeanStd
import random

class MtVAELearner:
    def __init__(self, mac, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger
        self.task2decomposer = mac.task2decomposer
        self.task2n_agents = mac.task2n_agents

        self.agent_params = list(mac.parameters())
        
        match self.args.optim_type.lower():
            case "rmsprop":
                self.actor_optimiser = RMSprop(params=self.agent_params, lr=self.args.lr, alpha=self.args.optim_alpha, eps=self.args.optim_eps, weight_decay=self.args.weight_decay)
            case "adam":
                self.actor_optimiser = Adam(params=self.agent_params, lr=self.args.lr, weight_decay=self.args.weight_decay)
            case "adamw":
                self.actor_optimiser = AdamW(params=self.agent_params, lr=self.args.lr, weight_decay=self.args.weight_decay)
            case _:
                raise ValueError("Invalid optimiser type", self.args.optim_type)

        self.log_stats_t = {}
        for task in args.train_task_ls:
            self.log_stats_t[task] = -args.learner_log_interval - 1
        
        # self.log_stats_t = -self.args.learner_log_interval - 1
        self.training_steps = 0
        self.last_target_update_step = 0
        self.last_target_update_episode = 0

        self.is_vqvae = self.mac.is_vqvae

        device = "cuda" if args.use_cuda else "cpu"

    def train(self, batch, t_env: int, task: int):
        # Get the relevant quantities
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        # termianted point 1 -> 0
        
        if not self.is_vqvae:
            mac_out = []
            mu_ls = []
            logvar_ls = []
            self.mac.init_hidden(batch.batch_size, task)
            for t in range(batch.max_seq_length):
                agent_outs, mu, logvar = self.mac.forward(batch, t=t, task=task)
                mac_out.append(agent_outs)
                mu_ls.append(mu)
                logvar_ls.append(logvar)
            mac_out = th.stack(mac_out, dim=1)[:, :-1] # (bs, t, n_agent, x)
            bs, max_t, n_agent, _ = mac_out.shape
            mu_ls = th.stack(mu_ls, dim=1)[:, :-1].reshape(bs * max_t * n_agent, -1)
            logvar_ls = th.stack(logvar_ls, dim=1)[:, :-1].reshape(bs * max_t * n_agent, -1)
            mask_a = mask.reshape(bs, max_t, -1).repeat(1, 1, n_agent)
            
            mac_out = mac_out.reshape(bs * max_t * n_agent, -1)
            target_action = actions.reshape(-1)
            mask_a = mask_a.reshape(-1)
            ce_loss = F.cross_entropy(mac_out, target_action, reduction='none')
            dec_loss = (ce_loss * mask_a).sum() / mask_a.sum()

            kl = -0.5 * (1 + logvar_ls - mu_ls.pow(2) - logvar_ls.exp())
            kl = kl.sum(dim=-1)
            enc_loss = (kl * mask_a).sum() / mask_a.sum()

            vae_loss = dec_loss + self.args.vae_beta * enc_loss
        else:
            mac_out = []
            vq_loss_ls = []
            perplexity_ls = []
            self.mac.init_hidden(batch.batch_size, task)
            for t in range(batch.max_seq_length):
                agent_outs, z_e, vq_loss, perplexity = self.mac.forward(batch, t=t, task=task)
                mac_out.append(agent_outs)
                vq_loss_ls.append(vq_loss)
                perplexity_ls.append(perplexity.item())
            mac_out = th.stack(mac_out, dim=1)[:, :-1] # (bs, t, n_agent, x)
            bs, max_t, n_agent, _ = mac_out.shape
            vq_loss_ls = th.stack(vq_loss_ls, dim=1)[:, :-1].reshape(bs * max_t * n_agent, -1)
            mac_out = mac_out.reshape(bs * max_t * n_agent, -1)
            mask_a = mask.reshape(bs, max_t, -1).repeat(1, 1, n_agent)
            target_action = actions.reshape(-1)
            mask_a = mask_a.reshape(-1)
            ce_loss = F.cross_entropy(mac_out, target_action, reduction='none')
            dec_loss = (ce_loss * mask_a).sum() / mask_a.sum()

            mask_vq = mask_a.reshape(-1, 1).repeat(1, vq_loss_ls.shape[1])
            enc_loss = (vq_loss_ls * mask_vq).sum() / mask_vq.sum()

            vae_loss = dec_loss + enc_loss

            mean_perplexity = sum(perplexity_ls)/len(perplexity_ls)

        self.actor_optimiser.zero_grad()
        vae_loss.backward()
        self.actor_optimiser.step()
        
        self.training_steps += 1
        
        if t_env - self.log_stats_t[task] >= self.args.learner_log_interval:
            log_prefix = f"{task}/"
            self.logger.log_stat(log_prefix+"dec_loss", dec_loss.item(), t_env)
            self.logger.log_stat(log_prefix+"enc_loss", enc_loss.item(), t_env)
            self.logger.log_stat(log_prefix+"total_loss", vae_loss.item(), t_env)
            if self.is_vqvae:
                self.logger.log_stat(log_prefix+"perplexity", mean_perplexity, t_env)
            #self.logger.log_stat("alpha_temp", self.args.alpha_temp, t_env)
            self.log_stats_t[task] = t_env

    def cuda(self):
        self.mac.cuda()
    
    def save_models(self, path):
        self.mac.save_models(path)
        th.save(self.actor_optimiser.state_dict(), "{}/actor_opt.th".format(path))
    
    def load_models(self, path):
        self.mac.load_models(path)
        self.actor_optimiser.load_state_dict(th.load("{}/actor_opt.th".format(path)))
                                            
    