import os
from nesim.utils.folder import get_filenames_in_a_folder
import numpy as np
from PIL import Image
from nesim.eval.resnet import imagenet_transforms

class FlocDataset:
    """Image dataset whose class label is encoded in each filename.

    Every file returned by ``get_filenames_in_a_folder(folder)`` is one sample.
    A file's label is the part of its basename before the first ``"-"``
    (e.g. ``"cat-001.png"`` -> ``"cat"``); a basename with no ``"-"`` uses the
    whole basename as its label. Labels are kept in sorted order, so label
    indices are stable for a given set of files.
    """

    def __init__(self, folder: str):
        """Index the files in ``folder`` and derive the sorted label vocabulary.

        Args:
            folder: Directory containing the image files.

        Raises:
            FileNotFoundError: If ``folder`` does not exist.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not os.path.exists(folder):
            raise FileNotFoundError(f"Dataset folder does not exist: {folder!r}")
        self.filenames = get_filenames_in_a_folder(folder=folder)
        # Sort so sample indices are deterministic regardless of listing order.
        self.filenames.sort()
        # Unique labels in sorted order; a label's index is its class id.
        self.labels = sorted({self._label_from_filename(x) for x in self.filenames})

    @staticmethod
    def _label_from_filename(filename: str) -> str:
        """Return the text before the first ``"-"`` in ``filename``'s basename."""
        return os.path.basename(filename).split("-", 1)[0]

    def filename_to_label_index(self, filename: str) -> int:
        """Return the position of ``filename``'s label within ``self.labels``.

        Raises:
            ValueError: If the file's label is not present in ``self.labels``.
        """
        return self.labels.index(self._label_from_filename(filename))

    def __getitem__(self, idx: int) -> tuple:
        """Load sample ``idx`` and return ``(transformed_image, label_index)``.

        The image is converted to RGB and then passed through
        ``imagenet_transforms``. Presumably that is ImageNet-style
        preprocessing; confirm in ``nesim.eval.resnet``.
        """
        filename = self.filenames[idx]
        # Close the file handle deterministically. Without a context manager,
        # Pillow may keep the file open (e.g. for multi-frame formats), which
        # leaks descriptors over many dataset reads.
        with Image.open(filename) as img:
            rgb = img.convert("RGB")  # convert() returns a new, fully loaded image
        return imagenet_transforms(rgb), self.filename_to_label_index(filename=filename)

    def __len__(self) -> int:
        """Return the number of files in the dataset."""
        return len(self.filenames)
