# Adding retrieval-related evaluation steps
import os
import os.path as osp
import random

import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.nn.modules.loss import _Loss
from torchvision.transforms import transforms
import seaborn as sns
import matplotlib.pyplot as plt

from tqdm import tqdm
import copy

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint
from dassl.optim import build_optimizer, build_lr_scheduler
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

from .mmrl import MMRL
from .mmrl import CustomCLIP as CustomCLIP_

# Module-level CLIP tokenizer instance.
_tokenizer = _Tokenizer()

# Per-dataset prompt templates, keyed by cfg.DATASET.NAME. Each template has a
# single `{}` placeholder that is filled with the class name (underscores
# replaced by spaces) before tokenization.
CUSTOM_TEMPLATES = {
    'OxfordPets': 'a photo of a {}, a type of pet.',
    'OxfordFlowers': 'a photo of a {}, a type of flower.',
    'FGVCAircraft': 'a photo of a {}, a type of aircraft.',
    'DescribableTextures': '{} texture.',
    'EuroSAT': 'a centered satellite photo of {}.',
    'StanfordCars': 'a photo of a {}.',
    'Food101': 'a photo of {}, a type of food.',
    'SUN397': 'a photo of a {}.',
    'Caltech101': 'a photo of a {}.',
    'UCF101': 'a photo of a person doing {}.',
    'ImageNet': 'a photo of a {}.',
    'ImageNetSketch': 'a photo of a {}.',
    'ImageNetV2': 'a photo of a {}.',
    'ImageNetA': 'a photo of a {}.',
    'ImageNetR': 'a photo of a {}.'
}

def load_clip_to_cpu(cfg, model_name="CLIP"):
    """Download the configured CLIP backbone and rebuild it on the CPU.

    Args:
        cfg: config node; reads ``cfg.MODEL.BACKBONE.NAME`` and the
            ``cfg.TRAINER.MMRL`` options forwarded as design details.
        model_name: value stored under the ``"model"`` key of the design
            details handed to ``clip.build_model_PGMPL``.

    Returns:
        The model returned by ``clip.build_model_PGMPL``.
    """
    checkpoint_path = clip._download(clip._MODELS[cfg.MODEL.BACKBONE.NAME])

    jit_model = None
    try:
        # The checkpoint is normally a TorchScript (JIT) archive.
        jit_model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
        weights = None
    except RuntimeError:
        # Not a JIT archive: fall back to loading it as a plain state dict.
        weights = torch.load(checkpoint_path, map_location="cpu")

    mmrl_cfg = cfg.TRAINER.MMRL
    design_details = {
        "model": model_name,
        "rep_tokens_layers": mmrl_cfg.REP_LAYERS,
        "n_rep_tokens": mmrl_cfg.N_REP_TOKENS,
        "beta0": mmrl_cfg.BETA0,
        "beta1": mmrl_cfg.BETA1,
        "beta2": mmrl_cfg.BETA2,
    }

    if not weights:
        weights = jit_model.state_dict()
    return clip.build_model_PGMPL(weights, design_details)

class TextEncoder_PGMPL(nn.Module):
    """Text encoder that feeds learnable representation/batch tokens to the
    CLIP text transformer alongside the prompt embeddings."""

    def __init__(self, clip_model):
        super().__init__()
        # Reuse the text-side components of the given CLIP model.
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts, compound_rep_tokens_text, compound_batch_tokens_text):
        """Encode prompts and return one projected feature per prompt.

        Args:
            prompts: prompt token embeddings, batch-first (N, L, D).
            tokenized_prompts: token ids; the argmax along the last dim is
                used as the EOT position of each prompt.
            compound_rep_tokens_text: per-layer representation tokens.
            compound_batch_tokens_text: per-layer batch tokens.
        """
        rep_len = compound_rep_tokens_text[0].shape[0]
        batch_len = compound_batch_tokens_text[0].shape[0]
        eot_index = tokenized_prompts.argmax(dim=-1)

        hidden = prompts + self.positional_embedding.type(self.dtype)
        hidden = hidden.permute(1, 0, 2)  # NLD -> LND

        # The transformer is an nn.Sequential-style stack, so everything is
        # packed into one list argument. The third slot is the counter that
        # tracks the depth of the representation tokens.
        packed = [hidden, compound_rep_tokens_text, 0, None, compound_batch_tokens_text, None]
        hidden = self.transformer(packed)[0]  # first slot carries the sequence
        hidden = hidden.permute(1, 0, 2)  # LND -> NLD
        hidden = self.ln_final(hidden).type(self.dtype)

        # The EOT index is offset by the number of rep and batch tokens
        # (presumably because the transformer inserts them before it).
        rows = torch.arange(hidden.shape[0])
        eot_position = eot_index + rep_len + batch_len
        return hidden[rows, eot_position] @ self.text_projection


class TextEncoder_CLIP(nn.Module):
    """Plain CLIP text encoder operating on pre-computed token embeddings."""

    def __init__(self, clip_model):
        super().__init__()
        # Reuse the text-side components of the given CLIP model.
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype

    def forward(self, prompts, tokenized_prompts):
        """Return one projected text feature per prompt.

        Args:
            prompts: prompt token embeddings, batch-first (N, L, D).
            tokenized_prompts: token ids used to locate the EOT position.
        """
        hidden = prompts + self.positional_embedding.type(self.dtype)
        hidden = self.transformer(hidden.permute(1, 0, 2))  # NLD -> LND
        hidden = hidden.permute(1, 0, 2)  # LND -> NLD
        hidden = self.ln_final(hidden).type(self.dtype)

        # The EOT token has the highest id in each sequence, so argmax finds it.
        rows = torch.arange(hidden.shape[0])
        eot_position = tokenized_prompts.argmax(dim=-1)
        return hidden[rows, eot_position] @ self.text_projection


def _get_text_base_features_zero_shot(cfg, classnames, clip_model, text_encoder):
    """Compute text features for every class name with a (frozen) text encoder.

    The encoder is moved to CUDA for the computation and moved back to the
    device it was on afterwards. The returned features are the raw encoder
    outputs (not normalised here).

    Args:
        cfg: config node; ``cfg.DATASET.NAME`` selects the prompt template.
        classnames: iterable of class-name strings.
        clip_model: model providing ``token_embedding`` and ``dtype``.
        text_encoder: module called as ``text_encoder(embeddings, tokens)``.
    """
    home_device = next(text_encoder.parameters()).device

    text_encoder = text_encoder.cuda()
    prompt_template = CUSTOM_TEMPLATES[cfg.DATASET.NAME]

    with torch.no_grad():
        # One tokenized prompt per class, kept on the encoder's original device
        # so that the token embedding lookup runs where clip_model lives.
        per_class_tokens = [
            clip.tokenize(prompt_template.format(name.replace('_', ' '))).to(home_device)
            for name in tqdm(classnames, desc="Extracting text features")
        ]
        tokenized_prompts = torch.cat(per_class_tokens)  # (n_classes, n_tokens)

        # (n_classes, n_tokens, embed_dim)
        embeddings = clip_model.token_embedding(tokenized_prompts).type(clip_model.dtype)
        text_embeddings = text_encoder(embeddings.cuda(), tokenized_prompts.cuda())

    text_encoder = text_encoder.to(home_device)
    return text_embeddings


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


class MultiModalRepresentationLearner(nn.Module):
    """Holds the learnable per-layer text/visual tokens and mixes the two
    modalities through a pair of cross-attention modules.

    Two families of tokens are kept ("rep" and "batch"), each with a text and
    a visual bank containing one parameter tensor per configured layer.
    """

    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        text_width = clip_model.ln_final.weight.shape[0]
        visual_width = clip_model.visual.width

        # Text tokens query visual tokens, and vice versa.
        self.crossAttnText = CrossAttentionRep(text_width, text_width, visual_width, num_heads=8)
        self.crossAttnImage = CrossAttentionRep(visual_width, visual_width, text_width, num_heads=8)

        n_rep_tokens = cfg.TRAINER.MMRL.N_REP_TOKENS
        self.dtype = clip_model.dtype

        text_dim = clip_model.ln_final.weight.shape[0]
        visual_dim = clip_model.visual.ln_post.weight.shape[0]

        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = cfg.INPUT.SIZE[0]
        rep_dim = cfg.TRAINER.MMRL.REP_DIM  # read from the config but not used below

        self.rep_layers_length = len(cfg.TRAINER.MMRL.REP_LAYERS)  # max=12
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        # Tokenize one templated prompt per class and embed them once.
        prompt_template = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
        per_class_tokens = [
            clip.tokenize(prompt_template.format(name.replace('_', ' ')))
            for name in classnames
        ]
        self.tokenized_prompts = torch.cat(per_class_tokens)  # (n_classes, n_tokens)

        with torch.no_grad():
            # (n_classes, n_tokens, embed_dim)
            self.prompt_embeddings = clip_model.token_embedding(self.tokenized_prompts).type(self.dtype)

        def new_token_bank(width):
            """One uninitialised (n_rep_tokens, width) parameter per layer."""
            return nn.ParameterList(
                [nn.Parameter(torch.empty(n_rep_tokens, width)) for _ in range(self.rep_layers_length)]
            )

        # Learnable tokens (the batch banks also use n_rep_tokens rows).
        self.compound_prompts_rep_text = new_token_bank(text_dim)
        self.compound_prompts_rep_visual = new_token_bank(visual_dim)
        self.compound_prompts_batch_text = new_token_bank(text_dim)
        self.compound_prompts_batch_visual = new_token_bank(visual_dim)

        # Initialise layer by layer, keeping the per-layer order
        # rep_text -> rep_visual -> batch_text -> batch_visual.
        layer_groups = zip(
            self.compound_prompts_rep_text,
            self.compound_prompts_rep_visual,
            self.compound_prompts_batch_text,
            self.compound_prompts_batch_visual,
        )
        for layer_params in layer_groups:
            for param in layer_params:
                nn.init.normal_(param, std=0.02)

    def _exchange(self, text_tokens, visual_tokens):
        """Average each modality's tokens with their cross-attended version.

        Returns ``(text_tokens_out, visual_tokens_out)`` cast to ``self.dtype``.
        """
        visual_attended = self.crossAttnImage(visual_tokens, text_tokens)
        text_attended = self.crossAttnText(text_tokens, visual_tokens)
        visual_mixed = (visual_tokens + visual_attended) / 2
        text_mixed = (text_tokens + text_attended) / 2
        return text_mixed.type(self.dtype), visual_mixed.type(self.dtype)

    def forward(self):
        """Return four per-layer lists of tokens:
        ``(rep_text, rep_visual, batch_text, batch_visual)``."""
        compound_rep_tokens_text = []
        compound_rep_tokens_visual = []
        compound_batch_tokens_text = []
        compound_batch_tokens_visual = []

        for layer in range(self.rep_layers_length):
            rep_text, rep_visual = self._exchange(
                self.compound_prompts_rep_text[layer],
                self.compound_prompts_rep_visual[layer],
            )
            compound_rep_tokens_text.append(rep_text)
            compound_rep_tokens_visual.append(rep_visual)

            batch_text, batch_visual = self._exchange(
                self.compound_prompts_batch_text[layer],
                self.compound_prompts_batch_visual[layer],
            )
            compound_batch_tokens_text.append(batch_text)
            compound_batch_tokens_visual.append(batch_visual)

        return compound_rep_tokens_text, compound_rep_tokens_visual, compound_batch_tokens_text, compound_batch_tokens_visual

class ClassConditionedPrototypeMemory(nn.Module):
    """Per-class prototype bank stored as a learnable parameter of shape
    (num_classes, num_prototypes_per_class, feature_dim), with an additional
    momentum (moving-average) update."""

    def __init__(self, num_classes, num_prototypes_per_class, feature_dim, momentum=0.9):
        super().__init__()
        self.num_classes = num_classes
        self.num_prototypes_per_class = num_prototypes_per_class
        self.feature_dim = feature_dim
        self.momentum = momentum

        bank_shape = (num_classes, num_prototypes_per_class, feature_dim)
        self.prototypes = nn.Parameter(torch.randn(*bank_shape))

    def get_prototypes(self, cross_it_features=None, labels=None):
        """Index the prototype bank with ``labels``.

        For a 1-D label tensor of length B the result has shape
        (B, num_prototypes_per_class, feature_dim). ``cross_it_features`` is
        accepted but not used.
        """
        return self.prototypes[labels]

    @torch.no_grad()
    def update_prototypes(self, labels, features):
        """Moving-average update of the prototypes of each sample's class.

        For sample ``i`` with class ``labels[i]`` the new-information term is
        ``features[i][labels[i]]``; samples are applied one after another, so
        repeated labels compound.
        """
        keep = self.momentum
        blend = 1 - self.momentum
        for i, label in enumerate(labels):
            refreshed = keep * self.prototypes[label] + blend * features[i][label]
            self.prototypes[label] = refreshed

class CrossAttentionRep(nn.Module):
    """Multi-head cross-attention between two unbatched token sets.

    The first input provides the queries and the second the keys/values. All
    projections are 1x1 ``nn.Conv1d`` layers, and each output token is
    L2-normalised.
    """

    def __init__(self, dim_out, dim_text, dim_visual, num_heads=8):
        super().__init__()
        self.dim_out = dim_out
        self.num_heads = num_heads  # 8
        self.head_dim = self.dim_out // num_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Conv1d(dim_text, dim_out, kernel_size=1)
        self.k_proj = nn.Conv1d(dim_visual, dim_out, kernel_size=1)
        self.v_proj = nn.Conv1d(dim_visual, dim_out, kernel_size=1)
        self.out_proj = nn.Conv1d(dim_out, dim_out, kernel_size=1)

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize the weights.

        Only Linear, Embedding and LayerNorm sub-modules are touched; this
        module only owns Conv1d layers, so they keep PyTorch's default init.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)

    def _split_heads(self, projection, tokens):
        """Apply a 1x1 conv projection over the channel dim of (B, N, C)
        tokens and reshape the result to (B, N, num_heads, head_dim)."""
        projected = projection(tokens.transpose(1, 2)).transpose(1, 2)
        return projected.reshape(projected.shape[0], projected.shape[1], self.num_heads, self.head_dim)

    def forward(self, text, image):
        """Attend from ``text`` (N1, C1) tokens to ``image`` (N2, C2) tokens.

        Returns a (N1, dim_out) tensor whose rows have unit L2 norm.
        """
        # Add a singleton batch dimension for the conv/einsum machinery.
        queries_in = text.unsqueeze(0)
        context_in = image.unsqueeze(0)

        B1, N1, _ = queries_in.shape
        _, _, _ = context_in.shape  # both inputs must be 3-D after unsqueeze

        q = self._split_heads(self.q_proj, queries_in)
        k = self._split_heads(self.k_proj, context_in)
        v = self._split_heads(self.v_proj, context_in)

        # Scaled dot-product attention, softmax over the context tokens.
        scores = torch.einsum('bnkc,bmkc->bknm', q, k) * self.scale
        weights = F.softmax(scores, dim=-1)

        merged = torch.einsum('bknm,bmkc->bnkc', weights, v).reshape(B1, N1, self.dim_out)
        merged = self.out_proj(merged.transpose(1, 2)).transpose(1, 2)
        merged = merged / merged.norm(dim=-1, keepdim=True)

        return merged.squeeze(0)

class CrossAttention(nn.Module):
    """Multi-head cross-attention from text tokens (queries) to batched image
    tokens (keys/values). The output width equals ``dim_text`` and every
    output token is L2-normalised."""

    def __init__(self, dim_text, dim_visual, num_heads=8):
        super().__init__()
        self.dim_out = dim_text
        self.num_heads = num_heads  # 8
        self.head_dim = self.dim_out // num_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Conv1d(dim_text, self.dim_out, kernel_size=1)
        self.k_proj = nn.Conv1d(dim_visual, self.dim_out, kernel_size=1)
        self.v_proj = nn.Conv1d(dim_visual, self.dim_out, kernel_size=1)
        self.out_proj = nn.Conv1d(self.dim_out, self.dim_out, kernel_size=1)

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize the weights.

        Only Linear, Embedding and LayerNorm sub-modules are touched; this
        module only owns Conv1d layers, so they keep PyTorch's default init.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)

    def _split_heads(self, projection, tokens):
        """Apply a 1x1 conv projection over the channel dim of (B, N, C)
        tokens and reshape the result to (B, N, num_heads, head_dim)."""
        projected = projection(tokens.transpose(1, 2)).transpose(1, 2)
        return projected.reshape(projected.shape[0], projected.shape[1], self.num_heads, self.head_dim)

    def forward(self, text, image):
        """Attend from text tokens to image tokens.

        Args:
            text: (N1, C1) or (B, N1, C1); a 2-D input or a batch of size 1 is
                broadcast to the image batch size.
            image: (B, N2, C2) image tokens.

        Returns:
            (B, N1, dim_out) tensor whose last-dim vectors have unit L2 norm.
        """
        queries_in = text.unsqueeze(0) if text.dim() == 2 else text
        if queries_in.size(0) == 1:
            # Share the same text tokens across every image in the batch.
            queries_in = queries_in.expand(image.size(0), -1, -1)

        B1, N1, _ = queries_in.shape
        _, _, _ = image.shape  # image must be 3-D

        q = self._split_heads(self.q_proj, queries_in)
        k = self._split_heads(self.k_proj, image)
        v = self._split_heads(self.v_proj, image)

        # Scaled dot-product attention, softmax over the image tokens.
        scores = torch.einsum('bnkc,bmkc->bknm', q, k) * self.scale
        weights = F.softmax(scores, dim=-1)

        merged = torch.einsum('bknm,bmkc->bnkc', weights, v).reshape(B1, N1, self.dim_out)
        merged = self.out_proj(merged.transpose(1, 2)).transpose(1, 2)
        return merged / merged.norm(dim=-1, keepdim=True)

class CustomCLIP(CustomCLIP_):
    """CLIP wrapper combining learnable representation/batch tokens, a
    text-to-patch cross-attention module and a class-conditioned prototype
    memory.

    In eval mode the learner outputs and the text features are computed once
    and cached; the cache is cleared whenever the module is switched back to
    training mode so that a later evaluation never reuses stale features.
    """

    def __init__(self, cfg, classnames, clip_model):
        super().__init__(cfg, classnames, clip_model)
        self.alpha = cfg.TRAINER.MMRL.ALPHA
        self.alpha2 = cfg.TRAINER.MMRL.ALPHA2
        self.classnames = classnames
        self.representation_learner = MultiModalRepresentationLearner(cfg, classnames, clip_model).type(clip_model.dtype)
        self.tokenized_prompts = self.representation_learner.tokenized_prompts
        self.register_buffer("prompt_embeddings", self.representation_learner.prompt_embeddings)
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder_PGMPL(clip_model)
        self.dtype = clip_model.dtype
        # Eval-time cache (filled lazily in forward, cleared in train()).
        self.text_features_for_inference = None
        self.compound_rep_tokens_text_for_inference = None
        self.compound_rep_tokens_visual_for_inference = None
        self.compound_batch_tokens_text = None
        self.compound_batch_tokens_visual = None
        self.text_dim = clip_model.ln_final.weight.shape[0]

        self.cross_attention = CrossAttention(clip_model.ln_final.weight.shape[0], clip_model.visual.width, num_heads=8)
        self.prototype_memory = ClassConditionedPrototypeMemory(
            num_classes=cfg.TRAINER.MMRL.NUM_CLASSES_TRAIN,
            num_prototypes_per_class=1,
            feature_dim=clip_model.ln_final.weight.shape[0]
        )

    def train(self, mode=True):
        """Switch train/eval mode, dropping the eval cache when training resumes.

        The cached text features and tokens are derived from learnable
        parameters; once training continues they are out of date, so they are
        reset here and recomputed on the next eval-mode forward pass.
        """
        if mode:
            self.text_features_for_inference = None
            self.compound_rep_tokens_text_for_inference = None
            self.compound_rep_tokens_visual_for_inference = None
            self.compound_batch_tokens_text = None
            self.compound_batch_tokens_visual = None
        return super().train(mode)

    def forward(self, image, labels=None):
        """Run the model on a batch of images.

        Args:
            image: input image batch.
            labels: class indices; required in training mode (used to select
                and update prototypes), unused in eval mode.

        Returns:
            Tuple ``(logits, logits_rep, logits_batch, logits_fusion,
            image_features, text_features, prototypes, image_features_,
            cross_it_features, cross_rt_features)`` where ``image_features``
            and ``text_features`` are L2-normalised and ``image_features_`` is
            a copy of the image features taken before normalisation.
        """
        if self.representation_learner.training:
            compound_rep_tokens_text, compound_rep_tokens_visual, compound_batch_tokens_text, compound_batch_tokens_visual = self.representation_learner()
            text_features = self.text_encoder(self.prompt_embeddings, self.tokenized_prompts, compound_rep_tokens_text, compound_batch_tokens_text)
        else:
            # Eval mode: compute tokens and text features once and reuse them.
            if self.text_features_for_inference is None:
                self.compound_rep_tokens_text_for_inference, self.compound_rep_tokens_visual_for_inference, self.compound_batch_tokens_text, self.compound_batch_tokens_visual = self.representation_learner()
                self.text_features_for_inference = self.text_encoder(self.prompt_embeddings, self.tokenized_prompts,
                                                                     self.compound_rep_tokens_text_for_inference, self.compound_batch_tokens_text)

            compound_rep_tokens_text, compound_rep_tokens_visual, compound_batch_tokens_text, compound_batch_tokens_visual = self.compound_rep_tokens_text_for_inference, self.compound_rep_tokens_visual_for_inference, self.compound_batch_tokens_text, self.compound_batch_tokens_visual
            text_features = self.text_features_for_inference

        image_features, image_features_rep, image_features_batch, patch_tokens, patch_tokens_rep, patch_tokens_batch = self.image_encoder([image.type(self.dtype), compound_rep_tokens_visual, compound_batch_tokens_visual])

        # Text features attend to each of the three patch-token streams; the
        # same cross-attention module is shared by all three.
        cross_it_features = self.cross_attention(text_features, patch_tokens)
        cross_it_features = cross_it_features / cross_it_features.norm(dim=-1, keepdim=True)
        cross_rt_features = self.cross_attention(text_features, patch_tokens_rep)
        cross_rt_features = cross_rt_features / cross_rt_features.norm(dim=-1, keepdim=True)
        cross_bt_features = self.cross_attention(text_features, patch_tokens_batch)
        cross_bt_features = cross_bt_features / cross_bt_features.norm(dim=-1, keepdim=True)

        if self.training:
            # Prototype-guided logits: re-express each cross-attended class
            # feature as a softmax-weighted mix of the label's prototypes.
            # NOTE(review): the image features used below have not yet been
            # normalised in this method (that happens further down).
            prototypes = self.prototype_memory.get_prototypes(labels=labels)

            features_it = F.normalize(cross_it_features, dim=-1)
            prototypes = F.normalize(prototypes, dim=-1)
            similarity_it = torch.einsum('bnd,bkd->bnk', features_it, prototypes)
            weights_it = F.softmax(similarity_it / 0.07, dim=-1)
            enhanced_it_features = torch.einsum('bnk,bkd->bnd', weights_it, prototypes)
            logits_it = 100. * torch.einsum('bd,bnd->bn', image_features, enhanced_it_features)

            # rt
            features_rt = F.normalize(cross_rt_features, dim=-1)
            similarity_rt = torch.einsum('bnd,bkd->bnk', features_rt, prototypes)
            weights_rt = F.softmax(similarity_rt / 0.07, dim=-1)
            enhanced_rt_features = torch.einsum('bnk,bkd->bnd', weights_rt, prototypes)
            logits_rt = 100. * torch.einsum('bd,bnd->bn', image_features_rep, enhanced_rt_features)

            # bt
            features_bt = F.normalize(cross_bt_features, dim=-1)
            similarity_bt = torch.einsum('bnd,bkd->bnk', features_bt, prototypes)
            weights_bt = F.softmax(similarity_bt / 0.07, dim=-1)
            enhanced_bt_features = torch.einsum('bnk,bkd->bnd', weights_bt, prototypes)
            logits_bt = 100. * torch.einsum('bd,bnd->bn', image_features_batch, enhanced_bt_features)

        alpha = self.alpha
        image_features_ = image_features.clone()  # copy taken before normalisation
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        image_features_rep = image_features_rep / image_features_rep.norm(dim=-1, keepdim=True)
        image_features_batch = image_features_batch / image_features_batch.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Standard CLIP-style cosine-similarity logits for each image stream.
        logits = 100. * image_features @ text_features.t()
        logits_rep = 100. * image_features_rep @ text_features.t()
        logits_batch = 100. * image_features_batch @ text_features.t()

        if self.training:
            alpha2 = self.alpha2
        else:
            alpha2 = 1  # infer

        if self.training:
            # Blend the plain logits with the prototype-guided ones.
            logits = alpha2 * logits + (1 - alpha2) * logits_it
            logits_rep = alpha2 * logits_rep + (1 - alpha2) * logits_rt
            logits_batch = alpha2 * logits_batch + (1 - alpha2) * logits_bt

        if self.training:
            # Momentum update of the prototypes from the it and rt features;
            # labels are duplicated to match the concatenated batch.
            cross_features = torch.cat([cross_it_features, cross_rt_features], dim=0)
            labels = torch.cat([labels, labels], dim=0)
            self.prototype_memory.update_prototypes(labels, cross_features)

        logits_fusion = alpha * logits + (1 - alpha) * logits_rep + (1 - alpha) * logits_batch
        return logits, logits_rep, logits_batch, logits_fusion, image_features, text_features, self.prototype_memory.prototypes, image_features_, cross_it_features, cross_rt_features

class PGMPL_loss(_Loss):
    """Training objective: weighted cross-entropy over the three logit heads
    plus cosine-similarity regularisers that keep the image and text features
    close to their reference CLIP counterparts."""

    def __init__(self, reg_weight=1.0, alpha=0.7):
        super().__init__()
        self.reg_weight = reg_weight
        self.alpha = alpha

    def forward(self, logits, logits_rep, logits_batch,
                image_features, text_features,
                image_features_clip, text_features_clip,
                label):
        """Return the scalar loss.

        The main head is weighted by ``alpha``; the rep and batch heads are
        each weighted by ``1 - alpha``. Each regulariser is
        ``1 - mean cosine similarity`` (along dim 1), scaled by ``reg_weight``.
        """
        ce_main = F.cross_entropy(logits, label)
        ce_rep = F.cross_entropy(logits_rep, label)
        ce_batch = F.cross_entropy(logits_batch, label)

        image_drift = 1 - F.cosine_similarity(image_features, image_features_clip, dim=1).mean()
        text_drift = 1 - F.cosine_similarity(text_features, text_features_clip, dim=1).mean()

        aux_weight = 1 - self.alpha
        return (self.alpha * ce_main
                + aux_weight * ce_rep
                + aux_weight * ce_batch
                + self.reg_weight * image_drift
                + self.reg_weight * text_drift)

from dassl.data.data_manager import DataManager as DataManager_
from dassl.data.data_manager import build_data_loader, build_dataset, build_transform

class DataManager(DataManager_):
    """Dassl ``DataManager`` extended with three extra evaluation loaders
    (``test_query_loader``, ``test_given_loader``, ``test_db_loader``) built
    from the dataset's ``test_query`` / ``test_given`` / ``test_db`` splits."""

    def __init__(self, cfg, custom_tfm_train=None, custom_tfm_test=None, dataset_wrapper=None):
        super().__init__(cfg, custom_tfm_train, custom_tfm_test, dataset_wrapper)

        # Reuse the dataset the parent constructor already built instead of
        # loading it a second time; only rebuild it if the parent did not
        # expose one.
        dataset = getattr(self, "dataset", None)
        if dataset is None:
            dataset = build_dataset(cfg)

        if custom_tfm_test is None:
            tfm_test = build_transform(cfg, is_train=False)
        else:
            print("* Using custom transform for testing")
            tfm_test = custom_tfm_test

        def make_eval_loader(data_source):
            """Build a test-mode loader over ``data_source`` using the test
            sampler, test batch size and test transform."""
            return build_data_loader(
                cfg,
                sampler_type=cfg.DATALOADER.TEST.SAMPLER,
                data_source=data_source,
                batch_size=cfg.DATALOADER.TEST.BATCH_SIZE,
                tfm=tfm_test,
                is_train=False,
                dataset_wrapper=dataset_wrapper
            )

        self.test_query_loader = make_eval_loader(dataset.test_query)
        self.test_given_loader = make_eval_loader(dataset.test_given)
        self.test_db_loader = make_eval_loader(dataset.test_db)

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
def visualize_features_tsne(features, labels, classnames,
                            title='t-SNE Visualization', save_path='tsne_vis.png'):
    """Project features to 2-D with t-SNE and save a per-class scatter plot.

    Args:
        features: tensor of feature vectors (moved to CPU / NumPy here).
        labels: tensor of integer labels aligned with ``features``.
        classnames: indexable mapping from label to legend text.
        title: figure title.
        save_path: output image path.
    """
    feature_array = features.cpu().numpy()
    label_array = labels.cpu().numpy()

    # Fixed random_state so repeated runs give the same embedding.
    embedded = TSNE(n_components=2, random_state=42).fit_transform(feature_array)

    plt.figure(figsize=(12, 8))

    present_labels = np.unique(label_array)
    palette = plt.cm.rainbow(np.linspace(0, 1, len(present_labels)))
    for position, class_id in enumerate(present_labels):
        members = label_array == class_id
        plt.scatter(
            embedded[members, 0],
            embedded[members, 1],
            c=[palette[position]],
            label=classnames[class_id],
            alpha=0.6,
        )

    plt.title(title)
    # Legend is placed outside the axes, to the right of the plot.
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

@TRAINER_REGISTRY.register()
class PGMPL(MMRL):
    ###
    def build_data_loader(self):
        """Create essential data-related attributes.

        A re-implementation of this method must create the
        same attributes (self.dm is optional).
        """
        # Seed every RNG and make cuDNN deterministic before building loaders.
        seed = self.cfg.SEED
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        dm = DataManager(self.cfg)

        # train_loader_u and val_loader are optional and can be None.
        exposed = (
            "train_loader_x",
            "train_loader_u",
            "val_loader",
            "test_loader",
            "test_query_loader",
            "test_given_loader",
            "test_db_loader",
            "num_classes",
            "num_source_domains",
            "lab2cname",  # dict {label: classname}
        )
        for attr in exposed:
            setattr(self, attr, getattr(dm, attr))

        self.dm = dm

    def check_cfg(self, cfg):
        """Assert that the configured precision is one of fp16 / fp32 / amp."""
        supported_precisions = ["fp16", "fp32", "amp"]
        precision = cfg.TRAINER.MMRL.PREC
        assert precision in supported_precisions

    def build_model(self):
        """Build the trainable CustomCLIP model, a frozen zero-shot CLIP
        reference, the loss, the optimizer/scheduler and (for amp) the
        gradient scaler.

        Relies on ``self.dm`` having been created by ``build_data_loader``.
        """
        # Re-seed all RNGs so model construction is reproducible.
        SEED = self.cfg.SEED
        torch.manual_seed(SEED)
        torch.cuda.manual_seed_all(SEED)
        np.random.seed(SEED)
        random.seed(SEED)

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        cfg = self.cfg
        classnames = self.dm.dataset.classnames
        self.num_classes = len(classnames)

        print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
        # Two copies: one built with the "PGMPL" design details (trained) and
        # one with the default "CLIP" name (used as the zero-shot reference).
        clip_model = load_clip_to_cpu(cfg, "PGMPL")
        clip_model_zero_shot = load_clip_to_cpu(cfg)

        if cfg.TRAINER.MMRL.PREC == "fp32" or cfg.TRAINER.MMRL.PREC == "amp":
            # CLIP's default precision is fp16
            clip_model.float()
            clip_model_zero_shot.float()

        self.dtype = clip_model.dtype

        # Pre-compute L2-normalised zero-shot text features once; they are
        # passed to the criterion as the text reference during training.
        with torch.no_grad():
            self.text_encoder_clip = TextEncoder_CLIP(clip_model_zero_shot)
            text_features_clip = _get_text_base_features_zero_shot(cfg, classnames, clip_model_zero_shot, self.text_encoder_clip)
            self.text_features_clip = text_features_clip / text_features_clip.norm(dim=-1, keepdim=True)
        self.image_encoder_clip = clip_model_zero_shot.visual

        print("Building custom CLIP")
        self.model = CustomCLIP(cfg, classnames, clip_model)

        print("Turning off gradients in both the image and the text encoder")
        # Only parameters whose name contains one of these substrings stay trainable.
        names_to_update = ["representation_learner", "image_encoder.proj_rep", "cross_attention", "prototype_memory"]

        for name, param in self.model.named_parameters():
            update = False

            for name_to_update in names_to_update:
                if name_to_update in name:
                    update = True
                    break
            param.requires_grad_(update)

        # Double check
        enabled = set()
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                enabled.add(name)
        print(f"Parameters to be updated: {enabled}")

        if cfg.MODEL.INIT_WEIGHTS:
            load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)

        self.model.to(self.device)

        self.image_encoder_clip.to(self.device)

        reg_weight = cfg.TRAINER.MMRL.REG_WEIGHT
        alpha = cfg.TRAINER.MMRL.ALPHA
        alpha2 = cfg.TRAINER.MMRL.ALPHA2  # read here but not used in this method
        self.criterion = PGMPL_loss(reg_weight=reg_weight, alpha=alpha)

        # NOTE: the whole model is handed to the optimizer builder; which
        # parameters it actually optimizes is decided inside build_optimizer
        # (presumably driven by requires_grad set above — TODO confirm).
        self.optim = build_optimizer(self.model, cfg.OPTIM)
        self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
        self.register_model("MultiModalPromptLearner", self.model, self.optim, self.sched)

        # Gradient scaling is only needed for mixed-precision (amp) training.
        self.scaler = GradScaler() if cfg.TRAINER.MMRL.PREC == "amp" else None

        # Note that multi-gpu training could be slow because CLIP's size is
        # big, which slows down the copy operation in DataParallel
        device_count = torch.cuda.device_count()
        if device_count > 1:
            print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
            self.model = nn.DataParallel(self.model)
            self.image_encoder_clip = nn.DataParallel(self.image_encoder_clip)

    def forward_backward(self, batch):
        """Run one optimisation step on ``batch`` and return a summary dict
        with the loss value and the accuracy of the fused logits."""
        image, label = self.parse_batch_train(batch)

        net = self.model
        optimizer = self.optim
        grad_scaler = self.scaler
        use_amp = self.cfg.TRAINER.MMRL.PREC == "amp"

        def run_forward():
            """Forward pass shared by both precision modes.

            Returns ``(loss, fused_logits)``.
            """
            # Reference image features from the frozen zero-shot encoder.
            with torch.no_grad():
                reference_image = self.image_encoder_clip(image.type(self.dtype))
                reference_image = reference_image / reference_image.norm(dim=-1, keepdim=True)

            logits, logits_rep, logits_batch, fused, image_features, text_features, _, _, _, _ = net(image, label)
            # Crop the returned text_features for multi-GPU compatibility
            text_features = text_features[0:self.num_classes]

            batch_loss = self.criterion(logits, logits_rep, logits_batch,
                                        image_features, text_features,
                                        reference_image, self.text_features_clip,
                                        label)
            return batch_loss, fused

        if use_amp:
            with autocast():
                loss, output = run_forward()

            optimizer.zero_grad()
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()
        else:
            loss, output = run_forward()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        loss_summary = {
            "loss": loss.item(),
            'acc': compute_accuracy(output, label)[0].item(),
        }

        # Step the LR scheduler once per epoch, after the last batch.
        is_last_batch = (self.batch_idx + 1) == self.num_batches
        if is_last_batch:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        """Return ``(images, labels)`` from ``batch``, both moved to ``self.device``."""
        images = batch["img"].to(self.device)
        labels = batch["label"].to(self.device)
        return images, labels

    
    @torch.no_grad()
    def test(self, split=None):
        """Evaluate by nearest-prototype classification of query images.

        One prototype per class is built by averaging the image features (5th
        output of ``self.model``) of at most ``N_img`` samples per class taken
        from ``self.test_given_loader``. Every image from
        ``self.test_db_loader`` is then L2-normalised and assigned to the class
        whose prototype gives the largest dot product.

        Args:
            split: Name used in the log message and in the scalar tag. Defaults
                to ``self.cfg.TEST.SPLIT``. It does not select the data: the
                two loaders named above are always used.

        Returns:
            The classification accuracy in percent, or ``0.0`` when the query
            loader yields no samples.
        """
        self.set_model_mode("eval")
        self.evaluator.reset()

        n_cls = self.num_classes
        dim = self.model.text_dim
        N_img = 5  # maximum number of 'given' images averaged into one class prototype

        if split is None:
            split = self.cfg.TEST.SPLIT
        print(f"Evaluate on the *{split}* set")

        print("\nPrecomputing all features...")

        given_loader = self.test_given_loader
        query_loader = self.test_db_loader

        print("Building database from 'given' images (one average feature per class)...")
        proto_sum = torch.zeros(n_cls, dim).to(self.device)
        # Per-class sample counts are kept on the host so that the cap check
        # does not need a device synchronisation for every sample.
        counts = [0] * n_cls
        for batch in tqdm(given_loader):
            images, labels = self.parse_batch_test(batch)
            _, _, _, _, given_feats, _, _, _, _, _ = self.model(images)
            for i, cls_idx in enumerate(labels.tolist()):
                if counts[cls_idx] >= N_img:
                    # Only this class is full; later samples in the batch may
                    # belong to classes that still need images, so skip rather
                    # than abandon the rest of the batch.
                    continue
                proto_sum[cls_idx] += given_feats[i]
                counts[cls_idx] += 1
            if all(c >= N_img for c in counts):
                # Every prototype is complete; further batches cannot change it.
                break

        n_missing = counts.count(0)
        if n_missing:
            print(f"Warning: {n_missing} of {n_cls} classes have no 'given' images; "
                  "their prototypes are undefined.")

        num_ = torch.tensor(counts, dtype=proto_sum.dtype, device=proto_sum.device).unsqueeze(1)
        # Classes without samples yield 0/0 = NaN here; the NaN guard in the
        # query loop below pushes their similarity to the dtype minimum.
        prototypes = proto_sum / num_

        print("Performing queries from 'db_loader' against the new database...")

        total_correct = 0
        total_samples = 0

        for batch in tqdm(query_loader, desc="Querying"):
            query_images, query_labels = self.parse_batch_test(batch)
            _, _, _, _, query_features, _, _, _, _, _ = self.model(query_images)
            query_features = F.normalize(query_features, p=2, dim=-1)  # (batch_size, dim)
            similarity = query_features @ prototypes.t()
            if torch.isnan(similarity).any():
                print("Warning: NaN detected in similarity matrix. Replacing with -inf.")
                # The replacement value is the smallest finite value of the
                # dtype, so NaN entries lose against any real similarity.
                nan_replace_value = torch.finfo(similarity.dtype).min
                similarity = torch.nan_to_num(similarity, nan=nan_replace_value)
            _, predicted_labels = similarity.max(1)

            total_correct += (predicted_labels == query_labels).sum().item()
            total_samples += query_labels.size(0)

        print("\n--- Evaluation Results ---")
        print("=> result")
        if total_samples == 0:
            # An empty query loader would otherwise raise ZeroDivisionError.
            print("Warning: no query samples were evaluated.")
            accuracy = 0.0
        else:
            accuracy = (total_correct / total_samples) * 100
        print(f"* Overall Classification Accuracy: {accuracy:.2f}%")
        self.write_scalar(f"{split}/classification_accuracy", accuracy, self.epoch)
        return accuracy

    def load_model(self, directory, epoch=None):
        """Load saved weights for every registered model from ``directory``.

        For each name returned by ``self.get_model_names()`` the sub-directory
        ``<directory>/<name>`` is scanned for a checkpoint file. A file whose
        name contains ``"model-best.pth"`` is preferred; otherwise a file whose
        name contains ``"model.pth"`` is used (if several match, the last one
        in ``os.listdir`` order wins). State-dict entries whose key contains
        ``"prompt_embeddings"`` are dropped, and the rest is loaded with
        ``strict=False``.

        Args:
            directory: Root directory holding one sub-directory per model name.
                A falsy value skips loading entirely.
            epoch: Accepted for interface compatibility; it is not used to
                select a checkpoint. The epoch shown in the log message is the
                one stored in the checkpoint.

        Raises:
            FileNotFoundError: If a model sub-directory does not exist or
                contains no matching checkpoint file.
        """
        if not directory:
            print('Note that load_model() is skipped as no pretrained model is given')
            return

        for name in self.get_model_names():
            model_path_prefix = osp.join(directory, name)
            if not osp.exists(model_path_prefix):
                raise FileNotFoundError(
                    'Model not found at "{}"'.format(model_path_prefix)
                )

            # Reset for every model so that a directory without a checkpoint
            # neither hits an unbound name nor reuses the previous model's path.
            model_path = None
            for file in os.listdir(model_path_prefix):
                if "model-best.pth" in file:
                    model_path = osp.join(model_path_prefix, file)
                    break
                if "model.pth" in file:
                    model_path = osp.join(model_path_prefix, file)

            if model_path is None:
                raise FileNotFoundError(
                    'No checkpoint file found in "{}"'.format(model_path_prefix)
                )
            if not osp.exists(model_path):
                raise FileNotFoundError(
                    'Model not found at "{}"'.format(model_path)
                )

            checkpoint = load_checkpoint(model_path)
            # Kept in its own local so the ``epoch`` argument is not overwritten.
            ckpt_epoch = checkpoint["epoch"]
            state_dict = {
                k: v for k, v in checkpoint["state_dict"].items()
                if "prompt_embeddings" not in k
            }

            print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, ckpt_epoch))
            # strict=False: the filtered keys above are intentionally absent.
            self._models[name].load_state_dict(state_dict, strict=False)