from __future__ import annotations
import math
import torch, torch.nn as nn, torch.nn.functional as F
from abc import ABC, abstractmethod
from typing import List, Sequence, Dict, Tuple, Union, Optional
from torch.utils.hooks import RemovableHandle
from logging import Logger as LoggerType
from utils.logs import SharedLogger
from prune.hessian_inverse import HybridInverseMixin


# ---------------------------------------------------------------------------
# 2. Layer‑agnostic hybrid OBS base class                                    #
# ---------------------------------------------------------------------------
class OBSHybridPrunerBase(ABC, HybridInverseMixin):
    def __init__(self, layer: nn.Module, block_size: int, device: str = "cpu", use_chunking: bool = False, chunk_size: int = 32):
        self.layer = layer
        self.block_size = block_size
        self.device = device
        self.logger = SharedLogger.get_logger(self.__class__.__name__)
        self.use_chunking = use_chunking
        self.chunk_size = chunk_size
        self.pruned_blocks: List[int] = []

        # candidate bookkeeping --------------------------------------------
        self.selected_blocks: List[int] = []
        self.blocks_to_global_indices: Dict[int, List[int]] = {}
        self.blocks_to_local_indices: Dict[int, List[int]] = {}

        # Hessian / inverse storage ----------------------------------------
        self._H_accum: torch.Tensor | None = None
        self.handle: RemovableHandle | None = None
        self.cols: torch.Tensor | None = None
        self.G_CC: torch.Tensor | None = None
        self.H_panel: torch.Tensor | None = None

        # calibration state -------------------------------------------------
        self.calibration_count: int = 0

    # ----------------------- calibration interface -------------------------
    @torch.inference_mode()
    def register_hook(self):
        """
        Attach a forward pre-hook that feeds each layer input to ``add_batch``.

        Raises:
            RuntimeError: if a Hessian is already being accumulated, or if a hook
                is already attached (a second hook would count every batch twice).
        """
        if self._H_accum is not None:
            raise RuntimeError("Calibration already in progress.")
        if self.handle is not None:
            raise RuntimeError("Hook already registered.")

        def hook(module, inputs):
            # Forward pre-hooks receive the positional args as a tuple; the first
            # one is the tensor whose second moment is accumulated.
            x = inputs[0] if isinstance(inputs, tuple) else inputs
            self.add_batch(x)

        self.handle = self.layer.register_forward_pre_hook(hook)
    
    @torch.inference_mode()
    def remove_hook(self):
        """
        Detach the calibration pre-hook installed by ``register_hook``.

        Raises:
            RuntimeError: if no hook is currently attached.
        """
        current = self.handle
        if current is None:
            raise RuntimeError("No hook registered.")
        current.remove()
        self.handle = None

    @torch.inference_mode()
    def add_batch(self, x: torch.Tensor):
        """
        Accumulate the (unnormalised) second moment ``X^T X`` of one input batch.

        ``x`` is first passed through the subclass hook ``_preprocess_input``; the
        result is used as a 2-D matrix whose rows are samples (counted in
        ``calibration_count``) and whose columns are layer inputs.

        With ``use_chunking`` the product comes from ``_chunked_matrix_multiply``
        and a new accumulator is created on that result's device; otherwise a new
        accumulator is allocated on ``self.device``. If the accumulator already
        exists on another device/dtype (e.g. after ``move_hessian_to_cpu``), the
        batch contribution is moved to it instead of failing on a mismatch.
        """
        x = self._preprocess_input(x)
        self.calibration_count += x.shape[0]

        if self.use_chunking:
            hx = self._chunked_matrix_multiply(x)
            if self._H_accum is None:
                self._H_accum = torch.zeros_like(hx)
            self._H_accum.add_(hx.to(device=self._H_accum.device, dtype=self._H_accum.dtype))
            del hx
            return

        if self._H_accum is None:
            d = x.shape[1]
            # Allocate directly on the target device. A pinned-host staging buffer
            # requires CUDA (so the default device="cpu" failed on CPU-only hosts)
            # and costs an extra d x d allocation for no benefit.
            self._H_accum = torch.zeros(d, d, device=self.device, dtype=x.dtype)
        x = x.to(device=self._H_accum.device, dtype=self._H_accum.dtype)
        self._H_accum.addmm_(x.T, x)
        
    @torch.inference_mode()
    def move_hessian_to_cpu(self):
        """
        Move the Hessian accumulator and panel tensors to CPU to free GPU memory.

        Each of ``_H_accum``, ``cols``, ``G_CC`` and ``H_panel`` that is set and
        lives on a CUDA device is copied to CPU with its dtype unchanged.
        """
        source_devices = set()
        for name in ("_H_accum", "cols", "G_CC", "H_panel"):
            tensor = getattr(self, name)
            if tensor is not None and tensor.device.type == "cuda":
                source_devices.add(tensor.device)
                setattr(self, name, tensor.to(device="cpu", non_blocking=True))
        # A non_blocking device-to-host copy may still be in flight when .to()
        # returns; wait for it so later CPU reads never observe partial data.
        for dev in source_devices:
            torch.cuda.synchronize(dev)
    
    @torch.inference_mode()
    def move_hessian_to_gpu(self):
        """
        Bring the Hessian accumulator and panel tensors back to ``self.device``.

        Each of ``_H_accum``, ``cols``, ``G_CC`` and ``H_panel`` that is set and
        currently on the CPU is copied to ``self.device``; the dtype is passed
        explicitly so it is preserved.
        """
        for attr in ("_H_accum", "cols", "G_CC", "H_panel"):
            tensor = getattr(self, attr)
            if tensor is None or 'cpu' not in tensor.device.type:
                continue
            setattr(self, attr, tensor.to(device=self.device, dtype=tensor.dtype, non_blocking=True))

    @torch.inference_mode()
    def finalize_calibration(self, min_damping: float = 1e-2, max_damping: float = 10.0,
                           max_iterative_iterations: int = 100, iterative_tolerance: float = 1e-6):
        """
        Turn the accumulated Hessian into the panel objects used for pruning.

        Steps:
          1. collect candidate columns C (global input indices of the selected
             blocks, in ``selected_blocks`` order);
          2. divide the accumulator by the sample count;
          3. put 1 on (numerically) zero diagonal entries and let the subclass
             zero the matching weights via ``zero_dead_weights``;
          4. get ``cols = H^{-1}[:, C]`` from ``build_panel_and_inverse`` (damping
             and iterative-solver arguments are forwarded to it);
          5. derive ``G_CC = H^{-1}[C, C]`` and ``H_panel = H[:, C]``;
          6. drop the full accumulator.

        When C is empty the accumulator, ``H_panel`` and ``G_CC`` are cleared and
        the method returns early.

        Raises:
            RuntimeError: if no batch was accumulated, or if the inverse columns
                could not be computed.
        """
        if self._H_accum is None:
            raise RuntimeError("add_batch() must be called before finalising.")

        # 1) candidate columns, concatenated in selected-block order
        ordered_cols: List[int] = [
            col
            for blk in self.selected_blocks
            for col in (self.blocks_to_global_indices.get(blk) or ())
        ]
        C = torch.tensor(ordered_cols, dtype=torch.long, device=self._H_accum.device)
        if C.numel() == 0:
            self._H_accum = None
            self.H_panel = None
            self.G_CC = None
            return
        if C.unique().numel() != C.numel():
            self.logger.warning("Candidate indices C contains duplicates!")

        # 2) average over calibration samples (guard against a zero count)
        self._H_accum /= max(1, self.calibration_count)

        # 3) dead inputs: keep H invertible and let the subclass zero the weights
        dead = torch.diag(self._H_accum).abs() <= 1e-9
        if dead.any():
            dead_idx = torch.arange(self._H_accum.size(0), device=self._H_accum.device)[dead]
            self._H_accum[dead_idx, dead_idx] = 1.0
            self.zero_dead_weights(dead)

        # 4) inverse columns, shape d x m
        self.cols = self.build_panel_and_inverse(
            self._H_accum,
            C,
            damp_scale=min_damping,
            max_damp_scale=max_damping,
            max_iterative_iterations=max_iterative_iterations,
            iterative_tolerance=iterative_tolerance,
        )
        if self.cols is None:
            raise RuntimeError("Failed to compute inverse columns; cols is None")

        # 5) panel objects
        self.G_CC = self.cols[C, :].contiguous()          # H^{-1}[C, C]  (m x m)
        self.H_panel = self._H_accum[:, C].contiguous()   # H[:, C]       (d x m)

        # 6) the full d x d accumulator is no longer needed
        del self._H_accum
        self._H_accum = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
    # ----------------------- exact OBS importance --------------------------
    @torch.inference_mode()
    def importance_all(self, return_tensor: bool = False):
        """
        OBS saliency of every candidate block, computed with Cholesky solves.

        For block b with panel columns P_b, let
        G_PP = G_CC[P_b, P_b] = (H^{-1})[P_b, P_b]. This is a diagonal block of
        the inverse, not the inverse of H_PP. Then:

            score(b) = 0.5 * sum_rows w_{row,P_b}^T G_PP^{-1} w_{row,P_b}
                     = 0.5 * sum_rows w_{row,P_b}^T v_row,  where  G_PP v_row = w_{row,P_b}.

        Args:
            return_tensor: if True, return a (B,) tensor in ``selected_blocks``
                order; otherwise return a dict ``{block_id: score}``.

        Raises:
            RuntimeError: if ``finalize_calibration()`` has not been called.
        """
        if self.G_CC is None or self.H_panel is None:
            raise RuntimeError("finalize_calibration() not yet called.")

        device = self.G_CC.device
        B = len(self.selected_blocks)
        if B == 0:
            return torch.empty(0, device=device) if return_tensor else {}

        k = self.block_size * getattr(self, "kernel_elems", 1)
        out_rows = self.layer.weight.shape[0]
        # Work in at least float32 and at least G_CC's precision. W and the
        # Cholesky factor must then share one dtype and device, because
        # cholesky_solve and matmul reject mixed inputs.
        work = torch.promote_types(self.G_CC.dtype, torch.float32)

        # Gather W for all blocks in selection order: (out, B*k) -> (B, out, k)
        cols_order = [c for blk in self.selected_blocks for c in self.blocks_to_global_indices[blk]]
        W_all = self.layer.weight.flatten(1)[:, cols_order].to(device=device, dtype=work)
        W_blocks = W_all.reshape(out_rows, B, k).permute(1, 0, 2).contiguous()  # (B, out, k)

        # Diagonal k x k blocks of G_CC -> (B, k, k). reshape (not view) also
        # accepts a non-contiguous G_CC.
        G4 = self.G_CC.reshape(B, k, B, k)
        blk_idx = torch.arange(B, device=device)
        G_pp = G4[blk_idx, :, blk_idx, :].to(work).contiguous()

        RHS = W_blocks.transpose(1, 2).contiguous()                 # (B, k, out)
        try:
            L = torch.linalg.cholesky(G_pp)                         # (B, k, k), lower
            Vt = torch.cholesky_solve(RHS, L, upper=False)          # (B, k, out)
        except RuntimeError:
            # G_pp not numerically PD (should be rare if damping worked)
            Vt = torch.linalg.pinv(G_pp) @ RHS                      # (B, k, out)
        V = Vt.transpose(1, 2)                                      # (B, out, k)

        # 0.5 * w^T v per row, summed over k, then over output rows
        scores = (0.5 * (W_blocks * V).sum(dim=2)).sum(dim=1)       # (B,)

        mean_val = scores.mean().item()
        if B > 1:
            std_val = scores.std().item()
            self.logger.debug(f"[importance chol] B={B}, k={k}, mean={mean_val:.3e}, std={std_val:.3e}")
        else:
            # std() of a single element would emit a degrees-of-freedom warning
            self.logger.debug(f"[importance chol] B={B}, k={k}, mean={mean_val:.3e}, std=N/A (single element)")

        if return_tensor:
            return scores
        return dict(zip(self.selected_blocks, scores.tolist()))
    
    # ----------------------- small helpers used by prune routines ------------
    @torch.inference_mode()
    def _apply_obs_update_single_block(
        self,
        P_loc: torch.Tensor,
        block_cols_global: torch.Tensor,
        *,
        R_glob_override: Optional[torch.Tensor] = None,
    ):
        """
        Apply the exact OBS weight update for removing panel columns ``P_loc``.

        For every output row:

            dW_R = - W_P G_PP^{-1} G_RP^T,
            G_PP = G_CC[P, P] = H^{-1}[P, P],  G_RP = cols[R, P] = H^{-1}[R, P].

        Only the columns in R are modified; the pruned columns themselves are
        not zeroed here.

        Args:
            P_loc: local panel column indices of the removed (sub)block.
            block_cols_global: the matching global input-column indices of W.
            R_glob_override: optional survivor set (global input-column indices in
                [0, d)). By default R is every input column not in
                ``block_cols_global``.

        Returns:
            The survivor index tensor R that was updated.
        """
        device = self.G_CC.device
        dtype_w = self.layer.weight.dtype
        # One working dtype for all solve/matmul operands (cholesky_solve and
        # matmul reject mixed dtypes); float32 unless G_CC is wider.
        work = torch.promote_types(self.G_CC.dtype, torch.float32)

        W2d = self.layer.weight.flatten(1)
        out_rows, d = W2d.shape
        W_P = W2d[:, block_cols_global].to(device=device, dtype=work)   # (out, k)

        # v_rows = W_P G_PP^{-1}   (G_PP symmetric)
        G_PP = self.G_CC[P_loc][:, P_loc].to(work)
        try:
            L = torch.linalg.cholesky(G_PP)
            Vt = torch.cholesky_solve(W_P.T.contiguous(), L, upper=False)  # (k, out)
            v_rows = Vt.T.contiguous()                                     # (out, k)
        except RuntimeError:
            v_rows = (torch.linalg.pinv(G_PP) @ W_P.T).T

        # Survivors R
        if R_glob_override is None:
            R_mask = torch.ones(self.H_panel.shape[0], dtype=torch.bool, device=device)
            R_mask[block_cols_global] = False
            R_glob = R_mask.nonzero(as_tuple=True)[0]
        else:
            R_glob = R_glob_override

        # dW_R = - v_rows @ G_RP^T, where G_RP = H^{-1}[R, P]
        G_RP = self.cols[R_glob][:, P_loc].to(work)   # (|R|, k)
        delta = - v_rows @ G_RP.T                     # (out, |R|)

        # Scatter-add onto W[:, R_glob] through the flat weight storage
        row_off = torch.arange(out_rows, device=device).unsqueeze(1) * d
        upd_idx = (row_off + R_glob.unsqueeze(0)).reshape(-1).long()
        self.layer.weight.data.view(-1)[upd_idx] += delta.reshape(-1).to(dtype_w)

        return R_glob

    @torch.inference_mode()
    def _downdate_and_bookkeep(
        self,
        P_loc: torch.Tensor,                  # local panel column indices (|P|k,)
        removed_blocks_local: Sequence[int],  # local block positions
        removed_block_ids: Sequence[int],     # global block ids
    ):
        """
        Remove panel columns ``P_loc`` from (G_CC, H_panel, cols) consistently,
        then update ``selected_blocks`` and rebuild the local index map.

        With the panel split into survivors R and removed P, and G = G_CC:
            T       = G_PP^{-1} G_PR                (Cholesky solve, pinv fallback)
            G'_RR   = G_RR - G_RP T                 (Schur complement on the panel)
            cols'   = cols[:, R] - cols[:, P] T     (surviving inverse columns)
            H_panel = H_panel[:, R]                 (drop removed columns)
        The small solves run in float64; results are cast back to the stored dtypes.
        """
        dev = self.G_CC.device
        f64 = torch.float64

        keep = torch.ones(self.G_CC.shape[0], dtype=torch.bool, device=dev)
        keep[P_loc] = False
        R_loc = keep.nonzero(as_tuple=False).squeeze(1)   # survivor panel columns

        def sub(rows, cols_):
            # G_CC[rows, cols_] promoted to float64
            return self.G_CC[rows][:, cols_].to(f64)

        G_PP, G_PR = sub(P_loc, P_loc), sub(P_loc, R_loc)
        G_RP, G_RR = sub(R_loc, P_loc), sub(R_loc, R_loc)

        # T = G_PP^{-1} G_PR: symmetrise and add a tiny jitter before factoring
        try:
            jitter = 1e-10 * torch.eye(G_PP.shape[0], device=dev, dtype=G_PP.dtype)
            chol = torch.linalg.cholesky(0.5 * (G_PP + G_PP.T) + jitter)
            T = torch.cholesky_solve(G_PR, chol, upper=False)   # (p, r)
        except RuntimeError:
            T = torch.linalg.pinv(G_PP) @ G_PR

        self.G_CC = (G_RR - G_RP @ T).contiguous().to(self.G_CC.dtype)
        self.H_panel = self.H_panel[:, R_loc].contiguous()

        if self.cols is not None:
            old_cols = self.cols
            survivors = old_cols[:, R_loc].to(f64) - old_cols[:, P_loc].to(f64) @ T
            self.cols = survivors.to(old_cols.dtype).contiguous()   # (d, r)

        # Highest position first so earlier positions stay valid while deleting
        for pos in sorted(removed_blocks_local, reverse=True):
            del self.selected_blocks[pos]
        for bid in removed_block_ids:
            self.blocks_to_global_indices.pop(bid, None)

        self._rebuild_local_index_map()

    @torch.inference_mode()
    def _downdate_and_bookkeep1(self, P_loc: torch.Tensor, removed_blocks_local: Sequence[int], removed_block_ids: Sequence[int]):
        """
        Remove panel columns ``P_loc`` using the mixin's ``rank_k_downdate`` for
        (G_CC, H_panel), then rebuild the block bookkeeping.

        ``rank_k_downdate`` is only given G_CC and H_panel, so ``cols``
        (H^{-1}[:, C]) is downdated here first, from the pre-downdate G_CC:
            cols' = cols[:, R] - cols[:, P] G_PP^{-1} G_PR.
        Without this, later OBS updates would index a stale ``cols`` with the
        new, shrunken local indices.

        A mismatch between a local block position and its expected block id is
        logged as a warning (the check is tolerant, not fatal).
        """
        if self.cols is not None:
            f64 = torch.float64
            keep = torch.ones(self.G_CC.shape[0], dtype=torch.bool, device=self.G_CC.device)
            keep[P_loc] = False
            R_loc = keep.nonzero(as_tuple=False).squeeze(1)
            G_PP = self.G_CC[P_loc][:, P_loc].to(f64)
            G_PR = self.G_CC[P_loc][:, R_loc].to(f64)
            try:
                jitter = 1e-10 * torch.eye(G_PP.shape[0], device=G_PP.device, dtype=f64)
                L = torch.linalg.cholesky(0.5 * (G_PP + G_PP.T) + jitter)
                T = torch.cholesky_solve(G_PR, L, upper=False)   # (p, r)
            except RuntimeError:
                T = torch.linalg.pinv(G_PP) @ G_PR
            cols_new = self.cols[:, R_loc].to(f64) - self.cols[:, P_loc].to(f64) @ T
            self.cols = cols_new.to(self.cols.dtype).contiguous()

        # Schur downdate on (G_CC, H_panel)
        self.G_CC, self.H_panel, _ = self.rank_k_downdate(self.G_CC, self.H_panel, P_loc)

        # Drop from selected_blocks, highest position first so indices stay valid
        expected = dict(zip(removed_blocks_local, removed_block_ids))
        for b in sorted(removed_blocks_local, reverse=True):
            bid = self.selected_blocks.pop(b)
            exp = expected.get(b)
            if exp is not None and exp != bid:
                self.logger.warning(
                    f"_downdate_and_bookkeep1: local block {b} holds id {bid}, expected {exp}."
                )

        # Remove from global-index map
        for bid in removed_block_ids:
            self.blocks_to_global_indices.pop(bid, None)

        # Rebuild contiguous local indices
        self._rebuild_local_index_map()

    @torch.inference_mode()
    def _rebuild_local_index_map(self):
        """Assign contiguous local panel indices to ``selected_blocks`` in order.

        The block at position ``p`` gets ``[p*k, (p+1)*k)`` with
        ``k = block_size * kernel_elems``.
        """
        width = self.block_size * getattr(self, "kernel_elems", 1)
        self.blocks_to_local_indices = {
            blk: list(range(pos * width, (pos + 1) * width))
            for pos, blk in enumerate(self.selected_blocks)
        }

    @torch.inference_mode()
    def prune_lowest(self, scores: torch.Tensor | None = None) -> Tuple[int, float]:
        if scores is None:
            scores = self.importance_all(return_tensor=True)
        worst_local = int(torch.argmin(scores).item())
        worst_block = self.selected_blocks[worst_local]
        self._prune_block(worst_local, worst_block)
        self.pruned_blocks.append(worst_block)
        return worst_block, scores[worst_local].item()
    
    # ----------------------- core prune routine ----------------------------
    @torch.inference_mode()
    def _prune_block(self, local_idx: int, block_id: int):
        """
        Remove a single candidate block.

        The method compensates the surviving inputs with the OBS update, zeroes
        the block's input columns in every output row, and then shrinks the
        panel objects and bookkeeping.

        Args:
            local_idx: position of the block in ``selected_blocks``.
            block_id: global id of that block.
        """
        dev = self.G_CC.device
        width = self.block_size * getattr(self, "kernel_elems", 1)

        local_cols = torch.tensor(self.blocks_to_local_indices[block_id], device=dev, dtype=torch.long)
        global_cols = torch.tensor(self.blocks_to_global_indices[block_id][:width], device=dev, dtype=torch.long)

        # OBS compensation on every input column outside the block
        self._apply_obs_update_single_block(local_cols, global_cols)

        # Zero the block's columns through the flat weight storage
        n_out = self.layer.weight.shape[0]
        n_in = self.layer.weight.flatten(1).shape[1]
        row_base = torch.arange(n_out, device=dev).unsqueeze(1) * n_in
        flat_idx = (row_base + global_cols.unsqueeze(0)).reshape(-1).long()
        self.layer.weight.data.view(-1)[flat_idx] = 0

        # Downdate panel objects and bookkeeping
        self._downdate_and_bookkeep(local_cols, [local_idx], [block_id])

    @torch.inference_mode()
    def prune_blocks_rank1_stream(self, B_loc: torch.Tensor, chunk_cols: int = 1):
        """
        Jointly remove several candidate blocks. The exact multi-block OBS update
        is applied in column chunks:

            dW_R = - (W_P G_PP^{-1}) G_RP^T,   G_PP = H^{-1}[P, P],  G_RP = H^{-1}[R, P]

        Here P is the union of the blocks' panel columns and R is every other
        input column. The method works in three stages:

          1. Solve once: V = G_PP^{-1} W_P^T. Because G_PP is symmetric, V[S].T
             equals W_P G_PP^{-1}[:, S] for each chunk S of P.
          2. Apply the n = |P| columns ``chunk_cols`` at a time.
          3. Zero the pruned columns, then call ``_downdate_and_bookkeep``. This
             shrinks G_CC, H_panel and cols together, the same panel downdate
             as ``apply_joint_update_and_downdate``.

        Unlike that method, this one does not append to ``pruned_blocks``.

        Args:
            B_loc: local positions (in ``selected_blocks``) of the blocks to remove.
                An empty tensor is a no-op.
            chunk_cols: number of P-columns applied per step.

        Raises:
            ValueError: if ``chunk_cols < 1``.
        """
        block_list = B_loc.flatten().tolist()
        if not block_list:
            return
        if chunk_cols < 1:
            raise ValueError("chunk_cols must be >= 1")

        device = self.G_CC.device
        dtype_w = self.layer.weight.dtype
        out_rows = self.layer.weight.shape[0]
        in_cols = self.layer.weight.flatten(1).shape[1]
        k = self.block_size * getattr(self, "kernel_elems", 1)
        # One working dtype for all solve/matmul operands; float32 unless G_CC is wider
        work = torch.promote_types(self.G_CC.dtype, torch.float32)

        # Union P: local panel columns and the matching global input columns
        P_cols = torch.cat([torch.arange(b * k, (b + 1) * k, device=device) for b in block_list])
        removed_block_ids = [self.selected_blocks[b] for b in block_list]
        block_cols_union = torch.tensor(
            [c for bid in removed_block_ids for c in self.blocks_to_global_indices[bid][:k]],
            device=device, dtype=torch.long,
        )

        # Survivors R = all input columns not in union(P), fixed for the whole update
        R_mask = torch.ones(self.H_panel.shape[0], dtype=torch.bool, device=device)
        R_mask[block_cols_union] = False
        R_glob = R_mask.nonzero(as_tuple=True)[0]

        # Snapshot the pre-update W_P (advanced indexing copies)
        W_P_full = self.layer.weight.flatten(1)[:, block_cols_union].to(device=device, dtype=work)  # (out, n)

        # V = G_PP^{-1} W_P^T  (n, out): a single factorisation and a single solve
        G_PP = self.G_CC[P_cols][:, P_cols].to(work)
        RHS = W_P_full.T.contiguous()
        try:
            L = torch.linalg.cholesky(G_PP)
            V = torch.cholesky_solve(RHS, L, upper=False)
        except RuntimeError:
            V = torch.linalg.pinv(G_PP) @ RHS

        G_RP_full = self.cols[R_glob][:, P_cols].to(work)   # (|R|, n), G_RP = H^{-1}[R, P]
        row_off = torch.arange(out_rows, device=device).unsqueeze(1) * in_cols
        upd_idx = (row_off + R_glob.unsqueeze(0)).reshape(-1).long()
        flat_w = self.layer.weight.data.view(-1)

        total_cols = P_cols.numel()
        for start in range(0, total_cols, chunk_cols):
            end = min(start + chunk_cols, total_cols)
            v_rows_chunk = V[start:end].T                              # (out, c)
            delta_chunk = - v_rows_chunk @ G_RP_full[:, start:end].T   # (out, |R|)
            flat_w[upd_idx] += delta_chunk.reshape(-1).to(dtype_w)

        # Zero all pruned columns
        prune_idx = (row_off + block_cols_union.unsqueeze(0)).reshape(-1).long()
        flat_w[prune_idx] = 0

        # Consistent downdate of G_CC, H_panel and cols, plus bookkeeping
        self._downdate_and_bookkeep(P_cols, block_list, removed_block_ids)

    # ----------------------- candidate bookkeeping ------------------------
    @torch.inference_mode()
    def get_selected_blocks(self) -> List[int]:
        """Return the live list of current candidate block ids (not a copy)."""
        current = self.selected_blocks
        return current
    
    @torch.inference_mode()
    def set_selected_blocks(self, blocks: Sequence[int]):
        """
        Replace the candidate block set and rebuild the index maps.

        Blocks are stored sorted as ints. Global input-column indices come from
        the subclass hook ``_compute_block_global_indices``. Local panel indices
        are contiguous, ``k = block_size * kernel_elems`` per block, in sorted
        order.

        All panel-dependent state (``cols``, ``G_CC``, ``H_panel``) is discarded,
        because it was laid out for the previous candidate set. Run
        ``finalize_calibration()`` again before scoring or pruning.
        """
        if self.cols is not None:
            import traceback
            self.logger.warning("set_selected_blocks called after finalize_calibration! Clearing caches.")
            traceback.print_stack()

        # Clear previous bookkeeping (in place, so existing references see it too)
        self.selected_blocks.clear()
        self.blocks_to_global_indices.clear()
        self.blocks_to_local_indices.clear()

        # Per-batch caches that may hold GPU tensors
        if hasattr(self, "_WiTWi"):
            self._WiTWi = []
        if hasattr(self, "_block_col_ranges"):
            try:
                del self._block_col_ranges
            except AttributeError:
                # e.g. defined on the class rather than the instance
                pass

        # Panel objects are tied to the old candidate layout
        self.cols = None
        self.G_CC = None
        self.H_panel = None

        # Store blocks in deterministic order
        self.selected_blocks = sorted(map(int, blocks))

        # Global index mapping via subclass hook, then contiguous local indices
        self._compute_block_global_indices()
        self._rebuild_local_index_map()

    @torch.inference_mode()
    def get_pruned_blocks(self) -> List[int]:
        """Return the live list of pruned block ids (not a copy)."""
        history = self.pruned_blocks
        return history
    
    # --- replace reset_pruner ---
    @torch.inference_mode()
    def reset_pruner(self):
        """
        Reset the pruner's working state.

        Drops candidate bookkeeping, the Hessian/panel tensors, the calibration
        hook, known per-iteration caches, and the sample counter.
        ``pruned_blocks`` is left untouched.
        """
        self.selected_blocks = []
        self.blocks_to_global_indices = {}
        self.blocks_to_local_indices = {}

        for name in ("_H_accum", "cols", "G_CC", "H_panel"):
            if getattr(self, name) is not None:
                setattr(self, name, None)

        if self.handle is not None:
            self.remove_hook()
            self.handle = None

        # Per-iteration caches that may reside on GPU; best-effort reset to None
        for name in ('_WiTWi', '_block_col_ranges', 'pivots', 'R_glob_cache', 'P_cols_cache'):
            if not hasattr(self, name):
                continue
            try:
                setattr(self, name, None)
            except Exception:
                pass

        self.calibration_count = 0
        
    @torch.inference_mode()
    def is_pruned(self, block_id: int) -> bool:
        """Return whether ``block_id`` appears in ``pruned_blocks``."""
        return self.pruned_blocks.count(block_id) > 0
    
    @torch.inference_mode()
    def log_memory_status(self, stage: str = ""):
        """Log current memory status for debugging.

        At debug level, logs the CUDA allocator statistics and the size and
        device of each large tensor that is set. Does nothing when CUDA is
        unavailable.
        """
        if not torch.cuda.is_available():
            return
        torch.cuda.synchronize()
        gib = 1024**3
        allocated = torch.cuda.memory_allocated() / gib
        reserved = torch.cuda.memory_reserved() / gib
        max_allocated = torch.cuda.max_memory_allocated() / gib
        self.logger.debug(f"Memory status [{stage}]: allocated={allocated:.3f}GB, reserved={reserved:.3f}GB, max_allocated={max_allocated:.3f}GB")

        tracked = (
            ("_H_accum", self._H_accum),
            ("cols", self.cols),
            ("G_CC", self.G_CC),
            ("H_panel", self.H_panel),
        )
        for label, tensor in tracked:
            if tensor is None:
                continue
            size_gb = tensor.numel() * tensor.element_size() / gib
            self.logger.debug(f"  {label}: {size_gb:.3f}GB on {tensor.device}")

    # ----------------------- helpers for subclasses ------------------------
    
    @torch.inference_mode()
    def _chunked_matrix_multiply(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return ``x.T @ x``, accumulated over row chunks of ``chunk_size`` samples.

        ``addmm_`` adds each chunk's ``rows.T @ rows`` straight into the result,
        so no extra (d, d) temporary is allocated per chunk.

        Args:
            x: 2-D tensor (samples, features).

        Returns:
            (features, features) tensor with ``x``'s device and dtype.
        """
        d = x.shape[1]
        result = torch.zeros(d, d, device=x.device, dtype=x.dtype)
        for start in range(0, x.shape[0], self.chunk_size):
            rows = x[start:start + self.chunk_size]   # (c, d)
            result.addmm_(rows.T, rows)
        return result
  
    @torch.inference_mode()
    def apply_joint_update_and_downdate(self, B_loc: torch.Tensor):
        r"""
        Exact joint OBS removal of several candidate blocks.

        Let P be the union of the blocks' panel columns and R the set of all other
        input columns. The update is

            dW_R = - (W_P G_PP^{-1}) G_RP^T,   G_PP = H^{-1}[P, P],  G_RP = H^{-1}[R, P]

        with G_PP solved by Cholesky and a pinv fallback. After the update:

          - the pruned columns are zeroed;
          - the block ids are appended to ``pruned_blocks``;
          - ``_downdate_and_bookkeep`` shrinks G_CC, H_panel and cols.

        Args:
            B_loc: local positions (in ``selected_blocks``) of the blocks to remove.
                An empty tensor is a no-op.
        """
        block_list = B_loc.flatten().tolist()
        if not block_list:
            return

        device = self.G_CC.device
        dtype_w = self.layer.weight.dtype
        out_rows = self.layer.weight.shape[0]
        in_cols = self.layer.weight.flatten(1).shape[1]
        k = self.block_size * getattr(self, 'kernel_elems', 1)
        # One working dtype for all solve/matmul operands; float32 unless G_CC is wider
        work = torch.promote_types(self.G_CC.dtype, torch.float32)

        # Union P: local panel columns and the matching global input columns
        P_cols = torch.cat([torch.arange(b * k, (b + 1) * k, device=device) for b in block_list])
        removed_block_ids = [self.selected_blocks[b] for b in block_list]
        block_cols_global = torch.tensor(
            [c for bid in removed_block_ids for c in self.blocks_to_global_indices[bid][:k]],
            device=device, dtype=torch.long,
        )

        # Survivors R = all input columns not in the pruned block columns
        R_mask = torch.ones(self.H_panel.shape[0], dtype=torch.bool, device=device)
        R_mask[block_cols_global] = False
        R_glob = R_mask.nonzero(as_tuple=True)[0]

        W_P = self.layer.weight.flatten(1)[:, block_cols_global].to(device=device, dtype=work)  # (out, |P|)

        # v_rows = W_P G_PP^{-1}  (G_PP symmetric)
        G_PP = self.G_CC[P_cols][:, P_cols].to(work)
        try:
            L = torch.linalg.cholesky(G_PP)
            Vt = torch.cholesky_solve(W_P.T.contiguous(), L, upper=False)   # (|P|, out)
            v_rows = Vt.T.contiguous()                                      # (out, |P|)
        except RuntimeError:
            v_rows = (torch.linalg.pinv(G_PP) @ W_P.T).T

        # dW_R = - v_rows @ G_RP^T, where G_RP = H^{-1}[R, P]
        G_RP = self.cols[R_glob][:, P_cols].to(work)   # (|R|, |P|)
        delta = - v_rows @ G_RP.T                      # (out, |R|)

        # Scatter-add onto survivors, then zero the pruned columns
        row_off = torch.arange(out_rows, device=device).unsqueeze(1) * in_cols
        flat_w = self.layer.weight.data.view(-1)
        upd_idx = (row_off + R_glob.unsqueeze(0)).reshape(-1).long()
        flat_w[upd_idx] += delta.reshape(-1).to(dtype_w)
        prune_idx = (row_off + block_cols_global.unsqueeze(0)).reshape(-1).long()
        flat_w[prune_idx] = 0

        # Record pruned block ids, then downdate + bookkeeping
        self.pruned_blocks.extend(removed_block_ids)
        self._downdate_and_bookkeep(P_cols, block_list, removed_block_ids)
   
    # ------------------------------ Batch Pruning ---------------------------------
    @torch.inference_mode()
    def _obs_score_after2(self, J_loc: list[int], i_loc: int) -> torch.Tensor:
        """
        Exact OBS score of candidate block i_loc AFTER jointly removing batch J_loc.
        All indices are LOCAL.
        E'(i) = sum_out || A'_i^{-1} W_i^T ||_F^2, with
        A'_i = A_i - G_{iJ} G_{JJ}^{-1} G_{Ji}
        where A_i = G_CC[i,i] (k×k).
        """
        device     = self.G_CC.device
        dtype_work = torch.float64
        k          = self.block_size * getattr(self, "kernel_elems", 1)

        # slices/cols in local G_CC
        Si = slice(i_loc * k, (i_loc + 1) * k)
        G  = self.G_CC

        # build A_i
        A = G[Si, Si].to(dtype_work)

        # Schur downdate if J_loc non-empty
        if len(J_loc) > 0:
            P_cols = torch.cat([torch.arange(j*k, (j+1)*k, device=device) for j in J_loc])
            GBi    = G[Si, :][:, P_cols].to(dtype_work)          # (k, |J|k)
            GBB    = G[P_cols][:, P_cols].to(dtype_work)         # (|J|k, |J|k)
            try:
                Lbb = torch.linalg.cholesky(
                    GBB + 1e-10 * torch.eye(GBB.shape[0], device=device, dtype=dtype_work)
                )
                GBB_inv = torch.cholesky_inverse(Lbb)
            except RuntimeError:
                GBB_inv = torch.linalg.pinv(GBB)
            A = A - (GBi @ GBB_inv @ GBi.T)                      # A'_i  (k×k)

        # grab Wi (out × k) in GLOBAL-column coordinates for this local block
        blk_id = self.selected_blocks[i_loc]
        gcols  = torch.tensor(self.blocks_to_global_indices[blk_id], device=device, dtype=torch.long)
        W2d    = self.layer.weight.flatten(1).to(torch.float32).to(device)
        Wi     = W2d[:, gcols].to(dtype_work)                    # (out, k)

        # E'(i) = sum_out || A'^{-1} w_row^T ||_2^2
        # solve A'^T X^T = Wi^T  (or A' X^T = Wi^T since A' is symmetric)
        try:
            La = torch.linalg.cholesky(
                A + 1e-10 * torch.eye(A.shape[0], device=device, dtype=dtype_work)
            )
            XiT = torch.cholesky_solve(Wi.T.contiguous(), La, upper=False)   # (k, out)
        except RuntimeError:
            XiT = torch.linalg.solve(A, Wi.T)  # fallback

        # Frobenius: sum of squared norms across rows
        E = (XiT * XiT).sum()
        return E.to(self.G_CC.dtype)

   
   
   
    # ---------- panel helpers (built on selected_blocks / blocks_to_global_indices) ----------

    @torch.inference_mode()
    def _panel_block_k(self) -> int:
        return self.block_size * getattr(self, "kernel_elems", 1)

    @torch.inference_mode()
    def _panel_size(self) -> int:
        # total candidate columns in panel = (#selected blocks) * k
        return len(self.selected_blocks) * self._panel_block_k()

    @torch.inference_mode()
    def _panel_slice_for_local(self, i_loc: int) -> slice:
        k = self._panel_block_k()
        return slice(i_loc * k, (i_loc + 1) * k)

    @torch.inference_mode()
    def _panel_C_cols(self, device=None) -> torch.Tensor:
        """
        Rebuild C (global column indices of candidates) in the exact order
        used to form self.cols = H^{-1}[:, C].
        """
        if device is None:
            device = self.layer.weight.device
        ordered_cols: list[int] = []
        for blk in self.selected_blocks:
            ordered_cols.extend(self.blocks_to_global_indices[blk])
        return torch.tensor(ordered_cols, device=device, dtype=torch.long)

    @torch.inference_mode()
    def _G_CC_from_cols(self) -> torch.Tensor:
        """
        Return the calibrated panel sub-inverse G_CC built at finalize_calibration().
        Do NOT rebuild from self.cols with a fresh C — that’s what caused the mismatches.
        """
        if self.G_CC is None:
            raise RuntimeError("G_CC is not available; call finalize_calibration() first.")
        return self.G_CC

    @torch.inference_mode()
    def _Wi_from_block(self, i_loc: int, *, dtype=torch.float64) -> torch.Tensor:
        """
        Wi (out × k) for local block i_loc, gathered in *global* input-column coords.
        """
        device = self.layer.weight.device
        blk_id = self.selected_blocks[i_loc]
        gcols  = torch.tensor(self.blocks_to_global_indices[blk_id],
                            device=device, dtype=torch.long)
        W2d = self.layer.weight.flatten(1).to(torch.float32)
        return W2d[:, gcols].to(dtype)

    # ---------- numeric primitive: append a k×k block to chol(G_JJ) ----------
    @torch.inference_mode()
    def _chol_append(self,
                    L11: torch.Tensor | None,
                    A21: torch.Tensor,   # (k, |J|k)  G_{iJ}
                    A22: torch.Tensor,   # (k, k)     G_{ii}
                    *,
                    base_jitter: float = 1e-10,
                    max_jitter_scale: float = 1e6) -> torch.Tensor | None:
        """
        Appends a new k×k block to the *lower* Cholesky factor of G_{JJ}.
        If L11 is None, returns chol(A22 + jitter I).
        Robust to mild indefiniteness via jitter & eigen-repair.
        Returns None if S cannot be stably factorized; caller should fall back.
        """
        device = A22.device
        dtype  = A22.dtype
        k = A22.shape[0]

        def _chol_try(M: torch.Tensor) -> torch.Tensor | None:
            # symmetry guard
            M = 0.5 * (M + M.T)
            # escalating jitter
            jitter = base_jitter
            # scale jitter by average diag to keep it relative
            diag = torch.diagonal(M)
            scale = float(diag.abs().mean().item()) if diag.numel() else 1.0
            if scale <= 0:
                scale = 1.0
            while jitter <= max_jitter_scale * scale:
                try:
                    L = torch.linalg.cholesky(M + jitter * torch.eye(M.shape[0], device=device, dtype=dtype))
                    if torch.isfinite(L).all():
                        return L
                except RuntimeError:
                    pass
                jitter *= 10.0
            # eigen-repair
            try:
                w, V = torch.linalg.eigh(M)
                # clip small/negative eigenvalues
                eps = 1e-12 * max(1.0, float(w.abs().max().item()))
                w_clipped = torch.clamp(w, min=eps)
                M_fix = (V * w_clipped) @ V.T
                # one last chol try on repaired matrix
                L = torch.linalg.cholesky(0.5 * (M_fix + M_fix.T))
                return L
            except RuntimeError:
                return None

        if L11 is None or L11.numel() == 0:
            L22 = _chol_try(A22)
            return L22  # may be None

        # Solve Y^T = L11^{-1} A21^T (triangular solve is stable)
        Yt = torch.linalg.solve_triangular(L11, A21.T.contiguous(),
                                        upper=False, left=True)
        Y  = Yt.T.contiguous()  # (k, |J|k)

        # Schur block for the new diagonal: S = A22 - Y Y^T
        S = A22 - (Y @ Y.T)
        S = 0.5 * (S + S.T)  # symmetry guard

        L22 = _chol_try(S)
        if L22 is None:
            return None

        # Stitch enlarged L
        n_old = L11.shape[0]
        L = torch.zeros((n_old + k, n_old + k), device=device, dtype=dtype)
        L[:n_old, :n_old] = L11
        L[n_old:, :n_old] = Y
        L[n_old:, n_old:] = L22
        return L

    # ---------- exact conditioned OBS score using self.cols only ----------

    @torch.inference_mode()
    def _obs_score_after(self,
                        J_loc: list[int],
                        i_loc: int,
                        *,
                        G_CC: torch.Tensor | None = None) -> torch.Tensor:
        """
        Exact OBS score of block i_loc after jointly removing blocks in J_loc.
        Uses self.cols.
        E'(i) = || A'_i^{-1} W_i^T ||_F^2  with
        A'_i = G_{ii} - G_{iJ} G_{JJ}^{-1} G_{Ji}.
        """
        dtype_work = torch.float64
        k = self._panel_block_k()

        # Build G_CC once per caller and pass it in; otherwise reconstruct here
        if G_CC is None:
            G_CC = self._G_CC_from_cols()

        # Panel slices
        Si = self._panel_slice_for_local(i_loc)

        # If J is empty: A'_i = G_{ii}
        if not J_loc:
            Aii = G_CC[Si, Si].to(dtype_work)
            Wi  = self._Wi_from_block(i_loc, dtype=dtype_work)     # (out, k)
            try:
                L = torch.linalg.cholesky(Aii + 1e-10*torch.eye(k, device=Aii.device, dtype=Aii.dtype))
                XiT = torch.cholesky_solve(Wi.T.contiguous(), L, upper=False)   # (k, out)
            except RuntimeError:
                XiT = torch.linalg.solve(Aii, Wi.T)
            E = (XiT * XiT).sum()
            return E.to(self.cols.dtype)

        # Build the current J panel index vector (|J|k,)
        J_panel = torch.cat([torch.arange(j*k, (j+1)*k, device=G_CC.device) for j in J_loc])

        # A'_i = G_{ii} - G_{iJ} G_{JJ}^{-1} G_{Ji}
        G_iJ = G_CC[Si, :][:, J_panel].to(dtype_work)    # (k, |J|k)
        G_JJ = G_CC[J_panel][:, J_panel].to(dtype_work)  # (|J|k, |J|k)
        Aii  = G_CC[Si, Si].to(dtype_work)               # (k, k)

        # Factor G_JJ once (caller may cache this across many i_loc)
        try:
            Ljj = torch.linalg.cholesky(G_JJ + 1e-10*torch.eye(G_JJ.shape[0], device=G_JJ.device, dtype=G_JJ.dtype))
            # Compute G_{iJ} G_{JJ}^{-1} via two triangular solves:
            #   solve Ljj Y^T = G_iJ^T  ⇒  Y = (G_JJ^{-1/2}) G_{Ji}
            Yt = torch.linalg.solve_triangular(Ljj, G_iJ.T.contiguous(), upper=False, left=True)
            #   then (G_{iJ} G_{JJ}^{-1} G_{Ji}) = (Y^T @ Y)
            schur = (Yt.T @ Yt)
        except RuntimeError:
            GJJ_inv = torch.linalg.pinv(G_JJ)
            schur = G_iJ @ GJJ_inv @ G_iJ.T

        Aip = Aii - schur

        # Score with Wi
        Wi  = self._Wi_from_block(i_loc, dtype=dtype_work)          # (out, k)
        try:
            L = torch.linalg.cholesky(Aip + 1e-10*torch.eye(k, device=Aip.device, dtype=Aip.dtype))
            XiT = torch.cholesky_solve(Wi.T.contiguous(), L, upper=False)       # (k, out)
        except RuntimeError:
            XiT = torch.linalg.solve(Aip, Wi.T)
        E = (XiT * XiT).sum()
        return E.to(self.cols.dtype)

    @torch.inference_mode()
    def _uncond_obs_score(self, i_loc: int, *, G_CC: torch.Tensor) -> torch.Tensor:
        """
        E(i) = || G_{ii}^{-1/2} W_i ||_F^2 (i.e., OBS score with J = ∅).
        This is a cheap k×k Cholesky solve.
        """
        dtype_work = torch.float64
        Si  = self._panel_slice_for_local(i_loc)
        k   = Si.stop - Si.start
        Gii = G_CC[Si, Si].to(dtype_work)
        Wi  = self._Wi_from_block(i_loc, dtype=dtype_work)             # (out, k)
        try:
            L = torch.linalg.cholesky(Gii + 1e-10 * torch.eye(k, device=Gii.device, dtype=Gii.dtype))
            XiT = torch.cholesky_solve(Wi.T.contiguous(), L, upper=False)  # (k, out)
        except RuntimeError:
            XiT = torch.linalg.solve(Gii, Wi.T)                            # fallback
        E = (XiT * XiT).sum()
        return E.to(self.cols.dtype)

    @torch.inference_mode()
    def _inflation_upper_factor(self,
                                i_loc: int,
                                *,
                                G_CC: torch.Tensor,
                                J_panel: torch.Tensor,
                                Ljj: torch.Tensor) -> float:
        """
        Compute the spectral inflation proxy \bar{rho}_c(J):
            \bar{rho} = ||G_{cJ}||_F^2 / (lambda_min(G_{cc}) * lambda_min(G_{JJ}))
        using the current Cholesky Ljj of G_{JJ} and tiny k×k eig/safe-Cholesky for G_{cc}.
        Returns +inf if the bound is not trustworthy (e.g., numerically non-PD).
        """
        if J_panel.numel() == 0:
            return float('inf')  # no conditioning ⇒ no inflation factor needed

        device = G_CC.device
        dtype  = torch.float64

        Si  = self._panel_slice_for_local(i_loc)
        Gcc = G_CC[Si, Si].to(dtype)
        GcJ = G_CC[Si, :][:, J_panel].to(dtype)     # (k, |J|k)

        # lambda_min(G_cc): robust via tiny eigvalsh (k is small)
        try:
            eigs_cc = torch.linalg.eigvalsh(0.5 * (Gcc + Gcc.T))
            lam_min_cc = float(eigs_cc.min().item())
        except RuntimeError:
            lam_min_cc = 0.0

        # lambda_min(G_JJ) from Cholesky diag: min diag(L)^2
        if Ljj is None or Ljj.numel() == 0:
            return float('inf')
        with torch.no_grad():
            dL = torch.diagonal(Ljj).to(dtype)
            lam_min_JJ = float((dL * dL).min().item())

        if lam_min_cc <= 0.0 or lam_min_JJ <= 0.0:
            return float('inf')

        num = float((GcJ * GcJ).sum().item())   # ||G_{cJ}||_F^2
        rho_bar = num / (lam_min_cc * lam_min_JJ)
        # If numerical noise makes it negative, clip to 0; if huge, it will be filtered outside.
        if rho_bar < 0.0:
            rho_bar = 0.0
        return rho_bar

    # ---------- greedy batch certifier with bound-based skipping ----------
    @torch.inference_mode()
    def certify_batch(self,
                    scores,
                    *,
                    candidates: torch.Tensor,
                    max_try: int,
                    budget: int | None = None,
                    rel_worsen_tol: float = 0.15,   # stop if best_E worsens >15% vs first pick
                    min_chol_diag: float = 1e-6,    # stop if min diag(L_JJ) < this (ill-conditioning)
                    time_budget_ms: int | None = None,
                    device=None) -> list[int]:
        """
        Greedy exact certification with incremental Cholesky on G_{JJ}.
        Returns up to min(max_try, budget, len(candidates)) LOCAL indices,
        with early-stop based on (a) score worsening, (b) Cholesky conditioning, (c) time budget.
        """
        import time
        if device is None:
            device = self.layer.weight.device

        cand = candidates.to(device=device, dtype=torch.long).unique(sorted=False)
        if cand.numel() == 0:
            return []

        limit = max_try
        if budget is not None:
            limit = min(limit, int(budget))
        limit = min(limit, int(cand.numel()))
        if limit <= 0:
            return []

        G_CC = self._G_CC_from_cols()
        k = self._panel_block_k()
        remaining = cand.tolist()
        J: list[int] = []

        Ljj = None
        J_panel = torch.empty(0, device=G_CC.device, dtype=torch.long)

        first_best_E = None
        t0 = time.time()

        while len(J) < limit and remaining:
            # time budget check
            if time_budget_ms is not None:
                if (time.time() - t0) * 1000.0 >= time_budget_ms:
                    break

            best_i, best_E = None, None

            if len(J) == 0:
                # First pick: no conditioning, exact scores for all
                for i_loc in remaining:
                    E_i = self._obs_score_after([], i_loc, G_CC=G_CC)
                    if (best_E is None) or (E_i < best_E):
                        best_E, best_i = E_i, i_loc

                # record first score for relative-worsen stopping later
                first_best_E = float(best_E.item())

                # initialize Ljj with the chosen block
                Si  = self._panel_slice_for_local(best_i)
                Aii = G_CC[Si, Si]
                Ljj = self._chol_append(None,
                                        torch.empty(0, 0, device=Aii.device, dtype=Aii.dtype),
                                        Aii)
                if Ljj is None:
                    # super-rare: cannot factor first block robustly
                    break
                # conditioning guard
                if float(torch.diagonal(Ljj).min().item()) < min_chol_diag:
                    break
                J_panel = torch.arange(Si.start, Si.stop, device=G_CC.device)

            else:
                # With conditioning: reuse Ljj
                for i_loc in remaining:
                    Si  = self._panel_slice_for_local(i_loc)
                    A21 = G_CC[Si, :][:, J_panel]   # (k, |J|k)
                    A22 = G_CC[Si, Si]              # (k, k)
                    Yt = torch.linalg.solve_triangular(Ljj, A21.T.contiguous(),
                                                    upper=False, left=True)
                    S  = (A22 - (Yt.T @ Yt)).to(torch.float64)
                    Wi = self._Wi_from_block(i_loc, dtype=torch.float64)
                    try:
                        Ls  = torch.linalg.cholesky(S + 1e-10*torch.eye(S.shape[0], device=S.device, dtype=S.dtype))
                        XiT = torch.cholesky_solve(Wi.T.contiguous(), Ls, upper=False)
                    except RuntimeError:
                        XiT = torch.linalg.solve(S, Wi.T)
                    E_i = (XiT * XiT).sum().to(self.cols.dtype)
                    if (best_E is None) or (E_i < best_E):
                        best_E, best_i = E_i, i_loc

                # early-stop: relative worsening vs first pick
                if first_best_E is not None:
                    rel = float(best_E.item()) / max(first_best_E, 1e-30)
                    if rel > (1.0 + rel_worsen_tol):
                        break

                # append best_i to factor
                Si   = self._panel_slice_for_local(best_i)
                A21  = G_CC[Si, :][:, J_panel]
                A22  = G_CC[Si, Si]
                Ljj  = self._chol_append(Ljj, A21, A22)
                if Ljj is None:
                    break
                # conditioning guard after append
                if float(torch.diagonal(Ljj).min().item()) < min_chol_diag:
                    # undo append by stopping here; we keep the previous J
                    break
                J_panel = torch.cat([J_panel, torch.arange(Si.start, Si.stop, device=G_CC.device)])

            # accept and shrink
            J.append(best_i)
            remaining.remove(best_i)

        # trim to limit for safety
        if len(J) > limit:
            J = J[:limit]
        return J
    
    # ---------- helpers: block view for G_CC ----------
    @torch.inference_mode()
    def _GCC_block_view(self, G_CC: torch.Tensor) -> torch.Tensor:
        """
        View G_CC (shape (B*k, B*k)) as a 4D block tensor (B, k, B, k) without copying.
        """
        k = self._panel_block_k()
        Bk = G_CC.shape[0]
        assert Bk % k == 0, f"G_CC side {Bk} is not divisible by k={k}"
        B = Bk // k
        return G_CC.view(B, k, B, k).contiguous()


    # ---------- vectorized scorer used by the certifier ----------
    @torch.inference_mode()
    def _score_many_given_Ljj(self,
                              G_CC: torch.Tensor,
                              Ljj: torch.Tensor | None,
                              J_blocks: torch.Tensor,          # 1D long tensor of LOCAL block ids in J (may be empty)
                              cand_loc: torch.Tensor           # 1D long tensor of LOCAL block ids to score
                              ) -> torch.Tensor:
        """
        Conditioned OBS scores for many candidate blocks at once.

        For each candidate i in ``cand_loc``:
            S_i = G_ii - G_iJ G_JJ^{-1} G_Ji        (S_i = G_ii if Ljj is None or J is empty)
            score_i = || S_i^{-1} W_i^T ||_F^2
        G_JJ^{-1} is applied through its lower Cholesky factor ``Ljj`` (G_JJ ≈ Ljj Ljj^T)
        with a batched triangular solve. The k×k solves go through
        ``HybridInverseMixin._batched_cholesky_solve_with_fallback``.

        All batched linear algebra runs in float64 on G_CC's device. The index tensors
        are moved there, and the weights are gathered only for the candidate blocks.

        Returns:
            1-D tensor (len(cand_loc),) in ``G_CC.dtype``.
        """
        dtype_work = torch.float64
        k = self._panel_block_k()
        g_dev = G_CC.device
        cand_g = cand_loc.to(device=g_dev, dtype=torch.long)

        # Block view and its diagonal k×k blocks for fast G_ii gathering
        G4 = self._GCC_block_view(G_CC)                                    # (B, k, B, k)
        G_diag = G4.diagonal(offset=0, dim1=0, dim2=2).permute(2, 0, 1)    # (B, k, k)
        Aii = G_diag.index_select(0, cand_g).to(dtype_work)                # (Bc, k, k)

        # W_i^T for the candidates only: gather their k panel columns from weight.flatten(1)
        W2d = self.layer.weight.flatten(1).to(torch.float32)               # (out, d)
        panel_cols = self._panel_C_cols(device=W2d.device).view(-1, k)     # (B, k)
        cand_w = cand_loc.to(device=W2d.device, dtype=torch.long)
        cand_cols = panel_cols.index_select(0, cand_w).reshape(-1)         # (Bc*k,)
        Wi = W2d.index_select(1, cand_cols).view(W2d.shape[0], -1, k)      # (out, Bc, k)
        WiT = Wi.permute(1, 2, 0).contiguous().to(dtype=dtype_work, device=g_dev)  # (Bc, k, out)

        # Unconditioned: score = ||G_ii^{-1} W_i^T||_F^2 per candidate
        if (Ljj is None) or (J_blocks.numel() == 0):
            XiT = HybridInverseMixin._batched_cholesky_solve_with_fallback(Aii, WiT)
            return (XiT * XiT).sum(dim=(1, 2)).to(G_CC.dtype)

        # Conditioned on J
        J_g = J_blocks.to(device=g_dev, dtype=torch.long)
        Jk = J_g.numel() * k
        G_iJ = G4.index_select(0, cand_g).index_select(2, J_g)             # (Bc, k, |J|, k)
        # Cast to the working dtype so solve_triangular sees matching dtypes with Ljj_b
        A21T = G_iJ.reshape(-1, k, Jk).transpose(1, 2).to(dtype_work).contiguous()  # (Bc, Jk, k)

        Ljj_b = Ljj.to(dtype=dtype_work, device=g_dev).unsqueeze(0).expand(A21T.shape[0], -1, -1)  # (Bc, Jk, Jk)
        Yt = torch.linalg.solve_triangular(Ljj_b, A21T, upper=False, left=True)     # (Bc, Jk, k)

        S = Aii - Yt.transpose(1, 2) @ Yt                                  # (Bc, k, k)
        XiT = HybridInverseMixin._batched_cholesky_solve_with_fallback(S, WiT)  # (Bc, k, out)
        return (XiT * XiT).sum(dim=(1, 2)).to(G_CC.dtype)


    # ---------- certifier: chunked, vectorized, with robust early stops ----------
    @torch.inference_mode()
    def certify_batch_chunked(self,
                            scores,
                            *,
                            candidates: torch.Tensor,
                            max_try: int,
                            chunk_size: int = 512,
                            budget: int | None = None,
                            rel_worsen_tol: float = 0.15,   # stop if best_E worsens >15% vs first pick
                            min_chol_diag: float = 1e-6,    # stop if min diag(L_JJ) < this (ill-conditioning)
                            device=None) -> list[int]:
        """
        Greedy exact certification with *vectorized* scoring in chunks and an incremental
        Cholesky factor on G_{JJ}. Returns LOCAL indices J (subset of `candidates`) such that
        selecting them in one batch matches the greedy single-add order (up to early stop).
        """
        if device is None:
            device = self.layer.weight.device

        cand = candidates.to(device=device, dtype=torch.long).unique(sorted=False)
        if cand.numel() == 0:
            return []

        # Limit for this batch
        limit = min(int(cand.numel()), int(max_try))
        if budget is not None:
            limit = min(limit, int(budget))
        if limit <= 0:
            return []

        # Build panel inverse once
        G_CC = self._G_CC_from_cols()
        k = self._panel_block_k()

        remaining: list[int] = cand.tolist()
        J: list[int] = []
        J_blocks = torch.empty(0, device=device, dtype=torch.long)  # LOCAL block ids in J
        Ljj: torch.Tensor | None = None

        first_best_E: float | None = None

        while len(J) < limit and remaining:
            # -------- score remaining candidates --------
            if (Ljj is None) or (J_blocks.numel() == 0):
                # Unconditioned: one pass
                rem_tensor = torch.tensor(remaining, device=device, dtype=torch.long)
                scores_rem = self._score_many_given_Ljj(G_CC, None, J_blocks, rem_tensor)
                idx_min = int(torch.argmin(scores_rem).item())
                best_i = remaining[idx_min]
                best_E_t = scores_rem[idx_min]
            else:
                # Conditioned: score in chunks
                best_i = None
                best_E_t = None
                nrem = len(remaining)
                for start in range(0, nrem, chunk_size):
                    end = min(start + chunk_size, nrem)
                    cand_chunk = torch.tensor(remaining[start:end], device=device, dtype=torch.long)
                    scores_chunk = self._score_many_given_Ljj(G_CC, Ljj, J_blocks, cand_chunk)
                    local_idx = int(torch.argmin(scores_chunk).item())
                    local_E = scores_chunk[local_idx]
                    local_i = int(cand_chunk[local_idx].item())
                    if (best_E_t is None) or (local_E < best_E_t):
                        best_E_t = local_E
                        best_i = local_i

            # Early-stop: relative worsening vs first pick
            if first_best_E is None:
                first_best_E = float(best_E_t.item())
            else:
                rel = float(best_E_t.item()) / max(first_best_E, 1e-30)
                if rel > (1.0 + rel_worsen_tol):
                    break

            # -------- append best_i to L_{JJ} --------
            G4 = self._GCC_block_view(G_CC)                   # (B, k, B, k)
            A22 = G4[best_i, :, best_i, :].contiguous()       # (k, k)
            if J_blocks.numel() == 0:
                Ljj = self._chol_append(None,
                                        torch.empty(0, 0, device=G_CC.device, dtype=G_CC.dtype),
                                        A22)
            else:
                # A21 = G_{iJ}: select dim0 by best_i, then dim2 by J_blocks
                tmp = G4.index_select(2, J_blocks)            # (B, k, |J|, k)
                G_iJ = tmp[best_i]                            # (k, |J|, k)
                A21 = G_iJ.reshape(k, -1).contiguous()        # (k, |J|k)
                Ljj = self._chol_append(Ljj, A21, A22)

            if Ljj is None:
                break
            if float(torch.diagonal(Ljj).min().item()) < min_chol_diag:
                break

            # accept and shrink
            J.append(best_i)
            J_blocks = torch.cat([J_blocks,
                                torch.tensor([best_i], device=device, dtype=torch.long)])
            remaining.remove(best_i)

        if len(J) > limit:
            J = J[:limit]
        return J, 
    
    
    @abstractmethod
    def zero_dead_weights(self, dead: torch.Tensor):
        pass
    
    @abstractmethod
    def _preprocess_input(self, x: torch.Tensor) -> torch.Tensor: ...

    @abstractmethod
    def _compute_block_global_indices(self): ...

    @abstractmethod
    def _gather_candidate_weights(self) -> torch.Tensor: ...


# ---------------------------------------------------------------------------
# 3. Linear layer pruner
# ---------------------------------------------------------------------------
class HybridOBSLinearPruner(OBSHybridPrunerBase):
    """
    Structured OBS pruner for ``nn.Linear`` over blocks of input features.

    Block ``b`` covers weight columns ``[b*block_size, (b+1)*block_size)``. Each input
    feature owns exactly one weight column, so ``kernel_elems = 1``.
    """

    def __init__(self, layer: nn.Linear, block_size: int, device: str = "cpu", use_chunking: bool = False, chunk_size: int = 32):
        super().__init__(layer, block_size, device, use_chunking, chunk_size)
        self.d = layer.weight.shape[1]   # number of input features (weight columns)
        self.kernel_elems = 1            # one weight column per input feature

    @torch.inference_mode()
    def _compute_block_global_indices(self):
        """
        Map every selected block to its ``block_size`` global column indices.

        Blocks that would extend past ``self.d`` (a partial tail) are logged and skipped.
        They are also removed from ``self.selected_blocks`` in place. Otherwise every
        later per-block lookup (panel columns, candidate weights, W_i) would raise
        KeyError on them.
        """
        kept: list[int] = []
        for blk in self.selected_blocks:
            s = blk * self.block_size
            e = min(s + self.block_size, self.d)
            if e - s < self.block_size:
                self.logger.warning(f"Skipping partial block idx {blk} (size {e-s} < {self.block_size})")
                continue
            self.blocks_to_global_indices[blk] = list(range(s, e))
            kept.append(blk)
        if len(kept) != len(self.selected_blocks):
            # Slice-assign so other holders of this list see the same filtered contents
            self.selected_blocks[:] = kept

    @torch.inference_mode()
    def _preprocess_input(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten all leading dims of ``x`` into rows, giving (N, in_features) in bfloat16 on the weight's device."""
        if x.ndim < 2:
            raise ValueError(f"Expected tensor with at least 2 dims, got {x.shape}")
        x = x.flatten(0, -2)
        # Optional: uncomment to stabilize further
        # x = F.layer_norm(x, x.shape[-1:])
        return x.to(dtype=torch.bfloat16, device=self.layer.weight.device)

    @torch.inference_mode()
    def _gather_candidate_weights(self) -> torch.Tensor:
        """Concatenate the selected blocks' weight columns (in selected order), then flatten row-major."""
        W = self.layer.weight
        cols = [W[:, self.blocks_to_global_indices[blk]] for blk in self.selected_blocks]
        return torch.cat(cols, dim=1).flatten()

    @torch.inference_mode()
    def zero_dead_weights(self, dead: torch.Tensor):
        """Zero the weight columns (input features) indexed by ``dead``."""
        W = self.layer.weight.data
        W[:, dead] = 0


# ---------------------------------------------------------------------------
# 4. Conv1d layer pruner (input-channel structured)
# ---------------------------------------------------------------------------
class HybridOBSConv1dPruner(OBSHybridPrunerBase):
    """
    Input-channel structured OBS pruner for ``nn.Conv1d``.

    ``weight.flatten(1)`` has shape (out, C_in * kernel_size) with channel-major columns.
    Block ``b`` therefore covers the contiguous columns
    ``[b * block_size * kernel_size, (b+1) * block_size * kernel_size)``.
    """

    def __init__(self, layer: nn.Conv1d, block_size: int, device: str = "cpu", use_chunking: bool = False, chunk_size: int = 32):
        super().__init__(layer, block_size, device, use_chunking, chunk_size)
        self.kernel_size = layer.kernel_size[0]
        self.kernel_elems = self.kernel_size   # weight columns per input channel

    @torch.inference_mode()
    def _compute_block_global_indices(self):
        """Map every selected block to its ``block_size * kernel_size`` global column indices."""
        # NOTE(review): unlike the Linear pruner there is no partial-tail check; confirm
        # C_in is always a multiple of block_size for the layers this is used on.
        stride = self.block_size * self.kernel_size
        for blk in self.selected_blocks:
            s = blk * stride
            e = s + stride
            self.blocks_to_global_indices[blk] = list(range(s, e))

    @torch.inference_mode()
    def _preprocess_input(self, x: torch.Tensor) -> torch.Tensor:
        """
        (N, C, L) → (N*L, C) in bfloat16 on the weight's device.

        NOTE(review): this yields C features per row, not C * kernel_size; the
        ``_input_channel_structured`` variant yields the latter. Confirm which one the
        calibration hook uses.
        """
        if x.ndim != 3:
            raise ValueError(f"Expected 3D activation for Conv1d output, got shape {x.shape}")
        x = x.permute(0, 2, 1).reshape(-1, x.shape[1])
        return x.to(self.layer.weight.device, dtype=torch.bfloat16)

    @torch.inference_mode()
    def _preprocess_input_input_channel_structured(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract Conv1d receptive fields as rows: (N, C, L) → (N * L_out, C * kernel_size), float32.

        The input is treated as a width-1 2-D image so ``F.unfold`` can be used. Kernel,
        dilation, padding and stride all follow the layer, so each row lines up with a
        column of ``weight.flatten(1)``.
        """
        unfold = F.unfold(
            x.unsqueeze(-1),
            kernel_size=(self.kernel_size, 1),
            dilation=(self.layer.dilation[0], 1),
            padding=(self.layer.padding[0], 0),
            stride=(self.layer.stride[0], 1),
        ).transpose(1, 2)
        return unfold.flatten(0, 1).to(self.layer.weight.device, dtype=torch.float32)

    @torch.inference_mode()
    def _gather_candidate_weights(self) -> torch.Tensor:
        """Gather the selected blocks' columns of ``weight.flatten(1)`` (in selected order), flattened."""
        W_flat = self.layer.weight.flatten(1)
        cols: list[int] = []
        for blk in self.selected_blocks:
            cols.extend(self.blocks_to_global_indices[blk])
        return W_flat[:, cols].flatten()

    @torch.inference_mode()
    def _gather_candidate_weights1(self) -> torch.Tensor:
        """
        Legacy alias of :meth:`_gather_candidate_weights`.

        The original body indexed ``blocks_to_global_indices`` (a dict) with the whole
        ``selected_blocks`` list, which always raises TypeError. Delegate instead.
        """
        return self._gather_candidate_weights()

    @torch.inference_mode()
    def zero_dead_weights(self, dead: torch.Tensor):
        """Zero every kernel tap of the input channels indexed by ``dead`` (weight dim 1)."""
        W = self.layer.weight.data
        W[:, dead, :] = 0


# ---------------------------------------------------------------------------
# 5. Conv2d *channel* pruner  (structured over input channels)
# ---------------------------------------------------------------------------
class HybridOBSConv2dPruner(OBSHybridPrunerBase):
    """
    Prunes blocks of input channels in a Conv2d layer.

    Each block is ``block_size`` contiguous input channels, i.e.
    ``block_size * (k_h * k_w)`` contiguous columns of
    ``layer.weight.flatten(1)``.
    """

    def __init__(self, layer: nn.Conv2d, block_size: int, device: str = "cpu", use_chunking: bool = False, chunk_size: int = 32):
        super().__init__(layer, block_size, device, use_chunking, chunk_size)
        kh, kw = layer.kernel_size
        # Flattened weight columns contributed by a single input channel.
        self.spatial = kh * kw
        # Flattened weight columns spanned by one block of input channels.
        self.in_stride = self.block_size * self.spatial
        self.kernel_elems = self.spatial

    @torch.inference_mode()
    def _compute_block_global_indices(self):
        """Record the flattened-weight column range owned by each selected block.

        Block ``blk`` owns columns ``[blk * in_stride, (blk + 1) * in_stride)``.
        Results are written into ``self.blocks_to_global_indices`` in place.
        """
        for blk in self.selected_blocks:
            start = blk * self.in_stride
            end = start + self.in_stride
            self.blocks_to_global_indices[blk] = list(range(start, end))

    @torch.inference_mode()
    def _preprocess_input(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten a 4-D activation into one row per (sample, spatial position).

        Dimension 1 is treated as the channel axis: the result has shape
        ``(N * H * W, C)`` for an ``(N, C, H, W)`` input and is float32 on the
        layer weight's device.

        NOTE(review): the Conv1d pruner's ``_preprocess_input`` produces
        bfloat16 while this one produces float32 — confirm the difference is
        intended.

        Raises:
            ValueError: if ``x`` is not 4-dimensional.
        """
        if x.ndim != 4:
            raise ValueError(f"Expected 4D tensor for Conv2d, got {x.ndim}D")
        x = x.permute(0, 2, 3, 1).reshape(-1, x.shape[1])
        return x.to(self.layer.weight.device, dtype=torch.float32)

    @torch.inference_mode()
    def _preprocess_input_input_channel_structured(self, x: torch.Tensor) -> torch.Tensor:
        """Unfold a Conv2d input into one receptive-field row per output position.

        ``F.unfold`` extracts patches with the layer's kernel size, padding,
        stride *and dilation*, so each row holds exactly the input values the
        convolution multiplies against ``self.layer.weight.flatten(1)``
        (channel-major, kernel-minor column order).

        Returns:
            A ``(N * L_out, C * k_h * k_w)`` bfloat16 tensor on the weight's
            device, where ``L_out`` is the number of output positions.

        NOTE(review): the Conv1d counterpart returns float32 here — confirm
        the dtype difference is intended. ``F.unfold`` always zero-pads, so
        non-``"zeros"`` ``padding_mode`` (or string padding like ``"same"``)
        is not reproduced exactly.
        """
        unfold = F.unfold(
            x,
            kernel_size=self.layer.kernel_size,
            # Without the layer's dilation the patches (and their count) would
            # not match a dilated convolution's receptive fields.
            dilation=self.layer.dilation,
            padding=self.layer.padding,
            stride=self.layer.stride,
        ).transpose(1, 2)
        return unfold.flatten(0, 1).to(self.layer.weight.device, dtype=torch.bfloat16)

    @torch.inference_mode()
    def _gather_candidate_weights(self) -> torch.Tensor:
        """Collect the weights of every selected block into one flat tensor.

        Columns are concatenated in ``self.selected_blocks`` order and the
        ``(out, n_cols)`` slice of ``weight.flatten(1)`` is flattened row-major.
        """
        W_flat = self.layer.weight.flatten(1)
        cols: list[int] = []
        for blk in self.selected_blocks:
            cols.extend(self.blocks_to_global_indices[blk])
        return W_flat[:, cols].flatten()

    @torch.inference_mode()
    def zero_dead_weights(self, dead: torch.Tensor):
        """Zero, in place, all kernel taps of the given input channels.

        ``dead`` indexes dimension 1 of the 4-D Conv2d weight (index tensor or
        boolean mask).
        """
        W = self.layer.weight.data
        W[:, dead, :, :] = 0
