import h5py
import io
import numpy as np
import torch
from imagenet_pretrain.batch_transforms import RandomHorizontalFlip, ToTensor, Normalize
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image


class ImageNet(Dataset):
    """
    Map-style dataset reading samples from an HDF5 file.

    The file is expected to hold one dataset per class at the path
    ``"<sample_set>/<class>"``. Every entry of such a dataset is handed to
    ``io.BytesIO`` and decoded with PIL, so entries are presumably encoded
    image bytes (e.g. JPEG) -- confirm against the file writer.

    Items are returned as ``(image, label)`` where ``image`` is the cropped
    RGB image converted with ``np.array`` and ``label`` is the position of
    the sample's class in ``classes``.
    """

    def __init__(self, dataset_file: str, sample_set: str, classes: list):
        """
        Index every sample of the requested classes.

        Args:
            dataset_file: Path of the HDF5 file.
            sample_set: Top-level group to read from; ``"train"`` selects a
                random resized crop, any other value a resize + center crop.
            classes: Class names; their order defines the integer labels.
        """
        self.dataset = dataset_file
        self.sample_set = sample_set
        self.classes = classes
        self.class_map = dict((name, index) for index, name in enumerate(classes))

        if sample_set != "train":
            # Deterministic preprocessing: resize to 256, then crop the center 224.
            self.resizedcrop = transforms.Compose(
                [transforms.Resize(256), transforms.CenterCrop(224)]
            )
        else:
            # Random crop resized to 224 for training.
            self.resizedcrop = transforms.RandomResizedCrop(224)

        # One (class name, row index) pair per stored sample, class by class.
        self.sample_ids = []
        with h5py.File(self.dataset, "r") as h5_file:
            for name in self.classes:
                count = int(h5_file[f"{sample_set}/{name}"].shape[0])
                self.sample_ids.extend((name, row) for row in range(count))

    def __getitem__(self, i):
        """Load, decode and crop sample ``i``; return ``(image, label)``."""
        label_name, row = self.sample_ids[i]

        # The file is reopened for every access and closed again right away.
        with h5py.File(self.dataset, "r") as h5_file:
            raw = h5_file[f"{self.sample_set}/{label_name}"][row]

        picture = Image.open(io.BytesIO(raw))
        if picture.mode != "RGB":
            picture = picture.convert("RGB")

        return np.array(self.resizedcrop(picture)), self.class_map[label_name]

    def __len__(self):
        """Return the total number of indexed samples."""
        return len(self.sample_ids)


class ImageNetDataLoader:
    """
    Wrapper for torch DataLoader.

    Builds an ``ImageNet`` dataset, wraps it in a ``DataLoader`` and applies
    batch-level transforms after each batch has been moved to ``device``.
    Iterating yields ``(sample, target)`` pairs; ``len()`` gives the number
    of batches of the wrapped loader.
    """

    def __init__(
        self,
        dataset_file: str,
        batch_size: int,
        shuffle: bool,
        sample_set: str,
        classes: list,
        num_workers: int = 1,
        device=torch.device("cpu"),
    ):
        """
        Args:
            dataset_file: Path of the HDF5 file passed on to ``ImageNet``.
            batch_size: Batch size for the underlying ``DataLoader``.
            shuffle: Whether the underlying ``DataLoader`` shuffles samples.
            sample_set: One of ``"train"``, ``"val"`` or ``"test"``.
            classes: Class names passed on to ``ImageNet``.
            num_workers: Worker process count for the ``DataLoader``.
            device: Device batches are moved to before being transformed.

        Raises:
            AssertionError: If ``sample_set`` is not one of the allowed values.
        """
        assert sample_set in ["train", "val", "test"], (
            f"sample_set must be 'train', 'val' or 'test', got {sample_set!r}"
        )

        dataset = ImageNet(dataset_file, sample_set, classes)
        self.dataloader = DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
        )
        self.device = device

        # Only the training pipeline flips; the remaining steps are shared.
        steps = [RandomHorizontalFlip(inplace=True)] if sample_set == "train" else []
        steps += [
            ToTensor(),
            # NOTE(review): the two tuples are presumably per-channel mean and
            # std -- confirm against batch_transforms.Normalize.
            Normalize(
                (0.485, 0.456, 0.406),
                (0.229, 0.224, 0.225),
                inplace=True,
                device=device,
            ),
        ]
        self.transform_batch = transforms.Compose(steps)

    def __len__(self):
        """Return the number of batches produced by the wrapped DataLoader."""
        return len(self.dataloader)

    def __iter__(self):
        """Yield ``(sample, target)`` batches on ``self.device``."""
        for sample, target in self.dataloader:
            sample, target = sample.to(self.device), target.to(self.device)
            # Move the last axis to position 1 (the dataset yields H x W x C
            # arrays, so this gives N x C x H x W) before the batch transforms.
            sample = sample.permute(0, 3, 1, 2).contiguous()
            sample = self.transform_batch(sample)
            yield sample, target
