"""
Evaluation codes for robust cross-modal retrieval.
"""
import time
import numpy as np
import torch
import torch.nn.functional as F
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from matplotlib import pyplot as plt
import json

from constants import images_normalize
import utils.utils as utils
from utils.utils_attack import attack_batch_eval
from utils.utils_text_metrics import get_sentence_length, get_pos_counts, get_sentence_inv_freq_sum


def eval_pipeline(
        args, model, data_loader, tokenizer, device, config, attacker=None, attack_name=None,
        num_iters=10,
        **kwargs,
    ):
    """
    Evaluate retrieval on ``data_loader``, optionally under an adversarial attack.

    Recall metrics from ``itm_eval`` are merged, per direction ("i2t" / "t2i"),
    with the AccuracyCalculator metrics from ``itm_acc_calculator``.

    Only ``kwargs["attack_fused_emb"]`` (default False) is read from ``kwargs``.

    Returns:
        (eval_result, score_matrix_i2t, score_matrix_t2i, feat_dict,
         adv_images_list, adv_texts_list)
    """
    if attacker is not None:
        print(f"Evaluating... {num_iters=}")

    (score_matrix_i2t, score_matrix_t2i, feat_dict,
     adv_images_list, adv_texts_list) = get_score_metrix(
        args,
        model, data_loader, tokenizer, device, config, attacker, attack_name,
        attack_fused_emb=kwargs.get("attack_fused_emb", False),
        num_iters=num_iters,
    )

    dataset = data_loader.dataset
    eval_result = itm_eval(score_matrix_i2t, score_matrix_t2i, dataset.txt2img, dataset.img2txt)

    # Fold the extra metrics into the matching direction's dict.
    for direction, extra_metrics in itm_acc_calculator(feat_dict, data_loader).items():
        eval_result[direction].update(extra_metrics)

    return eval_result, score_matrix_i2t, score_matrix_t2i, feat_dict, adv_images_list, adv_texts_list


def get_score_metrix(
        args,
        model, data_loader, tokenizer, device, config, 
        attacker=None, 
        attack_name=None,
        attack_fused_emb=False,
        num_iters=10,
    ):
    """
    Compute (optionally attacked) image/text features and their similarity matrices.

    For every image batch yielded by ``data_loader``, the captions of those images
    are looked up through ``data_loader.dataset.img2txt`` and ``data_loader.dataset.text``.
    When ``attacker`` is given, images and captions are first perturbed by
    ``attack_batch_eval``. Features come from ``model.inference_text`` and
    ``model.inference_image`` and are compared with a plain dot product.

    Side effect: ``model`` is cast to float and switched to eval mode in place.
    ``config`` is accepted for interface compatibility but is not used here.

    NOTE(review): text rows are ordered image-by-image in loader order (the
    captions of each image, via ``img2txt``). ``itm_eval`` indexes them with
    ``dataset.txt2img``, which presumably assumes an unshuffled loader and
    contiguous ``img2txt`` ids. Confirm this for new datasets.

    Returns:
    - score_matrix_i2t: numpy image-to-text similarity matrix (images x texts)
    - score_matrix_t2i: its transpose (texts x images)
    - feat_dict: {"text_feats": Tensor, "image_feats": Tensor}, both on CPU
    - adv_images_list: numpy array of the (possibly perturbed) images, captured
      before ``images_normalize`` and with axes permuted by (0, 2, 3, 1)
    - adv_texts_list: list of the (possibly perturbed) captions, in feature order
    """
    model.float()
    model.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Evaluation:'

    print('Computing features for evaluation...')

    # For visualization. Batches are collected in a list and concatenated once
    # at the end, instead of re-concatenating the growing array every batch.
    adv_image_batches = []
    adv_texts_list = []

    text_feats = []
    image_feats = []
    texts = data_loader.dataset.text
    img2txt = data_loader.dataset.img2txt
    for image, img_id in metric_logger.log_every(data_loader, 50, header):
        this_images = image.to(device, non_blocking=True)
        this_texts = []
        # Batch-local mapping: caption position -> position of its image in this batch.
        txt2img = []
        for j in range(image.shape[0]):
            caption_ids = img2txt[img_id[j].item()]
            captions = [texts[t_id] for t_id in caption_ids]
            this_texts.extend(captions)
            txt2img.extend([j] * len(captions))

        if attacker is not None:
            # The attack needs gradients, so it stays outside the no_grad region below.
            this_images, this_texts = attack_batch_eval(
                args, attack_name, attacker,
                this_images, this_texts, txt2img, device, 
                attack_fused_emb=attack_fused_emb,
                num_iters=num_iters
            )
            this_images.detach_()

        # Keep a CPU copy of the (possibly perturbed) images for visualization.
        adv_image_batches.append(this_images.cpu().numpy().transpose(0, 2, 3, 1))
        adv_texts_list += this_texts

        # normalize image
        this_images = images_normalize(this_images)

        # The features are detached right away, so skip building an autograd graph.
        with torch.no_grad():
            this_text_input = tokenizer(
                this_texts,
                padding="max_length",
                truncation=True,
                max_length=30,
                return_tensors="pt",
            ).to(device)

            text_feat = model.inference_text(this_text_input)["text_feat"].cpu().detach()
            image_feat = model.inference_image(this_images)["image_feat"].cpu().detach()
        text_feats.append(text_feat)
        image_feats.append(image_feat)

    adv_images_list = np.concatenate(adv_image_batches, axis=0) if adv_image_batches else None

    text_feats = torch.cat(text_feats, dim=0)
    image_feats = torch.cat(image_feats, dim=0)
    print("text_feats:", text_feats.shape)
    print("image_feats:", image_feats.shape)

    sims_matrix = image_feats @ text_feats.t()
    score_matrix_i2t = sims_matrix
    score_matrix_t2i = sims_matrix.t()

    feat_dict = {
        "text_feats": text_feats,
        "image_feats": image_feats,
    }

    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy(), feat_dict, adv_images_list, adv_texts_list


@torch.no_grad()
def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
    """
    Compute recall@K retrieval metrics for both matching directions.

    Ranks are 0-based positions in the descending-score ordering of each row.
    For image->text, an image's rank is that of its best-placed ground-truth
    caption (``img2txt[i]``), or 1e20 if it has none. For text->image, it is
    the position of ``txt2img[t]``.

    Returns:
        {"i2t": {"R@1", "R@5", "R@10", "mean"}, "t2i": {...}} with percentages.
    """
    def _position(order, target):
        # First index at which `target` appears in the ranking `order`.
        return np.where(order == target)[0][0]

    def _recall(ranks, k):
        return 100.0 * len(np.where(ranks < k)[0]) / len(ranks)

    # Images -> Text
    i2t_ranks = np.zeros(scores_i2t.shape[0])
    for img_idx, row in enumerate(scores_i2t):
        order = np.argsort(row)[::-1]
        i2t_ranks[img_idx] = min(
            (_position(order, txt_idx) for txt_idx in img2txt[img_idx]),
            default=1e20,
        )
    tr1, tr5, tr10 = (_recall(i2t_ranks, k) for k in (1, 5, 10))

    # Text -> Images
    t2i_ranks = np.zeros(scores_t2i.shape[0])
    for txt_idx, row in enumerate(scores_t2i):
        t2i_ranks[txt_idx] = _position(np.argsort(row)[::-1], txt2img[txt_idx])
    ir1, ir5, ir10 = (_recall(t2i_ranks, k) for k in (1, 5, 10))

    return {
        "i2t": {
            "R@1": tr1,
            "R@5": tr5,
            "R@10": tr10,
            "mean": (tr1 + tr5 + tr10) / 3,
        },
        "t2i": {
            "R@1": ir1,
            "R@5": ir5,
            "R@10": ir10,
            "mean": (ir1 + ir5 + ir10) / 3,
        },
    }


def itm_acc_calculator(
    feat_dict,
    data_loader,
    include=("mean_average_precision_at_r",),
    ):
    """
    Compute pytorch-metric-learning ``AccuracyCalculator`` metrics for I2T and T2I.

    Labels are image indices. Each image is labeled with its own index, and each
    caption with the index of the image it belongs to according to
    ``data_loader.dataset.img2txt``. A reference is therefore relevant to a
    query when both belong to the same image.

    NOTE(review): assumes the rows of ``feat_dict["text_feats"]`` are grouped by
    image in ``img2txt`` order (image 0's captions first, then image 1's, ...),
    as produced by ``get_score_metrix`` with an unshuffled loader. Confirm this.

    Args:
        feat_dict: {"image_feats": ..., "text_feats": ...} embeddings.
        data_loader: loader whose ``dataset.img2txt`` maps image index -> caption ids.
        include: metric names forwarded to ``AccuracyCalculator.get_accuracy``.
            Defaults to ``("mean_average_precision_at_r",)`` only. Other calculator
            metrics (e.g. "AMI", "NMI", "mean_average_precision",
            "mean_reciprocal_rank", "precision_at_1", "r_precision") must be
            requested explicitly.

    Returns dict:
        {
            "i2t": {<metric name>: <value>, ...},
            "t2i": {<metric name>: <value>, ...},
        }
        with one entry per metric in ``include``.
    """
    calculator = AccuracyCalculator(k="max_bin_count")

    acc_dict_all = {}
    image_feats = feat_dict["image_feats"]
    text_feats = feat_dict["text_feats"]
    img2txt = data_loader.dataset.img2txt
    # Image i is labeled i; each caption carries the label of its owning image.
    image_labels = np.arange(len(img2txt))
    caption_labels = np.array([i for i in range(len(img2txt)) for _ in img2txt[i]])
    for key, query_embeddings, reference_embeddings, query_labels, reference_labels in [
        ["i2t", image_feats, text_feats, image_labels, caption_labels],
        ["t2i", text_feats, image_feats, caption_labels, image_labels],
    ]:
        print("query_embeddings:", query_embeddings.shape)
        print("reference_embeddings:", reference_embeddings.shape)
        print("query_labels:", query_labels.shape)
        print("reference_labels:", reference_labels.shape)
        acc_dict_all[key] = calculator.get_accuracy(
            query_embeddings,
            query_labels,
            reference=reference_embeddings,
            reference_labels=reference_labels,
            include=tuple(include),
        )

    return acc_dict_all


def itm_average_precision_list(feat_dict, data_loader,):
    """
    Compute per-query MAP@R for image-text matching in both directions.

    Each query is scored on its own against the full reference set, using
    image-index labels (captions are labeled with their owning image via
    ``data_loader.dataset.img2txt``).

    Returns:
        {"i2t": np.ndarray, "t2i": np.ndarray}, one MAP@R value per query row.
    """
    print("Computing average precision...")
    image_feats = feat_dict["image_feats"]
    text_feats = feat_dict["text_feats"]
    img2txt = data_loader.dataset.img2txt
    image_ids = np.arange(len(img2txt))
    caption_owner = np.array([img for img in range(len(img2txt)) for _ in img2txt[img]])

    directions = (
        ("i2t", image_feats, text_feats, image_ids, caption_owner),
        ("t2i", text_feats, image_feats, caption_owner, image_ids),
    )

    result = {}
    for key, queries, references, q_labels, r_labels in directions:
        calculator = AccuracyCalculator(k="max_bin_count")
        per_query = [
            calculator.get_accuracy(
                queries[idx][None, :],
                q_labels[idx][None],
                reference=references,
                reference_labels=r_labels,
                include=("mean_average_precision_at_r",),
            )["mean_average_precision_at_r"]
            for idx in range(len(queries))
        ]
        result[key] = np.array(per_query)
    return result


def itm_rank_list(score_matrix_i2t, score_matrix_t2i, txt2img, img2txt):
    """
    Compute the per-query ground-truth ranks for image-text matching.

    Ranks are 0-based positions in the descending-score ordering of each row.

    Returns:
    - ranks_i2t: per image, the rank of its best-placed ground-truth caption
      (1e20 if ``img2txt`` lists none for it)
    - ranks_t2i: per text, the rank of its ground-truth image ``txt2img[t]``
    - i2t_top1_txt_idx: per image, the ground-truth caption achieving that rank
      (the overall top-scored text when the image has no captions)
    """
    def _position(order, target):
        return np.where(order == target)[0][0]

    ranks_i2t, i2t_top1_txt_idx = [], []
    for img_idx, row in enumerate(score_matrix_i2t):
        order = np.argsort(row)[::-1]
        best_rank, best_txt = 1e20, order[0]
        for txt_idx in img2txt[img_idx]:
            pos = _position(order, txt_idx)
            if pos < best_rank:
                best_rank, best_txt = pos, txt_idx
        ranks_i2t.append(best_rank)
        i2t_top1_txt_idx.append(best_txt)

    ranks_t2i = [
        _position(np.argsort(row)[::-1], txt2img[txt_idx])
        for txt_idx, row in enumerate(score_matrix_t2i)
    ]

    return ranks_i2t, ranks_t2i, i2t_top1_txt_idx


def analysis_each_query(score_matrix_i2t, score_matrix_t2i, txt2img, img2txt, texts, VIS_DIR):
    """
    Relate per-query retrieval ranks to simple caption statistics and save the results.

    For every caption in ``texts``, this computes its length, inverse-frequency sum
    and NOUN/VERB/ADJ counts (via the ``utils.utils_text_metrics`` helpers). It then:
    - saves ``T2I_<metric>_Rank.png`` scatter plots (caption metric vs. the rank of
      its ground-truth image) into ``VIS_DIR``;
    - saves ``I2T_<metric>_Rank.png`` scatter plots (metric of each image's
      best-ranked ground-truth caption vs. that rank) into ``VIS_DIR``;
    - prints the Pearson correlation between each metric and the ranks and writes
      it to ``VIS_DIR/corr_dict.json``.

    ``VIS_DIR`` must already exist; it is not created here.

    NOTE(review): ``texts`` is assumed to be indexed like the rows of
    ``score_matrix_t2i``. Images without any caption get rank 1e20 from
    ``itm_rank_list``, which would dominate the I2T correlation.
    """
    print("Analysis each query...")
    ranks_i2t, ranks_t2i, i2t_top1_txt_idx = itm_rank_list(score_matrix_i2t, score_matrix_t2i, txt2img, img2txt)

    # Per-caption statistics, one list entry per text.
    text_metrics = {}
    for text in texts:
        text_length = get_sentence_length(text)
        inv_freq_sum = get_sentence_inv_freq_sum(text)
        pos_counts = get_pos_counts(text)

        text_metrics.setdefault("text_length", []).append(text_length)
        text_metrics.setdefault("inv_freq_sum", []).append(inv_freq_sum)
        text_metrics.setdefault("NOUN", []).append(pos_counts.get("NOUN", 0))
        text_metrics.setdefault("VERB", []).append(pos_counts.get("VERB", 0))
        text_metrics.setdefault("ADJ", []).append(pos_counts.get("ADJ", 0))

    # For I2T, each image is represented by the statistics of its best-ranked
    # ground-truth caption. Build this projection once and reuse it for the
    # plots and the correlations.
    per_image_metrics = {
        k: [values[i] for i in i2t_top1_txt_idx]
        for k, values in text_metrics.items()
    }

    def _save_scatter(x, y, label, path):
        # One scatter plot per file; the figure is closed to release memory.
        plt.figure(figsize=(12, 8))
        plt.scatter(x, y, label=label)
        plt.xlabel(label)
        plt.ylabel("Retrieved Rank")
        plt.savefig(path)
        plt.close()

    for k, values in text_metrics.items():
        _save_scatter(values, ranks_t2i, k, f"{VIS_DIR}/T2I_{k}_Rank.png")
    for k, values in per_image_metrics.items():
        _save_scatter(values, ranks_i2t, k, f"{VIS_DIR}/I2T_{k}_Rank.png")

    # Correlation between each text metric and the ranks (each computed once).
    corr_dict = {}
    for k in text_metrics:
        print(f"{k}:")
        corr_t2i = np.corrcoef(text_metrics[k], ranks_t2i)[0, 1]
        corr_i2t = np.corrcoef(per_image_metrics[k], ranks_i2t)[0, 1]
        print("t2i:", corr_t2i)
        print("i2t:", corr_i2t)
        corr_dict[k] = {
            "t2i": corr_t2i,
            "i2t": corr_i2t,
        }
    with open(f"{VIS_DIR}/corr_dict.json", "w") as f:
        json.dump(corr_dict, f, indent=4)



if __name__ == "__main__":
    # Toy sanity check for AccuracyCalculator. Query and reference sets hold the
    # same four points (rows of 0s, 1s, 2s, 3s), but the reference labels of
    # points 2 and 3 are swapped.
    query_embeddings = torch.arange(4).float()[:, None].repeat(1, 4)
    reference_embeddings = query_embeddings.clone()
    query_labels = torch.tensor([0, 1, 2, 3])
    reference_labels = torch.tensor([0, 1, 3, 2])

    # Whole-set metrics with the calculator's default metric selection.
    calculator = AccuracyCalculator(k="max_bin_count")
    acc_dict = calculator.get_accuracy(
        query_embeddings,
        query_labels,
        reference=reference_embeddings,
        reference_labels=reference_labels,
    )
    print(acc_dict)

    # MAP@R computed one query at a time.
    calculator = AccuracyCalculator(k="max_bin_count")
    ap_list = [
        calculator.get_accuracy(
            query_embeddings[idx][None, :],
            query_labels[idx][None],
            reference=reference_embeddings,
            reference_labels=reference_labels,
            include=("mean_average_precision_at_r",),
        )["mean_average_precision_at_r"]
        for idx in range(len(query_embeddings))
    ]
    print(ap_list)
