"""Configs for Office-31 experiments."""

import copy
import itertools


def get_dataset_config_office31(src_task, trg_task, subsample_setting):
    """Build the dataset section of an Office-31 experiment config.

    Args:
        src_task: Value stored under the "src_task" key.
        trg_task: Value stored under the "trg_task" key.
        subsample_setting: Value stored under the "subsample_setting" key.

    Returns:
        A new dict holding the fixed name "Office31", an empty "mods" list
        and the three arguments, in that key order.
    """
    dataset_config = dict(name="Office31", mods=[])
    dataset_config["src_task"] = src_task
    dataset_config["trg_task"] = trg_task
    dataset_config["subsample_setting"] = subsample_setting
    return dataset_config


def get_algorithm_config(
    algorithm, extra_hparams=None, extra_discriminator_hparams=None
):
    """Build the algorithm section of an experiment config.

    A fresh dict of defaults shared by all algorithms is created on every
    call and then overridden with the optional extras.

    Args:
        algorithm: Value stored under the top-level "name" key.
        extra_hparams: Optional mapping merged into ``config["hparams"]``.
            The merge is top-level only: a key that already exists (including
            a nested section such as "fx_opt") is replaced wholesale.
        extra_discriminator_hparams: Optional mapping merged into
            ``config["hparams"]["discriminator"]["hparams"]``.

    Returns:
        A new config dict. It shares no mutable state with the arguments, so
        the same extras mapping can be reused for many configs safely.
    """
    # Common configs of all algorithms
    config = {
        "name": algorithm,
        "hparams": {
            "da_network": {
                "feature_extractor": {
                    "name": "ResNet",
                    "hparams": {
                        "feature_dim": 256,
                        "pretrained": True,
                        "freeze_bn": False,
                        "resnet18": False,
                        "resnet_dropout": 0.0,
                        "fc_lr_factor": 10.0,
                        "fc_wd_factor": 1.0,
                        "tlib": True,
                        # "pool_layer": True,
                    },
                },
                "classifier": {
                    "name": "LogLossClassifier",
                    "hparams": {
                        "num_hidden": None,
                        "special_init": True,
                    },
                },
            },
            "discriminator": {
                "hparams": {
                    "num_hidden": 1024,
                    "depth": 3,
                    "spectral": False,
                    "history_size": 0,
                    "relu": True,
                    "dropout": True,
                    "batch_norm": True,
                }
            },
            "ema_momentum": None,
            "fx_opt": {
                "name": "SGD",
                "kwargs": {
                    "lr": 0.001,
                    "momentum": 0.9,
                    "weight_decay": 1e-3,
                    "nesterov": True,
                },
            },
            "cls_opt": {
                "name": "SGD",
                "kwargs": {
                    "lr": 0.01,
                    "momentum": 0.9,
                    "weight_decay": 1e-3,
                    "nesterov": True,
                },
            },
            "disc_opt": {
                "name": "SGD",
                "kwargs": {
                    "lr": 0.01,
                    "momentum": 0.9,
                    "weight_decay": 1e-3,
                    "nesterov": True,
                },
            },
            "cls_weight": 1.0,
            "use_random": False,
            # "random_dim": 1024,
            # "z_dim": 256,
            "cls_trg_weight": 0.0,
            "alignment_weight": 1.0,
            "alignment_w_steps": 5000,
            "cdan_coeff": True,
            "cls_trg_weight_anneal": True,
            "disc_steps": 1,
            "l2_weight": 0.0,
            "lr_type": "inv",
            "vat_z": False,
            "use_vat": False,
            "num_steps": 10000,
            # "num_steps": 1,
        },
    }

    # Deep-copy the extras before merging: callers reuse one extras dict for
    # many configs, and a plain update() would alias any nested mutable value
    # (dict/list) into every config built from it.
    if extra_hparams is not None:
        config["hparams"].update(copy.deepcopy(extra_hparams))

    if extra_discriminator_hparams is not None:
        config["hparams"]["discriminator"]["hparams"].update(
            copy.deepcopy(extra_discriminator_hparams)
        )

    return config


def register_experiments(registry):
    """Register every Office-31 experiment with ``registry``.

    One experiment is registered for each combination of seed (1-5), ordered
    source/target task pair (6), subsample setting (3) and algorithm (9),
    i.e. 810 experiments in total. Each experiment is named
    ``office31/<subsample>/<src>2<trg>/seed_<seed>/<nickname>``.

    Args:
        registry: Object exposing ``register(name, config)``; it is called
            once per experiment with the name and a dict holding the
            "dataset", "algorithm" and "training" sections.
    """
    vada_extra_hparams = {
        "alignment_weight": 1.0,
        "cls_vat_src_weight": 0.0,
        "cls_vat_trg_weight": 0.1,
        "vat_xi": 1e-6,
        "vat_radius": 3.5,
    }

    pct_extra_hparams = {
        "nav_t": 1.0,
        "s_par": 0.5,  # [0, 0.5, 1]
        "beta": 0.0,  # 0.0001 OfficeHome
        "eps": 1e-6,
        "lr_gamma": 0.0002,
        "trade_off": 1.0,
    }

    sentry_hparams = {
        "src_weight": 1.0,
        "unsup_weight": 0.1,
        "ent_weight": 1.0,
        "committee_size": 3,
    }

    casa_extra_hparams = {
        "alignment_weight": 1.5,
        "cls_trg_weight": 0.1,
        "cls_vat_src_weight": 0.0,
        "cls_vat_trg_weight": 0.2,
        "vat_xi": 1e-6,
        "vat_radius": 15.0,
        "use_vat": True,
    }

    # Currently disabled algorithms, kept for reference:
    # iwdan_extra_hparams = {
    #     "alignment_weight": 1.0,
    #     "iw_update_period": 1000,
    #     "importance_weighting": {"ma": 0.5},
    # }
    # ("iwdan", "IWDAN", iwdan_extra_hparams, None),
    # ("iwcdan", "IWCDAN", iwdan_extra_hparams, None),

    # Algorithm configs format:
    # nickname, algorithm_name, extra_hparams, extra_discriminator_hparams
    algorithms = [
        ("source_only", "ERM", None, None),
        ("dann", "DANN_NS", {"alignment_weight": 1.0}, None),
        ("cdan", "CDAN", {"alignment_weight": 1.0}, None),
        ("sdann_4", "SDANN", {"alignment_weight": 1.0}, {"beta": 4.0}),
        ("vada", "VADA", vada_extra_hparams, None),
        ("sentry", "SENTRY", sentry_hparams, None),
        (
            "asa",
            "DANN_SUPP_ABS",
            {"alignment_weight": 1.0},
            {"history_size": 1000},
        ),
        ("casa", "CDAN_SUPP_ABS_E", casa_extra_hparams, {"history_size": 500}),
        ("pct", "PCT", pct_extra_hparams, None),
    ]

    seeds = list(range(1, 6))
    tasks = ["A", "D", "W"]
    subsample_settings = ["normal", "sub_s", "sub_t"]
    # All ordered (source, target) pairs with source != target; permutations
    # yields them in the same order as a filtered product would.
    src_trg_pairs = list(itertools.permutations(tasks, 2))
    assert len(src_trg_pairs) == 6

    # Training settings that do not depend on the loop variables; the seed is
    # added per experiment below.
    common_training = {
        "num_steps": 10000,
        "batch_size": 36,
        "num_workers": 4,
        "eval_period": 500,
        "log_period": 200,
        "eval_bn_update": True,
        "save_model": False,
        "save_period": 1,
        "disc_eval_period": 4,
    }

    for seed, (src_task, trg_task), subsample_setting, algorithm in itertools.product(
        seeds, src_trg_pairs, subsample_settings, algorithms
    ):
        (
            alg_nickname,
            algorithm_name,
            extra_hparams,
            extra_discriminator_hparams,
        ) = algorithm

        dataset_config = get_dataset_config_office31(
            src_task, trg_task, subsample_setting
        )
        # A fresh dict per experiment, with "seed" as the first key.
        training_config = {"seed": seed, **common_training}
        algorithm_config = get_algorithm_config(
            algorithm_name, extra_hparams, extra_discriminator_hparams
        )

        experiment_name = (
            f"office31/{subsample_setting}/{src_task}2{trg_task}"
            f"/seed_{seed}/{alg_nickname}"
        )
        experiment_config = {
            "dataset": dataset_config,
            "algorithm": algorithm_config,
            "training": training_config,
        }

        registry.register(experiment_name, experiment_config)
