"""
Jacobian-based local anisotropy metric.

Instead of measuring k-NN spread, measure how the encoder stretches/compresses
different input directions locally via the Jacobian J = ∂f/∂x.

Singular values of J reveal:
- Which input directions are amplified (large σ) vs suppressed (small σ)
- DINO should suppress augmentation directions, preserve semantic directions
- MAE might have more uniform singular values
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from tqdm import tqdm


@dataclass
class JacobianAnisotropyMetrics:
    """
    Container for the results of a Jacobian-based anisotropy measurement.

    Attributes:
        mean_top_singular: Mean of the top singular value (σ₁).
        mean_spectral_ratio: Mean σ₁ / σ_k ratio (the anisotropy measure).
        mean_effective_rank: Mean effective rank, (Σσ)² / Σσ².
        mean_frobenius: Mean Frobenius norm, ||J||_F = sqrt(Σσ²).
        singular_values: Top-k singular values per sample.
    """

    mean_top_singular: float
    mean_spectral_ratio: float
    mean_effective_rank: float
    mean_frobenius: float
    singular_values: np.ndarray


def compute_jvp_batch(
    model: torch.nn.Module,
    images: torch.Tensor,
    directions: torch.Tensor,
) -> torch.Tensor:
    """
    Compute Jacobian-vector products J @ v for multiple directions.

    The products are approximated with central finite differences (see
    `_compute_jvp_finite_diff`), not with autograd. An earlier autograd loop
    in this function computed gradients that were never used, cost one extra
    forward/backward pass per direction, and flipped `requires_grad` on the
    caller's `images` tensor; it has been removed. The returned values are
    unchanged.

    Args:
        model: Encoder model
        images: (B, C, H, W) input images
        directions: (B, n_dirs, C, H, W) input directions to probe

    Returns:
        (B, n_dirs, D) where D is embedding dim - the stretched directions
    """
    return _compute_jvp_finite_diff(model, images, directions)


def _compute_jvp_finite_diff(
    model: torch.nn.Module,
    images: torch.Tensor,
    directions: torch.Tensor,
    eps: float = 1e-3,
) -> torch.Tensor:
    """
    Compute J @ v using finite differences: (f(x + εv) - f(x - εv)) / 2ε

    Args:
        model: Encoder model. Its output may be a tensor, a tuple (the first
            element is used), or a dict (the "image" entry is used, else
            "features", else the first value).
        images: (B, C, H, W) input images
        directions: (B, n_dirs, C, H, W) input directions to probe
        eps: Finite-difference step size ε

    Returns:
        (B, n_dirs, D) tensor of directional derivatives, on the device of
        `images` and with the dtype of the model's embedding.
    """
    # Unpacking five dims keeps the original rank check on `directions`.
    B, n_dirs, _, _, _ = directions.shape

    def embed(x: torch.Tensor) -> torch.Tensor:
        # Reduce the supported model output containers to a plain tensor.
        out = model(x)
        if isinstance(out, dict):
            out = out.get("image", out.get("features", list(out.values())[0]))
        if isinstance(out, tuple):
            out = out[0]
        return out

    with torch.no_grad():
        # The base embedding is only needed for the output size and dtype.
        base_emb = embed(images)
        D = base_emb.shape[-1]

        # Allocate with the embedding dtype so float64/half results are not
        # silently cast to float32 on assignment.
        results = torch.zeros(
            B, n_dirs, D, device=images.device, dtype=base_emb.dtype
        )

        for i in range(n_dirs):
            step = eps * directions[:, i]  # (B, C, H, W)
            # Central difference around the unperturbed images.
            results[:, i] = (embed(images + step) - embed(images - step)) / (2 * eps)

    return results


def estimate_jacobian_singular_values(
    model: torch.nn.Module,
    images: torch.Tensor,
    n_directions: int = 64,
    n_power_iterations: int = 10,
    seed: int = 42,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Estimate top singular values of the Jacobian by subspace power iteration.

    Each iteration applies J.T @ J to a block of input directions
    (finite-difference JVP followed by an autograd VJP) and re-orthonormalizes
    the block with a QR decomposition. The singular values are then read off
    as ||J @ v|| for each resulting direction.

    Note: this calls `torch.manual_seed(seed)`, which reseeds the global
    torch RNG.

    Args:
        model: Encoder model
        images: (B, C, H, W) input images
        n_directions: Number of random directions to probe. When power
            iterations are run this must not exceed C * H * W, because the
            reduced QR cannot return more orthonormal directions than the
            input dimension.
        n_power_iterations: Power iteration steps
        seed: Random seed

    Returns:
        singular_values: (B, k) estimated top-k singular values, sorted in
            descending order per sample
        singular_vectors: (B, k, C, H, W) corresponding input directions, in
            the same order as `singular_values`
    """
    B, C, H, W = images.shape
    device = images.device

    torch.manual_seed(seed)

    # Random unit-norm starting directions. They are not orthogonal yet; the
    # QR step inside the loop takes care of that.
    V = torch.randn(B, n_directions, C, H, W, device=device)
    V = V / V.view(B, n_directions, -1).norm(dim=-1, keepdim=True).unsqueeze(-1).unsqueeze(-1)

    # Power iteration to find top singular vectors
    for _ in range(n_power_iterations):
        # J @ V
        JV = _compute_jvp_finite_diff(model, images, V)  # (B, n_dirs, D)

        # J.T @ (J @ V): J.T @ u is the gradient of (u.T @ f(x)) w.r.t. x
        V_new = _compute_vjp_batch(model, images, JV)  # (B, n_dirs, C, H, W)

        # Orthonormalize every sample's directions with one batched QR.
        # Directions are the columns of each (input_dim, n_dirs) matrix.
        cols = V_new.reshape(B, n_directions, -1).transpose(1, 2)
        Q, _ = torch.linalg.qr(cols)  # (B, input_dim, n_dirs)
        V = Q.transpose(1, 2).reshape(B, n_directions, C, H, W)

    # Compute final singular values: σ = ||J @ v||
    JV = _compute_jvp_finite_diff(model, images, V)  # (B, n_dirs, D)
    singular_values = JV.norm(dim=-1)  # (B, n_dirs)

    # Sort by magnitude and apply the same permutation to the directions so
    # that V[b, j] is the direction belonging to singular_values[b, j].
    singular_values, order = singular_values.sort(dim=-1, descending=True)
    batch_idx = torch.arange(B, device=order.device).unsqueeze(-1)  # (B, 1)
    V = V[batch_idx, order]

    return singular_values, V


def _compute_vjp_batch(
    model: torch.nn.Module,
    images: torch.Tensor,
    vectors: torch.Tensor,
    eps: float = 1e-3,
) -> torch.Tensor:
    """
    Compute J.T @ v (vector-Jacobian product) for multiple vectors.

    J.T @ v = gradient of (v.T @ f(x)) w.r.t. x

    The model is evaluated once on a private copy of `images`; each vector is
    then pulled back with its own backward pass through the retained graph.
    `images` itself is not modified.

    Args:
        model: Encoder model. Its output may be a tensor, a tuple (the first
            element is used), or a dict (the "image" entry is used, else
            "features", else the first value).
        images: (B, C, H, W) input images
        vectors: (B, n_vecs, D) vectors in embedding space
        eps: Unused; kept so the signature stays compatible with existing
            callers.

    Returns:
        (B, n_vecs, C, H, W) - the pulled-back vectors, on the device and
        with the dtype of `images`
    """
    B, n_vecs, D = vectors.shape
    C, H, W = images.shape[1:]

    # Follow the input dtype so float64 gradients are not cast to float32.
    results = torch.zeros(
        B, n_vecs, C, H, W, device=images.device, dtype=images.dtype
    )
    if n_vecs == 0:
        return results

    # Differentiate w.r.t. a detached copy so the caller's tensor and any
    # graph attached to it stay untouched.
    images_grad = images.detach().clone().requires_grad_(True)

    with torch.enable_grad():
        # The forward pass does not depend on the vector, so run it once.
        emb = model(images_grad)
        if isinstance(emb, dict):
            emb = emb.get("image", emb.get("features", list(emb.values())[0]))
        if isinstance(emb, tuple):
            emb = emb[0]

        for i in range(n_vecs):
            v = vectors[:, i].detach()  # (B, D)

            # Gradient of v.T @ emb w.r.t. the images. Samples are
            # independent in the sum, so row b of the gradient is J_b.T @ v_b.
            scalar = (v * emb).sum()
            grad = torch.autograd.grad(
                scalar,
                images_grad,
                # Keep the graph for the remaining vectors, free it on the last.
                retain_graph=i < n_vecs - 1,
            )[0]
            results[:, i] = grad.detach()

    return results


def compute_jacobian_anisotropy(
    model: torch.nn.Module,
    images: torch.Tensor,
    n_directions: int = 32,
    n_power_iterations: int = 5,
) -> Dict[str, float]:
    """
    Summarize the estimated Jacobian spectrum of one batch of images.

    Args:
        model: Encoder model
        images: (B, C, H, W) input images
        n_directions: Number of singular values to estimate
        n_power_iterations: Power iteration steps

    Returns:
        Dictionary with the batch means "top_singular", "spectral_ratio",
        "effective_rank" and "frobenius", plus "singular_values", the
        per-sample estimates as a NumPy array.
    """
    sigma, _ = estimate_jacobian_singular_values(
        model, images, n_directions, n_power_iterations
    )

    # Estimates arrive sorted in descending order along the last axis.
    largest = sigma[:, 0]
    smallest = sigma[:, -1]

    total = sigma.sum(dim=-1)  # Σσ
    energy = (sigma ** 2).sum(dim=-1)  # Σσ²

    # The small constants guard against division by zero.
    ratio = largest / (smallest + 1e-8)  # σ₁ / σ_k
    eff_rank = total ** 2 / (energy + 1e-8)  # (Σσ)² / Σσ²
    frob = energy.sqrt()  # sqrt(Σσ²)

    return {
        "top_singular": largest.mean().item(),
        "spectral_ratio": ratio.mean().item(),
        "effective_rank": eff_rank.mean().item(),
        "frobenius": frob.mean().item(),
        "singular_values": sigma.cpu().numpy(),
    }


def compute_jacobian_anisotropy_dataset(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    n_samples: int = 100,
    n_directions: int = 32,
    n_power_iterations: int = 5,
    device: str = "cuda",
    verbose: bool = True,
) -> JacobianAnisotropyMetrics:
    """
    Compute Jacobian anisotropy over a dataset.

    Batches are processed until `n_samples` images have been analyzed; the
    last batch is truncated if it would exceed that budget. The reported
    means are weighted by the number of images in each batch, so they are
    per-image means even when batches differ in size.

    Args:
        model: Encoder model (moved to `device` and put in eval mode)
        dataloader: DataLoader yielding (images, ...) batches or bare image
            tensors
        n_samples: Number of images to analyze
        n_directions: Singular values to estimate per image
        n_power_iterations: Power iteration steps
        device: Device to use
        verbose: Show a progress bar

    Returns:
        JacobianAnisotropyMetrics dataclass

    Raises:
        ValueError: If `n_samples` is not positive or no images were
            processed.
    """
    if n_samples <= 0:
        # A negative budget would otherwise become a negative slice bound
        # below and silently process part of the first batch.
        raise ValueError(f"n_samples must be positive, got {n_samples}")

    model = model.to(device).eval()

    batch_sizes = []
    all_top_sv = []
    all_spectral_ratio = []
    all_effective_rank = []
    all_frobenius = []
    all_singular_values = []

    n_processed = 0

    for batch in tqdm(dataloader, disable=not verbose, desc="Jacobian anisotropy"):
        images = batch[0] if isinstance(batch, (list, tuple)) else batch

        # Keep only as many images as the remaining budget allows.
        images = images[: n_samples - n_processed].to(device)
        B = images.shape[0]

        if B == 0:
            break

        metrics = compute_jacobian_anisotropy(
            model, images, n_directions, n_power_iterations
        )

        batch_sizes.append(B)
        all_top_sv.append(metrics["top_singular"])
        all_spectral_ratio.append(metrics["spectral_ratio"])
        all_effective_rank.append(metrics["effective_rank"])
        all_frobenius.append(metrics["frobenius"])
        all_singular_values.append(metrics["singular_values"])

        n_processed += B

        if n_processed >= n_samples:
            break

    if not batch_sizes:
        raise ValueError(
            "No images were processed; the dataloader yielded no usable batches"
        )

    # Per-batch values are batch means, so weight them by batch size to get
    # per-image means.
    weights = np.asarray(batch_sizes, dtype=np.float64)

    return JacobianAnisotropyMetrics(
        mean_top_singular=float(np.average(all_top_sv, weights=weights)),
        mean_spectral_ratio=float(np.average(all_spectral_ratio, weights=weights)),
        mean_effective_rank=float(np.average(all_effective_rank, weights=weights)),
        mean_frobenius=float(np.average(all_frobenius, weights=weights)),
        singular_values=np.concatenate(all_singular_values, axis=0),
    )


def quick_jacobian_test(
    model: torch.nn.Module,
    device: str = "cuda",
) -> Dict[str, float]:
    """
    Run the Jacobian anisotropy metrics on a small batch of random images.

    The model is moved to `device` and put in eval mode. Four 3x224x224
    Gaussian-noise images are min-max rescaled to [0, 1] as one batch and
    analyzed with 16 directions and 3 power iterations.
    """
    model = model.to(device).eval()

    # Gaussian noise, rescaled with the min and max of the whole batch.
    noise = torch.randn(4, 3, 224, 224, device=device)
    lowest = noise.min()
    highest = noise.max()
    images = (noise - lowest) / (highest - lowest)

    return compute_jacobian_anisotropy(
        model,
        images,
        n_directions=16,
        n_power_iterations=3,
    )


if __name__ == "__main__":
    # Smoke test: run the quick check on timm's vit_base_patch16_224.
    # timm is imported here so the module itself does not depend on it.
    import timm

    print("Loading model...")
    model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=0)

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print("Running Jacobian anisotropy test...")
    results = quick_jacobian_test(model, device=device)

    print(f"Top singular value: {results['top_singular']:.4f}")
    print(f"Spectral ratio: {results['spectral_ratio']:.4f}")
    print(f"Effective rank: {results['effective_rank']:.4f}")
    print(f"Frobenius norm: {results['frobenius']:.4f}")
