import argparse
import os
import sys

import ruamel.yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.distributed as dist

from PIL import Image

from easydict import EasyDict as edict
from models.get_model import load_model

# Code adapted from https://github.com/salesforce/ALBEF/blob/main/Retrieval.py
from dataset import create_dataset_no_norm, create_sampler, create_loader
# from optim import create_optimizer
from constants import images_normalize

from models.clip_model import clip

import utils.utils as utils
from utils.utils_attack import get_attacker
from utils.utils_eval import eval_pipeline, analysis_each_query
from utils.utils_visualization import vis_retrieval_results

# Two independent cross-entropy criteria, one per retrieval direction
# (image side / text side). They are not referenced anywhere else in this script.
loss_img, loss_txt = nn.CrossEntropyLoss(), nn.CrossEntropyLoss()


# Surrogate architectures that can be loaded for black-box attacks, mapped to
# the checkpoint path handed to ``load_model`` (``None`` = no checkpoint).
_SURROGATE_CKPTS = {
    "CLIP_ViT-B-16": None,
    "ALBEF": "/data/ckpts/ALBEF/flickr30k.pth",
    "TCL": "/data/ckpts/TCL/checkpoint_flickr_finetune.pth",
}

# Surrogates for which the attack is run with ``attack_fused_emb=True``.
_FUSED_EMB_SURROGATES = ("ALBEF", "TCL")

# Attacks that are evaluated on the train subset only (test set is skipped).
_TRAIN_ONLY_ATTACKS = ("Sup1-1", "Sup5-5")


def _load_surrogate_model(name, target_model, config, args, device, train_config):
    """Return the model used to craft adversarial examples for surrogate ``name``.

    ``"self"`` reuses ``target_model`` (white-box setting); any other known
    name loads a fresh model through ``load_model``.

    Raises:
        ValueError: if ``name`` is neither ``"self"`` nor a known surrogate.
    """
    if name == "self":
        return target_model
    if name not in _SURROGATE_CKPTS:
        raise ValueError(f"Invalid surrogate model name: {name}")
    surrogate_model, _, _ = load_model(
        config, name, _SURROGATE_CKPTS[name], args.text_encoder, device=device,
        train_config=train_config
    )
    return surrogate_model


def main(args, config):
    """Evaluate a retrieval model on clean inputs and under adversarial attacks.

    Builds the datasets and loaders, loads the target model (restoring
    ``args.eval_ckpt_path`` when given), optionally evaluates the val set, and
    then, for every surrogate in ``SURROGATE_LIST`` and every attack in
    ``ATTACK_EVAL_LIST``, runs ``eval_pipeline`` on the test set (and on the
    train subset when ``args.eval_train_subset`` is set). Results are written
    to ``args.output_dir`` as ``log.json``, ``eval_results.json`` and
    ``scores_dict.npy``; retrieval visualizations go to per-attack
    sub-directories.

    NOTE: this function reads the module-level names ``train_config``,
    ``SURROGATE_LIST`` and ``ATTACK_EVAL_LIST``, which are assigned in the
    ``__main__`` section of this script; they must exist before it is called.

    Args:
        args: parsed command-line namespace.
        config: evaluation configuration mapping (batch sizes, dataset
            settings, ...), as loaded from the YAML config file.

    Raises:
        ValueError: if ``SURROGATE_LIST`` contains an unknown surrogate name.
    """
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    ########################
    ###### set seed ########
    ########################
    # Offset by rank so that each distributed process gets its own seed.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    ########################
    ###### load data #######
    ########################
    train_dataset, val_dataset, test_dataset, train_dataset_for_eval = create_dataset_no_norm('re', config, get_train_eval=True)
    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None]
    else:
        samplers = [None, None, None]
    # Fourth entry: the train-subset loader used for evaluation has no sampler.
    samplers.append(None)
    # The training loader is built for parity with the loader factory but is
    # not used during evaluation.
    _train_loader, val_loader, test_loader, train_subset_loader = create_loader(
        [train_dataset, val_dataset, test_dataset, train_dataset_for_eval],
        samplers,
        batch_size=[config['batch_size_train']]+[config['batch_size_test']]*3,
        num_workers=[4]*4,
        is_trains=[True, False, False, False],
        collate_fns=[None]*4
    )

    ########################
    ###### load model ######
    ########################
    print("Loading model")
    model, ref_model, tokenizer = load_model(
        config, args.model, None, args.text_encoder, device=device,
        train_config=train_config
    )

    if args.eval_ckpt_path is None:
        print("No checkpoint path is provided. Use the pre-trained model.")
    else:
        ckpt_path = args.eval_ckpt_path
        print("Loading checkpoint:", ckpt_path)
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])

    model = model.to(device)
    ref_model = ref_model.to(device)
    # Put the target model in eval mode up front. Previously this only happened
    # as a side effect of the "self" surrogate setup, i.e. after the val-set
    # evaluation had already run.
    model.eval()

    model_without_ddp = model

    ########################
    ###### Evaluation ######
    ########################
    start_time = time.time()
    scores_dict = {
        "val": {}, "test": {}, "train_subset": {},
    }
    eval_results_dict = {
        "val": {}, "test": {}, "train_subset": {}
    }

    #### Clean scenario ####
    # The clean test-set numbers come from the "CLEAN" entry of the attack
    # list below; only the val set is evaluated here.
    if args.eval_val_set:
        print("Start evaluation on val set")
        eval_result, score_matrix_i2t, score_matrix_t2i, _, _, _ = eval_pipeline(
            args, model_without_ddp, val_loader, tokenizer, device, config
        )
        print(eval_result)
        eval_results_dict["val"]["Clean"] = eval_result
        scores_dict["val"]["Clean"] = {"i2t": score_matrix_i2t, "t2i": score_matrix_t2i}

    #### Attack scenario ####
    for surrogate_model_name in SURROGATE_LIST:
        print(f"Surrogate model: {surrogate_model_name}")

        ###### Surrogate model #####
        surrogate_model = _load_surrogate_model(
            surrogate_model_name, model_without_ddp, config, args, device, train_config
        )
        surrogate_model.to(device)
        surrogate_model.eval()

        # for each attack
        for att in ATTACK_EVAL_LIST:
            print(f"-- Attack: {att}")
            ###### get attacker ########
            if att == "CLEAN":
                eval_attacker = None
                attack_fused_emb = False
            else:
                attack_fused_emb = surrogate_model_name in _FUSED_EMB_SURROGATES
                eval_attacker = get_attacker(
                    args, train_config, att,
                    surrogate_model, ref_model, tokenizer,
                    cls=False, attack_fused_emb=attack_fused_emb,
                    eps=args.epsilon,
                    steps=args.num_iters,
                    step_size=args.step_size,
                )

            ######### eval #########
            # train subset
            if args.eval_train_subset:
                print("--- Start evaluation on train subset ---")
                eval_result, score_matrix_i2t, score_matrix_t2i, _, adv_images_list, adv_texts_list = eval_pipeline(
                    args, model_without_ddp, train_subset_loader, tokenizer, device, config,
                    attacker=eval_attacker, attack_name=att,
                    attack_fused_emb=attack_fused_emb
                )
                eval_results_dict["train_subset"].setdefault(att, {})[surrogate_model_name] = eval_result
                print(eval_result)

                # vis analysis; in the paired setting image i is matched with
                # text i, hence the identity mappings below.
                if args.is_analysis and args.eval_train_subset_paired:
                    vis_dir = os.path.join(args.output_dir, f"train_subset_paired_{surrogate_model_name}_{att}")
                    os.makedirs(vis_dir, exist_ok=True)
                    analysis_each_query(score_matrix_i2t, score_matrix_t2i,
                                        txt2img=np.arange(len(adv_texts_list)),
                                        img2txt=[[i] for i in range(len(adv_images_list))],
                                        texts=adv_texts_list,
                                        VIS_DIR=vis_dir)

            # test loader
            if att in _TRAIN_ONLY_ATTACKS:
                continue
            print("--- Start evaluation on test set ---")
            eval_result, score_matrix_i2t, score_matrix_t2i, _, adv_images_list, adv_texts_list = eval_pipeline(
                args, model_without_ddp, test_loader, tokenizer, device, config,
                attacker=eval_attacker, attack_name=att,
                attack_fused_emb=attack_fused_emb
            )
            print(eval_result)
            eval_results_dict["test"].setdefault(att, {})[surrogate_model_name] = eval_result
            scores_dict["test"].setdefault(att, {})[surrogate_model_name] = {"i2t": score_matrix_i2t, "t2i": score_matrix_t2i}

            # visualization
            vis_dir = os.path.join(args.output_dir, f"{surrogate_model_name}_{att}")
            os.makedirs(vis_dir, exist_ok=True)
            vis_retrieval_results(
                adv_images_list,
                adv_texts_list,
                score_matrix_i2t,
                score_matrix_t2i,
                vis_dir,
                txt2img=test_dataset.txt2img,
                img2txt=test_dataset.img2txt,
                show_n=10,
                top_k=5,
            )

            # vis analysis
            if args.is_analysis:
                analysis_each_query(score_matrix_i2t, score_matrix_t2i,
                                    test_loader.dataset.txt2img,
                                    test_loader.dataset.img2txt,
                                    test_loader.dataset.text,
                                    vis_dir)

    print("Evaluation finished")

    ##################
    ###### log #######
    ##################
    # ``log_stats`` is the same object as ``eval_results_dict``; the output
    # directory is recorded alongside the metrics.
    log_stats = eval_results_dict
    log_stats.update({"output_dir": args.output_dir})
    with open(os.path.join(args.output_dir, "log.json"), "w") as f:
        json.dump(log_stats, f, indent=4)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    # This script only evaluates, so report the elapsed time as such.
    print('Evaluation time {}'.format(total_time_str))

    # save the last log_stats (same content as log.json)
    with open(os.path.join(args.output_dir, "eval_results.json"), "w") as f:
        json.dump(log_stats, f, indent=4)

    # save scores; the nested dict is stored by numpy as a pickled object
    # array, so it must be read back with ``np.load(..., allow_pickle=True)``.
    with open(os.path.join(args.output_dir, "scores_dict.npy"), "wb") as f:
        np.save(f, scores_dict)

    print("===> Saved results to:", args.output_dir)

    
def str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into ``True``; this converter interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="./configs/Retrieval_flickr_train_clip_b128.yaml")
    parser.add_argument("--seed", default=42, type=int)

    parser.add_argument("--model", default="CLIP_ViT-B-16_PT", type=str) # model architecture
    parser.add_argument("--text_encoder", default="bert-base-uncased", type=str)

    # misc config
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    # str2bool instead of bool so that "--distributed False" really means False.
    parser.add_argument('--distributed', default=False, type=str2bool)

    # evaluation
    parser.add_argument('--eval_ckpt_path', default=None, type=str)
    parser.add_argument('--output_dir', default="../eval_results", type=str)
    parser.add_argument('--model_train_name', default=None, type=str)
    parser.add_argument('--attack_fused_emb', default=False, action='store_true')
    parser.add_argument('--cls', default=False, action='store_true')

    # attack config
    parser.add_argument("--epsilon", default=2.0, type=float)
    parser.add_argument("--step_size", default=0.5, type=float)
    parser.add_argument('--num_iters', default=10, type=int)

    # what to eval
    parser.add_argument('--eval_black_box', default=False, action='store_true')
    parser.add_argument('--eval_val_set', default=False, action='store_true')
    parser.add_argument('--eval_train_subset', default=False, action='store_true')
    parser.add_argument('--eval_train_subset_paired', default=False, action='store_true') # caps_k=1
    parser.add_argument('--eval_train_n', default=1000, type=int)
    parser.add_argument('--eval_attack_list', default=[
        "CLEAN", "SupPGD", "UnsupPGD", "BERT", "SGA"
        ], nargs="+", type=str)
    parser.add_argument('--eval_batch_size', default=None, type=int)
    parser.add_argument('--is_analysis', default=False, action='store_true')

    args = parser.parse_args()

    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if args.model in ["ALBEF", "ALBEF_PT", "TCL", "TCL_PT"]:
        parser.error(f"--model {args.model} is not supported as the evaluated model")

    # Module-level names read by main().
    ATTACK_EVAL_LIST = args.eval_attack_list
    if args.eval_black_box:
        SURROGATE_LIST = ["self", "CLIP_ViT-B-16", "ALBEF", "TCL"]
    else:
        SURROGATE_LIST = ["self"]
    print("Attack list:", ATTACK_EVAL_LIST)
    print("Surrogate list:", SURROGATE_LIST)

    ##############################
    ##### load train config ######
    ##############################
    if args.eval_ckpt_path is None:
        # No checkpoint given: the training config is taken from this fixed
        # reference run directory.
        ckpt_dir = "/data/train_results/flickr30k/CLIP_ViT-B-16/_bs64sgd_lr0.0001_ep5_wd0.0001_warmup0.3/MMA-iters1-step2.0-scale0-img-unsup0.0-sup1.0/2024-03-24_09-28-19/"
        print("===> Evaluate pre-trained model.")
    else:
        ckpt_dir = os.path.dirname(args.eval_ckpt_path)

    config_path = args.config
    with open(config_path, "r") as f:
        # ``ruamel.yaml.safe_load`` is no longer available in recent ruamel.yaml
        # releases; the YAML(typ="safe") object is the supported equivalent.
        config = yaml.YAML(typ="safe", pure=True).load(f)
    config = edict(config)

    config["batch_size_test"] = 32 if args.eval_batch_size is None else args.eval_batch_size
    # for ALBEF, TCL
    config["distill"] = False
    config["queue_size"] = 65536
    config["momentum"] = 0.995
    config["vision_width"] = 768
    config["embed_dim"] = 256
    config["temp"] = 0.07
    config["k_test"] = 128

    if args.eval_train_n is not None:
        config["eval_train_n"] = args.eval_train_n

    if args.eval_train_subset_paired:
        config["caps_k"] = 1

    # Prefer train_config.json; fall back to prompt_config.json when it is
    # missing or unreadable. A failure on the fallback propagates.
    try:
        train_config_path = os.path.join(ckpt_dir, "train_config.json")
        with open(train_config_path, "r") as f:
            train_config = json.load(f)
    except (OSError, ValueError):
        train_config_path = os.path.join(ckpt_dir, "prompt_config.json")
        with open(train_config_path, "r") as f:
            train_config = json.load(f)
    train_config = edict(train_config)

    ############################
    ## create output directory
    ############################
    dataset_name = config.get("dataset_name", "flickr30k")

    # --model_train_name is optional; os.path.join() would raise TypeError on
    # None, so the component is simply left out when it is not given.
    output_parts = [args.output_dir, dataset_name, args.model]
    if args.model_train_name is not None:
        output_parts.append(args.model_train_name)
    args.output_dir = os.path.join(*output_parts)
    os.makedirs(args.output_dir, exist_ok=True)
    print("Output directory:", args.output_dir)

    # log print: mirror stdout into out.txt for the rest of the process, so
    # the file handle is intentionally left open.
    sys.stdout = utils.Tee(sys.stdout, open(os.path.join(args.output_dir, "out.txt"), "w"))

    ############################
    ## save args
    ############################
    with open(os.path.join(args.output_dir, "args.json"), "w") as f:
        json.dump(vars(args), f, indent=4)
    # save config
    with open(os.path.join(args.output_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=4)
    with open(os.path.join(args.output_dir, "train_config.json"), "w") as f:
        json.dump(train_config, f, indent=4)

    main(args, config)
