from __future__ import annotations

import warnings

import torch
import torch.nn as nn


class MaskedMSELoss(nn.Module):
    """MSE loss that ignores NaN values in targets.

    Computes loss only for valid (non-NaN) target values. This allows
    training on datasets with incomplete labels across multiple properties.

    Args:
        reduction: Reduction method ('mean', 'sum', or 'none')

    Raises:
        ValueError: If ``reduction`` is not one of the supported methods.
    """

    # Supported reduction methods, shared by __init__ validation and forward.
    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, reduction: str = "mean"):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"Unknown reduction: {reduction}")
        self.reduction = reduction

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute masked MSE loss.

        Args:
            predictions: Model predictions of shape (batch_size, n_properties)
            targets: Target values of shape (batch_size, n_properties), may contain NaNs

        Returns:
            For 'mean': a scalar, the sum of squared errors over valid (non-NaN)
            entries divided by the number of valid entries (0 when there are none,
            still attached to the autograd graph).
            For 'sum': a scalar, the sum of squared errors over valid entries.
            For 'none': the element-wise squared errors, with 0 at NaN-target
            positions.

        Raises:
            ValueError: If ``self.reduction`` was set to an unsupported value
                after construction.
        """
        if predictions.shape != targets.shape:
            # Broadcasting mismatched shapes (e.g. (B,) vs (B, 1)) silently
            # yields a wrong loss, so surface it the way nn.MSELoss does.
            warnings.warn(
                f"MaskedMSELoss: predictions shape {tuple(predictions.shape)} differs "
                f"from targets shape {tuple(targets.shape)}; inputs will be broadcast, "
                "which may produce an incorrect loss.",
                stacklevel=2,
            )

        # True where the target is a real label, False where it is missing (NaN).
        mask = ~torch.isnan(targets)

        # Replace NaN targets with 0 before subtracting so NaN never enters the
        # arithmetic; otherwise it would reach the gradients (NaN * 0 is NaN).
        targets_safe = torch.where(mask, targets, torch.zeros_like(targets))
        squared_errors = (predictions - targets_safe) ** 2

        # torch.where (rather than multiplying by the mask) keeps masked positions
        # exactly 0 in the forward pass even if the prediction there is inf.
        masked_errors = torch.where(mask, squared_errors, torch.zeros_like(squared_errors))

        if self.reduction == "mean":
            # Clamping the count to >= 1 handles an all-NaN batch without a
            # data-dependent branch. The numerator is already 0 there, so the
            # result is 0 but stays attached to the autograd graph (a fresh
            # torch.tensor(0.0) would make loss.backward() fail) and keeps the
            # errors' floating dtype. It also avoids a host-device sync.
            n_valid = mask.sum()
            return masked_errors.sum() / n_valid.clamp(min=1)

        if self.reduction == "sum":
            return masked_errors.sum()

        if self.reduction == "none":
            return masked_errors

        # Reachable only if self.reduction was modified after construction.
        raise ValueError(f"Unknown reduction: {self.reduction}")
