from hydra.core.config_store import ConfigStore

from .experiment import (
    IFEExperimentConfig,
    IFEDataConfig,
    TrainingConfig,
    ModelConfig,
    EXP_NAME,
)


# Number of target classes; passed to both the data config and the model config.
N_CLASSES = 10

# Epoch budgets for the main training run and the fine-tuning run.
N_TRAIN_EPOCHS = 200
N_FINE_TUNE_EPOCHS = 25

# Seeds baked into every registered config (see register_ife_configs).
CONFIG_SEED_1 = 45
SAMPLE_SEED_1 = 2542763

# Transform names attached to every registered experiment config.
DEFAULT_TRANSFORMS = [
    "none",
    "move",
    "scale",
    "col_jitter",
]

# Seed id -> (config_seed, sampling_seed), consumed by set_ife_seeds.
# The seed id is the 0-based position of the pair in the list below.
SEEDS = dict(
    enumerate(
        [
            (111, 694),
            (222, 8320),
            (9600, 2817),
            (7894, 303),
            (7754, 5486),
            (3058, 9525),
            (5351, 1807),
            (5159, 6003),
            (3877, 7693),
            (7334, 2649),
        ]
    )
)

def _get_obj2d_data_config(
    config_seed: int, sample_seed: int
) -> IFEDataConfig:
    """Build the data configuration used by the obj2d experiment variants.

    Args:
        config_seed: Stored as ``config_seed`` on the returned config.
        sample_seed: Stored as ``sampling_seed`` on the returned config.

    Returns:
        An ``IFEDataConfig`` with ``N_CLASSES`` classes, ``img_size`` 32,
        50000 training / 10000 validation / 10000 test samples and a batch
        size of 256.
    """
    # NOTE(review): this spot was previously annotated "1667 per class";
    # with N_CLASSES = 10 an even split of 50000 would be 5000 -- confirm.
    split_sizes = {
        "n_training_samples": 50000,
        "n_val_samples": 10000,
        "n_test_samples": 10000,
    }
    return IFEDataConfig(
        config_seed=config_seed,
        sampling_seed=sample_seed,
        n_classes=N_CLASSES,
        img_size=32,
        batch_size=256,
        **split_sizes,
    )

def set_ife_seeds(
    config: IFEExperimentConfig,
    seed_id: int,
) -> IFEExperimentConfig:
    """Overwrite the data seeds of ``config`` with a predefined seed pair.

    Args:
        config: Experiment config whose ``data.config_seed`` and
            ``data.sampling_seed`` are replaced. It is modified in place.
        seed_id: Key into the module-level ``SEEDS`` table.

    Returns:
        The same ``config`` object that was passed in.

    Raises:
        KeyError: If ``seed_id`` is not a key of ``SEEDS``.
    """
    # The lookup and unpacking complete before either attribute is written.
    config.data.config_seed, config.data.sampling_seed = SEEDS[seed_id]
    return config

def register_ife_configs() -> None:
    """Register one experiment config per model type in Hydra's ConfigStore.

    Every config is stored in the ``"ife"`` group under its own name and is
    identical to the others apart from its name and the model type. All of
    them are built with ``CONFIG_SEED_1`` / ``SAMPLE_SEED_1``.
    """
    store = ConfigStore.instance()

    # (config name, model type, factory for the data config)
    variants = (
        ("obj2d_resnet-18", "resnet-18", _get_obj2d_data_config),
        ("obj2d_vgg-11", "vgg-11", _get_obj2d_data_config),
        ("obj2d_densenet-121", "densenet-121", _get_obj2d_data_config),
    )

    for variant_name, architecture, build_data_config in variants:
        experiment = IFEExperimentConfig(
            exp_name=[EXP_NAME, variant_name],
            training=TrainingConfig(
                max_epochs=N_TRAIN_EPOCHS,
                save_checkpoints=True,
            ),
            fine_tuning=TrainingConfig(
                max_epochs=N_FINE_TUNE_EPOCHS,
                save_checkpoints=True,
            ),
            data=build_data_config(
                config_seed=CONFIG_SEED_1,
                sample_seed=SAMPLE_SEED_1,
            ),
            model=ModelConfig(
                type=architecture,
                domain="cifar",
                num_classes=N_CLASSES,
            ),
            transforms=DEFAULT_TRANSFORMS,
        )
        store.store(name=variant_name, group="ife", node=experiment)
