from nesim.utils.folder import get_filenames_in_a_folder
import os
from PIL import Image

class SizeAnimacyDataset:
    """Index-able collection of images labelled by size and animacy.

    The root ``folder`` must contain one sub-folder per (size, animacy)
    combination, named ``"<Size>-<Animacy>"``::

        Small-Animate/  Small-Inanimate/  Big-Animate/  Big-Inanimate/

    Labels are stored as integers:

    * ``size``:    0 = "Small",   1 = "Big"
    * ``animacy``: 0 = "Animate", 1 = "Inanimate"

    Attributes are filled in the order the sub-folders are listed above, so
    ``image_filenames[i]``, ``labels["size"][i]`` and ``labels["animacy"][i]``
    all describe the same sample.
    """

    def __init__(self, folder: str) -> None:
        """Collect the filenames and labels found under ``folder``.

        Args:
            folder: Root directory holding the four ``"<Size>-<Animacy>"``
                sub-folders.

        Raises:
            AssertionError: If any expected sub-folder does not exist.
        """
        possible_sizes = ("Small", "Big")
        possible_animacies = ("Animate", "Inanimate")

        # Parallel lists: one integer label per entry of `image_filenames`.
        self.labels = {
            "size": [],
            "animacy": [],
        }
        self.image_filenames = []

        for size_label, size in enumerate(possible_sizes):
            for animacy_label, animacy in enumerate(possible_animacies):
                sub_folder = os.path.join(folder, f"{size}-{animacy}")
                # Raised explicitly instead of via `assert` so the check is
                # not stripped under `python -O`. AssertionError is kept as
                # the exception type so existing callers see no change.
                if not os.path.exists(sub_folder):
                    raise AssertionError(f"Invalid path: {sub_folder}")

                # NOTE(review): presumably returns a sized sequence of paths
                # usable as-is by `__getitem__` -- confirm against the helper.
                sub_folder_image_filenames = get_filenames_in_a_folder(sub_folder)
                num_images = len(sub_folder_image_filenames)

                self.image_filenames.extend(sub_folder_image_filenames)
                # Every image in this sub-folder shares the same two labels.
                self.labels["size"].extend([size_label] * num_images)
                self.labels["animacy"].extend([animacy_label] * num_images)

    def __getitem__(self, idx: int) -> dict:
        """Return the sample stored at position ``idx``.

        Args:
            idx: Position of the sample in ``image_filenames``.

        Returns:
            A dict with keys ``"image"`` (the result of ``Image.open`` on the
            sample's file), ``"size"`` and ``"animacy"`` (integer labels).

        Raises:
            IndexError: If ``idx`` is out of range.
            AssertionError: If the file recorded for ``idx`` no longer exists.
        """
        filename = self.image_filenames[idx]
        # Explicit raise (rather than `assert`) so the check survives
        # `python -O`; type and message match the previous behaviour.
        if not os.path.exists(filename):
            raise AssertionError(f"Invalid filename: {filename}")

        # NOTE(review): Pillow documents `Image.open` as lazy (the file stays
        # open until the data is loaded or the image is closed) -- confirm
        # callers load/close images when iterating over many samples.
        return {
            "image": Image.open(filename),
            "size": self.labels["size"][idx],
            "animacy": self.labels["animacy"][idx],
        }

    def __len__(self) -> int:
        """Return the number of samples (one per recorded size label)."""
        return len(self.labels["size"])