import argparse
import os
import json
import copy
import time
import numpy as np
import pandas as pd
import pytrec_eval
import logging
import faiss
from llava.eval.beir.faiss_index import FaissIndex
from llava.eval.beir.custom_metrics import mrr, recall_cap, hole, top_k_accuracy
from tqdm import tqdm
from typing import List, Dict, Optional, Sequence, Tuple
from collections import OrderedDict

from llava.constants import *
from llava.data.process import *

import math

logger = logging.getLogger(__name__)

# Adapted from the DenseRetrievalFaissSearch and FlatIPFaissSearch classes in beir/retrieval/search/dense/faiss_search.py: exact (flat inner-product) search with faiss.
class DenseRetrievalFaissSearch:
    """Exact dense retrieval over a flat inner-product faiss index.

    The corpus is (re)indexed on every :meth:`search` call. Per-call state
    (``rev_mapping`` and ``results``) is rebuilt each time, so one instance can
    be reused for independent searches.

    Attributes:
        score_functions: Supported similarity names (``'cos_sim'``, ``'dot'``).
        faiss_index: ``FaissIndex`` built by the most recent :meth:`index` call.
        use_gpu: Whether the flat index is copied to GPU 0 before building.
        single_gpu: ``faiss.StandardGpuResources`` when ``use_gpu`` is set, else ``None``.
        results: ``{query_id: {corpus_id: score}}`` from the most recent search.
        mapping: The ``corpus_mapping`` (corpus id -> faiss id) last passed in.
        rev_mapping: Inverse of ``mapping`` (faiss id -> corpus id).
    """

    def __init__(self, use_gpu: bool = False, **kwargs):
        self.score_functions = ['cos_sim', 'dot']
        self.faiss_index = None
        self.use_gpu = use_gpu
        # GPU resources are only allocated when a GPU index is requested.
        self.single_gpu = faiss.StandardGpuResources() if use_gpu else None
        self.results = {}
        self.mapping = {}
        self.rev_mapping = {}

    def _index(self, corpus_mapping: Dict):
        """Store ``corpus_mapping`` and its inverse; return its values (faiss ids) in insertion order."""
        self.mapping = corpus_mapping
        # Rebuild rather than extend the inverse map, so ids from a previous
        # index() call cannot leak into this one.
        self.rev_mapping = {faiss_id: corpus_id for corpus_id, faiss_id in corpus_mapping.items()}
        return list(corpus_mapping.values())

    def index(self, corpus_embeddings, corpus_mapping, score_function: str = None, **kwargs):
        """Build ``self.faiss_index`` from ``corpus_embeddings`` labelled with ``corpus_mapping`` values.

        ``score_function`` is accepted for interface compatibility and is not used.
        """
        faiss_ids = self._index(corpus_mapping)
        base_index = faiss.IndexFlatIP(corpus_embeddings.shape[-1])
        if self.use_gpu:
            logger.info("Moving Faiss Index from CPU to GPU...")
            base_index = faiss.index_cpu_to_gpu(self.single_gpu, 0, base_index)
        self.faiss_index = FaissIndex.build(passage_ids=faiss_ids, passage_embeddings=corpus_embeddings, index=base_index)

    @staticmethod
    def _l2_normalize(embeddings):
        """Return an L2-normalised float16 copy of ``embeddings``; the input is left untouched."""
        # faiss.normalize_L2 works in place on float32 data, so normalise a float32 copy.
        normalized = embeddings.astype(np.float32)
        faiss.normalize_L2(normalized)
        return normalized.astype(np.float16)

    def search(self,
               corpus_embeddings,
               query_embeddings,
               corpus_mapping,
               top_k: int,
               score_function: str, **kwargs) -> Dict[str, Dict[str, float]]:
        """Index the corpus and return the ``top_k`` hits for every query.

        Args:
            corpus_embeddings: Corpus embedding matrix passed to :meth:`index`.
            query_embeddings: Query embedding matrix; row ``i`` gets query id ``str(i)``.
            corpus_mapping: Corpus id -> faiss id mapping passed to :meth:`index`.
            top_k: Number of hits requested per query.
            score_function: ``'cos_sim'`` (both sides L2-normalised first) or ``'dot'``.
            **kwargs: Forwarded to ``FaissIndex.search``.

        Returns:
            ``{query_id: {corpus_id: score}}``. The dict is also stored on ``self.results``.
            If ``rev_mapping`` is empty, corpus ids are the stringified faiss ids.

        Raises:
            ValueError: If ``score_function`` is not supported.
        """
        if score_function not in self.score_functions:
            raise ValueError(
                f"score_function must be one of {self.score_functions}, got {score_function!r}")

        if score_function == "cos_sim":
            query_embeddings = self._l2_normalize(query_embeddings)
            corpus_embeddings = self._l2_normalize(corpus_embeddings)

        self.index(corpus_embeddings, corpus_mapping, score_function)

        faiss_scores, faiss_doc_ids = self.faiss_index.search(query_embeddings, top_k, **kwargs)

        # Start from an empty dict so queries from an earlier search cannot survive.
        self.results = {}
        has_mapping = bool(self.rev_mapping)
        for query_idx in range(len(query_embeddings)):
            scores = [float(score) for score in faiss_scores[query_idx]]
            if has_mapping:
                doc_ids = [self.rev_mapping[doc_id] for doc_id in faiss_doc_ids[query_idx]]
            else:
                doc_ids = [str(doc_id) for doc_id in faiss_doc_ids[query_idx]]
            self.results[str(query_idx)] = dict(zip(doc_ids, scores))

        return self.results

# Adapted from the EvaluateRetrieval class in beir/retrieval/evaluation.py.
class EvaluateRetrieval:
    """Run a retriever and score its output with pytrec_eval and custom metrics.

    Args:
        retriever: Object exposing ``search(corpus_embeddings, query_embeddings,
            corpus_mapping, top_k, score_function, **kwargs)``.
        k_values: Cut-offs at which metrics are reported. Defaults to
            ``[1, 10, 20, 100]``. The largest value is used as the retrieval depth.
        score_function: Similarity name forwarded to the retriever.
    """

    def __init__(self, retriever, k_values: Optional[List[int]] = None, score_function: str = "cos_sim"):
        # A None default gives every instance its own list instead of sharing one
        # mutable default across all instances.
        self.k_values = k_values if k_values is not None else [1, 10, 20, 100]
        self.top_k = max(self.k_values)
        self.retriever = retriever
        self.score_function = score_function

    # This function performs document-level retrieval
    def retrieve(self, corpus_embeddings, query_embeddings, corpus_mapping, **kwargs) -> Dict[str, Dict[str, float]]:
        """Delegate to ``retriever.search`` with ``self.top_k`` and ``self.score_function``.

        Raises:
            ValueError: If no retriever was provided.
        """
        if not self.retriever:
            raise ValueError("Model/Technique has not been provided!")
        return self.retriever.search(corpus_embeddings, query_embeddings, corpus_mapping, self.top_k, self.score_function, **kwargs)

    @staticmethod
    def evaluate(qrels: Dict[str, Dict[str, int]],
                 results: Dict[str, Dict[str, float]],
                 k_values: List[int],
                 ignore_identical_ids: bool = True) -> Tuple[
        Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float]]:
        """Compute mean NDCG, MAP, Recall and Precision at each cut-off in ``k_values``.

        Args:
            qrels: ``{query_id: {doc_id: relevance}}``.
            results: ``{query_id: {doc_id: score}}``. Modified in place when
                ``ignore_identical_ids`` is set.
            k_values: Metric cut-offs.
            ignore_identical_ids: Drop hits whose doc id equals the query id.

        Returns:
            ``(ndcg, map, recall, precision)`` dicts keyed ``"NDCG@k"``, ``"MAP@k"``,
            ``"Recall@k"`` and ``"P@k"``. Values are averaged over the queries that
            pytrec_eval scored and rounded to 5 decimals, or all 0.0 when no query
            was scored.
        """
        if ignore_identical_ids:
            logger.info(
                'For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this.')
            for qid, rels in results.items():
                rels.pop(qid, None)

        ndcg = {f"NDCG@{k}": 0.0 for k in k_values}
        _map = {f"MAP@{k}": 0.0 for k in k_values}
        recall = {f"Recall@{k}": 0.0 for k in k_values}
        precision = {f"P@{k}": 0.0 for k in k_values}

        cutoffs = ",".join(str(k) for k in k_values)
        measures = {f"map_cut.{cutoffs}", f"ndcg_cut.{cutoffs}", f"recall.{cutoffs}", f"P.{cutoffs}"}
        evaluator = pytrec_eval.RelevanceEvaluator(qrels, measures)
        scores = evaluator.evaluate(results)

        for query_scores in scores.values():
            for k in k_values:
                ndcg[f"NDCG@{k}"] += query_scores[f"ndcg_cut_{k}"]
                _map[f"MAP@{k}"] += query_scores[f"map_cut_{k}"]
                recall[f"Recall@{k}"] += query_scores[f"recall_{k}"]
                precision[f"P@{k}"] += query_scores[f"P_{k}"]

        num_queries = len(scores)
        if num_queries == 0:
            # Averaging would divide by zero, so every metric stays at 0.0.
            logger.warning("No query in `results` could be evaluated against `qrels`; reporting 0.0 for all metrics.")
        else:
            for k in k_values:
                ndcg[f"NDCG@{k}"] = round(ndcg[f"NDCG@{k}"] / num_queries, 5)
                _map[f"MAP@{k}"] = round(_map[f"MAP@{k}"] / num_queries, 5)
                recall[f"Recall@{k}"] = round(recall[f"Recall@{k}"] / num_queries, 5)
                precision[f"P@{k}"] = round(precision[f"P@{k}"] / num_queries, 5)

        for metric_scores in (ndcg, _map, recall, precision):
            logger.info("\n")
            for name, value in metric_scores.items():
                logger.info("{}: {:.4f}".format(name, value))

        return ndcg, _map, recall, precision

    @staticmethod
    def evaluate_custom(qrels: Dict[str, Dict[str, int]],
                        results: Dict[str, Dict[str, float]],
                        k_values: List[int], metric: str) -> Dict[str, float]:
        """Dispatch to a custom metric (MRR, capped recall, hole, top-k accuracy) by name.

        Raises:
            ValueError: If ``metric`` is not one of the recognised names.
        """
        metric_name = metric.lower()

        if metric_name in ["mrr", "mrr@k", "mrr_cut"]:
            return mrr(qrels, results, k_values)

        if metric_name in ["recall_cap", "r_cap", "r_cap@k"]:
            return recall_cap(qrels, results, k_values)

        if metric_name in ["hole", "hole@k"]:
            return hole(qrels, results, k_values)

        if metric_name in ["acc", "top_k_acc", "accuracy", "accuracy@k", "top_k_accuracy"]:
            return top_k_accuracy(qrels, results, k_values)

        # Previously fell through and returned None, which only failed later at the caller.
        raise ValueError(f"Unsupported custom metric: {metric!r}")


def retrieve(q_embeds, d_embeds, q_relevant_docs, d_mapping, use_gpu: bool = True):
    """Retrieve documents for every query and score them against ``q_relevant_docs``.

    Args:
        q_embeds: Query embeddings, one row per query.
        d_embeds: Document embeddings.
        q_relevant_docs: qrels ``{query_id: {wiki_url: relevance}}``.
        d_mapping: ``{wiki_url: ids}``. The first element of ``ids`` becomes the
            faiss id of that document. NOTE(review): iteration order is presumably
            expected to match the row order of ``d_embeds``; confirm.
        use_gpu: Build the faiss index on GPU 0 (default, as before) or on CPU.

    Returns:
        ``(scores, results)``. ``scores`` maps ``"<metric>_at_<k>"`` (ndcg, map,
        recall, precision, mrr) to its value. ``results`` is
        ``{query_id: {wiki_url: score}}``.
    """
    model = DenseRetrievalFaissSearch(use_gpu=use_gpu)
    retriever = EvaluateRetrieval(model, score_function="cos_sim")
    start_time = time.time()
    corpus_mapping = OrderedDict((wiki_url, wiki_ids[0]) for wiki_url, wiki_ids in d_mapping.items())
    results = retriever.retrieve(corpus_embeddings=d_embeds, query_embeddings=q_embeds, corpus_mapping=corpus_mapping)
    print("Time taken to document retrieve: {:.2f} seconds".format(time.time() - start_time))

    #### Evaluate your retrieval using NDCG@k, MAP@K ...
    ndcg, _map, recall, precision = retriever.evaluate(q_relevant_docs, results, retriever.k_values)
    # Named mrr_scores so the module-level ``mrr`` metric function is not shadowed.
    mrr_scores = retriever.evaluate_custom(q_relevant_docs, results, retriever.k_values, metric="mrr")

    # Flatten "NDCG@10"-style keys into "ndcg_at_10"-style keys.
    scores = {}
    for prefix, metric_scores in (("ndcg", ndcg), ("map", _map), ("recall", recall),
                                  ("precision", precision), ("mrr", mrr_scores)):
        for key, value in metric_scores.items():
            scores[f"{prefix}_at_{key.split('@')[1]}"] = value

    return scores, results


def rerank(q_embeds, s_embeds, q_relevant_secs, s_mapping, d_results, top_k, use_gpu: bool = True):
    """Retrieve sections for every query and score them against ``q_relevant_secs``.

    Section ids have the form ``f"{wiki_url}_section_{local_idx:02d}"``, where
    ``local_idx`` is the section's offset from the first global id of its document.

    Modes:
        * ``d_results is None``: search over every section in ``s_embeds``.
        * otherwise: only the sections of each query's ``top_k`` highest-scoring
          documents in ``d_results`` are indexed. Afterwards, each query's hits are
          restricted to sections of its own top-``top_k`` documents.

    Args:
        q_embeds: Query embeddings, one row per query.
        s_embeds: Section embeddings, indexed by the global ids stored in ``s_mapping``.
        q_relevant_secs: qrels ``{query_id: {section_id: relevance}}``.
        s_mapping: ``{wiki_url: [global section ids]}``. NOTE(review): each document's
            ids are assumed contiguous and ascending; confirm.
        d_results: Document retrieval output ``{query_id: {wiki_url: score}}``, or ``None``.
        top_k: Number of top documents per query whose sections are considered.
        use_gpu: Build the faiss index on GPU 0 (default, as before) or on CPU.

    Returns:
        ``{"<metric>_at_<k>": value}`` for ndcg, map, recall, precision and mrr.
    """
    model = DenseRetrievalFaissSearch(use_gpu=use_gpu)
    retriever = EvaluateRetrieval(model, score_function="cos_sim")
    start_time = time.time()

    # If doc_ret_results is None, we perform the section retrieval using all the section embeddings.
    if d_results is None:
        corpus_mapping = OrderedDict()
        accum_id = 0
        for wiki_url, global_section_ids in s_mapping.items():
            for section_idx in global_section_ids:
                local_section_idx = section_idx - global_section_ids[0]  # 0 ~ the number of sections - 1
                corpus_mapping[wiki_url + f'_section_{local_section_idx:02d}'] = accum_id + local_section_idx
            accum_id += len(global_section_ids)
        results = retriever.retrieve(corpus_embeddings=s_embeds, query_embeddings=q_embeds, corpus_mapping=corpus_mapping)

    # If doc_ret_results is given, we perform the section retrieval based on the retrieved documents for each test query.
    else:
        new_s_embeds = []
        corpus_mapping = {}
        seen_corpus_id = set()
        query_corpus_id = {}  # query_id -> set of its top_k wiki_urls

        accum_id = 0
        for query_id, doc_scores in d_results.items():
            top_docs = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[:top_k]
            query_corpus_id[query_id] = {wiki_url for wiki_url, _ in top_docs}

            for wiki_url, _ in top_docs:
                if wiki_url in seen_corpus_id:  # this document's sections are already indexed
                    continue
                seen_corpus_id.add(wiki_url)

                global_section_ids = s_mapping[wiki_url]
                # Collect this document's section embeddings in mapping order.
                new_s_embeds.append(s_embeds[np.array(global_section_ids)])
                for section_idx in global_section_ids:
                    local_section_idx = section_idx - global_section_ids[0]  # 0 ~ the number of sections - 1
                    corpus_mapping[wiki_url + f'_section_{local_section_idx:02d}'] = accum_id + local_section_idx
                accum_id += len(global_section_ids)

        new_s_embeds = np.concatenate(new_s_embeds, axis=0)
        results = retriever.retrieve(corpus_embeddings=new_s_embeds, query_embeddings=q_embeds, corpus_mapping=corpus_mapping)

        # Keep only sections belonging to each query's own top_k documents. The URL
        # is recovered by splitting on the last "_section_"; a fixed-width slice would
        # break once the local index needs 3+ digits.
        for query_id, section_scores in results.items():
            allowed_urls = query_corpus_id[query_id]
            removed_section_list = [
                wiki_section_id for wiki_section_id in section_scores
                if wiki_section_id.rsplit('_section_', 1)[0] not in allowed_urls
            ]
            for wiki_section_id in removed_section_list:
                del section_scores[wiki_section_id]

    print("Time taken to section retrieve: {:.2f} seconds".format(time.time() - start_time))

    #### Evaluate your retrieval using NDCG@k, MAP@K ...
    ndcg, _map, recall, precision = retriever.evaluate(q_relevant_secs, results, retriever.k_values)
    # Named mrr_scores so the module-level ``mrr`` metric function is not shadowed.
    mrr_scores = retriever.evaluate_custom(q_relevant_secs, results, retriever.k_values, metric="mrr")

    scores = {
        **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
        **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
        **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
        **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
        **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr_scores.items()},
    }

    return scores



def rerank_contrastive(args):
    """Run intra-document (section) retrieval and save the resulting metrics as JSON.

    Reads from ``args``:
        test_query_path: CSV with ``wikipedia_url`` and ``evidence_section_id`` columns.
        query_sec_embed_path, sec_embed_path: arrays loaded with ``np.load``.
        sec_mapping_path: JSON mapping each wikipedia_url to its global section ids.
        section_top_k: forwarded to ``rerank`` as ``top_k``.
        sec_ret_save_path: output JSON path for the metric dict.
    """
    test_queries = pd.read_csv(args.test_query_path)
    # No document-retrieval results are used, so sections are searched over the
    # whole corpus. NOTE(review): the ``--ret_result_path`` CLI flag is never read; confirm intent.
    doc_ret_results = None

    # Perform intra-document retrieval.
    query_sec_embeds = np.load(args.query_sec_embed_path)
    sec_embeds = np.load(args.sec_embed_path)
    # Intra-document mapping from each wikipedia_url to its section-embedding rows, e.g.
    # {"wikipedia_url_0": [0, 1, 2], "wikipedia_url_1": [3, 4, 5, 6], ...}
    with open(args.sec_mapping_path, 'r') as f:
        grouping_sec_mapping = json.load(f)

    # qrels: {query_index: {wikipedia_url + "_section_XX": 1}}. Local section ids are
    # used because test queries do not know the global section indices.
    # NOTE(review): assumes CSV row i corresponds to row i of query_sec_embeds; confirm.
    q_relevant_secs = OrderedDict()
    for query in test_queries.itertuples():
        positive_section_id = query.wikipedia_url + f'_section_{query.evidence_section_id:02d}'
        q_relevant_secs[str(query.Index)] = {positive_section_id: 1}

    tick = time.time()
    sec_ret_scores = rerank(q_embeds=query_sec_embeds, s_embeds=sec_embeds,
                            q_relevant_secs=q_relevant_secs, s_mapping=grouping_sec_mapping,
                            d_results=doc_ret_results, top_k=args.section_top_k)
    tock = time.time()
    logger.info(f"Section retrieval takes {tock - tick:.2f} seconds")

    # os.makedirs('') raises, so only create a directory when the path has one.
    save_dir = os.path.dirname(args.sec_ret_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    with open(args.sec_ret_save_path, 'w') as f:
        json.dump(sec_ret_scores, f)

    print("Done!")

# If you are unfamiliar with the retrieval metrics and pytrec_eval module, refer to the link below:
# https://weaviate.io/blog/retrieval-evaluation-metrics

# We follow the beir retrieval module, which stores the query embeddings and computes the document embeddings iteratively.
# Since our inputs are more complex than beir's default inputs, we restructure the code but keep the same retrieval flow:
# 1. Unlike beir, we pre-compute both the query and document embeddings; with multiple GPUs this extraction runs faster.
# 2. Then, using the embeddings and the faiss module, we retrieve the top-k most similar pairs.
# 3. For intra-document retrieval, we load the test-query CSV file to obtain the ground-truth section ids.
if __name__ == '__main__':
    # Command-line entry point; every argument except --section_top_k defaults to None.
    parser = argparse.ArgumentParser()
    # Query embeddings for section retrieval (rerank_contrastive loads them with np.load).
    parser.add_argument("--query_sec_embed_path", type=str, default=None)
    # Section embeddings (rerank_contrastive loads them with np.load).
    parser.add_argument("--sec_embed_path", type=str, default=None)
    # JSON mapping from wikipedia_url to its global section ids.
    parser.add_argument("--sec_mapping_path", type=str, default=None)
    # Test-query CSV (rerank_contrastive reads wikipedia_url and evidence_section_id).
    parser.add_argument("--test_query_path", type=str, default=None)
    # Number of top documents per query whose sections are kept (rerank's top_k).
    parser.add_argument("--section_top_k", type=int, default=25)
    # NOTE(review): not read by rerank_contrastive in the visible code.
    parser.add_argument("--ret_result_path", type=str, default=None)
    # Output JSON path for the section-retrieval metrics.
    parser.add_argument("--sec_ret_save_path", type=str, default=None)

    args = parser.parse_args()
    # NOTE(review): no call such as rerank_contrastive(args) is visible after parsing;
    # if the file ends here, the script parses its arguments and exits. Confirm.
