
import os
import utils
import copy
import time
import ruamel.yaml as yaml
import numpy as np
import argparse
import random
import time
import datetime
import json
from pathlib import Path

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from transformers import BertForMaskedLM
from torchvision import transforms
from PIL import Image

from models.model_retrieval import ALBEF, ALBEF1, ALBEF_i, ALBEF_t
from models.vit import interpolate_pos_embed
from models.tokenization_bert import BertTokenizer
from models import clip

from SEAttacker import Attacker, ImageAttacker, TextAttacker
from dataset import paired_dataset

def _alloc_feature_store(prefix, is_fusion_model, num_image, num_text, feat_dim):
    """Pre-allocate the CPU tensors that collect per-sample features for one model."""
    store = {
        f'{prefix}_image_feats': torch.zeros(num_image, feat_dim),
        f'{prefix}_text_feats': torch.zeros(num_text, feat_dim),
    }
    if is_fusion_model:
        # Sizes below (577 image tokens, 30 text tokens, width 768) are the
        # constants the original code hard-wires for ALBEF/TCL.
        store[f'{prefix}_image_embeds'] = torch.zeros(num_image, 577, 768)
        store[f'{prefix}_text_embeds'] = torch.zeros(num_text, 30, 768)
        store[f'{prefix}_text_atts'] = torch.zeros(num_text, 30).long()
    return store


def _collect_features(store, prefix, is_fusion_model, joint_model, image_model, text_model,
                      images_norm, text_input, images_ids, texts_ids):
    """Run one model on a batch and write its outputs into ``store`` at the sample ids."""
    if is_fusion_model:
        image_out = image_model(images_norm)
        text_out = text_model(text_input.input_ids, text_input.attention_mask)
        store[f'{prefix}_image_feats'][images_ids] = image_out['image_feat'].cpu().detach()
        store[f'{prefix}_image_embeds'][images_ids] = image_out['image_embed'].cpu().detach()
        store[f'{prefix}_text_feats'][texts_ids] = text_out['text_feat'].cpu().detach()
        store[f'{prefix}_text_embeds'][texts_ids] = text_out['text_embed'].cpu().detach()
        store[f'{prefix}_text_atts'][texts_ids] = text_input.attention_mask.cpu().detach()
    else:
        joint_out = joint_model(images_norm, text_input.input_ids)
        store[f'{prefix}_image_feats'][images_ids] = joint_out['image_feat'].cpu().float().detach()
        store[f'{prefix}_text_feats'][texts_ids] = joint_out['text_feat'].cpu().float().detach()


def _score_matrices(store, prefix, is_fusion_model, text_model, num_image, num_text, device):
    """Turn collected features into (image->text, text->image) numpy score matrices."""
    if is_fusion_model:
        i2t, t2i = retrieval_score(text_model, store[f'{prefix}_image_feats'], store[f'{prefix}_image_embeds'],
                                   store[f'{prefix}_text_feats'], store[f'{prefix}_text_embeds'],
                                   store[f'{prefix}_text_atts'], num_image, num_text, device=device)
        return i2t.cpu().numpy(), t2i.cpu().numpy()
    sims = store[f'{prefix}_image_feats'] @ store[f'{prefix}_text_feats'].t()
    return sims.cpu().numpy(), sims.t().cpu().numpy()


def retrieval_eval(model_inf, model_infi, model_inft, ref_model, t_model_inf, t_model_infi, t_model_inft, t_ref_model, t_test_transform, data_loader, tokenizer, t_tokenizer, device, args, config):
    """Attack every batch with the source model and score the result on both models.

    Adversarial images/texts are crafted with the source model triple
    (``model_inf``, ``model_infi``, ``model_inft``) and then fed to both the
    source and the target model; the resulting features are gathered on the CPU
    and converted to retrieval score matrices.

    Args:
        model_inf, model_infi, model_inft: source joint / image / text models.
            ``model_inf`` is accessed without ``.module`` here.
        ref_model: masked-LM handed to ``TextAttacker``.
        t_model_inf, t_model_infi, t_model_inft: target counterparts;
            ``t_model_inf`` is accessed through ``.module``.
        t_ref_model: target masked-LM (only switched to eval mode here).
        t_test_transform: per-image transform applied to each adversarial image
            before it is passed to the target model.
        data_loader: yields ``(images, texts_group, images_ids, text_ids_groups)``;
            its dataset exposes ``text`` and ``ann``.
        tokenizer: tokenizer used for source-model text inputs.
        t_tokenizer: tokenizer used for target-model text inputs.
        device: device the batch is moved to.
        args: reads ``eps``, ``num_steps``, ``source_model``, ``target_model``
            and ``scales``.
        config: reads ``embed_dim`` for ALBEF/TCL models.

    Returns:
        ``(s_i2t, s_t2i, t_i2t, t_t2i)`` numpy score matrices for the source
        and the target model.
    """
    for net in (model_inf, model_infi, model_inft, t_model_inf, t_model_infi, t_model_inft):
        net.float()
        net.eval()
    ref_model.eval()
    t_ref_model.eval()

    print('Initialzing...')

    source_is_fusion = args.source_model in ['ALBEF', 'TCL']
    target_is_fusion = args.target_model in ['ALBEF', 'TCL']

    images_normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    img_attacker = ImageAttacker(images_normalize, eps=args.eps/255, steps=args.num_steps, step_size=0.5/255) #2

    # Text length limit: 30 tokens for ALBEF/TCL, 77 otherwise.
    max_length = 30 if source_is_fusion else 77
    max_length_t = 30 if target_is_fusion else 77
    txt_attacker = TextAttacker(ref_model, tokenizer, cls=False, max_length=max_length, number_perturbation=1,
                                topk=10, threshold_pred_score=0.3)
    attacker = Attacker(model_inf, model_infi, model_inft, img_attacker, txt_attacker)

    print('Prepare memory')
    num_text = len(data_loader.dataset.text)
    num_image = len(data_loader.dataset.ann)

    s_feat_dim = config['embed_dim'] if source_is_fusion else model_inf.visual.output_dim
    s_feat_dict = _alloc_feature_store('s', source_is_fusion, num_image, num_text, s_feat_dim)

    t_feat_dim = config['embed_dim'] if target_is_fusion else t_model_inf.module.visual.output_dim
    t_feat_dict = _alloc_feature_store('t', target_is_fusion, num_image, num_text, t_feat_dim)

    if args.scales is not None:
        scales = [float(itm) for itm in args.scales.split(',')]
        print(scales)
    else:
        scales = None

    print('Forward')
    for batch_idx, (images, texts_group, images_ids, text_ids_groups) in enumerate(data_loader):
        print(f'--------------------> batch:{batch_idx}/{len(data_loader)}')
        # Flatten the per-image caption groups and remember which image (by
        # position in the batch) each caption belongs to.
        texts, texts_ids, txt2img = [], [], []
        for img_pos, (group_texts, group_ids) in enumerate(zip(texts_group, text_ids_groups)):
            texts += group_texts
            texts_ids += group_ids
            txt2img += [img_pos] * len(group_ids)

        images = images.to(device)

        # The attack itself runs outside no_grad.
        # NOTE(review): the keyword is spelled `max_lemgth`; kept as-is because
        # the Attacker.attack signature is not visible here — confirm.
        adv_images, adv_texts = attacker.attack(images, texts, txt2img, device=device, max_lemgth=max_length, scales=scales)

        with torch.no_grad():
            # Source model: adversarial images are used at the attack resolution.
            s_adv_images_norm = images_normalize(adv_images)
            adv_texts_input = tokenizer(adv_texts, padding='max_length', truncation=True, max_length=max_length,
                                        return_tensors="pt").to(device)
            _collect_features(s_feat_dict, 's', source_is_fusion, model_inf, model_infi, model_inft,
                              s_adv_images_norm, adv_texts_input, images_ids, texts_ids)

            # Target model: re-transform every adversarial image first.
            t_adv_imgs = torch.stack([t_test_transform(itm) for itm in adv_images], 0).to(device)
            t_adv_images_norm = images_normalize(t_adv_imgs)
            # Fix: tokenize with the target model's own tokenizer (the source
            # tokenizer was used before, leaving `t_tokenizer` unused).
            adv_texts_input_t = t_tokenizer(adv_texts, padding='max_length', truncation=True, max_length=max_length_t,
                                            return_tensors="pt").to(device)
            _collect_features(t_feat_dict, 't', target_is_fusion, t_model_inf, t_model_infi, t_model_inft,
                              t_adv_images_norm, adv_texts_input_t, images_ids, texts_ids)

    s_score_matrix_i2t, s_score_matrix_t2i = _score_matrices(
        s_feat_dict, 's', source_is_fusion, model_inft, num_image, num_text, device)
    t_score_matrix_i2t, t_score_matrix_t2i = _score_matrices(
        t_feat_dict, 't', target_is_fusion, t_model_inft, num_image, num_text, device)

    return s_score_matrix_i2t, s_score_matrix_t2i, t_score_matrix_i2t, t_score_matrix_t2i

@torch.no_grad()
def retrieval_score(model, image_feats, image_embeds, text_feats, text_embeds, text_atts, num_image, num_text, device=None, k_test=None):
    """Re-rank the top-k similarity candidates with the model's ITM head.

    For every image the ``k_test`` most similar texts (by feature dot product)
    are passed through ``text_encoder`` in ``mode='fusion'`` and scored by
    ``itm_head``; the same is done for every text against its top images.
    Pairs outside the top-k keep the fill value ``-100.0``.

    Args:
        model: object exposing ``text_encoder`` and ``itm_head``, either
            directly or through a ``.module`` attribute (e.g. ``nn.DataParallel``).
        image_feats, text_feats: per-sample features used for the coarse
            similarity ``image_feats @ text_feats.t()``.
        image_embeds, text_embeds, text_atts: per-sample encoder states and
            text attention masks fed to the fusion encoder.
        num_image, num_text: sizes of the returned score matrices.
        device: device for the fusion pass; defaults to ``image_embeds.device``.
        k_test: number of candidates to re-rank. Defaults to the module-level
            ``config['k_test']``, which only exists when the file is run as a
            script.

    Returns:
        ``(score_matrix_i2t, score_matrix_t2i)`` tensors on ``device`` with
        shapes ``(num_image, num_text)`` and ``(num_text, num_image)``.
    """
    if device is None:
        device = image_embeds.device
    if k_test is None:
        k_test = config['k_test']

    # Accept both a wrapped (DataParallel-style) and a bare model.
    fusion_model = getattr(model, 'module', model)

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Evaluation Direction Similarity With Bert Attack:'

    sims_matrix = image_feats @ text_feats.t()
    score_matrix_i2t = torch.full((num_image, num_text), -100.0).to(device)

    # topk raises when k exceeds the number of candidates, so clamp it.
    k_text = min(k_test, sims_matrix.size(1))
    for i, sims in enumerate(metric_logger.log_every(sims_matrix, 50, header)):
        _, topk_idx = sims.topk(k=k_text, dim=0)

        # One copy of the image states per candidate text.
        encoder_output = image_embeds[i].repeat(k_text, 1, 1).to(device)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device)
        output = fusion_model.text_encoder(encoder_embeds=text_embeds[topk_idx].to(device),
                                           attention_mask=text_atts[topk_idx].to(device),
                                           encoder_hidden_states=encoder_output,
                                           encoder_attention_mask=encoder_att,
                                           return_dict=True,
                                           mode='fusion'
                                           )
        # Column 1 of the ITM head output is used as the score.
        score = fusion_model.itm_head(output.last_hidden_state[:, 0, :])[:, 1]
        score_matrix_i2t[i, topk_idx] = score

    sims_matrix = sims_matrix.t()
    score_matrix_t2i = torch.full((num_text, num_image), -100.0).to(device)

    k_image = min(k_test, sims_matrix.size(1))
    for i, sims in enumerate(metric_logger.log_every(sims_matrix, 50, header)):
        _, topk_idx = sims.topk(k=k_image, dim=0)
        encoder_output = image_embeds[topk_idx].to(device)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device)
        # One copy of the text states/mask per candidate image.
        output = fusion_model.text_encoder(encoder_embeds=text_embeds[i].repeat(k_image, 1, 1).to(device),
                                           attention_mask=text_atts[i].repeat(k_image, 1).to(device),
                                           encoder_hidden_states=encoder_output,
                                           encoder_attention_mask=encoder_att,
                                           return_dict=True,
                                           mode='fusion'
                                           )
        score = fusion_model.itm_head(output.last_hidden_state[:, 0, :])[:, 1]
        score_matrix_t2i[i, topk_idx] = score

    return score_matrix_i2t, score_matrix_t2i


def _rank_index_stem(model_name, image_encoder):
    """Return the file-name prefix of the stored baseline rank-index files."""
    if model_name == 'CLIP':
        backbone = 'CNN' if image_encoder in ['RN50', 'RN101'] else 'ViT'
        return f'{model_name}_{backbone}'
    if model_name in ('ALBEF', 'TCL'):
        return model_name
    raise ValueError(
        f"No baseline rank-index files are defined for model {model_name!r}; "
        "expected one of 'CLIP', 'ALBEF', 'TCL'."
    )


def _load_rank_indices(rank_index_dir, stem, direction):
    """Load the baseline hit indices for R@1, R@5 and R@10 of one direction."""
    return [
        np.load(os.path.join(rank_index_dir, f'{stem}_{direction}{k}_rank_index.npy'))
        for k in (1, 5, 10)
    ]


def _attack_success_rate(origin_hits, after_hits):
    """Percentage of baseline hits that are no longer hits, rounded to 2 places."""
    if len(origin_hits) == 0:
        # Nothing was retrieved correctly to begin with, so nothing can be broken.
        return 0.0
    return round(100.0 * len(np.setdiff1d(origin_hits, after_hits)) / len(origin_hits), 2)


@torch.no_grad()
def itm_eval(scores_i2t, scores_t2i, img2txt, txt2img, model_name, image_encoder, original_rank_index_path=None):
    """Compute recall@{1,5,10} and attack success rates for both directions.

    The attack success rate (ASR) at k is the share of query indices stored in
    the baseline ``*_rank_index.npy`` files (presumably the hits of a clean
    run) that are not within the top k of the given scores.

    Args:
        scores_i2t: array indexed ``[image, text]``; higher is more similar.
        scores_t2i: array indexed ``[text, image]``.
        img2txt: for each image index, the iterable of matching text indices.
        txt2img: for each text index, the matching image index.
        model_name: ``'CLIP'``, ``'ALBEF'`` or ``'TCL'``; selects the baseline files.
        image_encoder: for CLIP, ``'RN50'``/``'RN101'`` select the ``CNN``
            files, anything else the ``ViT`` files.
        original_rank_index_path: directory holding the baseline files.
            Defaults to the module-level ``args.original_rank_index_path``.

    Returns:
        dict mapping ``'<metric>_ASR (<metric>)'`` to the string
        ``'<asr>(<recall>)'`` for txt/img r1, r5 and r10.

    Raises:
        ValueError: if ``model_name`` has no baseline files defined.
    """
    if original_rank_index_path is None:
        original_rank_index_path = args.original_rank_index_path

    # Fail fast on unsupported models instead of an UnboundLocalError later.
    stem = _rank_index_stem(model_name, image_encoder)
    cutoffs = (1, 5, 10)

    # Images->Text: rank of the best-placed ground-truth caption per image.
    ranks = np.zeros(scores_i2t.shape[0])
    for index, score in enumerate(scores_i2t):
        inds = np.argsort(score)[::-1]
        ranks[index] = min((np.where(inds == i)[0][0] for i in img2txt[index]), default=1e20)

    tr1, tr5, tr10 = (100.0 * len(np.where(ranks < k)[0]) / len(ranks) for k in cutoffs)
    after_attack_tr = [np.where(ranks < k)[0] for k in cutoffs]
    origin_tr = _load_rank_indices(original_rank_index_path, stem, 'tr')
    asr_tr1, asr_tr5, asr_tr10 = (
        _attack_success_rate(origin, after) for origin, after in zip(origin_tr, after_attack_tr)
    )

    # Text->Images: rank of the single ground-truth image per caption.
    ranks = np.zeros(scores_t2i.shape[0])
    for index, score in enumerate(scores_t2i):
        inds = np.argsort(score)[::-1]
        ranks[index] = np.where(inds == txt2img[index])[0][0]

    ir1, ir5, ir10 = (100.0 * len(np.where(ranks < k)[0]) / len(ranks) for k in cutoffs)
    after_attack_ir = [np.where(ranks < k)[0] for k in cutoffs]
    origin_ir = _load_rank_indices(original_rank_index_path, stem, 'ir')
    asr_ir1, asr_ir5, asr_ir10 = (
        _attack_success_rate(origin, after) for origin, after in zip(origin_ir, after_attack_ir)
    )

    eval_result = {'txt_r1_ASR (txt_r1)': f'{asr_tr1}({tr1})',
                   'txt_r5_ASR (txt_r5)': f'{asr_tr5}({tr5})',
                   'txt_r10_ASR (txt_r10)': f'{asr_tr10}({tr10})',
                   'img_r1_ASR (img_r1)': f'{asr_ir1}({ir1})',
                   'img_r5_ASR (img_r5)': f'{asr_ir5}({ir5})',
                   'img_r10_ASR (img_r10)': f'{asr_ir10}({ir10})'}
    return eval_result


def load_model(model_name, model_ckpt, image_encoder, text_encoder, device):
    """Build the joint / image-only / text-only model triple plus text tooling.

    Args:
        model_name: ``'ALBEF'`` or ``'TCL'`` build the ALBEF-style classes and
            load ``model_ckpt``; any other value goes through ``clip.load1``.
        model_ckpt: checkpoint path (only read for ALBEF/TCL).
        image_encoder: encoder name handed to ``clip.load1``.
        text_encoder: name/path used for ``BertTokenizer`` and
            ``BertForMaskedLM``.
        device: device passed to ``clip.load1``. ALBEF/TCL models are not
            moved to a device here.

    Returns:
        ``(model_inf, model_infi, model_inft, ref_model, tokenizer)``.

    Note:
        The ALBEF/TCL branch reads the module-level ``config`` that the script
        entry point defines.
    """
    tokenizer = BertTokenizer.from_pretrained(text_encoder)
    ref_model = BertForMaskedLM.from_pretrained(text_encoder)

    if model_name not in ['ALBEF', 'TCL']:
        # Weights come with clip.load1, so there is no checkpoint to load.
        model_inf, model_infi, model_inft, _, _, _ = clip.load1(image_encoder, device=device)
        for clip_model in (model_inf, model_infi, model_inft):
            clip_model.set_tokenizer(tokenizer)
        return model_inf, model_infi, model_inft, ref_model, tokenizer

    model_inf = ALBEF1(config=config, text_encoder=text_encoder, tokenizer=tokenizer)
    model_infi = ALBEF_i(config=config, text_encoder=text_encoder, tokenizer=tokenizer)
    model_inft = ALBEF_t(config=config, text_encoder=text_encoder, tokenizer=tokenizer)
    checkpoint = torch.load(model_ckpt, map_location='cpu')

    # Checkpoints either wrap the weights under 'model' or are the state dict
    # itself; test for the key instead of swallowing every exception.
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    if model_name == 'TCL':
        # Resize the stored position embeddings to the instantiated encoders.
        pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'], model_inf.visual_encoder)
        state_dict['visual_encoder.pos_embed'] = pos_embed_reshaped
        m_pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], model_inf.visual_encoder_m)
        state_dict['visual_encoder_m.pos_embed'] = m_pos_embed_reshaped

    # Drop the 'bert.' path component from parameter names. Only keys that
    # really contain 'bert.' are renamed: matching on plain 'bert' would leave
    # the name unchanged and the follow-up delete would then remove the weight.
    for key in list(state_dict.keys()):
        if 'bert.' in key:
            state_dict[key.replace('bert.', '')] = state_dict.pop(key)

    # strict=False: each of the three models takes only the keys it defines.
    model_inf.load_state_dict(state_dict, strict=False)
    model_infi.load_state_dict(state_dict, strict=False)
    model_inft.load_state_dict(state_dict, strict=False)

    return model_inf, model_infi, model_inft, ref_model, tokenizer


def eval_asr(model_inf, model_infi, model_inft, ref_model, tokenizer, t_model_inf, t_model_infi, t_model_inft, t_ref_model, t_tokenizer, t_test_transform, data_loader, device, args, config):
    """Run the attack evaluation once, print the metrics and append them to the log.

    Score matrices for the source and the target model come from
    ``retrieval_eval``; each pair is turned into metrics by ``itm_eval``.
    Two JSON lines (source first, then target) are appended to
    ``<args.output_dir>/<args.adv>_log_CLIP.txt``. Returns nothing.
    """
    print("Start eval")
    started_at = time.time()

    source_i2t, source_t2i, target_i2t, target_t2i = retrieval_eval(
        model_inf, model_infi, model_inft, ref_model,
        t_model_inf, t_model_infi, t_model_inft, t_ref_model, t_test_transform,
        data_loader, tokenizer, t_tokenizer, device, args, config,
    )

    dataset = data_loader.dataset

    # The target model is evaluated and reported before the source model.
    t_result = itm_eval(target_i2t, target_t2i, dataset.img2txt, dataset.txt2img,
                        args.target_model, args.target_image_encoder)
    print('Performance on {}{}: \n {}'.format(args.target_model, args.target_image_encoder, t_result))

    result = itm_eval(source_i2t, source_t2i, dataset.img2txt, dataset.txt2img,
                      args.source_model, args.source_image_encoder)
    print('Performance on {}: \n {}'.format(args.source_model, result))

    target_record = {f'test_{name}': value for name, value in t_result.items()}
    target_record['eval type'] = args.adv
    target_record['eps'] = args.eps
    target_record['iters'] = args.num_steps

    source_record = {f'test_{name}': value for name, value in result.items()}
    source_record['eval type'] = args.adv
    source_record['eps'] = args.eps
    source_record['iters'] = args.num_steps

    log_path = os.path.join(args.output_dir, str(args.adv) + "_log_CLIP.txt")
    with open(log_path, "a+") as log_file:
        for record in (source_record, target_record):
            log_file.write(json.dumps(record) + "\n")

    torch.cuda.empty_cache()

    elapsed_seconds = int(time.time() - started_at)
    total_time_str = str(datetime.timedelta(seconds=elapsed_seconds))
    print('Evaluate time {}'.format(total_time_str))

def _apply_dataset_presets(args):
    """Fill in checkpoint, rank-index, image-root and annotation paths for ``args.dataset``."""
    presets = {
        'flickr30': {
            'albef_ckpt': 'XXX/Datasets/ALBEF-main/checkpoint/flickr30k.pth',
            'tcl_ckpt': 'XXX/Datasets/TCL-main/checkpoint/checkpoint_retrieval_flickr_finetune.pth',
            'rank_index': 'XXX/noise/cross_modal_attack/VLA-main/std_eval_idx/flickr30k/',
            'image_root': 'XXX/Datasets/flickr30k_images/',
            'test_file': 'XXX/Datasets/data/flickr30k_test.json',
        },
        'mscoco': {
            'albef_ckpt': 'XXX/Datasets/ALBEF-main/checkpoint/mscoco.pth',
            'tcl_ckpt': 'XXX/Datasets/TCL-main/checkpoint/checkpoint_retrieval_coco_finetune.pth',
            'rank_index': 'XXX/noise/cross_modal_attack/VLA-main/std_eval_idx/mscoco',
            'image_root': 'XXX/Datasets/MSCOCO/',
            'test_file': 'XXX/Datasets/data/coco_test.json',
        },
    }
    preset = presets.get(args.dataset)
    if preset is None:
        return

    # 'CLIP' is assigned the ALBEF checkpoint path; load_model only reads the
    # checkpoint for ALBEF/TCL.
    ckpt_by_model = {
        'ALBEF': preset['albef_ckpt'],
        'CLIP': preset['albef_ckpt'],
        'TCL': preset['tcl_ckpt'],
    }
    if args.source_model in ckpt_by_model:
        args.source_ckpt = ckpt_by_model[args.source_model]
    if args.target_model in ckpt_by_model:
        args.target_ckpt = ckpt_by_model[args.target_model]

    args.original_rank_index_path = preset['rank_index']
    args.target_dataset_root = preset['image_root']
    args.test_file = preset['test_file']


def main(args, config):
    """Load the source and target models, build the test loader and run ``eval_asr``.

    Args:
        args: parsed command-line namespace; dataset-dependent paths are
            written onto it before the models are created.
        config: configuration mapping forwarded to ``eval_asr``.
    """
    torch.cuda.set_device(args.cuda_id)
    device = torch.device('cuda')

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    _apply_dataset_presets(args)
    # No preset assigns a target checkpoint for models other than
    # ALBEF/CLIP/TCL; mirror the '' default that --source_ckpt has so the
    # attribute always exists.
    if not hasattr(args, 'target_ckpt'):
        args.target_ckpt = ''

    print("Creating Source Model")
    model_inf, model_infi, model_inft, ref_model, tokenizer = load_model(args.source_model, args.source_ckpt, args.source_image_encoder, args.source_text_encoder, device)

    print("Creating Target Model")
    t_model_inf, t_model_infi, t_model_inft, t_ref_model, t_tokenizer = load_model(args.target_model, args.target_ckpt, args.target_image_encoder, args.target_text_encoder, device)

    # NOTE(review): the source `model_inf` is neither wrapped nor moved to
    # `device` here (retrieval_eval reads `model_inf.visual` without `.module`).
    # For ALBEF/TCL sources it therefore stays wherever load_model left it —
    # confirm that Attacker does not need it on the GPU.
    model_infi = nn.DataParallel(model_infi.to(device))
    model_inft = nn.DataParallel(model_inft.to(device))
    ref_model = nn.DataParallel(ref_model.to(device))

    t_model_inf = nn.DataParallel(t_model_inf.to(device))
    t_model_infi = nn.DataParallel(t_model_infi.to(device))
    t_model_inft = nn.DataParallel(t_model_inft.to(device))
    t_ref_model = nn.DataParallel(t_ref_model.to(device))

    print("Creating dataset")
    # Source-side transform: applied by the dataset to the loaded images.
    if args.source_model in ['ALBEF', 'TCL']:
        s_test_transform = transforms.Compose([
            transforms.Resize((384, 384), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
        ])
    else:
        n_px = model_inf.visual.input_resolution
        s_test_transform = transforms.Compose([
            transforms.Resize(n_px, interpolation=Image.BICUBIC),
            transforms.CenterCrop(n_px),
            transforms.ToTensor(),
        ])

    # Target-side transform: applied in retrieval_eval to each adversarial
    # image tensor before it is given to the target model.
    if args.target_model in ['ALBEF', 'TCL']:
        t_test_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((384, 384), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
        ])
    else:
        t_n_px = t_model_inf.module.visual.input_resolution
        t_test_transform = transforms.Compose([
            transforms.Resize(t_n_px, interpolation=Image.BICUBIC),
            transforms.CenterCrop(t_n_px),
        ])

    test_dataset = paired_dataset(args.test_file, s_test_transform, args.target_dataset_root)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=4, collate_fn=test_dataset.collate_fn)

    eval_asr(model_inf, model_infi, model_inft, ref_model, tokenizer, t_model_inf, t_model_infi, t_model_inft, t_ref_model, t_tokenizer, t_test_transform, test_loader, device, args, config)

    
def build_arg_parser():
    """Create the command-line parser for the attack evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/Retrieval_flickr.yaml')
    # choices must hold ints: argparse compares them with the value AFTER the
    # type=int conversion, so the former single string rejected every seed.
    parser.add_argument('--seed', default=42, type=int, choices=[42, 43, 44, 45, 46])
    parser.add_argument('--batch_size', default=4, type=int)
    parser.add_argument('--cuda_id', default=0, type=int)

    parser.add_argument('--method', default='SEA')
    parser.add_argument('--adv', default=13, type=int,
                        help='0=clean, 1=adv text, 2=adv image, 3=adv text and adv image,11')
    parser.add_argument('--dataset', default='flickr30', choices=['mscoco', 'flickr30'])
    parser.add_argument('--eps', default=2, type=int)

    parser.add_argument('--source_model', default='CLIP', choices=['ALBEF', 'TCL', 'CLIP', 'BLIP'])
    parser.add_argument('--source_image_encoder', default='ViT-B/16', type=str, choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101'])
    parser.add_argument('--source_text_encoder', default='bert-base-uncased', type=str)
    parser.add_argument('--source_ckpt', default='', type=str)

    parser.add_argument('--target_model', default='TCL', choices=['ALBEF', 'CLIP', 'TCL', 'BLIP'])
    parser.add_argument('--target_image_encoder', default='ViT-B/16', choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101'])
    parser.add_argument('--target_text_encoder', default='bert-base-uncased', type=str)
    # Counterpart of --source_ckpt; without it args.target_ckpt is undefined
    # whenever no dataset preset assigns one.
    parser.add_argument('--target_ckpt', default='', type=str)

    parser.add_argument('--original_rank_index_path', default='./std_eval_idx/flickr30k/')
    parser.add_argument('--scales', type=str, default='0.5,0.75,1.25,1.5')
    return parser


if __name__ == '__main__':
    # `args` and `config` are deliberately module-level: other functions in
    # this file read them as globals.
    args = build_arg_parser().parse_args()

    # The number of attack iterations is fixed rather than configurable.
    args.num_steps = 10

    args.output_dir = os.path.join('output', args.method, str(args.source_model) +'_'+ str(args.source_image_encoder), str(args.target_model) +'_'+ str(args.target_image_encoder), str(args.dataset) , str(args.eps) +'_'+ str(args.batch_size))

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Use a separate name so the `yaml` module is not shadowed, and close the
    # config file once it has been parsed.
    yaml_loader = yaml.YAML(typ="safe", pure=True)
    with open(args.config, 'r') as config_file:
        config = yaml_loader.load(config_file)

    main(args, config)
