# from typing import Literal
from dataclasses import dataclass
from copy import deepcopy
from typing import cast

import numpy as np
import pytorch_lightning as pl
import torch
from torchvision import transforms
from torchvision.transforms import functional

from vis_datasets.natural.cifar import CifarData, CifarDataConfig
from vis_datasets.loading import DataLoaderConfig
from utils.data import DataConfig, create_obj2d_dataset, DATASET_PARAMETERS
from datasets.objects_2d import (
    Objects2DData, UniformBackgroundsConfig, TRANSFORMS
)


@dataclass
class IFEDataConfig(DataConfig):
    """Configuration for the irrelevant-features datasets.

    Extends ``DataConfig``. ``create_irrelevant_features_dataset`` also reads
    ``batch_size`` and ``img_size`` from it (presumably inherited from
    ``DataConfig`` — confirm).
    """

    # Seed for the RNG that subsamples the foreground objects in
    # ``create_irrelevant_features_dataset``.
    config_seed: int
    # Number of foreground object classes to draw; currently must be 10.
    n_classes: int

@dataclass
class IFEDatasets:
    """Data modules produced by ``create_irrelevant_features_dataset``."""

    # Data modules keyed by dataset name (built from the *_KEY constants).
    data: dict[str, pl.LightningDataModule]
    # The keys of ``data``, in insertion order.
    data_keys: list[str]
    # Foreground objects sampled for these datasets; every entry in ``data``
    # was built with this same list.
    objects: list[int]

def create_cifar_only_dataset(
    batch_size: int,
    normalize: bool = True,
) -> pl.LightningDataModule:
    """Build a plain CIFAR-10 data module with no overlaid objects.

    Args:
        batch_size: Batch size for the data loaders.
        normalize: Forwarded to ``CifarDataConfig``.

    Returns:
        The ``CifarData`` module.
    """
    cifar_config = CifarDataConfig(
        cifar_type="cifar10",
        loader_config=DataLoaderConfig(batch_size=batch_size),
        normalize=normalize,
    )
    return CifarData(cifar_config)

# Foreground/background correlation levels. create_irrelevant_features_dataset
# builds one CIFAR-label mixed dataset per value.
CORRELATION_VALUES = [0, 0.2, 0.4, 0.6, 0.8, 0.85, 0.9, 0.95, 1.0]
# Foreground availability levels. create_irrelevant_features_dataset builds
# one CIFAR-label mixed dataset per value.
AVAILABILITY_VALUES = [0.2, 0.4, 0.6, 0.8]

# Dataset-name keys.
# CIFAR_ONLY_KEY is not used in the code shown here; presumably it keys the
# output of create_cifar_only_dataset in callers — confirm.
CIFAR_ONLY_KEY = "cifar_only"
# CIFAR_ONLY_MODEL_KEY = "m_cifar_only"
OBJECTS_ONLY_KEY = "objects_only"
# The next two are prefixes: the correlation / availability value is appended
# via an f-string, e.g. "mixed_cifar_labels_cor0.2" or "mixed_cifar_labels_av0.4".
MIXED_CIFAR_KEY = "mixed_cifar_labels_cor"
MIXED_CIFAR_AVAILABILITY_KEY = "mixed_cifar_labels_av"
MIXED_OBJECTS_KEY = "mixed_object_labels"

def create_irrelevant_features_dataset(
    config: IFEDataConfig,
    transform: str,
    normalize: bool = True,
) -> IFEDatasets:
    """Build every dataset used by the irrelevant-features experiments.

    ``config.n_classes`` foreground objects are sampled once (seeded by
    ``config.config_seed``), and every returned dataset is built with them:

    * one mixed dataset with CIFAR labels as targets per value in
      ``AVAILABILITY_VALUES`` (key ``f"{MIXED_CIFAR_AVAILABILITY_KEY}{value}"``),
    * one mixed dataset with CIFAR labels as targets per value in
      ``CORRELATION_VALUES`` (key ``f"{MIXED_CIFAR_KEY}{value}"``),
    * one mixed dataset whose targets are the object classes
      (key ``MIXED_OBJECTS_KEY``),
    * the objects on a uniform ``[0, 0, 0]`` background
      (key ``OBJECTS_ONLY_KEY``).

    Args:
        config: Dataset configuration. ``batch_size``, ``img_size``,
            ``config_seed`` and ``n_classes`` are read here; the whole config
            is also forwarded to the dataset builders.
        transform: Transform name passed to ``_create_transforms``.
        normalize: Forwarded to every dataset builder.

    Returns:
        ``IFEDatasets`` containing the data modules keyed by name, those keys
        in insertion order, and the sampled foreground objects.

    Raises:
        ValueError: If ``config.n_classes`` is not 10.
    """
    n_classes = config.n_classes
    # Explicit check instead of ``assert`` so the validation survives
    # ``python -O``. The value 10 presumably mirrors the 10 CIFAR-10 classes
    # that the correlated datasets pair objects with (``_create_mixed_dataset``
    # requires exactly 10 objects for a non-zero correlation).
    if n_classes != 10:
        raise ValueError(f"n_classes must be 10, got {n_classes}")

    loader_config = DataLoaderConfig(batch_size=config.batch_size)

    # Subsample the foreground objects. ``tolist()`` converts the NumPy
    # scalars returned by ``rng.choice`` into plain Python values, matching
    # the ``list[int]`` annotation of ``IFEDatasets.objects`` (and keeping them
    # serializable, e.g. for logging hyperparameters).
    dataset_params = DATASET_PARAMETERS["obj2d"]
    rng = np.random.default_rng(config.config_seed)
    objects = rng.choice(
        dataset_params.class_objects, n_classes, replace=False
    ).tolist()

    # Not named ``transforms`` so the torchvision module of that name
    # (imported at the top of the file) is not shadowed.
    object_transforms = _create_transforms(transform, config.img_size)

    # CIFAR images overlaid with random objects, using the CIFAR labels as
    # targets: one dataset per foreground availability level (the correlation
    # keeps the ``_create_mixed_dataset`` default of 0.0)...
    mixed_cifar_labels_availability_datasets = {
        f"{MIXED_CIFAR_AVAILABILITY_KEY}{fg_av}": _create_mixed_dataset(
            config,
            object_transforms,
            objects,
            loader_config,
            # The CIFAR images are the backgrounds
            class_from_foreground=False,
            fg_availability=fg_av,
            normalize=normalize,
        )
        for fg_av in AVAILABILITY_VALUES
    }
    # ...and one per foreground/background correlation level (availability
    # keeps the ``_create_mixed_dataset`` default of 1.0).
    mixed_cifar_labels_correlation_datasets = {
        f"{MIXED_CIFAR_KEY}{fg_bg_correlation}": _create_mixed_dataset(
            config,
            object_transforms,
            objects,
            loader_config,
            # The CIFAR images are the backgrounds
            class_from_foreground=False,
            fg_bg_correlation=fg_bg_correlation,
            normalize=normalize,
        )
        for fg_bg_correlation in CORRELATION_VALUES
    }

    # CIFAR images overlaid with random objects, using the object classes as
    # targets.
    mixed_object_labels_dataset = _create_mixed_dataset(
        config,
        object_transforms,
        objects,
        loader_config,
        # The object images are the foregrounds
        class_from_foreground=True,
        normalize=normalize,
    )

    # The objects alone, on a uniform [0, 0, 0] background.
    objects_only_dataset = create_obj2d_dataset(
        config,
        object_transforms,
        objects,
        UniformBackgroundsConfig([0, 0, 0], img_size=config.img_size),
        normalize=normalize,
    )

    datasets: dict[str, pl.LightningDataModule] = {
        **mixed_cifar_labels_availability_datasets,
        **mixed_cifar_labels_correlation_datasets,
        MIXED_OBJECTS_KEY: mixed_object_labels_dataset,
        OBJECTS_ONLY_KEY: objects_only_dataset,
    }
    return IFEDatasets(
        data=datasets,
        data_keys=list(datasets.keys()),
        objects=objects,
    )

def _scale_down(img: torch.Tensor) -> torch.Tensor:
    """Shrink ``img`` to 0.375 of its size with no rotation, shift or shear."""
    return functional.affine(
        img, angle=0, translate=[0, 0], scale=0.375, shear=[0, 0]
    )


SCALE_DOWN_TRANSFORM = transforms.Lambda(_scale_down)


def _move_to_upper_right(img: torch.Tensor) -> torch.Tensor:
    """Translate ``img`` by (+10, -10) pixels, i.e. right and up in image
    coordinates, with no rotation, scaling or shear."""
    return functional.affine(
        img, angle=0, translate=[10, -10], scale=1, shear=[0, 0]
    )


UPPER_RIGHT_CORNER = transforms.Lambda(_move_to_upper_right)

def _create_transforms(
    transform_name: str, img_size: int,
) -> list[torch.nn.Module]:
    """Assemble the transform pipeline selected by ``transform_name``.

    The pipeline always starts with ``SCALE_DOWN_TRANSFORM``. Then:

    * unless the name is ``"move"``, ``UPPER_RIGHT_CORNER`` is appended;
    * unless the name is ``"none"``, ``TRANSFORMS[transform_name]`` is called
      with ``img_size`` and the result is appended last.
    """
    # "move" presumably positions the object itself, so the fixed corner
    # shift is skipped for it — confirm against TRANSFORMS["move"].
    skip_corner_shift = transform_name == "move"
    # "none" adds nothing beyond the fixed steps, so TRANSFORMS is not
    # consulted for it.
    add_named_transform = transform_name != "none"

    pipeline = [SCALE_DOWN_TRANSFORM]
    if not skip_corner_shift:
        pipeline.append(UPPER_RIGHT_CORNER)
    if add_named_transform:
        pipeline.append(TRANSFORMS[transform_name](img_size))
    return cast(list[torch.nn.Module], pipeline)

def _create_mixed_dataset(
    config: IFEDataConfig,
    transforms: list[torch.nn.Module],
    fg_objects: list[int],
    loader_config: DataLoaderConfig,
    class_from_foreground: bool,
    fg_bg_correlation: float = 0.0,
    fg_availability: float = 1.0,
    normalize: bool = True,
) -> Objects2DData:
    """Overlay foreground objects on CIFAR-10 background images.

    Args:
        config: Dataset configuration, forwarded to ``create_obj2d_dataset``.
        transforms: Transforms forwarded to ``create_obj2d_dataset``.
        fg_objects: Foreground objects to overlay.
        loader_config: Data-loader settings for the CIFAR background module.
        class_from_foreground: If True, the targets are the foreground object
            classes; otherwise they are the CIFAR (background) labels.
        fg_bg_correlation: Forwarded to ``create_obj2d_dataset``. A non-zero
            value is only allowed when ``class_from_foreground`` is False and
            exactly 10 foreground objects are given.
        fg_availability: Forwarded to ``create_obj2d_dataset``.
        normalize: Forwarded to ``create_obj2d_dataset``.

    Returns:
        The data module built by ``create_obj2d_dataset``.

    Raises:
        AssertionError: If a non-zero ``fg_bg_correlation`` is requested
            without CIFAR labels as targets and exactly 10 objects.
    """
    # A correlation presumably pairs each object with one CIFAR-10 class, so
    # it needs CIFAR labels as targets and one object per CIFAR-10 class.
    # This is checked before the CIFAR data module is constructed, so an
    # invalid call fails before any data module is built.
    assert fg_bg_correlation == 0 or (
        not class_from_foreground and len(fg_objects) == 10
    ), (
        "a non-zero fg_bg_correlation requires class_from_foreground=False "
        f"and 10 foreground objects; got class_from_foreground="
        f"{class_from_foreground} and {len(fg_objects)} objects"
    )
    # The CIFAR module only supplies backgrounds, so its own tensor conversion
    # and normalization are disabled here. Normalization is requested from
    # create_obj2d_dataset instead (``normalize`` is forwarded below);
    # presumably it also converts to tensors after compositing — confirm.
    cifar_dataset = CifarData(CifarDataConfig(
        cifar_type="cifar10",
        loader_config=loader_config,
        to_tensor=False,
        normalize=False,
    ))
    return create_obj2d_dataset(
        config,
        transforms,
        fg_objects,
        cifar_dataset,
        class_from_foreground=class_from_foreground,
        fg_bg_correlation=fg_bg_correlation,
        fg_availability=fg_availability,
        normalize=normalize,
    )
