import torch

import torchvision
from torchvision import models
from torchvision.transforms import Compose, ToTensor, Normalize, Resize,Grayscale,ToPILImage,RandomCrop,RandomHorizontalFlip,RandomPerspective,RandomRotation

import torchvision.datasets as datasets
from cifar10_data import CIFARRandomLabels,CIFARSubset_balanced
from fashion_data import FashionSubset_balanced
from mnist_data import MNISTRandomLabels,MNISTSubset_balanced

from utils import train_val_dataset





######### load dataset ###################
def get_mnist_data(batch_size=100, num_samples_per_class=1000, num_classes=10,
                   num_workers=4, shuffle_train=True):
    """Build class-balanced MNIST train/test data loaders.

    The training split is wrapped in ``MNISTSubset_balanced`` and then split
    with ``train_val_dataset(train_dataset, 5000)``. Only the non-validation
    part feeds the returned training loader.

    Args:
        batch_size: Batch size for both returned loaders.
        num_samples_per_class: Forwarded to ``MNISTSubset_balanced`` for the
            training split.
        num_classes: Forwarded to ``MNISTSubset_balanced`` for the training
            split.
        num_workers: Worker processes per ``DataLoader``. Defaults to 4.
        shuffle_train: Reshuffle the training data every epoch. Set to
            ``False`` to iterate the training set in a fixed order.

    Returns:
        ``(train_loader, test_loader)``.
    """
    # Training augmentation: random rotation in [-10, 10] degrees, then the
    # same normalization as the test split.
    transforms_train = Compose([RandomRotation(10),
                                ToTensor(),
                                Normalize((0.1307,), (0.3081,))])
    transforms_test = Compose([ToTensor(),
                               Normalize((0.1307,), (0.3081,))])

    train_dataset = MNISTSubset_balanced(root='MNIST-data', train=True, download=True,
                                         transform=transforms_train, num_classes=num_classes,
                                         num_samples_per_class=num_samples_per_class)
    # NOTE(review): num_classes / num_samples_per_class are not forwarded to the
    # test split, so it uses MNISTSubset_balanced's own defaults — confirm intended.
    test_dataset = MNISTSubset_balanced(root='MNIST-data', download=True, train=False,
                                        transform=transforms_test)

    # The held-out part is not returned. The split is still performed because
    # it decides which samples remain in the training set.
    _val, train = train_val_dataset(train_dataset, 5000)

    train_loader = torch.utils.data.DataLoader(dataset=train,
                                               batch_size=batch_size,
                                               shuffle=shuffle_train,
                                               num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=num_workers)
    return train_loader, test_loader


def get_fashion_data(batch_size=100, num_samples_per_class=0, num_classes=10,
                     num_workers=4, shuffle_train=True):
    """Build class-balanced Fashion-MNIST train/test data loaders.

    The training split is wrapped in ``FashionSubset_balanced`` and then split
    with ``train_val_dataset(train_dataset, 5000)``. Only the non-validation
    part feeds the returned training loader. No augmentation is applied.

    Args:
        batch_size: Batch size for both returned loaders.
        num_samples_per_class: Forwarded to ``FashionSubset_balanced`` for the
            training split (meaning of the default 0 depends on that class —
            TODO confirm).
        num_classes: Forwarded to ``FashionSubset_balanced`` for the training
            split.
        num_workers: Worker processes per ``DataLoader``. Defaults to 4.
        shuffle_train: Reshuffle the training data every epoch. Set to
            ``False`` to iterate the training set in a fixed order.

    Returns:
        ``(train_loader, test_loader)``.
    """
    # NOTE(review): these mean/std values are the MNIST constants, reused here
    # for Fashion-MNIST. They are kept unchanged so existing results stay
    # comparable; switch to Fashion-MNIST statistics deliberately if desired.
    transforms_train = Compose([ToTensor(),
                                Normalize((0.1307,), (0.3081,))])
    transforms_test = Compose([ToTensor(),
                               Normalize((0.1307,), (0.3081,))])

    train_dataset = FashionSubset_balanced(root='Fashion-data', train=True, download=True,
                                           transform=transforms_train, num_classes=num_classes,
                                           num_samples_per_class=num_samples_per_class)
    # NOTE(review): num_classes / num_samples_per_class are not forwarded to the
    # test split, so it uses FashionSubset_balanced's own defaults — confirm intended.
    test_dataset = FashionSubset_balanced(root='Fashion-data', download=True, train=False,
                                          transform=transforms_test)

    # The held-out part is not returned. The split is still performed because
    # it decides which samples remain in the training set.
    _val, train = train_val_dataset(train_dataset, 5000)

    train_loader = torch.utils.data.DataLoader(dataset=train,
                                               batch_size=batch_size,
                                               shuffle=shuffle_train,
                                               num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=num_workers)
    return train_loader, test_loader


def get_cifar10_data(batch_size=100, num_samples_per_class=0, num_classes=10,
                     num_workers=4, shuffle_train=True):
    """Build class-balanced CIFAR-10 train/test data loaders.

    The training split is wrapped in ``CIFARSubset_balanced`` and then split
    with ``train_val_dataset(train_dataset, 5000)``. Only the non-validation
    part feeds the returned training loader.

    Args:
        batch_size: Batch size for both returned loaders.
        num_samples_per_class: Forwarded to ``CIFARSubset_balanced`` for the
            training split (meaning of the default 0 depends on that class —
            TODO confirm).
        num_classes: Forwarded to ``CIFARSubset_balanced`` for the training
            split.
        num_workers: Worker processes per ``DataLoader``. Defaults to 4.
        shuffle_train: Reshuffle the training data every epoch. Set to
            ``False`` to iterate the training set in a fixed order.

    Returns:
        ``(train_loader, test_loader)``.
    """
    # NOTE(review): this std triple is widely copied; the per-channel pixel std
    # of CIFAR-10 is often reported as roughly (0.247, 0.243, 0.261). Kept
    # as-is for comparability with existing results — verify before changing.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)

    # Training augmentation: 32x32 random crop from a 4-pixel padded image,
    # plus a random horizontal flip.
    transforms_train = Compose([RandomCrop(32, padding=4),
                                RandomHorizontalFlip(),
                                ToTensor(),
                                Normalize(mean, std)])
    transforms_test = Compose([ToTensor(),
                               Normalize(mean, std)])

    train_dataset = CIFARSubset_balanced(root='Cifar10-data', train=True, download=True,
                                         transform=transforms_train, num_classes=num_classes,
                                         num_samples_per_class=num_samples_per_class)
    # NOTE(review): num_classes / num_samples_per_class are not forwarded to the
    # test split, so it uses CIFARSubset_balanced's own defaults — confirm intended.
    test_dataset = CIFARSubset_balanced(root='Cifar10-data', download=True, train=False,
                                        transform=transforms_test)

    # The held-out part is not returned. The split is still performed because
    # it decides which samples remain in the training set.
    _val, train = train_val_dataset(train_dataset, 5000)

    train_loader = torch.utils.data.DataLoader(dataset=train,
                                               batch_size=batch_size,
                                               shuffle=shuffle_train,
                                               num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=num_workers)
    return train_loader, test_loader



def get_cifar100_data(batch_size, num_workers=4, shuffle_train=True, download=False):
    """Build CIFAR-100 train/test data loaders.

    The full training split is passed through
    ``train_val_dataset(train_dataset, 5000)``. Only the non-validation part
    feeds the returned training loader.

    Args:
        batch_size: Batch size for both returned loaders.
        num_workers: Worker processes per ``DataLoader``. Defaults to 4.
        shuffle_train: Reshuffle the training data every epoch. Set to
            ``False`` to iterate the training set in a fixed order.
        download: Passed to ``datasets.CIFAR100``. With the default ``False``,
            the data must already exist under ``Cifar100-data``.

    Returns:
        ``(train_loader, test_loader)``.
    """
    # Mean/std are written on the 0-255 pixel scale and divided by 255, since
    # ToTensor produces values in [0, 1].
    mean = [x / 255.0 for x in [125.3, 123.0, 113.9]]
    std = [x / 255.0 for x in [63.0, 62.1, 66.7]]

    # Training augmentation: 32x32 random crop from a 4-pixel padded image,
    # plus a random horizontal flip.
    transforms_train = Compose([RandomCrop(32, padding=4),
                                RandomHorizontalFlip(),
                                ToTensor(),
                                Normalize(mean=mean, std=std)])
    transforms_test = Compose([ToTensor(),
                               Normalize(mean=mean, std=std)])

    train_dataset = datasets.CIFAR100(root='Cifar100-data', train=True, download=download,
                                      transform=transforms_train)
    test_dataset = datasets.CIFAR100(root='Cifar100-data', download=download, train=False,
                                     transform=transforms_test)

    # The held-out part is not returned. The split is still performed because
    # it decides which samples remain in the training set.
    _val, train = train_val_dataset(train_dataset, 5000)

    train_loader = torch.utils.data.DataLoader(dataset=train,
                                               batch_size=batch_size,
                                               shuffle=shuffle_train,
                                               num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=num_workers)
    return train_loader, test_loader




