import os

import torch
import torchvision.datasets as datasets

from torch.utils.data import Subset
import numpy as np

class MNIST:
    """MNIST train/test datasets together with ready-made DataLoaders.

    Both splits are loaded through torchvision's ``datasets.MNIST`` with
    ``download=True`` into ``location``, and both use ``preprocess`` as their
    image transform.

    Attributes:
        train_dataset: training split.
        train_loader: loader over ``train_dataset``, shuffled.
        test_dataset: test split.
        test_loader: loader over ``test_dataset`` in dataset order.
        test_loader_shuffle: loader over ``test_dataset``, shuffled.
        classnames: digit labels ``"0"`` .. ``"9"``.
        test_loader_subset: loader over a random, non-repeating subset of the
            test split (``subset_fraction`` of it, a quarter by default),
            iterated in subset order.
        test_loader_subset_shuffle: loader over that same subset, shuffled.
    """

    def __init__(self, preprocess, location=os.path.expanduser("~/data"), batch_size=128, num_workers=0,
                 *, subset_fraction=0.25, seed=None):
        """Build the datasets and all loaders.

        Args:
            preprocess: transform applied to every image (passed as ``transform``).
            location: root directory the dataset files are stored in / downloaded to.
            batch_size: batch size used by every loader.
            num_workers: number of worker processes used by every loader.
            subset_fraction: fraction of the test split sampled (without
                replacement) for the subset loaders; must be in (0, 1].
                The default 0.25 keeps one quarter of the test set.
            seed: optional seed for the subset sampling. When given, a private
                ``np.random.default_rng(seed)`` is used so the subset is
                reproducible without touching global state. When ``None``,
                numpy's global RNG is used, so ``np.random.seed(...)`` called
                beforehand still controls the subset.

        Raises:
            ValueError: if ``subset_fraction`` is not in (0, 1].
        """
        # Validate before any (possibly slow) download happens.
        if not 0 < subset_fraction <= 1:
            raise ValueError(f"subset_fraction must be in (0, 1], got {subset_fraction!r}")

        self.train_dataset = datasets.MNIST(root=location, download=True, train=True, transform=preprocess)

        self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

        self.test_dataset = datasets.MNIST(root=location, download=True, train=False, transform=preprocess)

        self.test_loader = torch.utils.data.DataLoader(self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        self.test_loader_shuffle = torch.utils.data.DataLoader(self.test_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

        self.classnames = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]

        n_total = len(self.test_dataset)
        # Number of test samples to keep; with the default 0.25 this equals n_total // 4.
        n_sample = int(n_total * subset_fraction)
        # Without a seed, fall back to the global numpy RNG (original behavior);
        # with a seed, use an isolated Generator for a reproducible subset.
        rng = np.random if seed is None else np.random.default_rng(seed)
        # Random indices with no repeats (replace=False).
        indices = rng.choice(n_total, n_sample, replace=False)
        subset = Subset(self.test_dataset, indices)
        self.test_loader_subset = torch.utils.data.DataLoader(subset, batch_size=batch_size, num_workers=num_workers)
        self.test_loader_subset_shuffle = torch.utils.data.DataLoader(subset, shuffle=True, batch_size=batch_size, num_workers=num_workers)
