#!/usr/bin/env python3
"""Case study runner for FlashTrace and attribution baselines.

Modes supported (all emit JSON + HTML under ``exp/case_study/out``):

- ``ft``: FlashTrace (current project implementation; multi-hop IFR)
- ``ft_improve``: FlashTrace (multi-hop IFR, stop-token soft deletion)
- ``ft_split_hop``: FlashTrace (split-hop IFR over segmented thinking span)
- ``ifr_in_all_gen``: Experimental multi-hop IFR variant (hops over CoT+output; scheme B, aligns with exp/exp2)
- ``ifr``: IFR span-aggregate visualization (single hop; one panel)
- ``ifr_all_positions``: IFR full matrix + CAGE (Row/Recursive panels)
- ``ifr_all_positions_output_only``: IFR output-only token matrix + CAGE (Row/Recursive panels)
- ``attnlrp``: AttnLRP hop0 (reuse FT-AttnLRP span-aggregate; visualize raw hop0 vector)
- ``ft_attnlrp``: FT-AttnLRP (multi-hop aggregated AttnLRP; matches exp/exp2)
"""

from __future__ import annotations

import argparse
import json
import os
import sys
import types
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple

# Avoid torchvision dependency when importing transformers (Longformer).
os.environ.setdefault("TRANSFORMERS_NO_TORCHVISION", "1")
os.environ.setdefault("DISABLE_TRANSFORMERS_IMAGE_TRANSFORMS", "1")

def _early_set_cuda_visible_devices() -> None:
    """Export a multi-GPU ``--cuda`` spec as ``CUDA_VISIBLE_DEVICES``.

    Only the ``--cuda`` option is read from ``sys.argv``; every other argument
    is ignored here. The environment variable is written only when the
    stripped value contains a comma; a single index or a missing option leaves
    the environment untouched.

    Note: CUDA device indices are re-mapped inside the process after applying
    the mask, which is why this is meant to run before torch/transformers are
    imported.
    """

    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument("--cuda", type=str, default=None)
    known, _unknown = pre_parser.parse_known_args(sys.argv[1:])

    spec = known.cuda
    if not isinstance(spec, str):
        return

    spec = spec.strip()
    if "," not in spec:
        return

    os.environ["CUDA_VISIBLE_DEVICES"] = spec


# Apply the GPU mask only when executed as a script, and do it here so it
# happens before the `import torch` that follows.
if __name__ == "__main__":
    _early_set_cuda_visible_devices()

import torch

# Third ancestor of this file's resolved path; prepended to sys.path (once) so
# the project-level imports further down can be resolved.
# NOTE(review): assumes this file lives two directories below the repo root — confirm.
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))


def _stub_torchvision() -> None:
    """Register placeholder ``torchvision`` modules in ``sys.modules``.

    Does nothing when a ``torchvision`` entry already exists. Otherwise it
    creates an empty package plus a fixed set of empty submodules, and fills in
    the two names that are looked up by attribute:
    ``torchvision.transforms.InterpolationMode`` and
    ``torchvision.ops.misc.FrozenBatchNorm2d``. The intent is that Longformer
    imports succeed without the real package being installed.
    """

    if "torchvision" in sys.modules:
        return

    from importlib.machinery import ModuleSpec

    def _blank_module(qualname: str) -> types.ModuleType:
        # A bare module object carrying a loader-less spec.
        stub = types.ModuleType(qualname)
        stub.__spec__ = ModuleSpec(qualname, loader=None)
        return stub

    def _register_child(parent: types.ModuleType, leaf: str) -> types.ModuleType:
        # Create `<parent>.<leaf>`, publish it in sys.modules and on the parent.
        child = _blank_module(f"{parent.__name__}.{leaf}")
        sys.modules[child.__name__] = child
        setattr(parent, leaf, child)
        return child

    root = _blank_module("torchvision")
    # An (empty) __path__ marks the stub as a package.
    root.__path__ = []
    for leaf in ("transforms", "_meta_registrations", "datasets", "io", "models", "ops", "utils"):
        _register_child(root, leaf)

    class _InterpolationMode:
        NEAREST = 0
        NEAREST_EXACT = 0
        BILINEAR = 1
        BICUBIC = 2
        LANCZOS = 3
        BOX = 4
        HAMMING = 5

    transforms_stub = sys.modules["torchvision.transforms"]
    transforms_stub.InterpolationMode = _InterpolationMode
    transforms_stub.__all__ = ["InterpolationMode"]

    # ops + misc stub for timm/transformers imports
    ops_stub = sys.modules.get("torchvision.ops") or _blank_module("torchvision.ops")
    sys.modules["torchvision.ops"] = ops_stub
    setattr(root, "ops", ops_stub)
    misc_stub = _register_child(ops_stub, "misc")

    class _FrozenBatchNorm2d:
        def __init__(self, *args, **kwargs):
            pass

    misc_stub.FrozenBatchNorm2d = _FrozenBatchNorm2d

    # The package itself is published last, after all children are in place.
    sys.modules["torchvision"] = root


# Install the torchvision placeholders at import time, ahead of the
# `import transformers` further down in this file.
_stub_torchvision()


def _stub_timm() -> None:
    """Register placeholder ``timm`` modules in ``sys.modules``.

    Does nothing when a ``timm`` entry already exists. Otherwise it installs an
    empty ``timm`` package with ``timm.data``, ``timm.layers``,
    ``timm.layers.create_norm`` and ``timm.layers.classifier`` submodules, and
    provides inert ``ImageNetInfo``, ``infer_imagenet_subset`` and
    ``get_norm_layer`` names so optional vision dependencies are not required.
    """

    if "timm" in sys.modules:
        return

    from importlib.machinery import ModuleSpec

    def _install(qualname: str) -> types.ModuleType:
        # Build a loader-less module and publish it under its dotted name.
        stub = types.ModuleType(qualname)
        stub.__spec__ = ModuleSpec(qualname, loader=None)
        sys.modules[qualname] = stub
        return stub

    class _ImageNetInfo:
        pass

    def _infer_imagenet_subset(*args, **kwargs):
        return None

    def _get_norm_layer(*args, **kwargs):
        return None

    package = _install("timm")
    # An (empty) __path__ marks the stub as a package.
    package.__path__ = []

    data_stub = _install("timm.data")
    package.data = data_stub
    data_stub.ImageNetInfo = _ImageNetInfo
    data_stub.infer_imagenet_subset = _infer_imagenet_subset

    layers_stub = _install("timm.layers")
    package.layers = layers_stub

    create_norm_stub = _install("timm.layers.create_norm")
    layers_stub.create_norm = create_norm_stub
    create_norm_stub.get_norm_layer = _get_norm_layer

    layers_stub.classifier = _install("timm.layers.classifier")


# Install the timm placeholders at import time, right before `import transformers`.
_stub_timm()

import transformers

# Provide light stubs if Longformer classes are unavailable; IFR case study does not use them.
# Each stub only exists so that attribute lookups on `transformers` succeed;
# instantiating one raises ImportError with an install hint.
if not hasattr(transformers, "LongformerTokenizer"):
    class _DummyLongformerTokenizer:
        def __init__(self, *args, **kwargs):
            raise ImportError("LongformerTokenizer stubbed; install full transformers+torchvision if needed.")
    transformers.LongformerTokenizer = _DummyLongformerTokenizer

if not hasattr(transformers, "LongformerForMaskedLM"):
    class _DummyLongformerForMaskedLM:
        def __init__(self, *args, **kwargs):
            raise ImportError("LongformerForMaskedLM stubbed; install full transformers+torchvision if needed.")
    transformers.LongformerForMaskedLM = _DummyLongformerForMaskedLM

# Keep `transformers.__all__` (when present) in sync with the names ensured above.
if hasattr(transformers, "__all__"):
    for _name in ["LongformerTokenizer", "LongformerForMaskedLM"]:
        if _name not in transformers.__all__:
            transformers.__all__.append(_name)

# Gemma3n stubs (transformers may attempt to import even if unused)
# Only installed when the real configuration module has not been imported yet.
if "transformers.models.gemma3n.configuration_gemma3n" not in sys.modules:
    from importlib.machinery import ModuleSpec

    # Placeholder package object for `transformers.models.gemma3n`.
    gemma_pkg = types.ModuleType("transformers.models.gemma3n")
    gemma_pkg.__spec__ = ModuleSpec("transformers.models.gemma3n", loader=None, is_package=True)
    sys.modules["transformers.models.gemma3n"] = gemma_pkg

    # Placeholder configuration module that will carry the two config classes.
    gemma_conf = types.ModuleType("transformers.models.gemma3n.configuration_gemma3n")
    gemma_conf.__spec__ = ModuleSpec("transformers.models.gemma3n.configuration_gemma3n", loader=None)

    # Minimal config stand-ins: they accept any arguments and only record `model_type`.
    class Gemma3nConfig:
        def __init__(self, *args, **kwargs):
            self.model_type = "gemma3n"

    class Gemma3nTextConfig(Gemma3nConfig):
        pass

    gemma_conf.Gemma3nConfig = Gemma3nConfig
    gemma_conf.Gemma3nTextConfig = Gemma3nTextConfig
    gemma_conf.__all__ = ["Gemma3nConfig", "Gemma3nTextConfig"]
    sys.modules["transformers.models.gemma3n.configuration_gemma3n"] = gemma_conf
    setattr(gemma_pkg, "configuration_gemma3n", gemma_conf)

    # Advertise the stubbed config names in `transformers.__all__` when it exists.
    if hasattr(transformers, "__all__"):
        for _nm in ["Gemma3nConfig", "Gemma3nTextConfig"]:
            if _nm not in transformers.__all__:
                transformers.__all__.append(_nm)

import llm_attr
from exp.exp2 import dataset_utils as ds_utils
from evaluations.attribution_recovery import load_model

from exp.case_study import analysis, viz


def resolve_device(cuda: Optional[str], cuda_num: int) -> str:
    """Translate the ``--cuda`` / ``--cuda_num`` options into a device string.

    Args:
        cuda: Raw ``--cuda`` value. Surrounding whitespace is ignored.
        cuda_num: GPU index used when ``cuda`` is missing or blank.

    Returns:
        ``"auto"`` when ``cuda`` lists several devices (contains a comma); in
        that case ``CUDA_VISIBLE_DEVICES`` is set to the stripped spec, the
        same value ``_early_set_cuda_visible_devices`` exports. Otherwise
        ``"cuda:<index>"`` when CUDA is available, else ``"cpu"``.
    """

    spec = cuda.strip() if isinstance(cuda, str) else ""

    if "," in spec:
        # Export the stripped value so a padded spec such as " 0,1 " does not
        # overwrite the mask with stray whitespace. CUDA availability is
        # deliberately not queried before the mask is in place.
        os.environ["CUDA_VISIBLE_DEVICES"] = spec
        return "auto"

    if spec:
        try:
            idx = int(spec)
        except ValueError:
            # Best-effort fallback is kept, but no longer silent.
            print(
                f"[warn] --cuda={cuda!r} is not an integer GPU index; falling back to GPU 0.",
                file=sys.stderr,
            )
            idx = 0
        return f"cuda:{idx}" if torch.cuda.is_available() else "cpu"

    return f"cuda:{cuda_num}" if torch.cuda.is_available() else "cpu"


def load_example(dataset: str, index: int, data_root: Path) -> Tuple[ds_utils.CachedExample, str]:
    """Fetch one example, given either a cache file path or a dataset name.

    If ``dataset`` points at an existing path it is read with
    ``ds_utils.read_cached_jsonl`` and the reported name is the file name;
    otherwise it is loaded by name through ``ds_utils.DatasetLoader`` rooted at
    ``data_root``. Negative ``index`` values count from the end.

    Returns:
        ``(example, dataset_name)``.

    Raises:
        ValueError: if no examples were loaded.
        IndexError: if the (normalized) index is out of range.
    """

    candidate = Path(dataset)
    if candidate.exists():
        dataset_name = candidate.name
        examples = ds_utils.read_cached_jsonl(candidate)
    else:
        dataset_name = dataset
        examples = ds_utils.DatasetLoader(data_root=data_root).load(dataset)

    if not examples:
        raise ValueError(f"No examples found for dataset={dataset}")

    total = len(examples)
    # Negative indices are shifted once; the shifted value is what gets reported on error.
    position = total + index if index < 0 else index
    if position < 0 or position >= total:
        raise IndexError(f"index {position} out of range for dataset with {total} examples")

    return examples[position], dataset_name


def parse_args() -> argparse.Namespace:
    """Declare the case-study command-line options and parse ``sys.argv``.

    Returns:
        The parsed namespace; every option has a default, so no argument is
        required.
    """

    mode_choices = [
        "ft",
        "ft_improve",
        "ft_split_hop",
        "ifr_in_all_gen",
        "ifr",
        "ifr_all_positions",
        "ifr_all_positions_output_only",
        "attnlrp",
        "ft_attnlrp",
    ]
    mode_help = (
        "ft = FlashTrace (multi-hop IFR); ifr = standard IFR span-aggregate; "
        "ifr_in_all_gen = multi-hop IFR over CoT+output (scheme B; exp2-aligned); "
        "ifr_all_positions = full IFR matrix + CAGE row/rec; "
        "ft_improve = FlashTrace (multi-hop IFR, stop-token soft deletion); "
        "ft_split_hop = FlashTrace (split-hop IFR over segmented thinking span); "
        "ifr_all_positions_output_only = output-only IFR matrix + CAGE row/rec; "
        "attnlrp = AttnLRP hop0 (FT-AttnLRP span-aggregate); "
        "ft_attnlrp = FT-AttnLRP (multi-hop aggregated; exp2)."
    )

    cli = argparse.ArgumentParser("IFR multi-hop case study")
    add = cli.add_argument

    # Data selection.
    add("--dataset", type=str, default="exp/exp2/data/morehopqa.jsonl", help="Dataset name or JSONL path.")
    add("--data_root", type=str, default="exp/exp2/data", help="Cache root for dataset names.")
    add("--index", type=int, default=0, help="Sample index (supports negative for reverse).")

    # Attribution method.
    add("--mode", type=str, choices=mode_choices, default="ft", help=mode_help)

    # Model and device.
    add("--model", type=str, default="qwen-8B", help="HF repo id (ignored if --model_path set).")
    add("--model_path", type=str, default=None, help="Local model path to override --model.")
    add("--cuda", type=str, default=None, help="CUDA spec (e.g., '0' or '0,1').")
    add("--cuda_num", type=int, default=0, help="Fallback GPU index when --cuda unset.")

    # Hop / span configuration.
    add("--n_hops", type=int, default=1, help="Number of hops for IFR multi-hop.")
    add("--sink_span", type=int, nargs=2, default=None, help="Optional sink span over generation tokens.")
    add("--thinking_span", type=int, nargs=2, default=None, help="Optional thinking span over generation tokens.")

    # FT-AttnLRP post-processing.
    add(
        "--attnlrp_neg_handling",
        type=str,
        choices=["drop", "abs"],
        default="drop",
        help="FT-AttnLRP: how to handle negative values after each hop (drop=clamp>=0, abs=absolute value).",
    )
    add(
        "--attnlrp_norm_mode",
        type=str,
        choices=["norm", "no_norm"],
        default="norm",
        help="FT-AttnLRP: norm enables per-hop global+thinking normalization + ratios; no_norm disables all three.",
    )

    # IFR chunking and output location.
    add("--chunk_tokens", type=int, default=128, help="IFR chunk size.")
    add("--sink_chunk_tokens", type=int, default=32, help="IFR sink chunk size.")
    add("--output_dir", type=str, default="exp/case_study/out", help="Where to write HTML/JSON artifacts.")

    return cli.parse_args()


def _resolve_case_span(
    override: Optional[Sequence[int]],
    fallback: Optional[Sequence[int]],
) -> Optional[Tuple[int, ...]]:
    """Return the explicit span if one was given, else the example's span, else ``None``."""

    if override is not None:
        return tuple(override)
    return tuple(fallback) if fallback else None


def _collect_ifr_debug_info(attr: Any, result: Any) -> Dict[str, Any]:
    """Snapshot the token/id state left on ``attr`` plus the raw per-hop vectors from ``result``."""

    debug_info: Dict[str, Any] = {
        "full_prompt_tokens": list(getattr(attr, "prompt_tokens", []) or []),
        "generation_tokens": list(getattr(attr, "generation_tokens", []) or []),
        "user_prompt_indices": list(getattr(attr, "user_prompt_indices", []) or []),
        "chat_prompt_indices": list(getattr(attr, "chat_prompt_indices", []) or []),
    }
    for key in ("prompt_ids", "generation_ids"):
        ids = getattr(attr, key, None)
        debug_info[key] = ids.detach().cpu().tolist() if ids is not None else None

    raw_vectors: List[Any] = []
    metadata = result.metadata
    if metadata and "ifr" in metadata:
        raw_ifr = metadata["ifr"].get("raw")
        if raw_ifr is not None and hasattr(raw_ifr, "raw_attributions"):
            try:
                raw_vectors = [r.token_importance_total.detach().cpu() for r in raw_ifr.raw_attributions]
            except Exception as exc:
                # Raw vectors are optional debug output: keep going without
                # them, but report the failure instead of hiding it.
                print(f"[warn] could not extract raw IFR hop vectors: {exc!r}", file=sys.stderr)
                raw_vectors = []
    debug_info["raw_hop_vectors"] = raw_vectors
    return debug_info


def _run_ifr_attribution(
    attr: Any,
    calculate: Any,
    example: Any,
    *,
    n_hops: int,
    sink_span: Optional[Sequence[int]],
    thinking_span: Optional[Sequence[int]],
) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]:
    """Resolve spans, invoke ``calculate`` on the example, and gather debug info from ``attr``."""

    sink = _resolve_case_span(sink_span, example.sink_span)
    thinking = _resolve_case_span(thinking_span, example.thinking_span)
    result = calculate(
        example.prompt,
        target=example.target,
        sink_span=sink,
        thinking_span=thinking,
        n_hops=int(n_hops),
    )
    return result, sink, thinking, _collect_ifr_debug_info(attr, result)


def run_ft_multihop(
    example: ds_utils.CachedExample,
    model: Any,
    tokenizer: Any,
    *,
    n_hops: int,
    sink_span: Optional[Sequence[int]],
    thinking_span: Optional[Sequence[int]],
    chunk_tokens: int,
    sink_chunk_tokens: int,
) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]:
    """Execute FT (current multi-hop IFR) attribution for the selected example.

    Explicit ``sink_span`` / ``thinking_span`` arguments take precedence over
    the spans stored on ``example``.

    Returns:
        ``(result, sink_span, thinking_span, debug_info)`` where ``debug_info``
        holds the tokens, indices and ids found on the attribution object and
        the raw per-hop vectors (``raw_hop_vectors``, possibly empty).
    """

    attr = llm_attr.LLMIFRAttribution(
        model,
        tokenizer,
        chunk_tokens=chunk_tokens,
        sink_chunk_tokens=sink_chunk_tokens,
    )
    return _run_ifr_attribution(
        attr,
        attr.calculate_ifr_multi_hop,
        example,
        n_hops=n_hops,
        sink_span=sink_span,
        thinking_span=thinking_span,
    )


def run_ft_multihop_improve(
    example: ds_utils.CachedExample,
    model: Any,
    tokenizer: Any,
    *,
    n_hops: int,
    sink_span: Optional[Sequence[int]],
    thinking_span: Optional[Sequence[int]],
    chunk_tokens: int,
    sink_chunk_tokens: int,
) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]:
    """Execute experimental FT (multi-hop IFR) with stop-token soft deletion.

    Same arguments and return shape as ``run_ft_multihop``; the attribution is
    computed by ``ft_ifr_improve.LLMIFRAttributionImproved``.
    """

    # Imported lazily so the module is only needed when this mode is selected.
    import ft_ifr_improve

    attr = ft_ifr_improve.LLMIFRAttributionImproved(
        model,
        tokenizer,
        chunk_tokens=chunk_tokens,
        sink_chunk_tokens=sink_chunk_tokens,
    )
    return _run_ifr_attribution(
        attr,
        attr.calculate_ifr_multi_hop_stop_words,
        example,
        n_hops=n_hops,
        sink_span=sink_span,
        thinking_span=thinking_span,
    )


def run_ft_multihop_split_hop(
    example: ds_utils.CachedExample,
    model: Any,
    tokenizer: Any,
    *,
    n_hops: int,
    sink_span: Optional[Sequence[int]],
    thinking_span: Optional[Sequence[int]],
    chunk_tokens: int,
    sink_chunk_tokens: int,
) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]:
    """Execute experimental FT (split-hop IFR over segmented thinking span).

    Same arguments and return shape as ``run_ft_multihop``; the attribution is
    computed by ``ft_ifr_improve.LLMIFRAttributionSplitHop``.
    """

    # Imported lazily so the module is only needed when this mode is selected.
    import ft_ifr_improve

    attr = ft_ifr_improve.LLMIFRAttributionSplitHop(
        model,
        tokenizer,
        chunk_tokens=chunk_tokens,
        sink_chunk_tokens=sink_chunk_tokens,
    )
    return _run_ifr_attribution(
        attr,
        attr.calculate_ifr_multi_hop_split_hop,
        example,
        n_hops=n_hops,
        sink_span=sink_span,
        thinking_span=thinking_span,
    )


def run_ifr_in_all_gen(
    example: ds_utils.CachedExample,
    model: Any,
    tokenizer: Any,
    *,
    n_hops: int,
    sink_span: Optional[Sequence[int]],
    thinking_span: Optional[Sequence[int]],
    chunk_tokens: int,
    sink_chunk_tokens: int,
) -> Tuple[Any, Optional[Tuple[int, int]], Optional[Tuple[int, int]], Dict[str, Any]]:
    """Execute experimental IFR variant: multi-hop over all generation (CoT + output).

    Same arguments and return shape as ``run_ft_multihop``; the attribution is
    computed by ``ft_ifr_improve.LLMIFRAttributionInAllGen``.
    """

    # Imported lazily so the module is only needed when this mode is selected.
    import ft_ifr_improve

    attr = ft_ifr_improve.LLMIFRAttributionInAllGen(
        model,
        tokenizer,
        chunk_tokens=chunk_tokens,
        sink_chunk_tokens=sink_chunk_tokens,
    )
    return _run_ifr_attribution(
        attr,
        attr.calculate_ifr_in_all_gen,
        example,
        n_hops=n_hops,
        sink_span=sink_span,
        thinking_span=thinking_span,
    )


def make_output_stem(dataset_name: str, index: int, mode: str) -> str:
    """Compose the artifact file stem ``<prefix><dataset>_idx<index>``.

    Slashes and spaces in ``dataset_name`` become underscores. The prefix is
    ``<mode>_case_`` for every mode except ``ifr_all_positions_output_only``,
    which uses the shorter ``ifr_output_only_case_``.
    """

    # Only one mode deviates from the generic "<mode>_case_" pattern.
    prefix_overrides = {"ifr_all_positions_output_only": "ifr_output_only_case_"}
    prefix = prefix_overrides.get(mode, f"{mode}_case_")

    sanitized = dataset_name.translate(str.maketrans({"/": "_", " ": "_"}))
    return f"{prefix}{sanitized}_idx{index}"


def _decode_token_ids(tokenizer: Any, ids: Sequence[int]) -> List[str]:
    """Turn each token id into its own text piece, special tokens included.

    Ids are decoded one at a time. If converting or decoding an id fails for
    any reason, ``str(id)`` is used for that position instead.
    """

    def _piece(raw_id: Any) -> str:
        try:
            return tokenizer.decode(
                [int(raw_id)],
                skip_special_tokens=False,
                clean_up_tokenization_spaces=False,
            )
        except Exception:
            # Best-effort: fall back to the printable id.
            return str(raw_id)

    return [_piece(raw_id) for raw_id in ids]


def build_raw_tokens_from_ids(
    tokenizer: Any,
    prompt_ids: Optional[Sequence[int]],
    generation_ids: Optional[Sequence[int]],
) -> List[str]:
    """Decode prompt ids followed by generation ids into one flat list of text pieces.

    A missing or empty id sequence contributes no pieces.
    """

    prompt_seq = prompt_ids or []
    generation_seq = generation_ids or []
    decoded_prompt = _decode_token_ids(tokenizer, prompt_seq)
    decoded_generation = _decode_token_ids(tokenizer, generation_seq)
    return decoded_prompt + decoded_generation


def build_trimmed_roles(tokens: Sequence[str], segments: Dict[str, Any]) -> List[str]:
    """Label every trimmed token (prompt + generation) with a role.

    Tokens before ``segments["prompt_len"]`` (default 0) are ``"prompt"``, the
    rest ``"gen"``. ``thinking_span`` and ``sink_span``, when present, are
    inclusive ``(start, end)`` offsets relative to the prompt length; they are
    painted as ``"think"`` and then ``"output"`` (so the sink wins on overlap),
    clipped to the token count.
    """

    total = len(tokens)
    offset = segments.get("prompt_len", 0)

    labels = ["prompt"] * total
    for pos in range(offset, total):
        labels[pos] = "gen"

    # Later entries overwrite earlier ones where spans overlap.
    for key, label in (("thinking_span", "think"), ("sink_span", "output")):
        span = segments.get(key)
        if span is None:
            continue
        first = offset + int(span[0])
        last = offset + int(span[1])
        for pos in range(first, min(total, last + 1)):
            labels[pos] = label

    return labels


def build_raw_roles(
    tokens: Sequence[str],
    prompt_len_full: int,
    user_indices: Sequence[int],
    template_indices: Sequence[int],
    thinking_span_abs: Optional[Sequence[int]],
    sink_span_abs: Optional[Sequence[int]],
) -> List[str]:
    """Label every raw token (template + user + generation) with a role.

    Positions below ``prompt_len_full`` become ``"user"`` or ``"template"``
    according to the given index collections (user takes precedence), or
    ``"prompt"`` when listed in neither. Positions from ``prompt_len_full``
    onward become ``"gen"``. The absolute, inclusive thinking and sink spans
    are then painted as ``"think"`` and ``"output"`` in that order, clipped to
    the token count.
    """

    total = len(tokens)
    user_positions = {int(i) for i in user_indices}
    template_positions = {int(i) for i in template_indices}

    def _prompt_role(pos: int) -> str:
        if pos in user_positions:
            return "user"
        if pos in template_positions:
            return "template"
        return "prompt"

    def _paint(span: Optional[Sequence[int]], label: str) -> None:
        if span is None:
            return
        first, last = int(span[0]), int(span[1])
        for pos in range(first, min(total, last + 1)):
            labels[pos] = label

    labels = ["template"] * total
    for pos in range(min(total, prompt_len_full)):
        labels[pos] = _prompt_role(pos)
    for pos in range(prompt_len_full, total):
        labels[pos] = "gen"

    # Sink is painted last so it wins where the spans overlap.
    _paint(thinking_span_abs, "think")
    _paint(sink_span_abs, "output")
    return labels


def extract_prompt_only_vectors(hop_vectors: Sequence[torch.Tensor], prompt_len: int) -> List[torch.Tensor]:
    """Keep only the first ``prompt_len`` entries of every hop vector.

    Each vector is converted to a detached float32 CPU tensor before slicing,
    so generation-token entries are dropped from the result.

    Raises:
        ValueError: if ``prompt_len`` is negative or a vector has fewer than
            ``prompt_len`` elements.
    """

    if prompt_len < 0:
        raise ValueError("prompt_len must be >= 0.")

    def _head(vec: torch.Tensor) -> torch.Tensor:
        flat = torch.as_tensor(vec, dtype=torch.float32).detach().cpu()
        available = int(flat.numel())
        if available < int(prompt_len):
            raise ValueError(
                f"Hop vector too short for prompt-only slice: len={available} prompt_len={int(prompt_len)}."
            )
        return flat[:prompt_len]

    return [_head(vec) for vec in hop_vectors]


def _lift_trimmed_to_full(
    trimmed: torch.Tensor,
    *,
    prompt_len_full: int,
    gen_len: int,
    user_prompt_indices: Sequence[int],
) -> torch.Tensor:
    """Scatter a trimmed (user prompt + generation) vector into full token space.

    The result has ``prompt_len_full + gen_len`` float32 entries. The leading
    ``len(user_prompt_indices)`` values of ``trimmed`` land at the listed
    absolute positions, the remaining values fill the generation tail, and
    every other position (the chat-template tokens) stays zero.

    Raises:
        ValueError: if ``trimmed`` does not hold exactly
            ``len(user_prompt_indices) + gen_len`` elements.
    """

    vec = torch.as_tensor(trimmed, dtype=torch.float32).detach().cpu()
    n_user = len(user_prompt_indices)
    n_expected = int(n_user + gen_len)
    n_actual = int(vec.numel())
    if n_actual != n_expected:
        raise ValueError(f"Trimmed vector length mismatch: got {n_actual}, expected {n_expected}.")

    gen_start = int(prompt_len_full)
    gen_stop = int(prompt_len_full + gen_len)
    lifted = torch.zeros(gen_stop, dtype=torch.float32)

    # User-prompt scores first, then the generation tail in one slice copy.
    for slot, position in enumerate(user_prompt_indices):
        lifted[int(position)] = vec[slot]
    lifted[gen_start:gen_stop] = vec[n_user:]
    return lifted


def _postprocess_attnlrp_full_vector(
    raw_full: torch.Tensor,
    *,
    prompt_len_full: int,
    gen_len: int,
    user_prompt_indices: Sequence[int],
    neg_handling: str,
    norm_mode: str,
) -> torch.Tensor:
    """Post-process a full-space AttnLRP vector so it lines up with the stripped one.

    Steps, applied to a detached float32 CPU copy of ``raw_full``:
      1. NaN entries become 0.
      2. Negatives are clamped to 0 (``neg_handling='drop'``) or replaced by
         their absolute value (``'abs'``).
      3. When ``norm_mode == 'norm'``, the whole vector is divided by the sum
         over the *stripped* positions only (user-prompt indices plus the
         generation range), so that scores of tokens shared with the trimmed
         view are on the same scale. Any other ``norm_mode`` skips this step.

    An all-zero vector is returned when normalization is requested but there
    are no stripped positions or their sum is not positive.

    Raises:
        ValueError: for an unsupported ``neg_handling`` value.
    """

    scores = torch.nan_to_num(
        torch.as_tensor(raw_full, dtype=torch.float32).detach().cpu(),
        nan=0.0,
    )

    sign_fixers = {
        "drop": lambda t: t.clamp(min=0.0),
        "abs": lambda t: t.abs(),
    }
    fix_sign = sign_fixers.get(neg_handling)
    if fix_sign is None:
        raise ValueError(f"Unsupported neg_handling={neg_handling!r} (expected 'drop' or 'abs').")
    scores = fix_sign(scores)

    if norm_mode != "norm":
        return scores

    # Positions that survive stripping: user prompt tokens + every generation token.
    kept_positions = [int(i) for i in user_prompt_indices]
    kept_positions.extend(range(int(prompt_len_full), int(prompt_len_full + gen_len)))
    if not kept_positions:
        return torch.zeros_like(scores)

    selector = torch.as_tensor(kept_positions, dtype=torch.long)
    mass = float(scores.index_select(0, selector).sum().item())
    if mass <= 0.0:
        return torch.zeros_like(scores)
    return scores / (mass + 1e-12)


def _resolve_cli_or_example_spans(
    args: argparse.Namespace,
    example: Any,
) -> Tuple[Optional[Tuple[int, ...]], Optional[Tuple[int, ...]]]:
    """Pick the sink/thinking generation spans for a case.

    Each span prefers the CLI override (``args.*_span``) and then the example's own
    span; the thinking span finally falls back to the resolved sink span.
    """
    if args.sink_span is not None:
        sink_span = tuple(args.sink_span)
    elif example.sink_span:
        sink_span = tuple(example.sink_span)
    else:
        sink_span = None

    if args.thinking_span is not None:
        thinking_span = tuple(args.thinking_span)
    elif example.thinking_span:
        thinking_span = tuple(example.thinking_span)
    else:
        thinking_span = sink_span
    return sink_span, thinking_span


def _read_attributor_token_context(
    attributor: Any,
) -> Tuple[List[str], List[str], List[int], List[int], List[int], List[int]]:
    """Read token strings, token ids and prompt index maps off an attributor.

    Returns ``(user_prompt_tokens, generation_tokens, prompt_ids, generation_ids,
    user_prompt_indices, chat_prompt_indices)``. Ids are element 0 of the attributor's
    id tensors converted to lists; the two index lists default to empty when the
    attribute is missing or falsy.
    """
    return (
        list(attributor.user_prompt_tokens),
        list(attributor.generation_tokens),
        attributor.prompt_ids.detach().cpu().tolist()[0],
        attributor.generation_ids.detach().cpu().tolist()[0],
        list(getattr(attributor, "user_prompt_indices", []) or []),
        list(getattr(attributor, "chat_prompt_indices", []) or []),
    )


def _build_attnlrp_pretrim_vectors(
    raw_attributions: Sequence[Any],
    *,
    prompt_len_full: int,
    gen_len: int,
    user_prompt_indices: Sequence[int],
    neg_handling: str,
    norm_mode: str,
    expected_len: int,
) -> List[torch.Tensor]:
    """Build one postprocessed full-sequence vector per AttnLRP hop.

    Returns an empty list as soon as any hop lacks
    ``metadata['token_importance_total_with_chat_template']`` so the caller can fall
    back to lifting the trimmed vectors instead.

    Raises:
        RuntimeError: a vector's length differs from a non-zero ``expected_len``.
    """
    vectors: List[torch.Tensor] = []
    for hop in raw_attributions:
        hop_meta = getattr(hop, "metadata", None) or {}
        raw_full = hop_meta.get("token_importance_total_with_chat_template")
        if raw_full is None:
            return []
        vec = _postprocess_attnlrp_full_vector(
            torch.as_tensor(raw_full, dtype=torch.float32),
            prompt_len_full=prompt_len_full,
            gen_len=gen_len,
            user_prompt_indices=user_prompt_indices,
            neg_handling=neg_handling,
            norm_mode=norm_mode,
        )
        if expected_len and int(vec.numel()) != expected_len:
            raise RuntimeError(
                "AttnLRP full-vector length mismatch for pre-trim view: "
                f"got {int(vec.numel())}, expected {expected_len}."
            )
        vectors.append(vec)
    return vectors


def main() -> None:
    """Run one case study end to end and write its JSON + HTML artifacts.

    Steps: parse CLI args, load the model and one dataset example, run the attribution
    method selected by ``args.mode``, make sure every panel has a full-sequence
    ("pre-trim") vector, then write ``<stem>.json`` and ``<stem>.html`` into
    ``args.output_dir``.

    Raises:
        ValueError: unsupported mode, or an IFR mode without any sink span.
        RuntimeError: the attribution result lacks the vectors needed for rendering.
    """
    args = parse_args()
    device = resolve_device(args.cuda, args.cuda_num)
    if torch.cuda.is_available():
        visible = os.environ.get("CUDA_VISIBLE_DEVICES")
        print(f"[info] CUDA_VISIBLE_DEVICES={visible!r} torch.cuda.device_count()={torch.cuda.device_count()} device={device}")

    model_name = args.model_path if args.model_path is not None else args.model
    # Align with exp/exp2: always use the shared fp16 loader.
    model, tokenizer = load_model(model_name, device)

    example, ds_name = load_example(args.dataset, args.index, Path(args.data_root))
    mode = args.mode

    sink_span: Optional[Tuple[int, int]] = None
    thinking_span: Optional[Tuple[int, int]] = None
    thinking_ratios: Optional[Sequence[float]] = None

    prompt_tokens_trimmed: List[str] = []
    generation_tokens_trimmed: List[str] = []
    hop_vectors_trimmed: List[torch.Tensor] = []
    hop_vectors_raw: List[torch.Tensor] = []
    prompt_len_full: Optional[int] = None
    user_prompt_indices: List[int] = []
    chat_prompt_indices: List[int] = []
    method_meta: Dict[str, Any] = {}
    raw_prompt_ids: Optional[List[int]] = None
    raw_generation_ids: Optional[List[int]] = None
    attnlrp_raw_attributions: Optional[List[Any]] = None

    if mode in ("ft", "ft_improve", "ft_split_hop", "ifr_in_all_gen"):
        # All four runners share one call signature; only the callee differs.
        # Look the callee up per mode so only the selected runner name is resolved.
        if mode == "ft":
            run_multihop = run_ft_multihop
        elif mode == "ft_improve":
            run_multihop = run_ft_multihop_improve
        elif mode == "ft_split_hop":
            run_multihop = run_ft_multihop_split_hop
        else:
            run_multihop = run_ifr_in_all_gen
        attr_result, sink_span, thinking_span, debug_info = run_multihop(
            example,
            model,
            tokenizer,
            n_hops=args.n_hops,
            sink_span=args.sink_span,
            thinking_span=args.thinking_span,
            chunk_tokens=args.chunk_tokens,
            sink_chunk_tokens=args.sink_chunk_tokens,
        )
        ifr_meta = (attr_result.metadata or {}).get("ifr") or {}
        hop_vectors_trimmed = list(ifr_meta.get("per_hop_projected") or [])
        if not hop_vectors_trimmed:
            raise RuntimeError(f"No per-hop vectors found for {mode} mode.")

        prompt_tokens_trimmed = list(attr_result.prompt_tokens)
        generation_tokens_trimmed = list(attr_result.generation_tokens)
        thinking_ratios = ifr_meta.get("thinking_ratios")

        # Ids may arrive as a nested [[...]] list; unwrap to the inner list.
        raw_prompt_ids = debug_info.get("prompt_ids")
        if isinstance(raw_prompt_ids, list) and raw_prompt_ids and isinstance(raw_prompt_ids[0], list):
            raw_prompt_ids = raw_prompt_ids[0]
        raw_generation_ids = debug_info.get("generation_ids")
        if isinstance(raw_generation_ids, list) and raw_generation_ids and isinstance(raw_generation_ids[0], list):
            raw_generation_ids = raw_generation_ids[0]

        user_prompt_indices = list(debug_info.get("user_prompt_indices") or [])
        chat_prompt_indices = list(debug_info.get("chat_prompt_indices") or [])
        prompt_len_full = len(raw_prompt_ids) if isinstance(raw_prompt_ids, list) else None

        raw_vectors = debug_info.get("raw_hop_vectors") or []
        hop_vectors_raw = [vec.detach().cpu() if hasattr(vec, "detach") else torch.as_tensor(vec) for vec in raw_vectors]
        method_meta = {"ifr": analysis.sanitize_ifr_meta(ifr_meta)}

    elif mode == "ifr":
        # Standard IFR (single-hop span aggregate), with pre/post trim views.
        attr = llm_attr.LLMIFRAttribution(
            model,
            tokenizer,
            chunk_tokens=args.chunk_tokens,
            sink_chunk_tokens=args.sink_chunk_tokens,
        )
        sink_span, thinking_span = _resolve_cli_or_example_spans(args, example)

        if sink_span is None:
            raise ValueError("sink_span is required for IFR mode (use dataset sink_span or pass --sink_span).")
        span_result = attr.calculate_ifr_span(
            example.prompt,
            target=example.target,
            span=tuple(sink_span),
        )
        span_meta = span_result.metadata.get("ifr") if span_result.metadata else None
        aggregate = span_meta.get("aggregate") if isinstance(span_meta, dict) else None
        if aggregate is None or not hasattr(aggregate, "token_importance_total"):
            raise RuntimeError("IFR span aggregate missing from metadata; cannot render pre-trim view.")

        raw_vector = aggregate.token_importance_total.detach().cpu()
        trimmed_vector = attr._project_vector(raw_vector)
        hop_vectors_raw = [raw_vector]
        hop_vectors_trimmed = [trimmed_vector]

        (
            prompt_tokens_trimmed,
            generation_tokens_trimmed,
            raw_prompt_ids,
            raw_generation_ids,
            user_prompt_indices,
            chat_prompt_indices,
        ) = _read_attributor_token_context(attr)
        prompt_len_full = len(raw_prompt_ids)

        # Generation-relative spans shifted by the full prompt length.
        sink_abs = (prompt_len_full + sink_span[0], prompt_len_full + sink_span[1])
        think_abs = (prompt_len_full + thinking_span[0], prompt_len_full + thinking_span[1]) if thinking_span else None

        meta = {
            "type": "span_aggregate",
            "ifr_view": "aggregate",
            "sink_span_generation": sink_span,
            "sink_span_absolute": sink_abs,
            "thinking_span_generation": thinking_span,
            "thinking_span_absolute": think_abs,
        }
        method_meta = {"ifr": analysis.tensor_to_list(meta)}

    elif mode in ("ifr_all_positions", "ifr_all_positions_output_only"):
        # IFR all-positions (full generation or output-only) + token-level CAGE
        # (row/recursive) derived from the matrix.
        output_only = mode == "ifr_all_positions_output_only"
        attr = llm_attr.LLMIFRAttribution(
            model,
            tokenizer,
            chunk_tokens=args.chunk_tokens,
            sink_chunk_tokens=args.sink_chunk_tokens,
        )
        sink_span, thinking_span = _resolve_cli_or_example_spans(args, example)

        if sink_span is None:
            raise ValueError(
                f"sink_span is required for {mode} mode (use dataset sink_span or pass --sink_span)."
            )

        if output_only:
            attr_result = attr.calculate_ifr_for_all_positions_output_only(
                example.prompt,
                target=example.target,
                sink_span=tuple(sink_span),
            )
        else:
            attr_result = attr.calculate_ifr_for_all_positions(
                example.prompt,
                target=example.target,
            )

        indices_to_explain = list(sink_span)
        _, row_attr, rec_attr = attr_result.get_all_token_attrs(indices_to_explain)
        hop_vectors_trimmed = [
            row_attr.squeeze(0).detach().cpu(),
            rec_attr.squeeze(0).detach().cpu(),
        ]

        (
            prompt_tokens_trimmed,
            generation_tokens_trimmed,
            raw_prompt_ids,
            raw_generation_ids,
            user_prompt_indices,
            chat_prompt_indices,
        ) = _read_attributor_token_context(attr)
        prompt_len_full = len(raw_prompt_ids)
        # hop_vectors_raw stays empty here on purpose: the shared lift fallback
        # further down builds the pre-trim vectors with these same inputs.

        ifr_meta = dict((attr_result.metadata or {}).get("ifr") or {})
        ifr_meta["ifr_view"] = (
            "all_positions_output_only (row+rec)" if output_only else "all_positions (row+rec)"
        )
        ifr_meta["panel_titles"] = ["Row attribution", "Recursive attribution (CAGE)"]
        ifr_meta["indices_to_explain"] = indices_to_explain
        method_meta = {"ifr": analysis.tensor_to_list(ifr_meta)}

    elif mode in ("attnlrp", "ft_attnlrp"):
        # Reuse the shared LLMLRPAttribution implementations (root-level).
        attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
        sink_span, thinking_span = _resolve_cli_or_example_spans(args, example)

        if mode == "attnlrp":
            # Case-study AttnLRP: reuse FT-AttnLRP logic but take hop0 (the first span-aggregate)
            # for a full, signed attribution vector (no observation masking).
            attr_result = attributor.calculate_attnlrp_ft_hop0(
                example.prompt,
                target=example.target,
                sink_span=sink_span,
                thinking_span=thinking_span,
                neg_handling=args.attnlrp_neg_handling,
                norm_mode=args.attnlrp_norm_mode,
            )
            multi_hop = (attr_result.metadata or {}).get("multi_hop_result")
            raw_attributions = getattr(multi_hop, "raw_attributions", None) or []
            attnlrp_raw_attributions = list(raw_attributions)
            base_attr = raw_attributions[0] if raw_attributions else None
            if base_attr is None or not hasattr(base_attr, "token_importance_total"):
                raise RuntimeError("AttnLRP hop0 missing from multi-hop result.")

            hop0_vec = torch.as_tensor(getattr(base_attr, "token_importance_total"), dtype=torch.float32).detach().cpu()
            if hop0_vec.numel() <= 0:
                raise RuntimeError("Empty generation for AttnLRP case study.")

            # Use the actual sink span applied by hop0 (defaults to full generation when unset).
            sink_span = tuple(getattr(base_attr, "sink_range"))
            if thinking_span is None:
                thinking_span = sink_span

            hop_vectors_trimmed = [hop0_vec]
            attnlrp_meta: Dict[str, Any] = {
                "type": "calculate_attnlrp_multi_hop(n_hops=0) hop0 raw_attributions[0]",
            }
        else:
            # exp2 ft_attnlrp: multi-hop aggregated AttnLRP (metadata contains per-hop vectors).
            attr_result = attributor.calculate_attnlrp_aggregated_multi_hop(
                example.prompt,
                target=example.target,
                sink_span=sink_span,
                thinking_span=thinking_span,
                n_hops=int(args.n_hops),
                neg_handling=args.attnlrp_neg_handling,
                norm_mode=args.attnlrp_norm_mode,
            )
            multi_hop = (attr_result.metadata or {}).get("multi_hop_result")
            if multi_hop is None:
                raise RuntimeError("FT-AttnLRP case study missing metadata.multi_hop_result.")

            raw_attributions = getattr(multi_hop, "raw_attributions", None) or []
            attnlrp_raw_attributions = list(raw_attributions)
            hop_vectors_trimmed = [
                torch.as_tensor(getattr(hop, "token_importance_total"), dtype=torch.float32).detach().cpu()
                for hop in raw_attributions
            ]
            attnlrp_meta = {
                "type": "calculate_attnlrp_aggregated_multi_hop (exp2 ft_attnlrp)",
                "n_hops": int(args.n_hops),
            }

        thinking_ratios = list(getattr(multi_hop, "thinking_ratios", []) or [])
        # Shared keys are appended after the mode-specific ones to keep JSON key order.
        attnlrp_meta.update(
            {
                "sink_span_generation": sink_span,
                "thinking_span_generation": thinking_span,
                "thinking_ratios": thinking_ratios,
                "neg_handling": args.attnlrp_neg_handling,
                "norm_mode": args.attnlrp_norm_mode,
                "ratio_enabled": args.attnlrp_norm_mode == "norm",
            }
        )
        method_meta = {"attnlrp": attnlrp_meta}

        (
            prompt_tokens_trimmed,
            generation_tokens_trimmed,
            raw_prompt_ids,
            raw_generation_ids,
            user_prompt_indices,
            chat_prompt_indices,
        ) = _read_attributor_token_context(attributor)
        prompt_len_full = len(raw_prompt_ids)

    else:
        raise ValueError(f"Unsupported mode={mode}")

    if not hop_vectors_trimmed:
        raise RuntimeError("No hop vectors to visualize.")

    raw_tokens = build_raw_tokens_from_ids(tokenizer, raw_prompt_ids, raw_generation_ids)

    # Shift generation-relative spans into full-sequence coordinates when possible.
    sink_span_abs = None
    thinking_span_abs = None
    if prompt_len_full is not None and sink_span is not None:
        sink_span_abs = (prompt_len_full + sink_span[0], prompt_len_full + sink_span[1])
    if prompt_len_full is not None and thinking_span is not None:
        thinking_span_abs = (prompt_len_full + thinking_span[0], prompt_len_full + thinking_span[1])
    prompt_len_full_safe = int(prompt_len_full or 0)
    roles_raw = build_raw_roles(
        raw_tokens,
        prompt_len_full_safe,
        user_prompt_indices,
        chat_prompt_indices,
        thinking_span_abs,
        sink_span_abs,
    )

    prompt_tokens_only = list(prompt_tokens_trimmed)
    prompt_only_vectors = extract_prompt_only_vectors(hop_vectors_trimmed, len(prompt_tokens_only))

    # Ensure every method has a pre-trim full vector per panel.
    # NOTE: gen_len is computed only inside these branches; in ft modes the raw ids
    # come from debug_info unmodified and are only known to be lists in the other modes.
    if not hop_vectors_raw and mode in ("attnlrp", "ft_attnlrp") and attnlrp_raw_attributions is not None:
        gen_len = len(raw_generation_ids or [])
        hop_vectors_raw = _build_attnlrp_pretrim_vectors(
            attnlrp_raw_attributions,
            prompt_len_full=prompt_len_full_safe,
            gen_len=gen_len,
            user_prompt_indices=user_prompt_indices,
            neg_handling=args.attnlrp_neg_handling,
            norm_mode=args.attnlrp_norm_mode,
            expected_len=int((prompt_len_full_safe + gen_len) if prompt_len_full is not None else 0),
        )

    if not hop_vectors_raw and prompt_len_full is not None:
        # Fallback: lift trimmed vectors back to full token space with zeros for template tokens.
        gen_len = len(raw_generation_ids or [])
        hop_vectors_raw = [
            _lift_trimmed_to_full(
                v,
                prompt_len_full=prompt_len_full_safe,
                gen_len=gen_len,
                user_prompt_indices=user_prompt_indices,
            )
            for v in hop_vectors_trimmed
        ]

    if not hop_vectors_raw:
        raise RuntimeError("Missing pre-trim vectors; cannot render required full-sequence heatmap.")

    # Lightweight debug stats to catch silent all-zero / NaN cases.
    hop_stats_raw = [analysis.vector_stats(torch.nan_to_num(v.detach().cpu(), nan=0.0)) for v in hop_vectors_raw]
    hop_stats_prompt = [analysis.vector_stats(torch.nan_to_num(v.detach().cpu(), nan=0.0)) for v in prompt_only_vectors]
    for i in range(max(len(hop_stats_raw), len(hop_stats_prompt))):
        raw_abs = hop_stats_raw[i]["abs_max"] if i < len(hop_stats_raw) else None
        prompt_abs = hop_stats_prompt[i]["abs_max"] if i < len(hop_stats_prompt) else None
        print(f"[stats] panel {i}: raw_abs_max={raw_abs} prompt_abs_max={prompt_abs}")

    hop_token_raw = analysis.package_token_hops(hop_vectors_raw)
    hop_token_prompt = analysis.package_token_hops(prompt_only_vectors)

    is_attnlrp = mode in ("attnlrp", "ft_attnlrp")
    ifr_method_meta = method_meta.get("ifr")
    ifr_meta_is_dict = isinstance(ifr_method_meta, dict)
    case_meta: Dict[str, Any] = {
        "dataset": ds_name,
        "index": args.index,
        "sink_span": sink_span,
        "thinking_span": thinking_span,
        "n_hops": args.n_hops,
        "thinking_ratios": thinking_ratios,
        "mode": mode,
        "ifr_view": ifr_method_meta.get("ifr_view") if ifr_meta_is_dict else None,
        "panel_titles": ifr_method_meta.get("panel_titles") if ifr_meta_is_dict else None,
        "attnlrp_neg_handling": args.attnlrp_neg_handling if is_attnlrp else None,
        "attnlrp_norm_mode": args.attnlrp_norm_mode if is_attnlrp else None,
        "attnlrp_ratio_enabled": (args.attnlrp_norm_mode == "norm") if is_attnlrp else None,
        "vector_stats_raw": hop_stats_raw,
        "vector_stats_prompt": hop_stats_prompt,
    }

    record = {
        "meta": case_meta,
        "prompt": example.prompt,
        "target": example.target,
        "generation": "".join(generation_tokens_trimmed),
        "full_all_tokens": raw_tokens,
        "raw_token_roles": roles_raw,
        "prompt_tokens": prompt_tokens_only,
        "prompt_token_roles": ["user"] * len(prompt_tokens_only),
        "token_hops_raw": hop_token_raw,
        "token_hops_prompt": hop_token_prompt,
        "ifr_meta": method_meta.get("ifr"),
        "attnlrp_meta": method_meta.get("attnlrp"),
    }

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    stem = make_output_stem(ds_name, args.index, mode)
    json_path = out_dir / f"{stem}.json"
    html_path = out_dir / f"{stem}.html"

    # JSON is written first so the record survives even if HTML rendering fails.
    with json_path.open("w", encoding="utf-8") as f:
        json.dump(record, f, ensure_ascii=False, indent=2)

    html = viz.render_case_html(
        case_meta,
        token_view_raw={
            "label": "Pre-trim token-level heatmap (full sequence with chat template)",
            "tokens": raw_tokens,
            "roles": roles_raw,
            "hops": hop_token_raw,
        },
        token_view_prompt={
            "label": "Prompt-only token-level heatmap (user prompt only)",
            "tokens": prompt_tokens_only,
            "roles": ["user"] * len(prompt_tokens_only),
            "hops": hop_token_prompt,
        },
    )
    html_path.write_text(html, encoding="utf-8")

    print(f"[done] wrote {json_path}")
    print(f"[done] wrote {html_path}")


# Script entry point: run the case study only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
