import torch
import numpy as np
import scipy
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import os
import PIL.Image
    

# Per-channel (mean, std) normalization statistics used by each dataset family.
_IMAGENET_STATS = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
_CIFAR10_STATS = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
_CIFAR100_STATS = ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))

# name -> (num_classes, normalization stats, default DataLoader num_workers)
_DATASET_SPECS = {
    "dtd": (47, _IMAGENET_STATS, 8),
    "aircraft": (100, _IMAGENET_STATS, 8),
    "pet": (37, _IMAGENET_STATS, 8),
    "cifar10": (10, _CIFAR10_STATS, 8),
    "cifar100": (100, _CIFAR100_STATS, 1),
}


def _build_transforms(stats, do_transform, size=224):
    """Return ``(transform_train, transform_test)`` for the given normalization stats.

    Both pipelines resize to ``size x size`` (source images are typically too
    large, and CIFAR images are upscaled to match), convert to a tensor and
    normalize with ``stats = (mean, std)``. The training pipeline additionally
    applies a random horizontal flip when ``do_transform`` is true.
    """
    mean, std = stats
    train_steps = [transforms.Resize([size, size])]
    if do_transform:
        train_steps.append(transforms.RandomHorizontalFlip())
    train_steps += [transforms.ToTensor(), transforms.Normalize(mean, std)]
    test_steps = [
        transforms.Resize([size, size]),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    return transforms.Compose(train_steps), transforms.Compose(test_steps)


def _build_datasets(name, data_dir, transform_train, transform_test):
    """Return ``(trainset, testset)`` for ``name``, downloading into ``data_dir`` if needed."""
    tvd = torchvision.datasets
    if name == "dtd":
        # DTD's train and val splits are merged into a single training set.
        train_split = tvd.DTD(
            root=data_dir, split="train", download=True, transform=transform_train)
        val_split = tvd.DTD(
            root=data_dir, split="val", download=True, transform=transform_train)
        trainset = torch.utils.data.ConcatDataset([train_split, val_split])
        testset = tvd.DTD(
            root=data_dir, split="test", download=True, transform=transform_test)
    elif name == "aircraft":
        trainset = tvd.FGVCAircraft(
            root=data_dir, split="trainval", download=True, transform=transform_train)
        testset = tvd.FGVCAircraft(
            root=data_dir, split="test", download=True, transform=transform_test)
    elif name == "pet":
        trainset = tvd.OxfordIIITPet(
            root=data_dir, split="trainval", download=True, transform=transform_train)
        testset = tvd.OxfordIIITPet(
            root=data_dir, split="test", download=True, transform=transform_test)
    elif name == "cifar10":
        trainset = tvd.CIFAR10(
            root=data_dir, train=True, download=True, transform=transform_train)
        testset = tvd.CIFAR10(
            root=data_dir, train=False, download=True, transform=transform_test)
    elif name == "cifar100":
        trainset = tvd.CIFAR100(
            root=data_dir, train=True, download=True, transform=transform_train)
        testset = tvd.CIFAR100(
            root=data_dir, train=False, download=True, transform=transform_test)
    else:
        raise ValueError(f"That dataset is not yet implemented: {name!r}")
    return trainset, testset


def get_data(name, data_dir, batch_size, do_transform=True, num_workers=None):
    """Build train/test DataLoaders for a supported image-classification dataset.

    args:
    @ name: name of dataset to be used; one of "dtd", "aircraft", "pet",
      "cifar10", "cifar100"
    @ data_dir: directory where datasets are stored or will be downloaded to
    @ batch_size: training and testing batch size
    @ do_transform: whether to enable training-time augmentation
      (random horizontal flip)
    @ num_workers: number of DataLoader worker processes; None keeps the
      per-dataset default (1 for cifar100, 8 for every other dataset)

    returns:
    ``(trainloader, testloader, num_classes)``. The train loader shuffles and
    the test loader does not. Every image is resized to 224x224 and normalized.

    raises:
    ValueError: if ``name`` is not a supported dataset.
    """
    try:
        num_classes, stats, default_workers = _DATASET_SPECS[name]
    except KeyError:
        raise ValueError(f"That dataset is not yet implemented: {name!r}") from None
    if num_workers is None:
        num_workers = default_workers

    transform_train, transform_test = _build_transforms(stats, do_transform)
    if name.startswith("cifar"):
        # CIFAR images are upscaled to the same 224x224 input size as the others.
        # NOTE(review): training augmentation is only a horizontal flip here; no
        # 32px RandomCrop(padding=4) is applied before the resize. Confirm intended.
        print("Rescale to [224,224]")

    trainset, testset = _build_datasets(name, data_dir, transform_train, transform_test)

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return trainloader, testloader, num_classes