from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple

import torch

from utils.logs import SharedLogger


class HybridInverseMixin:
    """Utilities to create and maintain (H_panel, G_CC).

    Public entry points:

    * :meth:`build_panel_and_inverse` — returns the columns ``H^{-1}[:, C]`` of a
      (possibly diagonally damped) inverse without forming the full d×d inverse
      on the Cholesky paths.
    * :meth:`rank_k_downdate` — Schur-complement update of ``G`` after pruning.

    :meth:`_batched_cholesky_solve_with_fallback` solves a batch of small near-SPD
    systems with a similar fallback ladder (plain Cholesky → scaled damping →
    Rayleigh shift → CG → pinv).

    Note: the ``_try_cholesky_*`` helpers overwrite the diagonal of the matrix they
    receive in place.
    """

    logger = SharedLogger.get_logger("HybridInverseMixin")

    # =========================
    # Small data container
    # =========================
    @dataclass
    class _FactorResult:
        """Outcome of a single factorization strategy."""

        L: Optional[torch.Tensor]   # lower Cholesky factor on success, None on failure
        lam: float                  # diagonal shift of the last attempt (batch mean for *_batched helpers)
        attempts: int               # number of Cholesky attempts made
        method: str                 # "chol", "chol_batched", "rayleigh+chol", "rayleigh+chol_batched", or "fail"

    # =========================
    # Tiny utilities (shared)
    # =========================
    @staticmethod
    @torch.inference_mode()
    def _symmetrize(H: torch.Tensor) -> torch.Tensor:
        """Return ``(H + Hᵀ) / 2`` for a matrix or a batch of matrices.

        The transpose is taken over the last two dimensions, so leading batch
        dimensions are preserved. A new tensor is returned; ``H`` is untouched.
        """
        return 0.5 * (H + H.transpose(-2, -1))

    @staticmethod
    @torch.inference_mode()
    def _factor_is_valid(L: torch.Tensor, min_diag: float) -> bool:
        """Accept a (possibly batched) Cholesky factor only if every entry is finite
        and every diagonal entry is at least ``min_diag``."""
        if not bool(torch.isfinite(L).all()):
            return False
        return not bool((L.diagonal(dim1=-2, dim2=-1) < min_diag).any())

    # =========================
    # Factorization helpers
    # =========================
    @staticmethod
    @torch.inference_mode()
    def _try_cholesky_scaled_damping(H: torch.Tensor,
                                     diag_orig: torch.Tensor,
                                     *,
                                     damp_scale: float,
                                     max_damp_scale: float,
                                     min_diag: float) -> "_FactorResult":
        """Try Cholesky of ``H + λI`` with λ growing tenfold until success or limit.

        ``base = max(mean(diag_orig), 1e-8)``; λ starts at ``damp_scale * base`` and
        attempts continue while ``λ <= max_damp_scale * base``. A factor is rejected
        when it is non-finite or has a diagonal entry below ``min_diag``.

        Modifies ``H`` in place: its diagonal is left at ``diag_orig + λ`` for the
        last λ tried. On failure, ``lam`` in the result is the first λ past the limit.
        """
        mean_diag = float(diag_orig.mean().item()) if diag_orig.numel() else 1.0
        base = max(mean_diag, 1e-8)
        lam = damp_scale * base
        attempts = 0
        while lam <= max_damp_scale * base:
            attempts += 1
            H.diagonal().copy_(diag_orig + lam)
            try:
                L = torch.linalg.cholesky(H)
            except RuntimeError:
                L = None
            if L is not None and HybridInverseMixin._factor_is_valid(L, min_diag):
                return HybridInverseMixin._FactorResult(L=L, lam=lam, attempts=attempts, method="chol")
            lam *= 10.0
        return HybridInverseMixin._FactorResult(L=None, lam=lam, attempts=attempts, method="fail")

    @staticmethod
    @torch.inference_mode()
    def _try_cholesky_scaled_damping_batched(A: torch.Tensor,
                                             diag_orig: torch.Tensor,
                                             *,
                                             damp_scale: float,
                                             max_damp_scale: float,
                                             min_diag: float) -> "_FactorResult":
        """Batched counterpart of :meth:`_try_cholesky_scaled_damping`.

        ``A`` is ``(B, k, k)`` and ``diag_orig`` is ``(B, k)``. Each batch element b
        gets its own shift ``λ_b = damp_scale · max(mean(diag_orig[b]), 1e-8)``. All
        shifts grow tenfold together, and one batched Cholesky of sym(A) is tried per
        step, so the whole batch succeeds or fails together. ``lam`` in the result is
        the mean of the λ_b.

        Modifies ``A`` in place: its diagonal is left at ``diag_orig + λ`` for the
        last λ tried.
        """
        base = diag_orig.mean(dim=1, keepdim=True).clamp_min(1e-8)  # (B, 1)
        lam = damp_scale * base                                       # (B, 1)
        limit = max_damp_scale * base.max()
        attempts = 0
        while lam.max() <= limit:
            attempts += 1
            # (B, k) + (B, 1) broadcasts each element's shift along its own diagonal.
            A.diagonal(dim1=1, dim2=2).copy_(diag_orig + lam)
            try:
                L = torch.linalg.cholesky(HybridInverseMixin._symmetrize(A))
            except RuntimeError:
                L = None
            if L is not None and HybridInverseMixin._factor_is_valid(L, min_diag):
                return HybridInverseMixin._FactorResult(L=L, lam=lam.mean().item(), attempts=attempts, method="chol_batched")
            lam = lam * 10.0
        return HybridInverseMixin._FactorResult(L=None, lam=lam.mean().item(), attempts=attempts, method="fail")

    # (renamed from estimate_lambda_min)
    @staticmethod
    @torch.inference_mode()
    def _estimate_lambda_min(H: torch.Tensor, iters: int = 40) -> float:
        """Cheap randomized heuristic for the smallest eigenvalue of symmetric ``H``.

        Starts from a random unit vector (global torch RNG). It repeatedly replaces
        the vector with the normalized residual ``H v - ρ v`` of its Rayleigh quotient
        ρ, stopping early if the residual vanishes, and returns the last ρ.

        For symmetric ``H`` every Rayleigh quotient lies in ``[λ_min, λ_max]``, so the
        result may overestimate λ_min. Callers in this class re-check the resulting
        shift with a validated Cholesky.
        """
        n = H.size(0)
        v = torch.randn(n, 1, device=H.device, dtype=H.dtype)
        v = v / (v.norm() + 1e-12)
        ray = None
        for _ in range(iters):
            w = H @ v
            ray = (v.T @ w).squeeze()
            v = w - ray * v
            nv = v.norm()
            if nv <= 1e-10:
                break
            v = v / nv
        if ray is None:
            # iters <= 0: fall back to the Rayleigh quotient of the start vector.
            ray = (v.T @ (H @ v)).squeeze()
        return float(ray)

    @staticmethod
    @torch.inference_mode()
    def _try_cholesky_with_rayleigh(H: torch.Tensor,
                                    diag_orig: torch.Tensor,
                                    *,
                                    min_diag: float) -> "_FactorResult":
        """Shift the diagonal by ``max(1e-6, -λ̂_min + ε)`` and try a single Cholesky.

        ``ε = 1e-6 · max(1, mean(diag_orig))``. λ̂_min is estimated on the *undamped*
        matrix. The diagonal is first restored to ``diag_orig`` because a preceding
        damping loop may have left a shift c on it. That would inflate the estimate
        by c and make the new shift too small by c.

        Modifies ``H`` in place: its diagonal is left at ``diag_orig + λ``.
        """
        H.diagonal().copy_(diag_orig)
        ray = HybridInverseMixin._estimate_lambda_min(H)
        diag_mean = float(diag_orig.mean()) if diag_orig.numel() else 1.0
        eps = 1e-6 * max(1.0, diag_mean)
        lam = max(1e-6, -ray + eps)
        H.diagonal().copy_(diag_orig + lam)
        try:
            L = torch.linalg.cholesky(H)
        except Exception:
            L = None
        if L is not None and HybridInverseMixin._factor_is_valid(L, min_diag):
            return HybridInverseMixin._FactorResult(L=L, lam=lam, attempts=1, method="rayleigh+chol")
        return HybridInverseMixin._FactorResult(L=None, lam=lam, attempts=1, method="fail")

    @staticmethod
    @torch.inference_mode()
    def _try_cholesky_with_rayleigh_batched(A: torch.Tensor,
                                            diag_orig: torch.Tensor,
                                            *,
                                            min_diag: float) -> "_FactorResult":
        """Batched counterpart of :meth:`_try_cholesky_with_rayleigh`.

        λ̂_min is estimated per batch element on the undamped matrices (the diagonal
        is restored to ``diag_orig`` first), and the smallest estimate is used for
        every element. The per-element ``ε_b = 1e-6 · max(mean(diag_orig[b]), 1e-8)``
        is added, and the shift is clamped to at least 1e-6. One batched Cholesky of
        sym(A) is then tried. ``lam`` in the result is the batch mean of the shifts.

        Modifies ``A`` in place: its diagonal is left at ``diag_orig + λ``.
        """
        A.diagonal(dim1=1, dim2=2).copy_(diag_orig)
        ray_min = min(
            (HybridInverseMixin._estimate_lambda_min(A[b]) for b in range(A.shape[0])),
            default=0.0,
        )
        diag_mean = diag_orig.mean(dim=1, keepdim=True).clamp_min(1e-8)  # (B, 1)
        eps = 1e-6 * diag_mean
        lam = torch.clamp(-ray_min + eps, min=1e-6)                      # (B, 1)
        A.diagonal(dim1=1, dim2=2).copy_(diag_orig + lam)
        try:
            L = torch.linalg.cholesky(HybridInverseMixin._symmetrize(A))
        except Exception:
            L = None
        if L is not None and HybridInverseMixin._factor_is_valid(L, min_diag):
            return HybridInverseMixin._FactorResult(L=L, lam=lam.mean().item(), attempts=1, method="rayleigh+chol_batched")
        return HybridInverseMixin._FactorResult(L=None, lam=lam.mean().item(), attempts=1, method="fail")

    # =========================
    # Column solvers (shared)
    # =========================
    @staticmethod
    @torch.inference_mode()
    def _solve_inverse_columns_from_L(L: torch.Tensor,
                                      d: int,
                                      C: torch.Tensor,
                                      *,
                                      dtype: torch.dtype,
                                      device: torch.device) -> torch.Tensor:
        """Solve H X = I_C given a lower Cholesky factor L (H = L Lᵀ).

        ``I_C`` is the ``(d, |C|)`` selection of identity columns ``C``, so the result
        is ``H^{-1}[:, C]``. It is computed with two triangular solves.
        """
        m = C.numel()
        I_C = torch.zeros(d, m, dtype=dtype, device=device)
        I_C[C, torch.arange(m, device=device)] = 1
        Y = torch.linalg.solve_triangular(L, I_C, upper=False)
        X = torch.linalg.solve_triangular(L.T, Y, upper=True)
        return X

    @staticmethod
    @torch.inference_mode()
    def _iterative_inverse_columns(H: torch.Tensor, C: torch.Tensor,
                                   device, dtype,
                                   max_iterations: int, tolerance: float) -> Optional[torch.Tensor]:
        """
        Memory-efficient iterative solve for very large matrices using CG.

        Solves H X = I_C column by column, with at most ``min(max_iterations, n)``
        steps per column. A column stops early once ``‖r‖ < tolerance`` or on CG
        breakdown. Returns None if the result is non-finite or any exception occurs.
        """
        try:
            n = H.shape[0]
            m = C.numel()
            cols = torch.zeros(n, m, dtype=dtype, device=device)
            I_C = torch.zeros(n, m, dtype=dtype, device=device)
            I_C[C, torch.arange(m, device=device)] = 1.0
            for i in range(m):
                b = I_C[:, i]
                x = torch.zeros(n, dtype=dtype, device=device)
                r = b.clone()
                p = r.clone()
                rr = torch.dot(r, r)
                for _ in range(min(max_iterations, n)):
                    Ap = H @ p
                    denom = torch.dot(p, Ap)
                    # Breakdown: (numerically) zero curvature along p, e.g. singular/indefinite H.
                    if denom.abs() < 1e-30:
                        break
                    alpha = rr / denom
                    x = x + alpha * p
                    r = r - alpha * Ap
                    rr_new = torch.dot(r, r)
                    if torch.sqrt(rr_new) < tolerance:
                        break
                    beta = rr_new / rr if rr.abs() > 0 else 0.0
                    p = r + beta * p
                    rr = rr_new
                cols[:, i] = x
            if torch.isnan(cols).any() or torch.isinf(cols).any():
                return None
            return cols
        except Exception as e:
            HybridInverseMixin.logger.debug(f"Iterative solve failed: {e}")
            return None

    @staticmethod
    @torch.inference_mode()
    def _pinv_inverse_columns(H: torch.Tensor, C: torch.Tensor) -> Optional[torch.Tensor]:
        """Return ``pinv(H)[:, C]``, or None if pinv raises or the result is non-finite.

        Forms the full d×d pseudo-inverse, so this is the most memory-hungry path.
        """
        try:
            H_inv = torch.linalg.pinv(H)
            cols = H_inv[:, C]
            if torch.isnan(cols).any() or torch.isinf(cols).any():
                return None
            return cols
        except Exception:
            return None

    # =========================
    # Batched utilities
    # =========================
    @staticmethod
    @torch.inference_mode()
    def _iterative_solve_batched_single(A: torch.Tensor, B: torch.Tensor,
                                        *, max_iterations: int = 200,
                                        tolerance: float = 1e-6) -> Optional[torch.Tensor]:
        """
        Conjugate gradient solve for A X = B for a single small system with multiple RHS.

        A: (k, k), B: (k, r) -> X: (k, r)

        CG runs on sym(A) plus a tiny relative jitter ``1e-12 · max(mean|diag A|, 1e-12)``
        on the diagonal. Both are applied to a copy, so ``A`` is not modified.
        Returns None if the result is non-finite or any exception occurs.
        """
        try:
            k, r = B.shape
            X = torch.zeros(k, r, dtype=B.dtype, device=B.device)
            diag_mean = A.diagonal().abs().mean().clamp_min(1e-12)
            Aj = HybridInverseMixin._symmetrize(A) + (1e-12 * diag_mean) * torch.eye(k, device=A.device, dtype=A.dtype)
            tol_sq = tolerance ** 2
            # Solve column-by-column to keep it simple and numerically safe
            for j in range(r):
                b = B[:, j]
                x = torch.zeros(k, dtype=B.dtype, device=B.device)
                rvec = b - Aj @ x
                p = rvec.clone()
                rr = torch.dot(rvec, rvec)
                iters = 0
                while iters < max_iterations and rr > tol_sq:
                    Ap = Aj @ p
                    denom = torch.dot(p, Ap)
                    if denom.abs() < 1e-30:
                        break
                    alpha = rr / denom
                    x = x + alpha * p
                    rvec = rvec - alpha * Ap
                    rr_new = torch.dot(rvec, rvec)
                    if rr_new <= tol_sq:
                        rr = rr_new
                        break
                    beta = rr_new / rr
                    p = rvec + beta * p
                    rr = rr_new
                    iters += 1
                X[:, j] = x
            if torch.isnan(X).any() or torch.isinf(X).any():
                return None
            return X
        except Exception:
            return None

    # =========================
    # Orchestrator (shared)
    # =========================
    @staticmethod
    @torch.inference_mode()
    def _solve_inverse_columns_robust(H_in: torch.Tensor, C: torch.Tensor, *,
                                      dtype_panel: torch.dtype,
                                      device: torch.device,
                                      damp_scale: float,
                                      max_damp_scale: float,
                                      min_diag: float,
                                      max_iterative_iterations: int,
                                      iterative_tolerance: float,
                                      ) -> Tuple[Optional[torch.Tensor], dict]:
        """
        Compute ``H^{-1}[:, C]`` with a fallback ladder; the first success wins:

          1. plain Cholesky of sym(H)                      -> exact columns
          2. scaled-damping loop (Cholesky of H + λI)      -> damped columns
          3. Rayleigh shift + one Cholesky                 -> damped columns
          4. conjugate gradient on the Rayleigh-shifted H
          5. pinv of the Rayleigh-shifted H (forms a full d×d matrix)

        Returns ``(cols, meta)``. ``cols`` is None if every step fails. ``meta`` holds:
        ``used`` (method name or None), ``lam`` (shift used; NaN for steps 4–5) and
        ``attempts``.

        Note: when ``H_in`` already has ``device``/``dtype_panel`` and is contiguous,
        ``.to``/``.contiguous`` return it unchanged, so steps 2–3 shift ``H_in``'s
        diagonal in place.
        """
        meta = {"used": None, "lam": 0.0, "attempts": 0}
        H = H_in.to(device=device, dtype=dtype_panel).contiguous()
        # Index tensors must live on the same device as the tensors they index
        # (e.g. mode="cpu" with a CUDA C would otherwise fail inside the solvers).
        C = C.to(device=device)
        d = H.shape[0]
        diag_orig = H.diagonal().clone()

        def columns_from(L: torch.Tensor) -> torch.Tensor:
            # H^{-1}[:, C] from a validated Cholesky factor.
            return HybridInverseMixin._solve_inverse_columns_from_L(L, d, C, dtype=dtype_panel, device=device)

        # (1) Plain Cholesky first (like in _batched_cholesky_solve_with_fallback).
        # Only the factorization is guarded; errors in the column solve propagate.
        try:
            L = torch.linalg.cholesky(HybridInverseMixin._symmetrize(H))
        except RuntimeError:
            L = None
        if L is not None and HybridInverseMixin._factor_is_valid(L, min_diag):
            meta.update({"used": "chol", "lam": 0.0, "attempts": 0})
            return columns_from(L), meta

        # (2) Scaled-damping loop
        fr = HybridInverseMixin._try_cholesky_scaled_damping(H, diag_orig,
                                                             damp_scale=damp_scale,
                                                             max_damp_scale=max_damp_scale,
                                                             min_diag=min_diag)
        if fr.L is not None:
            meta.update({"used": fr.method, "lam": fr.lam, "attempts": fr.attempts})
            return columns_from(fr.L), meta

        # (3) Rayleigh shift + Cholesky
        fr = HybridInverseMixin._try_cholesky_with_rayleigh(H, diag_orig, min_diag=min_diag)
        if fr.L is not None:
            meta.update({"used": fr.method, "lam": fr.lam, "attempts": fr.attempts})
            return columns_from(fr.L), meta

        # (4) Iterative (CG) on the shifted H left by step 3
        cols = HybridInverseMixin._iterative_inverse_columns(H, C, device, dtype_panel,
                                                             max_iterations=max_iterative_iterations,
                                                             tolerance=iterative_tolerance)
        if cols is not None:
            meta.update({"used": "iterative", "lam": float('nan'), "attempts": 0})
            return cols, meta

        # (5) pinv
        cols = HybridInverseMixin._pinv_inverse_columns(H, C)
        if cols is not None:
            meta.update({"used": "pinv", "lam": float('nan'), "attempts": 0})
            return cols.to(device=device), meta

        return None, meta

    @staticmethod
    @torch.inference_mode()
    def _batched_cholesky_solve_with_fallback(A: torch.Tensor,
                                              BT: torch.Tensor,
                                              *,
                                              damp_scale: float = 1e-2,
                                              max_damp_scale: float = 10.0,
                                              min_diag: float = 1e-8,
                                              max_iterative_iterations: int = 200,
                                              iterative_tolerance: float = 1e-6) -> torch.Tensor:
        """
        Solve A X = BT for multiple systems in batch with robust fallback.

        Ladder (first success wins):
          1. batched Cholesky on A, then on sym(A)      -> exact solve
          2. scaled diagonal damping on the whole batch -> solves (A + λI) X = BT
          3. Rayleigh shift on the whole batch          -> solves (A + λI) X = BT
          4. per-element CG on the Rayleigh-shifted matrices
          5. per-element pinv of the symmetrized shifted matrices. If that raises,
             each element gets its own damping -> Rayleigh -> CG -> pinv ladder.

        The caller's ``A`` is never modified: steps 2+ work on a private copy.

        Shapes:
          A:  (B, k, k) SPD (or near-SPD)
          BT: (B, k, out)
        Returns:
          X:  (B, k, out)
        """
        log = HybridInverseMixin.logger

        # (1) Try naive batched Cholesky first
        try:
            L = torch.linalg.cholesky(A)
            if not (torch.isnan(L).any() or torch.isinf(L).any()):
                log.debug("Batched Cholesky succeeded directly")
                return torch.cholesky_solve(BT, L, upper=False)
            log.debug(f"Batched Cholesky failed validation: NaN={torch.isnan(L).any()}, Inf={torch.isinf(L).any()}")
        except Exception as e:
            log.debug(f"Batched Cholesky failed with exception: {e}")

        try:
            L = torch.linalg.cholesky(HybridInverseMixin._symmetrize(A))
            if not (torch.isnan(L).any() or torch.isinf(L).any()):
                log.debug("Batched Cholesky succeeded with symmetrized matrix")
                return torch.cholesky_solve(BT, L, upper=False)
            log.debug(f"Symmetrized batched Cholesky failed validation: NaN={torch.isnan(L).any()}, Inf={torch.isinf(L).any()}")
        except Exception as e:
            log.debug(f"Symmetrized batched Cholesky failed with exception: {e}")

        # The damping helpers below write to the diagonal in place; work on a private
        # copy so the caller's A is untouched (clone also handles expanded inputs).
        A = A.clone()
        Bn, _, _ = BT.shape
        diag_orig = A.diagonal(dim1=1, dim2=2).clone()  # (B, k)

        # (2) Try scaled damping on the whole batch
        log.debug("Trying scaled damping fallback")
        fr = HybridInverseMixin._try_cholesky_scaled_damping_batched(
            A, diag_orig, damp_scale=damp_scale, max_damp_scale=max_damp_scale, min_diag=min_diag
        )
        if fr.L is not None:
            log.debug(f"Scaled damping succeeded with λ={fr.lam:.3e}")
            return torch.cholesky_solve(BT, fr.L, upper=False)

        # (3) Try Rayleigh shift on the whole batch
        log.debug("Trying Rayleigh shift fallback")
        fr = HybridInverseMixin._try_cholesky_with_rayleigh_batched(A, diag_orig, min_diag=min_diag)
        if fr.L is not None:
            log.debug(f"Rayleigh shift succeeded with λ={fr.lam:.3e}")
            return torch.cholesky_solve(BT, fr.L, upper=False)

        # (4) Iterative solve per batch element (A carries the step-3 shift)
        try:
            X_list = []
            for b in range(Bn):
                Xb = HybridInverseMixin._iterative_solve_batched_single(
                    A[b], BT[b],
                    max_iterations=max_iterative_iterations,
                    tolerance=iterative_tolerance
                )
                if Xb is None:
                    break
                X_list.append(Xb)
            if len(X_list) == Bn:
                return torch.stack(X_list, dim=0)
        except Exception:
            pass

        # (5) pinv on the whole batch
        try:
            A_sym = HybridInverseMixin._symmetrize(A)
            X_list = [torch.linalg.pinv(A_sym[b]) @ BT[b] for b in range(Bn)]
            return torch.stack(X_list, dim=0)
        except Exception:
            # (6) Ultimate fallback: independent ladder for each batch element
            X_list = []
            for b in range(Bn):
                Ab = HybridInverseMixin._symmetrize(A[b]).contiguous()
                BTb = BT[b]
                # Try scaled damping
                diag_orig_b = Ab.diagonal().clone()
                fr = HybridInverseMixin._try_cholesky_scaled_damping(
                    Ab, diag_orig_b, damp_scale=damp_scale, max_damp_scale=max_damp_scale, min_diag=min_diag
                )
                if fr.L is not None:
                    X_list.append(torch.cholesky_solve(BTb, fr.L, upper=False))
                    continue
                # Try Rayleigh + chol
                fr = HybridInverseMixin._try_cholesky_with_rayleigh(
                    Ab, diag_orig_b, min_diag=min_diag
                )
                if fr.L is not None:
                    X_list.append(torch.cholesky_solve(BTb, fr.L, upper=False))
                    continue
                # Iterative fallback (CG per RHS)
                Xb = HybridInverseMixin._iterative_solve_batched_single(Ab, BTb)
                if Xb is not None:
                    X_list.append(Xb)
                    continue
                # Final fallback: pinv
                X_list.append(torch.linalg.pinv(Ab) @ BTb)

            return torch.stack(X_list, dim=0)

    # =========================
    # Public API – refactored
    # =========================
    @staticmethod
    @torch.inference_mode()
    def build_panel_and_inverse(
        H: torch.Tensor,
        C: torch.Tensor,
        *,
        dtype_panel: torch.dtype = torch.float32,
        mode: str = "gpu",                    # "gpu" | "cpu"
        damp_scale: float = 1e-2,             # λ = damp_scale · mean(diag(H))
        min_diag: float = 1e-8,
        max_damp_scale: float = 10.0,
        max_iterative_iterations: int = 100,
        iterative_tolerance: float = 1e-6,
    ) -> torch.Tensor | None:
        """
        Returns cols — columns of inverse Hessian H^{-1}[:, C] (fp32).

        No full d×d inverse is formed unless the final pinv fallback is reached. H is
        symmetrized first. Diagonal jitter is added only when plain Cholesky fails;
        see ``_solve_inverse_columns_robust`` for the full fallback ladder.

        ``mode="gpu"`` computes on H's own device (so CPU if H is on the CPU);
        ``mode="cpu"`` computes on the CPU.

        Returns None when ``C`` is empty.
        Raises ValueError when every inverse method fails.
        """
        log = HybridInverseMixin.logger
        log.debug(f"Start inverse build: H shape {tuple(H.shape)}, |C|={C.numel()}")

        if C.numel() == 0:
            log.debug("Empty candidate set; skipping inverse build")
            return None

        # Keep a CPU master copy in dtype_panel, symmetrized once. _symmetrize returns a
        # new tensor, so later in-place diagonal shifts never reach the caller's H.
        H_cpu = H.to(device="cpu", dtype=dtype_panel).detach().contiguous()
        H_cpu = HybridInverseMixin._symmetrize(H_cpu)

        device = H.device if mode == "gpu" else torch.device("cpu")

        cols, meta = HybridInverseMixin._solve_inverse_columns_robust(
            H_cpu.to(device=device),
            C,
            dtype_panel=dtype_panel,
            device=device,
            damp_scale=damp_scale,
            max_damp_scale=max_damp_scale,
            min_diag=min_diag,
            max_iterative_iterations=max_iterative_iterations,
            iterative_tolerance=iterative_tolerance,
        )
        if cols is None:
            raise ValueError("All inverse computation methods failed.")
        log.debug(f"Inverse columns via {meta['used']} (λ={meta['lam']:.3e}, attempts={meta['attempts']})")
        return cols.clone().to(torch.float32)

    # =========================
    # Existing utilities
    # =========================
    @staticmethod
    @torch.inference_mode()
    def rank_k_downdate(G: torch.Tensor, H_panel: torch.Tensor, P_loc: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Schur complement update after pruning block P (local indices).

        With R the complement of ``P_loc`` in ``range(G.shape[0])``, returns
        ``(G_RR - G_RP G_PP^{-1} G_PR, H_panel[:, R], R)``. ``G_PR`` is taken as
        ``G_RPᵀ``, i.e. ``G`` is assumed symmetric.
        """
        device = G.device
        m = G.shape[0]
        mask = torch.ones(m, dtype=torch.bool, device=device)
        mask[P_loc] = False
        R_loc = mask.nonzero(as_tuple=False).squeeze(1)

        G_PP = G[P_loc][:, P_loc]
        G_RP = G[R_loc][:, P_loc]
        G_PR = G_RP.T
        G_RR = G[R_loc][:, R_loc]
        G_RR = G_RR - G_RP @ torch.linalg.solve(G_PP, G_PR)
        return G_RR.contiguous(), H_panel[:, R_loc].contiguous(), R_loc


# ---------------------------------------------------------------------------
# 1. Hybrid storage helpers                                                  #
# ---------------------------------------------------------------------------
class HybridInverseMixin1:
    """Utilities to create and maintain (H_panel, G_CC).

    Monolithic variant of ``HybridInverseMixin``: the whole fallback ladder
    (scaled damping -> Rayleigh shift -> CG -> pinv) is written inline in
    :meth:`build_panel_and_inverse`, using only this class's own helpers.
    """
    # -----------------------------------------------------------------------
    #  HybridInverseMixin – single-damp robust inverse extraction
    # -----------------------------------------------------------------------
    logger = SharedLogger.get_logger("HybridInverseMixin")

    @staticmethod
    @torch.inference_mode()
    def build_panel_and_inverse(
        H: torch.Tensor,
        C: torch.Tensor,
        *,
        dtype_panel: torch.dtype = torch.float32,
        mode: str = "gpu",                      # "gpu" | "cpu"
        damp_scale: float = 1e-2,               # λ = damp_scale · mean(diag(H))
        min_diag: float = 1e-8,
        max_damp_scale: float = 10.0,
        max_iterative_iterations: int = 100,    # Max iterations for iterative solver
        iterative_tolerance: float = 1e-6,      # Tolerance for iterative solver
    ) -> torch.Tensor | None:
        """
        Returns cols — columns of the (damped) inverse Hessian H^(-1)[:, C], in ``dtype_panel``.

        No full d×d inverse is formed on the Cholesky paths. Strategy, first success wins:
          1. Cholesky of H + λI, λ = damp_scale·base growing ×10 while λ <= max_damp_scale·base
             (base = max(mean(diag(H)), 1e-8)). Factors with NaN/Inf or a diagonal entry
             below min_diag are rejected.
          2. Cholesky of H + λI with λ = max(1e-6, -λ̂_min + ε), using a Rayleigh-quotient
             estimate λ̂_min taken on the undamped H.
          3. Conjugate gradient on the shifted H, column by column.
          4. pinv of the shifted H (forms the full d×d pseudo-inverse).

        ``mode="gpu"`` computes on H's own device; ``mode="cpu"`` computes on the CPU.

        Returns:
            cols: Columns of inverse Hessian H^(-1)[:, C], or None if C is empty.
        Raises:
            ValueError: if every method fails, or if the triangular solves produce NaN/Inf.
        """
        log = HybridInverseMixin1.logger
        log.debug(f"Start inverse build: H shape {H.shape}, |C|={C.numel()}")

        device = H.device if mode == "gpu" else torch.device("cpu")

        # Keep a CPU copy in dtype_panel that we can move as needed
        H = H.to(device="cpu", dtype=dtype_panel).detach().contiguous()
        d, m = H.shape[0], C.numel()
        if m == 0:
            log.debug("Empty candidate set; skipping inverse build")
            return None

        # Enforce symmetry. This also yields a fresh tensor, so the in-place diagonal
        # writes below never reach the caller's H.
        H = 0.5 * (H + H.T)

        # Move H (and the index tensor C) to the target device for factorization.
        H = H.to(device=device, dtype=dtype_panel)
        C = C.to(device=device)

        # ---- (1) Scaled damping loop
        diag_orig = H.diagonal().clone()
        mean_diag = float(diag_orig.mean().item()) if diag_orig.numel() > 0 else 1.0
        base = max(mean_diag, 1e-8)
        current_damp = damp_scale * base
        chol_attempts = 0
        L = None
        while current_damp <= max_damp_scale * base:
            chol_attempts += 1
            H.diagonal().copy_(diag_orig + current_damp)
            try:
                L_try = torch.linalg.cholesky(H)
            except RuntimeError:
                current_damp *= 10.0
                continue
            if torch.isnan(L_try).any() or torch.isinf(L_try).any() or (L_try.diagonal() < min_diag).any():
                current_damp *= 10.0
                continue
            log.debug(f"Cholesky ok with λ={current_damp:.3e} (attempt {chol_attempts})")
            L = L_try
            break

        # ---- (2) Rayleigh shift
        if L is None:
            log.debug(f"Cholesky failed after {chol_attempts} attempts; trying Rayleigh iteration")
            # Estimate λ_min on the undamped matrix: the loop above left its last trial
            # shift on the diagonal, which would inflate the estimate by that shift.
            H.diagonal().copy_(diag_orig)
            ray = HybridInverseMixin1.estimate_lambda_min(H)
            log.debug(f"Rayleigh iteration found λ_min={ray:.3e}")
            diag_mean = float(diag_orig.mean()) if diag_orig.numel() > 0 else 1.0
            eps = 1e-6 * max(1.0, diag_mean)
            current_damp = max(1e-6, -ray + eps)    # just enough to be SPD, if the estimate is accurate
            log.debug(f"Trying with λ={current_damp:.3e}")
            H.diagonal().copy_(diag_orig + current_damp)
            try:
                L = torch.linalg.cholesky(H)
                if torch.isnan(L).any() or torch.isinf(L).any() or (L.diagonal() < min_diag).any():
                    raise Exception("Cholesky failed after trying Rayleigh iteration. Falling back to iterative.")
                log.debug(f"Cholesky ok with λ={current_damp:.3e}")
            except Exception as exc:
                log.debug(f"Cholesky failed: {exc}; trying iterative fallback")
                rayleigh_err = exc
                L = None

        # ---- (3) CG, then (4) pinv, both on the Rayleigh-shifted H
        if L is None:
            import gc

            try:
                log.debug("Trying memory-efficient iterative solve")
                cols = HybridInverseMixin1._iterative_solve(H, C, device, dtype_panel, max_iterative_iterations, iterative_tolerance)
                if cols is None:
                    raise Exception("iterative solver returned None")
                log.debug("Iterative solve successful")
                return cols.clone()
            except Exception as exc:
                iter_err = exc
            log.debug(f"Iterative fallback failed: {iter_err}; trying pinv fallback")

            # Clear some memory before the full d×d pinv allocation
            if device.type == "cuda":
                torch.cuda.empty_cache()
                gc.collect()
            try:
                H_inv = torch.linalg.pinv(H)
                cols = H_inv[:, C]
                if torch.isnan(cols).any() or torch.isinf(cols).any():
                    log.debug("pinv produced NaN/Inf")
                    raise Exception("pinv produced invalid results")
                log.debug("pinv fallback successful")
                return cols.clone()
            except Exception as exc:
                pinv_err = exc
            log.debug(f"pinv fallback failed: {pinv_err}")

            # Drop this frame's references to the large matrices before releasing the
            # CUDA cache; the traceback of the error below keeps the frame alive.
            H = H_inv = None
            if device.type == "cuda":
                torch.cuda.empty_cache()
                gc.collect()
            raise ValueError(f"All inverse computation methods failed. Cholesky: failed after {chol_attempts} attempts, Rayleigh: {rayleigh_err}, iterative: {iter_err}, and pinv: {pinv_err}")

        # Solve H · X = I_C via triangular solves (H = L Lᵀ); build exact sub-inverse
        I_C = torch.zeros(d, m, dtype=dtype_panel, device=device)
        I_C[C, torch.arange(m, device=device)] = 1
        Y = torch.linalg.solve_triangular(L, I_C, upper=False)
        cols = torch.linalg.solve_triangular(L.T, Y, upper=True)
        del I_C, Y, L
        if torch.isnan(cols).any() or torch.isinf(cols).any():
            log.debug("Cholesky triangular solve produced NaN/Inf; this should not happen after successful Cholesky")
            raise ValueError("Cholesky triangular solve produced NaN/Inf despite successful factorization")
        return cols.clone()

    # --------------------------------------------------------------------
    @staticmethod
    @torch.inference_mode()
    def _iterative_solve(H: torch.Tensor, C: torch.Tensor, device, dtype, max_iterations: int = 100, tolerance: float = 1e-6) -> torch.Tensor | None:
        """
        Memory-efficient iterative solve for very large matrices using conjugate gradient.

        Solves H X = I_C column by column, with at most ``min(max_iterations, n)`` CG
        steps per column and an early stop once ``‖r‖ < tolerance``. There is no
        breakdown guard: a zero ``pᵀHp`` yields NaN/Inf, in which case None is
        returned. Any exception also yields None.
        """
        try:
            n = H.shape[0]
            m = C.numel()
            cols = torch.zeros(n, m, dtype=dtype, device=device)

            # Selected identity columns as right-hand sides
            I_C = torch.zeros(n, m, dtype=dtype, device=device)
            I_C[C, torch.arange(m, device=device)] = 1.0

            for i in range(m):
                b = I_C[:, i]
                x = torch.zeros(n, dtype=dtype, device=device)
                r = b.clone()
                p = r.clone()
                rr = torch.dot(r, r)   # cached ‖r‖², reused for alpha and beta

                for _ in range(min(max_iterations, n)):
                    Ap = H @ p
                    alpha = rr / torch.dot(p, Ap)
                    x = x + alpha * p
                    r_new = r - alpha * Ap

                    if torch.norm(r_new) < tolerance:  # Convergence check
                        break

                    rr_new = torch.dot(r_new, r_new)
                    beta = rr_new / rr
                    p = r_new + beta * p
                    r = r_new
                    rr = rr_new

                cols[:, i] = x

            if torch.isnan(cols).any() or torch.isinf(cols).any():
                return None

            return cols

        except Exception as e:
            HybridInverseMixin1.logger.debug(f"Iterative solve failed: {e}")
            return None

    # --------------------------------------------------------------------
    @staticmethod
    @torch.inference_mode()
    def estimate_lambda_min(H: torch.Tensor, iters: int = 40) -> float:
        """Cheap randomized heuristic for the smallest eigenvalue of symmetric ``H``.

        Starts from a random unit vector (global torch RNG). It repeatedly replaces
        the vector with the normalized residual ``H v - ρ v`` of its Rayleigh quotient
        ρ, stopping early if the residual vanishes, and returns the last ρ. For
        symmetric ``H`` every Rayleigh quotient lies in ``[λ_min, λ_max]``, so the
        result may overestimate λ_min.
        """
        n = H.size(0)
        v = torch.randn(n, 1, device=H.device, dtype=H.dtype)
        v = v / (v.norm() + 1e-12)
        ray = None
        for _ in range(iters):
            w = H @ v
            ray = (v.T @ w).squeeze()
            # Move to the part of H v not explained by the current direction
            v = w - ray * v
            nv = v.norm()
            if nv <= 1e-10:
                break
            v = v / nv
        if ray is None:
            # iters <= 0: fall back to the Rayleigh quotient of the start vector.
            ray = (v.T @ (H @ v)).squeeze()
        return float(ray)


    # --------------------------------------------------------------------
    @staticmethod
    @torch.inference_mode()
    def rank_k_downdate(G: torch.Tensor, H_panel: torch.Tensor, P_loc: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Schur complement update after pruning block P (local indices).

        With R the complement of ``P_loc`` in ``range(G.shape[0])``, returns
        ``(G_RR - G_RP G_PP^{-1} G_PR, H_panel[:, R], R)``. ``G_PR`` is taken as
        ``G_RPᵀ``, i.e. ``G`` is assumed symmetric.
        """
        device = G.device
        m = G.shape[0]
        mask = torch.ones(m, dtype=torch.bool, device=device)
        mask[P_loc] = False
        R_loc = mask.nonzero(as_tuple=False).squeeze(1)

        G_PP = G[P_loc][:, P_loc]
        G_RP = G[R_loc][:, P_loc]
        G_PR = G_RP.T
        G_RR = G[R_loc][:, R_loc]
        G_RR = G_RR - G_RP @ torch.linalg.solve(G_PP, G_PR)
        return G_RR.contiguous(), H_panel[:, R_loc].contiguous(), R_loc

