from __future__ import annotations

import abc
import math
from pathlib import Path

import inferno
import lightning as L
import torch
import torchvision
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import v2 as transforms


class LightningDataset(L.LightningDataModule, abc.ABC):
    """Abstract base dataset wrapped as a Lightning data module.

    Stores dataset sizes, transforms and data loader settings, and builds the
    train / validation / test data loaders over ``self.data_train``,
    ``self.data_val`` and ``self.data_test``. Those attributes are initialised
    to ``None`` here and must be assigned elsewhere (presumably in a subclass's
    ``setup``) before the loaders are requested.

    :param input_shape: Input shape.
    :param train_and_validation_set_size: Size of the training and validation set combined.
    :param batch_size: Batch size for training.
    :param test_set_size: Size of the test set.
    :param batch_size_test: Batch size for validation and testing. Defaults to ``batch_size``.
    :param train_validation_split: Fractions ``[train, validation]`` of the combined
        training and validation set. Defaults to ``[0.9, 0.1]``. Sizes are truncated
        towards zero; if the two fractions sum to one, samples lost to truncation are
        assigned round-robin (training set first) so the two sizes add up to
        ``train_and_validation_set_size``.
    :param transform: Transform to apply to the data.
    :param target_transform: Transform to apply to the targets.
    :param data_augmentation_transform: Data augmentation to apply to the data.
        This is applied to the training set only.
    :param data_dir: Directory to download the dataset to. Defaults to the current
        working directory at construction time.
    :param pin_memory: If `True`, the data loader will copy Tensors into device/CUDA pinned memory
        before returning them.
    :param num_workers: How many subprocesses to use for data loading.
        `0` means that the data will be loaded in the main process.
    :param persistent_workers: If `True`, the data loader will not shutdown the worker processes
        after a dataset has been consumed.
        This is useful when using a data loader in a loop, as it will save the overhead of
        creating and destroying worker processes. Will be ignored if
        `num_workers` is `0`.
    :param generator: Random generator used for sampling batches.
    :param **kwargs: Keyword arguments used to set additional attributes of a dataset.
    """

    def __init__(
        self,
        input_shape: tuple[int, ...],
        train_and_validation_set_size: int,
        batch_size: int,
        test_set_size: int,
        batch_size_test: int | None = None,
        train_validation_split: list[float] | None = None,
        transform: transforms.Transform | None = None,
        target_transform: transforms.Transform | None = None,
        data_augmentation_transform: transforms.Transform | None = None,
        data_dir: Path | None = None,
        pin_memory: bool = True,
        num_workers: int = 0,
        persistent_workers: bool = True,
        generator: torch.Generator | None = None,
        **kwargs,
    ) -> None:
        super().__init__()

        # A fresh list per instance: a mutable default argument would be shared
        # (and any mutation leaked) across all instances.
        if train_validation_split is None:
            train_validation_split = [0.9, 0.1]

        self.input_shape = input_shape
        self.train_and_validation_set_size = train_and_validation_set_size
        # Copy so later mutation of the caller's list does not affect this instance.
        self.train_validation_split = list(train_validation_split)
        self.train_set_size, self.validation_set_size = self._compute_split_sizes(
            self.train_and_validation_set_size, self.train_validation_split
        )
        self.test_set_size = test_set_size

        # Resolved here rather than as a default argument, which would freeze the
        # working directory at import time.
        self.data_dir = Path.cwd() if data_dir is None else data_dir

        self.batch_size = batch_size
        # ``DataLoader(batch_size=None)`` disables automatic batching, so fall back
        # to the training batch size rather than passing ``None`` through.
        self.batch_size_test = batch_size if batch_size_test is None else batch_size_test

        self.pin_memory = pin_memory
        self.num_workers = num_workers
        # Persistent workers are only meaningful when worker processes exist.
        self.persistent_workers = False if num_workers == 0 else persistent_workers
        self.generator = generator

        self.transform = transform
        self.target_transform = target_transform
        self.data_augmentation_transform = data_augmentation_transform

        self.data_train: torch.utils.data.Dataset | None = None
        self.data_val: torch.utils.data.Dataset | None = None
        self.data_test: torch.utils.data.Dataset | None = None

        # Extra attributes (e.g. ``num_classes``) are set before the
        # hyperparameters are saved so they can be recorded below.
        for key, value in kwargs.items():
            setattr(self, key, value)

        self.save_hyperparameters(
            {
                "dataset": self.__class__.__name__,
                "num_classes": getattr(self, "num_classes", None),
                "input_shape": list(self.input_shape),
                "train_validation_split": self.train_validation_split,
                "train_set_size": self.train_set_size,
                "validation_set_size": self.validation_set_size,
                "test_set_size": self.test_set_size,
                "batch_size": self.batch_size,
                "batch_size_test": self.batch_size_test,
            }
        )

    @staticmethod
    def _compute_split_sizes(total: int, split: list[float]) -> tuple[int, int]:
        """Return ``(train_size, validation_size)`` for ``total`` samples.

        Only the first two entries of ``split`` are used. Each size is
        ``int(total * fraction)``. When the two fractions sum to one, the samples
        lost to truncation are added one at a time, round-robin starting with the
        training set, so that the sizes add up to ``total``. This mirrors
        ``torch.utils.data.random_split`` for fractional lengths.

        :raises IndexError: If ``split`` has fewer than two entries.
        """
        fractions = split[:2]
        sizes = [int(total * fraction) for fraction in fractions]
        if math.isclose(sum(fractions), 1.0):
            for i in range(total - sum(sizes)):
                sizes[i % len(sizes)] += 1
        return sizes[0], sizes[1]

    def _build_dataloader(
        self,
        dataset: torch.utils.data.Dataset | None,
        batch_size: int,
        shuffle: bool,
    ) -> DataLoader:
        """Create a ``DataLoader`` over ``dataset`` using the shared loader settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            pin_memory=self.pin_memory,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
            generator=self.generator,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training data loader."""
        return self._build_dataloader(self.data_train, self.batch_size, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Return the (unshuffled) validation data loader."""
        return self._build_dataloader(self.data_val, self.batch_size_test, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """Return the (unshuffled) test data loader."""
        return self._build_dataloader(self.data_test, self.batch_size_test, shuffle=False)
