import torch
import numpy as np
import scipy
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, MNIST, FashionMNIST, FGVCAircraft, DTD, OxfordIIITPet


# Number of target classes for each supported dataset name.
_NUM_CLASSES = {
    "DTD": 47,
    "AIRCRAFT": 100,  # FGVCAircraft with torchvision's default "variant" labels
    "PET": 37,
    "CIFAR10": 10,
    "CIFAR100": 100,
}

# Datasets whose training pipeline adds a RandomHorizontalFlip when
# ``do_transform`` is True. The flip is not applied for DTD, AIRCRAFT and PET,
# so ``do_transform`` currently has no effect for those datasets.
_FLIP_AUGMENTED = frozenset({"CIFAR10", "CIFAR100"})


def _make_transform(image_size, flip=False):
    """Build the preprocessing pipeline shared by every dataset.

    Images are resized to ``image_size x image_size`` (some datasets ship
    large, non-square images). If ``flip`` is set, a random horizontal flip is
    applied. Images are then converted to tensors, and each of the three
    channels is normalized with mean 0.5 and std 0.5. That maps ToTensor's
    [0, 1] range onto [-1, 1].
    """
    steps = [transforms.Resize([image_size, image_size])]
    if flip:
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(steps)


def _load_splits(name, data_dir, transform_train, transform_test):
    """Instantiate the (train, test) torchvision datasets for ``name``.

    Every dataset is created with ``download=True`` under ``data_dir``.

    Raises:
        ValueError: If ``name`` is not a supported dataset.
    """
    ds = torchvision.datasets
    if name == "DTD":
        train = ds.DTD(
            root=data_dir, split="train", download=True, transform=transform_train)
        val = ds.DTD(
            root=data_dir, split="val", download=True, transform=transform_train)
        # DTD ships a separate validation split. Only train/test loaders are
        # returned, so the validation images are folded into training.
        train = torch.utils.data.ConcatDataset([train, val])
        test = ds.DTD(
            root=data_dir, split="test", download=True, transform=transform_test)
    elif name == "AIRCRAFT":
        train = ds.FGVCAircraft(
            root=data_dir, split="trainval", download=True, transform=transform_train)
        test = ds.FGVCAircraft(
            root=data_dir, split="test", download=True, transform=transform_test)
    elif name == "PET":
        train = ds.OxfordIIITPet(
            root=data_dir, split="trainval", download=True, transform=transform_train)
        test = ds.OxfordIIITPet(
            root=data_dir, split="test", download=True, transform=transform_test)
    elif name in ("CIFAR10", "CIFAR100"):
        cifar = ds.CIFAR10 if name == "CIFAR10" else ds.CIFAR100
        train = cifar(
            root=data_dir, train=True, download=True, transform=transform_train)
        test = cifar(
            root=data_dir, train=False, download=True, transform=transform_test)
    else:
        raise ValueError("That dataset is not yet implemented!")
    return train, test


def get_data(name, data_dir, image_size, batch_size, do_transform=True, num_workers=1):
    """Build train/test DataLoaders for one of the supported image datasets.

    Args:
        name: Dataset identifier. One of "DTD", "AIRCRAFT", "PET", "CIFAR10"
            or "CIFAR100" (case-sensitive).
        data_dir: Directory where the dataset is stored or will be downloaded.
        image_size: Side length that every image is resized to (square output).
        batch_size: Batch size for both the training and the test loader.
        do_transform: Whether to enable training-time augmentation (a random
            horizontal flip). It only has an effect for CIFAR10 and CIFAR100.
        num_workers: Number of worker processes for each DataLoader. Defaults
            to 1, the previously hard-coded value.

    Returns:
        A tuple ``(trainloader, testloader, num_classes)``. The training
        loader shuffles; the test loader does not.

    Raises:
        ValueError: If ``name`` is not a supported dataset.
    """
    # Validate first, so an unknown name fails before any download happens.
    if name not in _NUM_CLASSES:
        raise ValueError("That dataset is not yet implemented!")

    flip = bool(do_transform) and name in _FLIP_AUGMENTED
    transform_train = _make_transform(image_size, flip=flip)
    transform_test = _make_transform(image_size)

    trainset, testset = _load_splits(name, data_dir, transform_train, transform_test)

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return trainloader, testloader, _NUM_CLASSES[name]