import torch
from collections import namedtuple
from typing import Dict, Any


# Immutable record of one set of training diagnostics, as built by
# compute_metrics: the loss, three whole-model norms ("full_*") and three
# per-parameter norm mappings ("param_*"). Field order is part of the
# interface because the record can be unpacked positionally.
Metrics = namedtuple(
    "Metrics",
    "loss "
    "full_grad_norm full_update_norm full_weight_norm "
    "param_grad_norm param_update_norm param_weight_norm",
)


def total_norm(params: Dict[str, torch.Tensor] | torch.Tensor) -> float:
    if isinstance(params, torch.Tensor):
        return torch.norm(params).item()

    total_norm_sq = 0.0
    for param in params.values():
        if param is not None:
            total_norm_sq += torch.sum(param**2).item()

    return total_norm_sq**0.5


def map_norm(params: Dict[str, torch.Tensor]) -> Dict[str, float]:
    return {name: torch.norm(param).item() for name, param in params.items()}


def compute_metrics(
    loss: float,
    params: Dict[str, torch.Tensor],
    grads: Dict[str, torch.Tensor] | None = None,
    updates: Dict[str, torch.Tensor] | None = None,
) -> Metrics:
    """Bundle the loss with weight, gradient and update norms.

    Args:
        loss: Loss value, stored in the result unchanged.
        params: Named weight tensors; their norms are always computed.
        grads: Named gradient tensors, or ``None`` if unavailable.
        updates: Named update tensors, or ``None`` if unavailable.

    Returns:
        A ``Metrics`` record. Each ``full_*`` field is the ``total_norm`` of
        the corresponding mapping and each ``param_*`` field is its
        ``map_norm``. Both gradient fields are ``None`` when ``grads`` is
        ``None``, and likewise both update fields when ``updates`` is ``None``.
    """

    def norm_pair(tensors):
        # (whole-mapping norm, per-entry norms), or a pair of Nones when the
        # mapping was not supplied.
        if tensors is None:
            return None, None
        return total_norm(tensors), map_norm(tensors)

    # Weights first, then gradients, then updates.
    weight_total, weight_by_name = norm_pair(params)
    grad_total, grad_by_name = norm_pair(grads)
    update_total, update_by_name = norm_pair(updates)

    return Metrics(
        loss=loss,
        full_grad_norm=grad_total,
        full_update_norm=update_total,
        full_weight_norm=weight_total,
        param_grad_norm=grad_by_name,
        param_update_norm=update_by_name,
        param_weight_norm=weight_by_name,
    )
