# experiments/valence/localize.py
from __future__ import annotations
import argparse, json, os, math
import numpy as np
import jax
import jax.numpy as jnp
import pandas as pd

from lmkit.impl import transformer
from lmkit.impl.hooks import HookType
from lmkit.impl.caching import TransformerCache, build_rope
from experiments.valence.core import (
    prepare_batches,
    extract_valence_events,
    build_direction_for_layer,
    head_postov_outputs_for_positions,
    ablate_heads_in_Wo,
    edit_Wo_along_wh,
    bond_token_ids,
)


# ---------- helpers ----------
def _explicit_positions(tokenizer, all_inputs):
    """Select explicit valence events that have a usable prediction index.

    Runs ``extract_valence_events`` on the inputs and keeps only events whose
    ``event_type`` is ``"explicit"`` and whose ``pred_idx`` is non-negative.

    Returns:
        ``(events, positions)`` where ``positions[i]`` is the
        ``(batch, pred_idx)`` pair of ``events[i]``.
    """
    all_events, _ = extract_valence_events(tokenizer, all_inputs)
    kept_events = []
    coords = []
    for event in all_events:
        if event.event_type == "explicit" and event.pred_idx >= 0:
            kept_events.append(event)
            coords.append((event.batch, event.pred_idx))
    return kept_events, coords


# ---------- (1) head ablation ----------
def run_head_ablation(
    model_dir,
    ckpt_id,
    dataset_dir,
    layer_id,
    num_examples,
    batch_size,
    seq_length,
    out_dir,
):
    """Measure the bond-token logit drop caused by ablating each head of a layer.

    For every head ``h`` of ``layer_id`` the parameters are edited with
    ``ablate_heads_in_Wo``, the model is re-run on all batches, and the mean
    ``baseline - ablated`` logit of the "-", "=" and "#" tokens is taken over
    the explicit-event positions. One row per head is written to
    ``<out_dir>/localize_head_ablation_L<layer_id>.csv``.

    Args:
        model_dir, ckpt_id, dataset_dir, layer_id, num_examples, batch_size,
            seq_length: forwarded unchanged to ``prepare_batches``.
        out_dir: output directory; created if missing.
    """
    os.makedirs(out_dir, exist_ok=True)
    tokenizer, cfg, params, batches, all_inputs, _targets, _positions = prepare_batches(
        model_dir, ckpt_id, dataset_dir, layer_id, num_examples, batch_size, seq_length
    )
    _events, pos_list = _explicit_positions(tokenizer, all_inputs)
    ids = bond_token_ids(tokenizer)
    # (csv column, token id) pairs, in the column order of the output file.
    bond_columns = (
        ("drop_minus", ids["-"]),
        ("drop_eq", ids["="]),
        ("drop_hash", ids["#"]),
    )

    # Integer index arrays so all event positions are gathered in one
    # NumPy advanced-indexing call instead of a Python loop per event.
    ex_idx = np.array([b for b, _ in pos_list], dtype=np.intp)
    tok_pos = np.array([t for _, t in pos_list], dtype=np.intp)

    # The baseline does not depend on the ablated head: gather it once.
    base_logits = np.concatenate([b.logits for b in batches], axis=0)
    base_at = {col: base_logits[ex_idx, tok_pos, tok] for col, tok in bond_columns}
    del base_logits  # only the gathered values are needed from here on

    def logits_with(edited_params):
        """Run every batch through the model with ``edited_params``."""
        outs = []
        for batch in batches:
            cache = TransformerCache.create(
                jnp.asarray(batch.positions), cfg, dtype=jnp.bfloat16, dynamic=False
            )
            # transformer.run returns a tuple whose first element is the logits.
            logits, *_ = transformer.run(
                jnp.asarray(batch.inputs), cache, edited_params, cfg
            )
            outs.append(np.asarray(logits))
        return np.concatenate(outs, axis=0)

    rows = []
    for h in range(int(cfg["num_heads"])):
        ablated_params = ablate_heads_in_Wo(
            params, [(layer_id, h)], cfg["hidden_size"], cfg["num_heads"]
        )
        ablated_logits = logits_with(ablated_params)

        row = dict(layer=layer_id, head=h, n=len(pos_list))
        for col, tok in bond_columns:
            if pos_list:
                drop = base_at[col] - ablated_logits[ex_idx, tok_pos, tok]
                row[col] = float(np.mean(drop))
            else:
                # No explicit events: report 0.0 rather than a NaN mean.
                row[col] = 0.0
        rows.append(row)

    pd.DataFrame(rows).to_csv(
        os.path.join(out_dir, f"localize_head_ablation_L{layer_id}.csv"), index=False
    )
    print(f"[ablation] wrote → localize_head_ablation_L{layer_id}.csv")


# ---------- (2) head–direction alignment ----------
def run_head_alignment(
    model_dir,
    ckpt_id,
    dataset_dir,
    layer_id,
    num_examples,
    batch_size,
    seq_length,
    out_dir,
    align_max_events=4096,
    align_event_chunk=512,
    align_head_chunk=2,
):
    """
    Streaming head–direction alignment:
      - samples up to `align_max_events` explicit decisions
      - processes them in chunks of `align_event_chunk`
      - processes `align_head_chunk` heads at a time (Q/K/V projections and
        RoPE tables of an event are shared by the heads of one chunk)

    For every head of `layer_id` the post-OV output at each sampled event is
    compared (cosine similarity) with the direction returned by
    `build_direction_for_layer`. Exactly one row per head, aggregated over
    all event chunks, is written to
    `<out_dir>/localize_head_alignment_L<layer_id>.csv`; the fitted direction
    is saved to `<out_dir>/valence_direction_L<layer_id>.json`.

    Raises:
        ValueError: if `align_event_chunk` or `align_head_chunk` is < 1.
    """
    if align_event_chunk < 1 or align_head_chunk < 1:
        raise ValueError("align_event_chunk and align_head_chunk must be >= 1")

    os.makedirs(out_dir, exist_ok=True)
    out_csv = os.path.join(out_dir, f"localize_head_alignment_L{layer_id}.csv")
    columns = ["layer", "head", "n", "cos_mean", "cos_median", "frac_positive"]

    tokenizer, cfg, params, batches, all_inputs, _targets, all_pos = prepare_batches(
        model_dir, ckpt_id, dataset_dir, layer_id, num_examples, batch_size, seq_length
    )
    explicit, pos_list = _explicit_positions(tokenizer, all_inputs)
    if not explicit:
        print("[align] no explicit events")
        # Write the header so the file can still be parsed by pd.read_csv.
        pd.DataFrame(columns=columns).to_csv(out_csv, index=False)
        return

    # Fit ŵ on this layer (uses RESID_PRE)
    resid_pre_batches = [b.resid_pre for b in batches]
    w_hat, info = build_direction_for_layer(resid_pre_batches, explicit, threshold=2)
    print(f"[align] dir N={info['N']} pos={info['pos']} neg={info['neg']}")
    # Unit-norm direction; loop-invariant, so computed once.
    w_unit = (w_hat / (np.linalg.norm(w_hat) + 1e-8)).astype(np.float32)

    # Sample events to a manageable size (fixed seed for reproducibility)
    rng = np.random.default_rng(42)
    ev_idx = np.arange(len(pos_list))
    if len(ev_idx) > align_max_events:
        ev_idx = rng.choice(ev_idx, size=align_max_events, replace=False)
    sampled = [pos_list[i] for i in ev_idx]

    event_chunks = [
        sampled[i : i + align_event_chunk]
        for i in range(0, len(sampled), align_event_chunk)
    ]

    # Pre-extract attn params
    attn = params["layers"][layer_id]["attn"]
    W_q, W_k, W_v, W_o = attn["W_q"], attn["W_k"], attn["W_v"], attn["W_o"]

    # Unified residual stream (bf16) and positions for this layer
    resid = jnp.asarray(np.concatenate(resid_pre_batches, axis=0)).astype(jnp.bfloat16)
    pos = jnp.asarray(all_pos)
    # Number of non-negative position entries per example, computed once on
    # the host instead of one device reduction per (event, head).
    valid_mask = np.asarray(pos >= 0)
    valid_len = valid_mask.reshape(valid_mask.shape[0], -1).sum(axis=1)

    H = int(cfg["num_heads"])
    hidden = int(cfg["hidden_size"])
    D = hidden // H
    half = D // 2
    inv_sqrt_d = 1.0 / math.sqrt(D)

    head_chunks = [
        list(range(h, min(h + align_head_chunk, H)))
        for h in range(0, H, align_head_chunk)
    ]

    def postov_for_heads(b, t, heads):
        """Post-OV output of each head in `heads` for the query at (b, t).

        Returns a list of float32 vectors aligned with `heads`, or None when
        there is no key/value position to attend to.
        """
        kv_len = min(int(valid_len[b]), t + 1)
        if kv_len <= 0:
            return None
        y_b = resid[b, :kv_len, :]  # keys/values: prefix up to the query
        y_t = resid[b, t, :]  # query position

        # Head-independent projections: done once per event, sliced per head.
        q = (y_t @ W_q).reshape(H, D)
        k = (y_b @ W_k).reshape(kv_len, H, D)
        v = (y_b @ W_v).reshape(kv_len, H, D)

        # RoPE tables for this example only
        sin, cos = build_rope(pos[b : b + 1], D, cfg["rope_base"])
        sin_q, cos_q = sin[0, t], cos[0, t]
        sin_k, cos_k = sin[0, :kv_len], cos[0, :kv_len]

        outs = []
        for h in heads:
            q_h, k_h, v_h = q[h], k[:, h, :], v[:, h, :]
            # rotate-half form: x * cos + concat(-x2, x1) * sin
            q_rot = q_h * cos_q + jnp.concatenate([-q_h[half:], q_h[:half]]) * sin_q
            k_rot = (
                k_h * cos_k
                + jnp.concatenate([-k_h[:, half:], k_h[:, :half]], axis=-1) * sin_k
            )
            attn_w = jax.nn.softmax((k_rot @ q_rot) * inv_sqrt_d, axis=-1)
            pre = jnp.sum(attn_w[:, None] * v_h, axis=0)  # (D,)
            post = pre @ W_o[h * D : (h + 1) * D, :]  # this head's rows of W_o
            outs.append(np.asarray(post, dtype=np.float32))
        return outs

    def summarize(h, cos_parts):
        """Build the CSV row for head `h` from its per-chunk cosine arrays."""
        if not cos_parts:
            return dict(
                layer=layer_id,
                head=h,
                n=0,
                cos_mean=0.0,
                cos_median=0.0,
                frac_positive=0.0,
            )
        cos_all = np.concatenate(cos_parts)
        return dict(
            layer=layer_id,
            head=h,
            n=int(cos_all.shape[0]),
            cos_mean=float(np.mean(cos_all)),
            cos_median=float(np.median(cos_all)),
            frac_positive=float(np.mean(cos_all > 0)),
        )

    rows = []
    for hc in head_chunks:
        # Only scalar cosines are kept across chunks, so memory stays small
        # while every head still gets a single aggregated row.
        cos_by_head = {h: [] for h in hc}
        for ev_chunk in event_chunks:
            vecs = {h: [] for h in hc}
            for b, t in ev_chunk:
                outs = postov_for_heads(b, t, hc)
                if outs is None:
                    continue
                for h, vec in zip(hc, outs):
                    vecs[h].append(vec)
            for h in hc:
                if not vecs[h]:
                    continue
                V = np.stack(vecs[h], 0)  # (n, hidden)
                norms = np.linalg.norm(V, axis=1) + 1e-8
                cos_by_head[h].append((V @ w_unit) / norms)
        for h in hc:
            rows.append(summarize(h, cos_by_head[h]))

    pd.DataFrame(rows, columns=columns).to_csv(out_csv, index=False)

    with open(os.path.join(out_dir, f"valence_direction_L{layer_id}.json"), "w") as f:
        json.dump(dict(layer=layer_id, w_hat=w_hat.tolist(), **info), f, indent=2)
    print(f"[align] wrote → {out_csv} and valence_direction_L{layer_id}.json")


# ---------- (3) single-head OV project-out / scale ----------
def run_head_project_out(
    model_dir,
    ckpt_id,
    dataset_dir,
    layer_id,
    heads,
    alphas,
    num_examples,
    batch_size,
    seq_length,
    out_dir,
    w_hat_json=None,
):
    """Edit single heads' W_o along ŵ and record the bond-token logit shift.

    For each ``(head, alpha)`` pair the parameters are edited with
    ``edit_Wo_along_wh``, the model is re-run on all batches, and the mean
    ``edited - baseline`` logit of the "-", "=" and "#" tokens is taken over
    the explicit-event positions. Results are written to
    ``<out_dir>/localize_project_out_L<layer_id>.csv``.

    Args:
        heads: iterable of head indices to edit.
        alphas: iterable of values passed as ``alpha`` to ``edit_Wo_along_wh``.
        w_hat_json: optional path to a JSON file with a ``"w_hat"`` entry.
            When omitted, or when the file does not exist, the direction is
            refitted with ``build_direction_for_layer``.
        Remaining arguments are forwarded to ``prepare_batches``; ``out_dir``
        is created if missing.
    """
    os.makedirs(out_dir, exist_ok=True)
    tokenizer, cfg, params, batches, all_inputs, _targets, _positions = prepare_batches(
        model_dir, ckpt_id, dataset_dir, layer_id, num_examples, batch_size, seq_length
    )
    explicit, pos_list = _explicit_positions(tokenizer, all_inputs)
    if w_hat_json and os.path.exists(w_hat_json):
        with open(w_hat_json, "r") as f:
            w_hat = np.array(json.load(f)["w_hat"], np.float32)
    else:
        if w_hat_json:
            # A wrong path would otherwise silently change which direction is used.
            print(
                f"[project-out] {w_hat_json} not found; refitting direction from data"
            )
        w_hat, _ = build_direction_for_layer(
            [b.resid_pre for b in batches], explicit, threshold=2
        )

    ids = bond_token_ids(tokenizer)
    # (csv column, token id) pairs, in the column order of the output file.
    bond_columns = (
        ("dlogit_minus", ids["-"]),
        ("dlogit_eq", ids["="]),
        ("dlogit_hash", ids["#"]),
    )

    # Integer index arrays so all event positions are gathered in one
    # NumPy advanced-indexing call instead of a Python loop per event.
    ex_idx = np.array([b for b, _ in pos_list], dtype=np.intp)
    tok_pos = np.array([t for _, t in pos_list], dtype=np.intp)

    # The baseline is the same for every (head, alpha): gather it once.
    base_logits = np.concatenate([b.logits for b in batches], axis=0)
    base_at = {
        col: base_logits[ex_idx, tok_pos, tok].astype(np.float32)
        for col, tok in bond_columns
    }
    del base_logits  # only the gathered values are needed from here on

    def logits_with(edited_params):
        """Run every batch through the model with ``edited_params``."""
        outs = []
        for batch in batches:
            cache = TransformerCache.create(
                jnp.asarray(batch.positions), cfg, dtype=jnp.bfloat16, dynamic=False
            )
            # transformer.run returns a tuple whose first element is the logits.
            logits, *_ = transformer.run(
                jnp.asarray(batch.inputs), cache, edited_params, cfg
            )
            outs.append(np.asarray(logits))
        return np.concatenate(outs, axis=0)

    rows = []
    for h in heads:
        for alpha in alphas:
            edited_params = edit_Wo_along_wh(
                params, layer_id, h, w_hat=w_hat, alpha=float(alpha)
            )
            edited_logits = logits_with(edited_params)

            row = dict(layer=layer_id, head=h, alpha=float(alpha), n=len(pos_list))
            for col, tok in bond_columns:
                base = base_at[col]
                if base.size:
                    edited = edited_logits[ex_idx, tok_pos, tok].astype(np.float32)
                    row[col] = float(np.mean(edited - base))
                else:
                    # No explicit events: report 0.0 rather than a NaN mean.
                    row[col] = 0.0
            rows.append(row)

    pd.DataFrame(rows).to_csv(
        os.path.join(out_dir, f"localize_project_out_L{layer_id}.csv"), index=False
    )
    print(f"[project-out] wrote → localize_project_out_L{layer_id}.csv")


# ---------- CLI ----------
def main(argv=None):
    """Command-line entry point for the valence localization experiments.

    Parses the options, then runs whichever of the three analyses were
    requested (``--run_ablation``, ``--run_head_align``,
    ``--run_project_out``), each at most once and in that order.

    Args:
        argv: optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Raises:
        SystemExit: on argparse errors, or when ``--run_project_out`` is given
            without ``--heads``.
    """
    ap = argparse.ArgumentParser(
        description="Valence localization: ablation, alignment, project-out"
    )
    ap.add_argument("--model_dir", required=True)
    ap.add_argument("--ckpt_id", type=int, required=True)
    ap.add_argument("--dataset_dir", required=True)
    ap.add_argument("--layer_id", type=int, default=3)
    ap.add_argument("--num_examples", type=int, default=4096)
    ap.add_argument("--batch_size", type=int, default=512)
    ap.add_argument("--seq_length", type=int, default=256)
    ap.add_argument("--out_dir", default="experiments/localize_valence_out")

    ap.add_argument("--run_ablation", action="store_true")
    ap.add_argument("--run_head_align", action="store_true")
    ap.add_argument("--run_project_out", action="store_true")
    ap.add_argument("--heads", type=int, nargs="+", default=None)
    ap.add_argument(
        "--alphas", type=float, nargs="+", default=[-1.0, -0.5, 0.5, 1.0, 2.0]
    )
    ap.add_argument(
        "--align_max_events",
        type=int,
        default=4096,
        help="Max explicit events to use for alignment.",
    )
    ap.add_argument(
        "--align_event_chunk",
        type=int,
        default=512,
        help="Process events in chunks to control VRAM.",
    )
    ap.add_argument(
        "--align_head_chunk",
        type=int,
        default=2,
        help="Process this many heads at a time.",
    )

    ap.add_argument("--w_hat_json", type=str, default=None)

    args = ap.parse_args(argv)

    # Validate before any expensive run so a missing --heads fails immediately
    # instead of after the ablation/alignment passes have completed.
    if args.run_project_out and args.heads is None:
        raise SystemExit("--run_project_out requires --heads")

    if args.run_ablation:
        run_head_ablation(
            args.model_dir,
            args.ckpt_id,
            args.dataset_dir,
            args.layer_id,
            args.num_examples,
            args.batch_size,
            args.seq_length,
            args.out_dir,
        )

    # Run the alignment exactly once, honouring the --align_* options.
    if args.run_head_align:
        run_head_alignment(
            args.model_dir,
            args.ckpt_id,
            args.dataset_dir,
            args.layer_id,
            args.num_examples,
            args.batch_size,
            args.seq_length,
            args.out_dir,
            align_max_events=args.align_max_events,
            align_event_chunk=args.align_event_chunk,
            align_head_chunk=args.align_head_chunk,
        )

    if args.run_project_out:
        run_head_project_out(
            args.model_dir,
            args.ckpt_id,
            args.dataset_dir,
            args.layer_id,
            args.heads,
            args.alphas,
            args.num_examples,
            args.batch_size,
            args.seq_length,
            args.out_dir,
            args.w_hat_json,
        )


# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
