from nesim.utils.folder import get_filenames_in_a_folder
import os
from PIL import Image
from nesim.eval.resnet import imagenet_transforms

def load_image_as_square(image_path, fill_color=(0, 0, 0)):
    """Load an image as RGB and pad it to a square canvas, centered.

    The image is never resized or cropped: the shorter side is padded on both
    sides with ``fill_color`` (when the padding is odd, the extra pixel ends up
    on the right/bottom because the offset is floor-divided).

    Args:
        image_path: Anything accepted by ``PIL.Image.open`` (path or file object).
        fill_color: RGB tuple used for the padding. Defaults to black
            ``(0, 0, 0)``, which is what this function has always used.

    Returns:
        A new RGB ``PIL.Image.Image`` of size ``(max_dim, max_dim)`` where
        ``max_dim = max(width, height)`` of the source image.
    """
    # Image.open is lazy and keeps the file handle open; convert() forces the
    # pixel data to be loaded into a new image, after which the source file
    # can be closed safely by the context manager.
    with Image.open(image_path) as src:
        img = src.convert("RGB")

    width, height = img.size
    max_dim = max(width, height)

    # Square canvas filled with the padding color, original pasted centered.
    canvas = Image.new("RGB", (max_dim, max_dim), fill_color)
    canvas.paste(img, ((max_dim - width) // 2, (max_dim - height) // 2))
    return canvas


class SizeAnimacyDataset:
    """Image dataset labelled by size (Small/Big) and animacy (Animate/Inanimate).

    ``folder`` must contain the four sub-folders ``Small-Animate``,
    ``Small-Inanimate``, ``Big-Animate`` and ``Big-Inanimate``. Every file
    returned by ``get_filenames_in_a_folder`` for a sub-folder receives the
    size index (``Small`` = 0, ``Big`` = 1) and animacy index
    (``Animate`` = 0, ``Inanimate`` = 1) encoded in that sub-folder's name.

    Args:
        folder: Root directory holding the four sub-folders.
        mode: Selects what ``__getitem__`` returns:
            ``"small-big"`` -> ``(image, size_label)``;
            ``"animate-inanimate"`` -> ``(image, animacy_label)``;
            ``None`` -> ``{"image": ..., "size": ..., "animacy": ...}``.
    """

    def __init__(self, folder: str, mode="small-big"):
        assert mode in [None, "small-big", 'animate-inanimate']
        self.mode = mode
        possible_sizes = ["Small", "Big"]
        possible_animacies = ["Animate", "Inanimate"]

        # Class names of the single label returned in tuple modes; None in
        # dict mode, where both labels are returned.
        if mode == "small-big":
            self.label_names = possible_sizes
        elif mode == "animate-inanimate":
            self.label_names = possible_animacies
        else:
            self.label_names = None

        # Parallel lists: labels["size"][i] and labels["animacy"][i] belong
        # to image_filenames[i].
        self.labels = {"size": [], "animacy": []}
        self.image_filenames = []
        for size_idx, size in enumerate(possible_sizes):
            for animacy_idx, animacy in enumerate(possible_animacies):
                sub_folder = os.path.join(folder, f"{size}-{animacy}")
                assert os.path.exists(sub_folder), f"Invalid path: {sub_folder}"
                sub_folder_image_filenames = get_filenames_in_a_folder(sub_folder)
                count = len(sub_folder_image_filenames)
                self.image_filenames.extend(sub_folder_image_filenames)
                self.labels["size"].extend([size_idx] * count)
                self.labels["animacy"].extend([animacy_idx] * count)

    def __getitem__(self, idx: int):
        """Return the transformed image at ``idx`` with its label(s), per ``self.mode``.

        Raises:
            FileNotFoundError: if the file recorded at ``idx`` no longer exists.
        """
        filename = self.image_filenames[idx]
        # Check existence *before* loading: previously the check ran after the
        # image had already been opened, so it could never trigger and the
        # informative message was lost. FileNotFoundError is the same type a
        # missing path raised from the loader, so callers see no type change.
        if not os.path.exists(filename):
            raise FileNotFoundError(f"Invalid filename: {filename}")
        image = imagenet_transforms(load_image_as_square(image_path=filename))

        if self.mode is None:
            return {
                "image": image,
                "size": self.labels["size"][idx],
                "animacy": self.labels["animacy"][idx],
            }
        if self.mode == "small-big":
            return image, self.labels["size"][idx]
        if self.mode == "animate-inanimate":
            return image, self.labels["animacy"][idx]
        # Unreachable for modes accepted by __init__; falls through to None
        # exactly as before.

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.labels["size"])
    

import os

class BigSmallDataset:
    """Image dataset with a binary size label.

    ``folder`` must contain ``Small`` and ``Big`` sub-folders. Files returned
    by ``get_filenames_in_a_folder`` for ``Small`` get label 0 and those for
    ``Big`` get label 1 (the label is the sub-folder's index in
    ``label_names``). ``__getitem__`` returns ``(image, label)``.
    """

    def __init__(
        self,
        folder: str
    ):
        assert os.path.exists(folder), f"Invalid path: {folder}"

        self.label_names = ["Small", "Big"]
        # Parallel lists: labels[i] is the label of filenames[i].
        self.filenames = []
        self.labels = []

        for label_idx, label in enumerate(self.label_names):
            sub_folder = os.path.join(folder, label)
            assert os.path.exists(sub_folder), f"Invalid path: {sub_folder}"
            filenames = get_filenames_in_a_folder(sub_folder)
            self.filenames.extend(filenames)
            self.labels.extend([label_idx] * len(filenames))

    def __getitem__(self, idx: int):
        """Return ``(transformed_square_image, label)`` for the file at ``idx``."""
        image = imagenet_transforms(load_image_as_square(image_path=self.filenames[idx]))
        label = self.labels[idx]
        return image, label

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.filenames)