from functools import partial

import torch.nn.functional as F
from typing import List, Dict, Tuple
import logging
from llava_hooks import register_llava_hooks
from transformers import AutoProcessor

import torch

def ablate_head_llava(layer, head_idx, scheme="zero", patching_cache=None):
    """
    Build a forward hook that ablates one head-sized slice of a layer's output.

    The hooked module's output (or its first element, when it returns a tuple) is
    expected to be ``[batch, seq, d_model]``. It is viewed as
    ``[batch, seq, n_heads, d_head]`` and channel block ``head_idx`` is replaced.

    NOTE(review): if this hook is registered on a HF ``LlamaAttention`` module, its
    output has already gone through ``o_proj``. In that case these channel blocks
    are not the per-head attention outputs; confirm this is the intended target.

    Args:
      layer: int, layer index. Selects the cache key ``f"llama_layer{layer}_attn"``.
      head_idx: int, index of the head-sized block to ablate.
      scheme: "zero" sets the block to 0. "mean" replaces it with the mean of this
        head's cached activations.
      patching_cache: dict whose ``f"llama_layer{layer}_attn"`` entry has shape
        ``[batch, seq, n_heads, d_head]`` or ``[batch, n_heads, d_head]``.
        Required for "mean". Optional for "zero": when omitted, the head count is
        read from the hooked module (``num_heads`` / ``num_attention_heads`` /
        ``config.num_attention_heads``).

    Returns:
      A forward hook ``hook(module, input, output)``. It returns the ablated output
      in the same container form (tensor or tuple) it received.

    Raises:
      ValueError: immediately, for an unknown scheme or for "mean" without a cache.
      RuntimeError: from the hook, when the cache is not 3-D/4-D or the head layout
        cannot be determined.
    """
    if scheme not in ("zero", "mean"):
        raise ValueError(f"Unsupported ablation scheme: {scheme}")
    if scheme == "mean" and patching_cache is None:
        raise ValueError("scheme='mean' requires a patching_cache")

    cache_key = f"llama_layer{layer}_attn"

    def _head_layout(module, d_model):
        # Prefer the cache as the source of truth. Fall back to the module so that
        # zero-ablation works without any cached activations.
        if patching_cache is not None:
            value = patching_cache[cache_key]
            if value.ndim not in (3, 4):
                raise RuntimeError(f"Unexpected cache shape: {value.shape}")
            return value.shape[-2], value.shape[-1]
        n_heads = (
            getattr(module, "num_heads", None)
            or getattr(module, "num_attention_heads", None)
            or getattr(getattr(module, "config", None), "num_attention_heads", None)
        )
        if not n_heads or d_model % n_heads:
            raise RuntimeError(
                f"Cannot infer head layout for d_model={d_model}; pass patching_cache"
            )
        return n_heads, d_model // n_heads

    def hook(module, input, output):
        is_tuple = isinstance(output, tuple)
        # Clone so the module's original output tensor is never modified in place.
        out = (output[0] if is_tuple else output).clone()
        batch, seq, d_model = out.shape
        n_heads, d_head = _head_layout(module, d_model)
        out = out.view(batch, seq, n_heads, d_head)

        if scheme == "mean":
            value = patching_cache[cache_key]
            if value.ndim == 4:
                # Per-sample mean over the cached positions -> [batch, d_head].
                mean_vec = value[:, :, head_idx, :].mean(dim=1)
            else:
                # A 3-D cache has no seq axis. Average over its batch axis and
                # broadcast that single vector to every sample -> [batch, d_head].
                mean_vec = value[:, head_idx, :].mean(dim=0, keepdim=True).expand(batch, -1)
            out[:, :, head_idx, :] = mean_vec[:, None, :].to(device=out.device, dtype=out.dtype)
        else:
            out[:, :, head_idx, :] = 0.0

        out = out.view(batch, seq, d_model)
        return (out,) + output[1:] if is_tuple else out

    return hook


def ablate_mlp_llava(layer, scheme="zero", patching_cache=None):
    """
    Build a forward hook that ablates a layer's MLP output.

    Supports two ablation methods:
    - scheme="mean": replace the output with the per-sample mean of the cached
      activations ``patching_cache[f"llama_layer{layer}_mlp"]``. A
      ``[batch, seq, d_model]`` cache is averaged over seq; a ``[batch, d_model]``
      cache is used as-is.
    - scheme="zero": replace the output with zeros. No cache is needed.

    The hooked output (or its first element, when it is a tuple) may be
    ``[batch, seq, d_model]`` or ``[batch, d_model]``. The output is cloned, so
    the module's original tensor is not modified.

    Raises:
      ValueError: immediately, for an unknown scheme or for "mean" without a cache.
      RuntimeError: from the hook, when the cache is not 2-D or 3-D.
    """
    if scheme not in ("zero", "mean"):
        raise ValueError(f"Unsupported ablation scheme: {scheme}")
    if scheme == "mean" and patching_cache is None:
        raise ValueError("scheme='mean' requires a patching_cache")

    cache_key = f"llama_layer{layer}_mlp"

    def hook(module, input, output):
        is_tuple = isinstance(output, tuple)
        out = (output[0] if is_tuple else output).clone()

        if scheme == "zero":
            out.zero_()
        else:
            value = patching_cache[cache_key]
            if value.ndim == 3:
                mean_vec = value.mean(dim=1)  # [batch, d_model]
            elif value.ndim == 2:
                mean_vec = value  # [batch, d_model]
            else:
                raise RuntimeError(f"MLP cache shape error: {value.shape}")
            mean_vec = mean_vec.to(device=out.device, dtype=out.dtype)
            # Only add a seq axis when the output actually has one.
            if out.ndim == 3:
                mean_vec = mean_vec[:, None, :]
            out.copy_(mean_vec.expand_as(out))

        return (out,) + output[1:] if is_tuple else out

    return hook

def ablate_head_llava_plus(layer, head_idx, scheme="zero", patching_cache=None, alpha=0.0, global_indices=None):
    """
    Build a forward hook that ablates a single head-sized slice (``head_idx``).

      - scheme="zero": set this slice of the output to zero.
      - scheme="mean": replace it with
        ``alpha * original + (1 - alpha) * mean_of_other_heads``. The mean is taken
        over all other heads in the cached activations of the same layer (or over
        all heads if the layer has only one).

    The mean branch distinguishes incremental from full inference:
      - When the cache is 4-D and its seq length differs from the output's (e.g.
        incremental decoding with out.seq == 1), only the last cached position is
        used, broadcast over the output's seq axis.
      - When the seq lengths match, each position uses its own cross-head mean.
      - A 3-D cache (no seq axis) is broadcast over the output's seq axis.

    NOTE(review): if registered on a HF ``LlamaAttention`` module, the output is
    already projected by ``o_proj``, so these channel blocks are not per-head
    outputs; confirm this is the intended target.

    Args:
      layer: int, layer index; selects the cache key ``f"llama_layer{layer}_attn"``.
      head_idx: int, the head-sized slice to ablate.
      scheme: "zero" or "mean".
      patching_cache: dict whose ``f"llama_layer{layer}_attn"`` entry is
        ``[batch, seq_full, n_heads, d_head]`` or ``[batch, n_heads, d_head]``.
        Required for "mean". Optional for "zero": the head count is then read
        from the hooked module.
      alpha: weight kept on the original head output in the "mean" scheme.
      global_indices: optional indices into the cache's batch axis, selecting the
        rows that belong to the batch currently running through the model.

    Returns:
      A forward hook that can be registered on the layer's self_attn.

    Raises:
      ValueError: immediately, for an unknown scheme or for "mean" without a cache.
      RuntimeError: from the hook, when the cache is not 3-D/4-D or the head layout
        cannot be determined.
    """
    if scheme not in ("zero", "mean"):
        raise ValueError(f"Unsupported ablation scheme: {scheme}")
    if scheme == "mean" and patching_cache is None:
        raise ValueError("scheme='mean' requires a patching_cache")

    cache_key = f"llama_layer{layer}_attn"

    def _load_value():
        value = patching_cache[cache_key]
        if global_indices is not None:
            value = value[global_indices, ...]  # only this batch's samples
        if value.ndim not in (3, 4):
            raise RuntimeError(f"Unexpected cache shape: {value.shape}")
        return value

    def _module_head_count(module, d_model):
        n_heads = (
            getattr(module, "num_heads", None)
            or getattr(module, "num_attention_heads", None)
            or getattr(getattr(module, "config", None), "num_attention_heads", None)
        )
        if not n_heads or d_model % n_heads:
            raise RuntimeError(
                f"Cannot infer head layout for d_model={d_model}; pass patching_cache"
            )
        return n_heads

    def hook(module, input, output):
        is_tuple = isinstance(output, tuple)
        # Clone so the module's original output tensor is never modified in place.
        out = (output[0] if is_tuple else output).clone()
        batch, seq_out, d_model = out.shape

        value = _load_value() if patching_cache is not None else None
        if value is not None:
            n_heads, d_head = value.shape[-2], value.shape[-1]
        else:
            n_heads = _module_head_count(module, d_model)
            d_head = d_model // n_heads
        out = out.view(batch, seq_out, n_heads, d_head)

        if scheme == "mean":
            # Heads to average over. Fall back to all heads when head_idx is the only one.
            others = [i for i in range(n_heads) if i != head_idx] or list(range(n_heads))
            if value.ndim == 4:
                if value.shape[1] != seq_out:
                    # Incremental decoding: use only the newest cached position.
                    value = value[:, -1:]
                mean_vec = value[:, :, others, :].mean(dim=2).expand(-1, seq_out, -1)
            else:
                mean_vec = value[:, others, :].mean(dim=1, keepdim=True).expand(-1, seq_out, -1)
            mean_vec = mean_vec.to(out.device)
            # The right-hand side is fully evaluated before the in-place write.
            out[:, :, head_idx, :] = alpha * out[:, :, head_idx, :] + (1.0 - alpha) * mean_vec
        else:
            out[:, :, head_idx, :] = 0.0

        out = out.view(batch, seq_out, d_model)
        return (out,) + output[1:] if is_tuple else out

    return hook

def dynamic_counterbalance_head_llava(layer, head_idx, faithful_heads, hallucination_heads,
                                      alpha=0.2, beta=0.9):
    """
    Build a forward hook that counterbalances faithful and hallucination heads.

    The hook views the output (or its first element, when it is a tuple) as
    ``[B, seq_out, n_heads, head_dim]``. It then computes:
      - ``hv`` / ``fv``: mean over the hallucination / faithful head slices,
        each ``[B, seq_out, head_dim]``;
      - ``D = ||hv|| / max(||fv||, 1e-4)`` per position;
      - ``D_smooth = beta * D_prev + (1 - beta) * D``. ``D_prev`` is this hook's
        stored state; it is reset when the shape of ``D`` changes (e.g. between
        prefill and incremental decoding, or when the batch size changes).
    If ``head_idx`` is a faithful head it gets ``+= alpha * D_smooth * fv``. If
    it is a hallucination head it gets ``-= alpha * D_smooth * hv``.

    The output is returned unchanged when ``head_idx`` is in neither group, or
    when either group is empty (the ratio would be a mean over zero heads, i.e. NaN).

    NOTE(review): the smoothing state persists across separate forward passes with
    the same shape (including different batches); confirm that is intended.

    Args:
      layer: layer index (not used in the computation).
      head_idx: current head index.
      faithful_heads: list of faithful head indices.
      hallucination_heads: list of hallucination head indices.
      alpha: adjustment strength factor.
      beta: smoothing coefficient for the dynamic factor.
    Returns:
      A registerable forward_hook function.

    Raises:
      RuntimeError: from the hook, when the head count cannot be determined from the
        module (``num_attention_heads``, ``num_heads``, ``config.num_attention_heads``
        or ``head_dim``).
    """
    D_prev = [None]

    def _head_count(module, d_model):
        n_heads = (
            getattr(module, "num_attention_heads", None)
            or getattr(module, "num_heads", None)
            or getattr(getattr(module, "config", None), "num_attention_heads", None)
        )
        if not n_heads:
            head_dim = getattr(module, "head_dim", None)
            if head_dim:
                n_heads = d_model // head_dim
        if not n_heads:
            raise RuntimeError(f"Cannot infer the number of heads for d_model={d_model}")
        return n_heads

    def hook(module, input, output):
        # Heads outside both groups are never modified.
        if head_idx not in faithful_heads and head_idx not in hallucination_heads:
            return output
        # With an empty group the balance ratio is undefined (NaN); leave output as-is.
        if len(faithful_heads) == 0 or len(hallucination_heads) == 0:
            return output

        is_tuple = isinstance(output, tuple)
        out = (output[0] if is_tuple else output).clone()
        B, seq_out, d_model = out.shape

        n_heads = _head_count(module, d_model)
        head_dim = d_model // n_heads
        out_heads = out.view(B, seq_out, n_heads, head_dim)

        hv = out_heads[:, :, hallucination_heads, :].mean(dim=2)  # [B, seq_out, head_dim]
        fv = out_heads[:, :, faithful_heads, :].mean(dim=2)       # [B, seq_out, head_dim]

        Sh = hv.norm(dim=-1, keepdim=True)                        # [B, seq_out, 1]
        Sf = fv.norm(dim=-1, keepdim=True).clamp(min=1e-4)        # [B, seq_out, 1]
        D = Sh.div(Sf)

        # Reset the smoothing state whenever the shape changes.
        if D_prev[0] is not None and D_prev[0].shape != D.shape:
            D_prev[0] = None
        if D_prev[0] is None:
            D_prev[0] = D
        D_smooth = beta * D_prev[0] + (1 - beta) * D
        D_prev[0] = D_smooth.detach()

        if head_idx in faithful_heads:
            out_heads[:, :, head_idx, :] += alpha * D_smooth * fv
        else:
            out_heads[:, :, head_idx, :] -= alpha * D_smooth * hv

        out_final = out_heads.view(B, seq_out, -1)
        return (out_final,) + output[1:] if is_tuple else out_final

    return hook


def _select_hidden_layer(nested_hs, layer_hidden_index):
    """Pick one layer's hidden states from generate()'s (step x layer) or (layer) tuple."""
    # "step x layer": keep only the final generation step's per-layer tuple.
    per_layer = nested_hs[-1] if isinstance(nested_hs[0], tuple) else nested_hs
    avail = len(per_layer)
    idx = (avail - 2) if layer_hidden_index is None else min(layer_hidden_index, avail - 1)
    return per_layer[idx]


def generate_and_extract_embeddings(
    model,
    processor,
    batch_samples,
    num_samples,
    device,
    layer_hidden_index=None,
    batch_size=8
):
    """
    Sample generations and collect one hidden vector per (sample, sampling round).

    ``batch_samples`` is a list of ``{"image": PIL.Image, "prompt": str}``. The
    outer loop runs ``num_samples`` sampling rounds. In each round the list is
    split into mini-batches of ``batch_size``, and each mini-batch is run through
    ``model.generate(do_sample=True, top_p=0.9, top_k=50,
    num_return_sequences=1, output_hidden_states=True)``. For each mini-batch the
    ``[:, -1, :]`` vector is kept from the selected layer at the final generation
    step.

    Layer selection: ``layer_hidden_index=None`` picks index ``len(layers) - 2``
    (the penultimate entry). Otherwise ``min(layer_hidden_index, len(layers) - 1)``
    is used.

    NOTE(review): with padding, position -1 of the prompt step is a padded
    position. Confirm the processor pads on the left when this matters.

    Returns:
      CPU tensor of shape ``(len(batch_samples) * num_samples, hidden_size)``.
      Rows are ordered round by round, in the dtype the model produced. When there
      is nothing to sample (empty ``batch_samples`` or ``num_samples <= 0``), an
      empty ``(0, 0)`` tensor is returned instead of raising.
    """
    model.eval()
    all_embs = []

    with torch.no_grad():
        for _ in range(num_samples):
            for start in range(0, len(batch_samples), batch_size):
                mini_batch = batch_samples[start:start + batch_size]

                inputs = processor(
                    text=[s["prompt"] for s in mini_batch],
                    images=[s["image"] for s in mini_batch],
                    return_tensors="pt",
                    padding=True,
                )
                for k, v in inputs.items():
                    inputs[k] = v.to(device, dtype=torch.float16) if k == "pixel_values" else v.to(device)

                # One sequence per call keeps VRAM usage bounded by batch_size.
                output = model.generate(
                    **inputs,
                    do_sample=True,
                    top_p=0.9,
                    top_k=50,
                    num_return_sequences=1,
                    output_hidden_states=True,
                    return_dict_in_generate=True,
                    output_attentions=False,
                )

                # Depending on the output type, states live in decoder_hidden_states or hidden_states.
                nested_hs = getattr(output, "decoder_hidden_states", None)
                if nested_hs is None:
                    nested_hs = output.hidden_states

                hidden_states = _select_hidden_layer(nested_hs, layer_hidden_index)
                all_embs.append(hidden_states[:, -1, :].detach().cpu())

                # Free VRAM held by this mini-batch before the next one.
                del inputs, output, nested_hs, hidden_states
                torch.cuda.empty_cache()

    if not all_embs:
        # Callers treat fewer than 2 rows as "no score"; avoid torch.cat([]) raising.
        return torch.empty((0, 0))
    return torch.cat(all_embs, dim=0)


def compute_eigenscore(
    model,
    processor,
    batch_samples,
    num_samples,
    device,
    layer_hidden_index=None,
    k_eigen=64,
    epsilon=1e-4,
    batch_size=8
):
    """
    Compute an INSIDE-style EigenScore over sampled last-token embeddings.

    Steps:
      1. ``generate_and_extract_embeddings`` produces an ``(M, H)`` CPU matrix,
         with ``M = len(batch_samples) * num_samples``.
      2. The matrix is cast to float32 and mean-centered. Its ``(H, H)`` sample
         covariance (divided by ``M - 1``) gets ``epsilon * I`` added.
      3. The ``k_eigen`` largest eigenvalues are clamped to at least ``epsilon``;
         the sum of their logs is returned.

    Args:
      batch_samples: list of {"image": PIL.Image, "prompt": str}.
      num_samples: sampled generations per sample.
      layer_hidden_index: forwarded to the embedding extractor (None = penultimate).
      k_eigen: number of largest eigenvalues summed in log space.
      epsilon: diagonal regularizer and eigenvalue floor.
      batch_size: mini-batch size for generation.

    Returns:
      float score. Returns 0.0 when fewer than two embeddings are available.
    """
    embeddings = generate_and_extract_embeddings(
        model=model,
        processor=processor,
        batch_samples=batch_samples,
        num_samples=num_samples,
        device=device,
        layer_hidden_index=layer_hidden_index,
        batch_size=batch_size
    ).float()
    n_rows, hidden = embeddings.shape

    if n_rows < 2:
        return 0.0

    centered = embeddings - embeddings.mean(dim=0, keepdim=True)
    covariance = (centered.t() @ centered) / (n_rows - 1)
    covariance = covariance + epsilon * torch.eye(hidden, dtype=covariance.dtype)

    # eigvalsh yields ascending eigenvalues; reorder largest-first.
    spectrum = torch.sort(torch.linalg.eigvalsh(covariance), descending=True).values
    leading = spectrum[: min(k_eigen, spectrum.shape[0])].clamp(min=epsilon)
    return torch.log(leading).sum().item()

def compute_improved_eigenscore(
    model,
    processor,
    batch_samples,
    num_samples,
    device,
    layer_hidden_index=None,
    k_eigen=64,
    epsilon=1e-4,
    batch_size=8,
    alpha=0.3,   # Improvement 1: spectral variance penalty weight
    gamma=1.0     # Improvement 3: spectral entropy weight
):
    """
    Compute an improved EigenScore from sampled last-token embeddings.

    The embeddings ``(M, H)`` come from ``generate_and_extract_embeddings``. Their
    covariance ``(H, H)`` (divided by ``M - 1``, plus ``epsilon * I``) is
    eigendecomposed, and the eigenvalues are clamped to at least ``epsilon``.
    The score is then

        improved = log_det - alpha * spectral_var + gamma * spectral_entropy

    where
      - ``log_det``: sum of logs of the ``k_eigen`` LARGEST eigenvalues;
      - ``spectral_var``: population variance of log-eigenvalues (all of them);
      - ``spectral_entropy``: entropy of ``p_i = lambda_i / sum_j lambda_j``.

    Returns:
      (improved_score, components). ``improved_score`` is a float.
      ``components`` is a dict with ``log_det``, ``spectral_var`` and
      ``spectral_entropy`` (all 0.0 when fewer than two embeddings are available).
    """
    # === 1. Extract embedding representations ===
    E_half = generate_and_extract_embeddings(
        model=model,
        processor=processor,
        batch_samples=batch_samples,
        num_samples=num_samples,
        device=device,
        layer_hidden_index=layer_hidden_index,
        batch_size=batch_size
    )

    E = E_half.float()  # (M, H)
    M, H = E.shape
    if M < 2:
        return 0.0, {"log_det": 0.0, "spectral_var": 0.0, "spectral_entropy": 0.0}

    # === 2. Covariance ===
    Ec = E - E.mean(dim=0, keepdim=True)    # (M, H)
    cov = Ec.t() @ Ec / (M - 1)             # (H, H)
    cov = cov + epsilon * torch.eye(H, dtype=cov.dtype, device=cov.device)

    # === 3. Eigenvalue spectrum (eigvalsh returns them in ascending order) ===
    eigvals_clamped = torch.clamp(torch.linalg.eigvalsh(cov), min=epsilon)

    # === 4. log det over the k LARGEST eigenvalues (matches compute_eigenscore).
    # Flip before slicing; slicing the ascending order would pick the smallest ones.
    eigvals_topk = torch.flip(eigvals_clamped, dims=(0,))[: min(k_eigen, H)]
    log_det = torch.log(eigvals_topk).sum()

    # === 5. Improvement 1: spectral variance of log(lambda_i)
    spectral_var = torch.var(torch.log(eigvals_clamped), unbiased=False)

    # === 6. Improvement 3: spectral entropy, p_i = lambda_i / sum(lambda_j)
    p_i = eigvals_clamped / eigvals_clamped.sum()
    spectral_entropy = -torch.sum(p_i * torch.log(p_i + 1e-8))  # 1e-8 guards log(0)

    # === 7. Weighted combination
    improved_score = log_det - alpha * spectral_var + gamma * spectral_entropy

    return improved_score.item(), {
        "log_det": log_det.item(),
        "spectral_var": spectral_var.item(),
        "spectral_entropy": spectral_entropy.item()
    }

def compute_improved_eigenscore_float(*args, **kwargs):
    """Return only the scalar score from ``compute_improved_eigenscore`` (components dropped)."""
    score_only, _components = compute_improved_eigenscore(*args, **kwargs)
    return score_only

def generate_layer_last_token_acts(
    model,
    processor: AutoProcessor,
    samples: List[Dict],
    device: str,
    batch_size: int,
    layer: int,
    do_sample: bool = True
) -> torch.Tensor:
    """
    Collect the last-position attention activations of one layer.

    For each mini-batch of ``samples`` the function does the following:
      1. Register the hooks from ``register_llava_hooks(model, cache)``.
      2. Run ``model.generate(max_new_tokens=1)``.
      3. Read ``cache[f"llama_layer{layer}_attn"]``. A 4-D entry
         ``(B, seq, heads, dim)`` is reduced to its last position; a 3-D entry
         ``(B, heads, dim)`` is used as-is.
    The hooks are removed even if ``generate`` raises, so a failure cannot leave
    them attached to the model.

    Returns:
      Tensor ``(len(samples), n_heads, head_dim)`` on whatever device the cache
      stored it (presumably the model's device; confirm against register_llava_hooks).
    """
    all_last = []
    key = f"llama_layer{layer}_attn"
    for start in range(0, len(samples), batch_size):
        mini = samples[start:start + batch_size]
        inputs = processor(
            text=[s["prompt"] for s in mini],
            images=[s["image"] for s in mini],
            return_tensors="pt",
            padding=True,
        )
        for k, v in inputs.items():
            inputs[k] = v.to(device, dtype=torch.float16 if k == "pixel_values" else v.dtype)

        cache = {}
        hooks = register_llava_hooks(model, cache)
        try:
            with torch.no_grad():
                model.generate(
                    **inputs,
                    do_sample=do_sample,
                    top_p=0.9, top_k=50,
                    num_return_sequences=1,
                    max_new_tokens=1,
                )
        finally:
            for h in hooks:
                h.remove()

        attn = cache[key]
        # NOTE(review): with right padding, position -1 may be a pad token; confirm padding side.
        last = attn[:, -1, :, :] if attn.dim() == 4 else attn  # (B, heads, dim)
        all_last.append(last)
        del inputs, cache, hooks
        torch.cuda.empty_cache()

    # (b, heads, dim) chunks -> (N, heads, dim)
    return torch.cat(all_last, dim=0)

def compute_score_from_embeddings(
    E: torch.Tensor,
    k_eigen: int,
    epsilon: float,
    alpha: float,
    gamma: float
) -> float:
    """
    Compute the improved EigenScore of an embedding matrix.

    ``E`` is ``(M, D)`` on any device and in any floating dtype; it is cast to
    float32 because ``torch.linalg.eigvalsh`` needs it. The covariance ``(D, D)``
    (divided by ``M - 1``, plus ``epsilon * I``) is eigendecomposed, and the
    eigenvalues are clamped to at least ``epsilon``. The result is

        log_det - alpha * spectral_var + gamma * spectral_entropy

    where ``log_det`` sums the logs of the ``k_eigen`` LARGEST eigenvalues,
    ``spectral_var`` is the population variance of all log-eigenvalues, and
    ``spectral_entropy`` is the entropy of the normalized eigenvalues.

    Returns:
      Python float. Returns 0.0 when ``M < 2``.
    """
    E = E.float()
    M, D = E.shape
    if M < 2:
        return 0.0

    # 1. Mean centering
    Ec = E - E.mean(dim=0, keepdim=True)

    # 2. Covariance (float32)
    cov = (Ec.t() @ Ec) / (M - 1)
    cov = cov + epsilon * torch.eye(D, device=E.device, dtype=torch.float32)

    # 3. Eigenvalue spectrum. eigvalsh returns eigenvalues in ascending order.
    eigvals = torch.clamp(torch.linalg.eigvalsh(cov), min=epsilon)

    # 4. log-det over the k largest eigenvalues: flip to descending before slicing.
    topk = torch.flip(eigvals, dims=(0,))[: min(k_eigen, D)]
    log_det = torch.log(topk).sum()

    # 5. Spectral variance
    spectral_var = torch.var(torch.log(eigvals), unbiased=False)

    # 6. Spectral entropy
    p = eigvals / eigvals.sum()
    spectral_entropy = -torch.sum(p * torch.log(p + 1e-8))

    # 7. Combination
    improved = log_det - alpha * spectral_var + gamma * spectral_entropy
    return improved.item()



def compute_head_consistency(
    model,
    processor,
    samples: List[Dict],
    device: str,
    layer: int,
    num_samples: int = 5,
    batch_size: int = 8,
    k_eigen: int = 64,
    epsilon: float = 1e-4,
    alpha: float = 0.25,
    gamma: float = 1.0,
) -> Dict[int, float]:
    """
    Score each attention head of ``layer`` by the improved EigenScore of its
    last-token activations across repeated sampled runs.

    ``generate_layer_last_token_acts`` is called ``num_samples`` times with
    sampling. The resulting ``(len(samples), n_heads, head_dim)`` tensors are
    stacked into ``(num_samples * len(samples), n_heads, head_dim)``. For each
    head ``h``, a single score is computed by ``compute_score_from_embeddings`` on
    the pooled ``(num_samples * len(samples), head_dim)`` matrix; it is one score
    over all pooled rows, not an average of per-round scores.

    Returns:
      dict mapping head index -> score. Empty when ``num_samples <= 0``.
    """
    if num_samples <= 0:
        return {}

    # 1. Repeated sampling rounds -> list of (N, heads, dim)
    all_acts = [
        generate_layer_last_token_acts(
            model, processor, samples,
            device=device, batch_size=batch_size,
            layer=layer, do_sample=True
        )
        for _ in range(num_samples)
    ]

    # 2. (num_samples * N, heads, dim)
    combined = torch.cat(all_acts, dim=0).to(device)
    n_heads = combined.shape[1]

    # 3. One score per head over its pooled (num_samples * N, head_dim) activations
    return {
        h: compute_score_from_embeddings(
            combined[:, h, :], k_eigen=k_eigen, epsilon=epsilon, alpha=alpha, gamma=gamma
        )
        for h in range(n_heads)
    }


def _ablation_delta_over_batches(model, processor, val_samples, batch_ranges, batch_baselines,
                                 target_module, make_hook, score_kwargs):
    """Mean over mini-batches of (ablated score - same mini-batch baseline score)."""
    deltas = []
    for (start, end), base in zip(batch_ranges, batch_baselines):
        handle = target_module.register_forward_hook(make_hook(start, end))
        try:
            with torch.no_grad():
                ablated = compute_improved_eigenscore_float(
                    model=model,
                    processor=processor,
                    batch_samples=val_samples[start:end],
                    **score_kwargs,
                )
        finally:
            # Always detach the hook so a failure cannot leave the model ablated.
            handle.remove()
        deltas.append(ablated - base)
    return sum(deltas) / len(deltas) if deltas else 0.0


def simple_auto_circuit_llava(
    model,
    processor,
    val_samples,
    val_cache,
    ablation_scheme="mean",
    device="cuda",
    include_mlps=False,
    num_samples=5,
    layer_hidden_index=None,
    target_layers=None,
    batch_size=8,
    k_eigen=64,
    epsilon=1e-4,
):
    """
    Run layer-wise and head-wise (optionally MLP) ablations on a LLaVA model and
    report the change in improved EigenScore for each (layer, head) or (layer, "mlp").

    ``val_samples`` is processed in mini-batches of ``batch_size`` to bound peak
    VRAM. For every mini-batch, the ablated score is compared against a baseline
    computed on the SAME mini-batch. EigenScore depends on the number of
    embeddings, so comparing against the full-set baseline would bias every delta.
    The reported delta is the mean over mini-batches.

    Args:
      - model: LlavaForConditionalGeneration model already loaded on device.
      - processor: matching AutoProcessor (PIL.Image + prompt -> model inputs).
      - val_samples: list of {"image": PIL.Image, "prompt": str, ...}.
      - val_cache: dict keyed by f"llama_layer{L}_attn" / f"llama_layer{L}_mlp". The
        values are CPU tensors:
            - "llama_layer{L}_attn": (N, seq_or_1, n_heads, d_head)
            - "llama_layer{L}_mlp": (N, seq_or_1, d_model)
      - ablation_scheme: "mean" (ablate with the cache's own mean) or "zero".
      - device: e.g. "cuda".
      - include_mlps: also ablate each target layer's MLP.
      - num_samples: sampled generations per sample for the EigenScore.
      - layer_hidden_index: forwarded to the embedding extractor. None keeps the
        extractor's own default (index len(hidden_states) - 2, the penultimate entry).
      - target_layers: layers to ablate. None defaults to range(5, n_layers - 4).
      - batch_size: mini-batch size for splitting val_samples.
      - k_eigen: number of largest covariance eigenvalues used in log_det.
      - epsilon: covariance regularizer.
    Returns:
      - result_tuples: list of (layer, head_idx, delta) and, if include_mlps,
        (layer, "mlp", delta_mlp). Empty when val_samples is empty.
    """
    model.eval()

    n_layers = model.config.text_config.num_hidden_layers
    n_heads = model.config.text_config.num_attention_heads

    if target_layers is None:
        # By default, skip the first 5 and last 4 layers
        target_layers = list(range(5, n_layers - 4))

    N = len(val_samples)
    if N == 0:
        return []

    batch_ranges = [(start, min(start + batch_size, N)) for start in range(0, N, batch_size)]
    score_kwargs = dict(
        num_samples=num_samples,
        device=device,
        layer_hidden_index=layer_hidden_index,
        k_eigen=k_eigen,
        epsilon=epsilon,
        batch_size=batch_size,
    )

    # 1. Full baseline (logged for reference)
    with torch.no_grad():
        score_base = compute_improved_eigenscore_float(
            model=model, processor=processor, batch_samples=val_samples, **score_kwargs
        )
    logging.info("[DEBUG] Full Baseline EigenScore = %.6f", score_base)

    # Per-mini-batch baselines, computed once and reused for every ablation.
    if len(batch_ranges) == 1:
        batch_baselines = [score_base]
    else:
        with torch.no_grad():
            batch_baselines = [
                compute_improved_eigenscore_float(
                    model=model, processor=processor,
                    batch_samples=val_samples[start:end], **score_kwargs
                )
                for start, end in batch_ranges
            ]

    result_tuples = []

    # 2. Ablate each head (and optionally the MLP) of each target layer
    for layer in target_layers:
        decoder_layer = model.model.language_model.layers[layer]
        attn_key = f"llama_layer{layer}_attn"

        for head_idx in range(n_heads):
            def make_head_hook(start, end, head_idx=head_idx):
                cache_slice = {attn_key: val_cache[attn_key][start:end].to(device)}
                return ablate_head_llava(layer, head_idx, ablation_scheme, cache_slice)

            delta = _ablation_delta_over_batches(
                model, processor, val_samples, batch_ranges, batch_baselines,
                decoder_layer.self_attn, make_head_hook, score_kwargs,
            )
            result_tuples.append((layer, head_idx, delta))
            logging.info("[RESULT] Layer %s Head %s ΔEigenScore = %.6f", layer, head_idx, delta)

        if include_mlps:
            mlp_key = f"llama_layer{layer}_mlp"

            def make_mlp_hook(start, end):
                cache_slice_mlp = {mlp_key: val_cache[mlp_key][start:end].to(device)}
                return ablate_mlp_llava(layer, ablation_scheme, cache_slice_mlp)

            delta_mlp = _ablation_delta_over_batches(
                model, processor, val_samples, batch_ranges, batch_baselines,
                decoder_layer.mlp, make_mlp_hook, score_kwargs,
            )
            result_tuples.append((layer, "mlp", delta_mlp))
            logging.info("[RESULT] Layer %s MLP ΔEigenScore = %.6f", layer, delta_mlp)

    return result_tuples