import os
import torch
import torchvision.transforms as transforms
import torchvision.datasets as dset
import torch.utils.data as data
import random


class FolderSubset(data.Dataset):
    """View of ``dataset`` restricted to ``indices``, with labels re-numbered.

    On construction every selected entry of ``dataset.samples`` is rewritten
    in place so that its label becomes the position of the original label in
    ``classes``.  The wrapped dataset object is therefore mutated.

    Args:
        dataset: object exposing a mutable ``samples`` list of
            ``(path, label)`` pairs and integer ``__getitem__`` (e.g. an
            ``ImageFolder``).
        classes: sequence of original labels to keep; its order defines the
            new label ids.
        indices: positions in ``dataset`` that make up this subset.
    """

    def __init__(self, dataset, classes, indices):
        self.dataset, self.classes, self.indices = dataset, classes, indices
        self.update_classes()

    def update_classes(self):
        """Relabel the selected samples in place via ``classes.index``."""
        samples = self.dataset.samples
        for pos in self.indices:
            path, old_label = samples[pos]
            samples[pos] = (path, self.classes.index(old_label))

    def __getitem__(self, idx):
        sample_pos = self.indices[idx]
        return self.dataset[sample_pos]

    def __len__(self):
        return len(self.indices)

class FolderSubsetPercent(data.Dataset):
    """Random per-class fraction of a split, with labels re-numbered.

    For each label in ``classes``, ``percent`` percent (rounded down) of that
    label's samples *among ``indices``* are drawn with ``random.sample``.  The
    global ``random`` RNG is used, so seed it for reproducibility.  The chosen
    entries of ``dataset.samples`` are then relabelled in place to the
    position of their original label within ``classes``.

    Args:
        dataset: object exposing ``targets`` (original label per sample), a
            mutable ``samples`` list of ``(path, label)`` pairs and integer
            ``__getitem__`` (e.g. an ``ImageFolder``).
        classes: original labels to keep; their order defines the new ids.
        indices: positions in ``dataset`` forming the split to sample from;
            ``None`` means the whole dataset.
        percent: percentage (0-100, int or float) of each class to keep.
    """

    def __init__(self, dataset, classes, indices, percent):
        self.dataset = dataset
        self.classes = classes
        self.indices = indices
        self.percent = percent
        self.update_classes_percent()

    def update_classes_percent(self):
        """Replace ``self.indices`` with the sampled subset and relabel it."""
        targets = self.dataset.targets
        # Sample only from the given split.  Drawing from all of ``targets``
        # would let positions belonging to another split (e.g. validation
        # images stored in the same folder) leak into this subset.
        pool = range(len(targets)) if self.indices is None else self.indices

        # Group the split's positions by original label in a single pass
        # rather than rescanning every target once per class.
        by_class = {cls: [] for cls in self.classes}
        for i in pool:
            bucket = by_class.get(targets[i])
            if bucket is not None:
                bucket.append(i)

        inds = []
        for cls in self.classes:
            ids = by_class[cls]
            # int() so a float percent still gives random.sample a valid size.
            k = int(len(ids) * self.percent // 100)
            inds.extend(random.sample(ids, k))
        self.indices = inds

        samples = self.dataset.samples
        for i in self.indices:
            img_path, cls = samples[i]
            samples[i] = (img_path, self.classes.index(cls))

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

class AddGaussianNoise(object):
    """Transform that adds Gaussian noise to a tensor.

    Returns ``tensor + randn * std + mean`` elementwise as a new tensor; the
    input is not modified in place.

    Args:
        mean: constant offset added to every element.
        std: scale applied to the standard-normal noise.
    """

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        if tensor.is_floating_point():
            # Match the input's dtype and device: works for GPU tensors and
            # avoids upcasting half/bfloat16 inputs to the default dtype.
            noise = torch.randn_like(tensor)
        else:
            # randn cannot produce integer tensors; draw default-float noise on
            # the same device and let the addition promote the result to float.
            noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return '{0}(mean={1}, std={2})'.format(self.__class__.__name__, self.mean, self.std)

def check_split(opt):
    """Load the pre-computed train/val/test split files named by ``opt.datasplit``.

    Each split is read with ``torch.load`` from
    ``data/split/<opt.datasplit>-<train|val|test>``, relative to the current
    working directory.  ``torch.load`` unpickles, so only point this at
    trusted split files.

    Returns:
        list of the three loaded objects, ordered train, val, test.
    """
    prefix = 'data/split/' + opt.datasplit + '-'
    return [torch.load(prefix + name) for name in ('train', 'val', 'test')]


def check_dataset(opt):
    """Build the train / val / test DataLoaders described by ``opt``.

    Reads ``opt.dataset``, ``opt.dataroot``, ``opt.datanoise``,
    ``opt.numTrain`` and ``opt.batchSize`` (plus ``opt.datasplit`` via
    ``check_split``), and sets ``opt.num_classes`` as a side effect.

    The train and val subsets are both drawn from ``<dataroot>/train`` (with
    training and evaluation transforms respectively); the test subset comes
    from ``<dataroot>/test``.  Each is restricted to the classes/indices of
    the matching split file.  When ``opt.numTrain != 100`` only ``numTrain``
    percent of each training class is kept.

    Returns:
        list of three ``DataLoader`` objects: train, val, test.

    Raises:
        ValueError: if ``opt.dataset`` is not one of the supported names.
    """
    supported = ('cub200', 'indoor', 'mit67', 'stanford40', 'dog',
                 'catvdog', 'pet', 'celeba', 'celeba-src')
    # Validate before reading any split file so a bad name fails clearly.
    if opt.dataset not in supported:
        raise ValueError('Unknown dataset: {}'.format(opt.dataset))

    normalize_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    train_large_transform = transforms.Compose([transforms.RandomResizedCrop(224),
                                                transforms.RandomHorizontalFlip()])
    val_large_transform = transforms.Compose([transforms.Resize(256),
                                              transforms.CenterCrop(224)])

    train_steps = [train_large_transform, normalize_transform]
    val_steps = [val_large_transform, normalize_transform]
    if opt.datanoise:
        # Noise is applied after normalisation, to both pipelines.
        noise_transform = transforms.Compose([AddGaussianNoise(0., 1.)])
        train_steps.append(noise_transform)
        val_steps.append(noise_transform)
    train_transform = transforms.Compose(train_steps)
    val_transform = transforms.Compose(val_steps)

    splits = check_split(opt)

    # Train and val read the same folder but need separate ImageFolder
    # instances: the subset wrappers rewrite ``samples`` labels in place.
    train_dir = os.path.join(opt.dataroot, 'train')
    test_dir = os.path.join(opt.dataroot, 'test')
    sets = [dset.ImageFolder(root=train_dir, transform=train_transform),
            dset.ImageFolder(root=train_dir, transform=val_transform),
            dset.ImageFolder(root=test_dir, transform=val_transform)]

    if opt.numTrain == 100:
        sets = [FolderSubset(dataset, *split) for dataset, split in zip(sets, splits)]
    else:
        sets = [FolderSubsetPercent(sets[0], *splits[0], opt.numTrain),
                FolderSubset(sets[1], *splits[1]),
                FolderSubset(sets[2], *splits[2])]

    opt.num_classes = len(splits[0][0])

    # NOTE(review): val/test loaders are shuffled as well (unchanged
    # behaviour); consider shuffle=False for a deterministic eval order.
    return [torch.utils.data.DataLoader(dataset,
                                        batch_size=opt.batchSize,
                                        shuffle=True,
                                        num_workers=0)
            for dataset in sets]
