# Torch Influence is Original work Copyright 2020 Alston Lo, Juhan Bae
# Licensed under the Apache License, Version 2.0 
#  Modifications Copyright 2025 Anonymous Authors
# - Added feature verify_ihvp, check_cg_stability, cg_torch
# - Refactored function inverse_hvp, gnhvp_batch_autograd




from __future__ import annotations

import abc
import logging
import math
import time
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse.linalg as L
import torch
import torch.func as F
from torch import nn
from torch.func import functional_call
from torch.utils import data
from torch.utils.data import DataLoader, Subset


def make_folds(N: int, K: int, *, stratify: Optional[np.ndarray], seed: int) -> List[np.ndarray]:
    """Partition example indices into ``K`` disjoint folds.

    Args:
        N: number of examples; only used when ``stratify`` is ``None``.
        K: number of folds, must be at least 2.
        stratify: optional per-example class labels. When given, the members
            of every class are shuffled and dealt round-robin over the folds.
            The deal restarts at fold 0 for each class, so earlier folds may
            end up slightly larger than later ones.
        seed: seed of the NumPy generator that drives all shuffling.

    Returns:
        A list of ``K`` integer index arrays. Without stratification the
        arrays keep their shuffled order; with stratification each array is
        sorted ascending. A fold may be empty when there are fewer examples
        than folds; it is still returned as an integer array.
    """
    assert K >= 2, "K must be >= 2"
    rng = np.random.default_rng(seed)

    if stratify is None:
        indices = np.arange(N)
        rng.shuffle(indices)
        # The first N % K folds receive one extra element.
        sizes = [N // K + (1 if i < (N % K) else 0) for i in range(K)]
        folds = []
        start = 0
        for sz in sizes:
            folds.append(indices[start:start + sz])
            start += sz
        return folds

    # Stratified split: deal each class's shuffled members round-robin.
    y = np.asarray(stratify)
    buckets = [[] for _ in range(K)]
    for c in np.unique(y):
        members = np.where(y == c)[0]
        rng.shuffle(members)
        for k in range(K):
            buckets[k].extend(members[k::K].tolist())

    # An explicit integer dtype keeps empty folds usable as index arrays
    # (np.array([]) would otherwise default to float64).
    return [np.array(sorted(b), dtype=np.intp) for b in buckets]

def eval_on_indices(model: nn.Module, dataset, idxs: np.ndarray, loss_fn: nn.Module, metrics: Optional[Dict[str, Callable]] = None) -> Dict:
    """Evaluate ``model`` on the subset of ``dataset`` selected by ``idxs``.

    The model is switched to eval mode and run under ``torch.no_grad()`` in
    batches of 128. Loss and metrics are computed from the same forward pass.

    Args:
        model: the model to evaluate. Inputs are moved to the device of its
            first parameter (CPU if the model has no parameters).
        dataset: an indexable dataset yielding ``(x, y)`` pairs.
        idxs: indices into ``dataset`` (array or any sequence of ints).
        loss_fn: callable ``loss_fn(outputs, targets)``; its value is weighted
            by the batch size, i.e. it is treated as a per-batch mean.
        metrics: optional mapping ``name -> fn(outputs, targets)``.

    Returns:
        A dict with ``"loss"`` (size-weighted average), ``"n"`` (number of
        evaluated examples) and one entry per metric. Metric values are
        averaged over batches without size weighting.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    pin_memory = device.type == "cuda"
    loader = DataLoader(
        Subset(dataset, np.asarray(idxs).tolist()),
        batch_size=128,
        shuffle=False,
        pin_memory=pin_memory,
    )

    metric_fns = metrics or {}
    metric_sums = {name: 0.0 for name in metric_fns}

    model.eval()
    total_loss = 0.0
    total_n = 0
    n_batches = 0
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device, non_blocking=pin_memory)
            yb = yb.to(device, non_blocking=pin_memory)
            out = model(xb)
            bs = xb.shape[0]
            total_loss += float(loss_fn(out, yb).item()) * bs
            total_n += bs
            # Reuse the outputs instead of running a second pass for metrics.
            for name, fn in metric_fns.items():
                metric_sums[name] += float(fn(out, yb))
            n_batches += 1

    result = {"loss": total_loss / max(total_n, 1), "n": total_n}
    for name in metric_fns:
        result[name] = metric_sums[name] / max(n_batches, 1)
    return result


def parameters_to_vector(params):
    """Concatenate the detached, flattened tensors in ``params`` into one 1-D tensor."""
    flat_chunks = []
    for tensor in params:
        flat_chunks.append(tensor.detach().reshape(-1))
    return torch.cat(flat_chunks)

def vector_to_parameters(vec, params) -> None:
    """Copy consecutive slices of the flat ``vec`` into each tensor of ``params`` in place.

    Slices are taken in iteration order and reshaped to the target tensor's
    shape; the write goes through ``.data``.
    """
    start = 0
    for param in params:
        stop = start + param.numel()
        param.data.copy_(vec[start:stop].view_as(param))
        start = stop


def flatten_params(model: nn.Module) -> torch.Tensor:
    """Return all parameters of ``model`` as a single detached 1-D tensor."""
    detached = [param.detach() for param in model.parameters()]
    return parameters_to_vector(detached)


def set_params_from_vector(model: nn.Module, vec: torch.Tensor) -> None:
    """Overwrite the parameters of ``model`` in place with the values in the flat ``vec``."""
    model_params = model.parameters()
    vector_to_parameters(vec, model_params)


def apply_param_update(model: nn.Module, vec: torch.Tensor) -> None:
    """Set the parameters of ``model`` to ``vec``.

    Note that ``vec`` replaces the parameter values; it is not added to them.
    """
    set_params_from_vector(model=model, vec=vec)



##############################################
# CLUSTERING HELPERS
##############################################

@dataclass
class ClusterStats:
    """Per-cluster summary statistics, as assembled by ``cluster_means_full``."""

    # Cluster label of every input row, stored exactly as passed in.
    labels: np.ndarray
    # Number of rows assigned to each cluster (one entry per cluster).
    sizes: np.ndarray
    # One row per cluster: the mean of the rows assigned to that cluster
    # (left at zero for clusters with no members).
    means_full: np.ndarray
    # Per cluster: mean squared Euclidean distance of its rows to the cluster mean.
    within_var: np.ndarray

def _kmeans_numpy(X: np.ndarray, C: int, iters: int = 50, seed: int = 0) -> np.ndarray:
    """Cluster the rows of ``X`` with plain Lloyd iterations.

    Args:
        X: 2-D array with one sample per row.
        C: number of clusters; the initial centroids are ``C`` distinct rows
            of ``X`` drawn with a seeded generator.
        iters: maximum number of assign/update rounds.
        seed: seed for the initial centroid draw.

    Returns:
        An integer label in ``[0, C)`` for every row of ``X``.
    """
    rng = np.random.default_rng(seed)
    n = X.shape[0]
    centroids = X[rng.choice(n, size=C, replace=False)]
    # Centroids are means, so keep them in floating point even for integer X.
    if not np.issubdtype(centroids.dtype, np.floating):
        centroids = centroids.astype(np.float64)

    labels = np.zeros(n, dtype=np.int64)
    for it in range(iters):
        # assign: squared Euclidean distance of every row to every centroid
        dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        new_labels = dists.argmin(axis=1)
        # The placeholder labels are all zero; only compare against a real
        # previous assignment, otherwise an all-zero first assignment would
        # be mistaken for convergence before any centroid update.
        if it > 0 and np.array_equal(new_labels, labels):
            break
        labels = new_labels
        # update: empty clusters keep their previous centroid
        for c in range(C):
            members = X[labels == c]
            if len(members) > 0:
                centroids[c] = members.mean(axis=0)
    return labels

def kmeans_cluster(projected: np.ndarray, C: int, seed: int) -> np.ndarray:
    """Cluster the rows of ``projected`` into ``C`` groups.

    scikit-learn's ``KMeans`` is used when it can be imported and fitted.
    If that fails for any reason, the NumPy fallback ``_kmeans_numpy`` is
    used instead and the reason is logged.

    Args:
        projected: 2-D array with one sample per row.
        C: number of clusters.
        seed: random seed passed to whichever implementation runs.

    Returns:
        One cluster label per row.
    """
    try:
        from sklearn.cluster import KMeans  # type: ignore
        km = KMeans(n_clusters=C, random_state=seed, n_init=10)
        return km.fit_predict(projected)
    except Exception as exc:
        # Deliberate best-effort fallback; report why sklearn was not used
        # instead of swallowing the error silently.
        logging.warning("sklearn KMeans unavailable or failed (%s); using NumPy k-means fallback.", exc)
        return _kmeans_numpy(projected, C=C, seed=seed)

def cluster_means_full(full_grads: np.ndarray, labels: np.ndarray) -> ClusterStats:
    """Summarise each cluster of rows in ``full_grads``.

    The number of clusters is ``labels.max() + 1``. For every cluster this
    records its size, the mean of its rows, and the mean squared Euclidean
    distance of its rows to that mean. Clusters without members keep zeros.

    Args:
        full_grads: 2-D array with one row per sample.
        labels: cluster label of each row.

    Returns:
        A ``ClusterStats`` holding the labels and the per-cluster statistics.
    """
    n_clusters = int(labels.max()) + 1
    n_features = full_grads.shape[1]
    grad_dtype = full_grads.dtype

    centroid_rows = np.zeros((n_clusters, n_features), dtype=grad_dtype)
    member_counts = np.zeros(n_clusters, dtype=np.int64)
    spread = np.zeros(n_clusters, dtype=grad_dtype)

    for cluster_id in range(n_clusters):
        member_rows = np.nonzero(labels == cluster_id)[0]
        member_counts[cluster_id] = len(member_rows)
        if member_counts[cluster_id] <= 0:
            continue
        members = full_grads[member_rows]
        centre = members.mean(axis=0)
        centroid_rows[cluster_id] = centre
        sq_dist = ((members - centre) ** 2).sum(axis=1)
        spread[cluster_id] = np.mean(sq_dist)

    return ClusterStats(
        labels=labels,
        sizes=member_counts,
        means_full=centroid_rows,
        within_var=spread,
    )

##############################################
# JL HELPERS
##############################################

def make_jl_matrix(out_dim: int, in_dim: int, seed: int):
    """Build a seeded random sign projection matrix.

    Every entry is ``+1/sqrt(out_dim)`` or ``-1/sqrt(out_dim)``, drawn from a
    NumPy generator seeded with ``seed``.

    Returns:
        A float32 torch tensor of shape ``(out_dim, in_dim)``.
    """
    generator = np.random.default_rng(seed)
    # Uniform bits in {0, 1}, mapped to signs in {-1, +1}.
    bits = generator.integers(0, 2, size=(out_dim, in_dim), endpoint=False)
    signs = (bits * 2 - 1).astype(np.float32)
    scaled = signs / math.sqrt(out_dim)
    return torch.tensor(scaled, dtype=torch.float32)

def project(J, P):
    """Return ``J`` multiplied on the right by the transpose of ``P``."""
    projection_t = P.T
    return J @ projection_t

##############################################
# Torch-influence
##############################################


def _set_attr(obj, names, val):
    """Set a nested attribute: walk ``names[:-1]`` from ``obj``, then assign ``val`` to ``names[-1]``."""
    target = obj
    for attr in names[:-1]:
        target = getattr(target, attr)
    setattr(target, names[-1], val)


def _del_attr(obj, names):
    """Delete a nested attribute: walk ``names[:-1]`` from ``obj``, then remove ``names[-1]``."""
    target = obj
    for attr in names[:-1]:
        target = getattr(target, attr)
    delattr(target, names[-1])


class BaseObjective(abc.ABC):
    """An abstract adapter that provides torch-influence with project-specific information
    about how training and test objectives are computed.

    In order to use torch-influence in your project, a subclass of this module should be
    created that implements this module's four abstract methods.
    """

    @abc.abstractmethod
    def train_outputs(self, model: nn.Module, batch: Any) -> torch.Tensor:
        """Returns a batch of model outputs (e.g., logits, probabilities) from a batch of data.

        Args:
            model: the model.
            batch: a batch of training data.

        Returns:
            the model outputs produced from the batch.
        """

        raise NotImplementedError()

    @abc.abstractmethod
    def train_loss_on_outputs(self, outputs: torch.Tensor, batch: Any) -> torch.Tensor:
        """Returns the **mean**-reduced loss of the model outputs produced from a batch of data.

        Args:
            outputs: a batch of model outputs.
            batch: a batch of training data.

        Returns:
            the loss of the outputs over the batch.

        Note:
            There may be some ambiguity in how to define :meth:`train_outputs()` and
            :meth:`train_loss_on_outputs()`: what point in the forward pass deliniates
            outputs from loss function? For example, in binary classification, the
            outputs can reasonably be taken to be the model logits or normalized probabilities.

            For standard use of influence functions, both choices produce the same behaviour.
            However, if using the Gauss-Newton Hessian approximation for influence functions,
            we require that :meth:`train_loss_on_outputs()` be convex in the model
            outputs.

        See also:
            :class:`CGInfluenceModule`
            :class:`LiSSAInfluenceModule`
        """

        raise NotImplementedError()

    @abc.abstractmethod
    def train_regularization(self, params: torch.Tensor) -> torch.Tensor:
        """Returns the regularization loss at a set of model parameters.

        Args:
            params: a flattened vector of model parameters.

        Returns:
            the regularization loss.
        """

        raise NotImplementedError()

    def train_loss(self, model: nn.Module, params: torch.Tensor, batch: Any) -> torch.Tensor:
        """Returns the **mean**-reduced regularized loss of a model over a batch of data.

        This method should not be overridden for most use cases. By default, torch-influence
        takes and expects the overall training loss to be::

            outputs = train_outputs(model, batch)
            loss = train_loss_on_outputs(outputs, batch) + train_regularization(params)

        Args:
            model: the model.
            params: a flattened vector of the model's parameters.
            batch: a batch of training data.

        Returns:
            the training loss over the batch.
        """

        outputs = self.train_outputs(model, batch)
        return self.train_loss_on_outputs(outputs, batch) + self.train_regularization(params)

    @abc.abstractmethod
    def test_loss(self, model: nn.Module, params: torch.Tensor, batch: Any) -> torch.Tensor:
        """Returns the **mean**-reduced loss of a model over a batch of data.

        Args:
            model: the model.
            params: a flattened vector of the model's parameters.
            batch: a batch of test data.

        Returns:
            the test loss over the batch.
        """

        raise NotImplementedError()


class BaseInfluenceModule(abc.ABC):
    """The core module that contains convenience methods for computing influence functions.

    Args:
        model: the model of interest.
        objective: an implementation of :class:`BaseObjective`.
        train_loader: a training dataset loader.
        test_loader: a test dataset loader.
        device: the device on which operations are performed.
    """

    def __init__(
            self,
            model: nn.Module,
            objective: BaseObjective,
            train_loader: data.DataLoader,
            test_loader: data.DataLoader,
            device: torch.device
    ):
        model.eval()
        self.model = model.to(device)
        self.device = device

        self.is_model_functional = False
        self.params_names = tuple(name for name, _ in self._model_params())
        self.params_shape = tuple(p.shape for _, p in self._model_params())

        self.objective = objective
        self.train_loader = train_loader
        self.test_loader = test_loader

    @abc.abstractmethod
    def inverse_hvp(self, vec: torch.Tensor) -> torch.Tensor:
        """Computes an inverse-Hessian vector product, where the Hessian is specifically
        that of the (mean) empirical risk over the training dataset.

        Args:
            vec: a vector.

        Returns:
            the inverse-Hessian vector product.
        """

        raise NotImplementedError()

    # ====================================================
    # Interface functions
    # ====================================================

    def train_loss_grad(self, train_idxs: List[int]) -> torch.Tensor:
        """Returns the gradient of the (mean) training loss over a set of training
        data points with respect to the model's flattened parameters.

        Args:
            train_idxs: the indices of the training points.

        Returns:
            the loss gradient at the training points.
        """

        return self._loss_grad(train_idxs, train=True)

    def test_loss_grad(self, test_idxs: List[int]) -> torch.Tensor:
        """Returns the gradient of the (mean) test loss over a set of test
        data points with respect to the model's flattened parameters.

        Args:
           test_idxs: the indices of the test points.

        Returns:
           the loss gradient at the test points.
        """

        return self._loss_grad(test_idxs, train=False)

    def stest(self, test_idxs: List[int]) -> torch.Tensor:
        """This function simply composes :func:`inverse_hvp` with :func:`test_loss_grad`.

        In the original influence function paper, the resulting vector was called
        :math:`\mathbf{s}_{\mathrm{test}}`.

        Args:
            test_idxs: the indices of the test points.

        Returns:
            the :math:`\mathbf{s}_{\mathrm{test}}` vector.
        """

        return self.inverse_hvp(self.test_loss_grad(test_idxs))

    def influences(
            self,
            train_idxs: List[int],
            test_idxs: List[int],
            stest: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Returns the influence scores of a set of training data points with respect to
        the (mean) test loss over a set of test data points.

        Specifically, this method returns a 1D tensor of ``len(train_idxs)`` influence scores.
        These scores estimate the following quantities:

            Let :math:`\mathcal{L}_0` be the (mean) test loss of the current model
            over the input test points. Suppose we produce a new model by (1) removing
            the ``train_idxs[i]``-th example from the training dataset and (2) retraining
            the model on this one-smaller dataset. Let :math:`\mathcal{L}` be the (mean)
            test loss of the **new** model over the input test points. Then the ``i``-th
            influence score estimates :math:`\mathcal{L} - \mathcal{L}_0`.

        Args:
            train_idxs: the indices of the training points.
            test_idxs: the indices of the test points.
            stest: this method requires the :math:`\mathbf{s}_{\mathrm{test}}` vector of
                the input test points. If not ``None``, this argument will be used taken as
                :math:`\mathbf{s}_{\mathrm{test}}`. Otherwise, :math:`\mathbf{s}_{\mathrm{test}}`
                will be computed internally with :meth:`stest`.

        Returns:
            the influence scores.
        """

        stest = self.stest(test_idxs) if (stest is None) else stest.to(self.device)

        scores = []
        for grad_z, _ in self._loss_grad_loader_wrapper(batch_size=1, subset=train_idxs, train=True):
            s = grad_z @ stest
            scores.append(s)
        return torch.tensor(scores) / len(self.train_loader.dataset)

    # ====================================================
    # Private helper functions
    # ====================================================

    # Model and parameter helpers

    def _model_params(self, with_names=True):
        assert not self.is_model_functional
        return tuple((name, p) if with_names else p for name, p in self.model.named_parameters() if p.requires_grad)

    def _model_make_functional(self):
        assert not self.is_model_functional
        params = tuple(p.detach().requires_grad_() for p in self._model_params(False))

        for name in self.params_names:
            _del_attr(self.model, name.split("."))
        self.is_model_functional = True

        return params

    def _model_reinsert_params(self, params, register=False):
        for name, p in zip(self.params_names, params):
            _set_attr(self.model, name.split("."), torch.nn.Parameter(p) if register else p)
        self.is_model_functional = not register

    def _flatten_params_like(self, params_like):
        vec = []
        for p in params_like:
            vec.append(p.view(-1))
        return torch.cat(vec)

    def _reshape_like_params(self, vec):
        pointer = 0
        split_tensors = []
        for dim in self.params_shape:
            num_param = dim.numel()
            split_tensors.append(vec[pointer: pointer + num_param].view(dim))
            pointer += num_param
        return tuple(split_tensors)

    # Data helpers

    def _transfer_to_device(self, batch):
        if isinstance(batch, torch.Tensor):
            return batch.to(self.device)
        elif isinstance(batch, (tuple, list)):
            return type(batch)(self._transfer_to_device(x) for x in batch)
        elif isinstance(batch, dict):
            return {k: self._transfer_to_device(x) for k, x in batch.items()}
        else:
            raise NotImplementedError()

    def _loader_wrapper(self, train, batch_size=None, subset=None, sample_n_batches=-1):
        loader = self.train_loader if train else self.test_loader
        if subset is None and sample_n_batches == -1 and batch_size is None:
            data_left = len(loader.dataset)
            for batch in loader:
                batch = self._transfer_to_device(batch)
                size = min(loader.batch_size, data_left)
                yield batch, size
                data_left -= size
            return

    # Loss and autograd helpers

    def _loss_grad_loader_wrapper(self, train, **kwargs):
        params = self._model_params(with_names=False)
        flat_params = self._flatten_params_like(params)

        for batch, batch_size in self._loader_wrapper(train=train, **kwargs):
            loss_fn = self.objective.train_loss if train else self.objective.test_loss
            loss = loss_fn(model=self.model, params=flat_params, batch=batch)
            yield self._flatten_params_like(torch.autograd.grad(loss, params)), batch_size

    def _loss_grad(self, idxs, train):
        grad = 0.0
        for grad_batch, batch_size in self._loss_grad_loader_wrapper(subset=idxs, train=train):
            grad = grad + grad_batch * batch_size
        return grad / len(idxs)

    def _hvp_at_batch(self, batch, flat_params, vec, gnh):

        def f(theta_):
            self._model_reinsert_params(self._reshape_like_params(theta_))
            return self.objective.train_loss(self.model, theta_, batch)

        def out_f(theta_):
            self._model_reinsert_params(self._reshape_like_params(theta_))
            return self.objective.train_outputs(self.model, batch)

        def loss_f(out_):
            return self.objective.train_loss_on_outputs(out_, batch)

        def reg_f(theta_):
            return self.objective.train_regularization(theta_)

        if gnh:
            y, jvp = torch.autograd.functional.jvp(out_f, flat_params, v=vec)
            hjvp = torch.autograd.functional.hvp(loss_f, y, v=jvp)[1]
            gnhvp_batch = torch.autograd.functional.vjp(out_f, flat_params, v=hjvp)[1]
            return gnhvp_batch + torch.autograd.functional.hvp(reg_f, flat_params, v=vec)[1]
        else:
            return torch.autograd.functional.hvp(f, flat_params, v=vec)[1]
    


class CGInfluenceModule(BaseInfluenceModule):
    """Influence module that solves ``(H + damp * I) x = v`` with conjugate gradients.

    ``H`` is the exact or Gauss-Newton Hessian of the training loss, evaluated
    through ``torch.func.functional_call`` so the model itself is never modified.

    Args:
        model: the model of interest.
        objective: an implementation of :class:`BaseObjective`.
        train_loader: a training dataset loader; batches must be ``(x, y)`` pairs.
        test_loader: a test dataset loader.
        device: the device on which operations are performed.
        damp: damping added to the diagonal of the Hessian.
        gnh: if true, use the Gauss-Newton Hessian approximation.
        **kwargs: solver options read by :meth:`inverse_hvp`: ``tol``,
            ``maxiter``, ``x0``, ``curvature_subsample``, ``curvature_seed``
            and ``curvature_batch_frac``.
    """

    def __init__(
        self,
        model: nn.Module,
        objective: BaseObjective,
        train_loader: data.DataLoader,
        test_loader: data.DataLoader,
        device: torch.device,
        damp: float,
        gnh: bool = False,
        **kwargs
    ):
        super().__init__(model, objective, train_loader, test_loader, device)
        self.damp = damp
        self.gnh = gnh

        # Restrict to trainable parameters so flat vectors have the same layout
        # as the gradients produced by the base class. Frozen parameters are
        # left in the module and picked up by functional_call as-is.
        trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
        self.param_names = [n for n, _ in trainable]
        self.param_shapes = [p.shape for _, p in trainable]
        self.param_sizes = [p.numel() for _, p in trainable]
        self.total_dim = sum(self.param_sizes)

        # freeze buffers as a dict for functional_call
        self.buffer_map = {n: b.to(device) for n, b in model.named_buffers()}

        # accepted: tol, maxiter, x0, curvature_subsample, curvature_seed,
        # curvature_batch_frac
        self.cg_kwargs = kwargs

    def _flat_to_param_dict(self, theta_flat: torch.Tensor) -> dict:
        """Splits a flat parameter vector into a ``name -> tensor`` dict."""
        splits = theta_flat.split(self.param_sizes)
        tensors = [t.view(s) for t, s in zip(splits, self.param_shapes)]
        return {n: t for n, t in zip(self.param_names, tensors)}

    def _flat_to_params_tuple(self, theta_flat: torch.Tensor):
        """Splits a flat parameter vector into a tuple of parameter-shaped tensors."""
        splits = theta_flat.split(self.param_sizes)
        return tuple(t.view(s) for t, s in zip(splits, self.param_shapes))

    def _current_theta_flat(self) -> torch.Tensor:
        """Returns the model's current trainable parameters as a detached flat vector."""
        name2p = dict(self.model.named_parameters())
        return torch.cat([name2p[n].detach().reshape(-1).to(self.device) for n in self.param_names])

    def _forward_with_flat(self, theta_flat: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Runs the model on ``x`` with its trainable parameters taken from ``theta_flat``."""
        pdict = self._flat_to_param_dict(theta_flat)
        return functional_call(self.model, {**pdict, **self.buffer_map}, (x,))

    def stest(self, test_idxs: List[int]) -> torch.Tensor:
        """Returns the inverse-Hessian vector product of the test loss gradient.

        :meth:`inverse_hvp` of this class returns a result dict; this override
        extracts the solution vector so that :meth:`influences` receives a tensor.

        Args:
            test_idxs: the indices of the test points.
        """
        return self.inverse_hvp(self.test_loss_grad(test_idxs))["ihvp"]

    def gnhvp_batch_autograd(self, batch, flat_params, vec, gnh: bool):
        """Returns the (undamped) Hessian-vector product of the training loss on one batch.

        Args:
            batch: an ``(x, y)`` pair.
            flat_params: flat vector of trainable parameters at which to evaluate.
            vec: the vector to multiply, same length as ``flat_params``.
            gnh: if true, use the Gauss-Newton approximation for the data term
                plus the exact Hessian of the regularizer; otherwise the exact
                Hessian of data term plus regularizer.
        """
        x, y = batch
        x = x.to(self.device, non_blocking=True)
        if isinstance(y, torch.Tensor):
            y = y.to(self.device, non_blocking=True)

        flat_params = flat_params.to(self.device).requires_grad_(True)
        vec = vec.to(self.device)

        def out_f(theta_flat):
            return self._forward_with_flat(theta_flat, x)  # model outputs

        def loss_f(logits):
            return self.objective.train_loss_on_outputs(logits, (x, y))  # scalar

        def reg_f(theta_flat):
            return self.objective.train_regularization(theta_flat)

        if gnh:
            # J^T (H_out (J v)) + H_reg v
            y_out, jvp = torch.autograd.functional.jvp(out_f, (flat_params,), (vec,))
            _, hjvp = torch.autograd.functional.hvp(loss_f, (y_out,), (jvp,))
            _, gnhvp = torch.autograd.functional.vjp(out_f, (flat_params,), v=hjvp)
            # Non-strict: a zero or linear regularizer has a zero Hessian, which
            # strict mode would reject with an error instead of returning zeros.
            _, reg_hvp = torch.autograd.functional.hvp(reg_f, (flat_params,), (vec,))
            return gnhvp[0] + reg_hvp[0]
        else:
            # Exact Hessian HVP on theta. The regularizer receives the flat
            # vector, as in the Gauss-Newton branch and the base class.
            def f(theta_flat):
                logits = self._forward_with_flat(theta_flat, x)
                return (
                    self.objective.train_loss_on_outputs(logits, (x, y))
                    + self.objective.train_regularization(theta_flat)
                )
            _, hvp = torch.autograd.functional.hvp(f, (flat_params,), (vec,), strict=True)
            return hvp[0]

    @torch.no_grad()
    def _cg_torch(self, matvec, b: torch.Tensor, tol=1e-5, maxiter=200, x0=None):
        """Conjugate Gradients that calls matvec(p) on the same device as b.

        Iterates until the residual norm is at most ``tol`` (an absolute
        tolerance) or ``maxiter`` iterations have run.

        Returns:
            ``(x, iterations, rr, history)``: the solution, the number of
            iterations, the final squared residual norm (a tensor) and the
            list of residual norms, starting with the initial one.
        """
        x = torch.zeros_like(b) if x0 is None else x0.clone()
        r = b - matvec(x)
        p = r.clone()
        rr = torch.dot(r.flatten(), r.flatten())
        rr_history = [torch.sqrt(rr)]
        i = 0
        while i < maxiter and torch.sqrt(rr) > tol:
            Ap = matvec(p)
            denom = torch.dot(p.flatten(), Ap.flatten())
            # A zero or non-finite curvature would turn every later iterate
            # into inf/NaN; stop and return the best iterate so far.
            if denom == 0 or not torch.isfinite(denom):
                logging.warning("CG breakdown at iteration %d (p^T A p = %s); returning current iterate.", i, denom.item())
                break
            alpha = rr / denom
            x = x + alpha * p
            r = r - alpha * Ap
            rr_new = torch.dot(r.flatten(), r.flatten())
            beta = rr_new / rr
            p = r + beta * p
            rr = rr_new
            rr_history.append(torch.sqrt(rr))
            i += 1

        return x, i, rr, [res.item() for res in rr_history]

    def check_cg_stability(self, matvec, b: torch.Tensor, rtol1=1e-6, rtol2=1e-8, maxiter=200):
        """Check stability of CG solution by comparing results with different tolerances.

        Both tolerances are passed to :meth:`_cg_torch` as its ``tol``. The
        solution is reported as stable when the two solves differ by less
        than ``1e-3`` relative to the norm of the second one.
        """
        x1, iter1, rr1, hist1 = self._cg_torch(matvec, b, tol=rtol1, maxiter=maxiter)
        x2, iter2, rr2, hist2 = self._cg_torch(matvec, b, tol=rtol2, maxiter=maxiter)

        # Compute relative difference
        diff_norm = torch.norm(x1 - x2)
        x2_norm = torch.norm(x2)
        if x2_norm > 0:
            rel_diff = diff_norm / x2_norm
        elif diff_norm == 0:
            # Both solutions are exactly zero: they agree.
            rel_diff = torch.zeros_like(diff_norm)
        else:
            rel_diff = torch.full_like(diff_norm, float('inf'))

        return {
            'is_stable': bool(rel_diff < 1e-3),
            'relative_difference': rel_diff.item(),
            'iterations_rtol1': iter1,
            'iterations_rtol2': iter2,
            'residual_rtol1': torch.sqrt(rr1).item(),
            'residual_rtol2': torch.sqrt(rr2).item(),
            'residual_history_rtol1': hist1,
            'residual_history_rtol2': hist2,
        }

    def verify_ihvp(self, v: torch.Tensor, ihvp: torch.Tensor, matvec):
        """Verify the inverse-Hessian-vector product by checking if H(H⁻¹v) ≈ v.

        Returns:
            a dict with the residual norm, the error relative to the norm of
            ``v``, and whether that relative error is below ``1e-3``.
        """
        Hx = matvec(ihvp)
        residual = v - Hx
        residual_norm = torch.norm(residual)
        v_norm = torch.norm(v)
        rel_error = residual_norm / v_norm if v_norm > 0 else torch.tensor(float('inf'))

        return {
            'residual_norm': residual_norm.item(),
            'relative_error': rel_error.item(),
            'is_accurate': bool(rel_error < 1e-3)
        }

    def _select_curvature_batches(self, cached, curvature_batches):
        """Returns the cached batches to use for curvature, optionally a seeded random subset."""
        if not bool(self.cg_kwargs.get("curvature_subsample", False)):
            return cached

        # A fraction, when valid and positive, overrides the batch count.
        curvature_batch_frac = self.cg_kwargs.get("curvature_batch_frac", None)
        if curvature_batch_frac is not None:
            try:
                frac = float(curvature_batch_frac)
            except (TypeError, ValueError):
                frac = None
            if frac is not None:
                frac = max(0.0, min(1.0, frac))
                if frac > 0.0:
                    curvature_batches = max(1, int(round(frac * len(cached))))

        try:
            n_use = int(curvature_batches)
        except (TypeError, ValueError, OverflowError):
            n_use = len(cached)
        n_use = max(1, min(len(cached), n_use))

        # A dedicated CPU generator keeps the draw reproducible for a given seed.
        curvature_seed = self.cg_kwargs.get("curvature_seed", None)
        seed = torch.initial_seed() if curvature_seed is None else curvature_seed
        gen = torch.Generator(device="cpu")
        gen.manual_seed(int(seed) % (2**63 - 1))

        perm = torch.randperm(len(cached), generator=gen)
        return [cached[i] for i in perm[:n_use].tolist()]

    def inverse_hvp(self, vec: torch.Tensor, check_stability=False, verify=True, curvature_batches=16):
        """Solves ``(H + damp * I) x = vec`` with conjugate gradients.

        ``H`` is averaged over the cached training batches (all of them, or a
        seeded subset when ``curvature_subsample`` was passed to the constructor).

        Args:
            vec: flat right-hand-side vector over the trainable parameters.
            check_stability: if true, additionally run :meth:`check_cg_stability`.
            verify: if true, run :meth:`verify_ihvp` on the solution.
            curvature_batches: number of batches to use when subsampling and no
                valid ``curvature_batch_frac`` is configured.

        Returns:
            a dict with keys ``ihvp``, ``iterations``, ``final_residual`` and
            ``residual_history``, plus ``stability_info`` and
            ``verification_info`` when the respective checks ran.

        Raises:
            RuntimeError: if the training loader yields no batches or no examples.
        """
        # vec must be flat tensor on self.device
        vec = vec.to(self.device)

        flat_params = self._current_theta_flat().detach().to(self.device)
        flat_params = flat_params.requires_grad_(True)

        # Cache the full training set on self.device for curvature computation
        cached = []
        for batch in self.train_loader:
            batch = self._transfer_to_device(batch)
            bs = batch[0].shape[0]
            cached.append((batch, bs))

        if len(cached) == 0:
            raise RuntimeError("Cached curvature batches had zero batches.")

        cached_used = self._select_curvature_batches(cached, curvature_batches)

        n_cached = sum(bs for _, bs in cached_used)
        if n_cached <= 0:
            raise RuntimeError("Cached curvature batches had zero total examples.")

        def hvp_matvec(v):
            v = v.to(self.device)
            # CG runs under no_grad; autograd must be re-enabled for the HVP.
            with torch.enable_grad():
                hvp = torch.zeros_like(v)
                for batch, bs in cached_used:
                    hvp_b = self.gnhvp_batch_autograd(batch, flat_params, vec=v, gnh=self.gnh)
                    hvp.add_(hvp_b, alpha=bs)
                hvp = hvp / n_cached          # IMPORTANT: divide by cached count, not full N
                hvp = hvp + self.damp * v
                return hvp.detach()

        # CG args
        tol = self.cg_kwargs.get('tol', 1e-5)
        maxiter = self.cg_kwargs.get('maxiter', 200)
        x0 = self.cg_kwargs.get('x0', None)
        if x0 is not None:
            x0 = x0.to(self.device)

        stability_info = None
        verification_info = None

        # Run stability check if requested
        if check_stability:
            # Compare a solve at a 10x looser tolerance against the requested
            # one; the two tolerances must differ for the check to mean anything.
            stability_info = self.check_cg_stability(
                hvp_matvec, vec,
                rtol1=tol * 10,
                rtol2=tol,
                maxiter=maxiter
            )
            if not stability_info['is_stable']:
                logging.warning(f"CG solution may be unstable. Relative difference: {stability_info['relative_difference']:.6e}")

        # Run regular CG
        ihvp, info, rr, rr_history = self._cg_torch(
            hvp_matvec, b=vec, tol=tol, maxiter=maxiter, x0=x0
        )

        # Verify IHVP if requested
        if verify:
            verification_info = self.verify_ihvp(vec, ihvp, hvp_matvec)
            if not verification_info['is_accurate']:
                logging.warning(f"IHVP verification failed. Relative error: {verification_info['relative_error']:.6e}")

        # functional_call never modifies the module, so the model's parameters
        # are left untouched and nothing needs to be reinserted here.
        result = {
            'ihvp': ihvp,
            'iterations': info,
            'final_residual': torch.sqrt(rr).item(),
            'residual_history': rr_history
        }

        if stability_info is not None:
            result['stability_info'] = stability_info

        if verification_info is not None:
            result['verification_info'] = verification_info

        return result
