from typing import Annotated

import typer
from datasets import Dataset, DatasetDict, load_dataset
from dotenv import load_dotenv

from custom_colbert.utils.image_utils import scale_to_max_dimension

# Load environment variables via python-dotenv as soon as the module is imported.
# NOTE(review): override=True presumably lets values from the .env file replace
# variables that are already set in the process environment -- confirm.
load_dotenv(override=True)

# Number of kept rows at which validate_dataset stops asking for more.
N_TEST_EXAMPLES = 100


def validate_dataset(dataset: Dataset) -> DatasetDict:
    """
    Manually validate a dataset, one row at a time.

    For each row (visited in a fixed shuffled order) the image filename, query and
    answer are printed, the image is displayed, and the user is asked whether to
    keep the row. Every answer is also appended to ``validation_backup.txt`` in the
    current working directory so that a crashed session leaves a trace.

    Validation ends when the dataset is exhausted, when ``N_TEST_EXAMPLES`` rows
    have been kept, or when the user types ``stop`` at the prompt.

    Args:
        dataset: The dataset to validate. Rows are expected to provide the
            ``image_filename``, ``query``, ``answer`` and ``image`` columns.

    Returns:
        A ``DatasetDict`` with a single ``"test"`` split that contains the kept
        rows and reuses the features of the input dataset.
    """
    features = dataset.features
    kept_rows = {key: [] for key in features}
    print(f"Validating {len(dataset)} rows")

    # Fixed seed: the review order is the same on every run.
    for data in dataset.shuffle(seed=42):
        print(f"**Image:** {data['image_filename']}")
        print(f"**Query:** {data['query']}")
        print(f"**Answer:** {data['answer']}")
        # NOTE(review): presumably downsizes the image so its longest side is at
        # most 1024 px before it is displayed -- confirm against image_utils.
        image = scale_to_max_dimension(data["image"], 1024)
        image.show()

        user_input = input("\nKeep this row? (y/n): ").strip().lower()

        # Re-open in append mode for every row so each decision is on disk before
        # the next prompt. Explicit UTF-8 keeps non-ASCII queries from failing on
        # platforms whose default encoding cannot represent them.
        with open("validation_backup.txt", "a", encoding="utf-8") as f:
            f.write(f"{data['query']}\n{user_input}\n\n")

        if user_input == "y":
            for key, value in data.items():
                kept_rows[key].append(value)

        n_kept = len(kept_rows["query"])
        print(f"No of rows kept: {n_kept}\n-------------------\n")

        # Report the real reason for stopping instead of always claiming that
        # enough rows were collected.
        if user_input == "stop":
            print("Stopping validation at user request")
            break
        if n_kept >= N_TEST_EXAMPLES:
            print("Stopping validation as we have enough rows")
            break

    # Expose the kept rows as the "test" split of a new DatasetDict.
    return DatasetDict({"test": Dataset.from_dict(kept_rows, features=features)})


def main(dataset_name: Annotated[str, typer.Argument(help="The name of the dataset to validate")]):
    """Manually validate a raw synthetic DocQA test set and push the kept rows to the hub."""
    # Keep the try body limited to loading the dataset and selecting its split, so
    # that a KeyError raised later (during validation or upload) is not misreported
    # as a missing dataset.
    try:
        test_split = load_dataset(f"coldoc/syntheticDocQA_{dataset_name}_test_raw")["test"]
    except (KeyError, FileNotFoundError):
        # KeyError: the dataset loaded but has no "test" split.
        # NOTE(review): FileNotFoundError is assumed to cover the "dataset not
        # found" error raised by load_dataset -- confirm for the installed
        # datasets version.
        print(f"Dataset {dataset_name} does not exist")
        return

    validated_dataset = validate_dataset(test_split)
    validated_dataset.push_to_hub(f"coldoc/syntheticDocQA_{dataset_name}_test")


# Script entry point: hand main to typer, which runs it as the command-line interface.
if __name__ == "__main__":
    typer.run(main)

# python scripts/generate_data/validate_test_set.py healthcare_industry
# python scripts/generate_data/validate_test_set.py government_reports
# python scripts/generate_data/validate_test_set.py artificial_intelligence
# python scripts/generate_data/validate_test_set.py energy
