from typing import Callable, Union
from hydra.core.config_store import ConfigStore

from .experiment import (
    TvIExperimentConfig,
    TvIDataConfig,
    TrainingConfig,
    EXP_NAME,
    ModelConfig,
)

# Number of object classes in the obj2d dataset; also passed to the model as
# ``num_classes``.
N_CLASSES = 30
# NOTE: the number of transforms is not a module-wide constant; each data
# config chooses its own via the ``n_transforms`` argument of
# ``_get_data_config_constructor``.
IMG_SIZE = 32  # forwarded as ``img_size`` to every TvIDataConfig

# ``max_epochs`` for the initial training phase and for fine-tuning.
N_TRAIN_EPOCHS = 100
N_FINE_TUNE_EPOCHS = 25

# Default (config_seed, sampling_seed) used when the tvi configs are
# registered; ``set_tvi_seeds`` can override them per run.
CONFIG_SEED_1 = 5939
SAMPLING_SEED_1 = 29402

# seed_id -> (config_seed, sampling_seed), looked up by ``set_tvi_seeds``.
SEEDS = {
    0: (59832, 2032),
    1: (113, 50202),
    2: (5902, 93),
    3: (26906, 8835),
    4: (78, 70887),
    5: (494, 237342),
    6: (3405, 21562),
    7: (746, 107634),
    8: (36684, 5295),
    9: (605, 34),
}

def register_transforms_vs_identity_configs() -> None:
    """Store every transforms-vs-identity experiment in Hydra's ConfigStore.

    Each experiment is stored in the ``tvi`` group under its own name. It
    combines shared training and fine-tuning settings with a per-experiment
    model architecture and data-config factory. The factory is always called
    with the default seeds ``CONFIG_SEED_1`` and ``SAMPLING_SEED_1``.
    """
    store = ConfigStore.instance()

    # Each row is (config name, model architecture, data-config factory).
    # Medium-size ("-m") variants are currently disabled.
    experiments = (
        ("obj2d-l-1t_resnet-18", "resnet-18", _get_obj2d_rw_l_1t_data_config),
        ("obj2d-l-2t_resnet-18", "resnet-18", _get_obj2d_rw_l_2t_data_config),
        ("obj2d-l-3t_resnet-18", "resnet-18", _get_obj2d_rw_l_3t_data_config),
        ("obj2d-l-3t_vgg-11", "vgg-11", _get_obj2d_rw_l_3t_data_config),
        (
            "obj2d-l-3t_densenet-121",
            "densenet-121",
            _get_obj2d_rw_l_3t_data_config,
        ),
        (
            "obj2d-rand-l-3t_resnet-18",
            "resnet-18",
            _get_obj2d_rand_l_3t_data_config,
        ),
    )

    for name, arch, make_data_config in experiments:
        # Every sub-config is built fresh per experiment, so no two stored
        # nodes share an instance.
        exp_name = [EXP_NAME, name]
        training = TrainingConfig(
            max_epochs=N_TRAIN_EPOCHS,
            save_checkpoints=True,
        )
        fine_tuning = TrainingConfig(
            max_epochs=N_FINE_TUNE_EPOCHS,
            save_checkpoints=True,
        )
        data = make_data_config(
            config_seed=CONFIG_SEED_1,
            sampling_seed=SAMPLING_SEED_1,
        )
        model = ModelConfig(
            type=arch,
            domain="cifar",
            num_classes=N_CLASSES,
        )
        node = TvIExperimentConfig(
            exp_name=exp_name,
            training=training,
            fine_tuning=fine_tuning,
            data=data,
            model=model,
        )
        # TODO: decide whether the random-object variants should be
        # configured differently here.
        store.store(name=name, group="tvi", node=node)

def set_tvi_seeds(
    config: TvIExperimentConfig,
    seed_id: int,
) -> TvIExperimentConfig:
    """Apply the seed pair registered under ``seed_id`` to ``config``.

    ``config.data.config_seed`` and ``config.data.sampling_seed`` are
    overwritten in place. The same object is also returned so the call can
    be chained.

    Args:
        config: Experiment config whose data seeds are replaced.
        seed_id: Key into ``SEEDS`` selecting a
            ``(config_seed, sampling_seed)`` pair.

    Returns:
        The mutated ``config`` object (not a copy).

    Raises:
        KeyError: If ``seed_id`` is not a key of ``SEEDS``. The message
            lists the valid ids.
    """
    try:
        config_seed, sampling_seed = SEEDS[seed_id]
    except KeyError:
        # Same exception type as a plain dict lookup, but with an actionable
        # message instead of just the offending key.
        raise KeyError(
            f"Unknown seed_id {seed_id!r}; expected one of {sorted(SEEDS)}"
        ) from None
    config.data.config_seed = config_seed
    config.data.sampling_seed = sampling_seed
    return config

def _get_data_config_constructor(
    is_large: bool, has_random_objects: bool, n_transforms: int,
) -> Callable[[int, int], TvIDataConfig]:
    """Build a factory that creates obj2d ``TvIDataConfig`` objects.

    Args:
        is_large: Selects the training-set size. If true, 50000 samples
            (~1667 per class when N_CLASSES is 30). If false, 15000 samples
            (500 per class when N_CLASSES is 30).
        has_random_objects: Forwarded as ``random_2``. ``random_1`` is
            always False.
        n_transforms: Forwarded as ``n_transforms``.

    Returns:
        A function ``(config_seed, sampling_seed) -> TvIDataConfig``. All
        other fields are fixed:
        - 10000 validation and 10000 test samples;
        - batch size 256;
        - ``N_CLASSES`` classes;
        - ``IMG_SIZE`` images.
    """
    n_training_samples = 50000 if is_large else 15000

    def get_obj2d_data_config(
        config_seed: int, sampling_seed: int
    ) -> TvIDataConfig:
        return TvIDataConfig(
            dataset="obj2d",
            config_seed=config_seed,
            sampling_seed=sampling_seed,
            n_classes=N_CLASSES,
            n_transforms=n_transforms,
            img_size=IMG_SIZE,
            # Size depends on ``is_large``; see the docstring above.
            n_training_samples=n_training_samples,
            n_val_samples=10000,
            n_test_samples=10000,
            batch_size=256,
            random_1=False,
            random_2=has_random_objects,
        )

    return get_obj2d_data_config

# Data-config factories for the large (50000-sample) obj2d variants.
# "rw" factories use has_random_objects=False and "rand" factories use
# has_random_objects=True. The "<k>t" suffix is the number of transforms.
_get_obj2d_rw_l_1t_data_config = _get_data_config_constructor(
    is_large=True, has_random_objects=False, n_transforms=1
)
_get_obj2d_rw_l_2t_data_config = _get_data_config_constructor(
    is_large=True, has_random_objects=False, n_transforms=2
)
_get_obj2d_rw_l_3t_data_config = _get_data_config_constructor(
    is_large=True, has_random_objects=False, n_transforms=3
)
_get_obj2d_rand_l_1t_data_config = _get_data_config_constructor(
    is_large=True, has_random_objects=True, n_transforms=1
)
_get_obj2d_rand_l_2t_data_config = _get_data_config_constructor(
    is_large=True, has_random_objects=True, n_transforms=2
)
_get_obj2d_rand_l_3t_data_config = _get_data_config_constructor(
    is_large=True, has_random_objects=True, n_transforms=3
)
# Medium (15000-sample) variants are currently disabled. To re-enable one,
# pass all three arguments (the constructor requires ``n_transforms``), e.g.
#   _get_data_config_constructor(
#       is_large=False, has_random_objects=False, n_transforms=<k>
#   )
