import torch

def get_hidden_l2norm(hidden_all) -> torch.Tensor:
    """Return the mean Euclidean norm of ``hidden_all`` along its last dimension.

    Every vector along the final axis is reduced to its L2 norm, and those
    norms are then averaged into a single scalar tensor.
    """
    per_vector_norms = hidden_all.norm(p=2, dim=-1)
    return per_vector_norms.mean()

def get_weights_l2norm(model) -> torch.Tensor:
    """Sum the L2 norms of every parameter whose name contains ``'weight'``.

    Each matching parameter is reduced to its 2-norm over all of its elements,
    and the per-parameter norms are added together. Parameters whose names do
    not contain ``'weight'`` (e.g. biases) are ignored.

    Raises:
        ValueError: if no parameter name contains ``'weight'``.
    """
    norms = []
    for param_name, tensor in model.named_parameters():
        if 'weight' not in param_name:
            continue
        norms.append(torch.norm(tensor, p=2))

    if not norms:
        raise ValueError("No weights found in the model.")

    return sum(norms)

class DiscountLoss(torch.nn.Module):
    """Wrap a loss and weight it geometrically along the future-step axis.

    For a loss of rank > 3, step ``k`` (0-based) along dimension 1 receives
    weight ``discount_factor ** k``. The weighted sum of per-step losses is
    divided by the sum of the weights, i.e. the result is a weighted average
    over steps. Lower-rank losses are simply averaged.

    Args:
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a loss tensor.
            Presumably built with ``reduction='none'`` so per-element values
            are available — TODO confirm at call sites.
        discount_factor: base of the geometric weights.
        n_future_pred: number of weights; must match the size of dimension 1
            of the unreduced loss for the discounted branch to broadcast.
    """

    def __init__(self, loss_fn, discount_factor, n_future_pred):
        super().__init__()
        self.loss_fn = loss_fn

        # Weights discount_factor**0 .. discount_factor**(n_future_pred - 1).
        # Stored as a frozen Parameter so they move with `.to(device)` and are
        # saved in the state_dict under "pow"; they are never trained.
        self.pow = torch.nn.parameter.Parameter(
            torch.pow(discount_factor, torch.arange(n_future_pred)),
            requires_grad=False,
        )

    def forward(self, outputs, labels):
        """Compute ``loss_fn(outputs, labels)`` and reduce it to a scalar.

        Returns:
            A scalar tensor. For a loss of rank > 3, this is the discount-
            weighted average of the per-step means, with dimension 1 treated as
            the step axis. Otherwise it is the plain mean of the loss.
        """
        loss = self.loss_fn(outputs, labels)

        if loss.dim() <= 3:
            # this happens when the validation loss is calculated only
            # on the first step of the future predictions
            return torch.mean(loss)

        # Average over every dimension except 1, leaving one value per future
        # step. Reducing over all remaining dims (rather than a fixed (0, 2, 3))
        # keeps the result 1-D for losses of rank > 4, so it lines up with
        # `self.pow` instead of broadcasting against a trailing dimension.
        reduce_dims = tuple(d for d in range(loss.dim()) if d != 1)
        per_step = torch.mean(loss, dim=reduce_dims)

        # Discounted, normalised average over the future steps.
        return torch.sum(per_step * self.pow) / self.pow.sum()
