# -*- coding: utf-8 -*-
"""
Created on Tue Jul 24 11:33:23 2018
"""
import argparse
import torch
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn as nn
from imagenet_dataloader import Imagenet32
import numpy as np

def _train_transform(mean, std):
    """Return the augmenting pipeline shared by the colour training sets."""
    return transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        # Resize stands in for transforms.Scale, which is deprecated and
        # absent from newer torchvision releases.
        transforms.Resize(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])


def _eval_transform(mean, std):
    """Return the augmentation-free pipeline: tensor conversion + normalisation."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])


def _make_loader(dataset, batch_size, num_workers):
    """Wrap *dataset* in a shuffling, memory-pinning ``DataLoader``."""
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True)


def dataset_loader(setname, batch_size):
    """Build the train/test ``DataLoader`` pair for a named dataset.

    Each dataset is read from a fixed path under ``data/``.  The torchvision
    datasets are created with ``download=False``, so their files must already
    be on disk.  Both the train and the test loader shuffle their samples and
    pin memory.

    Args:
        setname: One of ``'mnist'``, ``'cifar10'``, ``'cifar100'``,
            ``'cinic10'``, ``'imagenet32'`` or ``'svhn'``.
        batch_size: Batch size used for both loaders.

    Returns:
        The tuple ``(train_loader, test_loader, num_classes)``.

    Raises:
        ValueError: If ``setname`` is not one of the supported names.
    """
    # These statistics are applied to the CIFAR and CINIC-10 branches alike.
    # NOTE(review): the values look like the usual ImageNet channel
    # statistics rather than per-dataset ones -- confirm this is intended.
    rgb_mean = [0.485, 0.456, 0.406]
    rgb_std = [0.229, 0.224, 0.225]

    if setname == 'mnist':
        # This branch used to leave num_classes unassigned, which made the
        # final return statement raise UnboundLocalError.
        num_classes = 10
        workers = 2
        mean, std = (0.1307,), (0.3081,)
        # MNIST uses the plain pipeline for training as well (no augmentation).
        train_set = datasets.MNIST(root='data/mnist', train=True, download=False,
                                   transform=_eval_transform(mean, std))
        test_set = datasets.MNIST('data/mnist', train=False,
                                  transform=_eval_transform(mean, std))

    elif setname == 'cifar10':
        num_classes = 10
        workers = 4
        train_set = datasets.CIFAR10(root='data/', train=True, download=False,
                                     transform=_train_transform(rgb_mean, rgb_std))
        test_set = datasets.CIFAR10(root='data/', train=False, download=False,
                                    transform=_eval_transform(rgb_mean, rgb_std))

    elif setname == 'cifar100':
        num_classes = 100
        workers = 4
        train_set = datasets.CIFAR100(root='data/', train=True, download=False,
                                      transform=_train_transform(rgb_mean, rgb_std))
        test_set = datasets.CIFAR100(root='data/', train=False, download=False,
                                     transform=_eval_transform(rgb_mean, rgb_std))

    elif setname == 'cinic10':
        num_classes = 10
        workers = 4
        train_set = datasets.ImageFolder(root='data/cinic10/trainvalid',
                                         transform=_train_transform(rgb_mean, rgb_std))
        test_set = datasets.ImageFolder(root='data/cinic10/test',
                                        transform=_eval_transform(rgb_mean, rgb_std))

    elif setname == 'imagenet32':
        num_classes = 1000
        workers = 8
        # No transform is passed here; presumably Imagenet32 does its own
        # preprocessing -- verify in imagenet_dataloader.
        train_set = Imagenet32(root='data/imagenet/', train=True)
        test_set = Imagenet32(root='data/imagenet/', train=False)

    elif setname == 'svhn':
        num_classes = 10
        workers = 2
        # Statistics are given on the 0-255 scale and rescaled to 0-1.
        mean = [x / 255.0 for x in [109.9, 109.7, 113.8]]
        std = [x / 255.0 for x in [50.1, 50.6, 50.8]]
        train_set = datasets.SVHN(root='data/SVHN', split='train', download=False,
                                  transform=_train_transform(mean, std))
        test_set = datasets.SVHN(root='data/SVHN', split='test', download=False,
                                 transform=_eval_transform(mean, std))

    else:
        # Fail with a clear message instead of an UnboundLocalError at return.
        raise ValueError(
            "unknown dataset %r; expected one of 'mnist', 'cifar10', "
            "'cifar100', 'cinic10', 'imagenet32', 'svhn'" % (setname,))

    train_loader = _make_loader(train_set, batch_size, workers)
    test_loader = _make_loader(test_set, batch_size, workers)
    return train_loader, test_loader, num_classes
        


from scipy import misc
def savetensor2img(data):
    """Save a tensor as the JPEG file ``largelr/out.jpg``, printing debug info.

    The tensor is copied to the CPU, multiplied by 255, its three axes are
    reversed with ``transpose(2, 1, 0)`` and the result is cast to ``uint8``
    before being written.  The tensor size, the pixel array, its max/min and
    its shape are printed along the way.

    Args:
        data: A 3-D ``torch.Tensor`` (the transpose requires exactly three
            axes).  Presumably channel-first with values in [0, 1] -- TODO
            confirm against the callers.
    """
    # detach() lets tensors that still track gradients be converted to numpy.
    image = data.detach().cpu()
    print(image.size())
    pixels = image.numpy() * 255
    # NOTE(review): for a (C, H, W) input this yields (W, H, C), i.e. the
    # saved picture has rows and columns swapped; (1, 2, 0) would give
    # (H, W, C).  Left unchanged pending confirmation of the input layout.
    pixels = pixels.transpose(2, 1, 0)
    # Truncate to integers first, then narrow to 8 bits.
    # NOTE(review): values outside 0..255 are not clipped by this cast.
    pixels = np.uint8(pixels.astype('int'))
    print(pixels)
    print(np.max(pixels), np.min(pixels))
    print(np.shape(pixels))
    # Write the array out as an image.  scipy.misc.imsave was removed in
    # SciPy 1.2, so go through PIL via the torchvision transform this file
    # already imports.
    transforms.ToPILImage()(pixels).save('largelr/out.jpg')

