"""
Gauss-Newton-vector product operations.

Implements minibatch GGN-vP operator:
    v ↦ (1/S) Σ_{i ∈ curvature subset} J_i^T H_i J_i v + λv

where J_i is the per-sample Jacobian and H_i is the loss Hessian w.r.t. logits.

For cross-entropy loss, H_i = diag(p_i) - p_i p_i^T (the Fisher information).

Uses torch.func (functorch) for efficient vectorized JVP/VJP computation.
Implements both CG and LiSSA solvers for IHVP computation.

DDP-aware: In distributed mode, each rank computes on its local shard and 
the results are all-reduced to get the global GGN-vP.
"""

import torch
import torch.distributed as dist
from torch.backends.cuda import sdp_kernel
# NOTE: everything in this section runs at import time and changes
# process-wide torch settings, not just settings for this module.

# Disable the fused/native MHA fastpath that triggers _native_multi_head_attention
torch.backends.mha.set_fastpath_enabled(False)

# Enable TF32 for Ampere/Hopper GPUs - significant speedup with minimal numeric impact
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True  # Auto-tune convolution algorithms
torch.set_float32_matmul_precision("high")

# Optional dependency: when functorch cannot be imported for any reason the
# name is bound to None instead of failing the module import.
try:
    # functorch still works and is the cleanest way to get params as a tuple (no dict plumbing)
    from functorch import make_functional_with_buffers
except Exception:
    make_functional_with_buffers = None


import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.func import functional_call, vmap, vjp, jvp, grad
from typing import Optional, Tuple, List, Callable, Dict, Any, Union
import numpy as np
from tqdm import tqdm
import math
import time
from dataclasses import dataclass
from contextlib import contextmanager
from .vit_full import ViTWithHooks
from .logging_utils import get_logger, log_tensor_stats, log_dict, is_ddp, DDPState, allreduce_tensor

logger = get_logger(__name__)


# =============================================================================
# Utility: Remove hooks for torch.func compatibility
# =============================================================================

def _remove_all_hooks(module: nn.Module) -> dict:
    """
    Detach every hook registered on ``module`` and on all of its submodules.

    Four hook registries are emptied on each submodule: forward hooks,
    forward pre-hooks, backward hooks (``_backward_hooks``) and, when the
    attribute exists, backward pre-hooks (``_backward_pre_hooks``).

    This exists because some libraries (e.g. timm) register backward hooks
    that go through BackwardHookFunction, which torch.func transforms
    (vjp, jvp, vmap, ...) cannot handle.

    Args:
        module: Root module whose whole subtree is stripped of hooks.

    Returns:
        Dict mapping ``(submodule_name, kind)`` to a plain-dict snapshot of
        the removed hooks, where ``kind`` is one of ``'forward'``,
        ``'forward_pre'``, ``'backward'`` or ``'backward_pre'``. Only
        non-empty registries get an entry. Pass the dict to
        ``_restore_all_hooks`` to put the hooks back.
    """
    stash = {}

    for mod_name, submodule in module.named_modules():
        registries = (
            ('forward', submodule._forward_hooks),
            ('forward_pre', submodule._forward_pre_hooks),
            ('backward', submodule._backward_hooks),
            # Not every torch version defines this registry, hence getattr.
            ('backward_pre', getattr(submodule, '_backward_pre_hooks', None)),
        )
        for kind, registry in registries:
            if registry:
                # Snapshot first, then empty the live registry in place.
                stash[(mod_name, kind)] = dict(registry)
                registry.clear()

    return stash


def _restore_all_hooks(module: nn.Module, saved_hooks: dict):
    """
    Re-attach hooks previously captured by ``_remove_all_hooks``.

    Args:
        module: Root module the hooks were removed from.
        saved_hooks: Mapping ``(submodule_name, kind) -> {hook_id: hook}``.
            Entries whose submodule name no longer exists under ``module``,
            or whose kind is not one of the four known kinds, are ignored.
    """
    registry_attr = {
        'forward': '_forward_hooks',
        'forward_pre': '_forward_pre_hooks',
        'backward': '_backward_hooks',
        'backward_pre': '_backward_pre_hooks',
    }
    modules_by_name = dict(module.named_modules())

    for (mod_name, kind), hooks in saved_hooks.items():
        target = modules_by_name.get(mod_name)
        attr = registry_attr.get(kind)
        if target is None or attr is None:
            continue
        # Merge back into the live registry rather than replacing it.
        getattr(target, attr).update(hooks)


@contextmanager
def disable_hooks(module: nn.Module):
    """
    Run the ``with`` body with every hook on ``module`` (and its submodules)
    temporarily removed.

    The hooks are put back when the body exits, including when it raises.

    Example:
        with disable_hooks(model):
            # torch.func transforms work here
            _, vjp_fn = vjp(net_fn, params)
    """
    stashed_hooks = _remove_all_hooks(module)
    try:
        yield
    finally:
        # Always restore, even if the body raised.
        _restore_all_hooks(module, stashed_hooks)


# =============================================================================
# Utility functions for pytree-like operations (similar to JAX)
# =============================================================================


def _tree_zeros_like(template: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Return a new dict with the same keys as ``template`` and a
    ``torch.zeros_like`` tensor for every value."""
    zeros = {}
    for key, tensor in template.items():
        zeros[key] = torch.zeros_like(tensor)
    return zeros


def _tree_dot(a: Dict[str, torch.Tensor], b: Dict[str, torch.Tensor]) -> torch.Tensor:
    """
    Inner product of two param dicts.

    Iterates over the keys of ``a`` and adds up the elementwise products
    ``a[k] * b[k]``. For an empty ``a`` the result is the Python int ``0``.
    """
    total = 0
    for key in a:
        total = total + (a[key] * b[key]).sum()
    return total


# =============================================================================
# Core GGN-vector product computation using torch.func
# =============================================================================

def gnhvp_single(
    loss_fn: Callable[[torch.Tensor], torch.Tensor],
    net_fn: Callable[[Dict[str, torch.Tensor]], torch.Tensor],
    params: Dict[str, torch.Tensor],
    v: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
    """
    Compute Gauss-Newton-Hessian vector product for a single sample.

    GNH v = J^T (∇²_z L) J v

    where J = ∂net/∂params, and ∇²_z L is the Hessian of loss w.r.t. network output.

    Uses the identity:
        J^T H J v = J^T (H (J v))

    Step 1: Compute J v via forward-mode AD (jvp)
    Step 2: Compute H (J v) via forward-over-reverse AD on the loss
    Step 3: Compute J^T (H J v) via backward-mode AD on network

    Args:
        loss_fn: Maps the network output ``z`` to a scalar loss.
        net_fn: Maps a param dict to the network output.
        params: Point at which the Jacobian is evaluated.
        v: Tangent direction, same structure as ``params``.

    Returns:
        Param-structured dict holding ``J^T H J v``.
    """
    # Step 1: J v = d(net(params))/dparams · v using JVP
    outputs, jv = jvp(net_fn, (params,), (v,))

    # Step 2: H (J v) is the directional derivative of z -> ∇_z L(z) along jv.
    # torch.func.grad is used instead of torch.autograd.grad because the
    # latter is not reliably transformable by torch.func and fails outright
    # when `outputs` does not require grad (e.g. detached params).
    _, hjv = jvp(grad(loss_fn), (outputs,), (jv,))

    # Step 3: J^T (H J v) via VJP
    _, vjp_fn = vjp(net_fn, params)
    return vjp_fn(hjv)[0]


class GGNOperator:
    """
    Efficient Gauss-Newton-vector product operator using torch.func.
    
    Implements LinearOperator-like interface for CG solver.
    Computes true GGN-vP (not full Hessian) using vectorized operations.
    
    GGN v = (1/n) Σ_i J_i^T H_i J_i v + λv
    
    where H_i is the Hessian of loss w.r.t. logits (Fisher information for CE).
    """
    
    def __init__(
        self,
        model: ViTWithHooks,
        dataloader: DataLoader,
        damping: float = 1e-2,
        weight_decay: float = 0.01,
        device: str = "cuda",
        use_gnhvp: bool = True,
        use_mixed_precision: bool = True,
        compile_gnhvp: bool = False,
    ):
        """
        Args:
            model: ViT model with hooks
            dataloader: DataLoader for curvature subset
            damping: Damping parameter λ (added to diagonal)
            weight_decay: L2 regularization coefficient
            device: Computation device
            use_gnhvp: If True, use GN-HVP; if False, use full HVP
            use_mixed_precision: If True, use BF16/autocast for JVP/VJP (1.5-2x speedup)
            compile_gnhvp: If True, use torch.compile on gnhvp (experimental)
        """
        self.model = model
        self.dataloader = dataloader
        self.damping = damping
        self.weight_decay = weight_decay
        self.device = device
        self.use_gnhvp = use_gnhvp
        self.use_mixed_precision = use_mixed_precision
        self.compile_gnhvp = compile_gnhvp
        self.use_lm_damping = False

        # Compute number of parameters
        self.num_params = model.num_params
        
        self.model.model.eval()

        # Get parameter names and shapes for vectorization
        self._param_names = []
        self._param_shapes = []
        self._param_numels = []
        for name, p in model.model.named_parameters():
            self._param_names.append(name)
            self._param_shapes.append(p.shape)
            self._param_numels.append(p.numel())
        
        # Cache for batches
        self._cached_batches: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None
        
        # Track matvec calls
        self._matvec_count = 0

        logger.info(f"GGNOperator initialized: {self.num_params:,} params, damping={damping}, "
                    f"use_gnhvp={use_gnhvp}, mixed_precision={use_mixed_precision}")
    
    def _get_param_dict(self) -> Dict[str, torch.Tensor]:
        """Get current parameters as a dict."""
        return dict(self.model.model.named_parameters())
    
    def _vector_to_param_dict(self, v: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Convert flat vector to param dict."""
        params = {}
        offset = 0
        for name, shape, numel in zip(self._param_names, self._param_shapes, self._param_numels):
            params[name] = v[offset:offset + numel].view(shape)
            offset += numel
        return params
    
    def _param_dict_to_vector(self, params: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Convert param dict to flat vector."""
        return torch.cat([params[name].flatten() for name in self._param_names])
        
    def cache_batches(self):
        """
        Cache all batches as persistent contiguous tensors on GPU.
        
        This creates a single contiguous block of memory for all images/labels,
        avoiding DataLoader iteration overhead during matvec.
        """
        all_images = []
        all_labels = []
        for images, labels in self.dataloader:
            all_images.append(images)
            all_labels.append(labels)
        
        # Concatenate into persistent contiguous tensors
        self._all_images = torch.cat(all_images, dim=0).to(
            self.device, non_blocking=True
        ).contiguous()
        self._all_labels = torch.cat(all_labels, dim=0).to(
            self.device, non_blocking=True
        ).contiguous()
        
        # Store batch info for iteration
        batch_size = all_images[0].shape[0]
        self._batch_size = batch_size
        self._num_samples = self._all_images.shape[0]
        
        # Create index ranges for batches (for backwards compat with batch iteration)
        self._cached_batches = []
        for i in range(0, self._num_samples, batch_size):
            end_idx = min(i + batch_size, self._num_samples)
            self._cached_batches.append((
                self._all_images[i:end_idx],
                self._all_labels[i:end_idx],
            ))
        
        # Ensure all transfers complete
        if self.device == "cuda":
            torch.cuda.synchronize()
        
        logger.info(f"Cached {self._num_samples} samples as persistent tensors "
                    f"({len(self._cached_batches)} batches, {self._all_images.element_size() * self._all_images.numel() / 1e9:.2f} GB)")
        
    def clear_cache(self):
        """Clear cached batches and persistent tensors."""
        self._cached_batches = None
        if hasattr(self, '_all_images'):
            del self._all_images
        if hasattr(self, '_all_labels'):
            del self._all_labels
        if hasattr(self, '_frozen_params'):
            self._frozen_params = None
        if hasattr(self, '_output_buffer'):
            self._output_buffer = None
        torch.cuda.empty_cache()
        logger.debug("Cleared GGN batch cache and persistent tensors")
        
    def _get_batches(self):
        """Get batches from cache. Raises if not cached."""
        if self._cached_batches is None:
            raise RuntimeError(
                "Call cache_batches() or prepare_for_solve() before PCG. "
                "Lazy batch loading in matvec is disabled for performance."
            )
        return self._cached_batches
    
    def prepare_for_solve(self):
        """
        Prepare operator for CG/PCG solve.

        Must be called once before starting a solve. This:
        1. Caches batches on the device (if not already cached)
        2. Freezes parameters: each parameter is ``detach()``-ed without a
           clone; ``requires_grad`` is NOT re-enabled on the detached tensors
        3. Computes local and global sample counts (all-reduced under DDP)
        4. Preallocates two float32 output buffers (ping-pong for DDP safety)
        5. If ``self.use_lm_damping`` is set, estimates the diagonal of the
           undamped operator with ``estimate_diag_hutchinson`` and stores it
           in ``self._diag_ggn``

        Doing this once avoids rebuilding the parameter dict on every matvec.
        """
        # Cache batches if not already done
        if self._cached_batches is None:
            self.cache_batches()
        
        # Freeze params ONCE - no clone, just detach (grads are not re-enabled)
        self._frozen_params = {
            n: p.detach()          # <-- no requires_grad_(True)
            for n, p in self.model.model.named_parameters()
        }
                
        # =======================================================================
        # DDP: Compute local and global sample counts
        # =======================================================================
        self._local_num_samples = self._num_samples
        
        if is_ddp():
            # All-reduce to get global sample count
            # (collective call: every rank must reach this point)
            local_count_tensor = torch.tensor(
                self._local_num_samples, 
                device=self.device, 
                dtype=torch.long
            )
            dist.all_reduce(local_count_tensor, op=dist.ReduceOp.SUM)
            self._global_num_samples = local_count_tensor.item()
            
            logger.info(f"DDP GGN operator: local_samples={self._local_num_samples}, "
                       f"global_samples={self._global_num_samples}")
        else:
            self._global_num_samples = self._local_num_samples
        
        # =======================================================================
        # Preallocate ping-pong output buffers for DDP safety
        # Using two buffers prevents aliasing bugs in all-reduce
        # =======================================================================
        self._output_buffer_0 = torch.empty(
            self.num_params, device=self.device, dtype=torch.float32
        )
        self._output_buffer_1 = torch.empty(
            self.num_params, device=self.device, dtype=torch.float32
        )
        self._buffer_toggle = 0  # Track which buffer to use
        
        logger.debug("GGNOperator prepared for solve (ping-pong buffers allocated)")
        
        # Optional diagonal estimate. It must come after the buffers and
        # sample counts above, because it issues matvecs that rely on them.
        # Probe settings are read from optional attributes with defaults.
        if self.use_lm_damping:
            self._diag_ggn = self.estimate_diag_hutchinson(
                num_probes=getattr(self, "lm_num_probes", 16),
                seed=getattr(self, "lm_seed", 0),
                eps=getattr(self, "lm_diag_eps", 1e-6),
                normalize_mean=True,
            )

        logger.debug("GGNOperator prepared for solve")

    @torch.no_grad()
    def estimate_diag_hutchinson(self, num_probes: int = 16, seed: int = 0,
                                eps: float = 1e-6, normalize_mean: bool = True) -> torch.Tensor:
        diag_accum = torch.zeros(self.num_params, device=self.device, dtype=torch.float32)

        # deterministic probes across ranks
        gen = torch.Generator(device="cpu")
        gen.manual_seed(seed)

        for k in range(num_probes):
            r_cpu = torch.empty(self.num_params, dtype=torch.float32, device="cpu")
            r_cpu.bernoulli_(0.5, generator=gen)
            r_cpu.mul_(2.0).sub_(1.0)  # {-1,+1}
            r = r_cpu.to(self.device, non_blocking=True)

            Ar = self.matvec_ggn_only(r)  # IMPORTANT: no damping inside
            diag_accum.add_(r * Ar)

        diag = diag_accum / float(num_probes)
        diag = diag.clamp(min=eps)

        if normalize_mean:
            diag = diag / (diag.mean() + 1e-12)

        return diag

    def _param_dict_to_vector_inplace(self, params: Dict[str, torch.Tensor], out: torch.Tensor) -> torch.Tensor:
        """Convert param dict to flat vector using preallocated buffer (no torch.cat)."""
        off = 0
        for name, numel in zip(self._param_names, self._param_numels):
            out[off:off + numel].copy_(params[name].reshape(-1))
            off += numel
        return out
    
    def _compute_gnhvp_batch(
        self,
        params_dict: Dict[str, torch.Tensor],
        v_dict: Dict[str, torch.Tensor],
        images: torch.Tensor,
        labels: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute GN-HVP for a batch using efficient vectorized operations.
        
        Uses the decomposition:
            GNH v = J^T H J v
        
        For cross-entropy, H = diag(p) - pp^T
        
        Note: Mixed precision with torch.func.vjp can cause dtype mismatches
        in autograd, so we run the full JVP-VJP chain in fp32 for stability.
        TF32 (enabled at module level) still provides speedup.
        
        Note: We disable hooks during torch.func transforms because timm (and
        other libs) register backward hooks that use BackwardHookFunction,
        which is incompatible with functorch transforms.
        """
        batch_size = images.shape[0]
        
        # Define network function with functional_call
        def net_fn(p_dict):
            return functional_call(self.model.model, p_dict, (images,))
        
        # Note: autocast + torch.func.vjp causes dtype mismatches in backward pass
        # ("expected input and grad types to match"). TF32 is still enabled for
        # matmuls, which provides most of the speedup without the instability.
        
        # Disable hooks during torch.func transforms (timm registers backward hooks
        # that are incompatible with functorch)
        with disable_hooks(self.model.model):
            # Step 1: Compute J @ v via JVP (forward-mode AD)
            logits, jv = jvp(net_fn, (params_dict,), (v_dict,))
            # jv has shape (batch_size, num_classes)
            
            # Step 2: Compute H @ (J @ v) where H is Fisher for cross-entropy
            # H = diag(p) - p p^T, so H @ z = p * z - p * (p · z)
            probs = F.softmax(logits.detach(), dim=-1)  # (batch_size, num_classes)
            
            # H @ jv = diag(p) @ jv - p @ (p^T @ jv)
            #        = p * jv - p * sum(p * jv, dim=-1, keepdim=True)
            p_dot_jv = (probs * jv).sum(dim=-1, keepdim=True)  # (batch_size, 1)
            hjv = probs * jv - probs * p_dot_jv  # (batch_size, num_classes)
            
            # Step 3: Compute J^T @ (H @ J @ v) via VJP (backward-mode AD)
            _, vjp_fn = vjp(net_fn, params_dict)
            gnhvp_dict = vjp_fn(hjv)[0]
        
        return gnhvp_dict
    
    def _ggn_matvec(self, v: torch.Tensor, add_damping: bool) -> torch.Tensor:
        """
        Compute the GGN-vector product Ĝv, optionally with the diagonal terms.

        DDP behavior:
        - Each rank computes LOCAL sum over its shard (not averaged)
        - All-reduce (SUM) to get global sum
        - Divide by global_num_samples
        - Add damping locally

        Without DDP the local sum is divided by the number of samples seen
        in the cached batches.

        Args:
            v: Parameter vector (num_params,)
            add_damping: If True, add ``weight_decay * v`` (only when
                ``weight_decay > 0``) and ``damping * v`` to the result.
                If False, return the bare averaged product.

        Returns:
            Detached flat vector. When ``prepare_for_solve()`` allocated the
            ping-pong buffers, the returned tensor aliases one of them and
            is overwritten two calls later; clone it to keep the values.

        Raises:
            RuntimeError: If ``prepare_for_solve()`` has not been called
                (detected via the missing ``_global_num_samples``).
        """
        # Track timing for performance analysis
        # (wall-clock time; no device synchronisation happens here)
        if is_ddp():
            compute_start = time.time()
        
        self._matvec_count += 1
        v = v.to(self.device)
        
        # Verify prepare_for_solve was called (for batch caching, DDP setup, etc.)
        if not hasattr(self, '_global_num_samples'):
            raise RuntimeError(
                "matvec called without prepare_for_solve(). "
                "Call prepare_for_solve() before PCG."
            )
        
        # Use frozen params cached by prepare_for_solve() to avoid rebuilding every matvec
        # This is critical for performance and determinism
        params_dict = self._frozen_params
        # Convert to param dict format
        v_dict = self._vector_to_param_dict(v)
        
        # Accumulate GGN-vP over LOCAL batches (do NOT divide yet)
        result_dict = _tree_zeros_like(params_dict)
        local_samples = 0
        
        batches = self._get_batches()
        
        for images, labels in batches:
            batch_size = images.shape[0]
            local_samples += batch_size
            
            batch_gnhvp = self._compute_gnhvp_batch(params_dict, v_dict, images, labels)
            
            # In-place accumulation of the per-batch sums.
            with torch.no_grad():
                for key in result_dict:
                    result_dict[key].add_(batch_gnhvp[key].detach())
        
        # =======================================================================
        # DDP: Pack into flat vector, all-reduce, then apply normalization
        # =======================================================================
        
        # Select ping-pong buffer to avoid aliasing issues
        # (falls back to a fresh allocation when no buffers were preallocated)
        if hasattr(self, '_buffer_toggle'):
            if self._buffer_toggle == 0:
                output_buffer = self._output_buffer_0
                self._buffer_toggle = 1
            else:
                output_buffer = self._output_buffer_1
                self._buffer_toggle = 0
        else:
            output_buffer = torch.empty(self.num_params, device=self.device, dtype=torch.float32)
        
        # Pack result into flat vector
        result = self._param_dict_to_vector_inplace(result_dict, output_buffer)
        
        if is_ddp():
            compute_time = time.time() - compute_start
            allreduce_start = time.time()
            
            # All-reduce the accumulated sums across all ranks
            dist.all_reduce(result, op=dist.ReduceOp.SUM)
            
            allreduce_time = time.time() - allreduce_start
            
            # Divide by global sample count
            result.div_(self._global_num_samples)
            
            # Log timing periodically
            if self._matvec_count % 10 == 1:
                log_dict(logger, f"DDP GGN matvec #{self._matvec_count}", {
                    'compute_ms': compute_time * 1000,
                    'allreduce_ms': allreduce_time * 1000,
                    'input_norm': v.norm().item(),
                    'output_norm': result.norm().item(),
                    'local_samples': local_samples,
                    'global_samples': self._global_num_samples,
                }, level='debug')
        else:
            # Single GPU: just divide by local sample count
            result.div_(local_samples)
        
        # Add weight decay and damping (same on all ranks)
        if add_damping:
            with torch.no_grad():
                if self.weight_decay > 0:
                    result.add_(v, alpha=self.weight_decay)
                result.add_(v, alpha=self.damping)
        
        # Periodic debug log for the non-DDP path (the DDP path logs above).
        if self._matvec_count % 10 == 1 and not is_ddp():
            log_dict(logger, f"GGN matvec #{self._matvec_count}", {
                'input_norm': v.norm().item(),
                'output_norm': result.norm().item(),
                'total_samples': local_samples,
            }, level='debug')
        
        return result.detach()
    
    def matvec(self, v: torch.Tensor) -> torch.Tensor:
        return self._ggn_matvec(v, add_damping=True).clone()

    def matvec_ggn_only(self, v: torch.Tensor) -> torch.Tensor:
        return self._ggn_matvec(v, add_damping=False).clone()
    
    def __call__(self, v: torch.Tensor) -> torch.Tensor:
        """Alias for matvec."""
        return self.matvec(v).clone()
    
    @property
    def shape(self) -> Tuple[int, int]:
        """Shape of the operator (n, n)."""
        return (self.num_params, self.num_params)



# =============================================================================
# Utility functions for diagonal GGN and damping estimation
# =============================================================================

def compute_diagonal_ggn(
    model,  # ViTWithHooks
    dataloader: DataLoader,
    device: str = "cuda",
    ggn_op: Optional[object] = None,
) -> torch.Tensor:
    """
    Jacobi preconditioner for Gauss-Newton / GGN via output-space sampling.

    NOTE: despite the name, this returns the INVERSE diagonal.

    Args:
        model: Wrapper exposing ``.model`` (the nn.Module) and, optionally,
            ``.num_params``.
        dataloader: Yields ``(images, labels)`` batches; labels are ignored.
        device: Device used for the computation.
        ggn_op: Optional object from which ``damping``, ``weight_decay``,
            ``diag_num_probes``, ``diag_seed``, ``diag_eps`` and
            ``diag_clip_ratio`` are read when present. When None, damping
            and weight decay are both 0.

    Returns:
        Minv: flat (num_params,) float32 tensor where Minv ~= 1 / diag(A),
              with A = (GGN + weight_decay*I + damping*I). diag(A) is
              clipped to [median / clip_ratio, median * clip_ratio] around
              an approximate median before inversion.

    Raises:
        RuntimeError: If the dataloader yields zero samples.

    References:
      - Chapelle & Erhan (2011): unbiased GN diagonal estimation via modified backprop.
      - For softmax CE, sample u = onehot(y~p) - p to get Cov(u)=diag(p)-pp^T.
    """
    logger.info("Computing diagonal GGN (Jacobi) via output-space sampling...")

    model.model.eval()

    # Use the same parameter ordering as the flat-vector layout.
    named_params = list(model.model.named_parameters())
    num_params = int(getattr(model, "num_params", sum(p.numel() for _, p in named_params)))

    # Pull operator hyperparams to match CG.
    damping = float(getattr(ggn_op, "damping", 0.0)) if ggn_op is not None else 0.0
    weight_decay = float(getattr(ggn_op, "weight_decay", 0.0)) if ggn_op is not None else 0.0

    # Probes: the requested count is limited to the range [1, 8].
    num_probes = int(getattr(ggn_op, "diag_num_probes", 2)) if ggn_op is not None else 2
    num_probes = max(1, min(8, num_probes))

    seed = int(getattr(ggn_op, "diag_seed", 0)) if ggn_op is not None else 0
    gen = torch.Generator(device=device)
    gen.manual_seed(seed)

    # Accumulate diag in float32 (one float per parameter).
    diag_acc = torch.zeros(num_params, device=device, dtype=torch.float32)
    total_samples = 0

    # Build a detached parameter dict ONCE (no cloning). Works even if original params have requires_grad=False.
    # These are leaf tensors with grad enabled, used only through functional_call.
    params_dict = {}
    for n, p in named_params:
        t = p.detach().to(device)
        t.requires_grad_(True)
        params_dict[n] = t

    # Precompute offsets once.
    offsets = []
    off = 0
    for n, p in named_params:
        k = p.numel()
        offsets.append((n, off, k))
        off += k
    assert off == num_params

    # Helper: sample approximate median from a subset (avoids torch.quantile on the full vector).
    def approx_median(x: torch.Tensor, sample_n: int = 200_000) -> torch.Tensor:
        sample_n = min(sample_n, x.numel())
        idx = torch.randint(0, x.numel(), (sample_n,), generator=gen, device=x.device)
        return x[idx].median()

    for images, _labels in tqdm(dataloader, desc="Computing diag(GGN)"):
        images = images.to(device, non_blocking=True)
        B = int(images.shape[0])
        total_samples += B

        # Forward once; reuse graph for multiple probes on same batch.
        logits = functional_call(model.model, params_dict, (images,))
        # Probabilities are only used to draw the probe vectors, so they are
        # computed from detached logits. Otherwise the in-place `u.sub_(probs)`
        # below would attach every probe to the autograd graph.
        probs = F.softmax(logits.detach(), dim=-1)

        # Multiple output-space probes; retain graph except last.
        for t in range(num_probes):
            # Sample y ~ Categorical(probs) for each sample
            y = torch.multinomial(probs, num_samples=1, generator=gen).squeeze(1)

            # u = onehot(y) - p   (mean-zero, Cov(u)=diag(p)-pp^T)
            u = torch.zeros_like(probs)
            u.scatter_(1, y[:, None], 1.0)
            u.sub_(probs)

            grads = torch.autograd.grad(
                outputs=logits,
                inputs=[params_dict[n] for (n, _p) in named_params],
                grad_outputs=u,
                retain_graph=(t != num_probes - 1),
                create_graph=False,
                allow_unused=False,
            )

            # Accumulate g^2; g is the SUM over samples => unbiased for SUM diag; divide by total_samples later.
            for (g, (_n, off, k)) in zip(grads, offsets):
                diag_acc[off:off + k].add_(g.detach().reshape(-1).float().square())

        # Explicitly free big tensors per batch
        del logits, probs, grads, u, y

    if total_samples == 0:
        raise RuntimeError("dataloader yielded zero samples; cannot build preconditioner.")

    # Mean per-sample diagonal of (J^T H J)
    diag_gn = diag_acc / float(total_samples * num_probes)

    # Match CG operator A = GGN + wd*I + damping*I
    diag_A = diag_gn + (weight_decay + damping)

    # Robust clipping around the (approx) median to avoid trashing CG with extreme scaling.
    eps = float(getattr(ggn_op, "diag_eps", 1e-6)) if ggn_op is not None else 1e-6
    clip_ratio = float(getattr(ggn_op, "diag_clip_ratio", 1e3)) if ggn_op is not None else 1e3

    diag_A = diag_A.clamp_min(eps)
    med = approx_median(diag_A, sample_n=200_000)
    lo = (med / clip_ratio).clamp_min(eps)
    hi = (med * clip_ratio).clamp_min(lo * 1.01)

    diag_A = diag_A.clamp(lo, hi)

    Minv = (1.0 / diag_A).to(torch.float32)

    log_dict(logger, "Diagonal A statistics (Jacobi)", {
        "total_samples": int(total_samples),
        "num_probes": int(num_probes),
        "damping": float(damping),
        "weight_decay": float(weight_decay),
        "diagA_min": float(diag_A.min().item()),
        "diagA_max": float(diag_A.max().item()),
        "diagA_mean": float(diag_A.mean().item()),
        "diagA_median_approx": float(med.item()),
        "clip_ratio": float(clip_ratio),
        "Minv_min": float(Minv.min().item()),
        "Minv_max": float(Minv.max().item()),
        "Minv_mean": float(Minv.mean().item()),
    })

    return Minv


@torch.no_grad()
def hutchinson_diag_from_matvec(
    ggn_op,                         # object with matvec(v) -> (G + lam I)v
    num_probes: int = 8,
    probe_batch: int = 1,           # keep 1 unless you implement block-matvec
    eps: float = 1e-8,
    clamp_min: float = 1e-8,
    clamp_max: Optional[float] = None,
    seed: int = 0,
    estimate: str = "G",            # "G" or "A"
) -> torch.Tensor:
    """
    Estimate diagonal of G (or A=G+lam I) using Hutchinson probes.

    Args:
        ggn_op: Object exposing ``device``, ``num_params`` and ``matvec(v)``.
            ``damping`` (default 0.0) and ``matvec_ggn_only(v)`` are used
            when present.
        num_probes: Number of Rademacher probes to average over.
        probe_batch: Unused; reserved for a batched matvec.
        eps: Unused; kept for interface compatibility.
        clamp_min: Lower bound applied to the final estimate.
        clamp_max: Optional upper bound applied to the final estimate.
        seed: Seed of the CPU generator the probes are drawn from.
        estimate: ``"G"`` (case-insensitive) for the undamped diagonal;
            anything else estimates the diagonal of whatever ``matvec``
            applies.

    Returns:
        diag: (p,) float32 tensor on ggn_op.device
    """
    device = ggn_op.device
    p = ggn_op.num_params
    lam = float(getattr(ggn_op, "damping", 0.0))

    want_undamped = estimate.upper() == "G"
    # Prefer an explicit undamped product when the operator offers one:
    # subtracting `lam * r` from matvec() only removes the damping term and
    # would leave any other diagonal shift (e.g. weight decay) in the estimate.
    undamped_matvec = getattr(ggn_op, "matvec_ggn_only", None) if want_undamped else None

    gen = torch.Generator(device="cpu")
    gen.manual_seed(seed)

    diag_acc = torch.zeros(p, device=device, dtype=torch.float32)

    # Note: probe_batch > 1 would require a batched matvec for efficiency.
    for _ in range(num_probes):
        # Rademacher ±1 probe, drawn on CPU and then moved: a CPU generator
        # cannot fill a tensor that lives on another device, and CPU sampling
        # keeps the probes identical regardless of the target device.
        r_cpu = torch.empty(p, dtype=torch.float32, device="cpu")
        r_cpu.bernoulli_(0.5, generator=gen)
        r_cpu.mul_(2.0).sub_(1.0)
        r = r_cpu.to(device)

        if undamped_matvec is not None:
            prod = undamped_matvec(r)
        else:
            prod = ggn_op.matvec(r)  # returns (G + lam I)r
            if want_undamped:
                prod = prod - lam * r

        diag_acc.add_(r * prod)

    diag = diag_acc / float(num_probes)

    # Stabilize
    diag = diag.abs()  # due to noise; for PSD operator diag should be >=0
    diag = diag.clamp_min(clamp_min)
    if clamp_max is not None:
        diag = diag.clamp_max(clamp_max)

    # return as float32
    return diag


def make_diag_preconditioner(diag: torch.Tensor, damping: float, eps: float = 1e-8) -> torch.Tensor:
    """
    Build the inverse diagonal preconditioner for A = G + damping I.

    Computes ``1 / (diag + damping + eps)`` elementwise; ``eps`` keeps the
    denominator away from zero.
    """
    denominator = diag + float(damping) + eps
    return 1.0 / denominator


def estimate_damping(
    model: ViTWithHooks,
    dataloader: DataLoader,
    scale: float = 0.01,
    device: str = "cuda",
) -> float:
    """
    Derive a damping value from the mean of the diagonal GGN.

    λ = scale * mean(compute_diagonal_ggn(model, dataloader, device))

    The scale, the mean and the resulting damping are logged under
    "Damping estimation".

    Args:
        model: ViT model forwarded to ``compute_diagonal_ggn``
        dataloader: DataLoader forwarded to ``compute_diagonal_ggn``
        scale: Multiplier applied to the diagonal mean (default 0.01)
        device: Computation device forwarded to ``compute_diagonal_ggn``

    Returns:
        damping: Estimated damping value
    """
    # Reduce once and reuse the scalar for both the result and the log entry.
    mean_diag = compute_diagonal_ggn(model, dataloader, device).mean().item()
    damping = scale * mean_diag

    summary = {
        'scale': scale,
        'diag_mean': float(mean_diag),
        'estimated_damping': damping,
    }
    log_dict(logger, "Damping estimation", summary)

    return damping


# =============================================================================
# Convenience functions for gradient computation
# =============================================================================

def compute_sample_gradient(
    model: ViTWithHooks,
    image: torch.Tensor,
    label: torch.Tensor,
    device: str = "cuda",
) -> torch.Tensor:
    """
    Compute gradient of the cross-entropy loss for a single sample.

    The gradient is taken w.r.t. every entry of
    ``model.model.named_parameters()`` and concatenated in that order.
    Parameters that do not take part in the forward pass contribute zeros,
    so the result always has one entry per parameter element.

    Args:
        model: ViT model
        image: Single image tensor (C, H, W) or (1, C, H, W)
        label: Single label tensor (scalar or (1,))
        device: Computation device

    Returns:
        grad_vec: Flattened, detached gradient vector
    """
    image = image.to(device)
    label = label.to(device)

    # Add the batch dimension when a bare sample / scalar label is given.
    if image.dim() == 3:
        image = image.unsqueeze(0)
    if label.dim() == 0:
        label = label.unsqueeze(0)

    # Force autograd on: without this the call fails whenever the caller is
    # inside a torch.no_grad() block, because no graph would be recorded.
    with torch.enable_grad():
        # Detached leaves: gradients never accumulate into the model's .grad.
        params_dict = {
            name: p.detach().requires_grad_(True)
            for name, p in model.model.named_parameters()
        }
        leaves = list(params_dict.values())

        logits = functional_call(model.model, params_dict, (image,))
        loss = F.cross_entropy(logits, label)

        # allow_unused: return None instead of raising for parameters that
        # the forward pass never touched.
        grads = torch.autograd.grad(loss, leaves, allow_unused=True)

    grad_vec = torch.cat([
        (torch.zeros_like(leaf) if g is None else g).flatten()
        for leaf, g in zip(leaves, grads)
    ])

    return grad_vec.detach()


def compute_batch_gradients(
    model: ViTWithHooks,
    images: torch.Tensor,
    labels: torch.Tensor,
    device: str = "cuda",
) -> torch.Tensor:
    """
    Compute per-sample gradients for a batch using vmap.

    Each sample's cross-entropy loss is differentiated w.r.t. every entry of
    ``model.model.named_parameters()``; the per-parameter gradients are
    flattened and concatenated in that order.

    Args:
        model: ViT model
        images: Batch of images (B, C, H, W)
        labels: Batch of labels (B,)
        device: Computation device

    Returns:
        grads: Matrix of shape (B, num_params)
    """
    images = images.to(device)
    labels = labels.to(device)

    params_dict = {
        name: p.detach()
        for name, p in model.model.named_parameters()
    }

    def single_loss(p_dict, img, lbl):
        # vmap strips the batch dimension; restore a batch of one for the model.
        img = img.unsqueeze(0)
        lbl = lbl.unsqueeze(0)
        logits = functional_call(model.model, p_dict, (img,))
        return F.cross_entropy(logits, lbl)

    # Parameters are shared (in_dims=None); images and labels map over dim 0.
    grad_fn = vmap(grad(single_loss), in_dims=(None, 0, 0))
    batch_grads = grad_fn(params_dict, images, labels)

    # Flatten to matrix: every entry of batch_grads has a leading batch
    # dimension, so one reshape per parameter plus a single concatenation
    # along dim 1 builds the (B, num_params) matrix without a Python loop
    # over samples.
    batch_size = images.shape[0]
    return torch.cat(
        [batch_grads[name].reshape(batch_size, -1) for name in params_dict],
        dim=1,
    )


if __name__ == "__main__":
    """
    Test suite for GGN operators.
    
    Tests verify:
    1. SPD property: v^T G v > 0 for all v != 0
    2. Symmetry: G v1 · v2 = v1 · G v2
    3. GGN correctness: Compare GN-HVP with finite difference approximation
    4. Consistency between operators
    """
    from .vit_full import load_vit
    from torch.backends.cuda import sdp_kernel
    
    def test_spd(operator, num_tests: int = 10, name: str = "GGN"):
        """
        Check that the quadratic form v^T G v is strictly positive.

        Draws ``num_tests`` random unit vectors on the "cuda" device, applies
        ``operator.matvec`` and prints one PASSED/FAILED line per trial.

        Returns:
            True when every trial produced v^T G v > 0, False otherwise.
        """
        print(f"\n=== Testing SPD property for {name} ===")
        dim = operator.num_params
        failures = 0

        for trial in range(1, num_tests + 1):
            # Random unit direction
            direction = torch.randn(dim, device="cuda")
            direction = direction / direction.norm()

            # Quadratic form along that direction
            applied = operator.matvec(direction).clone()
            quad_form = torch.dot(direction, applied).item()

            if quad_form > 0:
                print(f"  Test {trial}: PASSED - v^T G v = {quad_form:.6e}")
            else:
                failures += 1
                print(f"  Test {trial}: FAILED - v^T G v = {quad_form:.6e} (should be > 0)")

        all_passed = failures == 0
        summary = (
            f"  ✓ All {num_tests} SPD tests passed for {name}"
            if all_passed
            else f"  ✗ Some SPD tests failed for {name}"
        )
        print(summary)

        return all_passed
    
    def test_symmetry(operator, num_tests: int = 10, name: str = "GGN", rtol: float = 1e-5, atol: float = 1e-7):
        """
        Check <G v1, v2> == <v1, G v2> on random unit vectors.

        A trial passes when
            |<Gv1,v2> - <v1,Gv2>| <= atol + rtol * (||Gv1|| ||v2|| + ||v1|| ||Gv2||)
        The inner products are accumulated in float64 to keep rounding noise
        out of the comparison. Vectors are created on the "cuda" device.

        Returns:
            True when every trial passed, False otherwise.
        """
        print(f"\n=== Testing Symmetry for {name} (robust) ===")
        dim = operator.num_params
        num_failed = 0

        for trial in range(1, num_tests + 1):
            u = torch.randn(dim, device="cuda")
            w = torch.randn(dim, device="cuda")
            u = u / u.norm()
            w = w / w.norm()

            Gu = operator.matvec(u).clone()
            Gw = operator.matvec(w).clone()

            # Bilinear forms in float64
            forward_form = torch.dot(Gu.double(), w.double())
            backward_form = torch.dot(u.double(), Gw.double())
            diff = (forward_form - backward_form).abs()

            scale = (Gu.norm().double() * w.norm().double() + u.norm().double() * Gw.norm().double())
            ok = bool(diff <= (atol + rtol * scale))

            if not ok:
                num_failed += 1
            status = "PASSED" if ok else "FAILED"
            print(f"  Test {trial}: {status} - diff={diff.item():.3e}, scale={scale.item():.3e}, ratio={(diff/scale).item():.3e}")

        all_passed = num_failed == 0
        print("  ✓ All passed" if all_passed else "  ✗ Some failed")
        return all_passed

    
    def test_ggn_vs_finite_diff(model, images, labels, damping: float = 1e-2, eps: float = 1e-4):
        """
        Sanity-check the GGN-vector product built as JVP -> output Hessian -> VJP.

        Despite the name, no finite-difference comparison is performed, and
        the `labels`, `damping` and `eps` arguments are not used. The function:

        1. draws a random direction v of norm 0.1 over all parameters of
           `model.model`,
        2. computes Jv with `jvp`, applies the cross-entropy output Hessian
           H = diag(p) - p p^T, and pulls the result back with `vjp` to
           obtain G v = J^T H J v (summed over the batch, undamped),
        3. prints diagnostic norms and compares v^T G v with the closed form
           sum_b [ sum_c p_c (Jv)_c^2 - (sum_c p_c (Jv)_c)^2 ].

        Args:
            model: Object whose `model` attribute is the network to call
            images: Input batch passed to the network
            labels: Unused
            damping: Unused
            eps: Unused

        Returns:
            passed: True iff v^T G v >= -1e-6. The printed relative error
            against the closed form is informational only and does not
            affect the result.
        """
        print("\n=== Testing GGN correctness via structure ===")
        
        # Work on detached copies so the model's own parameters are untouched.
        params_dict = {name: p.detach().clone() for name, p in model.model.named_parameters()}
        param_vec = torch.cat([p.flatten() for p in params_dict.values()])
        n_params = param_vec.numel()
        
        # Random direction
        v = torch.randn(n_params, device="cuda")
        v = v / v.norm() * 0.1  # Small magnitude
        
        # Convert v to dict (slices of v reshaped like each parameter, in dict order)
        v_dict = {}
        offset = 0
        for name, p in params_dict.items():
            numel = p.numel()
            v_dict[name] = v[offset:offset + numel].view(p.shape)
            offset += numel
        
        # Compute GN-HVP using JVP-VJP structure
        def net_fn(p_dict):
            return functional_call(model.model, p_dict, (images,))
        
        # Forward pass + JVP
        # NOTE(review): jvp/vjp differentiate w.r.t. their explicit inputs, so
        # this requires_grad_ flag may be unnecessary — confirm before removing.
        for p in params_dict.values():
            p.requires_grad_(True)
        
        logits, jv = jvp(net_fn, (params_dict,), (v_dict,))
        
        # Output space Hessian for cross-entropy: H = diag(p) - p p^T
        probs = F.softmax(logits.detach(), dim=-1)  # (batch, num_classes)
        
        # H @ Jv where H = diag(p) - pp^T
        p_dot_jv = (probs * jv).sum(dim=-1, keepdim=True)  # (batch, 1)
        Hjv = probs * jv - probs * p_dot_jv  # (batch, num_classes)
        
        # VJP: J^T @ (H @ J @ v)  (this runs a second forward pass through net_fn)
        _, vjp_fn = vjp(net_fn, params_dict)
        gnhvp_dict = vjp_fn(Hjv)[0]
        
        # Flatten in the same parameter order that was used to build v_dict.
        gnhvp = torch.cat([gnhvp_dict[name].flatten() for name in params_dict.keys()])
        
        # Basic sanity checks
        print(f"  Input vector norm: {v.norm().item():.6e}")
        print(f"  GN-HVP result norm: {gnhvp.norm().item():.6e}")
        print(f"  JVP (Jv) norm: {jv.norm().item():.6e}")
        print(f"  Hjv norm: {Hjv.norm().item():.6e}")
        
        # Verify GGN is PSD: v^T G v >= 0
        # For GGN: v^T J^T H J v = (Jv)^T H (Jv)
        # For cross-entropy Fisher, H is PSD, so this should be >= 0
        vTGv = torch.dot(v, gnhvp).item()
        
        # Alternative: (Jv)^T H (Jv) = sum_i p_i (jv_i)^2 - (sum_i p_i jv_i)^2
        # This is Var_p(Jv) which is always >= 0
        jv_sq = (jv ** 2)  # (batch, num_classes)
        expected_vTGv = (probs * jv_sq).sum() - (probs * jv).sum(dim=-1).pow(2).sum()
        expected_vTGv = expected_vTGv.item()
        
        print(f"\n  v^T G v = {vTGv:.6e}")
        print(f"  Expected (variance formula) = {expected_vTGv:.6e}")
        print(f"  Relative error: {abs(vTGv - expected_vTGv) / (abs(expected_vTGv) + 1e-10):.6e}")
        
        # Only the sign of v^T G v decides pass/fail.
        passed = vTGv >= -1e-6  # Allow small numerical error
        if passed:
            print("  ✓ GGN structure test PASSED (v^T G v >= 0)")
        else:
            print("  ✗ GGN structure test FAILED (v^T G v < 0)")
        
        return passed
    
    def test_operator_consistency(op1, op2, num_tests: int = 5, name1: str = "Op1", name2: str = "Op2"):
        """
        Test that two operators give similar results on random unit vectors.

        A trial passes when the cosine similarity of the two matvec results
        exceeds 0.9 and the symmetric relative norm difference
        |n1 - n2| / (n1 + n2) is below 0.1. Vectors are created on the
        "cuda" device with ``op1.num_params`` entries.

        Returns:
            True when every trial passed, False otherwise.
        """
        print(f"\n=== Testing consistency: {name1} vs {name2} ===")
        n_params = op1.num_params
        all_passed = True
        tol = 0.1  # 10% relative tolerance (different implementations may vary)

        for i in range(num_tests):
            v = torch.randn(n_params, device="cuda")
            v = v / v.norm()

            # Clone both results, as the sibling tests do: if matvec hands out
            # a reused buffer, the second call could overwrite the first
            # result and make the comparison trivially succeed.
            Gv1 = op1.matvec(v).clone()
            Gv2 = op2.matvec(v).clone()

            norm1 = Gv1.norm().item()
            norm2 = Gv2.norm().item()

            # Cosine similarity
            cos_sim = torch.dot(Gv1, Gv2) / (Gv1.norm() * Gv2.norm() + 1e-10)
            cos_sim = cos_sim.item()

            # Relative norm difference
            rel_norm_diff = abs(norm1 - norm2) / (norm1 + norm2 + 1e-10)

            passed = cos_sim > 0.9 and rel_norm_diff < tol

            if not passed:
                all_passed = False
                print(f"  Test {i+1}: FAILED - cos_sim={cos_sim:.4f}, norms=({norm1:.4e}, {norm2:.4e})")
            else:
                print(f"  Test {i+1}: PASSED - cos_sim={cos_sim:.4f}, norms=({norm1:.4e}, {norm2:.4e})")

        if all_passed:
            print(f"  ✓ All {num_tests} consistency tests passed")
        else:
            print("  ✗ Some consistency tests failed")

        return all_passed
    
    # Run tests
    with sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
        print("=" * 60)
        print("GGN Operator Test Suite")
        print("=" * 60)
        
        # Load model
        print("\nLoading ViT model...")
        model = load_vit(pretrained=True, device="cuda")
        print(f"Model has {model.num_params:,} parameters")
        
        # Create test data
        print("\nCreating test data...")
        dummy_images = torch.randn(8, 3, 224, 224).cuda()
        dummy_labels = torch.randint(0, 1000, (8,)).cuda()
        
        from torch.utils.data import TensorDataset, DataLoader
        dummy_dataset = TensorDataset(dummy_images, dummy_labels)
        dummy_loader = DataLoader(dummy_dataset, batch_size=4)
        
        # Test GGN structure
        test_ggn_vs_finite_diff(model, dummy_images[:2], dummy_labels[:2])
        
        # Create operators with damping (ensures SPD)
        damping = 0.1  # Large damping for numerical stability in tests
        
        print("\n" + "=" * 60)
        print("Testing GGNOperator (JVP-VJP based GN-HVP)")
        print("=" * 60)
        ggn_op = GGNOperator(model, dummy_loader, damping=damping, use_gnhvp=True)
        ggn_op.prepare_for_solve()
        test_spd(ggn_op, num_tests=5, name="GGNOperator")
        test_symmetry(ggn_op, num_tests=3, name="GGNOperator")
        torch.cuda.empty_cache()
