import torchvision
from torch.utils import data
from torchvision import datasets
from torchvision import transforms
import torch
import os
import numpy as np
import torch.nn.functional as F


def load_data(n_batch_train, n_batch_test, dataset, path_data):
    """Build the training and evaluation data loaders for CIFAR-10/100.

    Args:
        n_batch_train: Batch size used by the training loader.
        n_batch_test: Batch size used by the evaluation loader.
        dataset: ``"cifar10"`` selects CIFAR-10; any other value falls
            through to CIFAR-100.
        path_data: Directory handed to torchvision as the dataset ``root``
            (datasets are requested with ``download=True``).

    Returns:
        A ``(trainloader, val_loader)`` tuple. The training loader reads the
        ``train=True`` split with random crop / horizontal flip augmentation
        and shuffling; the evaluation loader reads the ``train=False`` split
        without augmentation or shuffling.
    """
    # NOTE(review): the same normalization constants are applied to both
    # datasets -- confirm this is intended for CIFAR-100.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    # Training pipeline: augmentation first, then tensor conversion and
    # normalization.
    augment_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    transform_train = transforms.Compose(augment_steps)

    # Evaluation pipeline: tensor conversion and normalization only.
    eval_steps = [
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    transform_test = transforms.Compose(eval_steps)

    # Everything that is not exactly "cifar10" is treated as CIFAR-100.
    if dataset == "cifar10":
        dataset_cls = torchvision.datasets.CIFAR10
    else:
        dataset_cls = torchvision.datasets.CIFAR100

    trainset = dataset_cls(
        root=path_data, train=True, download=True, transform=transform_train
    )
    testset = dataset_cls(
        root=path_data, train=False, download=True, transform=transform_test
    )

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=n_batch_train, shuffle=True, num_workers=1
    )
    val_loader = torch.utils.data.DataLoader(
        testset, batch_size=n_batch_test, shuffle=False, num_workers=1
    )
    return trainloader, val_loader
