"""Configs for STL->CIFAR experiments."""

import copy
import itertools


def get_weighting_config_class_dirichlet(alpha, seed):
    """Return a ``class_dirichlet`` weighting config.

    Args:
        alpha: forwarded unchanged as ``kwargs["alpha"]``.
        seed: forwarded unchanged as ``kwargs["seed"]``.

    Returns:
        ``{"name": "class_dirichlet", "kwargs": {"alpha": alpha, "seed": seed}}``
    """
    weighting_kwargs = dict(alpha=alpha, seed=seed)
    return dict(name="class_dirichlet", kwargs=weighting_kwargs)


def get_dataset_config_cifar_stl_balance_target():
    """Return the ``CIFAR_STL`` dataset config with class-uniform weighting on both sides.

    Source uses ``index`` 1 and target uses ``index`` 0. Both use
    ``class_uniform`` weighting with ``subsample=True``. The validation
    fraction is 0.15 and no ``mods`` are applied. A fresh dict is returned
    on every call.
    """

    def _uniform_domain(domain_index):
        # Source and target differ only in their index.
        return {
            "index": domain_index,
            "weighting": {"name": "class_uniform", "kwargs": {}},
            "subsample": True,
        }

    return {
        "name": "CIFAR_STL",
        "val_fraction": 0.15,
        "mods": [],
        "source": _uniform_domain(1),
        "target": _uniform_domain(0),
    }


def get_dataset_config_cifar_stl_dirichlet_target_imbalance(alpha, seed=None):
    """Return the ``CIFAR_STL`` dataset config with a class-Dirichlet-weighted target.

    The source (``index`` 1) uses ``class_uniform`` weighting. The target
    (``index`` 0) uses the ``class_dirichlet`` weighting built from ``alpha``
    and ``seed``. Both sides subsample, and the validation fraction is 0.15.

    Args:
        alpha: forwarded to ``get_weighting_config_class_dirichlet``.
        seed: forwarded to ``get_weighting_config_class_dirichlet``; defaults to None.
    """
    source_domain = {
        "index": 1,
        "weighting": {"name": "class_uniform", "kwargs": {}},
        "subsample": True,
    }
    target_domain = {
        "index": 0,
        "weighting": get_weighting_config_class_dirichlet(alpha, seed=seed),
        "subsample": True,
    }
    return {
        "name": "CIFAR_STL",
        "val_fraction": 0.15,
        "mods": [],
        "source": source_domain,
        "target": target_domain,
    }


def get_algorithm_config(
    algorithm, extra_hparams=None, extra_discriminator_hparams=None
):
    """Build the config for ``algorithm`` on top of the hparams shared by all algorithms.

    Args:
        algorithm: algorithm name, stored under ``"name"``.
        extra_hparams: optional mapping merged into ``config["hparams"]``.
            The merge is shallow: each key replaces the default value as a
            whole. For example, passing ``"fx_opt"`` replaces the entire
            optimizer dict rather than merging into its ``kwargs``.
        extra_discriminator_hparams: optional mapping merged into
            ``config["hparams"]["discriminator"]["hparams"]``. It is applied
            after ``extra_hparams``.

    Returns:
        A new dict ``{"name": algorithm, "hparams": {...}}``. The overrides are
        deep-copied, so the result shares no mutable objects with the caller's
        dicts or with configs returned by other calls.
    """
    # Common configs of all algorithms
    config = {
        "name": algorithm,
        "hparams": {
            "da_network": {
                "feature_extractor": {
                    "name": "DeepCNN",
                    "hparams": {
                        "num_features": 192,
                        "gaussian_noise": 1.0,
                    },
                },
                "classifier": {
                    "name": "LogLossClassifier",
                    "hparams": {
                        "num_hidden": None,
                    },
                },
            },
            "discriminator": {
                "hparams": {
                    "num_hidden": 512,
                    "depth": 3,
                    "spectral": False,
                    "history_size": 0,
                }
            },
            "ema_momentum": 0.998,
            "fx_opt": {
                "name": "Adam",
                "kwargs": {
                    "lr": 1e-3,
                    "weight_decay": 0.0,
                    "amsgrad": False,
                    "betas": (0.5, 0.999),
                },
            },
            "fx_lr_decay_start": None,
            "fx_lr_decay_steps": None,
            "fx_lr_decay_factor": None,
            "cls_opt": {
                "name": "Adam",
                "kwargs": {
                    "lr": 1e-3,
                    "weight_decay": 0.0,
                    "amsgrad": False,
                    "betas": (0.5, 0.999),
                },
            },
            "cls_weight": 1.0,
            "use_random": False,
            "num_steps": 40000,
            "cls_trg_weight": 0.1,
            "alignment_weight": None,
            "alignment_w_steps": 10000,
            "cls_trg_weight_anneal": True,
            "vat_z": False,
            "use_vat": False,
            "disc_opt": {
                "name": "Adam",
                "kwargs": {
                    "lr": 1e-3,
                    "weight_decay": 0.0,
                    "amsgrad": False,
                    "betas": (0.5, 0.999),
                },
            },
            "disc_steps": 1,
            "l2_weight": 0.0,
            "lr_type": "decay",
        },
    }

    if extra_hparams is not None:
        # Deep-copy so nested override values (e.g. {"importance_weighting": {...}})
        # are not aliased between configs built from the same override dict.
        config["hparams"].update(copy.deepcopy(extra_hparams))

    if extra_discriminator_hparams is not None:
        config["hparams"]["discriminator"]["hparams"].update(
            copy.deepcopy(extra_discriminator_hparams)
        )

    return config


def register_experiments(registry):
    """Register every STL/CIFAR experiment with ``registry``.

    One experiment is registered per (target imbalance alpha, seed, algorithm)
    combination via ``registry.register(name, config)``. The iteration order
    is alpha-major, then seed, then algorithm. Names have the form
    ``stl_cifar/deep_cnn/seed_{seed}/s_alpha_{tag}/{alg_nickname}``:

    - ``tag`` is ``None`` for the class-balanced target.
    - Otherwise ``tag`` is ``round(alpha * 10)`` zero-padded to three
      digits (e.g. alpha=0.5 -> ``005``).

    Each registered config gets its own copies of the dataset and training
    dicts, so mutating one experiment's config cannot affect another's.

    Args:
        registry: object exposing ``register(name, config)``.
    """
    iwdan_extra_hparams = {
        "alignment_weight": 0.1,
        "iw_update_period": 5000,
        "importance_weighting": {"ma": 0.5},
    }
    vada_extra_hparams = {
        "alignment_weight": 0.1,
        "cls_vat_src_weight": 0.0,
        "cls_vat_trg_weight": 0.1,
        "vat_xi": 1e-6,
        "vat_radius": 3.5,
    }
    casa_extra_hparams = {
        "alignment_weight": 0.5,
        "cls_vat_src_weight": 0.0,
        "cls_vat_trg_weight": 0.5,
        "vat_xi": 1e-6,
        "vat_radius": 3.5,
        "cls_trg_weight": 0.1,
        "use_vat": True,
    }
    sentry_hparams = {"src_weight": 1.0, "unsup_weight": 0.1, "ent_weight": 1.0}
    pct_extra_hparams = {
        "nav_t": 1.0,
        "s_par": 0.5,
        "beta": 0.0,
        "lr_gamma": 0.0002,
        "trade_off": 1.0,
    }

    # Algorithm configs format:
    # nickname, algorithm_name, extra_hparams, extra_discriminator_hparams
    algorithms = [
        ("source_only", "ERM", None, None),
        ("dann", "DANN_NS", {"alignment_weight": 0.1}, None),
        ("cdan", "CDAN", {"alignment_weight": 0.1}, None),
        ("iwdan", "IWDAN", iwdan_extra_hparams, None),
        ("iwcdan", "IWCDAN", iwdan_extra_hparams, None),
        ("sdann_4", "SDANN", {"alignment_weight": 0.1}, {"beta": 4.0}),
        ("vada", "VADA", vada_extra_hparams, None),
        ("asa", "DANN_SUPP_ABS", {"alignment_weight": 0.1}, {"history_size": 1000}),
        ("casa", "CDAN_SUPP_SQ_E", casa_extra_hparams, {"history_size": 1000}),
        ("sentry", "SENTRY", sentry_hparams, None),
        ("pct", "PCT", pct_extra_hparams, None),
    ]

    # None selects the class-balanced target instead of a Dirichlet-weighted one.
    imbalance_alphas = [0.5, 1.0, 3.0, 10.0, None]
    seeds = range(1, 6)

    for imbalance_alpha, seed in itertools.product(imbalance_alphas, seeds):
        # Build only the dataset config this (alpha, seed) pair actually uses.
        if imbalance_alpha is None:
            dataset_config = get_dataset_config_cifar_stl_balance_target()
            alpha_tag = "None"
        else:
            dataset_config = get_dataset_config_cifar_stl_dirichlet_target_imbalance(
                imbalance_alpha, seed=seed
            )
            # round() rather than int(): int() truncates float error,
            # e.g. int(2.3 * 10) == 22.
            alpha_tag = f"{round(imbalance_alpha * 10):03d}"

        training_config = {
            "seed": seed,
            "num_steps": 40000,
            "batch_size": 64,
            "num_workers": 4,
            "eval_period": 5000,
            "log_period": 2500,
            "eval_bn_update": True,
            "save_model": False,
            "save_period": 1,
            "disc_eval_period": 15,
        }

        for (
            alg_nickname,
            algorithm_name,
            extra_hparams,
            extra_discriminator_hparams,
        ) in algorithms:
            algorithm_config = get_algorithm_config(
                algorithm_name, extra_hparams, extra_discriminator_hparams
            )
            experiment_name = (
                f"stl_cifar/deep_cnn/seed_{seed}/"
                f"s_alpha_{alpha_tag}/{alg_nickname}"
            )
            experiment_config = {
                # Per-experiment copies: avoid sharing one mutable dict across
                # every algorithm registered for this (alpha, seed) pair.
                "dataset": copy.deepcopy(dataset_config),
                "algorithm": algorithm_config,
                "training": copy.deepcopy(training_config),
            }
            registry.register(experiment_name, experiment_config)
