import bm25s
import Stemmer
from tqdm import tqdm
import json
import numpy as np
from rank_bm25 import BM25Okapi
from fire import Fire
def calculate_bm25_scores(train_embs, dev_embs):
    """Score every training document against one or more queries with rank_bm25.

    Despite the ``*_embs`` names, both arguments are raw text. Documents and
    queries are tokenized by splitting on whitespace.

    Args:
        train_embs: Sequence of document strings forming the BM25 corpus.
        dev_embs: A single query string, or an iterable of query strings.

    Returns:
        For a single query, a 1-D array with one score per document in
        ``train_embs`` order. For several queries, a 2-D array of shape
        ``(n_queries, len(train_embs))``. For no queries, an empty array of
        shape ``(0, len(train_embs))``.
    """
    if isinstance(dev_embs, str):
        dev_embs = [dev_embs]
    queries = list(dev_embs)

    # split() with no argument splits on any whitespace run. Newlines (e.g.
    # the one before "Question:") therefore separate tokens, and repeated
    # spaces do not yield empty-string tokens.
    tokenized_corpus = [doc.split() for doc in train_embs]
    bm25 = BM25Okapi(tokenized_corpus)

    scores = [bm25.get_scores(query.split()) for query in tqdm(queries)]

    if not scores:
        # np.vstack([]) raises, so return an explicitly empty score matrix.
        return np.empty((0, len(tokenized_corpus)))
    if len(scores) == 1:
        return np.array(scores[0])
    return np.vstack(scores)
def calculate_bm25_scores_bm25s(train_embs, dev_embs):
    """Score every training document against one or more queries with bm25s.

    The corpus is tokenized with English stopword removal and Snowball
    stemming (PyStemmer), and queries get the same preprocessing.

    Args:
        train_embs: Sequence of document strings forming the BM25 corpus.
        dev_embs: A single query string, or an iterable of query strings.

    Returns:
        For a single query, a 1-D array with one score per document, aligned
        with ``train_embs`` order. For several queries, a 2-D array of shape
        ``(n_queries, len(train_embs))``. For no queries, an empty array of
        shape ``(0, len(train_embs))``.
    """
    if isinstance(dev_embs, str):
        dev_embs = [dev_embs]
    queries = list(dev_embs)
    corpus = list(train_embs)
    n_docs = len(corpus)

    stemmer = Stemmer.Stemmer("english")
    corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)

    all_scores = np.zeros((len(queries), n_docs))
    for row, query in enumerate(tqdm(queries)):
        # Use the corpus preprocessing so query terms match indexed terms.
        query_tokens = bm25s.tokenize(query, stopwords="en", stemmer=stemmer)
        # Without `corpus=`, retrieve() returns corpus *indices* ranked by
        # descending score. Indices keep duplicate documents distinct, which
        # a text -> position lookup cannot.
        doc_ids, doc_scores = retriever.retrieve(query_tokens, k=n_docs)
        # Scatter the ranked scores back into original corpus order.
        all_scores[row, doc_ids[0]] = doc_scores[0]

    if len(queries) == 1:
        return all_scores[0]
    return all_scores
def _load_split(dataset, split):
    """Load the ``split`` records of ``dataset`` from ``data/<dataset>/``.

    gsm8k is read as a single JSON document (``<dataset>_<split>.json``).
    Every other dataset is read as JSON Lines (``<dataset>_<split>.jsonl``),
    skipping blank lines.
    """
    if dataset == "gsm8k":
        with open(f"data/{dataset}/{dataset}_{split}.json", 'r', encoding='utf-8') as file:
            return json.load(file)
    with open(f"data/{dataset}/{dataset}_{split}.jsonl", 'r', encoding='utf-8') as file:
        # A trailing/blank line would otherwise make json.loads("") raise.
        return [json.loads(line) for line in file if line.strip()]


def _to_texts(layout_name, examples):
    """Build one BM25 text per example, using the field layout of ``layout_name``.

    Raises:
        ValueError: If ``layout_name`` matches no known dataset layout.
    """
    if "gsm8k" in layout_name or "prm800k" in layout_name:
        return [example["question"] for example in examples]
    if "squad" in layout_name or "sciq" in layout_name:
        return [f"Support: {example['support']}\nQuestion: {example['question']}" for example in examples]
    raise ValueError(f"No text layout known for dataset {layout_name!r}")


def main(train_dataset, test_dataset, num_queries=10):
    """Print bm25s scores for the first test examples against the train split.

    Both splits are converted to text with the field layout implied by
    ``train_dataset``. Each of the first ``num_queries`` test texts is then
    scored with ``calculate_bm25_scores_bm25s`` and the result is printed.

    Args:
        train_dataset: Name of the corpus dataset (e.g. "gsm8k", "sciq").
        test_dataset: Name of the query dataset.
        num_queries: Number of test examples to score; ``None`` scores all.

    Raises:
        ValueError: If a split is not a list of JSON objects, or if
            ``train_dataset`` has no known text layout.
    """
    ds_train = _load_split(train_dataset, "train")
    ds_test = _load_split(test_dataset, "test")

    for name, records in ((train_dataset, ds_train), (test_dataset, ds_test)):
        if not (isinstance(records, list) and all(isinstance(item, dict) for item in records)):
            raise ValueError(f"Expected a list of JSON objects for dataset {name!r}")

    train_embs = _to_texts(train_dataset, ds_train)
    dev_embs = _to_texts(train_dataset, ds_test)

    # One query per call; each call re-indexes the corpus.
    for dev_emb in dev_embs[:num_queries]:
        scores = calculate_bm25_scores_bm25s(train_embs=train_embs, dev_embs=[dev_emb])
        print(scores)
        
if __name__ == "__main__":
    # Expose `main` and its arguments as a command-line interface via python-fire.
    Fire(component=main)