import torch
import torchvision
import numpy as np
import pickle

class MNIST_Combine(torch.utils.data.Dataset):
    """Dataset backed by a single pickle file holding train and test splits.

    The pickle file must contain a mapping with the keys ``'train_data'``,
    ``'train_label'``, ``'test_data'`` and ``'test_label'``. The selected
    split is held in memory as ``self.data`` / ``self.targets``.

    Args:
        root: Path of the pickle file (``str`` or ``os.PathLike``).
        train: If true, use the ``train_*`` entries; otherwise the
            ``test_*`` entries.
    """

    def __init__(self, root, train=True):
        super().__init__()

        self.root = root
        # NOTE(security): pickle.load can execute arbitrary code while
        # deserializing; only point ``root`` at files from a trusted source.
        with open(root, 'rb') as f:
            contents = pickle.load(f)

        split = 'train' if train else 'test'
        self.data = contents[split + '_data']
        self.targets = contents[split + '_label']
        print("data size", self.data.shape)

    def __len__(self):
        """Return the number of samples, taken from the first targets axis."""
        return self.targets.shape[0]

    def __getitem__(self, idx):
        """Return the ``(sample, target)`` pair stored at ``idx``.

        When the file path contains the substring ``"mnist"``, the sample is
        returned with an extra leading dimension (``unsqueeze(0)``) --
        presumably a channel axis for single-channel images; confirm against
        the data producer. Otherwise the sample is returned as stored.
        """
        sample = self.data[idx]
        # str() lets the substring test work for pathlib.Path / os.PathLike
        # roots as well; ``"mnist" in Path(...)`` would raise TypeError.
        if "mnist" in str(self.root):
            sample = sample.unsqueeze(0)
        return sample, self.targets[idx]
