from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch

try:
    import torch.distributed as dist
except Exception:  # pragma: no cover
    dist = None

try:
    from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
except Exception:  # pragma: no cover
    ColumnParallelLinear = None
    RowParallelLinear = None

from quant_layerwise.hadamard import BlockRandomHadamard


def _resize_kv_caches(model: torch.nn.Module, new_batch_size: int):
    """Resize KV caches in all attention layers to accommodate larger batch sizes.

    The model pre-allocates KV caches with max_batch_size (default 32).
    For batched Hessian computation, we may need larger batches.
    """
    for layer in model.layers:
        attn = layer.attention
        old_cache_k = attn.cache_k
        old_cache_v = attn.cache_v

        # Only resize if needed
        if old_cache_k.shape[0] >= new_batch_size:
            continue

        # Get cache dimensions
        _, max_seq_len, n_kv_heads, head_dim = old_cache_k.shape

        # Create new larger caches
        attn.cache_k = torch.zeros(
            (new_batch_size, max_seq_len, n_kv_heads, head_dim),
            device=old_cache_k.device,
            dtype=old_cache_k.dtype,
        )
        attn.cache_v = torch.zeros(
            (new_batch_size, max_seq_len, n_kv_heads, head_dim),
            device=old_cache_v.device,
            dtype=old_cache_v.dtype,
        )

        # Free old caches
        del old_cache_k, old_cache_v


class ActivationCache:
    """Cache hidden states at transformer block boundaries to avoid quadratic recomputation.

    Instead of running the full model for each layer's Hessian computation,
    activations are cached and only the target block is run.  The cache starts
    as the token embeddings (the input to block 0).  Each call to
    :meth:`advance_through_block` replaces it with the output of the current
    block, which is the input to the next one.

    Cached activations are stored on CPU in ``dtype`` and moved to ``device`` on
    demand.  The model is accessed through ``model.tok_embeddings``,
    ``model.freqs_cis`` and ``model.layers``.  Each block is called as
    ``layer(h, start_pos=0, freqs_cis=..., mask=...)``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        dataset: torch.Tensor,
        seqlen: int,
        nsamples: int,
        device: torch.device,
        dtype: torch.dtype | None = None,
        batch_size: int = 1,
    ):
        """Initialize cache by computing embeddings for all samples.

        Args:
            model: Transformer model
            dataset: Token IDs shaped (total_len,) or (nsamples, seqlen). A 1-D
                stream is cut into consecutive ``seqlen`` chunks and any
                incomplete tail is dropped.
            seqlen: Sequence length
            nsamples: Maximum number of samples to cache. It is clipped to what
                the dataset provides; the effective count is ``self.nsamples``.
            device: Device for tensors handed to the model
            dtype: Data type for cached tensors (default: dtype of the model's
                first parameter)
            batch_size: Batch size for processing (used to grow KV caches if needed)

        Raises:
            ValueError: If no sample can be cached, i.e. the dataset has no
                complete ``seqlen`` chunk / no rows, or ``nsamples <= 0``.
        """
        # Shape and validate the dataset first, so bad input is rejected before
        # any side effect on the model (KV-cache resizing).
        if dataset.ndim == 1:
            nseq = int(dataset.shape[0]) // seqlen
            dataset = dataset[: nseq * seqlen].reshape(nseq, seqlen)

        available = int(dataset.shape[0])
        use = min(nsamples, available)
        if use <= 0:
            # Without this, torch.cat([]) in _compute_embeddings fails with an
            # opaque error.
            raise ValueError(
                f"No samples to cache: dataset provides {available} sequence(s) "
                f"of length {seqlen}, nsamples={nsamples}"
            )

        self.model = model
        self.seqlen = seqlen
        self.nsamples = use
        self.device = device
        # Auto-detect dtype from the model if not specified.
        if dtype is None:
            dtype = next(model.parameters()).dtype
        self.dtype = dtype

        # Grow KV caches if batch_size exceeds the model's pre-allocated batch rows.
        if batch_size > 1:
            _resize_kv_caches(model, batch_size)

        # Current block index - cached activations are valid as input to this block.
        self.current_block_idx = 0

        # Compute and cache embeddings for all samples.
        # Shape: [nsamples, seqlen, dim]
        self._cached_h = self._compute_embeddings(dataset[:use])

        # Precompute freqs_cis and the causal mask once (reused for every block).
        self._freqs_cis = model.freqs_cis[:seqlen].to(device)
        self._mask = None
        if seqlen > 1:
            # Causal mask for start_pos=0: -inf strictly above the diagonal and
            # 0 on/below it.  With start_pos=0 there are no cached-prefix
            # columns, so nothing needs to be prepended.
            mask = torch.full((seqlen, seqlen), float("-inf"), device=device)
            self._mask = torch.triu(mask, diagonal=1).to(self.dtype)

    @torch.no_grad()
    def _compute_embeddings(self, tokens: torch.Tensor) -> torch.Tensor:
        """Compute token embeddings for every row of ``tokens``, one row at a time.

        The result is stored on CPU in ``self.dtype`` to save GPU memory;
        slices are moved to ``self.device`` on demand.
        """
        embeddings_list = []
        for i in range(tokens.shape[0]):
            batch = tokens[i:i + 1].to(self.device)
            h = self.model.tok_embeddings(batch)
            embeddings_list.append(h.to(self.dtype).cpu())
        return torch.cat(embeddings_list, dim=0)

    def get_cached_activations(self, sample_idx: int) -> torch.Tensor:
        """Get cached activations for a specific sample.

        Returns tensor of shape [1, seqlen, dim] on the target device.
        """
        return self._cached_h[sample_idx:sample_idx + 1].to(self.device).to(self.dtype)

    def get_cached_activations_batch(self, start_idx: int, batch_size: int) -> torch.Tensor:
        """Get cached activations for a batch of samples.

        Args:
            start_idx: Starting sample index
            batch_size: Number of samples to retrieve. The slice is clipped at
                ``self.nsamples``, so the last batch may be smaller.

        Returns:
            Tensor of shape [<=batch_size, seqlen, dim] on the target device.
        """
        end_idx = min(start_idx + batch_size, self.nsamples)
        return self._cached_h[start_idx:end_idx].to(self.device).to(self.dtype)

    @torch.no_grad()
    def advance_through_block(self, block_idx: int, batch_size: int = 1):
        """Advance cached activations through a transformer block.

        Should be called after all weights in block_idx have been quantized.
        Updates the cache to hold activations that are input to block_idx+1.

        Args:
            block_idx: Which block to advance through; must equal
                ``self.current_block_idx``.
            batch_size: Number of samples to process in parallel (default: 1)

        Raises:
            ValueError: If ``block_idx`` does not match the cache position or
                ``batch_size < 1``.
        """
        if block_idx != self.current_block_idx:
            raise ValueError(
                f"Cannot advance through block {block_idx}, "
                f"cache is at block {self.current_block_idx}"
            )
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        layer = self.model.layers[block_idx]
        new_h_list = []

        # Process in batches for better GPU utilization.
        for i in range(0, self.nsamples, batch_size):
            h = self.get_cached_activations_batch(i, batch_size)
            # Run the full block on the batch.
            h_out = layer(h, start_pos=0, freqs_cis=self._freqs_cis, mask=self._mask)
            new_h_list.append(h_out.to(self.dtype).cpu())

        self._cached_h = torch.cat(new_h_list, dim=0)
        self.current_block_idx = block_idx + 1


def is_linear(module: torch.nn.Module) -> bool:
    """Return True if ``module`` is a linear layer this file knows how to hook.

    Accepts ``torch.nn.Linear`` and, when fairscale imported successfully, its
    ``ColumnParallelLinear`` / ``RowParallelLinear`` layers.
    """
    accepted = tuple(
        cls
        for cls in (ColumnParallelLinear, RowParallelLinear, torch.nn.Linear)
        if cls is not None
    )
    return isinstance(module, accepted)


@dataclass
class HessianResult:
    """Container for the output of ``RuntimeHessian.get()``.

    Attributes:
        H: Accumulated input second-moment matrix (optionally normalized,
            with a small diagonal ridge added by ``get()``).
        nseq: Sum of the leading input dimensions seen by the hook.
        ntokens: Number of flattened input rows seen by the hook.
    """

    H: torch.Tensor  # square (in_features, in_features) matrix
    nseq: int  # sequence count used for "seq" normalization
    ntokens: int  # token count used for "tokens" normalization


class RuntimeHessian:
    """Hook-based Hessian accumulator.

    A forward pre-hook is registered on ``module`` at construction.  Each time
    the module runs, its first positional input is flattened to
    ``(rows, in_features)`` and cast to ``dtype``.  If ``hadamard`` is given it
    is rotated with ``hadamard.right(..., transpose=False)``.  Then ``X^T X``
    is added into the running matrix ``H``.

    Only the process whose rank equals ``dst_rank`` allocates ``H`` and may
    call :meth:`get`.  For a fairscale ``RowParallelLinear`` with more than one
    process, each rank's input is gathered to ``dst_rank`` and concatenated
    along the last dimension before accumulation.

    Call :meth:`close` to remove the hook when done.
    """

    def __init__(
        self,
        module: torch.nn.Module,
        *,
        dst_rank: int = 0,
        hadamard: Optional[BlockRandomHadamard] = None,
        dtype: torch.dtype = torch.float32,
    ):
        """Register the accumulation hook on ``module``.

        Args:
            module: A linear layer accepted by ``is_linear``.
            dst_rank: Rank that owns the accumulator and may call :meth:`get`.
            hadamard: Optional rotation applied to each input before accumulation.
            dtype: Accumulation dtype for ``H``.

        Raises:
            TypeError: If ``module`` is not a supported linear layer.
        """
        if not is_linear(module):
            raise TypeError(f"RuntimeHessian expects a linear module, got {type(module)}")

        self.module = module
        self.dst_rank = int(dst_rank)
        self.hadamard = hadamard
        self.dtype = dtype

        n = int(module.in_features)
        self.n = n

        self._dist_enabled = bool(dist is not None and dist.is_available() and dist.is_initialized())
        self._world_size = dist.get_world_size() if self._dist_enabled else 1
        self._rank = dist.get_rank() if self._dist_enabled else 0

        # Row-parallel inputs are split along the feature dim across ranks
        # (hence the dim=-1 concat in _maybe_gather); assemble them on dst_rank.
        self._do_gather = (
            RowParallelLinear is not None
            and isinstance(module, RowParallelLinear)
            and self._world_size > 1
        )

        self._is_main = (self._rank == self.dst_rank)
        # Only dst_rank holds the (n, n) accumulator.
        self.H = (
            torch.zeros((n, n), device=module.weight.device, dtype=self.dtype)
            if self._is_main
            else None
        )
        self.nseq = 0
        self.ntokens = 0

        # Register hook.
        self._handle = module.register_forward_pre_hook(self._hook)

    def close(self):
        """Remove the forward pre-hook (safe to call more than once)."""
        if self._handle is not None:
            self._handle.remove()
            self._handle = None

    def _maybe_gather(self, X: torch.Tensor) -> Optional[torch.Tensor]:
        """Gather RowParallelLinear inputs to dst_rank so we can form full X.

        Returns ``X`` unchanged when no gathering is needed.  When gathering,
        returns the concatenated input on ``dst_rank`` and ``None`` elsewhere.

        Raises:
            RuntimeError: If gathering is required but torch.distributed is
                not initialized.
        """
        if not self._do_gather:
            return X

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if dist is None or not dist.is_initialized():
            raise RuntimeError("distributed must be initialized for gather")

        # Collective backends (e.g. NCCL) reject non-contiguous tensors.
        X = X.contiguous()
        if self._rank == self.dst_rank:
            # Every slot is fully overwritten by gather, so no zero-fill is needed.
            shards = [torch.empty_like(X) for _ in range(self._world_size)]
            dist.gather(X, shards, dst=self.dst_rank)
            return torch.cat(shards, dim=-1)

        dist.gather(X, None, dst=self.dst_rank)
        return None

    def _hook(self, _module: torch.nn.Module, inputs):
        """Forward pre-hook: accumulate ``X^T X`` from the module's first input."""
        X = inputs[0]
        X = self._maybe_gather(X)
        if X is None:
            return

        if not self._is_main:
            return

        # Leading dim counted as "sequences" -- assumes inputs shaped
        # (batch, ..., in_features); TODO confirm for other call sites.
        self.nseq += int(X.shape[0])

        X = X.detach().reshape(-1, X.shape[-1]).to(self.dtype)
        self.ntokens += int(X.shape[0])

        if self.hadamard is not None:
            X = self.hadamard.right(X, transpose=False)

        self.H.addmm_(X.T, X)

    def get(
        self,
        *,
        normalize: bool = True,
        normalize_by: str = "tokens",
        eps: float = 1e-12,
    ) -> HessianResult:
        """Return a snapshot of the accumulated matrix plus a small diagonal ridge.

        The running accumulator is never modified; the returned ``H`` is a
        single fresh copy.

        Args:
            normalize: Divide by a sample count instead of returning the raw sum.
            normalize_by: ``"tokens"``/``"token"`` divides by accumulated rows;
                ``"seq"``/``"sequences"``/``"batch"`` divides by ``nseq``.
            eps: Value added to every diagonal entry for numerical stability.

        Raises:
            RuntimeError: If called on a non-``dst_rank`` process or before any
                input was accumulated.
            ValueError: If ``normalize`` is set and ``normalize_by`` is unknown.
        """
        if not self._is_main:
            raise RuntimeError("Only dst_rank process can call get().")
        if self.nseq <= 0:
            raise RuntimeError("No samples were accumulated; did the model run forward?")

        if normalize:
            nb = str(normalize_by).strip().lower()
            if nb in ("seq", "sequences", "batch"):
                denom = max(self.nseq, 1)
            elif nb in ("token", "tokens"):
                denom = max(self.ntokens, 1)
            else:
                raise ValueError(f"normalize_by must be 'seq' or 'tokens', got {normalize_by!r}")
            # Division already yields a new tensor; no extra clone needed.
            H = self.H / float(denom)
        else:
            # Copy so the ridge below does not leak into the running accumulator.
            H = self.H.clone()

        # Add a tiny ridge in place through a diagonal view (no index tensors).
        H.diagonal().add_(eps)

        return HessianResult(H=H, nseq=int(self.nseq), ntokens=int(self.ntokens))


@torch.no_grad()
def compute_module_hessian(
    model: torch.nn.Module,
    dataset: torch.Tensor,
    module_name: str,
    *,
    seqlen: int,
    nsamples: int,
    dst_rank: int = 0,
    hadamard_cfg: Optional[dict] = None,
    normalize: bool = True,
    normalize_by: str = "tokens",
    dtype: torch.dtype = torch.float32,
    verbose: bool = True,
):
    """Accumulate the input Hessian of one module by running whole-model forwards.

    Each of up to ``nsamples`` sequences is fed through the full model with
    batch size 1.  A ``RuntimeHessian`` hook on the target module collects its
    inputs.

    Args:
        model: Transformer, called as ``model(tokens, start_pos=0)``.
        dataset: Token IDs, either a flat ``(total_len,)`` stream (cut into
            ``seqlen`` chunks, tail dropped) or a ``(num_seqs, seqlen)`` tensor.
            Each row is moved to the module's weight device before the forward.
        module_name: Key into ``model.named_modules()``.
        seqlen: Sequence length used to chunk a flat stream.
        nsamples: Maximum number of sequences to run.
        dst_rank: Rank that owns the accumulated matrix (distributed gathering).
        hadamard_cfg: If ``{"enabled": True, ...}``, accumulate in a rotated
            basis built from its ``seed`` / ``block_size`` entries.
        normalize: Return an average instead of a raw sum.
        normalize_by: ``"seq"`` or ``"tokens"``.
        dtype: Accumulation dtype.
        verbose: Print a one-line progress header.

    Returns:
        ``(H, nseq, ntokens)`` as produced by ``RuntimeHessian.get``.

    Raises:
        KeyError: If ``module_name`` is not in the model.
        ValueError: If a flat ``dataset`` is shorter than ``seqlen``.
    """
    by_name = dict(model.named_modules())
    if module_name not in by_name:
        raise KeyError(f"Module '{module_name}' not found in model")
    target = by_name[module_name]

    rotation = None
    if hadamard_cfg and hadamard_cfg.get("enabled", False):
        rotation = BlockRandomHadamard(
            int(target.in_features),
            seed=int(hadamard_cfg.get("seed", 0)),
            block_size=int(hadamard_cfg.get("block_size", 128)),
            device=target.weight.device,
            dtype=dtype,
        )

    accumulator = RuntimeHessian(target, dst_rank=dst_rank, hadamard=rotation, dtype=dtype)

    try:
        # Normalize the dataset to (num_seqs, seqlen).
        if dataset.ndim != 1:
            sequences = dataset
        else:
            complete = int(dataset.shape[0]) // seqlen
            if complete <= 0:
                raise ValueError("Dataset too short for requested seqlen")
            sequences = dataset[: complete * seqlen].reshape(complete, seqlen)

        count = min(nsamples, int(sequences.shape[0]))

        if verbose:
            print(f"[hessian] module={module_name}  seqlen={seqlen}  nsamples={count}")

        # One sequence per forward; only the hook's side effect matters.
        target_device = target.weight.device
        for row in range(count):
            model(sequences[row : row + 1].to(target_device), start_pos=0)

        result = accumulator.get(normalize=normalize, normalize_by=normalize_by)
        return result.H, result.nseq, result.ntokens

    finally:
        accumulator.close()


@torch.no_grad()
def compute_module_hessian_cached(
    model: torch.nn.Module,
    cache: ActivationCache,
    layer_id: int,
    module_name: str,
    *,
    dst_rank: int = 0,
    hadamard_cfg: Optional[dict] = None,
    normalize: bool = True,
    normalize_by: str = "tokens",
    dtype: torch.dtype = torch.float32,
    verbose: bool = True,
    batch_size: int = 1,
):
    """Accumulate one module's input Hessian by running only its transformer block.

    The cached hidden states that enter block ``layer_id`` are replayed through
    that single block.  The whole pipeline therefore costs one block per sample
    per layer, rather than re-running every preceding block.

    Args:
        model: Transformer model
        cache: ActivationCache positioned at block ``layer_id``
        layer_id: Which transformer block (0-indexed)
        module_name: Full module name (e.g., "layers.0.attention.wq")
        dst_rank: Rank that owns the accumulated matrix
        hadamard_cfg: If enabled, accumulate in a rotated basis
        normalize: Return average instead of sum
        normalize_by: "seq" or "tokens"
        dtype: Accumulation dtype
        verbose: Print a one-line progress header
        batch_size: Number of cached samples per block forward (default: 1)

    Returns:
        (H, nseq, ntokens) tuple

    Raises:
        ValueError: If the cache is not positioned at ``layer_id``.
        KeyError: If ``module_name`` is not in the model.
    """
    cached_block = cache.current_block_idx
    if cached_block != layer_id:
        raise ValueError(
            f"Cache is at block {cached_block}, "
            f"but requested layer_id={layer_id}"
        )

    by_name = dict(model.named_modules())
    if module_name not in by_name:
        raise KeyError(f"Module '{module_name}' not found in model")

    target = by_name[module_name]
    block = model.layers[layer_id]

    rotation = None
    if hadamard_cfg and hadamard_cfg.get("enabled", False):
        rotation = BlockRandomHadamard(
            int(target.in_features),
            seed=int(hadamard_cfg.get("seed", 0)),
            block_size=int(hadamard_cfg.get("block_size", 128)),
            device=target.weight.device,
            dtype=dtype,
        )

    accumulator = RuntimeHessian(target, dst_rank=dst_rank, hadamard=rotation, dtype=dtype)

    try:
        if verbose:
            print(f"[hessian-cached] module={module_name}  block={layer_id}  nsamples={cache.nsamples}  batch_size={batch_size}")

        # Replay cached inputs through this block only; the hook on the target
        # module does the accumulation, so the block output is discarded.
        freqs_cis, mask = cache._freqs_cis, cache._mask
        for start in range(0, cache.nsamples, batch_size):
            hidden = cache.get_cached_activations_batch(start, batch_size)
            block(hidden, start_pos=0, freqs_cis=freqs_cis, mask=mask)

        result = accumulator.get(normalize=normalize, normalize_by=normalize_by)
        return result.H, result.nseq, result.ntokens

    finally:
        accumulator.close()


# ==============================================================================
# Qronos Statistics Computation
# ==============================================================================

class DualActivationCapture:
    """Record the latest input seen by a pair of modules (unquantized and quantized).

    Forward pre-hooks store each module's first positional input, detached,
    flattened to ``(-1, last_dim)`` and cast to ``dtype``.  Only the most recent
    call is kept.  :meth:`get_activations` hands the pair over and resets it.
    """

    def __init__(
        self,
        module_unquant: torch.nn.Module,
        module_quant: torch.nn.Module,
        dtype: torch.dtype = torch.float32,
    ):
        self.dtype = dtype
        self.X_unquant: Optional[torch.Tensor] = None
        self.X_quant: Optional[torch.Tensor] = None

        self._handle_unquant = module_unquant.register_forward_pre_hook(self._hook_unquant)
        self._handle_quant = module_quant.register_forward_pre_hook(self._hook_quant)

    def _flatten(self, inputs) -> torch.Tensor:
        """Detach the first input, collapse leading dims, and cast to ``self.dtype``."""
        first = inputs[0]
        return first.detach().reshape(-1, first.shape[-1]).to(self.dtype)

    def _hook_unquant(self, _module, inputs):
        self.X_unquant = self._flatten(inputs)

    def _hook_quant(self, _module, inputs):
        self.X_quant = self._flatten(inputs)

    def close(self):
        """Remove both hooks (idempotent)."""
        for attr in ("_handle_unquant", "_handle_quant"):
            handle = getattr(self, attr)
            if handle is not None:
                handle.remove()
                setattr(self, attr, None)

    def get_activations(self):
        """Return ``(X_unquant, X_quant)`` and reset both slots to ``None``."""
        pair = (self.X_unquant, self.X_quant)
        self.X_unquant = None
        self.X_quant = None
        return pair


class ResidualCapture:
    """Capture residual-stream values and the quantized target input during a forward pass.

    For ``"wo"``, the residual is the input of the whole block and the target
    is ``attention.wo``.  For ``"w2"``, the residual is the input of
    ``ffn_norm`` and the target is ``feed_forward.w2``.  Residuals are stored
    detached (unflattened) in ``dtype``.  The target input ``X_hat`` is taken
    from the quantized layer only and is flattened to ``(-1, last_dim)``.
    """

    def __init__(
        self,
        layer_unquant: torch.nn.Module,
        layer_quant: torch.nn.Module,
        weight_type: str,  # "wo" or "w2"
        dtype: torch.dtype = torch.float32,
    ):
        self.weight_type = weight_type
        self.dtype = dtype

        # Residual values (skip-connection inputs) from each model.
        self.R_unquant: Optional[torch.Tensor] = None
        self.R_quant: Optional[torch.Tensor] = None
        # Input to the quantized target layer (wo or w2).
        self.X_hat: Optional[torch.Tensor] = None

        self._handles = []

        if weight_type == "wo":
            # Residual = block input; X_hat = input to the attention output projection.
            source_u, source_q = layer_unquant, layer_quant
            target = layer_quant.attention.wo
        elif weight_type == "w2":
            # Residual = h between attention and FFN, i.e. what ffn_norm receives.
            source_u, source_q = layer_unquant.ffn_norm, layer_quant.ffn_norm
            target = layer_quant.feed_forward.w2
        else:
            raise ValueError(f"Unsupported weight_type: {weight_type}, must be 'wo' or 'w2'")

        self._handles.extend([
            source_u.register_forward_pre_hook(self._hook_residual_unquant),
            source_q.register_forward_pre_hook(self._hook_residual_quant),
            target.register_forward_pre_hook(self._hook_xhat),
        ])

    def _hook_residual_unquant(self, _module, inputs):
        self.R_unquant = inputs[0].detach().to(self.dtype)

    def _hook_residual_quant(self, _module, inputs):
        self.R_quant = inputs[0].detach().to(self.dtype)

    def _hook_xhat(self, _module, inputs):
        first = inputs[0]
        self.X_hat = first.detach().reshape(-1, first.shape[-1]).to(self.dtype)

    def close(self):
        """Remove every registered hook."""
        while self._handles:
            self._handles.pop(0).remove()
        self._handles = []

    def get_and_clear(self):
        """Return (R_unquant, R_quant, X_hat) and clear."""
        captured = (self.R_unquant, self.R_quant, self.X_hat)
        self.R_unquant = self.R_quant = self.X_hat = None
        return captured


@torch.no_grad()
def compute_residual_stats_cached(
    model_unquant: torch.nn.Module,
    model_quant: torch.nn.Module,
    cache_unquant: "ActivationCache",
    cache_quant: "ActivationCache",
    layer_id: int,
    weight_type: str,  # "wo" or "w2"
    *,
    normalize: bool = True,
    normalize_by: str = "tokens",
    dtype: torch.dtype = torch.float32,
    verbose: bool = True,
    batch_size: int = 1,
):
    """Compute residual stream compensation statistics using cached activations.

    For wo/w2 layers that output to the residual stream, computes:
    - Σ_{ΔR,X̂} = E[(R - R̂) X̂^T]

    where:
    - R is the residual from unquantized model
    - R̂ is the residual from quantized model
    - X̂ is the input to wo/w2 from the quantized model

    Args:
        model_unquant: Original (unquantized) transformer model
        model_quant: Quantized transformer model
        cache_unquant: Activation cache for unquantized model
        cache_quant: Activation cache for quantized model
        layer_id: Which transformer block (0-indexed)
        weight_type: "wo" or "w2"
        normalize: Whether to normalize by count
        normalize_by: "seq" or "tokens"
        dtype: Data type for computation
        verbose: Print progress
        batch_size: Batch size for processing

    Returns:
        ResidualStats object with Σ_{ΔR,X̂}

    Raises:
        ValueError: If either cache is not positioned at ``layer_id``, the two
            caches hold different numbers of samples, or ``weight_type`` is
            unsupported.
    """
    if layer_id != cache_unquant.current_block_idx:
        raise ValueError(
            f"Unquant cache is at block {cache_unquant.current_block_idx}, "
            f"but requested layer_id={layer_id}"
        )
    if layer_id != cache_quant.current_block_idx:
        raise ValueError(
            f"Quant cache is at block {cache_quant.current_block_idx}, "
            f"but requested layer_id={layer_id}"
        )
    # Both caches are sliced in lockstep below, and get_cached_activations_batch
    # clips each slice to its own nsamples, so differing counts would produce
    # mismatched (or empty) batches.  Same check as compute_qronos_stats_cached.
    if cache_unquant.nsamples != cache_quant.nsamples:
        raise ValueError(
            f"Cache sample count mismatch: {cache_unquant.nsamples} vs {cache_quant.nsamples}"
        )

    # Deferred project import, done after the cheap argument checks.
    from quant_layerwise.qronos_stats import ResidualStatsAccumulator

    layer_unquant = model_unquant.layers[layer_id]
    layer_quant = model_quant.layers[layer_id]

    # The quantized target layer fixes the accumulator dimensions.
    if weight_type == "wo":
        module = layer_quant.attention.wo
    elif weight_type == "w2":
        module = layer_quant.feed_forward.w2
    else:
        raise ValueError(f"Unsupported weight_type: {weight_type}")

    out_features = module.weight.shape[0]  # output dim (hidden_dim for residual)
    in_features = module.weight.shape[1]   # input dim
    device = module.weight.device

    acc = ResidualStatsAccumulator(out_features, in_features, device=device, dtype=dtype)

    # Hooks that record R, R̂ and X̂ during the block forwards.
    capture = ResidualCapture(layer_unquant, layer_quant, weight_type, dtype=dtype)

    try:
        if verbose:
            print(f"[residual-stats] weight={weight_type}  block={layer_id}  nsamples={cache_unquant.nsamples}  batch_size={batch_size}")

        for i in range(0, cache_unquant.nsamples, batch_size):
            # Cached activations entering this block, from each model.
            h_unquant = cache_unquant.get_cached_activations_batch(i, batch_size)
            h_quant = cache_quant.get_cached_activations_batch(i, batch_size)

            # Run both blocks - this triggers the hooks.
            _ = layer_unquant(h_unquant, start_pos=0, freqs_cis=cache_unquant._freqs_cis, mask=cache_unquant._mask)
            _ = layer_quant(h_quant, start_pos=0, freqs_cis=cache_quant._freqs_cis, mask=cache_quant._mask)

            R_unquant, R_quant, X_hat = capture.get_and_clear()

            # Skip batches where a hook did not fire.
            if R_unquant is not None and R_quant is not None and X_hat is not None:
                acc.accumulate(R_unquant, R_quant, X_hat)

        return acc.get(normalize=normalize, normalize_by=normalize_by)

    finally:
        capture.close()


@torch.no_grad()
def compute_qronos_stats_cached(
    model_unquant: torch.nn.Module,
    model_quant: torch.nn.Module,
    cache_unquant: "ActivationCache",
    cache_quant: "ActivationCache",
    layer_id: int,
    module_name: str,
    *,
    normalize: bool = True,
    normalize_by: str = "tokens",
    dtype: torch.dtype = torch.float32,
    verbose: bool = True,
    batch_size: int = 1,
):
    """Accumulate Qronos statistics by replaying cached inputs through one block of each model.

    The block ``layer_id`` of both the unquantized and the quantized model is
    run on its own cached inputs.  The inputs of ``module_name`` in each model
    are recorded and fed into an accumulator for:
    - Σ_X̂ = E[X̂ X̂^T]
    - Σ_XX̂ = E[X X̂^T]

    Args:
        model_unquant: Original (unquantized) transformer model
        model_quant: Quantized transformer model (with previously quantized layers applied)
        cache_unquant: Activation cache for unquantized model
        cache_quant: Activation cache for quantized model
        layer_id: Which transformer block (0-indexed)
        module_name: Full module name (e.g., "layers.0.attention.wq")
        normalize: Whether to normalize by count
        normalize_by: "seq" or "tokens"
        dtype: Data type for computation
        verbose: Print progress
        batch_size: Batch size for processing

    Returns:
        QronosStats object with Σ_X̂ and Σ_XX̂

    Raises:
        ValueError: If a cache is not at ``layer_id`` or the sample counts differ.
        KeyError: If ``module_name`` is missing from either model.
    """
    from quant_layerwise.qronos_stats import QronosStats, QronosStatsAccumulator

    # Both caches must sit at the requested block boundary.
    for label, cache in (("Unquant", cache_unquant), ("Quant", cache_quant)):
        if cache.current_block_idx != layer_id:
            raise ValueError(
                f"{label} cache is at block {cache.current_block_idx}, "
                f"but requested layer_id={layer_id}"
            )
    if cache_unquant.nsamples != cache_quant.nsamples:
        raise ValueError(
            f"Cache sample count mismatch: {cache_unquant.nsamples} vs {cache_quant.nsamples}"
        )

    unquant_by_name = dict(model_unquant.named_modules())
    quant_by_name = dict(model_quant.named_modules())
    for label, by_name in (("unquant", unquant_by_name), ("quant", quant_by_name)):
        if module_name not in by_name:
            raise KeyError(f"Module '{module_name}' not found in {label} model")

    target_unquant = unquant_by_name[module_name]
    target_quant = quant_by_name[module_name]
    block_unquant = model_unquant.layers[layer_id]
    block_quant = model_quant.layers[layer_id]

    accumulator = QronosStatsAccumulator(
        target_unquant.in_features, device=target_unquant.weight.device, dtype=dtype
    )
    capture = DualActivationCapture(target_unquant, target_quant, dtype=dtype)

    try:
        total = cache_unquant.nsamples
        if verbose:
            print(f"[qronos-cached] module={module_name}  block={layer_id}  nsamples={cache_unquant.nsamples}  batch_size={batch_size}")

        for start in range(0, total, batch_size):
            hidden_unquant = cache_unquant.get_cached_activations_batch(start, batch_size)
            hidden_quant = cache_quant.get_cached_activations_batch(start, batch_size)

            # The forwards themselves are discarded; only the hooks matter.
            block_unquant(hidden_unquant, start_pos=0, freqs_cis=cache_unquant._freqs_cis, mask=cache_unquant._mask)
            block_quant(hidden_quant, start_pos=0, freqs_cis=cache_quant._freqs_cis, mask=cache_quant._mask)

            X, X_hat = capture.get_activations()
            if X is None or X_hat is None:
                continue
            accumulator.accumulate(X, X_hat)

        return accumulator.get(normalize=normalize, normalize_by=normalize_by)

    finally:
        capture.close()
