import torch


def _tensor_size(t):
    return t.numel()

def tv_loss(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6):
    """
    x: [B, C, H, W]
    weight (optional): [B, 1, H, W] or [B, H, W], typically alpha in [0,1]
        If provided, TV is computed with per-edge weights so background contributes little.
    """
    B, C, H, W = x.shape

    dx = x[:, :, 1:, :] - x[:, :, :-1, :]   # [B,C,H-1,W]
    dy = x[:, :, :, 1:] - x[:, :, :, :-1]   # [B,C,H,W-1]

    if weight is None:
        # original behavior (unweighted)
        count_h = _tensor_size(dx)
        count_w = _tensor_size(dy)
        h_tv = (dx * dx).sum()
        w_tv = (dy * dy).sum()
        return 2.0 * (h_tv / (count_h + eps) + w_tv / (count_w + eps)) / B

    # normalize weight shape to [B,1,H,W]
    if weight.dim() == 3:
        weight = weight.unsqueeze(1)
    elif weight.dim() == 2:
        # unlikely, but handle [H,W]
        weight = weight.unsqueeze(0).unsqueeze(0)
    assert weight.dim() == 4, f"weight must be [B,1,H,W] or [B,H,W], got {weight.shape}"

    # clamp to [0,1] for safety
    w = weight.clamp(0.0, 1.0)

    # edge weights: use min of the two pixels so background (low alpha) suppresses edges
    wx = torch.minimum(w[:, :, 1:, :], w[:, :, :-1, :])  # [B,1,H-1,W]
    wy = torch.minimum(w[:, :, :, 1:], w[:, :, :, :-1])  # [B,1,H,W-1]

    # apply weights (broadcast over C)
    h_tv = (dx * dx * wx).sum()
    w_tv = (dy * dy * wy).sum()

    # normalize by effective weight mass (avoid divide-by-zero)
    count_h = wx.sum() * C
    count_w = wy.sum() * C

    return 2.0 * (h_tv / (count_h + eps) + w_tv / (count_w + eps)) / B