import torch


def regularization(alpha, h, K, b, nu):
    """Return a weighted sum of halved squared norms of per-layer slices.

    For each layer index ``j`` from 1 up to ``K.shape[-1] - 1`` (layer 0 is
    not included), this adds::

        alpha * h * (||K_j||^2 / 2 + ||b_j||^2 / 2 + ||nu_j||^2 / 2)

    where ``X_j`` denotes ``X[:, :, j]`` and ``||.||`` is ``torch.norm``
    called with its default arguments.

    Args:
        alpha: Scalar weight applied to every layer's term.
        h: Scalar multiplier applied to every layer's term (presumably a
            step size -- TODO confirm with callers).
        K, b, nu: Tensors sliced as ``X[:, :, j]``. The number of layers is
            read from ``K.shape[-1]`` while the slicing uses axis 2, so this
            assumes rank-3 tensors whose last axis is the layer axis, and
            that ``b`` and ``nu`` have at least as many layers as ``K``
            (NOTE(review): confirm).

    Returns:
        The accumulated penalty as a tensor, or the int ``0`` when
        ``K.shape[-1] <= 1`` (no layers past index 0 to sum over).
    """
    def half_sq_norm(tensor):
        # Half of the squared (default) torch.norm of one layer slice.
        return 1 / 2 * torch.norm(tensor) ** 2

    # Start at 1 so that layer 0 is excluded, matching the original indexing.
    layer_indices = range(1, K.shape[-1])
    return sum(
        alpha * h * (half_sq_norm(K[:, :, idx])
                     + half_sq_norm(b[:, :, idx])
                     + half_sq_norm(nu[:, :, idx]))
        for idx in layer_indices
    )
