# experiments/valence/probes_and_causality.py
from __future__ import annotations
import argparse, json, math, os
import numpy as np
import jax.numpy as jnp
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

from experiments.smiles_events import _tok_str
from lmkit.impl.hooks import HookType
from experiments.valence.core import (
    run_with_hooks, pad_time_np, extract_valence_events,
    build_direction_for_layer, bond_token_ids
)
from lmkit.impl import config as config_lib
from lmkit.impl import transformer
from lmkit.impl.caching import TransformerCache
from lmkit.tools import compat, train_utils
from lmkit.tools import data as data_tools


# ------------------------ Probes (valence budget) ------------------------
def run_valence_budget(model_dir, ckpt_id, dataset_dir, num_examples, batch_size, out_dir, seq_length=256, max_events=60000):
    os.makedirs(out_dir, exist_ok=True)

    tokenizer = compat.load_tokenizer(
        tokenizer_path=f"{model_dir}/tokenizer.json",
        generation_config_file=f"{model_dir}/generation_config.json",
        mode="train", trunc_length=seq_length,
    )
    cfg = config_lib.load_from_dir(model_dir).copy(
        dict(bos_id=tokenizer.bos_token_id, eos_id=tokenizer.eos_token_id, pad_id=tokenizer.pad_token_id)
    )
    params, *_ = train_utils.load_checkpoint(f"{model_dir}/checkpoints/checkpoint_{ckpt_id}.pkl")

    ds = data_tools.load_and_tokenize(
        dataset_dir=dataset_dir, tokenizer=tokenizer, batch_size=batch_size,
        num_processes=8, seed=2002, target_column="smiles", caching=True, limit=num_examples
    )

    hook_pairs = [(i, HookType.RESID_PRE) for i in range(cfg["num_layers"])]
    batch_inputs, batch_targets, batch_positions = [], [], []
    resid_pre_per_batch = []

    processed = 0
    for batch in ds:
        inputs = jnp.asarray(batch["inputs"])
        targets = jnp.asarray(batch["targets"])
        positions = jnp.asarray(batch["positions"])
        _logits, captured = run_with_hooks(inputs, positions, params, cfg, hook_pairs)

        batch_inputs.append(np.asarray(inputs))
        batch_targets.append(np.asarray(targets))
        batch_positions.append(np.asarray(positions))
        resid_pre_per_batch.append({k: np.asarray(v, np.float32) for k, v in captured.items()})

        processed += inputs.shape[0]
        if processed >= num_examples: break

    Ts = [arr.shape[1] for arr in batch_inputs]
    if len(set(Ts)) > 1:
        T_max = max(Ts); pad_id = tokenizer.pad_token_id
        for i in range(len(batch_inputs)):
            batch_inputs[i]    = pad_time_np(batch_inputs[i],    T_max, pad_id)
            batch_targets[i]   = pad_time_np(batch_targets[i],   T_max, 0)
            batch_positions[i] = pad_time_np(batch_positions[i], T_max, -1)
            resid_pre_per_batch[i] = {k: pad_time_np(v, T_max, 0.0) for k, v in resid_pre_per_batch[i].items()}

    all_inputs = np.concatenate(batch_inputs, axis=0)
    events_all, _dbg = extract_valence_events(tokenizer, all_inputs)
    if len(events_all) > max_events: events_all = events_all[:max_events]

    y = np.array([int(max(0, min(4, math.floor(ev.remaining_before)))) for ev in events_all], np.int32)
    pred_positions = [(ev.batch, ev.pred_idx) for ev in events_all]

    rows = [dict(event=ev.event_type, y=int(max(0, min(4, math.floor(ev.remaining_before))))) for ev in events_all]
    pd.DataFrame(rows).value_counts(["event", "y"]).reset_index(name="count").to_csv(
        os.path.join(out_dir, "valence_events_summary.csv"), index=False
    )

    L = int(cfg["num_layers"]); probe_rows = []
    for layer in range(L):
        resid_batches = [cap[(layer, HookType.RESID_PRE)] for cap in resid_pre_per_batch]
        resid = np.concatenate(resid_batches, axis=0).astype(np.float32, copy=False)
        B,T,H = resid.shape
        b_idx = np.array([b for (b,t) in pred_positions], np.int32)
        t_idx = np.array([t for (b,t) in pred_positions], np.int32)
        ok = (t_idx >= 0) & (t_idx < T)
        b_idx, t_idx = b_idx[ok], t_idx[ok]
        y_layer = y[ok]
        X = resid[b_idx, t_idx, :].astype(np.float32, copy=False)
        if X.shape[0] < 200:
            probe_rows.append(dict(layer=layer, n=int(X.shape[0]), acc=0.0, f1_macro=0.0)); continue
        Xtr, Xte, ytr, yte = train_test_split(X, y_layer, test_size=0.3, random_state=42, stratify=y_layer)
        clf = LogisticRegression(max_iter=1000, solver="lbfgs", multi_class="multinomial")
        clf.fit(Xtr, ytr)
        yhat = clf.predict(Xte)
        probe_rows.append(dict(layer=layer, n=int(X.shape[0]), acc=float(accuracy_score(yte,yhat)), f1_macro=float(f1_score(yte,yhat,average="macro"))))
    pd.DataFrame(probe_rows).to_csv(os.path.join(out_dir, "valence_probe_by_layer.csv"), index=False)

    # small debug
    debug_examples = []
    for k, ev in enumerate(events_all[:20]):
        seq_tokens = [_tok_str(tokenizer, int(t)) for t in all_inputs[ev.batch]]
        debug_examples.append(dict(
            event=ev.event_type, pred_idx=int(ev.pred_idx), token_at_pred=ev.token_at_pred,
            remaining=float(ev.remaining_before), allowed=int(ev.allowed),
            consumed=float(ev.consumed_before), context=ev.context,
            prefix="".join(seq_tokens[: ev.pred_idx + 1])[:256],
            next_tok=seq_tokens[ev.pred_idx + 1] if ev.pred_idx + 1 < len(seq_tokens) else "<EOS>",
        ))
    with open(os.path.join(out_dir, "examples_debug.json"), "w") as f:
        json.dump(debug_examples, f, indent=2)
    print("[budget] wrote: valence_events_summary.csv, valence_probe_by_layer.csv, examples_debug.json")


# ------------------------ Causality (Δlogits vs α) ------------------------
def measure_delta_logits(model_dir, ckpt_id, dataset_dir, layer_id, alphas, num_examples=4096, batch_size=512, seq_length=256, out_dir="experiments/valence_causality_out"):
    os.makedirs(out_dir, exist_ok=True)

    tokenizer = compat.load_tokenizer(
        tokenizer_path=f"{model_dir}/tokenizer.json",
        generation_config_file=f"{model_dir}/generation_config.json",
        mode="train", trunc_length=seq_length,
    )
    cfg = config_lib.load_from_dir(model_dir).copy(
        dict(bos_id=tokenizer.bos_token_id, eos_id=tokenizer.eos_token_id, pad_id=tokenizer.pad_token_id)
    )
    params, *_ = train_utils.load_checkpoint(f"{model_dir}/checkpoints/checkpoint_{ckpt_id}.pkl")

    ds = data_tools.load_and_tokenize(
        dataset_dir=dataset_dir, tokenizer=tokenizer, batch_size=batch_size, num_processes=8,
        seed=2002, target_column="smiles", caching=True, limit=num_examples,
    )

    hook_pairs = [(layer_id, HookType.RESID_PRE)]
    batch_inputs, batch_positions, resid_pre_batches, logits_ref_batches = [], [], [], []
    processed = 0
    for batch in ds:
        inputs = jnp.asarray(batch["inputs"])
        positions = jnp.asarray(batch["positions"])
        logits_ref, captured = run_with_hooks(inputs, positions, params, cfg, hook_pairs)
        batch_inputs.append(np.asarray(inputs))
        batch_positions.append(np.asarray(positions))
        resid_pre_batches.append(np.asarray(captured[(layer_id, HookType.RESID_PRE)], np.float32))
        logits_ref_batches.append(np.asarray(logits_ref))
        processed += inputs.shape[0]
        if processed >= num_examples: break

    Ts = [arr.shape[1] for arr in batch_inputs]
    if len(set(Ts)) > 1:
        T_max = max(Ts); pad_id = tokenizer.pad_token_id
        for i in range(len(batch_inputs)):
            batch_inputs[i]    = pad_time_np(batch_inputs[i],    T_max, pad_id)
            batch_positions[i] = pad_time_np(batch_positions[i], T_max, -1)
            resid_pre_batches[i] = pad_time_np(resid_pre_batches[i], T_max, 0.0)
            logits_ref_batches[i] = pad_time_np(logits_ref_batches[i], T_max, 0.0)

    all_inputs  = np.concatenate(batch_inputs, axis=0)
    all_pos     = np.concatenate(batch_positions, axis=0)
    logits_ref  = np.concatenate(logits_ref_batches, axis=0)

    events_all, _ = extract_valence_events(tokenizer, all_inputs)
    explicit_events = [ev for ev in events_all if ev.event_type == "explicit"]
    if len(explicit_events) < 100:
        print(f"[warn] few explicit events: {len(explicit_events)}")

    w_hat, info = build_direction_for_layer(resid_pre_batches, explicit_events, threshold=2)
    print(f"[dir] layer={layer_id} N={info['N']} pos={info['pos']} neg={info['neg']}")

    ids = bond_token_ids(tokenizer)
    results = []
    for alpha in alphas:
        # lightweight additive editor via hooks (= inject alpha*w_hat at RESID_PRE)
        from experiments.valence.core import AddEditor  # lazy import
        editor = AddEditor(layer=layer_id, kind=HookType.RESID_PRE, w=w_hat, alpha=float(alpha))
        logits_ste_batches = []
        for inputs, positions in zip(batch_inputs, batch_positions):
            cache = TransformerCache.create(jnp.asarray(positions), cfg, dtype=jnp.bfloat16, dynamic=False)
            l, *_ = transformer.run(jnp.asarray(inputs), cache, params, cfg, hooks_to_return=(), hooks_to_stream=frozenset(), editor=editor)
            logits_ste_batches.append(np.asarray(l))
        logits_ste = np.concatenate(logits_ste_batches, axis=0)

        diffs = {"-": [], "=": [], "#": []}
        for ev in explicit_events:
            b, t = ev.batch, ev.pred_idx
            if t < 0 or all_pos[b, t] < 0: continue
            base = logits_ref[b, t, :]; ste = logits_ste[b, t, :]
            for sym, tid in ids.items():
                diffs[sym].append(float(ste[tid] - base[tid]))
        results.append(dict(
            alpha=float(alpha),
            dlogit_single = float(np.mean(diffs["-"])) if diffs["-"] else 0.0,
            dlogit_double  = float(np.mean(diffs["="])) if diffs["="] else 0.0,
            dlogit_triple = float(np.mean(diffs["#"])) if diffs["#"] else 0.0,
            n_positions = int(sum(len(v) for v in diffs.values()) // 3),
        ))

    out_path = os.path.join(out_dir, f"valence_causality_L{layer_id}.json")
    with open(out_path, "w") as f:
        json.dump(dict(layer=layer_id, results=results), f, indent=2)
    print(f"[causality] wrote → {out_path}")


# ------------------------ CLI ------------------------
def main():
    """Command-line entry point for the valence experiments.

    ``--run_budget`` runs the per-layer valence-budget probes, writing under
    ``<out_dir>/budget``. ``--run_causality`` runs the steering sweep over
    ``--alphas`` at ``--layer_id``, writing under ``<out_dir>/causality``.
    At least one of the two flags is required.
    """
    ap = argparse.ArgumentParser(description="Valence: (A) probes, (B) causality")
    ap.add_argument("--model_dir", required=True)
    ap.add_argument("--ckpt_id", type=int, required=True)
    ap.add_argument("--dataset_dir", required=True)
    ap.add_argument("--num_examples", type=int, default=4096)
    ap.add_argument("--batch_size", type=int, default=512)
    ap.add_argument("--seq_length", type=int, default=256)
    ap.add_argument("--out_dir", default="experiments/valence_out")

    ap.add_argument("--run_budget", action="store_true",
                    help="run the per-layer valence-budget probes")
    ap.add_argument("--run_causality", action="store_true",
                    help="run the steering sweep (delta logits vs alpha)")
    ap.add_argument("--max_events", type=int, default=60000,
                    help="maximum number of valence events used by --run_budget")
    ap.add_argument("--layer_id", type=int, default=3)
    ap.add_argument("--alphas", type=float, nargs="+",
                    default=[-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0])

    args = ap.parse_args()
    if not (args.run_budget or args.run_causality):
        # Fail loudly instead of exiting successfully without doing anything.
        ap.error("nothing to do: pass --run_budget and/or --run_causality")

    if args.run_budget:
        run_valence_budget(args.model_dir, args.ckpt_id, args.dataset_dir,
                           args.num_examples, args.batch_size, os.path.join(args.out_dir, "budget"),
                           seq_length=args.seq_length, max_events=args.max_events)
    if args.run_causality:
        # measure_delta_logits creates its own output directory.
        measure_delta_logits(args.model_dir, args.ckpt_id, args.dataset_dir,
                             args.layer_id, args.alphas,
                             num_examples=args.num_examples, batch_size=args.batch_size,
                             seq_length=args.seq_length, out_dir=os.path.join(args.out_dir, "causality"))


if __name__ == "__main__":
    main()
