#!/usr/bin/env python3
"""
Experiment 2 runner: token-level faithfulness (generation perturbation).

AT2 is omitted.
"""

from __future__ import annotations

import argparse
import hashlib
import json
import os
import sys
from itertools import islice
import math
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Early CUDA mask handling: set CUDA_VISIBLE_DEVICES before importing torch.
def _early_set_cuda_visible_devices():
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--cuda", type=str, default=None)
    # parse_known_args keeps the full argv for later parsing by the main parser
    args, _ = parser.parse_known_args(sys.argv[1:])
    if args.cuda and "," in args.cuda:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda


_early_set_cuda_visible_devices()

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, utils

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from pathlib import Path

# ensure repo root on path
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

import llm_attr
import llm_attr_eval
from attribution_datasets import AttributionExample
from exp.exp2 import dataset_utils as ds_utils

utils.logging.set_verbosity_error()


def _sha1_text(text: str) -> str:
    return hashlib.sha1(text.encode("utf-8")).hexdigest()


def _infer_attnlrp_spans_from_hops(
    raw_attributions: Any,
    *,
    gen_len: int,
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    if not raw_attributions:
        return (0, max(0, gen_len - 1)), (0, max(0, gen_len - 1))
    sink_span = tuple(int(x) for x in raw_attributions[0].sink_range)
    if len(raw_attributions) >= 2:
        thinking_span = tuple(int(x) for x in raw_attributions[1].sink_range)
    else:
        thinking_span = sink_span
    return sink_span, thinking_span


def _build_hop_trace_payload(
    attr_func: str,
    attr: Any,
    *,
    indices_to_explain: List[int],
) -> Optional[Dict[str, np.ndarray]]:
    """Extract per-hop vectors (postprocessed) and minimal span metadata."""
    prompt_len = int(len(getattr(attr, "prompt_tokens", []) or []))
    gen_len = int(len(getattr(attr, "generation_tokens", []) or []))
    total_len = prompt_len + gen_len
    if total_len <= 0:
        return None

    hop_vectors: List[torch.Tensor] = []
    sink_span_gen: Optional[Tuple[int, int]] = None
    thinking_span_gen: Optional[Tuple[int, int]] = None
    attnlrp_neg_handling: str = ""
    attnlrp_norm_mode: str = ""
    attnlrp_ratio_enabled: int = -1

    # IFR multi-hop variants expose projected hop vectors via metadata["ifr"]["per_hop_projected"].
    ifr_meta = (getattr(attr, "metadata", None) or {}).get("ifr") or {}
    ifr_per_hop = ifr_meta.get("per_hop_projected") or []

    if ifr_per_hop:
        hop_vectors = [torch.as_tensor(v, dtype=torch.float32) for v in ifr_per_hop]
        sink_span_gen = ifr_meta.get("sink_span_generation")
        thinking_span_gen = ifr_meta.get("thinking_span_generation")
        if sink_span_gen is not None:
            sink_span_gen = tuple(int(x) for x in sink_span_gen)
        if thinking_span_gen is not None:
            thinking_span_gen = tuple(int(x) for x in thinking_span_gen)

    elif attr_func in ("ft_attnlrp", "attnlrp_aggregated_multi_hop"):
        meta = getattr(attr, "metadata", None) or {}
        attnlrp_neg_handling = str(meta.get("neg_handling") or "")
        attnlrp_norm_mode = str(meta.get("norm_mode") or "")
        if meta.get("ratio_enabled") is not None:
            attnlrp_ratio_enabled = int(bool(meta.get("ratio_enabled")))
        multi_hop = meta.get("multi_hop_result")
        if multi_hop is None:
            return None
        raw_attributions = getattr(multi_hop, "raw_attributions", None) or []
        if not raw_attributions:
            return None
        hop_vectors = [
            torch.as_tensor(getattr(hop, "token_importance_total"), dtype=torch.float32)
            for hop in raw_attributions
        ]
        sink_span_gen, thinking_span_gen = _infer_attnlrp_spans_from_hops(raw_attributions, gen_len=gen_len)
        sink_override = meta.get("sink_span")
        thinking_override = meta.get("thinking_span")
        if sink_override is not None:
            sink_span_gen = tuple(int(x) for x in sink_override)
        if thinking_override is not None:
            thinking_span_gen = tuple(int(x) for x in thinking_override)

    else:
        return None

    if sink_span_gen is None:
        sink_span_gen = (0, max(0, gen_len - 1))
    if thinking_span_gen is None:
        thinking_span_gen = sink_span_gen

    stacked = torch.stack([v.reshape(-1) for v in hop_vectors], dim=0)
    if stacked.shape[1] != total_len:
        raise ValueError(
            f"Hop vector length mismatch for {attr_func}: expected T={total_len}, got {stacked.shape[1]}."
        )

    return {
        "vh": stacked.detach().cpu().numpy().astype(np.float32, copy=False),
        "prompt_len": np.asarray(prompt_len, dtype=np.int64),
        "gen_len": np.asarray(gen_len, dtype=np.int64),
        "sink_span_gen": np.asarray(sink_span_gen, dtype=np.int64),
        "thinking_span_gen": np.asarray(thinking_span_gen, dtype=np.int64),
        "indices_to_explain_gen": np.asarray(indices_to_explain, dtype=np.int64),
        "attnlrp_neg_handling": np.asarray(attnlrp_neg_handling, dtype="U16"),
        "attnlrp_norm_mode": np.asarray(attnlrp_norm_mode, dtype="U16"),
        "attnlrp_ratio_enabled": np.asarray(attnlrp_ratio_enabled, dtype=np.int64),
    }


def _write_hop_trace(
    trace_dir: Path,
    *,
    example_idx: int,
    attr_func: str,
    prompt: str,
    target: Optional[str],
    payload: Dict[str, np.ndarray],
    manifest_handle,
) -> None:
    trace_dir.mkdir(parents=True, exist_ok=True)
    npz_name = f"ex_{example_idx:06d}.npz"
    npz_path = trace_dir / npz_name
    np.savez_compressed(npz_path, **payload)

    record = {
        "example_idx": int(example_idx),
        "attr_func": attr_func,
        "file": npz_name,
        "prompt_sha1": _sha1_text(prompt),
        "target_sha1": _sha1_text(target) if target is not None else None,
        "prompt_len": int(payload["prompt_len"].item()),
        "gen_len": int(payload["gen_len"].item()),
        "n_hops_plus_one": int(payload["vh"].shape[0]),
        "total_len": int(payload["vh"].shape[1]),
        "sink_span_gen": payload["sink_span_gen"].tolist(),
        "thinking_span_gen": payload["thinking_span_gen"].tolist(),
        "indices_to_explain_gen": payload["indices_to_explain_gen"].tolist(),
        "attnlrp_neg_handling": str(payload["attnlrp_neg_handling"].item()),
        "attnlrp_norm_mode": str(payload["attnlrp_norm_mode"].item()),
        "attnlrp_ratio_enabled": int(payload["attnlrp_ratio_enabled"].item()),
    }
    manifest_handle.write(json.dumps(record, ensure_ascii=False) + "\n")
    manifest_handle.flush()


def _parse_modes(mode_args: Any) -> List[str]:
    """Parse --mode which may be provided as multiple args and/or comma-separated."""
    if mode_args is None:
        raw_parts: List[str] = []
    elif isinstance(mode_args, str):
        raw_parts = [mode_args]
    else:
        raw_parts = [str(x) for x in mode_args]

    modes: List[str] = []
    for chunk in raw_parts:
        for part in str(chunk).split(","):
            m = part.strip()
            if m:
                modes.append(m)

    # Default to faithfulness_gen for backward compatibility.
    if not modes:
        modes = ["faithfulness_gen"]

    allowed = {"faithfulness_gen", "recovery_ruler"}
    seen: set[str] = set()
    unique: List[str] = []
    for m in modes:
        if m not in seen:
            unique.append(m)
            seen.add(m)

    unknown = [m for m in unique if m not in allowed]
    if unknown:
        raise SystemExit(f"Unsupported --mode value(s): {unknown}. Allowed: {sorted(allowed)}.")

    return unique


def _trace_run_tag(
    testing_dict: Dict[str, Any],
    *,
    modes: List[str],
    total: int,
) -> str:
    attr_func = str(testing_dict.get("attr_func") or "attr")
    parts = [attr_func]

    if attr_func in (
        "ifr_multi_hop",
        "ifr_in_all_gen",
        "ifr_multi_hop_stop_words",
        "ifr_multi_hop_both",
        "ifr_multi_hop_split_hop",
        "ft_attnlrp",
        "attnlrp_aggregated_multi_hop",
    ):
        parts.append(f"n{int(testing_dict.get('n_hops', 0))}")

    if attr_func in ("attnlrp", "ft_attnlrp", "attnlrp_aggregated_multi_hop"):
        parts.append(f"neg{str(testing_dict.get('attnlrp_neg_handling', ''))}")
        parts.append(f"norm{str(testing_dict.get('attnlrp_norm_mode', ''))}")

    if modes:
        parts.append("m" + "+".join(modes))

    parts.append(f"{int(total)}ex")
    return "_".join(parts)


def _token_importance_vector(attr: torch.Tensor) -> np.ndarray:
    """Return token importance vector w = sum_rows(attr) in shape [P+G]."""
    w = torch.nan_to_num(attr.sum(0).to(dtype=torch.float32), nan=0.0).clamp(min=0.0)
    return w.detach().cpu().numpy().astype(np.float32, copy=False)


def _build_sample_trace_payload(
    example: ds_utils.CachedExample,
    *,
    attr_list: List[torch.Tensor],
    prompt_len: int,
    user_prompt_indices: Optional[List[int]],
    keep_prompt_token_indices: Optional[List[int]],
    gold_prompt_token_indices: Optional[List[int]],
    hop_payload: Optional[Dict[str, np.ndarray]],
    faithfulness_scores: Optional[np.ndarray],
    recovery_scores: Optional[np.ndarray],
    time_attr_s: Optional[float],
    time_faith_s: Optional[float],
    time_recovery_s: Optional[float],
) -> Dict[str, np.ndarray]:
    seq_attr, row_attr, rec_attr = attr_list
    gen_len = int(seq_attr.shape[0])

    v_seq_all = _token_importance_vector(seq_attr)
    v_row_all = _token_importance_vector(row_attr)
    v_rec_all = _token_importance_vector(rec_attr)

    payload: Dict[str, np.ndarray] = {
        "v_seq_all": v_seq_all,
        "v_row_all": v_row_all,
        "v_rec_all": v_rec_all,
        "v_seq_prompt": v_seq_all[:prompt_len],
        "v_row_prompt": v_row_all[:prompt_len],
        "v_rec_prompt": v_rec_all[:prompt_len],
        "prompt_len": np.asarray(int(prompt_len), dtype=np.int64),
        "gen_len": np.asarray(int(gen_len), dtype=np.int64),
        "indices_to_explain_gen": np.asarray(list(example.indices_to_explain or []), dtype=np.int64),
    }

    if example.sink_span is not None:
        payload["sink_span_gen"] = np.asarray(list(example.sink_span), dtype=np.int64)
    if example.thinking_span is not None:
        payload["thinking_span_gen"] = np.asarray(list(example.thinking_span), dtype=np.int64)

    if user_prompt_indices is not None:
        payload["user_prompt_indices"] = np.asarray(list(user_prompt_indices), dtype=np.int64)
    if keep_prompt_token_indices is not None:
        payload["keep_prompt_token_indices"] = np.asarray(list(keep_prompt_token_indices), dtype=np.int64)
    if gold_prompt_token_indices is not None:
        payload["gold_prompt_token_indices"] = np.asarray(list(gold_prompt_token_indices), dtype=np.int64)

    if faithfulness_scores is not None:
        payload["faithfulness_scores"] = np.asarray(faithfulness_scores, dtype=np.float64)
    if recovery_scores is not None:
        payload["recovery_scores"] = np.asarray(recovery_scores, dtype=np.float64)

    if time_attr_s is not None:
        payload["time_attr_s"] = np.asarray(float(time_attr_s), dtype=np.float64)
    if time_faith_s is not None:
        payload["time_faith_s"] = np.asarray(float(time_faith_s), dtype=np.float64)
    if time_recovery_s is not None:
        payload["time_recovery_s"] = np.asarray(float(time_recovery_s), dtype=np.float64)

    if hop_payload is not None:
        for k, v in hop_payload.items():
            if k in payload:
                continue
            payload[k] = v

    return payload


def _write_sample_trace(
    trace_dir: Path,
    *,
    example_idx: int,
    attr_func: str,
    prompt: str,
    target: Optional[str],
    payload: Dict[str, np.ndarray],
    manifest_handle,
    recovery_skipped_reason: Optional[str],
) -> None:
    trace_dir.mkdir(parents=True, exist_ok=True)
    npz_name = f"ex_{example_idx:06d}.npz"
    npz_path = trace_dir / npz_name
    np.savez_compressed(npz_path, **payload)

    prompt_len = int(np.asarray(payload.get("prompt_len", 0)).item())
    gen_len = int(np.asarray(payload.get("gen_len", 0)).item())
    record: Dict[str, Any] = {
        "example_idx": int(example_idx),
        "attr_func": attr_func,
        "file": npz_name,
        "prompt_sha1": _sha1_text(prompt),
        "target_sha1": _sha1_text(target) if target is not None else None,
        "prompt_len": prompt_len,
        "gen_len": gen_len,
        "indices_to_explain_gen": payload.get("indices_to_explain_gen").tolist()
        if payload.get("indices_to_explain_gen") is not None
        else None,
        "sink_span_gen": payload.get("sink_span_gen").tolist() if payload.get("sink_span_gen") is not None else None,
        "thinking_span_gen": payload.get("thinking_span_gen").tolist()
        if payload.get("thinking_span_gen") is not None
        else None,
        "faithfulness_scores": payload.get("faithfulness_scores").tolist()
        if payload.get("faithfulness_scores") is not None
        else None,
        "recovery_scores": payload.get("recovery_scores").tolist() if payload.get("recovery_scores") is not None else None,
        "recovery_skipped_reason": recovery_skipped_reason,
        "time_attr_s": float(np.asarray(payload.get("time_attr_s")).item()) if payload.get("time_attr_s") is not None else None,
        "time_faith_s": float(np.asarray(payload.get("time_faith_s")).item()) if payload.get("time_faith_s") is not None else None,
        "time_recovery_s": float(np.asarray(payload.get("time_recovery_s")).item())
        if payload.get("time_recovery_s") is not None
        else None,
    }

    # Derived, sample-level bookkeeping (token lengths and per-sample MAS/RISE).
    record["input_len"] = int(prompt_len)

    sink_span = record.get("sink_span_gen")
    if isinstance(sink_span, list) and len(sink_span) == 2:
        try:
            start = int(sink_span[0])
            end = int(sink_span[1])
            record["output_len"] = (end - start + 1) if end >= start else None
        except Exception:
            record["output_len"] = None
    else:
        record["output_len"] = None

    thinking_span = record.get("thinking_span_gen")
    if isinstance(thinking_span, list) and len(thinking_span) == 2:
        try:
            start = int(thinking_span[0])
            end = int(thinking_span[1])
            record["cot_len"] = (end - start + 1) if end >= start else None
        except Exception:
            record["cot_len"] = None
    else:
        record["cot_len"] = None

    record["rise_seq"] = None
    record["mas_seq"] = None
    record["rise_row"] = None
    record["mas_row"] = None
    record["rise_rec"] = None
    record["mas_rec"] = None
    faith = record.get("faithfulness_scores")
    if isinstance(faith, list) and len(faith) == 3:
        try:
            record["rise_seq"] = float(faith[0][0])
            record["mas_seq"] = float(faith[0][1])
            record["rise_row"] = float(faith[1][0])
            record["mas_row"] = float(faith[1][1])
            record["rise_rec"] = float(faith[2][0])
            record["mas_rec"] = float(faith[2][1])
        except Exception:
            pass

    if payload.get("vh") is not None:
        vh = payload["vh"]
        record["n_hops_plus_one"] = int(vh.shape[0])
        record["total_len"] = int(vh.shape[1])
        record["attnlrp_neg_handling"] = str(payload.get("attnlrp_neg_handling").item()) if payload.get("attnlrp_neg_handling") is not None else ""
        record["attnlrp_norm_mode"] = str(payload.get("attnlrp_norm_mode").item()) if payload.get("attnlrp_norm_mode") is not None else ""
        record["attnlrp_ratio_enabled"] = int(payload.get("attnlrp_ratio_enabled").item()) if payload.get("attnlrp_ratio_enabled") is not None else -1

    manifest_handle.write(json.dumps(record, ensure_ascii=False) + "\n")
    manifest_handle.flush()


def _compute_faithfulness_scores(
    testing_dict: Dict[str, Any],
    *,
    attr_list: List[torch.Tensor],
    prompt_len: int,
    prompt: str,
    generation: str,
    llm_evaluator: llm_attr_eval.LLMAttributionEvaluator,
    user_prompt_indices: Optional[List[int]],
    keep_prompt_token_indices: Optional[List[int]],
) -> np.ndarray:
    attr_func = str(testing_dict.get("attr_func") or "")
    results: List[Tuple[float, float, float]] = []
    for attr in attr_list:
        attr_prompt = attr[:, :prompt_len]
        if attr_func in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both") and keep_prompt_token_indices is not None:
            import ft_ifr_improve

            scores = ft_ifr_improve.faithfulness_test_skip_tokens(
                llm_evaluator,
                attr_prompt,
                prompt,
                generation,
                keep_prompt_token_indices=keep_prompt_token_indices,
                user_prompt_indices=user_prompt_indices,
            )
        elif user_prompt_indices is not None:
            scores = _faithfulness_test_with_user_prompt_indices(
                llm_evaluator,
                attr_prompt,
                prompt,
                generation,
                user_prompt_indices=user_prompt_indices,
            )
        else:
            scores = llm_evaluator.faithfulness_test(attr_prompt, prompt, generation)
        results.append(scores)
    return np.asarray(results, dtype=np.float64)


def _compute_recovery_scores(
    testing_dict: Dict[str, Any],
    *,
    attr_list: List[torch.Tensor],
    prompt_len: int,
    gold_prompt_token_indices: List[int],
    llm_evaluator: llm_attr_eval.LLMAttributionEvaluator,
    keep_prompt_token_indices: Optional[List[int]],
) -> Tuple[Optional[np.ndarray], Optional[str]]:
    attr_func = str(testing_dict.get("attr_func") or "")

    if prompt_len <= 0:
        return None, "empty_prompt_len"

    gold_prompt = [int(x) for x in (gold_prompt_token_indices or [])]
    if not gold_prompt:
        return None, "empty_gold_prompt"

    if attr_func in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both") and keep_prompt_token_indices is not None:
        import ft_ifr_improve

        keep_set = {int(x) for x in keep_prompt_token_indices}
        gold_filtered = [idx for idx in gold_prompt if int(idx) in keep_set]
        if not gold_filtered:
            return None, "empty_gold_after_keep_filter"

        scores = [
            ft_ifr_improve.evaluate_attr_recovery_skip_tokens(
                attr[:, :prompt_len],
                keep_prompt_token_indices=keep_prompt_token_indices,
                gold_prompt_token_indices=gold_prompt,
                top_fraction=0.1,
            )
            for attr in attr_list
        ]
    else:
        scores = [
            llm_evaluator.evaluate_attr_recovery(
                attr,
                prompt_len=prompt_len,
                gold_prompt_token_indices=gold_prompt,
                top_fraction=0.1,
            )
            for attr in attr_list
        ]

    return np.asarray(scores, dtype=np.float64), None


def evaluate_dataset_multi(
    args,
    dataset_name: str,
    examples: List[ds_utils.CachedExample],
    testing_dict: Dict[str, Any],
    *,
    modes: List[str],
) -> Dict[str, Any]:
    tokenizer = testing_dict["tokenizer"]
    llm_evaluator = llm_attr_eval.LLMAttributionEvaluator(testing_dict["model"], tokenizer)

    want_faith = "faithfulness_gen" in modes
    want_recovery = "recovery_ruler" in modes

    faith_results: List[np.ndarray] = []
    faith_durations: List[float] = []

    recovery_results: List[np.ndarray] = []
    recovery_attr_durations: List[float] = []
    recovery_skipped = 0

    total = min(len(examples), args.num_examples)
    iterator = islice(examples, total)

    save_traces = bool(getattr(args, "save_hop_traces", False))
    manifest_handle = None
    trace_dir: Optional[Path] = None
    if save_traces:
        model_tag = str(testing_dict.get("model_tag", "model"))
        run_tag = _trace_run_tag(testing_dict, modes=modes, total=total)
        trace_dir = Path(args.output_root) / "traces" / dataset_name / model_tag / run_tag
        trace_dir.mkdir(parents=True, exist_ok=True)
        manifest_handle = open(trace_dir / "manifest.jsonl", "w", encoding="utf-8")

    try:
        for example_idx, ex in enumerate(iterator):
            if want_recovery:
                needle_spans = (ex.metadata or {}).get("needle_spans")
                if not isinstance(needle_spans, list) or not needle_spans:
                    raise SystemExit(
                        "recovery_ruler requires RULER samples with metadata.needle_spans; "
                        f"dataset={dataset_name} has missing/empty needle_spans."
                    )
                if ex.target is None:
                    raise SystemExit(
                        "recovery_ruler requires cached targets (CoT+answer) so row/rec attribution is well-defined. "
                        f"dataset={dataset_name} has target=None; run exp/exp2/sample_and_filter.py first."
                    )

            # Determine generation/target once.
            target = ex.target
            if target is None:
                generation, full_output = llm_evaluator.response(ex.prompt)
                target = generation
                response_len = len(tokenizer(full_output).input_ids)
            else:
                response_len = len(tokenizer(llm_evaluator.format_prompt(" " + ex.prompt) + target).input_ids)

            testing_dict["batch_size"] = max(1, math.floor((testing_dict["max_input_len"] - 100) / max(1, response_len)))

            gold_prompt: Optional[List[int]] = None
            if want_recovery:
                gold_prompt = ds_utils.ruler_gold_prompt_token_indices(ex, tokenizer)

            if want_recovery and not want_faith and not save_traces:
                # Preserve recovery-only fast path when not saving traces: skip samples with empty gold.
                if not gold_prompt:
                    recovery_skipped += 1
                    continue

            time_attr_s = None
            time_faith_s = None
            time_recovery_s = None

            t0 = time.perf_counter()
            attr_list, hop_payload, user_prompt_indices, keep_prompt_token_indices = run_attribution(testing_dict, ex, target)
            time_attr_s = time.perf_counter() - t0

            seq_attr = attr_list[0]
            prompt_len = int(seq_attr.shape[1] - seq_attr.shape[0])  # cols=(P+G), rows=G

            if want_recovery and gold_prompt:
                recovery_attr_durations.append(float(time_attr_s))

            faith_scores = None
            if want_faith:
                t1 = time.perf_counter()
                faith_scores = _compute_faithfulness_scores(
                    testing_dict,
                    attr_list=attr_list,
                    prompt_len=prompt_len,
                    prompt=ex.prompt,
                    generation=target,
                    llm_evaluator=llm_evaluator,
                    user_prompt_indices=user_prompt_indices,
                    keep_prompt_token_indices=keep_prompt_token_indices,
                )
                time_faith_s = time.perf_counter() - t1
                faith_results.append(faith_scores)
                faith_durations.append(float(time_attr_s))

            recovery_scores = None
            recovery_skip_reason = None
            if want_recovery:
                if not gold_prompt:
                    recovery_skip_reason = "empty_gold_prompt"
                    recovery_skipped += 1
                else:
                    t2 = time.perf_counter()
                    recovery_scores, recovery_skip_reason = _compute_recovery_scores(
                        testing_dict,
                        attr_list=attr_list,
                        prompt_len=prompt_len,
                        gold_prompt_token_indices=gold_prompt,
                        llm_evaluator=llm_evaluator,
                        keep_prompt_token_indices=keep_prompt_token_indices,
                    )
                    time_recovery_s = time.perf_counter() - t2
                    if recovery_scores is None:
                        recovery_skipped += 1
                    else:
                        recovery_results.append(recovery_scores)

            if manifest_handle is not None and trace_dir is not None:
                try:
                    payload = _build_sample_trace_payload(
                        ex,
                        attr_list=attr_list,
                        prompt_len=prompt_len,
                        user_prompt_indices=user_prompt_indices,
                        keep_prompt_token_indices=keep_prompt_token_indices,
                        gold_prompt_token_indices=gold_prompt,
                        hop_payload=hop_payload,
                        faithfulness_scores=faith_scores,
                        recovery_scores=recovery_scores,
                        time_attr_s=time_attr_s,
                        time_faith_s=time_faith_s,
                        time_recovery_s=time_recovery_s,
                    )
                    _write_sample_trace(
                        trace_dir,
                        example_idx=example_idx,
                        attr_func=str(testing_dict.get("attr_func") or ""),
                        prompt=ex.prompt,
                        target=target,
                        payload=payload,
                        manifest_handle=manifest_handle,
                        recovery_skipped_reason=recovery_skip_reason,
                    )
                except Exception as exc:
                    print(f"[warn] sample trace save failed for {testing_dict.get('attr_func')} ex={example_idx}: {exc}")
    finally:
        if manifest_handle is not None:
            try:
                manifest_handle.close()
            except Exception:
                pass

    out: Dict[str, Any] = {}
    if want_faith:
        if not faith_results:
            out["faithfulness"] = None
        else:
            scores = np.stack(faith_results, axis=0)  # [N, 3, 3]
            out["faithfulness"] = {
                "mean": scores.mean(0),
                "std": scores.std(0),
                "avg_time": float(np.mean(faith_durations)) if faith_durations else 0.0,
            }
    if want_recovery:
        if not recovery_results:
            out["recovery"] = None
        else:
            scores = np.stack(recovery_results, axis=0)  # [N, 3]
            out["recovery"] = {
                "mean": scores.mean(0),
                "std": scores.std(0),
                "avg_time": float(np.mean(recovery_attr_durations)) if recovery_attr_durations else 0.0,
                "used": int(scores.shape[0]),
                "skipped": int(recovery_skipped),
            }

    return out


def _faithfulness_test_with_user_prompt_indices(
    llm_evaluator: llm_attr_eval.LLMAttributionEvaluator,
    attribution: torch.Tensor,
    prompt: str,
    generation: str,
    *,
    user_prompt_indices: List[int],
    k: int = 20, ### control the MAS steps per sample
) -> Tuple[float, float, float]:
    """Token-level MAS/RISE faithfulness via guided deletion in k perturbation steps using provided prompt indices.

    This mirrors llm_attr_eval.LLMAttributionEvaluator.faithfulness_test, but avoids
    locating the user prompt span via token-id subsequence matching (which may fail
    for some tokenizers due to non-compositional BPE merges at template boundaries).
    """

    def auc(arr: np.ndarray) -> float:
        return (arr.sum() - arr[0] / 2 - arr[-1] / 2) / max(1, (arr.shape[0] - 1))

    pad_token_id = llm_evaluator._ensure_pad_token_id()

    user_prompt = " " + prompt
    formatted_prompt = llm_evaluator.format_prompt(user_prompt)
    formatted_ids = llm_evaluator.tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False).input_ids

    prompt_ids = formatted_ids.to(llm_evaluator.device)
    prompt_ids_perturbed = prompt_ids.clone()
    generation_ids = llm_evaluator.tokenizer(
        generation + llm_evaluator.tokenizer.eos_token,
        return_tensors="pt",
        add_special_tokens=False,
    ).input_ids.to(llm_evaluator.device)

    attr_cpu = attribution.detach().cpu()
    w = attr_cpu.sum(0)
    sorted_attr_indices = torch.argsort(w, descending=True)
    attr_sum = float(w.sum().item())

    P = int(w.numel())
    if len(user_prompt_indices) != P:
        raise ValueError(
            "user_prompt_indices length does not match prompt-side attribution length: "
            f"indices P={len(user_prompt_indices)}, attr P={P}."
        )
    if P == 0:
        return 0.0, 0.0, 0.0

    if max(user_prompt_indices) >= int(prompt_ids_perturbed.shape[1]):
        raise ValueError("user_prompt_indices contains an out-of-bounds index for formatted prompt ids.")

    if P > 0:
        steps = int(k) if k is not None else 0
        if steps <= 0:
            steps = 1
        steps = min(steps, P)
    else:
        steps = 0

    scores = np.zeros(steps + 1, dtype=np.float64)
    density = np.zeros(steps + 1, dtype=np.float64)

    scores[0] = (
        llm_evaluator.compute_logprob_response_given_prompt(prompt_ids_perturbed, generation_ids).sum().cpu().detach().item()
    )
    density[0] = 1.0

    if attr_sum <= 0:
        density = np.linspace(1.0, 0.0, steps + 1)

    base = P // steps
    remainder = P % steps
    start = 0
    for step in range(steps):
        size = base + (1 if step < remainder else 0)
        group = sorted_attr_indices[start : start + size]
        start += size

        for idx in group:
            j = int(idx.item())
            abs_pos = int(user_prompt_indices[j])
            prompt_ids_perturbed[0, abs_pos] = pad_token_id
        scores[step + 1] = (
            llm_evaluator.compute_logprob_response_given_prompt(prompt_ids_perturbed, generation_ids).sum().cpu().detach().item()
        )
        if attr_sum > 0:
            dec = float(w.index_select(0, group).sum().item()) / attr_sum
            density[step + 1] = density[step] - dec

    min_normalized_pred = 1.0
    normalized_model_response = scores.copy()
    for i in range(len(scores)):
        normalized_pred = (normalized_model_response[i] - scores[-1]) / (abs(scores[0] - scores[-1]))
        normalized_pred = np.clip(normalized_pred, 0.0, 1.0)
        min_normalized_pred = min(min_normalized_pred, normalized_pred)
        normalized_model_response[i] = min_normalized_pred

    alignment_penalty = np.abs(normalized_model_response - density)
    corrected_scores = normalized_model_response + alignment_penalty
    corrected_scores = corrected_scores.clip(0.0, 1.0)
    corrected_scores = (corrected_scores - np.min(corrected_scores)) / (np.max(corrected_scores) - np.min(corrected_scores))

    if np.isnan(corrected_scores).any():
        corrected_scores = np.linspace(1.0, 0.0, len(scores))

    return auc(normalized_model_response), auc(corrected_scores), auc(normalized_model_response + alignment_penalty)


def load_model(model_name: str, device: str):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto" if device == "auto" else {"": int(device.split(":")[1])} if device.startswith("cuda:") else None,
        torch_dtype=torch.float16,
        attn_implementation="eager",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    model.eval()
    return model, tokenizer


def resolve_device(args) -> str:
    if args.cuda is not None and "," in args.cuda:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
        return "auto"
    if args.cuda is not None and args.cuda.strip():
        return f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu"
    return f"cuda:{args.cuda_num}" if torch.cuda.is_available() else "cpu"


def run_attribution(
    testing_dict, example: ds_utils.CachedExample, target: Optional[str]
) -> Tuple[List[torch.Tensor], Optional[Dict[str, np.ndarray]], Optional[List[int]]]:
    model = testing_dict["model"]
    tokenizer = testing_dict["tokenizer"]
    attr_func = testing_dict["attr_func"]

    indices_to_explain = example.indices_to_explain
    if not (isinstance(indices_to_explain, list) and len(indices_to_explain) == 2):
        raise ValueError(
            "exp2 requires token-span indices_to_explain=[start_tok,end_tok]. "
            "Please re-sample or run exp/exp2/migrate_indices_to_explain_token_span.py on your cache."
        )

    llm_attributor = None
    if "IG" in attr_func:
        llm_attributor = llm_attr.LLMGradientAttribtion(model, tokenizer)
        attr = llm_attributor.calculate_IG_per_generation(
            example.prompt,
            20,
            tokenizer.eos_token_id,
            batch_size=testing_dict["batch_size"],
            target=target,
        )
    elif "perturbation" in attr_func:
        if attr_func in ("perturbation_all_fast", "perturbation_CLP_fast", "perturbation_REAGENT_fast"):
            import perturbation_fast

            llm_attributor = perturbation_fast.LLMPerturbationFastAttribution(model, tokenizer)
            if attr_func == "perturbation_all_fast":
                attr = llm_attributor.calculate_feature_ablation_segments(
                    example.prompt,
                    baseline=tokenizer.eos_token_id,
                    measure="log_loss",
                    target=target,
                    source_k=20,
                )
            elif attr_func == "perturbation_CLP_fast":
                attr = llm_attributor.calculate_feature_ablation_segments(
                    example.prompt,
                    baseline=tokenizer.eos_token_id,
                    measure="KL",
                    target=target,
                    source_k=20,
                )
            else:
                attr = llm_attributor.calculate_feature_ablation_segments_mlm(
                    example.prompt,
                    target=target,
                    source_k=20,
                )
        else:
            llm_attributor = llm_attr.LLMPerturbationAttribution(model, tokenizer)
            if attr_func == "perturbation_all":
                attr = llm_attributor.calculate_feature_ablation_sentences(
                    example.prompt, baseline=tokenizer.eos_token_id, measure="log_loss", target=target
                )
            elif attr_func == "perturbation_CLP":
                attr = llm_attributor.calculate_feature_ablation_sentences(
                    example.prompt, baseline=tokenizer.eos_token_id, measure="KL", target=target
                )
            elif attr_func == "perturbation_REAGENT":
                attr = llm_attributor.calculate_feature_ablation_sentences_mlm(example.prompt, target=target)
            else:
                raise ValueError(f"Unsupported perturbation attr_func {attr_func}")
    elif "attention" in attr_func:
        llm_attributor = llm_attr.LLMAttentionAttribution(model, tokenizer)
        llm_attributor_ig = llm_attr.LLMGradientAttribtion(model, tokenizer)
        attr = llm_attributor.calculate_attention_attribution(example.prompt, target=target)
        attr_b = llm_attributor_ig.calculate_IG_per_generation(
            example.prompt, 20, tokenizer.eos_token_id, batch_size=testing_dict["batch_size"], target=target
        )
        attr.attribution_matrix = attr.attribution_matrix * attr_b.attribution_matrix
    elif attr_func == "ifr_all_positions":
        llm_attributor = llm_attr.LLMIFRAttribution(
            model,
            tokenizer,
            chunk_tokens=testing_dict["chunk_tokens"],
            sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
        )
        attr = llm_attributor.calculate_ifr_for_all_positions(example.prompt, target=target)
    elif attr_func == "ifr_all_positions_output_only":
        llm_attributor = llm_attr.LLMIFRAttribution(
            model,
            tokenizer,
            chunk_tokens=testing_dict["chunk_tokens"],
            sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
        )
        sink_span = tuple(example.sink_span) if example.sink_span else tuple(indices_to_explain)
        attr = llm_attributor.calculate_ifr_for_all_positions_output_only(
            example.prompt,
            target=target,
            sink_span=sink_span,
        )
    elif attr_func == "ifr_multi_hop":
        llm_attributor = llm_attr.LLMIFRAttribution(
            model,
            tokenizer,
            chunk_tokens=testing_dict["chunk_tokens"],
            sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
        )
        attr = llm_attributor.calculate_ifr_multi_hop(
            example.prompt,
            target=target,
            sink_span=tuple(example.sink_span) if example.sink_span else None,
            thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
            n_hops=testing_dict["n_hops"],
        )
    elif attr_func == "ifr_in_all_gen":
        import ft_ifr_improve

        llm_attributor = ft_ifr_improve.LLMIFRAttributionInAllGen(
            model,
            tokenizer,
            chunk_tokens=testing_dict["chunk_tokens"],
            sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
        )
        attr = llm_attributor.calculate_ifr_in_all_gen(
            example.prompt,
            target=target,
            sink_span=tuple(example.sink_span) if example.sink_span else None,
            thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
            n_hops=testing_dict["n_hops"],
        )
    elif attr_func == "ifr_multi_hop_stop_words":
        import ft_ifr_improve

        llm_attributor = ft_ifr_improve.LLMIFRAttributionImproved(
            model,
            tokenizer,
            chunk_tokens=testing_dict["chunk_tokens"],
            sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
        )
        attr = llm_attributor.calculate_ifr_multi_hop_stop_words(
            example.prompt,
            target=target,
            sink_span=tuple(example.sink_span) if example.sink_span else None,
            thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
            n_hops=testing_dict["n_hops"],
        )
    elif attr_func == "ifr_multi_hop_both":
        import ft_ifr_improve

        llm_attributor = ft_ifr_improve.LLMIFRAttributionBoth(
            model,
            tokenizer,
            chunk_tokens=testing_dict["chunk_tokens"],
            sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
        )
        attr = llm_attributor.calculate_ifr_multi_hop_both(
            example.prompt,
            target=target,
            sink_span=tuple(example.sink_span) if example.sink_span else None,
            thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
            n_hops=testing_dict["n_hops"],
        )
    elif attr_func == "ifr_multi_hop_split_hop":
        import ft_ifr_improve

        llm_attributor = ft_ifr_improve.LLMIFRAttributionSplitHop(
            model,
            tokenizer,
            chunk_tokens=testing_dict["chunk_tokens"],
            sink_chunk_tokens=testing_dict["sink_chunk_tokens"],
        )
        attr = llm_attributor.calculate_ifr_multi_hop_split_hop(
            example.prompt,
            target=target,
            sink_span=tuple(example.sink_span) if example.sink_span else None,
            thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
            n_hops=testing_dict["n_hops"],
        )
    elif attr_func == "attnlrp":
        llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
        attr = llm_attributor.calculate_attnlrp_ft_hop0(
            example.prompt,
            target=target,
            sink_span=tuple(example.sink_span) if example.sink_span else None,
            thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
            neg_handling=str(testing_dict.get("attnlrp_neg_handling", "drop")),
            norm_mode=str(testing_dict.get("attnlrp_norm_mode", "norm")),
        )
    elif attr_func in ("ft_attnlrp", "attnlrp_aggregated_multi_hop"):
        llm_attributor = llm_attr.LLMLRPAttribution(model, tokenizer)
        attr = llm_attributor.calculate_attnlrp_aggregated_multi_hop(
            example.prompt,
            target=target,
            sink_span=tuple(example.sink_span) if example.sink_span else None,
            thinking_span=tuple(example.thinking_span) if example.thinking_span else None,
            n_hops=testing_dict["n_hops"],
            neg_handling=str(testing_dict.get("attnlrp_neg_handling", "drop")),
            norm_mode=str(testing_dict.get("attnlrp_norm_mode", "norm")),
        )
    elif attr_func == "basic":
        llm_attributor = llm_attr.LLMBasicAttribution(model, tokenizer)
        attr = llm_attributor.calculate_basic_attribution(example.prompt, target=target)
    else:
        raise ValueError(f"Unsupported attr_func {attr_func}")

    seq_attr, row_attr, rec_attr = attr.get_all_token_attrs(indices_to_explain)
    hop_payload = None
    if bool(testing_dict.get("save_hop_traces", False)):
        try:
            hop_payload = _build_hop_trace_payload(attr_func, attr, indices_to_explain=indices_to_explain)
        except Exception as exc:
            print(f"[warn] hop trace extraction failed for {attr_func}: {exc}")
            hop_payload = None

    user_prompt_indices = getattr(llm_attributor, "user_prompt_indices", None)
    if isinstance(user_prompt_indices, list):
        user_prompt_indices = [int(x) for x in user_prompt_indices]
    else:
        user_prompt_indices = None

    keep_prompt_token_indices = None
    if attr_func in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both"):
        try:
            import ft_ifr_improve

            keep_prompt_token_indices = ft_ifr_improve.keep_token_indices(list(attr.prompt_tokens))
        except Exception:
            keep_prompt_token_indices = None

    return [seq_attr, row_attr, rec_attr], hop_payload, user_prompt_indices, keep_prompt_token_indices


def faithfulness_generation(
    testing_dict, example: ds_utils.CachedExample, target: str, llm_evaluator
) -> Tuple[np.ndarray, Optional[Dict[str, np.ndarray]]]:
    prompt = example.prompt
    generation = target

    attr_func = str(testing_dict.get("attr_func") or "")
    attr_list, hop_payload, user_prompt_indices, keep_prompt_token_indices = run_attribution(
        testing_dict, example, target
    )
    seq_attr = attr_list[0]
    prompt_len = int(seq_attr.shape[1] - seq_attr.shape[0])  # cols=(P+G), rows=G

    results = []
    for attr in attr_list:
        # Only use prompt-side attribution, matching evaluations/faithfulness.py
        attr_prompt = attr[:, :prompt_len]
        if attr_func in ("ifr_multi_hop_stop_words", "ifr_multi_hop_both") and keep_prompt_token_indices is not None:
            import ft_ifr_improve

            scores = ft_ifr_improve.faithfulness_test_skip_tokens(
                llm_evaluator,
                attr_prompt,
                prompt,
                generation,
                keep_prompt_token_indices=keep_prompt_token_indices,
                user_prompt_indices=user_prompt_indices,
            )
        elif user_prompt_indices is not None:
            scores = _faithfulness_test_with_user_prompt_indices(
                llm_evaluator,
                attr_prompt,
                prompt,
                generation,
                user_prompt_indices=user_prompt_indices,
            )
        else:
            scores = llm_evaluator.faithfulness_test(attr_prompt, prompt, generation)
        results.append(scores)

    return np.array(results), hop_payload


def evaluate_dataset(args, dataset_name: str, examples: List[ds_utils.CachedExample], testing_dict):
    out = evaluate_dataset_multi(args, dataset_name, examples, testing_dict, modes=["faithfulness_gen"])
    faith = out.get("faithfulness")
    if not faith:
        return None
    return faith["mean"], faith["std"], faith["avg_time"]


def evaluate_dataset_recovery_ruler(args, dataset_name: str, examples: List[ds_utils.CachedExample], testing_dict):
    out = evaluate_dataset_multi(args, dataset_name, examples, testing_dict, modes=["recovery_ruler"])
    rec = out.get("recovery")
    if not rec:
        return None
    return rec["mean"], rec["std"], rec["avg_time"], rec["used"], rec["skipped"]


def main():
    parser = argparse.ArgumentParser("Experiment 2 runner (math skipped, AT2 skipped).")
    parser.add_argument("--datasets", type=str, required=True, help="Comma-separated names or paths.")
    parser.add_argument("--attr_funcs", type=str, required=True, help="Comma-separated attr funcs (no AT2).")
    parser.add_argument("--model", type=str, default=None, help="HF repo id (required unless --model_path set).")
    parser.add_argument("--model_path", type=str, default=None, help="Local path; overrides --model for loading.")
    parser.add_argument("--cuda", type=str, default=None)
    parser.add_argument("--cuda_num", type=int, default=0)
    parser.add_argument("--num_examples", type=int, default=100)
    parser.add_argument(
        "--mode",
        type=str,
        nargs="+",
        default=["faithfulness_gen"],
        help=(
            "One or more of: faithfulness_gen, recovery_ruler. "
            "Accepts comma-separated values, e.g. '--mode faithfulness_gen,recovery_ruler' "
            "or '--mode faithfulness_gen, recovery_ruler'."
        ),
    )
    parser.add_argument("--sample", type=int, default=None, help="Optional subsample before num_examples.")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--chunk_tokens", type=int, default=128)
    parser.add_argument("--sink_chunk_tokens", type=int, default=32)
    parser.add_argument("--n_hops", type=int, default=3)
    parser.add_argument(
        "--attnlrp_neg_handling",
        type=str,
        choices=["drop", "abs"],
        default="drop",
        help="FT-AttnLRP: how to handle negative values after each hop (drop=clamp>=0, abs=absolute value).",
    )
    parser.add_argument(
        "--attnlrp_norm_mode",
        type=str,
        choices=["norm", "no_norm"],
        default="norm",
        help="FT-AttnLRP: norm enables per-hop global+thinking normalization + ratios; no_norm disables all three.",
    )
    parser.add_argument("--data_root", type=str, default="exp/exp2/data", help="Filtered dataset cache directory.")
    parser.add_argument("--output_root", type=str, default="exp/exp2/output", help="Directory to store evaluation outputs.")
    parser.add_argument(
        "--save_hop_traces",
        action="store_true",
        help=(
            "Save per-sample trace artifacts (attribution vectors + per-sample metrics) under output_root/traces/. "
            "For multi-hop methods, also saves per-hop token vectors (vh)."
        ),
    )
    args = parser.parse_args()
    modes = _parse_modes(args.mode)

    if args.model_path:
        model_name = args.model_path
    elif args.model:
        model_name = args.model
    else:
        raise SystemExit("Please set --model or --model_path.")
    model_tag = args.model if args.model else Path(args.model_path).name

    datasets = [d.strip() for d in args.datasets.split(",") if d.strip()]
    attr_funcs = [a.strip() for a in args.attr_funcs.split(",") if a.strip()]

    device = resolve_device(args)
    model, tokenizer = load_model(model_name, device)

    max_input_len = {
        "llama-1B": 5500,
        "llama-3B": 4800,
        "llama-8B": 3500,
        "qwen-1.7B": 5500,
        "qwen-4B": 3500,
        "qwen-8B": 5000,
        "qwen-32B": 1500,
        "gemma-12B": 1500,
        "gemma-27B": 2000,
    }.get(args.model, 2000)

    for ds_name in datasets:
        if "recovery_ruler" in modes and ds_name == "morehopqa":
            raise SystemExit("recovery_ruler only supports RULER datasets (with needle_spans), not morehopqa.")
        if "recovery_ruler" in modes and ds_name.startswith("math"):
            raise SystemExit("recovery_ruler only supports RULER datasets (with needle_spans), not math.")

        # Resolve dataset (prefer prepared cache under data_root)
        cached_path = Path(args.data_root) / f"{ds_name}.jsonl"
        if cached_path.exists():
            examples = ds_utils.load_cached(cached_path, sample=args.sample, seed=args.seed)
        else:
            # allow direct cached path or raw loader
            p = Path(ds_name)
            if p.exists():
                examples = ds_utils.load_cached(p, sample=args.sample, seed=args.seed)
            else:
                hint = "please run exp/exp2/sample_and_filter.py first (or pass an explicit cached JSONL path)."
                if ds_name.startswith("math"):
                    hint = "please run exp/exp2/map_math_mine_to_exp2_cache.py first (or pass an explicit cached JSONL path)."
                raise SystemExit(f"Missing exp2 cache for '{ds_name}'. Expected {cached_path}; {hint}")

        for attr_func in attr_funcs:
            if attr_func.lower() == "at2":
                print("Skipping AT2 as requested.")
                continue

            testing_dict: Dict[str, any] = {
                "model": model,
                "model_tag": model_tag,
                "tokenizer": tokenizer,
                "attr_func": attr_func,
                "max_input_len": max_input_len,
                "chunk_tokens": args.chunk_tokens,
                "sink_chunk_tokens": args.sink_chunk_tokens,
                "n_hops": args.n_hops,
                "attnlrp_neg_handling": args.attnlrp_neg_handling,
                "attnlrp_norm_mode": args.attnlrp_norm_mode,
                "device": device,
                "batch_size": 1,
                "save_hop_traces": bool(args.save_hop_traces),
            }
            result = evaluate_dataset_multi(args, ds_name, examples, testing_dict, modes=modes)

            if "faithfulness_gen" in modes:
                faith = result.get("faithfulness")
                if not faith:
                    print(f"No faithfulness results for {ds_name} with {attr_func}.")
                else:
                    mean = faith["mean"]
                    std = faith["std"]
                    avg_time = float(faith["avg_time"])

                    out_dir = Path(args.output_root) / "faithfulness" / ds_name / model_tag
                    out_dir.mkdir(parents=True, exist_ok=True)
                    filename = f"{attr_func}_{args.num_examples}_examples.csv"
                    with open(out_dir / filename, "w") as f:
                        f.write("Method,RISE,MAS,RISE+AP\n")
                        f.write(",".join(["Seq Attr Scores Mean"] + [str(x) for x in mean[0].tolist()]) + "\n")
                        f.write(",".join(["Row Attr Scores Mean"] + [str(x) for x in mean[1].tolist()]) + "\n")
                        f.write(",".join(["Recursive Attr Scores Mean"] + [str(x) for x in mean[2].tolist()]) + "\n")
                        f.write(",".join(["Seq Attr Scores Var"] + [str(x) for x in std[0].tolist()]) + "\n")
                        f.write(",".join(["Row Attr Scores Var"] + [str(x) for x in std[1].tolist()]) + "\n")
                        f.write(",".join(["Recursive Attr Scores Var"] + [str(x) for x in std[2].tolist()]) + "\n")
                        f.write(f"Avg Sample Time (s),{avg_time}\n")
                    print(f"[{ds_name}] {attr_func} -> {out_dir/filename} (avg sample time: {avg_time:.2f}s)")

            if "recovery_ruler" in modes:
                rec = result.get("recovery")
                if not rec:
                    print(f"No recovery results for {ds_name} with {attr_func}.")
                else:
                    mean = rec["mean"]
                    std = rec["std"]
                    avg_time = float(rec["avg_time"])
                    used = int(rec["used"])
                    skipped = int(rec["skipped"])

                    out_dir = Path(args.output_root) / "recovery" / ds_name / model_tag
                    out_dir.mkdir(parents=True, exist_ok=True)
                    filename = f"{attr_func}_{args.num_examples}_examples.csv"
                    with open(out_dir / filename, "w") as f:
                        f.write("Method,Recovery@10%\n")
                        f.write(f"Seq Attr Recovery Mean,{mean[0]}\n")
                        f.write(f"Row Attr Recovery Mean,{mean[1]}\n")
                        f.write(f"Recursive Attr Recovery Mean,{mean[2]}\n")
                        f.write(f"Seq Attr Recovery Std,{std[0]}\n")
                        f.write(f"Row Attr Recovery Std,{std[1]}\n")
                        f.write(f"Recursive Attr Recovery Std,{std[2]}\n")
                        f.write(f"Examples Used,{used}\n")
                        f.write(f"Examples Skipped,{skipped}\n")
                        f.write(f"Avg Sample Time (s),{avg_time}\n")
                    print(
                        f"[{ds_name}] {attr_func} -> {out_dir/filename} "
                        f"(used={used} skipped={skipped} avg sample time: {avg_time:.2f}s)"
                    )


if __name__ == "__main__":
    main()
