from __future__ import annotations

from pathlib import Path

import inferno
import torch
import torchvision
from torch.utils.data import random_split
from torchvision.transforms import v2 as transforms

from .dataset import LightningDataset


class CIFAR10(LightningDataset):
    """CIFAR10 image classification dataset.

    :param batch_size: Batch size for training.
    :param batch_size_test: Batch size for testing.
    :param train_validation_split: Fraction of data to use for training and validation.
    :param transform: Transform to apply to the data.
    :param target_transform: Transform to apply to the targets.
    :param data_augmentation_transform: Data augmentation to apply to the data.
        This is applied to the training set only, before `transform`.
    :param data_dir: Directory to download the dataset to. The data is stored in the
        subdirectory `CIFAR10/raw`. The default is the working directory at the time this
        module was imported.
    :param num_workers: How many subprocesses to use for data loading.
        `0` means that the data will be loaded in the main process.
    :param persistent_workers: If `True`, the data loader will not shutdown the worker processes
        after a dataset has been consumed.
    :param pin_memory: If `True`, the data loader will copy Tensors into device/CUDA pinned memory
        before returning them.
    :param generator: Random generator used for sampling batches.
    """

    def __init__(
        self,
        batch_size: int,
        batch_size_test: int | None = None,
        train_validation_split: list[float] = [0.9, 0.1],
        transform: transforms.Transform | None = transforms.Compose(
            [
                transforms.ToImage(),
                transforms.ToDtype(torch.float32, scale=True),
                transforms.Normalize(
                    mean=(0.4914, 0.4822, 0.4465),
                    std=(0.2470, 0.2435, 0.2616),
                ),
            ]
        ),
        target_transform: transforms.Transform | None = None,
        data_augmentation_transform: transforms.Transform | None = transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
            ]
        ),
        data_dir: Path = Path.cwd(),
        pin_memory: bool = True,
        persistent_workers: bool = True,
        num_workers: int = 0,
        generator: torch.Generator | None = None,
    ) -> None:

        super().__init__(
            input_shape=torch.Size((3, 32, 32)),
            num_classes=10,
            train_and_validation_set_size=50000,
            batch_size=batch_size,
            test_set_size=10000,
            batch_size_test=batch_size_test,
            # Pass a copy so the default list object is never shared between instances.
            train_validation_split=list(train_validation_split),
            transform=transform,
            target_transform=target_transform,
            data_augmentation_transform=data_augmentation_transform,
            data_dir=data_dir / Path("CIFAR10/raw"),
            pin_memory=pin_memory,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
            generator=generator,
        )

    def prepare_data(self) -> None:
        """Download the train and test set to `self.data_dir`."""
        torchvision.datasets.CIFAR10(self.data_dir, train=True, download=True)
        torchvision.datasets.CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage: str | None = None) -> None:
        """Create the datasets required for the given stage.

        :param stage: `"fit"` creates the training and validation set, `"test"` creates the
            test set and `None` creates all of them.
        """
        if stage == "fit" or stage is None:
            data_full = torchvision.datasets.CIFAR10(
                self.data_dir,
                train=True,
                transform=self.transform,
                target_transform=self.target_transform,
            )
            train_split, self.data_val = random_split(
                data_full,
                (self.train_set_size, self.validation_set_size),
                generator=torch.Generator().manual_seed(
                    42
                ),  # Fixed train / validation split.
            )

            if self.data_augmentation_transform is None:
                # Without augmentation the training set uses the same transform as validation.
                self.data_train = train_split
            else:
                # Augmentation is applied first, followed by the regular transform (if any).
                train_transform = self.data_augmentation_transform
                if self.transform is not None:
                    train_transform = transforms.Compose(
                        [
                            self.data_augmentation_transform,
                            self.transform,
                        ]
                    )

                # A subset fetches its items from the dataset it wraps, so setting a
                # `transform` attribute on the subset itself would have no effect. The
                # augmentation therefore lives in a second dataset instance, which is
                # indexed with exactly the training indices of the split above.
                data_full_augmented = torchvision.datasets.CIFAR10(
                    self.data_dir,
                    train=True,
                    transform=train_transform,
                    target_transform=self.target_transform,
                )
                self.data_train = torch.utils.data.Subset(
                    data_full_augmented, train_split.indices
                )

        if stage == "test" or stage is None:
            self.data_test = torchvision.datasets.CIFAR10(
                self.data_dir,
                train=False,
                transform=self.transform,
                target_transform=self.target_transform,
            )


class CIFAR100(LightningDataset):
    """CIFAR100 image classification dataset.

    :param batch_size: Batch size for training.
    :param batch_size_test: Batch size for testing.
    :param train_validation_split: Fraction of data to use for training and validation.
    :param transform: Transform to apply to the data.
    :param target_transform: Transform to apply to the targets.
    :param data_augmentation_transform: Data augmentation to apply to the data.
        This is applied to the training set only, before `transform`.
    :param data_dir: Directory to download the dataset to. The data is stored in the
        subdirectory `CIFAR100/raw`. The default is the working directory at the time this
        module was imported.
    :param num_workers: How many subprocesses to use for data loading.
        `0` means that the data will be loaded in the main process.
    :param persistent_workers: If `True`, the data loader will not shutdown the worker processes
        after a dataset has been consumed.
    :param pin_memory: If `True`, the data loader will copy Tensors into device/CUDA pinned memory
        before returning them.
    :param generator: Random generator used for sampling batches.
    """

    def __init__(
        self,
        batch_size: int,
        batch_size_test: int | None = None,
        train_validation_split: list[float] = [0.9, 0.1],
        transform: transforms.Transform | None = transforms.Compose(
            [
                transforms.ToImage(),
                transforms.ToDtype(torch.float32, scale=True),
                transforms.Normalize(
                    mean=(0.5071, 0.4867, 0.4408),
                    std=(0.2675, 0.2565, 0.2761),
                ),
            ]
        ),
        target_transform: transforms.Transform | None = None,
        data_augmentation_transform: transforms.Transform | None = transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
            ]
        ),
        data_dir: Path = Path.cwd(),
        pin_memory: bool = True,
        num_workers: int = 0,
        persistent_workers: bool = True,
        generator: torch.Generator | None = None,
    ) -> None:

        super().__init__(
            input_shape=torch.Size((3, 32, 32)),
            num_classes=100,
            train_and_validation_set_size=50000,
            batch_size=batch_size,
            test_set_size=10000,
            batch_size_test=batch_size_test,
            # Pass a copy so the default list object is never shared between instances.
            train_validation_split=list(train_validation_split),
            transform=transform,
            target_transform=target_transform,
            data_augmentation_transform=data_augmentation_transform,
            data_dir=data_dir / Path("CIFAR100/raw"),
            pin_memory=pin_memory,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
            generator=generator,
        )

    def prepare_data(self) -> None:
        """Download the train and test set to `self.data_dir`."""
        torchvision.datasets.CIFAR100(self.data_dir, train=True, download=True)
        torchvision.datasets.CIFAR100(self.data_dir, train=False, download=True)

    def setup(self, stage: str | None = None) -> None:
        """Create the datasets required for the given stage.

        :param stage: `"fit"` creates the training and validation set, `"test"` creates the
            test set and `None` creates all of them.
        """
        if stage == "fit" or stage is None:
            data_full = torchvision.datasets.CIFAR100(
                self.data_dir,
                train=True,
                transform=self.transform,
                target_transform=self.target_transform,
            )
            train_split, self.data_val = random_split(
                data_full,
                (self.train_set_size, self.validation_set_size),
                generator=torch.Generator().manual_seed(
                    42
                ),  # Fixed train / validation split.
            )

            if self.data_augmentation_transform is None:
                # Without augmentation the training set uses the same transform as validation.
                self.data_train = train_split
            else:
                # Augmentation is applied first, followed by the regular transform (if any).
                train_transform = self.data_augmentation_transform
                if self.transform is not None:
                    train_transform = transforms.Compose(
                        [
                            self.data_augmentation_transform,
                            self.transform,
                        ]
                    )

                # A subset fetches its items from the dataset it wraps, so setting a
                # `transform` attribute on the subset itself would have no effect. The
                # augmentation therefore lives in a second dataset instance, which is
                # indexed with exactly the training indices of the split above.
                data_full_augmented = torchvision.datasets.CIFAR100(
                    self.data_dir,
                    train=True,
                    transform=train_transform,
                    target_transform=self.target_transform,
                )
                self.data_train = torch.utils.data.Subset(
                    data_full_augmented, train_split.indices
                )

        if stage == "test" or stage is None:
            self.data_test = torchvision.datasets.CIFAR100(
                self.data_dir,
                train=False,
                transform=self.transform,
                target_transform=self.target_transform,
            )


class CIFAR10C(LightningDataset):
    """Corrupted CIFAR10 image classification dataset.

    From [Benchmarking Neural Network Robustness to Common Corruptions and Perturbations](https://arxiv.org/abs/1903.12261).

    This dataset only provides a test set; requesting a training or validation data loader
    raises `NotImplementedError`.

    :param batch_size: Batch size for training.
    :param batch_size_test: Batch size for testing.
    :param train_validation_split: Fraction of data to use for training and validation.
    :param transform: Transform to apply to the data.
    :param target_transform: Transform to apply to the targets.
    :param data_augmentation_transform: Data augmentation to apply to the data.
        Only forwarded to the base class, since this dataset has no training set.
    :param corruptions: List of corruptions to apply to the data.
    :param shift_severity: Severity of the corruption to apply.
        Must be an integer between 1 and 5.
    :param data_dir: Directory to download the dataset to. The default is the working
        directory at the time this module was imported.
    :param num_workers: How many subprocesses to use for data loading.
        `0` means that the data will be loaded in the main process.
    :param persistent_workers: If `True`, the data loader will not shutdown the worker processes
        after a dataset has been consumed.
    :param pin_memory: If `True`, the data loader will copy Tensors into device/CUDA pinned memory
        before returning them.
    :param generator: Random generator used for sampling batches.
    """

    def __init__(
        self,
        batch_size: int,
        batch_size_test: int | None = None,
        train_validation_split: list[float] = [0.9, 0.1],
        transform: transforms.Transform | None = transforms.Compose(
            [
                transforms.ToImage(),
                transforms.ToDtype(torch.float32, scale=True),
                # NOTE(review): these statistics differ from the ones used by `CIFAR10` in
                # this file -- confirm that this is intended.
                transforms.Normalize(
                    mean=(0.5048, 0.4970, 0.4642),
                    std=(0.2427, 0.2396, 0.2574),
                ),
            ]
        ),
        target_transform: transforms.Transform | None = None,
        data_augmentation_transform: transforms.Transform | None = None,
        corruptions: list[str] = [
            "brightness",
            "contrast",
            "defocus_blur",
            "elastic_transform",
            "fog",
            "frost",
            "gaussian_blur",
            "gaussian_noise",
            "glass_blur",
            "impulse_noise",
            "jpeg_compression",
            "motion_blur",
            "pixelate",
            "saturate",
            "shot_noise",
            "snow",
            "spatter",
            "speckle_noise",
            "zoom_blur",
        ],
        shift_severity: int = 5,
        data_dir: Path = Path.cwd(),
        pin_memory: bool = True,
        num_workers: int = 0,
        persistent_workers: bool = True,
        generator: torch.Generator | None = None,
    ) -> None:

        # Store a copy so that instances never share (and cannot mutate) the default list.
        self.corruptions = list(corruptions)
        self.shift_severity = shift_severity

        super().__init__(
            input_shape=torch.Size((3, 32, 32)),
            num_classes=10,
            train_and_validation_set_size=0,
            batch_size=batch_size,
            # Each selected corruption contributes 10000 test images.
            test_set_size=10000 * len(self.corruptions),
            batch_size_test=batch_size_test,
            train_validation_split=list(train_validation_split),
            transform=transform,
            target_transform=target_transform,
            data_augmentation_transform=data_augmentation_transform,
            data_dir=data_dir,
            pin_memory=pin_memory,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
            generator=generator,
        )

    def prepare_data(self) -> None:
        """Download the dataset to `self.data_dir`."""
        inferno.datasets.CIFAR10C(
            self.data_dir,
            transform=self.transform,
            target_transform=self.target_transform,
            corruptions=self.corruptions,
            shift_severity=self.shift_severity,
            download=True,
        )

    def setup(self, stage: str | None = None) -> None:
        """Create the test set if the given stage requires it.

        :param stage: `"test"` or `None` creates the test set. Nothing is created for any
            other stage, since there is no training or validation data.
        """
        # `None` requests all stages and therefore has to create the test set as well.
        if stage == "test" or stage is None:
            self.data_test = inferno.datasets.CIFAR10C(
                self.data_dir,
                transform=self.transform,
                target_transform=self.target_transform,
                corruptions=self.corruptions,
                shift_severity=self.shift_severity,
            )

    def train_dataloader(self):
        """Not available, since this dataset has no training set."""
        raise NotImplementedError

    def val_dataloader(self):
        """Not available, since this dataset has no validation set."""
        raise NotImplementedError


class CIFAR100C(LightningDataset):
    """Corrupted CIFAR100 image classification dataset.

    From [Benchmarking Neural Network Robustness to Common Corruptions and Perturbations](https://arxiv.org/abs/1903.12261).

    This dataset only provides a test set; requesting a training or validation data loader
    raises `NotImplementedError`.

    :param batch_size: Batch size for training.
    :param batch_size_test: Batch size for testing.
    :param train_validation_split: Fraction of data to use for training and validation.
    :param transform: Transform to apply to the data.
    :param target_transform: Transform to apply to the targets.
    :param data_augmentation_transform: Data augmentation to apply to the data.
        Only forwarded to the base class, since this dataset has no training set.
    :param corruptions: List of corruptions to apply to the data.
    :param shift_severity: Severity of the corruption to apply.
        Must be an integer between 1 and 5.
    :param data_dir: Directory to download the dataset to. The default is the working
        directory at the time this module was imported.
    :param num_workers: How many subprocesses to use for data loading.
        `0` means that the data will be loaded in the main process.
    :param persistent_workers: If `True`, the data loader will not shutdown the worker processes
        after a dataset has been consumed.
    :param pin_memory: If `True`, the data loader will copy Tensors into device/CUDA pinned memory
        before returning them.
    :param generator: Random generator used for sampling batches.
    """

    def __init__(
        self,
        batch_size: int,
        batch_size_test: int | None = None,
        train_validation_split: list[float] = [0.9, 0.1],
        transform: transforms.Transform | None = transforms.Compose(
            [
                transforms.ToImage(),
                transforms.ToDtype(torch.float32, scale=True),
                transforms.Normalize(
                    mean=(0.5071, 0.4867, 0.4408),
                    std=(0.2675, 0.2565, 0.2761),
                ),
            ]
        ),
        target_transform: transforms.Transform | None = None,
        data_augmentation_transform: transforms.Transform | None = None,
        corruptions: list[str] = [
            "brightness",
            "contrast",
            "defocus_blur",
            "elastic_transform",
            "fog",
            "frost",
            "gaussian_blur",
            "gaussian_noise",
            "glass_blur",
            "impulse_noise",
            "jpeg_compression",
            "motion_blur",
            "pixelate",
            "saturate",
            "shot_noise",
            "snow",
            "spatter",
            "speckle_noise",
            "zoom_blur",
        ],
        shift_severity: int = 5,
        data_dir: Path = Path.cwd(),
        pin_memory: bool = True,
        num_workers: int = 0,
        persistent_workers: bool = True,
        generator: torch.Generator | None = None,
    ) -> None:

        # Store a copy so that instances never share (and cannot mutate) the default list.
        self.corruptions = list(corruptions)
        self.shift_severity = shift_severity

        super().__init__(
            input_shape=torch.Size((3, 32, 32)),
            num_classes=100,
            train_and_validation_set_size=0,
            batch_size=batch_size,
            # Each selected corruption contributes 10000 test images.
            test_set_size=10000 * len(self.corruptions),
            batch_size_test=batch_size_test,
            train_validation_split=list(train_validation_split),
            transform=transform,
            target_transform=target_transform,
            data_augmentation_transform=data_augmentation_transform,
            data_dir=data_dir,
            pin_memory=pin_memory,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
            generator=generator,
        )

    def prepare_data(self) -> None:
        """Download the dataset to `self.data_dir`."""
        inferno.datasets.CIFAR100C(
            self.data_dir,
            transform=self.transform,
            target_transform=self.target_transform,
            corruptions=self.corruptions,
            shift_severity=self.shift_severity,
            download=True,
        )

    def setup(self, stage: str | None = None) -> None:
        """Create the test set if the given stage requires it.

        :param stage: `"test"` or `None` creates the test set. Nothing is created for any
            other stage, since there is no training or validation data.
        """
        # `None` requests all stages and therefore has to create the test set as well.
        if stage == "test" or stage is None:
            self.data_test = inferno.datasets.CIFAR100C(
                self.data_dir,
                transform=self.transform,
                target_transform=self.target_transform,
                corruptions=self.corruptions,
                shift_severity=self.shift_severity,
            )

    def train_dataloader(self):
        """Not available, since this dataset has no training set."""
        raise NotImplementedError

    def val_dataloader(self):
        """Not available, since this dataset has no validation set."""
        raise NotImplementedError
