import math
from dataclasses import dataclass, asdict, field, replace
from typing import Callable, Literal, Optional, get_args

import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100
from torch.utils.data import Sampler, Dataset, Subset
import pytorch_lightning as pl
# import pl_bolts.datamodules plb_dm

from ..loading import (
    DataLoaderConfig,
    load,
)
# from ..utils.data_loading import DataLoaderConfig, load
from ..lib.dirs import get_dataset_dir
from ..lib.dataset_accessor import LightningDataAccessor, DatasetStage
from ..wrappers.sample_wrapper import DatasetWrapper


# Short human-readable CIFAR-10 class names, in tuple order below.
# NOTE(review): presumably this order matches the dataset's integer targets
# (label i -> CIFAR10_CLASS_LABELS[i]) — confirm against torchvision's
# `CIFAR10.classes` before relying on it.
Cifar10ClassLabelType = Literal[
    "plane", "car", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
# Derived from the Literal above (same order) so the runtime tuple and the
# static type are a single source of truth and cannot drift apart.
CIFAR10_CLASS_LABELS: tuple[Cifar10ClassLabelType, ...] = get_args(
    Cifar10ClassLabelType
)

# Default augmentations for training images: pad each side by 4 pixels and
# take a random 32x32 crop, then apply a random horizontal flip.
# Tensor conversion and normalization are intentionally NOT listed here; they
# are appended separately according to the `to_tensor` / `normalize` flags of
# `CifarDataConfig`.
DEFAULT_TRAIN_TRANSFORMS = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
]
# Test images get no augmentation by default.
DEFAULT_TEST_TRANSFORMS = []

# Names of the supported CIFAR variants.
CifarType = Literal["cifar10", "cifar100"]

# Per-channel (mean, std) used to normalize tensors of each dataset.
# NOTE(review): the CIFAR-10 std triple (0.2023, 0.1994, 0.2010) is a widely
# copied figure that reportedly differs from the exact per-pixel std of the
# training set (roughly (0.247, 0.243, 0.261)). It is kept as is because any
# model trained with these constants presumably depends on them — confirm
# before changing.
_CIFAR_CHANNEL_STATS: dict[CifarType, tuple[tuple[float, ...], tuple[float, ...]]] = {
    "cifar10": ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    "cifar100": ((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
}

# One `Normalize` transform per dataset name, built from the table above.
NORMALIZATION_TRANSFORMS = {
    dataset_name: transforms.Normalize(
        torch.tensor(channel_mean), torch.tensor(channel_std)
    )
    for dataset_name, (channel_mean, channel_std) in _CIFAR_CHANNEL_STATS.items()
}

@dataclass
class CifarDataConfig:
    """Settings consumed by ``CifarData``.

    Attributes:
        cifar_type: Which dataset to load (``"cifar10"`` or ``"cifar100"``).
        loader_config: Base settings passed to the data-loading helper.
        transforms_train: Transforms applied to training images before the
            tensor/normalization steps. Defaults to a new shallow copy of
            ``DEFAULT_TRAIN_TRANSFORMS``.
        transforms_test: Transforms applied to test images before the
            tensor/normalization steps. Defaults to a new shallow copy of
            ``DEFAULT_TEST_TRANSFORMS``.
        sampler_constructor: Optional factory that builds a sampler for a
            given dataset; ``None`` means no custom sampler.
        to_tensor: Whether ``ToTensor`` is appended to the transforms.
        normalize: Whether the dataset's normalization is appended.
    """

    cifar_type: CifarType
    loader_config: DataLoaderConfig
    # Each instance gets its own list so that in-place edits do not leak
    # into the module-level defaults.
    transforms_train: list[torch.nn.Module] = field(
        default_factory=lambda: [*DEFAULT_TRAIN_TRANSFORMS]
    )
    transforms_test: list[torch.nn.Module] = field(
        default_factory=lambda: [*DEFAULT_TEST_TRANSFORMS]
    )
    sampler_constructor: Optional[Callable[[Dataset], Sampler]] = None
    to_tensor: bool = True
    normalize: bool = True


# Based on https://pytorch-lightning.readthedocs.io/en/latest/notebooks/lightning_examples/datamodules.html
class CifarData(LightningDataAccessor):
    """Lightning data accessor serving CIFAR-10 or CIFAR-100.

    The official training set is randomly split 90% / 10% into the "train"
    and "val" datasets, and the official test set becomes the "test" dataset.
    Training samples go through ``config.transforms_train``. Validation and
    test samples go through ``config.transforms_test``. Each list is followed
    by ``ToTensor`` and/or the dataset's normalization, depending on
    ``config.to_tensor`` and ``config.normalize``.

    Raises:
        ValueError: From ``__init__`` if ``config.cifar_type`` is not
            ``"cifar10"`` or ``"cifar100"``.
    """

    # Fraction of the official training set used for training; the rest is
    # held out as the validation set.
    _TRAIN_FRACTION = 0.9

    def __init__(self, config: CifarDataConfig) -> None:
        super().__init__()

        # Explicit lookup so that an unrecognized name fails loudly instead of
        # silently falling through to CIFAR-100.
        constructors = {"cifar10": CIFAR10, "cifar100": CIFAR100}
        if config.cifar_type not in constructors:
            raise ValueError(
                f"Invalid cifar_type {config.cifar_type!r}; "
                f"expected one of {sorted(constructors)}"
            )
        self.dataset_constructor = constructors[config.cifar_type]

        self.data_dir = str(get_dataset_dir(config.cifar_type))
        self.transforms_train = config.transforms_train
        self.transforms_test = config.transforms_test
        self.loader_config = config.loader_config
        self.sampler_constructor = config.sampler_constructor

        # Steps appended after the user transforms: ToTensor first, then the
        # normalization, which operates on tensors.
        self.norm_tensor_transforms = []
        if config.to_tensor:
            self.norm_tensor_transforms.append(transforms.ToTensor())
        if config.normalize:
            self.norm_tensor_transforms.append(
                NORMALIZATION_TRANSFORMS[config.cifar_type]
            )

        # (train indices, val indices) into the official training set. Drawn
        # once on the first train/val setup and reused afterwards, so repeated
        # setup() calls never move samples between the two splits.
        self._train_val_indices: Optional[tuple[list[int], list[int]]] = None

    def prepare_data(self) -> None:
        """Download the official train and test splits into ``self.data_dir``."""
        self.dataset_constructor(self.data_dir, train=True, download=True)
        self.dataset_constructor(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None) -> None:
        """Register the datasets needed for ``stage``.

        Args:
            stage: ``"fit"`` or ``"validate"`` registers "train" and "val",
                ``"test"`` registers "test", and ``None`` registers all three.

        Raises:
            ValueError: If ``stage`` is any other value (e.g. ``"predict"``,
                which this accessor does not serve).
        """
        if stage not in (None, "fit", "validate", "test"):
            raise ValueError(f"Invalid stage '{stage}'")

        if stage in (None, "fit", "validate"):
            self._setup_train_val()

        if stage in (None, "test"):
            self.set_dataset("test", DatasetWrapper(self.dataset_constructor(
                self.data_dir, train=False,
                transform=self._compose(self.transforms_test),
            )))

    def _setup_train_val(self) -> None:
        """Split the official training set into the "train" and "val" datasets."""
        # Two views of the same training files, each with its own transform,
        # so that augmentations are applied to training samples only.
        train_view = self.dataset_constructor(
            self.data_dir, train=True,
            transform=self._compose(self.transforms_train),
        )
        val_view = self.dataset_constructor(
            self.data_dir, train=True,
            transform=self._compose(self.transforms_test),
        )

        if self._train_val_indices is None:
            n_samples = len(train_view)
            n_train = int(math.floor(self._TRAIN_FRACTION * n_samples))
            train_split, val_split = random_split(
                train_view, [n_train, n_samples - n_train]
            )
            self._train_val_indices = (
                list(train_split.indices),
                list(val_split.indices),
            )

        train_indices, val_indices = self._train_val_indices
        self.set_dataset(
            "train", DatasetWrapper(Subset(train_view, train_indices))
        )
        self.set_dataset("val", DatasetWrapper(Subset(val_view, val_indices)))

    def _compose(self, steps: list) -> transforms.Compose:
        """Chain ``steps`` with the configured tensor/normalization steps."""
        return transforms.Compose([*steps, *self.norm_tensor_transforms])

    def train_dataloader(self) -> DataLoader:
        """Loader over the "train" dataset (``train=True``)."""
        return self._make_loader("train", train=True)

    def val_dataloader(self) -> DataLoader:
        """Loader over the "val" dataset (``train=False``)."""
        return self._make_loader("val", train=False)

    def test_dataloader(self) -> DataLoader:
        """Loader over the "test" dataset (``train=False``)."""
        return self._make_loader("test", train=False)

    def _make_loader(self, name: str, train: bool) -> DataLoader:
        """Build a loader for the dataset registered under ``name``."""
        dataset = self.get_dataset(name)
        return load(
            dataset,
            train=train,
            config=self._get_loader_config(dataset),
        )

    def _get_loader_config(self, dataset: Dataset) -> DataLoaderConfig:
        """Return the loader config to use for ``dataset``.

        Without a ``sampler_constructor``, the shared base config is returned
        unchanged. Otherwise, a copy is returned whose ``sampler`` field holds
        a sampler built for ``dataset``.
        """
        if not self.sampler_constructor:
            return self.loader_config
        # replace() makes a shallow copy and leaves nested dataclass fields
        # intact. An asdict() round-trip would deep-copy every field and turn
        # nested dataclasses into plain dicts.
        # NOTE(review): the sampler is applied to the val/test loaders too, and
        # it is up to `load` not to combine it with shuffling — confirm.
        return replace(
            self.loader_config, sampler=self.sampler_constructor(dataset)
        )
