import glob
import numpy as np
import os
import torch

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

# Example `data_path` (the "ut-zap50k-images-square" directory):
#   '/scratch/voletivi/Datasets/ut-zap50k-images-square'


class Zap50kDataset(Dataset):
    """UT-Zap50K image dataset with a deterministic train/val split.

    Files are collected exactly four path levels below ``data_path``
    (``<data_path>/<class>/<*>/<*>/<file>``). The first level below
    ``data_path`` is taken as the class name, and the target is that name's
    index in ``self.classes``.

    A fixed-seed random subset of ``int(split_fraction * n_files)`` files is
    the validation split; every remaining file belongs to the training split.
    Both splits index into the same sorted file list, so the partition is the
    same for every instance built over the same directory contents.

    Args:
        data_path: Path to the "ut-zap50k-images-square" directory.
        split: Which partition this instance serves, ``"train"`` or ``"val"``.
        split_fraction: Fraction of all files assigned to validation.
        transform: Callable applied to the RGB ``PIL.Image`` of each sample.

    Raises:
        AssertionError: If ``data_path`` does not exist or ``split`` is not
            ``"train"`` or ``"val"``.
    """

    def __init__(self, data_path, split, split_fraction=0.1, transform=transforms.Compose([transforms.ToTensor()])):

        self.data_path = data_path  # "ut-zap50k-images-square" directory
        self.split = split
        self.split_fraction = split_fraction
        self.transform = transform
        self.classes = ['Boots', 'Sandals', 'Shoes', 'Slippers']

        assert os.path.exists(self.data_path), f"ut-zap50k file path does not exist! Given: {self.data_path}"
        assert self.split in ["train", "val"], f"split must be 'train' or 'val'!"

        # Sorted so that the index-based split below does not depend on the
        # order in which the filesystem happens to list entries.
        self.files = sorted(glob.glob(os.path.join(data_path, '*', '*', '*', '*')))

        self.N_VAL = int(split_fraction * len(self.files))
        self.N_TRAIN = len(self.files) - self.N_VAL

        # Use a private fixed-seed generator instead of np.random.seed(0):
        # it draws the same legacy stream (hence the same split) without
        # resetting the global NumPy RNG that other code may rely on.
        rng = np.random.RandomState(0)
        # np.sort keeps an integer dtype even when the validation set is empty.
        self.val_idx = np.sort(rng.choice(len(self.files), self.N_VAL, replace=False))
        self.train_idx = np.setdiff1d(np.arange(len(self.files)), self.val_idx)

    def __len__(self):
        """Return the number of samples in the selected split."""
        if self.split == 'train':
            return self.N_TRAIN
        else:
            return self.N_VAL

    def __getitem__(self, idx):
        """Return ``(image, target)`` for position ``idx`` within the split.

        ``image`` is the output of ``self.transform`` on the RGB image and
        ``target`` is a ``torch.tensor`` holding the class index.

        Raises:
            IndexError: If ``idx`` is out of range for the selected split.
            ValueError: If the file's class directory is not in ``self.classes``.
        """
        # Map the split-local position to an index into the full file list.
        split_idx = self.train_idx if self.split == 'train' else self.val_idx
        file = self.files[split_idx[idx]]

        # Close the underlying file handle as soon as the pixels are decoded.
        with Image.open(file) as img:
            rgb = img.convert('RGB')
        image = self.transform(rgb)

        # The class directory is the 4th component from the end; normalise
        # first so this also works with '\\' or mixed path separators.
        class_name = os.path.normpath(file).split(os.sep)[-4]
        target = torch.tensor(self.classes.index(class_name))
        return image, target
