# diagnostics/jacobian_diag.py
"""
Jacobian and sensitivity diagnostics for the denoiser or encoder networks.

Provides:
 - a finite-difference Jacobian estimate (cheap for small dimensions)
 - the full Jacobian via torch.autograd.functional.jacobian (may be expensive)
 - simple norms: Frobenius norm, spectral norm (approximated via power iteration)

Use with small batches / low-dimensional latents to avoid heavy compute.
"""
import torch
import numpy as np

def frobenius_norm(mat):
    """
    Frobenius norm sqrt(sum |m_ij|^2) over all elements of ``mat``.

    mat: torch.Tensor or array-like (e.g. a numpy Jacobian); converted with
         torch.as_tensor. Integer input is promoted to the default float dtype;
         complex input uses |m_ij|, so the result is real.
    Returns: 0-d torch.Tensor.
    """
    t = torch.as_tensor(mat)
    if not (t.is_floating_point() or t.is_complex()):
        # vector_norm requires a floating/complex dtype.
        t = t.to(torch.get_default_dtype())
    # With dim=None the 2-norm is taken over all elements, i.e. the Frobenius norm.
    return torch.linalg.vector_norm(t)

def approx_spectral_norm(J, n_iter=10):
    """
    Approximate the spectral norm (largest singular value) of a 2-D matrix J
    by power iteration on J^H J.

    J: 2-D torch.Tensor or array-like (e.g. a numpy Jacobian); converted with
       torch.as_tensor. Integer input is promoted to the default float dtype.
    n_iter: number of power-iteration steps; with 0 the result is ||J v||
       for a random unit vector v.
    Returns: float estimate ||J v|| for a (near-)unit vector v, so it does not
       exceed the true spectral norm beyond rounding; 0.0 if J v == 0.
    The start vector comes from torch.randn, so the result depends on the
    global RNG state.
    """
    J = torch.as_tensor(J)
    if not (J.is_floating_point() or J.is_complex()):
        J = J.to(torch.get_default_dtype())
    with torch.no_grad():
        # Conjugate transpose, hoisted out of the loop (equals J.t() for real J).
        J_h = J.conj().t()
        # Match J's dtype/device: a float32 start vector cannot multiply a float64 J.
        v = torch.randn(J.shape[1], device=J.device, dtype=J.dtype)
        v = v / (v.norm() + 1e-9)
        for _ in range(n_iter):
            w = J.mv(v)
            if w.norm().item() == 0:
                return 0.0
            v = J_h.mv(w)
            v = v / (v.norm() + 1e-9)
        # Evaluate with the final v so the last update is used (and n_iter=0 works).
        return J.mv(v).norm().item()

def jacobian_via_autograd(func, x):
    """
    Compute the full Jacobian of ``func`` at ``x`` with torch.autograd.functional.jacobian.

    func: callable x -> y (y shape [D_out])
    x: tensor [D_in]. It does not need requires_grad=True and is left unmodified:
       jacobian() detaches the input and enables gradients on its own copy.
    Returns: Jacobian [D_out, D_in] (in general, shape y.shape + x.shape).
    """
    # Previously x.requires_grad_(True) was called here, which mutated the caller's
    # tensor (later in-place ops on it would fail and downstream ops would build graphs).
    # create_graph=False: the returned Jacobian is not attached to an autograd graph.
    return torch.autograd.functional.jacobian(func, x, create_graph=False)

def finite_diff_jacobian(func, x, eps=1e-3, central=False):
    """
    Finite difference Jacobian approximation for small-scale testing.

    x: torch.Tensor (1D in the common case; any shape is flattened, so column i
       corresponds to element i of ``x.reshape(-1)``). Integer input is promoted
       to the default float dtype so it can be perturbed.
    func: mapping from x to a torch.Tensor; its output is flattened the same way
       (a 0-d scalar output gives D_out == 1).
    eps: perturbation size; must be non-zero.
    central: if True, use central differences (f(x+eps) - f(x-eps)) / (2*eps),
       which are second-order accurate but call ``func`` twice per input element.
       The default is a first-order forward difference (f(x+eps) - f(x)) / eps.
    Returns: numpy array [D_out, D_in] (float64, or complex128 for complex outputs).
    Raises: ValueError if eps == 0.
    """
    if eps == 0:
        raise ValueError("eps must be non-zero")
    x = x.clone().detach()
    if not (x.is_floating_point() or x.is_complex()):
        # An integer tensor cannot store x + eps in place.
        x = x.to(torch.get_default_dtype())
    shape = x.shape
    # reshape, not view: clone() preserves the strides of a dense non-contiguous
    # tensor (e.g. a transpose), and such a tensor cannot be viewed flat.
    flat_x = x.reshape(-1)

    def _eval(flat):
        # Evaluate func at a flat input and return its flattened output as numpy.
        return func(flat.reshape(shape)).detach().cpu().numpy().reshape(-1)

    y0 = func(x).detach().cpu().numpy().reshape(-1)
    D_out = y0.shape[0]
    D_in = flat_x.numel()
    # float64 for real outputs (as before); complex outputs keep their imaginary part.
    J = np.zeros((D_out, D_in), dtype=np.result_type(y0.dtype, np.float64))
    for i in range(D_in):
        # Fresh copy per column so x (and anything func aliased from it) is never mutated.
        xp = flat_x.clone()
        xp[i] += eps
        yp = _eval(xp)
        if central:
            xm = flat_x.clone()
            xm[i] -= eps
            J[:, i] = (yp - _eval(xm)) / (2 * eps)
        else:
            J[:, i] = (yp - y0) / eps
    return J
