from torch.utils.data import DataLoader

from datasets import FunctionDataSet, CELEBA

__all__ = ['get_dataset']

def get_dataset(args):
    """Build the ``(trainloader, validloader)`` pair selected by ``args.dataset``.

    Supported values of ``args.dataset``:

    * ``'function'`` -- ``FunctionDataSet`` for both splits. The training
      loader uses ``args.batch_size`` and 8 workers. The validation loader
      uses a batch size of 1 and 2 workers. Neither loader passes
      ``shuffle``/``drop_last``, so DataLoader defaults apply.
    * ``'celeba'`` -- ``CELEBA`` with ``train=True`` / ``train=False`` under
      ``args.root`` at ``args.size``. The training loader shuffles and drops
      the last incomplete batch (8 workers). The validation loader does not
      shuffle (4 workers). Both use ``args.batch_size``.

    In both cases only the validation dataset receives ``args.visualize``.

    Args:
        args: Namespace-like config object. ``dataset`` is always read.
            The 'function' branch also reads ``total_iter_train``,
            ``total_iter_valid``, ``num_total_points``, ``visualize`` and
            ``batch_size``. The 'celeba' branch reads ``root``, ``size``,
            ``visualize`` and ``batch_size``.

    Returns:
        tuple: ``(trainloader, validloader)``.

    Raises:
        NotImplementedError: if ``args.dataset`` is not a supported name.
    """
    name = args.dataset

    if name == 'function':
        train_data = FunctionDataSet(total_iter=args.total_iter_train,
                                     num_total_points=args.num_total_points)
        valid_data = FunctionDataSet(total_iter=args.total_iter_valid,
                                     num_total_points=args.num_total_points,
                                     visualize=args.visualize)
        # Tuple elements are evaluated left to right, so the train loader is
        # still constructed before the validation loader.
        return (
            DataLoader(train_data, batch_size=args.batch_size, num_workers=8),
            DataLoader(valid_data, batch_size=1, num_workers=2),
        )

    if name == 'celeba':
        train_data = CELEBA(root=args.root, size=args.size, train=True)
        valid_data = CELEBA(root=args.root, size=args.size, train=False,
                            visualize=args.visualize)
        return (
            DataLoader(train_data, batch_size=args.batch_size, num_workers=8,
                       shuffle=True, drop_last=True),
            DataLoader(valid_data, batch_size=args.batch_size, num_workers=4,
                       shuffle=False),
        )

    raise NotImplementedError('{} not implemented in datasets/'.format(name))
