from __future__ import annotations

from pathlib import Path

import sklearn
import torch
import torchvision
from torch.utils.data import random_split
from torchvision.transforms import v2 as transforms

from .dataset import LightningDataset


class TwoMoons(LightningDataset):
    """Two Moons classification dataset.

    The data is generated synthetically with :func:`sklearn.datasets.make_moons`
    using fixed seeds, so the training/validation and test sets are reproducible.
    Inputs are standardized with the per-feature mean and standard deviation of the
    training and validation sample; the same statistics are applied to the test set.

    :param batch_size: Batch size for training.
    :param batch_size_test: Batch size for testing.
    :param noise: Standard deviation of Gaussian noise added to the input data.
    :param train_and_validation_set_size: Size of the training and validation set.
    :param test_set_size: Size of the test set.
    :param train_validation_split: Fraction of data to use for training and validation.
    :param transform: Transform to apply to the data.
    :param target_transform: Transform to apply to the targets.
    :param data_augmentation_transform: Data augmentation to apply to the data.
        This is applied to the training set only.
    :param data_dir: Directory to download the dataset to. The default is the working
        directory at the time this module was imported.
    :param num_workers: How many subprocesses to use for data loading.
        `0` means that the data will be loaded in the main process.
    :param persistent_workers: If `True`, the data loader will not shutdown the worker processes
        after a dataset has been consumed.
    :param pin_memory: If `True`, the data loader will copy Tensors into device/CUDA pinned memory
        before returning them.
    :param generator: Random generator used for sampling batches.
    """

    # Fixed seeds keep the generated data and the train / validation split reproducible.
    _TRAIN_AND_VALIDATION_SEED = 9674
    _TEST_SEED = 2445
    _SPLIT_SEED = 42

    def __init__(
        self,
        batch_size: int,
        batch_size_test: int | None = None,
        noise: float = 0.15,
        train_and_validation_set_size: int = 400,
        test_set_size: int = 1000,
        train_validation_split: list[float] = [0.5, 0.5],
        transform: transforms.Transform | None = None,
        target_transform: transforms.Transform | None = None,
        data_augmentation_transform: transforms.Transform | None = None,
        data_dir: Path = Path.cwd(),
        pin_memory: bool = True,
        num_workers: int = 0,
        persistent_workers: bool = True,
        generator: torch.Generator | None = None,
    ) -> None:
        self.noise = noise

        super().__init__(
            input_shape=torch.Size((2,)),
            num_classes=2,
            train_and_validation_set_size=train_and_validation_set_size,
            batch_size=batch_size,
            test_set_size=test_set_size,
            batch_size_test=batch_size_test,
            # Pass a copy: the default list object is shared by every call, so handing
            # it out directly would let one instance's mutation leak into all others.
            train_validation_split=list(train_validation_split),
            transform=transform,
            target_transform=target_transform,
            data_augmentation_transform=data_augmentation_transform,
            data_dir=data_dir,
            pin_memory=pin_memory,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
            generator=generator,
        )

    def prepare_data(self) -> None:
        """Do nothing; the data is generated in :meth:`setup`, not downloaded."""

    def _sample_moons(self, n_samples: int, random_state: int):
        """Draw ``n_samples`` noisy two-moons points with the given seed.

        :param n_samples: Number of points to generate.
        :param random_state: Seed passed to ``make_moons``.
        :return: Tuple ``(X, y)`` of inputs and labels as returned by ``make_moons``.
        """
        # Import the function explicitly: a bare ``import sklearn`` does not reliably
        # expose the ``datasets`` submodule as an attribute across scikit-learn versions.
        from sklearn.datasets import make_moons

        return make_moons(
            n_samples=n_samples,
            noise=self.noise,
            random_state=random_state,
        )

    def _fit_normalization(self, X) -> None:
        """Store the mean and standard deviation of ``X`` taken along its sample axis."""
        self.X_mean = X.mean(axis=-2)
        self.X_std = X.std(axis=-2)

    def _to_tensor_dataset(self, X, y) -> torch.utils.data.TensorDataset:
        """Standardize ``X`` with the stored statistics and wrap ``(X, y)`` as float32 tensors."""
        return torch.utils.data.TensorDataset(
            torch.as_tensor((X - self.X_mean) / self.X_std, dtype=torch.float32),
            torch.as_tensor(y, dtype=torch.float32),
        )

    def setup(self, stage: str | None = None) -> None:
        """Generate the datasets needed for the given stage.

        :param stage: ``"fit"`` builds the training and validation sets, ``"test"``
            builds the test set, and ``None`` builds all of them. Other values do nothing.
        """
        if stage == "fit" or stage is None:
            X, y = self._sample_moons(
                self.train_and_validation_set_size,
                self._TRAIN_AND_VALIDATION_SEED,
            )
            self._fit_normalization(X)
            data_full = self._to_tensor_dataset(X, y)
            self.data_train, self.data_val = random_split(
                data_full,
                (self.train_set_size, self.validation_set_size),
                # Fixed train / validation split.
                generator=torch.Generator().manual_seed(self._SPLIT_SEED),
            )
            # TODO: `transform` and `data_augmentation_transform` are not applied
            # to the data currently.

        if stage == "test" or stage is None:
            if (
                getattr(self, "X_mean", None) is None
                or getattr(self, "X_std", None) is None
            ):
                # The test set is standardized with training statistics. When only the
                # test stage is set up, recover them from the (fixed-seed, hence
                # identical) training sample instead of failing on a missing attribute.
                X_train, _ = self._sample_moons(
                    self.train_and_validation_set_size,
                    self._TRAIN_AND_VALIDATION_SEED,
                )
                self._fit_normalization(X_train)

            X, y = self._sample_moons(self.test_set_size, self._TEST_SEED)
            self.data_test = self._to_tensor_dataset(X, y)
