from typing import List
from datasets import Dataset, DatasetDict, Features, Image, Value, load_dataset
import pytesseract
from custom_colbert.tesseract import to_extracted_words
import typer
from typing import Annotated
import huggingface_hub
import time
from tqdm import tqdm


def get_ocr_from_ds(ds: Dataset) -> DatasetDict:
    """
    Build a copy of ``ds`` whose non-text chunks carry Tesseract OCR text.

    Rows whose ``chunk_type`` equals ``"text"`` are passed through unchanged.
    For every other row, ``chunk_image`` is run through
    ``pytesseract.image_to_data`` and the words returned by
    ``to_extracted_words`` are joined with single spaces and written to
    ``text_description`` (overwriting any previous value).

    The total and average OCR time are printed once the dataset is built.

    Args:
        ds: Source dataset. Each row is read for ``chunk_type`` and, for
            non-text rows, ``chunk_image``. ``ds.features`` is reused as the
            schema of the output, so it presumably already declares a
            ``text_description`` column -- verify against the datasets used.

    Returns:
        A ``DatasetDict`` with a single ``"test"`` split holding the rows.
    """
    # Seconds spent inside Tesseract, one entry per OCR'd row. Filled as a side
    # effect of the generator below.
    ocr_durations = []

    def gen():
        for data in tqdm(ds):
            if data["chunk_type"] != "text":
                # perf_counter is monotonic, so the measured interval cannot be
                # skewed by system clock adjustments.
                start = time.perf_counter()
                ocr = pytesseract.image_to_data(data["chunk_image"], output_type="dict")
                ocr_durations.append(time.perf_counter() - start)

                extracted_words = to_extracted_words(ocr)
                data["text_description"] = " ".join(word.text for word in extracted_words)

            # Yield every column of the original row (with the OCR text
            # substituted for non-text chunks).
            yield data

    dataset_dict = DatasetDict({"test": Dataset.from_generator(gen, features=ds.features)})

    # NOTE(review): if `from_generator` serves a cached result the generator may
    # not run at all, in which case no timings are collected -- confirm.
    if ocr_durations:
        print(f"OCR extraction time: {sum(ocr_durations)}")
        print(f"Average OCR extraction time: {sum(ocr_durations) / len(ocr_durations)}")
    else:
        print("No OCR text extracted")

    return dataset_dict


def _ocr_and_push(dataset_id: str) -> DatasetDict:
    """Load the ``test`` split of ``dataset_id``, OCR it and push it to the Hub.

    The result is pushed as the ``test`` split of ``<dataset_id>_ocr_chunk`` and
    also returned to the caller.
    """
    ds = load_dataset(dataset_id)["test"]
    print(f"Dataset loaded: {ds}")

    new_ds = get_ocr_from_ds(ds)
    new_ds["test"].push_to_hub(dataset_id + "_ocr_chunk", split="test")
    return new_ds


def main(
    # model_name: Annotated[str, typer.Option(help="model name to use for evaluation")],
    dataset_name: Annotated[str, typer.Option(help="dataset on hugging face to evaluate")] = "",
    collection_name: Annotated[str, typer.Option(help="collection name to use for evaluation")] = "",
    # text_only: Annotated[bool, typer.Option(help="If True, only text chunks will be used for evaluation")] = False,
    # ocr_only: Annotated[bool, typer.Option(help="If True, only ocr chunks will be used for evaluation")] = False,
):
    """
    Run Tesseract OCR over one dataset, or over every item of a collection.

    For each processed dataset the ``test`` split is loaded, OCR'd with
    ``get_ocr_from_ds`` and pushed to the Hub as ``<dataset id>_ocr_chunk``.
    The OCR result of the last processed dataset is printed at the end.

    Args:
        dataset_name: Hub id of a single dataset to process. Takes precedence
            over ``collection_name`` when both are given.
        collection_name: Hub collection whose items are processed one by one;
            only used when ``dataset_name`` is empty.

    Raises:
        typer.Exit: With exit code 1 when neither option is provided.
    """
    print(f"Dataset name: {dataset_name}")
    print(f"Collection name: {collection_name}")

    # OCR result of the most recently processed dataset; stays None when an
    # (empty) collection yields nothing to process.
    last_result = None

    if dataset_name != "":
        last_result = _ocr_and_push(dataset_name)

    elif collection_name != "":
        collection = huggingface_hub.get_collection(collection_name)
        datasets = collection.items
        print(datasets)
        for dataset in datasets:
            print(f"\n---------------------------\nEvaluating {dataset.item_id}")
            last_result = _ocr_and_push(dataset.item_id)

    else:
        # Without an input there is no dataset to OCR; fail explicitly instead
        # of crashing on an unbound variable further down.
        print("Nothing to do: pass either --dataset-name or --collection-name")
        raise typer.Exit(code=1)

    # Reuse the result computed above rather than running the (expensive) OCR
    # pass a second time just to print it.
    if last_result is not None:
        print(last_result)


if __name__ == "__main__":
    typer.run(main)


# Example usage (Typer exposes underscores in parameter names as dashes):
# python scripts/baselines/tesseract.py --collection-name "coldoc/syntheticdocqa-test-6655c59cfda461267c0d9ac8"
# python scripts/baselines/tesseract.py --dataset-name "coldoc/existing-datasets-test-6655c5e0504da7ec0c14253c"
