from torchvision import transforms, datasets
import numpy as np

class Cifar10:
    """CIFAR-10 test split restricted to samples flagged by two status arrays.

    A test sample is kept only when it is flagged in both
    ``model_predict_status_arr`` and ``victim_model_predict_status_arr``.
    Of the kept samples, at most the first ``num_samples`` (in original
    dataset order) are exposed. Items are ``(image, label)`` pairs where the
    image has been passed through ``transforms.ToTensor()``.
    """

    def __init__(self, model_predict_status_arr, victim_model_predict_status_arr,
                 data_dir='../data', num_samples=2000):
        """Load the CIFAR-10 test split and apply the selection.

        Args:
            model_predict_status_arr: per-sample flags, one per test image;
                truthy entries are candidates for selection. Presumably marks
                samples the model predicted correctly — confirm with caller.
            victim_model_predict_status_arr: same-length per-sample flags for
                the victim model.
            data_dir: root directory passed to ``datasets.CIFAR10``; the data
                is downloaded there if missing.
            num_samples: maximum number of selected samples to keep (default
                2000, the previous hard-coded value). ``None`` keeps all.
        """
        data_selected = self._combine_masks(model_predict_status_arr,
                                            victim_model_predict_status_arr)

        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        self.cifar10 = datasets.CIFAR10(data_dir, train=False, download=True)
        self.data, self.labels = self._select(
            self.cifar10.data, self.cifar10.targets, data_selected, num_samples)

    @staticmethod
    def _combine_masks(first, second):
        """Return the element-wise AND of two flag arrays as a numpy bool mask."""
        # Cast to bool first: with 0/1 integer inputs, ``&`` would yield an
        # integer array, which numpy treats as fancy indices (selecting rows
        # 0 and 1) rather than as a boolean mask. Casting also accepts lists.
        return np.asarray(first, dtype=bool) & np.asarray(second, dtype=bool)

    @staticmethod
    def _select(data, targets, mask, num_samples):
        """Filter ``data``/``targets`` by boolean ``mask`` and keep the first ``num_samples``.

        Returns a ``(data, labels)`` tuple. ``data`` stays an array indexed
        like the input. ``labels`` is a Python list.
        """
        selected_data = data[mask][:num_samples]
        selected_labels = np.asarray(targets)[mask][:num_samples].tolist()
        return selected_data, selected_labels

    def __len__(self):
        """Number of selected samples."""
        return self.data.shape[0]

    def __getitem__(self, item):
        """Return the ``(transformed_image, label)`` pair at position ``item``."""
        img = self.data[item]
        target = self.labels[item]

        img = self.transform(img)

        return img, target

