import torch
import torch.nn as nn


class UncertaintyLossWeighting(nn.Module):
    """
    Homoscedastic uncertainty loss weighting (Kendall et al., 2018).

    Holds one learnable log(sigma) per loss and combines a 1D tensor of losses
    into a single scalar:

        L = sum_i [ 1/(2 sigma_i^2) * L_i + log sigma_i ].

    Initialization of ``log_sigma``:
      * It is created as zeros (sigma=1).
      * On the first forward pass that receives only finite losses it is
        overwritten with ``0.5 * log(L_i)`` (sigma_i^2 = L_i), which is the
        value that minimizes the combined loss for those L_i.
      * If ``log_sigma`` is restored through ``load_state_dict``, the
        data-dependent initialization is skipped so the loaded values are
        not overwritten.

    Attributes:
        log_sigma: learnable parameter of shape (n_losses,).
        initialized: plain bool (not part of the state dict); True once
            ``log_sigma`` has been set from data or from a loaded state dict.
    """

    def __init__(self, n_losses: int):
        """
        Args:
            n_losses: number of individual losses to combine (>= 1).
        Raises:
            ValueError: if ``n_losses`` is smaller than 1.
        """
        super().__init__()
        if n_losses < 1:
            raise ValueError(f"n_losses must be >=1, got {n_losses}")
        self.log_sigma = nn.Parameter(torch.zeros(n_losses))
        self.initialized = False

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """
        Load parameters as usual, then mark the module as initialized when
        ``log_sigma`` was successfully restored.

        ``initialized`` is a plain attribute and is not stored in the state
        dict; without this hook a freshly constructed module would overwrite
        the restored ``log_sigma`` on its first forward pass.
        """
        errors_before = len(error_msgs)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
        loaded_cleanly = len(error_msgs) == errors_before
        if loaded_cleanly and (prefix + "log_sigma") in state_dict:
            self.initialized = True

    def forward(self, losses: torch.Tensor) -> torch.Tensor:
        """
        Args:
            losses: 1D tensor of shape (n_losses,) containing each individual loss L_i.
        Returns:
            A scalar tensor: sum_i [ 1/(2 sigma_i^2) * L_i + log sigma_i ].
        Raises:
            ValueError: if ``losses`` is not 1D with length n_losses, or if any
                entry is <= 0.
        """
        if losses.dim() != 1 or losses.size(0) != self.log_sigma.size(0):
            raise ValueError(
                f"Expected losses to be a 1D tensor of length {self.log_sigma.size(0)}, "
                f"but got shape {tuple(losses.shape)}"
            )
        if torch.any(losses <= 0.0):
            raise ValueError("All losses must be positive.")

        # Data-dependent initialization of log_sigma. NaN/+inf pass the
        # positivity check above (comparisons with NaN are False), and taking
        # their log would permanently poison the parameter, so the
        # initialization is deferred until a batch of finite losses arrives.
        if not self.initialized and bool(torch.isfinite(losses).all()):
            with torch.no_grad():
                self.log_sigma.copy_(0.5 * torch.log(losses))
            self.initialized = True

        # 1 / (2 sigma^2) expressed through log(sigma).
        precision = torch.exp(-2.0 * self.log_sigma) * 0.5
        weighted_losses = precision * losses + self.log_sigma
        return weighted_losses.sum()
