import copy
from modules.mixers.lmix import LMixer
from modules.mixers.qattn import QattnMixer, MtQattnMixer
from modules.critics.mlp import MLPCritic
from modules.encoders.transition_encoder import TransformerTransitionEncoder, TransformerPriorRoleEncoder, TransformerTransitionRoleEncoder
from modules.encoders.temporal_encoder import TransformerTemporalEncoder, TransformerTemporalRoleEncoder
from modules.encoders.local_encoder import LocalEncoder, LocalRoleEncoder
from modules.encoders.global_encoder import GlobalEncoder
from modules.critics.transformer import TransformerCritic, MtTransformerCritic, MtTransformerCriticCont
import torch as th
from torch.distributions import Categorical
from torch.optim import RMSprop, Adam, AdamW
from components.standarize_stream import RunningMeanStd
import random

class MtOMIGADiffusionLearner:
    def __init__(self, mac, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger
        self.task2decomposer = mac.task2decomposer
        self.task2n_agents = mac.task2n_agents

        self.agent_params = list(mac.parameters())
        task2input_shape_info = mac._get_input_shape()
        task_0 = list(self.task2decomposer.keys())[0]

        self.is_vqvae = self.mac.is_vqvae
        if not self.is_vqvae:
            self.v_critic = MtTransformerCriticCont(task2input_shape_info, self.task2decomposer, self.task2n_agents, True, args)
            self.q_critic = MtTransformerCriticCont(task2input_shape_info, self.task2decomposer, self.task2n_agents, False, args)
        else:
            self.v_critic = MtTransformerCritic(task2input_shape_info, self.task2decomposer, self.task2n_agents, 1, args)
            self.q_critic = MtTransformerCritic(task2input_shape_info, self.task2decomposer, self.task2n_agents, self.mac.vae.num_embeddings, args, is_vqvae=True)
        self.mixer = MtQattnMixer(self.task2decomposer[task_0], args)

        self.temporal_encoder = TransformerTemporalEncoder(self.task2decomposer, args)
        self.local_encoder = LocalEncoder(args)
        self.global_encoder = GlobalEncoder(args)
        self.load_encoders(args.encoder_path_ls[args.encoder_id])

        self.temporal_role_encoder = TransformerTemporalRoleEncoder(self.task2decomposer, args)
        self.local_role_encoder = LocalRoleEncoder(args)
        if self.args.use_role_encoder:
            self.load_role_encoder(args.role_encoder_path_ls[args.role_encoder_id])

        self.prior_role_encoder = TransformerPriorRoleEncoder(self.task2decomposer, args)
        if self.args.use_role_encoder:
            self.load_prior_role_encoder(args.prior_role_encoder_path_ls[args.prior_role_encoder_id])

        self.v_params = list(self.v_critic.parameters())  
        self.q_params = list(self.q_critic.parameters()) + list(self.mixer.parameters()) + list(self.global_encoder.parameters())
        
        match self.args.optim_type.lower():
            case "rmsprop":
                self.actor_optimiser = RMSprop(params=self.agent_params, lr=self.args.lr, alpha=self.args.optim_alpha, eps=self.args.optim_eps, weight_decay=self.args.weight_decay)
                self.q_optimiser = RMSprop(params=self.q_params, lr=self.args.lr, alpha=self.args.optim_alpha, eps=self.args.optim_eps, weight_decay=self.args.weight_decay)
                self.v_optimiser = RMSprop(params=self.v_params, lr=self.args.lr, alpha=self.args.optim_alpha, eps=self.args.optim_eps, weight_decay=self.args.weight_decay)
            case "adam":
                self.actor_optimiser = Adam(params=self.agent_params, lr=self.args.lr, weight_decay=self.args.weight_decay)
                self.q_optimiser = Adam(params=self.q_params, lr=self.args.lr, weight_decay=self.args.weight_decay)
                self.v_optimiser = Adam(params=self.v_params, lr=self.args.lr, weight_decay=self.args.weight_decay)
            case "adamw":
                self.actor_optimiser = AdamW(params=self.agent_params, lr=self.args.lr, weight_decay=self.args.weight_decay)
                self.q_optimiser = AdamW(params=self.q_params, lr=self.args.critic_lr, weight_decay=self.args.weight_decay)
                self.v_optimiser = AdamW(params=self.v_params, lr=self.args.critic_lr, weight_decay=self.args.weight_decay)
            case _:
                raise ValueError("Invalid optimiser type", self.args.optim_type)

        # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC
        self.target_mac = copy.deepcopy(mac)
        self.target_v_critic = copy.deepcopy(self.v_critic)
        self.target_q_critic = copy.deepcopy(self.q_critic)
        self.target_mixer = copy.deepcopy(self.mixer)

        self.vae = self.mac.vae

        self.log_stats_t = {}
        for task in args.train_task_ls:
            self.log_stats_t[task] = -args.learner_log_interval - 1
        
        # self.log_stats_t = -self.args.learner_log_interval - 1
        self.training_steps = 0
        self.last_target_update_step = 0
        self.last_target_update_episode = 0

        device = "cuda" if args.use_cuda else "cpu"
    
    def get_task_encoding(self, batch, task):
        rewards = batch["reward"][:, :-1]
        actions_one_hot = batch["actions_onehot"][:, :-1]
        obs = batch["obs"][:, :-1]
        next_obs = batch["obs"][:, 1:]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])

        temporal_encoding = self.temporal_encoder(obs, actions_one_hot, next_obs, rewards, mask, task).detach()
        local_encoding = self.local_encoder(temporal_encoding).detach()
        global_encoding = self.global_encoder(temporal_encoding)
        return global_encoding, local_encoding
    
    def get_role_encoding(self, batch, task_encoding, task):
        actions_one_hot = batch["actions_onehot"][:, :-1]
        obs = batch["obs"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        temporal_role_encoding = self.temporal_role_encoder(obs, actions_one_hot, mask, task)
        role_encoding = self.local_role_encoder(temporal_role_encoding, task_encoding).detach()
        return role_encoding

    def train(self, batch, batch_encoding, t_env: int, episode_num: int, task: int):
        """Run one training step for ``task``: update the actor, Q critic and V critic.

        The step proceeds as follows:

        1. Build the task encoding from ``batch_encoding`` and, when
           ``args.use_role_encoder`` is set, extend or replace the local encoding
           with a role encoding (from ``get_role_encoding``, or from the prior
           role encoder with probability ``args.train_with_prior_encoding_p``).
        2. Q loss: masked mean squared error between the mixed Q value and
           ``reward + gamma * (1 - terminated) * V_tot_next``, where
           ``V_tot_next`` comes from the target V critic and the target mixer.
        3. V loss: masked mean of
           ``exp(z - max_z) + exp(-max_z) * w * V / alpha_temp`` with
           ``z = (w * Q_target - w * V) / alpha_temp`` clamped to ``[-10, 10]``.
        4. Actor loss: masked mean of ``mac.get_diffusion_loss`` weighted by
           ``exp(z)``.
        5. Step the actor, Q and V optimisers (in that order), each with
           gradient-norm clipping at ``args.grad_norm_clip``.
        6. Update the target networks (hard or soft) and log the three losses.

        Args:
            batch: Episode batch used for the losses. Accessed through
                ``batch[key]``, ``batch.batch_size`` and ``batch.max_seq_length``.
            batch_encoding: Episode batch used only to compute the task encoding.
            t_env: Environment step counter; used for the logging interval and as
                the x-value of logged statistics.
            episode_num: Episode counter; used to schedule hard target updates.
            task: Task identifier; used as a key into the per-task dictionaries
                and as the log prefix.
                NOTE(review): annotated ``int`` but only ever used as a
                dict key / format argument - confirm the actual type.

        Returns:
            None. Network parameters, optimiser state, target networks and
            logging bookkeeping are updated in place.
        """
        train_with_prior_encoding = False

        # Get the relevant quantities (all but the last timestep unless noted)
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :-1]
        states = batch["state"][:, :-1]
        actions_one_hot = batch["actions_onehot"][:, :-1]  # NOTE: not used below
        obs = batch["obs"][:, :-1]
        next_obs = batch["obs"][:, 1:]  # NOTE: not used below
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        # Zero the mask for every step that follows a terminated step.
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]  # NOTE: not used below

        global_encoding, local_encoding = self.get_task_encoding(batch_encoding, task)
        if self.args.perturb_local_encoding:
            # Additive Gaussian noise scaled by args.perturb_noise_scale.
            noise = th.randn_like(local_encoding) * self.args.perturb_noise_scale
            local_encoding = local_encoding + noise

        if self.args.use_role_encoder:
            # Randomly choose between the prior role encoder and get_role_encoding.
            p = random.random()
            if p < self.args.train_with_prior_encoding_p:
                train_with_prior_encoding = True
            
            if not train_with_prior_encoding:
                role_encoding = self.get_role_encoding(batch, local_encoding, task)
                if not self.args.only_role_encoding:
                    local_encoding = th.cat([local_encoding, role_encoding], dim=-1).detach()
                else:
                    local_encoding = role_encoding.detach()
            else:
                prior_role_encoding = self.prior_role_encoder(obs, local_encoding, task)
                bs, max_t, n_agents, _ = obs.shape
                # Insert a time axis at dim 1 and tile the local encoding over max_t steps.
                local_encoding = local_encoding.unsqueeze(1).repeat(1, max_t, 1, 1)
                if not self.args.only_role_encoding:
                    local_encoding = th.cat([local_encoding, prior_role_encoding], dim=-1).detach()
                else:
                    local_encoding = prior_role_encoding.detach()
    
        critic_inputs = self._build_critic_inputs(batch, task)

        if not self.is_vqvae:
            # Non-VQ-VAE path: the critics receive the latent actions as an input.
            bs, max_t, n_agents, _ = actions.shape
            repeat_states = states.reshape(bs*max_t, -1)
            repeat_actions = actions.reshape(bs * max_t, n_agents, -1)
            latent_actions = self.vae.get_encoding(repeat_states, task, repeat_actions).reshape(bs, max_t, n_agents, -1).detach()
            
            cur_q_vals = self.q_critic(critic_inputs[:, :-1], latent_actions, task, local_encoding)
            cur_chosen_q_tot = self.mixer(cur_q_vals, batch["state"][:, :-1], global_encoding, task_decomposer=self.task2decomposer[task])
            
            # NOTE(review): the next-step V is evaluated with the current-step
            # latent_actions - confirm whether the V critic uses that argument.
            next_v_vals = self.target_v_critic(critic_inputs[:, 1:], latent_actions, task, local_encoding) # (b, T, n_agents, 1)
            next_w, next_b = self.target_mixer.w_and_b(batch["state"][:, 1:], global_encoding, task_decomposer=self.task2decomposer[task]) # (b, T, n_agents, 1). (b, T, 1)
            next_v_tot = (next_w * next_v_vals).sum(dim=-2) + next_b
            
            # One-step TD target bootstrapped from the target V network.
            q_target = rewards + self.args.gamma * (1 - terminated) * next_v_tot.detach()
            q_error = (cur_chosen_q_tot - q_target) # (bs, T, 1)
            
            mask_q = mask.expand_as(q_error)
            
            q_loss = ((q_error * mask_q) ** 2).sum() / mask_q.sum()
            
            
            target_chosen_q_vals = self.target_q_critic(critic_inputs[:, :-1], latent_actions, task, local_encoding)
            target_w, _ = self.target_mixer.w_and_b(batch["state"][:, :-1], global_encoding, task_decomposer=self.task2decomposer[task])
            cur_v = self.v_critic(critic_inputs[:, :-1], latent_actions, task, local_encoding) # (b, T, n_agents, 1)
            
            # Only cur_v carries gradient here; the target Q and mixer weights are detached.
            z = 1 / self.args.alpha_temp * (target_w.detach() * target_chosen_q_vals.detach() - target_w.detach() * cur_v)
            z = th.clamp(z, min=-10.0, max=10.0)
            # Global max of z, floored at -1.0 and detached; used to rescale the exponentials.
            max_z = th.max(z)
            max_z = th.where(max_z < -1.0, th.tensor(-1.0).to(self.args.device), max_z)
            max_z = max_z.detach()
            
            v_error = th.exp(z - max_z) + th.exp(-max_z) * target_w.detach() * cur_v / self.args.alpha_temp
            mask_v = mask_q.unsqueeze(-1).expand_as(v_error)
        
            v_loss = (v_error * mask_v).sum() / mask_v.sum()
        else:
            # VQ-VAE path: latent actions are indices used to gather from the
            # Q critic's last dimension.
            bs, max_t, n_agents, _ = actions.shape
            repeat_states = states.reshape(bs*max_t, -1)
            repeat_actions = actions.reshape(bs * max_t, n_agents, -1)
            latent_actions = self.vae.get_encoding_id_wo_onehot(repeat_states, task, repeat_actions).reshape(bs, max_t, n_agents, -1).detach()
            # Match dtype/device of ``actions`` so the ids can serve as a gather index.
            latent_actions = latent_actions.to(actions.dtype).to(actions.device)
            
            cur_q_vals = self.q_critic(critic_inputs[:, :-1], task, local_encoding)
            cur_chosen_q_vals = th.gather(cur_q_vals, dim=3, index=latent_actions)
            cur_chosen_q_tot = self.mixer(cur_chosen_q_vals, batch["state"][:, :-1], global_encoding, task_decomposer=self.task2decomposer[task])
            
            next_v_vals = self.target_v_critic(critic_inputs[:, 1:], task, local_encoding) # (b, T, n_agents, 1)
            next_w, next_b = self.target_mixer.w_and_b(batch["state"][:, 1:], global_encoding, task_decomposer=self.task2decomposer[task]) # (b, T, n_agents, 1). (b, T, 1)
            next_v_tot = (next_w * next_v_vals).sum(dim=-2) + next_b
            
            # One-step TD target bootstrapped from the target V network.
            q_target = rewards + self.args.gamma * (1 - terminated) * next_v_tot.detach()
            q_error = (cur_chosen_q_tot - q_target) # (bs, T, 1)
            
            mask_q = mask.expand_as(q_error)
            
            q_loss = ((q_error * mask_q) ** 2).sum() / mask_q.sum()
            
            
            target_q_vals = self.target_q_critic(critic_inputs[:, :-1], task, local_encoding)
            targe_chosen_q_vals = th.gather(target_q_vals, dim=3, index=latent_actions)
            target_w, _ = self.target_mixer.w_and_b(batch["state"][:, :-1], global_encoding, task_decomposer=self.task2decomposer[task])
            cur_v = self.v_critic(critic_inputs[:, :-1], task, local_encoding) # (b, T, n_agents, 1)
            
            # Only cur_v carries gradient here; the target Q and mixer weights are detached.
            z = 1 / self.args.alpha_temp * (target_w.detach() * targe_chosen_q_vals.detach() - target_w.detach() * cur_v)
            z = th.clamp(z, min=-10.0, max=10.0)
            # Global max of z, floored at -1.0 and detached; used to rescale the exponentials.
            max_z = th.max(z)
            max_z = th.where(max_z < -1.0, th.tensor(-1.0).to(self.args.device), max_z)
            max_z = max_z.detach()
            
            v_error = th.exp(z - max_z) + th.exp(-max_z) * target_w.detach() * cur_v / self.args.alpha_temp
            mask_v = mask_q.unsqueeze(-1).expand_as(v_error)
        
            v_loss = (v_error * mask_v).sum() / mask_v.sum()

        
        # Per-agent actor weights: exp of the clamped z, with no gradient.
        exp_a = th.exp(z).detach().squeeze(-1)
        mac_out = []
        if not train_with_prior_encoding:
            # A single encoding is set once for the whole sequence.
            self.mac.set_task_encoding(local_encoding, task)
        self.mac.init_hidden(batch.batch_size, task)
        for t in range(batch.max_seq_length):
            if train_with_prior_encoding:
                # The encoding is set per timestep; the final step reuses the
                # last available slice (presumably because local_encoding has
                # one fewer timestep than the batch - TODO confirm).
                if t == batch.max_seq_length - 1:
                    self.mac.set_task_encoding(local_encoding[:,-1,:,:], task)
                else:
                    self.mac.set_task_encoding(local_encoding[:,t,:,:], task)
            tmp_loss = self.mac.get_diffusion_loss(batch, weights=exp_a, t=t, task=task)
            mac_out.append(tmp_loss)
        mac_out = th.stack(mac_out, dim=1)
        # Drop the loss of the last timestep to align with the [:, :-1] slices above.
        actor_loss = mac_out[:, :-1].squeeze(-1)
        if len(actor_loss.shape)==3:
            actor_loss = actor_loss.unsqueeze(-1)
        
        mask_a = mask_q.unsqueeze(-1).repeat(1,1,self.task2n_agents[task],actor_loss.shape[-1])
        actor_loss = (actor_loss * mask_a).sum() / mask_a.sum()
                
        # Each loss is back-propagated and applied by its own optimiser.
        self.actor_optimiser.zero_grad()
        actor_loss.backward()
        th.nn.utils.clip_grad_norm_(self.agent_params, self.args.grad_norm_clip)
        self.actor_optimiser.step()
        
        self.q_optimiser.zero_grad()
        q_loss.backward()
        th.nn.utils.clip_grad_norm_(self.q_params, self.args.grad_norm_clip)
        self.q_optimiser.step()
        
        self.v_optimiser.zero_grad()
        v_loss.backward()
        th.nn.utils.clip_grad_norm_(self.v_params, self.args.grad_norm_clip)
        self.v_optimiser.step()
        
        self.training_steps += 1
        # target_update_interval_or_tau > 1 is treated as an episode interval for
        # hard updates; a value <= 1.0 is treated as the soft-update rate tau.
        if self.args.target_update_interval_or_tau > 1 and (episode_num - self.last_target_update_episode) / self.args.target_update_interval_or_tau >= 1.0:
            self._update_targets_hard()
            self.last_target_update_episode = episode_num
        elif self.args.target_update_interval_or_tau <= 1.0:
            self._update_targets_soft(self.args.target_update_interval_or_tau)
        
        # Log at most once every learner_log_interval environment steps per task.
        if t_env - self.log_stats_t[task] >= self.args.learner_log_interval:
            log_prefix = f"{task}/"
            self.logger.log_stat(log_prefix+"q_loss", q_loss.item(), t_env)
            self.logger.log_stat(log_prefix+"v_loss", v_loss.item(), t_env)
            self.logger.log_stat(log_prefix+"actor_loss", actor_loss.item(), t_env)
            self.log_stats_t[task] = t_env

    def _build_critic_inputs(self, batch, task):
        inputs  = []
        bs, max_t = batch.batch_size, batch.max_seq_length

        inputs.append(batch["obs"])
        assert batch.max_seq_length == batch["state"].shape[1]
        if self.args.obs_last_action:
            inputs.append(th.cat([th.zeros_like(batch["actions_onehot"][:, :1]), batch["actions_onehot"][:, :-1]], dim=1))
        if self.args.obs_agent_id:
            inputs.append(th.eye(self.task2n_agents[task]).unsqueeze(0).unsqueeze(0).expand(bs, max_t, -1, -1).to(self.args.device))
        inputs = th.cat([x.reshape(bs, max_t, self.task2n_agents[task], -1) for x in inputs], dim=-1)
        return inputs

    
    def _update_targets_hard(self):
        self.target_mac.load_state(self.mac)
        self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.target_q_critic.load_state_dict(self.q_critic.state_dict())
        self.target_v_critic.load_state_dict(self.v_critic.state_dict())
        
    def _update_targets_soft(self, tau):
        for target_param, param in zip(self.target_mac.parameters(), self.mac.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
       
        for target_param, param in zip(self.target_mixer.parameters(), self.mixer.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
            
        for target_param, param in zip(self.target_q_critic.parameters(), self.q_critic.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
        
        for target_param, param in zip(self.target_v_critic.parameters(), self.v_critic.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
        
    def cuda(self):
        self.mac.cuda()
        self.target_mac.cuda()
        self.q_critic.cuda()
        self.target_q_critic.cuda()
        self.v_critic.cuda()
        self.target_v_critic.cuda()
        self.mixer.cuda()
        self.target_mixer.cuda()
        self.temporal_encoder.cuda()
        self.local_encoder.cuda()
        self.global_encoder.cuda()
        self.temporal_role_encoder.cuda()
        self.local_role_encoder.cuda()
        self.prior_role_encoder.cuda()
    
    def save_models(self, path):
        self.mac.save_models(path)
        th.save(self.q_critic.state_dict(), "{}/q_critic.th".format(path))  
        th.save(self.v_critic.state_dict(), "{}/v_critic.th".format(path))
        th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        th.save(self.global_encoder.state_dict(), "{}/global_encoder.th".format(path))
        th.save(self.actor_optimiser.state_dict(), "{}/actor_opt.th".format(path))
        th.save(self.q_optimiser.state_dict(), "{}/q_opt.th".format(path))
        th.save(self.v_optimiser.state_dict(), "{}/v_opt.th".format(path))
    
    def load_encoders(self, path):
        self.temporal_encoder.load_state_dict(th.load("{}/temporal_encoder.th".format(path)))
        self.local_encoder.load_state_dict(th.load("{}/local_encoder.th".format(path)))
        self.global_encoder.load_state_dict(th.load("{}/global_encoder.th".format(path)))
    
    def load_role_encoder(self, path):
        self.temporal_role_encoder.load_state_dict(th.load("{}/temporal_role_encoder.th".format(path), map_location=lambda storage, loc: storage))
        self.local_role_encoder.load_state_dict(th.load("{}/local_role_encoder.th".format(path), map_location=lambda storage, loc: storage))
    
    def load_prior_role_encoder(self, path):
        self.prior_role_encoder.load_state_dict(th.load("{}/prior_role_encoder.th".format(path), map_location=lambda storage, loc: storage))
    
    def load_models(self, path):
        self.mac.load_models(path)
        self.target_mac.load_models(path)
        self.mixer.load_state_dict(th.load("{}/mixer.th".format(path)))
        self.target_mixer.load_state_dict(th.load("{}/mixer.th".format(path)))
        self.q_critic.load_state_dict(th.load("{}/q_critic.th".format(path)))
        self.target_q_critic.load_state_dict(th.load("{}/q_critic.th".format(path)))
        self.v_critic.load_state_dict(th.load("{}/v_critic.th".format(path)))
        self.target_v_critic.load_state_dict(th.load("{}/v_critic.th".format(path)))
        self.global_encoder.load_state_dict(th.load("{}/global_encoder.th".format(path)))
        self.actor_optimiser.load_state_dict(th.load("{}/actor_opt.th".format(path)))
        self.q_optimiser.load_state_dict(th.load("{}/q_opt.th".format(path)))
        self.v_optimiser.load_state_dict(th.load("{}/v_opt.th".format(path)))
                                            
    