"""
Preconditioned Conjugate Gradient (PCG) Solver.

Implements PCG with optional warm-start from previous solutions.
Used to solve (Ĝ + λI)z = b for cluster mean gradients.

Integrates with GGNOperator from ggn_ops.py for efficient IHVP computation.
Supports float64 computation for improved numerical stability.

DDP-aware: Scalar all-reduces on pAp, rz, rz_new keep all ranks synchronized.

Performance optimizations:
- dot64(): fp64 accumulation without full-vector casts
- Reduced .item() calls (GPU sync) in inner loop
- Batched verbose printing
"""

import torch
import torch.distributed as dist
from torch.backends.cuda import sdp_kernel
from typing import Optional, Callable, Tuple, Union, TYPE_CHECKING
import time

from .logging_utils import get_logger, log_dict, is_ddp, DDPState, allreduce_scalar

if TYPE_CHECKING:
    from .ggn_ops import GGNOperator, FastGGNOperator

logger = get_logger(__name__)


# Type alias for the linear operators accepted by this module: either a plain
# callable mapping a vector v to A v, or a GGNOperator / FastGGNOperator object
# (written as string forward references because those classes are only imported
# under TYPE_CHECKING). _apply_operator() invokes callables directly and falls
# back to a ``.matvec`` method otherwise.
LinearOperator = Union[Callable[[torch.Tensor], torch.Tensor], "GGNOperator", "FastGGNOperator"]

def allreduce_sum_(t: torch.Tensor) -> torch.Tensor:
    """
    Sum ``t`` over all ranks, in place, when running under DDP.

    Outside DDP the tensor is handed back untouched. The (possibly reduced)
    tensor is also returned so calls can be chained.
    NOTE(review): assumes ``t`` lives on a device the process-group backend
    supports (the original comment says CUDA) — confirm for CPU callers.
    """
    if not is_ddp():
        return t
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t

def allreduce_max_flag_(flag: torch.Tensor) -> torch.Tensor:
    """
    Take the maximum of ``flag`` over all ranks, in place, under DDP.

    Acts as a logical OR across ranks: callers in this module pass a 0/1
    int32 tensor, so every rank sees 1 if any rank raised the flag. Outside
    DDP the tensor is returned unchanged.
    """
    if not is_ddp():
        return flag
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return flag

def global_dot64(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Dot product of ``a`` and ``b`` (both flattened) with float64 accumulation.

    Under DDP the per-rank partial result is SUM-all-reduced. Returns a 0-dim
    float64 tensor on the inputs' device.
    """
    local = dot64(a.flatten(), b.flatten())
    return allreduce_sum_(local.to(torch.float64))

def global_norm(a: torch.Tensor) -> torch.Tensor:
    """
    Euclidean norm of ``a`` built on global_dot64 (all-reduced under DDP).

    The squared norm is clamped at zero before the square root to guard
    against tiny negative round-off.
    """
    squared = global_dot64(a, a)
    return squared.clamp(min=0.0).sqrt()


def _compute_debug_info(
    A: "LinearOperator",
    x: torch.Tensor,
    b: torch.Tensor,
    b_norm: torch.Tensor,
    verbose: bool = True,
) -> dict:
    """
    Compute debug metrics for PCG solution quality.

    All inner products go through global_dot64, so under DDP every metric is
    the all-reduced value.

    Args:
        A: The operator that was solved against. Its ``damping`` attribute is
            used as λ; if absent, λ falls back to 5.0.
            NOTE(review): the 5.0 fallback is arbitrary — confirm it matches
            the operator's real damping for plain-callable operators.
        x: Candidate solution.
        b: Right-hand side.
        b_norm: ||b|| as a 0-dim tensor or a Python float.
        verbose: If True, print each metric via rank0_print.

    Returns:
        dict with keys ``rayleigh_ratio`` (xᵀAx / ||x||²), ``Hx_over_lamx``
        (||Ax − λx|| / ||λx||), ``cos_xb``, ``x_norm``, ``rhs_norm``, and the
        ridge-limit comparisons against x_ridge = b / λ: ``cos_x_xridge``,
        ``rel_err_ridge``, ``x_ridge_norm``. If λ is 0, b / λ is non-finite
        and the ridge metrics come out inf/nan.
    """
    from .logging_utils import rank0_print

    # Damping λ (see docstring for the fallback).
    lam = getattr(A, 'damping', 5.0)

    # Rayleigh ratio: R(x) = x^T A x / ||x||²
    Ax = _apply_operator(A, x)
    xTAx = global_dot64(x, Ax).item()
    x_norm_sq = global_dot64(x, x).item()
    x_norm = x_norm_sq ** 0.5
    rayleigh_ratio = xTAx / (x_norm_sq + 1e-10)

    if verbose:
        rank0_print(f"  Rayleigh ratio: R(x) = x^T A x / ||x||² = {rayleigh_ratio:.6e}")

    # ||Hx|| / ||λx|| where H = A - λI (GGN without damping).
    # λx is a full-size vector: build it once and reuse it for both the
    # subtraction and its norm, and release temporaries as soon as possible
    # to keep peak memory down for very large parameter vectors.
    lamx = lam * x
    Hx = Ax - lamx
    del Ax
    Hx_norm_sq = global_dot64(Hx, Hx).item()
    del Hx
    lamx_norm_sq = global_dot64(lamx, lamx).item()
    del lamx
    Hx_over_lamx = (Hx_norm_sq / (lamx_norm_sq + 1e-12)) ** 0.5

    if verbose:
        rank0_print(f"  ||Hx|| / ||λx|| = {Hx_over_lamx:.3f}")

    # Cosine similarity between solution x and RHS b
    b_norm_val = b_norm.item() if isinstance(b_norm, torch.Tensor) else b_norm
    b_norm_sq = b_norm_val ** 2
    xb_dot = global_dot64(x, b).item()
    cos_xb = xb_dot / ((x_norm_sq * b_norm_sq + 1e-12) ** 0.5)

    if verbose:
        rank0_print(f"  cos(x, b) = {cos_xb:.3f}")

    # Ridge-limit comparison: x_ridge = b / λ.
    # In the ridge limit (λ >> eigenvalues of H), (H + λI)^{-1} ≈ (1/λ)I.
    x_ridge = b / lam
    x_ridge_norm_sq = global_dot64(x_ridge, x_ridge).item()
    x_ridge_norm = x_ridge_norm_sq ** 0.5

    # Cosine similarity: cos(x, x_ridge)
    x_xridge_dot = global_dot64(x, x_ridge).item()
    cos_x_xridge = x_xridge_dot / ((x_norm_sq * x_ridge_norm_sq + 1e-12) ** 0.5)

    # Relative error: ||x - x_ridge|| / ||x||
    diff = x - x_ridge
    del x_ridge
    diff_norm_sq = global_dot64(diff, diff).item()
    del diff
    rel_err_ridge = (diff_norm_sq ** 0.5) / (x_norm + 1e-12)

    if verbose:
        rank0_print(f"  Ridge limit: cos(x, b/λ) = {cos_x_xridge:.3f}, ||x - b/λ|| / ||x|| = {rel_err_ridge:.3f}")

    return {
        'rayleigh_ratio': rayleigh_ratio,
        'Hx_over_lamx': Hx_over_lamx,
        'cos_xb': cos_xb,
        'x_norm': x_norm,
        'rhs_norm': b_norm_val,
        'cos_x_xridge': cos_x_xridge,
        'rel_err_ridge': rel_err_ridge,
        'x_ridge_norm': x_ridge_norm,
    }


def dot64(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Dot product of ``a`` and ``b`` reduced in float64.

    Unlike ``torch.dot(a.double(), b.double())``, this never builds float64
    copies of the inputs. Only the elementwise product is materialized, in the
    inputs' own dtype, and the reduction then runs with ``dtype=torch.float64``.
    Shapes only need to be broadcast-compatible; every element is summed.

    Returns:
        0-dim float64 tensor.
    """
    products = torch.mul(a, b)
    return torch.sum(products, dtype=torch.float64)


def _apply_operator(
    A: LinearOperator,
    x: torch.Tensor,
) -> torch.Tensor:
    """
    Compute ``A x`` for either flavour of linear operator.

    Callables take priority and are invoked as ``A(x)``; non-callable objects
    must expose ``matvec`` and are invoked as ``A.matvec(x)``.

    Raises:
        TypeError: if ``A`` is neither callable nor has a ``matvec`` attribute.
    """
    if not callable(A):
        if not hasattr(A, 'matvec'):
            raise TypeError(f"Unsupported operator type: {type(A)}")
        return A.matvec(x)
    return A(x)


def pcg_solve(
    A: LinearOperator,
    b: torch.Tensor,
    M_inv: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    x0: Optional[torch.Tensor] = None,
    max_iter: int = 50,
    tol: float = 1e-3,
    atol: float = 1e-8,
    device: str = "cuda",
    verbose: bool = True,
) -> Tuple[torch.Tensor, dict]:
    """
    Preconditioned Conjugate Gradient solver.

    Solves Ax = b using PCG with optional preconditioner M^{-1}.

    DDP behavior:
    - All ranks run PCG in lock-step.
    - Scalar dot products (pAp, rz, rz_new) and residual norms are all-reduced
      (global_dot64 / global_norm), and convergence/breakdown flags are
      MAX-reduced, so every rank takes the same branches and exits together.
    - matvec A(v) internally does vector all-reduce.

    Algorithm:
        x0 = warm_start or zeros
        r = b - A @ x0
        z = M_inv(r)
        p = z
        for k in range(max_iter):
            Ap = A @ p
            alpha = (r · z) / (p · Ap)
            x += alpha * p
            r_new = r - alpha * Ap          (every 10th iter: r_new = b - A @ x)
            if ||r_new|| < atol or ||r_new|| < tol * ||b||: break
            z_new = M_inv(r_new)
            beta = (r_new · z_new) / (r · z)
            p = z_new + beta * p

    If the initial r^T z is non-positive (preconditioner not SPD), the
    preconditioner is disabled for the entire solve, i.e. plain CG is used.
    The loop stops early (without converging) if p^T A p is non-positive or
    non-finite, or if r^T z breaks down.

    Args:
        A: Linear operator (callable or GGNOperator), computes Av for vector v
        b: Right-hand side vector
        M_inv: Optional preconditioner (callable), computes M^{-1}v
        x0: Optional initial guess (warm start)
        max_iter: Maximum number of iterations
        tol: Relative tolerance for convergence
        atol: Absolute tolerance for convergence
        device: Computation device
        verbose: Print convergence info

    Returns:
        x: On convergence the final iterate; otherwise the iterate with the
            smallest residual norm seen.
        info: Dict with ``converged``, ``iterations`` (number of x-updates
            actually performed, which may be < max_iter after a breakdown),
            ``residual`` (last residual norm, float), ``residuals`` (list of
            floats starting with the initial residual), ``time`` (seconds),
            ``best_residual`` when not converged, plus the metrics from
            _compute_debug_info once at least one iteration has run.
    """
    from .logging_utils import rank0_print

    ddp = is_ddp()
    # Let operators that support it (e.g. GGNOperator) prepare themselves.
    if hasattr(A, 'prepare_for_solve'):
        A.prepare_for_solve()
        if ddp:
            dist.barrier()

    b = b.to(device)
    n = b.shape[0]
    logger.debug(f"PCG solve: n={n:,}, max_iter={max_iter}, tol={tol}, ddp={ddp}")

    def A_wrapped(v):
        # NOTE(review): presumably cloned so the result never aliases an
        # operator-owned buffer — confirm.
        return _apply_operator(A, v).clone()

    # Read at call time by M_inv_wrapped; set to False to fall back to
    # identity preconditioning for the rest of the solve.
    precond_enabled = M_inv is not None

    def M_inv_wrapped(v):
        if not precond_enabled:
            return v.clone()
        return M_inv(v)

    # Initialize
    if x0 is not None:
        x = x0.to(device).clone()
        logger.debug(f"Using warm start with norm {x0.norm().item():.6e}")
        r = b - A_wrapped(x)
    else:
        x = torch.zeros_like(b)
        # With x0 = 0, A(x0) = 0, so r = b (skip the useless matvec).
        r = b.clone()

    b_norm = global_norm(b)
    b_norm_val = b_norm.item()
    r_norm = global_norm(r)
    r0 = r_norm.item()

    conv = allreduce_max_flag_(((r_norm < atol) | (r_norm < tol * b_norm)).to(torch.int32))
    if conv.item():
        if ddp:
            dist.barrier()
        logger.info(f"PCG converged immediately: residual={r0:.6e}")
        return x, {'converged': True, 'iterations': 0, 'residual': r0, 'residuals': [r0], 'time': 0.0}

    # Apply preconditioner
    logger.info(f"r norm before M_inv: {r0:.6e}")
    z = M_inv_wrapped(r)
    p = z.clone()
    # dot64-based: fp64 accumulation without materializing fp64 vectors
    rz = global_dot64(r, z)

    bad_rz = allreduce_max_flag_((rz <= 0).to(torch.int32))
    if bad_rz.item():
        logger.warning("PCG: non-positive r^T z; disabling preconditioner for this solve.")
        # Disable for every subsequent iteration too, not just this first
        # step; mixing preconditioners mid-solve breaks CG conjugacy.
        precond_enabled = False
        z = r.clone()
        p = z.clone()
        rz = global_dot64(r, z)

    # Tuning knobs
    recompute_every = 10          # true residual refresh period
    eps = 1e-30
    print_every = 1

    residuals = [r0]
    start_time = time.time()

    best_x = x.clone()
    best_r_norm = r_norm

    # Kept as tensors to avoid .item() syncs unless printing
    last_alpha = torch.tensor(0.0, device=device)
    last_beta = torch.tensor(0.0, device=device)

    iter_times = []
    iters_done = 0

    for k in range(max_iter):
        iter_start = time.time()

        Ap = A_wrapped(p)
        pAp = global_dot64(p, Ap)

        bad = allreduce_max_flag_(((~torch.isfinite(pAp)) | (pAp <= 0)).to(torch.int32))
        if bad.item():
            logger.warning(f"PCG: non-positive/invalid pAp={pAp.item():.3e} at iter {k}, stopping early")
            break

        # rz and pAp are fp64; cast the step size back to the vector dtype
        alpha = (rz / pAp).to(p.dtype)
        last_alpha = alpha

        x = x + alpha * p
        iters_done = k + 1

        # Residual update, periodically refreshed from the true residual to
        # limit drift of the recursive update.
        if recompute_every > 0 and (k + 1) % recompute_every == 0:
            r_new = b - A_wrapped(x)
        else:
            r_new = r - alpha * Ap

        r_norm = global_norm(r_new)
        residuals.append(r_norm.item())

        if r_norm < best_r_norm:
            best_r_norm = r_norm
            best_x = x.clone()

        conv = allreduce_max_flag_(((r_norm < atol) | (r_norm < tol * b_norm)).to(torch.int32))
        if conv.item():
            if ddp:
                dist.barrier()
            elapsed = time.time() - start_time
            final = residuals[-1]

            debug_info = _compute_debug_info(A, x, b, b_norm, verbose)

            log_dict(logger, "PCG converged", {
                "iterations": iters_done,
                "final_residual": final,
                "relative_residual": final / (b_norm_val + 1e-10),
                "time_seconds": elapsed,
                "residual_reduction": r0 / (final + 1e-10),
                "ddp": ddp,
                **debug_info,
            })
            return x, {"converged": True, "iterations": iters_done, "residual": final, "residuals": residuals, "time": elapsed, **debug_info}

        z_new = M_inv_wrapped(r_new)
        rz_new = global_dot64(r_new, z_new)

        # rz is the denominator of beta; rz_new must be finite.
        bad = allreduce_max_flag_(((~torch.isfinite(rz_new)) | (rz.abs() < 1e-15)).to(torch.int32))
        if bad.item():
            logger.warning(f"PCG: rz breakdown (rz={rz.item():.3e}, rz_new={rz_new.item():.3e}) at iter {k}")
            break

        beta = (rz_new / (rz + eps)).to(p.dtype)
        last_beta = beta

        p = z_new + beta * p

        r = r_new
        rz = rz_new  # stays an fp64 tensor

        iter_times.append(time.time() - iter_start)

        if verbose and (k + 1) % print_every == 0:
            r_val = residuals[-1]
            rank0_print(f"PCG iter {k+1}: residual={r_val:.6e}, rel_residual={(r_val/(b_norm_val + 1e-10)):.6e}, "
                f"alpha={last_alpha.item():.6e}, beta={last_beta.item():.6e}")

    elapsed = time.time() - start_time
    final = r_norm.item()
    if verbose:
        rank0_print(f"PCG did not converge in {iters_done} iterations (residual: {final:.6e})")

    # Debug metrics for the solution actually returned
    debug_info = _compute_debug_info(A, best_x, b, b_norm, verbose)

    log_dict(logger, "PCG did not converge", {
        'iterations': iters_done,
        'final_residual': final,
        'relative_residual': final / (b_norm_val + 1e-10),
        'time_seconds': elapsed,
        'residual_reduction': r0 / (final + 1e-10),
        'avg_iter_ms': 1000 * sum(iter_times) / len(iter_times) if iter_times else 0,
        'ddp': ddp,
        **debug_info,
    }, level='warning')

    return best_x, {
        'converged': False,
        'iterations': iters_done,
        'residual': final,
        'residuals': residuals,
        'best_residual': best_r_norm.item(),
        'time': elapsed,
        **debug_info,
    }


class PCGSolver:
    """
    PCG Solver class with warm-start management.

    Wraps pcg_solve() with a fixed operator, preconditioner and tolerances,
    and keeps per-cluster caches (stored on CPU) of previous solutions and
    their centroids. A later solve can be warm-started from the cached
    solution whose centroid is most cosine-similar to a query centroid.

    Operator preparation (``prepare_for_solve()``) is performed by
    pcg_solve() at the start of every solve when the operator defines it.
    """

    def __init__(
        self,
        A: "LinearOperator",
        M_inv: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        max_iter: int = 50,
        tol: float = 1e-3,
        atol: float = 1e-8,
        device: str = "cuda",
    ):
        """
        Args:
            A: Linear operator (callable or GGNOperator)
            M_inv: Optional preconditioner
            max_iter: Maximum CG iterations
            tol: Relative convergence tolerance
            atol: Absolute convergence tolerance
            device: Computation device
        """
        self.A = A
        self.M_inv = M_inv
        self.max_iter = max_iter
        self.tol = tol
        self.atol = atol
        self.device = device

        # Warm-start caches (CPU tensors):
        # cluster_id -> solution, cluster_id -> centroid
        self._solution_cache: dict = {}
        self._centroid_cache: dict = {}

        # Stats tracking
        self._solve_count = 0
        self._warm_start_used = 0
        self._total_iterations = 0
        self._total_time = 0.0

        # Kept for backward compatibility; always True because pcg_solve()
        # prepares the operator itself on every solve.
        self._operator_prepared = True

        logger.info(f"PCGSolver initialized: max_iter={max_iter}, tol={tol}")

    def prepare(self):
        """
        Kept for API compatibility; intentionally a no-op.

        pcg_solve() already calls ``prepare_for_solve()`` on the operator (when
        it defines one) at the start of every solve, so no separate
        preparation step is required.
        """
        return None

    @classmethod
    def from_ggn_operator(
        cls,
        ggn_op: "GGNOperator",
        M_inv: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        max_iter: int = 50,
        tol: float = 1e-3,
        atol: float = 1e-8,
    ) -> "PCGSolver":
        """
        Create PCGSolver from a GGNOperator, inheriting its ``device``.

        Args:
            ggn_op: GGNOperator instance from ggn_ops.py
            M_inv: Optional preconditioner
            max_iter: Maximum iterations
            tol: Relative tolerance
            atol: Absolute tolerance

        Returns:
            PCGSolver instance
        """
        return cls(
            A=ggn_op,
            M_inv=M_inv,
            max_iter=max_iter,
            tol=tol,
            atol=atol,
            device=ggn_op.device,
        )

    def solve(
        self,
        b: torch.Tensor,
        cluster_id: Optional[int] = None,
        centroid: Optional[torch.Tensor] = None,
        use_warm_start: bool = False,
        verbose: bool = True,
        x0: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, dict]:
        """
        Solve Ax = b with optional warm-start.

        Args:
            b: Right-hand side vector
            cluster_id: Optional cluster ID; when given, the solution (and the
                centroid, if provided) is cached under this ID
            centroid: Optional cluster centroid for finding nearest cached solution
            use_warm_start: Whether to use warm-starting from cached centroids
            verbose: Print info (rank 0 only under DDP)
            x0: Explicit initial guess (takes precedence over centroid warm-start)

        Returns:
            x: Solution vector
            info: Convergence information from pcg_solve()
        """
        from .logging_utils import rank0_print

        warm_start_source = None

        # x0 takes precedence if provided explicitly
        if x0 is not None:
            if verbose:
                rank0_print(f"  Using explicit x0 as initial guess (norm={x0.norm().item():.4f})")
            warm_start_source = "explicit_x0"
        elif use_warm_start and centroid is not None and self._centroid_cache:
            x0, nearest_id = self._find_nearest_solution(centroid)
            if verbose and x0 is not None:
                rank0_print(f"  Using warm-start from cluster {nearest_id}")
            warm_start_source = nearest_id

        # Forces the math SDPA backend during the solve.
        # NOTE(review): presumably because the operator differentiates through
        # attention and the fused kernels don't support that — confirm.
        with sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            x, info = pcg_solve(
                self.A, b, self.M_inv, x0,
                max_iter=self.max_iter,
                tol=self.tol,
                atol=self.atol,
                device=self.device,
                verbose=verbose,
            )

        # Update statistics
        self._solve_count += 1
        if warm_start_source is not None:
            self._warm_start_used += 1
        self._total_iterations += info['iterations']
        self._total_time += info.get('time', 0.0)

        # Cache solution
        if cluster_id is not None:
            self._solution_cache[cluster_id] = x.detach().cpu()
            if centroid is not None:
                self._centroid_cache[cluster_id] = centroid.detach().cpu()

        logger.debug(f"PCG solve #{self._solve_count}: cluster={cluster_id}, warm_start={warm_start_source}, "
                    f"iters={info['iterations']}, converged={info['converged']}")

        return x, info

    def _find_nearest_solution(
        self,
        centroid: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], Optional[int]]:
        """
        Find the cached solution whose centroid is most cosine-similar.

        Similarities are computed in float64 on CPU, so query and cached
        centroids may have different dtypes. Cached centroids with a different
        number of elements, or (near-)zero norm, are skipped.

        Args:
            centroid: Query centroid

        Returns:
            Tuple of (solution moved to ``self.device``, cluster_id), or
            (None, None) if the cache is empty, the query has (near-)zero norm,
            or no cached centroid is comparable.
        """
        if not self._centroid_cache:
            return None, None

        query = centroid.detach().cpu().flatten().to(torch.float64)
        query_norm = query.norm()
        if query_norm < 1e-8:
            return None, None

        best_sim = -1.0
        best_id = None

        for cid, cached in self._centroid_cache.items():
            cached = cached.flatten().to(torch.float64)
            if cached.numel() != query.numel():
                continue  # not comparable with the query
            cached_norm = cached.norm()
            if cached_norm < 1e-8:
                continue

            sim = (torch.dot(query, cached) / (query_norm * cached_norm)).item()
            if sim > best_sim:
                best_sim = sim
                best_id = cid

        if best_id is None:
            return None, None
        return self._solution_cache[best_id].to(self.device), best_id

    def clear_cache(self):
        """Clear solution and centroid caches."""
        self._solution_cache.clear()
        self._centroid_cache.clear()
        logger.debug("Cleared PCG solution cache")

    def get_cached_solution(self, cluster_id: int) -> Optional[torch.Tensor]:
        """Get cached (CPU) solution for a cluster, or None."""
        return self._solution_cache.get(cluster_id)

    @property
    def num_cached(self) -> int:
        """Number of cached solutions."""
        return len(self._solution_cache)

    def get_statistics(self) -> dict:
        """Get solver statistics (also logged at INFO level)."""
        stats = {
            'total_solves': self._solve_count,
            'warm_starts_used': self._warm_start_used,
            'warm_start_ratio': self._warm_start_used / max(1, self._solve_count),
            'total_iterations': self._total_iterations,
            'avg_iterations': self._total_iterations / max(1, self._solve_count),
            'total_time': self._total_time,
            'avg_time': self._total_time / max(1, self._solve_count),
            'cached_solutions': self.num_cached,
        }
        logger.info(f"PCG Statistics: {stats}")
        return stats


def cg_solve(
    A: LinearOperator,
    b: torch.Tensor,
    x0: Optional[torch.Tensor] = None,
    max_iter: int = 50,
    tol: float = 1e-3,
    atol: float = 1e-8,
    device: str = "cuda",
    verbose: bool = True,
) -> Tuple[torch.Tensor, dict]:
    """
    Plain (unpreconditioned) Conjugate Gradient.

    Thin wrapper that forwards every argument to pcg_solve() with
    ``M_inv=None``.

    Args:
        A: Linear operator (callable or GGNOperator)
        b: Right-hand side vector
        x0: Optional initial guess
        max_iter: Maximum iterations
        tol: Relative tolerance
        atol: Absolute tolerance
        device: Computation device
        verbose: Print info

    Returns:
        x: Solution vector
        info: Convergence information
    """
    return pcg_solve(
        A,
        b,
        M_inv=None,
        x0=x0,
        max_iter=max_iter,
        tol=tol,
        atol=atol,
        device=device,
        verbose=verbose,
    )


def estimate_condition_number(
    A: LinearOperator,
    n: int,
    k: int = 20,
    device: str = "cuda",
) -> Tuple[float, float]:
    """
    Estimate the extreme eigenvalues and condition number of a symmetric operator.

    λ_max comes from ``k`` steps of power iteration (then raised to the
    largest Lanczos Ritz value if that is bigger). λ_min is the smallest Ritz
    value of a Lanczos tridiagonalization with ``min(k, 30)`` steps, floored
    at 1e-10. κ(A) = λ_max / λ_min.

    The operator is applied to float32 copies of the iterates, while the Krylov
    bookkeeping stays in float64. The dot products here are local (not
    all-reduced across DDP ranks).

    Args:
        A: Linear operator (callable or GGNOperator)
        n: Dimension
        k: Number of power iterations (also caps the Lanczos steps at min(k, 30))
        device: Computation device

    Returns:
        Tuple of (lambda_max, estimated_condition_number)
    """
    logger.info(f"Estimating condition number with {k} power iterations...")

    def matvec(vec: torch.Tensor) -> torch.Tensor:
        return _apply_operator(A, vec.float()).double()

    # --- Power iteration for λ_max ---
    v = torch.randn(n, device=device, dtype=torch.float64)
    v = v / v.norm()
    for _ in range(k):
        Av = matvec(v)
        Av_norm = Av.norm()
        if not Av_norm > 1e-10:
            break
        v = Av / Av_norm

    lambda_max = v.dot(matvec(v)).item()

    # --- Lanczos tridiagonalization for the eigenvalue range ---
    alpha_list = []
    beta_list = []

    q = torch.randn(n, device=device, dtype=torch.float64)
    q = q / q.norm()
    q_prev = torch.zeros_like(q)
    beta = 0.0

    for _ in range(min(k, 30)):
        # Three-term recurrence w = A q_j − β_{j−1} q_{j−1} − α_j q_j, with
        # each projection removed exactly once (β = 0 and q_prev = 0 on the
        # first step).
        w = matvec(q) - beta * q_prev
        alpha = q.dot(w).item()
        alpha_list.append(alpha)
        w = w - alpha * q

        beta = w.norm().item()
        if beta < 1e-10:
            break  # Krylov subspace became invariant
        beta_list.append(beta)

        q_prev, q = q, w / beta

    # Fallback when no usable Ritz values are available
    lambda_min = 1e-10
    condition_number = lambda_max / 1e-10

    m = len(alpha_list)
    if m >= 1:  # a 1x1 tridiagonal is still a valid Ritz estimate
        T = torch.zeros(m, m, device=device, dtype=torch.float64)
        for i, a in enumerate(alpha_list):
            T[i, i] = a
        # beta_list has m entries if the loop ran to completion, m - 1 if it
        # broke early; only the first m - 1 belong in T.
        for i, bt in enumerate(beta_list[: m - 1]):
            T[i, i + 1] = bt
            T[i + 1, i] = bt

        try:
            eigs = torch.linalg.eigvalsh(T)
        except RuntimeError as err:  # includes torch.linalg.LinAlgError
            logger.warning(f"Lanczos eigvalsh failed ({err}); using lambda_min=1e-10")
        else:
            lambda_min = max(eigs[0].item(), 1e-10)
            lambda_max = max(lambda_max, eigs[-1].item())
            condition_number = lambda_max / lambda_min

    log_dict(logger, "Condition number estimation", {
        'largest_eigenvalue': lambda_max,
        'smallest_eigenvalue': lambda_min,
        'condition_number': condition_number,
        'method': 'power_iteration+lanczos',
        'iterations': k,
    })

    return lambda_max, condition_number


def create_diagonal_preconditioner(
    diag_ggn: torch.Tensor,
    damping: float = 0.0,
    min_value: float = 1e-8,
) -> Callable[[torch.Tensor], torch.Tensor]:
    """
    Create a diagonal (Jacobi) preconditioner from a diagonal GGN estimate.

    M^{-1} x = x / max(diag + damping, min_value)

    The returned callable yields a tensor with the same dtype and device as
    its input. It converts the inverse diagonal to each (device, dtype) it
    sees once and reuses that copy, so repeated calls inside PCG do not
    re-transfer the full diagonal.

    Args:
        diag_ggn: Diagonal of GGN matrix
        damping: Damping value to add
        min_value: Minimum diagonal value for numerical stability

    Returns:
        Preconditioner function
    """
    diag_inv = 1.0 / torch.clamp(diag_ggn + damping, min=min_value)
    # (device, dtype) -> diag_inv converted to that device/dtype
    converted: dict = {}

    def preconditioner(x: torch.Tensor) -> torch.Tensor:
        key = (x.device, x.dtype)
        scale = converted.get(key)
        if scale is None:
            scale = diag_inv.to(device=x.device, dtype=x.dtype)
            converted[key] = scale
        return x * scale

    logger.info(f"Created diagonal preconditioner: diag range [{diag_ggn.min():.2e}, {diag_ggn.max():.2e}]")
    return preconditioner


def solve_ihvp(
    ggn_op: "GGNOperator",
    v: torch.Tensor,
    max_iter: int = 100,
    tol: float = 1e-6,
    preconditioner: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    verbose: bool = True,
    x0: Optional[torch.Tensor] = None,
    atol: float = 1e-10,
) -> Tuple[torch.Tensor, dict]:
    """
    Solve inverse Hessian-vector product: H^{-1} v.
    
    Convenience function that wraps the PCG solver for IHVP computation
    using GGNOperator from ggn_ops.py. The operator's own `device` attribute
    is forwarded to the solver.
    
    Args:
        ggn_op: GGNOperator instance
        v: Vector to multiply with H^{-1}
        max_iter: Maximum CG iterations
        tol: Relative convergence tolerance
        preconditioner: Optional preconditioner (applies M^{-1})
        verbose: Print convergence info
        x0: Optional initial guess for warm-starting PCG (e.g. the solution
            of a previous, similar solve). None keeps the solver's default
            start.
        atol: Absolute convergence tolerance forwarded to pcg_solve
            (default 1e-10, the previously hard-coded value)
        
    Returns:
        ihvp: H^{-1} v
        info: Convergence information returned by pcg_solve
    """
    return pcg_solve(
        A=ggn_op,
        b=v,
        M_inv=preconditioner,
        x0=x0,
        max_iter=max_iter,
        tol=tol,
        atol=atol,
        device=ggn_op.device,
        verbose=verbose,
    )


if __name__ == "__main__":
    """
    PCG Test Suite with ViT GGN Operator.
    
    First runs synthetic sanity checks of cg_solve / pcg_solve on a small
    SPD system, then the ViT GGN tests:
    1. Determinism: ||A(v) - A(v)|| / ||A(v)|| should be ~0
    2. Rayleigh quotient: v^T A v / v^T v (positive for SPD)
    3. Symmetry: <Av1, v2> = <v1, Av2>
    4. PCG convergence vs damping sweep
    5. IHVP accuracy verification
    """
    import numpy as np
    from torch.utils.data import TensorDataset, DataLoader
    from .vit_full import load_vit
    from .ggn_ops import GGNOperator

    # Set seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
        # Initialise CUDA only when it exists: torch.cuda.init() raises on
        # CPU-only machines, which would abort the CPU-capable synthetic
        # checks below before they run.
        torch.cuda.init()

    # ------------------------------------------------------------------
    # Synthetic sanity checks on a small SPD system (CPU or GPU)
    # ------------------------------------------------------------------
    print("Testing PCG solver...")
    
    # Create test problem: A = diag + low-rank + 0.01 * I
    n = 1000
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    # SPD matrix: positive diagonal (>= 0.1) plus a rank-10 PSD term V V^T
    D = torch.abs(torch.randn(n, device=device)) + 0.1
    V = torch.randn(n, 10, device=device)
    V = V / V.norm(dim=0)  # unit-norm columns
    
    def A_op(x):
        return D * x + V @ (V.T @ x) + 0.01 * x
    
    # Preconditioner (diagonal): uses D + 0.01 and ignores the diagonal
    # contribution of the low-rank term, so it is only approximate.
    def M_inv(x):
        return x / (D + 0.01)
    
    # True solution
    x_true = torch.randn(n, device=device)
    b = A_op(x_true)
    
    # Solve without preconditioner
    print("\nWithout preconditioner:")
    x_cg, info_cg = cg_solve(A_op, b, max_iter=100, tol=1e-6, device=device)
    print(f"  Error: {(x_cg - x_true).norm().item():.6e}")
    print(f"  Iterations: {info_cg['iterations']}")
    
    # Solve with preconditioner
    print("\nWith diagonal preconditioner:")
    x_pcg, info_pcg = pcg_solve(A_op, b, M_inv, max_iter=100, tol=1e-6, device=device)
    print(f"  Error: {(x_pcg - x_true).norm().item():.6e}")
    print(f"  Iterations: {info_pcg['iterations']}")
    
    # Test warm-start: a slightly perturbed RHS should need fewer iterations
    # when started from the previous solution.
    print("\nTesting warm-start...")
    b2 = b + 0.1 * torch.randn_like(b)  # Slightly perturbed RHS
    
    print("Cold start:")
    _, info1 = pcg_solve(A_op, b2, M_inv, max_iter=100, tol=1e-6, device=device)
    
    print("Warm start (from previous solution):")
    _, info2 = pcg_solve(A_op, b2, M_inv, x0=x_pcg, max_iter=100, tol=1e-6, device=device)
    
    print(f"\nIterations: cold={info1['iterations']}, warm={info2['iterations']}")
    
    # TODO: also exercise estimate_condition_number(A_op, n, k=30, device=device)
    # here and print λ_max and κ(A).
    
    # Test diagonal preconditioner creation (diag_estimate matches M_inv above)
    print("\nTesting create_diagonal_preconditioner...")
    diag_estimate = D + 0.01
    M_inv_created = create_diagonal_preconditioner(diag_estimate, damping=0.0)
    x_pcg2, info_pcg2 = pcg_solve(A_op, b, M_inv_created, max_iter=100, tol=1e-6, device=device)
    print(f"  Error: {(x_pcg2 - x_true).norm().item():.6e}")
    print(f"  Iterations: {info_pcg2['iterations']}")

    def test_determinism(ggn_op: "GGNOperator", num_tests: int = 5) -> dict:
        """
        Check that the operator is deterministic.

        Each trial applies ggn_op.matvec twice to the same unit-norm random
        vector and measures ||A(v) - A(v)|| / ||A(v)||. For a deterministic
        operator this is exactly 0. A trial passes when it is below 1e-6.

        Returns:
            dict with 'passed', 'mean_relative_diff', 'max_relative_diff',
            and 'all_diffs' (one ratio per trial)
        """
        banner = "=" * 60
        print("\n" + banner)
        print("TEST: Operator Determinism")
        print(banner)

        dim = ggn_op.num_params
        diffs = []

        for trial in range(1, num_tests + 1):
            probe = torch.randn(dim, device=ggn_op.device)
            probe = probe / probe.norm()

            # Two independent applications of the same operator
            first = ggn_op.matvec(probe)
            second = ggn_op.matvec(probe)

            # Epsilon keeps the ratio finite if A(v) happens to be zero
            rel = (first - second).norm().item() / (first.norm().item() + 1e-10)
            diffs.append(rel)

            verdict = "✓ PASS" if rel < 1e-6 else "✗ FAIL"
            print(f"  Test {trial}: ||A(v)-A(v)||/||A(v)|| = {rel:.2e} {verdict}")

        mean_diff = np.mean(diffs)
        max_diff = np.max(diffs)
        all_passed = max_diff < 1e-6

        print(f"\n  Summary: mean={mean_diff:.2e}, max={max_diff:.2e}")
        print(f"  {'✓ All determinism tests PASSED' if all_passed else '✗ Some determinism tests FAILED'}")

        return {
            'passed': all_passed,
            'mean_relative_diff': mean_diff,
            'max_relative_diff': max_diff,
            'all_diffs': diffs,
        }
    
    def test_rayleigh_quotient(ggn_op: "GGNOperator", num_tests: int = 10) -> dict:
        """
        Test Rayleigh quotient: R(v) = v^T A v / v^T v.

        For SPD matrices, R(v) > 0 for all v != 0, and every sample satisfies
        λ_min ≤ R(v) ≤ λ_max. Therefore max R / min R over the samples is a
        LOWER bound on the condition number κ(A) = λ_max / λ_min. For random
        probes the quotients typically cluster near the mean eigenvalue, so
        this lower bound is usually loose.

        Returns:
            dict with 'passed' (all quotients positive), 'min_rayleigh',
            'max_rayleigh', 'condition_estimate' (the lower bound above),
            and 'all_values'
        """
        print("\n" + "=" * 60)
        print("TEST: Rayleigh Quotient (SPD Check)")
        print("=" * 60)
        
        n = ggn_op.num_params
        rayleigh_values = []
        
        for i in range(num_tests):
            v = torch.randn(n, device=ggn_op.device)
            v_norm_sq = torch.dot(v, v).item()
            
            Av = ggn_op.matvec(v)
            vTAv = torch.dot(v, Av).item()
            
            rayleigh = vTAv / v_norm_sq
            rayleigh_values.append(rayleigh)
            
            status = "✓ PASS" if rayleigh > 0 else "✗ FAIL"
            print(f"  Test {i+1}: R(v) = v^T A v / ||v||² = {rayleigh:.6e} {status}")
        
        all_positive = all(r > 0 for r in rayleigh_values)
        min_rayleigh = min(rayleigh_values)
        max_rayleigh = max(rayleigh_values)
        # Sampled quotients lie inside [λ_min, λ_max], so this ratio can only
        # under-estimate κ(A). The epsilon guards against division by zero.
        condition_lower_bound = max_rayleigh / (min_rayleigh + 1e-10)
        
        print(f"\n  Rayleigh range: [{min_rayleigh:.6e}, {max_rayleigh:.6e}]")
        print(f"  Estimated condition number (lower bound): {condition_lower_bound:.2e}")
        print(f"  {'✓ All Rayleigh tests PASSED (matrix is SPD)' if all_positive else '✗ Matrix is NOT SPD'}")
        
        return {
            'passed': all_positive,
            'min_rayleigh': min_rayleigh,
            'max_rayleigh': max_rayleigh,
            'condition_estimate': condition_lower_bound,
            'all_values': rayleigh_values,
        }
    
    def test_symmetry(ggn_op: "GGNOperator", num_tests: int = 5) -> dict:
        """
        Check operator symmetry: <A v1, v2> should equal <v1, A v2>.

        Each trial draws two independent unit-norm random vectors and records
        |lhs - rhs| / (|lhs| + |rhs| + 1e-10). A trial passes when this is
        below 1e-4.

        Returns:
            dict with 'passed', 'mean_diff', 'max_diff', and 'all_diffs'
        """
        rule = "=" * 60
        print("\n" + rule)
        print("TEST: Operator Symmetry")
        print(rule)

        dim = ggn_op.num_params
        gaps = []

        for trial in range(1, num_tests + 1):
            u = torch.randn(dim, device=ggn_op.device)
            w = torch.randn(dim, device=ggn_op.device)
            u = u / u.norm()
            w = w / w.norm()

            Au = ggn_op.matvec(u)
            Aw = ggn_op.matvec(w)

            lhs = torch.dot(Au, w).item()  # <A v1, v2>
            rhs = torch.dot(u, Aw).item()  # <v1, A v2>
            gap = abs(lhs - rhs) / (abs(lhs) + abs(rhs) + 1e-10)
            gaps.append(gap)

            verdict = "✓ PASS" if gap < 1e-4 else "✗ FAIL"
            print(f"  Test {trial}: <Av1,v2>={lhs:.6e}, <v1,Av2>={rhs:.6e}, rel_diff={gap:.2e} {verdict}")

        mean_diff = np.mean(gaps)
        max_diff = np.max(gaps)
        all_passed = max_diff < 1e-4

        print(f"\n  Summary: mean_diff={mean_diff:.2e}, max_diff={max_diff:.2e}")
        print(f"  {'✓ All symmetry tests PASSED' if all_passed else '✗ Some symmetry tests FAILED'}")

        return {
            'passed': all_passed,
            'mean_diff': mean_diff,
            'max_diff': max_diff,
            'all_diffs': gaps,
        }
    
    def test_pcg_damping_sweep(
        model,
        dataloader,
        damping_values: list,
        num_rhs: int = 3,
        diag_ggn: torch.Tensor = None,
    ) -> dict:
        """
        Sweep damping values and monitor PCG behavior.

        For each damping λ a fresh GGNOperator is built and the following
        are recorded:
        - PCG iterations to converge (one solve per RHS vector)
        - Final residual reported by pcg_solve
        - Convergence flag reported by pcg_solve
        - Determinism check: ||A(v)-A(v)|| / ||A(v)|| on the first RHS
        - Symmetry check: relative gap between <Av1, v2> and <v1, Av2>
        - Rayleigh quotient v^T A v / v^T v on the first RHS

        Args:
            model: Model exposing `num_params`; passed to GGNOperator
            dataloader: Data passed to GGNOperator
            damping_values: Damping values λ to sweep
            num_rhs: Number of random unit-norm RHS vectors (must be >= 1)
            diag_ggn: Optional diagonal GGN; when given, a diagonal
                preconditioner is built for each damping

        Returns:
            dict mapping each damping value to its per-metric result lists
        """
        print("\n" + "=" * 60)
        print("TEST: PCG Damping Sweep")
        print(f"λ ∈ {damping_values}")
        print("=" * 60)
        
        n_params = model.num_params
        results = {}
        
        # Generate fixed RHS vectors for consistent comparison across dampings
        torch.manual_seed(123)
        rhs_vectors = [torch.randn(n_params, device="cuda") for _ in range(num_rhs)]
        rhs_vectors = [v / v.norm() for v in rhs_vectors]
        
        print(f"\nUsing {num_rhs} random RHS vectors (normalized)")
        print("-" * 60)
        
        for damping in damping_values:
            print(f"\n>>> Damping λ = {damping}")
            
            # Create GGN operator with this damping
            ggn_op = GGNOperator(
                model=model,
                dataloader=dataloader,
                damping=damping,
                use_gnhvp=True,
                device="cuda",
            )
            
            damping_results = {
                'iterations': [],
                'residuals': [],
                'converged': [],
                'determinism': [],
                'symmetry': [],
                'rayleigh': [],
            }
            
            # Quick determinism check: apply A twice to the same vector
            v_test = rhs_vectors[0]
            Av_test = ggn_op.matvec(v_test)
            Av_test_repeat = ggn_op.matvec(v_test)
            det_diff = (Av_test - Av_test_repeat).norm().item() / (Av_test.norm().item() + 1e-10)
            damping_results['determinism'].append(det_diff)
            
            # Quick symmetry check on (v1, v2). v1 is v_test, so A v1 is
            # reused from the determinism check. With a single RHS, v2 == v1
            # and the repeated product is reused as well. Each reuse saves a
            # full GGN pass over the dataloader.
            v1 = v_test
            v2 = rhs_vectors[1] if num_rhs > 1 else rhs_vectors[0]
            Av1 = Av_test
            Av2 = ggn_op.matvec(v2) if num_rhs > 1 else Av_test_repeat
            sym_lhs = torch.dot(Av1, v2).item()
            sym_rhs = torch.dot(v1, Av2).item()
            sym_diff = abs(sym_lhs - sym_rhs) / (abs(sym_lhs) + abs(sym_rhs) + 1e-10)
            damping_results['symmetry'].append(sym_diff)
            
            # Rayleigh quotient on v_test
            rayleigh = torch.dot(v_test, Av_test).item() / torch.dot(v_test, v_test).item()
            damping_results['rayleigh'].append(rayleigh)
            
            # Create preconditioner for this damping
            if diag_ggn is not None:
                precond = create_diagonal_preconditioner(diag_ggn, damping=damping)
            else:
                precond = None
            
            # Solve IHVP for each RHS
            for j, rhs in enumerate(rhs_vectors):
                print(f"  RHS {j+1}: ", end="")
                
                x, info = pcg_solve(
                    A=ggn_op,
                    b=rhs,
                    M_inv=precond,
                    max_iter=100,
                    tol=1e-4,
                    atol=1e-8,
                    device="cuda",
                    verbose=True,
                )
                
                damping_results['iterations'].append(info['iterations'])
                damping_results['residuals'].append(info['residual'])
                damping_results['converged'].append(info['converged'])
                
                status = "✓" if info['converged'] else "✗"
                print(f"iters={info['iterations']:3d}, residual={info['residual']:.2e} {status}")
            
            # Summary for this damping
            avg_iters = np.mean(damping_results['iterations'])
            avg_residual = np.mean(damping_results['residuals'])
            convergence_rate = np.mean(damping_results['converged'])
            
            print(f"  Summary: avg_iters={avg_iters:.1f}, avg_residual={avg_residual:.2e}, "
                  f"converged={convergence_rate*100:.0f}%")
            print(f"  Checks: determinism={det_diff:.2e}, symmetry={sym_diff:.2e}, rayleigh={rayleigh:.2e}")
            
            results[damping] = damping_results
        
        # Print comparison table
        print("\n" + "=" * 60)
        print("SUMMARY TABLE")
        print("=" * 60)
        header = (f"{'Damping':<10} {'Avg Iters':<12} {'Avg Residual':<14} "
                  f"{'Converged':<12} {'Rayleigh':<12} {'Sym Error':<12}")
        print(header)
        print("-" * len(header))
        
        for damping in damping_values:
            r = results[damping]
            avg_iters = np.mean(r['iterations'])
            avg_residual = np.mean(r['residuals'])
            # Build the percentage string before padding so the '%' stays
            # inside its column instead of trailing the padded field.
            conv_pct = f"{np.mean(r['converged']) * 100:.0f}%"
            rayleigh = r['rayleigh'][0]
            sym_err = r['symmetry'][0]
            
            print(f"{damping:<10.4f} {avg_iters:<12.1f} {avg_residual:<14.2e} {conv_pct:<12} {rayleigh:<12.2e} {sym_err:<12.2e}")
        
        return results
    
    def test_ihvp_accuracy(ggn_op: "GGNOperator", num_tests: int = 3, preconditioner=None) -> dict:
        """
        Check IHVP accuracy via the round trip A @ (A^{-1} @ v) ≈ v.

        Each trial solves a unit-norm random v with solve_ihvp (max_iter=30,
        tol=1e-5), re-applies A to the solution, and records the relative
        reconstruction error ||A x - v|| / ||v||. A trial passes when that
        error is below 0.01.

        Returns:
            dict with 'passed', 'mean_error', 'max_error', and 'all_errors'
        """
        line = "=" * 60
        print("\n" + line)
        print("TEST: IHVP Accuracy (A @ A^{-1} @ v ≈ v)")
        print(line)

        dim = ggn_op.num_params
        errors = []

        for trial in range(1, num_tests + 1):
            rhs = torch.randn(dim, device=ggn_op.device)
            rhs = rhs / rhs.norm()

            # x = A^{-1} v via PCG
            solution, solve_info = solve_ihvp(
                ggn_op=ggn_op,
                v=rhs,
                max_iter=30,
                tol=1e-5,
                preconditioner=preconditioner,
                verbose=True,
            )

            # Round-trip residual A x - v
            residual = ggn_op.matvec(solution) - rhs
            err = residual.norm().item() / rhs.norm().item()
            errors.append(err)

            verdict = "✓ PASS" if err < 0.01 else "✗ FAIL"
            print(f"  Test {trial}: ||A(A^{{-1}}v) - v|| / ||v|| = {err:.6e}, "
                  f"PCG iters={solve_info['iterations']} {verdict}")

        mean_error = np.mean(errors)
        max_error = np.max(errors)
        all_passed = max_error < 0.01

        print(f"\n  Summary: mean_error={mean_error:.2e}, max_error={max_error:.2e}")
        print(f"  {'✓ All IHVP accuracy tests PASSED' if all_passed else '✗ Some IHVP accuracy tests FAILED'}")

        return {
            'passed': all_passed,
            'mean_error': mean_error,
            'max_error': max_error,
            'all_errors': errors,
        }
    
    # ==========================================================================
    # Run Test Suite
    # ==========================================================================
    
    # Restrict scaled-dot-product attention to the math backend (flash and
    # memory-efficient kernels disabled). NOTE(review): presumably required so
    # the GGN-vector products can differentiate through attention — confirm.
    with sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
        print("=" * 60)
        print("PCG Test Suite with ViT GGN Operator")
        print("=" * 60)
        
        # Load ViT model
        print("\nLoading ViT model...")
        model = load_vit(pretrained=True, device="cuda")
        print(f"Model has {model.num_params:,} parameters")
        
        # Create ImageNet-like test data (random images, 224x224, 3 channels)
        print("\nCreating ImageNet-like test data...")
        torch.manual_seed(42)
        n_samples = 16
        dummy_images = torch.randn(n_samples, 3, 224, 224).cuda()
        # Standardize, then rescale to mean 0.485 / std 0.229 (ImageNet's
        # channel-0 pixel statistics). NOTE(review): this mimics raw pixel
        # values rather than already-normalized model inputs — confirm intent.
        dummy_images = (dummy_images - dummy_images.mean()) / dummy_images.std()
        dummy_images = dummy_images * 0.229 + 0.485
        dummy_labels = torch.randint(0, 1000, (n_samples,)).cuda()
        
        dummy_dataset = TensorDataset(dummy_images, dummy_labels)
        dummy_loader = DataLoader(dummy_dataset, batch_size=4)
        
        print(f"Test data: {n_samples} samples, batch_size=4")
        
        # Create GGN operator with moderate damping for initial tests
        print("\nCreating GGN operator (damping=0.01)...")
        ggn_op = GGNOperator(
            model=model,
            dataloader=dummy_loader,
            damping=0.01,
            use_gnhvp=True,
            device="cuda",
        )
        ggn_op.prepare_for_solve()
        
        # Diagonal GGN preconditioning is currently disabled: Minv stays None,
        # so every PCG solve below runs without a preconditioner.
        # TODO: re-enable via compute_diagonal_ggn(model, dummy_loader,
        # device="cuda", ggn_op=ggn_op). NOTE(review): Minv is passed below as
        # `diag_ggn` (a diagonal that gets inverted), whereas the earlier
        # Jacobi variant multiplied by it directly. Confirm whether
        # compute_diagonal_ggn returns the diagonal or its inverse before
        # re-enabling.
        print("\nSkipping diagonal GGN preconditioner (not computed); PCG runs unpreconditioned.")
        Minv = None
        
        # Run damping sweep
        print("\n" + "#" * 60)
        print("RUNNING DAMPING SWEEP")
        print("#" * 60)
        # TODO: optionally derive the sweep from the spectrum, e.g. take
        # lambda_max from estimate_condition_number(ggn_op, ggn_op.num_params,
        # k=30, device="cuda") and use max(1e-4, lambda_max / factor).
        damping_values = [1.0, 0.5, 0.1, 0.05, 0.01, 0.005, 0.001]
        sweep_results = test_pcg_damping_sweep(
            model=model,
            dataloader=dummy_loader,
            damping_values=damping_values,
            num_rhs=1,
            diag_ggn=Minv,
        )
        
        # Run individual tests
        print("\n" + "#" * 60)
        print("RUNNING INDIVIDUAL TESTS")
        print("#" * 60)
        
        # Create preconditioner for individual tests (skip if no diagonal GGN)
        preconditioner = None
        if Minv is not None:
            preconditioner = create_diagonal_preconditioner(Minv, damping=0.01)
        
        det_results = test_determinism(ggn_op, num_tests=5)
        ray_results = test_rayleigh_quotient(ggn_op, num_tests=10)
        sym_results = test_symmetry(ggn_op, num_tests=5)
        ihvp_results = test_ihvp_accuracy(ggn_op, num_tests=3, preconditioner=preconditioner)
        
        # Final summary (the damping sweep is informational and not graded)
        print("\n" + "=" * 60)
        print("FINAL TEST SUMMARY")
        print("=" * 60)
        
        all_tests = [
            ("Determinism", det_results['passed']),
            ("Rayleigh (SPD)", ray_results['passed']),
            ("Symmetry", sym_results['passed']),
            ("IHVP Accuracy", ihvp_results['passed']),
        ]
        
        for name, passed in all_tests:
            status = "✓ PASSED" if passed else "✗ FAILED"
            print(f"  {name}: {status}")
        
        all_passed = all(p for _, p in all_tests)
        print("\n" + "-" * 60)
        if all_passed:
            print("✓ ALL TESTS PASSED")
        else:
            print("✗ SOME TESTS FAILED")
        print("=" * 60)
