import io
from typing import Callable, Optional, Tuple

import h5py
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils import data


class HDF5Dataset(data.Dataset):
    """
        Image dataset backed by a single HDF5 file.

        The file must contain a ``data`` dataset, whose entries are the encoded bytes
        of one image each, and a ``label`` dataset with the matching labels.
        The file is opened lazily on the first item access rather than in ``__init__``,
        so a freshly constructed instance holds no open h5py objects.
    """
    def __init__(self, path: str,
                 transform: Optional[Callable] = None):
        """
        :param str path: Path to the h5 file of the data set
        :param transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``.
            When ``None``, items are returned as the decoded PIL image.
        """
        super().__init__()
        self._file_path = path
        # File handle and dataset views are created on first access (see _open).
        self._file = None
        self._data = None
        self._label = None
        self._transform = transform
        # Read only the lengths here; the handle is closed again by the context manager.
        with h5py.File(self._file_path, 'r') as file:
            self._data_len = len(file['data'])
            self._label_len = len(file['label'])

    def _open(self) -> None:
        """ Open the h5 file once and bind both datasets to that single handle """
        self._file = h5py.File(self._file_path, 'r')
        self._data = self._file['data']
        self._label = self._file['label']

    def close(self) -> None:
        """ Close the h5 file if it is open; the next item access reopens it """
        file = getattr(self, '_file', None)
        if file is not None:
            file.close()
        self._file = None
        self._data = None
        self._label = None

    def __getstate__(self) -> dict:
        """ Return the instance state without the open h5py objects.

        h5py objects are expected to refuse pickling, so they are dropped here and
        every unpickled copy reopens the file on its first item access.
        """
        state = self.__dict__.copy()
        state['_file'] = None
        state['_data'] = None
        state['_label'] = None
        return state

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        """ Return the (transformed) image and the label stored at ``index``.

        The label is returned exactly as read from the ``label`` dataset.
        If no transform was given, the first element is the PIL image itself.
        """
        if self._data is None:
            self._open()

        # Each entry of `data` is decoded from its in-memory bytes by PIL.
        img = Image.open(io.BytesIO(self._data[index]))
        x = self._transform(img) if self._transform is not None else img
        y = self._label[index]
        return x, y

    def __len__(self) -> int:
        """ Return the number of entries in the ``label`` dataset """
        return self._label_len
