import torch

def _complex_from_video(x: torch.Tensor):
    # x: (B,C,T,H,W) with channels [trend, Re, Im]
    re = x[:, 1]  # (B,T,H,W)
    im = x[:, 2]
    return torch.complex(re, im)  # (B,T,H,W)

def _coherence_bcthw(Z: torch.Tensor, m_bthw: torch.Tensor, kind: str = "msc", eps: float = 1e-6):
    """
    Z: complex STFT (B,T,H,W); m_bthw: mask (B,T,H,W) in {0,1}
    Returns: coherence averaged over H (frequency) -> (B,W,W) real.
    """
    B, T, H, W = Z.shape
    m = m_bthw.to(Z.real.dtype)

    # We want per-frequency (H) cross-spectra averaged over time (T).
    # Weight each (t,h,w) by mask.
    # For autospectra: Sxx(h,w) = mean_t |Z|^2
    denom_t = m.sum(dim=1, keepdim=True).clamp_min(1.0)             # (B,1,H,W)
    Z_mask = Z * m                                                  # broadcast (real/imag)
    Sxx = (Z_mask.conj() * Z_mask).real.sum(dim=1) / denom_t.squeeze(1)   # (B,H,W)

    # Cross-spectra across covariates for each frequency h
    # shape tricks: (B,H,W) -> (B,H,W,1) & (B,H,1,W) to outer-product across W
    Zm = Z_mask                                                     # (B,T,H,W)
    denom_t2 = denom_t.squeeze(1)                                   # (B,H,W)
    # Average over time first: Z̄(b,h,w) = sum_t Z_mask / sum_t m
    Zbar = Zm.sum(dim=1) / denom_t2.clamp_min(1.0)                  # (B,H,W) complex

    # Sxy(b,h,w1,w2) = Z̄(w1)^* Z̄(w2)
    Sxy = torch.einsum('bhw,bhj->bhwj', Zbar.conj(), Zbar)          # (B,H,W,W) complex

    # Coherence per frequency
    denom = torch.sqrt(Sxx.unsqueeze(-1) * Sxx.unsqueeze(-2)).clamp_min(eps)  # (B,H,W,W)
    C = Sxy / denom
    if kind == "msc":
        C = (C.abs() ** 2).real.clamp(0.0, 1.0)                     # (B,H,W,W)
    elif kind == "real":
        C = C.real.clamp(-1.0, 1.0)                                  # (B,H,W,W)
    else:
        raise ValueError("kind must be 'msc' or 'real'")

    # Average across frequencies (uniform; optionally weight by bin widths)
    C_mean = C.mean(dim=1)  # (B,W,W)
    return C_mean

def loss_coherence_STFT_bcthw(xp, xt, m_bthw, lambda_cov=5e-3, kind="msc", offdiag_only=True, eps=1e-6):
    """
    Per-sample loss matching the covariate coherence matrices of two STFT videos.

    xp, xt: (B,C,T,H,W) prediction / target; channels 1 and 2 are read as Re and Im.
    m_bthw: (B,T,H,W) mask in {0,1}, shared by both inputs.
    kind, eps: forwarded to the coherence computation ("msc" or "real").

    The (B,W,W) coherence matrices of xp and xt are differenced and squared. The
    result is averaged over all W*W entries of each sample and scaled by lambda_cov.
    With offdiag_only the diagonal entries are zeroed but still counted in that mean.

    Return: (B,) per-sample loss.
    """
    pred_complex = _complex_from_video(xp)
    true_complex = _complex_from_video(xt)
    coh_pred = _coherence_bcthw(pred_complex, m_bthw, kind=kind, eps=eps)  # (B,W,W)
    coh_true = _coherence_bcthw(true_complex, m_bthw, kind=kind, eps=eps)  # (B,W,W)

    residual = coh_pred - coh_true
    if offdiag_only:
        n_cov = residual.shape[-1]
        off_diag = 1 - torch.eye(n_cov, device=residual.device, dtype=residual.dtype)
        residual = residual * off_diag

    # Mean squared entry of each sample's W x W residual (no reduction over the batch).
    per_sample = residual.square().reshape(residual.shape[0], -1).mean(dim=-1)
    return lambda_cov * per_sample



def loss_crosscov_STFT_bcthw_fast(xp, xt, m_bthw, lambda_cov=5e-3, offdiag_only=True, eps=1e-6):
    """
    Per-sample loss matching trace-normalized covariances of STFT power envelopes.

    xp, xt: (B,C,T,H,W); channels 1 and 2 are read as Re and Im.
    m_bthw: (B,T,H,W) mask, bool or {0,1}.

    Steps:
      - For each input, |Z|^2 is averaged over the observed frequencies (H) to
        give a (B,T,W) power envelope.
      - Its W x W covariance over time uses only observed (t, w) frames, i.e.
        frames with at least one observed frequency. Each covariate is centered
        by its mean over its own observed frames.
      - Each covariance entry is normalized by the number of frames where both
        covariates are observed, then the matrix is divided by its trace
        (floored at eps).
      - The loss is lambda_cov times the mean, over the W*W entries, of the
        squared difference of the two matrices. With offdiag_only the diagonal
        is zeroed but still counted in that mean.

    Returns: (B,) tensor.
    """
    def _power(x):
        # |Z|^2 from the Re/Im channels -> (B,T,H,W)
        return x[:, 1] ** 2 + x[:, 2] ** 2

    mag2_p = _power(xp)
    mag2_t = _power(xt)
    m = m_bthw.to(mag2_p.dtype)

    count_H = m.sum(dim=2)                                # (B,T,W) observed frequencies per frame
    # The observed-frame indicator must come from the *unclamped* count: the
    # clamped denominator is always >= 1 and would mark fully-masked frames as
    # observed, injecting spurious zero samples into the mean and covariance.
    Mt = (count_H > 0).to(mag2_p.dtype)                   # (B,T,W)
    den_H = count_H.clamp_min(1.0)

    n_t = Mt.sum(dim=1).clamp_min(1.0)                    # (B,W) observed frames per covariate
    n_pair = (Mt.transpose(1, 2) @ Mt).clamp_min(1.0)     # (B,W,W) co-observed frames per pair

    def _normalized_cov(mag2):
        # Trace-normalized masked covariance of the frequency-averaged power envelope.
        S = (mag2 * m).sum(dim=2) / den_H                 # (B,T,W)
        mu = (S * Mt).sum(dim=1) / n_t                    # (B,W)
        Zc = (S - mu[:, None, :]) * Mt                    # centered; unobserved frames zeroed
        cov = (Zc.transpose(1, 2) @ Zc) / n_pair          # (B,W,W)
        tr = cov.diagonal(dim1=-2, dim2=-1).sum(dim=-1).clamp_min(eps)  # (B,)
        return cov / tr[:, None, None]

    diff = _normalized_cov(mag2_p) - _normalized_cov(mag2_t)
    if offdiag_only:
        I = torch.eye(diff.size(-1), device=diff.device, dtype=diff.dtype)[None]
        diff = diff * (1 - I)
    return lambda_cov * diff.pow(2).flatten(1).mean(dim=1)