# from typing import Literal
from dataclasses import dataclass
from typing import Union, cast

import numpy as np
import pytorch_lightning as pl

from utils.data import DataConfig, create_obj2d_dataset, DATASET_PARAMETERS
from datasets.objects_2d import (
    RandomForegroundsConfig,
    RandomBackgroundsConfig,
)


# TviDatasetType = Literal["obj2d"]

@dataclass
class TMDataConfig(DataConfig):
    """Configuration consumed by ``create_tm_dataset``.

    Extends the project's ``DataConfig``; any fields inherited from it are
    not visible in this file. ``create_tm_dataset`` also forwards the whole
    object unchanged to the dataset factory.

    NOTE(review): if ``DataConfig`` is a dataclass, the non-default fields
    below require that it declares no defaulted fields (dataclass field
    ordering rule) — keep in mind when editing the base class.
    """
    # dataset: TviDatasetType = "obj2d"
    # Dataset family; create_tm_dataset currently only accepts "obj2d".
    dataset: str
    # Used to determine the structure of the dataset, i.e. which transforms
    # and classes to use. Seeds the NumPy generator that samples the
    # transforms, the foreground classes and the shuffled transform orders.
    config_seed: int
    # Number of foreground classes sampled (without replacement) from the
    # dataset's class objects.
    n_classes: int
    # Use randomly generated foregrounds/backgrounds instead of sampled ones.
    # Not implemented yet: create_tm_dataset raises NotImplementedError.
    random: bool = False

@dataclass
class ExpandedDataConfig:
    """Concrete choices that ``create_tm_dataset`` samples from a config."""
    # Sampled transforms, in sampling order. create_tm_dataset builds its
    # dataset "k" from the first k entries of this list.
    all_transforms: list[str]
    # mismatch_transforms: list[str]
    # Keyword arguments shared by every generated dataset, under the keys
    # "foregrounds" and "backgrounds". In the (only implemented) non-random
    # path these are the sampled class objects (presumably class indices —
    # confirm against DATASET_PARAMETERS) and None, respectively.
    features: dict[str, Union[
        RandomForegroundsConfig, RandomBackgroundsConfig, list[int], None
    ]]

@dataclass
class TMData:
    """Datamodules and metadata produced by ``create_tm_dataset``."""
    # Keyed "d_<k>": datamodule built from the first k sampled transforms.
    data: dict[str, pl.LightningDataModule]
    # mismatch_data: dict[str, pl.LightningDataModule]
    # Keyed "d_sh_<k>" (k >= 2): same first k transforms in a shuffled order.
    shuffle_data: dict[str, pl.LightningDataModule]
    # Prefix lengths as strings ("1", "2", ...), i.e. the keys of ``data``
    # without the "d_" prefix.
    config_names: list[str]
    # Sampled transforms and feature settings shared by all datamodules.
    expanded_config: ExpandedDataConfig


N_MAX_TRANSFORMS = 8

def create_tm_dataset(
    config: TMDataConfig,
) -> TMData:
    """Build the family of nested-transform datasets described by ``config``.

    A NumPy generator seeded with ``config.config_seed`` draws, without
    replacement, ``N_MAX_TRANSFORMS`` transforms and ``config.n_classes``
    foreground class objects from ``DATASET_PARAMETERS[config.dataset]``.
    From these it builds:

    * ``data``: one datamodule per prefix length ``k = 1..N_MAX_TRANSFORMS``,
      keyed ``"d_<k>"``, using the first ``k`` sampled transforms;
    * ``shuffle_data``: one datamodule per ``k = 2..N_MAX_TRANSFORMS``, keyed
      ``"d_sh_<k>"``, using the same first ``k`` transforms in a random order.

    Args:
        config: Dataset configuration. It is also forwarded unchanged as the
            first argument of the dataset factory.

    Returns:
        A ``TMData`` bundling both datamodule families, the prefix names
        (``"1"`` .. ``str(N_MAX_TRANSFORMS)``) and the sampled transforms and
        features.

    Raises:
        ValueError: If ``config.dataset`` is not ``"obj2d"``. NumPy also
            raises ``ValueError`` if the dataset offers fewer than
            ``N_MAX_TRANSFORMS`` transforms or fewer than ``config.n_classes``
            class objects (sampling is without replacement).
        NotImplementedError: If ``config.random`` is set.
    """
    if config.dataset == "obj2d":
        create_dataset = create_obj2d_dataset
    else:
        raise ValueError(f"Invalid dataset config: {config.dataset}")
    dataset_params = DATASET_PARAMETERS[config.dataset]

    # All sampling below draws from this one generator, in a fixed order, so
    # the same config_seed always reproduces the same datasets.
    rng = np.random.default_rng(config.config_seed)

    # Subsample transforms. ``.tolist()`` turns the NumPy scalars returned by
    # ``rng.choice`` (e.g. np.str_ / np.int64) into plain Python values, which
    # matches the list[str] / list[int] annotations of ExpandedDataConfig and
    # keeps the values JSON-serializable. The random draws are unchanged.
    n_all_transforms = N_MAX_TRANSFORMS
    all_transforms = rng.choice(
        dataset_params.transforms,
        n_all_transforms,
        replace=False,
    ).tolist()

    # Subsample foreground objects (exactly n_classes of them).
    n_classes = config.n_classes
    fg_indices = rng.choice(
        dataset_params.class_objects, n_classes, replace=False
    ).tolist()

    if config.random:
        # TODO: random features are not implemented. The earlier (unreachable)
        # draft here was ``features_config = _get_random_configs(config, [])``.
        raise NotImplementedError("TMDataConfig.random is not supported yet")
    features_config = {
        "foregrounds": fg_indices,
        "backgrounds": None,
    }

    # Create datasets with different numbers of transformations: dataset "k"
    # uses the first k sampled transforms, so smaller sets nest in larger ones.
    data_configs = {
        str(n_transforms): {
            "transforms": all_transforms[:n_transforms],
            **features_config,
        }
        for n_transforms in range(1, n_all_transforms + 1)
    }
    datasets: dict[str, pl.LightningDataModule] = {
        f"d_{config_name}": create_dataset(config, **data_config)
        for config_name, data_config in data_configs.items()
    }

    # NOTE(review): the disabled draft below uses ``*features_config`` and
    # ``*data_config`` (iterable unpacking). If revived, these must become
    # ``**`` (dict unpacking), as in the live code above.
    # mismatch_transforms = list(rng.choice(
    #     list(set(dataset_params.transforms) - set(all_transforms)),
    #     n_all_transforms,
    #     replace=False,
    # ))
    # mismatch_configs = {
    #     str(n_transforms): (
    #         # Keep the transformations the longest that are part of the
    #         # smallest training subsets
    #         all_transforms[:n_all_transforms - n_transforms]
    #         + mismatch_transforms[:n_transforms],
    #         *features_config,
    #     )
    #     for n_transforms in range(1, n_all_transforms + 1)
    # }
    # mismatch_datasets: dict[str, pl.LightningDataModule] = {
    #     f"d_{config_name}_mm": create_dataset(config, *data_config)
    #     for config_name, data_config in mismatch_configs.items()
    # }

    # Shuffled variants start at k=2 because a single transform has only one
    # order. NOTE(review): a random permutation can equal the original order
    # (probability 1/k!, e.g. 1/2 for k=2). That case is left as-is so the
    # seeded draws stay reproducible.
    shuffle_configs = {}
    for n_transforms in range(2, n_all_transforms + 1):
        # Slicing copies, so the in-place shuffle leaves all_transforms intact.
        subset_transforms = all_transforms[:n_transforms]
        rng.shuffle(subset_transforms)
        shuffle_configs[str(n_transforms)] = {
            "transforms": subset_transforms,
            **features_config,
        }
    shuffle_datasets: dict[str, pl.LightningDataModule] = {
        f"d_sh_{config_name}": create_dataset(config, **data_config)
        for config_name, data_config in shuffle_configs.items()
    }

    return TMData(
        data=datasets,
        # mismatch_data=mismatch_datasets,
        shuffle_data=shuffle_datasets,
        config_names=list(data_configs.keys()),
        expanded_config=ExpandedDataConfig(
            all_transforms=all_transforms,
            # mismatch_transforms=mismatch_transforms,
            features=features_config,
        ),
    )

# Number of backgrounds requested from RandomBackgroundsConfig.
N_RANDOM_BACKGROUNDS = 1000

def _get_random_configs(
    config: TMDataConfig,
    transforms: list[str],
) -> dict[str, Union[RandomForegroundsConfig, RandomBackgroundsConfig]]:
    """Return random foreground/background configs derived from ``config``.

    Presumably intended for ``TMDataConfig.random`` mode, which
    ``create_tm_dataset`` does not implement yet.

    Args:
        config: Source of ``n_classes``, ``config_seed`` and ``img_size``.
            ``img_size`` is not declared on ``TMDataConfig`` itself; this
            assumes it is provided by the ``DataConfig`` base — TODO confirm.
        transforms: Transform names handed to the foreground config.

    Returns:
        A dict with a ``"foregrounds"`` ``RandomForegroundsConfig`` and a
        ``"backgrounds"`` ``RandomBackgroundsConfig``. Their seeds are
        ``config_seed`` shifted by two different fixed offsets.
    """
    # raise NotImplementedError()
    base_seed = config.config_seed
    foreground_cfg = RandomForegroundsConfig(
        n_classes=config.n_classes,
        transforms=transforms,
        seed=base_seed + 302190,
        img_size=config.img_size,
    )
    background_cfg = RandomBackgroundsConfig(
        n_backgrounds=N_RANDOM_BACKGROUNDS,
        seed=base_seed + 4920,
        img_size=config.img_size,
    )
    return {"foregrounds": foreground_cfg, "backgrounds": background_cfg}
