
import argparse
import os

import ruamel.yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path

# Restrict this process to GPUs 0 and 1. The variable is set here, before
# `import torch` below, so it is already in place when torch is loaded.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
#os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler

from transformers import BertForMaskedLM
from torchvision import transforms
from PIL import Image

from models.model_retrieval import ALBEF
from models.vit import interpolate_pos_embed
from models.tokenization_bert import BertTokenizer
from models import clip
from models.discriminator import domain_classifier

import utils

from attacker import SGAttacker, ImageAttacker, TextAttacker
from sklearn.cluster import KMeans
from dataset import paired_dataset, pair_dataset3


def retrieval_eval(model_inf, model_infi, model_inft, ref_model, t_model, t_ref_model, t_test_transform, data_loader, test_dataset1, tokenizer, t_tokenizer, device, config):
    """Attack every batch with the source models and score it on source and target.

    For each batch of ``data_loader`` an adversarial (images, texts) pair is
    produced by ``SGAttacker`` (``attack7`` when ``args.adv == 2``, ``attack8``
    when ``args.adv == 4``). The adversarial pair is encoded by the source
    model ``model_inf`` and by the target model ``t_model``; the features are
    stored by dataset index and multiplied into image-text similarity matrices.

    Besides its parameters this function reads the module-level ``args``
    namespace (``adv``, ``eps``, ``num_steps``, ``scales``, ``scales_num``).

    Args:
        model_inf, model_infi, model_inft: Source models handed to ``SGAttacker``.
            ``model_inf`` is also used to encode the adversarial samples and
            must expose ``.module.visual.output_dim`` (i.e. be wrapped).
        ref_model: Masked-LM reference model given to ``TextAttacker``.
        t_model: Target model; must expose ``.module.visual.output_dim``.
        t_ref_model: Only switched to eval mode here.
        t_test_transform: Transform applied to each adversarial image tensor
            before it is fed to ``t_model``.
        data_loader: Yields ``(images, texts_group, images_ids, text_ids_groups)``;
            its dataset must expose ``text`` and ``ann``.
        test_dataset1: Loader whose batches are 5-tuples; the third item is read
            once per batch.
        tokenizer: Tokenizer given to ``TextAttacker``.
        t_tokenizer: Unused in this function.
        device: Device the images and token ids are moved to.
        config: Unused in this function.

    Returns:
        Tuple ``(s_i2t, s_t2i, t_i2t, t_t2i)`` of numpy arrays: the source
        similarity matrix, its transpose, the target similarity matrix and
        its transpose.

    Raises:
        ValueError: If ``args.adv`` is not one of the implemented modes (2, 4).
    """
    # Fail fast: only these two attack modes are implemented below.
    if args.adv not in (2, 4):
        raise ValueError(
            f"Unsupported attack mode args.adv={args.adv!r}; retrieval_eval implements 2 (attack7) and 4 (attack8)."
        )

    # float32 + inference mode for every model that is run here.
    for net in (model_inf, model_infi, model_inft, t_model):
        net.float()
    for net in (model_inf, model_infi, model_inft, ref_model, t_model, t_ref_model):
        net.eval()

    print('Computing features for evaluation adv...')

    # NOTE(review): these constants look like CLIP's preprocessing mean/std - confirm.
    images_normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    img_attacker = ImageAttacker(images_normalize, eps=args.eps/255, steps=args.num_steps, step_size=0.5/255)  #10

    txt_attacker = TextAttacker(ref_model, tokenizer, device, cls=False, max_length=77, number_perturbation=1,
                                topk=10, threshold_pred_score=0.3)
    attacker = SGAttacker(model_inf, model_infi, model_inft, img_attacker, txt_attacker)

    print('Prepare memory')
    num_text = len(data_loader.dataset.text)
    num_image = len(data_loader.dataset.ann)

    # Feature banks, filled batch by batch at the dataset indices of each sample.
    s_dim = model_inf.module.visual.output_dim
    t_dim = t_model.module.visual.output_dim
    s_image_feats = torch.zeros(num_image, s_dim)
    s_text_feats = torch.zeros(num_text, s_dim)
    t_image_feats = torch.zeros(num_image, t_dim)
    t_text_feats = torch.zeros(num_text, t_dim)

    if args.scales is not None:
        scales = [float(itm) for itm in args.scales.split(',')]
        print(scales)
    else:
        scales = None

    # Created lazily on the first batch and restarted whenever it is exhausted.
    gen_iter = None

    for batch_idx, (images, texts_group, images_ids, text_ids_groups) in enumerate(data_loader):
        print(f'--------------------> batch:{batch_idx}/{len(data_loader)}')

        # Flatten the per-image caption groups; txt2img maps every caption to
        # the position of its image inside this batch.
        texts, texts_ids, txt2img = [], [], []
        for img_pos, (captions, caption_ids) in enumerate(zip(texts_group, text_ids_groups)):
            texts += captions
            texts_ids += caption_ids
            txt2img += [img_pos] * len(caption_ids)
        images = images.to(device)

        # One batch of generated images is drawn per step, as before.
        # NOTE(review): the drawn batch is not consumed by the attack7/attack8
        # calls below; it is still fetched so the loader advances as it used to.
        if gen_iter is None:
            gen_iter = iter(test_dataset1)
        try:
            _, _, gen_image, _, _ = next(gen_iter)
        except StopIteration:
            gen_iter = iter(test_dataset1)
            _, _, gen_image, _, _ = next(gen_iter)
        g_images = gen_image  # noqa: F841 - kept for parity, see note above

        # The keyword ``max_lemgth`` is spelled exactly as in the existing call.
        # NOTE(review): confirm it matches the attacker's signature.
        if args.adv == 2:
            adv_images, adv_texts = attacker.attack7(images, texts, txt2img, args.scales_num, args, device=device,
                                                     max_lemgth=30, scales=scales)
        else:  # args.adv == 4, validated at the top of the function
            adv_images, adv_texts = attacker.attack8(images, texts, txt2img, args.scales_num, args, device=device,
                                                     max_lemgth=30, scales=scales)

        adv_tokens = txt_attacker.tokenizer(adv_texts, padding='max_length', truncation=True, max_length=77, return_tensors="pt")
        adv_input_ids = adv_tokens.input_ids.to(device)

        with torch.no_grad():
            # Source model on the adversarial pair.
            adv_images_norm = images_normalize(adv_images)
            output = model_inf(adv_images_norm, adv_input_ids)
            s_image_feats[images_ids] = output['image_feat'].cpu().float().detach()
            s_text_feats[texts_ids] = output['text_feat'].cpu().float().detach()

            # Target model: re-apply the target transform to each adversarial
            # image before normalising.
            t_adv_imgs = torch.stack([t_test_transform(itm) for itm in adv_images], 0).to(device)
            t_adv_images_norm = images_normalize(t_adv_imgs)
            output = t_model(t_adv_images_norm, adv_input_ids)

            t_image_feats[images_ids] = output['image_feat'].cpu().float().detach()
            t_text_feats[texts_ids] = output['text_feat'].cpu().float().detach()

    s_sims_matrix = s_image_feats @ s_text_feats.t()
    t_sims_matrix = t_image_feats @ t_text_feats.t()

    return (
        s_sims_matrix.cpu().numpy(),
        s_sims_matrix.t().cpu().numpy(),
        t_sims_matrix.cpu().numpy(),
        t_sims_matrix.t().cpu().numpy(),
    )
        

@torch.no_grad()
def retrieval_score(model, image_feats, image_embeds, text_feats, text_embeds, text_atts, num_image, num_text, device=None):
    """Re-score the top-k candidates of each query with the fusion encoder.

    A feature similarity matrix ``image_feats @ text_feats.t()`` selects, for
    every image (and then for every text), the ``config['k_test']`` best
    candidates. Each (query, candidate) pair is passed through
    ``model.text_encoder`` in ``'fusion'`` mode and column 1 of
    ``model.itm_head`` is written into the score matrix. Entries that were
    never re-scored keep the fill value ``-100.0``.

    Reads ``k_test`` from the module-level ``config``.

    Args:
        model: Model exposing ``text_encoder`` and ``itm_head``.
        image_feats, text_feats: Features used for candidate selection.
        image_embeds, text_embeds, text_atts: Inputs of the fusion encoder.
        num_image, num_text: Dimensions of the returned matrices.
        device: Target device; defaults to ``image_embeds.device``.

    Returns:
        ``(score_matrix_i2t, score_matrix_t2i)`` with shapes
        ``(num_image, num_text)`` and ``(num_text, num_image)``.
    """
    if device is None:
        device = image_embeds.device

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Evaluation Direction Similarity With Bert Attack:'

    def itm_logit(txt_embeds, txt_mask, img_states):
        # Fuse text with image states and read column 1 of the ITM head.
        img_states = img_states.to(device)
        img_mask = torch.ones(img_states.size()[:-1], dtype=torch.long).to(device)
        fused = model.text_encoder(
            encoder_embeds=txt_embeds.to(device),
            attention_mask=txt_mask.to(device),
            encoder_hidden_states=img_states,
            encoder_attention_mask=img_mask,
            return_dict=True,
            mode='fusion',
        )
        return model.itm_head(fused.last_hidden_state[:, 0, :])[:, 1]

    # ---- image -> text ----
    i2t_sims = image_feats @ text_feats.t()
    score_matrix_i2t = torch.full((num_image, num_text), -100.0).to(device)
    for img_idx, row in enumerate(metric_logger.log_every(i2t_sims, 50, header)):
        k = config['k_test']
        _, candidates = row.topk(k=k, dim=0)
        # The single image is tiled k times to pair with its k candidate texts.
        score_matrix_i2t[img_idx, candidates] = itm_logit(
            text_embeds[candidates],
            text_atts[candidates],
            image_embeds[img_idx].repeat(k, 1, 1),
        )

    # ---- text -> image ----
    t2i_sims = i2t_sims.t()
    score_matrix_t2i = torch.full((num_text, num_image), -100.0).to(device)
    for txt_idx, row in enumerate(metric_logger.log_every(t2i_sims, 50, header)):
        k = config['k_test']
        _, candidates = row.topk(k=k, dim=0)
        # Here the single text is tiled k times to pair with its k candidate images.
        score_matrix_t2i[txt_idx, candidates] = itm_logit(
            text_embeds[txt_idx].repeat(k, 1, 1),
            text_atts[txt_idx].repeat(k, 1),
            image_embeds[candidates],
        )

    return score_matrix_i2t, score_matrix_t2i



@torch.no_grad()
def itm_eval(scores_i2t, scores_t2i, img2txt, txt2img, model_name, image_encoder):
    """Compute recall@{1,5,10} and attack success rates for both retrieval directions.

    For every image the rank of its best-ranked ground-truth text is computed
    from ``scores_i2t``; for every text the rank of its ground-truth image is
    computed from ``scores_t2i``. A query counts as a hit at ``k`` when its
    rank is below ``k``. The attack success rate (ASR) at ``k`` is the
    percentage of indices stored in the corresponding "original rank index"
    ``.npy`` file that are no longer hits, rounded to two decimals.

    The index files are read from ``args.original_rank_index_path`` (module
    level ``args``) and are named
    ``<prefix>_<tr|ir><k>_rank_index.npy`` where ``<prefix>`` is
    ``CLIP_CNN`` (CLIP with ``RN50``/``RN101``), ``CLIP_ViT`` (any other CLIP
    encoder), ``ALBEF`` or ``TCL``.

    Args:
        scores_i2t: 2-D array, one row of text scores per image.
        scores_t2i: 2-D array, one row of image scores per text.
        img2txt: Maps an image index to an iterable of ground-truth text indices.
        txt2img: Maps a text index to its ground-truth image index.
        model_name: ``'CLIP'``, ``'ALBEF'`` or ``'TCL'``.
        image_encoder: Encoder name; only consulted for ``'CLIP'``.

    Returns:
        Dict with six entries whose values are strings of the form
        ``"<asr>(<recall>)"``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    original_rank_index_path = args.original_rank_index_path

    # Resolve the file-name prefix once; unsupported names fail here with a
    # clear message instead of an unbound-variable error further down.
    if model_name == 'CLIP':
        if image_encoder in ('RN50', 'RN101'):
            print(model_name)
            file_prefix = f'{model_name}_CNN'
        else:
            file_prefix = f'{model_name}_ViT'
    elif model_name in ('ALBEF', 'TCL'):
        file_prefix = model_name
    else:
        raise ValueError(
            f"No original rank index files are known for model_name={model_name!r}; "
            "expected 'CLIP', 'ALBEF' or 'TCL'."
        )

    def summarise(ranks, direction):
        # Returns ({k: recall}, {k: asr}) for k in 1, 5, 10.
        recalls, asrs = {}, {}
        for k in (1, 5, 10):
            hits = np.where(ranks < k)[0]
            origin = np.load(f'{original_rank_index_path}/{file_prefix}_{direction}{k}_rank_index.npy')
            recalls[k] = 100.0 * len(hits) / len(ranks)
            asrs[k] = round(100.0 * len(np.setdiff1d(origin, hits)) / len(origin), 2)
        return recalls, asrs

    # Images->Text: best (lowest) rank among all ground-truth texts of the image.
    ranks = np.zeros(scores_i2t.shape[0])
    for index, score in enumerate(scores_i2t):
        inds = np.argsort(score)[::-1]
        ranks[index] = min((np.where(inds == i)[0][0] for i in img2txt[index]), default=1e20)
    tr, asr_tr = summarise(ranks, 'tr')

    # Text->Images: rank of the single ground-truth image.
    ranks = np.zeros(scores_t2i.shape[0])
    for index, score in enumerate(scores_t2i):
        inds = np.argsort(score)[::-1]
        ranks[index] = np.where(inds == txt2img[index])[0][0]
    ir, asr_ir = summarise(ranks, 'ir')

    eval_result = {'txt_r1_ASR (txt_r1)': f'{asr_tr[1]}({tr[1]})',
                   'txt_r5_ASR (txt_r5)': f'{asr_tr[5]}({tr[5]})',
                   'txt_r10_ASR (txt_r10)': f'{asr_tr[10]}({tr[10]})',
                   'img_r1_ASR (img_r1)': f'{asr_ir[1]}({ir[1]})',
                   'img_r5_ASR (img_r5)': f'{asr_ir[5]}({ir[5]})',
                   'img_r10_ASR (img_r10)': f'{asr_ir[10]}({ir[10]})'}
    return eval_result

def load_model1(model_name, model_ckpt, text_encoder, device):
    """Load the three source models plus the BERT reference model and tokenizer.

    For ``model_name`` in ``('ALBEF', 'TCL')`` three ``ALBEF`` instances are
    built from the module-level ``config`` and each is initialised from the
    checkpoint at ``model_ckpt`` (loaded on CPU, ``strict=False``). For any
    other name the three models come from ``clip.load1(model_name, device=device)``
    and get ``tokenizer`` attached via ``set_tokenizer``.

    Args:
        model_name: ``'ALBEF'``, ``'TCL'`` or a name understood by ``clip.load1``.
        model_ckpt: Checkpoint path; only read on the ALBEF/TCL path.
        text_encoder: Name/path passed to ``BertTokenizer`` / ``BertForMaskedLM``.
        device: Device passed to ``clip.load1``.

    Returns:
        ``(model_inf, model_infi, model_inft, ref_model, tokenizer)``.
    """
    tokenizer = BertTokenizer.from_pretrained(text_encoder)
    ref_model = BertForMaskedLM.from_pretrained(text_encoder)

    if model_name not in ('ALBEF', 'TCL'):
        model_inf, model_infi, model_inft, _, _, _ = clip.load1(model_name, device=device)
        for clip_model in (model_inf, model_infi, model_inft):
            clip_model.set_tokenizer(tokenizer)
        return model_inf, model_infi, model_inft, ref_model, tokenizer

    # ALBEF / TCL: the caller expects three model objects, so build three
    # instances and load the same weights into each of them.
    model_inf, model_infi, model_inft = (
        ALBEF(config=config, text_encoder=text_encoder, tokenizer=tokenizer)
        for _ in range(3)
    )
    checkpoint = torch.load(model_ckpt, map_location='cpu')

    # Checkpoints either wrap the weights under 'model' or are the weights themselves.
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    if model_name == 'TCL':
        pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'], model_inf.visual_encoder)
        state_dict['visual_encoder.pos_embed'] = pos_embed_reshaped
        m_pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], model_inf.visual_encoder_m)
        state_dict['visual_encoder_m.pos_embed'] = m_pos_embed_reshaped

    # Drop the 'bert.' path component from parameter names. Matching on
    # 'bert.' (with the dot) guarantees the new key differs from the old one,
    # so an entry is never deleted without being re-inserted.
    for key in list(state_dict.keys()):
        if 'bert.' in key:
            encoder_key = key.replace('bert.', '')
            state_dict[encoder_key] = state_dict[key]
            del state_dict[key]

    for albef_model in (model_inf, model_infi, model_inft):
        albef_model.load_state_dict(state_dict, strict=False)

    return model_inf, model_infi, model_inft, ref_model, tokenizer


def load_model2(model_name, model_ckpt, text_encoder, device):
    """Load the target model plus its BERT reference model and tokenizer.

    For ``model_name`` in ``('ALBEF', 'TCL')`` an ``ALBEF`` model is built
    from the module-level ``config`` and initialised from the checkpoint at
    ``model_ckpt`` (loaded on CPU, ``strict=False``). For any other name the
    model comes from ``clip.load2(model_name, device=device)`` and gets
    ``tokenizer`` attached via ``set_tokenizer``.

    Args:
        model_name: ``'ALBEF'``, ``'TCL'`` or a name understood by ``clip.load2``.
        model_ckpt: Checkpoint path; only read on the ALBEF/TCL path.
        text_encoder: Name/path passed to ``BertTokenizer`` / ``BertForMaskedLM``.
        device: Device passed to ``clip.load2``.

    Returns:
        ``(model, ref_model, tokenizer)``.
    """
    tokenizer = BertTokenizer.from_pretrained(text_encoder)
    ref_model = BertForMaskedLM.from_pretrained(text_encoder)

    if model_name not in ('ALBEF', 'TCL'):
        model, _ = clip.load2(model_name, device=device)
        model.set_tokenizer(tokenizer)
        return model, ref_model, tokenizer

    model = ALBEF(config=config, text_encoder=text_encoder, tokenizer=tokenizer)
    checkpoint = torch.load(model_ckpt, map_location='cpu')

    # Checkpoints either wrap the weights under 'model' or are the weights themselves.
    if isinstance(checkpoint, dict) and 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    if model_name == 'TCL':
        pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'], model.visual_encoder)
        state_dict['visual_encoder.pos_embed'] = pos_embed_reshaped
        m_pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], model.visual_encoder_m)
        state_dict['visual_encoder_m.pos_embed'] = m_pos_embed_reshaped

    # Drop the 'bert.' path component from parameter names. Matching on
    # 'bert.' (with the dot) guarantees the new key differs from the old one,
    # so an entry is never deleted without being re-inserted.
    for key in list(state_dict.keys()):
        if 'bert.' in key:
            encoder_key = key.replace('bert.', '')
            state_dict[encoder_key] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict, strict=False)

    return model, ref_model, tokenizer

def eval_asr(model_inf, model_infi, model_inft, ref_model, tokenizer, t_model, t_ref_model, t_tokenizer, t_test_transform, data_loader, test_dataset1, device, args, config):
    """Run the adversarial retrieval evaluation and append the results to a log.

    Calls ``retrieval_eval`` to obtain the source and target similarity
    matrices, evaluates both with ``itm_eval`` (target first, then source),
    prints the two result dicts and the elapsed time, and appends two JSON
    lines (target result, then source result) to
    ``<args.output_dir>/<args.adv>_log_CLIP.txt``. Each JSON line also
    carries the run settings read from ``args`` and ``config['num_iters']``.

    Returns:
        None.
    """
    print("Start eval")
    started_at = time.time()

    src_i2t, src_t2i, tgt_i2t, tgt_t2i = retrieval_eval(
        model_inf, model_infi, model_inft, ref_model, t_model, t_ref_model, t_test_transform,
        data_loader, test_dataset1, tokenizer, t_tokenizer, device, config,
    )

    dataset = data_loader.dataset

    target_metrics = itm_eval(tgt_i2t, tgt_t2i, dataset.img2txt, dataset.txt2img,
                              args.target_model, args.target_image_encoder)
    print('Performance on target {}: \n {}'.format(args.target_model, target_metrics))

    source_metrics = itm_eval(src_i2t, src_t2i, dataset.img2txt, dataset.txt2img,
                              args.source_model, args.source_image_encoder)
    print('Performance on source {}: \n {}'.format(args.source_model, source_metrics))

    torch.cuda.empty_cache()

    elapsed = datetime.timedelta(seconds=int(time.time() - started_at))
    print('Evaluate time {}'.format(str(elapsed)))

    # Settings appended after the metrics in every log record.
    run_settings = {
        'eval type': args.adv,
        'eps': args.eps,
        'eps1': args.eps1,
        'num_steps': args.num_steps,
        'iters': config['num_iters'],
        'scales_num': args.scales_num,
    }
    records = []
    for metrics in (target_metrics, source_metrics):
        record = {f'test_{name}': value for name, value in metrics.items()}
        record.update(run_settings)
        records.append(record)

    log_path = os.path.join(args.output_dir, str(args.adv) + "_log_CLIP.txt")
    with open(log_path, "a+") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")



def main(args, config):
    """Build models and data loaders from ``args`` and run ``eval_asr``.

    Steps: pick ``cuda:0``, seed the RNGs, fill in the dataset-dependent
    paths on ``args``, load the source models (``load_model1``) and the
    target model (``load_model2``), build the test loader and the loader over
    ``args.diffusion_data_path``, wrap every model in ``nn.DataParallel`` and
    call ``eval_asr``.

    Args:
        args: Parsed command-line namespace; several attributes
            (``target_ckpt``, ``original_rank_index_path``,
            ``target_dataset_root``, ``test_file``) are assigned here.
        config: Loaded YAML configuration, forwarded to ``eval_asr``.

    Raises:
        RuntimeError: If no CUDA device is visible.
    """
    device_count = torch.cuda.device_count()
    if device_count < 1:
        # Without a device nothing below can run; stop with a clear error
        # rather than failing later on an unbound ``device``.
        raise RuntimeError("CUDA is not available. No GPU devices found.")
    device = torch.device("cuda:0")
    if device_count > 1:
        print(f"Using GPU {device} - {torch.cuda.get_device_name(device)}")

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    # Dataset-dependent paths. ``target_ckpt`` is only assigned for the
    # ALBEF / CLIP / TCL target models.
    if args.dataset == 'flickr30':
        if args.target_model in ('ALBEF', 'CLIP'):
            args.target_ckpt = 'xxx/ALBEF-main/checkpoint/flickr30k.pth'
        elif args.target_model == 'TCL':
            args.target_ckpt = 'xxx/TCL-main/checkpoint/checkpoint_retrieval_flickr_finetune.pth'
        args.original_rank_index_path = 'xxx/Co-Attack-main/std_eval_idx/flickr30k'
        args.target_dataset_root = 'xxx/Datasets/flickr30k_images/'
        args.test_file = 'xxx/Datasets/data/flickr30k_test.json'

    elif args.dataset == 'mscoco':
        if args.target_model in ('ALBEF', 'CLIP'):
            args.target_ckpt = 'xxx/ALBEF-main/checkpoint/mscoco.pth'
        elif args.target_model == 'TCL':
            args.target_ckpt = 'xxx/TCL-main/checkpoint/checkpoint_retrieval_coco_finetune.pth'
        args.original_rank_index_path = 'xxx/Co-Attack-main/std_eval_idx/mscoco'
        args.target_dataset_root = 'xxx/Datasets/MSCOCO/'
        args.test_file = 'xxx/data/coco_test.json'

    print("Creating Source Model")
    # Both loaders are keyed by the image-encoder name, not by the model family.
    model_inf, model_infi, model_inft, ref_model, tokenizer = load_model1(args.source_image_encoder, args.source_ckpt, args.source_text_encoder, device)
    t_model, t_ref_model, t_tokenizer = load_model2(args.target_image_encoder, args.target_ckpt, args.target_text_encoder, device)

    #### Dataset ####
    print("Creating dataset")
    n_px = model_inf.visual.input_resolution
    s_test_transform = transforms.Compose([
        transforms.Resize(n_px, interpolation=Image.BICUBIC),
        transforms.CenterCrop(n_px),
        transforms.ToTensor(),
    ])

    # No ToTensor here: retrieval_eval applies this transform to the
    # adversarial images it gets back from the attacker.
    t_n_px = t_model.visual.input_resolution
    t_test_transform = transforms.Compose([
        transforms.Resize(t_n_px, interpolation=Image.BICUBIC),
        transforms.CenterCrop(t_n_px)
    ])

    test_dataset = paired_dataset(args.test_file, s_test_transform, args.target_dataset_root)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             num_workers=4, collate_fn=test_dataset.collate_fn)

    test_dataset1 = pair_dataset3(args.diffusion_data_path)
    test_dataset1 = DataLoader(test_dataset1, batch_size=args.batch_size, shuffle=True, num_workers=4)

    # retrieval_eval reaches the models through ``.module``, so every model is wrapped.
    model_inf = nn.DataParallel(model_inf)
    model_infi = nn.DataParallel(model_infi)
    model_inft = nn.DataParallel(model_inft)
    ref_model = nn.DataParallel(ref_model.to(device))
    t_model = nn.DataParallel(t_model)
    t_ref_model = nn.DataParallel(t_ref_model.to(device))

    eval_asr(model_inf, model_infi, model_inft, ref_model, tokenizer, t_model, t_ref_model, t_tokenizer, t_test_transform, test_loader, test_dataset1, device, args, config)


if __name__ == '__main__':
    # Script entry point. ``args`` and ``config`` are deliberately left as
    # module-level names: retrieval_eval / itm_eval read ``args`` and
    # retrieval_score / load_model1 / load_model2 read ``config`` as globals.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/Retrieval_flickr.yaml')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--batch_size', default=4, type=int) #8
    # NOTE(review): retrieval_eval only dispatches on --adv 2 and 4, so the
    # default of 9 is not handled there; confirm the intended default.
    parser.add_argument('--adv', default=9, type=int,
                        help='0=clean, 1=adv text, 2=adv image, 3=adv text and adv image,')
    parser.add_argument('--dataset', default='flickr30', choices=['mscoco', 'flickr30'])
    parser.add_argument('--eps', default=4, type=int)

    parser.add_argument('--source_model', default='CLIP')
    parser.add_argument('--source_image_encoder', default='ViT-B/32', type=str, choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101'])
    parser.add_argument('--source_text_encoder', default='bert-base-uncased', type=str)
    parser.add_argument('--source_ckpt', default='', type=str)

    parser.add_argument('--target_model', default='CLIP', choices=['ALBEF', 'CLIP', 'TCL', 'BLIP'])
    parser.add_argument('--target_image_encoder', default='RN50', choices=['ViT-L/14', 'ViT-B/16', 'ViT-B/32', 'RN50', 'RN101'])
    parser.add_argument('--target_text_encoder', default='bert-base-uncased', type=str)

    parser.add_argument('--scales_num', type=int, default=5)
    # The scale lists are free-form comma-separated floats. The former
    # single-entry ``choices`` lists only admitted one literal string each
    # (and did not contain the --scales1 default), so they now serve as
    # examples in the help text instead.
    parser.add_argument('--scales', type=str, default='1.25, 1.5, 1.75, 2',
                        help='comma-separated floats, e.g. "1.25, 1.5, 1.75, 2"')
    parser.add_argument('--scales1', type=str, default='0.5, 0.6, 0.7, 0.8',
                        help='comma-separated floats, e.g. "0.5,0.75, 1.25, 1.5"')
    # ``choices`` are compared after ``type`` conversion, so they must be
    # floats; string choices rejected every value given on the command line.
    parser.add_argument('--eps1', default=0.004, type=float, choices=[0.004, 1.0])
    parser.add_argument('--constraint', default='Linf', choices=['Linf', 'L2'], type=str)
    parser.add_argument('--diffusion_data_path', default='xxx/Co-Attack-main/data/flickr30/0.032_diffusion_data_v3')

    args = parser.parse_args()
    args.num_steps = 60

    args.output_dir = os.path.join('output', str(args.source_model) +'_'+ str(args.source_image_encoder), str(args.target_model) +'_'+ str(args.target_image_encoder), str(args.dataset) , str(args.eps) +'_'+ str(args.scales_num) +'_'+ str(args.batch_size))

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Use a separate name for the loader so the ``yaml`` module is not
    # shadowed, and close the config file once it has been parsed.
    yaml_loader = yaml.YAML(typ="safe", pure=True)
    with open(args.config, 'r') as config_file:
        config = yaml_loader.load(config_file)

    main(args, config)


