#!/usr/bin/env python3
# bench3.py
# Extended to support both LGD CSV and CausalGym (HuggingFace "aryaman/causalgym" or local JSON).
# - dataset_type: lgd | causalgym
# - CausalGym loader:
#     * Uses huggingface datasets if available (aryaman/causalgym) or local JSON files.
#     * Converts each minimal pair into two samples (base and src) with gold and alt continuation strings.
#     * Adds Ze for CausalGym using Preposition family near the verb: NONE/OF/IN/WITH_OR_BY/OTHER -> {0,1,2,3,4}.
# - Interventions: gbi (PGD/FGSM, L∞/L2), inlp, alterrep, null, hdmi
# - Metrics: ΔTask Acc, Completeness (TV), Selectivity (if Ze available), Reliability (H harmonic mean)
#
# Example (CausalGym, Pythia):
#   python bench3.py --dataset_type causalgym --model_name EleutherAI/pythia-70m \
#       --cg_tasks agr_sv_num_subj-relc agr_sv_num_obj-relc --intervention gbi \
#       --gbi_attack pgd --gbi_norm linf --epsilon 0.112 --pgd_steps 40
#
# Example (LGD):
#   python bench3.py --dataset_type lgd --data_csv lgd_equiv_sva.csv \
#       --model_name EleutherAI/pythia-70m --intervention alterrep --inlp_rank 16 --alterrep_alpha 0.1

import os
import re
import math
import json
import random
import argparse
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM, logging as hf_logging

hf_logging.set_verbosity_error()


# ----------------------------
# Utilities
# ----------------------------

def set_seed(seed: int):
    """Seed every random number generator this script relies on.

    Covers Python's ``random``, NumPy's global generator, and PyTorch
    (the default generator plus all CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


def get_device(arg_device: str) -> torch.device:
    """Resolve a device string into a ``torch.device``.

    ``"auto"`` picks CUDA when available, then Apple MPS, then CPU.
    Any other value is handed to ``torch.device`` unchanged.
    """
    if arg_device != "auto":
        return torch.device(arg_device)
    if torch.cuda.is_available():
        return torch.device("cuda")
    has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    return torch.device("mps" if has_mps else "cpu")


def softmax_np(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over all entries of ``x``.

    The maximum is subtracted before exponentiating. If the exponentials
    sum to a non-positive value, a uniform distribution is returned.
    """
    exps = np.exp(x - np.max(x))
    total = exps.sum()
    if total <= 0:
        # Degenerate normaliser: fall back to uniform.
        return np.ones_like(x) / len(x)
    return exps / total


def tv_distance(p: np.ndarray, q: np.ndarray) -> float:
    """Total-variation distance: half the L1 distance between ``p`` and ``q``."""
    abs_gap = np.abs(p - q)
    return 0.5 * abs_gap.sum()


def harmonic_mean(a: float, b: float, eps: float = 1e-9) -> float:
    """Harmonic mean of ``a`` and ``b``; 0.0 when either is non-positive.

    ``eps`` is added to the denominator for numerical safety.
    """
    if a <= 0.0 or b <= 0.0:
        return 0.0
    numerator = 2.0 * a * b
    denominator = a + b + eps
    return numerator / denominator


# ----------------------------
# Dataset-agnostic sample
# ----------------------------

@dataclass
class GenericSample:
    """Dataset-agnostic evaluation item: one prompt with a gold and an alternative continuation.

    Both loaders in this file (LGD CSV rows and CausalGym minimal pairs)
    produce this record so the rest of the pipeline is dataset-independent.
    """
    prompt_text: str           # full prompt string before the continuation
    gold_cont: str             # gold continuation string (may already start with a space)
    alt_cont: str              # alternative continuation string
    zc: int                    # class index of the target causal property (e.g., sg/pl -> 0/1)
    cf_zc: int                 # counterfactual class index for this sample (paired type)
    ze: Optional[int] = None   # auxiliary property label if available; CausalGym uses Preposition family {0=NONE,1=OF,2=IN,3=WITH_OR_BY,4=OTHER}; the LGD loader uses {0=missing,1=sg,2=pl}
    task: Optional[str] = None # task name (for CausalGym) for reference; the LGD loader sets "lgd_sva"


# ----------------------------
# LGD CSV loader (legacy)
# ----------------------------

def _normalize_label(val) -> Optional[int]:
    """Map a singular/plural label to 0/1; return None when missing or unrecognised.

    Matching is case-insensitive and ignores surrounding whitespace.
    """
    if val is None:
        return None
    is_missing_float = isinstance(val, float) and (np.isnan(val) or pd.isna(val))
    if is_missing_float:
        return None
    aliases = {
        "sg": 0, "singular": 0, "s": 0, "single": 0,
        "pl": 1, "plural": 1, "p": 1, "multi": 1,
    }
    return aliases.get(str(val).strip().lower())


def _strip_repeat_header_rows(df: pd.DataFrame) -> pd.DataFrame:
    """Drop rows whose ``text`` cell is just the header word "text" (e.g. from concatenated CSVs).

    The comparison ignores case and surrounding whitespace; the index is reset.
    """
    normalized_text = df["text"].astype(str).str.strip().str.lower()
    is_header_row = normalized_text == "text"
    kept = df[~is_header_row]
    return kept.reset_index(drop=True)


def _clean_lgd_dataset(path: str) -> pd.DataFrame:
    """Load the LGD CSV at ``path`` and keep only usable rows.

    Adds ``Zc_norm`` / ``Ze_norm`` columns (0 = singular, 1 = plural, missing
    when unrecognised) and casts the verb columns to ``str``.

    A row is kept when ``Zc`` normalises to a label and ``text``, ``verb_sg``
    and ``verb_pl`` are all present and non-blank.

    Raises:
        ValueError: if any of the columns text/verb_sg/verb_pl/Zc/Ze is absent.
    """
    df = pd.read_csv(path)
    expected_cols = {"text", "verb_sg", "verb_pl", "Zc", "Ze"}
    missing = expected_cols - set(df.columns)
    if missing:
        raise ValueError(f"Dataset missing columns: {missing}. Found: {list(df.columns)}")
    df = _strip_repeat_header_rows(df)
    df["Zc_norm"] = df["Zc"].apply(_normalize_label)
    df["Ze_norm"] = df["Ze"].apply(_normalize_label)

    # Record which verb cells are truly present BEFORE casting: astype(str)
    # turns a missing cell into the literal string "nan", which would otherwise
    # pass the non-blank check and be scored as a real verb downstream.
    verbs_present = df["verb_sg"].notna() & df["verb_pl"].notna()
    df["verb_sg"] = df["verb_sg"].astype(str)
    df["verb_pl"] = df["verb_pl"].astype(str)

    def _has_text(s):
        return isinstance(s, str) and len(s.strip()) > 0

    ok = (
        df["Zc_norm"].notna()
        & verbs_present
        & df["verb_sg"].apply(_has_text)
        & df["verb_pl"].apply(_has_text)
        & df["text"].apply(_has_text)
    )
    return df.loc[ok].reset_index(drop=True)


def make_lgd_samples(df: pd.DataFrame) -> List[GenericSample]:
    """Turn cleaned LGD rows into ``GenericSample`` objects.

    The prompt is the text before the ``[VERB]`` (or bare ``VERB``) marker,
    right-stripped; without a marker the whole text is used. Both verb forms
    get a single leading space. ``ze`` is a 3-class label: 0 when ``Ze_norm``
    is missing, otherwise ``1 + Ze_norm`` (1 = sg, 2 = pl).
    """

    def prompt_before_verb(text: str) -> str:
        # "[VERB]" is tried first so the bracketed marker wins over bare "VERB".
        for marker in ("[VERB]", "VERB"):
            if marker in text:
                return text.partition(marker)[0].rstrip()
        return text.rstrip()

    samples: List[GenericSample] = []
    for _, row in df.iterrows():
        zc = int(row["Zc_norm"])
        singular = " " + str(row["verb_sg"]).strip()
        plural = " " + str(row["verb_pl"]).strip()
        gold, alt = (singular, plural) if zc == 0 else (plural, singular)

        ze_raw = row["Ze_norm"]
        ze3 = 0 if pd.isna(ze_raw) else 1 + int(ze_raw)

        sample = GenericSample(
            prompt_text=prompt_before_verb(str(row["text"])),
            gold_cont=gold,
            alt_cont=alt,
            zc=zc,
            cf_zc=1 - zc,
            ze=int(ze3),
            task="lgd_sva",
        )
        samples.append(sample)
    return samples


def split_three_way(n: int, interv_frac: float = 0.4, valprobe_frac: float = 0.4, seed: int = 123) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Shuffle ``range(n)`` and cut it into (intervention, validation-probe, test) index arrays.

    The first two parts get ``int(n * frac)`` indices each; the test part
    receives whatever is left.

    Raises:
        ValueError: if the two fractions sum to 1.0 or more.
    """
    if interv_frac + valprobe_frac >= 1.0:
        raise ValueError("interv_frac + valprobe_frac must be < 1.0")
    order = np.arange(n)
    np.random.default_rng(seed).shuffle(order)
    first_cut = int(n * interv_frac)
    second_cut = first_cut + int(n * valprobe_frac)
    return order[:first_cut], order[first_cut:second_cut], order[second_cut:]


def make_independent_subset(indices: np.ndarray, samples: List[GenericSample], seed: int = 123) -> np.ndarray:
    """Subsample ``indices`` so each Ze group mirrors the overall Zc=0 / Zc=1 proportion.

    For every non-negative Ze class, the group's target counts are
    ``round(n_group * P(Zc=0))`` for Zc=0 and the remainder for Zc=1; each
    target is capped by what the group actually contains and drawn without
    replacement. Only Zc values 0 and 1 are ever selected. When no sample has
    a Ze label, ``indices`` is returned unchanged.
    """
    rng = np.random.default_rng(seed)
    zc = np.array([samples[i].zc for i in indices], dtype=np.int64)
    # Missing Ze labels are encoded as -1 and ignored below.
    ze = np.array(
        [-1 if samples[i].ze is None else samples[i].ze for i in indices],
        dtype=np.int64,
    )
    ze_classes = [z for z in np.unique(ze) if z >= 0]
    if len(ze_classes) == 0:
        return indices

    p_zc0 = (zc == 0).mean()
    chosen: List[int] = []
    for ze_val in ze_classes:
        in_group = ze == ze_val
        grp_idx = indices[in_group]
        n_grp = len(grp_idx)
        if n_grp == 0:
            continue
        zc_grp = zc[in_group]
        want0 = int(round(n_grp * p_zc0))
        # Draw Zc=0 first, then Zc=1, so the RNG stream order is fixed.
        for cls, quota in ((0, want0), (1, n_grp - want0)):
            pool = grp_idx[zc_grp == cls]
            take = min(quota, len(pool))
            if take > 0:
                chosen.extend(rng.choice(pool, size=take, replace=False).tolist())
    return np.array(chosen, dtype=np.int64)


# ----------------------------
# CausalGym loader (HuggingFace or local JSON)
# ----------------------------

def _safe_join(spans):
    """Concatenate a list of string spans; stringify anything that is not a list."""
    return "".join(spans) if isinstance(spans, list) else str(spans)

# Preposition-family vocabulary behind the CausalGym Ze label:
#   0=NONE, 1=OF, 2=IN, 3=WITH_OR_BY, 4=OTHER
PREP_OF = {"of"}
PREP_IN = {"in", "inside", "within"}
PREP_WITH_BY = {"with", "without", "by"}
# Remaining common prepositions. A few entries also belong to the IN and
# WITH_OR_BY families above; the overlap is part of this set's contents.
PREP_OTHER = {
    "about", "across", "after", "against", "along", "amid", "among",
    "amongst", "around", "at", "before", "behind", "beneath", "beside",
    "besides", "between", "beyond", "during", "except", "for", "from",
    "inside", "into", "like", "near", "off", "on", "onto", "out", "outside",
    "over", "past", "per", "since", "through", "till", "to", "toward",
    "towards", "under", "unlike", "until", "upon", "via", "within", "without",
}
PREP_ANY = PREP_OF | PREP_IN | PREP_WITH_BY | PREP_OTHER

def _tokenize_simple(text: str) -> List[str]:
    """Lower-case ``text`` and split it into word runs and single punctuation characters."""
    lowered = str(text).lower()
    token_pattern = r"\w+|[^\w\s]"
    return re.findall(token_pattern, lowered)

def ze_prep_family(prompt: str, window: int = 12) -> int:
    """
    Label a prompt by the family of its most recent preposition:
      0=NONE, 1=OF, 2=IN, 3=WITH_OR_BY, 4=OTHER.
    Only the last 'window' tokens are scanned, newest first; the first
    preposition found decides the label. Families are tested in the order
    OF, IN, WITH_OR_BY, then any other known preposition.
    """
    families = (
        (PREP_OF, 1),
        (PREP_IN, 2),
        (PREP_WITH_BY, 3),
        (PREP_ANY, 4),
    )
    # Slicing clamps, so prompts shorter than the window are scanned in full.
    recent = _tokenize_simple(prompt)[-window:]
    for tok in reversed(recent):
        for members, label in families:
            if tok in members:
                return label
    return 0  # NONE


def _load_causalgym_hf() -> Optional[dict]:
    """Best-effort load of ``aryaman/causalgym`` through HuggingFace ``datasets``.

    Returns a dict with the keys "train", "dev" and "test" (a value is None
    when that split is absent), or None when ``datasets`` cannot be imported
    or loading fails for any reason.
    """
    try:
        from datasets import load_dataset  # type: ignore

        ds = load_dataset("aryaman/causalgym")
        return {name: ds.get(name) for name in ("train", "dev", "test")}
    except Exception:
        # Deliberately broad: the caller falls back to another source on None.
        return None


def _load_causalgym_local_json(dir_path: str) -> dict:
    """Load ``train.json``, ``dev.json`` and ``test.json`` from ``dir_path``.

    Files are decoded as UTF-8 explicitly, so the result does not depend on
    the platform's locale encoding.

    Returns:
        Dict mapping "train"/"dev"/"test" to the parsed JSON content.

    Raises:
        FileNotFoundError: if any of the three files is missing.
    """
    splits = {}
    for split in ("train", "dev", "test"):
        fp = os.path.join(dir_path, f"{split}.json")
        if not os.path.isfile(fp):
            raise FileNotFoundError(f"Missing {fp}. Provide a directory with train.json/dev.json/test.json.")
        with open(fp, "r", encoding="utf-8") as f:
            splits[split] = json.load(f)
    return splits


def _filter_tasks(records, wanted_tasks: Optional[List[str]]) -> list:
    """Keep the records whose "task" field is in ``wanted_tasks``.

    With no filter (None or empty), every record is returned as a new list.
    Records without a "task" key are treated as having task "".
    """
    if not wanted_tasks:
        return list(records)
    allowed = set(wanted_tasks)
    return [rec for rec in records if rec.get("task", "") in allowed]

def _collect_type_map(records: list) -> Dict[str, int]:
    """Assign a stable integer id to every type name seen in ``records``.

    Type names come from each record's "base_type" and "src_type" (stringified);
    ids follow the sorted order of the distinct names.
    """
    seen = {str(rec[key]) for rec in records for key in ("base_type", "src_type")}
    return {name: position for position, name in enumerate(sorted(seen))}


def make_causalgym_samples(
    splits: dict,
    tasks: Optional[List[str]],
    dev_frac: float = 0.50,
) -> Tuple[List[GenericSample], List[GenericSample], List[GenericSample]]:
    """Convert CausalGym minimal-pair records into train/dev/test ``GenericSample`` lists.

    Each record yields two samples: the base prompt (gold = base label,
    alt = src label) and the src prompt with the roles swapped. ``ze`` is
    always recomputed from the prompt text with ``ze_prep_family``.

    Args:
        splits: mapping-like object with a ``get`` method. Split names are
            looked up under common aliases (train/training,
            dev/validation/val, test/testing).
        tasks: task names to keep; None or empty keeps every task.
        dev_frac: fraction of train moved into dev when no dev split exists
            (at least one record is always moved). The default of 0.50 is
            the value this function has always used.

    Returns:
        ``(train_samples, dev_samples, test_samples)``.

    Raises:
        ValueError: if ``dev_frac`` is not strictly between 0 and 1.
    """
    if not 0.0 < dev_frac < 1.0:
        raise ValueError(f"dev_frac must be in (0, 1), got {dev_frac}")

    # Objects without a .get (unexpected input) behave like an empty mapping.
    lookup = getattr(splits, "get", dict().get)

    def first_split(*names) -> list:
        # Return the first non-empty split among the aliases as a plain list.
        for name in names:
            found = lookup(name)
            if found:
                return list(found)
        return []

    train_raw = first_split("train", "training")
    dev_raw = first_split("dev", "validation", "val")
    test_raw = first_split("test", "testing")

    # Fallback: with no dev/validation split, move `dev_frac` of train into dev.
    if len(dev_raw) == 0 and len(train_raw) > 0:
        rng = random.Random(0)
        idx = list(range(len(train_raw)))
        rng.shuffle(idx)
        k = max(1, int(dev_frac * len(idx)))
        # NOTE: dev records are emitted in set-iteration order (not shuffle
        # order); kept as-is so the carved-out dev ordering matches earlier runs.
        val_idx = set(idx[:k])
        new_dev = [train_raw[i] for i in val_idx]
        new_train = [train_raw[i] for i in idx[k:]]
        dev_raw, train_raw = new_dev, new_train
        print(f"[make_causalgym_samples] No dev/validation split found; using {len(dev_raw)} examples ({len(dev_raw)/(len(dev_raw)+len(train_raw)):.1%}) from train as validation.")

    if tasks:
        train_raw = _filter_tasks(train_raw, tasks)
        dev_raw = _filter_tasks(dev_raw, tasks)
        test_raw = _filter_tasks(test_raw, tasks)

    # Type ids are defined by the (filtered) train split only. Records in any
    # split that mention a type absent from train are skipped below.
    tmap = _collect_type_map(train_raw)

    def samples_from_records(recs, split_name: str) -> List[GenericSample]:
        out: List[GenericSample] = []
        skipped_unknown_type = 0
        for r in recs:
            base_t_str = str(r["base_type"])
            src_t_str = str(r["src_type"])
            # Skip types unseen in train to avoid out-of-range class labels.
            if base_t_str not in tmap or src_t_str not in tmap:
                skipped_unknown_type += 1
                continue
            base_text = _safe_join(r["base"])
            src_text = _safe_join(r["src"])
            base_label = str(r["base_label"])
            src_label = str(r["src_label"])
            base_type = tmap[base_t_str]
            src_type = tmap[src_t_str]
            task_name = r.get("task", "unknown")

            # Ze comes from the preposition-family heuristic on each prompt.
            ze_base = ze_prep_family(base_text, window=12)
            ze_src = ze_prep_family(src_text, window=12)

            # base sample
            out.append(GenericSample(
                prompt_text=base_text,
                gold_cont=base_label,
                alt_cont=src_label,
                zc=base_type,
                cf_zc=src_type,
                ze=int(ze_base),
                task=task_name,
            ))
            # src sample (roles swapped)
            out.append(GenericSample(
                prompt_text=src_text,
                gold_cont=src_label,
                alt_cont=base_label,
                zc=src_type,
                cf_zc=base_type,
                ze=int(ze_src),
                task=task_name,
            ))
        if skipped_unknown_type > 0:
            print(f"[make_causalgym_samples] {split_name}: skipped {skipped_unknown_type}/{len(recs)} records with types unseen in train.")
        return out

    train_samples = samples_from_records(train_raw, "train")
    dev_samples = samples_from_records(dev_raw, "dev")
    test_samples = samples_from_records(test_raw, "test")
    return train_samples, dev_samples, test_samples


# ----------------------------
# Probes
# ----------------------------

class Probe(nn.Module):
    """Classification probe over hidden vectors.

    With a positive ``hidden`` size the probe is LayerNorm -> Linear -> ReLU ->
    Dropout -> Linear; otherwise it is a single Linear layer. Inputs with more
    than two dimensions are mean-pooled over every middle dimension first.
    """

    def __init__(self, in_dim: int, n_classes: int, hidden: Optional[int] = None, dropout: float = 0.0):
        super().__init__()
        use_mlp = bool(hidden) and hidden > 0
        if use_mlp:
            layers = [
                nn.LayerNorm(in_dim),
                nn.Linear(in_dim, hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden, n_classes),
            ]
        else:
            layers = [nn.Linear(in_dim, n_classes)]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n_dims = x.dim()
        if n_dims > 2:
            # Collapse everything between the batch and feature axes.
            x = x.mean(dim=tuple(range(1, n_dims - 1)))
        return self.net(x.contiguous())


def _to_float_tensor(x, device):
    """Move ``x`` (NumPy array or tensor) to ``device`` as a floating-point tensor.

    Tensors that are already floating point keep their dtype; anything else
    is converted with ``.float()``.
    """
    tensor = torch.from_numpy(x) if isinstance(x, np.ndarray) else x
    tensor = tensor.to(device)
    return tensor if tensor.is_floating_point() else tensor.float()


@torch.no_grad()
def eval_probe_acc(model: nn.Module, X: np.ndarray, y: np.ndarray, device: torch.device) -> float:
    """Accuracy of ``model`` on features ``X`` against labels ``y``; NaN when ``y`` is empty."""
    if len(y) == 0:
        return float("nan")
    predictions = model(_to_float_tensor(X, device)).argmax(dim=-1)
    hits = predictions.cpu().numpy() == y
    return hits.mean()


def train_probe(
    H,
    y,
    epochs: int = 100,
    lr: float = 1e-2,
    wb: float = 1e-6,
    device: torch.device = torch.device("cuda"),
    hidden: Optional[int] = None,
    batch_size: int = 256,
    verbose: bool = False,
) -> Probe:
    """Fit a ``Probe`` on features ``H`` and integer labels ``y`` with AdamW and cross-entropy.

    Args:
        H: features (NumPy array or tensor); inputs with more than two
            dimensions are mean-pooled over the middle dimensions.
        y: non-negative class ids; the class count is ``max(y) + 1``.
        epochs: number of passes over the data.
        lr: AdamW learning rate.
        wb: AdamW weight decay.
        device: device used for the data and the probe.
        hidden: hidden width forwarded to ``Probe`` (None/0 gives a linear probe).
        batch_size: mini-batch size.
        verbose: print progress roughly every tenth of the run.

    Raises:
        ValueError: if any label is negative.
    """
    features = _to_float_tensor(H, device)
    labels = torch.as_tensor(y, dtype=torch.long, device=device)

    if features.dim() > 2:
        features = features.mean(dim=tuple(range(1, features.dim() - 1)))

    n_features = int(features.shape[-1])
    # Labels must be 0..C-1; gaps are allowed because C is max(y)+1.
    if torch.any(labels < 0):
        raise ValueError("Negative class ids passed to train_probe. Ensure to mask/remap labels >= 0.")
    n_classes = int(labels.max().item()) + 1

    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(features, labels),
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
    )

    model = Probe(n_features, n_classes, hidden=hidden).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wb)
    criterion = nn.CrossEntropyLoss()

    log_every = max(1, epochs // 10)
    for epoch_no in range(1, epochs + 1):
        model.train()
        for batch_x, batch_y in loader:
            optimizer.zero_grad(set_to_none=True)
            criterion(model(batch_x), batch_y).backward()
            optimizer.step()
        if verbose and (epoch_no % log_every == 0 or epoch_no == epochs):
            print(f"[probe] epoch {epoch_no}/{epochs}")

    return model


def train_validation_probe_with_selection(
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_val: np.ndarray,
    y_val: np.ndarray,
    hidden_grid=(0, 64, 256, 512),
    device: torch.device = torch.device("cuda"),
    epochs: int = 75,
    lr: float = 1e-2,
    wd: float = 1e-6,
    batch_size: int = 256,
) -> Probe:
    """Train one probe per hidden size and return the one with the best validation accuracy.

    Ties keep the earliest candidate. When validation accuracy is undefined
    (``eval_probe_acc`` returns NaN for an empty validation set), the first
    trained probe is returned instead of ``None``.

    Raises:
        ValueError: if ``hidden_grid`` yields no candidates.
    """
    best_acc = -1.0
    best_model = None
    for hidden in hidden_grid:
        m = train_probe(X_train, y_train, epochs=epochs, lr=lr, wb=wd, device=device, hidden=hidden, batch_size=batch_size, verbose=False)
        acc = eval_probe_acc(m, X_val, y_val, device=device)
        # `nan > best_acc` is always False, so without the `is None` fallback
        # an empty validation set would leave best_model unset.
        if best_model is None or acc > best_acc:
            best_model = m
            if not math.isnan(acc):
                best_acc = acc
    if best_model is None:
        raise ValueError("hidden_grid must contain at least one hidden size")
    return best_model


# ----------------------------
# LM helpers
# ----------------------------

@torch.no_grad()
def encode(tokenizer: AutoTokenizer, text: str, device: torch.device) -> torch.Tensor:
    """Tokenize ``text`` without special tokens and return its ``input_ids`` on ``device``."""
    encoded = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    input_ids = encoded["input_ids"]
    return input_ids.to(device)


def get_prompt_hidden_at_layer(
    model: AutoModelForCausalLM,
    input_ids: torch.Tensor,
    layer_idx: int = -1,
) -> torch.Tensor:
    """Return the last-position hidden vector of ``input_ids`` at one entry of ``hidden_states``.

    ``layer_idx`` indexes the model's ``hidden_states`` tuple; negative values
    count from the end (-1 is the final entry). If the output exposes no
    ``hidden_states``, ``last_hidden_state`` is used instead.

    Raises:
        ValueError: if ``layer_idx`` falls outside the ``hidden_states`` tuple.
    """
    outputs = model(input_ids=input_ids, output_hidden_states=True, return_dict=True, use_cache=False)
    h_states = getattr(outputs, "hidden_states", None)
    if h_states is None:
        # fallback: last hidden state only
        return outputs.last_hidden_state[:, -1, :]

    L = len(h_states)
    idx = layer_idx + L if layer_idx < 0 else layer_idx
    if not 0 <= idx < L:
        raise ValueError(f"Invalid layer_idx {layer_idx} for hidden_states with length {L}")
    # Indexing assumes each entry is [B, T, D]; take the last token of the prompt.
    return h_states[idx][:, -1, :]


@torch.no_grad()
def sequence_logprob(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt_text: str,
    candidate: str,
    device: torch.device,
) -> Dict[str, torch.Tensor]:
    """
    Score a candidate continuation autoregressively given the prompt.

    A single space is inserted between prompt and candidate only when neither
    side already supplies whitespace at the junction. The candidate tokens are
    taken to be whatever the tokenization of the joined text adds beyond the
    prompt's token count.

    Returns a dict with:
      "total_lp":       summed log-prob of the candidate tokens (scalar tensor)
      "first_token_id": id of the first candidate token (scalar tensor)
      "first_lp":       log-prob of that first token (scalar tensor)
      "prefix_len":     number of prompt tokens (int)
      "full_len":       number of tokens in prompt + candidate (int)

    When the prompt is empty or the candidate adds no tokens, the log-probs
    are the sentinel -1e9 and "first_token_id" is -1.
    """
    prefix_ids = encode(tokenizer, prompt_text, device)

    needs_space = (
        len(candidate) > 0
        and not candidate[0].isspace()
        and len(prompt_text) > 0
        and not prompt_text[-1].isspace()
    )
    separator = " " if needs_space else ""
    full_ids = encode(tokenizer, prompt_text + separator + candidate, device)

    n_prefix = prefix_ids.shape[-1]
    n_full = full_ids.shape[-1]

    # Nothing to score: no prompt context or no added tokens.
    if n_full <= n_prefix or n_prefix == 0:
        return {
            "total_lp": torch.tensor(-1e9, device=device),
            "first_token_id": torch.tensor(-1, device=device),
            "first_lp": torch.tensor(-1e9, device=device),
            "prefix_len": int(n_prefix),
            "full_len": int(n_full),
        }

    logits = model(input_ids=full_ids, use_cache=False, return_dict=True).logits  # [1, n_full, V]
    logp = F.log_softmax(logits, dim=-1)

    # The token at position t is predicted by the logits at position t-1.
    cand_ids = full_ids[0, n_prefix:]                                     # [cand_len]
    pred_positions = torch.arange(n_prefix - 1, n_full - 1, device=device)  # [cand_len]
    token_logps = logp[0, pred_positions, cand_ids]                       # [cand_len]

    return {
        "total_lp": token_logps.sum(),
        "first_token_id": cand_ids[0],
        "first_lp": token_logps[0],
        "prefix_len": int(n_prefix),
        "full_len": int(n_full),
    }


def lm_head_from_model(model: AutoModelForCausalLM) -> nn.Module:
    """Locate the model's output projection (LM head).

    Tries ``get_output_embeddings()`` first and the ``lm_head`` attribute second.

    Raises:
        RuntimeError: if neither yields a module.
    """
    head = model.get_output_embeddings()
    if head is None:
        head = getattr(model, "lm_head", None)
    if head is None:
        raise RuntimeError("Could not find LM head on the model.")
    return head


def compute_first_step_logprob_from_h(
    model: AutoModelForCausalLM,
    h_vec: torch.Tensor,          # [1, D]
    token_id: torch.Tensor,       # scalar tensor
) -> torch.Tensor:
    """Log-probability of ``token_id`` when the LM head is applied directly to ``h_vec``.

    ``h_vec`` is cast to the head's parameter dtype before the projection.
    """
    head = lm_head_from_model(model)
    param_dtype = next(head.parameters()).dtype
    vocab_logits = head(h_vec.to(param_dtype))  # [1, V]
    return F.log_softmax(vocab_logits, dim=-1)[0, token_id]


# ----------------------------
# Interventions: attacks and methods
# ----------------------------

def fgsm_linf(h: torch.Tensor, probe: nn.Module, target_label: int, eps: float) -> torch.Tensor:
    """Single-step targeted FGSM under an L-infinity budget.

    Moves ``h`` by ``eps`` along the negative sign of the gradient of
    cross-entropy towards ``target_label``, raising the probe's probability
    of that label.

    Gradient tracking is enabled locally, so the function also works when the
    caller is inside ``torch.no_grad()``. The result is detached.
    """
    target = torch.tensor([target_label], dtype=torch.long, device=h.device)
    with torch.enable_grad():
        point = h.detach().clone().requires_grad_(True)
        loss = F.cross_entropy(probe(point), target)
        (grad,) = torch.autograd.grad(loss, point, retain_graph=False, create_graph=False)
    # targeted: descend CE(target) to increase p(target)
    return (point - eps * grad.sign()).detach()


def pgd_linf(h: torch.Tensor, probe: nn.Module, target_label: int, eps: float, steps: int = 40, step_size: Optional[float] = None) -> torch.Tensor:
    """Targeted projected gradient descent inside an L-infinity ball of radius ``eps``.

    Starts from a uniformly random point in the ball, takes signed-gradient
    steps that lower cross-entropy towards ``target_label``, and clamps the
    perturbation back into ``[-eps, eps]`` after every step. At least one
    step is always taken.

    Args:
        step_size: per-step size; defaults to ``2.5 * eps / steps``.

    Gradient tracking is enabled locally, so the function also works when the
    caller is inside ``torch.no_grad()``. The result is detached.
    """
    n_steps = max(1, steps)
    if step_size is None:
        step_size = 2.5 * eps / n_steps
    base = h.detach()
    target = torch.tensor([target_label], dtype=torch.long, device=h.device)
    adv = (base + torch.empty_like(base).uniform_(-eps, eps)).detach()
    for _ in range(n_steps):
        with torch.enable_grad():
            point = adv.clone().requires_grad_(True)
            loss = F.cross_entropy(probe(point), target)
            (grad,) = torch.autograd.grad(loss, point, retain_graph=False, create_graph=False)
        with torch.no_grad():
            stepped = point - step_size * grad.sign()  # targeted step
            delta = (stepped - base).clamp(-eps, eps)
            adv = (base + delta).detach()
    return adv


def fgsm_l2(h: torch.Tensor, probe: nn.Module, target_label: int, eps: float) -> torch.Tensor:
    """Single-step targeted gradient attack under an L2 budget.

    Moves ``h`` by ``eps`` along the negative unit-normalised gradient of
    cross-entropy towards ``target_label``.

    Gradient tracking is enabled locally, so the function also works when the
    caller is inside ``torch.no_grad()``. The result is detached.
    """
    target = torch.tensor([target_label], dtype=torch.long, device=h.device)
    with torch.enable_grad():
        point = h.detach().clone().requires_grad_(True)
        loss = F.cross_entropy(probe(point), target)
        (grad,) = torch.autograd.grad(loss, point, retain_graph=False, create_graph=False)
    # The small constant guards against division by a zero gradient norm.
    unit_grad = grad / (grad.norm(dim=-1, keepdim=True) + 1e-12)
    return (point - eps * unit_grad).detach()  # targeted


def pgd_l2(h: torch.Tensor, probe: nn.Module, target_label: int, eps: float, steps: int = 40, step_size: Optional[float] = None) -> torch.Tensor:
    """Targeted projected gradient descent inside an L2 ball of radius ``eps``.

    Starts at ``h``, takes unit-normalised gradient steps that lower
    cross-entropy towards ``target_label``, and rescales the perturbation back
    onto the ball whenever its norm exceeds ``eps``. At least one step is
    always taken.

    Args:
        step_size: per-step length; defaults to ``2.5 * eps / steps``, the
            same rule ``pgd_linf`` uses. The earlier default of
            ``0.25 * eps / steps`` limited total movement to a quarter of the
            budget, so the projection could never trigger.

    Gradient tracking is enabled locally, so the function also works when the
    caller is inside ``torch.no_grad()``. The result is detached.
    """
    n_steps = max(1, steps)
    if step_size is None:
        step_size = 2.5 * eps / n_steps
    base = h.detach()
    target = torch.tensor([target_label], dtype=torch.long, device=h.device)
    adv = base.clone().detach()
    for _ in range(n_steps):
        with torch.enable_grad():
            point = adv.clone().requires_grad_(True)
            loss = F.cross_entropy(probe(point), target)
            (grad,) = torch.autograd.grad(loss, point, retain_graph=False, create_graph=False)
        with torch.no_grad():
            unit_grad = grad / (grad.norm(dim=-1, keepdim=True) + 1e-12)
            delta = (point - step_size * unit_grad) - base  # targeted
            norm = delta.norm(dim=-1, keepdim=True).clamp(min=1e-12)
            # Shrink only when outside the ball: factor = min(1, eps / ||delta||).
            factor = (eps / norm).clamp(max=1.0)
            adv = (base + delta * factor).detach()
    return adv


def gbi_intervention(
    h: torch.Tensor,
    probe_c: nn.Module,
    target_label: int,
    epsilon: float = 0.112,
    pgd_steps: int = 40,
    attack: str = "pgd",      # "fgsm" | "pgd"
    norm: str = "linf",       # "linf" | "l2"
    step_size: Optional[float] = None,
) -> torch.Tensor:
    """Gradient-based intervention: run the selected targeted attack on ``h`` against ``probe_c``.

    ``norm == "linf"`` selects the L-infinity variants and any other value
    the L2 variants; ``attack == "fgsm"`` selects the single-step attack and
    any other value PGD. ``pgd_steps`` and ``step_size`` apply to PGD only.
    """
    use_linf = norm == "linf"
    if attack == "fgsm":
        single_step = fgsm_linf if use_linf else fgsm_l2
        return single_step(h, probe_c, target_label, eps=epsilon)
    iterative = pgd_linf if use_linf else pgd_l2
    return iterative(h, probe_c, target_label, eps=epsilon, steps=pgd_steps, step_size=step_size)


def nullspace_erasure(h: torch.Tensor, probe_c: Probe, strength: float = 1.0) -> torch.Tensor:
    """Remove from ``h`` its component along the probe's class-separating direction.

    The direction comes from the last ``nn.Linear`` in ``probe_c.net``: row 1
    minus row 0 of its weight (or row 0 alone for a one-row weight),
    normalised to unit length. ``strength`` scales how much of the projection
    is subtracted. If the probe has no linear layer, ``h`` is returned as-is.

    The matrix product requires ``h``'s last dimension to equal that layer's
    input width.
    """
    last_linear = next(
        (layer for layer in reversed(list(probe_c.net)) if isinstance(layer, nn.Linear)),
        None,
    )
    if last_linear is None:
        return h

    W = last_linear.weight  # [C, D]
    direction = W[0] if W.shape[0] < 2 else W[1] - W[0]
    direction = direction / (direction.norm() + 1e-12)

    coeff = (h @ direction.unsqueeze(-1)).squeeze(-1)
    erased = h - strength * coeff.unsqueeze(-1) * direction.unsqueeze(0)
    return erased.detach()


def hdmi_intervention(
    model: nn.Module,
    h: torch.Tensor,
    target_token_id: int,
    source_token_id: Optional[int] = None,
    alpha: float = 10.0,
    inner_steps: int = 3,
    use_margin: bool = False,
    normalize_grad: bool = False,
    grad_clip_norm: float = 0.0,
) -> torch.Tensor:
    """Gradient-ascent intervention on ``h`` through the LM head.

    Each of the ``inner_steps`` iterations (at least one) adds ``alpha`` times
    the gradient of an objective with respect to the current vector. The
    objective is the target token's logit, or, with ``use_margin`` and a valid
    ``source_token_id``, the target logit minus the source logit. The gradient
    can be norm-clipped (``grad_clip_norm > 0``) and then unit-normalised
    (``normalize_grad``), in that order.

    Computation runs in the head's parameter dtype; the result is cast back to
    ``h``'s dtype. ``h`` is returned unchanged when ``target_token_id`` is
    None or negative.
    """
    head = lm_head_from_model(model)
    head_dtype = next(head.parameters()).dtype
    if target_token_id is None or target_token_id < 0:
        return h

    tgt = int(target_token_id)
    margin_on = use_margin and (source_token_id is not None) and (source_token_id >= 0)
    tiny = 1e-9
    current = h.detach().to(dtype=head_dtype)

    for _ in range(max(1, int(inner_steps))):
        # Gradients are needed even when the caller runs under no_grad.
        with torch.enable_grad():
            point = current.clone().detach().requires_grad_(True)
            logits = head(point)
            if margin_on:
                objective = logits[0, tgt] - logits[0, int(source_token_id)]
            else:
                objective = logits[0, tgt]
            (grad_h,) = torch.autograd.grad(objective, point, retain_graph=False, create_graph=False)

        if grad_clip_norm and grad_clip_norm > 0.0:
            grad_norm = grad_h.norm(p=2, dim=-1, keepdim=True)
            clipped = grad_h * (grad_clip_norm / (grad_norm + tiny))
            grad_h = torch.where(grad_norm > grad_clip_norm, clipped, grad_h)
        if normalize_grad:
            grad_norm = grad_h.norm(p=2, dim=-1, keepdim=True)
            grad_h = grad_h / (grad_norm + tiny)

        with torch.no_grad():
            current = (point + alpha * grad_h).detach()

    return current.to(dtype=h.dtype)


# ----------------------------
# INLP and AlterRep
# ----------------------------

class INLPProjector:
    """Iterative nullspace projection built from successive linear classifiers.

    ``fit`` repeatedly trains a linear classifier on the residual features,
    records one unit direction per round, and projects that direction out of
    the residual. The recorded directions in ``vs`` are then used to erase
    (``project_vector``) or erase-and-push (``rowspace_push_alterrep``).
    """

    def __init__(self, rank: int = 8, lr: float = 1e-2, wd: float = 1e-6, epochs: int = 50, batch_size: int = 256, device: torch.device = torch.device("cpu")):
        self.rank = rank              # number of directions to extract
        self.lr = lr                  # AdamW learning rate
        self.wd = wd                  # AdamW weight decay
        self.epochs = epochs          # epochs per linear classifier
        self.batch_size = batch_size
        self.device = device
        self.vs: List[torch.Tensor] = []    # unit directions, in extraction order
        self.dim: Optional[int] = None      # feature width seen by fit()

    def _train_linear(self, X: torch.Tensor, y: torch.Tensor) -> nn.Linear:
        """Train a fresh linear classifier with ``max(y) + 1`` outputs on ``(X, y)``."""
        n_outputs = int(y.max().item()) + 1
        clf = nn.Linear(X.shape[-1], n_outputs, bias=True).to(self.device)
        optimizer = torch.optim.AdamW(clf.parameters(), lr=self.lr, weight_decay=self.wd)
        criterion = nn.CrossEntropyLoss()
        loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(X, y),
            batch_size=self.batch_size,
            shuffle=True,
        )
        for _ in range(self.epochs):
            clf.train()
            for batch_x, batch_y in loader:
                optimizer.zero_grad(set_to_none=True)
                criterion(clf(batch_x), batch_y).backward()
                optimizer.step()
        return clf

    def fit(self, X: np.ndarray, y: np.ndarray):
        """Extract ``rank`` directions from features ``X`` with labels ``y``."""
        features = torch.from_numpy(X).float().to(self.device)
        labels = torch.from_numpy(y).long().to(self.device)
        self.dim = features.shape[-1]
        self.vs = []
        residual = features.clone()
        for _ in range(self.rank):
            W = self._train_linear(residual, labels).weight.detach()  # [C, D]
            n_cls = W.shape[0]
            if n_cls == 1:
                direction = W[0]
            else:
                # Among all class pairs, keep the weight difference with the largest norm.
                gaps = [W[j] - W[i] for i in range(n_cls) for j in range(i + 1, n_cls)]
                direction = max(gaps, key=lambda gap: gap.norm().item())
            direction = direction / (direction.norm() + 1e-12)
            self.vs.append(direction.detach().clone())
            # Remove the new direction before training the next classifier.
            coeff = residual @ direction.unsqueeze(-1)
            residual = residual - coeff * direction.unsqueeze(0)

    @torch.no_grad()
    def project_vector(self, h: torch.Tensor, upto_rank: Optional[int] = None) -> torch.Tensor:
        """Return ``h`` with the first ``upto_rank`` stored directions projected out (all by default)."""
        n_dirs = len(self.vs) if upto_rank is None else min(upto_rank, len(self.vs))
        out = h.clone()
        for direction in self.vs[:max(n_dirs, 0)]:
            v = direction.to(out.device).to(out.dtype)
            coeff = out @ v.unsqueeze(-1)
            out = out - coeff * v.unsqueeze(0)
        return out.detach()

    @torch.no_grad()
    def rowspace_push_alterrep(self, h: torch.Tensor, target_label: int, alpha: float = 0.1, upto_rank: Optional[int] = None) -> torch.Tensor:
        """Erase the stored directions from ``h``, then push by ``alpha`` against the original projections.

        For each direction the push is opposite to the sign of ``h``'s original
        projection onto it (a zero projection counts as positive).
        ``target_label`` is accepted but not consulted. ``h`` must hold a
        single vector, because each projection is read with ``.item()``.
        """
        if upto_rank is None:
            upto_rank = len(self.vs)
        erased = self.project_vector(h, upto_rank=upto_rank)
        push = torch.zeros_like(erased)
        for i in range(min(upto_rank, len(self.vs))):
            v = self.vs[i].to(h.device).to(h.dtype)
            m = float((h @ v.unsqueeze(-1)).item())
            sign = math.copysign(1.0, m) if m != 0.0 else 1.0
            push = push - sign * v.unsqueeze(0)
        return (erased + alpha * push).detach()


# ----------------------------
# Feature building
# ----------------------------

@torch.no_grad()
def build_features(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    samples: List[GenericSample],
    device: torch.device,
    layer_idx: int = -1,
) -> np.ndarray:
    """Stack the last-token hidden vector of every sample's prompt into a float32 matrix.

    Returns an array with one row per sample; for an empty ``samples`` list
    the result has shape ``(0, model.config.hidden_size)``.
    """

    def prompt_vector(sample) -> np.ndarray:
        ids = encode(tokenizer, sample.prompt_text, device)
        hidden = get_prompt_hidden_at_layer(model, ids, layer_idx=layer_idx)  # [1, D]
        return hidden[0].float().cpu().numpy()

    rows = [prompt_vector(s) for s in samples]
    if len(rows) == 0:
        return np.zeros((0, model.config.hidden_size), dtype=np.float32)
    return np.stack(rows, axis=0)


def choose_candidate_by_logprob(lp_gold: float, lp_alt: float) -> int:
    """Return 0 when the gold continuation scores at least as high as the alternative, else 1."""
    gold_wins = lp_gold >= lp_alt  # ties go to gold
    return int(not gold_wins)


# ----------------------------
# Evaluation
# ----------------------------

@dataclass
class Metrics:
    """Summary numbers produced by one ``evaluate`` run.

    NOTE(review): the field meanings below follow the file header's metric
    list; the code that fills them is in ``evaluate`` — confirm there.
    """
    baseline_task_acc: float   # task accuracy before the intervention
    after_task_acc: float      # task accuracy after the intervention
    delta_task_acc: float      # change in task accuracy (presumably after - baseline; confirm)
    completeness: float        # completeness score (header: TV-distance based)
    selectivity: float         # selectivity score (header: only meaningful when Ze is available)
    reliability: float         # reliability score (header: harmonic mean)
    n: int                     # presumably the number of evaluated samples; confirm


def _eval_completeness(p_c_after: np.ndarray, intervention: str, cf_zc) -> float:
    """Score how close the post-intervention Zc distribution is to the intervention's goal.

    * ``gbi`` / ``alterrep`` / ``hdmi``: ``1 - TV(p, one_hot(cf_zc))``.
    * ``inlp`` / ``null``: ``1 - TV(p, uniform) / (1 - 1/k)`` for ``k`` classes.
    * any other intervention: ``0.0``.
    """
    k = p_c_after.shape[0]
    if intervention in {"gbi", "alterrep", "hdmi"}:
        goal = np.zeros(k, dtype=np.float32)
        # Clip so an out-of-range counterfactual label cannot index past the probe's classes.
        goal[int(np.clip(int(cf_zc), 0, k - 1))] = 1.0
        return 1.0 - tv_distance(p_c_after, goal)
    if intervention in {"inlp", "null"}:
        normaliser = 1.0 - 1.0 / k
        if normaliser <= 0:
            # Single-class probe: the normaliser vanishes. Use the same 0.0
            # fallback the selectivity score uses for a degenerate normaliser.
            return 0.0
        uniform = np.ones(k) / k
        return 1.0 - (tv_distance(p_c_after, uniform) / normaliser)
    return 0.0


def _eval_selectivity(p_ze_before: np.ndarray, p_ze_after: np.ndarray) -> float:
    """Score how little the Ze distribution moved: ``1 - TV(after, before) / m``.

    ``m = max(1 - min(before), max(before))``; a non-positive ``m`` yields ``0.0``.
    """
    m = max(1.0 - float(np.min(p_ze_before)), float(np.max(p_ze_before)))
    if m <= 0:
        return 0.0
    return 1.0 - (tv_distance(p_ze_after, p_ze_before) / m)


def evaluate(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    samples: List[GenericSample],
    probe_c_interv: Probe,           # interventional probe for Zc (used by GBI)
    vZc: Probe,                      # validation probe for Zc (disjoint)
    vZe: Optional[Probe] = None,     # optional validation probe for Ze (if available)
    intervention: str = "none",      # none | gbi | inlp | alterrep | null | hdmi
    epsilon: float = 0.112,
    pgd_steps: int = 40,
    gbi_attack: str = "pgd",
    gbi_norm: str = "linf",
    gbi_step_size: Optional[float] = None,
    erase_strength: float = 1.0,
    inlp_projector: Optional[INLPProjector] = None,
    inlp_rank_apply: Optional[int] = None,
    alterrep_alpha: float = 0.1,
    hdmi_alpha: float = 25.0,
    hdmi_inner_steps: int = 1,
    hdmi_use_margin: bool = False,
    hdmi_normalize_grad: bool = False,
    hdmi_grad_clip_norm: float = 0.0,
    device: torch.device = torch.device("cpu"),
    layer_idx: int = -1,
    verbose: bool = False,
) -> Metrics:
    """Apply one intervention to every sample's prompt hidden state and score it.

    Per sample:
      1. Score the gold and alternative continuations with ``sequence_logprob``
         and record whether gold wins (baseline task accuracy).
      2. Read the prompt hidden state ``h0`` at ``layer_idx`` and transform it
         into ``h1`` with the selected intervention. Unknown intervention names,
         and ``inlp`` / ``alterrep`` without a projector, leave ``h1 = h0``.
      3. Re-score both continuations by replacing only their first-token
         log-probability with the one computed from ``h1`` (skipped when either
         first-token id is negative, in which case baseline scores are reused).
      4. Compute completeness from ``vZc(h1)`` and selectivity from ``vZe``
         before/after (selectivity is 1.0 when ``vZe`` is ``None``).

    Args:
        model, tokenizer: Language model and tokenizer used for scoring.
        samples: Samples providing ``prompt_text``, ``gold_cont``, ``alt_cont``
            and (for gbi / alterrep / hdmi) ``cf_zc``.
        probe_c_interv: Probe handed to the ``gbi`` and ``null`` interventions.
        vZc: Validation probe used for completeness.
        vZe: Optional validation probe used for selectivity.
        intervention: One of none | gbi | inlp | alterrep | null | hdmi.
        epsilon, pgd_steps, gbi_attack, gbi_norm, gbi_step_size: Forwarded to
            ``gbi_intervention``.
        erase_strength: Forwarded to ``nullspace_erasure``.
        inlp_projector, inlp_rank_apply, alterrep_alpha: Used by ``inlp`` /
            ``alterrep``.
        hdmi_*: Forwarded to ``hdmi_intervention``.
        device: Device used when encoding and scoring.
        layer_idx: Hidden layer to intervene on.
        verbose: Print progress roughly every 10% of samples.

    Returns:
        A ``Metrics`` instance; accuracies are 0.0 for an empty sample list.
    """
    model.eval()
    probe_c_interv.eval()
    vZc.eval()
    if vZe is not None:
        vZe.eval()

    n = len(samples)
    progress_every = max(1, n // 10)
    correct_base = 0
    correct_after = 0
    comp_vals: List[float] = []
    sel_vals: List[float] = []

    with torch.no_grad():
        for idx, s in enumerate(samples):
            # Baseline logprobs
            base_gold = sequence_logprob(model, tokenizer, s.prompt_text, s.gold_cont, device)
            base_alt  = sequence_logprob(model, tokenizer, s.prompt_text, s.alt_cont, device)
            lp_gold0  = float(base_gold["total_lp"].item())
            lp_alt0   = float(base_alt["total_lp"].item())
            if choose_candidate_by_logprob(lp_gold0, lp_alt0) == 0:  # chose gold
                correct_base += 1

            # First-token ids are needed both by hdmi and by the re-scoring step.
            token_gold = base_gold["first_token_id"]
            token_alt  = base_alt["first_token_id"]
            gold_item  = token_gold.item()
            alt_item   = token_alt.item()

            # Hidden for prompt
            prefix_ids = encode(tokenizer, s.prompt_text, device)
            h0 = get_prompt_hidden_at_layer(model, prefix_ids, layer_idx=layer_idx)  # [1, D]

            # Ze distribution before the intervention (only if a Ze probe exists)
            if vZe is not None:
                p_ze_before = F.softmax(vZe(h0), dim=-1).cpu().numpy()[0]
            else:
                p_ze_before = None

            # Apply intervention
            if intervention == "gbi":
                # enable grads just for the intervention
                with torch.enable_grad():
                    h1 = gbi_intervention(
                        h0, probe_c_interv, target_label=int(s.cf_zc),
                        epsilon=epsilon, pgd_steps=pgd_steps,
                        attack=gbi_attack, norm=gbi_norm, step_size=gbi_step_size
                    )
            elif intervention == "null":
                h1 = nullspace_erasure(h0, probe_c_interv, strength=erase_strength)
            elif intervention == "inlp":
                if inlp_projector is None:
                    h1 = h0
                else:
                    h1 = inlp_projector.project_vector(h0, upto_rank=inlp_rank_apply)
            elif intervention == "alterrep":
                if inlp_projector is None:
                    h1 = h0
                else:
                    h1 = inlp_projector.rowspace_push_alterrep(
                        h0, target_label=int(s.cf_zc), alpha=alterrep_alpha, upto_rank=inlp_rank_apply
                    )
            elif intervention == "hdmi":
                # push toward counterfactual (alt), away from gold
                target_id = int(alt_item)
                source_id = int(gold_item)
                if target_id < 0:
                    h1 = h0
                else:
                    # NOTE(review): hdmi_intervention is defined elsewhere; its
                    # normalize_grad / grad_clip_norm options suggest it uses
                    # autograd, which the surrounding no_grad() would block.
                    # Enabling grads here mirrors the gbi branch and is a no-op
                    # if the helper already manages grad mode itself.
                    with torch.enable_grad():
                        h1 = hdmi_intervention(
                            model=model,
                            h=h0,
                            target_token_id=target_id,
                            source_token_id=(source_id if source_id >= 0 else None),
                            alpha=hdmi_alpha,
                            inner_steps=hdmi_inner_steps,
                            use_margin=hdmi_use_margin,
                            normalize_grad=hdmi_normalize_grad,
                            grad_clip_norm=hdmi_grad_clip_norm,
                        )
            else:
                h1 = h0

            # Replace first-step term for both gold and alt
            if gold_item < 0 or alt_item < 0:
                lp_gold1 = lp_gold0
                lp_alt1  = lp_alt0
            else:
                lp1_gold_first = compute_first_step_logprob_from_h(model, h1, token_gold).item()
                lp1_alt_first  = compute_first_step_logprob_from_h(model, h1, token_alt).item()
                lp_gold1 = lp_gold0 - float(base_gold["first_lp"].item()) + lp1_gold_first
                lp_alt1  = lp_alt0  - float(base_alt["first_lp"].item())  + lp1_alt_first

            if choose_candidate_by_logprob(lp_gold1, lp_alt1) == 0:
                correct_after += 1

            # Completeness via vZc
            p_c_after = F.softmax(vZc(h1), dim=-1).cpu().numpy()[0]  # distribution over classes
            comp_vals.append(_eval_completeness(p_c_after, intervention, s.cf_zc))

            # Selectivity via vZe (if available)
            if p_ze_before is not None:
                p_ze_after = F.softmax(vZe(h1), dim=-1).cpu().numpy()[0]
                sel_vals.append(_eval_selectivity(p_ze_before, p_ze_after))
            else:
                sel_vals.append(1.0)  # default when no Ze probe available

            if verbose and (idx % progress_every == 0):
                print(f"[eval] {idx+1}/{n}")

    baseline_acc = correct_base / n if n > 0 else 0.0
    after_acc = correct_after / n if n > 0 else 0.0
    completeness = float(np.mean(comp_vals)) if comp_vals else 0.0
    selectivity = float(np.mean(sel_vals)) if sel_vals else 1.0

    return Metrics(
        baseline_task_acc=baseline_acc,
        after_task_acc=after_acc,
        delta_task_acc=after_acc - baseline_acc,
        completeness=completeness,
        selectivity=selectivity,
        reliability=harmonic_mean(completeness, selectivity),
        n=n,
    )


# ----------------------------
# Main runner
# ----------------------------

def run(args):
    """Run the full benchmark for one parsed CLI configuration.

    Steps: pick the device and seed, load the model and tokenizer, build the
    train / validation-probe / test samples (LGD CSV or CausalGym), extract
    hidden-state features, train the interventional Zc probe, optionally fit
    the INLP projector, train the validation probes (vZc and, if possible,
    vZe), evaluate the chosen intervention on the test split and print the
    results as JSON.

    For CausalGym, ``--resplit_3way`` pools the dataset-provided splits and
    re-splits them with ``--interv_frac`` / ``--valprobe_frac``;
    ``--decorrelate_val`` then filters the validation-probe part through
    ``make_independent_subset`` when Ze labels exist.

    Args:
        args: Namespace produced by ``build_argparser().parse_args()``.

    Raises:
        RuntimeError: If CausalGym cannot be loaded or any split ends up empty.
    """
    device = get_device(args.device)
    print(f"Using device: {device}")
    set_seed(args.seed)

    # Load model + tokenizer
    print(f"Loading model: {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token is not None else tokenizer.unk_token
    model = AutoModelForCausalLM.from_pretrained(args.model_name)
    model.to(device).eval()
    model.config.output_hidden_states = True
    print("Model loaded.")

    # Build dataset according to dataset_type
    if args.dataset_type == "lgd":
        # LGD CSV
        df = _clean_lgd_dataset(args.data_csv)
        if args.max_samples is not None and args.max_samples > 0:
            df = df.iloc[:args.max_samples].reset_index(drop=True)
        samples_all = make_lgd_samples(df)
        print(f"LGD: {len(samples_all)} samples.")
        # Save cleaned if requested
        if args.cleaned_csv_out:
            out_df = df[["text", "verb_sg", "verb_pl", "Zc", "Ze", "Zc_norm", "Ze_norm"]].copy()
            out_df.to_csv(args.cleaned_csv_out, index=False)
            print(f"Wrote cleaned dataset to {args.cleaned_csv_out}")
        # 3-way split: interventional (train), validation-probe (val), test
        N = len(samples_all)
        interv_idx, valprobe_idx, test_idx = split_three_way(
            N, interv_frac=args.interv_frac, valprobe_frac=args.valprobe_frac, seed=args.seed
        )
        # Decorrelation for validation probes
        indep_idxs = make_independent_subset(valprobe_idx, samples_all, seed=args.seed + 2)
        train_samples = [samples_all[i] for i in interv_idx]
        val_samples   = [samples_all[i] for i in indep_idxs]
        test_samples  = [samples_all[i] for i in test_idx]
        ze_available  = True
    else:
        # CausalGym
        if args.cg_local_json_dir:
            splits = _load_causalgym_local_json(args.cg_local_json_dir)
        else:
            ds = _load_causalgym_hf()
            if ds is None:
                raise RuntimeError("Failed to load HuggingFace dataset aryaman/causalgym and no --cg_local_json_dir provided.")
            splits = {k: v for k, v in ds.items()}
        train_samples, val_samples, test_samples = make_causalgym_samples(splits, tasks=args.cg_tasks)

        # These two flags used to be parsed but ignored; honour them as their
        # help text describes, reusing the helpers of the LGD path.
        resplit = getattr(args, "resplit_3way", False)
        decorrelate = getattr(args, "decorrelate_val", False)
        if resplit:
            pooled = list(train_samples) + list(val_samples) + list(test_samples)
            if pooled:
                print("CausalGym: re-splitting pooled samples into interventional/validation-probe/test.")
                interv_idx, valprobe_idx, test_idx = split_three_way(
                    len(pooled), interv_frac=args.interv_frac, valprobe_frac=args.valprobe_frac, seed=args.seed
                )
                if decorrelate:
                    if any(s.ze is not None for s in pooled):
                        valprobe_idx = make_independent_subset(valprobe_idx, pooled, seed=args.seed + 2)
                    else:
                        print("Warning: --decorrelate_val requested but no Ze labels found; skipping decorrelation.")
                train_samples = [pooled[i] for i in interv_idx]
                val_samples   = [pooled[i] for i in valprobe_idx]
                test_samples  = [pooled[i] for i in test_idx]
        elif decorrelate:
            print("Warning: --decorrelate_val has no effect without --resplit_3way.")

        if args.max_samples is not None and args.max_samples > 0:
            train_samples = train_samples[:args.max_samples]
            val_samples   = val_samples[:max(1, args.max_samples // 4)]
            test_samples  = test_samples[:max(1, args.max_samples // 4)]
        print(f"CausalGym: train={len(train_samples)} dev={len(val_samples)} test={len(test_samples)}")
        # Ze now computed for CausalGym via Preposition family heuristic
        ze_available = any(s.ze is not None for s in (train_samples + val_samples))
        if ze_available:
            print("Ze (Preposition family: 0=NONE,1=OF,2=IN,3=WITH_OR_BY,4=OTHER) computed for CausalGym.")
        else:
            print("Warning: Ze not found/computed; selectivity will default to 1.0.")

    # Guard against empty splits
    if len(train_samples) == 0:
        raise RuntimeError("No training samples found after preprocessing/filtering. Adjust dataset, --cg_tasks, or --max_samples.")
    if len(val_samples) == 0:
        raise RuntimeError("No dev/validation samples found after preprocessing/filtering.")
    if len(test_samples) == 0:
        raise RuntimeError("No test samples found after preprocessing/filtering.")

    # Extract features for each split
    print("Extracting features...")
    X_train = build_features(model, tokenizer, train_samples, device, layer_idx=args.layer_idx)
    y_train_zc = np.array([s.zc for s in train_samples], dtype=np.int64)
    X_val = build_features(model, tokenizer, val_samples, device, layer_idx=args.layer_idx)
    y_val_zc = np.array([s.zc for s in val_samples], dtype=np.int64)

    # Holdout inside the interventional split (80/20) to evaluate the interventional probe
    rng = np.random.default_rng(args.seed + 1)
    perm = rng.permutation(len(train_samples))
    cut = int(0.8 * len(perm))
    tr_i, te_i = perm[:cut], perm[cut:]
    X_train_c, y_train_c = X_train[tr_i], y_train_zc[tr_i]
    X_test_c,  y_test_c  = X_train[te_i], y_train_zc[te_i]

    print("Training interventional Zc probe (for GBI/INLP)...")
    # NOTE(review): the weight-decay keyword is spelled ``wb`` here while the
    # sibling calls below use ``wd``. train_probe's signature is not visible
    # from this function, so it is left untouched -- confirm it is intended.
    probe_c_interv = train_probe(
        X_train_c, y_train_c,
        epochs=args.probe_epochs,
        lr=args.probe_lr,
        wb=args.probe_wd,
        device=device,
        hidden=args.probe_hidden,     # MLP if >0
        batch_size=args.probe_batch_size,
        verbose=args.probe_verbose,
    )
    acc_zc_interv = eval_probe_acc(probe_c_interv, X_test_c, y_test_c, device=device)
    print(f"Interventional Zc probe acc (holdout within interv split): {acc_zc_interv:.3f} (N={len(y_test_c)})")
    # INLP projector (fit on train)
    inlp_projector = None
    if "inlp" in args.intervention or "alterrep" in args.intervention or args.precompute_inlp:
        print(f"Fitting INLP projector (rank={args.inlp_rank}) on train split...")
        inlp_projector = INLPProjector(rank=args.inlp_rank, lr=args.inlp_lr, wd=args.inlp_wd,
                                       epochs=args.inlp_epochs, batch_size=args.inlp_batch_size, device=device)
        inlp_projector.fit(X_train, y_train_zc)

    # Validation probe vZc (select best hidden size): 80/20 split of the validation-probe samples
    print("Training validation probe vZc (select best hidden size) on validation-probe split only...")
    rng = np.random.default_rng(args.seed + 3)
    perm2 = rng.permutation(len(val_samples))
    cut2 = int(0.8 * len(perm2))
    tr_v_idx, va_v_idx = perm2[:cut2], perm2[cut2:]

    X_v_train, y_v_zc_train = X_val[tr_v_idx], y_val_zc[tr_v_idx]
    X_v_val,   y_v_zc_val   = X_val[va_v_idx], y_val_zc[va_v_idx]

    vZc = train_validation_probe_with_selection(
        X_v_train, y_v_zc_train,
        X_v_val,   y_v_zc_val,
        hidden_grid=args.valprobe_hidden_grid,
        device=device,
        epochs=args.valprobe_epochs,
        lr=args.valprobe_lr,
        wd=args.valprobe_wd,
        batch_size=args.valprobe_batch_size,
    )
    acc_vZc = eval_probe_acc(vZc, X_v_val, y_v_zc_val, device=device)
    print(f"vZc validation acc: {acc_vZc:.3f} (N={len(y_v_zc_val)})")

    # Optional vZe
    vZe = None
    if ze_available:
        y_train_ze = np.array([s.ze for s in train_samples], dtype=np.int64)
        y_val_ze   = np.array([s.ze for s in val_samples], dtype=np.int64)

        # Negative Ze labels are excluded from probe training/evaluation.
        train_mask = y_train_ze >= 0
        val_mask   = y_val_ze >= 0

        # Require more than 10 labelled rows and at least two distinct classes per side.
        has_train = train_mask.sum() > 10 and len(np.unique(y_train_ze[train_mask])) > 1
        has_val   = val_mask.sum() > 10 and len(np.unique(y_val_ze[val_mask])) > 1

        if has_train and has_val:
            print("Training validation probe vZe (train -> dev)...")
            vZe = train_validation_probe_with_selection(
                X_train[train_mask], y_train_ze[train_mask],
                X_val[val_mask],     y_val_ze[val_mask],
                hidden_grid=args.valprobe_hidden_grid,
                device=device,
                epochs=args.valprobe_epochs,
                lr=args.valprobe_lr,
                wd=args.valprobe_wd,
                batch_size=args.valprobe_batch_size,
            )
            acc_vZe = eval_probe_acc(vZe, X_val[val_mask], y_val_ze[val_mask], device=device)
            print(f"vZe dev acc: {acc_vZe:.3f} (N={int(val_mask.sum())})")
        else:
            print("Skipping vZe (insufficient or imbalanced Ze labels).")
    else:
        print("No Ze available. Selectivity will default to 1.0 and Reliability==Completeness.")

    # The rank actually applied falls back to the fitted rank when not overridden.
    inlp_rank_apply = args.inlp_rank_apply if args.inlp_rank_apply is not None else args.inlp_rank

    # Evaluate on test split
    print(f"Evaluating on test split with intervention={args.intervention} ...")
    metrics = evaluate(
        model=model,
        tokenizer=tokenizer,
        samples=test_samples,
        probe_c_interv=probe_c_interv,
        vZc=vZc,
        vZe=vZe,
        intervention=args.intervention,
        epsilon=args.epsilon,
        pgd_steps=args.pgd_steps,
        gbi_attack=args.gbi_attack,
        gbi_norm=args.gbi_norm,
        gbi_step_size=args.gbi_step_size,
        erase_strength=args.erase_strength,
        inlp_projector=inlp_projector,
        inlp_rank_apply=inlp_rank_apply,
        alterrep_alpha=args.alterrep_alpha,
        hdmi_alpha=args.hdmi_alpha,
        hdmi_inner_steps=args.hdmi_inner_steps,
        hdmi_use_margin=args.hdmi_use_margin,
        hdmi_normalize_grad=args.hdmi_normalize_grad,
        hdmi_grad_clip_norm=args.hdmi_grad_clip_norm,
        device=device,
        layer_idx=args.layer_idx,
        verbose=args.eval_verbose,
    )

    print("\nResults (Test Split)")
    print(json.dumps({
        "N": metrics.n,
        "baseline_task_acc": round(metrics.baseline_task_acc, 4),
        "after_task_acc": round(metrics.after_task_acc, 4),
        "delta_task_acc": round(metrics.delta_task_acc, 4),
        "completeness": round(metrics.completeness, 4),
        "selectivity": round(metrics.selectivity, 4),
        "reliability": round(metrics.reliability, 4),
        "intervention": args.intervention,
        "epsilon": args.epsilon,
        "pgd_steps": args.pgd_steps,
        "gbi_attack": args.gbi_attack,
        "gbi_norm": args.gbi_norm,
        "gbi_step_size": args.gbi_step_size,
        "erase_strength": args.erase_strength,
        "inlp_rank_apply": inlp_rank_apply,
        "alterrep_alpha": args.alterrep_alpha,
        "hdmi_alpha": args.hdmi_alpha,
        "hdmi_inner_steps": args.hdmi_inner_steps,
        "hdmi_use_margin": args.hdmi_use_margin,
        "hdmi_normalize_grad": args.hdmi_normalize_grad,
        "hdmi_grad_clip_norm": args.hdmi_grad_clip_norm,
        "layer_idx": args.layer_idx,
        "model": args.model_name,
        "dataset_type": args.dataset_type,
        "cg_tasks": args.cg_tasks,
    }, indent=2))


def build_argparser():
    """Build the command-line parser for the benchmark.

    Returns:
        An ``argparse.ArgumentParser`` whose parsed namespace is consumed by ``run``.

    Note:
        ``--hdmi_use_margin`` is on by default (the effective default of the
        previous ``set_defaults`` call). Because a ``store_true`` flag that
        defaults to ``True`` can never switch the feature off, the companion
        flag ``--no_hdmi_use_margin`` is provided to disable it.
    """
    p = argparse.ArgumentParser(description="Benchmark probing interventions on causal LMs (LGD or CausalGym) with TV-based metrics.")

    # Dataset selection
    p.add_argument("--dataset_type", type=str, default="lgd", choices=["lgd", "causalgym"], help="Which dataset to use.")
    # LGD CSV
    p.add_argument("--data_csv", type=str, default="lgd_equiv_sva.csv", help="LGD CSV (for dataset_type=lgd).")
    p.add_argument("--cleaned_csv_out", type=str, default="", help="Optional: write cleaned LGD CSV.")
    p.add_argument("--interv_frac", type=float, default=0.4, help="LGD: fraction for interventional-probe split.")
    p.add_argument("--valprobe_frac", type=float, default=0.4, help="LGD: fraction for validation-probe split.")
    # CausalGym
    p.add_argument("--cg_local_json_dir", type=str, default="", help="If provided, read CausalGym splits from this dir (train.json/dev.json/test.json). Otherwise uses HF dataset.")
    p.add_argument("--cg_tasks", type=str, nargs="+", default=None, help="Optional: restrict to these CausalGym task names (exact match). E.g., agr_sv_num_subj-relc agr_sv_num_obj-relc")
    # Re-splitting control (applies to CausalGym; LGD already uses 3-way split)
    p.add_argument("--resplit_3way", action="store_true",
                   help="For CausalGym, ignore dataset-provided splits and re-split all samples into interventional/validation-probe/test using --interv_frac and --valprobe_frac.")
    p.add_argument("--decorrelate_val", action="store_true",
                   help="After 3-way re-split (CausalGym), decorrelate validation-probe set so Zc ⟂ Ze (if Ze available).")
    # Model
    p.add_argument("--model_name", type=str, default="EleutherAI/pythia-70m", help="HF model id for AutoModelForCausalLM.")
    p.add_argument("--device", type=str, default="auto", help="cpu | cuda | mps | auto")
    p.add_argument("--layer_idx", type=int, default=-1, help="Hidden layer index to extract: -1 for final.")
    p.add_argument("--max_samples", type=int, default=None, help="Limit rows/samples for quicker runs.")
    p.add_argument("--seed", type=int, default=1337)

    # Interventional Zc probe training
    p.add_argument("--probe_epochs", type=int, default=150)
    p.add_argument("--probe_lr", type=float, default=1e-2)
    p.add_argument("--probe_wd", type=float, default=1e-6)
    p.add_argument("--probe_hidden", type=int, default=256, help="0 for linear; >0 for MLP hidden size.")
    p.add_argument("--probe_batch_size", type=int, default=256)
    p.add_argument("--probe_verbose", action="store_true")

    # Validation probes (vZc, and vZe if available)
    p.add_argument("--valprobe_epochs", type=int, default=150)
    p.add_argument("--valprobe_lr", type=float, default=1e-2)
    p.add_argument("--valprobe_wd", type=float, default=1e-6)
    p.add_argument("--valprobe_batch_size", type=int, default=256)
    p.add_argument("--valprobe_hidden_grid", type=int, nargs="+", default=[0, 64, 256, 512], help="Hidden sizes to try; best by dev acc.")

    # Interventions
    p.add_argument("--intervention", type=str, default="none", choices=["none", "gbi", "inlp", "alterrep", "null", "hdmi"])

    # GBI settings
    p.add_argument("--epsilon", type=float, default=0.112, help="Radius for GBI.")
    p.add_argument("--pgd_steps", type=int, default=40)
    p.add_argument("--gbi_attack", type=str, default="pgd", choices=["fgsm", "pgd"])
    p.add_argument("--gbi_norm", type=str, default="linf", choices=["linf", "l2"])
    p.add_argument("--gbi_step_size", type=float, default=None, help="Optional step size override (None = heuristic).")

    # Null (legacy)
    p.add_argument("--erase_strength", type=float, default=1.0)

    # INLP settings
    p.add_argument("--precompute_inlp", action="store_true", help="Fit INLP projector even if not used.")
    p.add_argument("--inlp_rank", type=int, default=8)
    p.add_argument("--inlp_rank_apply", type=int, default=None)
    p.add_argument("--inlp_epochs", type=int, default=50)
    p.add_argument("--inlp_lr", type=float, default=1e-2)
    p.add_argument("--inlp_wd", type=float, default=1e-6)
    p.add_argument("--inlp_batch_size", type=int, default=256)

    # AlterRep settings
    p.add_argument("--alterrep_alpha", type=float, default=0.1)

    # hdmi
    p.add_argument("--hdmi_alpha", type=float, default=25.0)
    p.add_argument("--hdmi_inner_steps", type=int, default=1)
    # Default stays True (what the old set_defaults call produced); the
    # positive flag is kept so existing command lines continue to parse.
    p.add_argument("--hdmi_use_margin", dest="hdmi_use_margin", action="store_true", default=True,
                   help="Enabled by default; accepted for backward compatibility.")
    p.add_argument("--no_hdmi_use_margin", dest="hdmi_use_margin", action="store_false",
                   help="Disable --hdmi_use_margin.")
    p.add_argument("--hdmi_normalize_grad", action="store_true", default=False)
    p.add_argument("--hdmi_grad_clip_norm", type=float, default=0.0)

    # Eval
    p.add_argument("--eval_verbose", action="store_true")

    return p


if __name__ == "__main__":
    # Script entry point: parse the command line and launch the benchmark.
    run(build_argparser().parse_args())
