import torch
from torchvision import datasets
from torchvision import transforms
import mnistm
def _gray_transform(image_size):
    """Return a Compose that resizes, tensorises and normalises a 1-channel image.

    ``Normalize([0.5], [0.5])`` maps ``ToTensor``'s [0, 1] output to [-1, 1].
    """
    return transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])])


def _rgb_transform(image_size):
    """Return a Compose that resizes, tensorises and normalises a 3-channel image.

    Per-channel ``Normalize(0.5, 0.5)`` maps ``ToTensor``'s [0, 1] output to [-1, 1].
    """
    return transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


def _make_loader(dataset, config, shuffle):
    """Wrap *dataset* in a DataLoader using ``config.batch_size`` / ``config.num_workers``."""
    return torch.utils.data.DataLoader(dataset=dataset,
                                       batch_size=config.batch_size,
                                       shuffle=shuffle,
                                       num_workers=config.num_workers)


def get_loader(config):
    """Build and return train/test DataLoaders for SVHN, MNIST, USPS and MNIST-M.

    Every dataset is created with ``download=True``, so it is fetched on first use.

    Args:
        config: Object exposing ``image_size``, ``batch_size``, ``num_workers``,
            ``svhn_path`` and ``mnist_path`` (USPS is also stored under
            ``mnist_path``). An optional ``mnistm_path`` attribute sets the
            MNIST-M root; it defaults to ``"./mnistm"``.

    Returns:
        An 8-tuple ``(svhn_loader, mnist_loader, svhn_test_loader,
        mnist_test_loader, usps_loader, usps_test_loader, mnistm_loader,
        mnistm_test_loader)``. Train loaders shuffle; test loaders do not.
    """
    # SVHN and MNIST-M are 3-channel; MNIST and USPS are single-channel.
    # NOTE: to feed MNIST to a 3-channel model instead, insert
    # transforms.Lambda(lambda x: x.repeat(3, 1, 1)) after ToTensor() and
    # normalise with 3-channel statistics.
    gray = _gray_transform(config.image_size)
    rgb = _rgb_transform(config.image_size)
    mnistm_root = getattr(config, "mnistm_path", "./mnistm")

    svhn = datasets.SVHN(root=config.svhn_path, download=True, transform=rgb, split='train')
    mnist = datasets.MNIST(root=config.mnist_path, download=True, transform=gray, train=True)
    usps = datasets.USPS(root=config.mnist_path, download=True, transform=gray, train=True)
    # MNISTM comes from the local ``mnistm`` module; assumed to take
    # (root, train=..., download=..., transform=...) like torchvision datasets.
    mnistm_train = mnistm.MNISTM(mnistm_root, train=True, download=True, transform=rgb)

    svhn_test = datasets.SVHN(root=config.svhn_path, download=True, transform=rgb, split='test')
    mnist_test = datasets.MNIST(root=config.mnist_path, download=True, transform=gray, train=False)
    usps_test = datasets.USPS(root=config.mnist_path, download=True, transform=gray, train=False)
    mnistm_test = mnistm.MNISTM(mnistm_root, train=False, download=True, transform=rgb)

    svhn_loader = _make_loader(svhn, config, shuffle=True)
    mnist_loader = _make_loader(mnist, config, shuffle=True)
    usps_loader = _make_loader(usps, config, shuffle=True)
    mnistm_loader = _make_loader(mnistm_train, config, shuffle=True)

    # Test loaders keep a fixed order so evaluation is reproducible.
    svhn_test_loader = _make_loader(svhn_test, config, shuffle=False)
    mnist_test_loader = _make_loader(mnist_test, config, shuffle=False)
    usps_test_loader = _make_loader(usps_test, config, shuffle=False)
    mnistm_test_loader = _make_loader(mnistm_test, config, shuffle=False)

    return (svhn_loader, mnist_loader, svhn_test_loader, mnist_test_loader,
            usps_loader, usps_test_loader, mnistm_loader, mnistm_test_loader)
