import argparse
import os
import json
import copy
import time
import numpy as np
import pandas as pd
import pytrec_eval
import logging
import faiss
from llava.eval.beir.faiss_index import FaissIndex
from llava.eval.beir.custom_metrics import mrr, recall_cap, hole, top_k_accuracy
from tqdm import tqdm
from typing import List, Dict, Optional, Sequence, Tuple
from collections import OrderedDict

from llava.constants import *
from llava.data.process import *

import math

logger = logging.getLogger(__name__)

# Based on the DenseRetrievalFaissSearch and FlatIPFaissSearch classes in beir.retrieval.search.dense.faiss_search: exact search using faiss.
class DenseRetrievalFaissSearch:
    """Dense retrieval over pre-computed embeddings with a flat inner-product faiss index.

    The index is rebuilt from the given corpus embeddings on every call to
    :meth:`search`; all per-search state (``mapping``, ``rev_mapping``,
    ``results``) is reset at that point so an instance can be reused safely.
    """

    def __init__(self, use_gpu: bool = False, **kwargs):
        """Create an empty searcher.

        Args:
            use_gpu: When true, a ``faiss.StandardGpuResources`` handle is
                allocated and the index is moved to GPU 0 when it is built.
            **kwargs: Accepted for call-site compatibility; ignored.
        """
        self.score_functions = ['cos_sim', 'dot']
        self.faiss_index = None
        self.use_gpu = use_gpu
        self.single_gpu = faiss.StandardGpuResources() if use_gpu else None
        self.results = {}
        self.mapping = {}      # corpus id -> faiss id
        self.rev_mapping = {}  # faiss id -> corpus id

    def _index(self, corpus_mapping: Dict):
        """Store the corpus-id -> faiss-id mapping and build its inverse.

        Args:
            corpus_mapping: Maps each corpus id to the id used inside faiss.

        Returns:
            The faiss ids, in ``corpus_mapping`` iteration order.
        """
        self.mapping = corpus_mapping
        # Rebuild the inverse from scratch so ids from a previously indexed
        # corpus cannot leak into this one.
        self.rev_mapping = {faiss_id: corpus_id for corpus_id, faiss_id in corpus_mapping.items()}
        return list(corpus_mapping.values())

    def index(self, corpus_embeddings, corpus_mapping, score_function: Optional[str] = None, **kwargs):
        """Build ``self.faiss_index`` from the corpus embeddings.

        Args:
            corpus_embeddings: Array whose last dimension is the embedding size.
            corpus_mapping: Maps each corpus id to its faiss id.
            score_function: Unused here; kept for interface compatibility.
            **kwargs: Ignored.
        """
        faiss_ids = self._index(corpus_mapping)
        flat_index = faiss.IndexFlatIP(corpus_embeddings.shape[-1])
        if self.use_gpu:
            logger.info("Moving Faiss Index from CPU to GPU...")
            flat_index = faiss.index_cpu_to_gpu(self.single_gpu, 0, flat_index)
        self.faiss_index = FaissIndex.build(passage_ids=faiss_ids, passage_embeddings=corpus_embeddings, index=flat_index)

    @staticmethod
    def _l2_normalized(embeddings):
        """Return an L2-normalized float16 copy of ``embeddings``."""
        # astype always copies, so the caller's array is never modified by the
        # in-place faiss normalization; order="C" gives faiss a contiguous buffer.
        normalized = embeddings.astype(np.float32, order="C")
        faiss.normalize_L2(normalized)
        # Stored back as float16 as in the original pipeline.
        # NOTE(review): presumably FaissIndex accepts float16 input - confirm.
        return normalized.astype(np.float16)

    def _collect_results(self, faiss_scores, faiss_doc_ids) -> Dict[str, Dict[str, float]]:
        """Convert per-query score/id rows into ``{query_idx: {doc_id: score}}``."""
        results = {}
        for query_idx, (row_scores, row_ids) in enumerate(zip(faiss_scores, faiss_doc_ids)):
            hits = {}
            for doc_id, score in zip(row_ids, row_scores):
                # faiss pads rows with label -1 when fewer than top_k neighbours
                # exist; assuming FaissIndex.search passes that padding through,
                # skip it instead of failing on the reverse lookup.
                if doc_id == -1 and doc_id not in self.rev_mapping:
                    continue
                key = self.rev_mapping[doc_id] if self.rev_mapping else str(doc_id)
                hits[key] = float(score)
            results[str(query_idx)] = hits
        return results

    def search(self,
               corpus_embeddings,
               query_embeddings,
               corpus_mapping,
               top_k: int,
               score_function: str, **kwargs) -> Dict[str, Dict[str, float]]:
        """Index the corpus and return the ``top_k`` documents for every query.

        Args:
            corpus_embeddings: Document embeddings to index.
            query_embeddings: Query embeddings; row ``i`` becomes query id ``str(i)``.
            corpus_mapping: Maps each corpus id to its faiss id.
            top_k: Number of neighbours requested per query.
            score_function: ``"cos_sim"`` (embeddings are L2-normalized first)
                or ``"dot"``.
            **kwargs: Forwarded to ``FaissIndex.search``.

        Returns:
            ``{query_id: {corpus_id: score}}`` for this call only.

        Raises:
            ValueError: If ``score_function`` is not supported.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if score_function not in self.score_functions:
            raise ValueError(
                f"Unsupported score_function {score_function!r}; expected one of {self.score_functions}")

        if score_function == "cos_sim":
            query_embeddings = self._l2_normalized(query_embeddings)
            corpus_embeddings = self._l2_normalized(corpus_embeddings)

        self.index(corpus_embeddings, corpus_mapping, score_function)

        faiss_scores, faiss_doc_ids = self.faiss_index.search(query_embeddings, top_k, **kwargs)

        # Start from a fresh dict so results of an earlier search are not mixed in.
        self.results = self._collect_results(faiss_scores, faiss_doc_ids)
        return self.results

# EvaluateRetrieval class from beir.retrieval.evaluation.py
class EvaluateRetrieval:
    """Runs a retriever and scores its ranked results with pytrec_eval and custom metrics."""

    def __init__(self, retriever, k_values: Sequence[int] = (1, 10, 100), score_function: str = "cos_sim"):
        """Configure the evaluation.

        Args:
            retriever: Object exposing ``search(corpus_embeddings, query_embeddings,
                corpus_mapping, top_k, score_function, **kwargs)``.
            k_values: Cut-offs at which metrics are reported.
            score_function: Score function name passed to ``retriever.search``.
        """
        # Copy so instances never share (or mutate) one default list.
        self.k_values = list(k_values)
        self.top_k = max(self.k_values)
        self.retriever = retriever
        self.score_function = score_function

    # This function performs document-level retrieval
    def retrieve(self, corpus_embeddings, query_embeddings, corpus_mapping, **kwargs) -> Dict[str, Dict[str, float]]:
        """Return the retriever's top ``self.top_k`` documents for every query.

        Raises:
            ValueError: If no retriever was provided.
        """
        if not self.retriever:
            raise ValueError("Model/Technique has not been provided!")
        return self.retriever.search(corpus_embeddings, query_embeddings, corpus_mapping, self.top_k, self.score_function, **kwargs)

    @staticmethod
    def evaluate(qrels: Dict[str, Dict[str, int]],
                 results: Dict[str, Dict[str, float]],
                 k_values: List[int],
                 ignore_identical_ids: bool = True) -> Tuple[
        Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float]]:
        """Compute NDCG, MAP, Recall and Precision at each cut-off in ``k_values``.

        Args:
            qrels: ``{query_id: {doc_id: relevance}}`` ground truth.
            results: ``{query_id: {doc_id: score}}`` retrieval output. When
                ``ignore_identical_ids`` is true, entries whose doc id equals
                the query id are removed from this dict IN PLACE.
            k_values: Cut-offs to evaluate.
            ignore_identical_ids: Drop results whose doc id equals the query id.

        Returns:
            Four dicts keyed ``NDCG@k``, ``MAP@k``, ``Recall@k`` and ``P@k``,
            each averaged over the scored queries and rounded to 5 decimals.
            All values are 0.0 when no query could be scored.
        """
        if ignore_identical_ids:
            logger.info(
                'For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this.')
            # In-place on purpose: later metric calls on the same ``results``
            # object see the filtered rankings as well.
            for qid, rels in results.items():
                rels.pop(qid, None)

        ndcg = {f"NDCG@{k}": 0.0 for k in k_values}
        _map = {f"MAP@{k}": 0.0 for k in k_values}
        recall = {f"Recall@{k}": 0.0 for k in k_values}
        precision = {f"P@{k}": 0.0 for k in k_values}

        cutoffs = ",".join(str(k) for k in k_values)
        measures = {"map_cut." + cutoffs, "ndcg_cut." + cutoffs, "recall." + cutoffs, "P." + cutoffs}
        evaluator = pytrec_eval.RelevanceEvaluator(qrels, measures)
        scores = evaluator.evaluate(results)

        # Sum the per-query values for every cut-off.
        for per_query in scores.values():
            for k in k_values:
                ndcg[f"NDCG@{k}"] += per_query["ndcg_cut_" + str(k)]
                _map[f"MAP@{k}"] += per_query["map_cut_" + str(k)]
                recall[f"Recall@{k}"] += per_query["recall_" + str(k)]
                precision[f"P@{k}"] += per_query["P_" + str(k)]

        num_queries = len(scores)
        if num_queries == 0:
            # Nothing was scored; keep the zero-initialised metrics rather than
            # dividing by zero.
            logger.warning("No queries were scored; all metrics are reported as 0.0.")
        else:
            for metric_group in (ndcg, _map, recall, precision):
                for name in metric_group:
                    metric_group[name] = round(metric_group[name] / num_queries, 5)

        for metric_group in (ndcg, _map, recall, precision):
            logger.info("\n")
            for name, value in metric_group.items():
                logger.info("{}: {:.4f}".format(name, value))

        return ndcg, _map, recall, precision

    @staticmethod
    def evaluate_custom(qrels: Dict[str, Dict[str, int]],
                        results: Dict[str, Dict[str, float]],
                        k_values: List[int], metric: str) -> Dict[str, float]:
        """Compute one of the custom metrics (MRR, capped recall, hole, top-k accuracy).

        Args:
            qrels: ``{query_id: {doc_id: relevance}}`` ground truth.
            results: ``{query_id: {doc_id: score}}`` retrieval output.
            k_values: Cut-offs to evaluate.
            metric: Case-insensitive metric name or alias.

        Returns:
            Whatever the selected metric function returns (the MRR result is
            consumed as a ``{name: value}`` dict by callers in this file).

        Raises:
            ValueError: If ``metric`` is not a recognised name.
        """
        metric_name = metric.lower()

        if metric_name in ["mrr", "mrr@k", "mrr_cut"]:
            return mrr(qrels, results, k_values)

        if metric_name in ["recall_cap", "r_cap", "r_cap@k"]:
            return recall_cap(qrels, results, k_values)

        if metric_name in ["hole", "hole@k"]:
            return hole(qrels, results, k_values)

        if metric_name in ["acc", "top_k_acc", "accuracy", "accuracy@k", "top_k_accuracy"]:
            return top_k_accuracy(qrels, results, k_values)

        # Fail loudly instead of silently returning None for a typo.
        raise ValueError(f"Unknown custom metric: {metric!r}")


def retrieve(q_embeds, d_embeds, q_relevant_docs, d_mapping, use_gpu: bool = True):
    """Run cosine-similarity document retrieval and score it against the ground truth.

    Args:
        q_embeds: Query embeddings; row ``i`` is evaluated as query id ``str(i)``.
        d_embeds: Document embeddings to search over.
        q_relevant_docs: ``{query_id: {doc_id: relevance}}`` ground truth.
        d_mapping: Maps each document key to a sequence whose first element is
            used as that document's faiss id.
        use_gpu: Build the faiss index on GPU (default, as before) or on CPU.

    Returns:
        ``(scores, results)`` where ``scores`` maps names such as
        ``ndcg_at_10`` / ``mrr_at_10`` to metric values and ``results`` is the
        ``{query_id: {doc_id: score}}`` ranking.
    """
    model = DenseRetrievalFaissSearch(use_gpu=use_gpu)
    retriever = EvaluateRetrieval(model, score_function="cos_sim")

    start_time = time.time()
    corpus_mapping = OrderedDict((wiki_url, wiki_id[0]) for wiki_url, wiki_id in d_mapping.items())
    results = retriever.retrieve(corpus_embeddings=d_embeds, query_embeddings=q_embeds, corpus_mapping=corpus_mapping)
    end_time = time.time()
    print("Time taken to document retrieve: {:.2f} seconds".format(end_time - start_time))

    #### Evaluate your retrieval using NDCG@k, MAP@K ...
    ndcg, _map, recall, precision = retriever.evaluate(q_relevant_docs, results, retriever.k_values)
    # Named mrr_scores so the imported ``mrr`` metric function is not shadowed.
    mrr_scores = retriever.evaluate_custom(q_relevant_docs, results, retriever.k_values, metric="mrr")

    # Flatten "<Metric>@k" keys into "<prefix>_at_k".
    scores = {}
    for prefix, metric_values in (("ndcg", ndcg), ("map", _map), ("recall", recall),
                                  ("precision", precision), ("mrr", mrr_scores)):
        for name, value in metric_values.items():
            scores[f"{prefix}_at_{name.split('@')[1]}"] = value

    return scores, results


def retrieve_document(args):
    """Load embeddings and ground truth, run document retrieval, and save the outputs.

    Args:
        args: Namespace with ``test_query_path`` (CSV with a ``wikipedia_url``
            column), ``query_doc_embed_path`` and ``doc_embed_path`` (files
            readable by ``np.load``), ``doc_mapping_path`` (JSON), and the
            output paths ``ret_score_save_path`` / ``ret_result_save_path``.
    """
    test_queries = pd.read_csv(args.test_query_path)

    # Perform document retrieval
    query_doc_embeds = np.load(args.query_doc_embed_path)
    doc_embeds = np.load(args.doc_embed_path)
    # Document mapping that maps each wikipedia_url to the id(s) of its document embedding,
    # e.g. {"wikipedia_url_0": [0], ...}; only the first id of each entry is used downstream.
    with open(args.doc_mapping_path, 'r', encoding='utf-8') as f:
        grouping_doc_mapping = json.load(f)

    # Relevant-document mapping for pytrec_eval:
    # key: query index (as a string), value: {ground-truth document id: 1}
    q_relevant_docs = OrderedDict(
        (str(query.Index), {query.wikipedia_url: 1}) for query in test_queries.itertuples()
    )

    tick = time.time()
    doc_ret_scores, doc_ret_results = retrieve(q_embeds=query_doc_embeds, d_embeds=doc_embeds,
                                                q_relevant_docs=q_relevant_docs, d_mapping=grouping_doc_mapping)
    tock = time.time()
    logger.info(f"Document retrieval takes {tock - tick:.2f} seconds")

    for save_path, payload in ((args.ret_score_save_path, doc_ret_scores),
                               (args.ret_result_save_path, doc_ret_results)):
        # A bare file name has an empty dirname, which os.makedirs rejects.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(payload, f)

    print("Done!")


# If you are unfamiliar with the retrieval metrics and pytrec_eval module, refer to the link below:
# https://weaviate.io/blog/retrieval-evaluation-metrics

# We follow the beir module in saving the query embeddings and iteratively computing the document embeddings.
# Since our input is much more complicated than the default beir input,
# we modify the structure but follow the same retrieval flow:
# 1. Unlike beir, we pre-compute the query embeddings and document embeddings. If we have multiple GPUs, we can extract them faster.
# 2. Then, using the embeddings and the faiss module, we retrieve the top-k most similar pairs.
# 3. For the intra-document retrieval, we load the test query csv file to extract the ground-truth section id.
if __name__ == '__main__':
    # Every path is consumed unconditionally by retrieve_document, so a missing
    # one is rejected here with a clear usage error instead of failing later.
    parser = argparse.ArgumentParser(
        description="Document retrieval over pre-computed embeddings, with metric evaluation.")
    parser.add_argument("--query_doc_embed_path", type=str, required=True,
                        help="File with the query embeddings (loaded with np.load).")
    parser.add_argument("--doc_embed_path", type=str, required=True,
                        help="File with the document embeddings (loaded with np.load).")
    parser.add_argument("--doc_mapping_path", type=str, required=True,
                        help="JSON file mapping each wikipedia_url to its document embedding id(s).")
    parser.add_argument("--test_query_path", type=str, required=True,
                        help="CSV file of test queries with a wikipedia_url column.")
    parser.add_argument("--ret_score_save_path", type=str, required=True,
                        help="Output JSON path for the retrieval metrics.")
    parser.add_argument("--ret_result_save_path", type=str, required=True,
                        help="Output JSON path for the per-query retrieval results.")
    args = parser.parse_args()

    retrieve_document(args)
