import torch
import torchvision
import torch.nn as nn
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torch import optim
from torch.autograd import Variable
from Trace import LossScaledTrace
import time
import copy

def Dataset(train_size):
    """Load the FashionMNIST training and test splits.

    Seeds torch's global RNG with 1 before building the datasets. Both splits
    are rooted at 'data', requested with ``download=True``, and use the same
    transform pipeline: ``ToTensor`` followed by
    ``ConvertImageDtype(torch.float)``.

    Args:
        train_size: Currently unused. The evenly strided subsetting of the
            training split that consumed it is disabled, so the full
            training split is always returned.

    Returns:
        A ``(train_data, test_data)`` tuple of FashionMNIST datasets.
    """
    torch.manual_seed(1)

    def _build_transform():
        # A fresh pipeline per split, so the two datasets do not share a
        # single transform object.
        return transforms.Compose([
            transforms.ToTensor(),  # convert the image to a PyTorch tensor first
            transforms.ConvertImageDtype(torch.float),
        ])

    # Constructor arguments common to both splits.
    shared_kwargs = {'root': 'data', 'download': True}

    train_split = datasets.FashionMNIST(
        train=True,
        transform=_build_transform(),
        **shared_kwargs,
    )
    test_split = datasets.FashionMNIST(
        train=False,
        transform=_build_transform(),
        **shared_kwargs,
    )
    return train_split, test_split