import h5py
import io
import json
import numpy as np
import os
import torch

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms

# classes : https://gist.github.com/aaronpolhamus/964a4411c0906315deb9f4a3723aac57


def ImageNetDataLoaders(h5_path, batch_size, classes=None, imsize=64,
                        shuffle=True, workers=10, distributed=False):
    """Build the training and validation DataLoaders for an ImageNet HDF5 file.

    Args:
        h5_path: Path to the HDF5 file; forwarded to ImageNetDataset.
        batch_size: Batch size used by both loaders.
        classes: Optional collection of class ids restricting both datasets.
            None or an empty collection selects every class.
        imsize: Size passed to the final ``transforms.Resize`` of both
            pipelines.
        shuffle: Whether the training loader shuffles. Forced to False when
            ``distributed`` is set, because the DistributedSampler then
            controls the ordering.
        workers: ``num_workers`` for both loaders.
        distributed: If True, the training loader uses a DistributedSampler.

    Returns:
        Tuple ``(train_dl, val_dl)`` of DataLoaders.
    """
    # Normalise here so a plain list is always forwarded to the dataset.
    classes = [] if classes is None else classes

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224 if imsize < 256 else 256),
        transforms.Resize(imsize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # NOTE(review): the crop-size threshold here (<= 224) differs from the
    # training one (< 256); kept as in the original pipeline.
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224 if imsize <= 224 else 256),
        transforms.Resize(imsize),
        transforms.ToTensor(),
        normalize,
    ])

    train_ds = ImageNetDataset(h5_path, "train", train_transform, classes)
    val_ds = ImageNetDataset(h5_path, "val", val_transform, classes)

    if distributed:
        # DataLoader rejects shuffle=True together with an explicit sampler.
        sampler = DistributedSampler(train_ds)
        shuffle = False
    else:
        sampler = None

    train_dl = DataLoader(
        train_ds, batch_size=batch_size, shuffle=shuffle,
        num_workers=workers, pin_memory=True, sampler=sampler, drop_last=True)

    # NOTE(review): drop_last=True also applies to validation, so the final
    # partial batch of validation images is never evaluated.
    val_dl = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True, drop_last=True)

    return train_dl, val_dl


class ImageNetDataset(Dataset):
    """ImageNet split backed by a single HDF5 file.

    The file is read through two datasets, ``encoded_images`` (encoded image
    bytes) and ``targets``. Samples are addressed as one flat range laid out
    as train, then val, then test, so a split-local index is shifted by the
    sizes of the preceding splits before the file is accessed.

    Args:
        h5_path: Path to the HDF5 file (ilsvrc2012.hdf5).
        split: One of 'train', 'val' or 'test'.
        transform: Callable applied to the decoded RGB PIL image. Defaults to
            ``transforms.Compose([transforms.ToTensor()])``.
        classes: Optional collection of class ids to keep ('train' and 'val'
            only). The per-class sample indices are read from
            ``imagenet.json`` located next to this module. None or an empty
            collection keeps every sample of the split.
    """

    # Number of samples in each split of the HDF5 file.
    N_TRAIN = 1281167
    N_VAL = 50000
    N_TEST = 100000

    def __init__(self, h5_path, split, transform=None, classes=None):
        self.h5_path = h5_path     # Path to ilsvrc2012.hdf5
        self.split = split
        # Defaults are resolved here rather than in the signature so that no
        # object is created at import time or shared between instances.
        self.transform = (transforms.Compose([transforms.ToTensor()])
                          if transform is None else transform)
        self.classes = [] if classes is None else classes

        assert os.path.exists(self.h5_path), f"ImageNet h5 file path does not exist! Given: {self.h5_path}"
        assert self.split in ["train", "val", "test"], f"split must be 'train' or 'val' or 'test'! Given: {self.split}"

        if self.split in ('train', 'val') and len(self.classes) > 0:
            json_path = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), "imagenet.json")
            # Context manager so the JSON file handle is closed right away.
            with open(json_path) as json_file:
                class_idxs_dict = json.load(json_file)[self.split]
            # Split-local indices of every sample of the requested classes.
            self.class_idxs = sorted(
                i for c in self.classes for i in class_idxs_dict[str(c)])
            self.n = len(self.class_idxs)
        else:
            split_sizes = {
                'train': self.N_TRAIN,
                'val': self.N_VAL,
                'test': self.N_TEST,
            }
            self.n = split_sizes[self.split]

        # Opened lazily on first access in __getitem__; presumably so that
        # each DataLoader worker process opens its own handle -- confirm.
        self.h5_data = None

        print(f"Dataset length: {len(self)}")

    def __len__(self):
        """Return the number of samples selected by split and class filter."""
        return self.n

    def _h5_index(self, idx):
        """Translate a dataset index into a row index of the HDF5 file.

        Negative indices count from the end, as with a list.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        if not -self.n <= idx < self.n:
            # Without this check an out-of-range index would silently read a
            # sample that belongs to the following split.
            raise IndexError(
                f"index {idx} is out of range for dataset of length {self.n}")
        if idx < 0:
            idx += self.n

        # Map through the class filter, if any.
        if len(self.classes) > 0 and self.split in ('train', 'val'):
            idx = self.class_idxs[idx]

        # Shift past the splits stored before this one.
        if self.split == 'val':
            idx += self.N_TRAIN
        elif self.split == 'test':
            idx += self.N_TRAIN + self.N_VAL

        return idx

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the sample at ``idx``.

        ``image`` is the output of ``self.transform`` on the decoded RGB
        image. ``target`` is the first element of the sample's ``targets``
        row converted to a long tensor, or None for the 'test' split.

        NOTE(review): a None target cannot be batched by the default
        DataLoader collate function; a custom ``collate_fn`` is needed for
        the test split.
        """
        idx = self._h5_index(idx)

        # Read h5 file
        if self.h5_data is None:
            self.h5_data = h5py.File(self.h5_path, mode='r')

        # Extract info
        encoded = self.h5_data['encoded_images'][idx]
        image = self.transform(Image.open(io.BytesIO(encoded)).convert('RGB'))
        if self.split != 'test':
            target = torch.from_numpy(self.h5_data['targets'][idx])[0].long()
        else:
            target = None

        return image, target


# json.dump( data, open( "file_name.json", 'w' ) )


# IMGS_PER_CLASS_TRAIN = [
#     1300, 1300, 1300, 1300, 1300, 1150, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1025, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1156, 1300, 1300, 1300,  738, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,  732, 1300, 1300,
#     1300,  755, 1300, 1300, 1300, 1300, 1300, 1300, 1206, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1070, 1300,  936, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     860, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,  772,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1273, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300,  969, 1300, 1258, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300,  954, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1218, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300,  977, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1136, 1290, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,  754, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1272, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1238, 1300, 1300, 1300, 1300, 1300, 1206, 1300, 1300,
#     1300, 1300, 1118, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1159, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,  976,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1282, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300,  891, 1067,  986, 1300,  908, 1300,
#     1300, 1300, 1300, 1254, 1300, 1300, 1300, 1194, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1141, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1165,
#     969, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1117, 1300, 1300, 1300, 1300, 1300,
#     1266, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1071, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1267, 1300, 1300, 1120, 1300, 1004, 1300, 1283,
#     1199, 1300, 1300, 1300, 1292, 1300, 1299, 1300, 1300, 1300, 1084,
#     889, 1300, 1300, 1155, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1162, 1300,
#     1034, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1117, 1300, 1300, 1300, 1300, 1300, 1253, 1300,
#     1300, 1300, 1157, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1216,
#     1300, 1300, 1259, 1300, 1133, 1300, 1300, 1300, 1300, 1300, 1300,
#     1180, 1300, 1160, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1136,
#     1300, 1300, 1300, 1300, 1137, 1300, 1300, 1300, 1300, 1187, 1222,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1152, 1153, 1300, 1300,
#     1300, 1300, 1300, 1155, 1300, 1300, 1300, 1300, 1300, 1300, 1270,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1250, 1300, 1300, 1300, 1300, 1211,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1053, 1156, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1239, 1300, 1300, 1300, 1300, 1125, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1029, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1154, 1300, 1300, 1149, 1300, 1300, 1300, 1300, 1300, 1300,
#     1149, 1055, 1300, 1154, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300,  962, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1029, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1153, 1300, 1217, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1193, 1300, 1053,
#     1300, 1300, 1300, 1300, 1300, 1300, 1249, 1176, 1300, 1300, 1300,
#     931, 1300, 1300, 1300, 1282, 1300, 1300, 1207, 1300, 1247, 1300,
#     1300, 1209, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1045,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1097, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1186, 1300,
#     1300, 1300, 1272,  980, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1264, 1300, 1300, 1300, 1236, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1213,
#     1300, 1005, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1285, 1300, 1300, 1069, 1300, 1300, 1300, 1300,
#     1062, 1300, 1300, 1300, 1137, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300,
#     1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300, 1300]

# IMGS_PER_CLASS_VAL = [50] * 1000
