from components.episode_buffer import EpisodeBatch
import os
import torch as th
from torch.optim import RMSprop, Adam, AdamW
import torch.nn.functional as F
import torch.nn as nn
from components.standarize_stream import RunningMeanStd
from modules.encoders.transition_encoder import TransformerTransitionEncoder
from modules.encoders.transition_decoder import TransformerTransitionDecoder
from modules.encoders.temporal_encoder import TransformerTemporalEncoder
from modules.encoders.local_encoder import LocalEncoder
from modules.encoders.global_encoder import GlobalEncoder
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
# from cuml.manifold import TSNE
# import cupy as cp
import numpy as np
import swanlab


class EncoderLearner:
    def __init__(self, mac, logger, args):
        """Build the trajectory encoders, the transition decoder and their optimiser(s).

        Args:
            mac: multi-agent controller; its ``train_tasks`` and
                ``task2decomposer`` attributes are reused here.
            logger: logger object stored for use by the training/visualisation methods.
            args: encoder configuration (reads ``use_club``, ``wo_global``,
                ``optim_type``, ``lr``, ``weight_decay``, ``optim_alpha``,
                ``optim_eps`` and ``encoder_learner_log_interval``).

        Raises:
            ValueError: if ``args.optim_type`` (case-insensitive) is not
                ``rmsprop``, ``adam`` or ``adamw``.
        """
        self.args, self.mac, self.logger = args, mac, logger

        self.train_tasks = mac.train_tasks
        self.task2decomposer = mac.task2decomposer

        self.use_club = args.use_club
        self.wo_global = args.wo_global

        decomposer = self.task2decomposer
        self.temporal_encoder = TransformerTemporalEncoder(decomposer, args)
        self.local_encoder = LocalEncoder(args)
        self.global_encoder = GlobalEncoder(args)
        self.decoder = TransformerTransitionDecoder(decomposer, args)
        # One flat parameter list for the main optimiser (encoders + decoder).
        self.params = [
            param
            for module in (self.temporal_encoder, self.local_encoder, self.global_encoder, self.decoder)
            for param in module.parameters()
        ]

        if self.use_club:
            # Separate CLUB heads, trained by their own optimiser.
            self.club_temporal_encoder = TransformerTemporalEncoder(decomposer, args, is_club=True)
            self.club_local_encoder = LocalEncoder(args, is_club=True)
            self.club_global_encoder = GlobalEncoder(args, is_club=True)
            self.club_params = [
                param
                for module in (self.club_temporal_encoder, self.club_local_encoder, self.club_global_encoder)
                for param in module.parameters()
            ]

        optim_type = self.args.optim_type.lower()
        if optim_type == "rmsprop":
            def build_optimiser(params):
                return RMSprop(params=params, lr=self.args.lr, alpha=self.args.optim_alpha, eps=self.args.optim_eps, weight_decay=self.args.weight_decay)
        elif optim_type == "adam":
            def build_optimiser(params):
                return Adam(params=params, lr=self.args.lr, weight_decay=self.args.weight_decay)
        elif optim_type == "adamw":
            def build_optimiser(params):
                return AdamW(params=params, lr=self.args.lr, weight_decay=self.args.weight_decay)
        else:
            raise ValueError("Invalid optimiser type", self.args.optim_type)

        self.optimiser = build_optimiser(self.params)
        if self.use_club:
            self.club_optimiser = build_optimiser(self.club_params)

        # Start far enough in the past that the first train() call logs.
        self.log_stats_t = -self.args.encoder_learner_log_interval - 1
        self.training_steps = 0

    def train(self, train_tasks, main_args, task2args, task2batch, episode):

        task2local_encoding, task2global_encoding = {}, {}
        task2obs, task2actions, task2next_obs, task2rewards, task2mask = {}, {}, {}, {}, {}
        task2obs_o, task2actions_o, task2next_obs_o, task2rewards_o, task2mask_o = {}, {}, {}, {}, {}
        task2t = {}
        club_model_loss = 0
        for task in train_tasks:
            batch = task2batch[task]
            rewards = batch["reward"][:, :-1]
            actions_one_hot = batch["actions_onehot"][:, :-1]
            obs = batch["obs"][:, :-1]
            next_obs = batch["obs"][:, 1:]
            terminated = batch["terminated"][:, :-1].float()
            mask = batch["filled"][:, :-1].float()
            mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
            task2obs_o[task] = obs
            task2actions_o[task] = actions_one_hot
            task2next_obs_o[task] = next_obs
            task2rewards_o[task] = rewards
            task2mask_o[task] = mask
            temporal_encoding = self.temporal_encoder(obs, actions_one_hot, next_obs, rewards, mask, task)
            local_encoding = self.local_encoder(temporal_encoding)
            global_encoding = self.global_encoder(temporal_encoding)
            bs, n_agents, _ = local_encoding.shape
            local_encoding = local_encoding.reshape(bs*n_agents, -1)
            task2local_encoding[task] = local_encoding
            task2global_encoding[task] = global_encoding
            bs, t, n_agents, _ = obs.shape
            obs = obs.reshape(bs*n_agents*t,-1)
            actions_one_hot = actions_one_hot.reshape(bs*n_agents*t,-1)
            next_obs = next_obs.reshape(bs*n_agents*t,-1)
            rewards = rewards.unsqueeze(2).repeat(1,1,n_agents,1).reshape(bs * t * n_agents, -1)
            mask = mask.unsqueeze(2).repeat(1,1,n_agents,1).reshape(bs * t * n_agents, -1)
            task2obs[task] = obs
            task2actions[task] = actions_one_hot
            task2next_obs[task] = next_obs
            task2rewards[task] = rewards
            task2mask[task] = mask
            task2t[task] = t
            if self.use_club:
                club_temporal_encoding = self.club_temporal_encoder(task2obs_o[task], task2actions_o[task], task2next_obs_o[task], task2rewards_o[task], task2mask_o[task], task)
                club_local_mean, club_local_var = self.club_local_encoder(club_temporal_encoding)
                club_global_mean, club_global_var = self.club_global_encoder(club_temporal_encoding)
                club_local_mean = club_local_mean.reshape(bs*n_agents, -1)
                club_local_var = club_local_var.reshape(bs*n_agents, -1)
                club_model_local_loss = ((local_encoding.detach()-club_local_mean)**2/(2*club_local_var)+th.log(th.sqrt(club_local_var))).mean()
                club_model_global_loss = ((global_encoding.detach()-club_global_mean)**2/(2*club_global_var)+th.log(th.sqrt(club_global_var))).mean()
                club_model_loss = club_model_loss + club_model_local_loss + club_model_global_loss

        if self.use_club:
            club_model_loss = club_model_loss / len(train_tasks)
            self.club_optimiser.zero_grad()
            club_model_loss.backward()
            self.club_optimiser.step()
        
        total_loss = 0

        if not self.wo_global:
            if main_args.cl_loss == "dml":
                global_cl_loss = self.global_dml_loss(train_tasks, main_args, task2global_encoding)
            elif main_args.cl_loss == "cl" or main_args.cl_loss == "global_guided_only_cl":
                global_cl_loss = self.global_cl_loss(train_tasks, main_args, task2global_encoding)
            else:
                assert False, f"Not implemented cl loss {main_args.cl_loss}"
            total_loss += global_cl_loss
            
            if main_args.use_decoder_loss and episode >= getattr(main_args, "encoder_local_start_episode", 0):
                global_decoder_loss = 0
                global_decoder_own_loss, global_decoder_enemy_loss, global_decoder_ally_loss, global_decoder_reward_loss = 0, 0, 0, 0
                for task in train_tasks:
                    n_agents = task2args[task].n_agents
                    obs = task2obs[task]
                    actions = task2actions[task]
                    next_obs = task2next_obs[task]
                    rewards = task2rewards[task]
                    mask = task2mask[task]
                    t = task2t[task]
                    global_encoding = task2global_encoding[task].unsqueeze(1).unsqueeze(1).repeat(1,t,n_agents,1).reshape(bs * t * n_agents,-1)
                    decoder_loss, own_loss, enemy_loss, ally_loss, reward_loss = self.decoder.get_decoding_loss(obs, actions, global_encoding, task, next_obs, rewards, mask)
                    global_decoder_loss += decoder_loss
                    global_decoder_own_loss += own_loss
                    global_decoder_enemy_loss += enemy_loss
                    global_decoder_ally_loss += ally_loss
                    global_decoder_reward_loss += reward_loss
                global_decoder_loss /= len(train_tasks)
                global_decoder_own_loss /= len(train_tasks)
                global_decoder_enemy_loss /= len(train_tasks)
                global_decoder_ally_loss /= len(train_tasks)
                global_decoder_reward_loss /= len(train_tasks)
                total_loss += global_decoder_loss
            
            if self.use_club:
                global_club_loss = 0
                for task in train_tasks:
                    n_agents = task2args[task].n_agents
                    obs = task2obs_o[task]
                    actions = task2actions_o[task]
                    next_obs = task2next_obs_o[task]
                    rewards = task2rewards_o[task]
                    mask = task2mask_o[task]
                    t = task2t[task]
                    club_temporal_encoding = self.club_temporal_encoder(obs, actions, next_obs, rewards, mask, task)
                    club_global_mean, club_global_var = self.club_global_encoder(club_temporal_encoding)
                    club_global_mean = club_global_mean.detach()
                    club_global_var = club_global_var.detach()
                    global_encoding = task2global_encoding[task]
                    positive = - ((global_encoding-club_global_mean)**2/club_global_var).mean()
                    bs = club_global_mean.shape[0]
                    club_global_mean_expand = club_global_mean.unsqueeze(1).expand(-1,bs,-1).reshape(bs**2,-1)
                    club_global_var_expand = club_global_var.unsqueeze(1).expand(-1,bs,-1).reshape(bs**2,-1)
                    global_encoding_expand = global_encoding.repeat(bs, 1)
                    negative = - ((global_encoding_expand - club_global_mean_expand)**2/club_global_var_expand).mean()
                    global_club_loss += main_args.club_loss_weight * (positive - negative)
                global_club_loss /= len(train_tasks)
                total_loss += global_club_loss

        
        if not main_args.global_only and (episode >= getattr(main_args, "encoder_local_start_episode", 0) or self.wo_global):
            # 分离global encoding
            for task in train_tasks:
                task2global_encoding[task] = task2global_encoding[task].detach()
            
            if main_args.use_mg2l:
                local_cl_loss = self.mg2l_local_cl_loss(train_tasks, main_args, task2global_encoding, task2local_encoding)
            elif main_args.cl_loss == "dml":
                local_cl_loss = self.local_dml_loss(train_tasks, main_args, task2global_encoding, task2local_encoding)
            elif main_args.cl_loss == "cl":
                local_cl_loss = self.local_cl_loss(train_tasks, main_args, task2global_encoding, task2local_encoding)
            elif main_args.cl_loss == "global_guided_only_cl":
                local_cl_loss = self.local_cl_loss_global_guided_only(train_tasks, main_args, task2global_encoding, task2local_encoding)
            else:
                assert False, f"Not implemented cl loss {main_args.cl_loss}"
            
            total_loss += local_cl_loss

            if main_args.use_decoder_loss:
                local_decoder_loss = 0
                local_decoder_own_loss, local_decoder_enemy_loss, local_decoder_ally_loss, local_decoder_reward_loss = 0, 0, 0, 0
                for task in train_tasks:
                    n_agents = task2args[task].n_agents
                    obs = task2obs[task]
                    actions = task2actions[task]
                    next_obs = task2next_obs[task]
                    rewards = task2rewards[task]
                    mask = task2mask[task]
                    t = task2t[task]
                    bs = int(task2local_encoding[task].shape[0] / n_agents)
                    local_encoding = task2local_encoding[task].reshape(bs,n_agents,-1).unsqueeze(1).repeat(1,t,1,1).reshape(bs * t * n_agents,-1)
                    decoder_loss, own_loss, enemy_loss, ally_loss, reward_loss = self.decoder.get_decoding_loss(obs, actions, local_encoding, task, next_obs, rewards, mask)
                    local_decoder_loss += decoder_loss
                    local_decoder_own_loss += own_loss
                    local_decoder_enemy_loss += enemy_loss
                    local_decoder_ally_loss += ally_loss
                    local_decoder_reward_loss += reward_loss
                local_decoder_loss /= len(train_tasks)
                local_decoder_own_loss /= len(train_tasks)
                local_decoder_enemy_loss /= len(train_tasks)
                local_decoder_ally_loss /= len(train_tasks)
                local_decoder_reward_loss /= len(train_tasks)
                total_loss += local_decoder_loss
            
            if self.use_club:
                local_club_loss = 0
                for task in train_tasks:
                    n_agents = task2args[task].n_agents
                    obs = task2obs_o[task]
                    actions = task2actions_o[task]
                    next_obs = task2next_obs_o[task]
                    rewards = task2rewards_o[task]
                    mask = task2mask_o[task]
                    t = task2t[task]
                    club_temporal_encoding = self.club_temporal_encoder(obs, actions, next_obs, rewards, mask, task)
                    club_local_mean, club_local_var = self.club_local_encoder(club_temporal_encoding)
                    club_local_mean = club_local_mean.detach()
                    club_local_var = club_local_var.detach()
                    local_encoding = task2local_encoding[task]
                    bs, n_agents, _ = club_local_mean.shape
                    club_local_mean = club_local_mean.reshape(bs*n_agents,-1)
                    club_local_var = club_local_var.reshape(bs*n_agents,-1)
                    positive = - ((local_encoding-club_local_mean)**2/club_local_var).mean()
                    bs = club_local_mean.shape[0]
                    club_local_mean_expand = club_local_mean.unsqueeze(1).expand(-1,bs,-1).reshape(bs**2,-1)
                    club_local_var_expand = club_local_var.unsqueeze(1).expand(-1,bs,-1).reshape(bs**2,-1)
                    local_encoding_expand = local_encoding.repeat(bs, 1)
                    negative = - ((local_encoding_expand - club_local_mean_expand)**2/club_local_var_expand).mean()
                    local_club_loss += main_args.club_loss_weight * (positive - negative)
                local_club_loss /= len(train_tasks)
                total_loss += local_club_loss
        
        self.optimiser.zero_grad()
        total_loss.backward()
        # grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip)
        self.optimiser.step()

        if episode - self.log_stats_t >= self.args.encoder_learner_log_interval:
            self.logger.log_stat(f"loss", total_loss.item(), episode)
            if not self.wo_global:
                self.logger.log_stat(f"global_cl_loss", global_cl_loss.item(), episode)
            if not main_args.global_only and (episode >= getattr(main_args, "encoder_local_start_episode", 0) or self.wo_global):
                self.logger.log_stat(f"local_cl_loss", local_cl_loss.item(), episode)
            # self.logger.log_stat(f"grad_norm", grad_norm.item(), episode)
            if main_args.use_decoder_loss and episode >= getattr(main_args, "encoder_local_start_episode", 0):
                if not self.wo_global:
                    self.logger.log_stat(f"global_decoder_loss", global_decoder_loss.item(), episode)
                    self.logger.log_stat(f"global_decoder_reward_loss", global_decoder_reward_loss.item(), episode)
                    self.logger.log_stat(f"global_decoder_own_loss", global_decoder_own_loss.item(), episode)
                    self.logger.log_stat(f"global_decoder_enemy_loss", global_decoder_enemy_loss.item(), episode)
                    self.logger.log_stat(f"global_decoder_ally_loss", global_decoder_ally_loss.item(), episode)
                if not main_args.global_only:
                    self.logger.log_stat(f"local_decoder_loss", local_decoder_loss.item(), episode)
                    self.logger.log_stat(f"local_decoder_reward_loss", local_decoder_reward_loss.item(), episode)
                    self.logger.log_stat(f"local_decoder_own_loss", local_decoder_own_loss.item(), episode)
                    self.logger.log_stat(f"local_decoder_enemy_loss", local_decoder_enemy_loss.item(), episode)
                    self.logger.log_stat(f"local_decoder_ally_loss", local_decoder_ally_loss.item(), episode)
            if self.use_club:
                self.logger.log_stat(f"club_model_loss", club_model_loss.item(), episode)
                if not self.wo_global:
                    self.logger.log_stat(f"global_club_loss", global_club_loss.item(), episode)
                if not main_args.global_only:
                    self.logger.log_stat(f"local_club_loss", local_club_loss.item(), episode)
            self.log_stats_t = episode
    
    def visualize(self, train_tasks, main_args, task2args, task2batch, episode):
        self.logger.console_logger.info("Starting encoder visualization")
        save_path = os.path.join(main_args.vis_save_dir, str(episode))
        os.makedirs(save_path, exist_ok=True)
        global_x = []
        global_y = []
        local_x = []
        local_y = []
        with th.no_grad():
            for id, task in enumerate(train_tasks):
                batch = task2batch[task]
                rewards = batch["reward"][:, :-1]
                actions_one_hot = batch["actions_onehot"][:, :-1]
                obs = batch["obs"][:, :-1]
                next_obs = batch["obs"][:, 1:]
                terminated = batch["terminated"][:, :-1].float()
                mask = batch["filled"][:, :-1].float()
                mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
                temporal_encoding = self.temporal_encoder(obs, actions_one_hot, next_obs, rewards, mask, task)
                local_encoding = self.local_encoder(temporal_encoding)
                global_encoding = self.global_encoder(temporal_encoding).cpu().detach().numpy()
                bs, n_agents, _ = local_encoding.shape
                local_encoding = local_encoding.reshape(bs*n_agents, -1).cpu().detach().numpy()
                for i in range(main_args.vis_batch_size):
                    global_x.append(global_encoding[i])
                    global_y.append(id)
                for j in range(main_args.vis_batch_size*n_agents):
                    local_x.append(local_encoding[j])
                    local_y.append(id)
        
        # tsne = TSNE(n_components=2, init='pca', random_state=0)
        # total_x = global_x + local_x
        # len_global = len(global_x)
        # X_tsne = tsne.fit_transform(np.asarray(total_x))
        # x_min, x_max = np.min(X_tsne, 0), np.max(X_tsne, 0)
        # global_X_tsne = X_tsne[:len_global]
        # local_X_tsne = X_tsne[len_global:]
        # global_data = (global_X_tsne - x_min) / (x_max - x_min)
        # local_data = (local_X_tsne - x_min) / (x_max - x_min)

        global_tsne = TSNE(n_components=2, init='pca', random_state=0)
        local_tsne = TSNE(n_components=2, init='pca', random_state=0)
        global_X_tsne = global_tsne.fit_transform(np.asarray(global_x))
        local_X_tsne = local_tsne.fit_transform(np.asarray(local_x))
        global_x_min, global_x_max = np.min(global_X_tsne, 0), np.max(global_X_tsne, 0)
        local_x_min, local_x_max = np.min(local_X_tsne, 0), np.max(local_X_tsne, 0)
        global_data = (global_X_tsne - global_x_min) / (global_x_max - global_x_min)
        local_data = (local_X_tsne - local_x_min) / (local_x_max - local_x_min)

        colors = plt.cm.rainbow(np.linspace(0,1,len(train_tasks)))

        plt.clf()
        for i in range(global_data.shape[0]):
            plt.scatter(global_data[i, 0], global_data[i, 1],
                    color=colors[global_y[i]])
        plt.xticks([])
        plt.yticks([])
        if main_args.use_swanlab:
            swanlab.log({f"global_vis": swanlab.Image(plt)}, step=episode)
        plt.savefig(save_path+"/global_vis.png")

        plt.clf()
        for i in range(local_data.shape[0]):
            plt.scatter(local_data[i, 0], local_data[i, 1],
                    color=colors[local_y[i]])
        plt.xticks([])
        plt.yticks([])
        if main_args.use_swanlab:
            swanlab.log({f"local_vis": swanlab.Image(plt)}, step=episode)
        plt.savefig(save_path+"/local_vis.png")

        plt.clf()
        for i in range(global_data.shape[0]):
            plt.scatter(global_data[i, 0], global_data[i, 1],
                    color=colors[global_y[i]])
        for i in range(local_data.shape[0]):
            plt.scatter(local_data[i, 0], local_data[i, 1],
                    color=colors[local_y[i]])
        plt.xticks([])
        plt.yticks([])
        if main_args.use_swanlab:
            swanlab.log({f"total_vis": swanlab.Image(plt)}, step=episode)
        plt.savefig(save_path+"/total_vis.png")
        self.logger.console_logger.info("Finishing encoder visualization")
    
    def get_visualize_local_data(self, train_tasks, main_args, task2args, task2batch, episode, data_save_path):
        self.logger.console_logger.info("Starting encoder visualization")
        save_path = os.path.join(main_args.vis_save_dir, str(episode))
        os.makedirs(save_path, exist_ok=True)
        local_x = []
        local_y = []
        with th.no_grad():
            for id, task in enumerate(train_tasks):
                batch = task2batch[task]
                rewards = batch["reward"][:, :-1]
                actions_one_hot = batch["actions_onehot"][:, :-1]
                obs = batch["obs"][:, :-1]
                next_obs = batch["obs"][:, 1:]
                terminated = batch["terminated"][:, :-1].float()
                mask = batch["filled"][:, :-1].float()
                mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
                temporal_encoding = self.temporal_encoder(obs, actions_one_hot, next_obs, rewards, mask, task)
                local_encoding = self.local_encoder(temporal_encoding)
                bs, n_agents, _ = local_encoding.shape
                local_encoding = local_encoding.reshape(bs*n_agents, -1).cpu().detach().numpy()
                for j in range(main_args.vis_batch_size*n_agents):
                    local_x.append(local_encoding[j])
                    local_y.append(id)

        local_tsne = TSNE(n_components=2, init='pca', random_state=0)
        local_X_tsne = local_tsne.fit_transform(np.asarray(local_x))
        local_x_min, local_x_max = np.min(local_X_tsne, 0), np.max(local_X_tsne, 0)
        local_data = (local_X_tsne - local_x_min) / (local_x_max - local_x_min)
        np.savez(data_save_path, local_data=local_data, local_y=local_y)
        self.logger.console_logger.info("Finishing encoder visualization")
        
    
    # q, k (b, dim); neg (b, N, dim)
    def contrastive_loss(self, q, k, neg):
        N = neg.shape[1]
        b = q.shape[0]
        l_pos = th.bmm(q.view(b, 1, -1), k.view(b, -1, 1)) # (b,1,1)
        l_neg = th.bmm(q.view(b, 1, -1), neg.transpose(1,2)) # (b,1,N)
        logits = th.cat([l_pos.view(b, 1), l_neg.view(b, N)], dim=1)
        
        labels = th.zeros(b, dtype=th.long)
        labels = labels.to(q.device)
        cross_entropy_loss = nn.CrossEntropyLoss()
        loss = cross_entropy_loss(logits/self.args.infonce_temp, labels)
        return loss
    
    def global_dml_loss(self, train_tasks, main_args, task2global_encoding, epsilon=1e-3):
        pos_z_loss = 0
        neg_z_loss = 0
        pos_cnt = 0
        neg_cnt = 0

        for i in range(len(train_tasks)):
            for mb_i in range(main_args.meta_batch_size):
                z_q = task2global_encoding[train_tasks[i]][mb_i]
                for mb_pos in range(mb_i+1, main_args.meta_batch_size):
                    z_pos = task2global_encoding[train_tasks[i]][mb_pos]
                    pos_z_loss += th.sqrt(th.mean((z_q - z_pos)**2)+epsilon)
                    pos_cnt += 1
                for j in range(i+1, len(train_tasks)):
                    for mb_j in range(main_args.meta_batch_size):
                        z_neg = task2global_encoding[train_tasks[j]][mb_j]
                        neg_z_loss += 1/(th.mean((z_q - z_neg)**2)+epsilon*100)
                        neg_cnt += 1
        
        return pos_z_loss/pos_cnt + neg_z_loss/neg_cnt
    
    def local_dml_loss(self, train_tasks, main_args, task2global_encoding, task2local_encoding, epsilon=1e-3):
        pos_z_loss = 0
        neg_z_loss = 0
        pos_cnt = 0
        neg_cnt = 0

        for i in range(len(train_tasks)):
            bs_i = task2local_encoding[train_tasks[i]].shape[0]
            for mb_i in range(bs_i):
                z_q = task2local_encoding[train_tasks[i]][mb_i]
                for mb_pos in range(mb_i+1, bs_i):
                    z_pos = task2local_encoding[train_tasks[i]][mb_pos]
                    pos_z_loss += th.sqrt(th.mean((z_q - z_pos)**2)+epsilon)
                    pos_cnt += 1
                for j in range(i+1, len(train_tasks)):
                    bs_j = task2local_encoding[train_tasks[j]].shape[0]
                    for mb_j in range(bs_j):
                        z_neg = task2local_encoding[train_tasks[j]][mb_j]
                        neg_z_loss += 1/(th.mean((z_q - z_neg)**2)+epsilon*100)
                        neg_cnt += 1
                if not self.wo_global:
                    for gl_pos in range(main_args.meta_batch_size):
                        z_pos = task2global_encoding[train_tasks[i]][gl_pos]
                        pos_z_loss += th.sqrt(th.mean((z_q - z_pos)**2)+epsilon)
                        pos_cnt += 1
                # 这里和global之间先只做positive而不做negative，可以节省非常多训练时间
                # for k in range(len(train_tasks)):
                #     if k == i:
                #         continue
                #     for gl_k in range(main_args.meta_batch_size):
                #         z_neg = task2global_encoding[train_tasks[k]][gl_k]
                #         neg_z_loss += 1/(th.mean((z_q - z_neg)**2)+epsilon*100)
                #         neg_cnt += 1
        
        return pos_z_loss/pos_cnt + neg_z_loss/neg_cnt
    
    def global_cl_loss(self, train_tasks, main_args, task2global_encoding):
        total_cnt = 0
        cl_loss = 0
        for i in range(len(train_tasks)):
            neg_z_ls = []
            for j in range(len(train_tasks)):
                if j==i:
                    continue
                neg_z_ls.append(task2global_encoding[train_tasks[j]])
            z_neg = th.stack(neg_z_ls, dim=0).unsqueeze(0)
            # z_neg = th.cat(neg_z_ls, dim=0)
            for mb_i in range(main_args.meta_batch_size):
                z_q = task2global_encoding[train_tasks[i]][mb_i].unsqueeze(0)
                for mb_pos in range(mb_i+1, main_args.meta_batch_size):
                    z_pos = task2global_encoding[train_tasks[i]][mb_pos].unsqueeze(0)

                    bs, A, B, dim = z_neg.shape
                    indices = th.randint(0, B, (bs, A)).to(z_neg.device)
                    indices_expanded = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, dim)
                    z_neg_sampled = th.gather(z_neg, dim=2, index=indices_expanded).reshape(bs, A, dim)

                    cl_loss += self.contrastive_loss(z_q, z_pos, z_neg_sampled)
                    total_cnt += 1
        return cl_loss/total_cnt
    
    def local_cl_loss(self, train_tasks, main_args, task2global_encoding, task2local_encoding):
        total_cnt = 0
        cl_loss = 0
        for i in range(len(train_tasks)):
            neg_z_local_ls = []
            neg_z_global_ls = []
            for j in range(len(train_tasks)):
                if j==i:
                    continue
                neg_z_local_ls.append(task2local_encoding[train_tasks[j]])
                neg_z_global_ls.append(task2global_encoding[train_tasks[j]])
            bs_i = task2local_encoding[train_tasks[i]].shape[0]
            # z_neg_local = th.stack(neg_z_local_ls, dim=0).unsqueeze(0)
            z_neg_global = th.stack(neg_z_global_ls, dim=0).unsqueeze(0)
            for mb_i in range(bs_i):
                z_q = task2local_encoding[train_tasks[i]][mb_i].unsqueeze(0)
                for mb_pos in range(mb_i+1, bs_i):
                    z_pos = task2local_encoding[train_tasks[i]][mb_pos].unsqueeze(0)

                    z_neg_local_tmp = []
                    for new_j in range(len(neg_z_local_ls)):
                        tmp_bs = neg_z_local_ls[new_j].shape[0]
                        indices = th.randint(0,tmp_bs,(1,))
                        new_res = neg_z_local_ls[new_j][indices, :][0]
                        z_neg_local_tmp.append(new_res)
                    z_neg_sampled = th.stack(z_neg_local_tmp, dim=0).unsqueeze(0)
                    # print(z_neg_sampled.shape)
                    # assert False

                    # bs, A, B, dim = z_neg_local.shape
                    # indices = th.randint(0, B, (bs, A)).to(z_neg_local.device)
                    # indices_expanded = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, dim)
                    # z_neg_sampled = th.gather(z_neg_local, dim=2, index=indices_expanded).reshape(bs, A, dim)

                    cl_loss += self.contrastive_loss(z_q, z_pos, z_neg_sampled)
                    total_cnt += 1
                if not self.wo_global:
                    for gl_pos in range(main_args.meta_batch_size):
                        z_pos = task2global_encoding[train_tasks[i]][gl_pos].unsqueeze(0)

                        bs, A, B, dim = z_neg_global.shape
                        indices = th.randint(0, B, (bs, A)).to(z_neg_global.device)
                        indices_expanded = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, dim)
                        z_neg_sampled = th.gather(z_neg_global, dim=2, index=indices_expanded).reshape(bs, A, dim)

                        cl_loss += self.contrastive_loss(z_q, z_pos, z_neg_sampled)
                        total_cnt += 1

        return cl_loss/total_cnt
    
    def local_cl_loss_global_guided_only(self, train_tasks, main_args, task2global_encoding, task2local_encoding):
        """Contrastive loss pulling each local encoding toward its own task's global encodings.

        For every task ``i`` and every row ``mb_i`` of its local encodings, each of
        the first ``main_args.meta_batch_size`` global encodings of the same task is
        used as a positive. One randomly chosen global encoding from every *other*
        task is used as a negative. ``self.contrastive_loss`` is evaluated once per
        (query, positive) pair and the results are averaged.

        Args:
            train_tasks: sequence of task keys into the two encoding dicts.
            main_args: config object; only ``meta_batch_size`` is read.
            task2global_encoding: task -> global encodings. Each entry must be 2-D
                (rows, dim), as required by the 4-way shape unpack below, and all
                entries must share a shape, as required by ``th.stack``.
            task2local_encoding: task -> local encodings; iterated over dim 0.

        Returns:
            The mean contrastive loss. Returns a zero scalar tensor when no pair
            was formed: ``self.wo_global`` is set, every task has no local rows,
            or ``meta_batch_size`` is 0. The original code raised
            ``ZeroDivisionError`` in these cases.
        """
        if self.wo_global:
            # This loss draws its positives exclusively from global encodings, so
            # without them there is nothing to contrast.
            return th.zeros(())

        total_cnt = 0
        cl_loss = 0
        for i, task in enumerate(train_tasks):
            # Negatives: global encodings of every other task (compared by position,
            # not by key), stacked to (1, n_other_tasks, rows, dim).
            neg_z_global_ls = [task2global_encoding[other] for j, other in enumerate(train_tasks) if j != i]
            z_neg_global = th.stack(neg_z_global_ls, dim=0).unsqueeze(0)
            bs, A, B, dim = z_neg_global.shape

            local_z = task2local_encoding[task]
            global_z = task2global_encoding[task]
            for mb_i in range(local_z.shape[0]):
                z_q = local_z[mb_i].unsqueeze(0)
                for gl_pos in range(main_args.meta_batch_size):
                    z_pos = global_z[gl_pos].unsqueeze(0)

                    # Pick one random row (out of B) per negative task -> (1, A, dim).
                    indices = th.randint(0, B, (bs, A)).to(z_neg_global.device)
                    indices_expanded = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, dim)
                    z_neg_sampled = th.gather(z_neg_global, dim=2, index=indices_expanded).reshape(bs, A, dim)

                    cl_loss += self.contrastive_loss(z_q, z_pos, z_neg_sampled)
                    total_cnt += 1

        if total_cnt == 0:
            # No local rows (or meta_batch_size == 0): avoid dividing by zero.
            return th.zeros(())
        return cl_loss / total_cnt

    def mg2l_local_cl_loss(self, train_tasks, main_args, task2global_encoding, task2local_encoding):
        """Contrastive loss over joint ``[global, local]`` encodings with in-task positives.

        Step 1: each task's global encodings are broadcast to one row per local
        row and detached. Global row ``k`` is repeated ``n_agents`` times, where
        ``n_agents = local_rows // global_rows``, so joint row ``r`` pairs
        ``global[r // n_agents]`` with ``local[r]``.
        NOTE(review): this assumes local rows are laid out batch-major, i.e.
        agent-minor. Confirm against the local encoder.

        Step 2: for every task ``i`` and every pair of its rows ``mb_i < mb_pos``:
        joint row ``mb_i`` is the query, joint row ``mb_pos`` is the positive,
        and one random joint row from every other task is a negative.
        ``self.contrastive_loss`` is evaluated per pair and the results are
        averaged.

        NOTE(review): the broadcast, *detached* global encodings are written back
        into ``task2global_encoding``, so the caller's dict is mutated. Later
        uses of it see (local_rows, dim) tensors that carry no gradient. This is
        kept for backward compatibility; confirm whether any caller relies on it.

        Args:
            train_tasks: sequence of task keys into the two encoding dicts.
            main_args: unused here.
            task2global_encoding: task -> global encodings (rows first); mutated
                as described above.
            task2local_encoding: task -> local encodings; the row count must be a
                multiple of the matching global row count.

        Returns:
            Mean loss over all (query, positive) pairs. Returns a zero scalar
            tensor when no task has at least two rows (the original code raised
            ``ZeroDivisionError`` here).

        Raises:
            ValueError: if a task has no global rows, or its local row count is
                not a multiple of its global row count.
        """
        for task in train_tasks:
            local_z = task2local_encoding[task]
            global_z = task2global_encoding[task]
            bs = global_z.shape[0]
            if bs == 0 or local_z.shape[0] % bs != 0:
                raise ValueError(
                    f"task {task!r}: local rows ({local_z.shape[0]}) must be a "
                    f"multiple of global rows ({bs})"
                )
            n_agents = local_z.shape[0] // bs
            global_z = global_z.unsqueeze(1).repeat(1, n_agents, 1).reshape(bs * n_agents, -1)
            task2global_encoding[task] = global_z.detach().clone()

        # Queries, positives and negatives all live in the joint [global, local]
        # space; build each task's joint tensor once instead of per pair.
        joint = {
            task: th.cat([task2global_encoding[task], task2local_encoding[task]], dim=-1)
            for task in train_tasks
        }

        total_cnt = 0
        cl_loss = 0
        for i, task in enumerate(train_tasks):
            # Negative candidates: joint rows of every other task (by position).
            neg_z_ls = [joint[other] for j, other in enumerate(train_tasks) if j != i]
            z_task = joint[task]
            n_rows = z_task.shape[0]
            for mb_i in range(n_rows):
                z_q = z_task[mb_i].unsqueeze(0)
                for mb_pos in range(mb_i + 1, n_rows):
                    z_pos = z_task[mb_pos].unsqueeze(0)
                    # One random row per other task -> (1, n_other_tasks, joint_dim).
                    # randint is called in the same order as before, so seeded runs reproduce.
                    z_neg_sampled = th.stack(
                        [neg_z[th.randint(0, neg_z.shape[0], (1,))][0] for neg_z in neg_z_ls],
                        dim=0,
                    ).unsqueeze(0)
                    cl_loss += self.contrastive_loss(z_q, z_pos, z_neg_sampled)
                    total_cnt += 1

        if total_cnt == 0:
            # Every task has fewer than two rows, so there are no positive pairs.
            return th.zeros(())
        return cl_loss / total_cnt

    def cuda(self):
        """Call ``.cuda()`` on the MAC and every encoder/decoder this learner owns.

        The CLUB encoders are included only when ``self.use_club`` is set.
        """
        for name in ("mac", "temporal_encoder", "local_encoder", "global_encoder", "decoder"):
            getattr(self, name).cuda()
        if not self.use_club:
            return
        for name in ("club_temporal_encoder", "club_local_encoder", "club_global_encoder"):
            getattr(self, name).cuda()
    
    def save_models(self, path):
        """Write this learner's state dicts to ``<path>/<name>.th`` files.

        Always written: ``temporal_encoder``, ``local_encoder``, ``global_encoder``
        and ``opt`` (the main optimiser). ``decoder`` is written only when
        ``args.use_decoder_loss`` is set. ``club_temporal_encoder``,
        ``club_local_encoder``, ``club_global_encoder`` and ``club_opt`` are
        written only when ``use_club`` is set.

        ``path`` is not created here. The MAC is not saved by this method; its
        call is commented out.
        """
        # self.mac.save_models(path)

        def dump(attr, stem):
            # Look the attribute up only when it is about to be written.
            th.save(getattr(self, attr).state_dict(), f"{path}/{stem}.th")

        for attr in ("temporal_encoder", "local_encoder", "global_encoder"):
            dump(attr, attr)
        if self.args.use_decoder_loss:
            dump("decoder", "decoder")
        dump("optimiser", "opt")

        if not self.use_club:
            return
        for attr in ("club_temporal_encoder", "club_local_encoder", "club_global_encoder"):
            dump(attr, attr)
        dump("club_optimiser", "club_opt")
    
    def load_models(self, path):
        """Restore state dicts from the ``<path>/<name>.th`` files written by ``save_models``.

        Every file is loaded with a ``map_location`` that returns each storage
        unchanged, which deserialises onto CPU. ``decoder.th`` is read only when
        ``args.use_decoder_loss`` is set. The ``club_*`` files are read only
        when ``use_club`` is set. The MAC is not restored here; its call is
        commented out.
        """
        # self.mac.load_models(path)

        def restore(attr, stem):
            target = getattr(self, attr)
            target.load_state_dict(th.load(f"{path}/{stem}.th", map_location=lambda storage, loc: storage))

        for attr in ("temporal_encoder", "local_encoder", "global_encoder"):
            restore(attr, attr)
        if self.args.use_decoder_loss:
            restore("decoder", "decoder")
        restore("optimiser", "opt")

        if not self.use_club:
            return
        for attr in ("club_temporal_encoder", "club_local_encoder", "club_global_encoder"):
            restore(attr, attr)
        restore("club_optimiser", "club_opt")
