"""
Functions to generate train, eval and test datasets for vision retriever
benchmarks and push them to the Hugging Face Hub.
"""

import hashlib
import json
import os
from typing import Annotated

import numpy as np
import typer
from datasets import Dataset, DatasetDict, Features, Image, Value, load_dataset
from dotenv import load_dotenv
from pdf2image import convert_from_path
from PIL import Image as PILImage
from tqdm import tqdm

# Populate os.environ from a local .env file, if one exists, at import time.
# NOTE(review): presumably this supplies the Hugging Face credentials used by
# push_to_hub in main() — confirm.
load_dotenv()


def add_metadata_column(dataset, column_name, value):
    """Return ``dataset`` with ``column_name`` set to the constant ``value`` on every row.

    The work is delegated to ``dataset.map``; a column that already exists
    under ``column_name`` is overwritten.
    """

    def _stamp(row):
        row.update({column_name: value})
        return row

    return dataset.map(_stamp)


# DocVQA eval
def load_docvqa_eval() -> DatasetDict:
    """Load the DocVQA validation and test splits from ``lmms-lab/DocVQA``.

    Columns are renamed to the shared benchmark schema (``query``, ``answer``,
    ``page``, ``image_filename``) and a constant ``source="docvqa"`` column is
    added to each split.

    Returns:
        A ``DatasetDict`` with keys ``"eval"`` (the hub "validation" split) and
        ``"test"`` (the hub "test" split).
    """
    column_mapping = {
        "question": "query",
        "answers": "answer",
        "ucsf_document_page_no": "page",
        "ucsf_document_id": "image_filename",
    }

    splits = {}
    # Both splits go through the identical pipeline; only the hub split name differs.
    for key, hub_split in (("eval", "validation"), ("test", "test")):
        split_ds = load_dataset("lmms-lab/DocVQA", "DocVQA", split=hub_split)
        split_ds = split_ds.rename_columns(column_mapping)
        splits[key] = add_metadata_column(split_ds, "source", "docvqa")

    return DatasetDict(splits)


# DocVQA train
def load_docvqa_train() -> DatasetDict:
    """Build the DocVQA training set from ``HuggingFaceM4/the_cauldron`` ("docvqa").

    Each cauldron example carries a list of images and a list of
    user/assistant turns. One output row is emitted per turn, paired with the
    example's first image.

    ``image_filename`` is the SHA-256 hex digest of the first image's pixel
    array. It does not remove duplicates; it only gives every row built from
    the same pixels the same identifier.

    Returns:
        A ``DatasetDict`` with a single ``"train"`` split.
    """
    docvqa_ds = load_dataset("HuggingFaceM4/the_cauldron", "docvqa")["train"]

    features = Features(
        {
            "image": Image(),
            "image_filename": Value("string"),
            "query": Value("string"),
            "answer": Value("string"),
            "source": Value("string"),
        }
    )

    def gen():
        for example in docvqa_ds:
            image = example["images"][0]
            # C-order dump of the pixels: the same bytes as hashing np.array(image).flatten().
            image_hash = hashlib.sha256(np.asarray(image).tobytes()).hexdigest()
            for text in example["texts"]:
                yield {
                    "image": image,
                    "image_filename": image_hash,
                    "query": text["user"],
                    "answer": text["assistant"],
                    "source": text["source"],
                }

    return DatasetDict({"train": Dataset.from_generator(gen, features=features)})


# InfographicVQA eval
def load_infovqa_eval() -> DatasetDict:
    """Load the InfographicVQA validation and test splits from ``lmms-lab/DocVQA``.

    Columns are renamed to the shared benchmark schema (``answer``, ``query``,
    ``image_filename``) and a constant ``source="infovqa"`` column is added to
    each split.

    Returns:
        A ``DatasetDict`` with keys ``"eval"`` (the hub "validation" split) and
        ``"test"`` (the hub "test" split).
    """
    column_mapping = {
        "answers": "answer",
        "question": "query",
        "image_url": "image_filename",
    }

    splits = {}
    # Both splits go through the identical pipeline; only the hub split name differs.
    for key, hub_split in (("eval", "validation"), ("test", "test")):
        split_ds = load_dataset("lmms-lab/DocVQA", "InfographicVQA", split=hub_split)
        split_ds = split_ds.rename_columns(column_mapping)
        splits[key] = add_metadata_column(split_ds, "source", "infovqa")

    return DatasetDict(splits)


# InfoVQA train
def load_infovqa_train() -> DatasetDict:
    """Build the InfographicVQA training set from ``HuggingFaceM4/the_cauldron`` ("infographic_vqa").

    One row is produced for every user/assistant turn of every cauldron
    example. Each row reuses the example's first image, and its
    ``image_filename`` is the SHA-256 hex digest of that image's flattened
    pixel array.

    Returns:
        A ``DatasetDict`` with a single ``"train"`` split.
    """
    cauldron_train = load_dataset("HuggingFaceM4/the_cauldron", "infographic_vqa")["train"]

    schema = Features(
        {
            "image": Image(),
            "image_filename": Value("string"),
            "query": Value("string"),
            "answer": Value("string"),
            "source": Value("string"),
        }
    )

    def iter_rows():
        for sample in cauldron_train:
            first_image = sample["images"][0]
            digest = hashlib.sha256(np.array(first_image).flatten()).hexdigest()
            yield from (
                {
                    "image": first_image,
                    "image_filename": digest,
                    "query": turn["user"],
                    "answer": turn["assistant"],
                    "source": turn["source"],
                }
                for turn in sample["texts"]
            )

    return DatasetDict({"train": Dataset.from_generator(iter_rows, features=schema)})


# TabFquad eval
def load_tabfquad_retrieving() -> DatasetDict:
    """Load the ``manu/tabfquad_retrieving`` test split tagged with ``source="tabfquad"``.

    Returns:
        A ``DatasetDict`` with a single ``"test"`` split.
    """
    tagged = add_metadata_column(
        load_dataset("manu/tabfquad_retrieving", split="test"),
        "source",
        "tabfquad",
    )
    return DatasetDict({"test": tagged})


# TAT-DQA (train / eval / test)
def load_tatdqa() -> DatasetDict:
    """Build the TAT-DQA dataset from the local download in ``data/downloaded_datasets/tatdqa``.

    For each of the ``train``, ``dev`` and ``test`` splits, reads
    ``tatdqa_dataset_<split>.json`` and emits one row per question. The row's
    image is the first page of ``<data_dir>/<split>/<doc uid>.pdf``, rendered
    with ``pdf2image``. Documents whose PDF is not present on disk are skipped.

    Returns:
        A ``DatasetDict`` with keys ``"train"``, ``"eval"`` (built from the
        ``dev`` split) and ``"test"``.
    """
    data_dir = "data/downloaded_datasets/tatdqa"

    features = Features(
        {
            "query": Value("string"),
            "image_filename": Value("string"),
            "image": Image(),
            "answer": Value("string"),
            "answer_type": Value("string"),
            "page": Value("string"),
            "model": Value("string"),
            "prompt": Value("string"),
            "source": Value("string"),
        }
    )

    def gen(split: str, split_file: str):
        with open(split_file, "r", encoding="utf-8") as f:
            entries = json.load(f)

        for entry in entries:
            doc = entry["doc"]
            questions = entry["questions"]
            pdf_path = os.path.join(data_dir, split, f"{doc['uid']}.pdf")
            # Nothing to emit for documents without questions or without a downloaded PDF.
            if not questions or not os.path.exists(pdf_path):
                continue

            # Only the first page is used, so render just that page, and do it
            # once per document rather than once per question.
            page_image = convert_from_path(pdf_path, first_page=1, last_page=1)[0].convert("RGB")

            for question in questions:
                yield {
                    "query": question["question"],
                    "image_filename": pdf_path,
                    "image": page_image,
                    "source": "tatdqa",
                    "answer": question.get("answer", ""),
                    "answer_type": question.get("answer_type", ""),
                    "page": doc["page"],
                    "model": "",
                    "prompt": "",
                }

    dataset_dict = {}
    for split in ["train", "dev", "test"]:
        split_file = os.path.join(data_dir, f"tatdqa_dataset_{split}.json")
        # Pass the per-split values explicitly rather than through a closure
        # over the loop variables.
        dataset_dict[split] = Dataset.from_generator(
            gen,
            features=features,
            gen_kwargs={"split": split, "split_file": split_file},
        )

    return DatasetDict({"train": dataset_dict["train"], "eval": dataset_dict["dev"], "test": dataset_dict["test"]})


# arxiv qa
def load_arxivqa() -> DatasetDict:
    """Build the ArxivQA dataset from the local ``data/arxiv_qa`` folder.

    Reads ``data/arxiv_qa/arxivqa.jsonl`` (one JSON object per line; blank
    lines are ignored) and creates one row per record. The image is opened
    from ``data/arxiv_qa/<record["image"]>``, the question goes to ``query``,
    ``options`` is copied as-is and ``label`` goes to ``answer``. ``model`` is
    set to ``"gpt4V"`` and ``source`` to ``"arxiv_qa"``.

    Returns:
        A ``DatasetDict`` with a 95% ``"train"`` split and a 5% ``"test"``
        split (``seed=42``).
    """
    with open("data/arxiv_qa/arxivqa.jsonl", "r", encoding="utf-8") as fr:
        arxiv_qa = [json.loads(line) for line in fr if line.strip()]

    dataset_dict = {
        "query": [],
        "image": [],
        "image_filename": [],
        "options": [],
        "answer": [],
        "page": [],
        "model": [],
        "prompt": [],
        "source": [],
    }

    features = Features(
        {
            "query": Value("string"),
            "image": Image(),
            "image_filename": Value("string"),
            "options": Value("string"),
            "answer": Value("string"),
            "page": Value("string"),
            "model": Value("string"),
            "prompt": Value("string"),
            "source": Value("string"),
        }
    )

    for qa in tqdm(arxiv_qa, desc="Processing entries"):
        image = PILImage.open("data/arxiv_qa/" + qa["image"])
        # Close immediately so thousands of file handles are not held open.
        # NOTE(review): this presumably relies on the `Image` feature encoding a
        # file-backed PIL image by its filename rather than its pixels — confirm.
        image.close()

        dataset_dict["query"].append(qa["question"])
        dataset_dict["image"].append(image)
        dataset_dict["image_filename"].append(qa["image"])
        dataset_dict["options"].append(qa["options"])
        dataset_dict["answer"].append(qa["label"])
        dataset_dict["page"].append("")
        dataset_dict["model"].append("gpt4V")
        dataset_dict["prompt"].append("")
        dataset_dict["source"].append("arxiv_qa")

    dataset = Dataset.from_dict(dataset_dict, features=features)

    # Split into train (95%) and test (5%); there is no separate validation split.
    dataset = dataset.train_test_split(test_size=0.05, seed=42)

    return DatasetDict({"train": dataset["train"], "test": dataset["test"]})


def main(dataset: Annotated[str, typer.Argument(help="The dataset to generate.")]):
    """Generate DATASET and push each of its splits to the Hugging Face Hub.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    # dataset name -> (loader, [(split key in the returned DatasetDict, hub repo id), ...]).
    # Splits are pushed in the listed order, each with ``split=<split key>``.
    registry = {
        "docvqa_eval": (
            load_docvqa_eval,
            [("test", "coldoc/docvqa_test"), ("eval", "coldoc/docvqa_eval")],
        ),
        "docvqa_train": (load_docvqa_train, [("train", "coldoc/docvqa_train")]),
        # Must use the InfographicVQA loader, not the DocVQA one.
        "infovqa_train": (load_infovqa_train, [("train", "coldoc/infovqa_train")]),
        "infovqa_eval": (
            load_infovqa_eval,
            [("test", "coldoc/infovqa_test"), ("eval", "coldoc/infovqa_eval")],
        ),
        # Push the "test" Dataset itself: DatasetDict.push_to_hub takes no `split` argument.
        "tabfquad_test": (load_tabfquad_retrieving, [("test", "coldoc/tabfquad_retrieving_test")]),
        "tatdqa": (
            load_tatdqa,
            [
                ("train", "coldoc/tatdqa_train"),
                # NOTE(review): unlike docvqa/infovqa, the eval split is pushed to
                # the *train* repo — confirm this is intended.
                ("eval", "coldoc/tatdqa_train"),
                ("test", "coldoc/tatdqa_test"),
            ],
        ),
        "arxivqa": (
            load_arxivqa,
            [("train", "coldoc/arxivqa_train"), ("test", "coldoc/arxivqa_test")],
        ),
    }

    if dataset not in registry:
        raise ValueError(f"Dataset {dataset} not supported.")

    loader, targets = registry[dataset]
    ds_dict = loader()
    for split, repo_id in targets:
        ds_dict[split].push_to_hub(repo_id, split=split)

    print("Dataset successfully generated and pushed to the hub.")


if __name__ == "__main__":
    typer.run(main)
