from __future__ import annotations

import torch


def pearson_corr_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    *,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Negative Pearson correlation loss, averaged over the batch.

    pred/target: (batch, num_masks)

    Correlation is computed independently for each row along dim 1, and
    ``-corr`` is then averaged over rows, so lower is better (-1 means every
    row is perfectly positively correlated with its target).

    Why this loss:
    - the ablation outputs are only meaningful up to an affine transform per example
      (scale/shift), so correlation is the right invariant objective. Note that
      the invariance is to per-row shifts and *positive* rescaling; a negative
      scale flips the sign of the correlation.

    Args:
        pred: predictions, (batch, num_masks).
        target: targets, same shape as ``pred``.
        eps: added to the denominator so a row with zero variance yields a
            correlation of 0 instead of a division by zero.

    Returns:
        Scalar tensor: the batch mean of ``-corr``.
    """
    pred = pred - pred.mean(dim=1, keepdim=True)
    target = target - target.mean(dim=1, keepdim=True)

    cov = (pred * target).mean(dim=1)
    # Clamp the variances away from exactly zero before sqrt: d sqrt(v)/dv is
    # infinite at v == 0, so a constant row would otherwise produce 0 * inf =
    # NaN gradients. For any row whose variance exceeds the dtype's smallest
    # normal number the clamp is a no-op, so forward values are unchanged.
    pred_var = pred.pow(2).mean(dim=1)
    target_var = target.pow(2).mean(dim=1)
    pred_std = pred_var.clamp_min(torch.finfo(pred_var.dtype).tiny).sqrt()
    target_std = target_var.clamp_min(torch.finfo(target_var.dtype).tiny).sqrt()
    corr = cov / (pred_std * target_std + eps)
    return (-corr).mean()
