from datasets import Dataset, IterableDataset,DatasetDict, Features, Image, Value, \
    load_dataset, concatenate_datasets  
from typing import Annotated, cast
import typer
from pathlib import Path
import PIL.Image as PILImage
from tqdm import tqdm
import random

# Seed the global RNG that add_all_pages uses for random.sample, so the same
# candidate list (in the same order) yields the same sampled pages each run.
random.seed(a=42)

def shorten_image_path(image_path: str) -> str:
    """
    Return ``image_path`` reduced to its last two components: "<parent>/<file>".

    The result always uses forward slashes. A path without a parent directory
    (a bare file name, or a file directly under the root) comes back as just
    the file name. Used only for human-readable progress messages.
    """
    path = Path(image_path)
    # An empty parent name (no directory, or the root) collapses away when
    # joined, leaving only the file name.
    return (Path(path.parent.name) / path.name).as_posix()



def add_all_pages(dataset: Dataset, sample_size: int = 900) -> Dataset:
    """
    Append on-disk pages that ``dataset`` does not reference yet, as query-less rows.

    The searched directory is the grandparent of the first row's
    ``image_filename``; every ``*.jpg`` below it (recursively) is a candidate.
    Candidates already present in the ``image_filename`` column are skipped,
    and at most ``sample_size`` of the rest are drawn with ``random.sample``
    (module-level RNG state) and appended with ``source="pdf"`` and ``None``
    for query/answer/model/prompt.

    Args:
        dataset: Dataset with an ``image_filename`` column. Its features must
            be compatible with the ones declared below, or
            ``concatenate_datasets`` will refuse to join them.
        sample_size: Maximum number of new pages to append; must be >= 0.

    Returns:
        ``dataset`` concatenated with the new rows, or ``dataset`` unchanged
        when no new pages are selected.

    Raises:
        ValueError: If ``sample_size`` is negative, or ``dataset`` is empty
            (there is no path to derive the page directory from).
    """
    if sample_size < 0:
        raise ValueError(f"sample_size must be non-negative, got {sample_size}")
    if len(dataset) == 0:
        raise ValueError("Cannot locate the page directory of an empty dataset")

    # Read only the filename column; `dataset[0]` would also decode row 0's image.
    existing_filenames = list(dataset['image_filename'])

    # NOTE(review): assumes all pages share the first row's grandparent
    # directory — confirm for datasets mixing several sources.
    base_dir = Path(existing_filenames[0]).parent.parent
    print(f"Base directory is {base_dir}")

    # Sorted because glob order is filesystem-dependent; without a stable
    # order the seeded random.sample below would not be reproducible.
    pages = sorted(str(page) for page in base_dir.glob('**/*.jpg'))
    print(f"Found {len(pages)} pages")

    # Set lookup keeps this filter linear instead of O(pages * existing).
    existing_pages = set(existing_filenames)
    pages = [page for page in pages if page not in existing_pages]

    if len(pages) > sample_size:
        pages = random.sample(pages, sample_size)

    # Printed after sampling so the count matches what is actually appended.
    print(f"Adding {len(pages)} pages to the dataset")

    if not pages:
        return dataset

    features = Features(
        {
            "query": Value("string"),
            "image": Image(),
            "image_filename": Value("string"),
            "answer": Value("string"),
            "page": Value("string"),
            "model": Value("string"),
            "prompt": Value("string"),
            "source": Value("string"),
        }
    )

    def gen():
        with tqdm(pages) as progress:
            for image_path in progress:
                progress.set_description(f"Processing {shorten_image_path(image_path)}")
                # The file is closed when the generator resumes, i.e. after the
                # consumer has taken this row, instead of lingering until GC.
                with PILImage.open(image_path) as pil_image:
                    yield {
                        "query": None,
                        "image": pil_image,
                        "image_filename": image_path,
                        "answer": None,
                        # Page number is taken from the last "_"-separated
                        # token of the file stem.
                        "page": Path(image_path).stem.split("_")[-1],
                        "model": None,
                        "prompt": None,
                        "source": "pdf",
                    }

    # Non-streaming `Dataset.from_generator` produces a `Dataset`, not an
    # `IterableDataset`; the cast only informs type checkers.
    new_pages = cast(Dataset, Dataset.from_generator(gen, features=features))

    dataset = concatenate_datasets([dataset, new_pages])

    print(f"Dataset now has {len(dataset)} examples")

    return dataset
    

def main(dataset_name: Annotated[str, typer.Argument(help= "dataset on hugging face to evaluate")],
         sample_size: Annotated[int, typer.Option(help="Number of pages to add to the dataset")] = 900):
    # CLI entry point (kept docstring-free so typer's --help text is unchanged):
    # load the Hub dataset, extend its 'test' split with unreferenced pages,
    # and push the result back as a single-split DatasetDict.
    all_splits = load_dataset(dataset_name)
    augmented_test = add_all_pages(all_splits['test'], sample_size=sample_size)

    print(f"Finished adding all pages to {dataset_name}")

    # Destination repo: the source name with every '_queries_only' removed.
    # NOTE(review): without that marker this pushes over the source dataset.
    target_repo = dataset_name.replace('_queries_only', '')
    DatasetDict({'test': augmented_test}).push_to_hub(target_repo)
    
# Run as a script: typer builds the command line from main's annotated
# parameters (typer.Argument / typer.Option) and invokes it.
if __name__ == "__main__":
    typer.run(main)
