from cgi import test
import os
import torch
import torchvision
from torchvision import datasets, models
from torchvision.transforms import transforms

def _rgb_transforms(img_width):
    """Build the (train, test) transforms shared by the CIFAR-10/100 and Tiny tasks.

    Args:
        img_width: crop size passed to ``RandomCrop`` for training images.

    Returns:
        tuple: ``(transform_train, transform_test)``.
    """
    # NOTE(review): these per-channel mean/std values are the ones commonly
    # quoted for CIFAR-10, but they are applied to CIFAR-100 and Tiny as well —
    # confirm whether dataset-specific statistics were intended.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(img_width, padding=4),
        transforms.RandomHorizontalFlip(),
        # NOTE(review): ColorJitter() is constructed with no arguments, which
        # appears to apply no jitter — confirm whether strengths were intended.
        transforms.ColorJitter(),
        transforms.ToTensor(),
        normalize,
    ])
    # Evaluation images get no augmentation: tensor conversion + normalization only.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    return transform_train, transform_test


def _rgb_datasets(task, root, transform_train, transform_test):
    """Instantiate the (train, test) datasets for 'cifar10', 'cifar100' or 'tiny'.

    CIFAR-10/100 are loaded through torchvision with ``download=True``; 'tiny'
    reads ``<root>/train`` and ``<root>/val`` as ``ImageFolder`` directories.
    """
    if task == 'tiny':
        trainset = datasets.ImageFolder(os.path.join(root, "train"), transform=transform_train)
        testset = datasets.ImageFolder(os.path.join(root, "val"), transform=transform_test)
        return trainset, testset

    dataset_cls = torchvision.datasets.CIFAR10 if task == 'cifar10' else torchvision.datasets.CIFAR100
    trainset = dataset_cls(root=root, train=True, download=True, transform=transform_train)
    testset = dataset_cls(root=root, train=False, download=True, transform=transform_test)
    return trainset, testset


def DataLoader(args, num_workers=32):
    """Build train and test data loaders for the dataset selected by ``args.task``.

    Args:
        args: namespace-like config. Reads ``task`` (one of ``'mnist'``,
            ``'cifar10'``, ``'cifar100'``, ``'tiny'``), ``dataset`` (data root
            directory) and ``batch_size``; the non-MNIST tasks also read
            ``img_width`` (training random-crop size).
        num_workers: worker count passed to ``torch.utils.data.DataLoader``
            for the CIFAR-10/100 and Tiny tasks. Defaults to 32, the value
            that was previously hard-coded. MNIST does not pass it.

    Returns:
        tuple: ``(trainloader, testloader)``, both ``torch.utils.data.DataLoader``.

    Raises:
        ValueError: if ``args.task`` is not one of the supported task names.
    """
    task = args.task
    if task == 'mnist':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        trainset = datasets.MNIST(args.dataset, download=True, train=True, transform=transform)
        testset = datasets.MNIST(args.dataset, download=True, train=False, transform=transform)
        # As before, MNIST loaders are built without an explicit num_workers.
        loader_kwargs = {}
    elif task in ('cifar10', 'cifar100', 'tiny'):
        transform_train, transform_test = _rgb_transforms(args.img_width)
        trainset, testset = _rgb_datasets(task, args.dataset, transform_train, transform_test)
        loader_kwargs = {'num_workers': num_workers}
    else:
        # Raising a plain string is itself a TypeError in Python 3, which hid
        # the real problem; raise a proper exception naming the bad task.
        raise ValueError(f"Unknown task: {task!r}")

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=args.batch_size, shuffle=True, **loader_kwargs)
    # The test loader must wrap ``testset``; the MNIST branch previously
    # wrapped ``trainset`` by mistake and evaluated on training data.
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=args.batch_size, shuffle=True, **loader_kwargs)
    return trainloader, testloader