import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import os

def data_loader(a, batch_size, is_train=True, root='~/'):
    """Build a shuffled DataLoader for one of four image-classification datasets.

    Args:
        a: Dataset selector -- 0: ImageNet-style ``ImageFolder`` tree,
            1: CIFAR-10, 2: MNIST, 3: CIFAR-100.
        batch_size: Number of samples per batch.
        is_train: For CIFAR-10/100 and MNIST, selects the training split (True)
            or the test split (False). Random augmentation (crops/flips) is
            applied only when True, so evaluation sees deterministic inputs.
        root: Dataset directory (``~`` is expanded). CIFAR-10/100 and MNIST
            are downloaded here if missing. Defaults to the home directory,
            matching the previous hard-coded location.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected dataset, created
        with ``shuffle=True``.

    Raises:
        ValueError: If ``a`` is not 0, 1, 2 or 3.
    """
    # Fail loudly instead of silently returning None for an unknown selector.
    if a not in (0, 1, 2, 3):
        raise ValueError(
            "unknown dataset selector a=%r; expected 0, 1, 2 or 3" % (a,))
    root = os.path.expanduser(root)

    if a == 0:
        # ImageNet
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        if is_train:
            spatial = [transforms.RandomResizedCrop(224),
                       transforms.RandomHorizontalFlip()]
        else:
            # Deterministic evaluation preprocessing producing 224x224 crops.
            spatial = [transforms.Resize(256), transforms.CenterCrop(224)]
        transform = transforms.Compose(
            spatial + [transforms.ToTensor(), normalize])
        # NOTE(review): the same directory is read for both splits (as before);
        # point `root` at the train or val folder as appropriate.
        data_set = datasets.ImageFolder(root, transform)
        return torch.utils.data.DataLoader(
            data_set, batch_size=batch_size, shuffle=True, num_workers=0)

    if a == 2:
        # MNIST
        data_set = datasets.MNIST(root, train=is_train, download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.1307,), (0.3081,)),
                                  ]))
    else:
        # CIFAR-10 (a == 1) or CIFAR-100 (a == 3)
        if a == 1:
            dataset_cls = datasets.CIFAR10
            # Same statistics as the ImageNet branch; kept unchanged so models
            # trained with the previous preprocessing stay compatible.
            mean = [0.485, 0.456, 0.406]
            std = [0.229, 0.224, 0.225]
        else:
            dataset_cls = datasets.CIFAR100
            mean = [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
            std = [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]
        steps = []
        if is_train:
            steps += [transforms.RandomHorizontalFlip(),
                      transforms.RandomCrop(32, 4)]
        steps += [transforms.ToTensor(),
                  transforms.Normalize(mean=mean, std=std)]
        data_set = dataset_cls(root=root, train=is_train,
                               transform=transforms.Compose(steps),
                               download=True)

    return torch.utils.data.DataLoader(
        data_set,
        batch_size=batch_size, shuffle=True,
        num_workers=12, pin_memory=True)

def save_model(model, saved_path, epoch):
    """Serialize ``model``'s state dict to ``saved_path`` followed by ``epoch``.

    The epoch is appended to the prefix with no separator, so a prefix of
    ``"ckpt/epoch"`` and epoch 3 writes ``"ckpt/epoch3"``.
    """
    weights = model.state_dict()
    checkpoint_file = saved_path + str(epoch)
    torch.save(weights, checkpoint_file)

def load_model(model, epoch, saved_path="saved/epoch", map_location=None):
    """Load the checkpoint ``saved_path + str(epoch)`` into ``model`` and set eval mode.

    Args:
        model: Module whose parameters/buffers are overwritten in place.
        epoch: Epoch identifier appended (no separator) to ``saved_path``.
        saved_path: Checkpoint path prefix. Defaults to the previously
            hard-coded ``"saved/epoch"`` (relative to the working directory).
        map_location: Forwarded to ``torch.load``. For example, ``'cpu'`` loads a
            GPU-saved checkpoint on a CPU-only machine. None keeps
            ``torch.load``'s default.

    Returns:
        The same ``model`` object, now in eval mode.
    """
    path = saved_path + str(epoch)
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    model.load_state_dict(torch.load(path, map_location=map_location))
    model.eval()
    return model
def load_model_new(model, saved_path, epoch, map_location=None):
    """Load the checkpoint ``saved_path + str(epoch)`` into ``model``.

    Args:
        model: Module whose parameters/buffers are overwritten in place.
        saved_path: Checkpoint path prefix. The epoch is appended directly,
            with no separator, matching ``save_model``.
        epoch: Epoch identifier appended to the prefix.
        map_location: Forwarded to ``torch.load``. For example, ``'cpu'`` loads a
            GPU-saved checkpoint on a CPU-only machine. None keeps
            ``torch.load``'s default.

    Returns:
        The same ``model`` object. Its train/eval mode is left unchanged.
    """
    path = saved_path + str(epoch)
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    state_dict = torch.load(path, map_location=map_location)
    model.load_state_dict(state_dict)
    return model