from torchvision import transforms
import torch
import torchvision
import sys
sys.path.append('..')
from config import opt


class CIFAR100_90(object):
    """CIFAR-100 wrapper whose training split excludes the labels in ``EXCLUDED_CLASSES``.

    Both splits are downloaded to ``opt.work_dir + 'datasets'`` and share one
    resize / to-tensor / normalize transform. Only the training split is
    filtered; the test split keeps every class.
    """

    # Labels removed from the training split (classes 90-99 by default).
    EXCLUDED_CLASSES = range(90, 100)

    def __init__(self, input_size = 32, n_classes = 10):
        """Load CIFAR-100 train/test splits and drop ``EXCLUDED_CLASSES`` from train.

        Args:
            input_size: size passed to ``transforms.Resize`` for every image.
            n_classes: stored as ``self.n_classes``; not otherwise used here.
        """
        # NOTE(review): the default of 10 does not match the 90 classes left in
        # the training split -- confirm what consumers of n_classes expect.
        self.n_classes = n_classes
        transform = transforms.Compose([
            transforms.Resize(input_size),
            transforms.ToTensor(),
            # NOTE(review): these per-channel mean/std values do not look like
            # the commonly published CIFAR-100 statistics; confirm their origin.
            transforms.Normalize(
                (.48,.07,.02,), (.43,.77,.87,)
            ),
        ])
        self.train_dataset = torchvision.datasets.CIFAR100(
            root = opt.work_dir+'datasets',
            train = True,
            download = True,
            transform = transform
        )
        # NOTE(review): the test split is NOT filtered, so it still contains the
        # excluded classes -- presumably intentional (e.g. held-out classes); verify.
        self.test_dataset = torchvision.datasets.CIFAR100(
            root = opt.work_dir+'datasets',
            train = False,
            download = True,
            transform = transform
        )
        self._drop_classes(self.train_dataset, self.EXCLUDED_CLASSES)

    @staticmethod
    def _drop_classes(dataset, excluded):
        """Remove, in place, every sample of ``dataset`` whose label is in ``excluded``.

        ``dataset.data`` and ``dataset.targets`` are rewritten, because those are
        the attributes torchvision's CIFAR datasets index in ``__getitem__`` and
        measure in ``__len__``. Writing only ``train_data``/``train_labels`` (as
        this class used to) left the dataset unfiltered.
        """
        excluded = set(excluded)
        keep = [i for i, label in enumerate(dataset.targets) if label not in excluded]
        dataset.data = dataset.data[keep]
        dataset.targets = [dataset.targets[i] for i in keep]
        # Legacy attribute names kept as aliases so existing readers still work.
        dataset.train_data = dataset.data
        dataset.train_labels = dataset.targets

    def train_dataloader(self, *args, **kwargs):
        """Return a shuffled DataLoader over the filtered training split.

        Positional and keyword arguments are accepted but ignored.
        """
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size = 64,
            shuffle = True,
            num_workers = 2,
            drop_last = True
        )

    def test_dataloader(self, *args, **kwargs):
        """Return an unshuffled DataLoader over the (unfiltered) test split.

        Positional and keyword arguments are accepted but ignored.
        """
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size = 16,
            num_workers = 2,
            drop_last = False
        )

