"""
Imported and modified from:
 https://github.com/Zoky-2020/SGA/blob/main/eval_albef2tcl_flickr.py
 https://github.com/adversarial-for-goodness/Co-Attack/blob/main/RetrievalEval.py
 https://github.com/adversarial-for-goodness/Co-Attack/blob/main/RetrievalFusionEval.py
"""

import argparse
import os
import glob
import sys

import ruamel.yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
from PIL import Image
from easydict import EasyDict as edict

import umap
from sklearn.manifold import TSNE

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.patches as pat
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

import torch
from  torch.cuda.amp import autocast

import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from transformers import BertForMaskedLM
from torchvision import transforms
from PIL import Image

from models.get_model import load_model

import utils.utils as utils

# SGA
sys.path.append("SGA")
from attacker import (
    SGAttacker as SGAttacker,
    ImageAttacker as SGA_ImageAttacker,
    TextAttacker as SGA_TextAttacker,
)

# ACMMM2022: https://github.com/adversarial-for-goodness/Co-Attack/tree/main
# attack image_feat and text_feat, which are normalized features.
from co_attack_modified import ImageAttacker as PGD_ImageAttacker
from co_attack_modified import BertAttack, BertAttackFusion
from co_attack_modified import MultiModalAttacker as CoAttacker
# attack image_embed and text_embed
# sys.path.append("Co-Attack")
# from attack.imageAttack import ImageAttacker as PGD_ImageAttacker
# from attack.bert_attack import BertAttack, BertAttackFusion
# from attack.multimodalAttack import MultiModalAttacker as CoAttacker

from dataset_simple import paired_dataset, get_test_transform, get_test_data


# Extra probe captions embedded alongside the dataset captions, keyed by a
# consecutive integer index starting at 0 (the key is used as a row index
# when the captions are embedded).
img2test_txt = dict(
    enumerate(
        (
            "a photo of a man",
            "a photo of a dog",
            "a photo of martial arts",
            "a photo taken in winter",
            "a photo of a rooftop",
        )
    )
)

def get_feats(
    args,
    model,
    ref_model,
    data_loader,
    tokenizer,
    device,
    config,
):
    """Attack every batch of ``data_loader`` and embed the attacked samples.

    The attack is selected by ``args.attack`` ("SGA", "BERT", "PGD",
    "Sep-Attack", "Co-Attack" or "Clean"). Iteration stops once more than
    ``args.max_n`` images have been processed.

    Args:
        args: Parsed CLI namespace; reads ``attack``, ``epsilon``, ``cls``,
            ``attack_fused_emb``, ``alpha``, ``scales``, ``max_n`` and
            ``source_model``.
        model: Source model; must expose ``inference`` and, for non-fused
            models, ``visual.output_dim``.
        ref_model: Reference language model handed to the text attackers.
        data_loader: Yields ``(images, texts_group, images_ids,
            text_ids_groups)``; its dataset exposes ``text`` and ``ann``.
        tokenizer: Tokenizer handed to the text attackers.
        device: Device the images are moved to before attacking.
        config: Mapping providing ``embed_dim`` (fused models) and
            ``num_iters`` (PGD-style attacks).

    Returns:
        Tuple ``(adv_images_list, adv_texts_list, s_image_feats,
        s_text_feats, s_image_sigma, s_text_sigma)``. ``adv_images_list`` is
        a channel-last numpy array of all attacked images, or ``None`` when
        no batch was processed. The sigma tensors stay zero unless
        ``"PDE"`` is part of ``args.source_model``.

    Raises:
        ValueError: If ``args.attack`` is not one of the supported names.
    """
    max_n = args.max_n

    # test
    model.float()
    model.eval()
    ref_model.eval()

    print("Computing features for evaluation adv...")

    images_normalize = transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
    )

    # Integer mode expected by the Co-Attack style attacker for each setting.
    adv_modes = {"Clean": 0, "BERT": 1, "PGD": 2, "Sep-Attack": 3, "Co-Attack": 4}

    if args.attack == "SGA":
        img_attacker = SGA_ImageAttacker(
            images_normalize, eps=args.epsilon / 255, steps=10, step_size=0.5 / 255
        )
        txt_attacker = SGA_TextAttacker(
            ref_model,
            tokenizer,
            cls=False,
            max_length=30,
            number_perturbation=1,
            topk=10,
            threshold_pred_score=0.3,
        )
        attacker = SGAttacker(model, img_attacker, txt_attacker)
    elif args.attack in adv_modes:
        adv_mode = adv_modes[args.attack]
        image_attacker = PGD_ImageAttacker(
            args.epsilon / 255.0,
            preprocess=images_normalize,
            bounding=(0, 1),
            cls=args.cls,
        )
        if args.attack_fused_emb:
            text_attacker = BertAttackFusion(ref_model, tokenizer, cls=args.cls)
        else:
            text_attacker = BertAttack(ref_model, tokenizer, cls=args.cls)
        multi_attacker = CoAttacker(
            model, image_attacker, text_attacker, tokenizer, cls=args.cls
        )
    else:
        raise ValueError(f"Invalid attack mode: {args.attack}")

    # Everything below is shared by all attacks. It used to be nested under
    # the non-SGA branch, which left the SGA path without any buffers.
    print("Prepare memory")
    num_text = len(data_loader.dataset.text)
    num_image = len(data_loader.dataset.ann)

    if args.source_model in FUSED_MODELS:
        feat_dim = config["embed_dim"]
    else:
        feat_dim = model.visual.output_dim

    s_image_feats = torch.zeros(num_image, feat_dim)
    s_text_feats = torch.zeros(num_text, feat_dim)
    s_image_sigma = torch.zeros(num_image, feat_dim)
    s_text_sigma = torch.zeros(num_text, feat_dim)

    # for visualization
    adv_image_batches = []
    adv_texts_list = []

    scales = None
    if args.attack == "SGA" and args.scales is not None:
        scales = [float(itm) for itm in args.scales.split(",")]
        print(scales)

    use_pde = "PDE" in args.source_model

    print("Forward")
    n = 0
    n_texts = 0
    for batch_idx, (images, texts_group, images_ids, text_ids_groups) in enumerate(
        data_loader
    ):
        if n > max_n:
            print("Stop iterations at n:", n)
            break
        print(f"--------------------> batch:{batch_idx}/{len(data_loader)}")

        # Flatten the per-image caption groups; txt2img[k] is the in-batch
        # image index of the k-th flattened caption.
        texts_ids = []
        txt2img = []
        texts = []
        for i in range(len(texts_group)):
            texts += texts_group[i]
            texts_ids += text_ids_groups[i]
            txt2img += [i] * len(text_ids_groups[i])

        images = images.to(device)

        if args.attack == "SGA":
            # NOTE(review): "max_lemgth" looks like a typo for "max_length";
            # kept as-is because the signature of attacker.attack is not
            # visible here -- confirm against the SGA attacker before renaming.
            adv_images, adv_texts = attacker.attack(
                images, texts, txt2img, device=device, max_lemgth=30, scales=scales
            )
        elif args.attack in ["BERT", "PGD", "Sep-Attack", "Co-Attack"]:
            # The Co-Attack style attackers take paired (image, text) input,
            # so each image is repeated once per caption. This count is
            # needed by both branches below and must be computed first.
            img2txt_n = [txt2img.count(i) for i in range(len(images))]
            repeated_images = torch.stack(
                [im for im, cnt in zip(images, img2txt_n) for _ in range(cnt)],
                dim=0,
            )
            if args.attack_fused_emb:
                adv_images, adv_texts = multi_attacker.run(
                    repeated_images,
                    texts,
                    adv=adv_mode,
                    num_iters=config["num_iters"],
                    alpha=args.alpha,
                )
            else:
                max_length = 1e3 if args.source_model in FUSED_MODELS else 77
                adv_images, adv_texts = multi_attacker.run_before_fusion(
                    repeated_images,
                    texts,
                    adv=adv_mode,
                    num_iters=config["num_iters"],
                    alpha=args.alpha,
                    max_length=max_length,
                )
            # Keep one adversarial image per original image: the copy that
            # was paired with that image's last caption.
            indice = np.cumsum(img2txt_n) - 1
            adv_images = adv_images[indice]
        elif args.attack == "Clean":
            adv_images = images
            adv_texts = texts
        else:
            raise ValueError(f"Invalid attack mode: {args.attack}")

        # NCHW -> NHWC for plotting; batches are concatenated once after the
        # loop instead of re-copying the growing array on every iteration.
        adv_image_batches.append(
            adv_images.detach().cpu().numpy().transpose(0, 2, 3, 1)
        )
        adv_texts_list += adv_texts

        with torch.no_grad():
            adv_images_norm = images_normalize(adv_images)

            if use_pde:
                output = model.inference(adv_images_norm, adv_texts, return_pde=True)
                s_image_feats[images_ids] = output["img_mu"].cpu().float().detach()
                s_text_feats[texts_ids] = output["txt_mu"].cpu().float().detach()

                s_image_sigma[images_ids] = output["img_sigma"].cpu().float().detach()
                s_text_sigma[texts_ids] = output["txt_sigma"].cpu().float().detach()
            else:
                output = model.inference(adv_images_norm, adv_texts)
                s_image_feats[images_ids] = output["image_feat"].cpu().float().detach()
                s_text_feats[texts_ids] = output["text_feat"].cpu().float().detach()

        n += len(images)
        n_texts += len(texts)

    adv_images_list = (
        np.concatenate(adv_image_batches, axis=0) if adv_image_batches else None
    )

    # if not all the data is used
    # NOTE(review): slicing by count assumes image/text ids are consecutive
    # from 0 in loader order -- confirm against the dataset.
    s_image_feats = s_image_feats[:n]
    s_text_feats = s_text_feats[:n_texts]
    s_image_sigma = s_image_sigma[:n]
    s_text_sigma = s_text_sigma[:n_texts]

    print(s_image_feats.shape, s_text_feats.shape)

    return adv_images_list, adv_texts_list, s_image_feats, s_text_feats, s_image_sigma, s_text_sigma


def get_feats_txt(args, model, config, tokenizer, img2test_txt):
    """Embed the standalone captions in ``img2test_txt``.

    Each ``(index, caption)`` pair is tokenized on its own and passed to
    ``model.inference_text``; the result is written to row ``index`` of the
    returned tensors.

    Args:
        args: Parsed CLI namespace; only ``source_model`` is read.
        model: Model exposing ``inference_text`` and, for non-fused models,
            ``visual.output_dim``.
        config: Mapping providing ``embed_dim`` for fused models.
        tokenizer: Callable tokenizer returning PyTorch tensors.
        img2test_txt: Mapping from row index to caption string.

    Returns:
        ``(s_text_feats, s_text_sigma)``; the sigma tensor stays all-zero
        unless ``"PDE"`` is part of ``args.source_model``.
    """
    print("Computing features for evaluation adv...")

    # test
    model.float()
    model.eval()

    print("Prepare memory")
    n_captions = len(img2test_txt)
    feat_dim = (
        config["embed_dim"]
        if args.source_model in FUSED_MODELS
        else model.visual.output_dim
    )
    s_text_feats = torch.zeros(n_captions, feat_dim)
    s_text_sigma = torch.zeros(n_captions, feat_dim)

    is_pde = "PDE" in args.source_model

    print("Forward")
    for row, caption in img2test_txt.items():
        print(f"--------------------> txt:{row}/{len(img2test_txt)}: {caption}")
        encoded = tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=30,
            return_tensors="pt",
        )
        with torch.no_grad():
            if not is_pde:
                # Deterministic models only provide a point embedding.
                result = model.inference_text(encoded)
                s_text_feats[row] = result["text_feat"].cpu().float().detach()
                continue
            result = model.inference_text(encoded, return_pde=True)
            s_text_feats[row] = result["mu"].cpu().float().detach()
            s_text_sigma[row] = result["sigma"].cpu().float().detach()

    print(s_text_feats.shape)
    return s_text_feats, s_text_sigma


def vis_feats(
    adv_images_list,
    adv_texts_list,
    s_image_feats,
    s_text_feats,
    s_image_sigma,
    s_text_sigma,
    img2txt,
    show_n=5,
    additional_txt_dict=None,
):
    """
    Visualize features

    Projects image, caption and (optionally) extra caption embeddings to 2-D
    with t-SNE and saves a scatter plot with one ellipse per sample to
    ``VIS_DIR/vis_feature.png``.

    Args:
        adv_images_list: Indexable collection of images to draw; only the
            first ``show_n`` entries are used.
        adv_texts_list: Captions, indexed by the ids stored in ``img2txt``.
        s_image_feats: Image embeddings, one row per image (torch tensor).
        s_text_feats: Caption embeddings, one row per caption (torch tensor).
        s_image_sigma: Per-image sigma, same shape as ``s_image_feats``.
        s_text_sigma: Per-caption sigma, same shape as ``s_text_feats``.
        img2txt: Maps an image index to its caption indices; only the first
            caption of each image is drawn.
        show_n: Maximum number of images to draw.
        additional_txt_dict: Optional dict with keys ``"txt"`` (list of
            strings), ``"feats"`` and ``"sigma"`` (torch tensors) describing
            extra captions to draw in black. The dict is not modified.

    Returns:
        None. The figure is written to disk and closed.
    """
    has_extra = additional_txt_dict is not None

    # normalize features (sigma is rescaled by the norm of its own mean)
    img_norm = s_image_feats.norm(dim=1, keepdim=True)
    s_image_sigma = s_image_sigma / img_norm
    s_image_feats = s_image_feats / img_norm

    txt_norm = s_text_feats.norm(dim=1, keepdim=True)
    s_text_sigma = s_text_sigma / txt_norm
    s_text_feats = s_text_feats / txt_norm

    # sigma size
    s_text_sigma_size = s_text_sigma.norm(dim=1, keepdim=True)
    print(s_text_sigma_size)
    print(s_text_sigma_size.shape)

    if has_extra:
        # Work on local tensors so the caller's dict is left untouched.
        add_txts = additional_txt_dict["txt"]
        add_norm = additional_txt_dict["feats"].norm(dim=1, keepdim=True)
        add_feats = additional_txt_dict["feats"] / add_norm
        add_sigma = additional_txt_dict["sigma"] / add_norm
        add_sigma_size = add_sigma.norm(dim=1, keepdim=True)
        # zero sigma for test. (for visualization)
        add_sigma = torch.zeros_like(add_sigma)

    # zero sigma for test. (for visualization)
    s_image_sigma = torch.zeros_like(s_image_sigma)
    s_text_sigma = torch.zeros_like(s_text_sigma)

    # print distance matrix
    matrix = s_image_feats.cpu().numpy() @ s_text_feats.cpu().numpy().T
    matrix_rank_per_img = np.argsort(-matrix, axis=1)
    print("Distance matrix: ")
    print(matrix_rank_per_img)

    start_time = time.time()

    #####################
    ###### TSNE #########
    #####################
    # Every group contributes three stacked copies (mean, mean + sigma,
    # mean - sigma) so the 2-D extent can be read back after projection.
    groups = [(s_image_feats, s_image_sigma), (s_text_feats, s_text_sigma)]
    if has_extra:
        groups.append((add_feats, add_sigma))

    stacked = []
    for feats, sigma in groups:
        for part in (feats, feats + sigma, feats - sigma):
            stacked.append(part.detach().cpu().numpy())
    X = np.concatenate(stacked, axis=0)
    print(X.shape)

    # t-SNE requires perplexity < n_samples; keep the original 30 whenever
    # enough points are available.
    perplexity = min(30, max(X.shape[0] - 1, 1))
    reducer = TSNE(
        n_components=2, random_state=0, perplexity=perplexity, n_iter=2000
    )
    X = reducer.fit_transform(X)

    # Undo the stacking: consecutive triples of equally sized slices.
    projected = []
    offset = 0
    for feats, _ in groups:
        count = len(feats)
        mean = X[offset:offset + count]
        plus = X[offset + count:offset + 2 * count]
        minus = X[offset + 2 * count:offset + 3 * count]
        spread = np.maximum(np.abs(plus - mean), np.abs(mean - minus))
        projected.append((mean, spread))
        offset += 3 * count

    img_embed, img_sigma = projected[0]
    txt_embed, txt_sigma = projected[1]
    if has_extra:
        add_txt_embed, add_txt_sigma = projected[2]

    interval = time.time() - start_time
    print(f"Compression time: {interval:.2f}s")

    def imscatter(x, y, image, ax=None, zoom=1):
        """Draw ``image`` as a thumbnail centred on each ``(x, y)`` point."""
        if ax is None:
            ax = plt.gca()
        try:
            image = Image.fromarray(image)
        except TypeError:
            # Likely already an array...
            pass
        im = OffsetImage(image, zoom=zoom)
        x, y = np.atleast_1d(x, y)
        artists = []
        for x0, y0 in zip(x, y):
            ab = AnnotationBbox(im, (x0, y0), xycoords='data', frameon=False)
            artists.append(ax.add_artist(ab))
        ax.update_datalim(np.column_stack([x, y]))
        ax.autoscale()
        return artists

    # plot: draw ellipse based on mean and sigma
    color_list = cm.tab10(np.linspace(0, 1, 10))
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    # Never index past the images that were actually embedded.
    for i in range(min(show_n, len(img_embed))):
        adv_img = adv_images_list[i]
        # show only first
        inds = [img2txt[i][0]]
        print(inds)
        adv_txts = [adv_texts_list[j] for j in inds]
        adv_img_embed = img_embed[i]
        adv_img_sigma = img_sigma[i]
        adv_txt_embed = txt_embed[inds]
        adv_txt_sigma = txt_sigma[inds]

        c = color_list[i % 10]

        # Ellipse width/height are full extents, so the half-extent sigma is
        # doubled here exactly as for the caption ellipses below.
        ax.add_patch(
            pat.Ellipse(
                xy=adv_img_embed,
                width=adv_img_sigma[0] * 2,
                height=adv_img_sigma[1] * 2,
                angle=0,
                facecolor=c,
                edgecolor="r",
                alpha=0.2,
            )
        )
        imscatter(adv_img_embed[0], adv_img_embed[1], adv_img, ax=ax, zoom=0.2)
        ax.scatter(adv_img_embed[0], adv_img_embed[1], s=5, color=c)

        for j, (center, spread) in enumerate(zip(adv_txt_embed, adv_txt_sigma)):
            ax.add_patch(
                pat.Ellipse(
                    xy=center,
                    width=spread[0] * 2,
                    height=spread[1] * 2,
                    angle=0,
                    facecolor=c,
                    edgecolor="b",
                    alpha=0.2,
                )
            )
            ax.text(center[0], center[1], adv_txts[j], fontsize=8, color="black")
            ax.scatter(center[0], center[1], s=5, color=c)

            print(adv_txts[j], s_text_sigma_size[inds[j]])

    if has_extra:
        for i in range(len(add_txts)):
            print(add_txts[i], add_sigma_size[i])

        for i in range(len(add_txts)):
            ax.add_patch(
                pat.Ellipse(
                    xy=add_txt_embed[i],
                    width=add_txt_sigma[i][0] * 2,
                    height=add_txt_sigma[i][1] * 2,
                    angle=0,
                    facecolor='none',
                    edgecolor="black",
                    alpha=0.5,
                )
            )
            # Put the label at the top of the ellipse.
            x = add_txt_embed[i][0]
            y = add_txt_embed[i][1] + add_txt_sigma[i][1]
            ax.text(x, y, add_txts[i], fontsize=8, color="black")
            ax.scatter(add_txt_embed[i][0], add_txt_embed[i][1], s=5, color="black")

    # save
    save_fig_path = os.path.join(VIS_DIR, "vis_feature.png")
    fig.savefig(save_fig_path)
    plt.close(fig)
    print(f"Save vis feature to {save_fig_path}")


def main(args, config):
    """Load the source model and test data, extract features and plot them.

    Reads the module-level ``prompt_config`` and ``img2test_txt``; the plot
    itself is written by ``vis_feats``.

    Args:
        args: Parsed CLI namespace (``seed``, ``source_model``,
            ``source_ckpt``, ``source_text_encoder`` are read here; the rest
            is forwarded).
        config: Loaded YAML configuration, forwarded to the helpers.
    """
    device = torch.device("cuda")

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    cudnn.benchmark = True

    print("Creating Source Model")
    model, ref_model, tokenizer = load_model(
        config,
        args.source_model,
        args.source_ckpt,
        args.source_text_encoder,
        device,
        prompt_config=prompt_config,
    )

    #### Dataset ####
    print("Creating dataset")
    test_loader = get_test_data(
        args, config, get_test_transform(config, args.source_model, model)
    )

    # get features
    (
        adv_imgs,
        adv_txts,
        img_feats,
        txt_feats,
        img_sigma,
        txt_sigma,
    ) = get_feats(args, model, ref_model, test_loader, tokenizer, device, config)

    # features of the extra probe captions
    probe_feats, probe_sigma = get_feats_txt(
        args, model, config, tokenizer, img2test_txt
    )
    probe_dict = dict(
        txt=list(img2test_txt.values()),
        feats=probe_feats,
        sigma=probe_sigma,
    )

    # visualize features
    vis_feats(
        adv_imgs,
        adv_txts,
        img_feats,
        txt_feats,
        img_sigma,
        txt_sigma,
        test_loader.dataset.img2txt,
        show_n=5,
        additional_txt_dict=probe_dict,
    )



if __name__ == "__main__":
    # Script entry point: parse the CLI, load the YAML config, resolve the
    # prompt-tuning config for "PT" models, then run main() per attack.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="./configs/Retrieval_flickr.yaml")
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--batch_size", default=8, type=int)

    parser.add_argument("--source_model", default="ALBEF", type=str) # model architecture
    parser.add_argument("--source_model_name", default=None, type=str) # id for the model
    parser.add_argument("--source_text_encoder", default="bert-base-uncased", type=str)
    parser.add_argument("--source_ckpt", default=None, type=str)

    # parser.add_argument("--target_model", default=None, type=str) # model architecture
    # parser.add_argument("--target_model_name", default=None, type=str) # id for the model
    # parser.add_argument("--target_text_encoder", default="bert-base-uncased", type=str)
    # parser.add_argument("--target_ckpt", default=None, type=str)

    # parser.add_argument("--original_rank_index_path", default=None, type=str)

    # SGA attack
    parser.add_argument("--scales", type=str, default="0.5,0.75,1.25,1.5")

    # evaluation config
    parser.add_argument("--cls", default=False, action="store_true")
    parser.add_argument("--is_force_recompute_feats", default=False, action="store_true")
    parser.add_argument("--max_n", default=10000, type=int)
    parser.add_argument("--attack_fused_emb", default=False, action="store_true")
    parser.add_argument(
        "--attack",
        default="SGA",
        type=str,
        choices=["SGA", "Co-Attack", "Sep-Attack", "PGD", "BERT", "Clean"],
    )
    parser.add_argument("--eval_save_dir", default="../eval_results", type=str)
    parser.add_argument("--epsilon", default=2.0, type=float)
    parser.add_argument("--alpha", default=3.0, type=float)

    parser.add_argument("--prompt_config", default=None, type=str)

    args = parser.parse_args()

    # Use a context manager so the config file handle is always closed.
    # NOTE(review): yaml.Loader can construct arbitrary objects; only load
    # trusted config files.
    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.Loader)

    # Without an explicit id the output directory cannot be built
    # (os.path.join rejects None), so fall back to the architecture name.
    if args.source_model_name is None:
        args.source_model_name = args.source_model

    FUSED_MODELS = [
        "ALBEF",
        "ALBEF_PT",
        "TCL",
        "TCL_PT",
    ]

    # modify save dir
    if args.max_n < 200:
        print(
            "Warning: max_n is less than 200, which may not be enough for evaluation."
        )
        args.eval_save_dir = os.path.join(args.eval_save_dir, f"max_n_{args.max_n}")

    # The CLI values of --attack_fused_emb / --cls are overridden here based
    # on the model family.
    if args.source_model in FUSED_MODELS:
        print("ALBEF or TCL model, set attack_fused_emb == True, cls == False. (Multi-full)")
        args.attack_fused_emb = True
        args.cls = False
    else:
        print("CLIP model: (Uni-cls)")
        args.attack_fused_emb = False
        args.cls = False

    # prompt config
    prompt_config = None
    vision_ctx = 0
    if "PT" in args.source_model:
        if args.prompt_config is not None:
            with open(args.prompt_config, "r") as f:
                prompt_config = yaml.load(f, Loader=yaml.Loader)
            vision_ctx = prompt_config["vision_ctx"]
            language_ctx = prompt_config["language_ctx"]
            print("load prompt config: ", args.prompt_config)
            prompt_config = edict(prompt_config)
        else:
            try:
                source_ckpt_dir = os.path.dirname(args.source_ckpt)
                prompt_config_path = glob.glob(os.path.join(source_ckpt_dir, "prompt_config.json"))[0]
                with open(prompt_config_path, "r") as f:
                    prompt_config = json.load(f)
            except (TypeError, IndexError, OSError, ValueError):
                # TypeError: no --source_ckpt given; IndexError: no
                # prompt_config.json next to the checkpoint; OSError /
                # ValueError: file unreadable or not valid JSON. Anything
                # else (e.g. KeyboardInterrupt) is allowed to propagate.
                prompt_config_path = "/data/robust_crossmodal-retrieval/configs/prompt_tuning/base5.yaml"
                print("fall back to default prompt config: ", prompt_config_path)
                with open(prompt_config_path, "r") as f:
                    prompt_tuning_config = yaml.load(f, Loader=yaml.Loader)
                    prompt_config = edict(prompt_tuning_config)["PROMPT"]
            vision_ctx = prompt_config["NUM_TOKENS"]
            language_ctx = 0

    ATTACK = args.attack
    # NOTE: only the "Clean" setting is run; the --attack value is kept in
    # ATTACK but is not used by the loop below.
    # for attack in ["Clean", ATTACK]:
    for attack in ["Clean"]:
        args.attack = attack
        print(f"Start eval for {args.source_model_name} with {args.attack}")

        # dir to save the results
        LOG_BASE_DIR = os.path.join(args.eval_save_dir, args.source_model_name, args.attack)
        # SAVE_FEAT_DIR = os.path.join(LOG_BASE_DIR, "features")
        # SAVE_IMG_DIR = os.path.join(LOG_BASE_DIR, "imgs")
        VIS_DIR = os.path.join(LOG_BASE_DIR, "vis")

        # os.makedirs(SAVE_FEAT_DIR, exist_ok=True)
        os.makedirs(LOG_BASE_DIR, exist_ok=True)
        # os.makedirs(SAVE_IMG_DIR, exist_ok=True)
        os.makedirs(VIS_DIR, exist_ok=True)

        main(args, config)
