# For Elastic Search issue refer to https://github.com/beir-cellar/beir/issues/4
# Can also refer to https://colab.research.google.com/drive/1HfutiEhHMJLXiWGT8pcipxT5L2TpYEdt?usp=sharing#scrollTo=nqotyXuIBPt6
from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.lexical import BM25Search as BM25
from beir.reranking.models import CrossEncoder
from beir.reranking import Rerank
import torch
import sys
from inference import Inferencer
import pathlib, os
import logging
import random
from nltk.tokenize import sent_tokenize
import numpy as np

#### Just some code to print debug information to stdout
def beir_benchmarking(align_func):
    """Benchmark ``align_func`` as a BEIR re-ranker on top of BM25 retrieval.

    The pipeline downloads one BEIR dataset, retrieves candidates with BM25
    (Elasticsearch on ``localhost``), re-ranks the top 100 candidates per
    query with scores derived from ``align_func``, evaluates the re-ranked
    results and finally logs the best documents of one randomly chosen query.

    Args:
        align_func: Callable invoked as ``align_func(list_a, list_b)`` on two
            equally long lists of strings. Its result is indexed: element
            ``[0]`` must support ``.tolist()`` (presumably one regression
            score per pair) and element ``[2]`` must support ``.numpy()`` and
            yield a 2-D array with at least three columns (presumably 3-way
            class scores per pair) -- TODO confirm against ``Inferencer``.

    Returns:
        The ``(ndcg, _map, recall, precision)`` values returned by
        ``EvaluateRetrieval.evaluate`` for the re-ranked results.
    """
    print("Running beir benchmarking")
    #### Route log records through BEIR's LoggingHandler at INFO level
    logging.basicConfig(format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO,
                        handlers=[LoggingHandler()])

    # Change dataset_idx to evaluate another dataset; the dataset name is
    # also used as the Elasticsearch index name below.
    all_dataset = ['msmarco', 'trec-covid', 'nfcorpus', 'bioasq', 'nq', 'hotpotqa', 'fiqa', 'signal1m', 'trec-news', 'robust04', 'arguana', 'webis-touche2020', 'cqadupstack', 'quora', 'dbpedia-entity', 'scidocs', 'fever', 'climate-fever', 'scifact']
    dataset_idx = -1

    #### Download <dataset>.zip and unzip it next to this file
    dataset = all_dataset[dataset_idx]
    url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
    out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
    data_path = util.download_and_unzip(url, out_dir)

    #### Provide the data path where the dataset has been downloaded and unzipped to the data loader
    # data folder would contain these files:
    # (1) <dataset>/corpus.jsonl  (format: jsonlines)
    # (2) <dataset>/queries.jsonl (format: jsonlines)
    # (3) <dataset>/qrels/test.tsv (format: tsv ("\t"))

    corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")

    #########################################
    #### (1) RETRIEVE Top-100 docs using BM25
    #########################################

    #### Provide parameters for Elasticsearch
    hostname = "localhost"
    index_name = dataset
    # NOTE(review): presumably True (re)builds the index on every run; set to
    # False to reuse an existing one -- confirm against the BEIR docs.
    initialize = True

    model = BM25(index_name=index_name, hostname=hostname, initialize=initialize)
    retriever = EvaluateRetrieval(model)

    #### Retrieve BM25 (lexical) results (format of results is identical to qrels)
    results = retriever.retrieve(corpus, queries)

    ################################################
    #### (2) RERANK Top-100 docs using Cross-Encoder
    ################################################

    #### Reranking using Cross-Encoder models #####
    #### https://www.sbert.net/docs/pretrained_cross-encoders.html
    # cross_encoder_model = CrossEncoder('cross-encoder/ms-marco-electra-base')

    #### Or use MiniLM, TinyBERT etc. CE models (https://www.sbert.net/docs/pretrained-models/ce-msmarco.html)
    # cross_encoder_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    # cross_encoder_model = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-6')

    class AlignModelForBEIR:
        """Adapter exposing ``align_func`` through the ``predict`` method used by ``Rerank``."""

        def __init__(self, **kwargs):
            pass

        def predict(self, sentences, **kwargs):
            """Score (query, document) text pairs.

            For every pair the document is split into sentences, each sentence
            is scored against the query, and the two sentences with the
            highest value in column 0 and column 2 of ``align_func(...)[2]``
            are concatenated in both orders. The final score of a pair is the
            mean of ``align_func(...)[0]`` over the two concatenations.

            Args:
                sentences: Sequence of ``(query, document)`` string pairs.
                **kwargs: Accepted for interface compatibility and ignored.

            Returns:
                A list with one float score per input pair.
            """
            if not sentences:
                return []
            query = [each[0] for each in sentences]
            document = [each[1] for each in sentences]
            new_sent_inorder = []
            new_sent_reverse = []
            for q, d in zip(query, document):
                # sent_tokenize yields [] for empty/whitespace-only text, which
                # would make np.argmax fail; fall back to the raw text instead.
                document_sent = sent_tokenize(d) or [d]
                query_sent = [q] * len(document_sent)
                tri_labels = align_func(query_sent, document_sent)[2].numpy()
                max_e = np.argmax(tri_labels[:, 0])
                max_n = np.argmax(tri_labels[:, 2])
                # NOTE(review): the two sentences are joined without a
                # separating space; kept as-is so scores stay comparable with
                # earlier runs -- confirm whether this is intended.
                new_sent_inorder.append(document_sent[max_e] + document_sent[max_n])
                new_sent_reverse.append(document_sent[max_n] + document_sent[max_e])
            inorder_output = align_func(new_sent_inorder, query)[0].tolist()
            reverse_output = align_func(new_sent_reverse, query)[0].tolist()
            return [(i + r) / 2 for i, r in zip(inorder_output, reverse_output)]
            # Alternative: score the full document directly with the regression label
            # output = align_func(document, query)[0].tolist()
            # return output

    reranker = Rerank(AlignModelForBEIR(), batch_size=128)

    # Rerank top-100 results using the reranker provided
    rerank_results = reranker.rerank(corpus, queries, results, top_k=100)

    #### Evaluate your retrieval using NDCG@k, MAP@K ...
    ndcg, _map, recall, precision = EvaluateRetrieval.evaluate(qrels, rerank_results, retriever.k_values)

    #### Print top-k documents retrieved ####
    top_k = 10

    # Guard against an empty result set: random.choice raises on an empty list.
    if rerank_results:
        query_id, ranking_scores = random.choice(list(rerank_results.items()))
        scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True)
        logging.info("Query : %s\n", queries[query_id])

        # Slice instead of indexing so queries with fewer than top_k
        # re-ranked documents do not raise IndexError.
        for rank, (doc_id, _score) in enumerate(scores_sorted[:top_k], start=1):
            # Format: Rank x: ID [Title] Body
            logging.info("Rank %d: %s [%s] - %s\n", rank, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text"))

    return ndcg, _map, recall, precision

if __name__ == "__main__":
    # Script entry point: load the alignment checkpoint and benchmark its
    # inference function on BEIR.
    ckpt_stem = "bert-base-uncased_no_mlm_xsum_cnndm_mnli_squad_paws_paws_qqp_vitaminc_race_anli_r1_anli_r2_anli_r3_snli_wikihow_msmarco_paws_unlabeled_wiki103_qqp_stsb_sick_ctc_500000_16x4x4_final"
    ckpt_file = f"checkpoints/bert-base-uncased/{ckpt_stem}.ckpt"
    # NOTE(review): the device is pinned to GPU index 7 -- adjust for the host.
    aligner = Inferencer(
        ckpt_path=ckpt_file,
        model='bert-base-uncased',
        batch_size=32,
        device='cuda:7',
    )
    beir_benchmarking(aligner.inference)