from typing import Optional
from hydra.core.config_store import ConfigStore

from .experiment import (
    RIExperimentConfig,
    RIDataConfig,
    TrainingConfig,
    ModelConfig,
    EXP_NAME,
)
from ..cross_transforms.config import SEEDS


# Number of classes; used for both the data config and the model head below.
N_CLASSES = 30

# Epoch budgets.
N_TRAIN_EPOCHS = 50
N_FINE_TUNE_EPOCHS = 20

# Initial seeds passed when the configs are registered (noted there as
# "will be overwritten").
CONFIG_SEED_1 = 45
SAMPLING_SEED_1 = 2542763

# Transform names passed to the data config.
TRANSFORMS = [
    "none",
    "move",
    "scale",
    "rot",
    "v_flip",
    "h_flip",
    "col_jitter",
    "blur",
    "sharpen",
]
# Reduced alternative, kept for reference:
# TRANSFORMS = ["none", "move", "scale", "col_jitter"]

# Sample counts passed to the experiment config.
# Alternative value kept for reference: N_DISTANCE_SAMPLES = 1024
N_DISTANCE_SAMPLES = 4096
N_OBJECT_SAMPLES = 10

def _get_obj2d_data_config(
    config_seed: int, sampling_seed: int
) -> RIDataConfig:
    """Build the data configuration for the ``obj2d`` dataset.

    Args:
        config_seed: Value stored in the returned config's ``config_seed``.
        sampling_seed: Value stored in the returned config's ``sampling_seed``.

    Returns:
        A new ``RIDataConfig`` with ``dataset="obj2d"``, ``N_CLASSES`` classes,
        ``img_size=32``, 50000/10000/10000 train/val/test samples, batch size
        256, the transform names from ``TRANSFORMS`` and ``with_random=False``.
    """
    return RIDataConfig(
        dataset="obj2d",
        config_seed=config_seed,
        sampling_seed=sampling_seed,
        n_classes=N_CLASSES,
        img_size=32,
        n_training_samples=50000,  # ~1667 per class (50000 / 30)
        n_val_samples=10000,
        n_test_samples=10000,
        batch_size=256,
        # Currently not set:
        # object_sampling_seed=93,
        # n_object_samples=100,
        # Pass a copy so that mutating one config's transform list cannot
        # alter the module-level TRANSFORMS constant or other configs.
        transforms=list(TRANSFORMS),
        with_random=False,
    )

def set_ri_seeds(
    config: RIExperimentConfig,
    seed_id: int,
) -> RIExperimentConfig:
    """Overwrite the data seeds of ``config`` with the pair ``SEEDS[seed_id]``.

    Args:
        config: Experiment config whose ``data.config_seed`` and
            ``data.sampling_seed`` are replaced in place.
        seed_id: Key (or index) used to look up the seed pair in ``SEEDS``.

    Returns:
        The same ``config`` object that was passed in, after modification.
    """
    # Unpack first so a malformed SEEDS entry fails before anything is written.
    cfg_seed, smp_seed = SEEDS[seed_id]
    data_section = config.data
    data_section.config_seed = cfg_seed
    data_section.sampling_seed = smp_seed
    return config

def register_ri_configs() -> None:
    """Register the RI experiment configs with Hydra's ``ConfigStore``.

    For every row of the experiment table below, one ``RIExperimentConfig``
    node is stored under group ``"ri"``, named after the row's config name.
    """
    store = ConfigStore.instance()

    # Rows: (config name, model type, data-config builder, distance metric).
    experiments = (
        ("obj2d_resnet-18_l2", "resnet-18", _get_obj2d_data_config, "l2"),
        ("obj2d_resnet-18_cka", "resnet-18", _get_obj2d_data_config, "linear_cka"),
        # Disabled rows. NOTE: they have only three fields and would need a
        # distance metric added before being re-enabled.
        # ("obj2d_vgg-11", "vgg-11", _get_obj2d_data_config),
        # ("obj2d_densenet-121", "densenet-121", _get_obj2d_data_config),
    )

    for name, arch, build_data_config, metric in experiments:
        training_cfg = TrainingConfig(
            max_epochs=N_TRAIN_EPOCHS,
            save_checkpoints=True,
            train=False,
            eval=False,
        )
        # These seeds are initial values only; they will be overwritten
        # (presumably via set_ri_seeds -- confirm against callers).
        data_cfg = build_data_config(
            config_seed=CONFIG_SEED_1,
            sampling_seed=SAMPLING_SEED_1,
        )
        model_cfg = ModelConfig(
            domain="cifar",
            type=arch,
            num_classes=N_CLASSES,
        )
        experiment = RIExperimentConfig(
            exp_name=[EXP_NAME, name],
            training=training_cfg,
            data=data_cfg,
            model=model_cfg,
            eval_invariance="no_prior_res",
            n_distance_samples=N_DISTANCE_SAMPLES,
            n_object_samples=N_OBJECT_SAMPLES,
            dist_metric=metric,
        )
        store.store(name=name, group="ri", node=experiment)
