# -*- coding: utf-8 -*-
"""TinyImageNetLoader.ipynb

Automatically generated by Colaboratory.


"""

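# Loads Tiny ImageNet-200 images as 3x64x64 (channels x height x width) tensors.

# To download and extract the dataset first (originally Colab shell commands):
#!wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
#!unzip -q tiny-imagenet-200.zip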

import torch
import torchvision
import numpy as np
import os, glob

import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import Dataset

from torchvision.io import read_image, ImageReadMode
from torchvision import transforms
from ..dataset_utils import *
from PIL import Image

class TinyImageNet200(Dataset):
    """Tiny ImageNet-200 held fully in memory as tensors.

    Call ``build_dataset()`` before indexing. It reads the ``train`` and
    ``val`` splits from ``data_conf['data_path']`` (the extracted
    ``tiny-imagenet-200`` directory) and sets:

    * ``X`` / ``Y``: train images followed by val images, int64 labels.
    * ``X_train`` / ``Y_train`` and ``X_val`` / ``Y_val``: views into
      ``X`` / ``Y`` for each split (they share storage with ``X`` / ``Y``).
    * ``X_test`` / ``Y_test``: aliases of the val split, because the test
      split labels are not available.

    Images are float32 tensors of shape (3, H, W) on the 0-255 scale.
    Grayscale JPEGs are expanded to three channels.
    """

    def __init__(self, data_conf):
        self.data_conf = data_conf
        self.d = self.data_conf["dimension"]
        self.num_classes = 200
        self.f = None
        # Populated by build_dataset(). Initialised here so that code checking
        # them (e.g. get_subset) does not raise AttributeError beforehand.
        self.X = None
        self.Y = None

    def __getitem__(self, index):
        """Return ``(image, label, index)`` for row ``index`` of ``X`` / ``Y``."""
        return self.X[index], self.Y[index], index

    def __len__(self):
        return len(self.X)

    def len(self):
        """Alias of ``len(self)``, kept for existing callers."""
        return len(self.X)

    def get_subset(self, idcs):
        """Return a ``CustomTensorDataset`` restricted to the rows ``idcs``.

        Args:
            idcs: integer indices (or a boolean mask) accepted by
                ``np.asarray`` and usable to index ``X`` / ``Y``.

        Returns:
            A ``CustomTensorDataset``. It is returned even when labels are
            absent; its ``X`` / ``Y`` are ``None`` wherever ours are.
        """
        idcs = np.asarray(idcs)
        if idcs.size == 0:
            # np.asarray([]) has dtype float64, which torch rejects as an index.
            idcs = idcs.astype(np.int64)
        X = None if self.X is None else self.X[idcs]
        Y = None if self.Y is None else self.Y[idcs]
        return CustomTensorDataset(X=X, Y=Y, num_classes=self.num_classes,
                                   d=self.d, transform=None)

    @staticmethod
    def _read_wnid_index(data_dir):
        """Map each WordNet id listed in ``wnids.txt`` to its 0-based line number."""
        with open(os.path.join(data_dir, 'wnids.txt'), 'r') as fh:
            return {line.rstrip(): i for i, line in enumerate(fh)}

    @staticmethod
    def _read_val_annotations(data_dir):
        """Map each val image filename to its WordNet id.

        Reads ``val/val_annotations.txt``, taking the first two tab-separated
        fields of every non-blank line.
        """
        mapping = {}
        with open(os.path.join(data_dir, 'val', 'val_annotations.txt'), 'r') as fh:
            for line in fh:
                if not line.strip():
                    continue
                fields = line.split('\t')
                mapping[fields[0]] = fields[1].strip()
        return mapping

    @staticmethod
    def _wnid_from_train_path(img_path):
        """Return the WordNet id prefix of a ``<wnid>_<n>.JPEG`` train filename."""
        return os.path.basename(img_path).split('_')[0]

    @staticmethod
    def _read_rgb(img_path):
        """Decode an image as a uint8 (3, H, W) tensor, expanding grayscale to RGB."""
        return read_image(img_path, ImageReadMode.RGB)

    def build_dataset(self):
        """Load the train and val splits from ``data_conf['data_path']``.

        The directory must contain ``wnids.txt``, ``train/*/*/*.JPEG``,
        ``val/images/*.JPEG`` and ``val/val_annotations.txt``.
        """
        data_dir = self.data_conf['data_path']

        # Per-channel Normalize(mean, std). The magnitudes suggest the 0-255
        # scale. Stored for callers; it is NOT applied to X here.
        self.transform = transforms.Compose([
            transforms.Normalize((122.4786, 114.2755, 101.3963),
                                 (70.4924, 68.5679, 71.8127)),
        ])

        id_dict = self._read_wnid_index(data_dir)

        # Sorted so sample order (and hence idcs_std_*) is reproducible;
        # glob itself returns files in arbitrary order.
        train_paths = sorted(glob.glob(os.path.join(data_dir, 'train', '*', '*', '*.JPEG')))
        train_images = [self._read_rgb(p) for p in train_paths]
        train_labels = [id_dict[self._wnid_from_train_path(p)] for p in train_paths]

        val_wnids = self._read_val_annotations(data_dir)
        val_paths = sorted(glob.glob(os.path.join(data_dir, 'val', 'images', '*.JPEG')))
        val_images = [self._read_rgb(p) for p in val_paths]
        val_labels = [id_dict[val_wnids[os.path.basename(p)]] for p in val_paths]

        n_train = len(train_images)
        n_total = n_train + len(val_images)

        # Every image is decoded as uint8 RGB, so all samples share one dtype
        # and scale. Convert once to float32 (0-255).
        self.X = torch.stack(train_images + val_images).float()
        self.Y = torch.tensor(train_labels + val_labels, dtype=torch.long)

        # Split views share storage with X / Y instead of duplicating pixels.
        self.X_train, self.X_val = self.X[:n_train], self.X[n_train:]
        self.Y_train, self.Y_val = self.Y[:n_train], self.Y[n_train:]

        # The test split labels are unknown; this is a temporary stand-in.
        # Do not report numbers computed on it as test results.
        self.X_test = self.X_val
        self.Y_test = self.Y_val

        self.idcs_std_train = np.arange(0, n_train, 1).astype(int)
        self.idcs_std_val = np.arange(n_train, n_total, 1).astype(int)

        self.ds_std_train = CustomTensorDataset(X=self.X_train, Y=self.Y_train,
                                                num_classes=self.num_classes,
                                                d=self.d, transform=None)
        self.ds_std_val = CustomTensorDataset(X=self.X_val, Y=self.Y_val,
                                              num_classes=self.num_classes,
                                              d=self.d, transform=None)

    def get_test_datasets(self):
        """Return the (val-backed) test split as a ``CustomTensorDataset``."""
        return CustomTensorDataset(X=self.X_test, Y=self.Y_test,
                                   num_classes=self.num_classes,
                                   d=self.d, transform=None)
