from __future__ import annotations

import torch


def l2_weight_decay(model: torch.nn.Module, weight: float) -> torch.Tensor:
    """Return ``weight`` times the sum of squared entries of all parameters.

    Every tensor yielded by ``model.parameters()`` contributes, with no
    filtering on ``requires_grad``.

    Args:
        model: Module whose parameters are penalized.
        weight: Scale applied to the summed squares. A value ``<= 0``
            disables the penalty.

    Returns:
        A zero-dimensional tensor. It lives on the device of the first
        parameter, or on the CPU when the module has no parameters. It is
        a plain zero when the penalty is disabled or there is nothing to
        penalize.
    """
    params = list(model.parameters())
    # A parameterless module (e.g. ``torch.nn.Identity``) has no "first
    # parameter" to borrow a device from; fall back to the CPU instead of
    # letting ``next()`` raise StopIteration.
    device = params[0].device if params else torch.device("cpu")
    reg = torch.tensor(0.0, device=device)
    if weight <= 0 or not params:
        return reg
    for p in params:
        reg = reg + (p ** 2).sum()
    return weight * reg


def clip_grad_norm(module: torch.nn.Module, max_norm: float):
    """Clip the gradient norm of ``module``'s parameters in place.

    Delegates to ``torch.nn.utils.clip_grad_norm_`` with ``max_norm`` as
    the threshold. Nothing happens when ``max_norm`` is falsy (``None``,
    ``0``) or not strictly positive. The function always returns ``None``.
    """
    # Falsy check first so that ``None`` never reaches the comparison.
    clipping_disabled = not max_norm or not (max_norm > 0)
    if clipping_disabled:
        return
    torch.nn.utils.clip_grad_norm_(module.parameters(), max_norm)

