from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict

from inferno.bnn import params

if TYPE_CHECKING:
    from lightning import LightningModule


def _wsvi_hyperparameters(
    lightning_module: LightningModule,
    architecture: str,
    out_size: int,
    cov: params.FactorizedCovariance,
    num_samples_train: int,
    num_samples_test: int,
    kl_weight: float,
    lr: float,
    momentum: float,
    nesterov: bool,
    weight_decay: float,
    max_epochs: int,
    scale_mean_input_init_weight: float,
    scale_mean_input_init_bias: float,
    scale_mean_input_lr_weight: float,
    scale_mean_input_lr_bias: float,
    scale_mean_input_forward_weight: float,
    scale_mean_input_forward_bias: float,
    scale_mean_output_init_weight: float,
    scale_mean_output_init_bias: float,
    scale_mean_output_lr_weight: float,
    scale_mean_output_lr_bias: float,
    scale_mean_output_forward_weight: float,
    scale_mean_output_forward_bias: float,
    scale_cov_input_init_weight: float,
    scale_cov_input_init_bias: float,
    scale_cov_input_lr_weight: float,
    scale_cov_input_lr_bias: float,
    scale_cov_input_forward_weight: float,
    scale_cov_input_forward_bias: float,
    scale_cov_output_init_weight: float,
    scale_cov_output_init_bias: float,
    scale_cov_output_lr_weight: float,
    scale_cov_output_lr_bias: float,
    scale_cov_output_forward_weight: float,
    scale_cov_output_forward_bias: float,
    optimizer: str = "SGD",
) -> Dict[str, Any]:
    """Common hyperparameters for models trained via weight-space VI."""

    VARIATIONAL_FAMILY = {
        "DiagonalCovariance": "Mean-field",
        "KroneckerCovariance": "Kronecker",
        "LowRankCovariance": "Low-rank",
    }

    return {
        "model": lightning_module.__class__.__name__,
        "inference_method": f"Weight-space VI ({VARIATIONAL_FAMILY[cov.__class__.__name__]})",
        "architecture": architecture,
        "out_size": out_size,
        "num_trainable_parameters": sum(
            p.numel() for p in lightning_module.model.parameters() if p.requires_grad
        ),
        "num_parameters_and_buffers": sum(
            p.numel() for p in lightning_module.model.parameters()
        )
        + sum(b.numel() for b in lightning_module.model.buffers()),
        "cov": cov.__class__.__name__,
        "parametrization": lightning_module.model.parametrization.__class__.__name__,
        "num_samples_train": num_samples_train,
        "num_samples_test": num_samples_test,
        "kl_weight": kl_weight,
        "optimizer": optimizer,
        "lr": lr,
        "momentum": momentum,
        "nesterov": nesterov,
        "weight_decay": weight_decay,
        "max_epochs": max_epochs,
        "scale_mean_input_init_weight": scale_mean_input_init_weight,
        "scale_mean_input_init_bias": scale_mean_input_init_bias,
        "scale_mean_input_lr_weight": scale_mean_input_lr_weight,
        "scale_mean_input_lr_bias": scale_mean_input_lr_bias,
        "scale_mean_input_forward_weight": scale_mean_input_forward_weight,
        "scale_mean_input_forward_bias": scale_mean_input_forward_bias,
        "scale_mean_output_init_weight": scale_mean_output_init_weight,
        "scale_mean_output_init_bias": scale_mean_output_init_bias,
        "scale_mean_output_lr_weight": scale_mean_output_lr_weight,
        "scale_mean_output_lr_bias": scale_mean_output_lr_bias,
        "scale_mean_output_forward_weight": scale_mean_output_forward_weight,
        "scale_mean_output_forward_bias": scale_mean_output_forward_bias,
        "scale_cov_input_init_weight": scale_cov_input_init_weight,
        "scale_cov_input_init_bias": scale_cov_input_init_bias,
        "scale_cov_input_lr_weight": scale_cov_input_lr_weight,
        "scale_cov_input_lr_bias": scale_cov_input_lr_bias,
        "scale_cov_input_forward_weight": scale_cov_input_forward_weight,
        "scale_cov_input_forward_bias": scale_cov_input_forward_bias,
        "scale_cov_output_init_weight": scale_cov_output_init_weight,
        "scale_cov_output_init_bias": scale_cov_output_init_bias,
        "scale_cov_output_lr_weight": scale_cov_output_lr_weight,
        "scale_cov_output_lr_bias": scale_cov_output_lr_bias,
        "scale_cov_output_forward_weight": scale_cov_output_forward_weight,
        "scale_cov_output_forward_bias": scale_cov_output_forward_bias,
    }
