from hydra.core.config_store import ConfigStore

from .experiment import (
    CTExperimentConfig,
    CTDataConfig,
    TrainingConfig,
    ModelConfig,
    EXP_NAME,
)


# Number of classes, shared by the obj2d data config and the model config.
N_CLASSES = 30
# Epoch budgets for the ``training`` and ``fine_tuning`` stages respectively.
N_TRAIN_EPOCHS = 50
N_FINE_TUNE_EPOCHS = 20
# Default (config_seed, sampling_seed) pair baked into every registered "ct" config.
CONFIG_SEED_1 = 45
SAMPLE_SEED_1 = 2542763

# Selectable seed pairs for ``set_ct_seeds``: the position in the list is the
# seed_id, and each entry is a (config_seed, sampling_seed) tuple.
SEEDS = dict(
    enumerate(
        [
            (111, 694),
            (222, 8320),
            (5703, 2014),
            (7542, 6235),
            (1929, 1607),
            (2850, 848),
            (6763, 1982),
            (2487, 47297),
            (270, 6076),
            (1981, 65519),
        ]
    )
)

def _get_obj2d_data_config(
    config_seed: int, sample_seed: int
) -> CTDataConfig:
    """Build the ``CTDataConfig`` for the ``obj2d`` dataset.

    Args:
        config_seed: Value stored as ``CTDataConfig.config_seed``.
        sample_seed: Value stored as ``CTDataConfig.sampling_seed``.

    Returns:
        An obj2d data config with ``N_CLASSES`` classes, ``img_size=32``,
        fixed train/val/test sample counts and a batch size of 256.
    """
    # Sample counts per split; 50000 training samples is ~1667 per class
    # when N_CLASSES is 30.
    split_sizes = dict(
        n_training_samples=50000,
        n_val_samples=10000,
        n_test_samples=10000,
    )
    return CTDataConfig(
        dataset="obj2d",
        config_seed=config_seed,
        sampling_seed=sample_seed,
        n_classes=N_CLASSES,
        img_size=32,
        batch_size=256,
        **split_sizes,
    )

def set_ct_seeds(
    config: CTExperimentConfig,
    seed_id: int,
) -> CTExperimentConfig:
    """Overwrite the data seeds of ``config`` with a predefined seed pair.

    The ``(config_seed, sampling_seed)`` pair stored under ``seed_id`` in
    ``SEEDS`` is written to ``config.data.config_seed`` and
    ``config.data.sampling_seed``. ``config`` is modified in place and is
    also returned for convenience.

    Args:
        config: Experiment config whose ``data`` section is updated.
        seed_id: Key into ``SEEDS``.

    Returns:
        The same ``config`` object, after mutation.

    Raises:
        KeyError: If ``seed_id`` is not a key of ``SEEDS``. ``config`` is
            left unmodified in that case.
    """
    try:
        config_seed, sampling_seed = SEEDS[seed_id]
    except KeyError:
        # A bare KeyError(seed_id) gives no hint of what is allowed; list the
        # valid ids while keeping the exception type callers may catch.
        raise KeyError(
            f"Unknown seed_id {seed_id!r}; expected one of {sorted(SEEDS)}"
        ) from None
    config.data.config_seed = config_seed
    config.data.sampling_seed = sampling_seed
    return config

def register_cross_transforms_configs() -> None:
    """Register one ``ct`` experiment config per model architecture with Hydra.

    Each config is stored in the ``ConfigStore`` singleton under group
    ``"ct"``. It uses:

    - the obj2d data config seeded with ``CONFIG_SEED_1`` / ``SAMPLE_SEED_1``;
    - ``N_TRAIN_EPOCHS`` of training and ``N_FINE_TUNE_EPOCHS`` of
      fine-tuning, both with checkpoint saving enabled;
    - a model of the given architecture with ``N_CLASSES`` outputs.
    """
    store = ConfigStore.instance()

    def _build_node(name, architecture):
        # Every node gets freshly constructed sub-configs, in the same order
        # as the fields are passed to CTExperimentConfig.
        training = TrainingConfig(
            max_epochs=N_TRAIN_EPOCHS,
            save_checkpoints=True,
        )
        fine_tuning = TrainingConfig(
            max_epochs=N_FINE_TUNE_EPOCHS,
            save_checkpoints=True,
        )
        data = _get_obj2d_data_config(
            config_seed=CONFIG_SEED_1,
            sample_seed=SAMPLE_SEED_1,
        )
        # NOTE(review): domain "cifar" presumably selects small-image model
        # variants for the 32px obj2d inputs -- confirm in ModelConfig.
        model = ModelConfig(
            type=architecture,
            domain="cifar",
            num_classes=N_CLASSES,
        )
        return CTExperimentConfig(
            exp_name=[EXP_NAME, name],
            training=training,
            fine_tuning=fine_tuning,
            data=data,
            model=model,
        )

    # (config name, model type) pairs; all share the obj2d data config.
    architectures = (
        ("obj2d_resnet-18", "resnet-18"),
        ("obj2d_vgg-11", "vgg-11"),
        ("obj2d_densenet-121", "densenet-121"),
    )
    for name, architecture in architectures:
        store.store(
            name=name,
            group="ct",
            node=_build_node(name, architecture),
        )
