from torchvision import transforms
import torch
import torchvision
import sys 
sys.path.append("..")
from config import opt


class CIFAR10:
    """CIFAR-10 train/test datasets plus ready-made DataLoaders.

    Every dataset is a ``torchvision.datasets.CIFAR10`` rooted at
    ``opt.data_dir + 'datasets'`` and constructed with ``download=True``.

    Attributes:
        n_classes: Number of labels (10).
        train_dataset: Training split with augmentation (random crop with
            4-pixel padding, random horizontal flip) followed by
            ``ToTensor`` and normalization.
        test_dataset: Test split with ``Resize``, ``ToTensor`` and
            normalization (no augmentation).
        train_dataset_no_aug: Training split using the test-time pipeline.
        dataset: Training split using the caller-supplied ``transform``;
            only set when ``transform`` is truthy.
    """

    def __init__(self, input_size=32, transform=None, partition=None):
        """Build the datasets.

        Args:
            input_size: Crop size for the training pipeline and resize
                target for the test pipeline.
            transform: Optional pipeline; when truthy, an extra
                training-split dataset is stored as ``self.dataset``.
            partition: Accepted but not used by this class.
        """
        self.n_classes = 10

        # NOTE(review): this std triple differs from the commonly quoted
        # CIFAR-10 per-channel std (~0.247, 0.243, 0.261). It is kept
        # unchanged because it defines the input distribution any existing
        # model was trained on. Confirm this is intended.
        normalize = transforms.Normalize(
            (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        )
        augmented = transforms.Compose([
            transforms.RandomCrop(input_size, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        # CIFAR-10 images are 32x32, so with the default input_size=32 this
        # Resize does not change their size.
        plain = transforms.Compose([
            transforms.Resize(input_size),
            transforms.ToTensor(),
            normalize,
        ])

        root = opt.data_dir + 'datasets'

        def load(train, pipeline):
            # All splits share one root and are built with download=True.
            return torchvision.datasets.CIFAR10(
                root=root, train=train, download=True, transform=pipeline
            )

        self.train_dataset = load(True, augmented)
        self.test_dataset = load(False, plain)
        self.train_dataset_no_aug = load(True, plain)
        if transform:
            self.dataset = load(True, transform)

    def train_dataloader(self, *args, **kwargs):
        """Return a shuffled DataLoader over ``train_dataset``.

        Uses batches of 128 and 4 worker processes, and drops the final
        incomplete batch. Positional and keyword arguments are accepted
        for interface compatibility but ignored.
        """
        settings = {
            'batch_size': 128,
            'shuffle': True,
            'num_workers': 4,
            'drop_last': True,
        }
        return torch.utils.data.DataLoader(self.train_dataset, **settings)

    def test_dataloader(self, *args, **kwargs):
        """Return an unshuffled DataLoader over ``test_dataset``.

        Uses batches of 100 and 4 worker processes, and keeps the final
        incomplete batch. Positional and keyword arguments are accepted
        for interface compatibility but ignored.
        """
        settings = {
            'batch_size': 100,
            'num_workers': 4,
            'drop_last': False,
        }
        return torch.utils.data.DataLoader(self.test_dataset, **settings)

