import argparse
import os
import json
import copy
import time
import numpy as np
import pandas as pd
import pytrec_eval
from llava.eval.beir.custom_metrics import mrr, recall_cap, hole, top_k_accuracy

from typing import List, Dict, Tuple
from collections import OrderedDict
from llava.constants import *
from llava.data.process import *
import logging

# Module-level logger; evaluate() reports its settings and averaged metrics through it.
logger = logging.getLogger(__name__)

def evaluate(qrels: Dict[str, Dict[str, int]],
                results: Dict[str, Dict[str, float]],
                k_values: List[int],
                ignore_identical_ids: bool = True) -> Tuple[
    Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float]]:
    """Compute mean NDCG@k, MAP@k, Recall@k and P@k with pytrec_eval.

    Adapted from BEIR's retrieval evaluation.

    Args:
        qrels: Ground-truth relevance, ``{query_id: {doc_id: relevance}}``.
        results: Retrieval scores, ``{query_id: {doc_id: score}}``. When
            ``ignore_identical_ids`` is True, entries whose doc id equals the
            query id are removed from this dict *in place*.
        k_values: Cut-offs at which every metric is computed.
        ignore_identical_ids: Drop documents whose id equals the query id
            before scoring.

    Returns:
        ``(ndcg, map, recall, precision)``: four dicts keyed ``"NDCG@k"``,
        ``"MAP@k"``, ``"Recall@k"`` and ``"P@k"``. Each value is the mean over
        the queries scored by pytrec_eval, rounded to 5 decimals. If no query
        is scored, every value is 0.0 and a warning is logged.
    """
    if ignore_identical_ids:
        logger.info(
            'For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this.')
        for qid, rels in results.items():
            # Iterate over a snapshot of the keys because we pop while looping.
            for pid in list(rels):
                if qid == pid:
                    rels.pop(pid)

    ndcg = {f"NDCG@{k}": 0.0 for k in k_values}
    _map = {f"MAP@{k}": 0.0 for k in k_values}
    recall = {f"Recall@{k}": 0.0 for k in k_values}
    precision = {f"P@{k}": 0.0 for k in k_values}

    # pytrec_eval measure spec, e.g. "ndcg_cut.1,10,20"
    cutoffs = ",".join(str(k) for k in k_values)
    evaluator = pytrec_eval.RelevanceEvaluator(
        qrels,
        {"map_cut." + cutoffs, "ndcg_cut." + cutoffs, "recall." + cutoffs, "P." + cutoffs},
    )
    scores = evaluator.evaluate(results)

    # Sum per-query scores; pytrec_eval reports them as e.g. "ndcg_cut_10".
    for query_scores in scores.values():
        for k in k_values:
            ndcg[f"NDCG@{k}"] += query_scores[f"ndcg_cut_{k}"]
            _map[f"MAP@{k}"] += query_scores[f"map_cut_{k}"]
            recall[f"Recall@{k}"] += query_scores[f"recall_{k}"]
            precision[f"P@{k}"] += query_scores[f"P_{k}"]

    num_queries = len(scores)
    if num_queries == 0:
        logger.warning("No queries were scored by pytrec_eval; all metrics are reported as 0.0.")
        # All sums above are still 0.0, so dividing by 1 yields 0.0 instead of raising.
        num_queries = 1

    for k in k_values:
        ndcg[f"NDCG@{k}"] = round(ndcg[f"NDCG@{k}"] / num_queries, 5)
        _map[f"MAP@{k}"] = round(_map[f"MAP@{k}"] / num_queries, 5)
        recall[f"Recall@{k}"] = round(recall[f"Recall@{k}"] / num_queries, 5)
        precision[f"P@{k}"] = round(precision[f"P@{k}"] / num_queries, 5)

    for metrics in (ndcg, _map, recall, precision):
        logger.info("\n")
        for name, value in metrics.items():
            logger.info("{}: {:.4f}".format(name, value))

    return ndcg, _map, recall, precision

def evaluate_custom(qrels: Dict[str, Dict[str, int]],
                    results: Dict[str, Dict[str, float]],
                    k_values: List[int], metric: str) -> Dict[str, float]:
    """Dispatch to one of the custom BEIR-style metrics.

    Args:
        qrels: Ground-truth relevance, ``{query_id: {doc_id: relevance}}``.
        results: Retrieval scores, ``{query_id: {doc_id: score}}``.
        k_values: Cut-offs passed through to the metric function.
        metric: Case-insensitive metric name. Accepted aliases are
            "mrr"/"mrr@k"/"mrr_cut", "recall_cap"/"r_cap"/"r_cap@k",
            "hole"/"hole@k", and
            "acc"/"top_k_acc"/"accuracy"/"accuracy@k"/"top_k_accuracy".

    Returns:
        Whatever the selected metric function returns. ``rerank`` consumes it
        as a mapping with ``"<NAME>@k"``-style keys.

    Raises:
        ValueError: if ``metric`` is not a recognized alias. Previously an
            unknown name silently returned ``None``.
    """
    name = metric.lower()

    if name in ("mrr", "mrr@k", "mrr_cut"):
        return mrr(qrels, results, k_values)

    if name in ("recall_cap", "r_cap", "r_cap@k"):
        return recall_cap(qrels, results, k_values)

    if name in ("hole", "hole@k"):
        return hole(qrels, results, k_values)

    if name in ("acc", "top_k_acc", "accuracy", "accuracy@k", "top_k_accuracy"):
        return top_k_accuracy(qrels, results, k_values)

    raise ValueError(f"Unsupported custom metric: {metric!r}")

def rerank(args):
    """Score reranked retrieval results against the gold sections and save the metrics.

    Args:
        args: Namespace providing
            ``query_path``: CSV of test queries; the ``wikipedia_url`` and
                ``evidence_section_id`` columns identify the gold section.
            ``rerank_result_path``: JSON of reranked scores,
                ``{query_index: {"<wikipedia_url>_section_<NN>": score, ...}}``.
            ``rerank_score_save_path``: output JSON path for the metrics.
            ``k_values`` (optional): metric cut-offs; defaults to ``[1, 10, 20]``.

    Writes a flat JSON dict with keys like ``ndcg_at_10``, ``map_at_10``,
    ``recall_at_10``, ``precision_at_10`` and ``mrr_at_10``.
    """
    test_queries = pd.read_csv(args.query_path)

    # Ground truth: each query (keyed by its row index as a string) has exactly
    # one relevant section, identified as "<wikipedia_url>_section_<NN>".
    # Sections are named from the URL because the test queries do not know the
    # global section indices.
    q_relevant_secs = OrderedDict()
    for query in test_queries.itertuples():
        sec_id = query.wikipedia_url + f'_section_{query.evidence_section_id:02d}'
        q_relevant_secs[str(query.Index)] = {sec_id: 1}

    # Reranked top-k results, keyed the same way as q_relevant_secs.
    with open(args.rerank_result_path, 'r') as f:
        rerank_result = json.load(f)

    k_values = getattr(args, "k_values", None) or [1, 10, 20]

    ndcg, _map, recall, precision = evaluate(q_relevant_secs, rerank_result, k_values)
    # Named mrr_scores so it does not shadow the imported ``mrr`` metric function.
    mrr_scores = evaluate_custom(q_relevant_secs, rerank_result, k_values, metric="mrr")

    # Turn "NDCG@10"-style keys into "ndcg_at_10"-style keys.
    scores = {
        **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
        **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
        **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
        **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
        **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr_scores.items()},
    }

    save_dir = os.path.dirname(args.rerank_score_save_path)
    if save_dir:  # os.makedirs('') raises when the path has no directory part
        os.makedirs(save_dir, exist_ok=True)
    with open(args.rerank_score_save_path, 'w') as f:
        json.dump(scores, f)

    print("Done!")


# If you are unfamiliar with retrieval metrics and the pytrec_eval module, see:
# https://weaviate.io/blog/retrieval-evaluation-metrics

# We follow the BEIR module's approach: save the query embeddings and iteratively compute the document embeddings.
# Since our input is much more complicated than the default BEIR input,
# we modify the structure but follow the same retrieval flow:
# 1. Unlike BEIR, we pre-compute both the query and document embeddings, so extraction can be parallelized across multiple GPUs.
# 2. Then, using the embeddings and the faiss module, we retrieve the top-k most similar pairs.
# 3. For intra-document retrieval, we load the test query CSV file to extract the ground-truth section ids.
if __name__ == '__main__':
    # Command-line entry point: evaluate a reranking result file and save its metrics.
    parser = argparse.ArgumentParser(
        description="Evaluate reranked intra-document retrieval results.")
    # The three paths are required: rerank() cannot run without them, and
    # failing here gives a usage message instead of a crash inside pandas/json.
    parser.add_argument("--query_path", type=str, required=True,
                        help="CSV file with the test queries.")
    parser.add_argument("--rerank_result_path", type=str, required=True,
                        help="JSON file with the reranked retrieval scores.")
    parser.add_argument("--rerank_score_save_path", type=str, required=True,
                        help="Output JSON path for the computed metrics.")
    parser.add_argument("--k_values", type=int, nargs="+", default=[1, 10, 20],
                        help="Cut-offs for the @k metrics.")
    args = parser.parse_args()

    rerank(args)
