from __future__ import annotations

from pathlib import Path

import inferno
import torch
from torch.utils.data import random_split
from torchvision.transforms import v2 as transforms

from .dataset import LightningDataset


class TinyImageNet(LightningDataset):
    """TinyImageNet image classification dataset.

    Images have shape ``(3, 64, 64)`` and belong to one of 200 classes. The
    100,000 training images are split into a training and a validation set
    according to ``train_validation_split``; the test set has 10,000 images.

    :param batch_size: Batch size for training.
    :param batch_size_test: Batch size for testing.
    :param train_validation_split: Fraction of data to use for training and validation.
    :param transform: Transform to apply to the data.
    :param target_transform: Transform to apply to the targets.
    :param data_augmentation_transform: Data augmentation to apply to the data.
        This is applied to the training set only, before ``transform``.
    :param data_dir: Directory to download the dataset to. The data is stored in the
        subdirectory ``TinyImageNet/raw``.
    :param num_workers: How many subprocesses to use for data loading.
        `0` means that the data will be loaded in the main process.
    :param persistent_workers: If `True`, the data loader will not shutdown the worker processes
        after a dataset has been consumed.
    :param pin_memory: If `True`, the data loader will copy Tensors into device/CUDA pinned memory
        before returning them.
    :param generator: Random generator used for sampling batches.
    """

    def __init__(
        self,
        batch_size: int,
        batch_size_test: int | None = None,
        train_validation_split: list[float] = [0.9, 0.1],
        transform: transforms.Transform | None = transforms.Compose(
            [
                transforms.ToImage(),
                transforms.ToDtype(torch.float32, scale=True),
                transforms.Normalize(
                    mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225),
                ),
            ]
        ),
        target_transform: transforms.Transform | None = None,
        data_augmentation_transform: transforms.Transform | None = None,
        data_dir: Path = Path.cwd(),
        pin_memory: bool = True,
        num_workers: int = 0,
        persistent_workers: bool = True,
        generator: torch.Generator | None = None,
    ) -> None:
        super().__init__(
            input_shape=torch.Size((3, 64, 64)),
            num_classes=200,
            train_and_validation_set_size=100000,
            batch_size=batch_size,
            test_set_size=10000,
            batch_size_test=batch_size_test,
            train_validation_split=train_validation_split,
            transform=transform,
            target_transform=target_transform,
            data_augmentation_transform=data_augmentation_transform,
            data_dir=data_dir / Path("TinyImageNet/raw"),
            pin_memory=pin_memory,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
            generator=generator,
        )

    def prepare_data(self) -> None:
        """Download the training and the test set into ``self.data_dir``."""
        for train in (True, False):
            inferno.datasets.TinyImageNet(self.data_dir, train=train, download=True)

    def setup(self, stage: str | None = None) -> None:
        """Create the datasets needed for the given stage.

        :param stage: ``"fit"`` or ``"validate"`` builds the training and validation
            sets, ``"test"`` builds the test set, and ``None`` builds all of them.
        """
        if stage in (None, "fit", "validate"):
            data_full = inferno.datasets.TinyImageNet(
                self.data_dir,
                train=True,
                transform=self.transform,
                target_transform=self.target_transform,
            )
            # Fixed seed so that the train / validation split is identical across runs.
            split_generator = torch.Generator().manual_seed(42)
            self.data_train, self.data_val = random_split(
                data_full,
                (self.train_set_size, self.validation_set_size),
                generator=split_generator,
            )

            if self.data_augmentation_transform is not None:
                # Assigning ``.transform`` on the ``Subset`` returned by ``random_split``
                # has no effect, because ``Subset`` only indexes into its wrapped
                # dataset. Instead, a second dataset instance carries the augmented
                # pipeline and is restricted to the same training indices, so the
                # validation set stays free of augmentation.
                train_transform = transforms.Compose(
                    [
                        t
                        for t in (self.data_augmentation_transform, self.transform)
                        if t is not None
                    ]
                )
                data_full_augmented = inferno.datasets.TinyImageNet(
                    self.data_dir,
                    train=True,
                    transform=train_transform,
                    target_transform=self.target_transform,
                )
                self.data_train = torch.utils.data.Subset(
                    data_full_augmented, self.data_train.indices
                )

        if stage in (None, "test"):
            self.data_test = inferno.datasets.TinyImageNet(
                self.data_dir,
                train=False,
                transform=self.transform,
                target_transform=self.target_transform,
            )


class TinyImageNetC(LightningDataset):
    """Corrupted TinyImageNet image classification dataset.

    From [Benchmarking Neural Network Robustness to Common Corruptions and Perturbations](https://arxiv.org/abs/1903.12261).

    This dataset only provides a test set (10,000 images per corruption); the
    training and validation data loaders are not available.

    :param batch_size: Batch size for training.
    :param batch_size_test: Batch size for testing.
    :param train_validation_split: Fraction of data to use for training and validation.
    :param transform: Transform to apply to the data.
    :param target_transform: Transform to apply to the targets.
    :param data_augmentation_transform: Data augmentation to apply to the data.
        This is applied to the training set only. Since this dataset has no training
        set, it is only forwarded to the base class.
    :param corruptions: List of corruptions to apply to the data.
    :param shift_severity: Severity of the corruption to apply.
        Must be an integer between 1 and 5.
    :param data_dir: Directory to download the dataset to.
    :param num_workers: How many subprocesses to use for data loading.
        `0` means that the data will be loaded in the main process.
    :param persistent_workers: If `True`, the data loader will not shutdown the worker processes
        after a dataset has been consumed.
    :param pin_memory: If `True`, the data loader will copy Tensors into device/CUDA pinned memory
        before returning them.
    :param generator: Random generator used for sampling batches.
    """

    def __init__(
        self,
        batch_size: int,
        batch_size_test: int | None = None,
        train_validation_split: list[float] = [0.9, 0.1],
        transform: transforms.Transform | None = transforms.Compose(
            [
                transforms.ToImage(),
                transforms.ToDtype(torch.float32, scale=True),
                transforms.Normalize(
                    mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225),
                ),
            ]
        ),
        target_transform: transforms.Transform | None = None,
        data_augmentation_transform: transforms.Transform | None = None,
        corruptions: list[str] = [
            "brightness",
            "contrast",
            "defocus_blur",
            "elastic_transform",
            "fog",
            "frost",
            "gaussian_blur",
            "gaussian_noise",
            "glass_blur",
            "impulse_noise",
            "jpeg_compression",
            "motion_blur",
            "pixelate",
            "saturate",
            "shot_noise",
            "snow",
            "spatter",
            "speckle_noise",
            "zoom_blur",
        ],
        shift_severity: int = 5,
        data_dir: Path = Path.cwd(),
        pin_memory: bool = True,
        num_workers: int = 0,
        persistent_workers: bool = True,
        generator: torch.Generator | None = None,
    ) -> None:
        # Copy so that the instance never aliases the shared default list
        # (or a list the caller keeps mutating).
        self.corruptions = list(corruptions)
        self.shift_severity = shift_severity

        super().__init__(
            input_shape=torch.Size((3, 64, 64)),
            num_classes=200,
            # Test-only dataset: there is no training / validation data.
            train_and_validation_set_size=0,
            batch_size=batch_size,
            test_set_size=10000 * len(self.corruptions),
            batch_size_test=batch_size_test,
            train_validation_split=train_validation_split,
            transform=transform,
            target_transform=target_transform,
            data_augmentation_transform=data_augmentation_transform,
            data_dir=data_dir,
            pin_memory=pin_memory,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
            generator=generator,
        )

    def prepare_data(self) -> None:
        """Download the corrupted dataset into ``self.data_dir``."""
        inferno.datasets.TinyImageNetC(
            self.data_dir,
            transform=self.transform,
            target_transform=self.target_transform,
            corruptions=self.corruptions,
            shift_severity=self.shift_severity,
            download=True,
        )

    def setup(self, stage: str | None = None) -> None:
        """Create the test set.

        :param stage: The test set is built for ``"test"`` and for ``None``
            (meaning all stages). Any other stage is a no-op because this
            dataset has no training or validation data.
        """
        if stage in (None, "test"):
            self.data_test = inferno.datasets.TinyImageNetC(
                self.data_dir,
                transform=self.transform,
                target_transform=self.target_transform,
                corruptions=self.corruptions,
                shift_severity=self.shift_severity,
            )

    def train_dataloader(self):
        """Not available: this dataset has no training set."""
        raise NotImplementedError(
            "TinyImageNetC is a test-only dataset and has no training set."
        )

    def val_dataloader(self):
        """Not available: this dataset has no validation set."""
        raise NotImplementedError(
            "TinyImageNetC is a test-only dataset and has no validation set."
        )
