from typing import Union
from hydra.core.config_store import ConfigStore

from .experiment import (
    TMExperimentConfig,
    TMDataConfig,
    TrainingConfig,
    EXP_NAME,
    ModelConfig,
)

# N_CLASSES feeds both the data config and the model config; IMG_SIZE feeds
# the data config's ``img_size``.
N_CLASSES = 30
IMG_SIZE = 32

# Epoch budgets: N_TRAIN_EPOCHS for the ``training`` phase, N_FINE_TUNE_EPOCHS
# for each of the mismatch phases.
N_TRAIN_EPOCHS = 100
N_FINE_TUNE_EPOCHS = 25

# Default (config_seed, sampling_seed) used when the configs are registered.
CONFIG_SEED_1 = 5939
SAMPLING_SEED_1 = 29402

# (config_seed, sampling_seed) pairs keyed by an integer seed_id 0..9;
# set_tm_seeds copies the selected pair into a config's data section.
SEEDS = dict(enumerate([
    (9825, 507031),
    (174, 348779),
    (148, 7367),
    (1219, 901),
    (696, 3787),
    (673, 6150),
    (469, 4611),
    (6995, 1756),
    (4360, 264),
    (1750, 4345),
]))

def register_transforms_mismatch_configs() -> None:
    """Store each transforms-mismatch experiment config in Hydra's ConfigStore.

    Every variant is stored under group ``"tm"`` with its config name as the
    entry name. The ``training`` phase runs for ``N_TRAIN_EPOCHS`` epochs.
    The ``quant_mismatch`` and ``order_mismatch`` phases each run for
    ``N_FINE_TUNE_EPOCHS`` epochs. All phases save checkpoints. Data configs
    are built with the default seeds ``CONFIG_SEED_1`` / ``SAMPLING_SEED_1``.
    """
    store = ConfigStore.instance()

    # (config name, model type, data-config factory)
    variants = (
        ("obj2d_resnet-18", "resnet-18", _get_obj2d_rw_l_data_config),
    )

    def checkpointed(epochs):
        # A new TrainingConfig per phase, so phases never share one instance.
        return TrainingConfig(max_epochs=epochs, save_checkpoints=True)

    for name, architecture, make_data_config in variants:
        experiment = TMExperimentConfig(
            exp_name=[EXP_NAME, name],
            training=checkpointed(N_TRAIN_EPOCHS),
            quant_mismatch=checkpointed(N_FINE_TUNE_EPOCHS),
            # TODO: the ``qual_mismatch`` phase is disabled; to enable it, pass
            # qual_mismatch=checkpointed(N_FINE_TUNE_EPOCHS) here.
            order_mismatch=checkpointed(N_FINE_TUNE_EPOCHS),
            data=make_data_config(
                config_seed=CONFIG_SEED_1,
                sampling_seed=SAMPLING_SEED_1,
            ),
            model=ModelConfig(
                type=architecture,
                domain="cifar",
                num_classes=N_CLASSES,
            ),
        )
        # NOTE(review): the ``_get_obj2d_rand_*`` data-config factories are
        # not registered as variants yet.
        store.store(name=name, group="tm", node=experiment)

def set_tm_seeds(
    config: TMExperimentConfig,
    seed_id: int,
) -> TMExperimentConfig:
    """Replace the data seeds of ``config`` with the pair stored for ``seed_id``.

    The config is mutated in place and also returned, so calls can be chained.

    Args:
        config: Experiment config whose ``data.config_seed`` and
            ``data.sampling_seed`` are overwritten.
        seed_id: Key into ``SEEDS``.

    Returns:
        The same ``config`` object, after mutation.

    Raises:
        KeyError: If ``seed_id`` is not a key of ``SEEDS``. The config is left
            untouched in that case.
    """
    try:
        config_seed, sampling_seed = SEEDS[seed_id]
    except KeyError:
        # Keep the exception type callers may catch, but name the valid ids
        # instead of surfacing a bare ``KeyError: <id>``.
        raise KeyError(
            f"unknown seed_id {seed_id!r}; expected one of {sorted(SEEDS)}"
        ) from None
    config.data.config_seed = config_seed
    config.data.sampling_seed = sampling_seed
    return config

def _get_data_config_constructor(is_large: bool, has_random_objects: bool):
    """Return a factory that builds ``obj2d`` ``TMDataConfig`` objects.

    Args:
        is_large: If true, use 50000 training samples (about 1667 per class
            with ``N_CLASSES == 30``). Otherwise use 15000 (500 per class).
        has_random_objects: Forwarded unchanged as ``TMDataConfig.random``.

    Returns:
        A function ``(config_seed, sampling_seed) -> TMDataConfig``. Every
        other field is fixed by the arguments above and the module constants.
    """
    # Per-class figures in the docstring assume N_CLASSES == 30.
    n_training_samples = 50000 if is_large else 15000

    def get_obj2d_data_config(
        config_seed: int, sampling_seed: int
    ) -> TMDataConfig:
        """Build the obj2d data config for the given seed pair."""
        return TMDataConfig(
            dataset="obj2d",
            config_seed=config_seed,
            sampling_seed=sampling_seed,
            n_classes=N_CLASSES,
            img_size=IMG_SIZE,
            n_training_samples=n_training_samples,
            n_val_samples=10000,
            n_test_samples=10000,
            batch_size=256,
            random=has_random_objects,
        )

    return get_obj2d_data_config

# Data-config factories. "_l" variants use is_large=True and "_m" variants
# is_large=False. "rand" variants set has_random_objects=True and "rw"
# variants False.
_get_obj2d_rw_l_data_config = _get_data_config_constructor(
    is_large=True, has_random_objects=False
)
_get_obj2d_rand_l_data_config = _get_data_config_constructor(
    is_large=True, has_random_objects=True
)
_get_obj2d_rw_m_data_config = _get_data_config_constructor(
    is_large=False, has_random_objects=False
)
_get_obj2d_rand_m_data_config = _get_data_config_constructor(
    is_large=False, has_random_objects=True
)
