import pytorch_lightning as pl
from torchvision.datasets import FakeData
from torch.utils.data import DataLoader
import torchvision
from datasets.dataset_utils import DatasetWithIndices
from datasets.corrupted import CorruptedDataModule

from lightly.transforms import SimCLRTransform


class DummyDataModule(pl.LightningDataModule):
    """Lightning data module serving randomly generated ``FakeData`` images.

    The train loader wraps its dataset in ``DatasetWithIndices``; the
    validation and test loaders iterate plain ``FakeData`` datasets that
    are each 20% of the training size.

    ``FakeData`` seeds every sample from ``index + random_offset``, so
    datasets built with the same offset contain identical samples. Each
    split therefore gets its own, non-overlapping offset range.

    Args:
        batch_size: Batch size used by every loader.
        input_channels: Number of channels of the generated images.
    """

    def __init__(self, batch_size: int = 32, input_channels: int = 3):
        super().__init__()
        self.batch_size = batch_size
        self.num_classes = 10
        self.input_channels = input_channels
        self.num_train_samples = 1000
        self.image_size = 224
        self.transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
            ]
        )

    def _num_eval_samples(self) -> int:
        """Return the size of the val/test splits (20% of the train size)."""
        return int(0.2 * self.num_train_samples)

    def _make_fake_dataset(self, size: int, random_offset: int = 0) -> FakeData:
        """Build a ``FakeData`` split with this module's shape and transforms.

        Args:
            size: Number of samples in the split.
            random_offset: Offset added to the per-index seed; distinct
                offsets keep splits from repeating each other's samples.
        """
        return FakeData(
            size=size,
            image_size=(self.input_channels, self.image_size, self.image_size),
            num_classes=self.num_classes,
            transform=self.transforms,
            random_offset=random_offset,
        )

    def train_dataloader(self) -> DataLoader:
        """Return a shuffled loader over the index-aware training split."""
        data_with_indices = DatasetWithIndices(
            # Offset 0 keeps the training samples identical to what this
            # module has always produced.
            self._make_fake_dataset(self.num_train_samples),
            # if False, only index is returned
            with_labels=True,
        )
        return DataLoader(
            data_with_indices, batch_size=self.batch_size, shuffle=True
        )

    def val_dataloader(self) -> DataLoader:
        """Return an unshuffled loader over the validation split."""
        val_dataset = self._make_fake_dataset(
            self._num_eval_samples(),
            # Start right after the training seeds so no sample is shared.
            random_offset=self.num_train_samples,
        )
        return DataLoader(
            val_dataset, batch_size=self.batch_size, shuffle=False, drop_last=False
        )

    def test_dataloader(self) -> DataLoader:
        """Return an unshuffled loader over the test split."""
        test_dataset = self._make_fake_dataset(
            self._num_eval_samples(),
            # Start after both the training and the validation seeds.
            random_offset=self.num_train_samples + self._num_eval_samples(),
        )
        return DataLoader(
            test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
        )


class DummyCIFARSimCLR(CorruptedDataModule):
    """FakeData-based stand-in for a CIFAR-sized SimCLR data module.

    Generates 32x32, 3-channel fake images over 10 classes. The methods
    below override hooks of ``CorruptedDataModule``; how and when the base
    class calls them is not visible here (NOTE(review): confirm against
    ``datasets.corrupted``).
    """

    num_classes = 10
    image_size = 32
    # Normalisation statistics used by ``val_transform``.
    IMAGENET_MEAN = [0.485, 0.456, 0.406]
    IMAGENET_STD = [0.229, 0.224, 0.225]

    def define_train_dataset(self):
        """Return the 1000-sample fake training dataset (no transform attached)."""
        return FakeData(
            size=1000,
            image_size=(3, self.image_size, self.image_size),
            num_classes=self.num_classes,
        )

    def define_val_dataset(self):
        """Return the 1000-sample fake validation dataset (no transform attached)."""
        return FakeData(
            size=1000,
            image_size=(3, self.image_size, self.image_size),
            num_classes=self.num_classes,
            # FakeData seeds each sample from ``index + random_offset``.
            # Without an offset this split would be an exact copy of the
            # training split, so skip past the 1000 training seeds.
            random_offset=1000,
        )

    def val_transform(self):
        """Return the validation pipeline: tensor conversion, then normalisation."""
        return torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(self.IMAGENET_MEAN, self.IMAGENET_STD),
            ]
        )

    def original_train_transform(self):
        """Return lightly's ``SimCLRTransform`` sized for this module's images."""
        return SimCLRTransform(input_size=self.image_size)


class DummyImageNetSimCLR(CorruptedDataModule):
    """A FakeData-based DataModule simulating ImageNet for SimCLR workflows.

    Generates 224x224, 3-channel fake images over 1000 classes. The
    methods below override hooks of ``CorruptedDataModule``; how and when
    the base class calls them is not visible here (NOTE(review): confirm
    against ``datasets.corrupted``).
    """

    num_classes = 1000
    image_size = 224
    # Normalisation statistics used by ``val_transform``.
    IMAGENET_MEAN = [0.485, 0.456, 0.406]
    IMAGENET_STD = [0.229, 0.224, 0.225]

    def define_train_dataset(self):
        """Return the 1280-sample fake training dataset (no transform attached)."""
        return FakeData(
            size=1280,
            image_size=(3, self.image_size, self.image_size),
            num_classes=self.num_classes,
        )

    def define_val_dataset(self):
        """Return the 320-sample fake validation dataset (no transform attached)."""
        return FakeData(
            size=320,
            image_size=(3, self.image_size, self.image_size),
            num_classes=self.num_classes,
            # FakeData seeds each sample from ``index + random_offset``.
            # Without an offset these samples would repeat the first 320
            # training samples, so skip past the 1280 training seeds.
            random_offset=1280,
        )

    def original_train_transform(self):
        """Return lightly's ``SimCLRTransform`` sized for 224x224 inputs."""
        return SimCLRTransform(input_size=self.image_size)

    def val_transform(self):
        """Return the validation pipeline: tensor conversion, then normalisation."""
        return torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(self.IMAGENET_MEAN, self.IMAGENET_STD),
            ]
        )


if __name__ == "__main__":
    # Manual smoke check: build the module with its default arguments and
    # construct the training loader once. The loader itself is discarded.
    data_module = DummyDataModule()
    data_module.train_dataloader()
