import copy
import os

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from tqdm import tqdm_notebook as tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

def dataset(root: str = "./data", *, download: bool = True):
    """Build the MNIST training and validation datasets.

    Both splits use the same transform: convert each image to a tensor,
    then normalize it with the MNIST mean (0.1307) and standard deviation
    (0.3081).

    Args:
        root: Directory where torchvision stores (or looks for) the MNIST
            files. Defaults to ``"./data"``.
        download: If True, download MNIST into ``root`` when it is not
            already present. Keyword-only.

    Returns:
        A ``(train_dataset, valid_dataset)`` tuple. ``train_dataset`` is the
        MNIST training split (``train=True``). ``valid_dataset`` is the MNIST
        test split (``train=False``), used here for validation.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean and standard deviation
    ])

    train_dataset = datasets.MNIST(root=root, train=True, download=download, transform=transform)
    # MNIST ships only train/test splits; the test split serves as validation data.
    valid_dataset = datasets.MNIST(root=root, train=False, download=download, transform=transform)
    return train_dataset, valid_dataset
