from typing import List
from datasets import Dataset, DatasetDict, Features, Image, Value, load_dataset
import pytesseract
from custom_colbert.tesseract import to_extracted_words
import typer
from typing import Annotated
import huggingface_hub
import time


def get_ocr_from_ds(ds: Dataset) -> DatasetDict:
    """
    Run Tesseract OCR on every page image of ``ds``.

    Args:
        ds: dataset whose rows provide ``image``, ``image_filename`` and
            ``query`` columns (other columns are dropped).

    Returns:
        A ``DatasetDict`` with a single ``"test"`` split containing the
        original ``image``, ``image_filename`` and ``query`` columns plus a
        ``text_description`` column holding the OCR'd words of the page
        joined by single spaces.
    """
    features = Features(
        {
            "image": Image(),
            "image_filename": Value("string"),
            "query": Value("string"),
            "text_description": Value("string"),
        }
    )

    # Per-image OCR durations (seconds), filled as a side effect of gen().
    times: List[float] = []

    def gen():
        for row in ds:
            image = row["image"]

            start = time.perf_counter()
            ocr_data = pytesseract.image_to_data(image, output_type="dict")
            times.append(time.perf_counter() - start)

            extracted_words = to_extracted_words(ocr_data)
            page_text = " ".join(word.text for word in extracted_words)

            yield {
                "image": image,
                "image_filename": row["image_filename"],
                "query": row["query"],
                "text_description": page_text,
            }

    dataset_dict = DatasetDict(
        {"test": Dataset.from_generator(gen, features=features)}
    )

    # Dataset.from_generator may reuse a cached result without ever calling
    # gen(), so an empty `times` does not necessarily mean nothing was OCR'd.
    if not times:
        print("No OCR timings recorded (empty dataset, or result loaded from the datasets cache)")
        return dataset_dict

    print(f"OCR extraction time: {sum(times)}")
    print(f"Average OCR extraction time: {sum(times) / len(times)}")

    return dataset_dict


def _ocr_and_push(dataset_id: str) -> DatasetDict:
    """
    OCR the 'test' split of ``dataset_id`` and push it to the Hub as
    ``<dataset_id>_tesseract`` (split "test").

    Returns the OCR'd DatasetDict so the caller can inspect it.
    """
    ds = load_dataset(dataset_id)['test']
    print(f"Dataset loaded: {ds}")

    new_ds = get_ocr_from_ds(ds)
    new_ds["test"].push_to_hub(dataset_id + "_tesseract", split="test")
    return new_ds


def main(
    #model_name: Annotated[str, typer.Option(help="model name to use for evaluation")],
    dataset_name: Annotated[str, typer.Option(help="dataset on hugging face to evaluate")] = "",
    collection_name: Annotated[str, typer.Option(help="collection name to use for evaluation")] = "",
    #text_only: Annotated[bool, typer.Option(help="If True, only text chunks will be used for evaluation")] = False,
    #ocr_only: Annotated[bool, typer.Option(help="If True, only ocr chunks will be used for evaluation")] = False,
):
    """
    OCR one Hugging Face dataset, or every dataset in a Hugging Face
    collection, and push each result to the Hub as ``<name>_tesseract``.

    ``dataset_name`` takes precedence when both options are given. Exits with
    status 1 when neither is given.
    """
    print(f"Dataset name: {dataset_name}")
    print(f"Collection name: {collection_name}")

    if dataset_name != "":
        print(_ocr_and_push(dataset_name))

    elif collection_name != "":
        collection = huggingface_hub.get_collection(collection_name)
        items = collection.items
        print(items)
        for item in items:
            # Collections may also hold models, spaces or papers, which
            # load_dataset cannot open.
            if item.item_type != "dataset":
                print(f"Skipping {item.item_id} (item type: {item.item_type})")
                continue
            print(f"\n---------------------------\nEvaluating {item.item_id}")
            print(_ocr_and_push(item.item_id))

    else:
        # Previously this fell through to an unbound `ds` and crashed.
        print("Provide either --dataset-name or --collection-name.")
        raise typer.Exit(code=1)


if __name__ == "__main__":
    # Typer exposes main()'s keyword parameters as CLI options; underscores in
    # parameter names become dashes (--dataset-name, --collection-name).
    typer.run(main)


# Usage (Typer turns the snake_case parameters into dash-separated options):
#python scripts/baselines/tesseract.py --collection-name "coldoc/syntheticdocqa-test-6655c59cfda461267c0d9ac8"
#python scripts/baselines/tesseract.py --collection-name "coldoc/existing-datasets-test-6655c5e0504da7ec0c14253c"