#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import re
import json
import math
import argparse
import random
from typing import List, Dict, Any, Tuple, Optional

import numpy as np
import pandas as pd

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from transformers import AutoTokenizer, AutoModelForCausalLM

# your loader
from data_utils import get_sft_dataset
# Matplotlib's "tab20" colormap; used to pick per-condition line colours in the direction3 plots.
tab20 = plt.cm.tab20

# -------------------------
# misc
# -------------------------
def ensure_dir(p: str):
    """Create directory ``p`` (and any missing parents); a pre-existing directory is not an error."""
    os.makedirs(name=p, exist_ok=True)

def parse_multi_layer_specs(spec: str) -> List[str]:
    """
    Split a comma-separated list of layer specs into its individual entries.

    Each entry is whitespace-stripped and empty entries are dropped, e.g.
    "0-7, 8-15,,16-23" -> ["0-7", "8-15", "16-23"].
    """
    stripped_pieces = map(str.strip, spec.split(","))
    return list(filter(None, stripped_pieces))


def set_seed(seed: int):
    """Seed Python's ``random``, NumPy, and PyTorch (CPU generator and all CUDA devices) with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def get_num_layers(model) -> int:
    """
    Return the number of transformer blocks in ``model``.

    The same attribute layouts that ``LayerPerturber._find_layers`` recognises are
    probed, in the same order, so that any model that can be hooked can also be sized:
      - model.model.layers   (Llama-like)
      - model.model.h
      - model.layers
      - model.transformer.h  (GPT-like)

    Raises:
        ValueError: if none of the layouts is present.
    """
    layouts = (
        ("model", "layers"),
        ("model", "h"),
        (None, "layers"),
        ("transformer", "h"),
    )
    for owner_name, blocks_name in layouts:
        owner = model if owner_name is None else getattr(model, owner_name, None)
        if owner is not None and hasattr(owner, blocks_name):
            return len(getattr(owner, blocks_name))
    raise ValueError("Cannot infer num layers for this model.")


def apply_lm_head_with_norm(model, hidden: torch.Tensor) -> torch.Tensor:
    """
    Project a hidden state to vocabulary logits.

    If the model exposes ``model.model.norm`` (Llama-style final norm) it is applied
    first; otherwise ``hidden`` is passed to ``model.lm_head`` unchanged.

    NOTE(review): callers pass every entry of ``hidden_states`` through this function;
    some implementations may already have normalised the last entry - confirm that
    re-applying the norm there is intended.
    """
    has_final_norm = hasattr(model, "model") and hasattr(model.model, "norm")
    normed = model.model.norm(hidden) if has_final_norm else hidden
    return model.lm_head(normed)


def softmax_np(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """
    Softmax along ``axis``.

    The per-slice maximum is subtracted before exponentiation to avoid overflow, and
    1e-12 is added to the denominator.
    """
    peak = np.max(x, axis=axis, keepdims=True)
    exps = np.exp(x - peak)
    denom = np.sum(exps, axis=axis, keepdims=True) + 1e-12
    return exps / denom


def cosine_pairwise_mean(vecs_4xd: np.ndarray) -> float:
    """
    Mean cosine similarity over the 6 unordered pairs of the first four rows.

    vecs_4xd: [4,d]. Rows are L2-normalised (1e-12 is added to each norm) before the
    dot products are taken.
    """
    row_norms = np.linalg.norm(vecs_4xd, axis=1, keepdims=True)
    unit = vecs_4xd / (row_norms + 1e-12)

    sims = []
    for a in range(4):
        for b in range(a + 1, 4):
            sims.append(np.dot(unit[a], unit[b]))
    return float(np.mean(sims))


# -------------------------
# dataset -> prompt/answer decoding
# -------------------------
# Option labels for 4-way tasks; a label's list position is the class index used when scoring.
LETTER4 = ["A", "B", "C", "D"]
# Option labels for 2-way tasks.
LETTER2 = ["A", "B"]


def parse_prompt_answer_from_labels(tokenizer, input_ids: torch.Tensor, labels: torch.Tensor) -> Tuple[str, str]:
    """
    Recover (prompt_text, answer_text) from one SFT-encoded example.

    Expected layout:
      input_ids: prompt + answer + eos + pads
      labels:    -100 on prompt, answer token ids on answer (+ eos), -100 on pads

    The prompt is every token before the first supervised position. The answer is the
    run of supervised tokens starting there, ending at the first -100 label or right
    after the first EOS token, whichever comes first. Both are decoded with
    ``skip_special_tokens=True``; only the answer text is stripped.

    Raises:
        ValueError: if no label differs from -100.
    """
    token_list = input_ids.tolist()
    label_list = labels.tolist()

    # first supervised position (None when everything is masked)
    start = next((pos for pos, lab in enumerate(label_list) if lab != -100), None)
    if start is None:
        raise ValueError("All labels are -100; cannot extract answer tokens.")

    eos_id = tokenizer.eos_token_id
    answer_tokens = []
    pos = start
    while pos < len(label_list) and label_list[pos] != -100:
        token = token_list[pos]
        answer_tokens.append(token)
        if eos_id is not None and token == eos_id:
            break
        pos += 1

    prompt_text = tokenizer.decode(token_list[:start], skip_special_tokens=True)
    answer_text = tokenizer.decode(answer_tokens, skip_special_tokens=True).strip()
    return prompt_text, answer_text


def answer_letter_index(answer_text: str, k: int) -> int:
    """
    Parse the gold option index from the decoded supervised segment.

    Accepted forms, tried in this order:
      1. A standalone option letter A/B/C/D (case-insensitive), i.e. one that is not
         part of a longer alphanumeric run: "A", "b.", "(C)", "Answer: D".
         Letters inside words are ignored, so "Answer: C" yields C rather than the
         "A" of "Answer".
      2. The first integer in the text:
           * 2-way: 0/1 are taken as 0-based indices; 2 maps to index 1.
           * otherwise (4-way): 0..3 are taken as 0-based indices; 4 maps to index 3.
         The value 1 (and 1..3 for 4-way) is therefore always read as 0-based; a
         1-based label set cannot be told apart from a 0-based one here.

    Args:
        answer_text: decoded answer text.
        k: number of options; 2 selects the 2-way rules, anything else the 4-way rules.

    Returns:
        Index in [0, k-1].

    Raises:
        ValueError: if the text is empty, contains neither a standalone letter nor an
            integer, or the parsed value is out of range for ``k``.
    """
    s = str(answer_text).strip()
    if not s:
        raise ValueError("Empty answer_text")

    # 1) Standalone letter: the lookarounds reject letters embedded in words or numbers.
    m = re.search(r"(?<![A-Z0-9])[ABCD](?![A-Z0-9])", s.upper())
    if m is not None:
        ch = m.group(0)
        if k == 2 and ch not in ("A", "B"):
            raise ValueError(f"2-way expects A/B, got {ch} from {answer_text}")
        return "ABCD".index(ch)

    # 2) Fallback: digit parsing
    md = re.search(r"-?\d+", s)
    if md is None:
        raise ValueError(f"Cannot parse answer from: {answer_text}")
    n = int(md.group(0))

    if k == 2:
        if n in (0, 1):
            return n
        if n == 2:
            return 1
        raise ValueError(f"2-way expects 0/1 or 1/2, got {n} from {answer_text}")

    if 0 <= n <= 3:
        return n
    if n == 4:
        return 3
    raise ValueError(f"4-way expects 0..3 or 1..4, got {n} from {answer_text}")

def task_num_options(task_name: str) -> int:
    """
    Number of answer options for a task, decided by substring match on its name.

    Names containing "boolq", "piqa" or "wino" (case-insensitive) are 2-way;
    everything else is treated as 4-way.
    """
    lowered = task_name.lower()
    two_way_markers = ("boolq", "piqa", "wino")
    return 2 if any(marker in lowered for marker in two_way_markers) else 4


# -------------------------
# model forward helpers
# -------------------------
@torch.no_grad()
def forward_hidden_states(model, input_ids: torch.Tensor, attention_mask: torch.Tensor):
    """
    Run one gradient-free forward pass and return ``hidden_states`` from the output.

    The model is called with ``output_hidden_states=True``, ``use_cache=False`` and
    ``return_dict=True``; callers in this file treat the result as a tuple with one
    entry per layer plus one (L+1).
    """
    forward_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        use_cache=False,
        return_dict=True,
    )
    outputs = model(**forward_kwargs)
    return outputs.hidden_states


@torch.no_grad()
def score_answer_letter_per_layer(
    model,
    tokenizer,
    prompt_text: str,
    letters: List[str],
    max_length: int,
    device: str,
) -> np.ndarray:
    """
    Score each option letter at every layer by the summed log-probability of the
    continuation tokens of ``prompt_text + " " + letter``.

    Each layer's hidden states are projected through ``apply_lm_head_with_norm``
    and the log-probability of every continuation token is read from the position
    immediately before it.

    Differences from a naive implementation:
      - Only the positions that predict continuation tokens are projected to the
        vocabulary, instead of the whole sequence at every layer.
      - Token positions are derived from the attention mask, so the result does not
        depend on whether the tokenizer pads on the left or on the right.

    Args:
        model: causal LM accepted by ``forward_hidden_states``.
        tokenizer: tokenizer used for both the prompt and the continuations.
        prompt_text: decoded prompt.
        letters: option labels; one batch row is built per label.
        max_length: truncation length for tokenization.
        device: device the encoded tensors are moved to.

    Returns:
        opt_logprobs: float32 array of shape [L+1, K]. A row whose continuation is
        empty (e.g. truncated away) keeps the score 0.0.
    """
    K = len(letters)
    cont_texts = [prompt_text + " " + letter for letter in letters]

    enc = tokenizer(
        cont_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
        add_special_tokens=False,   # IMPORTANT: keep alignment consistent with the SFT encoding style
    )
    input_ids = enc["input_ids"].to(device)
    attn = enc["attention_mask"].to(device)

    # prompt length (no special tokens)
    enc_prompt = tokenizer(
        prompt_text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        add_special_tokens=False,
    )
    prompt_len = int(enc_prompt["input_ids"].shape[1])

    # Per row: index of the first real token (0 with right padding) and real length.
    # argmax returns the first maximal index, i.e. the first position whose mask is 1.
    first_real = attn.long().argmax(dim=1).tolist()
    seq_lens = attn.sum(dim=1).tolist()

    # Half-open [lo, hi) span of continuation tokens per row. Token 0 has no
    # preceding position that could predict it, hence max(prompt_len, 1).
    spans = []
    for b in range(K):
        lo = int(first_real[b]) + max(prompt_len, 1)
        hi = int(first_real[b]) + int(seq_lens[b])
        spans.append((lo, hi))

    hs_tuple = forward_hidden_states(model, input_ids, attn)  # (L+1), each [K,T,H]
    Lp1 = len(hs_tuple)

    opt_logprobs = np.zeros((Lp1, K), dtype=np.float32)

    for li, hs in enumerate(hs_tuple):
        for b, (lo, hi) in enumerate(spans):
            if hi <= lo:
                continue  # nothing to score for this row
            # Token t is predicted from position t-1, so project positions [lo-1, hi-1).
            logits = apply_lm_head_with_norm(model, hs[b:b + 1, lo - 1:hi - 1, :])  # [1,n,V]
            logp = F.log_softmax(logits, dim=-1)[0]                                  # [n,V]
            targets = input_ids[b, lo:hi].to(logp.device).unsqueeze(-1)              # [n,1]
            token_logps = logp.gather(-1, targets).squeeze(-1).tolist()
            opt_logprobs[li, b] = sum(token_logps, 0.0)

    return opt_logprobs


# -------------------------
# Direction 1: option disentanglement (representation similarity among options)
# -------------------------
@torch.no_grad()
def direction1_option_disentangle(
    model, tokenizer, dataset, task_name: str, out_dir: str, limit: int, max_length: int, device: str, seed: int
):
    """
    Direction 1: per-layer similarity between the representations of the K options.

    For each of the first ``min(limit, len(dataset))`` samples the prompt is
    recovered from the labels, one continuation ``prompt + " " + letter`` is built
    per option, and the hidden state of the last real token of every continuation is
    taken at every layer. The mean pairwise cosine (K == 4) or cosine(A, B) (K == 2)
    is averaged over samples.

    The last real token is located from the attention mask, so the result does not
    depend on the tokenizer's padding side.

    Writes ``{task_name}_direction1_option_cosine.csv`` and ``.png`` into ``out_dir``.
    """
    ensure_dir(out_dir)
    set_seed(seed)

    K = task_num_options(task_name)
    letters = LETTER2 if K == 2 else LETTER4

    num_layers = get_num_layers(model)
    Lp1 = num_layers + 1

    total = min(limit, len(dataset))
    sum_cos = np.zeros(Lp1, dtype=np.float64)
    n = 0

    for i in range(total):
        batch = dataset[i]
        prompt_text, _ = parse_prompt_answer_from_labels(tokenizer, batch["input_ids"], batch["labels"])
        # build K continuations and forward once to get hidden states (same trick as scoring)
        cont_texts = [prompt_text + " " + letter for letter in letters]
        enc = tokenizer(
            cont_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
            add_special_tokens=False,
        )
        input_ids = enc["input_ids"].to(device)
        attn = enc["attention_mask"].to(device)

        hs_tuple = forward_hidden_states(model, input_ids, attn)  # (L+1), each [K,T,H]

        # Index of the last position whose mask is 1, per row. Flipping the mask and
        # taking argmax (first maximal index) counts the trailing pad positions.
        width = attn.shape[1]
        last_idx = (width - 1) - attn.flip(dims=[1]).long().argmax(dim=1)  # [K]
        rows = torch.arange(attn.shape[0], device=attn.device)

        for li, hs in enumerate(hs_tuple):
            vecs = hs[rows.to(hs.device), last_idx.to(hs.device), :].float().cpu().numpy()  # [K,H]
            if K == 4:
                sum_cos[li] += cosine_pairwise_mean(vecs)
            else:
                # K=2: cosine between A and B only
                v = vecs / (np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12)
                sum_cos[li] += float(np.dot(v[0], v[1]))

        n += 1
        if n % 50 == 0:
            print(f"[direction1] {task_name}: {n}/{total}")

    # max(n, 1) keeps the division defined when no sample was processed
    mean_cos = sum_cos / max(n, 1)
    df = pd.DataFrame({"layer": np.arange(Lp1, dtype=int), "mean_option_cosine": mean_cos.astype(np.float32), "n": n})
    csv_path = os.path.join(out_dir, f"{task_name}_direction1_option_cosine.csv")
    df.to_csv(csv_path, index=False)

    plt.figure()
    plt.plot(df["layer"], df["mean_option_cosine"], marker="o")
    plt.title(f"Option Disentanglement (cosine) | task={task_name} | N={n}")
    plt.xlabel("Layer ID")
    plt.ylabel("Mean Option Cosine" if K == 4 else "Cosine(A,B)")
    plt.grid(True, alpha=0.3)
    fig_path = os.path.join(out_dir, f"{task_name}_direction1_option_cosine.png")
    plt.tight_layout()
    plt.savefig(fig_path, dpi=200)
    plt.close()

    print(f"[direction1] saved: {csv_path}")
    print(f"[direction1] saved: {fig_path}")


# -------------------------
# Direction 2: structure probes (simple interpretable labels from prompt text)
# -------------------------
# Negation cue words as whole tokens: a match may not be adjacent to another letter
# or digit, but may touch punctuation, newlines or the string boundaries.
_NEGATION_RE = re.compile(r"(?<![a-z0-9])(?:not|except|never|none|least|cannot|can['’]t)(?![a-z0-9])")


def make_structure_labels(prompt_text: str) -> Dict[str, int]:
    """
    Derive simple, interpretable structure labels from the prompt text.

    Currently a single label is produced:
      - "negation": 1 if the lower-cased prompt contains one of the cue words
        not / except / never / none / least / cannot / can't as a whole word, else 0.
        Cues followed by punctuation or a newline ("EXCEPT:", "not.") count as well;
        substrings of longer words ("nothing", "knot") do not.

    Other probe labels (length bucket, numeric options) are computed by the caller.
    """
    neg = int(_NEGATION_RE.search(prompt_text.lower()) is not None)
    return {"negation": neg}


def direction2_structure_probe(
    model, tokenizer, dataset, task_name: str, out_dir: str, limit: int, max_length: int, device: str, seed: int,
    batch_size: int = 8,
):
    """
    Direction 2: per-layer linear probes for simple structural properties of the prompt.

    For each layer, the hidden state of the last real prompt token (prompt only, no
    answer appended) is used to fit a standardised logistic-regression probe for:
      - negation present             (``make_structure_labels``)
      - prompt length bucket (0/1/2) (<=200 / <=400 / >400 prompt tokens)
      - numeric option exists        (a line such as "A. 12")
    Probe accuracy on a held-out 20% split is written per layer to
    ``{task_name}_direction2_structure_probes.csv`` and plotted to the matching ``.png``.

    Prompts are forwarded in mini-batches of ``batch_size`` so that the hidden states
    of all layers never have to be held for the whole dataset at once, and the last
    real token is located from the attention mask so the result does not depend on
    the tokenizer's padding side.

    Args:
        batch_size: number of prompts per forward pass (values < 1 are treated as 1).

    Raises:
        ValueError: if there is no sample to process.
    """
    ensure_dir(out_dir)
    set_seed(seed)

    from sklearn.model_selection import train_test_split
    from sklearn.linear_model import LogisticRegression
    from sklearn.preprocessing import StandardScaler
    from sklearn.pipeline import Pipeline

    N = min(limit, len(dataset))
    if N <= 0:
        raise ValueError("direction2 needs at least one sample; check `limit` and the dataset size.")

    prompts = []
    y_neg = []
    y_len = []
    y_num = []

    for i in range(N):
        batch = dataset[i]
        prompt_text, _ = parse_prompt_answer_from_labels(tokenizer, batch["input_ids"], batch["labels"])
        prompts.append(prompt_text)

        # negation
        y_neg.append(make_structure_labels(prompt_text)["negation"])

        # length bucket by token count
        tlen = len(tokenizer(prompt_text, add_special_tokens=False)["input_ids"])
        if tlen <= 200:
            y_len.append(0)
        elif tlen <= 400:
            y_len.append(1)
        else:
            y_len.append(2)

        # numeric option heuristic: any line like "A. 12" / "B. 3.14"
        y_num.append(int(bool(re.search(r"\n[A-D]\.\s*\d", prompt_text))))

    y_neg = np.array(y_neg, dtype=int)
    y_len = np.array(y_len, dtype=int)
    y_num = np.array(y_num, dtype=int)

    # Encode prompts as prompt-only inputs (no answer appended), one mini-batch at a time,
    # keeping only the last-token hidden state of every layer.
    step = max(int(batch_size), 1)
    rep_chunks = []
    with torch.no_grad():
        for start in range(0, N, step):
            enc = tokenizer(
                prompts[start:start + step],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
                add_special_tokens=False,
            )
            input_ids = enc["input_ids"].to(device)
            attn = enc["attention_mask"].to(device)

            hs_tuple = forward_hidden_states(model, input_ids, attn)  # (L+1), each [B,T,H]

            # Last position whose mask is 1, per row: flipping the mask and taking
            # argmax (first maximal index) counts the trailing pad positions.
            width = attn.shape[1]
            last_idx = (width - 1) - attn.flip(dims=[1]).long().argmax(dim=1)  # [B]
            rows = torch.arange(attn.shape[0], device=attn.device)

            chunk = torch.stack(
                [hs[rows.to(hs.device), last_idx.to(hs.device), :].float().cpu() for hs in hs_tuple],
                dim=1,
            )  # [B, L+1, H]
            rep_chunks.append(chunk.numpy())

    reps = np.concatenate(rep_chunks, axis=0)  # [N, L+1, H], float32
    Lp1 = reps.shape[1]

    def probe_acc(X: np.ndarray, y: np.ndarray) -> float:
        """Held-out accuracy of a standardised logistic-regression probe (NaN if it cannot be fit)."""
        classes, counts = np.unique(y, return_counts=True)
        if len(classes) < 2:
            return 1.0
        # A stratified split needs at least two members per class; otherwise split randomly.
        stratify = y if counts.min() >= 2 else None
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=seed, stratify=stratify
        )
        if len(np.unique(y_train)) < 2:
            return float("nan")  # the training split ended up with a single class
        # `multi_class` is left at its default ("auto"); passing it explicitly is
        # deprecated in recent scikit-learn releases.
        clf = Pipeline([
            ("scaler", StandardScaler()),
            ("lr", LogisticRegression(max_iter=200, n_jobs=1)),
        ])
        clf.fit(X_train, y_train)
        return float(clf.score(X_test, y_test))

    rows_out = []
    for li in range(Lp1):
        X = reps[:, li, :]
        rows_out.append({"layer": li, "probe": "negation", "acc": probe_acc(X, y_neg), "n": N})
        rows_out.append({"layer": li, "probe": "len_bucket", "acc": probe_acc(X, y_len), "n": N})
        rows_out.append({"layer": li, "probe": "numeric_opt", "acc": probe_acc(X, y_num), "n": N})

    df = pd.DataFrame(rows_out)
    csv_path = os.path.join(out_dir, f"{task_name}_direction2_structure_probes.csv")
    df.to_csv(csv_path, index=False)

    plt.figure()
    for probe in ["negation", "len_bucket", "numeric_opt"]:
        sub = df[df["probe"] == probe].sort_values("layer")
        plt.plot(sub["layer"], sub["acc"], marker="o", label=probe)
    plt.title(f"Structure Probes vs Layer | task={task_name} | N={N}")
    plt.xlabel("Layer ID")
    plt.ylabel("Probe Accuracy")
    plt.grid(True, alpha=0.3)
    plt.legend()
    fig_path = os.path.join(out_dir, f"{task_name}_direction2_structure_probes.png")
    plt.tight_layout()
    plt.savefig(fig_path, dpi=200)
    plt.close()

    print(f"[direction2] saved: {csv_path}")
    print(f"[direction2] saved: {fig_path}")


# -------------------------
# Direction 3: perturb shallow layers and see if transition curves change
# -------------------------
class LayerPerturber:
    """
    Perturb the output of selected transformer blocks through forward hooks.

    Modes:
      - "noise":   add Gaussian noise whose standard deviation is ``noise_std`` times
                   the standard deviation of the block's output tensor, i.e.
                   ``noise_std`` is a relative noise level (0.1 == 10% of the signal).
                   The noise is drawn from torch's global RNG.
      - "shuffle": permute the batch dimension of the block's output, using a NumPy
                   RNG seeded with ``seed``.
    Any other mode leaves the output untouched.

    Usage: ``install()`` registers one hook per block, ``remove()`` unregisters them.
    """

    def __init__(self, model, layers: List[int], mode: str, noise_std: float, seed: int):
        self.model = model
        self.layers = set(layers)          # indices of the blocks to perturb
        self.mode = mode
        self.noise_std = float(noise_std)
        self.rng = np.random.RandomState(seed)  # drives the "shuffle" permutations
        self.handles = []                  # hook handles registered by install()
        self._layer_modules = None         # list-like of blocks, set by install()

    def _find_layers(self):
        """
        Try common decoder-only layouts:
          - model.model.layers    (Llama/Qwen/Mistral etc in many HF implementations)
          - model.model.h
          - model.layers          (some models)
          - model.transformer.h   (GPT-like)

        Raises:
            ValueError: if none of the layouts is present.
        """
        if hasattr(self.model, "model") and hasattr(self.model.model, "layers"):
            return self.model.model.layers
        if hasattr(self.model, "model") and hasattr(self.model.model, "h"):
            return self.model.model.h
        if hasattr(self.model, "layers"):
            return self.model.layers
        if hasattr(self.model, "transformer") and hasattr(self.model.transformer, "h"):
            return self.model.transformer.h
        raise ValueError("Cannot locate transformer blocks for this model (layers/h not found).")

    def _apply(self, tensor_out: torch.Tensor):
        """Return the perturbed version of ``tensor_out``; always returns a tensor."""
        if self.mode == "noise":
            if self.noise_std <= 0:
                return tensor_out
            # Standard deviation of the activations in the current batch.
            layer_std = tensor_out.std().item()
            # noise_std is a fraction of the layer's signal strength (e.g. 0.1 == 10% noise).
            actual_noise = torch.randn_like(tensor_out) * (layer_std * self.noise_std)
            return tensor_out + actual_noise

        if self.mode == "shuffle":
            if tensor_out.dim() == 0 or tensor_out.shape[0] <= 1:
                return tensor_out  # nothing to permute
            perm = torch.as_tensor(
                self.rng.permutation(tensor_out.shape[0]),
                dtype=torch.long,
                device=tensor_out.device,
            )
            return tensor_out.index_select(0, perm)

        # Unknown mode: pass through unchanged rather than returning None, which
        # would end up inside the block's output tuple.
        return tensor_out

    def _hook(self, idx: int):
        """Build the forward hook for block ``idx``."""
        def fn(module, inp, out):
            if idx not in self.layers:
                return out

            # Case 1: out is a Tensor -> directly perturb
            if torch.is_tensor(out):
                return self._apply(out)

            # Case 2: out is a tuple/list and first element is hidden states
            if isinstance(out, (tuple, list)) and len(out) > 0 and torch.is_tensor(out[0]):
                out0 = self._apply(out[0])
                if isinstance(out, tuple):
                    return (out0,) + tuple(out[1:])
                out = list(out)
                out[0] = out0
                return out

            # Otherwise, do nothing
            return out

        return fn

    def install(self):
        """Register one forward hook per block; a previous installation is removed first."""
        if self.handles:
            self.remove()  # avoid stacking duplicate hooks on repeated install()
        self._layer_modules = self._find_layers()
        for i, layer in enumerate(self._layer_modules):
            self.handles.append(layer.register_forward_hook(self._hook(i)))

    def remove(self):
        """Unregister all hooks installed by ``install()``."""
        for h in self.handles:
            h.remove()
        self.handles = []
        self._layer_modules = None


@torch.no_grad()
def compute_transition_curves(
    model, tokenizer, dataset, task_name: str, limit: int, max_length: int, device: str
) -> pd.DataFrame:
    """
    Average, per layer, the probability of the gold option and the decision margin.

    For each of the first ``min(limit, len(dataset))`` samples the option letters are
    scored at every layer with ``score_answer_letter_per_layer``. P(gold) is the
    softmax over the option scores; the margin is the gold score minus the best
    non-gold score.

    Returns a DataFrame with columns ``layer``, ``p_gold_mean``, ``margin_mean``, ``n``.
    """
    sample_count = min(limit, len(dataset))
    num_options = task_num_options(task_name)
    option_letters = LETTER4 if num_options != 2 else LETTER2

    depth = get_num_layers(model) + 1

    gold_prob_total = np.zeros(depth, dtype=np.float64)
    margin_total = np.zeros(depth, dtype=np.float64)

    for done, idx in enumerate(range(sample_count), start=1):
        example = dataset[idx]
        prompt_text, ans_text = parse_prompt_answer_from_labels(
            tokenizer, example["input_ids"], example["labels"]
        )
        gold = answer_letter_index(ans_text, num_options)

        layer_scores = score_answer_letter_per_layer(
            model, tokenizer, prompt_text, option_letters, max_length, device
        )  # [depth, K]

        gold_prob = softmax_np(layer_scores, axis=1)[:, gold]

        gold_scores = layer_scores[:, gold]
        if num_options > 1:
            # best-scoring competitor at every layer
            best_rival = np.max(np.delete(layer_scores, gold, axis=1), axis=1)
        else:
            best_rival = gold_scores

        gold_prob_total += gold_prob
        margin_total += gold_scores - best_rival

        if done % 50 == 0:
            print(f"[direction3] {task_name}: {done}/{sample_count}")

    columns = {
        "layer": np.arange(depth, dtype=int),
        "p_gold_mean": (gold_prob_total / sample_count).astype(np.float32),
        "margin_mean": (margin_total / sample_count).astype(np.float32),
        "n": sample_count,
    }
    return pd.DataFrame(columns)


def parse_layer_spec(spec: str) -> List[int]:
    """
    Parse a layer spec into a list of layer indices.

    The spec is a comma-separated list whose items are either a single index ("5")
    or an inclusive range ("0-7"); items may be mixed ("0-2,5,8-9") and empty items
    are ignored. A range whose end is below its start contributes nothing.

    Raises:
        ValueError: if an item is not an integer or a well-formed "a-b" range.
    """
    layers: List[int] = []
    for piece in spec.split(","):
        piece = piece.strip()
        if not piece:
            continue
        if "-" in piece:
            lo, _, hi = piece.partition("-")
            try:
                start, stop = int(lo), int(hi)
            except ValueError:
                raise ValueError(f"Invalid layer range: {piece!r}") from None
            layers.extend(range(start, stop + 1))
        else:
            layers.append(int(piece))
    return layers


def parse_layers(spec: str) -> List[int]:
    """
    Turn a layer spec into the list of layer indices it denotes.

    Accepted forms:
      - an inclusive range:           "0-7"    -> [0, 1, ..., 7]
      - a comma-separated index list: "1,4,9"  -> [1, 4, 9]
      - any mix of the two:           "0-2,5"  -> [0, 1, 2, 5]
    Blank items are skipped; a range with end < start yields no indices.

    Raises:
        ValueError: if an item is neither an integer nor a well-formed "a-b" range.
    """
    indices: List[int] = []
    for item in spec.strip().split(","):
        item = item.strip()
        if not item:
            continue
        if "-" not in item:
            indices.append(int(item))
            continue
        first, _, last = item.partition("-")
        try:
            lo, hi = int(first), int(last)
        except ValueError:
            raise ValueError(f"Invalid layer range: {item!r}") from None
        indices.extend(range(lo, hi + 1))
    return indices


def _plot_transition_metric(df, column: str, ylabel: str, title: str, fig_path: str, zero_line: bool = False):
    """Plot ``column`` against layer for every condition in ``df`` and save the figure to ``fig_path``."""
    plt.figure(figsize=(10, 6))

    for idx, cond in enumerate(df["condition"].unique()):
        sub = df[df["condition"] == cond].sort_values("layer")

        # Baseline is drawn as a thick solid line, perturbed runs as thinner dashed lines.
        if cond == "baseline":
            color, lw, ls = tab20(0), 3.2, "-"
        else:
            color, lw, ls = tab20(2 * idx + 2), 2.2, "--"

        plt.plot(
            sub["layer"],
            sub[column],
            color=color,
            linewidth=lw,
            linestyle=ls,
            marker="o",
            alpha=0.95,
            label=cond,
        )

    if zero_line:
        # y = 0 reference line (decision boundary)
        plt.axhline(y=0.0, color="gray", linestyle="--", linewidth=1.2, alpha=0.8, zorder=0)

    plt.title(title, fontsize=22)
    plt.xlabel("Layer ID", fontsize=22)
    plt.ylabel(ylabel, fontsize=22)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.grid(True, alpha=0.3)
    plt.legend(loc="upper left", fontsize=14)

    plt.tight_layout()
    plt.savefig(fig_path, dpi=200)
    plt.close()


def direction3_shallow_perturb(
    model, tokenizer, dataset, task_name: str, out_dir: str, limit: int,
    max_length: int, device: str, seed: int,
    shallow_layers: str, perturb_mode: str, noise_std: float,
    title: Optional[str] = None,
):
    """
    Direction 3: perturb shallow layers and compare the per-layer transition curves.

    The baseline curves are computed once; then, for every comma-separated entry of
    ``shallow_layers`` (e.g. "0-7,8-15"), a ``LayerPerturber`` is installed on those
    layers and the curves are recomputed. All conditions are written to
    ``{task_name}_direction3_transition_curves.csv`` and drawn into ONE figure per
    metric (P(gold) and decision margin).

    Args:
        shallow_layers: comma-separated layer specs, each parsed by ``parse_layers``.
        perturb_mode: "noise" or anything else (labelled "shuffle" in the output).
        noise_std: noise level passed to ``LayerPerturber``.
        title: figure title. Defaults to "<model name> | <task_name>", where the model
            name is the basename of the model's ``name_or_path`` attribute when it has
            one (presumably set by ``from_pretrained`` - confirm), else just the task.
    """
    ensure_dir(out_dir)
    set_seed(seed)

    if title is None:
        name_or_path = (
            getattr(model, "name_or_path", None)
            or getattr(getattr(model, "config", None), "name_or_path", None)
            or ""
        )
        model_label = os.path.basename(str(name_or_path).rstrip("/\\"))
        title = f"{model_label} | {task_name}" if model_label else task_name

    # -------- parse multiple layer ranges --------
    layer_specs = parse_multi_layer_specs(shallow_layers)

    # -------- baseline (only once) --------
    df_base = compute_transition_curves(
        model, tokenizer, dataset, task_name, limit, max_length, device
    )
    df_base["condition"] = "baseline"

    all_dfs = [df_base]

    # -------- run perturbations --------
    for spec in layer_specs:
        layers = parse_layers(spec)

        pert = LayerPerturber(model, layers, perturb_mode, noise_std, seed)
        pert.install()
        print("[perturb] num blocks =", len(pert._find_layers()), "target layers =", sorted(list(pert.layers))[:40], "...")
        try:
            df_p = compute_transition_curves(
                model, tokenizer, dataset, task_name, limit, max_length, device
            )
        finally:
            # always unhook, so a failure cannot leave the model perturbed
            pert.remove()

        if perturb_mode == "noise":
            cond = f"noise(std={noise_std})@{spec}"
        else:
            cond = f"shuffle@{spec}"

        df_p["condition"] = cond
        all_dfs.append(df_p)

    # -------- merge --------
    df = pd.concat(all_dfs, axis=0, ignore_index=True)
    csv_path = os.path.join(out_dir, f"{task_name}_direction3_transition_curves.csv")
    df.to_csv(csv_path, index=False)

    # -------- Plot 1: P(gold) --------
    fig1 = os.path.join(out_dir, f"{task_name}_direction3_p_gold_multi.png")
    _plot_transition_metric(df, "p_gold_mean", "Probability of Gold", title, fig1)

    # -------- Plot 2: Decision Margin (with y=0 reference) --------
    fig2 = os.path.join(out_dir, f"{task_name}_direction3_margin_multi.png")
    _plot_transition_metric(df, "margin_mean", "Decision Margin", title, fig2, zero_line=True)

    print(f"[direction3] saved: {csv_path}")
    print(f"[direction3] saved: {fig1}")
    print(f"[direction3] saved: {fig2}")



# -------------------------
# main
# -------------------------
def main():
    """
    Command-line entry point.

    Parses the arguments, loads the tokenizer and model, builds the dataset with
    ``get_sft_dataset`` and runs the analyses selected by ``--mode``. Each analysis
    writes into its own sub-directory of ``--out_dir``.

    Raises:
        ValueError: if the tokenizer has neither a pad token nor an EOS token to
            fall back on (the analyses encode with ``padding=True``).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--task", type=str, required=True,
                    help="task name passed into get_sft_dataset(), e.g., arc_challenge / arc_easy / Hellaswag")
    ap.add_argument("--split", type=str, default="test", choices=["train", "test", "validation"])
    ap.add_argument("--mode", type=str, required=True, choices=["direction1", "direction2", "direction3", "all"])
    ap.add_argument("--model_name_or_path", type=str, required=True)
    ap.add_argument("--out_dir", type=str, required=True)
    ap.add_argument("--limit", type=int, default=500)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--max_length", type=int, default=1024)
    ap.add_argument("--device", type=str, default="cuda")
    ap.add_argument("--dtype", type=str, default="bf16", choices=["fp16", "bf16", "fp32"])

    # direction3
    ap.add_argument("--shallow_layers", type=str, default="0-7")
    ap.add_argument("--perturb_mode", type=str, default="noise", choices=["noise", "shuffle"])
    ap.add_argument("--noise_std", type=float, default=0.02)

    args = ap.parse_args()
    ensure_dir(args.out_dir)
    set_seed(args.seed)

    # tokenizer + model
    tok = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
    if tok.pad_token_id is None:
        # Reuse EOS as the pad token; fail early and clearly when that is impossible
        # instead of silently leaving the tokenizer without a pad token.
        if tok.eos_token is None:
            raise ValueError(
                "Tokenizer has neither a pad token nor an EOS token; cannot encode with padding=True."
            )
        tok.pad_token = tok.eos_token

    dtype_by_name = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
    torch_dtype = dtype_by_name[args.dtype]
    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch_dtype)
    model.to(args.device)
    model.eval()

    # dataset (num_samples=args.limit caps the number of samples)
    ds = get_sft_dataset(
        name=args.task,
        tokenizer=tok,
        max_length=args.max_length,
        seed=args.seed,
        num_samples=args.limit,
        split=args.split,  # argparse `choices` already restricts this to train/test/validation
    )

    # IMPORTANT: get_sft_dataset uses a single parquet file and sets split="train" inside some datasets;
    # num_samples is still honored, so num_samples=limit is what bounds the sample count.

    if args.mode in ["direction1", "all"]:
        direction1_option_disentangle(model, tok, ds, args.task, os.path.join(args.out_dir, "direction1"),
                                      args.limit, args.max_length, args.device, args.seed)
    if args.mode in ["direction2", "all"]:
        direction2_structure_probe(model, tok, ds, args.task, os.path.join(args.out_dir, "direction2"),
                                   args.limit, args.max_length, args.device, args.seed)
    if args.mode in ["direction3", "all"]:
        direction3_shallow_perturb(model, tok, ds, args.task, os.path.join(args.out_dir, "direction3"),
                                   args.limit, args.max_length, args.device, args.seed,
                                   args.shallow_layers, args.perturb_mode, args.noise_std)


# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
