from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict

from inferno import bnn

if TYPE_CHECKING:
    from lightning import LightningModule

import itertools


# Short labels used to build the human-readable ``inference_method`` string.
_SUBSET_OF_WEIGHTS_LABELS = {
    "all": "Full",
    "subnetwork": "Subnet",
    "last_layer": "Last-layer",
}
_PRIOR_PRECISION_OPT_LABELS = {"gridsearch": "GS", "marglik": "ML"}


def _count_tensor_elements(obj: Any) -> int:
    """Return the total number of elements stored in ``obj``.

    ``obj`` may be a single tensor-like object (anything exposing a
    ``numel()`` method) or an arbitrarily nested list/tuple of such objects.

    Raises:
        TypeError: If a leaf is neither tensor-like nor a list/tuple.
    """
    if hasattr(obj, "numel"):
        return int(obj.numel())
    if isinstance(obj, (list, tuple)):
        return sum(_count_tensor_elements(item) for item in obj)
    raise TypeError(
        f"Cannot count elements of object of type {type(obj).__name__!r}."
    )


def _laplace_hyperparameters(
    lightning_module: LightningModule,
    architecture: str,
    out_size: int,
    num_samples_test: int,
    subset_of_weights: str,
    hessian_structure: str,
    pred_type: str,
    link_approx: str,
    method_prior_precision_optimization: str,
    lr: float,
    momentum: float,
    nesterov: bool,
    weight_decay: float,
    max_epochs: int,
    optimizer: str = "SGD",
) -> Dict[str, Any]:
    """Common hyperparameters for models with a post-hoc Laplace approximation.

    Args:
        lightning_module: Module exposing ``.model`` (whose parameters and
            buffers are counted) and ``.laplace_approximation`` (whose
            ``state_dict()["H"]`` entry is counted).
        architecture: Name of the network architecture.
        out_size: Output dimension of the model.
        num_samples_test: Number of samples used at test time.
        subset_of_weights: One of ``"all"``, ``"subnetwork"``, ``"last_layer"``.
        hessian_structure: Hessian structure of the Laplace approximation.
        pred_type: Predictive type of the Laplace approximation.
        link_approx: Link approximation used for predictions.
        method_prior_precision_optimization: One of ``"gridsearch"`` or
            ``"marglik"``.
        lr, momentum, nesterov, weight_decay, max_epochs, optimizer:
            Optimizer / training settings, recorded verbatim.

    Returns:
        A flat dictionary of hyperparameters and model-size statistics.

    Raises:
        KeyError: If ``subset_of_weights`` or
            ``method_prior_precision_optimization`` is not a supported key.
    """
    model = lightning_module.model

    subset_label = _SUBSET_OF_WEIGHTS_LABELS[subset_of_weights]
    prior_opt_label = _PRIOR_PRECISION_OPT_LABELS[method_prior_precision_optimization]

    num_parameters = sum(p.numel() for p in model.parameters())
    num_buffers = sum(b.numel() for b in model.buffers())
    # Elements defining the Hessian (i.e. inverse precision matrix) of the
    # Laplace approximation. Its container type presumably depends on the
    # Hessian structure (plain tensor vs. nested lists of Kronecker factors),
    # so count recursively instead of assuming a fixed nesting depth.
    num_hessian_elements = _count_tensor_elements(
        lightning_module.laplace_approximation.state_dict()["H"]
    )

    return {
        "model": lightning_module.__class__.__name__,
        "inference_method": f"Laplace ({subset_label}, {prior_opt_label})",
        "architecture": architecture,
        "out_size": out_size,
        "num_trainable_parameters": sum(
            p.numel() for p in model.parameters() if p.requires_grad
        ),
        "num_parameters_and_buffers": num_parameters
        + num_buffers
        + num_hessian_elements,
        "parametrization": (
            model.parametrization.__class__.__name__
            if hasattr(model, "parametrization")
            else bnn.params.Standard.__name__
        ),
        "num_samples_test": num_samples_test,
        "subset_of_weights": subset_of_weights,
        "hessian_structure": hessian_structure,
        "pred_type": pred_type,
        "link_approx": link_approx,
        "method_prior_precision_optimization": method_prior_precision_optimization,
        "optimizer": optimizer,
        "lr": lr,
        "momentum": momentum,
        "nesterov": nesterov,
        "weight_decay": weight_decay,
        "max_epochs": max_epochs,
    }
