"""
Analyze augmentation effect.
- Alignment: is the augmented image aligned with its text description (and vice versa)?
- Distribution gap: how far is the augmented image from the original image? (same for text)
- Diversity: does the augmentation yield diverse images? (same for text)
"""

import argparse
import os
import sys
import gc

import yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
import pprint
import copy
from tqdm import tqdm

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.data.dataset import Subset
from PIL import Image

import faiss
import scipy.stats as stats

from easydict import EasyDict as edict
from models.get_model import load_model

# Code imported from https://github.com/salesforce/ALBEF/blob/main/Retrieval.py
from dataset import create_dataset_no_norm, create_sampler, create_loader
from dataset.caption_dataset import re_train_dataset_subset
from dataset.randaugment import RandomAugment
# from scheduler import create_scheduler_each_step

# from optim import create_optimizer
from constants import images_normalize

from models.clip_model import clip

import utils.utils as utils
from utils.utils_attack import get_attacker, attack_batch_train
from utils.utils_eval import eval_pipeline
from utils.utils_optimizer import get_trainable_params, get_optimizer
from utils.utils_visualization import vis_img_txt_pairs
from utils.utils_loss import BCEwithProj

from attacks.MMA import eda

from utils import FARE

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# KL-divergence criterion with summed (not averaged) reduction, placed on `device`.
criterion_kl = nn.KLDivLoss(reduction="sum")
criterion_kl = criterion_kl.to(device)


def check_empty_text(text):
    """Replace zero-length captions with a generic placeholder (debug helper).

    Every empty string in ``text`` is replaced by ``"a photo"`` and a message
    is printed for each replacement; all other entries are kept unchanged.

    Args:
        text: iterable of caption strings.

    Returns:
        A new list with the same length and order as ``text``.
    """
    placeholder = "a photo"
    cleaned = []
    for caption in text:
        if len(caption) > 0:
            cleaned.append(caption)
            continue
        print("Empty text: replaced with 'a photo'")
        cleaned.append(placeholder)
    return cleaned


def eda_text(text, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=1):
    """Augment each caption in ``text`` with ``eda`` and keep its first output.

    If ``eda`` raises for a caption, the failure is reported and that caption
    is replaced by ``"a photo"``, so one bad caption does not abort a batch.
    Only ``Exception`` subclasses are caught; ``KeyboardInterrupt`` and
    ``SystemExit`` still propagate.

    Args:
        text: iterable of caption strings.
        alpha_sr, alpha_ri, alpha_rs, p_rd: forwarded unchanged to ``eda``
            (presumably the synonym-replacement / random-insertion /
            random-swap / random-deletion rates of EDA; confirm in
            ``attacks.MMA.eda``).
        num_aug: forwarded to ``eda``; only element ``[0]`` of its result is used.

    Returns:
        List of augmented captions, same length and order as ``text``.
    """
    new_text = []
    for t in text:
        try:
            t = eda(t, alpha_sr=alpha_sr, alpha_ri=alpha_ri, alpha_rs=alpha_rs, p_rd=p_rd, num_aug=num_aug)[
                0
            ]
        except Exception as err:  # not bare `except:`, so Ctrl-C / sys.exit still work
            print("Error in EDA")
            print("t:", t)
            print("error:", repr(err))
            t = "a photo"
        new_text.append(t)
    return new_text


def feat_extract_test(model, data_loader, device, config, is_norm=False, eda_alpha=0.0, max_n=1000):
    """Encode up to ``max_n`` image/caption pairs from ``data_loader`` with ``model``.

    Each batch is unpacked as ``(image, text, idx)``. Images are moved to
    ``device``, passed through ``images_normalize`` and encoded with
    ``model.encode_image``. Captions are optionally EDA-augmented, tokenized
    with ``clip.tokenize(..., truncate=True)`` and encoded with
    ``model.encode_text``. Reading stops once ``max_n`` samples have been seen.
    All outputs are then trimmed to at most ``max_n`` samples, because the
    last batch may overshoot the limit.

    Args:
        model: encoder exposing ``encode_image`` and ``encode_text``; it is put
            in eval mode.
        data_loader: iterable yielding ``(image, text, idx)`` batches.
        device: device inputs are moved to before encoding.
        config: not used here; kept for interface compatibility.
        is_norm: if True, divide each feature vector by its L2 norm.
        eda_alpha: if non-zero, captions go through ``eda_text`` with this
            value for all four EDA rates before encoding.
        max_n: maximum number of samples to extract and return.

    Returns:
        ``(feat_dict, images, texts)``, where:
            - ``feat_dict``: ``{"image": Tensor, "text": Tensor}`` on CPU,
              one row per sample.
            - ``images``: the image batches as yielded by ``data_loader``
              (before ``images_normalize``), concatenated along dim 0.
            - ``texts``: the list of (possibly augmented) captions.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    images = []
    texts = []

    model.eval()
    feat_dict = {}
    n = 0
    for image, text, _idx in tqdm(data_loader):
        # Keep the batch exactly as the loader produced it (before normalization).
        images.append(image)

        image = image.to(device, non_blocking=True)

        # normalize before forward
        image = images_normalize(image)

        if eda_alpha != 0.0:
            text = eda_text(
                text, alpha_sr=eda_alpha, alpha_ri=eda_alpha, alpha_rs=eda_alpha, p_rd=eda_alpha, num_aug=1
            )
        texts.extend(text)

        text_input_ids = clip.tokenize(text, truncate=True).to(device)

        # get features; inference only, so no autograd graph is built
        with torch.no_grad():
            image_feat = model.encode_image(image).cpu()
            text_feat = model.encode_text(text_input_ids).cpu()
        feat_dict.setdefault("image", []).append(image_feat)
        feat_dict.setdefault("text", []).append(text_feat)

        n += len(image)
        if n >= max_n:
            break

    if not images:
        raise ValueError("feat_extract_test: data_loader yielded no batches")

    # The last batch can push the count past max_n; trim every output so that
    # features, images and texts all describe the same <= max_n samples.
    feat_dict = {k: torch.cat(v, dim=0)[:max_n] for k, v in feat_dict.items()}
    images = torch.cat(images, dim=0)[:max_n]
    texts = list(texts)[:max_n]

    if is_norm:
        feat_dict = {k: v / v.norm(dim=-1, keepdim=True) for k, v in feat_dict.items()}

    print("Extracted features: ", feat_dict["image"].size(), feat_dict["text"].size())

    return feat_dict, images, texts


def main(args, config, test_config_dict):
    # test_config_dict = {
    #    name: [image/text, test_config]

    # utils.init_distributed_mode(args)

    device = torch.device(args.device)

    ########################
    ###### set seed ########
    ########################
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #################################
    ###### load original data #######
    #################################
    test_size = 1000
    test_indices = list(range(test_size))
    # print("Reference indices:", reference_indices[:10])
    print("Test indices:", test_indices[:10])

    test_transform = transforms.Compose(
        [
            transforms.Resize((config["image_res"], config["image_res"]), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
        ]
    )
    aug_n = 2
    aug_m = 5
    aug_scale = (0.7, 1.0)
    train_transform_2 = transforms.Compose(
        [
            transforms.RandomResizedCrop(config["image_res"], scale=aug_scale, interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            RandomAugment(
                aug_n,
                aug_m,
                isPIL=True,
                augs=[
                    "Identity",
                    "AutoContrast",
                    "Equalize",
                    "Brightness",
                    "Sharpness",
                    "ShearX",
                    "ShearY",
                    "TranslateX",
                    "TranslateY",
                    "Rotate",
                ],
            ),
            transforms.ToTensor(),
            # normalize,
        ]
    )

    # strong augmentation
    aug_n = 2
    aug_m = 7
    aug_scale = (0.5, 1.0)
    train_transform_5 = transforms.Compose(
        [
            transforms.RandomResizedCrop(config["image_res"], scale=aug_scale, interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(),
            RandomAugment(
                aug_n,
                aug_m,
                isPIL=True,
                augs=[
                    "Identity",
                    "AutoContrast",
                    "Equalize",
                    "Brightness",
                    "Sharpness",
                    "ShearX",
                    "ShearY",
                    "TranslateX",
                    "TranslateY",
                    "Rotate",
                ],
            ),
            transforms.ToTensor(),
            # normalize,
        ]
    )

    reference_dataset = re_train_dataset_subset(
        [config["train_file"][0]],
        test_transform,
        config["image_root"],
        indices=test_indices,
        caps_k=1,
    )

    test_dataset_dict = {}
    test_dataset_dict.update({
        key: re_train_dataset_subset(
            [test_config[1]["train_file"][0]],
            test_transform,
            config["image_root"],
            indices=test_indices,
            caps_k=1,
        )
        for key, test_config in test_config_dict.items()
    })
    # random augmentation
    test_dataset_dict["Basic(RandAug)"] = re_train_dataset_subset(
        [config["train_file"][0]],
        train_transform_2,
        config["image_root"],
        indices=test_indices,
        caps_k=1,
    )
    test_config_dict["Basic(RandAug)"] = ["image", config]
    # test_dataset_dict["randaug-strong"] = re_train_dataset_subset(
    #     [config["train_file"][0]],
    #     train_transform_5,
    #     config["image_root"],
    #     indices=test_indices,
    #     caps_k=1,
    # )
    # test_config_dict["randaug-strong"] = ["image", config]


    bs = 256
    reference_loader = DataLoader(reference_dataset, batch_size=bs, pin_memory=True, shuffle=False)
    test_loader_dict = {
        key: DataLoader(test_dataset, batch_size=bs, pin_memory=True, shuffle=False)
        for key, test_dataset in test_dataset_dict.items()
    }

    ########################
    ###### load model ######
    ########################
    print("Loading model")
    model, _, _ = load_model(
        config, args.model, args.ckpt, args.text_encoder, device=device, train_config=train_config
    )
    model.eval()

    model = model.to(device)

    ##############################
    ###### extract features ######
    ##############################
    print("Extracting features")
    if os.path.exists(os.path.join(args.output_dir, "reference_feat_dict.pth")):
    # if False:
        # load features
        reference_feat_dict = torch.load(os.path.join(args.output_dir, "reference_feat_dict.pth"))
        test_feat_dict_dict = torch.load(os.path.join(args.output_dir, "test_feat_dict_dict.pth"))
        reference_images = torch.load(os.path.join(args.output_dir, "reference_images.pth"))
        reference_texts = torch.load(os.path.join(args.output_dir, "reference_texts.pth"))
        test_images_dict = torch.load(os.path.join(args.output_dir, "test_images_dict.pth"))
        test_texts_dict = torch.load(os.path.join(args.output_dir, "test_texts_dict.pth"))

        print("Loaded features")

        Key = "Basic(EDA)"
        test_config_dict[Key] = ["text", config]
    

    else:
        reference_feat_dict, reference_images, reference_texts = feat_extract_test(model, reference_loader, device, config, max_n=test_size)
        print(reference_feat_dict["image"][0][:10])

        test_feat_dict_dict = {}
        test_images_dict = {}
        test_texts_dict = {}
        ##### eda 0.1
        # _features, _images, _texts = feat_extract_test(
        #     model, reference_loader, device, config, eda_alpha=0.1, max_n=test_size
        # )
        # test_feat_dict_dict["EDA0.1"] = _features
        # test_images_dict["EDA0.1"] = _images
        # test_texts_dict["EDA0.1"] = _texts
        # test_config_dict["EDA0.1"] = ["text", config]
        
        ##### eda 0.3
        Key = "Basic(EDA)"
        _features, _images, _texts = feat_extract_test(
            model, reference_loader, device, config, eda_alpha=0.3, max_n=test_size
        )
        test_feat_dict_dict[Key] = _features
        test_images_dict[Key] = _images
        test_texts_dict[Key] = _texts
        test_config_dict[Key] = ["text", config]
        ##### Others
        for key, test_loader in test_loader_dict.items():
            print(key)
            _features, _images, _texts = feat_extract_test(model, test_loader, device, config, max_n=test_size)
            test_feat_dict_dict[key] = _features
            test_images_dict[key] = _images
            test_texts_dict[key] = _texts
            print(_features["image"][0][:10])

        # save features
        torch.save(reference_feat_dict, os.path.join(args.output_dir, "reference_feat_dict.pth"))
        torch.save(test_feat_dict_dict, os.path.join(args.output_dir, "test_feat_dict_dict.pth"))
        torch.save(reference_images, os.path.join(args.output_dir, "reference_images.pth"))
        torch.save(reference_texts, os.path.join(args.output_dir, "reference_texts.pth"))
        torch.save(test_images_dict, os.path.join(args.output_dir, "test_images_dict.pth"))
        torch.save(test_texts_dict, os.path.join(args.output_dir, "test_texts_dict.pth"))


    # color palatte
    orig_color = "black"
    colors = plt.cm.get_cmap("tab10", 10)
    
    test_config_image_list = [key for key, test_config in test_config_dict.items() if test_config[0] == "image"]
    test_config_text_list = [key for key, test_config in test_config_dict.items() if test_config[0] == "text"]
    colors_dict = {key: colors(i) for i, key in enumerate(test_config_image_list)}
    colors_dict.update({key: colors(i+len(test_config_image_list)) for i, key in enumerate(test_config_text_list)})
    colors_dict["Orig."] = orig_color


    ###################################
    ###### distribution gap: KNN ######
    ###################################
    # print("KNN")
    # for k in [1, 50]:
    #     for approx in ["cosine", "l2"]: #, "standard"]:
            
    #         knn_dict = {}
    #         if approx == "cosine":
    #             # normalize
    #             img_feats = F.normalize(reference_feat_dict["image"], dim=-1).numpy()
    #             txt_feats = F.normalize(reference_feat_dict["text"], dim=-1).numpy()
    #         elif approx == "standard":
    #             img_feats = (reference_feat_dict["image"] - reference_feat_dict["image"].mean(dim=0)) / reference_feat_dict["image"].std(dim=0)
    #             txt_feats = (reference_feat_dict["text"] - reference_feat_dict["text"].mean(dim=0)) / reference_feat_dict["text"].std(dim=0)
    #             img_feats = img_feats.numpy()
    #             txt_feats = txt_feats.numpy()
    #         elif approx == "l2":
    #             img_feats = reference_feat_dict["image"].numpy()
    #             txt_feats = reference_feat_dict["text"].numpy()
    #         else:
    #             raise ValueError("Invalid approx")
    #         knn_dict["image"] = faiss.IndexFlatL2(img_feats.shape[-1])
    #         knn_dict["image"].add(img_feats)
    #         knn_dict["text"] = faiss.IndexFlatL2(txt_feats.shape[-1])
    #         knn_dict["text"].add(txt_feats)

    #         _min = 1
    #         _max = 0
    #         scores_ood_dict = {}
    #         for key in test_config_dict:
    #             print(key)
    #             aug_modality = test_config_dict[key][0]

    #             # scores_ood_dict = {}
    #             # for key, test_feat_dict in test_feat_dict_dict.items():
    #             test_feat_dict = test_feat_dict_dict[key]
    #             if approx == "cosine":
    #                 test_feat = F.normalize(test_feat_dict[aug_modality], dim=-1).numpy()
    #             elif approx == "standard":
    #                 test_feat = (test_feat_dict[aug_modality] - test_feat_dict[aug_modality].mean(dim=0)) / test_feat_dict[aug_modality].std(dim=0)
    #                 test_feat = test_feat.numpy()
    #             elif approx == "l2":
    #                 test_feat = test_feat_dict[aug_modality].numpy()
    #             else:
    #                 raise ValueError("Invalid approx")
    #             D, I = knn_dict[aug_modality].search(test_feat, k)
    #             scores_ood = D[:, -1]
    #             scores_ood_dict.setdefault(aug_modality, {})[key] = scores_ood
    #             _min = min(_min, scores_ood.min())
    #             _max = max(_max, scores_ood.max())

    #         # plot histogram
    #         for m in ["image", "text"]:
    #             for key, scores_ood in scores_ood_dict[m].items():
    #                 plt.hist(scores_ood, bins=50, alpha=0.5, label=key, histtype=u'step')
    #             plt.legend()
    #             plt.title(f"{m} KNN")
    #             plt.savefig(os.path.join(args.output_dir, f"{m}_KNN={k}_{approx}.png"))
    #             plt.close()

    #             # density plot
    #             for key, scores_ood in scores_ood_dict[m].items():
    #                 density = stats.gaussian_kde(scores_ood)
    #                 x = np.linspace(_min, _max, 1000)
    #                 plt.plot(x, density(x), label=key)
    #             plt.legend()
    #             plt.title(f"{m} KNN")
    #             plt.savefig(os.path.join(args.output_dir, f"{m}_KNN={k}_{approx}_density.png"))
    #             plt.close()

    ##################################################
    ###### distribution gap: pair-wise distance ######
    ##################################################
    print("Pair-wise distance")
    _min_cos_sim = 1
    _max_cos_sim = 0
    _min_l2_dist = 100
    _max_l2_dist = -100
    image_pairwise_cos_sim_dict = {}
    image_pairwise_l2_dist_dict = {}
    text_pairwise_cos_sim_dict = {}
    text_pairwise_l2_dist_dict = {}
    for key in test_config_dict:
        aug_modality = test_config_dict[key][0]
        test_feat_dict = test_feat_dict_dict[key]

        cos_sim = 1 - F.cosine_similarity(reference_feat_dict[aug_modality], test_feat_dict[aug_modality], dim=-1)
        l2_dist = F.pairwise_distance(reference_feat_dict[aug_modality], test_feat_dict[aug_modality], p=2)
        if aug_modality == "image":
            image_pairwise_cos_sim_dict[key] = cos_sim
            image_pairwise_l2_dist_dict[key] = l2_dist
        else:
            text_pairwise_cos_sim_dict[key] = cos_sim
            text_pairwise_l2_dist_dict[key] = l2_dist
        _min_cos_sim = min(_min_cos_sim, cos_sim.min())
        _max_cos_sim = max(_max_cos_sim, cos_sim.max())
        _min_l2_dist = min(_min_l2_dist, l2_dist.min())
        _max_l2_dist = max(_max_l2_dist, l2_dist.max())

        print("cossim:", key, cos_sim.min(), cos_sim.max())
        print("l2dist:", key, l2_dist.min(), l2_dist.max())
    _max_l2_dist = 11

    # plot histogram: cos dist
    for m, cos_sim_dict in zip(["image", "text"], [image_pairwise_cos_sim_dict, text_pairwise_cos_sim_dict]):
        # for key, cos_sim in cos_sim_dict.items():
        #     plt.hist(cos_sim, bins=50, alpha=0.5, label=key, histtype=u'step')
        # plt.legend()
        # plt.xlabel("Cosine distance")
        # plt.ylabel("Number of samples")
        # plt.title("Cosine distance from orig. sample's representation")
        # plt.savefig(os.path.join(args.output_dir, f"{m}_pairwise_cos_dist.png"))
        # plt.close()

        # density plot
        plt.figure(figsize=(6, 3))
        for key, cos_sim in cos_sim_dict.items():
            density = stats.gaussian_kde(cos_sim)
            x = np.linspace(_min_cos_sim, _max_cos_sim, 1000)
            y = density(x) * len(cos_sim)
            plt.plot(x, y, label=key)
        plt.legend()
        plt.xlim(_min_cos_sim, _max_cos_sim)
        plt.xlabel("Cosine distance")
        plt.ylabel("Num. samples")
        # plt.ylabel("Density")
        plt.title("Cosine distance from orig. sample's representation")
        plt.tight_layout()
        plt.savefig(os.path.join(args.output_dir, f"{m}_pairwise_cos_dist_density.png"))
        plt.close()

    # plot histogram: l2 dist
    for m, l2_dist_dict in zip(["image", "text"], [image_pairwise_l2_dist_dict, text_pairwise_l2_dist_dict]):
        # for key, l2_dist in l2_dist_dict.items():
        #     plt.hist(l2_dist, bins=50, alpha=0.5, label=key, histtype=u'step')
        # plt.legend()
        # plt.xlabel("L2 distance")
        # plt.ylabel("Number of samples")
        # plt.title("L2 distance from orig. sample's representation")
        # plt.savefig(os.path.join(args.output_dir, f"{m}_pairwise_l2_dist.png"))
        # plt.close()

        # density plot
        plt.figure(figsize=(6, 3)) 
        for key, l2_dist in l2_dist_dict.items():
            density = stats.gaussian_kde(l2_dist)
            x = np.linspace(_min_l2_dist, _max_l2_dist, 1000)
            y = density(x) * len(l2_dist)
            plt.plot(x, y, label=key, color=colors_dict[key])
            print(key, l2_dist.min(), l2_dist.max())
        plt.xlim(0, _max_l2_dist)
        plt.legend()
        plt.xlabel("L2 distance")
        plt.ylabel("Num. samples")
        # plt.ylabel("Density")
        plt.title("L2 distance from orig. sample's representation")
        plt.tight_layout()
        plt.savefig(os.path.join(args.output_dir, f"{m}_pairwise_l2_dist_density.png"))
        plt.close()


    # save statistics
    data = {}
    for m, cos_sim_dict in zip(["image", "text"], [image_pairwise_cos_sim_dict, text_pairwise_cos_sim_dict]):
        for key, cos_sim in cos_sim_dict.items():
            data.setdefault(m, {})[key] = {
                "mean": cos_sim.mean().item(),
                "std": cos_sim.std().item(),
                "50th": np.median(cos_sim.numpy()),
                "25th": np.percentile(cos_sim.numpy(), 25),
                "75th": np.percentile(cos_sim.numpy(), 75),
            }
            # float
            for k, v in data[m][key].items():
                data[m][key][k] = float(v)
            data[m][key]["all"] = [float(x) for x in cos_sim.numpy().tolist()]
    with open(os.path.join(args.output_dir, "pairwise_cos_sim_stats.json"), "w") as f:
        json.dump(data, f, indent=4)

    data = {}
    for m, l2_dist_dict in zip(["image", "text"], [image_pairwise_l2_dist_dict, text_pairwise_l2_dist_dict]):
        for key, l2_dist in l2_dist_dict.items():
            data.setdefault(m, {})[key] = {
                "mean": l2_dist.mean().item(),
                "std": l2_dist.std().item(),
                "50th": np.median(l2_dist.numpy()),
                "25th": np.percentile(l2_dist.numpy(), 25),
                "75th": np.percentile(l2_dist.numpy(), 75),
            }
            # float
            for k, v in data[m][key].items():
                data[m][key][k] = float(v)
            data[m][key]["all"] = [float(x) for x in l2_dist.numpy().tolist()]
    with open(os.path.join(args.output_dir, "pairwise_l2_dist_stats.json"), "w") as f:
        json.dump(data, f, indent=4)


    ##########################
    ###### alignment #########
    ##########################
    print("Alignment orig vs aug")
    orig_img_feat = reference_feat_dict["image"]
    orig_txt_feat = reference_feat_dict["text"]
    cos_sim_orig = F.cosine_similarity(orig_img_feat, orig_txt_feat, dim=-1)

    _min = 1
    _max = 0
    cos_sim_test_dict = {}
    for key, test_feat_dict in test_feat_dict_dict.items():
        aug_modality = test_config_dict[key][0]

        aug_img_feat = test_feat_dict["image"]
        aug_txt_feat = test_feat_dict["text"]
        if aug_modality == "image":
            cos_sim_test = F.cosine_similarity(aug_img_feat, orig_txt_feat, dim=-1)
        else:
            cos_sim_test = F.cosine_similarity(orig_img_feat, aug_txt_feat, dim=-1)
        cos_sim_test_dict.setdefault(aug_modality, {})[key] = cos_sim_test
        _min = min(_min, cos_sim_test.min())
        _max = max(_max, cos_sim_test.max())


    for aug_modality, this_cos_sim_test_dict in cos_sim_test_dict.items():
        # hist
        # plt.hist(cos_sim_orig, bins=50, alpha=0.5, label="Orig.", histtype=u'step')
        # for key, cos_sim_test in cos_sim_test_dict.items():
        #     plt.hist(cos_sim_test, bins=50, alpha=0.5, label=key, histtype=u'step')
        # plt.legend()
        # plt.title(f"Alignment score of image-text pairs")
        # plt.xlabel("Cosine similarity")
        # plt.ylabel("Number of samples")
        # plt.savefig(os.path.join(args.output_dir, f"{aug_modality}_alignment.png"))
        # plt.close()

        # density plot
        plt.figure(figsize=(6, 3)) 
        density = stats.gaussian_kde(cos_sim_orig)
        x = np.linspace(_min, _max, 1000)
        y = density(x) * len(cos_sim_orig)
        plt.plot(x, y, label="Orig.", color=orig_color, linestyle="--")
        print("Orig.", cos_sim_orig.min(), cos_sim_orig.max())
        for key, cos_sim_test in this_cos_sim_test_dict.items():
            density = stats.gaussian_kde(cos_sim_test)
            y = density(x) * len(cos_sim_test)
            plt.plot(x, y, label=key, color=colors_dict[key])
            print(key, cos_sim_test.min(), cos_sim_test.max())
        plt.legend()
        plt.title("Alignment score of image-text pairs")
        plt.xlabel("Cosine similarity")
        plt.ylabel("Num. samples")
        # plt.ylabel("Density")
        plt.tight_layout()
        plt.savefig(os.path.join(args.output_dir, f"{aug_modality}_alignment_density.png"))
        plt.close()

    # save stats
    data = {}
    for aug_modality, this_cos_sim_test_dict in cos_sim_test_dict.items():
        for key, cos_sim_test in this_cos_sim_test_dict.items():
            data.setdefault(aug_modality, {})[key] = {
                "mean": cos_sim_test.mean().item(),
                "std": cos_sim_test.std().item(),
                "50th": np.median(cos_sim_test.numpy()),
                "25th": np.percentile(cos_sim_test.numpy(), 25),
                "75th": np.percentile(cos_sim_test.numpy(), 75),
            }
            # float
            for k, v in data[aug_modality][key].items():
                data[aug_modality][key][k] = float(v)
            data[aug_modality][key]["all"] = [float(x) for x in cos_sim_test.numpy().tolist()]
    with open(os.path.join(args.output_dir, "alignment_stats.json"), "w") as f:
        json.dump(data, f, indent=4)

    
    ###############################################################
    ###### 3d density map of alignment score and diff score #######
    ###############################################################
    print("3D visualization")
    # for each augmentation, visualize orig vs aug
    for key, test_feat_dict in test_feat_dict_dict.items():
        aug_modality = test_config_dict[key][0]
        paired_modality = "text" if aug_modality == "image" else "image"
        
        print(key, aug_modality)
        cos_sim_alignment = F.cosine_similarity(reference_feat_dict[paired_modality], test_feat_dict[aug_modality], dim=-1)
        cos_sim_diff = F.cosine_similarity(reference_feat_dict[aug_modality], test_feat_dict[aug_modality], dim=-1)

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        hist, xedges, yedges = np.histogram2d(cos_sim_alignment.numpy(), cos_sim_diff.numpy(), bins=50, range=[[0, 1], [0, 1]])

        xpos, ypos = np.meshgrid(xedges[:-1] + 0.25, yedges[:-1] + 0.25, indexing="ij")
        xpos = xpos.ravel()
        ypos = ypos.ravel()
        zpos = 0

        dx = dy = 0.5 * np.ones_like(zpos)
        dz = hist.ravel()

        ax.bar3d(xpos, ypos, zpos, dx, dy, dz, zsort='average')

        ax.set_xlabel('Alignment')
        ax.set_ylabel('Diff')
        ax.set_zlabel('Frequency')

        plt.title(f"3D {aug_modality} {key}")
        plt.tight_layout()
        plt.savefig(os.path.join(args.output_dir, f"{aug_modality}_3D_{key}.png"))
        plt.close()


    ##########################
    ###### visualization #########
    #########################

    # Standardize features
    # for k, v in reference_feat_dict.items():
    #     reference_feat_dict[k] = (v - v.mean(dim=0)) / v.std(dim=0)
    # for k, v in test_feat_dict_dict.items():
    #     for kk, vv in v.items():
    #         test_feat_dict_dict[k][kk] = (vv - vv.mean(dim=0)) / vv.std(dim=0)

    # l2 normalize
    for k, v in reference_feat_dict.items():
        reference_feat_dict[k] = F.normalize(v, dim=-1)
    for k, v in test_feat_dict_dict.items():
        for kk, vv in v.items():
            test_feat_dict_dict[k][kk] = F.normalize(vv, dim=-1)

    # TSNE visualization
    show_N = min(100, test_size)
    indices = np.random.choice(test_size, show_N, replace=False)

    print("== TSNE ==")
    from sklearn.manifold import TSNE

    # for each augmentation, visualize orig vs aug
    # tsne is trained on orig and aug
    for key, test_feat_dict in test_feat_dict_dict.items():
        aug_modality = test_config_dict[key][0]
        paired_modality = "text" if aug_modality == "image" else "image"
        
        print(key, aug_modality)
        tsne = TSNE(n_components=2, random_state=0)
        orig_feat = reference_feat_dict[aug_modality]
        orig_paired_feat = reference_feat_dict[paired_modality]
        aug_feat = test_feat_dict[aug_modality]
        feat_all = torch.cat([orig_feat, orig_paired_feat, aug_feat], dim=0).numpy()
        tsne_feats = tsne.fit_transform(feat_all)
        tsne_orig = tsne_feats[:test_size]
        tsne_paired = tsne_feats[test_size:2*test_size]
        tsne_aug = tsne_feats[2*test_size:]

        print(test_size)
        print(tsne_feats.shape)
        print(tsne_orig.shape)
        print(tsne_paired.shape)
        print(tsne_aug.shape)

        #########
        # before after visualization
        plt.scatter(tsne_orig[indices, 0], tsne_orig[indices, 1], label=f"orig. {aug_modality}")
        plt.scatter(tsne_aug[indices, 0], tsne_aug[indices, 1], label=f"aug. {aug_modality}")

        # draw arrows for each data point
        for i in indices:
            # orig -> aug
            plt.arrow(
                tsne_orig[i, 0],
                tsne_orig[i, 1],
                tsne_aug[i, 0] - tsne_orig[i, 0],
                tsne_aug[i, 1] - tsne_orig[i, 1],
                head_width=0.1,
                head_length=0.1,
                fc="k",
                ec="k",
                linewidth=0.5
            )

        plt.legend()
        plt.title("Visualization of datapoints")
        plt.savefig(os.path.join(args.output_dir, f"{aug_modality}_TSNE_{key}_1.png"))
        plt.close()

        #########
        # triplet visualization
        plt.scatter(tsne_orig[indices, 0], tsne_orig[indices, 1], label=f"orig. {aug_modality}")
        plt.scatter(tsne_aug[indices, 0], tsne_aug[indices, 1], label=f"aug. {aug_modality}")
        plt.scatter(tsne_paired[indices, 0], tsne_paired[indices, 1], label=f"paired {paired_modality}")

        # draw arrows for each data point
        for i in indices:
            # orig -> aug
            plt.arrow(
                tsne_orig[i, 0],
                tsne_orig[i, 1],
                tsne_aug[i, 0] - tsne_orig[i, 0],
                tsne_aug[i, 1] - tsne_orig[i, 1],
                head_width=0.1,
                head_length=0.1,
                fc="k",
                ec="k",
                linewidth=0.5
            )
            # orig -- pair (dotted line)
            plt.plot([tsne_orig[i, 0], tsne_paired[i, 0]], [tsne_orig[i, 1], tsne_paired[i, 1]], linestyle="--", color="k", linewidth=0.2)
            # pair -- aug (dotted line)
            plt.plot([tsne_paired[i, 0], tsne_aug[i, 0]], [tsne_paired[i, 1], tsne_aug[i, 1]], linestyle="--", color="k", linewidth=0.2)

        plt.legend()
        plt.title(f"Visualization of datapoints")
        plt.tight_layout()
        plt.savefig(os.path.join(args.output_dir, f"{aug_modality}_TSNE_{key}.png"))
        plt.close()

    # UMAP visualization
    print("== UMAP ==")
    import umap

    # for each augmentation, visualize orig vs aug
    # tsne is trained on orig and aug
    for key, test_feat_dict in test_feat_dict_dict.items():
        aug_modality = test_config_dict[key][0]
        paired_modality = "text" if aug_modality == "image" else "image"
        
        print(key, aug_modality)
        umap_model = umap.UMAP(random_state=0)
        orig_feat = reference_feat_dict[aug_modality]
        orig_paired_feat = reference_feat_dict[paired_modality]
        aug_feat = test_feat_dict[aug_modality]
        feat_all = torch.cat([orig_feat, orig_paired_feat, aug_feat], dim=0).numpy()
        umap_orig_aug = umap_model.fit_transform(feat_all)
        umap_orig = umap_orig_aug[:test_size]
        umap_paired = umap_orig_aug[test_size:2*test_size]
        umap_aug = umap_orig_aug[2*test_size:]

        #########
        # before after visualization
        plt.scatter(umap_orig[indices, 0], umap_orig[indices, 1], label=f"orig. {aug_modality}")
        plt.scatter(umap_aug[indices, 0], umap_aug[indices, 1], label=f"aug. {aug_modality}")

        # draw arrows for each data point
        for i in indices:
            # orig -> aug
            plt.arrow(
                umap_orig[i, 0],
                umap_orig[i, 1],
                umap_aug[i, 0] - umap_orig[i, 0],
                umap_aug[i, 1] - umap_orig[i, 1],
                head_width=0.1,
                head_length=0.1,
                fc="k",
                ec="k",
                linewidth=0.5
            )

        plt.legend()
        plt.title("Visualization of datapoints")
        plt.tight_layout()
        plt.savefig(os.path.join(args.output_dir, f"{aug_modality}_UMAP_{key}_1.png"))
        plt.close()

        #########
        # triplet visualization
        plt.scatter(umap_orig[indices, 0], umap_orig[indices, 1], label=f"orig. {aug_modality}")
        plt.scatter(umap_aug[indices, 0], umap_aug[indices, 1], label=f"aug. {aug_modality}")
        plt.scatter(umap_paired[indices, 0], umap_paired[indices, 1], label=f"paired {paired_modality}")
        

        # draw arrows for each data point
        for i in indices:
            # orig -> aug
            plt.arrow(
                umap_orig[i, 0],
                umap_orig[i, 1],
                umap_aug[i, 0] - umap_orig[i, 0],
                umap_aug[i, 1] - umap_orig[i, 1],
                head_width=0.1,
                head_length=0.1,
                fc="k",
                ec="k",
                linewidth=0.5
            )
            # orig -- pair (dotted line)
            plt.plot([umap_orig[i, 0], umap_paired[i, 0]], [umap_orig[i, 1], umap_paired[i, 1]], linestyle="--", color="k", linewidth=0.2)
            # pair -- aug (dotted line)
            plt.plot([umap_paired[i, 0], umap_aug[i, 0]], [umap_paired[i, 1], umap_aug[i, 1]], linestyle="--", color="k", linewidth=0.2)

        plt.legend()
        plt.title("Visualization of datapoints")
        plt.tight_layout()
        plt.savefig(os.path.join(args.output_dir, f"{aug_modality}_UMAP_{key}.png"))
        plt.close()


    # Visualize Caps and SD at the same time
    orig_img_feat = reference_feat_dict["image"]
    orig_txt_feat = reference_feat_dict["text"]
    aug_img_feat = test_feat_dict_dict["SD"]["image"]
    aug_txt_feat = test_feat_dict_dict["Caps"]["text"]

    tsne = TSNE(n_components=2, random_state=0)
    feat_all = torch.cat([orig_img_feat, orig_txt_feat, aug_img_feat, aug_txt_feat], dim=0).numpy()
    tsne_feats = tsne.fit_transform(feat_all)
    tsne_orig_img = tsne_feats[:test_size]
    tsne_orig_txt = tsne_feats[test_size:2*test_size]
    tsne_aug_img = tsne_feats[2*test_size:3*test_size]
    tsne_aug_txt = tsne_feats[3*test_size:]

    plt.scatter(tsne_orig_img[indices, 0], tsne_orig_img[indices, 1], label="orig_img")
    plt.scatter(tsne_orig_txt[indices, 0], tsne_orig_txt[indices, 1], label="orig_txt")
    plt.scatter(tsne_aug_img[indices, 0], tsne_aug_img[indices, 1], label="aug_img")
    plt.scatter(tsne_aug_txt[indices, 0], tsne_aug_txt[indices, 1], label="aug_txt")

    # draw arrows for each data point
    for i in indices:
        # orig_img -> aug_img
        plt.arrow(
            tsne_orig_img[i, 0],
            tsne_orig_img[i, 1],
            tsne_aug_img[i, 0] - tsne_orig_img[i, 0],
            tsne_aug_img[i, 1] - tsne_orig_img[i, 1],
            head_width=0.1,
            head_length=0.1,
            fc="k",
            ec="k",
            linewidth=0.5
        )
        # orig_txt -> aug_txt
        plt.arrow(
            tsne_orig_txt[i, 0],
            tsne_orig_txt[i, 1],
            tsne_aug_txt[i, 0] - tsne_orig_txt[i, 0],
            tsne_aug_txt[i, 1] - tsne_orig_txt[i, 1],
            head_width=0.1,
            head_length=0.1,
            fc="k",
            ec="k",
            linewidth=0.5
        )

    # orig_img - orig_txt
    for i in indices:
        plt.plot([tsne_orig_img[i, 0], tsne_orig_txt[i, 0]], [tsne_orig_img[i, 1], tsne_orig_txt[i, 1]], linestyle="--", color="k", linewidth=0.2)

    plt.legend()
    plt.title(f"TSNE Caps + SD")
    plt.savefig(os.path.join(args.output_dir, f"TSNE_Caps_SD.png"))
    plt.close()

    # UMAP visualization
    umap_model = umap.UMAP(random_state=0)
    feat_all = torch.cat([orig_img_feat, orig_txt_feat, aug_img_feat, aug_txt_feat], dim=0).numpy()
    umap_orig_aug = umap_model.fit_transform(feat_all)
    umap_orig_img = umap_orig_aug[:test_size]
    umap_orig_txt = umap_orig_aug[test_size:2*test_size]
    umap_aug_img = umap_orig_aug[2*test_size:3*test_size]
    umap_aug_txt = umap_orig_aug[3*test_size:]

    plt.scatter(umap_orig_img[indices, 0], umap_orig_img[indices, 1], label="orig_img")
    plt.scatter(umap_orig_txt[indices, 0], umap_orig_txt[indices, 1], label="orig_txt")
    plt.scatter(umap_aug_img[indices, 0], umap_aug_img[indices, 1], label="aug_img")
    plt.scatter(umap_aug_txt[indices, 0], umap_aug_txt[indices, 1], label="aug_txt")

    # draw arrows for each data point
    for i in indices:
        # orig_img -> aug_img
        plt.arrow(
            umap_orig_img[i, 0],
            umap_orig_img[i, 1],
            umap_aug_img[i, 0] - umap_orig_img[i, 0],
            umap_aug_img[i, 1] - umap_orig_img[i, 1],
            head_width=0.1,
            head_length=0.1,
            fc="k",
            ec="k",
            linewidth=0.5
        )
        # orig_txt -> aug_txt
        plt.arrow(
            umap_orig_txt[i, 0],
            umap_orig_txt[i, 1],
            umap_aug_txt[i, 0] - umap_orig_txt[i, 0],
            umap_aug_txt[i, 1] - umap_orig_txt[i, 1],
            head_width=0.1,
            head_length=0.1,
            fc="k",
            ec="k",
            linewidth=0.5
        )

    # orig_img - orig_txt
    for i in indices:
        plt.plot([umap_orig_img[i, 0], umap_orig_txt[i, 0]], [umap_orig_img[i, 1], umap_orig_txt[i, 1]], linestyle="--", color="k", linewidth=0.2)

    plt.legend()
    plt.title(f"UMAP Caps + SD")
    plt.savefig(os.path.join(args.output_dir, f"UMAP_Caps_SD.png"))
    plt.close()


    ###################################################
    #### Visualize top N, worst N aligned pairs #######
    ###################################################
    N = 20
    print("Visualize top N, worst N aligned pairs")
    for key, test_feat_dict in test_feat_dict_dict.items():
        aug_modality = test_config_dict[key][0]

        cos_sim = F.cosine_similarity(reference_feat_dict[aug_modality], test_feat_dict[aug_modality], dim=-1)
        topk = torch.topk(cos_sim, N, largest=True)
        worstk = torch.topk(cos_sim, N, largest=False)
        # reversed_worstk = torch.topk(cos_sim, N, largest=True)
        indices = torch.cat([topk.indices, worstk.indices], dim=0)
        
        if aug_modality == "image":
            # for each row, show original image and augmented image, text description + distance
            # for each column, show top N, worst N
            # number of rows = N * 2
            fig_size_x = 5
            fig_size_y = N * 5
            plt.subplots(N * 2, 3, figsize=(fig_size_x, fig_size_y))
            for i, idx in enumerate(indices):
                # original image
                numpy_image = (reference_images[idx].numpy() * 255).astype(np.uint8).transpose(1, 2, 0)
                resized_img = np.array(Image.fromarray(numpy_image).resize((224, 224)))
                resized_img = resized_img.astype(np.uint8)
                plt.subplot(N * 2, 3, 3 * i + 1)
                plt.imshow(resized_img)
                plt.axis("off")
                # augmented image
                numpy_image = (test_images_dict[key][idx].numpy() * 255).astype(np.uint8).transpose(1, 2, 0)
                resized_img = np.array(Image.fromarray(numpy_image).resize((224, 224)))
                resized_img = resized_img.astype(np.uint8)
                plt.subplot(N * 2, 3, 3 * i + 2)
                plt.imshow(resized_img)
                plt.axis("off")
                # text/distance at 3rd column
                plt.subplot(N * 2, 3, 3 * i + 3)
                txt = reference_texts[idx]
                this_cos_sim = cos_sim[idx].item()
                plt.text(0.5, 0.5, f"cos-sim: {this_cos_sim:.2f}\n{txt}", ha="center", va="center", wrap=True)
                # print top or worst k
                if i < N:
                    plt.text(0.5, 0.8, f"Top {i+1}", ha="center", va="center", wrap=True)
                else:
                    plt.text(0.5, 0.8, f"Worst {i+1-N}", ha="center", va="center", wrap=True)
                plt.axis("off")
            
            # save fig
            save_fig_path = os.path.join(args.output_dir, f"{aug_modality}_top_worst_{key}.png")
            plt.savefig(save_fig_path, bbox_inches="tight")

        elif aug_modality == "text":
            # save to json folder
            # key: top-{k}
            # value: {"orig_text": str, "aug_text": str, "cos_sim": float}
            top5_worst5_dict = {}
            for i, idx in enumerate(indices[:N]):
                top5_worst5_dict[f"top-{i+1}"] = {
                    "orig_text": reference_texts[idx],
                    "aug_text": test_texts_dict[key][idx],
                    "cos_sim": cos_sim[idx].item()
                }
            for i, idx in enumerate(indices[-N:]):
                k = i+1-N
                top5_worst5_dict[f"worst-{k}"] = {
                    "orig_text": reference_texts[idx],
                    "aug_text": test_texts_dict[key][idx],
                    "cos_sim": cos_sim[idx].item()
                }
            save_json_path = os.path.join(args.output_dir, f"{aug_modality}_top_worst_{key}.json")
            with open(save_json_path, "w") as f:
                json.dump(top5_worst5_dict, f, indent=4)




    print("Done")



def t2bool(t):
    """Parse a command-line string into a bool.

    Meant to be used as an ``argparse`` ``type=`` callable: ``type=bool``
    treats every non-empty string (including "False") as True, whereas this
    accepts only "true" / "false" (any letter case, surrounding whitespace
    ignored).

    Args:
        t: The raw string value.

    Returns:
        True for "true", False for "false".

    Raises:
        ValueError: If ``t`` is neither "true" nor "false". When used as an
            argparse ``type=``, argparse reports this as an
            "invalid t2bool value" usage error.
    """
    normalized = t.strip().lower()
    if normalized == "true":
        return True
    if normalized == "false":
        return False
    raise ValueError(f"Invalid value {t!r}: expected 'true' or 'false'")


def t2fl(t):
    """Convert a comma-separated string (e.g. "0.5,1,2") into a list of floats."""
    pieces = t.split(",")
    return list(map(float, pieces))


def t2il(t):
    """Convert a comma-separated string (e.g. "1,5,10") into a list of ints."""
    pieces = t.split(",")
    return list(map(int, pieces))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="./configs/Retrieval_flickr_train.yaml")
    # The three --test_config* options are zipped together element-wise below,
    # so their defaults must be lists: a bare-string default would be iterated
    # character by character.
    parser.add_argument("--test_config", nargs="+", default=["./configs/Retrieval_flickr_train.yaml"])
    parser.add_argument("--test_config_name", nargs="+", default=["flickr"])
    parser.add_argument("--test_config_modal", nargs="+", default=["image"])

    parser.add_argument("--seed", default=42, type=int)

    parser.add_argument("--model", default="CLIP_ViT-B-16_PT", type=str)  # model architecture
    parser.add_argument("--model_name", default="CLIP_ViT-B-16_PT", type=str)  # id for the model
    parser.add_argument("--text_encoder", default="bert-base-uncased", type=str)
    parser.add_argument("--ckpt", default=None, type=str)

    # training config
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--world_size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--gpu", default=0, nargs="+", type=int, help="GPU id to use.")
    parser.add_argument("--dist_url", default="env://", help="url used to set up distributed training")
    # t2bool instead of type=bool: bool("False") is True.
    parser.add_argument("--distributed", default=False, type=t2bool)
    parser.add_argument("--multi_gpu", default=False, type=t2bool)

    # adversarial training config
    parser.add_argument(
        "--attack",
        default=None,
        type=str,
        choices=[
            "SGA",
            "Co-Attack",
            "Sep-Attack",
            "PGD",
            "BERT",
            "Clean",
            "FSGA",
            "PDE-MMA",
            "SupPGD",
            "UnsupPGD",
            "MMA",
        ],
    )
    parser.add_argument("--is_rand_mask", default=False, action="store_true")
    parser.add_argument("--attack_fused_emb", default=False, type=t2bool)
    parser.add_argument("--cls", default=False, type=t2bool)
    parser.add_argument("--output_dir", default="../train_results", type=str)
    parser.add_argument("--epsilon", default=2.0, type=float)
    parser.add_argument("--alpha", default=3.0, type=float)  # for Co-Attack
    parser.add_argument("--num_iters", default=10, type=int)
    parser.add_argument("--step_size", default=0.5, type=float)

    # dataset
    parser.add_argument(
        "--caps_k", default=None, type=int
    )  # how many captions per image is used for training.
    parser.add_argument("--aug_n", default=2, type=int)
    parser.add_argument("--aug_m", default=7, type=int)
    parser.add_argument("--aug_scale", default=0.5, type=float)
    parser.add_argument("--is_eda", default=False, type=t2bool, help="easy data augmentation")

    # FSGA config
    parser.add_argument("--scale_ver", default=0, type=int)
    parser.add_argument("--txt_att_k", default=0, type=int)
    parser.add_argument("--txt_attack", default=None, type=str, choices=["rand", "bert"])
    parser.add_argument("--img_attack_loss", default=False, type=str)

    # MMA config
    parser.add_argument("--is_use_gt_caps", default=False, type=t2bool)
    parser.add_argument(
        "--txt_sup_k", default=5, type=int
    )  # if > 1, use augmented texts for text-supervised image attack
    parser.add_argument("--alpha_sr", default=0.1, type=float)
    parser.add_argument("--alpha_ri", default=0.1, type=float)
    parser.add_argument("--alpha_rs", default=0.1, type=float)
    parser.add_argument("--p_rd", default=0.1, type=float)
    parser.add_argument("--alpha_unsup", default=0.0, type=float)
    parser.add_argument("--alpha_sup", default=1.0, type=float)
    parser.add_argument("--is_txt_aug", default=False, type=t2bool)
    parser.add_argument("--txt_aug", default="sr", type=str)
    parser.add_argument("--curric_eps", default=None, type=t2fl)
    parser.add_argument("--curric_iter", default=None, type=t2il)

    # train config
    parser.add_argument("--train_config", default=None, type=str)
    parser.add_argument("--fix_gau", default=False, action="store_true")
    parser.add_argument("--label_smoothing", default=0.0, type=float)
    parser.add_argument("--is_aug_txt", default=False, type=t2bool)
    parser.add_argument("--aug_alpha", default=0.3, type=float)
    parser.add_argument("--no_attack_warmup_epoch", default=0, type=int)

    # pre-trained model guided adversarial training: https://github.com/serendipity1122/Pre-trained-Model-Guided-Fine-Tuning-for-Zero-Shot-Adversarial-Robustness/
    parser.add_argument("--is_pretrained_model_guided", default=False, action="store_true")

    # other
    parser.add_argument("--skip_eval", default=False, action="store_true")
    parser.add_argument("--eval_every_epoch", default=False, action="store_true")
    parser.add_argument("--resume_ckpt", default=None, type=str)
    parser.add_argument("--is_resume_opt", default=None, type=str)

    # evaluation
    parser.add_argument("--evaluate", action="store_true")
    parser.add_argument("--eval_ckpt_path", default=None, type=str)
    parser.add_argument("--eval_only_clean", default=False, action="store_true")
    # NOTE(review): default=True with store_true means this flag is always True.
    parser.add_argument("--eval_train_subset", default=True, action="store_true")

    # overwrite config for grid search
    parser.add_argument("--lr", default=None, type=float)
    parser.add_argument("--pde_mul_lr", default=None, type=float)
    parser.add_argument("--epochs", default=None, type=int)
    parser.add_argument("--total_steps", default=None, type=int)
    parser.add_argument("--batch_size", default=None, type=int)
    parser.add_argument("--lr_scheduler", default=None, type=str)

    parser.add_argument("--mark", default=None, type=str)

    args = parser.parse_args()

    # Validate CLI input with parser.error rather than assert (asserts are
    # stripped under `python -O`).
    if args.model in ["ALBEF", "ALBEF_PT", "TCL", "TCL_PT"]:
        parser.error(f"--model {args.model} is not supported by this script")
    if args.train_config is None:
        # train_config is loaded unconditionally below; fail with a clear
        # message instead of a TypeError from open(None).
        parser.error("--train_config is required")
    if not (len(args.test_config) == len(args.test_config_name) == len(args.test_config_modal)):
        # zip() below would otherwise silently drop the unmatched entries.
        parser.error(
            "--test_config, --test_config_name and --test_config_modal "
            "must be given the same number of values"
        )

    ATTACK_EVAL_LIST = ["SupPGD", "UnsupPGD", "BERT", "Co-Attack", "SGA"]
    if args.eval_only_clean:
        ATTACK_EVAL_LIST = []

    def _load_yaml_config(path):
        """Load the YAML file at *path* into an EasyDict, closing the file afterwards."""
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only point this at trusted config files.
        with open(path, "r") as f:
            return edict(yaml.load(f, Loader=yaml.Loader))

    config = _load_yaml_config(args.config)

    train_config = _load_yaml_config(args.train_config)
    attack_config = train_config["attack"]

    # create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # log print: duplicate stdout into out.txt (file kept open for the whole run)
    sys.stdout = utils.Tee(sys.stdout, open(os.path.join(args.output_dir, "out.txt"), "w"))

    ############################
    ## save args
    ############################
    with open(os.path.join(args.output_dir, "args.json"), "w") as f:
        json.dump(vars(args), f, indent=4)
    # save config
    with open(os.path.join(args.output_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=4)
    with open(os.path.join(args.output_dir, "train_config.json"), "w") as f:
        json.dump(train_config, f, indent=4)

    loss_img = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing).to(device)
    loss_txt = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing).to(device)

    # name -> [augmented modality, loaded test config]
    test_config_dict = {
        test_config_name: [test_config_modal, _load_yaml_config(test_config)]
        for test_config_name, test_config_modal, test_config in zip(
            args.test_config_name, args.test_config_modal, args.test_config
        )
    }
    main(args, config, test_config_dict)
