# from typing import Literal
from dataclasses import dataclass
from typing import Union, cast

import numpy as np
import pytorch_lightning as pl

from utils.data import DataConfig, create_obj2d_dataset, DATASET_PARAMETERS
from datasets.objects_2d import (
    RandomForegroundsConfig,
    RandomBackgroundsConfig,
)


# TviDatasetType = Literal["obj2d"]

@dataclass
class ITDataConfig(DataConfig):
    """Extension of ``DataConfig`` read by ``create_datasets``.

    Every field below is declared without a default; the values that were
    once defaults are kept as trailing comments for reference.
    """

    # Name of the dataset family; ``create_datasets`` accepts only "obj2d".
    # dataset: TviDatasetType = "obj2d"
    dataset: str  # = "obj2d"

    # Seed of the random generator that decides the structure of the
    # datasets, i.e. which transforms and which classes are sampled.
    config_seed: int  # = 6802

    # Size of each of the two class groups; ``create_datasets`` samples
    # 2 * n_classes class objects and splits them in half.
    n_classes: int  # = 50

    # Size of each of the two transform groups; ``create_datasets`` samples
    # 2 * n_transforms transforms and splits them in half.
    n_transforms: int  # = 3

@dataclass
class ExpandedDataConfig:
    """Record of the concrete transforms and classes that were sampled.

    ``create_datasets`` fills this in with the two halves of its transform
    sample and the two halves of its class sample.
    """

    # First half of the sampled transforms.
    transforms_1: list[str]
    # Second half of the sampled transforms.
    transforms_2: list[str]
    # First half of the sampled class objects.
    features_1: list[int]
    # Second half of the sampled class objects.
    features_2: list[int]

@dataclass
class DatasetCombination:
    """Bundle returned by ``create_datasets``."""

    # The created data modules; ``create_datasets`` keys them as "d_<id>".
    data: dict[str, pl.LightningDataModule]
    # The dataset ids; ``create_datasets`` stores them WITHOUT the "d_"
    # prefix used for the keys of ``data``.
    data_keys: list[str]
    # Which transforms and classes were sampled to build ``data``.
    expanded_config: ExpandedDataConfig

# Id (without the "d_" prefix) of the dataset that `create_datasets` builds
# from all sampled transforms and all sampled classes.
FULL_DATASET_NAME = "full"

def create_datasets(
    config: ITDataConfig,
) -> DatasetCombination:
    """Create the "t1", "mixed", "t2" and full datasets for ``config``.

    A generator seeded with ``config.config_seed`` samples, without
    replacement, ``2 * config.n_transforms`` transforms and
    ``2 * config.n_classes`` class objects from the parameters registered
    for ``config.dataset``. Each sample is split into a first and a second
    half, and four datasets are built over all sampled classes:

    * ``t1``: every class is paired with the first transform half.
    * ``t2``: every class is paired with the second transform half.
    * ``mixed``: the first class half gets the first transform half, the
      second class half gets the second transform half.
    * ``FULL_DATASET_NAME``: receives the flat list of all sampled
      transforms instead of a per-class mapping.

    Args:
        config: Dataset settings; it is also forwarded unchanged to the
            dataset factory.

    Returns:
        A ``DatasetCombination`` whose ``data`` maps ``"d_<id>"`` to the
        created dataset, whose ``data_keys`` lists the ids *without* the
        ``"d_"`` prefix, and whose ``expanded_config`` records the sampled
        transform and class halves.

    Raises:
        ValueError: If ``config.dataset`` is not ``"obj2d"``. Sampling
            without replacement is also expected to raise ``ValueError``
            when more items are requested than are available.
    """
    if config.dataset != "obj2d":
        raise ValueError(f"Invalid dataset config: {config.dataset}")
    create_dataset = create_obj2d_dataset

    dataset_params = DATASET_PARAMETERS[config.dataset]
    rng = np.random.default_rng(config.config_seed)

    # NOTE: transforms are drawn before classes. Keep this order, since
    # swapping the two draws changes what a given seed produces.

    # Subsample transforms
    n_transforms = config.n_transforms
    all_transforms = list(rng.choice(
        dataset_params.transforms,
        2 * n_transforms,
        replace=False,
    ))
    transforms_1 = all_transforms[:n_transforms]
    transforms_2 = all_transforms[n_transforms:]

    # Subsample foreground objects
    n_classes = config.n_classes
    fg_indices = list(rng.choice(
        dataset_params.class_objects, 2 * n_classes, replace=False
    ))
    g1_indices = fg_indices[:n_classes]
    g2_indices = fg_indices[n_classes:]

    # Per-class transform assignments for the three structured datasets.
    t1_transforms = {fg_idx: transforms_1 for fg_idx in fg_indices}
    t2_transforms = {fg_idx: transforms_2 for fg_idx in fg_indices}
    mixed_transforms = {
        **{fg_idx: transforms_1 for fg_idx in g1_indices},
        **{fg_idx: transforms_2 for fg_idx in g2_indices},
    }
    # Internal sanity check: the two class halves together cover every
    # sampled class.
    assert set(mixed_transforms) == set(fg_indices)

    data_configs = {
        "t1": (t1_transforms, fg_indices),
        "mixed": (mixed_transforms, fg_indices),
        "t2": (t2_transforms, fg_indices),
        FULL_DATASET_NAME: (all_transforms, fg_indices),
    }
    datasets: dict[str, pl.LightningDataModule] = {
        f"d_{config_id}": create_dataset(config, *data_config)
        for config_id, data_config in data_configs.items()
    }

    return DatasetCombination(
        data=datasets,
        data_keys=list(data_configs),
        expanded_config=ExpandedDataConfig(
            transforms_1=transforms_1,
            transforms_2=transforms_2,
            features_1=g1_indices,
            features_2=g2_indices,
        ),
    )
