import h5py
import io
import json
import numpy as np
import os
import torch

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms

# List of ImageNet classes: https://gist.github.com/aaronpolhamus/964a4411c0906315deb9f4a3723aac57

# h5_path = '/scratch/voletivi/Datasets/ImageNet/ilsvrc2012.hdf5'


def ImageNetDataLoaders(h5_path, batch_size, classes=None, shuffle=True,
                        workers=10, distributed=False):
    """Build the train and validation DataLoaders over an ImageNet HDF5 file.

    Args:
        h5_path: Path to the HDF5 file read by ``ImageNetDataset``.
        batch_size: Batch size used by both loaders.
        classes: Optional sequence of class ids restricting both splits to
            those classes. ``None`` (or empty) means all classes.
        shuffle: Whether to shuffle the training set. In distributed mode the
            flag is forwarded to the ``DistributedSampler``, because
            ``DataLoader`` does not accept ``shuffle=True`` together with a
            sampler.
        workers: Number of DataLoader worker processes for each loader.
        distributed: If True, shard the training set across processes with a
            ``DistributedSampler``. Callers should call
            ``train_dl.sampler.set_epoch(epoch)`` every epoch so the shuffle
            order changes between epochs.

    Returns:
        A ``(train_dl, val_dl)`` tuple of DataLoaders.
    """
    # Avoid a shared mutable default. The dataset expects a sized sequence.
    if classes is None:
        classes = []

    # Per-channel RGB mean/std (the commonly used ImageNet statistics).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Training: random crop + flip augmentation.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Evaluation: deterministic resize + center crop.
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_ds = ImageNetDataset(h5_path, "train", train_transform, classes)
    val_ds = ImageNetDataset(h5_path, "val", val_transform, classes)

    if distributed:
        # The sampler owns shuffling in distributed mode. Forward the caller's
        # choice instead of letting the sampler's default (shuffle=True)
        # silently override shuffle=False.
        sampler = DistributedSampler(train_ds, shuffle=shuffle)
        loader_shuffle = False
    else:
        sampler = None
        loader_shuffle = shuffle

    train_dl = DataLoader(
        train_ds, batch_size=batch_size, shuffle=loader_shuffle,
        num_workers=workers, pin_memory=True, sampler=sampler, drop_last=True)

    # NOTE(review): drop_last=True on the validation loader discards the final
    # partial batch, so evaluation skips up to batch_size - 1 images. Kept for
    # backward compatibility; confirm this is intended.
    val_dl = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True, drop_last=True)

    return train_dl, val_dl


class ImageNetDataset(Dataset):
    """Map-style dataset reading ImageNet images and labels from one HDF5 file.

    The file stores all splits back to back in the ``encoded_images`` and
    ``targets`` datasets:
    - the first ``N_TRAIN`` rows are train;
    - the next ``N_VAL`` rows are val;
    - the last ``N_TEST`` rows are test.

    Each ``encoded_images`` row holds encoded image bytes that PIL can decode.

    Args:
        h5_path: Path to ``ilsvrc2012.hdf5``.
        split: One of ``"train"``, ``"val"``, ``"test"``.
        transform: Callable applied to each RGB PIL image. ``None`` means
            ``transforms.Compose([transforms.ToTensor()])``.
        classes: Optional sequence of class ids. When non-empty (train/val
            only), the dataset is restricted to rows of those classes, looked
            up in ``imagenet.json`` next to this file. ``None`` means all
            classes. Ignored for the test split.
    """

    # Number of rows per split in the HDF5 file.
    N_TRAIN = 1281167
    N_VAL = 50000
    N_TEST = 100000

    def __init__(self, h5_path, split, transform=None, classes=None):
        self.h5_path = h5_path     # Path to ilsvrc2012.hdf5
        self.split = split
        # Build defaults per instance rather than once at definition time.
        self.transform = (transform if transform is not None
                          else transforms.Compose([transforms.ToTensor()]))
        self.classes = classes if classes is not None else []

        assert os.path.exists(self.h5_path), f"ImageNet h5 file path does not exist! Given: {self.h5_path}"
        assert self.split in ["train", "val", "test"], f"split must be 'train' or 'val' or 'test'! Given: {self.split}"

        if self.split == 'test':
            self.n = self.N_TEST
        elif len(self.classes) > 0:
            self.class_idxs = self._load_class_idxs()
            self.n = len(self.class_idxs)
        elif self.split == 'train':
            self.n = self.N_TRAIN
        else:
            self.n = self.N_VAL

        # The HDF5 handle is opened lazily in __getitem__ rather than here,
        # presumably so the dataset can be handed to DataLoader workers before
        # any file handle exists.
        self.h5_data = None

    def _load_class_idxs(self):
        """Return the sorted row indices of ``self.classes`` for this split.

        Reads ``imagenet.json`` (next to this file), which maps
        ``str(class_id)`` to a list of row indices for each split. The
        indices are presumably relative to the split's first row, since
        ``__getitem__`` adds the split offset to them. Verify against the
        JSON generator.
        """
        json_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "imagenet.json")
        with open(json_path) as f:
            class_idxs_dict = json.load(f)[self.split]
        return sorted(i for c in self.classes for i in class_idxs_dict[str(c)])

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        """Return ``(image, target)`` for item ``idx`` of this split.

        ``target`` is a scalar ``torch.long`` label, or ``None`` for the test
        split. NOTE(review): torch's default collate cannot batch ``None``, so
        the test split likely needs a custom ``collate_fn`` in a DataLoader.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Normalize negative indices and reject out-of-range ones. Without
        # this, indices past the end silently read rows of the next split, and
        # `for x in ds` never stops at the split boundary.
        if idx < 0:
            idx += self.n
        if not 0 <= idx < self.n:
            raise IndexError(f"index out of range for '{self.split}' split of size {self.n}")

        # Map a position in the class-filtered subset to its row in the split.
        if len(self.classes) > 0 and self.split in ['train', 'val']:
            idx = self.class_idxs[idx]

        # Shift to the absolute row: splits are stored back to back.
        if self.split == 'val':
            idx += self.N_TRAIN
        elif self.split == 'test':
            idx += self.N_TRAIN + self.N_VAL

        if self.h5_data is None:
            self.h5_data = h5py.File(self.h5_path, mode='r')

        raw = self.h5_data['encoded_images'][idx]
        image = self.transform(Image.open(io.BytesIO(raw)).convert('RGB'))
        # Each targets row appears to be an array whose first entry is the label.
        target = torch.from_numpy(self.h5_data['targets'][idx])[0].long() if self.split != 'test' else None

        return image, target


# json.dump( data, open( "file_name.json", 'w' ) )


# IMGS_PER_CLASS_TRAIN = [
#     1300, 1300, 1300, 1300, 1300, 1150, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1025, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1156, 1300, 1300, 1300,  738, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,  732, 1300, 1300,
#     1300,  755, 1300, 1300, 1300, 1300, 1300, 1300, 1206, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1070, 1300,  936, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     860, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,  772,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1273, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300,  969, 1300, 1258, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300,  954, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1218, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300,  977, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1136, 1290, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,  754, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1272, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1238, 1300, 1300, 1300, 1300, 1300, 1206, 1300, 1300,
#     1300, 1300, 1118, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1159, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,  976,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1282, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300,  891, 1067,  986, 1300,  908, 1300,
#     1300, 1300, 1300, 1254, 1300, 1300, 1300, 1194, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1141, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1165,
#     969, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1117, 1300, 1300, 1300, 1300, 1300,
#     1266, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1071, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1267, 1300, 1300, 1120, 1300, 1004, 1300, 1283,
#     1199, 1300, 1300, 1300, 1292, 1300, 1299, 1300, 1300, 1300, 1084,
#     889, 1300, 1300, 1155, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1162, 1300,
#     1034, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1117, 1300, 1300, 1300, 1300, 1300, 1253, 1300,
#     1300, 1300, 1157, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1216,
#     1300, 1300, 1259, 1300, 1133, 1300, 1300, 1300, 1300, 1300, 1300,
#     1180, 1300, 1160, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1136,
#     1300, 1300, 1300, 1300, 1137, 1300, 1300, 1300, 1300, 1187, 1222,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1152, 1153, 1300, 1300,
#     1300, 1300, 1300, 1155, 1300, 1300, 1300, 1300, 1300, 1300, 1270,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1250, 1300, 1300, 1300, 1300, 1211,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1053, 1156, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1239, 1300, 1300, 1300, 1300, 1125, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1029, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1154, 1300, 1300, 1149, 1300, 1300, 1300, 1300, 1300, 1300,
#     1149, 1055, 1300, 1154, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300,  962, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1029, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1153, 1300, 1217, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1193, 1300, 1053,
#     1300, 1300, 1300, 1300, 1300, 1300, 1249, 1176, 1300, 1300, 1300,
#     931, 1300, 1300, 1300, 1282, 1300, 1300, 1207, 1300, 1247, 1300,
#     1300, 1209, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1045,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1097, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1186, 1300,
#     1300, 1300, 1272,  980, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1264, 1300, 1300, 1300, 1236, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1213,
#     1300, 1005, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1285, 1300, 1300, 1069, 1300, 1300, 1300, 1300,
#     1062, 1300, 1300, 1300, 1137, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300]

# IMGS_PER_CLASS_VAL = [50] * 1000
