import bm25s
import Stemmer
from tqdm import tqdm
import json
import numpy as np
from rank_bm25 import BM25Okapi
from fire import Fire
def calculate_bm25_scores(train_embs,test_embs):
    """Score one or more queries against a corpus using ``rank_bm25``.

    Both corpus documents and queries are tokenized by splitting on a single
    space character (no lower-casing, stemming or stop-word removal).

    Args:
        train_embs: Iterable of corpus document strings.
        test_embs: A single query string, or a list of query strings. Any
            value that is not a ``list`` is treated as one query.

    Returns:
        For exactly one query, ``np.array`` of that query's
        ``BM25Okapi.get_scores`` result; otherwise the per-query results
        stacked row-wise with ``np.vstack``.
    """
    # A bare (non-list) query is wrapped so the rest of the code sees a list.
    queries = test_embs if isinstance(test_embs, list) else [test_embs]

    ranker = BM25Okapi([text.split(" ") for text in train_embs])
    per_query = [ranker.get_scores(text.split(" ")) for text in queries]

    # Multiple (or zero) queries go through vstack; a lone query is returned
    # as its own array rather than a one-row matrix.
    if len(queries) != 1:
        return np.vstack(per_query)
    return np.array(per_query[0])
def calculate_bm25_scores_bm25s(train_embs,test_embs):
    """Build a dense query-by-document BM25 score matrix using ``bm25s``.

    The corpus is tokenized with English stop-word removal and an English
    Snowball stemmer, indexed once, and then every query is scored against
    the whole corpus.

    Args:
        train_embs: Sequence of corpus document strings (must support
            ``len``).
        test_embs: Sequence of query strings (must support ``len``).

    Returns:
        ``np.ndarray`` of shape ``(len(test_embs), len(train_embs))`` where
        entry ``[q, d]`` is the score of corpus document ``d`` (in the order
        given by ``train_embs``) for query ``q``.
    """
    corpus = train_embs
    num_docs = len(corpus)
    stemmer = Stemmer.Stemmer("english")

    corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)

    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)

    all_scores = np.zeros((len(test_embs), num_docs))

    progress = tqdm(
        enumerate(test_embs),
        total=len(test_embs),
        desc="Calculating BM25 scores",
    )
    for idx, query in progress:
        # Inner progress bars are silenced so they do not fight the outer one.
        query_tokens = bm25s.tokenize([query], stemmer=stemmer, show_progress=False)

        # Without a `corpus` argument, retrieve() yields document indices
        # rather than the documents themselves (per the bm25s docs). Asking
        # for k == corpus size therefore gives a score for every document
        # without materializing an array of all document texts per query.
        doc_ids, doc_scores = retriever.retrieve(
            query_tokens, k=num_docs, show_progress=False
        )

        # Results come back ranked; scatter them back into corpus order.
        all_scores[idx, doc_ids[0]] = doc_scores[0]
    return all_scores
def _read_split(path, json_lines):
    """Load one dataset split from ``path`` (JSON Lines if ``json_lines``, else JSON)."""
    with open(path, 'r', encoding='utf-8') as file:
        if json_lines:
            # Skip blank lines (e.g. a trailing newline) instead of crashing.
            return [json.loads(line) for line in file if line.strip()]
        return json.load(file)


def _build_texts(records, with_support):
    """Turn dataset records into the strings that are fed to BM25."""
    if with_support:
        return [f"Support: {data['support']}\nQuestion: {data['question']}" for data in records]
    return [data["question"] for data in records]


def main(dataset):
    """Compute and save BM25 score matrices for a dataset.

    Reads ``./data/{dataset}/{dataset}_train`` and ``..._test`` (``.json``
    for ``gsm8k``, ``.jsonl`` for every other dataset), builds one text per
    record, scores test-vs-train and train-vs-train with
    ``calculate_bm25_scores_bm25s``, prints the top-left 10x10 corner of each
    matrix, and saves them next to the data as
    ``{dataset}_train_bm25_scores.npy`` and ``{dataset}_test_bm25_scores.npy``.

    Args:
        dataset: Dataset name; also the directory and file-name prefix. Names
            containing ``"sciq"`` or ``"squad"`` use the ``support`` field in
            addition to ``question``.

    Raises:
        ValueError: If either split is not a list of JSON objects. Previously
            this case fell through and failed later with a ``NameError``.
    """
    print(f"Starting to calculate BM25 scores for {dataset}...")

    is_json_lines = dataset != "gsm8k"
    extension = "jsonl" if is_json_lines else "json"
    ds_train = _read_split(f"./data/{dataset}/{dataset}_train.{extension}", is_json_lines)
    ds_test = _read_split(f"./data/{dataset}/{dataset}_test.{extension}", is_json_lines)

    for split_name, split in (("train", ds_train), ("test", ds_test)):
        if not (isinstance(split, list) and all(isinstance(item, dict) for item in split)):
            raise ValueError(
                f"The {split_name} split of {dataset} must be a list of JSON objects."
            )

    with_support = "sciq" in dataset or "squad" in dataset
    train_embs = _build_texts(ds_train, with_support)
    test_embs = _build_texts(ds_test, with_support)

    print("Start calculating test scores")
    test_scores = calculate_bm25_scores_bm25s(train_embs=train_embs,test_embs=test_embs)

    print("Start calculating train scores")
    train_scores = calculate_bm25_scores_bm25s(train_embs=train_embs,test_embs=train_embs)

    print(train_scores[:10,:10])
    print(test_scores[:10,:10])
    np.save(f'./data/{dataset}/{dataset}_train_bm25_scores.npy', train_scores)
    np.save(f'./data/{dataset}/{dataset}_test_bm25_scores.npy', test_scores)

    print(f"BM25 scores for {dataset} have been saved.")
# Script entry point: only when run directly (not imported), hand `main` to
# Fire; presumably its `dataset` argument is then taken from the command line
# (see the python-fire docs).
if __name__ == "__main__":
    Fire(main)