import torch
import torchvision.transforms as transforms
import random
import os
from PIL import Image


class train_dataset_from_folder(torch.utils.data.Dataset):
    """Image dataset backed by the files of a single folder.

    Every entry of ``root_dir`` is treated as an image file.  When ``num`` is
    given, a random subset of ``num`` file names is drawn (without
    replacement); otherwise every file in the folder is used, in sorted order.

    Args:
        root_dir: Folder containing the image files.
        num: Number of files to sample, or ``None`` to use all of them.

    Raises:
        ValueError: Propagated from ``random.sample`` when ``num`` is larger
            than the number of entries in ``root_dir`` (or negative).
    """

    def __init__(self, root_dir, num=None):
        self.root_dir = root_dir
        # Always record ``num`` (even when None) so ``reset`` can be called
        # on a dataset that was built without a subset size.
        self.num = num
        self.names = self._select_names()
        self.transforms = transforms.ToTensor()

    def _select_names(self):
        """Return the file names to use, re-reading ``root_dir`` from disk."""
        listing = sorted(os.listdir(self.root_dir))
        if self.num is None:
            return listing
        return random.sample(listing, self.num)

    def reset(self):
        """Re-read the folder and draw a fresh selection of file names.

        With ``num`` set this draws a new random subset; with ``num`` equal
        to ``None`` it simply refreshes the full sorted listing.
        """
        self.names = self._select_names()

    def __len__(self):
        """Return the number of currently selected files."""
        return len(self.names)

    def __getitem__(self, idx):
        """Load the ``idx``-th selected image and return it transformed."""
        path = os.path.join(self.root_dir, self.names[idx])
        # PIL opens files lazily; the context manager guarantees the file
        # handle is released once the pixel data has been converted.
        with Image.open(path) as img:
            return self.transforms(img)


class test_dataset_from_folder(torch.utils.data.Dataset):
    """Image dataset made of the files in a folder that a train set skipped.

    The selection is every entry of ``root_dir`` whose name does not appear
    in ``trainset.names``, kept in sorted order.

    Args:
        root_dir: Folder containing the image files.
        trainset: Object exposing a ``names`` iterable of file names to
            exclude.
    """

    def __init__(self, root_dir, trainset):
        self.root_dir = root_dir
        # Same computation as ``reset``; delegate so the two cannot diverge.
        self.reset(trainset)
        self.transforms = transforms.ToTensor()

    def reset(self, trainset):
        """Re-read the folder and recompute the files not used by ``trainset``."""
        available = set(os.listdir(self.root_dir))
        self.names = sorted(available - set(trainset.names))

    def __len__(self):
        """Return the number of currently selected files."""
        return len(self.names)

    def __getitem__(self, idx):
        """Load the ``idx``-th selected image and return it transformed."""
        path = os.path.join(self.root_dir, self.names[idx])
        # PIL opens files lazily; the context manager guarantees the file
        # handle is released once the pixel data has been converted.
        with Image.open(path) as img:
            return self.transforms(img)


def dataset_picker(root_dir, type, **kwargs):
    """Build a (train, validation) dataset pair from one image folder.

    Args:
        root_dir: Folder containing the image files.
        type: Name of the split strategy.  Only ``'standard'`` is supported:
            the train set is a random sample of the folder and the validation
            set is every remaining file.
        **kwargs: Strategy options.  For ``'standard'``: ``num`` is the number
            of files sampled into the train set.  When omitted (or ``None``)
            the train set takes every file, which leaves the validation set
            empty.

    Returns:
        A ``(trainset, valset)`` tuple.

    Raises:
        NotImplementedError: If ``type`` is not a supported strategy.
    """
    # NOTE: ``type`` shadows the builtin; the name is kept because callers
    # may pass it by keyword.
    if type != 'standard':
        raise NotImplementedError(f'unsupported dataset type: {type!r}')

    # ``num`` is optional, mirroring train_dataset_from_folder's own default.
    trainset = train_dataset_from_folder(root_dir, num=kwargs.get('num'))
    valset = test_dataset_from_folder(root_dir, trainset)
    return trainset, valset
