# -*- coding: utf-8 -*-
"""

@author: Anonymous Author
"""

import torch
import torchvision
from torchvision import datasets
import numpy as np
from PIL import Image

class CIFAR10sub(datasets.CIFAR10):
    """CIFAR-10 dataset restricted to a subset of samples.

    Loads the chosen CIFAR-10 split through ``datasets.CIFAR10`` and keeps
    only the rows selected by ``indexs``. The selection is stored in
    ``self.indexs``, so position ``i`` of this subset corresponds to
    position ``self.indexs[i]`` of the full split (for integer indices).

    Args:
        root: Root directory, passed through to ``CIFAR10``.
        indexs: Any NumPy-compatible index into the full split that selects
            the samples to keep (e.g. a list/array of ints or a boolean
            mask). An empty selection yields an empty dataset.
        train: Load the training split if True, otherwise the test split.
        transform: Optional callable applied to each PIL image.
        target_transform: Optional callable applied to each label.
        download: Passed through to ``CIFAR10``. The default here is True.
    """

    def __init__(self, root, indexs, train=True,
                 transform=None, target_transform=None,
                 download=True):
        # Forward target_transform as well. Previously it was dropped, so
        # label transforms passed by callers were silently ignored.
        super().__init__(root, train=train,
                         transform=transform,
                         target_transform=target_transform,
                         download=download)

        # np.array (not asarray) keeps a private copy, so later mutation of
        # the caller's index array cannot change self.indexs.
        indexs = np.array(indexs)
        if indexs.size == 0:
            # np.array([]) is float64, which NumPy rejects as an index. An
            # empty selection should simply produce an empty subset.
            indexs = indexs.astype(np.intp)

        # Labels are converted to an ndarray once, so they can be
        # fancy-indexed with the same selection as the image data.
        self.data = self.data[indexs]
        self.targets = np.asarray(self.targets)[indexs]
        self.indexs = indexs

    def __getitem__(self, index):
        """Return the ``(img, target)`` pair at ``index`` within the subset.

        ``img`` is built with ``Image.fromarray`` and then passed through
        ``self.transform`` when one is set. ``target`` is an element of the
        NumPy label array (a NumPy scalar, not a Python int) and is passed
        through ``self.target_transform`` when one is set.
        """
        img, target = self.data[index], self.targets[index]
        # Produce a PIL image so that transforms expecting PIL input work
        # the same way they do for the parent dataset.
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target