import torch
from torch import Tensor


def detect_divergence(
    loss_history: list[Tensor], rthold: float = 10.0, window: int = 7
) -> bool:
    """Detect divergence from a spike in the most recent loss.

    The latest loss is compared against a robust threshold built from the
    ``window - 1`` losses that precede it::

        threshold = median(ref) + rthold * MAD(ref)

    where ``MAD`` is the median absolute deviation from the median.

    Args:
        loss_history: Recorded losses, with the most recent one last.
            Entries are expected to be scalar (0-dim) tensors.
        rthold: Number of MADs above the reference median that the latest
            loss must exceed to be flagged.
        window: Number of trailing entries considered, including the latest
            one; the reference statistics use the other ``window - 1``.

    Returns:
        ``True`` if the latest loss is non-finite (NaN or +/-inf) or lies
        above the threshold; ``False`` otherwise, including when the history
        is empty, shorter than ``window``, or leaves no reference samples.

    Note:
        If the reference losses are all identical the MAD is 0, so any
        increase over the median, however small, is reported as divergence.
    """
    if not loss_history:
        return False

    last = loss_history[-1]
    # A NaN or infinite loss is divergence regardless of how much history
    # is available.
    if not torch.isfinite(last):
        return True

    if len(loss_history) < window:
        return False

    reference = loss_history[-window:-1]
    if not reference:
        # e.g. window == 1: nothing to compare against (and torch.stack
        # would raise on an empty list).
        return False

    ref = torch.stack(reference)
    # torch.quantile only accepts float32/float64 inputs; promote
    # half/bfloat16/integer losses so mixed-precision histories work.
    if ref.dtype not in (torch.float32, torch.float64):
        ref = ref.float()

    # With an even number of samples, torch.quantile at q=0.5 averages the
    # two middle values (default linear interpolation).
    med = torch.quantile(ref, q=0.5)
    mad = torch.quantile(torch.abs(ref - med), q=0.5)

    diverge = last > med + rthold * mad
    return bool(diverge.item())
