"""Configs for VisDA17 experiments."""
import copy
import itertools


# def get_weighting_config_class_pareto(alpha, reverse, seed):
def get_weighting_config_class_dirichet(alpha, seed):
    """Return the weighting config for the ``"class_dirichlet"`` scheme.

    The function name is misspelled ("dirichet"); it is kept for existing
    callers. ``get_weighting_config_class_dirichlet`` is a correctly spelled
    alias for it.

    Args:
        alpha: Stored as the ``alpha`` kwarg of the weighting (presumably the
            Dirichlet concentration parameter -- confirm against the weighting
            implementation).
        seed: Stored as the ``seed`` kwarg; may be ``None``.

    Returns:
        ``{"name": "class_dirichlet", "kwargs": {"alpha": alpha, "seed": seed}}``
    """
    return {
        "name": "class_dirichlet",
        "kwargs": {"alpha": alpha, "seed": seed},
    }


# Correctly spelled alias; the misspelled name above stays for backward compat.
get_weighting_config_class_dirichlet = get_weighting_config_class_dirichet


def get_dataset_config_visda17_dirichlet_target_imbalance(alpha, seed=None):
    """Build a VisDA17 dataset config whose target domain uses Dirichlet weighting.

    The source domain (index 0) uses the ``"class_uniform"`` weighting. The
    target domain (index 1) uses the ``"class_dirichlet"`` weighting,
    parameterized by ``alpha`` and ``seed``. Both domains have
    ``subsample=True``, and the config sets ``val_fraction`` to 0.15 with no
    mods.

    Args:
        alpha: Forwarded as the Dirichlet weighting's ``alpha``.
        seed: Forwarded as the Dirichlet weighting's ``seed`` (default ``None``).

    Returns:
        A freshly built nested config dict.
    """
    source_domain = {
        "index": 0,
        "weighting": {"name": "class_uniform", "kwargs": {}},
        "subsample": True,
    }
    target_domain = {
        "index": 1,
        "weighting": get_weighting_config_class_dirichet(alpha, seed=seed),
        "subsample": True,
    }
    return {
        "name": "VisDA17",
        "val_fraction": 0.15,
        "mods": [],
        "source": source_domain,
        "target": target_domain,
    }


def get_dataset_config_visda17_target_balance():
    """Build a VisDA17 dataset config with class-uniform weighting in both domains.

    The source domain (index 0) and the target domain (index 1) each get
    their own ``"class_uniform"`` weighting dict and ``subsample=True``. The
    config sets ``val_fraction`` to 0.15 with no mods.

    Returns:
        A freshly built nested config dict.
    """

    def _uniform_domain(domain_index):
        # Each call returns new dicts, so source and target share no objects.
        return {
            "index": domain_index,
            "weighting": {"name": "class_uniform", "kwargs": {}},
            "subsample": True,
        }

    return {
        "name": "VisDA17",
        "val_fraction": 0.15,
        "mods": [],
        "source": _uniform_domain(0),
        "target": _uniform_domain(1),
    }


def get_algorithm_config(
    algorithm, extra_hparams=None, extra_discriminator_hparams=None
):
    """Build the VisDA17 config for one domain-adaptation algorithm.

    The function starts from hyperparameters shared by every algorithm: a
    ResNet feature extractor, a LogLossClassifier head, discriminator
    settings, and separate SGD optimizers for the feature extractor,
    classifier and discriminator. It then applies the optional overrides.

    Args:
        algorithm: Algorithm name stored under ``config["name"]``.
        extra_hparams: Optional mapping merged shallowly into
            ``config["hparams"]``. A key present here replaces that key's
            default value wholesale.
        extra_discriminator_hparams: Optional mapping merged into
            ``config["hparams"]["discriminator"]["hparams"]``. It is applied
            after ``extra_hparams``.

    Returns:
        A freshly built nested dict. Override values are deep-copied, so the
        result shares no mutable objects with the arguments, and the
        arguments are never modified.
    """
    # Common configs of all algorithms
    config = {
        "name": algorithm,
        "hparams": {
            "da_network": {
                "feature_extractor": {
                    "name": "ResNet",
                    "hparams": {
                        "feature_dim": 256,
                        "pretrained": True,
                        "freeze_bn": False,
                        "resnet18": False,
                        "resnet_dropout": 0.0,
                        "fc_lr_factor": 1.0,
                        "fc_wd_factor": 1.0,
                    },
                },
                "classifier": {
                    "name": "LogLossClassifier",
                    "hparams": {
                        "num_hidden": None,
                        "special_init": True,
                    },
                },
            },
            "discriminator": {
                "hparams": {
                    "num_hidden": 1024,
                    "depth": 3,
                    "spectral": False,
                    "history_size": 0,
                }
            },
            "ema_momentum": None,
            "fx_opt": {
                "name": "SGD",
                "kwargs": {
                    "lr": 0.001,
                    "momentum": 0.9,
                    "weight_decay": 0.001,
                    "nesterov": True,
                },
            },
            "fx_lr_decay_start": 0,
            "fx_lr_decay_steps": 25000,
            "fx_lr_decay_factor": 0.05,
            "cls_opt": {
                "name": "SGD",
                "kwargs": {
                    "lr": 0.01,
                    "momentum": 0.9,
                    "weight_decay": 0.001,
                    "nesterov": True,
                },
            },
            "cls_weight": 1.0,
            "use_random": False,
            "random_dim": 1024,
            "cls_trg_weight": 0.0,
            "alignment_weight": None,
            "alignment_w_steps": 5000,
            "cls_trg_weight_anneal": False,
            "disc_opt": {
                "name": "SGD",
                "kwargs": {
                    "lr": 0.005,
                    "momentum": 0.9,
                    "weight_decay": 0.001,
                    "nesterov": True,
                },
            },
            "disc_steps": 1,
            "l2_weight": 0.0,
            "lr_type": "decay",
            "vat_z": False,
            "use_vat": False,
            "num_steps": 25000,
        },
    }

    # Deep-copy overrides. The same override dict may be passed for several
    # algorithms, and without a copy its nested values (e.g. {"ma": 0.5})
    # would be aliased across every config built from it. Also, if
    # extra_hparams replaces "discriminator", the update below would mutate
    # the caller's dict.
    if extra_hparams is not None:
        config["hparams"].update(copy.deepcopy(extra_hparams))

    if extra_discriminator_hparams is not None:
        config["hparams"]["discriminator"]["hparams"].update(
            copy.deepcopy(extra_discriminator_hparams)
        )

    return config


def _visda17_algorithms():
    """Return the algorithm specs run on VisDA17.

    Each entry is a tuple of
    ``(nickname, algorithm_name, extra_hparams, extra_discriminator_hparams)``.
    """
    iwdan_extra_hparams = {
        "alignment_weight": 0.1,
        "iw_update_period": 4000,
        "importance_weighting": {"ma": 0.5},
    }
    vada_extra_hparams = {
        "alignment_weight": 0.1,
        "cls_vat_src_weight": 0.0,
        "cls_vat_trg_weight": 0.1,
        "vat_xi": 1e-6,
        "vat_radius": 3.5,
    }
    casa_extra_hparams = {
        "alignment_weight": 0.2,
        "cls_vat_src_weight": 0.0,
        "cls_vat_trg_weight": 0.1,
        "vat_xi": 1e-6,
        "vat_radius": 3.5,
        "cls_trg_weight": 0.05,
        "use_vat": False,
    }
    sentry_hparams = {
        "src_weight": 1.0,
        "unsup_weight": 0.1,
        "ent_weight": 1.0,
        "committee_size": 3,
    }
    pct_extra_hparams = {
        "nav_t": 1.0,
        "s_par": 0.5,
        "beta": 0.0,
        "lr_gamma": 0.0002,
        "trade_off": 1.0,
    }
    return [
        ("source_only", "ERM", None, None),
        ("dann", "DANN_NS", {"alignment_weight": 0.1}, None),
        ("cdan", "CDAN", {"alignment_weight": 0.1}, None),
        ("iwdan", "IWDAN", iwdan_extra_hparams, None),
        ("iwcdan", "IWCDAN", iwdan_extra_hparams, None),
        ("sdann_4", "SDANN", {"alignment_weight": 0.1}, {"beta": 4.0}),
        ("vada", "VADA", vada_extra_hparams, None),
        (
            "asa",
            "DANN_SUPP_SQ",
            {"alignment_weight": 0.1},
            {"history_size": 1000},
        ),
        ("casa", "CDAN_SUPP_SQ_E", casa_extra_hparams, {"history_size": 1000}),
        ("sentry", "SENTRY", sentry_hparams, None),
        ("pct", "PCT", pct_extra_hparams, None),
    ]


def register_experiments(registry):
    """Register every VisDA17 experiment with ``registry``.

    One experiment is registered per (target imbalance alpha, seed, algorithm)
    combination, under the name
    ``visda17/resnet50/seed_{seed}/s_alpha_{tag}/{nickname}``. The ``tag`` is
    ``None`` when the target is class-balanced (alpha is ``None``). Otherwise
    it is ``int(alpha * 10)`` zero-padded to three digits (e.g. 0.5 -> ``005``).

    Args:
        registry: Object providing ``register(name, config)``.
    """
    algorithms = _visda17_algorithms()

    for imbalance_alpha in [0.5, 1.0, 3.0, 10.0, None]:
        if imbalance_alpha is None:
            alpha_tag = "None"
        else:
            alpha_tag = f"{int(imbalance_alpha * 10):03d}"

        for seed in range(1, 6):
            # Build only the dataset config this alpha uses. Alpha None means
            # a class-balanced target, not a Dirichlet weighting with alpha=None.
            if imbalance_alpha is None:
                dataset_config = get_dataset_config_visda17_target_balance()
            else:
                dataset_config = (
                    get_dataset_config_visda17_dirichlet_target_imbalance(
                        imbalance_alpha, seed=seed
                    )
                )

            training_config = {
                "seed": seed,
                "num_steps": 25000,
                "batch_size": 64,
                "num_workers": 4,
                "eval_period": 5000,
                "log_period": 1000,
                "eval_bn_update": True,
                "save_model": False,
                "save_period": 1,
                "disc_eval_period": 4,
            }

            for (
                alg_nickname,
                algorithm_name,
                extra_hparams,
                extra_discriminator_hparams,
            ) in algorithms:
                algorithm_config = get_algorithm_config(
                    algorithm_name, extra_hparams, extra_discriminator_hparams
                )
                experiment_name = (
                    f"visda17/resnet50/seed_{seed}/s_alpha_{alpha_tag}/"
                    f"{alg_nickname}"
                )
                experiment_config = {
                    "dataset": dataset_config,
                    "algorithm": algorithm_config,
                    "training": training_config,
                }
                registry.register(experiment_name, experiment_config)
