from components.episode_buffer import EpisodeBatch
import os
import torch as th
from torch.optim import RMSprop, Adam, AdamW
import torch.nn.functional as F
import torch.nn as nn
from components.standarize_stream import RunningMeanStd
from modules.encoders.transition_encoder import TransformerTransitionEncoder, TransformerPriorRoleEncoder
from modules.encoders.transition_decoder import TransformerTransitionDecoder, TransformerTransitionRoleDecoder
from modules.encoders.temporal_encoder import TransformerTemporalEncoder, TransformerTemporalRoleEncoder
from modules.encoders.local_encoder import LocalEncoder, LocalRoleEncoder
from modules.encoders.global_encoder import GlobalEncoder
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
# from cuml.manifold import TSNE
# import cupy as cp
import numpy as np
import swanlab


class PriorRoleEncoderLearner:
    def __init__(self, mac, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger

        self.train_tasks = mac.train_tasks
        self.task2decomposer = mac.task2decomposer

        self.temporal_encoder = TransformerTemporalEncoder(self.task2decomposer, args)
        self.local_encoder = LocalEncoder(args)
        self.global_encoder = GlobalEncoder(args)

        self.load_task_encoder(args.encoder_path_ls[args.encoder_id])

        self.temporal_role_encoder = TransformerTemporalRoleEncoder(self.task2decomposer, args)
        self.local_role_encoder = LocalRoleEncoder(args)
        # self.role_decoder = TransformerTransitionRoleDecoder(self.task2decomposer, args)
        self.load_role_encoder(args.role_encoder_path_ls[args.role_encoder_id])

        self.prior_role_encoder = TransformerPriorRoleEncoder(self.task2decomposer, args)

        self.weighted_prior_learning = getattr(args, "weighted_prior_learning", False)


        self.params = self.prior_role_encoder.parameters()

        match self.args.optim_type.lower():
            case "rmsprop":
                self.optimiser = RMSprop(params=self.params, lr=self.args.lr, alpha=self.args.optim_alpha, eps=self.args.optim_eps, weight_decay=self.args.weight_decay)
            case "adam":
                self.optimiser = Adam(params=self.params, lr=self.args.lr, weight_decay=self.args.weight_decay)
            case "adamw":
                self.optimiser = AdamW(params=self.params, lr=self.args.lr, weight_decay=self.args.weight_decay)
            case _:
                raise ValueError("Invalid optimiser type", self.args.optim_type)
        
        self.log_stats_t = -self.args.encoder_learner_log_interval - 1
        self.training_steps = 0

    def train(self, train_tasks, main_args, task2args, task2episode_sample, task2encoding_sample, episode):
        total_loss = 0
        for task in train_tasks:
            with th.no_grad():
                batch = task2encoding_sample[task]
                rewards = batch["reward"][:, :-1]
                actions_one_hot = batch["actions_onehot"][:, :-1]
                obs = batch["obs"][:, :-1]
                next_obs = batch["obs"][:, 1:]
                terminated = batch["terminated"][:, :-1].float()
                mask = batch["filled"][:, :-1].float()
                mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
                temporal_encoding = self.temporal_encoder(obs, actions_one_hot, next_obs, rewards, mask, task)
                local_encoding = self.local_encoder(temporal_encoding)

                batch = task2episode_sample[task]
                actions_one_hot = batch["actions_onehot"][:, :-1]
                actions_long = batch["actions"][:, :-1]
                obs = batch["obs"][:, :-1]
                terminated = batch["terminated"][:, :-1].float()
                mask = batch["filled"][:, :-1].float()
                mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
                bs, max_t, n_agents, _ = obs.shape
                temporal_role_encoding = self.temporal_role_encoder(obs, actions_one_hot, mask, task)
                role_encoding = self.local_role_encoder(temporal_role_encoding, local_encoding)
                role_encoding = role_encoding.unsqueeze(1).repeat(1, max_t, 1, 1)

            if main_args.prior_role_use_history:
                mac_out = []
                prior_role_hidden = self.prior_role_encoder.init_hidden().unsqueeze(0).expand(bs, n_agents, -1)
                for t in range(batch.max_seq_length):
                    if t == batch.max_seq_length - 1:
                        prior_role_encoding, prior_role_hidden = self.prior_role_encoder(obs[:,-1,:,:], local_encoding, task, hidden_state=prior_role_hidden)
                    else:
                        prior_role_encoding, prior_role_hidden = self.prior_role_encoder(obs[:,t,:,:], local_encoding, task, hidden_state=prior_role_hidden)
                    mac_out.append(prior_role_encoding)
                # print(prior_role_encoding.shape, mask.shape)
                # assert False
                prior_role_encoding = th.stack(mac_out, dim=1)[:,:-1]
            else:
                prior_role_encoding, _ = self.prior_role_encoder(obs, local_encoding, task)
            
            weight = rewards.unsqueeze(2).repeat(1, 1, n_agents, 1).sum(1).unsqueeze(1).repeat(1, max_t, 1, 1)
            mean_weight = rewards.sum(1).mean()
            weight = self.args.weight_alpha * (weight - mean_weight)
            weight = th.clamp(weight, min=self.args.weight_min, max=self.args.weight_max)
            exp_w = th.exp(weight).detach()

            encoding_error = (prior_role_encoding - role_encoding.detach())
            mask = mask.unsqueeze(2).repeat(1,1,n_agents,1)
            masked_encoding_error = encoding_error * mask
            if self.weighted_prior_learning:
                temp_loss = (exp_w * (masked_encoding_error ** 2)).sum() / mask.sum()
            else:
                temp_loss = (masked_encoding_error ** 2).sum() / mask.sum()
            total_loss += temp_loss
        total_loss = total_loss / len(train_tasks)
        
        self.optimiser.zero_grad()
        total_loss.backward()
        # grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip)
        self.optimiser.step()

        if episode - self.log_stats_t >= self.args.encoder_learner_log_interval:
            self.logger.log_stat(f"loss", total_loss.item(), episode)
            self.log_stats_t = episode
    
    def visualize(self, train_roles, main_args, task2args, role2episode_sample_vis, role2encoding_sample_vis, role2task_sample_vis, episode):
        pass
        
    
    # q, k (b, dim); neg (b, N, dim)
    def contrastive_loss(self, q, k, neg):
        N = neg.shape[1]
        b = q.shape[0]
        l_pos = th.bmm(q.view(b, 1, -1), k.view(b, -1, 1)) # (b,1,1)
        l_neg = th.bmm(q.view(b, 1, -1), neg.transpose(1,2)) # (b,1,N)
        logits = th.cat([l_pos.view(b, 1), l_neg.view(b, N)], dim=1)
        
        labels = th.zeros(b, dtype=th.long)
        labels = labels.to(q.device)
        cross_entropy_loss = nn.CrossEntropyLoss()
        loss = cross_entropy_loss(logits/self.args.infonce_temp, labels)
        return loss
    
    def global_dml_loss(self, train_tasks, main_args, task2global_encoding, epsilon=1e-3):
        pos_z_loss = 0
        neg_z_loss = 0
        pos_cnt = 0
        neg_cnt = 0

        for i in range(len(train_tasks)):
            for mb_i in range(main_args.meta_batch_size):
                z_q = task2global_encoding[train_tasks[i]][mb_i]
                for mb_pos in range(mb_i+1, main_args.meta_batch_size):
                    z_pos = task2global_encoding[train_tasks[i]][mb_pos]
                    pos_z_loss += th.sqrt(th.mean((z_q - z_pos)**2)+epsilon)
                    pos_cnt += 1
                for j in range(i+1, len(train_tasks)):
                    for mb_j in range(main_args.meta_batch_size):
                        z_neg = task2global_encoding[train_tasks[j]][mb_j]
                        neg_z_loss += 1/(th.mean((z_q - z_neg)**2)+epsilon*100)
                        neg_cnt += 1
        
        return pos_z_loss/pos_cnt + neg_z_loss/neg_cnt
    
    def local_dml_loss(self, train_tasks, main_args, task2global_encoding, task2local_encoding, epsilon=1e-3):
        pos_z_loss = 0
        neg_z_loss = 0
        pos_cnt = 0
        neg_cnt = 0

        for i in range(len(train_tasks)):
            bs_i = task2local_encoding[train_tasks[i]].shape[0]
            for mb_i in range(bs_i):
                z_q = task2local_encoding[train_tasks[i]][mb_i]
                for mb_pos in range(mb_i+1, bs_i):
                    z_pos = task2local_encoding[train_tasks[i]][mb_pos]
                    pos_z_loss += th.sqrt(th.mean((z_q - z_pos)**2)+epsilon)
                    pos_cnt += 1
                for j in range(i+1, len(train_tasks)):
                    bs_j = task2local_encoding[train_tasks[j]].shape[0]
                    for mb_j in range(bs_j):
                        z_neg = task2local_encoding[train_tasks[j]][mb_j]
                        neg_z_loss += 1/(th.mean((z_q - z_neg)**2)+epsilon*100)
                        neg_cnt += 1
                if not self.wo_global:
                    for gl_pos in range(main_args.meta_batch_size):
                        z_pos = task2global_encoding[train_tasks[i]][gl_pos]
                        pos_z_loss += th.sqrt(th.mean((z_q - z_pos)**2)+epsilon)
                        pos_cnt += 1
                # 这里和global之间先只做positive而不做negative，可以节省非常多训练时间
                # for k in range(len(train_tasks)):
                #     if k == i:
                #         continue
                #     for gl_k in range(main_args.meta_batch_size):
                #         z_neg = task2global_encoding[train_tasks[k]][gl_k]
                #         neg_z_loss += 1/(th.mean((z_q - z_neg)**2)+epsilon*100)
                #         neg_cnt += 1
        
        return pos_z_loss/pos_cnt + neg_z_loss/neg_cnt
    
    def global_cl_loss(self, train_tasks, main_args, task2global_encoding):
        total_cnt = 0
        cl_loss = 0
        for i in range(len(train_tasks)):
            neg_z_ls = []
            for j in range(len(train_tasks)):
                if j==i:
                    continue
                neg_z_ls.append(task2global_encoding[train_tasks[j]])
            z_neg = th.stack(neg_z_ls, dim=0).unsqueeze(0)
            # z_neg = th.cat(neg_z_ls, dim=0)
            for mb_i in range(main_args.meta_batch_size):
                z_q = task2global_encoding[train_tasks[i]][mb_i].unsqueeze(0)
                for mb_pos in range(mb_i+1, main_args.meta_batch_size):
                    z_pos = task2global_encoding[train_tasks[i]][mb_pos].unsqueeze(0)

                    bs, A, B, dim = z_neg.shape
                    indices = th.randint(0, B, (bs, A)).to(z_neg.device)
                    indices_expanded = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, dim)
                    z_neg_sampled = th.gather(z_neg, dim=2, index=indices_expanded).reshape(bs, A, dim)

                    cl_loss += self.contrastive_loss(z_q, z_pos, z_neg_sampled)
                    total_cnt += 1
        return cl_loss/total_cnt
    
    def local_cl_loss(self, train_tasks, main_args, task2global_encoding, task2local_encoding):
        total_cnt = 0
        cl_loss = 0
        for i in range(len(train_tasks)):
            neg_z_local_ls = []
            neg_z_global_ls = []
            for j in range(len(train_tasks)):
                if j==i:
                    continue
                neg_z_local_ls.append(task2local_encoding[train_tasks[j]])
                neg_z_global_ls.append(task2global_encoding[train_tasks[j]])
            bs_i = task2local_encoding[train_tasks[i]].shape[0]
            z_neg_local = th.stack(neg_z_local_ls, dim=0).unsqueeze(0)
            z_neg_global = th.stack(neg_z_global_ls, dim=0).unsqueeze(0)
            for mb_i in range(bs_i):
                z_q = task2local_encoding[train_tasks[i]][mb_i].unsqueeze(0)
                for mb_pos in range(mb_i+1, bs_i):
                    z_pos = task2local_encoding[train_tasks[i]][mb_pos].unsqueeze(0)

                    bs, A, B, dim = z_neg_local.shape
                    indices = th.randint(0, B, (bs, A)).to(z_neg_local.device)
                    indices_expanded = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, dim)
                    z_neg_sampled = th.gather(z_neg_local, dim=2, index=indices_expanded).reshape(bs, A, dim)

                    cl_loss += self.contrastive_loss(z_q, z_pos, z_neg_sampled)
                    total_cnt += 1
                if not self.wo_global:
                    for gl_pos in range(main_args.meta_batch_size):
                        z_pos = task2global_encoding[train_tasks[i]][gl_pos].unsqueeze(0)

                        bs, A, B, dim = z_neg_global.shape
                        indices = th.randint(0, B, (bs, A)).to(z_neg_global.device)
                        indices_expanded = indices.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, dim)
                        z_neg_sampled = th.gather(z_neg_global, dim=2, index=indices_expanded).reshape(bs, A, dim)

                        cl_loss += self.contrastive_loss(z_q, z_pos, z_neg_sampled)
                        total_cnt += 1

        return cl_loss/total_cnt

    def cuda(self):
        self.mac.cuda()
        self.temporal_encoder.cuda()
        self.local_encoder.cuda()
        self.global_encoder.cuda()
        self.temporal_role_encoder.cuda()
        self.local_role_encoder.cuda()
        self.prior_role_encoder.cuda()
    
    def save_models(self, path):
        # self.mac.save_models(path)
        th.save(self.prior_role_encoder.state_dict(), "{}/prior_role_encoder.th".format(path))
        th.save(self.optimiser.state_dict(), "{}/opt.th".format(path))
    
    def load_task_encoder(self, path):
        # self.mac.load_models(path)
        self.temporal_encoder.load_state_dict(th.load("{}/temporal_encoder.th".format(path), map_location=lambda storage, loc: storage))
        self.local_encoder.load_state_dict(th.load("{}/local_encoder.th".format(path), map_location=lambda storage, loc: storage))
        self.global_encoder.load_state_dict(th.load("{}/global_encoder.th".format(path), map_location=lambda storage, loc: storage))
    
    def load_role_encoder(self, path):
        # self.mac.load_models(path)
        self.temporal_role_encoder.load_state_dict(th.load("{}/temporal_role_encoder.th".format(path), map_location=lambda storage, loc: storage))
        self.local_role_encoder.load_state_dict(th.load("{}/local_role_encoder.th".format(path), map_location=lambda storage, loc: storage))
    
    def load_models(self, path):
        # self.mac.load_models(path)
        self.prior_role_encoder.load_state_dict(th.load("{}/prior_role_encoder.th".format(path), map_location=lambda storage, loc: storage))
        self.optimiser.load_state_dict(th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))
