import os
import pytorch_lightning as pl
from torchvision import transforms as transform_lib
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from typing import Callable
import torchvision
from typing import List, Dict
import numpy as np
import torch
from datasets.dataset_utils import DatasetWithIndices
import lightly.data as lightly_data
from lightly.transforms import (
    SimCLRTransform,
    BYOLTransform,
    DINOTransform,
    VICRegTransform,
    AIMTransform,
    MAETransform,
)
from lightly.transforms.utils import IMAGENET_NORMALIZE

from datasets.corrupted import CorruptedDataModule


class ImageNet100DataModule(CorruptedDataModule):
    """ImageNet-100 data module backed by torchvision ``ImageFolder`` splits.

    The train and val splits are read from the ``train`` and ``val``
    subdirectories of ``self.data_dir``. NOTE(review): ``data_dir`` is
    assumed to be set up by ``CorruptedDataModule``; confirm there.
    Subclasses supply ``original_train_transform`` to pick the
    self-supervised augmentation pipeline.
    """

    # Class-level metadata; presumably read by CorruptedDataModule / model code.
    num_classes = 100
    image_size = 224

    def define_train_dataset(self):
        """Return an ``ImageFolder`` over ``<data_dir>/train``."""
        return ImageFolder(os.path.join(self.data_dir, "train"))

    def define_val_dataset(self):
        """Return an ``ImageFolder`` over ``<data_dir>/val``."""
        return ImageFolder(os.path.join(self.data_dir, "val"))

    def val_transform(self):
        """
        Build the conventional ImageNet evaluation pipeline.

        Resize the shorter side to ``image_size + 32`` (256 for the default
        224), center-crop to ``image_size``, convert to a tensor, and
        normalize with lightly's ``IMAGENET_NORMALIZE`` mean/std.
        """
        resize_side = self.image_size + 32
        pipeline = [
            transform_lib.Resize(resize_side),
            transform_lib.CenterCrop(self.image_size),
            transform_lib.ToTensor(),
            transform_lib.Normalize(**IMAGENET_NORMALIZE),
        ]
        return transform_lib.Compose(pipeline)


class ImageNet100SimCLRDataModule(ImageNet100DataModule):
    """ImageNet-100 data module whose training views come from SimCLR augmentations."""

    def original_train_transform(self):
        """Return a default-configured lightly ``SimCLRTransform``."""
        augmentation = SimCLRTransform()
        return augmentation


class ImageNet100BYOLDataModule(ImageNet100DataModule):
    """ImageNet-100 data module whose training views come from BYOL augmentations."""

    def original_train_transform(self):
        """Return a default-configured lightly ``BYOLTransform``."""
        augmentation = BYOLTransform()
        return augmentation


class ImageNet100VICRegDataModule(ImageNet100DataModule):
    """ImageNet-100 data module whose training views come from VICReg augmentations."""

    def original_train_transform(self):
        """Return a default-configured lightly ``VICRegTransform``.

        Uses the VICReg-specific pipeline rather than ``BYOLTransform``:
        lightly ships a dedicated transform for VICReg, and BYOL's pipeline
        uses different (asymmetric) per-view settings.
        """
        return VICRegTransform()


class ImageNet100DINODataModule(ImageNet100DataModule):
    """ImageNet-100 data module whose training views come from DINO augmentations."""

    def original_train_transform(self):
        """Return a default-configured lightly ``DINOTransform``."""
        augmentation = DINOTransform()
        return augmentation


class ImageNet100MAEDataModule(ImageNet100DataModule):
    """ImageNet-100 data module whose training inputs come from MAE augmentations."""

    def original_train_transform(self):
        """Return a default-configured lightly ``MAETransform``."""
        augmentation = MAETransform()
        return augmentation