"""
A wrapper for the ImageNet validation set. This is a simple loader
with the appropriate transforms; it supports neither shuffling nor batching.
"""
import random
import os
import json
from PIL import Image

import numpy as np
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

IMAGENET_SL = 224


# Put your ImageNet validation path here.
# The directory is expected to be laid out as follows:
# label_1/
#     ILS....JPEG
#     ILS....JPEG
# label_2/
#     ILS....JPEG
#     ILS....JPEG
#
# The labels are indexed based on 1001 classes, where 0 is the "I don't know" label.


    


class ImagenetValidData():
    """Sequential, index-based reader for the ImageNet validation split.

    ``__init__`` builds two parallel lists, ``self.samples`` (image paths) and
    ``self.targets`` (integer class ids), from the sorted listing of the
    validation image directory. Images are decoded lazily in ``__getitem__``.
    """

    @staticmethod
    def _build_transform():
        """Return the preprocessing pipeline shared by both initializers.

        The pipeline resizes with ``IMAGENET_SL``, center-crops to
        ``IMAGENET_SL``, then converts the image to a NumPy array divided by
        255.
        """
        return transforms.Compose([
            transforms.Resize(IMAGENET_SL),
            transforms.CenterCrop(IMAGENET_SL),
            transforms.Lambda(lambda img: np.array(img) / 255.)])

    def prev__init__(self, data_dir="/home/DATA/ITWM/lorenzp/ImageNet/val"):
        """Legacy ``ImageFolder``-based setup; not invoked by ``__init__``.

        Only sets ``self.dset`` and ``self.idxs``. It does not populate
        ``self.samples`` / ``self.targets`` / ``self.transform``, which the
        other methods of this class read.

        Args:
            data_dir: Root directory handed to ``ImageFolder``.
        """
        self.dset = ImageFolder(root=data_dir,
                                transform=self._build_transform())
        # Shuffle the index list with a fixed seed so the order is reproducible.
        self.idxs = list(range(len(self.dset)))
        random.Random(1).shuffle(self.idxs)

    # Adapted from ImageNetKaggle
    def __init__(self, data_dir="/work/LAS/jtian-lab/data/IMN/"):
        """Index the validation images and their class ids.

        Reads ``imagenet_class_index.json`` (class id -> entry whose first
        element is the synset id) and ``ILSVRC2012_val_labels.json`` (image
        file name -> synset id) from the data directory, then walks
        ``ILSVRC/Data/CLS-LOC/val`` in sorted order.

        Args:
            data_dir: Currently ignored, see the HACK note below.

        Raises:
            OSError: If a JSON file or the image directory cannot be read.
            KeyError: If a directory entry has no label, or its synset id is
                missing from the class index.
        """
        # HACK: the argument is deliberately overridden so that the data
        # directory is always fixed. Change this path to relocate the data
        # (the original note mentions /work/mech-ai-scratch/data/IMN).
        # NOTE(review): callers passing ``data_dir`` get no effect - confirm
        # this is still wanted before honoring the parameter.
        data_dir = "/work/LAS/jtian-lab/data/IMN/"
        split = "val"

        self.transform = self._build_transform()

        with open(os.path.join(data_dir, "imagenet_class_index.json"), "rb") as f:
            class_index = json.load(f)
        # Invert the index: synset id -> integer class id.
        self.syn_to_class = {
            entry[0]: int(class_id) for class_id, entry in class_index.items()
        }

        with open(os.path.join(data_dir, "ILSVRC2012_val_labels.json"), "rb") as f:
            self.val_to_syn = json.load(f)

        samples_dir = os.path.join(data_dir, "ILSVRC/Data/CLS-LOC", split)
        self.samples = []
        self.targets = []
        # Sorted listing keeps the sample order deterministic across runs.
        for entry in sorted(os.listdir(samples_dir)):
            syn_id = self.val_to_syn[entry]
            self.samples.append(os.path.join(samples_dir, entry))
            self.targets.append(self.syn_to_class[syn_id])

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, target)``.

        The image is opened from ``self.samples[idx]``, converted to RGB and
        passed through ``self.transform`` when one is set.
        """
        # The context manager closes the underlying file promptly; convert()
        # returns an independent, fully loaded image.
        with Image.open(self.samples[idx]) as img:
            x = img.convert("RGB")
        if self.transform is not None:
            x = self.transform(x)
        return x, self.targets[idx]

    def get_eval_data(self, bstart, bend):
        """Return the samples with indices in ``range(bstart, bend)``.

        Args:
            bstart: First index (inclusive).
            bend: Last index (exclusive).

        Returns:
            A ``(images, labels)`` tuple of NumPy arrays stacked in index
            order. An empty range yields two empty arrays instead of raising.

        Raises:
            IndexError: If the range reaches past the end of the dataset.
        """
        images, labels = [], []
        for idx in range(bstart, bend):
            image, label = self[idx]
            images.append(image)
            labels.append(label)
        # NOTE(review): labels are returned exactly as stored in
        # ``self.targets``; no offset for an "I don't know" class is applied
        # here - confirm against the consumer's label convention.
        return np.array(images), np.array(labels)
