import torch
import torch.nn as nn


class InitialLossWeighting(nn.Module):
    """
    Weight several losses by the inverse of their first observed values.

    - On the 1st forward: take the detached initial losses L0, compute
      weights w_i = 1 / max(L0_i, eps) and store them in a buffer.
    - Afterwards: total loss = mean_i[ w_i * L_i ] (just a multiply + mean).

    The weights and the "initialized" flag are registered buffers, so they
    are saved in the module's ``state_dict`` and move with the module.

    Args:
        n_losses: Number of individual losses passed to ``forward``.
        eps: Lower bound applied to the initial losses before inversion,
            which keeps the weights finite when an initial loss is zero.

    Raises:
        ValueError: If ``n_losses`` is smaller than 1.
    """

    def __init__(self, n_losses: int, eps: float = 1e-6):
        super().__init__()
        if n_losses < 1:
            raise ValueError(f"n_losses must be >= 1, got {n_losses}")
        self.n_losses = n_losses
        self.eps = eps

        # Buffers for the frozen weights and the "already initialized" flag.
        self.register_buffer("weights", torch.ones(n_losses))
        self.register_buffer("initialized", torch.zeros(1, dtype=torch.bool))

    def forward(self, losses: torch.Tensor) -> torch.Tensor:
        """
        Combine the individual losses into a single weighted scalar.

        Args:
            losses: 1D tensor of shape (n_losses,) containing each loss L_i.

        Returns:
            A scalar: mean_i [ w_i * L_i ].

        Raises:
            ValueError: If ``losses`` is not 1D with length ``n_losses``.
        """
        if losses.dim() != 1 or losses.size(0) != self.n_losses:
            raise ValueError(
                f"Expected a 1D tensor of length {self.n_losses}, got shape {tuple(losses.shape)}"
            )

        if not self.initialized:
            # Out-of-place ops on purpose: ``detach()`` shares storage with
            # ``losses``, so in-place clamp/reciprocal would overwrite the
            # caller's tensor and corrupt the product computed below.
            initial = losses.detach().clamp_min(self.eps)
            self.weights.copy_(initial.reciprocal())
            self.initialized.fill_(True)

        return (self.weights * losses).mean()
