import json
from pathlib import Path
from typing import Annotated, cast

import huggingface_hub
import nltk
import torch
import typer
from datasets import Dataset, load_dataset
from dotenv import load_dotenv
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from rank_bm25 import BM25Okapi

from custom_colbert.trainer.retrieval_evaluator import CustomEvaluator

# Fetch the NLTK stop-word corpus at import time; `preprocess_text` reads it through `stopwords.words`.
# NOTE(review): `word_tokenize` is used in this file as well and typically needs the NLTK "punkt"
# tokenizer data, which is not downloaded here — confirm it is provisioned elsewhere.
nltk.download("stopwords")

# Root directory of the tensors dumped when DEBUG_SAVE_TENSOR is enabled.
OUTPUT_MEASURE_EMB_SIZE = Path("outputs/measure_emb_size/")
# Debug switch: when True, `recall_at_k_bm25` saves BM25 score tensors to disk and exits instead of
# computing the recall.
DEBUG_SAVE_TENSOR = False
if DEBUG_SAVE_TENSOR:
    print("DEBUG_SAVE_TENSOR is set to `True`: the first embeddings will be saved on disk and the script will stop.")

# Load variables from a `.env` file into the environment; `override=True` lets them replace existing values.
load_dotenv(override=True)


def log_metrics(metrics, dataset_name, log_file="metrics.json"):
    """
    Store `metrics` under the `dataset_name` key of a JSON log file.

    If `log_file` already exists, its content is loaded, the entry for `dataset_name` is added (or
    overwritten) and the whole mapping is written back. Otherwise a new file holding only this entry is
    created, together with any missing parent directory.

    Args:
        metrics: JSON-serializable metrics to record.
        dataset_name: key under which the metrics are stored.
        log_file: path of the JSON file to create or update.
    """
    log_path = Path(log_file)

    if log_path.exists():
        # Read through a context manager so the handle is closed before the file is reopened for writing.
        with open(log_path) as f:
            all_metrics = json.load(f)
    else:
        all_metrics = {}
        # The target folder may not exist yet on a first run.
        log_path.parent.mkdir(parents=True, exist_ok=True)

    all_metrics[dataset_name] = metrics
    with open(log_path, "w") as f:
        json.dump(all_metrics, f)


def preprocess_text(documents, stopwords_language="english", tokenizer_language="french"):
    """
    Basic preprocessing of the text data: tokenize, drop non-alphanumeric tokens (punctuation),
    lowercase all the words and remove stopwords.

    Args:
        documents: mapping whose values are the raw texts to process (the keys are ignored).
        stopwords_language: name of the NLTK stop-word list used for filtering.
        tokenizer_language: language given to NLTK's `word_tokenize`.

    NOTE: the defaults combine the English stop-word list with the French tokenizer. This mirrors the
    historical behaviour of this function so existing results stay comparable; pass both arguments
    explicitly to preprocess a corpus consistently in a single language.

    return the tokenized list of words (one list per document, in the iteration order of `documents`)
    """
    stop_words = set(stopwords.words(stopwords_language))

    tokenized_list = []
    for text in documents.values():
        kept_words = []
        for word in word_tokenize(text, language=tokenizer_language):
            # Punctuation and other non-alphanumeric tokens are dropped before lowercasing.
            if not word.isalnum():
                continue
            lowered = word.lower()
            if lowered not in stop_words:
                kept_words.append(lowered)
        tokenized_list.append(kept_words)

    return tokenized_list


def recall_at_k_bm25(
    bm25: BM25Okapi, corpus: dict, questions: dict, passage2filename: dict, query2filename: dict, k: int
):
    """
    Compute the file-level recall@k of a BM25 index.

    A question counts as retrieved when the filename registered for it in `query2filename` appears among
    the filenames of its `k` best-scoring passages.

    Args:
        bm25: BM25 index; the scores it returns are indexed against the iteration order of `corpus`.
        corpus: mapping passage id -> passage text.
        questions: mapping question id -> question text.
        passage2filename: mapping passage text -> filename.
        query2filename: mapping question text -> expected filename.
        k: number of top-scoring passages kept per question. A value <= 0 keeps none.

    Returns:
        A tuple `(recall, not_retrieved, corresponding_indices_retrieved)` where `recall` is the
        fraction of questions retrieved (0.0 when `questions` is empty), `not_retrieved` lists the ids
        of the missed questions and `corresponding_indices_retrieved` holds, for each missed question,
        the passage ids that were retrieved instead.

    When the module-level `DEBUG_SAVE_TENSOR` flag is set, the function only dumps the BM25 score
    vectors to disk and terminates the process.
    """
    recall = 0

    tokenized_questions = preprocess_text(questions)
    not_retrieved = []
    corresponding_indices_retrieved = []

    if DEBUG_SAVE_TENSOR:
        savedir_dense = OUTPUT_MEASURE_EMB_SIZE / "bm25" / "dense"
        savedir_dense.mkdir(parents=True, exist_ok=True)

        savedir_sparse = OUTPUT_MEASURE_EMB_SIZE / "bm25" / "sparse"
        savedir_sparse.mkdir(parents=True, exist_ok=True)

        for i, question in enumerate(tokenized_questions):
            scores = bm25.get_scores(question)

            if i == 0:
                # NOTE: all dense tensors have the same size, so we only need to save the first one
                scores_dense = torch.tensor(scores).to(torch.float32)
                savepath_dense = savedir_dense / f"doc_embedding_{i}.pt"
                torch.save(scores_dense, savepath_dense)
                print(f"Dense document embedding saved to `{savepath_dense}`")

            scores_sparse = torch.tensor(scores).to_sparse().coalesce()
            savepath_sparse = savedir_sparse / f"doc_embedding_{i}.pt"
            torch.save(scores_sparse, savepath_sparse)
            print(f"Sparse document embedding saved to `{savepath_sparse}`")

        # `exit()` is only injected by the `site` module; raising SystemExit works in every interpreter.
        raise SystemExit

    if not questions:
        # Nothing to evaluate: avoid dividing by zero below.
        return 0.0, not_retrieved, corresponding_indices_retrieved

    # Built once: position in the BM25 score vector -> passage id.
    corpus_ids = list(corpus)

    # `preprocess_text` walks `questions.values()` in order, so both iterables are aligned.
    for (question_id, question_text), tokens in zip(questions.items(), tokenized_questions):
        scores = bm25.get_scores(tokens)

        # Indices of the k best scores, in ascending score order. `[-0:]` would select the whole
        # ranking, so non-positive k values are handled explicitly.
        ranked = scores.argsort()
        top_indices = ranked[-k:] if k > 0 else ranked[:0]

        # Map the positions back to passage ids, then to the files they come from.
        corresponding_indices = [corpus_ids[idx] for idx in top_indices]
        ground_truth_filename = query2filename[question_text]
        retrieved_filenames = [passage2filename[corpus[passage_id]] for passage_id in corresponding_indices]

        if ground_truth_filename in retrieved_filenames:
            recall += 1
        else:
            not_retrieved.append(question_id)
            corresponding_indices_retrieved.append(corresponding_indices)

    return recall / len(questions), not_retrieved, corresponding_indices_retrieved


def evaluate_bm25(dataset_name: str, model_name: str, text_only: bool = False):
    """
    Evaluate BM25 retrieval on the test split of a Hugging Face dataset and log the metrics.

    The `text_description` column is indexed with BM25, every distinct non-null entry of the `query`
    column is scored against it, and the best score obtained by each `image_filename` is handed to
    `CustomEvaluator.compute_metrics` along with the expected file of each query.

    Args:
        dataset_name: name of the dataset to load (its "test" split is used).
        model_name: label inserted in the name of the metrics file.
        text_only: when True and the dataset has a `chunk_type` column, keep only rows whose
            `chunk_type` is "text".

    Returns:
        None. The metrics are written with `log_metrics` to
        `data/evaluation_results/metrics/metrics_{model_name}_ocrvis.json` under the `dataset_name` key.
    """
    ds = cast(Dataset, load_dataset(dataset_name, split="test"))

    print(f"Dataset length: {len(ds)}")
    if "chunk_type" in ds.column_names and text_only:
        ds = ds.filter(lambda x: x["chunk_type"] == "text")
        print(f"Text only dataset length: {len(ds)}")

    evaluator = CustomEvaluator(is_multi_vector=False)

    # Read each column once; plain lists allow cheap positional access below.
    all_queries = list(ds["query"])
    filenames = list(ds["image_filename"])
    passages = list(ds["text_description"])

    # Distinct non-null queries. `dict.fromkeys` keeps the first-seen order, which makes runs
    # reproducible (iteration order of a set of strings changes between interpreter runs).
    queries = [query for query in dict.fromkeys(all_queries) if query is not None]

    # When a query appears on several rows, the last row wins.
    query2filename = dict(zip(all_queries, filenames))

    corpus = dict(enumerate(passages))
    questions = dict(enumerate(queries))

    tokenized_corpus = preprocess_text(corpus)
    tokenized_questions = preprocess_text(questions)
    bm25 = BM25Okapi(tokenized_corpus)

    relevant_docs = {}
    results = {}
    # `preprocess_text` keeps the order of `questions`, so `queries` and `tokenized_questions` are aligned.
    for query, tokens in zip(queries, tokenized_questions):
        # One score per passage, in corpus (i.e. dataset row) order.
        scores = bm25.get_scores(tokens)
        relevant_docs[query] = {query2filename[query]: 1}

        # Keep, for every file, the best score reached by one of its passages. Passages are visited in
        # ascending score order, as before, so the insertion order of the files is unchanged.
        file_scores = {}
        for idx in scores.argsort():
            # Row `idx` of the dataset gives the file directly; going through the passage text would
            # merge identical texts that belong to different files.
            filename = filenames[int(idx)]
            score_passage = float(scores[idx])
            # True maximum: no implicit floor at 0, so negative scores are treated consistently.
            if filename not in file_scores or score_passage > file_scores[filename]:
                file_scores[filename] = score_passage
        results[query] = file_scores

    assert len(results) == len(tokenized_questions), f"Results length is {len(results)} instead of {len(tokenized_questions)}"

    metrics = evaluator.compute_metrics(relevant_docs, results)

    log_metrics(metrics, dataset_name, log_file=f"data/evaluation_results/metrics/metrics_{model_name}_ocrvis.json")


def main(
    model_name: Annotated[str, typer.Option(help="model name to use for evaluation")],
    collection_name: Annotated[str, typer.Option(help="collection name to use for evaluation")] = "",
    dataset_name: Annotated[str, typer.Option(help="dataset on hugging face to evaluate")] = "",
    text_only: Annotated[bool, typer.Option(help="If True, only text chunks will be used for evaluation")] = False,
):
    """
    Command-line entry point: run the BM25 evaluation on one dataset or on every item of a collection.

    `dataset_name` takes precedence: when it is set, only that dataset is evaluated. Otherwise every
    item of the Hugging Face collection `collection_name` is evaluated in turn.

    Raises:
        typer.BadParameter: if neither `dataset_name` nor `collection_name` is provided.
    """
    # print parameters
    print(f"Model name: {model_name}")
    print(f"Dataset name: {dataset_name}")
    print(f"Text only: {text_only}")

    if dataset_name:
        evaluate_bm25(dataset_name, model_name, text_only=text_only)
        return

    if collection_name:
        collection = huggingface_hub.get_collection(collection_name)
        for dataset in collection.items:
            print(f"\n---------------------------\nEvaluating {dataset.item_id}")
            evaluate_bm25(dataset.item_id, model_name, text_only=text_only)
        return

    # Fail loudly instead of exiting successfully without having evaluated anything.
    raise typer.BadParameter("Provide either --dataset-name or --collection-name.")

# Expose `main` as a command-line interface when the file is executed as a script.
if __name__ == "__main__":
    typer.run(main)
