# experiments/valence/decisions_and_robustness.py
from __future__ import annotations
import argparse, json, os
import numpy as np
import pandas as pd
import jax.numpy as jnp

from lmkit.impl.hooks import HookType
from experiments.valence.core import (
    prepare_batches,
    extract_valence_events,
    build_direction_for_layer,
    AddEditor,
    ProjectOutEditor,
    ComposeEditor,
    bond_token_ids,
)


# ------------------------ Shared helpers ------------------------
def _collect_explicit(tokenizer, all_inputs):
    """Return the explicit valence events that carry a usable prediction index.

    Calls ``extract_valence_events`` on the inputs, ignores its second return
    value, and keeps only events whose ``event_type`` equals ``"explicit"``
    and whose ``pred_idx`` is non-negative. Input order is preserved.
    """
    events, _unused = extract_valence_events(tokenizer, all_inputs)
    kept = []
    for ev in events:
        if ev.event_type == "explicit" and ev.pred_idx >= 0:
            kept.append(ev)
    return kept


def _margin3(vec: np.ndarray, gold_id: int, ids):
    """Return a logit margin restricted to the three bond tokens.

    Args:
        vec: 1-D score vector indexed by token id.
        gold_id: token id of the target at this position.
        ids: mapping with keys ``"-"``, ``"="`` and ``"#"`` giving token ids.

    Returns:
        If ``gold_id`` is one of the three bond ids: the gold score minus the
        larger of the other two scores (negative when the gold token loses).
        Otherwise: the gap between the best and second-best of the three.
        Scores are cast to float32 before the subtraction.
    """
    bond_ids = (ids["-"], ids["="], ids["#"])
    scores = np.array([vec[tok] for tok in bond_ids], np.float32)
    # First matching slot wins, mirroring the '-', '=', '#' priority order.
    for slot, tok in enumerate(bond_ids):
        if gold_id == tok:
            rivals = np.delete(scores, slot)
            return float(scores[slot] - rivals.max())
    ranked = np.sort(scores)
    return float(ranked[-1] - ranked[-2])


def _pred3(vec: np.ndarray, ids) -> int:
    """Return which bond token scores highest in ``vec``.

    The result is an index into the fixed order ``('-', '=', '#')``:
    0 for ``'-'``, 1 for ``'='``, 2 for ``'#'``. Scores are compared as
    float32; on ties the earliest entry in that order wins.
    """
    scores = np.asarray(
        [vec[ids[sym]] for sym in ("-", "=", "#")], dtype=np.float32
    )
    return int(scores.argmax())


def _unsat_dir(before: int, after: int) -> int:
    """Return the sign of ``after - before`` as -1, 0 or +1.

    Callers pass indices produced by ``_pred3`` (0 for '-', 1 for '=',
    2 for '#'), so +1 means the prediction moved to a higher index in that
    order, -1 to a lower one, and 0 means no change.
    """
    delta = after - before
    if delta > 0:
        return 1
    if delta < 0:
        return -1
    return 0


# ------------------------ (A) Decision metrics & context slices ------------------------
def run_decision_metrics(
    model_dir,
    ckpt_id,
    dataset_dir,
    layer_id,
    num_examples,
    batch_size,
    seq_length,
    out_dir,
    alphas,
    ctrl_seed=13,
):
    """Measure 3-way bond decision changes under steering and write a CSV.

    For every ``alpha`` the model is re-run with an ``AddEditor`` hooked at
    ``HookType.RESID_PRE`` of ``layer_id``, configured with the direction
    from ``build_direction_for_layer`` and a mask selecting the explicit
    event positions. Steered and baseline logits are then compared at:

    * the explicit event positions (margin change via ``_margin3``,
      prediction switch rate, mean ``_unsat_dir`` shift), and
    * an equally sized random sample of valid non-event positions
      (top-1 minus top-2 gap among the three bond logits, switch rate).

    Args:
        model_dir, ckpt_id, dataset_dir, layer_id, num_examples, batch_size,
            seq_length: forwarded unchanged to ``prepare_batches``.
        out_dir: directory for ``decision_metrics_L{layer_id}.csv``
            (created if missing).
        alphas: iterable of steering scales; one CSV row per value.
        ctrl_seed: seed for the control-position sample.

    Raises:
        RuntimeError: if no explicit bond events are found.
    """
    from experiments.valence.core import run_with_hooks

    tokenizer, cfg, params, batches, all_inputs, all_targets, all_pos = prepare_batches(
        model_dir, ckpt_id, dataset_dir, layer_id, num_examples, batch_size, seq_length
    )
    explicit = _collect_explicit(tokenizer, all_inputs)
    if not explicit:
        raise RuntimeError("No explicit bond events found.")
    resid_pre_batches = [b.resid_pre for b in batches]
    w_hat, info = build_direction_for_layer(resid_pre_batches, explicit, threshold=2)
    ids = bond_token_ids(tokenizer)
    bond_cols = [ids["-"], ids["="], ids["#"]]
    base_logits = np.concatenate([b.logits for b in batches], axis=0)

    # The intervention masks depend only on the batches and the events, not
    # on alpha, so build them once instead of once per alpha.
    pos_masks = []
    for b in batches:
        mask = np.zeros_like(b.positions, bool)
        row_lo, row_hi = b.row_start, b.row_start + b.inputs.shape[0]
        for e in explicit:
            if (
                row_lo <= e.batch < row_hi
                and b.positions[e.batch - row_lo, e.pred_idx] >= 0
            ):
                mask[e.batch - row_lo, e.pred_idx] = True
        pos_masks.append(mask)

    # Events that are actually scored (valid position at the prediction index).
    scored = [
        e for e in explicit if e.pred_idx >= 0 and all_pos[e.batch, e.pred_idx] >= 0
    ]

    # Random control (same N, non-decision positions). The candidate set, the
    # sample size and the seed are all alpha-independent, so the draw is done
    # once; reseeding per alpha would reproduce exactly the same sample.
    eligible = np.asarray(all_pos) >= 0
    for e in explicit:
        eligible[e.batch, e.pred_idx] = False
    cand = np.argwhere(eligible)  # row-major order, shape (n_candidates, 2)
    rng = np.random.default_rng(ctrl_seed)
    k = min(len(cand), len(scored))
    idx = (
        rng.choice(len(cand), size=k, replace=False)
        if k and len(cand) > k
        else np.arange(k)
    )
    ctrl = cand[idx]

    results = []
    for alpha in alphas:
        ste_batches = []
        for b, mask in zip(batches, pos_masks):
            ed = AddEditor(
                layer=layer_id,
                kind=HookType.RESID_PRE,
                w=w_hat,
                alpha=float(alpha),
                pos_mask=mask,
            )
            logits, _ = run_with_hooks(
                jnp.asarray(b.inputs),
                jnp.asarray(b.positions),
                params,
                cfg,
                hook_pairs=(),
                editor=ed,
            )
            ste_batches.append(np.asarray(logits))
        ste_logits = np.concatenate(ste_batches, axis=0)

        d_margins, switches, unsat_shifts = [], [], []
        for e in scored:
            r, t = e.batch, e.pred_idx
            gold = int(all_targets[r, t])
            m0 = _margin3(base_logits[r, t, :], gold, ids)
            m1 = _margin3(ste_logits[r, t, :], gold, ids)
            d_margins.append(m1 - m0)
            p0 = _pred3(base_logits[r, t, :], ids)
            p1 = _pred3(ste_logits[r, t, :], ids)
            switches.append(int(p0 != p1))
            unsat_shifts.append(_unsat_dir(p0, p1))

        ctrl_dm, ctrl_sw = [], []
        for r, t in ctrl:
            three0 = np.asarray(base_logits[r, t, bond_cols], dtype=np.float32)
            three1 = np.asarray(ste_logits[r, t, bond_cols], dtype=np.float32)
            # No gold bond token at control positions: use the top-1 minus
            # top-2 gap among the three bond logits.
            m0 = float(three0.max() - np.partition(three0, -2)[-2])
            m1 = float(three1.max() - np.partition(three1, -2)[-2])
            ctrl_dm.append(m1 - m0)
            ctrl_sw.append(int(np.argmax(three0) != np.argmax(three1)))

        results.append(
            dict(
                layer=layer_id,
                alpha=float(alpha),
                n_events=len(d_margins),
                dmargin3_event=float(np.mean(d_margins)) if d_margins else 0.0,
                switch_rate_event=float(np.mean(switches)) if switches else 0.0,
                unsat_shift_event=float(np.mean(unsat_shifts)) if unsat_shifts else 0.0,
                n_ctrl=len(ctrl_dm),
                dmargin3_ctrl=float(np.mean(ctrl_dm)) if ctrl_dm else 0.0,
                switch_rate_ctrl=float(np.mean(ctrl_sw)) if ctrl_sw else 0.0,
                dir_N=int(info["N"]),
                dir_pos=int(info["pos"]),
                dir_neg=int(info["neg"]),
            )
        )
    os.makedirs(out_dir, exist_ok=True)
    pd.DataFrame(results).to_csv(
        os.path.join(out_dir, f"decision_metrics_L{layer_id}.csv"), index=False
    )
    print(f"[decision] wrote → decision_metrics_L{layer_id}.csv")


def run_context_slices(
    model_dir,
    ckpt_id,
    dataset_dir,
    layer_id,
    num_examples,
    batch_size,
    seq_length,
    out_dir,
    alphas,
):
    """Break the steering decision metrics down by event context; write a CSV.

    Explicit events are grouped by ``(elem, charge, aromatic)`` read from
    ``e.context`` (defaults ``"?"``, ``0`` and ``False``). Charges outside
    ``{-1, 0, 1}`` are bucketed to ``+2`` or ``-2``. For every ``alpha`` the
    model is re-run with an ``AddEditor`` at ``HookType.RESID_PRE`` of
    ``layer_id`` masked to the event positions, and each group gets one row
    with its mean ``_margin3`` change, prediction switch rate and mean
    ``_unsat_dir`` shift (all ``0.0`` when the group has no scorable event).

    Args:
        model_dir, ckpt_id, dataset_dir, layer_id, num_examples, batch_size,
            seq_length: forwarded unchanged to ``prepare_batches``.
        out_dir: directory for ``context_slices_L{layer_id}.csv``
            (created if missing).
        alphas: iterable of steering scales.
    """
    from experiments.valence.core import run_with_hooks

    tokenizer, cfg, params, batches, all_inputs, all_targets, all_pos = prepare_batches(
        model_dir, ckpt_id, dataset_dir, layer_id, num_examples, batch_size, seq_length
    )
    explicit = _collect_explicit(tokenizer, all_inputs)
    resid_pre_batches = [b.resid_pre for b in batches]
    w_hat, _ = build_direction_for_layer(resid_pre_batches, explicit, threshold=2)
    ids = bond_token_ids(tokenizer)
    base_logits = np.concatenate([b.logits for b in batches], axis=0)

    def key(e):
        ch = e.context.get("charge", 0)
        # Collapse large charges into a single +2 / -2 bucket.
        ch = ch if ch in (-1, 0, 1) else (2 if ch > 1 else -2)
        return (
            e.context.get("elem", "?"),
            int(ch),
            bool(e.context.get("aromatic", False)),
        )

    # Group only once, and drop unscorable events here rather than on every
    # alpha pass. Groups that end up empty are kept so they still get a row.
    groups = {}
    for e in explicit:
        bucket = groups.setdefault(key(e), [])
        if e.pred_idx >= 0 and all_pos[e.batch, e.pred_idx] >= 0:
            bucket.append(e)

    # The intervention masks are alpha-independent; build them once.
    pos_masks = []
    for b in batches:
        mask = np.zeros_like(b.positions, bool)
        row_lo, row_hi = b.row_start, b.row_start + b.inputs.shape[0]
        for e in explicit:
            if (
                row_lo <= e.batch < row_hi
                and b.positions[e.batch - row_lo, e.pred_idx] >= 0
            ):
                mask[e.batch - row_lo, e.pred_idx] = True
        pos_masks.append(mask)

    rows = []
    for alpha in alphas:
        ste_batches = []
        for b, mask in zip(batches, pos_masks):
            ed = AddEditor(
                layer=layer_id,
                kind=HookType.RESID_PRE,
                w=w_hat,
                alpha=float(alpha),
                pos_mask=mask,
            )
            logits, _ = run_with_hooks(
                jnp.asarray(b.inputs),
                jnp.asarray(b.positions),
                params,
                cfg,
                hook_pairs=(),
                editor=ed,
            )
            ste_batches.append(np.asarray(logits))
        ste_logits = np.concatenate(ste_batches, axis=0)

        for gk, evs in groups.items():
            d_marg, sw, sh = [], [], []
            for e in evs:
                r, t = e.batch, e.pred_idx
                gold = int(all_targets[r, t])
                m0 = _margin3(base_logits[r, t, :], gold, ids)
                m1 = _margin3(ste_logits[r, t, :], gold, ids)
                d_marg.append(m1 - m0)
                p0 = _pred3(base_logits[r, t, :], ids)
                p1 = _pred3(ste_logits[r, t, :], ids)
                sw.append(int(p0 != p1))
                sh.append(_unsat_dir(p0, p1))
            n = len(d_marg)
            rows.append(
                dict(
                    layer=layer_id,
                    alpha=float(alpha),
                    elem=gk[0],
                    charge=int(gk[1]),
                    aromatic=bool(gk[2]),
                    n=n,
                    dmargin3=float(np.mean(d_marg)) if n else 0.0,
                    switch_rate=float(np.mean(sw)) if n else 0.0,
                    unsat_shift=float(np.mean(sh)) if n else 0.0,
                )
            )
    os.makedirs(out_dir, exist_ok=True)
    pd.DataFrame(rows).to_csv(
        os.path.join(out_dir, f"context_slices_L{layer_id}.csv"), index=False
    )
    print(f"[context] wrote → context_slices_L{layer_id}.csv")


# ------------------------ (B) Hardening: bootstrap, null, lesion ------------------------
def run_bootstrap(
    model_dir,
    ckpt_id,
    dataset_dir,
    layer_id,
    num_examples,
    batch_size,
    seq_length,
    out_dir,
    alphas,
    num_bootstrap=2000,
    direction_path=None,
):
    """Bootstrap 95% intervals for per-event logit shifts under steering.

    For every ``alpha`` the model is re-run with an ``AddEditor`` at
    ``HookType.RESID_PRE`` of ``layer_id`` masked to the explicit event
    positions. The steered-minus-baseline logit difference is collected per
    event for each bond token in ``bond_token_ids`` and, when
    ``e.token_at_pred`` is one of those tokens, for that "observed" token.
    Each metric's mean is reported together with the 2.5 / 97.5 percentiles
    of ``num_bootstrap`` resampled means (resampling events with replacement,
    fixed seed 2023).

    Args:
        model_dir, ckpt_id, dataset_dir, layer_id, num_examples, batch_size,
            seq_length: forwarded unchanged to ``prepare_batches``.
        out_dir: directory for ``bootstrap_effects_L{layer_id}.csv``
            (created if missing).
        alphas: iterable of steering scales.
        num_bootstrap: number of bootstrap resamples per metric.
        direction_path: optional JSON file with a ``"w_hat"`` list; used when
            the path exists, otherwise the direction is rebuilt from data.

    Note:
        The resampling index array has shape ``(num_bootstrap, n_events)``,
        so memory grows with the product of the two.
    """
    from experiments.valence.core import run_with_hooks

    tokenizer, cfg, params, batches, all_inputs, _targets, all_pos = prepare_batches(
        model_dir, ckpt_id, dataset_dir, layer_id, num_examples, batch_size, seq_length
    )
    explicit = _collect_explicit(tokenizer, all_inputs)
    resid_pre_batches = [b.resid_pre for b in batches]
    if direction_path and os.path.exists(direction_path):
        # Context manager so the handle is closed deterministically.
        with open(direction_path) as fh:
            obj = json.load(fh)
        w_hat = np.array(obj["w_hat"], np.float32)
    else:
        w_hat, _ = build_direction_for_layer(resid_pre_batches, explicit, threshold=2)

    ids = bond_token_ids(tokenizer)
    base_logits = np.concatenate([b.logits for b in batches], axis=0)

    # The intervention masks are alpha-independent; build them once.
    pos_masks = []
    for b in batches:
        mask = np.zeros_like(b.positions, bool)
        row_lo, row_hi = b.row_start, b.row_start + b.inputs.shape[0]
        for e in explicit:
            if (
                row_lo <= e.batch < row_hi
                and b.positions[e.batch - row_lo, e.pred_idx] >= 0
            ):
                mask[e.batch - row_lo, e.pred_idx] = True
        pos_masks.append(mask)

    # Events with a valid position at the prediction index.
    scored = [
        e for e in explicit if e.pred_idx >= 0 and all_pos[e.batch, e.pred_idx] >= 0
    ]

    out_path = os.path.join(out_dir, f"bootstrap_effects_L{layer_id}.csv")
    rng = np.random.default_rng(2023)
    os.makedirs(out_dir, exist_ok=True)
    with open(out_path, "w") as f:
        f.write("layer,alpha,metric,n,mean,lo95,hi95\n")
        for alpha in alphas:
            # one run to collect per-event deltas
            diffs = {"-": [], "=": [], "#": [], "observed": []}
            ste_batches = []
            for b, mask in zip(batches, pos_masks):
                ed = AddEditor(
                    layer=layer_id,
                    kind=HookType.RESID_PRE,
                    w=w_hat,
                    alpha=float(alpha),
                    pos_mask=mask,
                )
                logits, _ = run_with_hooks(
                    jnp.asarray(b.inputs),
                    jnp.asarray(b.positions),
                    params,
                    cfg,
                    hook_pairs=(),
                    editor=ed,
                )
                ste_batches.append(np.asarray(logits))
            ste_logits = np.concatenate(ste_batches, axis=0)

            for e in scored:
                r, t = e.batch, e.pred_idx
                base = base_logits[r, t, :]
                ste = ste_logits[r, t, :]
                for sym, tid in ids.items():
                    diffs[sym].append(float(ste[tid] - base[tid]))
                if e.token_at_pred in ids:
                    tid = ids[e.token_at_pred]
                    diffs["observed"].append(float(ste[tid] - base[tid]))

            # bootstrap
            for metric, vals in diffs.items():
                arr = np.array(vals, np.float64)
                n = arr.size
                if n == 0:
                    f.write(f"{layer_id},{alpha},{metric},0,0,0,0\n")
                    continue
                idxs = rng.integers(0, n, size=(num_bootstrap, n))
                boots = arr[idxs].mean(axis=1)
                mean = float(arr.mean())
                lo, hi = np.percentile(boots, [2.5, 97.5])
                f.write(
                    f"{layer_id},{alpha},{metric},{n},{mean},{float(lo)},{float(hi)}\n"
                )
    print(f"[bootstrap] wrote → {out_path}")


def run_null_tests(
    model_dir,
    ckpt_id,
    dataset_dir,
    layer_id,
    num_examples,
    batch_size,
    seq_length,
    out_dir,
    alphas,
    k_random=64,
    direction_path=None,
    orth_to_dir=True,
):
    """Null control: inject random unit directions instead of the learned one.

    For every ``alpha``, ``k_random`` Gaussian directions are drawn (fixed
    seed 1234), optionally made orthogonal to ``w_hat``, normalised, and
    injected with an ``AddEditor`` at ``HookType.RESID_PRE`` of ``layer_id``
    masked to the explicit event positions. Steered-minus-baseline logit
    differences are pooled over all events and all random directions; the
    CSV reports the pooled mean and the 2.5 / 97.5 percentiles of the pooled
    per-event values (not of per-direction means).

    Args:
        model_dir, ckpt_id, dataset_dir, layer_id, num_examples, batch_size,
            seq_length: forwarded unchanged to ``prepare_batches``.
        out_dir: directory for ``null_injection_L{layer_id}.csv``
            (created if missing).
        alphas: iterable of steering scales.
        k_random: number of random directions per alpha.
        direction_path: optional JSON file with a ``"w_hat"`` list; only
            consulted when ``orth_to_dir`` is true.
        orth_to_dir: when true, remove the ``w_hat`` component from each
            random direction before normalising (mode ``"orth_rand"``;
            otherwise ``"rand"``).
    """
    from experiments.valence.core import run_with_hooks

    tokenizer, cfg, params, batches, all_inputs, _targets, all_pos = prepare_batches(
        model_dir, ckpt_id, dataset_dir, layer_id, num_examples, batch_size, seq_length
    )
    explicit = _collect_explicit(tokenizer, all_inputs)
    resid_pre_batches = [b.resid_pre for b in batches]
    w_hat = None
    if orth_to_dir:
        if direction_path and os.path.exists(direction_path):
            # Context manager so the handle is closed deterministically.
            with open(direction_path) as fh:
                obj = json.load(fh)
            w_hat = np.array(obj["w_hat"], np.float32)
        else:
            w_hat, _ = build_direction_for_layer(
                resid_pre_batches, explicit, threshold=2
            )

    ids = bond_token_ids(tokenizer)
    base_logits = np.concatenate([b.logits for b in batches], axis=0)

    # The intervention masks depend on neither alpha nor the random
    # direction; build them once instead of len(alphas) * k_random times.
    pos_masks = []
    for b in batches:
        mask = np.zeros_like(b.positions, bool)
        row_lo, row_hi = b.row_start, b.row_start + b.inputs.shape[0]
        for e in explicit:
            if (
                row_lo <= e.batch < row_hi
                and b.positions[e.batch - row_lo, e.pred_idx] >= 0
            ):
                mask[e.batch - row_lo, e.pred_idx] = True
        pos_masks.append(mask)

    # Events with a valid position at the prediction index.
    scored = [
        e for e in explicit if e.pred_idx >= 0 and all_pos[e.batch, e.pred_idx] >= 0
    ]

    out_path = os.path.join(out_dir, f"null_injection_L{layer_id}.csv")
    rng = np.random.default_rng(1234)
    os.makedirs(out_dir, exist_ok=True)
    mode = "orth_rand" if orth_to_dir else "rand"
    with open(out_path, "w") as f:
        f.write("layer,alpha,mode,metric,k,mean,lo95,hi95\n")
        H = resid_pre_batches[0].shape[-1]
        for alpha in alphas:
            stats = {m: [] for m in ["-", "=", "#", "observed"]}
            for _ in range(k_random):
                v = rng.normal(size=(H,)).astype(np.float32)
                if orth_to_dir and w_hat is not None:
                    v = v - float(v @ w_hat) * w_hat
                # Epsilon guards against a (near-)zero vector after projection.
                v /= np.linalg.norm(v) + 1e-8

                # one pass with random direction
                ste_batches = []
                for b, mask in zip(batches, pos_masks):
                    ed = AddEditor(
                        layer=layer_id,
                        kind=HookType.RESID_PRE,
                        w=v,
                        alpha=float(alpha),
                        pos_mask=mask,
                    )
                    logits, _ = run_with_hooks(
                        jnp.asarray(b.inputs),
                        jnp.asarray(b.positions),
                        params,
                        cfg,
                        hook_pairs=(),
                        editor=ed,
                    )
                    ste_batches.append(np.asarray(logits))
                ste_logits = np.concatenate(ste_batches, axis=0)

                # stats
                for e in scored:
                    r, t = e.batch, e.pred_idx
                    base = base_logits[r, t, :]
                    ste = ste_logits[r, t, :]
                    for sym, tid in ids.items():
                        stats[sym].append(float(ste[tid] - base[tid]))
                    if e.token_at_pred in ids:
                        tid = ids[e.token_at_pred]
                        stats["observed"].append(float(ste[tid] - base[tid]))

            for m, vals in stats.items():
                arr = np.array(vals, np.float64)
                mean = float(arr.mean()) if arr.size else 0.0
                lo, hi = np.percentile(arr, [2.5, 97.5]) if arr.size else (0.0, 0.0)
                f.write(
                    f"{layer_id},{alpha},{mode},{m},{k_random},{mean},{float(lo)},{float(hi)}\n"
                )
    print(f"[null] wrote → {out_path}")


def run_projectout_lesion(
    model_dir,
    ckpt_id,
    dataset_dir,
    layer_id,
    num_examples,
    batch_size,
    seq_length,
    out_dir,
    direction_path=None,
):
    """Lesion test: run with a ``ProjectOutEditor`` and report logit shifts.

    The model is re-run once with a ``ProjectOutEditor`` built from ``w_hat``
    at ``HookType.RESID_PRE`` of ``layer_id`` (no position mask is passed).
    At every scorable explicit event the lesioned-minus-baseline logit
    difference is collected for each bond token in ``bond_token_ids`` and,
    when ``e.token_at_pred`` is one of them, for that "observed" token.
    Per-metric counts and float32 means are written to a CSV.

    Args:
        model_dir, ckpt_id, dataset_dir, layer_id, num_examples, batch_size,
            seq_length: forwarded unchanged to ``prepare_batches``.
        out_dir: directory for ``projectout_collapse_L{layer_id}.csv``
            (created if missing).
        direction_path: optional JSON file with a ``"w_hat"`` list; used when
            the path exists, otherwise the direction is rebuilt from data.
    """
    from experiments.valence.core import run_with_hooks

    tokenizer, cfg, params, batches, all_inputs, _targets, all_pos = prepare_batches(
        model_dir, ckpt_id, dataset_dir, layer_id, num_examples, batch_size, seq_length
    )
    explicit = _collect_explicit(tokenizer, all_inputs)
    resid_pre_batches = [b.resid_pre for b in batches]
    if direction_path and os.path.exists(direction_path):
        # Context manager so the handle is closed deterministically.
        with open(direction_path) as fh:
            obj = json.load(fh)
        w_hat = np.array(obj["w_hat"], np.float32)
    else:
        w_hat, _ = build_direction_for_layer(resid_pre_batches, explicit, threshold=2)

    editor = ProjectOutEditor(layer=layer_id, kind=HookType.RESID_PRE, w_hat=w_hat)
    ste_batches = []
    for b in batches:
        logits, _ = run_with_hooks(
            jnp.asarray(b.inputs),
            jnp.asarray(b.positions),
            params,
            cfg,
            hook_pairs=(),
            editor=editor,
        )
        ste_batches.append(np.asarray(logits))
    ste_logits = np.concatenate(ste_batches, axis=0)

    ids = bond_token_ids(tokenizer)
    base_logits = np.concatenate([b.logits for b in batches], axis=0)
    diffs = {"-": [], "=": [], "#": [], "observed": []}
    for e in explicit:
        r, t = e.batch, e.pred_idx
        # Skip events without a valid position at the prediction index.
        if t < 0 or all_pos[r, t] < 0:
            continue
        base = base_logits[r, t, :]
        ste = ste_logits[r, t, :]
        for sym, tid in ids.items():
            diffs[sym].append(float(ste[tid] - base[tid]))
        if e.token_at_pred in ids:
            tid = ids[e.token_at_pred]
            diffs["observed"].append(float(ste[tid] - base[tid]))

    out_path = os.path.join(out_dir, f"projectout_collapse_L{layer_id}.csv")
    os.makedirs(out_dir, exist_ok=True)
    with open(out_path, "w") as f:
        f.write("layer,metric,n,mean\n")
        for metric, vals in diffs.items():
            arr = np.array(vals, np.float32)
            f.write(
                f"{layer_id},{metric},{arr.size},{float(arr.mean()) if arr.size else 0.0}\n"
            )
    print(f"[lesion] wrote → {out_path}")


# ------------------------ CLI ------------------------
def main():
    """Parse command-line options and run the selected experiments.

    Each ``--run_*`` flag enables one experiment. Enabled experiments run in
    a fixed order: decision metrics, context slices, bootstrap, null tests,
    project-out lesion. With no ``--run_*`` flag nothing is executed.
    """
    ap = argparse.ArgumentParser(description="Decision metrics + robustness suite")
    ap.add_argument("--model_dir", required=True)
    ap.add_argument("--ckpt_id", type=int, required=True)
    ap.add_argument("--dataset_dir", required=True)
    ap.add_argument("--layer_id", type=int, default=3)
    ap.add_argument("--num_examples", type=int, default=4096)
    ap.add_argument("--batch_size", type=int, default=512)
    ap.add_argument("--seq_length", type=int, default=256)
    ap.add_argument("--out_dir", default="experiments/decision_robustness_out")
    # argparse applies `type` only to string defaults, so the defaults are
    # written as floats; otherwise CSVs would show "-2" by default but
    # "-2.0" when the same value is passed on the command line.
    ap.add_argument(
        "--alphas",
        type=float,
        nargs="+",
        default=[-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0],
    )
    ap.add_argument("--num_bootstrap", type=int, default=2000)
    ap.add_argument("--k_random", type=int, default=64)
    ap.add_argument("--direction_path", type=str, default=None)

    ap.add_argument("--run_decision_metrics", action="store_true")
    ap.add_argument("--run_context_slices", action="store_true")
    ap.add_argument("--run_bootstrap", action="store_true")
    ap.add_argument("--run_null_tests", action="store_true")
    ap.add_argument("--run_projectout_lesion", action="store_true")

    args = ap.parse_args()

    # Leading positional arguments shared by every experiment entry point.
    common = (
        args.model_dir,
        args.ckpt_id,
        args.dataset_dir,
        args.layer_id,
        args.num_examples,
        args.batch_size,
        args.seq_length,
        args.out_dir,
    )

    if args.run_decision_metrics:
        run_decision_metrics(*common, args.alphas)
    if args.run_context_slices:
        run_context_slices(*common, args.alphas)
    if args.run_bootstrap:
        run_bootstrap(
            *common,
            args.alphas,
            num_bootstrap=args.num_bootstrap,
            direction_path=args.direction_path,
        )
    if args.run_null_tests:
        run_null_tests(
            *common,
            args.alphas,
            k_random=args.k_random,
            direction_path=args.direction_path,
            orth_to_dir=True,
        )
    if args.run_projectout_lesion:
        run_projectout_lesion(*common, direction_path=args.direction_path)


if __name__ == "__main__":
    main()
