import torch

def loss_ridge(network, lam_h: float):
    """Ridge (squared-L2) penalty on the hidden layers.

    Args:
        network: object whose ``hidden_layers`` is an iterable of modules,
            each exposing a ``.weight`` tensor.
        lam_h: scale applied to the summed penalty.

    Returns:
        ``lam_h`` times the sum of squared weights over all hidden layers
        (``0.0`` when ``hidden_layers`` is empty).
    """
    penalty = 0
    for layer in network.hidden_layers:
        penalty = penalty + torch.sum(layer.weight ** 2)
    return lam_h * penalty

def loss_conf(network, lam_c: float):
    """Group penalty on learnable confounders.

    ``network.conf`` is squeezed along dim 1. Each group is then the set of
    entries along dim 2 for fixed indices in the remaining dims. The group's
    contribution is its root-mean-square, and all group contributions are
    summed and scaled by ``lam_c``.

    Args:
        network: object exposing a tensor ``conf`` (assumed to have at least
            4 dims with a size-1 dim 1 -- TODO confirm against the model).
        lam_c: scale applied to the summed penalty.

    Returns:
        Scalar tensor ``lam_c * sum(RMS over dim 2)``.
    """
    z = network.conf.squeeze(1)
    # RMS(x) == ||x||_2 / sqrt(n). sqrt(mean(z**2)) has an infinite derivative
    # at 0, which becomes NaN gradients for an all-zero group. torch.norm's
    # backward returns a zero subgradient there instead.
    group_size = z.shape[2]
    group_rms = torch.norm(z, dim=2) / group_size ** 0.5
    return lam_c * torch.sum(group_rms)

def loss_itv(network, lam_v: float):
    """Group-sparsity penalty on the first layer(s) of the network.

    For each weight tensor, an L2 norm is taken over dims 0 and 2 and the
    resulting norms are summed. This is applied to ``network.gc_layer`` and
    to every layer in ``network.itv_layers``.

    Args:
        network: object exposing ``gc_layer`` and an iterable ``itv_layers``,
            each with a ``.weight`` tensor (assumed 3-D, so one norm is
            produced per index of dim 1 -- TODO confirm).
        lam_v: scale applied to the total penalty.

    Returns:
        ``lam_v`` times the summed group norms.
    """
    def summed_group_norms(weight):
        return torch.norm(weight, dim=[0, 2]).sum()

    gc_term = summed_group_norms(network.gc_layer.weight)
    itv_term = sum(summed_group_norms(layer.weight) for layer in network.itv_layers)
    return lam_v * (gc_term + itv_term)

def restore_parameters(model, best_model):
    """Overwrite ``model``'s parameter values with those of ``best_model``.

    Parameters are paired positionally via ``parameters()``. ``zip`` stops
    at the shorter of the two, so both models are expected to share the
    same architecture.

    Each value is copied rather than aliased. Otherwise ``model`` and
    ``best_model`` would share storage, and a later in-place update of
    ``model`` (e.g. an optimizer step) would silently overwrite the saved
    snapshot.
    """
    for param, best_param in zip(model.parameters(), best_model.parameters()):
        # detach + clone: fresh storage with the snapshot's shape/dtype/device.
        param.data = best_param.detach().clone()