from hydra.core.config_store import ConfigStore

from .experiment import (
    ITExperimentConfig,
    ITDataConfig,
    TrainingConfig,
    EXP_NAME,
    ModelConfig,
)

# Number of classes in each of the two class groups; register_ti_configs sizes
# the model output as 2 * N_CLASSES_PER_GROUP.
N_CLASSES_PER_GROUP = 15  # per group, so total twice that many
# Passed as n_transforms / img_size to the obj2d ITDataConfig.
N_TRANSFORMS = 3
IMG_SIZE = 32

# Epoch budgets for the main training run and the fine-tuning run.
N_TRAIN_EPOCHS = 100
N_FINE_TUNE_EPOCHS = 25
# Default (config_seed, sampling_seed) pair baked into the registered configs.
CONFIG_SEED_1 = 5939
SAMPLING_SEED_1 = 29402

# Alternative seed pairs selectable by integer id. Each value is
# (config_seed, sampling_seed), in the order set_it_seeds unpacks them.
SEEDS = {
    0: (9607, 4932),
    1: (7126, 66245),
    2: (805, 637222),
    3: (3823, 577),
    4: (5087, 8092),
    5: (6917, 5372),
    6: (3279, 5),
    7: (6982, 2509),
    8: (735, 2341),
    9: (2649, 66881),
}

def register_ti_configs() -> None:
    """Store the "it" experiment configs in Hydra's ConfigStore.

    One node is stored under group "it" for every entry of the spec table.
    Each node uses N_TRAIN_EPOCHS / N_FINE_TUNE_EPOCHS with checkpointing
    enabled and the default seed pair (CONFIG_SEED_1, SAMPLING_SEED_1). Its
    model is sized for both class groups.
    """
    store = ConfigStore.instance()

    # Each row is (config name, model type, data-config factory). The factory
    # is called with config_seed= and sampling_seed= keyword arguments.
    specs = [
        ("obj2d_resnet-18", "resnet-18", get_obj2d_data_config),
    ]

    for name, arch, make_data_config in specs:
        train_cfg = TrainingConfig(
            max_epochs=N_TRAIN_EPOCHS,
            save_checkpoints=True,
        )
        fine_tune_cfg = TrainingConfig(
            max_epochs=N_FINE_TUNE_EPOCHS,
            save_checkpoints=True,
        )
        data_cfg = make_data_config(
            config_seed=CONFIG_SEED_1,
            sampling_seed=SAMPLING_SEED_1,
        )
        model_cfg = ModelConfig(
            type=arch,
            domain="cifar",
            # There are two class groups, so double the per-group count.
            num_classes=2 * N_CLASSES_PER_GROUP,
        )
        experiment = ITExperimentConfig(
            exp_name=[EXP_NAME, name],
            training=train_cfg,
            fine_tuning=fine_tune_cfg,
            data=data_cfg,
            model=model_cfg,
        )
        # NOTE(review): the original author considered adding "the random
        # part" here -- presumably seed variation (cf. SEEDS / set_it_seeds).
        # Confirm before changing.
        store.store(name=name, group="it", node=experiment)

def set_it_seeds(
    config: ITExperimentConfig,
    seed_id: int,
) -> ITExperimentConfig:
    """Apply the seed pair registered under *seed_id* to ``config.data``.

    ``SEEDS[seed_id]`` is unpacked as ``(config_seed, sampling_seed)``, and
    both values are written onto ``config.data``. The config is mutated in
    place; the same object is returned for convenience.

    Raises:
        KeyError: if *seed_id* is not a key of ``SEEDS``.
    """
    chosen_config_seed, chosen_sampling_seed = SEEDS[seed_id]
    data_section = config.data
    data_section.config_seed = chosen_config_seed
    data_section.sampling_seed = chosen_sampling_seed
    return config

def get_obj2d_data_config(
    config_seed: int, sampling_seed: int
) -> ITDataConfig:
    """Build the ITDataConfig for the "obj2d" dataset.

    Class count, transform count and image size come from the module
    constants. Split sizes and batch size are fixed here.

    Args:
        config_seed: forwarded unchanged as ``config_seed``.
        sampling_seed: forwarded unchanged as ``sampling_seed``.
    """
    # 50000 training samples spread over 2 * N_CLASSES_PER_GROUP = 30 classes
    # gives roughly 1667 samples per class.
    train_size, val_size, test_size = 50000, 10000, 10000
    return ITDataConfig(
        dataset="obj2d",
        n_classes=N_CLASSES_PER_GROUP,
        n_transforms=N_TRANSFORMS,
        img_size=IMG_SIZE,
        config_seed=config_seed,
        sampling_seed=sampling_seed,
        n_training_samples=train_size,
        n_val_samples=val_size,
        n_test_samples=test_size,
        batch_size=256,
    )
