from __future__ import annotations

from pathlib import Path

import inferno
import torch
import torchvision
from torch.utils.data import random_split
from torchvision.transforms import v2 as transforms

from .dataset import LightningDataset


class MNIST(LightningDataset):
    """MNIST digit classification dataset.

    :param batch_size: Batch size for training.
    :param batch_size_test: Batch size for testing.
    :param train_validation_split: Fraction of data to use for training and validation.
    :param transform: Transform to apply to the data.
    :param target_transform: Transform to apply to the targets.
    :param data_augmentation_transform: Data augmentation to apply to the data.
        This is applied to the training set only, before ``transform``.
    :param data_dir: Directory to download the dataset to.
        Defaults to the current working directory at instantiation time.
    :param num_workers: How many subprocesses to use for data loading.
        `0` means that the data will be loaded in the main process.
    :param persistent_workers: If `True`, the data loader will not shutdown the worker processes
        after a dataset has been consumed.
    :param pin_memory: If `True`, the data loader will copy Tensors into device/CUDA pinned memory
        before returning them.
    :param generator: Random generator used for sampling batches.
    """

    def __init__(
        self,
        batch_size: int,
        batch_size_test: int | None = None,
        train_validation_split: list[float] = [0.9, 0.1],
        transform: transforms.Transform | None = transforms.Compose(
            [
                transforms.ToImage(),
                transforms.ToDtype(torch.float32, scale=True),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        ),
        target_transform: transforms.Transform | None = None,
        data_augmentation_transform: transforms.Transform | None = None,
        data_dir: Path | None = None,
        pin_memory: bool = True,
        num_workers: int = 0,
        persistent_workers: bool = True,
        generator: torch.Generator | None = None,
    ) -> None:
        super().__init__(
            input_shape=torch.Size((1, 28, 28)),
            num_classes=10,
            train_and_validation_set_size=60000,
            batch_size=batch_size,
            test_set_size=10000,
            batch_size_test=batch_size_test,
            # Copy so that the shared default list object is never handed out
            # (and possibly mutated) across instances.
            train_validation_split=list(train_validation_split),
            transform=transform,
            target_transform=target_transform,
            data_augmentation_transform=data_augmentation_transform,
            # Resolve the working directory when the object is created rather
            # than once at import time.
            data_dir=Path.cwd() if data_dir is None else data_dir,
            pin_memory=pin_memory,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
            generator=generator,
        )

    def prepare_data(self) -> None:
        """Download the MNIST train and test sets into ``self.data_dir``."""
        torchvision.datasets.MNIST(self.data_dir, train=True, download=True)
        torchvision.datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str | None = None) -> None:
        """Create the datasets needed for the given stage.

        :param stage: ``"fit"`` creates the training and validation sets,
            ``"test"`` creates the test set and ``None`` creates all of them.
        """
        if stage == "fit" or stage is None:
            data_full = torchvision.datasets.MNIST(
                self.data_dir,
                train=True,
                transform=self.transform,
                target_transform=self.target_transform,
            )
            self.data_train, self.data_val = random_split(
                data_full,
                (self.train_set_size, self.validation_set_size),
                generator=torch.Generator().manual_seed(
                    42
                ),  # Fixed train / validation split.
            )

            if self.data_augmentation_transform is not None:
                # ``random_split`` returns ``Subset`` objects, which ignore any
                # ``transform`` attribute set on them. To augment the training
                # data only, index a second dataset that carries the augmented
                # pipeline with the training indices of the split above.
                train_transforms = [self.data_augmentation_transform]
                if self.transform is not None:
                    train_transforms.append(self.transform)
                data_full_augmented = torchvision.datasets.MNIST(
                    self.data_dir,
                    train=True,
                    transform=transforms.Compose(train_transforms),
                    target_transform=self.target_transform,
                )
                self.data_train = torch.utils.data.Subset(
                    data_full_augmented, self.data_train.indices
                )

        if stage == "test" or stage is None:
            self.data_test = torchvision.datasets.MNIST(
                self.data_dir,
                train=False,
                transform=self.transform,
                target_transform=self.target_transform,
            )


class MNISTC(LightningDataset):
    """Corrupted MNIST image classification dataset.

    From [MNIST-C: A Robustness Benchmark for Computer Vision](https://arxiv.org/abs/1906.02337).

    This data module only provides a test set: the training / validation set size
    is zero and the train and validation dataloaders raise ``NotImplementedError``.

    :param batch_size: Batch size for training.
    :param batch_size_test: Batch size for testing.
    :param train_validation_split: Fraction of data to use for training and validation.
    :param transform: Transform to apply to the data.
    :param target_transform: Transform to apply to the targets.
    :param data_augmentation_transform: Data augmentation to apply to the data.
        This is applied to the training set only.
    :param corruptions: List of corruptions to include in the test set.
    :param data_dir: Directory to download the dataset to.
        Defaults to the current working directory at instantiation time.
    :param num_workers: How many subprocesses to use for data loading.
        `0` means that the data will be loaded in the main process.
    :param persistent_workers: If `True`, the data loader will not shutdown the worker processes
        after a dataset has been consumed.
    :param pin_memory: If `True`, the data loader will copy Tensors into device/CUDA pinned memory
        before returning them.
    :param generator: Random generator used for sampling batches.
    """

    def __init__(
        self,
        batch_size: int,
        batch_size_test: int | None = None,
        train_validation_split: list[float] = [0.9, 0.1],
        transform: transforms.Transform | None = transforms.Compose(
            [
                transforms.ToImage(),
                transforms.ToDtype(torch.float32, scale=True),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        ),
        target_transform: transforms.Transform | None = None,
        data_augmentation_transform: transforms.Transform | None = None,
        corruptions: list[str] = [
            "brightness",
            "canny_edges",
            "dotted_line",
            "fog",
            "glass_blur",
            "impulse_noise",
            "motion_blur",
            "rotate",
            "scale",
            "shear",
            "shot_noise",
            "spatter",
            "stripe",
            "translate",
            "zigzag",
        ],
        data_dir: Path | None = None,
        pin_memory: bool = True,
        num_workers: int = 0,
        persistent_workers: bool = True,
        generator: torch.Generator | None = None,
    ) -> None:
        # Store a copy so that the shared default list object is never aliased
        # by (and possibly mutated through) an instance.
        self.corruptions = list(corruptions)

        super().__init__(
            input_shape=torch.Size((1, 28, 28)),
            num_classes=10,
            train_and_validation_set_size=0,
            batch_size=batch_size,
            # 10000 test images per selected corruption.
            test_set_size=10000 * len(self.corruptions),
            batch_size_test=batch_size_test,
            train_validation_split=list(train_validation_split),
            transform=transform,
            target_transform=target_transform,
            data_augmentation_transform=data_augmentation_transform,
            # Resolve the working directory when the object is created rather
            # than once at import time.
            data_dir=Path.cwd() if data_dir is None else data_dir,
            pin_memory=pin_memory,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
            generator=generator,
        )

    def prepare_data(self) -> None:
        """Download the MNIST-C dataset into ``self.data_dir``."""
        inferno.datasets.MNISTC(
            self.data_dir,
            transform=self.transform,
            target_transform=self.target_transform,
            corruptions=self.corruptions,
            download=True,
        )

    def setup(self, stage: str | None = None) -> None:
        """Create the test set.

        There is no training or validation data, so the ``"fit"`` stage is a no-op.

        :param stage: ``"test"`` or ``None`` creates the test set; any other
            stage does nothing.
        """
        # NOTE: do not return early for ``stage is None``; ``None`` means
        # "set up everything", which here is just the test set.
        if stage == "test" or stage is None:
            self.data_test = inferno.datasets.MNISTC(
                self.data_dir,
                transform=self.transform,
                target_transform=self.target_transform,
                corruptions=self.corruptions,
            )

    def train_dataloader(self):
        """Not available: this data module has no training set."""
        raise NotImplementedError("MNISTC is a test-only dataset without a training set.")

    def val_dataloader(self):
        """Not available: this data module has no validation set."""
        raise NotImplementedError("MNISTC is a test-only dataset without a validation set.")
