import torch
import torch.nn as nn
from torchvision import datasets, transforms

class Normalize(nn.Module):
    """Standardize every channel of a batch as ``(x - mean) / std``.

    The statistics are stored in ``self.ms`` as ``[means, stds]``, one entry
    per channel.
    NOTE(review): the values look like the commonly used CIFAR-10 RGB
    statistics -- confirm against the training recipe.
    """

    def __init__(self,):
        super(Normalize, self).__init__()
        # [per-channel means, per-channel stds]
        self.ms = [(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)]

    def forward(self, input):
        """Return a normalized copy of ``input``; ``input`` is left untouched.

        Args:
            input: Tensor whose dimension 1 indexes channels. It may carry
                fewer channels than there are statistics, in which case only
                the leading statistics are used.

        Returns:
            A new tensor with the same shape as ``input``. Non-floating
            inputs are promoted to the default floating dtype so the result
            is not truncated back to integers.

        Raises:
            IndexError: if ``input`` has more channels than statistics (or
                has no dimension 1 at all).
        """
        means, stds = self.ms
        channels = input.shape[1]
        available = min(len(means), len(stds))
        if channels > available:
            raise IndexError(
                "input has %d channels but only %d normalization statistics "
                "are defined" % (channels, available))

        if input.is_floating_point():
            dtype = input.dtype
        else:
            dtype = torch.get_default_dtype()

        # Shape (1, C, 1, ..., 1) so the statistics broadcast over the batch
        # and every trailing (spatial) dimension in a single operation.
        stat_shape = (1, channels) + (1,) * (input.dim() - 2)
        mean = torch.tensor(
            means[:channels], dtype=dtype, device=input.device).view(stat_shape)
        std = torch.tensor(
            stds[:channels], dtype=dtype, device=input.device).view(stat_shape)

        # Out-of-place arithmetic: no clone and no per-channel slice writes.
        return (input - mean) / std

class Normalize1(nn.Module):
    """Standardize every channel of a batch as ``(x - mean) / std``.

    The statistics are stored in ``self.ms`` as ``[means, stds]``, one entry
    per channel.
    NOTE(review): the means look like the usual CIFAR-10 RGB means and the
    stds like an alternative CIFAR-10 std estimate -- confirm which checkpoint
    or recipe expects this variant.
    """

    def __init__(self,):
        super(Normalize1, self).__init__()
        # [per-channel means, per-channel stds]
        self.ms = [(0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)]

    def forward(self, input):
        """Return a normalized copy of ``input``; ``input`` is left untouched.

        Args:
            input: Tensor whose dimension 1 indexes channels. It may carry
                fewer channels than there are statistics, in which case only
                the leading statistics are used.

        Returns:
            A new tensor with the same shape as ``input``. Non-floating
            inputs are promoted to the default floating dtype so the result
            is not truncated back to integers.

        Raises:
            IndexError: if ``input`` has more channels than statistics (or
                has no dimension 1 at all).
        """
        means, stds = self.ms
        channels = input.shape[1]
        available = min(len(means), len(stds))
        if channels > available:
            raise IndexError(
                "input has %d channels but only %d normalization statistics "
                "are defined" % (channels, available))

        if input.is_floating_point():
            dtype = input.dtype
        else:
            dtype = torch.get_default_dtype()

        # Shape (1, C, 1, ..., 1) so the statistics broadcast over the batch
        # and every trailing (spatial) dimension in a single operation.
        stat_shape = (1, channels) + (1,) * (input.dim() - 2)
        mean = torch.tensor(
            means[:channels], dtype=dtype, device=input.device).view(stat_shape)
        std = torch.tensor(
            stds[:channels], dtype=dtype, device=input.device).view(stat_shape)

        # Out-of-place arithmetic: no clone and no per-channel slice writes.
        return (input - mean) / std

def get_test_loaders(dir_, batch_size, num_workers=0):
    """Build a DataLoader over the CIFAR-10 test split.

    Args:
        dir_: Dataset root directory passed to ``datasets.CIFAR10``; the
            dataset is requested with ``download=True``.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes. Defaults to 0,
            which keeps loading in the calling process.

    Returns:
        A single ``torch.utils.data.DataLoader`` (despite the plural name)
        that iterates the test split without shuffling. Images only go
        through ``ToTensor``; no normalization is applied here.
    """
    test_dataset = datasets.CIFAR10(
        dir_,
        train=False,
        transform=transforms.Compose([transforms.ToTensor()]),
        download=True,
    )
    return torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=False,  # keep evaluation order fixed
        pin_memory=True,
        num_workers=num_workers,
    )

def get_loaders(dir_, batch_size, num_workers=0):
    """Build the CIFAR-10 training and test DataLoaders.

    Args:
        dir_: Dataset root directory passed to ``datasets.CIFAR10``; both
            splits are requested with ``download=True``.
        batch_size: Number of samples per batch, used for both loaders.
        num_workers: Number of DataLoader worker processes, applied to both
            loaders. Defaults to 0, which keeps loading in the calling
            process.

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles and
        augments with a padded random 32x32 crop and a random horizontal
        flip; the test loader is unshuffled and unaugmented. Images only go
        through ``ToTensor``; no normalization is applied here.
    """
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_dataset = datasets.CIFAR10(
        dir_, train=True, transform=train_transform, download=True)
    test_dataset = datasets.CIFAR10(
        dir_, train=False, transform=test_transform, download=True)

    # Settings shared by both loaders; only shuffling differs per split.
    shared_kwargs = dict(
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
    )
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, shuffle=True, **shared_kwargs)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, shuffle=False, **shared_kwargs)
    return train_loader, test_loader