import torchvision
from torchvision import datasets
from lightly.transforms import (
    SimCLRTransform,
    BYOLTransform,
    BYOLView1Transform,
    BYOLView2Transform,
    DINOTransform,
    VICRegTransform,
)

try:
    import utils
    from corrupted import CorruptedDataModule
except ImportError:
    from . import utils
    from .corrupted import CorruptedDataModule


class CIFAR10DataModule(CorruptedDataModule):
    """CIFAR-10 flavour of ``CorruptedDataModule``.

    Exposes the torchvision CIFAR-10 train and test splits (downloaded into
    ``self.data_dir`` when they are not already present) and a validation
    transform that does nothing but convert each image to a tensor.
    """

    # Dataset metadata: 10 classes of 32x32 images. ``image_size`` is also
    # used by subclasses as the input size of their training augmentations.
    num_classes = 10
    image_size = 32

    def define_train_dataset(self):
        """Return the CIFAR-10 training split with no transform attached."""
        train_split = datasets.CIFAR10(
            root=self.data_dir,
            train=True,
            download=True,
        )
        return train_split

    def define_val_dataset(self):
        """Return the CIFAR-10 test split (used for validation), untransformed."""
        test_split = datasets.CIFAR10(
            root=self.data_dir,
            train=False,
            download=True,
        )
        return test_split

    def val_transform(self):
        """Return the validation transform: a plain ``ToTensor`` conversion."""
        to_tensor = torchvision.transforms.ToTensor()
        return to_tensor


class CIFAR100DataModule(CorruptedDataModule):
    """CIFAR-100 flavour of ``CorruptedDataModule``.

    Exposes the torchvision CIFAR-100 train and test splits (downloaded into
    ``self.data_dir`` when they are not already present) and a validation
    transform that only converts each image to a tensor.
    """

    # Dataset metadata: 100 classes of 32x32 images. ``image_size`` is also
    # used by subclasses as the input size of their training augmentations.
    num_classes = 100
    image_size = 32

    def define_train_dataset(self):
        """Return the CIFAR-100 training split with no transform attached."""
        train_split = datasets.CIFAR100(
            root=self.data_dir,
            train=True,
            download=True,
        )
        return train_split

    def define_val_dataset(self):
        """Return the CIFAR-100 test split (used for validation), untransformed."""
        test_split = datasets.CIFAR100(
            root=self.data_dir,
            train=False,
            download=True,
        )
        return test_split

    def val_transform(self):
        """Return the validation transform: ``ToTensor`` wrapped in a ``Compose``."""
        pipeline_steps = [torchvision.transforms.ToTensor()]
        return torchvision.transforms.Compose(pipeline_steps)


class CIFAR10SimCLRDataModule(CIFAR10DataModule):
    """CIFAR-10 data module whose training views come from ``SimCLRTransform``."""

    def original_train_transform(self):
        """Return a SimCLR augmentation pipeline sized for CIFAR-10 images.

        Colour jitter runs at strength 0.5, Gaussian blur is switched off
        (``gaussian_blur=0.0``) and no normalization is applied.
        """
        simclr_options = {
            "input_size": self.image_size,
            "cj_strength": 0.5,
            "gaussian_blur": 0.0,
            "normalize": None,
        }
        return SimCLRTransform(**simclr_options)


class CIFAR100SimCLRDataModule(CIFAR100DataModule):
    """CIFAR-100 data module whose training views come from ``SimCLRTransform``."""

    def original_train_transform(self):
        """Return a SimCLR augmentation pipeline sized for CIFAR-100 images.

        Colour jitter runs at strength 0.5, Gaussian blur is switched off
        (``gaussian_blur=0.0``) and no normalization is applied.
        """
        simclr_options = {
            "input_size": self.image_size,
            "cj_strength": 0.5,
            "gaussian_blur": 0.0,
            "normalize": None,
        }
        return SimCLRTransform(**simclr_options)


class CIFAR10BYOLDataModule(CIFAR10DataModule):
    """CIFAR-10 data module whose training views come from ``BYOLTransform``."""

    def original_train_transform(self):
        """Return a two-view BYOL augmentation pipeline for CIFAR-10 images.

        Both views are built at ``self.image_size`` with Gaussian blur
        disabled (``gaussian_blur=0.0``) and no normalization.
        """
        # Both BYOL views take exactly the same overrides.
        shared_view_args = {
            "input_size": self.image_size,
            "gaussian_blur": 0.0,
            "normalize": None,
        }
        first_view = BYOLView1Transform(**shared_view_args)
        second_view = BYOLView2Transform(**shared_view_args)
        return BYOLTransform(
            view_1_transform=first_view,
            view_2_transform=second_view,
        )


class CIFAR100BYOLDataModule(CIFAR100DataModule):
    """CIFAR-100 data module whose training views come from ``BYOLTransform``."""

    def original_train_transform(self):
        """Return a two-view BYOL augmentation pipeline for CIFAR-100 images.

        Both views are built at ``self.image_size`` with Gaussian blur
        disabled (``gaussian_blur=0.0``) and no normalization.
        """
        # Both BYOL views take exactly the same overrides.
        shared_view_args = {
            "input_size": self.image_size,
            "gaussian_blur": 0.0,
            "normalize": None,
        }
        first_view = BYOLView1Transform(**shared_view_args)
        second_view = BYOLView2Transform(**shared_view_args)
        return BYOLTransform(
            view_1_transform=first_view,
            view_2_transform=second_view,
        )


class CIFAR10VICRegDataModule(CIFAR10DataModule):
    """CIFAR-10 data module whose training views come from ``VICRegTransform``."""

    def original_train_transform(self):
        """Return a VICReg augmentation pipeline sized for CIFAR-10 images.

        Colour jitter runs at strength 0.5, Gaussian blur is switched off
        (``gaussian_blur=0.0``) and no normalization is applied.
        """
        transform = VICRegTransform(
            normalize=None,
            gaussian_blur=0.0,
            cj_strength=0.5,
            input_size=self.image_size,
        )
        return transform


class CIFAR100VICRegDataModule(CIFAR100DataModule):
    """CIFAR-100 data module whose training views come from ``VICRegTransform``."""

    def original_train_transform(self):
        """Return a VICReg augmentation pipeline sized for CIFAR-100 images.

        Colour jitter runs at strength 0.5, Gaussian blur is switched off
        (``gaussian_blur=0.0``) and no normalization is applied.
        """
        transform = VICRegTransform(
            normalize=None,
            gaussian_blur=0.0,
            cj_strength=0.5,
            input_size=self.image_size,
        )
        return transform


class CIFAR10DINODataModule(CIFAR10DataModule):
    """CIFAR-10 data module whose training views come from ``DINOTransform``.

    Only the global crops are produced (``n_local_views=0``), at the
    ``image_size`` resolution inherited from ``CIFAR10DataModule``.
    """

    def original_train_transform(self):
        """Return a DINO augmentation pipeline sized for CIFAR-10 images.

        Returns:
            A ``DINOTransform`` with no local views, colour-jitter strength
            0.5, Gaussian blur disabled on every view and no normalization.
        """
        return DINOTransform(
            global_crop_size=self.image_size,
            n_local_views=0,
            cj_strength=0.5,
            # One blur probability per view group; all zero => no blur at all.
            gaussian_blur=(0, 0, 0),
            # Per lightly's API, DINOTransform normalizes with ImageNet
            # statistics unless told otherwise. The SimCLR/BYOL/VICReg modules
            # here pass ``normalize=None`` and ``val_transform`` is a bare
            # ``ToTensor()``. Without this, DINO training inputs would sit on
            # a different scale from the validation inputs.
            normalize=None,
        )


class CIFAR100DINODataModule(CIFAR100DataModule):
    """CIFAR-100 data module whose training views come from ``DINOTransform``.

    Only the global crops are produced (``n_local_views=0``), at the
    ``image_size`` resolution inherited from ``CIFAR100DataModule``.
    """

    def original_train_transform(self):
        """Return a DINO augmentation pipeline sized for CIFAR-100 images.

        Returns:
            A ``DINOTransform`` with no local views, colour-jitter strength
            0.5, Gaussian blur disabled on every view and no normalization.
        """
        return DINOTransform(
            global_crop_size=self.image_size,
            n_local_views=0,
            cj_strength=0.5,
            # One blur probability per view group; all zero => no blur at all.
            gaussian_blur=(0, 0, 0),
            # Per lightly's API, DINOTransform normalizes with ImageNet
            # statistics unless told otherwise. The SimCLR/BYOL/VICReg modules
            # here pass ``normalize=None`` and ``val_transform`` is a
            # ``Compose([ToTensor()])``. Without this, DINO training inputs
            # would sit on a different scale from the validation inputs.
            normalize=None,
        )
