"""
For some odd reason, the huggingface imagenet dataset class is super slow.
In order to hopefully fix that, I will convert the dataset from the arrow thing that hf uses
to using simple image files.
"""

from datasets import load_dataset
from tqdm import tqdm
import os
from .json_stuff import dict_to_json


class ImageNetDatasetConverter:
    """Export one split of Hugging Face ImageNet-1k as JPEG files plus a label list.

    Output layout::

        output_folder/
            images/<idx>.jpg   resized copy of sample ``idx``
            labels.json        labels of the exported samples, in index order
    """

    def __init__(
        self,
        existing_cache_dir: str,
        slice_name: str,
        output_folder: str,
        max_num_samples=5000,
        size: tuple = (224, 224),
    ):
        """Validate arguments, load the requested split, and create ``images/``.

        Args:
            existing_cache_dir: Existing directory passed to ``load_dataset`` as
                ``cache_dir``.
            slice_name: Split to convert, either ``"train"`` or ``"validation"``.
            output_folder: Existing directory that receives ``images/`` and
                ``labels.json``.
            max_num_samples: Export at most this many samples (the first ones
                of the split). ``None`` exports the whole split.
            size: Target size passed to each image's ``resize``. This is
                presumably a PIL image, so the tuple is ``(width, height)``;
                confirm against the dataset features.

        Raises:
            ValueError: If ``slice_name`` is not ``"train"`` or ``"validation"``.
            FileNotFoundError: If ``existing_cache_dir`` or ``output_folder``
                does not exist.
        """
        # Explicit raises rather than ``assert``: asserts are stripped under
        # ``python -O``. Validation also runs before the expensive dataset load.
        if slice_name not in ("train", "validation"):
            raise ValueError(f"Invalid slice name for ImageNetDataset: {slice_name}")
        for path in (existing_cache_dir, output_folder):
            if not os.path.exists(path):
                raise FileNotFoundError(f"Directory does not exist: {path}")

        self.dataset = load_dataset(
            "imagenet-1k",
            cache_dir=existing_cache_dir,
        )[slice_name]

        self.output_folder = output_folder
        self.max_num_samples = max_num_samples
        self.size = size
        self.image_folder = os.path.join(output_folder, "images")
        # No shell involved, so paths with spaces or shell metacharacters are
        # safe. Errors raise instead of being silently ignored.
        os.makedirs(self.image_folder, exist_ok=True)

    def convert(self):
        """Write each sample's resized JPEG and save all labels to ``labels.json``.

        Images already on disk are not re-encoded, so an interrupted run can be
        resumed. Their labels are still read from the dataset, so
        ``labels.json`` covers every exported index.
        """
        num_samples = len(self.dataset)
        if self.max_num_samples is not None:
            # Clamp so a limit larger than the split does not index past its end.
            num_samples = min(self.max_num_samples, num_samples)

        labels = []
        for idx in tqdm(range(num_samples), desc="Converting dataset"):
            item = self.dataset[idx]
            labels.append(item["label"])

            image_path = os.path.join(self.image_folder, f"{idx}.jpg")
            if os.path.exists(image_path):
                continue  # exported by a previous run
            self._save_jpeg(item["image"].resize(self.size), image_path)

        print("Saving labels...")
        dict_to_json(
            dictionary=labels, filename=os.path.join(self.output_folder, "labels.json")
        )
        print("Done!")

    @staticmethod
    def _save_jpeg(image, image_path):
        """Save ``image`` to ``image_path`` as JPEG, replacing the file atomically.

        The image is written to a temporary sibling file and then renamed into
        place. An interrupted run therefore never leaves a truncated ``.jpg``
        that the resume check in ``convert`` would mistake for a finished file.
        """
        tmp_path = image_path + ".tmp"
        try:
            image.save(tmp_path, format="JPEG")
        except OSError:
            # JPEG cannot store every image mode (Pillow raises OSError for
            # modes such as RGBA or P), so retry as RGB.
            image.convert("RGB").save(tmp_path, format="JPEG")
        os.replace(tmp_path, image_path)
