"""
Imported and modified from:
 https://github.com/Zoky-2020/SGA/blob/main/eval_albef2tcl_flickr.py
 https://github.com/adversarial-for-goodness/Co-Attack/blob/main/RetrievalEval.py
 https://github.com/adversarial-for-goodness/Co-Attack/blob/main/RetrievalFusionEval.py
"""

import argparse
import os
import sys

import ruamel.yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
import matplotlib.pyplot as plt
from PIL import Image
from easydict import EasyDict as edict

import torch
from  torch.cuda.amp import autocast

import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from transformers import BertForMaskedLM
from torchvision import transforms
from PIL import Image

from models.ALBEF.model_retrieval import ALBEF
from models.TCL.model_retrieval import ALBEF as TCL
from models.vit import interpolate_pos_embed
from models.tokenization_bert import BertTokenizer
from models import clip
from models.get_model import load_model

import utils.utils as utils

# SGA
sys.path.append("SGA")
from attacker import (
    SGAttacker as SGAttacker,
    ImageAttacker as SGA_ImageAttacker,
    TextAttacker as SGA_TextAttacker,
)

# ACMMM2022: https://github.com/adversarial-for-goodness/Co-Attack/tree/main
# attack image_feat and text_feat, which are normalized features.
from co_attack_modified import ImageAttacker as PGD_ImageAttacker
from co_attack_modified import BertAttack, BertAttackFusion
from co_attack_modified import MultiModalAttacker as CoAttacker
# attack image_embed and text_embed
# sys.path.append("Co-Attack")
# from attack.imageAttack import ImageAttacker as PGD_ImageAttacker
# from attack.bert_attack import BertAttack, BertAttackFusion
# from attack.multimodalAttack import MultiModalAttacker as CoAttacker

from dataset_simple import paired_dataset

def retrieval_eval(
    args,
    model,
    ref_model,
    t_model,
    t_ref_model,
    t_test_transform,
    data_loader,
    tokenizer,
    t_tokenizer,
    device,
    config,
):
    """Attack the dataset (or load cached results) and build retrieval score matrices.

    When no cached features exist in ``SAVE_FEAT_DIR`` (or
    ``args.is_force_recompute_feats`` is set), each batch of ``data_loader`` is
    attacked with the method selected by ``args.attack``. The adversarial
    images/texts are encoded by the source ``model`` and, when
    ``args.target_model`` is set, by ``t_model``. Features and adversarial
    samples are then saved to disk. Otherwise the cached tensors are loaded.

    Args:
        args: parsed CLI arguments (attack, epsilon, source/target model names, ...).
        model: source model that is attacked.
        ref_model: masked-LM used by the source text attacker.
        t_model: target (transfer) model.
        t_ref_model: masked-LM associated with the target model.
        t_test_transform: transform applied to each adversarial image before
            it is fed to the target model.
        data_loader: loader yielding ``(images, texts_group, images_ids, text_ids_groups)``.
        tokenizer: tokenizer of the source model.
        t_tokenizer: tokenizer of the target model (falls back to ``tokenizer``
            when ``None``).
        device: torch device used for attacks and inference.
        config: experiment config (``embed_dim``, ``num_iters``, ...).

    Returns:
        ``(adv_images_list, adv_texts_list, s_score_matrix_i2t, s_score_matrix_t2i,
        t_score_matrix_i2t, t_score_matrix_t2i)``. The score matrices are numpy
        arrays; the target ones are ``None`` when no target model is given.
    """
    max_n = args.max_n
    # Needed by both the cached and the fresh path (fused-model scoring uses them).
    num_text = len(data_loader.dataset.text)
    num_image = len(data_loader.dataset.ann)
    # One local name for the image cache directory, used for both loading and
    # saving. Assigning the global-looking SAVE_IMG_DIR only in the compute
    # branch made it local everywhere and broke the cached path.
    save_img_dir = os.path.join(LOG_BASE_DIR, "imgs")
    # The target model's texts must be tokenized with the target tokenizer.
    target_tokenizer = t_tokenizer if t_tokenizer is not None else tokenizer

    # Put every model in inference mode for both paths: the fused ITM scoring
    # below runs the text encoder even when features come from the cache.
    model.float()
    model.eval()
    ref_model.eval()
    t_model.float()
    t_model.eval()
    t_ref_model.eval()

    p = os.path.join(SAVE_FEAT_DIR, "s_image_feats.pth")
    if os.path.exists(p) and not args.is_force_recompute_feats:
        print(f"Features already exist in {SAVE_FEAT_DIR}, loading...")
        s_image_feats = torch.load(p)
        s_image_embeds = torch.load(os.path.join(SAVE_FEAT_DIR, "s_image_embeds.pth"))
        s_text_feats = torch.load(os.path.join(SAVE_FEAT_DIR, "s_text_feats.pth"))
        s_text_embeds = torch.load(os.path.join(SAVE_FEAT_DIR, "s_text_embeds.pth"))
        s_text_atts = torch.load(os.path.join(SAVE_FEAT_DIR, "s_text_atts.pth"))

        if args.target_model is not None:
            t_image_feats = torch.load(
                os.path.join(SAVE_FEAT_DIR_T, "t_image_feats.pth")
            )
            t_image_embeds = torch.load(
                os.path.join(SAVE_FEAT_DIR_T, "t_image_embeds.pth")
            )
            t_text_feats = torch.load(os.path.join(SAVE_FEAT_DIR_T, "t_text_feats.pth"))
            t_text_embeds = torch.load(
                os.path.join(SAVE_FEAT_DIR_T, "t_text_embeds.pth")
            )
            t_text_atts = torch.load(os.path.join(SAVE_FEAT_DIR_T, "t_text_atts.pth"))

        # load the cached adversarial samples
        adv_images_list = np.load(os.path.join(save_img_dir, "adv_images_list.npy"))
        with open(os.path.join(save_img_dir, "adv_texts_list.json"), "r") as f:
            adv_texts_list = json.load(f)

    else:
        print(f"Features do not exist in {SAVE_FEAT_DIR}.")
        print("Computing features for evaluation adv...")

        images_normalize = transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
        )

        if args.attack == "SGA":
            img_attacker = SGA_ImageAttacker(
                images_normalize, eps=args.epsilon / 255, steps=10, step_size=0.5 / 255
            )
            txt_attacker = SGA_TextAttacker(
                ref_model,
                tokenizer,
                cls=False,
                max_length=30,
                number_perturbation=1,
                topk=10,
                threshold_pred_score=0.3,
            )
            attacker = SGAttacker(model, img_attacker, txt_attacker)
        elif args.attack in ["BERT", "PGD", "Sep-Attack", "Co-Attack", "Clean"]:
            # adv mode codes understood by the Co-Attack MultiModalAttacker
            adv_mode = {
                "Clean": 0,
                "BERT": 1,
                "PGD": 2,
                "Sep-Attack": 3,
                "Co-Attack": 4,
            }[args.attack]
            image_attacker = PGD_ImageAttacker(
                args.epsilon / 255.0,
                preprocess=images_normalize,
                bounding=(0, 1),
                cls=args.cls,
            )
            if args.attack_fused_emb:
                text_attacker = BertAttackFusion(ref_model, tokenizer, cls=args.cls)
            else:
                text_attacker = BertAttack(ref_model, tokenizer, cls=args.cls)
            multi_attacker = CoAttacker(
                model, image_attacker, text_attacker, tokenizer, cls=args.cls
            )

        print("Prepare memory")
        s_embed_dim = 577 if "PT" not in args.source_model else 577 + vision_ctx
        s_text_embed_dim = 30
        if args.source_model in FUSED_MODELS:
            s_feat_dim = config["embed_dim"]
        else:
            s_feat_dim = model.visual.output_dim
        s_image_feats = torch.zeros(num_image, s_feat_dim)
        s_text_feats = torch.zeros(num_text, s_feat_dim)
        s_image_embeds = torch.zeros(num_image, s_embed_dim, 768)
        s_text_embeds = torch.zeros(num_text, s_text_embed_dim, 768)
        s_text_atts = torch.zeros(num_text, s_text_embed_dim).long()

        if args.target_model is not None:
            if args.target_model in FUSED_MODELS:
                t_feat_dim = config["embed_dim"]
            else:
                t_feat_dim = t_model.visual.output_dim
            t_image_feats = torch.zeros(num_image, t_feat_dim)
            t_text_feats = torch.zeros(num_text, t_feat_dim)
            # NOTE(review): 577 tokens assumes a non-prompted ViT target, whereas the
            # source side adds vision_ctx for "PT" models — confirm for PT targets.
            t_image_embeds = torch.zeros(num_image, 577, 768)
            t_text_embeds = torch.zeros(num_text, 30, 768)
            t_text_atts = torch.zeros(num_text, 30).long()

        # for visualization
        adv_images_list = None
        adv_texts_list = []

        if args.attack == "SGA":
            if args.scales is not None:
                scales = [float(itm) for itm in args.scales.split(",")]
                print(scales)
            else:
                scales = None

        print("Forward")
        n = 0
        n_texts = 0
        for batch_idx, (images, texts_group, images_ids, text_ids_groups) in enumerate(
            data_loader
        ):
            if max_n is not None and n > max_n:
                print("Stop iterations at n:", n)
                break
            print(f"--------------------> batch:{batch_idx}/{len(data_loader)}")
            # Flatten the per-image caption groups; txt2img[k] is the in-batch
            # index of the image that caption k belongs to.
            texts_ids = []
            txt2img = []
            texts = []
            for i in range(len(texts_group)):
                texts += texts_group[i]
                texts_ids += text_ids_groups[i]
                txt2img += [i] * len(text_ids_groups[i])

            images = images.to(device)

            if args.attack == "SGA":
                # `max_lemgth` (sic) is the keyword name used by SGAttacker.attack.
                adv_images, adv_texts = attacker.attack(
                    images, texts, txt2img, device=device, max_lemgth=30, scales=scales
                )
            elif args.attack in ["BERT", "PGD", "Sep-Attack", "Co-Attack"]:
                # Co-Attack works on (image, text) pairs: repeat each image once
                # per caption, then keep one adversarial image per original image.
                img2txt_n = [txt2img.count(i) for i in range(len(images))]
                repeated_images = torch.stack(
                    [im for im, cnt in zip(images, img2txt_n) for _ in range(cnt)],
                    dim=0,
                )
                if args.attack_fused_emb:
                    adv_images, adv_texts = multi_attacker.run(
                        repeated_images,
                        texts,
                        adv=adv_mode,
                        num_iters=config["num_iters"],
                        alpha=args.alpha,
                    )
                else:
                    max_length = 1e3 if args.source_model in FUSED_MODELS else 77
                    adv_images, adv_texts = multi_attacker.run_before_fusion(
                        repeated_images,
                        texts,
                        adv=adv_mode,
                        num_iters=config["num_iters"],
                        alpha=args.alpha,
                        max_length=max_length,
                    )
                # index of the last repeated copy of every original image
                indice = np.cumsum(img2txt_n) - 1
                adv_images = adv_images[indice]
            elif args.attack == "Clean":
                adv_images = images
                adv_texts = texts
            else:
                raise ValueError(f"Invalid attack mode: {args.attack}")

            np_adv_images = adv_images.detach().cpu().numpy().transpose(0, 2, 3, 1)
            if adv_images_list is None:
                adv_images_list = np_adv_images
            else:
                adv_images_list = np.concatenate(
                    [adv_images_list, np_adv_images], axis=0
                )
            adv_texts_list += adv_texts

            with torch.no_grad():
                adv_images_norm = images_normalize(adv_images)
                adv_texts_input = tokenizer(
                    adv_texts,
                    padding="max_length",
                    truncation=True,
                    max_length=30,
                    return_tensors="pt",
                ).to(device)
                if args.source_model in FUSED_MODELS:
                    s_output_img = model.inference_image(adv_images_norm)
                    s_output_txt = model.inference_text(adv_texts_input)
                    s_image_feats[images_ids] = s_output_img["image_feat"].cpu().detach()
                    s_image_embeds[images_ids] = s_output_img["image_embed"].cpu().detach()
                    s_text_feats[texts_ids] = s_output_txt["text_feat"].cpu().detach()
                    s_text_embeds[texts_ids] = s_output_txt["text_embed"].cpu().detach()
                    s_text_atts[texts_ids] = adv_texts_input.attention_mask.cpu().detach()
                else:
                    output = model.inference(adv_images_norm, adv_texts)
                    s_image_feats[images_ids] = output["image_feat"].cpu().float().detach()
                    s_text_feats[texts_ids] = output["text_feat"].cpu().float().detach()

                if args.target_model is not None:
                    t_adv_imgs = torch.stack(
                        [t_test_transform(itm) for itm in adv_images], 0
                    ).to(device)
                    t_adv_images_norm = images_normalize(t_adv_imgs)
                    if args.target_model in FUSED_MODELS:
                        t_adv_texts_input = target_tokenizer(
                            adv_texts,
                            padding="max_length",
                            truncation=True,
                            max_length=30,
                            return_tensors="pt",
                        ).to(device)
                        t_output_img = t_model.inference_image(t_adv_images_norm)
                        t_output_txt = t_model.inference_text(t_adv_texts_input)
                        t_image_feats[images_ids] = (
                            t_output_img["image_feat"].cpu().detach()
                        )
                        t_image_embeds[images_ids] = (
                            t_output_img["image_embed"].cpu().detach()
                        )
                        t_text_feats[texts_ids] = t_output_txt["text_feat"].cpu().detach()
                        t_text_embeds[texts_ids] = t_output_txt["text_embed"].cpu().detach()
                        t_text_atts[texts_ids] = (
                            t_adv_texts_input.attention_mask.cpu().detach()
                        )
                    else:
                        t_output = t_model.inference(t_adv_images_norm, adv_texts)
                        t_image_feats[images_ids] = t_output["image_feat"].cpu().float().detach()
                        t_text_feats[texts_ids] = t_output["text_feat"].cpu().float().detach()

            n += len(images)
            n_texts += len(texts)

        # drop unused rows if the loop stopped early (max_n)
        s_image_feats = s_image_feats[:n]
        s_image_embeds = s_image_embeds[:n]
        s_text_feats = s_text_feats[:n_texts]
        s_text_embeds = s_text_embeds[:n_texts]
        s_text_atts = s_text_atts[:n_texts]

        if args.target_model is not None:
            t_image_feats = t_image_feats[:n]
            t_image_embeds = t_image_embeds[:n]
            t_text_feats = t_text_feats[:n_texts]
            t_text_embeds = t_text_embeds[:n_texts]
            t_text_atts = t_text_atts[:n_texts]

        # save the features
        torch.save(s_image_feats, os.path.join(SAVE_FEAT_DIR, "s_image_feats.pth"))
        torch.save(s_image_embeds, os.path.join(SAVE_FEAT_DIR, "s_image_embeds.pth"))
        torch.save(s_text_feats, os.path.join(SAVE_FEAT_DIR, "s_text_feats.pth"))
        torch.save(s_text_embeds, os.path.join(SAVE_FEAT_DIR, "s_text_embeds.pth"))
        torch.save(s_text_atts, os.path.join(SAVE_FEAT_DIR, "s_text_atts.pth"))

        if args.target_model is not None:
            torch.save(
                t_image_feats, os.path.join(SAVE_FEAT_DIR_T, "t_image_feats.pth")
            )
            torch.save(
                t_image_embeds, os.path.join(SAVE_FEAT_DIR_T, "t_image_embeds.pth")
            )
            torch.save(t_text_feats, os.path.join(SAVE_FEAT_DIR_T, "t_text_feats.pth"))
            torch.save(
                t_text_embeds, os.path.join(SAVE_FEAT_DIR_T, "t_text_embeds.pth")
            )
            torch.save(t_text_atts, os.path.join(SAVE_FEAT_DIR_T, "t_text_atts.pth"))

        # save the adversarial samples
        os.makedirs(save_img_dir, exist_ok=True)
        np.save(os.path.join(save_img_dir, "adv_images_list.npy"), adv_images_list)
        with open(os.path.join(save_img_dir, "adv_texts_list.json"), "w") as f:
            json.dump(adv_texts_list, f)

    print(s_image_feats.shape, s_text_feats.shape)

    # get matching score
    if args.source_model in FUSED_MODELS:
        print(f"Get matching score with fusion... ({args.source_model_name})")
        s_score_matrix_i2t, s_score_matrix_t2i = retrieval_score(
            model,
            s_image_feats,
            s_image_embeds,
            s_text_feats,
            s_text_embeds,
            s_text_atts,
            num_image,
            num_text,
            device=device,
        )
    else:
        print(f"Get matching score without fusion... ({args.source_model_name})")
        s_sims_matrix = s_image_feats @ s_text_feats.t()
        s_score_matrix_i2t = s_sims_matrix
        s_score_matrix_t2i = s_sims_matrix.t()
    s_score_matrix_i2t = s_score_matrix_i2t.cpu().numpy()
    s_score_matrix_t2i = s_score_matrix_t2i.cpu().numpy()

    # target model
    if args.target_model is not None:
        if args.target_model in FUSED_MODELS:
            print(f"Get matching score with fusion... ({args.target_model_name})")
            t_score_matrix_i2t, t_score_matrix_t2i = retrieval_score(
                t_model,
                t_image_feats,
                t_image_embeds,
                t_text_feats,
                t_text_embeds,
                t_text_atts,
                num_image,
                num_text,
                device=device,
            )
        else:
            print(f"Get matching score without fusion... ({args.target_model_name})")
            t_sims_matrix = t_image_feats @ t_text_feats.t()
            t_score_matrix_i2t = t_sims_matrix
            t_score_matrix_t2i = t_sims_matrix.t()
        t_score_matrix_i2t = t_score_matrix_i2t.cpu().numpy()
        t_score_matrix_t2i = t_score_matrix_t2i.cpu().numpy()
    else:
        t_score_matrix_i2t = None
        t_score_matrix_t2i = None

    return (
        adv_images_list,
        adv_texts_list,
        s_score_matrix_i2t,
        s_score_matrix_t2i,
        t_score_matrix_i2t,
        t_score_matrix_t2i
    )


@torch.no_grad()
def retrieval_score(
    model,
    image_feats,
    image_embeds,
    text_feats,
    text_embeds,
    text_atts,
    num_image,
    num_text,
    device=None,
    k_test=None,
):
    """Re-rank the top-k candidates of every query with the fused ITM head.

    This is only used for fused models, such as ALBEF or TCL. Aligned models,
    such as CLIP, use the plain feature similarity instead.

    The coarse similarity ``image_feats @ text_feats.T`` selects ``k_test``
    candidates per query. Each candidate pair is then scored by
    ``model.text_encoder(mode="fusion")`` followed by ``model.itm_head``.

    Args:
        model: fused model exposing ``text_encoder`` and ``itm_head``.
        image_feats, text_feats: projected features used for candidate selection.
        image_embeds: per-image encoder outputs passed as ``encoder_hidden_states``.
        text_embeds, text_atts: per-text encoder outputs and attention masks.
        num_image, num_text: sizes of the returned score matrices.
        device: device for the fusion forward pass (defaults to ``image_embeds.device``).
        k_test: candidates re-ranked per query; defaults to the module-level
            ``config["k_test"]``. It is clamped to the number of available
            candidates, since ``topk`` raises when k exceeds the dimension size
            (e.g. when the data was truncated by ``max_n``).

    Returns:
        ``(score_matrix_i2t, score_matrix_t2i)`` with shapes
        ``(num_image, num_text)`` and ``(num_text, num_image)``; entries that
        were not re-ranked stay at -100.
    """
    if device is None:
        device = image_embeds.device
    if k_test is None:
        k_test = config["k_test"]

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Evaluation Direction Similarity With Attack:"

    # Image -> Text
    sims_matrix = image_feats @ text_feats.t()
    score_matrix_i2t = torch.full((num_image, num_text), -100.0, device=device)
    k_i2t = min(k_test, sims_matrix.shape[1])

    for i, sims in enumerate(metric_logger.log_every(sims_matrix, 50, header)):
        _, topk_idx = sims.topk(k=k_i2t, dim=0)

        # the same image is paired with each of its k candidate texts
        encoder_output = image_embeds[i].repeat(k_i2t, 1, 1).to(device)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
            device
        )
        output = model.text_encoder(
            encoder_embeds=text_embeds[topk_idx].to(device),
            attention_mask=text_atts[topk_idx].to(device),
            encoder_hidden_states=encoder_output,
            encoder_attention_mask=encoder_att,
            return_dict=True,
            mode="fusion",
        )
        score = model.itm_head(output.last_hidden_state[:, 0, :])[:, 1]
        score_matrix_i2t[i, topk_idx] = score

    # Text -> Image
    sims_matrix = sims_matrix.t()
    score_matrix_t2i = torch.full((num_text, num_image), -100.0, device=device)
    k_t2i = min(k_test, sims_matrix.shape[1])

    for i, sims in enumerate(metric_logger.log_every(sims_matrix, 50, header)):
        _, topk_idx = sims.topk(k=k_t2i, dim=0)
        encoder_output = image_embeds[topk_idx].to(device)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(
            device
        )
        # the same text is paired with each of its k candidate images
        output = model.text_encoder(
            encoder_embeds=text_embeds[i].repeat(k_t2i, 1, 1).to(device),
            attention_mask=text_atts[i].repeat(k_t2i, 1).to(device),
            encoder_hidden_states=encoder_output,
            encoder_attention_mask=encoder_att,
            return_dict=True,
            mode="fusion",
        )
        score = model.itm_head(output.last_hidden_state[:, 0, :])[:, 1]
        score_matrix_t2i[i, topk_idx] = score

    return score_matrix_i2t, score_matrix_t2i


@torch.no_grad()
def itm_eval(scores_i2t, scores_t2i, img2txt, txt2img, model_name, rank_index_dir=None):
    """
    Evaluate the R@K of the model, based on similarity matrix.

    For every query, candidates are sorted by descending score. The query's
    rank is the position of its best ground-truth match; a query whose ground
    truth is absent from the candidates gets a huge sentinel rank and so never
    counts as a hit. The indices of queries that hit at K=1/5/10 are saved as
    ``{model_name}_{tr|ir}{K}_rank_index.npy`` so that ``itm_eval_asr`` can
    later compare against them.

    Args:
        scores_i2t: (num_images, num_texts) score matrix.
        scores_t2i: (num_texts, num_images) score matrix.
        img2txt: maps an image index to the list of its ground-truth text indices.
        txt2img: maps a text index to its ground-truth image index.
        model_name: prefix of the saved rank-index files.
        rank_index_dir: output directory; defaults to the module-level ``RANK_INDEX_DIR``.

    Returns:
        dict with text-retrieval (``txt_r*``), image-retrieval (``img_r*``)
        recalls in percent and their means.

    Imported from:
    https://github.com/adversarial-for-goodness/Co-Attack/blob/96f2c0ebf743f600860ed7aa78c14b38c8883521/RetrievalEval.py#L136-L179
    """
    if rank_index_dir is None:
        rank_index_dir = RANK_INDEX_DIR
    not_found = 1e20  # larger than any K, so a missing ground truth is never a hit

    def recall_at(ranks, k):
        # percentage of queries whose best ground-truth match is in the top-k
        if len(ranks) == 0:
            return 0.0
        return 100.0 * np.count_nonzero(ranks < k) / len(ranks)

    def save_rank_indices(ranks, tag):
        for k in (1, 5, 10):
            np.save(
                os.path.join(rank_index_dir, f"{model_name}_{tag}{k}_rank_index.npy"),
                np.where(ranks < k)[0],
            )

    ########################################################
    # Images->Text
    ranks = np.zeros(scores_i2t.shape[0])
    for index, score in enumerate(scores_i2t):
        inds = np.argsort(score)[::-1]
        # np.flatnonzero is ascending, so the first hit is the best position
        hits = np.flatnonzero(np.isin(inds, img2txt[index]))
        ranks[index] = hits[0] if hits.size else not_found

    tr1 = recall_at(ranks, 1)
    tr5 = recall_at(ranks, 5)
    tr10 = recall_at(ranks, 10)
    save_rank_indices(ranks, "tr")

    ########################################################
    # Text->Images
    ranks = np.zeros(scores_t2i.shape[0])
    for index, score in enumerate(scores_t2i):
        inds = np.argsort(score)[::-1]
        hits = np.flatnonzero(inds == txt2img[index])
        # previously fell through to np.where(...)[0][0] and raised IndexError
        ranks[index] = hits[0] if hits.size else not_found

    ir1 = recall_at(ranks, 1)
    ir5 = recall_at(ranks, 5)
    ir10 = recall_at(ranks, 10)
    save_rank_indices(ranks, "ir")

    tr_mean = (tr1 + tr5 + tr10) / 3
    ir_mean = (ir1 + ir5 + ir10) / 3
    r_mean = (tr_mean + ir_mean) / 2

    eval_result = {
        "txt_r1": tr1,
        "txt_r5": tr5,
        "txt_r10": tr10,
        "txt_r_mean": tr_mean,
        "img_r1": ir1,
        "img_r5": ir5,
        "img_r10": ir10,
        "img_r_mean": ir_mean,
        "r_mean": r_mean,
    }
    return eval_result


@torch.no_grad()
def itm_eval_asr(scores_i2t, scores_t2i, img2txt, txt2img, model_name, clean_rank_index_dir=None):
    """
    Evaluate the ASR of the model.

    The attack success rate at K is the percentage of queries that were hits
    at K on clean data (the rank-index files saved by ``itm_eval`` for a clean
    run) but are no longer hits after the attack. Queries whose ground truth
    is missing from the candidates get a huge sentinel rank and never count as
    hits. When a clean index file is empty, the corresponding ASR is 0.0.

    Args:
        scores_i2t: (num_images, num_texts) score matrix after the attack.
        scores_t2i: (num_texts, num_images) score matrix after the attack.
        img2txt: maps an image index to the list of its ground-truth text indices.
        txt2img: maps a text index to its ground-truth image index.
        model_name: prefix of the clean rank-index files to load.
        clean_rank_index_dir: directory holding the clean rank-index files;
            defaults to the module-level ``CLEAN_RANK_INDEX_DIR``.

    Returns:
        dict mapping e.g. ``"txt_r1_ASR (txt_r1)"`` to the string ``"ASR(recall)"``.

    Imported from:
    https://github.com/Zoky-2020/SGA/blob/f49e6926cd13148e49bdbc2c3e537660e096230c/eval_albef2clip-vit_flickr.py#L163-L227
    """
    if clean_rank_index_dir is None:
        clean_rank_index_dir = CLEAN_RANK_INDEX_DIR
    not_found = 1e20  # larger than any K, so a missing ground truth is never a hit

    def recall_at(ranks, k):
        if len(ranks) == 0:
            return 0.0
        return 100.0 * np.count_nonzero(ranks < k) / len(ranks)

    def load_clean(tag, k):
        return np.load(
            os.path.join(clean_rank_index_dir, f"{model_name}_{tag}{k}_rank_index.npy")
        )

    def attack_success_rate(origin, after):
        # share of clean hits that the attack turned into misses
        if len(origin) == 0:
            return 0.0
        return round(100.0 * len(np.setdiff1d(origin, after)) / len(origin), 2)

    # Images->Text
    ranks = np.zeros(scores_i2t.shape[0])
    for index, score in enumerate(scores_i2t):
        inds = np.argsort(score)[::-1]
        hits = np.flatnonzero(np.isin(inds, img2txt[index]))
        ranks[index] = hits[0] if hits.size else not_found

    tr1 = recall_at(ranks, 1)
    tr5 = recall_at(ranks, 5)
    tr10 = recall_at(ranks, 10)

    asr_tr1 = attack_success_rate(load_clean("tr", 1), np.where(ranks < 1)[0])
    asr_tr5 = attack_success_rate(load_clean("tr", 5), np.where(ranks < 5)[0])
    asr_tr10 = attack_success_rate(load_clean("tr", 10), np.where(ranks < 10)[0])

    # Text->Images
    ranks = np.zeros(scores_t2i.shape[0])
    for index, score in enumerate(scores_t2i):
        inds = np.argsort(score)[::-1]
        hits = np.flatnonzero(inds == txt2img[index])
        # previously fell through to np.where(...)[0][0] and raised IndexError
        ranks[index] = hits[0] if hits.size else not_found

    ir1 = recall_at(ranks, 1)
    ir5 = recall_at(ranks, 5)
    ir10 = recall_at(ranks, 10)

    asr_ir1 = attack_success_rate(load_clean("ir", 1), np.where(ranks < 1)[0])
    asr_ir5 = attack_success_rate(load_clean("ir", 5), np.where(ranks < 5)[0])
    asr_ir10 = attack_success_rate(load_clean("ir", 10), np.where(ranks < 10)[0])

    eval_result = {
        "txt_r1_ASR (txt_r1)": f"{asr_tr1}({tr1})",
        "txt_r5_ASR (txt_r5)": f"{asr_tr5}({tr5})",
        "txt_r10_ASR (txt_r10)": f"{asr_tr10}({tr10})",
        "img_r1_ASR (img_r1)": f"{asr_ir1}({ir1})",
        "img_r5_ASR (img_r5)": f"{asr_ir5}({ir5})",
        "img_r10_ASR (img_r10)": f"{asr_ir10}({ir10})",
    }
    return eval_result


def _to_display_image(img):
    """Return a 224x224 uint8 copy of one HWC image with values in [0, 1].

    Values are clipped to [0, 1] before scaling so that small overshoots do not
    wrap around when cast to uint8.
    """
    numpy_image = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)
    resized = Image.fromarray(numpy_image).resize((224, 224))
    return np.array(resized).astype(np.uint8)


def vis_retrieval_results(
    adv_images_list,
    adv_texts_list,
    scores_i2t,
    scores_t2i,
    img2txt,
    txt2img,
    show_n=5,
    top_k=5,
):
    """
    Visualize the retrieval results.

    Two figures are saved under ``VIS_DIR`` and shown:

    * ``img2txt.png``: the first ``show_n`` query images, each next to its
      ``top_k`` best and ``top_k`` worst scored captions.
    * ``txt2img.png``: the first ``show_n`` distinct query captions, each
      followed by its ``top_k`` best and ``top_k`` worst scored images.

    Args:
        adv_images_list: images indexed by image id, HWC with values in [0, 1].
        adv_texts_list: captions indexed by text id.
        scores_i2t: (num_images, num_texts) score matrix.
        scores_t2i: (num_texts, num_images) score matrix.
        img2txt, txt2img: ground-truth mappings (currently unused; kept for API
            compatibility).
        show_n: number of queries to display per direction.
        top_k: number of best and of worst results shown per query.
    """
    # Images->Text: caption summary for each query image
    img_idx2txt = {}
    for index, score in enumerate(scores_i2t):
        order = np.argsort(score)
        best, worst = order[::-1][:top_k], order[:top_k]
        parts = [f"Top {i+1}: {adv_texts_list[t]}\n" for i, t in enumerate(best)]
        parts.append("\n")
        parts += [f"Worst {i+1}: {adv_texts_list[t]}\n" for i, t in enumerate(worst)]
        img_idx2txt[index] = "".join(parts)

        if len(img_idx2txt) >= show_n:
            break

    # Text->Images: best/worst image ids for each query caption
    # (identical captions share one entry, since the caption is the key)
    txt2img_idx_dict = {}
    for index, score in enumerate(scores_t2i):
        order = np.argsort(score)
        img_idx_dict = {f"Top {i+1}": idx for i, idx in enumerate(order[::-1][:top_k])}
        img_idx_dict.update({f"Worst {i+1}": idx for i, idx in enumerate(order[:top_k])})
        txt2img_idx_dict[adv_texts_list[index]] = img_idx_dict

        if len(txt2img_idx_dict) >= show_n:
            break

    # show Images->Text: one row per query, image on the left, captions on the right
    if img_idx2txt:
        n_rows = max(show_n, len(img_idx2txt))
        fig, axes = plt.subplots(n_rows, 2, figsize=(20, 20), squeeze=False)
        for row, (img_idx, txt) in enumerate(img_idx2txt.items()):
            axes[row, 0].imshow(_to_display_image(adv_images_list[img_idx]))
            axes[row, 0].axis("off")
            axes[row, 1].text(0.5, 0.5, txt, ha="center", va="center", wrap=True)
            axes[row, 1].axis("off")
        fig.savefig(os.path.join(VIS_DIR, "img2txt.png"), bbox_inches="tight")
        plt.show()
        plt.close(fig)

    # show Text->Images: one row per query, caption then top_k*2 images.
    # squeeze=False keeps `axes` 2-D even when there is a single row.
    if txt2img_idx_dict:
        rows = len(txt2img_idx_dict)
        cols = 1 + top_k * 2
        fig, axes = plt.subplots(
            rows, cols, figsize=(cols * 3, rows * 3), squeeze=False
        )
        for i, (txt, img_idx_dict) in enumerate(txt2img_idx_dict.items()):
            axes[i, 0].text(0, 1.1, txt, ha="left", va="top", wrap=True)
            axes[i, 0].axis("off")
            for j, (title, img_idx) in enumerate(img_idx_dict.items(), start=1):
                ax = axes[i, j]
                ax.imshow(_to_display_image(adv_images_list[img_idx]))
                ax.axis("off")
                # set_title places the label above the image in axes space;
                # ax.text(0.5, 1.2) used pixel (data) coordinates instead.
                ax.set_title(title)
        fig.savefig(os.path.join(VIS_DIR, "txt2img.png"), bbox_inches="tight")
        plt.show()
        plt.close(fig)


# def load_model(config, model_name, model_ckpt, text_encoder, device):
#     tokenizer = BertTokenizer.from_pretrained(text_encoder)
#     ref_model = BertForMaskedLM.from_pretrained(text_encoder)
#     if model_name in FUSED_MODELS:
#         model = ALBEF(config=config, text_encoder=text_encoder, tokenizer=tokenizer)
#         checkpoint = torch.load(model_ckpt, map_location="cpu")
#     ### load checkpoint
#     else:
#         print("CLIP model")
#         model, preprocess = clip.load(model_name, device=device)
#         model.set_tokenizer(tokenizer)
#         return model, ref_model, tokenizer
    
#     try:
#         state_dict = checkpoint["model"]
#     except:
#         state_dict = checkpoint

#     if model_name == "TCL":
#         pos_embed_reshaped = interpolate_pos_embed(
#             state_dict["visual_encoder.pos_embed"], model.visual_encoder
#         )
#         state_dict["visual_encoder.pos_embed"] = pos_embed_reshaped
#         m_pos_embed_reshaped = interpolate_pos_embed(
#             state_dict["visual_encoder_m.pos_embed"], model.visual_encoder_m
#         )
#         state_dict["visual_encoder_m.pos_embed"] = m_pos_embed_reshaped

#     for key in list(state_dict.keys()):
#         if "bert" in key:
#             encoder_key = key.replace("bert.", "")
#             state_dict[encoder_key] = state_dict[key]
#             del state_dict[key]
#     model.load_state_dict(state_dict, strict=False)

#     if "_PT" in model_name:
#         model.wrap_vision_encoder_with_prompter(prompt_config)
#         model.to(device)
#         model.eval()

#     return model, ref_model, tokenizer


def _itm_metrics(score_i2t, score_t2i, dataset, model_name):
    """Compute retrieval metrics and attack-success-rate metrics for one model.

    Args:
        score_i2t: image-to-text similarity scores.
        score_t2i: text-to-image similarity scores.
        dataset: dataset exposing ``img2txt`` and ``txt2img`` ground-truth maps.
        model_name: identifier forwarded to the metric functions.

    Returns:
        Tuple ``(result, result_asr)`` holding the outputs of ``itm_eval`` and
        ``itm_eval_asr`` respectively.
    """
    result = itm_eval(
        score_i2t, score_t2i, dataset.img2txt, dataset.txt2img, model_name
    )
    result_asr = itm_eval_asr(
        score_i2t, score_t2i, dataset.img2txt, dataset.txt2img, model_name
    )
    return result, result_asr


def eval_asr(
    model,
    ref_model,
    tokenizer,
    t_model,
    t_ref_model,
    t_tokenizer,
    t_test_transform,
    data_loader,
    device,
    args,
    config,
):
    """Run the configured attack, visualize it, and score source/target models.

    Moves all models to ``device``, calls ``retrieval_eval`` to obtain the
    adversarial samples and similarity scores, renders retrieval
    visualizations, then computes retrieval and ASR metrics.

    Returns:
        Tuple ``(result, t_result, result_asr, t_result_asr)``. ``result*``
        are the source-model metrics. ``t_result*`` are the target-model
        metrics, or empty dicts when ``args.target_model`` is None.
    """
    model = model.to(device)
    ref_model = ref_model.to(device)

    t_model = t_model.to(device)
    t_ref_model = t_ref_model.to(device)

    print("Start eval")
    start_time = time.time()

    adv_images_list, adv_texts_list, score_i2t, score_t2i, t_score_i2t, t_score_t2i = (
        retrieval_eval(
            args,
            model,
            ref_model,
            t_model,
            t_ref_model,
            t_test_transform,
            data_loader,
            tokenizer,
            t_tokenizer,
            device,
            config,
        )
    )

    dataset = data_loader.dataset

    # vis
    vis_retrieval_results(
        adv_images_list,
        adv_texts_list,
        score_i2t,
        score_t2i,
        dataset.img2txt,
        dataset.txt2img,
        show_n=5,
        top_k=5,
    )

    # eval: target (transfer) metrics only exist when a target model is given
    t_result, t_result_asr = {}, {}
    if args.target_model is not None:
        t_result, t_result_asr = _itm_metrics(
            t_score_i2t, t_score_t2i, dataset, args.target_model_name
        )
        print("Performance on {}: \n {}".format(args.target_model_name, t_result))

    result, result_asr = _itm_metrics(
        score_i2t, score_t2i, dataset, args.source_model_name
    )
    print("Performance on {}: \n {}".format(args.source_model_name, result))

    torch.cuda.empty_cache()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Evaluate time {}".format(total_time_str))

    return result, t_result, result_asr, t_result_asr


def get_data(args, config, model, t_model=None):
    """Build the test DataLoader plus the source and target image transforms.

    Args:
        args: parsed CLI arguments; reads ``source_model``, ``target_model``
            and ``batch_size``.
        config: experiment config; reads ``image_res``, ``test_file`` and
            ``image_root``.
        model: source model. For non-fused (CLIP) sources, its
            ``visual.input_resolution`` sets the resize/crop size.
        t_model: target model. Only required when ``args.target_model`` is a
            non-fused (CLIP) model, whose ``visual.input_resolution`` sets the
            target resize/crop size.

    Returns:
        Tuple ``(test_loader, s_test_transform, t_test_transform)``. The
        dataset is built with ``s_test_transform``. The fused target transform
        starts with ``ToPILImage``, so ``t_test_transform`` appears to be meant
        for already-tensorized (e.g. adversarial) images — confirm at the
        call site.

    Raises:
        ValueError: if a non-fused target model is configured but ``t_model``
            was not supplied.
    """
    if args.source_model in FUSED_MODELS:
        s_test_transform = transforms.Compose(
            [
                transforms.Resize(
                    (config["image_res"], config["image_res"]), interpolation=Image.BICUBIC
                ),
                transforms.ToTensor(),
            ]
        )
    else:
        n_px = model.visual.input_resolution
        s_test_transform = transforms.Compose(
            [
                transforms.Resize(n_px, interpolation=Image.BICUBIC),
                transforms.CenterCrop(n_px),
                transforms.ToTensor(),
            ]
        )

    if args.target_model is not None:
        if args.target_model in FUSED_MODELS:
            t_test_transform = transforms.Compose(
                [
                    transforms.ToPILImage(),
                    transforms.Resize(
                        (config["image_res"], config["image_res"]), interpolation=Image.BICUBIC
                    ),
                    transforms.ToTensor(),
                ]
            )
        else:
            # The target's own input resolution is needed here; previously this
            # referenced an undefined name and raised NameError.
            if t_model is None:
                raise ValueError(
                    f"get_data needs t_model to build the transform for target model {args.target_model!r}"
                )
            t_n_px = t_model.visual.input_resolution
            t_test_transform = transforms.Compose(
                [
                    transforms.Resize(t_n_px, interpolation=Image.BICUBIC),
                    transforms.CenterCrop(t_n_px),
                ]
            )
    else:
        t_test_transform = s_test_transform

    test_dataset = paired_dataset(
        config["test_file"], s_test_transform, config["image_root"]
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        num_workers=4,
        collate_fn=test_dataset.collate_fn,
    )
    return test_loader, s_test_transform, t_test_transform


def main(args, config):
    """Load models, build the data pipeline, evaluate, and save results.

    Reads the module-level ``prompt_config`` and ``LOG_BASE_DIR`` globals set
    in the ``__main__`` section. Writes ``result.json`` and ``args.json``
    under ``LOG_BASE_DIR``.
    """
    device = torch.device("cuda")

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    print("Creating Source Model")
    model, ref_model, tokenizer = load_model(
        config, args.source_model, args.source_ckpt, args.source_text_encoder, device, prompt_config=prompt_config
    )
    if args.target_model is not None:
        print("Creating Target Model")
        t_model, t_ref_model, t_tokenizer = load_model(
            config, args.target_model, args.target_ckpt, args.target_text_encoder, device, prompt_config=prompt_config
        )
    else:
        t_model, t_ref_model, t_tokenizer = model, ref_model, tokenizer

    #### Dataset ####
    print("Creating dataset")
    # Pass the target model so a CLIP target can supply its input resolution.
    test_loader, s_test_transform, t_test_transform = get_data(
        args, config, model, t_model
    )

    result, t_result, result_asr, t_result_asr = eval_asr(
        model,
        ref_model,
        tokenizer,
        t_model,
        t_ref_model,
        t_tokenizer,
        t_test_transform,
        test_loader,
        device,
        args,
        config,
    )

    # save result
    result_path = os.path.join(LOG_BASE_DIR, "result.json")
    with open(result_path, "w") as f:
        json.dump(
            {
                "whitebox": result,
                f"transfer_{args.target_model_name}": t_result,
                "whitebox_ASR": result_asr,
                f"transfer_{args.target_model_name}_ASR": t_result_asr,
            },
            f,
            indent=4,
        )

    # save args
    args_path = os.path.join(LOG_BASE_DIR, "args.json")
    with open(args_path, "w") as f:
        json.dump(vars(args), f, indent=4)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="./configs/Retrieval_flickr.yaml")
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--batch_size", default=8, type=int)

    parser.add_argument("--source_model", default="ALBEF", type=str) # model architecture
    parser.add_argument("--source_model_name", default=None, type=str) # id for the model
    parser.add_argument("--source_text_encoder", default="bert-base-uncased", type=str)
    parser.add_argument("--source_ckpt", default=None, type=str)

    parser.add_argument("--target_model", default=None, type=str) # model architecture
    parser.add_argument("--target_model_name", default=None, type=str) # id for the model
    parser.add_argument("--target_text_encoder", default="bert-base-uncased", type=str)
    parser.add_argument("--target_ckpt", default=None, type=str)

    # parser.add_argument("--original_rank_index_path", default=None, type=str)

    # SGA attack
    parser.add_argument("--scales", type=str, default="0.5,0.75,1.25,1.5")

    # evaluation config
    parser.add_argument("--cls", default=False, action="store_true")
    parser.add_argument("--is_force_recompute_feats", default=False, action="store_true")
    parser.add_argument("--max_n", default=10000, type=int)
    parser.add_argument("--attack_fused_emb", default=False, action="store_true")
    parser.add_argument(
        "--attack",
        default="SGA",
        type=str,
        choices=["SGA", "Co-Attack", "Sep-Attack", "PGD", "BERT", "Clean"],
    )
    parser.add_argument("--eval_save_dir", default="../eval_results", type=str)
    parser.add_argument("--epsilon", default=2.0, type=float)
    parser.add_argument("--alpha", default=3.0, type=float)

    parser.add_argument("--prompt_config", default=None, type=str)

    args = parser.parse_args()

    # The model ids are used to build output paths and result keys; fall back
    # to the architecture name instead of crashing on os.path.join(..., None).
    if args.source_model_name is None:
        args.source_model_name = args.source_model
    if args.target_model is not None and args.target_model_name is None:
        args.target_model_name = args.target_model

    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.Loader)

    FUSED_MODELS = [
        "ALBEF",
        "ALBEF_PT",
        "TCL",
        "TCL_PT",
    ]

    # modify save dir
    if args.max_n < 200:
        print(
            "Warning: max_n is less than 200, which may not be enough for evaluation."
        )
        args.eval_save_dir = os.path.join(args.eval_save_dir, f"max_n_{args.max_n}")

    if args.source_model in FUSED_MODELS:
        print("ALBEF or TCL model, set attack_fused_emb == True, cls == False. (Multi-full)")
        args.attack_fused_emb = True
        args.cls = False
    else:
        print("CLIP model: (Uni-cls)")
        args.attack_fused_emb = False
        args.cls = False

    # prompt config
    prompt_config = None
    vision_ctx = 0
    if "PT" in args.source_model:
        if args.prompt_config is not None:
            with open(args.prompt_config, "r") as f:
                prompt_config = yaml.load(f, Loader=yaml.Loader)
            vision_ctx = prompt_config["vision_ctx"]
            language_ctx = prompt_config["language_ctx"]
            print("load prompt config: ", args.prompt_config)
            prompt_config = edict(prompt_config)
        else:
            # Prefer a prompt_config.json saved next to the source checkpoint.
            # This used to go through an un-imported `glob` inside a bare
            # `except:`, so the NameError was swallowed and the default config
            # below was always used silently.
            ckpt_prompt_config_path = None
            if args.source_ckpt is not None:
                ckpt_prompt_config_path = os.path.join(
                    os.path.dirname(args.source_ckpt), "prompt_config.json"
                )
            if ckpt_prompt_config_path is not None and os.path.isfile(ckpt_prompt_config_path):
                with open(ckpt_prompt_config_path, "r") as f:
                    # edict for attribute access, consistent with the other branches
                    prompt_config = edict(json.load(f))
                print("load prompt config: ", ckpt_prompt_config_path)
            else:
                prompt_config_path = "/data/robust_crossmodal-retrieval/configs/prompt_tuning/base5.yaml"
                with open(prompt_config_path, "r") as f:
                    prompt_tuning_config = yaml.load(f, Loader=yaml.Loader)
                    prompt_config = edict(prompt_tuning_config)["PROMPT"]
            vision_ctx = prompt_config["NUM_TOKENS"]
            language_ctx = 0

    ATTACK = args.attack
    for attack in ["Clean", ATTACK]:
        args.attack = attack
        print(f"Start eval for {args.source_model_name} with {args.attack}")

        # dir to save the results
        LOG_BASE_DIR = os.path.join(args.eval_save_dir, args.source_model_name, args.attack)
        SAVE_FEAT_DIR = os.path.join(LOG_BASE_DIR, "features")
        SAVE_IMG_DIR = os.path.join(LOG_BASE_DIR, "imgs")
        VIS_DIR = os.path.join(LOG_BASE_DIR, "vis")
        RANK_INDEX_DIR = os.path.join(LOG_BASE_DIR, "rank_index")
        CLEAN_RANK_INDEX_DIR = os.path.join(args.eval_save_dir, args.source_model_name, "Clean", "rank_index")

        # Skip the clean pass when its rank index was already computed.
        if attack == "Clean" and os.path.exists(os.path.join(CLEAN_RANK_INDEX_DIR, f"{args.source_model_name}_tr1_rank_index.npy")):
            print("Original rank index exists, use it.")
            continue

        os.makedirs(SAVE_FEAT_DIR, exist_ok=True)
        os.makedirs(LOG_BASE_DIR, exist_ok=True)
        os.makedirs(SAVE_IMG_DIR, exist_ok=True)
        os.makedirs(VIS_DIR, exist_ok=True)
        os.makedirs(RANK_INDEX_DIR, exist_ok=True)
        os.makedirs(CLEAN_RANK_INDEX_DIR, exist_ok=True)

        if args.target_model is not None:
            SAVE_FEAT_DIR_T = os.path.join(SAVE_FEAT_DIR, args.target_model_name)
            os.makedirs(SAVE_FEAT_DIR_T, exist_ok=True)

        main(args, config)
