# from typing import Literal
from dataclasses import dataclass
from typing import Union, cast

import numpy as np
import pytorch_lightning as pl

from utils.data import DataConfig, create_obj2d_dataset, DATASET_PARAMETERS
from datasets.objects_2d import (
    RandomForegroundsConfig,
    RandomBackgroundsConfig,
)


# TviDatasetType = Literal["obj2d"]

@dataclass
class TvIDataConfig(DataConfig):
    """Configuration consumed by ``create_dataset_combination``.

    Extends ``DataConfig``. Fields read elsewhere in this module but not
    declared here (e.g. ``img_size``, used by ``_get_random_configs``)
    presumably come from the base class -- confirm against ``DataConfig``.
    """

    # Dataset family; create_dataset_combination currently only accepts "obj2d".
    dataset: str  # e.g. "obj2d"
    # Seeds the generator that decides the structure of the combination, i.e.
    # which transforms and which foreground classes are used. Also serves as
    # the base of the seeds given to the random foreground/background configs.
    config_seed: int  # e.g. 6802
    # Number of foreground classes per feature set; 2 * n_classes distinct
    # class objects are drawn in total and split into two disjoint sets.
    n_classes: int  # e.g. 50
    # Number of transforms per transform set; 2 * n_transforms distinct
    # transforms are drawn in total and split into two disjoint sets.
    n_transforms: int  # e.g. 3
    # If True, feature set 1 / 2 uses randomly generated foregrounds and
    # backgrounds instead of the sampled class objects.
    random_1: bool = False
    random_2: bool = False

@dataclass
class ExpandedDataConfig:
    """Concrete choices sampled by ``create_dataset_combination``.

    Records which transforms make up each of the two transform sets and what
    each of the two feature sets was built from.
    """

    # Transform names of transform set 1 and 2 (drawn without replacement,
    # so the two sets are disjoint).
    transforms_1: list[str]
    transforms_2: list[str]
    # Feature set 1 and 2: either a (foregrounds, backgrounds) pair of random
    # configs when the corresponding ``random_*`` flag is set, or
    # (foreground indices sampled from the dataset's class objects, None).
    features_1: Union[
        tuple[RandomForegroundsConfig, RandomBackgroundsConfig],
        tuple[list[int], None],
    ]
    features_2: Union[
        tuple[RandomForegroundsConfig, RandomBackgroundsConfig],
        tuple[list[int], None],
    ]

@dataclass
class DatasetCombination:
    """Result of ``create_dataset_combination``: every transform/feature pairing."""

    # Data modules keyed by
    # get_item_name(True, transforms_idx, features_idx, is_random).
    data: dict[str, pl.LightningDataModule]
    # The (transforms_idx, features_idx, is_random) tuples used to build
    # ``data``, in creation order.
    data_keys: list[tuple[int, int, bool]]
    # The sampled transforms and feature sources behind ``data``.
    expanded_config: ExpandedDataConfig


def get_item_name(
    is_data: bool, transforms_idx: int, objects_idx: int, is_random: bool
) -> str:
    """Build the string key for one item of a dataset combination.

    The key has the form ``<p>_t<transforms_idx>_i<objects_idx>_<s>`` where
    ``<p>`` is ``d`` when ``is_data`` is true and ``m`` otherwise, and ``<s>``
    is ``rand`` when ``is_random`` is true and ``rw`` otherwise.

    Args:
        is_data: Selects the ``d`` prefix (otherwise ``m``).
        transforms_idx: Index of the transform set.
        objects_idx: Index of the feature (objects) set.
        is_random: Selects the ``rand`` suffix (otherwise ``rw``).

    Returns:
        The formatted key, e.g. ``"d_t1_i2_rw"``.
    """
    if is_data:
        prefix = "d"
    else:
        prefix = "m"

    if is_random:
        type_postfix = "rand"
    else:
        type_postfix = "rw"

    return f"{prefix}_t{transforms_idx}_i{objects_idx}_{type_postfix}"

def create_dataset_combination(
    config: TvIDataConfig,
) -> DatasetCombination:
    """Build datasets for every pairing of two transform sets and two feature sets.

    A generator seeded with ``config.config_seed`` draws two disjoint sets of
    ``config.n_transforms`` transforms and two disjoint sets of
    ``config.n_classes`` foreground class objects. Feature set 1 / 2 is
    replaced by randomly generated foregrounds/backgrounds when
    ``config.random_1`` / ``config.random_2`` is set. One data module is then
    created for each of the four (transform set, feature set) pairs.

    Args:
        config: Data configuration; ``config.dataset`` selects the dataset.

    Returns:
        A ``DatasetCombination`` holding the data modules (keyed by
        ``get_item_name``), the ``(transforms_idx, features_idx, is_random)``
        keys in creation order, and the sampled configuration.

    Raises:
        ValueError: If ``config.dataset`` is not supported, or (from numpy) if
            the dataset has too few transforms / class objects to draw the
            requested number without replacement.
    """
    if config.dataset == "obj2d":
        create_dataset = create_obj2d_dataset
    else:
        raise ValueError(f"Invalid dataset config: {config.dataset}")
    dataset_params = DATASET_PARAMETERS[config.dataset]
    rng = np.random.default_rng(config.config_seed)

    # Subsample 2 * n_transforms distinct transforms and split them into two
    # disjoint sets. ``.tolist()`` converts numpy scalars to plain Python
    # values, matching the ``list[str]`` annotation of ExpandedDataConfig.
    n_transforms = config.n_transforms
    transform_names = rng.choice(
        dataset_params.transforms,
        2 * n_transforms,
        replace=False,
    ).tolist()
    transforms_1 = transform_names[:n_transforms]
    transforms_2 = transform_names[n_transforms:]

    # Subsample 2 * n_classes distinct foreground objects, likewise split in
    # two. The draw happens even when a feature set ends up random, so the
    # class split depends only on config_seed, not on the random_* flags.
    n_classes = config.n_classes
    fg_indices = rng.choice(
        dataset_params.class_objects, 2 * n_classes, replace=False
    ).tolist()
    if config.random_1:
        features_config_1 = _get_random_configs(config, 69)
    else:
        features_config_1 = (fg_indices[:n_classes], None)
    if config.random_2:
        features_config_2 = _get_random_configs(config, 20592)
    else:
        features_config_2 = (fg_indices[n_classes:], None)

    # Cross every transform set with every feature set. Keys are
    # (transforms_idx, features_idx, is_random); values are the positional
    # arguments passed to create_dataset after ``config``.
    data_configs = {
        (1, 1, config.random_1): (transforms_1, *features_config_1),
        (1, 2, config.random_2): (transforms_1, *features_config_2),
        (2, 1, config.random_1): (transforms_2, *features_config_1),
        (2, 2, config.random_2): (transforms_2, *features_config_2),
    }
    datasets: dict[str, pl.LightningDataModule] = {
        get_item_name(True, *data_id): create_dataset(config, *data_config)
        for data_id, data_config in data_configs.items()
    }

    return DatasetCombination(
        data=datasets,
        data_keys=list(data_configs.keys()),
        expanded_config=ExpandedDataConfig(
            transforms_1=transforms_1,
            transforms_2=transforms_2,
            features_1=features_config_1,
            features_2=features_config_2,
        ),
    )

# Value passed as ``n_backgrounds`` to every RandomBackgroundsConfig built
# by _get_random_configs.
N_RANDOM_BACKGROUNDS = 1_000

def _get_random_configs(
    config: TvIDataConfig, seed_offset: int,
) -> tuple[RandomForegroundsConfig, RandomBackgroundsConfig]:
    """Build the random foreground/background config pair for one feature set.

    Args:
        config: Source of ``n_classes``, ``config_seed`` and ``img_size``.
        seed_offset: Added to ``config.config_seed`` so that different feature
            sets (callers pass different offsets) get different seeds.

    Returns:
        ``(foregrounds, backgrounds)``. The foreground seed is
        ``config_seed + seed_offset + 294``, and the background seed is
        ``config_seed + seed_offset + 2012``.
    """
    base_seed = config.config_seed + seed_offset
    image_size = config.img_size

    foregrounds = RandomForegroundsConfig(
        n_classes=config.n_classes,
        seed=base_seed + 294,
        img_size=image_size,
    )
    backgrounds = RandomBackgroundsConfig(
        n_backgrounds=N_RANDOM_BACKGROUNDS,
        seed=base_seed + 2012,
        img_size=image_size,
    )
    return foregrounds, backgrounds
