# from typing import Literal
from dataclasses import dataclass
from typing import Optional, Union

import numpy as np
import pytorch_lightning as pl

from utils.data import DataConfig, create_obj2d_dataset, DATASET_PARAMETERS
from datasets.objects_2d import (
    RandomForegroundsConfig,
    RandomBackgroundsConfig,
    TRANSFORMS_WITH_NONE,
    OBJECTS,
)


# TviDatasetType = Literal["obj2d"]

@dataclass
class CTDataConfig(DataConfig):
    """Configuration consumed by ``create_transforms_datasets``.

    Extends ``DataConfig``. ``img_size`` is read by
    ``create_transforms_datasets`` but not declared here, so it presumably
    comes from ``DataConfig`` — confirm there.
    """

    # Dataset family; "obj2d" is the only value create_transforms_datasets
    # accepts (anything else raises ValueError). Required, e.g. "obj2d".
    dataset: str
    # Seed for the RNG that subsamples the foreground object classes, so it
    # fixes which classes are used. Offset copies of it also seed the random
    # foreground/background configs. Required, e.g. 934.
    config_seed: int
    # Number of foreground object classes drawn (without replacement) from
    # OBJECTS; also passed as n_classes to the random foreground config.
    # Required, e.g. 30.
    n_classes: int
    # Transform names to build datasets for; None means every key of
    # TRANSFORMS_WITH_NONE.
    transforms: Optional[list[str]] = None
    # If True, also build a "rand_<transform>" dataset per transform from
    # randomly generated foregrounds/backgrounds.
    with_random: bool = False


# Value passed as ``n_backgrounds`` to every RandomBackgroundsConfig built by
# get_random_configs.
N_RANDOM_BACKGROUNDS = 800

@dataclass
class CTDatasets:
    """Bundle of data modules returned by ``create_transforms_datasets``."""

    # Data modules keyed by "d_<key>" for every key in ``data_keys``.
    data: dict[str, pl.LightningDataModule]
    # Dataset keys in creation order: "rw_<transform>" and, when random data
    # is enabled, "rand_<transform>".
    data_keys: list[str]
    # Foreground objects subsampled from OBJECTS.
    objects: list[int]
    # Transform names the datasets were built for.
    transforms: list[str]

def create_transforms_datasets(
    config: CTDataConfig,
) -> CTDatasets:
    """Build one data module per transform (and per data source).

    For every transform name a real-world dataset keyed ``"rw_<name>"`` is
    built from a seeded subsample of the foreground objects. When
    ``config.with_random`` is set, a matching ``"rand_<name>"`` dataset built
    from random foreground/background configs is added right after it. If
    ``"none"`` is among the transforms, its entries always come first.

    Args:
        config: Dataset configuration. ``config.dataset`` selects the dataset
            factory (only ``"obj2d"`` is supported). ``config.transforms``
            restricts the transforms (``None`` means every key of
            ``TRANSFORMS_WITH_NONE``). ``config.config_seed`` seeds the
            object subsampling.

    Returns:
        A ``CTDatasets`` whose ``data`` maps ``"d_<key>"`` to the data module
        created for each ``<key>`` in ``data_keys`` (same order).

    Raises:
        ValueError: If ``config.dataset`` is unsupported, or if
            ``config.n_classes`` exceeds the number of available objects.
    """
    if config.dataset == "obj2d":
        create_dataset = create_obj2d_dataset
    else:
        raise ValueError(f"Invalid dataset config: {config.dataset}")

    all_transforms = list(TRANSFORMS_WITH_NONE.keys())
    all_objects = list(OBJECTS)
    rng = np.random.default_rng(config.config_seed)

    # Copy so the returned CTDatasets does not alias config.transforms.
    transform_names = list(
        config.transforms if config.transforms is not None
        else all_transforms
    )

    # Subsample foreground objects.
    n_classes = config.n_classes
    if n_classes > len(all_objects):
        raise ValueError(
            f"n_classes={n_classes} exceeds the {len(all_objects)} "
            "available objects"
        )
    # .tolist() turns the NumPy scalars returned by rng.choice back into
    # native Python values.
    objects = rng.choice(all_objects, n_classes, replace=False).tolist()
    print("objects:", objects)

    # The identity transform "none" is always listed first when requested.
    # NOTE(review): an earlier version pre-seeded "rw_none" with an empty
    # transform list ([]) but then overwrote it with ["none"] in the main
    # loop; the effective value ["none"] is kept here — confirm that [] was
    # not the intent.
    if "none" in transform_names:
        ordered_names = ["none"] + [
            name for name in transform_names if name != "none"
        ]
    else:
        ordered_names = transform_names

    data_configs: dict[str, Union[
        tuple[list[str], list[int]],
        tuple[list[str], RandomForegroundsConfig, RandomBackgroundsConfig]
    ]] = {}
    for transform_name in ordered_names:
        # Fresh lists per entry so downstream consumers never share state.
        data_configs[f"rw_{transform_name}"] = ([transform_name], objects)
        if config.with_random:
            data_configs[f"rand_{transform_name}"] = (
                [transform_name],
                *get_random_configs(
                    config.n_classes,
                    [transform_name],
                    config.config_seed,
                    config.img_size,
                ),
            )

    datasets: dict[str, pl.LightningDataModule] = {
        f"d_{key}": create_dataset(config, *data_config)
        for key, data_config in data_configs.items()
    }

    return CTDatasets(
        data=datasets,
        data_keys=list(data_configs),
        objects=objects,
        transforms=transform_names,
    )

def get_random_configs(
    n_classes: int,
    transforms: list[str],
    config_seed: int,
    img_size: int,
) -> tuple[RandomForegroundsConfig, RandomBackgroundsConfig]:
    """Create the random foreground and background configs for one dataset.

    Args:
        n_classes: Number of random foreground classes.
        transforms: Transform names applied to the random foregrounds.
        config_seed: Base seed; each config gets its own fixed offset of it.
        img_size: Image size shared by both configs.

    Returns:
        ``(foreground_config, background_config)``.
    """
    # Fixed offsets give the two configs seed values that differ from each
    # other and from config_seed itself (which seeds object subsampling).
    foreground_seed = config_seed + 26
    background_seed = config_seed + 140

    foreground_cfg = RandomForegroundsConfig(
        n_classes=n_classes,
        transforms=transforms,
        seed=foreground_seed,
        img_size=img_size,
    )
    background_cfg = RandomBackgroundsConfig(
        n_backgrounds=N_RANDOM_BACKGROUNDS,
        seed=background_seed,
        img_size=img_size,
    )
    return foreground_cfg, background_cfg
