# Dataset alias -> name handed to LuceneSearcher.from_prebuilt_index() in main().
# The table is assembled family by family; insertion order matches the
# hand-written listing it replaces.
THE_INDEX = {
    # 'dl*' aliases share one of two MS MARCO passage indexes.
    **dict.fromkeys(('dl19', 'dl20'), 'msmarco-v1-passage'),
    **dict.fromkeys(
        ('dl21-passage', 'dl22-passage', 'dl23-passage'),
        'msmarco-v2-passage',
    ),
    # 'beir-v1.0.0-<corpus>.flat' indexes whose alias may differ from <corpus>.
    **{
        alias: f'beir-v1.0.0-{corpus}.flat'
        for alias, corpus in (
            ('covid', 'trec-covid'),
            ('arguana', 'arguana'),
            ('touche', 'webis-touche2020'),
            ('news', 'trec-news'),
            ('scifact', 'scifact'),
            ('fiqa', 'fiqa'),
            ('hotpotqa', 'hotpotqa'),
            ('nq', 'nq'),
            ('scidocs', 'scidocs'),
            ('nfc', 'nfcorpus'),
            ('quora', 'quora'),
            ('dbpedia-entity', 'dbpedia-entity'),
            ('climate-fever', 'climate-fever'),
            ('fever', 'fever'),
            ('bioasq', 'bioasq'),
            ('robust04', 'robust04'),
            ('signal', 'signal1m'),
        )
    },
    # 'cqadupstack-<forum>' aliases reuse the alias verbatim in the index name.
    **{
        f'cqadupstack-{forum}': f'beir-v1.0.0-cqadupstack-{forum}.flat'
        for forum in (
            'android', 'english', 'gaming', 'gis', 'mathematica', 'physics',
            'programmers', 'stats', 'tex', 'unix', 'webmasters', 'wordpress',
        )
    },
    'msmarco': 'msmarco-passage.flat',
    # 'mrtydi-<code>' aliases map a two-letter code to the spelled-out language.
    **{
        f'mrtydi-{code}': f'mrtydi-v1.1-{language}'
        for code, language in (
            ('ar', 'arabic'),
            ('bn', 'bengali'),
            ('fi', 'finnish'),
            ('id', 'indonesian'),
            ('ja', 'japanese'),
            ('ko', 'korean'),
            ('ru', 'russian'),
            ('sw', 'swahili'),
            ('te', 'telugu'),
            ('th', 'thai'),
        )
    },
}

# Dataset alias -> name handed to get_topics() (and, for some datasets,
# get_qrels()) in main(). Built family by family; insertion order matches the
# hand-written listing it replaces.
THE_TOPICS = {
    # 'dl19'/'dl20' gain a '-passage' suffix; the later 'dl*' aliases lose it.
    **{alias: f'{alias}-passage' for alias in ('dl19', 'dl20')},
    **{f'dl{year}-passage': f'dl{year}' for year in ('21', '22', '23')},
    # 'beir-v1.0.0-<corpus>-test' sets whose alias may differ from <corpus>.
    **{
        alias: f'beir-v1.0.0-{corpus}-test'
        for alias, corpus in (
            ('covid', 'trec-covid'),
            ('arguana', 'arguana'),
            ('touche', 'webis-touche2020'),
            ('news', 'trec-news'),
            ('scifact', 'scifact'),
            ('fiqa', 'fiqa'),
            ('dbpedia-entity', 'dbpedia-entity'),
            ('climate-fever', 'climate-fever'),
            ('nq', 'nq'),
            ('hotpotqa', 'hotpotqa'),
            ('scidocs', 'scidocs'),
            ('nfc', 'nfcorpus'),
            ('quora', 'quora'),
            ('fever', 'fever'),
            ('bioasq', 'bioasq'),
            ('robust04', 'robust04'),
            ('signal', 'signal1m'),
        )
    },
    # 'cqadupstack-<forum>' aliases reuse the alias verbatim in the topic name.
    **{
        f'cqadupstack-{forum}': f'beir-v1.0.0-cqadupstack-{forum}-test'
        for forum in (
            'android', 'english', 'gaming', 'gis', 'mathematica', 'physics',
            'programmers', 'stats', 'tex', 'unix', 'webmasters', 'wordpress',
        )
    },
    'msmarco': 'msmarco-passage-dev',
    # 'mrtydi-<code>' aliases map a two-letter code to the spelled-out language.
    **{
        f'mrtydi-{code}': f'mrtydi-v1.1-{language}-test'
        for code, language in (
            ('ar', 'arabic'),
            ('bn', 'bengali'),
            ('fi', 'finnish'),
            ('id', 'indonesian'),
            ('ja', 'japanese'),
            ('ko', 'korean'),
            ('ru', 'russian'),
            ('sw', 'swahili'),
            ('te', 'telugu'),
            ('th', 'thai'),
        )
    },
}
from pyserini.search.lucene import LuceneSearcher
from pyserini.search import get_topics, get_qrels
from tqdm import tqdm
import tempfile
import os
import json
import shutil
from beir_eval import run_rerank_eval


def _passage_text(content, docid):
    """Return the passage body of a decoded index document.

    The body is looked up under 'text', then 'contents', then 'passage'; the
    first key holding a non-None value wins.

    Raises:
        ValueError: if none of those keys carries a value.
    """
    for field in ('text', 'contents', 'passage'):
        value = content.get(field)
        if value is not None:
            return value
    raise ValueError(
        f"Document {docid!r} has no 'text', 'contents' or 'passage' field "
        f"(available keys: {sorted(content)})"
    )


def run_retriever(topics, searcher, qrels=None, k=100, qid=None):
    """Retrieve the top-k hits for a single query string or for a set of topics.

    Args:
        topics: Either one query string (single-query mode) or a mapping from
            query id to a dict whose 'title' entry is the query text
            (multi-topic mode).
        searcher: Object providing ``search(query, k=...)``, whose hits expose
            ``docid`` and ``score``, and ``doc(docid).raw()``, which returns
            the stored document as a JSON string.
        qrels: Mapping ``{query id: {doc id: relevance}}``. Used only in
            multi-topic mode, where topics missing from it are skipped. When
            None, every topic is retrieved and its 'qrels' entry is empty.
        k: Number of hits requested per query.
        qid: Query id copied into every hit in single-query mode.

    Returns:
        Single-query mode: ``{'query': topics, 'hits': [...]}`` where each hit
        has 'content', 'qid', 'docid', 'rank' (1-based) and 'score'.
        Multi-topic mode: a list with one dict per retrieved topic, holding
        'qid', 'q_text', 'bm25_results' (dicts with 'text', 'title',
        'bm25_score', 'pid') and 'qrels' (judgments with non-zero relevance,
        cast to int).

    Raises:
        ValueError: in multi-topic mode, if a retrieved document has no
            'text', 'contents' or 'passage' field.
    """
    if isinstance(topics, str):
        result = {'query': topics, 'hits': []}
        for rank, hit in enumerate(searcher.search(topics, k=k), start=1):
            content = json.loads(searcher.doc(hit.docid).raw())
            if 'title' in content:
                text = 'Title: ' + content['title'] + ' ' + 'Content: ' + content['text']
            else:
                text = content['contents']
            # Collapse every run of whitespace (including newlines) to one space.
            text = ' '.join(text.split())
            result['hits'].append({
                'content': text,
                'qid': qid, 'docid': hit.docid, 'rank': rank, 'score': hit.score})
        return result

    ranks = []
    # A distinct loop name keeps the `qid` parameter from being shadowed.
    for topic_id in tqdm(topics):
        if qrels is not None and topic_id not in qrels:
            continue
        query = topics[topic_id]['title']
        judgments = qrels[topic_id] if qrels is not None else {}
        # Compare as int so that both '0' and 0 count as "not relevant".
        relevant = {
            docid: int(rel) for docid, rel in judgments.items() if int(rel) != 0
        }
        record = {'qid': topic_id, 'q_text': query, 'bm25_results': [], 'qrels': relevant}
        for hit in searcher.search(query, k=k):
            content = json.loads(searcher.doc(hit.docid).raw())
            record['bm25_results'].append({
                'text': _passage_text(content, hit.docid),
                'title': content.get('title') or '',
                'bm25_score': hit.score,
                'pid': hit.docid,
            })
        ranks.append(record)
    return ranks
    

def main():
    """Retrieve, export and sanity-check BM25 results for each selected dataset.

    For every alias in the hard-coded dataset list this opens the prebuilt
    index named in THE_INDEX, loads the topics and qrels, retrieves the top
    100 hits per judged topic with run_retriever, writes the records to
    ``./data/new/<alias>.jsonl`` and passes them to run_rerank_eval.
    """
    # Imported up front so a missing package fails before any retrieval runs.
    import jsonlines

    # Aliases to process; any key present in both THE_INDEX and THE_TOPICS can
    # be listed here. For the CQADupStack forums, loop over the forum names and
    # use f"cqadupstack-{forum}" as the alias.
    datasets = ['dl22-passage', 'dl23-passage']

    for data in datasets:
        print('#' * 20)
        print(f'Evaluation on {data}')
        print('#' * 20)
        # Retrieve passages using pyserini BM25.
        searcher = LuceneSearcher.from_prebuilt_index(THE_INDEX[data])
        topics = get_topics(THE_TOPICS[data] if data != 'dl20' else 'dl20')
        # Aliases ending in '-passage' are used directly as the qrels name;
        # every other alias (including 'dl19'/'dl20') goes through THE_TOPICS.
        # NOTE(review): assumes 'dl19'/'dl20' are not valid get_qrels names
        # while 'dl19-passage'/'dl20-passage' are -- confirm against Pyserini.
        qrels_name = data if data.endswith('-passage') else THE_TOPICS[data]
        qrels = get_qrels(qrels_name)
        rank_results = run_retriever(topics, searcher, qrels, k=100)

        out_path = f'./data/new/{data}.jsonl'
        # Create the output directory so the write cannot fail after retrieval.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        with jsonlines.open(out_path, 'w') as writer:
            writer.write_all(rank_results)
        print(f"Writing to {out_path} done!")
        # sanity check
        print(f"\nSanity: running evaluation..")
        run_rerank_eval(rank_results, combined=True)
        print(f"Length: {len(rank_results)}")
        # To merge the per-forum CQADupStack files afterwards:
        #   cat cqadupstack-*.jsonl > cqadupstack-full.jsonl


# Run the retrieval pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()
