import argparse
import os
import sys

import ruamel.yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.distributed as dist

from scipy.stats import pearsonr

from PIL import Image

import utils

from easydict import EasyDict as edict
from models.get_model import load_model

# Codes imported from https://github.com/salesforce/ALBEF/blob/main/Retrieval.py
from dataset import create_dataset_no_norm, create_sampler, create_loader
# from optim import create_optimizer
from constants import images_normalize

from models.clip_model import clip

from utils_optimizer import get_trainable_params, get_optimizer
from utils_attack import get_attacker, attack_batch_train
from utils_eval import eval_pipeline, analysis_each_query
from utils_visualization import vis_retrieval_results
from utils_text_metrics import get_sentence_length, get_pos_counts, get_sentence_inv_freq_sum


# Cross-entropy criteria used by compare_attack: loss_img is applied to
# logits_per_image and loss_txt to logits_per_text, both against the
# diagonal (index == index) ground truth.
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()



def compare_attack(model, data_loader, optimizer, epoch, device, scheduler, attacker_dict):
    """Compare the attackers in ``attacker_dict`` on the first batches of ``data_loader``.

    For every batch, each attacker perturbs the same clean image/text pairs,
    the model scores the perturbed pairs without gradients, and the softmax
    probability that each text assigns to its own (same-index) image is
    recorded.  After at most 21 batches the function prints:

    * per-batch image->text and text->image accuracies for every attacker,
    * per batch, the fraction of samples whose ground-truth probability under
      the first attacker is higher than under the second attacker,
    * Pearson correlations between those probabilities (and their gap) and
      simple per-caption text statistics.

    The function reads the module-level ``args`` (``args.attack`` and
    ``args.is_use_gt_caps``), so it only works once the script entry point has
    parsed the command line.

    Args:
        model: Called as ``model(image, token_ids)``; must return
            ``(logits_per_image, logits_per_text)``.
        data_loader: Yields ``(image, text, idx)`` or
            ``(image, text, idx, gt_caps_list)`` batches.
        optimizer: Only used to clear gradients before each forward pass.
        epoch: Used in the progress-log header only.
        device: Device the images and token ids are moved to.
        scheduler: Unused.
        attacker_dict: Maps ``(scale_ver, txt_sup_k)`` to an attacker object.
            Iteration order matters: the first two entries are the ones
            compared against each other.  At least two entries are required.

    Raises:
        ValueError: If fewer than two attackers are given, or if
            ``args.is_use_gt_caps`` is set but a batch carries no
            ground-truth caption list.
        SystemExit: With status 1 if any attacker produces a NaN loss.
    """
    num_attackers = len(attacker_dict)
    if num_attackers < 2:
        raise ValueError(
            "compare_attack needs at least two attackers, got {}".format(num_attackers)
        )

    model.float()  # float32
    # NOTE(review): the model is put in train mode although only no-grad
    # forward passes are run below -- confirm this is intended.
    model.train()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Train Epoch: [{}]'.format(epoch)
    print_freq = 50

    # One result list per attacker, in attacker_dict iteration order.
    acc_per_image_list = [[] for _ in range(num_attackers)]
    acc_per_text_list = [[] for _ in range(num_attackers)]
    probs_gt_list = [[] for _ in range(num_attackers)]
    is_set_attack_stronger_list = []
    text_metrics_all = {}

    for i, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        if len(data) == 4:
            image, text, _idx, gt_caps_list = data
        else:
            image, text, _idx = data
            # Reset explicitly so a caption list from an earlier batch is never reused.
            gt_caps_list = None
        if args.is_use_gt_caps and gt_caps_list is None:
            raise ValueError(
                "args.is_use_gt_caps is set but the batch has no gt_caps_list"
            )
        orig_image = image.to(device, non_blocking=True)
        orig_text = text

        # text metrics
        text_metrics = {
            "sentence_length": [get_sentence_length(t) for t in orig_text],
            "sentence_inv_freq_sum": [get_sentence_inv_freq_sum(t) for t in orig_text],
        }
        for t in orig_text:
            pos_counts = get_pos_counts(t)
            for k in ("ADJ", "NOUN", "VERB"):
                text_metrics.setdefault(k, []).append(pos_counts.get(k, 0))

        for k, values in text_metrics.items():
            text_metrics_all.setdefault(k, []).extend(values)

        # shape
        print("orig_image.shape:", orig_image.shape)
        print("orig_text.shape:", len(orig_text))

        # attack
        probs_gt = []
        batch_losses = []
        for att_idx, ((scale_ver, txt_sup_k), attacker) in enumerate(attacker_dict.items()):
            print("scale_ver:", scale_ver, "txt_sup_k:", txt_sup_k)
            # Identity text->image mapping: caption j belongs to image j.
            txt2img = list(range(len(orig_image)))
            adv_image, adv_text = attack_batch_train(
                args, args.attack, attacker,
                orig_image, orig_text, txt2img, device,
                return_pair=True,
                gt_caps_list=gt_caps_list if args.is_use_gt_caps else None
                )
            adv_image.detach_()

            # normalize image
            adv_image = images_normalize(adv_image)

            optimizer.zero_grad()

            with torch.no_grad():
                # forward
                text_input_ids = clip.tokenize(adv_text).to(device)

                logits_per_image, logits_per_text = model(adv_image, text_input_ids)
                # ground truth label (1: 1)
                gt = torch.arange(len(adv_image), dtype=torch.long, device=device)
                loss = (loss_img(logits_per_image, gt) + loss_txt(logits_per_text, gt)) / 2
                acc_per_image = (logits_per_image.argmax(dim=-1) == gt).float().mean()
                acc_per_text = (logits_per_text.argmax(dim=-1) == gt).float().mean()
                print("loss:", loss.item(), "acc_per_image:", acc_per_image.item(), "acc_per_text:", acc_per_text.item())

                # Probability each text assigns to its own image.
                probs = torch.softmax(logits_per_text, dim=-1)
                probs_diagonal = probs[gt, gt].cpu().numpy()
                print("probs:", [round(x, 2) for x in probs_diagonal])
                probs_gt.append(probs_diagonal)
                batch_losses.append(loss)

                acc_per_image_list[att_idx].append(acc_per_image.item())
                acc_per_text_list[att_idx].append(acc_per_text.item())

        # Compare the first two attackers: where the ground-truth probability
        # under attacker 0 exceeds that under attacker 1, attacker 1 is the
        # stronger one for that sample.
        is_set_attack_stronger = (probs_gt[0] > probs_gt[1])
        is_set_attack_stronger_list.append(is_set_attack_stronger.mean())

        for att_idx, probs_diagonal in enumerate(probs_gt):
            probs_gt_list[att_idx].extend(probs_diagonal)

        # Abort if any attacker (not just the last one) produced a NaN loss.
        if torch.isnan(torch.stack(batch_losses)).any():
            print("loss is nan")
            sys.exit(1)

        # Only the first 21 batches are analysed.
        if i == 20:
            break

    if not is_set_attack_stronger_list:
        print("No batches were processed; nothing to compare.")
        return

    # correlation
    gap_set_stronger = np.array(probs_gt_list[0]) - np.array(probs_gt_list[1])
    corr_is_set_strong_k = {}
    for k, values in text_metrics_all.items():
        r, _p_value = pearsonr(gap_set_stronger, values)
        corr_is_set_strong_k[k] = r

    corr_probs_k = [{} for _ in range(num_attackers)]
    for att_idx in range(num_attackers):
        for k, values in text_metrics_all.items():
            r, _p_value = pearsonr(probs_gt_list[att_idx], values)
            corr_probs_k[att_idx][k] = r

    # print results
    print("acc_per_image_list:", [[round(x, 2) for x in l] for l in acc_per_image_list])
    print("acc_per_text_list:", [[round(x, 2) for x in l] for l in acc_per_text_list])
    print("is_set_attack_stronger_list:", [round(x, 2) for x in is_set_attack_stronger_list])
    corr_is_set_strong_k = {k: round(v, 2) for k, v in corr_is_set_strong_k.items()}
    for k in text_metrics_all:
        print("k:", k, "corr_is_set_strong_k:", corr_is_set_strong_k[k])

    for att_idx in range(num_attackers):
        print("--- attack idx:", att_idx)
        corr_probs_k[att_idx] = {k: round(v, 2) for k, v in corr_probs_k[att_idx].items()}
        for k in text_metrics_all:
            print("k:", k, "corr_probs_k:", corr_probs_k[att_idx][k])


def main(args, config):
    """Build the data loaders, model, optimizer and two MMA attackers, then run ``compare_attack``.

    Besides its parameters this function reads the module-level
    ``train_config`` and ``attack_config`` objects created by the script
    entry point, and it overwrites several ``attack_config["MMA"]`` entries.

    Args:
        args: Parsed command-line namespace.
        config: Dataset/loader configuration (indexed by
            ``batch_size_train`` and ``batch_size_test`` here).
    """
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # ---- seed (offset by rank so each process draws differently) ----
    rank_seed = args.seed + utils.get_rank()
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(rank_seed)
    cudnn.benchmark = True

    # ---- data ----
    datasets = create_dataset_no_norm('re', config, get_train_eval=True)
    train_dataset, val_dataset, test_dataset, train_dataset_for_eval = datasets

    if args.distributed:
        train_samplers = create_sampler(
            [train_dataset], [True], utils.get_world_size(), utils.get_rank()
        )
    else:
        train_samplers = [None]
    # Only the training set may get a sampler; the other three loaders use none.
    samplers = train_samplers + [None, None, None]

    train_loader, _val_loader, _test_loader, _train_subset_loader = create_loader(
        [train_dataset, val_dataset, test_dataset, train_dataset_for_eval],
        samplers,
        batch_size=[config['batch_size_train']] + [config['batch_size_test']] * 3,
        num_workers=[4] * 4,
        is_trains=[True, False, False, False],
        collate_fns=[None] * 4,
    )

    # ---- model ----
    print("Loading model")
    model, ref_model, tokenizer = load_model(
        config, args.model, None, args.text_encoder, device=device,
        train_config=train_config
    )

    if args.eval_ckpt_path is not None:
        print("Loading checkpoint:", args.eval_ckpt_path)
        state = torch.load(args.eval_ckpt_path, map_location="cpu")
        model.load_state_dict(state["model"])
    else:
        print("No checkpoint path is provided. Use the pre-trained model.")

    model = model.to(device)
    ref_model = ref_model.to(device)
    model_without_ddp = model

    # ---- optimizer ----
    trainable = get_trainable_params(model, train_config)
    optimizer = get_optimizer(trainable, utils.AttrDict(train_config['optimizer']))

    # ---- attackers ----
    mma_overrides = {
        "is_use_gt_caps": False,
        "alpha_sr": args.alpha_sr,
        "alpha_ri": args.alpha_ri,
        "alpha_rs": args.alpha_rs,
        "p_rd": args.p_rd,
        "alpha_unsup": 0,
        "alpha_sup": 1,
        "is_txt_aug": False,
        "txt_aug": "rand",
    }
    for key, value in mma_overrides.items():
        attack_config["MMA"][key] = value

    attack_name = "MMA"
    attacker_dict = {}
    # NOTE(review): the same attack_config object is mutated between the two
    # get_attacker calls; if get_attacker keeps a reference instead of reading
    # the values eagerly, both attackers would see the last setting -- confirm.
    for scale_ver, txt_sup_k in ((0, 1), (2, 5)):
        attack_config["MMA"]["txt_sup_k"] = txt_sup_k
        attack_config["MMA"]["scale_ver"] = scale_ver
        attacker_dict[(scale_ver, txt_sup_k)] = get_attacker(
            args, train_config, attack_name,
            model_without_ddp, ref_model, tokenizer,
            attack_config=attack_config,
            eps=args.epsilon, steps=args.num_iters, step_size=args.step_size,
        )

    # ---- compare ----
    compare_attack(model, train_loader, optimizer, 0, device, None, attacker_dict)


def t2bool(t):
    """Convert the text "true" / "false" (any letter case) to the matching bool.

    Raises:
        ValueError: If ``t`` is any other string.
    """
    lowered = t.lower()
    table = {"true": True, "false": False}
    if lowered not in table:
        raise ValueError("Invalid value")
    return table[lowered]

if __name__ == "__main__":
    # Script entry point: parse the CLI, load the dataset and training YAML
    # configs, apply command-line overrides, tee stdout to
    # <output_dir>/out.txt and run main().
    #
    # `args`, `train_config` and `attack_config` are deliberately module-level
    # names: main() and compare_attack() read them as globals.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="./configs/Retrieval_flickr_train.yaml")
    parser.add_argument("--seed", default=42, type=int)

    parser.add_argument("--model", default="CLIP_ViT-B-16_PT", type=str) # model architecture
    parser.add_argument("--model_name", default="CLIP_ViT-B-16_PT", type=str) # id for the model
    parser.add_argument("--text_encoder", default="bert-base-uncased", type=str)
    parser.add_argument("--ckpt", default=None, type=str)

    # training config
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    # Boolean flags use t2bool: argparse's type=bool turns every non-empty
    # string (including "False") into True.
    parser.add_argument('--distributed', default=False, type=t2bool)

    # adversarial training config
    parser.add_argument(
        "--attack",
        default=None,
        type=str,
        choices=[
            "SGA", "Co-Attack", "Sep-Attack", "PGD", "BERT", "Clean",
             "FSGA", "PDE-MMA", "SupPGD", "UnsupPGD", "MMA"],
    )
    parser.add_argument("--is_rand_mask", default=False, action="store_true")
    parser.add_argument("--attack_fused_emb", default=False, type=t2bool)
    parser.add_argument("--cls", default=False, type=t2bool)
    parser.add_argument("--output_dir", default="../train_results", type=str)
    parser.add_argument("--epsilon", default=2.0, type=float)
    parser.add_argument("--alpha", default=3.0, type=float)
    parser.add_argument("--num_iters", default=10, type=int)
    parser.add_argument("--step_size", default=0.5, type=float)

    # dataset
    parser.add_argument("--caps_k", default=None, type=int) # how many captions per image is used for training.

    # FSGA config
    parser.add_argument("--scale_ver", default=0, type=int)
    parser.add_argument("--txt_att_k", default=0, type=int)
    parser.add_argument("--txt_attack", default=None, type=str, choices=["rand", "bert"])
    # NOTE(review): default is the bool False while the type is str -- confirm
    # what consumers of img_attack_loss expect.
    parser.add_argument("--img_attack_loss", default=False, type=str)

    # MMA config
    parser.add_argument("--is_use_gt_caps", default=False, type=t2bool)
    parser.add_argument("--txt_sup_k", default=5, type=int) # if > 1, use augmented texts for text-supervised image attack
    parser.add_argument("--alpha_sr", default=0.1, type=float)
    parser.add_argument("--alpha_ri", default=0.1, type=float)
    parser.add_argument("--alpha_rs", default=0.1, type=float)
    parser.add_argument("--p_rd", default=0.1, type=float)
    parser.add_argument("--alpha_unsup", default=0.0, type=float)
    parser.add_argument("--alpha_sup", default=1.0, type=float)
    parser.add_argument("--is_txt_aug", default=True, type=t2bool)
    parser.add_argument("--txt_aug", default="sr", type=str)

    # train config
    parser.add_argument("--train_config", default=None, type=str)

    # overwrite config for grid search
    parser.add_argument("--epochs", default=None, type=int)
    # Parsed as int: the value replaces config["batch_size_train"], which is
    # passed on as a loader batch size.
    parser.add_argument("--batch_size", default=None, type=int)

    parser.add_argument("--eval_ckpt_path", default=None, type=str)

    args = parser.parse_args()

    # Explicit checks instead of assert, which is stripped under `python -O`.
    if args.model in ("ALBEF", "ALBEF_PT", "TCL", "TCL_PT"):
        parser.error("--model {} is not supported by this script".format(args.model))
    if args.train_config is None:
        parser.error("--train_config is required")

    # NOTE: yaml.Loader is the full (non-safe) loader; only point --config and
    # --train_config at trusted files.
    with open(args.config, "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)
    config = edict(config)

    with open(args.train_config, "r") as train_config_file:
        train_config = yaml.load(train_config_file, Loader=yaml.Loader)
    train_config = edict(train_config)

    attack_config = train_config["attack"]

    ########################################
    ### overwrite config for grid search ###
    ########################################
    if args.epochs is not None:
        # "schedular" (sic) is the key spelling this script writes to.
        train_config["schedular"]["epochs"] = args.epochs
        print("Overwrite epochs:", args.epochs)
    if args.batch_size is not None:
        config["batch_size_train"] = args.batch_size
        print("Overwrite batch_size:", args.batch_size)

    if args.is_use_gt_caps:
        config["is_return_set_data"] = True
    if args.caps_k is not None:
        config["caps_k"] = args.caps_k

    # log print: mirror stdout into <output_dir>/out.txt. The directory is
    # created first so a fresh output_dir does not fail on open(). The log file
    # is intentionally left open for the lifetime of the process.
    os.makedirs(args.output_dir, exist_ok=True)
    sys.stdout = utils.Tee(sys.stdout, open(os.path.join(args.output_dir, "out.txt"), "w"))

    main(args, config)
