import torch
import math


def ivon_regularization(optimizer: torch.optim.Optimizer) -> float:
    """Return KL(q || p) summed over every parameter of an IVON-style optimizer.

    For each parameter group, the posterior is a diagonal Gaussian centred on
    the current weights, ``q = N(w, 1 / (ess * (hess + weight_decay)))``. The
    prior is a zero-mean isotropic Gaussian, ``p = N(0, 1 / (ess * weight_decay))``.
    The element-wise closed-form Gaussian KL divergence is summed over all
    parameter elements of all groups.

    Args:
        optimizer: Object whose ``param_groups`` each contain:

            - ``"params"``: the group's parameter tensors.
            - ``"hess"``: one Hessian estimate per parameter element, in
              ``"params"`` order. Flattened here.
            - ``"ess"`` and ``"weight_decay"``: optional per group. A group
              that lacks them falls back to the first group's values.

    Returns:
        The total KL divergence as a Python float, detached from autograd.

    Raises:
        KeyError: If the first group has no ``"ess"`` or ``"weight_decay"``
            entry, or a group has no ``"hess"`` / ``"params"`` entry.
        ValueError: If a group's ``hess`` length differs from the number of
            elements in its parameters.
        ZeroDivisionError: If ``ess * weight_decay`` is zero for float
            hyperparameters (the prior variance would be infinite).
    """
    default_ess = optimizer.param_groups[0]["ess"]
    default_weight_decay = optimizer.param_groups[0]["weight_decay"]

    kl_divergence = 0.0
    for group in optimizer.param_groups:
        # Honour per-group hyperparameters rather than silently applying
        # group 0's values to every group.
        lambda_ess = group.get("ess", default_ess)
        weight_decay = group.get("weight_decay", default_weight_decay)
        hess = group["hess"].reshape(-1)

        # reshape (not view) so non-contiguous parameters, e.g. channels_last
        # conv weights, can be flattened as well.
        weights = [p.detach().reshape(-1) for p in group["params"] if p is not None]
        # torch.cat([]) raises, so an empty group gets an empty mean vector.
        mu_q = torch.cat(weights) if weights else hess.new_empty(0)
        if mu_q.numel() != hess.numel():
            raise ValueError(
                f"hess has {hess.numel()} entries but the group's parameters "
                f"have {mu_q.numel()} elements"
            )

        variance_q = 1.0 / (lambda_ess * (hess + weight_decay))
        variance_p = 1.0 / (lambda_ess * weight_decay)

        # KL(N(mu, var_q) || N(0, var_p))
        #   = 0.5*log(var_p/var_q) + (var_q + mu^2) / (2*var_p) - 0.5
        kl = (
            0.5 * torch.log(variance_p / variance_q)
            + (variance_q + mu_q**2) / (2.0 * variance_p)
            - 0.5
        )
        kl_divergence += kl.sum().item()

    return kl_divergence


def adamw_regularization(model: torch.nn.Module, weight_decay: float = 1.0):
    """Compute a squared-L2 penalty over every parameter of ``model``.

    The penalty is ``0.5 * weight_decay * sum(p ** 2)``, summed over each
    tensor ``p`` yielded by ``model.parameters()``. It is built from tensor
    operations, so it stays attached to autograd and can be added to a loss.

    Args:
        model: Module whose parameters are penalised.
        weight_decay: Multiplier applied to half the squared norm.

    Returns:
        A scalar tensor. When the model has no parameters, returns the plain
        number ``0.5 * weight_decay * 0`` instead.
    """
    squared_norm = 0
    for tensor in model.parameters():
        squared_norm = squared_norm + tensor.pow(2.0).sum()
    return 0.5 * weight_decay * squared_norm
