import argparse
import json
import os
from rank1 import rank1
from standardrr import StandardRR
from rank1_no_reason import Rank1_NoReason
from reasonrr import ReasonRR
from pyserini.search.lucene import LuceneSearcher
from index_paths import THE_TOPICS, THE_SPARSE_INDEX, load_queries_qids

from bm25 import BM25
from helpers import write_scores_to_file, evaluate, save_outputs_to_pkl
from prompts import get_prompt, PROMPT_DICT
from datasets import load_dataset

# Names of the BRIGHT subsets. A --corpus_name found here is loaded from the
# "xlangai/BRIGHT" dataset and reranked from a supplied first-stage run file;
# any other corpus name goes through the Pyserini/Lucene BM25 path.
BRIGHT_CORPUS = set("""
    biology earth_science economics psychology robotics stackoverflow
    sustainable_living pony leetcode aops theoremqa_theorems theoremqa_questions
""".split())

def top_k_docs(doc_scores, k):
    """Return the ``k`` highest-scoring ``doc_id -> score`` entries, ordered by descending score."""
    return dict(sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[:k])


def load_run_file(run_file, k):
    """Load a first-stage run file and keep the top-``k`` documents per query.

    Files whose name ends in ``.json`` are parsed as ``{query_id: {doc_id: score}}``.
    Anything else is read as a whitespace-separated TREC run with six columns
    (``qid Q0 doc_id rank score tag``); blank lines are skipped.

    Args:
        run_file: Path to the run file.
        k: Number of documents to keep per query.

    Returns:
        ``{query_id: {doc_id: score}}`` with each inner dict ordered by descending score.

    Raises:
        ValueError: If a non-blank TREC line does not have exactly six columns,
            or its score is not a number.
    """
    if run_file.lower().endswith('.json'):
        with open(run_file, 'r', encoding='utf-8') as f:
            run = json.load(f)
    else:
        run = {}
        with open(run_file, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    continue  # tolerate blank / trailing lines
                if len(fields) != 6:
                    raise ValueError(
                        f'{run_file}:{line_no}: expected 6 TREC run columns, got {len(fields)}')
                query_id, _, doc_id, _, score, _ = fields
                run.setdefault(query_id, {})[doc_id] = float(score)
    return {query_id: top_k_docs(doc_scores, k) for query_id, doc_scores in run.items()}


def build_first_stage_outputs(run, queries, corpus):
    """Flatten a run into the parallel lists consumed by score writing and reranking.

    Args:
        run: ``{query_id: {doc_id: score}}`` as returned by ``load_run_file``.
        queries: Iterable of records with ``'id'`` and ``'query'`` fields.
        corpus: Iterable of records with ``'id'`` and ``'content'`` fields.

    Returns:
        Dict with equal-length lists under ``'qids'``, ``'queries'``, ``'docids'``,
        ``'passage_texts'`` and ``'bm25_scores'``, one entry per (query, document)
        pair in run order. Spaces in document ids are replaced with underscores on
        both the run side and the corpus side so the ids line up.

    Raises:
        KeyError: If the run references a query or document id that is absent
            from ``queries`` / ``corpus``.
    """
    id_to_query = {ex['id']: ex['query'] for ex in queries}
    id_to_passage = {ex['id'].replace(' ', '_'): ex['content'] for ex in corpus}
    pairs = [(qid, doc_id.replace(' ', '_'), score)
             for qid, doc_scores in run.items()
             for doc_id, score in doc_scores.items()]
    qids = [qid for qid, _, _ in pairs]
    docids = [doc_id for _, doc_id, _ in pairs]
    return {
        'qids': qids,
        'queries': [id_to_query[qid] for qid in qids],
        'docids': docids,
        'passage_texts': [id_to_passage[doc_id] for doc_id in docids],
        'bm25_scores': [score for _, _, score in pairs],
    }


def first_stage_label(run_file):
    """Return the prefix used to name the first-stage score file for ``run_file``.

    ``'bm25'`` or ``'reasonir'`` when the path contains that substring (checked in
    that order); otherwise ``'firststage'``, so an unrecognised run file still gets
    a score file name instead of leaving it undefined.
    """
    if 'bm25' in run_file:
        return 'bm25'
    if 'reasonir' in run_file:
        return 'reasonir'
    return 'firststage'


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Evaluate Rerankers')
    parser.add_argument('--model_path', required=True, help='base model path')
    parser.add_argument('--model_name', required=True, help='model name')
    parser.add_argument('--lora_module', default=None, help='lora path')
    parser.add_argument('--num_gpus', type=int, default=1)
    parser.add_argument('--corpus_name', required=True, help='dataset path')
    parser.add_argument('--bright_run_file', required=False,
                        help='first stage retrieval run file for BRIGHT dataset '
                             '(JSON or TREC; required when --corpus_name is a BRIGHT subset)')
    parser.add_argument('--qrels_path', required=False, default=None, help='Custom Qrels file for non-Pyserini corpora')
    parser.add_argument('--k', type=int, default=100, help='search results for final retrieval')
    parser.add_argument('--save_reasoning_text', action='store_true')
    parser.add_argument('--output_filename', required=True)
    args = parser.parse_args()

    # Validate argument combinations up front, before any retrieval or model
    # loading, and without relying on `assert` (which is stripped under -O).
    if args.corpus_name in BRIGHT_CORPUS and args.bright_run_file is None:
        parser.error('--bright_run_file is required for BRIGHT corpora')
    if args.model_name in ('standardrr', 'rank1_our_impl') and args.lora_module is None:
        parser.error('Please provide a lora module')

    ###############################
    # First stage: Pyserini BM25 search, or a precomputed run file for BRIGHT.
    output_dir = os.path.dirname(args.output_filename)
    if args.corpus_name not in BRIGHT_CORPUS:
        qids, queries = load_queries_qids(args.corpus_name)
        index_name = THE_SPARSE_INDEX[args.corpus_name]
        # A None result means the name was not resolved as a prebuilt index;
        # fall back to opening it as a local index path.
        searcher = LuceneSearcher.from_prebuilt_index(index_name)
        if searcher is None:
            searcher = LuceneSearcher(index_name)
        bm25 = BM25(searcher=searcher, task=args.corpus_name)
        bm25_outputs = bm25.run_search(qids, queries, k=args.k, return_passage_texts=True)
        first_stage = 'bm25'
    else:
        bright_corpus = load_dataset("xlangai/BRIGHT", 'documents')
        # NOTE: We rerank using the original query following the setup of the Rank1 paper
        bright_queries = load_dataset("xlangai/BRIGHT", 'examples')
        bm25_run = load_run_file(args.bright_run_file, args.k)
        bm25_outputs = build_first_stage_outputs(bm25_run,
                                                 queries=bright_queries[args.corpus_name],
                                                 corpus=bright_corpus[args.corpus_name])
        first_stage = first_stage_label(args.bright_run_file)
    bm25_output_filename = os.path.join(output_dir, f'{first_stage}_{args.corpus_name}')
    write_scores_to_file(all_qids=bm25_outputs['qids'],
                         all_docids=bm25_outputs['docids'],
                         scores=bm25_outputs['bm25_scores'],
                         output_filename=bm25_output_filename)
    print("Quickly evaluating BM25....", flush=True)
    evaluate(args.corpus_name, bm25_output_filename, qrels_path=args.qrels_path)
    ###############################
    # Load reranker (LoRA requirement already validated above)
    prompt = get_prompt(args.corpus_name)
    context_size = 32768
    num_gpus = args.num_gpus
    model_name = args.model_name
    if model_name == 'standardrr':
        reranker = StandardRR(base_model_name_or_path=args.model_path,
                              lora_module=args.lora_module,
                              context_size=context_size,
                              num_gpus=num_gpus,
                              dataset_prompt=prompt)
    elif model_name == 'rank1_our_impl':
        reranker = ReasonRR(base_model_name_or_path=args.model_path,
                            lora_module=args.lora_module,
                            context_size=context_size,
                            num_gpus=num_gpus,
                            dataset_prompt=prompt)
    elif model_name == 'rank1_noreason':
        reranker = Rank1_NoReason(base_model_name_or_path=args.model_path,
                                  context_size=context_size,
                                  num_gpus=num_gpus,
                                  dataset_prompt=prompt)
    else:
        reranker = rank1(model_name_or_path=args.model_path,
                         context_size=context_size,
                         num_gpus=num_gpus,
                         dataset_prompt=prompt)
    ###############################
    # Run Reranking!
    print("Now reranking!", flush=True)
    reranker_output_filename = args.output_filename
    rerank_queries = bm25_outputs['queries']
    rerank_passages = bm25_outputs['passage_texts']
    # StandardRR.predict is never passed save_reasoning_text, so
    # --save_reasoning_text only takes effect for the other rerankers.
    if model_name != 'standardrr' and args.save_reasoning_text:
        texts, reranker_scores = reranker.predict(queries=rerank_queries,
                                                  passages=rerank_passages,
                                                  save_reasoning_text=args.save_reasoning_text)
        save_outputs_to_pkl(bm25_outputs['qids'], bm25_outputs['docids'], texts,
                            output_filename=reranker_output_filename)
    else:
        reranker_scores = reranker.predict(queries=rerank_queries, passages=rerank_passages)
    write_scores_to_file(bm25_outputs['qids'], bm25_outputs['docids'], reranker_scores, reranker_output_filename)
    ###############################
    # Eval!
    evaluate(args.corpus_name, reranker_output_filename, qrels_path=args.qrels_path)