"""
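Sample images from a HuggingFace dataset and save them to a specified directory.

Command-line arguments:
    --n (int): Number of images to sample and save (required).
    --output_dir (str): Directory where the sampled images will be saved (required).
    --seed (int): Random seed for shuffling (default: 0).
    --dataset_name (str): Name of the HuggingFace dataset (default: "ILSVRC/imagenet-1k").
    --split (str): Dataset split, e.g. 'train' or 'test' (default: "test").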
"""
import os
import argparse
from datasets import load_dataset
import torchvision.transforms as T
from PIL import Image
from io import BytesIO
import tqdm

def main(n, output_dir, seed, dataset_name, split):
    """Stream a dataset, center-crop each image to a square, and save it as JPEG.

    Args:
        n (int): Number of images to save. Nothing is sampled when ``n <= 0``.
        output_dir (str): Directory to write into; created if it does not exist.
        seed (int): Seed passed to the streaming shuffle.
        dataset_name (str): Dataset identifier passed to ``load_dataset``.
        split (str): Dataset split to stream.

    Saved files are named by a zero-padded running index (``00000.JPEG``,
    ``00001.JPEG``, ...). An example that raises while being cropped or saved
    is reported and skipped; it does not consume an index. Ctrl-C stops the
    sampling loop without a traceback.
    """
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Nothing to do; avoids opening the dataset stream (and saving a stray
    # image, since the stop condition is only checked after a save).
    if n <= 0:
        return

    # Preprocessing: crop the largest centered square (no resizing is done).
    # NOTE(review): assumes example["image"] is a PIL image (uses .size/.save).
    transform = T.Compose([
        T.Lambda(lambda img: T.CenterCrop(min(img.size))(img))
    ])

    # Load and shuffle dataset
    dataset = load_dataset(dataset_name, split=split, streaming=True, trust_remote_code=True)
    shuffled = dataset.shuffle(seed=seed, buffer_size=10_000)

    # Sample and save images
    count = 0
    try:
        for example in tqdm.tqdm(shuffled, total=n, desc="Sampling images"):
            try:
                image = transform(example["image"])

                # Use the function argument, not the CLI-level global `args`.
                filename = os.path.join(output_dir, f"{count:05d}.JPEG")
                image.save(filename, format="JPEG")
            except Exception as e:
                # Best effort: report the bad example and move on.
                print(f"Skipped one due to error: {e}")
                continue

            count += 1
            if count >= n:
                break
    except KeyboardInterrupt:
        # Wraps the whole loop so an interrupt while the iterator is fetching
        # the next example is handled too, not only one raised in the body.
        print("Interrupted by user.")

    if count < n:
        print(f"Saved {count} of {n} requested images.")


if __name__ == "__main__":
    # Command-line interface, declared as (flag, options) pairs in the order
    # they should appear in --help.
    option_specs = [
        ("--n", dict(type=int, required=True, help="Number of images to sample and save.")),
        ("--output_dir", dict(type=str, required=True, help="Directory to save the images.")),
        ("--seed", dict(type=int, default=0, help="Random seed for shuffling.")),
        ("--dataset_name", dict(type=str, default="ILSVRC/imagenet-1k", help="Name of the HuggingFace dataset.")),
        ("--split", dict(type=str, default="test", help="Dataset split (e.g., 'train', 'test').")),
    ]

    parser = argparse.ArgumentParser(description="Sample and save images from a HuggingFace dataset.")
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    # The parsed namespace holds exactly n, output_dir, seed, dataset_name and
    # split, which are the keyword parameters of main().
    args = parser.parse_args()
    main(**vars(args))
