# diagnostics/jacobian_diag.py
"""
Jacobian and sensitivity diagnostics for neural models.

Provides:
  - jacobian_autograd(func, x): compute full Jacobian using torch.autograd.functional.jacobian
  - finite_difference_jacobian(func, x, eps): finite-difference Jacobian (cheap for small dims)
  - frobenius_norm(mat): Frobenius norm of a matrix-like tensor (for the spectral norm see power_iteration_spectral_norm)
  - layerwise_jacobian_stats(model, sample_input): compute jacobian stats for model outputs wrt inputs
  - power_iteration_spectral_norm(A, n_iter=20): power-iteration for spectral norm of 2D tensor

Notes:
  - These operations can be expensive. Use on small batch sizes or low-dim inputs.
  - For very large models use Hutchinson-based trace / implicit methods (not implemented here).
"""
from typing import Callable, Tuple
import torch
import numpy as np

def jacobian_autograd(func: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor) -> torch.Tensor:
    """
    Differentiate ``func`` at ``x`` with torch.autograd.functional.jacobian.

    func: callable mapping a tensor to a tensor.
    x: point at which the Jacobian is evaluated. A detached copy with
       gradients enabled is handed to autograd, so the caller's tensor keeps
       its own ``requires_grad`` flag and autograd history.

    Returns the tensor produced by torch.autograd.functional.jacobian; for a
    1D output and a 1D input this is a matrix of shape [out_dim, in_dim].
    """
    point = x.detach()
    point.requires_grad_(True)
    # create_graph=False: the Jacobian itself is not meant to be differentiated.
    return torch.autograd.functional.jacobian(func, point, create_graph=False)

def finite_difference_jacobian(func: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor, eps: float = 1e-3) -> np.ndarray:
    """
    Forward finite-difference approximation of the Jacobian of ``func`` at ``x``.

    func: callable mapping a tensor shaped like ``x`` to a tensor of any shape
          (a 0-d/scalar output is treated as a single output component).
    x: evaluation point; it is detached and copied, never modified.
    eps: step added to one input component at a time.

    Returns a float numpy array of shape [out_dim, in_dim], where out_dim and
    in_dim are the element counts of the output and of ``x``. Column i holds
    (func(x + eps * e_i) - func(x)) / eps.

    Costs in_dim + 1 evaluations of ``func``, so it is suitable only for small dims.
    """
    x0 = x.detach().cpu().numpy()
    device = x.device
    flat_x = x0.ravel()

    def evaluate(flat_values: np.ndarray) -> np.ndarray:
        # Rebuild a tensor with the original shape/device and flatten the result.
        inp = torch.from_numpy(flat_values.reshape(x0.shape)).to(device)
        return func(inp).detach().cpu().numpy().ravel()

    # Evaluate on a copy so func never receives memory shared with the caller's x.
    y0 = evaluate(flat_x.copy())

    # Use .size (a Python int) rather than np.prod(shape): np.prod(()) is the
    # float 1.0 for scalar outputs, which np.zeros rejects as a dimension.
    out_dim = y0.size
    in_dim = flat_x.size
    J = np.zeros((out_dim, in_dim), dtype=float)
    for i in range(in_dim):
        xp = flat_x.copy()
        xp[i] += eps
        J[:, i] = (evaluate(xp) - y0) / eps
    return J

def frobenius_norm(mat: torch.Tensor) -> float:
    """Return the Frobenius norm of a matrix-like tensor as a Python float."""
    norm_value = mat.norm()  # default norm: square root of the sum of squared entries
    return float(norm_value.item())

def power_iteration_spectral_norm(A: torch.Tensor, n_iter: int = 20) -> float:
    """
    Approximate spectral norm (largest singular value) of 2D tensor A via power iteration.
    A: torch.Tensor [m, n]
    Returns scalar spectral norm.
    """
    if A.dim() != 2:
        raise ValueError("power_iteration_spectral_norm expects 2D matrix")
    device = A.device
    m, n = A.shape
    v = torch.randn(n, device=device)
    v = v / (v.norm() + 1e-12)
    for _ in range(n_iter):
        w = A.matmul(v)
        w_norm = w.norm() + 1e-12
        v = A.t().matmul(w) 
        v = v / (v.norm() + 1e-12)
    sigma = float(w.norm().cpu().item())
    return sigma

def layerwise_jacobian_stats(model: torch.nn.Module, sample_input: torch.Tensor, output_selector: Callable[[torch.Tensor], torch.Tensor] = None, use_autograd: bool = True):
    """
    Compute Jacobian statistics for the overall mapping sample_input -> model(sample_input).

    - model: a torch.nn.Module. It is evaluated in eval mode; its previous
      training/eval mode is restored before returning.
    - sample_input: torch.Tensor (single example or small batch)
    - output_selector: optional func(y) to select which part of model output to analyze (e.g., a single logit)
    - use_autograd: if True use torch.autograd.functional.jacobian and fall back
      to finite differences if autograd raises RuntimeError; if False use
      finite differences directly (slower)

    Returns dict with:
      'jacobian' -> Jacobian flattened to a 2D tensor [output elements, input elements]
                    (None if it could not be computed)
      'frobenius' -> Frobenius norm of that matrix (None if unavailable)
      'spectral' -> spectral norm approx (power iteration) (None if unavailable)

    A RuntimeWarning is issued when the autograd path falls back to finite
    differences and when the finite-difference computation fails.
    """
    # Function-scope import of a name used nowhere else in this function.
    # (Importing numpy here under the module-level alias would make that alias
    # local to the whole function and break earlier references to it.)
    import warnings

    def forward_fn(inp):
        # Gradients must stay enabled here: if the output carries no autograd
        # graph, the non-strict autograd Jacobian silently comes back all zeros.
        out = model(inp)
        # If user wants to analyze a scalar output, project here
        return out if output_selector is None else output_selector(out)

    def forward_no_grad(inp):
        # Finite differences only need forward values, so skip graph building.
        with torch.no_grad():
            return forward_fn(inp)

    stats = {"jacobian": None, "frobenius": None, "spectral": None}
    x = sample_input.detach().requires_grad_(True)

    was_training = model.training
    model.eval()
    try:
        J = None
        if use_autograd:
            try:
                J = torch.autograd.functional.jacobian(forward_fn, x, create_graph=False)
            except RuntimeError as err:
                warnings.warn(
                    f"autograd Jacobian failed ({err}); falling back to finite differences",
                    RuntimeWarning,
                    stacklevel=2,
                )

        if J is not None:
            # J has shape output.shape + x.shape; flatten to [out elements, in elements].
            in_dim = x.numel()
            out_dim = J.numel() // in_dim if in_dim else 0
            Jmat = J.reshape(out_dim, in_dim)
        else:
            try:
                J_np = finite_difference_jacobian(forward_no_grad, x, eps=1e-3)
            except Exception as err:
                # Best effort: give up (e.g. dims too large) but tell the user why.
                warnings.warn(
                    f"finite-difference Jacobian failed ({err}); no statistics computed",
                    RuntimeWarning,
                    stacklevel=2,
                )
                return stats
            Jmat = torch.from_numpy(J_np).float()
    finally:
        # Do not leave the caller's model permanently switched to eval mode.
        model.train(was_training)

    stats["jacobian"] = Jmat
    stats["frobenius"] = frobenius_norm(Jmat)
    # The spectral estimate is run in the default float dtype, independent of
    # the dtype the Jacobian was computed in.
    stats["spectral"] = power_iteration_spectral_norm(Jmat.to(torch.get_default_dtype()))
    return stats
