"""Code to load the NSD dataset stimuli"""
import os

import h5py
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class NSDStimuli(Dataset):
    """
    Dataset for loading the stimuli in the NSD dataset.

    Images are read from the ``imgBrick`` dataset of an HDF5 file and are
    indexed according to nsdID (row ``i`` is the stimulus with nsdID ``i``).

    The HDF5 file is opened lazily, once per process, instead of being held
    open from ``__init__``. A single h5py handle must not be shared across
    forked DataLoader workers, and h5py objects cannot be pickled for
    spawn-started workers, so each process opens its own handle on first
    access.

    Args:
        hdf5_filename: Path to the HDF5 file containing ``imgBrick``.
        transform: Callable applied to each PIL image (default:
            ``transforms.ToTensor()``). If ``None``, the PIL image is
            returned unchanged.
        expected_num_images: Number of images ``imgBrick`` must contain
            (73000 for the full NSD stimulus set). Pass ``None`` to skip the
            check, e.g. for a subset file.

    Raises:
        ValueError: If ``imgBrick`` does not contain ``expected_num_images``
            images.
    """
    def __init__(self, hdf5_filename, transform=transforms.ToTensor(), *,
                 expected_num_images=73000):
        self.transform = transform
        self.hdf5_filename = hdf5_filename
        # Per-process lazily opened handle (see ``all_stimuli``).
        self._h5_file = None
        self._stimuli = None
        self._h5_pid = None
        # Open only briefly to validate and cache the length, then close so
        # that no handle is inherited by forked DataLoader workers.
        with h5py.File(hdf5_filename, 'r') as stim_hdf5:
            # documented as (73000, 425, 425, 3) for the full NSD stimulus set
            num_images = len(stim_hdf5['imgBrick'])
        if expected_num_images is not None and num_images != expected_num_images:
            raise ValueError(
                f"Expected {expected_num_images} stimuli in 'imgBrick' of "
                f"{hdf5_filename!r}, found {num_images}"
            )
        self._num_images = num_images

    @property
    def all_stimuli(self):
        """The ``imgBrick`` h5py dataset, opened for the current process."""
        pid = os.getpid()
        if self._h5_file is None or self._h5_pid != pid:
            # A handle inherited across fork is dropped (not closed) rather
            # than reused; this process gets its own handle.
            self._h5_file = h5py.File(self.hdf5_filename, 'r')
            self._stimuli = self._h5_file['imgBrick']
            self._h5_pid = pid
        return self._stimuli

    def close(self):
        """Close this process's HDF5 handle, if one is open."""
        if self._h5_file is not None and self._h5_pid == os.getpid():
            self._h5_file.close()
        self._h5_file = None
        self._stimuli = None
        self._h5_pid = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __getstate__(self):
        # h5py objects cannot be pickled; the receiving process reopens the
        # file lazily on first access.
        state = self.__dict__.copy()
        state['_h5_file'] = None
        state['_stimuli'] = None
        state['_h5_pid'] = None
        return state

    def __len__(self):
        return self._num_images

    def __getitem__(self, idx):
        image_array = self.all_stimuli[idx, :, :, :]  # NumPy array (H, W, C)
        image = Image.fromarray(image_array)  # PIL Image
        if self.transform is not None:
            image = self.transform(image)
        # The 0 is a fake label for interface purposes (an (input, target)
        # pair); please ignore it.
        return image, 0