from components.episode_buffer import EpisodeBatch
import os
import torch as th
from torch.optim import RMSprop, Adam, AdamW
import torch.nn.functional as F
import torch.nn as nn
from components.standarize_stream import RunningMeanStd
from modules.encoders.transition_encoder import TransformerTransitionEncoder
from modules.encoders.transition_decoder import TransformerTransitionDecoder, TransformerTransitionRoleDecoder
from modules.encoders.temporal_encoder import TransformerTemporalEncoder, TransformerTemporalRoleEncoder
from modules.encoders.local_encoder import LocalEncoder, LocalRoleEncoder
from modules.encoders.global_encoder import GlobalEncoder
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
# from cuml.manifold import TSNE
# import cupy as cp
import numpy as np
import swanlab


class RoleEncoderLearner:
    def __init__(self, mac, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger

        self.train_tasks = mac.train_tasks
        self.task2decomposer = mac.task2decomposer

        self.temporal_encoder = TransformerTemporalEncoder(self.task2decomposer, args)
        self.local_encoder = LocalEncoder(args)
        self.global_encoder = GlobalEncoder(args)

        self.load_task_encoder(args.encoder_path_ls[args.encoder_id])

        self.temporal_role_encoder = TransformerTemporalRoleEncoder(self.task2decomposer, args)
        self.local_role_encoder = LocalRoleEncoder(args)
        self.role_decoder = TransformerTransitionRoleDecoder(self.task2decomposer, args)

        self.params = list(self.temporal_role_encoder.parameters()) + list(self.local_role_encoder.parameters()) + list(self.role_decoder.parameters())

        match self.args.optim_type.lower():
            case "rmsprop":
                self.optimiser = RMSprop(params=self.params, lr=self.args.lr, alpha=self.args.optim_alpha, eps=self.args.optim_eps, weight_decay=self.args.weight_decay)
            case "adam":
                self.optimiser = Adam(params=self.params, lr=self.args.lr, weight_decay=self.args.weight_decay)
            case "adamw":
                self.optimiser = AdamW(params=self.params, lr=self.args.lr, weight_decay=self.args.weight_decay)
            case _:
                raise ValueError("Invalid optimiser type", self.args.optim_type)
        
        self.log_stats_t = -self.args.encoder_learner_log_interval - 1
        self.training_steps = 0

    def train(self, train_roles, main_args, task2args, role2episode_sample, role2episode_positive, role2episode_negative, role2encoding_sample, role2encoding_positive, role2encoding_negative, role2task_sample, role2task_positive, role2task_negative, episode):
        
        role2local_encoding = {}
        role2local_encoding_positive = {}
        role2local_encoding_negative = {}
        with th.no_grad():
            for role in train_roles:
                batch = role2encoding_sample[role]
                task = role2task_sample[role]
                rewards = batch["reward"][:, :-1]
                actions_one_hot = batch["actions_onehot"][:, :-1]
                obs = batch["obs"][:, :-1]
                next_obs = batch["obs"][:, 1:]
                terminated = batch["terminated"][:, :-1].float()
                mask = batch["filled"][:, :-1].float()
                mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
                temporal_encoding = self.temporal_encoder(obs, actions_one_hot, next_obs, rewards, mask, task)
                local_encoding = self.local_encoder(temporal_encoding)
                bs, n_agents, _ = local_encoding.shape
                local_encoding = local_encoding.reshape(bs*n_agents, -1)
                shuffled_indices = th.randperm(local_encoding.size(0))
                local_encoding = local_encoding[shuffled_indices][:bs]
                role2local_encoding[role] = local_encoding

                batch = role2encoding_positive[role]
                task = role2task_positive[role]
                rewards = batch["reward"][:, :-1]
                actions_one_hot = batch["actions_onehot"][:, :-1]
                obs = batch["obs"][:, :-1]
                next_obs = batch["obs"][:, 1:]
                terminated = batch["terminated"][:, :-1].float()
                mask = batch["filled"][:, :-1].float()
                mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
                temporal_encoding = self.temporal_encoder(obs, actions_one_hot, next_obs, rewards, mask, task)
                local_encoding = self.local_encoder(temporal_encoding)
                bs, n_agents, _ = local_encoding.shape
                local_encoding = local_encoding.reshape(bs*n_agents, -1)
                shuffled_indices = th.randperm(local_encoding.size(0))
                local_encoding = local_encoding[shuffled_indices][:bs]
                role2local_encoding_positive[role] = local_encoding

                role2local_encoding_negative[role] = []
                for i in range(len(role2encoding_negative[role])):
                    batch = role2encoding_negative[role][i]
                    task = role2task_negative[role][i]
                    rewards = batch["reward"][:, :-1]
                    actions_one_hot = batch["actions_onehot"][:, :-1]
                    obs = batch["obs"][:, :-1]
                    next_obs = batch["obs"][:, 1:]
                    terminated = batch["terminated"][:, :-1].float()
                    mask = batch["filled"][:, :-1].float()
                    mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
                    temporal_encoding = self.temporal_encoder(obs, actions_one_hot, next_obs, rewards, mask, task)
                    local_encoding = self.local_encoder(temporal_encoding)
                    bs, n_agents, _ = local_encoding.shape
                    local_encoding = local_encoding.reshape(bs*n_agents, -1)
                    shuffled_indices = th.randperm(local_encoding.size(0))
                    local_encoding = local_encoding[shuffled_indices][:bs]
                    role2local_encoding_negative[role].append(local_encoding)
        
        role2role_encoding = {}
        role2role_encoding_positive = {}
        role2role_encoding_negative = {}
        for role in train_roles:
            batch = role2episode_sample[role]
            task = role2task_sample[role]
            actions_one_hot = batch["actions_onehot"][:, :-1]
            actions_long = batch["actions"][:, :-1]
            obs = batch["obs"][:, :-1]
            terminated = batch["terminated"][:, :-1].float()
            mask = batch["filled"][:, :-1].float()
            mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
            temporal_role_encoding = self.temporal_role_encoder(obs, actions_one_hot, mask, task)
            local_task_encoding = role2local_encoding[role]
            role_encoding = self.local_role_encoder(temporal_role_encoding, local_task_encoding)
            role2role_encoding[role] = role_encoding

            batch = role2episode_positive[role]
            task = role2task_positive[role]
            actions_one_hot = batch["actions_onehot"][:, :-1]
            actions_long = batch["actions"][:, :-1]
            obs = batch["obs"][:, :-1]
            terminated = batch["terminated"][:, :-1].float()
            mask = batch["filled"][:, :-1].float()
            mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
            temporal_role_encoding = self.temporal_role_encoder(obs, actions_one_hot, mask, task)
            local_task_encoding = role2local_encoding_positive[role]
            role_encoding = self.local_role_encoder(temporal_role_encoding, local_task_encoding)
            role2role_encoding_positive[role] = role_encoding

            role2role_encoding_negative[role] = []
            for i in range(len(role2episode_negative[role])):
                batch = role2episode_negative[role][i]
                task = role2task_negative[role][i]
                actions_one_hot = batch["actions_onehot"][:, :-1]
                actions_long = batch["actions"][:, :-1]
                obs = batch["obs"][:, :-1]
                terminated = batch["terminated"][:, :-1].float()
                mask = batch["filled"][:, :-1].float()
                mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
                temporal_role_encoding = self.temporal_role_encoder(obs, actions_one_hot, mask, task)
                local_task_encoding = role2local_encoding_negative[role][i]
                role_encoding = self.local_role_encoder(temporal_role_encoding, local_task_encoding)
                role2role_encoding_negative[role].append(role_encoding)
            role2role_encoding_negative[role] = th.stack(role2role_encoding_negative[role], dim=1)
        q = []
        p = []
        n = []
        for role in train_roles:
            q.append(role2role_encoding[role])
            p.append(role2role_encoding_positive[role])
            n.append(role2role_encoding_negative[role])
        q = th.cat(q, dim=0)
        p = th.cat(p, dim=0)
        n = th.cat(n, dim=0)

        total_loss = 0
        cl_loss = self.contrastive_loss(q, p, n)
        total_loss += cl_loss
        
        if main_args.use_decoder_loss:
            total_decoder_loss = 0
            for role in train_roles:
                batch = role2episode_sample[role]
                task = role2task_sample[role]
                actions_long = batch["actions"][:, :-1]
                obs = batch["obs"][:, :-1]
                terminated = batch["terminated"][:, :-1].float()
                mask = batch["filled"][:, :-1].float()
                mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
                bs, t, _ = obs.shape
                obs = obs.reshape(bs*t, -1)
                actions_long = actions_long.reshape(bs*t, 1)
                mask = mask.reshape(bs*t, 1)
                temp_role_encoding = role2role_encoding[role].unsqueeze(1).repeat(1,t,1).reshape(bs*t,-1)
                decoder_loss = self.role_decoder.get_decoding_loss(obs, actions_long, temp_role_encoding, task, mask)
                total_decoder_loss += decoder_loss
                
            total_decoder_loss /= len(train_roles)
            total_loss += total_decoder_loss
        
        self.optimiser.zero_grad()
        total_loss.backward()
        # grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip)
        self.optimiser.step()

        if episode - self.log_stats_t >= self.args.encoder_learner_log_interval:
            self.logger.log_stat(f"loss", total_loss.item(), episode)
            self.logger.log_stat(f"cl_loss", cl_loss.item(), episode)
            # self.logger.log_stat(f"grad_norm", grad_norm.item(), episode)
            if main_args.use_decoder_loss:
                self.logger.log_stat(f"decoder_loss", total_decoder_loss.item(), episode)
                
            self.log_stats_t = episode
    
    def visualize(self, train_roles, main_args, task2args, role2episode_sample_vis, role2encoding_sample_vis, role2task_sample_vis, episode):
        self.logger.console_logger.info("Starting encoder visualization")
        save_path = os.path.join(main_args.vis_save_dir, str(episode))
        role2task_ls = main_args.role2task
        os.makedirs(save_path, exist_ok=True)
        x = []
        y = []
        with th.no_grad():
            for role in train_roles:
                task_ls = role2task_ls[role]
                for i, task in enumerate(task_ls):
                    batch = role2encoding_sample_vis[role][i]
                    rewards = batch["reward"][:, :-1]
                    actions_one_hot = batch["actions_onehot"][:, :-1]
                    obs = batch["obs"][:, :-1]
                    next_obs = batch["obs"][:, 1:]
                    terminated = batch["terminated"][:, :-1].float()
                    mask = batch["filled"][:, :-1].float()
                    mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
                    temporal_encoding = self.temporal_encoder(obs, actions_one_hot, next_obs, rewards, mask, task)
                    local_encoding = self.local_encoder(temporal_encoding)
                    bs, n_agents, _ = local_encoding.shape
                    local_encoding = local_encoding.reshape(bs*n_agents, -1)
                    shuffled_indices = th.randperm(local_encoding.size(0))
                    local_encoding = local_encoding[shuffled_indices][:bs]

                    batch = role2episode_sample_vis[role][i]
                    actions_one_hot = batch["actions_onehot"][:, :-1]
                    actions_long = batch["actions"][:, :-1]
                    obs = batch["obs"][:, :-1]
                    terminated = batch["terminated"][:, :-1].float()
                    mask = batch["filled"][:, :-1].float()
                    mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
                    temporal_role_encoding = self.temporal_role_encoder(obs, actions_one_hot, mask, task)
                    role_encoding = self.local_role_encoder(temporal_role_encoding, local_encoding).cpu().detach().numpy()

                    for i in range(main_args.vis_batch_size):
                        x.append(role_encoding[i])
                        y.append(role)
        
        tsne = TSNE(n_components=2, init='pca', random_state=0)
        total_x = x
        X_tsne = tsne.fit_transform(np.asarray(total_x))
        x_min, x_max = np.min(X_tsne, 0), np.max(X_tsne, 0)
        X_data = (X_tsne - x_min) / (x_max - x_min)

        colors = plt.cm.rainbow(np.linspace(0,1,len(train_roles)))

        plt.clf()
        for i in range(X_data.shape[0]):
            plt.scatter(X_data[i, 0], X_data[i, 1],
                    color=colors[y[i]])
        plt.xticks([])
        plt.yticks([])
        if main_args.use_swanlab:
            swanlab.log({f"role_vis": swanlab.Image(plt)}, step=episode)
        plt.savefig(save_path+"/role_vis.png")
        self.logger.console_logger.info("Finishing encoder visualization")
    
    def get_visualize_data(self, train_roles, main_args, task2args, role2episode_sample_vis, role2encoding_sample_vis, role2task_sample_vis, episode):
        self.logger.console_logger.info("Starting encoder visualization")
        save_path = os.path.join(main_args.vis_save_dir, str(episode))
        role2task_ls = main_args.role2task
        os.makedirs(save_path, exist_ok=True)
        x = []
        y = []
        task_id = []
        with th.no_grad():
            for role in train_roles:
                task_ls = role2task_ls[role]
                for i, task in enumerate(task_ls):
                    batch = role2encoding_sample_vis[role][i]
                    rewards = batch["reward"][:, :-1]
                    actions_one_hot = batch["actions_onehot"][:, :-1]
                    obs = batch["obs"][:, :-1]
                    next_obs = batch["obs"][:, 1:]
                    terminated = batch["terminated"][:, :-1].float()
                    mask = batch["filled"][:, :-1].float()
                    mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
                    temporal_encoding = self.temporal_encoder(obs, actions_one_hot, next_obs, rewards, mask, task)
                    local_encoding = self.local_encoder(temporal_encoding)
                    bs, n_agents, _ = local_encoding.shape
                    local_encoding = local_encoding.reshape(bs*n_agents, -1)
                    shuffled_indices = th.randperm(local_encoding.size(0))
                    local_encoding = local_encoding[shuffled_indices][:bs]

                    batch = role2episode_sample_vis[role][i]
                    actions_one_hot = batch["actions_onehot"][:, :-1]
                    actions_long = batch["actions"][:, :-1]
                    obs = batch["obs"][:, :-1]
                    terminated = batch["terminated"][:, :-1].float()
                    mask = batch["filled"][:, :-1].float()
                    mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
                    temporal_role_encoding = self.temporal_role_encoder(obs, actions_one_hot, mask, task)
                    role_encoding = self.local_role_encoder(temporal_role_encoding, local_encoding).cpu().detach().numpy()

                    for j in range(main_args.vis_batch_size):
                        x.append(role_encoding[j])
                        y.append(role)
                        task_id.append(task)
        
        tsne = TSNE(n_components=2, init='pca', random_state=0)
        total_x = x
        X_tsne = tsne.fit_transform(np.asarray(total_x))
        x_min, x_max = np.min(X_tsne, 0), np.max(X_tsne, 0)
        X_data = (X_tsne - x_min) / (x_max - x_min)

        np.savez(save_path+"/role_vis_data.npz", data=X_data, y=y, task_id=task_id)
        self.logger.console_logger.info("Finishing encoder visualization")
        
    
    # q, k (b, dim); neg (b, N, dim)
    def contrastive_loss(self, q, k, neg):
        N = neg.shape[1]
        b = q.shape[0]
        l_pos = th.bmm(q.view(b, 1, -1), k.view(b, -1, 1)) # (b,1,1)
        l_neg = th.bmm(q.view(b, 1, -1), neg.transpose(1,2)) # (b,1,N)
        logits = th.cat([l_pos.view(b, 1), l_neg.view(b, N)], dim=1)
        
        labels = th.zeros(b, dtype=th.long)
        labels = labels.to(q.device)
        cross_entropy_loss = nn.CrossEntropyLoss()
        loss = cross_entropy_loss(logits/self.args.infonce_temp, labels)
        return loss
    
    def global_dml_loss(self, train_tasks, main_args, task2global_encoding, epsilon=1e-3):
        pos_z_loss = 0
        neg_z_loss = 0
        pos_cnt = 0
        neg_cnt = 0

        for i in range(len(train_tasks)):
            for mb_i in range(main_args.meta_batch_size):
                z_q = task2global_encoding[train_tasks[i]][mb_i]
                for mb_pos in range(mb_i+1, main_args.meta_batch_size):
                    z_pos = task2global_encoding[train_tasks[i]][mb_pos]
                    pos_z_loss += th.sqrt(th.mean((z_q - z_pos)**2)+epsilon)
                    pos_cnt += 1
                for j in range(i+1, len(train_tasks)):
                    for mb_j in range(main_args.meta_batch_size):
                        z_neg = task2global_encoding[train_tasks[j]][mb_j]
                        neg_z_loss += 1/(th.mean((z_q - z_neg)**2)+epsilon*100)
                        neg_cnt += 1
        
        return pos_z_loss/pos_cnt + neg_z_loss/neg_cnt
    
    def local_dml_loss(self, train_tasks, main_args, task2global_encoding, task2local_encoding, epsilon=1e-3):
        pos_z_loss = 0
        neg_z_loss = 0
        pos_cnt = 0
        neg_cnt = 0

        for i in range(len(train_tasks)):
            bs_i = task2local_encoding[train_tasks[i]].shape[0]
            for mb_i in range(bs_i):
                z_q = task2local_encoding[train_tasks[i]][mb_i]
                for mb_pos in range(mb_i+1, bs_i):
                    z_pos = task2local_encoding[train_tasks[i]][mb_pos]
                    pos_z_loss += th.sqrt(th.mean((z_q - z_pos)**2)+epsilon)
                    pos_cnt += 1
                for j in range(i+1, len(train_tasks)):
                    bs_j = task2local_encoding[train_tasks[j]].shape[0]
                    for mb_j in range(bs_j):
                        z_neg = task2local_encoding[train_tasks[j]][mb_j]
                        neg_z_loss += 1/(th.mean((z_q - z_neg)**2)+epsilon*100)
                        neg_cnt += 1
                if not self.wo_global:
                    for gl_pos in range(main_args.meta_batch_size):
                        z_pos = task2global_encoding[train_tasks[i]][gl_pos]
                        pos_z_loss += th.sqrt(th.mean((z_q - z_pos)**2)+epsilon)
                        pos_cnt += 1
                # 这里和global之间先只做positive而不做negative，可以节省非常多训练时间
                # for k in range(len(train_tasks)):
                #     if k == i:
                #         continue
                #     for gl_k in range(main_args.meta_batch_size):
                #         z_neg = task2global_encoding[train_tasks[k]][gl_k]
                #         neg_z_loss += 1/(th.mean((z_q - z_neg)**2)+epsilon*100)
                #         neg_cnt += 1
        
        return pos_z_loss/pos_cnt + neg_z_loss/neg_cnt
    
    def global_cl_loss(self, train_tasks, main_args, task2global_encoding):
        total_cnt = 0
        cl_loss = 0
        for i in range(len(train_tasks)):
            neg_z_ls = []
            for j in range(len(train_tasks)):
                if j==i:
                    continue
                neg_z_ls.append(task2global_encoding[train_tasks[j]])
            z_neg = th.stack(neg_z_ls, dim=0).unsqueeze(0)
            # z_neg = th.cat(neg_z_ls, dim=0)
            for mb_i in range(main_args.meta_batch_size):
                z_q = task2global_encoding[train_tasks[i]][mb_i].unsqueeze(0)
                for mb_pos in range(mb_i+1, main_args.meta_batch_size):
                    z_pos = task2global_encoding[train_tasks[i]][mb_pos].unsqueeze(0)

                    bs, A, B, dim = z_neg.shape
                    indices = th.randint(0, B, (bs, A)).to(z_neg.device)
                    indices_expanded = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, dim)
                    z_neg_sampled = th.gather(z_neg, dim=2, index=indices_expanded).reshape(bs, A, dim)

                    cl_loss += self.contrastive_loss(z_q, z_pos, z_neg_sampled)
                    total_cnt += 1
        return cl_loss/total_cnt
    
    def local_cl_loss(self, train_tasks, main_args, task2global_encoding, task2local_encoding):
        total_cnt = 0
        cl_loss = 0
        for i in range(len(train_tasks)):
            neg_z_local_ls = []
            neg_z_global_ls = []
            for j in range(len(train_tasks)):
                if j==i:
                    continue
                neg_z_local_ls.append(task2local_encoding[train_tasks[j]])
                neg_z_global_ls.append(task2global_encoding[train_tasks[j]])
            bs_i = task2local_encoding[train_tasks[i]].shape[0]
            z_neg_local = th.stack(neg_z_local_ls, dim=0).unsqueeze(0)
            z_neg_global = th.stack(neg_z_global_ls, dim=0).unsqueeze(0)
            for mb_i in range(bs_i):
                z_q = task2local_encoding[train_tasks[i]][mb_i].unsqueeze(0)
                for mb_pos in range(mb_i+1, bs_i):
                    z_pos = task2local_encoding[train_tasks[i]][mb_pos].unsqueeze(0)

                    bs, A, B, dim = z_neg_local.shape
                    indices = th.randint(0, B, (bs, A)).to(z_neg_local.device)
                    indices_expanded = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, dim)
                    z_neg_sampled = th.gather(z_neg_local, dim=2, index=indices_expanded).reshape(bs, A, dim)

                    cl_loss += self.contrastive_loss(z_q, z_pos, z_neg_sampled)
                    total_cnt += 1
                if not self.wo_global:
                    for gl_pos in range(main_args.meta_batch_size):
                        z_pos = task2global_encoding[train_tasks[i]][gl_pos].unsqueeze(0)

                        bs, A, B, dim = z_neg_global.shape
                        indices = th.randint(0, B, (bs, A)).to(z_neg_global.device)
                        indices_expanded = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, dim)
                        z_neg_sampled = th.gather(z_neg_global, dim=2, index=indices_expanded).reshape(bs, A, dim)

                        cl_loss += self.contrastive_loss(z_q, z_pos, z_neg_sampled)
                        total_cnt += 1

        return cl_loss/total_cnt

    def cuda(self):
        self.mac.cuda()
        self.temporal_encoder.cuda()
        self.local_encoder.cuda()
        self.global_encoder.cuda()
        self.temporal_role_encoder.cuda()
        self.local_role_encoder.cuda()
        self.role_decoder.cuda()
    
    def save_models(self, path):
        # self.mac.save_models(path)
        th.save(self.temporal_role_encoder.state_dict(), "{}/temporal_role_encoder.th".format(path))
        th.save(self.local_role_encoder.state_dict(), "{}/local_role_encoder.th".format(path))
        th.save(self.role_decoder.state_dict(), "{}/role_decoder.th".format(path))
        th.save(self.optimiser.state_dict(), "{}/opt.th".format(path))
    
    def load_task_encoder(self, path):
        # self.mac.load_models(path)
        self.temporal_encoder.load_state_dict(th.load("{}/temporal_encoder.th".format(path), map_location=lambda storage, loc: storage))
        self.local_encoder.load_state_dict(th.load("{}/local_encoder.th".format(path), map_location=lambda storage, loc: storage))
        self.global_encoder.load_state_dict(th.load("{}/global_encoder.th".format(path), map_location=lambda storage, loc: storage))
    
    def load_models(self, path):
        # self.mac.load_models(path)
        self.temporal_role_encoder.load_state_dict(th.load("{}/temporal_role_encoder.th".format(path), map_location=lambda storage, loc: storage))
        self.local_role_encoder.load_state_dict(th.load("{}/local_role_encoder.th".format(path), map_location=lambda storage, loc: storage))
        self.role_decoder.load_state_dict(th.load("{}/role_decoder.th".format(path), map_location=lambda storage, loc: storage))
        self.optimiser.load_state_dict(th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))
