import sys
sys.path.append("/data/robust_crossmodal-retrieval")
print(sys.path)

from models.ALBEF.model_retrieval import ALBEF
# from models.ALBEF.model_retrieval_pt import ALBEF as ALBEF_PT
from models.TCL.model_retrieval import ALBEF as TCL
# from models.TCL.model_retrieval_pt import ALBEF as TCL_PT
from models.vit import interpolate_pos_embed
from models.tokenization_bert import BertTokenizer
from models.clip_model import clip
import torch
import torch.nn as nn
from transformers import BertForMaskedLM

from torch.nn import functional as F

from models.pcmepp_clip import load_clip_with_pcmepp
# from models.pcmepp.encoder import get_pcmepp_model
from pcmepp.criterions.pcmepp import ClosedFormSampledDistanceLoss

def _unwrap_checkpoint(checkpoint):
    """Return ``checkpoint["model"]`` if that entry exists, else ``checkpoint``.

    Checkpoints may be either a bare state dict or a dict wrapping the state
    dict under the ``"model"`` key; both layouts are accepted.
    """
    try:
        return checkpoint["model"]
    except (KeyError, TypeError, IndexError):
        return checkpoint


def _rename_state_dict_keys(state_dict, old, new):
    """In place, replace ``old`` with ``new`` in every key that contains ``old``.

    Renamed entries are re-inserted at the end of the dict.
    """
    for key in list(state_dict.keys()):
        if old in key:
            state_dict[key.replace(old, new)] = state_dict.pop(key)


def _interpolate_albef_pos_embeds(state_dict, model):
    """Resize the ViT positional embeddings in ``state_dict`` to fit ``model``.

    Both ``visual_encoder.pos_embed`` and ``visual_encoder_m.pos_embed`` are
    rewritten in place. Raises ``KeyError`` if either entry is missing.
    """
    for prefix in ("visual_encoder", "visual_encoder_m"):
        key = prefix + ".pos_embed"
        state_dict[key] = interpolate_pos_embed(state_dict[key], getattr(model, prefix))


def _load_encoder_state(model, state_dict, model_ckpt):
    """Strip ``bert.`` from key names, load non-strictly and print the result."""
    # Only keys that actually contain "bert." are renamed; the previous loop
    # also deleted keys that merely contained "bert".
    _rename_state_dict_keys(state_dict, "bert.", "")
    msg = model.load_state_dict(state_dict, strict=False)
    print('load checkpoint from %s' % model_ckpt)
    print(msg)


def _load_albef_checkpoint(model, state_dict, model_ckpt, train_config):
    """Load an ALBEF checkpoint, trying progressively more permissive layouts.

    1. "pretrained": resize positional embeddings, then load.
    2. "fine-tuned": load without resizing positional embeddings.
    3. Call ``model.wrap_vision_encoder_with_prompter(train_config)`` first,
       then load.

    A failed attempt may already have modified ``state_dict`` in place
    (resized pos-embeds, renamed keys); later attempts continue from there.
    """
    try:
        print("loading pretrained model")
        # reshape positional embedding to accommodate for image resolution change
        _interpolate_albef_pos_embeds(state_dict, model)
        _load_encoder_state(model, state_dict, model_ckpt)
        return
    except Exception as err:  # fall back to the next layout, but say why
        print("pretrained layout failed (%r); retrying" % (err,))

    try:
        print("loading fine-tuned model")
        _load_encoder_state(model, state_dict, model_ckpt)
        return
    except Exception as err:
        print("fine-tuned layout failed (%r); retrying with prompter" % (err,))

    model.wrap_vision_encoder_with_prompter(train_config)
    print("loading fine-tuned model")
    _load_encoder_state(model, state_dict, model_ckpt)


def _prepare_prompt_clip(model, tokenizer, model_ckpt, rename_image_encoder):
    """Attach the tokenizer, cast to fp32 and optionally load ``model_ckpt``.

    When ``rename_image_encoder`` is true, checkpoint keys containing
    ``image_encoder`` are renamed to use ``visual`` before a strict load.
    """
    model.set_tokenizer(tokenizer)

    # Important! default is half precision. This affects attack performance.
    model.float()

    if not model_ckpt:
        return
    checkpoint = _unwrap_checkpoint(torch.load(model_ckpt, map_location="cpu"))
    if rename_image_encoder:
        _rename_state_dict_keys(checkpoint, "image_encoder", "visual")
    model.load_state_dict(checkpoint)
    print("loaded: ", model_ckpt)


def load_model(config, model_name, model_ckpt, text_encoder, device, train_config=None):
    """Build the retrieval model named ``model_name`` and load its weights.

    Args:
        config: model config forwarded to the ALBEF / TCL constructors.
        model_name: one of "ALBEF", "TCL", "CLIP_ViT-B-16_PT",
            "CLIP_ViT-B-16_PT_PDE", "CLIP_ViT-B-16", "CLIP_ViT-B-16_PDE",
            "CLIP_ViT-B-16_PDE_CSD", "CLIP_ViT-B-16_PCMEPP"; any other value
            is passed straight to ``clip.load``.
        model_ckpt: checkpoint path. Required for ALBEF/TCL; optional (falsy
            to skip) for the prompt-tuned CLIP variants; unused otherwise.
        text_encoder: name/path used for the BERT tokenizer and reference MLM.
        device: device passed to the CLIP loaders.
        train_config: extra config for the prompt / PDE / PCME++ variants.

    Returns:
        ``(model, ref_model, tokenizer)`` where ``ref_model`` is a
        ``BertForMaskedLM`` and ``tokenizer`` a ``BertTokenizer``.
    """
    tokenizer = BertTokenizer.from_pretrained(text_encoder)
    ref_model = BertForMaskedLM.from_pretrained(text_encoder)

    if model_name in ("ALBEF", "TCL"):
        model_cls = ALBEF if model_name == "ALBEF" else TCL
        model = model_cls(config=config, text_encoder=text_encoder, tokenizer=tokenizer)
        state_dict = _unwrap_checkpoint(torch.load(model_ckpt, map_location="cpu"))
        if model_name == "ALBEF":
            _load_albef_checkpoint(model, state_dict, model_ckpt, train_config)
        else:
            print("loading pretrained model")
            _interpolate_albef_pos_embeds(state_dict, model)
            _load_encoder_state(model, state_dict, model_ckpt)
        return model, ref_model, tokenizer

    if model_name == "CLIP_ViT-B-16_PT":
        model = load_clip_ivlp("ViT-B/16", train_config, device)
        _prepare_prompt_clip(model, tokenizer, model_ckpt, rename_image_encoder=True)
        return model, ref_model, tokenizer

    if model_name == "CLIP_ViT-B-16_PT_PDE":
        model = load_clip_ivlp_with_pde("ViT-B/16", train_config, device)
        _prepare_prompt_clip(model, tokenizer, model_ckpt, rename_image_encoder=False)
        return model, ref_model, tokenizer

    if model_name == "CLIP_ViT-B-16":
        model, _ = clip.load("ViT-B/16", device=device)
    elif model_name == "CLIP_ViT-B-16_PDE":
        model = load_clip_with_pde("ViT-B/16", train_config, device)
    elif model_name == "CLIP_ViT-B-16_PDE_CSD":
        model = load_clip_with_pde_csd("ViT-B/16", train_config, device)
    elif model_name == "CLIP_ViT-B-16_PCMEPP":
        model = load_clip_with_pcmepp("ViT-B/16", train_config["pcmepp"], device)
    else:
        print("CLIP model")
        model, _ = clip.load(model_name, device=device)
    model.set_tokenizer(tokenizer)
    return model, ref_model, tokenizer


# Clip with prompt tuning
# https://github.com/muzairkhattak/multimodal-prompt-learning/blob/69bce21ae8eda80ad6187534b2dce09cf6c59e17/trainers/independentVL.py#L21
def load_clip_ivlp(backbone_name, train_config, device):
    """Build an Independent Vision-Language Prompting (IVLP) CLIP model.

    Pretrained CLIP weights are loaded through the ``clip`` package found on
    ``sys.path`` after adding "multimodal-prompt-learning". The model is
    rebuilt with learnable prompts described by ``train_config`` and wrapped in
    a retrieval-oriented ``CustomCLIP``. Only parameters whose name contains
    "VPT" stay trainable.

    Args:
        backbone_name: CLIP backbone; only "ViT-B/16" is supported.
        train_config: mapping with "vision_depth", "language_depth",
            "vision_ctx", "language_ctx" and optionally "disable_prompt".
        device: device the wrapper is moved to.

    Returns:
        The ``CustomCLIP`` wrapper on ``device``.
    """
    assert backbone_name in ["ViT-B/16"]

    # Add the prompt-learning fork only once so repeated calls don't keep
    # growing sys.path.
    prompt_repo = "multimodal-prompt-learning"
    if prompt_repo not in sys.path:
        sys.path.append(prompt_repo)
    from clip import clip

    # Load the pretrained weights to CPU: JIT archive if possible, otherwise a
    # plain state dict.
    model_path = clip._download(clip._MODELS[backbone_name])
    try:
        jit_model = torch.jit.load(model_path, map_location="cpu").eval()
    except RuntimeError:
        state_dict = torch.load(model_path, map_location="cpu")
    else:
        state_dict = jit_model.state_dict()

    # This version of clip appends learnable prompts to each attention block:
    # https://github.com/muzairkhattak/multimodal-prompt-learning/blob/69bce21ae8eda80ad6187534b2dce09cf6c59e17/clip/model.py#L654
    design_details = {
        "trainer": 'IVLP',
        "vision_depth": train_config["vision_depth"],      # cfg.TRAINER.IVLP.PROMPT_DEPTH_VISION
        "language_depth": train_config["language_depth"],  # cfg.TRAINER.IVLP.PROMPT_DEPTH_TEXT
        "vision_ctx": train_config["vision_ctx"],          # cfg.TRAINER.IVLP.N_CTX_VISION
        "language_ctx": train_config["language_ctx"],      # cfg.TRAINER.IVLP.N_CTX_TEXT
    }
    clip_model = clip.build_model(state_dict, design_details)
    print("Model loaded")
    for key in ("vision_depth", "language_depth", "vision_ctx", "language_ctx"):
        print(key + ": ", train_config[key])

    # Redesign of "CustomCLIP", which was originally built for classification in:
    # https://github.com/muzairkhattak/multimodal-prompt-learning/blob/main/trainers/independentVL.py
    class CustomCLIP(nn.Module):
        def __init__(self, cfg, clip_model):
            super().__init__()
            # fp32 instead of CLIP's default dtype.
            self.dtype = torch.float32
            n_ctx = cfg["language_ctx"]
            ctx_dim = clip_model.ln_final.weight.shape[0]

            # Learnable text prompt, randomly initialised: [n_ctx, ctx_dim].
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=self.dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            self.text_VPT_input = nn.Parameter(ctx_vectors)

            self.visual = clip_model.visual
            self.transformer = clip_model.transformer
            self.ln_final = clip_model.ln_final
            self.text_projection = clip_model.text_projection
            self.token_embedding = clip_model.token_embedding
            self.positional_embedding = clip_model.positional_embedding
            self.logit_scale = clip_model.logit_scale

            self.disable_prompt = cfg["disable_prompt"] if "disable_prompt" in cfg else False

        def encode_image(self, image, return_full=False):
            """Return image embeddings from the (prompted) visual encoder.

            ``return_full`` is not supported and raises NotImplementedError.
            """
            if return_full:
                raise NotImplementedError
            return self.visual(image.type(self.dtype))

        def encode_text(self, prompt, tokenized_text, return_full=False):
            """Encode token ids, inserting ``prompt`` right after the SOT token.

            prompt: learnable prompt, [n_ctx, dim].
            tokenized_text: CLIP token ids (77 per sequence, as produced by
                ``clip.tokenize(..., 77, True)`` in this class).
            return_full: if true, return all token features instead of the
                pooled, projected embedding.
            """
            x = self.token_embedding(tokenized_text).type(self.dtype)

            if not self.disable_prompt:
                # [SOT] + prompt + remaining tokens, truncated back to the
                # original sequence length.
                seq_len = x.shape[1]
                ctx = prompt.to(x.dtype).unsqueeze(0).expand(x.shape[0], -1, -1)
                x = torch.cat([x[:, :1, :], ctx, x[:, 1:, :]], dim=1)[:, :seq_len, :]

            x = x + self.positional_embedding.type(self.dtype)
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.transformer(x)
            x = x.permute(1, 0, 2)  # LND -> NLD
            x = self.ln_final(x).type(self.dtype)

            if return_full:
                return x
            # Pool at the position of the highest token id (CLIP's EOT token).
            # NOTE(review): when prompts are inserted, every text token
            # (including EOT) shifts right by prompt.shape[0], but this index
            # is taken from the *unshifted* tokenized_text. Kept as-is for
            # compatibility with already-trained prompts; confirm intent.
            return x[torch.arange(x.shape[0]), tokenized_text.argmax(dim=-1)] @ self.text_projection

        def forward(self, image, tokenized_text):
            """Return ``(logits_per_image, logits_per_text)``: scaled cosine similarities."""
            logit_scale = self.logit_scale.exp()

            text_features = self.encode_text(self.text_VPT_input, tokenized_text)
            image_features = self.encode_image(image.type(self.dtype))

            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            logits_per_image = logit_scale * image_features @ text_features.t()
            return logits_per_image, logits_per_image.t()

        def inference_text(self, text_input, return_full=False):
            """Encode BERT-tokenised text with the CLIP text encoder.

            text_input: output of the BERT tokenizer (uses ``.input_ids``).
            The ids are decoded back to strings, stripped of [PAD]/[CLS]/[SEP]
            and re-tokenised with CLIP's tokenizer.
            """
            text = [
                self.tokenizer.decode(input_ids)
                .replace("[PAD]", "")
                .replace("[CLS]", "")
                .replace("[SEP]", "")
                .strip()
                for input_ids in text_input.input_ids
            ]
            clip_ids = clip.tokenize(text, 77, True).to(self.logit_scale.device)
            if return_full:
                return self.encode_text(self.text_VPT_input, clip_ids, return_full=True)
            text_embed = self.encode_text(self.text_VPT_input, clip_ids)
            text_features = text_embed / text_embed.norm(dim=-1, keepdim=True)
            return {"text_embed": text_embed, "text_feat": text_features}

        def inference_image(self, image, return_full=False):
            """Return raw and L2-normalised image embeddings."""
            if return_full:
                return self.encode_image(image.type(self.dtype), return_full=True)
            image_embed = self.encode_image(image.type(self.dtype))
            image_features = image_embed / image_embed.norm(dim=-1, keepdim=True)
            return {"image_embed": image_embed, "image_feat": image_features}

        def inference(self, image, text):
            """Return L2-normalised image and text features for raw ``text``."""
            text_input = clip.tokenize(text, 77, True).to(self.logit_scale.device)
            text_features = self.encode_text(self.text_VPT_input, text_input)
            image_features = self.encode_image(image.type(self.dtype))
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            return {"image_feat": image_features, "text_feat": text_features}

        def set_tokenizer(self, tokenizer):
            self.tokenizer = tokenizer

    model = CustomCLIP(train_config, clip_model).to(device)

    # Only the prompts are trained, following:
    # https://github.com/muzairkhattak/multimodal-prompt-learning/blob/69bce21ae8eda80ad6187534b2dce09cf6c59e17/trainers/independentVL.py#L183
    for name, param in model.named_parameters():
        param.requires_grad_("VPT" in name)

    return model


def load_clip_ivlp_with_pde(backbone_name, train_config, device):
    """Build an IVLP CLIP model with PDE (probability distribution embedding) heads.

    Same backbone setup as the IVLP loader. In addition, two ``DisTrans``
    heads map token features to a per-token mean and log-sigma. Similarity is
    a negative Wasserstein-2 distance between image and text distributions.
    Parameters whose name contains "VPT" or "gau_encoder" stay trainable.

    Args:
        backbone_name: CLIP backbone; only "ViT-B/16" is supported.
        train_config: mapping with the IVLP prompt keys ("vision_depth",
            "language_depth", "vision_ctx", "language_ctx"), optionally
            "disable_prompt", and "pde" with "negative_scale" and "shift".
        device: device the wrapper is moved to.

    Returns:
        The ``CustomCLIP`` wrapper on ``device``.
    """
    assert backbone_name in ["ViT-B/16"]

    # Add the prompt-learning fork only once so repeated calls don't keep
    # growing sys.path.
    prompt_repo = "multimodal-prompt-learning"
    if prompt_repo not in sys.path:
        sys.path.append(prompt_repo)
    from clip import clip

    # Load the pretrained weights to CPU: JIT archive if possible, otherwise a
    # plain state dict.
    model_path = clip._download(clip._MODELS[backbone_name])
    try:
        jit_model = torch.jit.load(model_path, map_location="cpu").eval()
    except RuntimeError:
        state_dict = torch.load(model_path, map_location="cpu")
    else:
        state_dict = jit_model.state_dict()

    # This version of clip appends learnable prompts to each attention block:
    # https://github.com/muzairkhattak/multimodal-prompt-learning/blob/69bce21ae8eda80ad6187534b2dce09cf6c59e17/clip/model.py#L654
    design_details = {
        "trainer": 'IVLP',
        "vision_depth": train_config["vision_depth"],      # cfg.TRAINER.IVLP.PROMPT_DEPTH_VISION
        "language_depth": train_config["language_depth"],  # cfg.TRAINER.IVLP.PROMPT_DEPTH_TEXT
        "vision_ctx": train_config["vision_ctx"],          # cfg.TRAINER.IVLP.N_CTX_VISION
        "language_ctx": train_config["language_ctx"],      # cfg.TRAINER.IVLP.N_CTX_TEXT
    }
    clip_model = clip.build_model(state_dict, design_details)
    print("Model loaded")
    for key in ("vision_depth", "language_depth", "vision_ctx", "language_ctx"):
        print(key + ": ", train_config[key])

    # PDE heads, adapted from
    # https://github.com/IIGROUP/MAP/blob/main/map/modules/map_module.py
    # DisTrans returns (mu, logsigma, _) per token.
    from models.PDE import DisTrans

    def Wasserstein2(mu1, sigma1, mu2, sigma2):
        """Pairwise distances between two batches of diagonal distributions.

        mu1/sigma1: [bs1, dim]; mu2/sigma2: [bs2, dim].
        Returns ``(||mu1-mu2||^2 + ||sigma1-sigma2||^2, ||mu1-mu2||^2)``,
        each [bs1, bs2]. Broadcasting replaces the original stacked copies;
        the values are identical.
        """
        mu_term = (mu1.unsqueeze(1) - mu2.unsqueeze(0)).pow(2).sum(dim=-1)
        sigma_term = (sigma1.unsqueeze(1) - sigma2.unsqueeze(0)).pow(2).sum(dim=-1)
        return mu_term + sigma_term, mu_term

    def init_weights(module):
        """BERT-style init: N(0, 0.02) weights, zero biases, unit LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    class CustomCLIP(nn.Module):
        def __init__(self, cfg, clip_model):
            super().__init__()
            # fp32 instead of CLIP's default dtype.
            self.dtype = torch.float32
            n_ctx = cfg["language_ctx"]
            ctx_dim = clip_model.ln_final.weight.shape[0]

            # Learnable text prompt, randomly initialised: [n_ctx, ctx_dim].
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=self.dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            self.text_VPT_input = nn.Parameter(ctx_vectors)

            self.visual = clip_model.visual
            self.transformer = clip_model.transformer
            self.ln_final = clip_model.ln_final
            self.text_projection = clip_model.text_projection
            self.token_embedding = clip_model.token_embedding
            self.positional_embedding = clip_model.positional_embedding
            self.logit_scale = clip_model.logit_scale

            self.disable_prompt = cfg["disable_prompt"] if "disable_prompt" in cfg else False

            # PDE heads (MAP originally used DisTrans(768, 12)).
            self.img_gau_encoder = DisTrans(512, 8)
            self.txt_gau_encoder = DisTrans(512, 8)
            self.img_gau_encoder.apply(init_weights)
            self.txt_gau_encoder.apply(init_weights)

            pde_config = cfg["pde"]
            self.negative_scale = pde_config["negative_scale"]
            self.shift = pde_config["shift"]

        def _pool_text(self, token_feats, tokenized_text):
            """Pick each sequence's highest-id (EOT) position and project it.

            NOTE(review): when prompts are inserted in ``encode_text``, the real
            EOT token sits prompt-length positions to the right of this index.
            Kept as-is for compatibility with trained checkpoints; confirm intent.
            """
            rows = torch.arange(token_feats.shape[0])
            return token_feats[rows, tokenized_text.argmax(dim=-1)] @ self.text_projection

        def encode_image(self, image):
            """Return per-token image features, [batch_size, tokens, dim].

            Re-implements the prompted ViT forward pass so that all tokens
            (not just CLS) are available to the PDE head.
            """
            x = image.type(self.dtype)

            ViT = self.visual
            x = ViT.conv1(x)  # shape = [*, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
            cls_tok = ViT.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
            x = torch.cat([cls_tok, x], dim=1)  # shape = [*, grid ** 2 + 1, width]
            x = x + ViT.positional_embedding.to(x.dtype)

            # Append the (trainable) visual prompts after the positional
            # embeddings. Cast to the activation dtype: the model runs in fp32,
            # and the previous `.half()` rounded the prompts to fp16.
            if ViT.VPT_shallow:
                visual_ctx = ViT.VPT.expand(x.shape[0], -1, -1).to(x.dtype)
                x = torch.cat([x, visual_ctx], dim=1)
            else:
                assert ViT.prompt_till_layer_visual == 0

            x = ViT.ln_pre(x)
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = ViT.transformer(x)
            x = x.permute(1, 0, 2)  # LND -> NLD
            x = ViT.ln_post(x)  # [batch_size, tokens, width]

            if ViT.proj is not None:
                x = x @ ViT.proj  # [batch_size, tokens, output_dim]

            assert len(x.shape) == 3  # [batch_size, token_num, dim]
            return x

        def encode_text(self, prompt, tokenized_text):
            """Return per-token text features (unprojected), [batch_size, tokens, width].

            prompt: learnable prompt, [n_ctx, dim], inserted after the SOT token.
            tokenized_text: CLIP token ids.
            """
            x = self.token_embedding(tokenized_text).type(self.dtype)

            if not self.disable_prompt:
                # [SOT] + prompt + remaining tokens, truncated back to the
                # original sequence length.
                seq_len = x.shape[1]
                ctx = prompt.to(x.dtype).unsqueeze(0).expand(x.shape[0], -1, -1)
                x = torch.cat([x[:, :1, :], ctx, x[:, 1:, :]], dim=1)[:, :seq_len, :]

            x = x + self.positional_embedding.type(self.dtype)
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.transformer(x)
            x = x.permute(1, 0, 2)  # LND -> NLD
            x = self.ln_final(x).type(self.dtype)

            assert len(x.shape) == 3  # [batch_size, token_num, dim]
            return x

        def forward(self, image, tokenized_text):
            """Return ``(logits_per_image, logits_per_text)`` from W2 distances.

            logits = (-negative_scale * W2 + shift) * exp(logit_scale); the
            means/sigmas are not normalised before the distance.
            """
            logit_scale = self.logit_scale.exp()

            text_features = self.encode_text(self.text_VPT_input, tokenized_text)
            image_features = self.encode_image(image.type(self.dtype))
            assert len(image_features.shape) == 3  # [batch_size, token_num, dim]
            assert len(text_features.shape) == 3  # [batch_size, token_num, dim]

            img_mu, img_logsigma, _ = self.img_gau_encoder(image_features, mask=None)
            txt_mu, txt_logsigma, _ = self.txt_gau_encoder(text_features, mask=None)

            # Image: CLS token. Text: EOT position, projected like CLIP's text
            # embedding (the projection is applied to sigma as well).
            img_mu = img_mu[:, 0, :]
            img_sigma = torch.exp(img_logsigma)[:, 0, :]
            txt_mu = self._pool_text(txt_mu, tokenized_text)
            txt_sigma = self._pool_text(torch.exp(txt_logsigma), tokenized_text)

            W2_distance, _ = Wasserstein2(img_mu, img_sigma, txt_mu, txt_sigma)

            logits_per_image = (-self.negative_scale * W2_distance + self.shift) * logit_scale
            return logits_per_image, logits_per_image.t()

        def inference_text(self, text_input, return_pde=False):
            """Encode BERT-tokenised text; the embedding is the PDE mean.

            text_input: output of the BERT tokenizer (uses ``.input_ids``);
            ids are decoded, stripped of [PAD]/[CLS]/[SEP] and re-tokenised
            with CLIP's tokenizer.
            return_pde: if true, return {"mu", "sigma"} (unnormalised).
            """
            text = [
                self.tokenizer.decode(input_ids)
                .replace("[PAD]", "")
                .replace("[CLS]", "")
                .replace("[SEP]", "")
                .strip()
                for input_ids in text_input.input_ids
            ]
            clip_ids = clip.tokenize(text, 77, True).to(self.logit_scale.device)
            token_feats = self.encode_text(self.text_VPT_input, clip_ids)

            txt_mu, txt_logsigma, _ = self.txt_gau_encoder(token_feats, mask=None)
            text_embed = self._pool_text(txt_mu, clip_ids)
            if return_pde:
                txt_sigma = self._pool_text(torch.exp(txt_logsigma), clip_ids)
                return {"mu": text_embed, "sigma": txt_sigma}

            text_features = text_embed / text_embed.norm(dim=-1, keepdim=True)
            assert len(text_embed.shape) == 2, (text_embed.shape)  # [batch_size, dim]
            assert len(text_features.shape) == 2, (text_features.shape)  # [batch_size, dim]
            return {"text_embed": text_embed, "text_feat": text_features}

        def inference_image(self, image, return_pde=False):
            """Encode images; the embedding is the PDE mean at the CLS token.

            return_pde: if true, return {"mu", "sigma"} (unnormalised).
            """
            token_feats = self.encode_image(image.type(self.dtype))

            img_mu, img_logsigma, _ = self.img_gau_encoder(token_feats, mask=None)
            image_embed = img_mu[:, 0, :]
            if return_pde:
                img_sigma = torch.exp(img_logsigma)[:, 0, :]
                return {"mu": image_embed, "sigma": img_sigma}

            image_features = image_embed / image_embed.norm(dim=-1, keepdim=True)
            assert len(image_features.shape) == 2
            assert len(image_embed.shape) == 2
            return {"image_embed": image_embed, "image_feat": image_features}

        def inference(self, image, text, return_pde=False):
            """Return L2-normalised PDE means for images and raw ``text``.

            With ``return_pde`` the *unnormalised* means and sigmas are
            returned as {"img_mu", "txt_mu", "img_sigma", "txt_sigma"}.
            """
            text_input = clip.tokenize(text, 77, True).to(self.logit_scale.device)
            text_tokens = self.encode_text(self.text_VPT_input, text_input)
            image_tokens = self.encode_image(image.type(self.dtype))

            img_mu, img_logsigma, _ = self.img_gau_encoder(image_tokens, mask=None)
            txt_mu, txt_logsigma, _ = self.txt_gau_encoder(text_tokens, mask=None)
            img_mu = img_mu[:, 0, :]
            txt_mu = self._pool_text(txt_mu, text_input)

            if return_pde:
                img_sigma = torch.exp(img_logsigma)[:, 0, :]
                txt_sigma = self._pool_text(torch.exp(txt_logsigma), text_input)
                return {"img_mu": img_mu, "txt_mu": txt_mu, "img_sigma": img_sigma, "txt_sigma": txt_sigma}

            image_features = img_mu / img_mu.norm(dim=-1, keepdim=True)
            text_features = txt_mu / txt_mu.norm(dim=-1, keepdim=True)
            return {"image_feat": image_features, "text_feat": text_features}

        def set_tokenizer(self, tokenizer):
            self.tokenizer = tokenizer

    model = CustomCLIP(train_config, clip_model).to(device)

    # Train only the prompts and the PDE heads, following:
    # https://github.com/muzairkhattak/multimodal-prompt-learning/blob/69bce21ae8eda80ad6187534b2dce09cf6c59e17/trainers/independentVL.py#L183
    for name, param in model.named_parameters():
        param.requires_grad_("VPT" in name or "gau_encoder" in name)

    return model



def load_clip_with_pde(backbone_name, config, device):
    """
    Load a CLIP model with PDE (Probability Distribution Embedding) heads.

    Each modality gets a ``DisTrans`` head that maps CLIP's per-token features
    to a mean and a log-sigma. Inference uses the pooled mean; ``forward``
    scores every image-text pair with a Wasserstein-2 distance between the
    pooled distributions.

    Args:
        backbone_name: CLIP backbone name; only "ViT-B/16" is supported.
        config: config mapping. ``config["pde"]["negative_scale"]`` and
            ``config["pde"]["shift"]`` are required. A truthy optional
            ``config["is_visualize"]`` additionally creates ``proj_img`` and
            ``proj_txt`` (512 -> 2 linear layers).
        device: device passed to ``clip.load`` and that the wrapper is moved to.

    Returns:
        The ``CustomCLIP`` wrapper, cast to float32.
    """
    assert backbone_name in ["ViT-B/16"], f"Unsupported backbone: {backbone_name}"

    # The preprocessing transform returned alongside the model is not used here.
    clip_model, _ = clip.load(backbone_name, device=device)

    # PDE heads follow MAP:
    # https://github.com/IIGROUP/MAP/blob/main/map/modules/map_module.py
    # DisTrans returns a 3-tuple. The first two outputs are used below as the
    # mean and the log of sigma (exp() of it is treated as sigma).
    # NOTE(review): confirm against models.PDE.DisTrans whether the second
    # output is log-sigma or log-variance.
    from models.PDE import DisTrans

    def Wasserstein2(mu1, sigma1, mu2, sigma2):
        """
        Pairwise Wasserstein-2 style distance between two batches of distributions.

        mu1/sigma1 have leading dim bs1, mu2/sigma2 have leading dim bs2; the
        last dim is the feature dim that gets summed out.

        Returns:
            (total, mean_term), each with leading shape [bs1, bs2], where
            total[i, j] = ||mu1[i] - mu2[j]||^2 + ||sigma1[i] - sigma2[j]||^2
            (the squared W2 closed form for diagonal Gaussians when sigma is
            the standard deviation) and mean_term is the first summand alone.
        """
        # Broadcasting ([bs1, 1, ...] vs [1, bs2, ...]) yields the same pairwise
        # differences as stacking bs2 / bs1 copies of the inputs, without the copies.
        mean_term = (mu1.unsqueeze(1) - mu2.unsqueeze(0)).pow(2).sum(dim=-1)
        sigma_term = (sigma1.unsqueeze(1) - sigma2.unsqueeze(0)).pow(2).sum(dim=-1)
        return mean_term + sigma_term, mean_term

    def init_weights(module):
        """Init N(0, 0.02) weights for Linear/Embedding, zero biases, unit LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm without affine parameters has weight/bias set to None.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    class CustomCLIP(nn.Module):
        """CLIP image/text encoders wrapped with one PDE head per modality."""

        def __init__(self, cfg, clip_model):
            super().__init__()
            # Run everything in float32 rather than CLIP's default dtype.
            self.dtype = torch.float32

            # Pretrained CLIP components, shared (not copied) from clip_model.
            self.visual = clip_model.visual
            self.transformer = clip_model.transformer
            self.ln_final = clip_model.ln_final
            self.text_projection = clip_model.text_projection
            self.token_embedding = clip_model.token_embedding
            self.positional_embedding = clip_model.positional_embedding
            self.logit_scale = clip_model.logit_scale

            if cfg.get("is_visualize", False):
                # Extra 512 -> 2 projections, only created when is_visualize is set.
                self.proj_img = nn.Linear(512, 2)
                self.proj_txt = nn.Linear(512, 2)

            # PDE heads (built exactly once, whatever is_visualize says).
            self.img_gau_encoder = DisTrans(512, 8)  # original: 768, 12
            self.txt_gau_encoder = DisTrans(512, 8)  # original: 768, 12
            self.img_gau_encoder.apply(init_weights)
            self.txt_gau_encoder.apply(init_weights)

            pde_config = cfg["pde"]
            # forward: logit = (-negative_scale * W2 + shift) * exp(logit_scale)
            self.negative_scale = pde_config["negative_scale"]
            self.shift = pde_config["shift"]

        def _pool_image(self, x):
            """Select the class-token position (index 0) of [batch, tokens, dim] features."""
            return x[:, 0, :]

        def _pool_text(self, x, tokenized_text):
            """
            Select each sequence's EOT position and apply CLIP's text projection.

            The EOT token has the highest id in each CLIP-tokenized sequence,
            so argmax over the token ids locates it.
            """
            rows = torch.arange(x.shape[0])
            return x[rows, tokenized_text.argmax(dim=-1)] @ self.text_projection

        def encode_image(self, image):
            """
            Run the CLIP vision transformer and return features for every token.

            image: a tensor of shape [batch_size, 3, imsize, imsize]
            Returns: [batch_size, token_num, dim]; token 0 is the class token.
            ``visual.proj`` (when present) is applied to every token.
            """
            x = image.type(self.dtype)

            vit = self.visual
            x = vit.conv1(x)
            x = x.reshape(x.shape[0], x.shape[1], -1)  # flatten spatial dims
            x = x.permute(0, 2, 1)                     # -> [batch, positions, width]
            class_token = vit.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
            )
            x = torch.cat([class_token, x], dim=1)     # prepend class token
            x = x + vit.positional_embedding.to(x.dtype)
            x = vit.ln_pre(x)

            x = x.permute(1, 0, 2)  # NLD -> LND
            x = vit.transformer(x)
            x = x.permute(1, 0, 2)  # LND -> NLD

            x = vit.ln_post(x)

            if vit.proj is not None:
                x = x @ vit.proj

            assert x.dim() == 3  # [batch_size, token_num, dim]
            return x

        def encode_text(self, tokenized_text):
            """
            Run the CLIP text transformer and return features for every position.

            tokenized_text: CLIP token ids, [batch_size, n_ctx].
            Returns: [batch_size, n_ctx, transformer.width] after ln_final.
            EOT pooling and text_projection are applied later (_pool_text).
            """
            x = self.token_embedding(tokenized_text).type(self.dtype)
            x = x + self.positional_embedding.type(self.dtype)
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.transformer(x)
            x = x.permute(1, 0, 2)  # LND -> NLD
            x = self.ln_final(x).type(self.dtype)

            assert x.dim() == 3
            return x

        def forward(self, image, tokenized_text):
            """
            Score every image against every text with the PDE distance.

            image: a tensor of shape [batch_size, 3, imsize, imsize]
            tokenized_text: CLIP token ids.

            Returns:
                (logits_per_image, logits_per_text) where
                logits_per_image[i, j] = (-negative_scale * W2(i, j) + shift) * exp(logit_scale)
                and logits_per_text is its transpose.
            """
            logit_scale = self.logit_scale.exp()

            text_tokens = self.encode_text(tokenized_text)
            image_tokens = self.encode_image(image.type(self.dtype))

            img_mu, img_logsigma, _ = self.img_gau_encoder(image_tokens, mask=None)
            txt_mu, txt_logsigma, _ = self.txt_gau_encoder(text_tokens, mask=None)

            img_mu = self._pool_image(img_mu)
            img_sigma = self._pool_image(torch.exp(img_logsigma))
            txt_mu = self._pool_text(txt_mu, tokenized_text)
            # NOTE(review): sigma goes through the linear text_projection, so the
            # text-side "sigma" is not guaranteed to stay positive.
            txt_sigma = self._pool_text(torch.exp(txt_logsigma), tokenized_text)

            # Features are deliberately not normalized before the distance.
            w2_distance, _ = Wasserstein2(img_mu, img_sigma, txt_mu, txt_sigma)

            logits_per_image = (-self.negative_scale * w2_distance + self.shift) * logit_scale
            logits_per_text = logits_per_image.t()
            return logits_per_image, logits_per_text

        def inference_text(self, text_input, return_pde=False):
            """
            Encode text given as (BERT-style) tokenizer output.

            text_input: output of tokenizer(text), with an ``input_ids`` field.
            The ids are decoded with ``self.tokenizer`` (see set_tokenizer),
            stripped of [PAD]/[CLS]/[SEP], and re-tokenized with ``clip.tokenize``.

            Returns:
                return_pde=True: {"mu": unnormalized mean, "sigma": pooled exp(log-sigma)}
                otherwise: {"text_embed": unnormalized mean, "text_feat": L2-normalized mean}
            """
            texts = []
            for input_ids in text_input.input_ids:
                decoded = self.tokenizer.decode(input_ids)
                for special_token in ("[PAD]", "[CLS]", "[SEP]"):
                    decoded = decoded.replace(special_token, "")
                texts.append(decoded.strip())

            clip_tokens = clip.tokenize(texts, 77, True).to(self.logit_scale.device)
            text_tokens = self.encode_text(clip_tokens)

            txt_mu, txt_logsigma, _ = self.txt_gau_encoder(text_tokens, mask=None)
            text_embed = self._pool_text(txt_mu, clip_tokens)  # use the mean
            if return_pde:
                txt_sigma = self._pool_text(torch.exp(txt_logsigma), clip_tokens)
                return {"mu": text_embed, "sigma": txt_sigma}

            text_features = text_embed / text_embed.norm(dim=-1, keepdim=True)
            return {"text_embed": text_embed, "text_feat": text_features}

        def inference_image(self, image, return_pde=False):
            """
            Encode a batch of images.

            image: a tensor of shape [batch_size, 3, imsize, imsize]

            Returns:
                return_pde=True: {"mu": unnormalized mean, "sigma": pooled exp(log-sigma)}
                otherwise: {"image_embed": unnormalized mean, "image_feat": L2-normalized mean}
            """
            image_tokens = self.encode_image(image.type(self.dtype))

            img_mu, img_logsigma, _ = self.img_gau_encoder(image_tokens, mask=None)
            image_embed = self._pool_image(img_mu)  # use the mean
            if return_pde:
                img_sigma = self._pool_image(torch.exp(img_logsigma))
                return {"mu": image_embed, "sigma": img_sigma}

            image_features = image_embed / image_embed.norm(dim=-1, keepdim=True)
            return {"image_embed": image_embed, "image_feat": image_features}

        def inference(self, image, text, return_pde=False):
            """
            Encode images and raw strings together.

            image: a tensor of shape [batch_size, 3, imsize, imsize]
            text: raw text, tokenized here with ``clip.tokenize``.

            Returns:
                return_pde=True: {"img_mu", "txt_mu", "img_sigma", "txt_sigma"}; the
                    means are NOT normalized, sigma is pooled exp(log-sigma).
                otherwise: {"image_feat", "text_feat"}, the L2-normalized means.
            """
            text_input = clip.tokenize(text, 77, True).to(self.logit_scale.device)
            text_tokens = self.encode_text(text_input)
            image_tokens = self.encode_image(image.type(self.dtype))

            img_mu, img_logsigma, _ = self.img_gau_encoder(image_tokens, mask=None)
            txt_mu, txt_logsigma, _ = self.txt_gau_encoder(text_tokens, mask=None)
            img_mu = self._pool_image(img_mu)
            txt_mu = self._pool_text(txt_mu, text_input)

            if return_pde:
                img_sigma = self._pool_image(torch.exp(img_logsigma))
                txt_sigma = self._pool_text(torch.exp(txt_logsigma), text_input)
                return {"img_mu": img_mu, "txt_mu": txt_mu, "img_sigma": img_sigma, "txt_sigma": txt_sigma}

            image_features = img_mu / img_mu.norm(dim=-1, keepdim=True)
            text_features = txt_mu / txt_mu.norm(dim=-1, keepdim=True)
            return {"image_feat": image_features, "text_feat": text_features}

        def _set_backbone_requires_grad(self, flag):
            """Set requires_grad on all pretrained CLIP parts (PDE heads and logit_scale untouched)."""
            for module in (self.visual, self.transformer, self.ln_final, self.token_embedding):
                for param in module.parameters():
                    param.requires_grad_(flag)
            self.text_projection.requires_grad_(flag)
            self.positional_embedding.requires_grad_(flag)

        def freeze_backbone(self):
            """Disable gradients for the pretrained CLIP backbone."""
            self._set_backbone_requires_grad(False)

        def unfreeze_backbone(self):
            """Re-enable gradients for the pretrained CLIP backbone."""
            self._set_backbone_requires_grad(True)

        def set_tokenizer(self, tokenizer):
            """Store the tokenizer used by inference_text to decode input ids."""
            self.tokenizer = tokenizer

    model = CustomCLIP(config, clip_model).to(device)
    # clip.load may return half-precision weights; run everything in float32.
    model.float()

    return model


def load_clip_with_pde_csd(backbone_name, config, device):
    """
    Load a CLIP model with PDE heads trained by the PCME++ CSD loss.

    Each modality gets a ``DisTrans`` head producing a mean and a log-sigma.
    ``forward`` returns the ``ClosedFormSampledDistanceLoss`` value; the
    inference methods return the pooled mean (and optionally the pooled
    log-sigma, NOT exponentiated).

    Args:
        backbone_name: CLIP backbone name; only "ViT-B/16" is supported.
        config: config mapping. ``config["pcmepp"]["criterion"]`` must provide
            ``init_shift``, ``init_negative_scale`` and ``prob_distance``. A
            truthy optional ``config["is_visualize"]`` additionally creates
            ``proj_img`` and ``proj_txt`` (512 -> 2 linear layers).
        device: device passed to ``clip.load`` and that the wrapper is moved to.

    Returns:
        The ``CustomCLIP`` wrapper, cast to float32.
    """
    assert backbone_name in ["ViT-B/16"], f"Unsupported backbone: {backbone_name}"

    # The preprocessing transform returned alongside the model is not used here.
    clip_model, _ = clip.load(backbone_name, device=device)

    # PDE heads follow MAP:
    # https://github.com/IIGROUP/MAP/blob/main/map/modules/map_module.py
    # DisTrans returns a 3-tuple; the first two outputs are used as the mean and
    # the log of sigma. NOTE(review): confirm log-sigma vs log-variance.
    from models.PDE import DisTrans

    def init_weights(module):
        """Init N(0, 0.02) weights for Linear/Embedding, zero biases, unit LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm without affine parameters has weight/bias set to None.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    class CustomCLIP(nn.Module):
        """CLIP encoders with per-modality PDE heads and a PCME++ CSD criterion."""

        def __init__(self, cfg, clip_model):
            super().__init__()
            # Run everything in float32 rather than CLIP's default dtype.
            self.dtype = torch.float32

            # Pretrained CLIP components, shared (not copied) from clip_model.
            self.visual = clip_model.visual
            self.transformer = clip_model.transformer
            self.ln_final = clip_model.ln_final
            self.text_projection = clip_model.text_projection
            self.token_embedding = clip_model.token_embedding
            self.positional_embedding = clip_model.positional_embedding
            self.logit_scale = clip_model.logit_scale

            if cfg.get("is_visualize", False):
                # Extra 512 -> 2 projections, only created when is_visualize is set.
                self.proj_img = nn.Linear(512, 2)
                self.proj_txt = nn.Linear(512, 2)

            # PDE heads (built exactly once, whatever is_visualize says).
            self.img_gau_encoder = DisTrans(512, 8)  # original: 768, 12
            self.txt_gau_encoder = DisTrans(512, 8)  # original: 768, 12
            self.img_gau_encoder.apply(init_weights)
            self.txt_gau_encoder.apply(init_weights)

            # CSD loss, configured from this constructor's own cfg argument.
            criterion_cfg = cfg["pcmepp"]["criterion"]
            self.criterion = ClosedFormSampledDistanceLoss(
                init_shift=criterion_cfg["init_shift"],
                init_negative_scale=criterion_cfg["init_negative_scale"],
                prob_distance=criterion_cfg["prob_distance"],
            )
            # NOTE(review): not used inside this class, but registered as
            # submodules (part of state_dict), so kept for checkpoint compatibility.
            self.image_std_ln = nn.LayerNorm(512)
            self.text_std_ln = nn.LayerNorm(512)

        def _pool_image(self, x):
            """Select the class-token position (index 0) of [batch, tokens, dim] features."""
            return x[:, 0, :]

        def _pool_text(self, x, tokenized_text):
            """
            Select each sequence's EOT position and apply CLIP's text projection.

            The EOT token has the highest id in each CLIP-tokenized sequence,
            so argmax over the token ids locates it.
            """
            rows = torch.arange(x.shape[0])
            return x[rows, tokenized_text.argmax(dim=-1)] @ self.text_projection

        def encode_image(self, image):
            """
            Run the CLIP vision transformer and return features for every token.

            image: a tensor of shape [batch_size, 3, imsize, imsize]
            Returns: [batch_size, token_num, dim]; token 0 is the class token.
            ``visual.proj`` (when present) is applied to every token.
            """
            x = image.type(self.dtype)

            vit = self.visual
            x = vit.conv1(x)
            x = x.reshape(x.shape[0], x.shape[1], -1)  # flatten spatial dims
            x = x.permute(0, 2, 1)                     # -> [batch, positions, width]
            class_token = vit.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
            )
            x = torch.cat([class_token, x], dim=1)     # prepend class token
            x = x + vit.positional_embedding.to(x.dtype)
            x = vit.ln_pre(x)

            x = x.permute(1, 0, 2)  # NLD -> LND
            x = vit.transformer(x)
            x = x.permute(1, 0, 2)  # LND -> NLD

            x = vit.ln_post(x)

            if vit.proj is not None:
                x = x @ vit.proj

            assert x.dim() == 3  # [batch_size, token_num, dim]
            return x

        def encode_text(self, tokenized_text):
            """
            Run the CLIP text transformer and return features for every position.

            tokenized_text: CLIP token ids, [batch_size, n_ctx].
            Returns: [batch_size, n_ctx, transformer.width] after ln_final.
            EOT pooling and text_projection are applied later (_pool_text).
            """
            x = self.token_embedding(tokenized_text).type(self.dtype)
            x = x + self.positional_embedding.type(self.dtype)
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.transformer(x)
            x = x.permute(1, 0, 2)  # LND -> NLD
            x = self.ln_final(x).type(self.dtype)

            assert x.dim() == 3
            return x

        def forward(self, image, tokenized_text):
            """
            Compute the PCME++ CSD loss for a batch of images and texts.

            image: a tensor of shape [batch_size, 3, imsize, imsize]
            tokenized_text: CLIP token ids.

            Means are L2-normalized before the loss. The "std" entries given to
            the criterion are the pooled DisTrans log-sigma outputs, NOT
            exponentiated (text side also passes through text_projection).

            Returns: the loss returned by ``self.criterion``.
            """
            text_tokens = self.encode_text(tokenized_text)
            image_tokens = self.encode_image(image.type(self.dtype))

            img_mu, img_logsigma, _ = self.img_gau_encoder(image_tokens, mask=None)
            txt_mu, txt_logsigma, _ = self.txt_gau_encoder(text_tokens, mask=None)

            img_mu = self._pool_image(img_mu)
            img_sigma = self._pool_image(img_logsigma)
            txt_mu = self._pool_text(txt_mu, tokenized_text)
            txt_sigma = self._pool_text(txt_logsigma, tokenized_text)

            img_mu = img_mu / img_mu.norm(dim=-1, keepdim=True)
            txt_mu = txt_mu / txt_mu.norm(dim=-1, keepdim=True)

            img_emb = {"mean": img_mu, "std": img_sigma}
            txt_emb = {"mean": txt_mu, "std": txt_sigma}
            loss, _ = self.criterion(img_emb, txt_emb)
            return loss

        def inference_text(self, text_input, return_pde=False):
            """
            Encode text given as (BERT-style) tokenizer output.

            text_input: output of tokenizer(text), with an ``input_ids`` field.
            The ids are decoded with ``self.tokenizer`` (see set_tokenizer),
            stripped of [PAD]/[CLS]/[SEP], and re-tokenized with ``clip.tokenize``.

            Returns:
                return_pde=True: {"mu": unnormalized mean, "sigma": pooled log-sigma}
                otherwise: {"text_embed": unnormalized mean, "text_feat": L2-normalized mean}
            """
            texts = []
            for input_ids in text_input.input_ids:
                decoded = self.tokenizer.decode(input_ids)
                for special_token in ("[PAD]", "[CLS]", "[SEP]"):
                    decoded = decoded.replace(special_token, "")
                texts.append(decoded.strip())

            clip_tokens = clip.tokenize(texts, 77, True).to(self.logit_scale.device)
            text_tokens = self.encode_text(clip_tokens)

            txt_mu, txt_logsigma, _ = self.txt_gau_encoder(text_tokens, mask=None)
            text_embed = self._pool_text(txt_mu, clip_tokens)  # use the mean
            if return_pde:
                txt_sigma = self._pool_text(txt_logsigma, clip_tokens)
                return {"mu": text_embed, "sigma": txt_sigma}

            text_features = text_embed / text_embed.norm(dim=-1, keepdim=True)
            return {"text_embed": text_embed, "text_feat": text_features}

        def inference_image(self, image, return_pde=False):
            """
            Encode a batch of images.

            image: a tensor of shape [batch_size, 3, imsize, imsize]

            Returns:
                return_pde=True: {"mu": unnormalized mean, "sigma": pooled log-sigma}
                otherwise: {"image_embed": unnormalized mean, "image_feat": L2-normalized mean}
            """
            image_tokens = self.encode_image(image.type(self.dtype))

            img_mu, img_logsigma, _ = self.img_gau_encoder(image_tokens, mask=None)
            image_embed = self._pool_image(img_mu)  # use the mean
            if return_pde:
                img_sigma = self._pool_image(img_logsigma)
                return {"mu": image_embed, "sigma": img_sigma}

            image_features = image_embed / image_embed.norm(dim=-1, keepdim=True)
            return {"image_embed": image_embed, "image_feat": image_features}

        def inference(self, image, text, return_pde=False):
            """
            Encode images and raw strings together.

            image: a tensor of shape [batch_size, 3, imsize, imsize]
            text: raw text, tokenized here with ``clip.tokenize``.

            Returns:
                return_pde=True: {"img_mu", "txt_mu", "img_sigma", "txt_sigma"}; the
                    means are NOT normalized, the sigmas are pooled log-sigma.
                otherwise: {"image_feat", "text_feat"}, the L2-normalized means.
            """
            text_input = clip.tokenize(text, 77, True).to(self.logit_scale.device)
            text_tokens = self.encode_text(text_input)
            image_tokens = self.encode_image(image.type(self.dtype))

            img_mu, img_logsigma, _ = self.img_gau_encoder(image_tokens, mask=None)
            txt_mu, txt_logsigma, _ = self.txt_gau_encoder(text_tokens, mask=None)
            img_mu = self._pool_image(img_mu)
            txt_mu = self._pool_text(txt_mu, text_input)

            if return_pde:
                img_sigma = self._pool_image(img_logsigma)
                txt_sigma = self._pool_text(txt_logsigma, text_input)
                return {"img_mu": img_mu, "txt_mu": txt_mu, "img_sigma": img_sigma, "txt_sigma": txt_sigma}

            image_features = img_mu / img_mu.norm(dim=-1, keepdim=True)
            text_features = txt_mu / txt_mu.norm(dim=-1, keepdim=True)
            return {"image_feat": image_features, "text_feat": text_features}

        def _set_backbone_requires_grad(self, flag):
            """Set requires_grad on all pretrained CLIP parts (PDE heads, criterion, logit_scale untouched)."""
            for module in (self.visual, self.transformer, self.ln_final, self.token_embedding):
                for param in module.parameters():
                    param.requires_grad_(flag)
            self.text_projection.requires_grad_(flag)
            self.positional_embedding.requires_grad_(flag)

        def freeze_backbone(self):
            """Disable gradients for the pretrained CLIP backbone."""
            self._set_backbone_requires_grad(False)

        def unfreeze_backbone(self):
            """Re-enable gradients for the pretrained CLIP backbone."""
            self._set_backbone_requires_grad(True)

        def set_tokenizer(self, tokenizer):
            """Store the tokenizer used by inference_text to decode input ids."""
            self.tokenizer = tokenizer

    model = CustomCLIP(config, clip_model).to(device)
    # clip.load may return half-precision weights; run everything in float32.
    model.float()

    return model




if __name__ == "__main__":
    # Smoke test: build a model, run one forward/backward pass on random
    # images, and take a single signed-gradient step on those images.
    import yaml
    from easydict import EasyDict as edict

    config_path = "/data/robust_crossmodal-retrieval/configs/Retrieval_flickr_train_clip_b64.yaml"
    with open(config_path, "r") as f:
        # yaml.Loader can construct arbitrary Python objects: trusted configs only.
        config = edict(yaml.load(f, Loader=yaml.Loader))

    train_config_path = "/data/robust_crossmodal-retrieval/configs/train/full/clip_pcmepp.yaml"
    with open(train_config_path) as f:
        train_config = edict(yaml.load(f, Loader=yaml.FullLoader))

    # model_name = "CLIP_ViT-B-16_PT"
    model_name = "CLIP_ViT-B-16_PCMEPP"
    model_ckpt = ""
    text_encoder = "bert-base-uncased"
    device = "cuda"

    model, ref_model, tokenizer = load_model(config, model_name, model_ckpt, text_encoder, device, train_config)
    print(model)
    model.eval()

    from utils_attack import set_mode_for_attack
    set_mode_for_attack(model)

    imgs = torch.randn(2, 3, 224, 224).to(device)
    adv_imgs = imgs.clone().detach().requires_grad_(True)

    text = ["a photo of a cat", "a photo of a dog"]
    text_input = clip.tokenize(text, 77, True).to(device)
    print(text_input)

    # Assumes model(images, tokens) returns a loss tensor for this model_name
    # (backward() below needs a scalar) -- confirm for new model variants.
    loss = model(adv_imgs, text_input)
    loss.backward()

    grad = adv_imgs.grad
    # Per-sample mean-|grad| normalization; clamp avoids 0/0 -> NaN when a
    # sample's gradient is all zeros.
    grad = grad / torch.mean(torch.abs(grad), dim=(1, 2, 3), keepdim=True).clamp_min(1e-12)

    step_size = 0.5
    eps = 1
    perturbation = step_size * grad.sign()
    adv_imgs = adv_imgs.detach() + perturbation
    adv_imgs = torch.min(torch.max(adv_imgs, imgs - eps), imgs + eps)
    adv_imgs = torch.clamp(adv_imgs, 0.0, 1.0)