import torch
import torchvision
from torchvision import transforms
from torch.utils.data import Subset

import numpy as np

class MNIST(object):
    """MNIST wrapper exposing a train and a test ``DataLoader``.

    Images are converted to tensors and normalized with mean 0.1307 and
    std 0.3081. Data is read from ``root`` and is never downloaded.

    Args:
        batch_size: Batch size used by both loaders.
        use_gpu: When truthy, both loaders are created with ``pin_memory=True``.
        num_workers: Number of loader worker processes.
        is_shuffle: Whether the training loader shuffles; the test loader
            never shuffles.
        n_samples: If not None, the training loader draws from a random subset
            of this many training examples, sampled without replacement using
            the global NumPy RNG. The test split is always used in full.
        root: Directory holding the MNIST files. Defaults to the original
            hard-coded location.

    Attributes:
        trainloader: ``DataLoader`` over the (possibly subsampled) training split.
        testloader: ``DataLoader`` over the full test split.
        num_classes: Number of label classes (10).
    """

    def __init__(self, batch_size, use_gpu, num_workers, is_shuffle, n_samples=None,
                 root='/home/zhangxj/workspace/datasets/mnist'):
        transform = transforms.Compose([
            transforms.ToTensor(),
            # Commonly used MNIST pixel mean / std.
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        pin_memory = bool(use_gpu)

        full_trainset = torchvision.datasets.MNIST(root=root, train=True, download=False, transform=transform)
        # Honor n_samples the same way CIFAR10 does instead of silently ignoring it.
        trainset = self._subsample(full_trainset, n_samples)

        trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=batch_size, shuffle=is_shuffle,
            num_workers=num_workers, pin_memory=pin_memory,
        )

        testset = torchvision.datasets.MNIST(root=root, train=False, download=False, transform=transform)

        testloader = torch.utils.data.DataLoader(
            testset, batch_size=batch_size, shuffle=False,
            num_workers=num_workers, pin_memory=pin_memory,
        )

        self.trainloader = trainloader
        self.testloader = testloader
        self.num_classes = 10

    @staticmethod
    def _subsample(dataset, n_samples):
        """Return ``dataset`` itself if ``n_samples`` is None, else a random ``Subset`` of that size.

        Indices are drawn without replacement via ``np.random.choice``. NumPy
        raises ``ValueError`` if ``n_samples`` exceeds ``len(dataset)``.
        """
        if n_samples is None:
            return dataset
        indices = np.random.choice(len(dataset), n_samples, replace=False)
        return Subset(dataset, indices)

class CIFAR10(object):
    """CIFAR-10 wrapper exposing a train and a test ``DataLoader``.

    The training split uses random horizontal flips and a 32x32 random crop
    with 4-pixel padding. Both splits are converted to tensors and normalized.
    Data is read from ``root`` and is never downloaded.

    Args:
        batch_size: Batch size used by both loaders.
        use_gpu: When truthy, both loaders are created with ``pin_memory=True``.
        num_workers: Number of loader worker processes.
        is_shuffle: Whether the training loader shuffles; the test loader
            never shuffles.
        n_samples: If not None, the training loader draws from a random subset
            of this many training examples, sampled without replacement using
            the global NumPy RNG. The test split is always used in full.
        root: Directory holding the CIFAR-10 files. Defaults to the original
            hard-coded location.

    Attributes:
        trainloader: ``DataLoader`` over the (possibly subsampled) training split.
        testloader: ``DataLoader`` over the full test split.
        num_classes: Number of label classes (10).
    """

    def __init__(self, batch_size, use_gpu, num_workers, is_shuffle, n_samples=None,
                 root="/home/zhangxj/workspace/datasets/cifar10"):
        # NOTE(review): these are the usual ImageNet channel statistics, not
        # CIFAR-10's own. Kept unchanged so preprocessing matches existing runs.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        pin_memory = bool(use_gpu)

        full_trainset = torchvision.datasets.CIFAR10(root=root, train=True, transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),  # second positional arg is padding
            transforms.ToTensor(),
            normalize,
        ]), download=False)

        full_testset = torchvision.datasets.CIFAR10(root=root, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]), download=False)

        trainset = self._subsample(full_trainset, n_samples)

        trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=batch_size, shuffle=is_shuffle,
            num_workers=num_workers, pin_memory=pin_memory)

        testloader = torch.utils.data.DataLoader(
            full_testset, batch_size=batch_size, shuffle=False,
            num_workers=num_workers, pin_memory=pin_memory)

        self.trainloader = trainloader
        self.testloader = testloader
        self.num_classes = 10

    @staticmethod
    def _subsample(dataset, n_samples):
        """Return ``dataset`` itself if ``n_samples`` is None, else a random ``Subset`` of that size.

        Indices are drawn without replacement via ``np.random.choice``. NumPy
        raises ``ValueError`` if ``n_samples`` exceeds ``len(dataset)``.
        """
        if n_samples is None:
            return dataset
        indices = np.random.choice(len(dataset), n_samples, replace=False)
        return Subset(dataset, indices)


# Registry of dataset names accepted by ``create`` -> dataset wrapper class.
__factory = dict(
    mnist=MNIST,
    cifar10=CIFAR10,
)

def create(name, batch_size, use_gpu, num_workers, is_shuffle, n_samples=None):
    """Build the dataset wrapper registered under ``name``.

    All remaining arguments are forwarded unchanged to the wrapper's
    constructor, with ``n_samples`` passed by keyword.

    Raises:
        KeyError: If ``name`` is not a registered dataset.
    """
    dataset_cls = __factory.get(name)
    if dataset_cls is None:
        raise KeyError("Unknown dataset: {}".format(name))
    return dataset_cls(batch_size, use_gpu, num_workers, is_shuffle, n_samples=n_samples)
