import argparse
import json
import os
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from datasets import Dataset
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma3Config


def detect_dataset_type(ds: Dataset) -> str:
    """
    Classify a dataset by the keys present in its first record.

    Only record 0 is inspected. The "prefix_embedding" check runs first, so a
    record carrying both kinds of fields is reported as "prefix_tuning".

    Returns:
        "prefix_tuning" if the first record has a "prefix_embedding" key
        "progressive" if it has both "embedding" and "stage_index"
        "unknown" otherwise (including an empty dataset)
    """
    if len(ds) == 0:
        return "unknown"

    sample = ds[0]
    if "prefix_embedding" in sample:
        return "prefix_tuning"

    has_progressive_fields = all(key in sample for key in ("embedding", "stage_index"))
    return "progressive" if has_progressive_fields else "unknown"


def load_dataset(dataset_path: str) -> Tuple[Dataset, str]:
    """
    Read a dataset from disk and report which kind it is.

    Args:
        dataset_path: Location passed to ``Dataset.load_from_disk``

    Returns:
        Tuple of (dataset, dataset_type). dataset_type is whatever
        ``detect_dataset_type`` reports for the loaded dataset:
        "progressive", "prefix_tuning" or "unknown".
    """
    loaded = Dataset.load_from_disk(dataset_path)
    return loaded, detect_dataset_type(loaded)


def load_progressive_dataset(dataset_path: str) -> Dataset:
    """Deprecated: return only the dataset part of ``load_dataset(dataset_path)``."""
    # The detected dataset type (second tuple element) is intentionally discarded.
    return load_dataset(dataset_path)[0]


def filter_records(
    ds: Dataset,
    sample_id: Optional[int] = None,
    stage_index: Optional[int] = None,
    dataset_type: str = "progressive",
) -> List[Dict[str, Any]]:
    """
    Return the records of ``ds`` that match the requested sample and/or stage.

    Before scanning, the columns "orig_embedding", "initialization_embedding" and
    "initialization_prefix_embedding" are dropped when present, so returned rows
    never contain them.

    Args:
        ds: Dataset to filter
        sample_id: Keep only rows whose "sample_id" equals this value (compared as int).
            None disables the filter.
        stage_index: Keep only rows whose "stage_index" equals this value (compared as
            int). Applied only when dataset_type is "progressive"; None disables it.
        dataset_type: Type of dataset ("progressive" or "prefix_tuning")

    Returns:
        List of matching rows, in dataset order. A row that lacks the filtered key is
        treated as having the value -1 for that key.
    """
    # Not every dataset type carries these columns, so drop only the ones that exist.
    droppable = ("orig_embedding", "initialization_embedding", "initialization_prefix_embedding")
    columns_to_remove = [col for col in droppable if col in ds.column_names]
    if columns_to_remove:
        ds = ds.remove_columns(columns_to_remove)

    # Resolve the filters once instead of re-evaluating them for every row.
    wanted_sample = None if sample_id is None else int(sample_id)
    apply_stage_filter = dataset_type == "progressive" and stage_index is not None
    wanted_stage = int(stage_index) if apply_stage_filter else None

    rows: List[Dict[str, Any]] = []
    # Iterate the dataset directly rather than indexing row by row.
    for record in tqdm(ds, desc="Filtering records"):
        if wanted_sample is not None and int(record.get("sample_id", -1)) != wanted_sample:
            continue
        if wanted_stage is not None and int(record.get("stage_index", -1)) != wanted_stage:
            continue
        rows.append(record)
    return rows


def collate_stages_by_sample(
    rows: List[Dict[str, Any]],
    dataset_type: str = "progressive",
) -> Dict[int, List[Dict[str, Any]]]:
    """
    Bucket rows by their integer "sample_id" (rows without one go under -1).

    For "progressive" datasets every bucket is additionally sorted in place by
    "stage_index" (missing values count as 0). For any other dataset type the rows
    stay in input order; prefix_tuning samples are expected to have one row each.

    Returns:
        Dictionary mapping sample_id to its list of rows, keyed in first-seen order.
    """
    grouped: Dict[int, List[Dict[str, Any]]] = {}
    for record in rows:
        grouped.setdefault(int(record.get("sample_id", -1)), []).append(record)

    if dataset_type != "progressive":
        return grouped

    def stage_order(record: Dict[str, Any]) -> int:
        return int(record.get("stage_index", 0))

    for bucket in grouped.values():
        bucket.sort(key=stage_order)
    return grouped


def extract_attention_mass_for_all_seq_lengths(
    attentions: tuple,
    num_compression_tokens: int,
    target_seq_lengths: List[int],
    block_size: int = 16,
    block_threshold: int = 0,
) -> Dict[int, Dict[int, float]]:
    """
    Extract attention mass percent from full attention map for all target sequence lengths at once.

    For one target length ``t`` the effective length is ``e = num_compression_tokens + t``.
    The per-layer value is the head-averaged attention that query positions ``0 .. e-1``
    place on the first ``num_compression_tokens`` key positions, averaged over those
    queries and multiplied by 100. Only batch element 0 of each attention tensor is read.

    Lengths ``t <= block_threshold`` are reported individually. Larger lengths are grouped
    into non-overlapping blocks of ``block_size`` consecutive lengths starting at
    ``block_threshold + 1``; a block's value is the mean of the per-length values of the
    requested lengths that fall in it, and its key is the largest requested length in it.
    With the default ``block_threshold=0`` every positive length is therefore blocked.

    Lengths whose effective length is outside ``[1, seq_len]`` are ignored; a group with
    no usable length produces no entry.

    Args:
        attentions: Tuple of attention tensors, one per layer [batch_size, num_heads, seq_len, seq_len]
        num_compression_tokens: Number of compression tokens
        target_seq_lengths: List of target sequence lengths (input tokens, excluding compression tokens).
            Duplicates are collapsed.
        block_size: Width (in target lengths) of each averaging block
        block_threshold: Largest target length that is still reported individually

    Returns:
        Dictionary mapping target_seq_len to layer_index to attention_mass_percent.
        Empty if ``attentions`` is empty.

    Raises:
        ValueError: If num_compression_tokens < 1, block_size < 1, block_threshold < 0,
            or num_compression_tokens exceeds the attention map's last dimension.
    """
    if not attentions:
        return {}
    if num_compression_tokens < 1:
        raise ValueError("num_compression_tokens must be >= 1")
    if block_size < 1:
        raise ValueError("block_size must be >= 1")
    if block_threshold < 0:
        raise ValueError("block_threshold must be >= 0")

    total_seq_len = attentions[0].shape[-1]
    if num_compression_tokens > total_seq_len:
        raise ValueError(f"num_compression_tokens ({num_compression_tokens}) exceeds total_seq_len ({total_seq_len})")

    # Per layer: comp_attn[q] = mean over heads of sum_k attn(q -> k), k in [0, num_compression_tokens).
    # Slice the compression columns *before* reducing so the full [seq_len, seq_len]
    # head-average is never materialised, and accumulate in float32 so that fp16/bf16
    # attention maps do not lose precision in the running sums below.
    per_layer_mass = []
    for attn_layer in attentions:
        to_compression = attn_layer[0, :, :, :num_compression_tokens].to(torch.float32)  # [heads, seq_len, C]
        per_layer_mass.append(to_compression.sum(dim=-1).mean(dim=0))  # [seq_len]
    compression_attention_per_layer = torch.stack(per_layer_mass, dim=0)  # [num_layers, seq_len]

    # Prefix sums let us answer every effective length in O(1):
    #   mean over queries 0..e-1  ==  prefix_sums[:, e - 1] / e
    prefix_sums = compression_attention_per_layer.cumsum(dim=-1).cpu()  # [num_layers, seq_len]
    max_effective_len = prefix_sums.shape[1]

    # Build (key, member lengths) groups: singles up to block_threshold, then blocks.
    # Blocks are keyed by their max target_seq_len to keep the x-axis monotonic for plotting.
    lengths = sorted({int(x) for x in target_seq_lengths})
    block_start_min = block_threshold + 1

    groups: List[Tuple[int, List[int]]] = []
    blocks: Dict[int, List[int]] = {}
    for t in lengths:
        if t <= block_threshold:
            groups.append((t, [t]))
            continue
        block_start = block_start_min + ((t - block_start_min) // block_size) * block_size
        blocks.setdefault(block_start, []).append(t)
    groups.extend((max(members), members) for _, members in sorted(blocks.items()))

    results: Dict[int, Dict[int, float]] = {}
    for key, members in tqdm(groups, desc="Save results attention mass all seq lengths"):
        # Effective lengths are 1-indexed counts, so e maps to prefix index e - 1 and
        # the usable range is [1, seq_len].
        effective_lengths = [
            num_compression_tokens + t for t in members if 1 <= num_compression_tokens + t <= max_effective_len
        ]
        if not effective_lengths:
            continue

        idx = torch.tensor([e - 1 for e in effective_lengths], dtype=torch.long)
        denom = torch.tensor(effective_lengths, dtype=prefix_sums.dtype).unsqueeze(0)
        per_length_means = prefix_sums[:, idx] / denom  # [num_layers, num_lengths]
        layer_vals = (per_length_means.mean(dim=1) * 100.0).tolist()  # one percent value per layer
        results[int(key)] = dict(enumerate(layer_vals))

    return results


def _reshape_embedding_by_element_count(
    compression_embeddings: torch.Tensor,
    original_shape: torch.Size,
    model_hidden_size: int,
    num_compression_tokens: int,
    error_message: str,
) -> Tuple[torch.Tensor, int]:
    """Reshape to [numel // hidden_size, hidden_size]; raise ValueError(error_message) if not divisible."""
    total_elements = compression_embeddings.numel()
    if total_elements % model_hidden_size != 0:
        raise ValueError(error_message)

    num_tokens_from_shape = total_elements // model_hidden_size
    compression_embeddings = compression_embeddings.reshape(num_tokens_from_shape, model_hidden_size)
    if num_tokens_from_shape != num_compression_tokens:
        print(
            f"Warning: Reshaped embedding from {original_shape} to [{num_tokens_from_shape}, {model_hidden_size}], but expected {num_compression_tokens} tokens. Using {num_tokens_from_shape}."
        )
        num_compression_tokens = num_tokens_from_shape
    return compression_embeddings, num_compression_tokens


def _normalize_progressive_embedding(
    compression_embeddings: torch.Tensor,
    model_hidden_size: int,
    num_compression_tokens: int,
) -> Tuple[torch.Tensor, int]:
    """Bring a stored progressive embedding into [num_compression_tokens, hidden_size] form.

    Accepts 1D (flattened), 2D (either orientation) and 3D (leading batch dim of 1)
    tensors. Returns the reshaped tensor together with the token count, which is
    replaced by the count found in the tensor when the 1D/2D paths disagree with the
    recorded one (a warning is printed in that case).

    Raises:
        ValueError: If the tensor cannot be interpreted with the given hidden size.
    """
    original_shape = compression_embeddings.shape
    print(
        f"Original embedding shape: {original_shape}, model hidden_size: {model_hidden_size}, num_compression_tokens: {num_compression_tokens}"
    )

    rank = len(original_shape)
    if rank == 1:
        # Flattened: reshape to [num_tokens, hidden_size]
        total_elements = compression_embeddings.numel()
        compression_embeddings, num_compression_tokens = _reshape_embedding_by_element_count(
            compression_embeddings,
            original_shape,
            model_hidden_size,
            num_compression_tokens,
            f"Cannot reshape embedding from shape {original_shape} (total elements: {total_elements}) "
            f"to [num_tokens, hidden_size={model_hidden_size}]. Total elements must be divisible by hidden_size.",
        )
    elif rank == 2:
        if original_shape[1] == model_hidden_size:
            # Already [num_tokens, hidden_size]; trust the tensor over the recorded count.
            if original_shape[0] != num_compression_tokens:
                print(
                    f"Warning: Embedding has {original_shape[0]} tokens but expected {num_compression_tokens}. Using {original_shape[0]}."
                )
                num_compression_tokens = original_shape[0]
        elif original_shape[0] == model_hidden_size:
            # Stored as [hidden_size, num_tokens]: transpose.
            compression_embeddings = compression_embeddings.transpose(0, 1)
            if compression_embeddings.shape[0] != num_compression_tokens:
                print(
                    f"Warning: After transpose, embedding has {compression_embeddings.shape[0]} tokens but expected {num_compression_tokens}. Using {compression_embeddings.shape[0]}."
                )
                num_compression_tokens = compression_embeddings.shape[0]
        elif original_shape[0] == num_compression_tokens:
            # [num_compression_tokens, wrong_hidden_size]. Neither axis equals the model
            # hidden size at this point, so this is a model mismatch.
            raise ValueError(
                f"Model hidden size mismatch: embedding has hidden_size={original_shape[1]}, "
                f"but model has hidden_size={model_hidden_size}. "
                f"This embedding was created with a different model. "
                f"Please use the same model checkpoint that was used to create the embeddings."
            )
        elif original_shape[1] == num_compression_tokens:
            # [wrong_hidden_size, num_compression_tokens]: same mismatch, transposed layout.
            raise ValueError(
                f"Model hidden size mismatch: embedding has hidden_size={original_shape[0]}, "
                f"but model has hidden_size={model_hidden_size}. "
                f"This embedding was created with a different model. "
                f"Please use the same model checkpoint that was used to create the embeddings."
            )
        else:
            # Fall back to the element count. This also covers the case where the total
            # equals num_compression_tokens * hidden_size (no warning is printed then).
            compression_embeddings, num_compression_tokens = _reshape_embedding_by_element_count(
                compression_embeddings,
                original_shape,
                model_hidden_size,
                num_compression_tokens,
                f"Cannot reshape embedding from shape {original_shape} to match model hidden_size={model_hidden_size}. "
                f"Expected shape [num_compression_tokens={num_compression_tokens}, hidden_size={model_hidden_size}] "
                f"or total elements divisible by {model_hidden_size}. "
                f"This may indicate the embedding was created with a different model.",
            )
    elif rank == 3:
        # Remove batch dimension if present: [1, num_tokens, hidden_size] or [1, hidden_size, num_tokens]
        if original_shape[0] != 1:
            raise ValueError(f"Unexpected 3D embedding shape: {original_shape}. Expected [1, num_tokens, hidden_size].")
        compression_embeddings = compression_embeddings.squeeze(0)
        if compression_embeddings.shape[1] == model_hidden_size:
            pass  # Already [num_tokens, hidden_size]; token count is checked by the final validation
        elif compression_embeddings.shape[0] == model_hidden_size:
            compression_embeddings = compression_embeddings.transpose(0, 1)
        else:
            compression_embeddings, num_compression_tokens = _reshape_embedding_by_element_count(
                compression_embeddings,
                original_shape,
                model_hidden_size,
                num_compression_tokens,
                f"Cannot reshape embedding from shape {original_shape} to match model hidden_size={model_hidden_size}.",
            )
    else:
        raise ValueError(f"Unexpected embedding shape: {original_shape}. Expected 1D, 2D, or 3D tensor.")

    # Final validation: ensure shape is [num_compression_tokens, hidden_size]
    if compression_embeddings.shape != (num_compression_tokens, model_hidden_size):
        if compression_embeddings.shape[1] != model_hidden_size:
            raise ValueError(
                f"Embedding hidden size mismatch: embedding has hidden_size={compression_embeddings.shape[1]}, "
                f"but model has hidden_size={model_hidden_size}. "
                f"This suggests the embedding was created with a different model. "
                f"Please use the same model checkpoint that was used to create the embeddings. "
                f"Embedding shape: {compression_embeddings.shape}, expected: [{num_compression_tokens}, {model_hidden_size}], "
                f"original shape: {original_shape}."
            )
        raise ValueError(
            f"After reshaping, embedding shape is {compression_embeddings.shape}, "
            f"but expected [{num_compression_tokens}, {model_hidden_size}]. "
            f"Original shape was {original_shape}."
        )

    return compression_embeddings, num_compression_tokens


def _build_prefix_tuning_model(
    model: AutoModelForCausalLM,
    compression_embeddings: torch.Tensor,
    num_compression_tokens: int,
    device: torch.device,
) -> Any:
    """Wrap ``model`` with a PEFT prefix-tuning adapter and load the saved prefix into it.

    The saved tensor is written into the first trainable 2D parameter whose first
    dimension equals ``num_compression_tokens``; it is reshaped first when only the
    element count (not the shape) matches.

    Raises:
        ImportError: If ``peft`` is not installed.
        ValueError: If no matching parameter exists or the element counts differ.
    """
    try:
        from peft import PrefixTuningConfig, TaskType, get_peft_model
    except ImportError as exc:
        raise ImportError(
            "peft is required for prefix tuning visualization. Install it (e.g. `uv add peft`)."
        ) from exc

    peft_config = PrefixTuningConfig(
        task_type=TaskType.CAUSAL_LM,
        num_virtual_tokens=num_compression_tokens,
    )
    # NOTE(review): get_peft_model is applied to the caller's model on every call —
    # confirm that repeated wrapping of the same base model is acceptable.
    peft_model = get_peft_model(model, peft_config).to(device)

    # Find the prefix embedding parameter in the PEFT model
    target_param = next(
        (
            param
            for _, param in peft_model.named_parameters()
            if param.requires_grad and param.ndim == 2 and param.shape[0] == num_compression_tokens
        ),
        None,
    )
    if target_param is None:
        raise ValueError("Could not find prefix embedding parameter in PEFT model.")

    original_shape = compression_embeddings.shape
    print(f"Original prefix embedding shape: {original_shape}, num_virtual_tokens: {num_compression_tokens}")

    # The PEFT parameter shape may differ from [num_virtual_tokens, hidden_size]
    target_shape = target_param.shape
    print(f"PEFT parameter shape: {target_shape}")

    if compression_embeddings.shape != target_shape:
        total_elements = compression_embeddings.numel()
        if total_elements != target_param.numel():
            raise ValueError(
                f"Cannot reshape prefix embedding from {original_shape} (total: {total_elements}) "
                f"to PEFT parameter shape {target_shape} (total: {target_param.numel()}). "
                f"Element counts don't match."
            )
        compression_embeddings = compression_embeddings.reshape(target_shape)
        print(f"Reshaped prefix embedding from {original_shape} to {target_shape}")

    # Overwrite the freshly initialised PEFT parameter with the saved prefix
    with torch.no_grad():
        target_param.data = compression_embeddings.to(device).to(target_param.dtype)

    return peft_model


def compute_attention_mass_for_stages(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    stages: List[Dict[str, Any]],
    device: torch.device,
    attention_block_size: int = 16,
    target_seq_lengths_override: Optional[List[int]] = None,
    dataset_type: str = "progressive",
) -> Tuple[Dict[int, Dict[int, float]], Optional[tuple], Optional[str], Optional[int]]:
    """
    Compute attention mass percent for all stages.

    Uses the longest sequence to compute attention once, then extracts attention mass
    for each stage from the full attention map (due to causal attention mask).

    For progressive datasets the stored compression embedding is normalised to
    [num_compression_tokens, hidden_size] and prepended to the token embeddings.
    For prefix_tuning datasets the stored prefix is loaded into a PEFT prefix-tuning
    wrapper around ``model`` and the wrapper is run on the token ids.

    Args:
        model: Language model
        tokenizer: Tokenizer
        stages: List of stage records, sorted by stage_index (progressive) or single entry (prefix_tuning)
        device: Device to run on
        attention_block_size: Passed as ``block_size`` to
            ``extract_attention_mass_for_all_seq_lengths`` (``block_threshold`` keeps its default)
        target_seq_lengths_override: If given, these lengths are used instead of the ones
            derived from the data (1..token count for prefix_tuning, the positive
            ``stage_seq_len`` values of ``stages`` for progressive)
        dataset_type: Type of dataset ("progressive" or "prefix_tuning")

    Returns:
        Tuple of (results_dict, attentions, text, num_compression_tokens) where:
        - results_dict: Dictionary mapping seq_len to layer_index to attention_mass_percent
        - attentions: Tuple of attention tensors or None
        - text: Input text string or None
        - num_compression_tokens: Number of compression tokens or None

        ``({}, None, None, None)`` is returned when ``stages`` is empty or the selected
        record has no usable text, embedding, or (progressive only) positive stage_seq_len.

    Raises:
        ImportError: For prefix_tuning when ``peft`` is not installed.
        ValueError: If the embedding cannot be matched to the model, or the model
            returns no attention weights.
    """
    empty_result = ({}, None, None, None)
    if not stages:
        return empty_result

    # Pick the record to run (prefix_tuning: the only one; progressive: the longest)
    if dataset_type == "prefix_tuning":
        stage_record = stages[0]
        text = stage_record.get("text", "")
        if not isinstance(text, str) or text.strip() == "":
            return empty_result
        # Tokenize to get the actual sequence length
        enc = tokenizer(text, truncation=True, padding=False, return_tensors="pt")
        max_seq_len = enc["input_ids"].shape[1]
        embedding = stage_record.get("prefix_embedding")
        if embedding is None:
            return empty_result
        num_compression_tokens = int(stage_record.get("num_virtual_tokens", 1))
    else:
        longest_stage = max(stages, key=lambda s: int(s.get("stage_seq_len", 0)))
        max_seq_len = int(longest_stage.get("stage_seq_len", -1))
        if max_seq_len < 1:
            return empty_result
        embedding = longest_stage.get("embedding")
        if embedding is None:
            return empty_result
        num_compression_tokens = int(longest_stage.get("num_compression_tokens", 1))
        text = longest_stage.get("text", "")
        if not isinstance(text, str) or text.strip() == "":
            return empty_result

    # Convert to a float32 tensor. Tensors are copied explicitly; everything else
    # (nested lists, arrays) goes through torch.tensor.
    if isinstance(embedding, torch.Tensor):
        compression_embeddings = embedding.detach().clone().to(torch.float32)
    else:
        compression_embeddings = torch.tensor(embedding, dtype=torch.float32)

    # Gemma3 keeps the text model's hidden size on a nested config
    if isinstance(model.config, Gemma3Config):
        model_hidden_size = model.config.text_config.hidden_size
    else:
        model_hidden_size = model.config.hidden_size

    use_peft_model = dataset_type == "prefix_tuning"
    if use_peft_model:
        model_to_use = _build_prefix_tuning_model(model, compression_embeddings, num_compression_tokens, device)
    else:
        compression_embeddings, num_compression_tokens = _normalize_progressive_embedding(
            compression_embeddings, model_hidden_size, num_compression_tokens
        )
        model_to_use = model

    # Compute attention once for the longest sequence
    print(f"Computing attention for longest sequence (length={max_seq_len})...")
    if use_peft_model:
        print(f"Using PEFT model for prefix tuning with {num_compression_tokens} virtual tokens")
    else:
        print(f"Compression embeddings shape: {compression_embeddings.shape}, num_compression_tokens: {num_compression_tokens}")

    model_to_use.eval()
    with torch.no_grad():
        enc = tokenizer(text, truncation=True, padding=False, return_tensors="pt")
        input_ids = enc["input_ids"].to(device)
        attention_mask = enc["attention_mask"].to(device)

        if use_peft_model:
            # The PEFT wrapper injects the prefix itself, so plain token ids are enough
            outputs = model_to_use(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=True,
            )
        else:
            # Progressive: prepend the compression embeddings to the token embeddings
            input_text_embeds = model.get_input_embeddings()(input_ids)
            compression_embeddings = compression_embeddings.to(device).to(input_text_embeds.dtype)
            # Add batch dimension: [1, num_compression_tokens, hidden_size]
            compression_embeddings = compression_embeddings.unsqueeze(0)
            input_embeds = torch.cat([compression_embeddings, input_text_embeds], dim=1)

            # Extend attention mask so the compression positions are attended to
            comp_attention = torch.ones(
                (attention_mask.shape[0], num_compression_tokens), device=device, dtype=attention_mask.dtype
            )
            extended_attention_mask = torch.cat([comp_attention, attention_mask], dim=1)

            outputs = model_to_use(
                inputs_embeds=input_embeds,
                attention_mask=extended_attention_mask,
                output_attentions=True,
            )

        attentions = outputs.attentions
        if attentions is None:
            raise ValueError(
                "Attention weights are None. The model may not support output_attentions. "
                "Try setting model.set_attn_implementation('eager') before loading."
            )

    if target_seq_lengths_override is not None:
        target_seq_lengths = list(target_seq_lengths_override)
    elif dataset_type == "prefix_tuning":
        # For prefix tuning, use sequence lengths from 1 to max_seq_len
        target_seq_lengths = list(range(1, max_seq_len + 1))
    else:
        # All unique positive sequence lengths recorded on the stages
        stage_lengths = {int(s.get("stage_seq_len", -1)) for s in stages}
        target_seq_lengths = sorted(length for length in stage_lengths if length > 0)

    # Extract attention mass for all sequence lengths at once
    print(f"Extracting attention mass for {len(target_seq_lengths)} sequence lengths...")
    results = extract_attention_mass_for_all_seq_lengths(
        attentions=attentions,
        num_compression_tokens=num_compression_tokens,
        target_seq_lengths=target_seq_lengths,
        block_size=attention_block_size,
    )

    return results, attentions, text, num_compression_tokens


def plot_attention_hijacking_heatmap(
    results: Dict[int, Dict[int, float]],
    sample_id: Optional[int],
    output_path: str,
    title_prefix: str = "Attention Hijacking: Compression Token Attention Mass %",
    colorbar_label: str = "Attention Mass % on Compression Tokens",
) -> None:
    """
    Plot heatmap of attention mass percent vs sequence length vs layer.

    Rows are layers (lowest index at the bottom), columns are sequence lengths in
    ascending order, and the colour scale is fixed to 0-100. A layer missing for some
    sequence length is drawn as 0. The figure is written to ``output_path`` and is
    always closed afterwards, even if drawing or saving fails.

    Args:
        results: Dictionary mapping stage_seq_len to layer_index to attention_mass_percent.
            If empty, a message is printed and nothing is plotted.
        sample_id: Optional sample ID for title
        output_path: Path to save the plot
        title_prefix: Prefix for the plot title
        colorbar_label: Label for the colorbar
    """
    if not results:
        print("No results to plot")
        return

    seq_lengths = sorted(results)
    layer_indices = sorted({layer_idx for per_layer in results.values() for layer_idx in per_layer})

    # Heatmap matrix: rows = layers, cols = sequence lengths. The reshape keeps the
    # array two-dimensional even when no layer entries exist at all.
    heatmap_data = np.array(
        [[results[seq_len].get(layer_idx, 0.0) for seq_len in seq_lengths] for layer_idx in layer_indices],
        dtype=float,
    ).reshape(len(layer_indices), len(seq_lengths))

    fig = plt.figure(figsize=(max(8, len(seq_lengths) * 0.8), max(6, len(layer_indices) * 0.5)))
    try:
        ax = sns.heatmap(
            heatmap_data,
            xticklabels=seq_lengths,
            yticklabels=[f"Layer {idx}" for idx in layer_indices],
            cmap="viridis",
            annot=True,
            fmt=".1f",
            cbar_kws={"label": colorbar_label},
            vmin=0,
            vmax=100,
        )
        # Flip layers axis: lower layer indices appear at the bottom.
        ax.invert_yaxis()
        plt.xlabel("Sequence Length", fontsize=12)
        plt.ylabel("Layer", fontsize=12)
        title = title_prefix
        if sample_id is not None:
            title += f" (Sample {sample_id})"
        plt.title(title, fontsize=14, fontweight="bold")
        plt.tight_layout()
        plt.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even on failure so repeated calls do not pile up open figures.
        plt.close(fig)
    print(f"Saved heatmap to: {output_path}")


def plot_attention_weights_heatmap(
    attentions: tuple,
    num_compression_tokens: int,
    tokenizer: AutoTokenizer,
    text: str,
    sample_id: Optional[int],
    output_path: str,
    layer_idx: Optional[int] = None,
    head_idx: Optional[int] = None,
):
    """
    Plot full attention weights heatmap (seq_len x seq_len) for a specific layer and optionally head.

    Only the first batch element is plotted. When more than 50 positions would be shown,
    the first 25 and last 25 rows/columns are kept and separated by an all-zero "..." row/column.
    Tick labels are drawn only when at most 30 positions are displayed.

    Args:
        attentions: Tuple of attention tensors, one per layer [batch_size, num_heads, seq_len, seq_len]
        num_compression_tokens: Number of compression tokens
        tokenizer: Tokenizer for token labels
        text: Input text for token labels
        sample_id: Optional sample ID for title
        output_path: Path to save the plot
        layer_idx: Layer index to plot (if None, averages across all layers)
        head_idx: Head index to plot (if None, averages across all heads)

    Raises:
        ValueError: If head_idx is given without layer_idx, or if layer_idx / head_idx is out of range.
    """
    if not attentions:
        print("No attention weights to plot")
        return

    num_layers = len(attentions)
    _, num_heads, seq_len, _ = attentions[0].shape

    # A single head can only be selected inside a single layer
    if head_idx is not None and layer_idx is None:
        raise ValueError("head_idx requires layer_idx to be specified")

    # Select layer and head
    if layer_idx is None:
        # Average across heads within each layer, then across all layers
        attention_weights = torch.stack([attn.mean(dim=1) for attn in attentions], dim=0).mean(dim=0)
        layer_label = "All Layers (avg)"
        head_label = " (avg across heads)"
    else:
        if layer_idx < 0 or layer_idx >= num_layers:
            raise ValueError(f"layer_idx {layer_idx} out of range [0, {num_layers})")
        layer_label = f"Layer {layer_idx}"
        if head_idx is not None:
            # Specific head in specific layer
            if head_idx < 0 or head_idx >= num_heads:
                raise ValueError(f"head_idx {head_idx} out of range [0, {num_heads})")
            attention_weights = attentions[layer_idx][0, head_idx, :, :]
            head_label = f", Head {head_idx}"
        else:
            # Average across heads in specific layer
            attention_weights = attentions[layer_idx].mean(dim=1)
            head_label = " (avg across heads)"

    # Drop the batch dimension if it is still present: [batch_size, seq_len, seq_len] -> [seq_len, seq_len]
    if attention_weights.ndim == 3:
        attention_weights = attention_weights[0]
    # detach() + float() make the conversion safe for grad-tracking tensors and for
    # dtypes NumPy cannot represent (e.g. bfloat16), where a bare .numpy() raises.
    attention_matrix = attention_weights.detach().float().cpu().numpy()

    # Get token labels
    enc = tokenizer(text, truncation=True, padding=False, return_tensors="pt", add_special_tokens=False)
    tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"][0].tolist())

    # The attention matrix includes compression tokens, so seq_len = num_compression_tokens + text_tokens.
    # Truncate or pad the text tokens so the labels line up with the matrix.
    expected_text_tokens = seq_len - num_compression_tokens
    if len(tokens) > expected_text_tokens:
        tokens = tokens[:expected_text_tokens]
    elif len(tokens) < expected_text_tokens:
        tokens = tokens + ["<pad>"] * (expected_text_tokens - len(tokens))

    # Create labels: compression tokens + text tokens
    token_labels = [f"Comp_{i}" for i in range(num_compression_tokens)] + tokens

    # Truncate labels if sequence is too long (for readability)
    max_display_tokens = 50
    if len(token_labels) > max_display_tokens:
        # Show first few and last few tokens, with a zero-filled gap in between
        half = max_display_tokens // 2
        display_labels = token_labels[:half] + ["..."] + token_labels[-half:]
        gap_row = np.zeros((1, attention_matrix.shape[1]))
        kept_rows = np.concatenate(
            [attention_matrix[:half, :], gap_row, attention_matrix[-half:, :]],
            axis=0,
        )
        gap_col = np.zeros((kept_rows.shape[0], 1))
        display_matrix = np.concatenate(
            [kept_rows[:, :half], gap_col, kept_rows[:, -half:]],
            axis=1,
        )
    else:
        display_labels = token_labels
        display_matrix = attention_matrix

    # Create heatmap
    fig_size = max(10, len(display_labels) * 0.3)
    show_labels = len(display_labels) <= 30
    plt.figure(figsize=(fig_size, fig_size))
    sns.heatmap(
        display_matrix,
        xticklabels=display_labels if show_labels else False,
        yticklabels=display_labels if show_labels else False,
        cmap="viridis",
        cbar_kws={"label": "Attention Weight"},
        vmin=0,
        vmax=1,
    )
    plt.xlabel("Key Position", fontsize=12)
    plt.ylabel("Query Position", fontsize=12)
    title = f"Attention Weights: {layer_label}{head_label}"
    if sample_id is not None:
        title += f" (Sample {sample_id})"
    plt.title(title, fontsize=14, fontweight="bold")
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"Saved attention weights heatmap to: {output_path}")


def average_attention_mass_results(
    results_list: List[Dict[int, Dict[int, float]]],
) -> Dict[int, Dict[int, float]]:
    """
    Average several attention-mass result dictionaries cell by cell.

    Every (seq_len, layer_index) cell is averaged over the entries of results_list that
    contain it; cells missing from an entry simply do not contribute to that cell.

    Args:
        results_list: List of dictionaries mapping seq_len to layer_index to attention_mass_percent

    Returns:
        Dictionary mapping seq_len to layer_index to the mean value, with seq_len keys and
        layer_index keys inserted in ascending order. Empty if results_list is empty.
    """
    # (seq_len, layer_idx) -> [running total, number of contributions]
    accumulator: Dict[Tuple[int, int], list] = {}
    for sample_result in results_list:
        for seq_len, per_layer in sample_result.items():
            for layer_idx, value in per_layer.items():
                cell = accumulator.setdefault((seq_len, layer_idx), [0.0, 0])
                cell[0] += float(value)
                cell[1] += 1

    # Sorting the tuple keys orders by seq_len first, then by layer index
    averaged: Dict[int, Dict[int, float]] = {}
    for seq_len, layer_idx in sorted(accumulator):
        total, contributions = accumulator[(seq_len, layer_idx)]
        averaged.setdefault(seq_len, {})[layer_idx] = total / contributions
    return averaged


def compute_average_attention_mass_per_layer(
    results: Dict[int, Dict[int, float]],
    num_layers: int,
) -> List[float]:
    """
    Collapse per-sequence-length attention mass into one mean value per layer.

    Args:
        results: Dictionary mapping seq_len to layer_index to attention_mass_percent
        num_layers: Number of layers in the model; layer indices outside [0, num_layers) are ignored

    Returns:
        List with num_layers floats: the mean over all sequence lengths that reported the
        layer, or 0.0 for layers that never appear in results
    """
    totals: Dict[int, float] = {}
    hits: Dict[int, int] = {}

    for per_layer in results.values():
        for layer_idx, value in per_layer.items():
            if not (0 <= layer_idx < num_layers):
                continue
            totals[layer_idx] = totals.get(layer_idx, 0.0) + float(value)
            hits[layer_idx] = hits.get(layer_idx, 0) + 1

    # Layers without any contribution fall back to 0.0
    return [totals[idx] / hits[idx] if idx in hits else 0.0 for idx in range(num_layers)]


def compute_attention_mass_for_original_sequence(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    text: str,
    device: torch.device,
    target_seq_lengths: List[int],
    num_layers: int,
) -> List[float]:
    """
    Compute average attention mass per layer for original sequence without compression embeddings.

    The text is run through the model once. For each target sequence length L (with
    2 <= L <= tokenized length) and each layer, the head-averaged attention that query
    positions 1..L-1 put on key position 0 is summed, divided by L-1 and expressed in percent.
    These per-length values are then averaged to get one value per layer.

    Args:
        model: Language model
        tokenizer: Tokenizer
        text: Input text
        device: Device to run on
        target_seq_lengths: List of target sequence lengths to evaluate
        num_layers: Number of layers in the model

    Returns:
        List of average attention mass per layer (one float per layer).
        A layer is 0.0 when the text is empty/non-string, the model returns no attentions,
        no target length is usable, or the model returned fewer attention layers than num_layers.
    """
    if not isinstance(text, str) or not text.strip():
        return [0.0] * num_layers

    model.eval()
    with torch.no_grad():
        # Tokenize text
        enc = tokenizer(text, truncation=True, padding=False, return_tensors="pt")
        input_ids = enc["input_ids"].to(device)
        attention_mask = enc["attention_mask"].to(device)

        # Forward pass with attention outputs
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=True,
        )

        avg_per_layer = [0.0] * num_layers
        attentions = outputs.attentions
        if attentions is None:
            return avg_per_layer

        # A length of 1 has no query position other than position 0 itself (self-attention
        # is excluded to avoid an outlier), so only lengths >= 2 contribute.
        max_seq_len = input_ids.shape[1]
        valid_lengths = [length for length in target_seq_lengths if 1 < length <= max_seq_len]
        if not valid_lengths:
            return avg_per_layer

        # Guard against a model that returns fewer attention tensors than num_layers
        for layer_idx in range(min(num_layers, len(attentions))):
            # [batch_size, num_heads, seq_len, seq_len] -> head average of batch element 0,
            # restricted to key position 0. Computed once per layer rather than once per
            # target length, since it does not depend on the target length.
            first_key_column = attentions[layer_idx].mean(dim=1)[0][:, 0]

            total_percent = 0.0
            for target_seq_len in valid_lengths:
                # Sum attention from positions 1 to target_seq_len-1 to key position 0
                attention_on_first_token = first_key_column[1:target_seq_len].sum().item()
                # Normalize by number of query positions (excluding position 0)
                total_percent += (attention_on_first_token / (target_seq_len - 1)) * 100.0
            avg_per_layer[layer_idx] = total_percent / len(valid_lengths)

    return avg_per_layer


def compute_attention_mass_for_original_sequence_per_seq_len(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    text: str,
    device: torch.device,
    target_seq_lengths: List[int],
) -> Dict[int, Dict[int, float]]:
    """
    Compute attention mass for original sequence per sequence length (for visualization).

    Returns results in the same format as compression embeddings: Dict[seq_len, Dict[layer_idx, attention_mass_percent]]

    The text is run through the model once. For a target length L > 1 the value is the
    head-averaged attention that query positions 1..L-1 put on key position 0, summed,
    divided by L-1 and expressed in percent. For L == 1 the value is 0.0.
    Target lengths outside [1, tokenized length] are skipped.

    Args:
        model: Language model
        tokenizer: Tokenizer
        text: Input text
        device: Device to run on
        target_seq_lengths: List of target sequence lengths to evaluate

    Returns:
        Dictionary mapping seq_len to layer_index to attention_mass_percent.
        Empty when the text is empty/non-string or the model returns no attentions.
    """
    if not isinstance(text, str) or not text.strip():
        return {}

    model.eval()
    with torch.no_grad():
        # Tokenize text
        enc = tokenizer(text, truncation=True, padding=False, return_tensors="pt")
        input_ids = enc["input_ids"].to(device)
        attention_mask = enc["attention_mask"].to(device)

        # Forward pass with attention outputs
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=True,
        )

        attentions = outputs.attentions
        if attentions is None or len(attentions) == 0:
            return {}

        # Pre-create one entry per usable target length, in the order they were requested
        max_seq_len = input_ids.shape[1]
        results: Dict[int, Dict[int, float]] = {
            length: {} for length in target_seq_lengths if 1 <= length <= max_seq_len
        }
        if not results:
            return results

        for layer_idx, attn_layer in enumerate(attentions):
            # [batch_size, num_heads, seq_len, seq_len] -> head average of batch element 0,
            # restricted to key position 0. Computed once per layer rather than once per
            # target length, since it does not depend on the target length.
            first_key_column = attn_layer.mean(dim=1)[0][:, 0]

            for target_seq_len, layer_attention_mass in results.items():
                if target_seq_len > 1:
                    # Sum attention from positions 1 to target_seq_len-1 to key position 0,
                    # excluding self-attention of position 0 to avoid an outlier
                    attention_on_first_token = first_key_column[1:target_seq_len].sum().item()
                    # Normalize by number of query positions (excluding position 0)
                    attention_mass_percent = (attention_on_first_token / (target_seq_len - 1)) * 100.0
                else:
                    # If only one token, no other tokens to attend to BOS
                    attention_mass_percent = 0.0
                layer_attention_mass[layer_idx] = attention_mass_percent

    return results


def save_attention_mass_cache(
    cache_data: Dict[str, Any],
    output_dir: str,
    sample_id: Optional[int] = None,
):
    """
    Write attention mass data to a JSON cache file inside output_dir.

    The file is named after the sample when sample_id is given and is treated as the
    cross-sample average otherwise. An existing file with the same name is overwritten.

    Args:
        cache_data: Dictionary containing attention mass data (must be JSON-serializable)
        output_dir: Directory to save cache file
        sample_id: Optional sample ID (if None, saves as average)
    """
    cache_filename = (
        "attention_mass_cache_avg.json"
        if sample_id is None
        else f"attention_mass_cache_sample_{sample_id}.json"
    )
    cache_path = os.path.join(output_dir, cache_filename)

    with open(cache_path, "w") as handle:
        json.dump(cache_data, handle, indent=2)

    print(f"Saved attention mass cache to: {cache_path}")


def _mean_over_samples_per_layer(samples: List[List[float]], num_layers: int) -> List[float]:
    """Average each layer's value over the samples that actually contain that layer (0.0 if none do)."""
    averages = [0.0] * num_layers
    for layer_idx in range(num_layers):
        values = [sample[layer_idx] for sample in samples if layer_idx < len(sample)]
        if values:
            averages[layer_idx] = sum(values) / len(values)
    return averages


def print_attention_mass_summary(
    all_compression_attention_mass: List[List[float]],
    all_original_attention_mass: List[List[float]],
    num_layers: int,
):
    """
    Print summary of average attention mass across all samples.

    Prints a per-layer table, the overall average over layers and, when both inputs are
    non-empty, the Pearson correlation between the two per-layer averages.

    Args:
        all_compression_attention_mass: List of attention mass per layer for each sample (compression embeddings)
        all_original_attention_mass: List of attention mass per layer for each sample (original sequence, BOS token)
        num_layers: Number of layers in the model
    """
    compression_samples = all_compression_attention_mass or []
    original_samples = all_original_attention_mass or []

    if not compression_samples and not original_samples:
        print("\nNo attention mass data to summarize.")
        return

    if len(compression_samples) != len(original_samples):
        print(
            f"\nWarning: Mismatch in number of samples: compression={len(compression_samples)}, original={len(original_samples)}"
        )

    num_samples = max(len(compression_samples), len(original_samples))

    # Compute averages across samples for each layer. A sample shorter than num_layers is
    # left out of the layers it lacks instead of being counted as a zero.
    avg_compression_per_layer = _mean_over_samples_per_layer(compression_samples, num_layers)
    avg_original_per_layer = _mean_over_samples_per_layer(original_samples, num_layers)

    # Print summary
    print("\n" + "=" * 80)
    print("ATTENTION MASS SUMMARY (averaged across all samples)")
    print("=" * 80)
    print(f"Number of samples: {num_samples}")
    print(f"Number of layers: {num_layers}")
    print("\nAverage Attention Mass per Layer:")
    print("-" * 80)
    print(f"{'Layer':<10} {'Compression Embeddings (%)':<30} {'BOS Token Original (%)':<30}")
    print("-" * 80)
    for layer_idx, (comp_val, orig_val) in enumerate(zip(avg_compression_per_layer, avg_original_per_layer)):
        print(f"{layer_idx:<10} {comp_val:<30.2f} {orig_val:<30.2f}")

    # Overall averages (guarded so that num_layers == 0 does not divide by zero)
    overall_avg_compression = sum(avg_compression_per_layer) / num_layers if num_layers > 0 else 0.0
    overall_avg_original = sum(avg_original_per_layer) / num_layers if num_layers > 0 else 0.0
    print("-" * 80)
    print(f"{'Overall Avg':<10} {overall_avg_compression:<30.2f} {overall_avg_original:<30.2f}")

    # Compute correlation between compression and original attention mass across layers
    if compression_samples and original_samples:
        compression_array = np.array(avg_compression_per_layer)
        original_array = np.array(avg_original_per_layer)
        print("-" * 80)
        # Pearson correlation is undefined for fewer than two points or a constant series
        if len(compression_array) > 1 and np.std(compression_array) > 0 and np.std(original_array) > 0:
            correlation = np.corrcoef(compression_array, original_array)[0, 1]
            print(f"Correlation (across layers): {correlation:.4f}")
        else:
            print("Correlation (across layers): N/A (insufficient variance)")

    print("=" * 80)


def main():
    """
    Command-line entry point: compute and plot attention mass for every eligible sample.

    Steps, in order:
      1. Parse arguments and validate --min_seq_length.
      2. Pick the output directory (explicit, or derived from the dataset path) and create it.
      3. Load the dataset, detect its type and group records by sample.
      4. Load the model (eager attention, so attention weights can be returned) and tokenizer.
      5. Keep samples whose max sequence length is >= --min_seq_length; the evaluated target
         lengths run from --min_seq_length to the smallest per-sample maximum.
      6. Per sample: compute attention mass with compression embeddings, optionally save
         per-sample plots, compute the original-sequence baseline and write a JSON cache.
      7. Plot the averages over samples and print the summary table.

    Raises:
        ValueError: On an invalid --min_seq_length, an undetectable dataset type, an empty
            dataset, a model checkpoint that is neither given nor stored in the dataset,
            or when no sample is long enough.
    """
    parser = argparse.ArgumentParser(description="Visualize attention hijacking with compression tokens")
    parser.add_argument(
        "--dataset_path",
        type=str,
        required=True,
        help="Path to progressive_prefixes or prefix_tuning_prefixes dataset",
    )
    parser.add_argument(
        "--model_checkpoint",
        type=str,
        default=None,
        help="Model checkpoint path (if not provided, will try to infer from dataset)",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Directory to save figures (default: inferred from dataset_path)",
    )
    parser.add_argument(
        "--min_seq_length",
        type=int,
        default=1,
        help="Filter out samples whose max sequence length is < this value.",
    )
    parser.add_argument(
        "--attention_block_size",
        type=int,
        default=16,
        help="Block size for averaging attention for long sequences (target_seq_len > 100).",
    )
    parser.add_argument(
        "--save_per_sample_heatmaps",
        action="store_true",
        help="Save individual heatmaps for each sample.",
    )
    parser.add_argument(
        "--save_attention_weights_heatmaps",
        action="store_true",
        help="Save full attention weights heatmaps (seq_len x seq_len) for each sample.",
    )
    parser.add_argument(
        "--attention_layer_idx",
        type=int,
        default=None,
        help="Layer index to plot for attention weights heatmap (if not provided, averages across all layers).",
    )
    parser.add_argument(
        "--attention_head_idx",
        type=int,
        default=None,
        help="Head index to plot for attention weights heatmap (if not provided, averages across all heads).",
    )

    args = parser.parse_args()

    # Fail fast on bad arguments, before the (slow) dataset and model loading
    if args.min_seq_length < 1:
        raise ValueError("--min_seq_length must be >= 1")

    # Determine output directory
    output_dir = args.output_dir
    if output_dir is None:
        # Try to infer from dataset path. This substring also matches the
        # "artifacts/experiments_progressive" and "artifacts/experiments_prefix_tuning" roots.
        dataset_path = args.dataset_path
        if "artifacts/experiments" in dataset_path:
            exp_dir = os.path.dirname(dataset_path)
            output_dir = os.path.join(exp_dir, "attention_visualizations")
        else:
            output_dir = "attention_visualizations"
    os.makedirs(output_dir, exist_ok=True)

    # Load dataset and detect type
    print(f"Loading dataset from: {args.dataset_path}")
    ds, dataset_type = load_dataset(args.dataset_path)
    print(f"Detected dataset type: {dataset_type}")

    if dataset_type == "unknown":
        raise ValueError("Could not detect dataset type. Expected 'progressive' or 'prefix_tuning' dataset.")

    # Filter records
    rows = filter_records(ds, sample_id=None, dataset_type=dataset_type)
    if not rows:
        raise ValueError("No records found with given filters.")

    # Group by sample
    by_sid = collate_stages_by_sample(rows, dataset_type=dataset_type)

    # Determine model checkpoint
    model_checkpoint = args.model_checkpoint
    if model_checkpoint is None:
        # Try to infer from dataset (rows is guaranteed non-empty at this point)
        model_checkpoint = rows[0].get("model_checkpoint", "")
        if not model_checkpoint:
            raise ValueError(
                "model_checkpoint not provided and cannot be inferred from dataset. " "Please provide --model_checkpoint."
            )

    print(f"Using model checkpoint: {model_checkpoint}")

    # Load model and tokenizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loading model on device: {device}")
    # Set attention implementation to 'eager' to enable output_attentions
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_checkpoint,
            attn_implementation="eager",
        ).to(device)
    except TypeError:
        # Fallback for older transformers versions that don't support attn_implementation
        print("Warning: attn_implementation parameter not supported, loading model without it...")
        model = AutoModelForCausalLM.from_pretrained(model_checkpoint).to(device)
        # Try to set attention implementation after loading
        try:
            model.set_attn_implementation("eager")
        except (AttributeError, ValueError):
            print("Warning: Could not set attention implementation to 'eager'. Attention outputs may not be available.")
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    # Initialize variables for summary
    all_compression_attention_mass: List[List[float]] = []
    all_original_attention_mass: List[List[float]] = []
    num_layers = model.config.num_hidden_layers

    # Average heatmaps over all samples, limiting to the minimum max sequence length across samples.
    eligible_by_sid: Dict[int, List[Dict[str, Any]]] = {}
    per_sample_max = []
    for _sid, stages in by_sid.items():
        if dataset_type == "prefix_tuning":
            # For prefix tuning, get sequence length from tokenized text
            stage_record = stages[0]
            text = stage_record.get("text", "")
            if not isinstance(text, str) or text.strip() == "":
                continue
            enc = tokenizer(text, truncation=True, padding=False, return_tensors="pt")
            max_len = enc["input_ids"].shape[1]
        else:
            # Progressive: use stage_seq_len
            max_len = max((int(s.get("stage_seq_len", -1)) for s in stages), default=-1)
        if max_len >= args.min_seq_length:
            eligible_by_sid[_sid] = stages
            per_sample_max.append(max_len)

    if not per_sample_max:
        raise ValueError(
            f"No samples with max sequence length >= {args.min_seq_length} found. "
            "Lower --min_seq_length or check the dataset."
        )

    min_max_len = min(per_sample_max)
    print(
        f"\nAveraging over {len(eligible_by_sid)} samples; " f"using target_seq_len in [{args.min_seq_length}, {min_max_len}]"
    )
    target_seq_lengths_override = list(range(args.min_seq_length, min_max_len + 1))

    all_results: List[Dict[int, Dict[int, float]]] = []
    all_original_results: List[Dict[int, Dict[int, float]]] = []
    for sample_id, stages in eligible_by_sid.items():
        stage_count = len(stages)
        stage_label = "stages" if dataset_type == "progressive" else "entry"
        print(f"\nProcessing sample {sample_id} with {stage_count} {stage_label}...")
        results, attentions, text, num_compression_tokens = compute_attention_mass_for_stages(
            model=model,
            tokenizer=tokenizer,
            stages=stages,
            device=device,
            attention_block_size=args.attention_block_size,
            target_seq_lengths_override=target_seq_lengths_override,
            dataset_type=dataset_type,
        )
        if not results:
            continue

        all_results.append(results)
        # Save per-sample heatmap if flag is set
        if args.save_per_sample_heatmaps:
            output_path = os.path.join(output_dir, f"attention_hijacking_sample_{sample_id}.png")
            plot_attention_hijacking_heatmap(
                results=results,
                sample_id=sample_id,
                output_path=output_path,
            )
        # Save attention weights heatmap if flag is set
        if (
            args.save_attention_weights_heatmaps
            and attentions is not None
            and text is not None
            and num_compression_tokens is not None
        ):
            output_path = os.path.join(output_dir, f"attention_weights_sample_{sample_id}.png")
            plot_attention_weights_heatmap(
                attentions=attentions,
                num_compression_tokens=num_compression_tokens,
                tokenizer=tokenizer,
                text=text,
                sample_id=sample_id,
                output_path=output_path,
                layer_idx=args.attention_layer_idx,
                head_idx=args.attention_head_idx,
            )

        # Compute and save attention mass cache for compression embeddings
        if text is None:
            continue

        avg_attention_mass_compression = compute_average_attention_mass_per_layer(
            results=results,
            num_layers=num_layers,
        )

        # Compute attention mass for original sequence (without compression)
        print(f"Computing attention mass for original sequence (sample {sample_id})...")
        # Compute per-sequence-length results for visualization
        original_results = compute_attention_mass_for_original_sequence_per_seq_len(
            model=model,
            tokenizer=tokenizer,
            text=text,
            device=device,
            target_seq_lengths=target_seq_lengths_override,
        )
        if original_results:
            all_original_results.append(original_results)

        # Per-layer summary, derived from the per-length results instead of running the
        # model on the same text a second time. Length 1 is left out: it has no query
        # position besides the first token itself and only carries a 0.0 placeholder.
        avg_attention_mass_original = compute_average_attention_mass_per_layer(
            results={seq_len: layer_map for seq_len, layer_map in original_results.items() if seq_len > 1},
            num_layers=num_layers,
        )

        # Collect for summary
        all_compression_attention_mass.append(avg_attention_mass_compression)
        all_original_attention_mass.append(avg_attention_mass_original)

        # Save to cache
        cache_data = {
            "sample_id": sample_id,
            "num_layers": num_layers,
            "target_seq_lengths": target_seq_lengths_override,
            "avg_attention_mass_per_layer_compression": avg_attention_mass_compression,
            "avg_attention_mass_per_layer_original": avg_attention_mass_original,
        }
        save_attention_mass_cache(cache_data, output_dir, sample_id=sample_id)

    avg_results = average_attention_mass_results(all_results)
    output_path = os.path.join(output_dir, "attention_hijacking_avg.png")
    plot_attention_hijacking_heatmap(
        results=avg_results,
        sample_id=None,
        output_path=output_path,
    )

    # Average and plot original sequence attention mass
    if all_original_results:
        avg_original_results = average_attention_mass_results(all_original_results)
        output_path_original = os.path.join(output_dir, "attention_hijacking_avg_original.png")
        plot_attention_hijacking_heatmap(
            results=avg_original_results,
            sample_id=None,
            output_path=output_path_original,
            title_prefix="Attention Hijacking: BOS Token Attention Mass % (Original Sequence)",
            colorbar_label="Attention Mass % on BOS Token",
        )

    # Print summary of average attention mass
    if all_compression_attention_mass or all_original_attention_mass:
        print_attention_mass_summary(
            all_compression_attention_mass,
            all_original_attention_mass,
            num_layers,
        )

    print(f"\nAll visualizations saved to: {output_dir}")


# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
