# lmkit/experiments/pointer_heads.py
from __future__ import annotations
import argparse
import json
import math
import os
from dataclasses import dataclass
from typing import Dict, List, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd

from lmkit.impl import transformer, config as config_lib
from lmkit.impl.hooks import (
    HookRequest,
    HookType,
    capture as capture_hooks,
    unpack_captured,
)
from lmkit.impl.caching import TransformerCache, build_rope
from lmkit.tools import compat, data as data_tools, train_utils
from ..smiles_events import extract_events, Events, RingEvent, ParenEvent


# ------------------------------
# small helpers
# ------------------------------
def _clone_params(params):
    # Shallow copy except at the small subtrees we mutate
    p = dict(params)
    p["layers"] = list(params["layers"])
    return p


def ablate_head_in_Wo(
    params, layer_id: int, head_idx: int, hidden_size: int, num_heads: int
):
    """
    Return a PARAM COPY with rows for the selected head zeroed in W_o.

    Head ``head_idx`` owns rows ``[head_idx * head_dim, (head_idx + 1) * head_dim)``
    of ``params["layers"][layer_id]["attn"]["W_o"]``, where
    ``head_dim = hidden_size // num_heads``. Works with either jax.numpy arrays
    (functional ``.at[].set``) or numpy arrays (copied, then assigned).

    ``params`` itself is not modified. The top-level dict, the ``layers`` list,
    the selected layer dict and its ``attn`` dict are copied. All other arrays
    stay shared with ``params``.

    Raises:
        IndexError: if ``head_idx`` is not in ``[0, num_heads)`` or the head's
            row block does not fit inside ``W_o``. Without this check an
            out-of-range head yields an empty slice, and the "ablation"
            silently becomes a no-op.
    """
    if not 0 <= head_idx < num_heads:
        raise IndexError(
            f"head_idx={head_idx} out of range for num_heads={num_heads}"
        )
    head_dim = hidden_size // num_heads
    row_start = head_idx * head_dim
    row_end = row_start + head_dim

    new_params = _clone_params(params)
    lyr = dict(new_params["layers"][layer_id])
    attn = dict(lyr["attn"])
    W_o = attn["W_o"]
    if row_end > W_o.shape[0]:
        raise IndexError(
            f"head rows [{row_start}, {row_end}) exceed W_o rows ({W_o.shape[0]})"
        )

    if hasattr(W_o, "at"):
        # JAX arrays are immutable: .at[].set returns an updated copy.
        W_o2 = W_o.at[row_start:row_end, :].set(0)
    else:
        # NumPy path: copy first so the caller's array is untouched; assigning
        # the scalar 0 keeps the array's dtype (incl. bfloat16 if present).
        W_o2 = np.array(W_o, copy=True)
        W_o2[row_start:row_end, :] = 0

    attn["W_o"] = W_o2
    lyr["attn"] = attn
    new_params["layers"][layer_id] = lyr
    return new_params

def _json_default(o):
    # Minimal, dependency-free encoder for numpy/jax types
    import numpy as _np

    try:
        import jax.numpy as _jnp

        jnp_arr = _jnp.ndarray
    except Exception:

        class _Dummy:
            pass

        jnp_arr = _Dummy  # no-op if jax not present

    if isinstance(o, (_np.integer,)):
        return int(o)
    if isinstance(o, (_np.floating,)):
        return float(o)
    if isinstance(o, (_np.ndarray, jnp_arr)):
        return o.tolist()
    # 0-dim arrays / DeviceArray scalars often have .item()
    if hasattr(o, "item"):
        try:
            return o.item()
        except Exception:
            pass
    # Fallback—make it a string so we never crash
    return str(o)


def run_with_hooks(inputs, positions, params, config, hook_pairs):
    """Run a single forward pass and return its logits plus captured hooks.

    Args:
        inputs: token ids passed to ``transformer.run``.
        positions: position ids used to create a non-dynamic bfloat16
            ``TransformerCache`` for this pass.
        params: model parameters.
        config: model config.
        hook_pairs: iterable of ``(layer, hook_type)`` pairs; one
            ``HookRequest`` is built per pair.

    Returns:
        ``(logits, captured)`` where ``captured`` is the result of
        ``unpack_captured`` applied to the raw captures of this pass.
    """
    requests = [HookRequest(layer, hook_type) for layer, hook_type in hook_pairs]
    hooks_to_return, _ = capture_hooks(*requests)

    cache = TransformerCache.create(
        positions, config, dtype=jnp.bfloat16, dynamic=False
    )
    run_kwargs = dict(
        hooks_to_return=hooks_to_return,
        hooks_to_stream=frozenset(),
        editor=None,
    )
    logits, _, raw_captures = transformer.run(
        inputs, cache, params, config, **run_kwargs
    )
    return logits, unpack_captured(hooks_to_return, raw_captures)


def _pad_time_np(a: np.ndarray, T_target: int, pad_value) -> np.ndarray:
    """Right-pad along time dimension to T_target."""
    if a.shape[1] == T_target:
        return a
    pad_shape = (a.shape[0], T_target - a.shape[1]) + a.shape[2:]
    pad_block = np.full(pad_shape, pad_value, dtype=a.dtype)
    return np.concatenate([a, pad_block], axis=1)


# ------------------------------
# attention computation (per-event, per-head)
# ------------------------------
def compute_pointer_mass_for_events(
    resid_pre_layer: jnp.ndarray,  # (B,T,HIDDEN)
    positions: jnp.ndarray,  # (B,T) -1 for pad
    params_layer_attn: Dict,  # dict with W_q/W_k/W_v/W_o
    config,  # FrozenDict
    events: List[Tuple[int, int, int]],  # list of (b, pred_idx, opener_idx)
) -> np.ndarray:
    """
    Return per-head pointer mass, shape (num_heads,), averaged over ``events``.

    For each event ``(b, i_pred, j_open)``:
      * the query at step ``i_pred`` of row ``b`` is scored against the keys at
        steps ``0..i_pred`` (causal);
      * keys at padding steps (``positions < 0``) are masked out;
      * the softmax weight landing on ``j_open`` is recorded for every head.
    Only the event steps are scored, not the full (B, T) attention pattern.

    GQA is supported (``num_kv_heads < num_heads``): each K head is repeated
    across its group of query heads.

    An event is skipped when any of these holds:
      * ``b`` is out of range, or any index is negative;
      * ``i_pred`` is past the time axis, or ``j_open > i_pred``;
      * either the query step or the opener step is padding.
    JAX clamps out-of-range indices instead of raising, so these guards are
    explicit. Returns float32 zeros when no event survives.

    Attention logits and softmax are evaluated in float32 so low-precision
    activations (e.g. bfloat16) do not distort small attention weights.

    NOTE(review): Q/K are projected directly from ``resid_pre_layer`` with
    ``W_q``/``W_k``. Any pre-attention normalization, projection bias or QK-norm
    that the model applies is NOT reproduced here. TODO confirm against
    ``lmkit.impl.transformer``.

    Raises:
        ValueError: if ``num_heads`` is not a multiple of ``num_kv_heads``.
    """
    from lmkit.impl.transformer import rope as rope_rot

    hidden = config["hidden_size"]
    num_heads = int(config["num_heads"])
    num_kv_heads = int(config.get("num_kv_heads", num_heads))
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    head_dim = hidden // num_heads

    B, T, _ = resid_pre_layer.shape
    sin, cos = build_rope(positions, head_dim, config["rope_base"])  # (B,T,head_dim)

    # Project once for the whole batch, split into heads, apply RoPE.
    Q = jnp.reshape(
        resid_pre_layer @ params_layer_attn["W_q"], (B, T, num_heads, head_dim)
    )
    K = jnp.reshape(
        resid_pre_layer @ params_layer_attn["W_k"], (B, T, num_kv_heads, head_dim)
    )
    Qr = rope_rot(Q, sin, cos)  # (B,T,H,D)
    Kr = rope_rot(K, sin, cos)  # (B,T,H_kv,D)
    if num_kv_heads != num_heads:
        # jnp.repeat gives [k0,k0,k1,k1,...]: query head h reads kv head h // group
        Kr = jnp.repeat(Kr, repeats=num_heads // num_kv_heads, axis=2)  # (B,T,H,D)

    # Validity mask: host copy for the Python-side guards (one transfer instead
    # of a device reduction + sync per event), device copy for key masking.
    valid_dev = positions >= 0  # (B,T)
    valid = np.asarray(valid_dev)
    inv_sqrt_d = 1.0 / math.sqrt(head_dim)

    per_head_scores = []
    for b, i_pred, j_open in events:
        if b < 0 or b >= B or i_pred < 0 or j_open < 0:
            continue
        if i_pred >= T or j_open > i_pred:
            continue
        if not valid[b, i_pred] or not valid[b, j_open]:
            continue

        q = Qr[b, i_pred].astype(jnp.float32)  # (H,D)
        k = Kr[b, : i_pred + 1].astype(jnp.float32)  # (i_pred+1,H,D)
        logits = jnp.einsum("hd,thd->ht", q, k) * inv_sqrt_d  # (H, i_pred+1)
        # Padding keys get zero weight; the query step itself is valid, so
        # every row keeps at least one finite logit.
        key_valid = valid_dev[b, : i_pred + 1]
        logits = jnp.where(key_valid[None, :], logits, -jnp.inf)

        weights = jax.nn.softmax(logits, axis=-1)  # (H, i_pred+1)
        per_head_scores.append(weights[:, j_open])  # (H,)

    if not per_head_scores:
        return np.zeros((num_heads,), dtype=np.float32)

    arr = jnp.stack(per_head_scores, axis=0)  # (N_events, H)
    return np.array(jnp.mean(arr, axis=0))  # (H,)


# ------------------------------
# main experiment runner
# ------------------------------
@dataclass(init=True, repr=True, eq=True)
class HeadKey:
    """Identifies a single attention head by its layer and head index."""

    layer: int  # layer index
    head: int  # head index within that layer


def _event_index_arrays(pos_list):
    """Split ``[(row, t), ...]`` into parallel int32 index vectors ``(rows, ts)``."""
    rows = jnp.array([p[0] for p in pos_list], dtype=jnp.int32)
    ts = jnp.array([p[1] for p in pos_list], dtype=jnp.int32)
    return rows, ts


def _top1_accuracy(logits, targets, pos_list) -> float:
    """Fraction of ``(row, t)`` positions whose argmax logit equals the target.

    Returns 0.0 for an empty ``pos_list``.
    """
    if not pos_list:
        return 0.0
    rows, ts = _event_index_arrays(pos_list)
    pred = jnp.argmax(logits[rows, ts, :], axis=-1)
    return float(jnp.mean((pred == targets[rows, ts]).astype(jnp.float32)))


def _gold_logit_drop(base_logits, abl_logits, targets, pos_list) -> float:
    """Mean (baseline - ablated) logit of the gold token over ``(row, t)`` positions.

    The subtraction and mean run in float32. Returns 0.0 for an empty ``pos_list``.
    """
    if not pos_list:
        return 0.0
    rows, ts = _event_index_arrays(pos_list)
    gold = targets[rows, ts]
    base = base_logits[rows, ts, gold].astype(jnp.float32)
    abl = abl_logits[rows, ts, gold].astype(jnp.float32)
    return float(jnp.mean(base - abl))


def _select_control_positions(positions, exclude, n_ctrl: int):
    """Return up to ``n_ctrl`` non-padding ``(row, t)`` pairs that are not in ``exclude``.

    Pairs are taken in row-major order. The selection is deterministic, not
    random, so controls concentrate in the first rows of the batch.
    """
    ctrl = []
    if n_ctrl <= 0:
        return ctrl
    for row, t in np.argwhere(np.asarray(positions) >= 0):
        pos = (int(row), int(t))
        if pos in exclude:
            continue
        ctrl.append(pos)
        if len(ctrl) >= n_ctrl:
            break
    return ctrl


def _write_empty_summary(out_dir: str, message: str):
    """Write a header-only summary CSV, print ``message``, return the empty result dict."""
    out_csv = os.path.join(out_dir, "pointer_heads_summary.csv")
    pd.DataFrame(columns=["layer", "head", "pointer_ring", "pointer_paren"]).to_csv(
        out_csv, index=False
    )
    print(message)
    return dict(
        ring_pointer=None,
        paren_pointer=None,
        summary_csv=out_csv,
        generation_json=None,
    )


def run_pointer_experiments(
    model_dir: str,
    ckpt_id: int,
    dataset_dir: str,
    split: str,
    num_examples: int,
    batch_size: int,
    out_dir: str,
    max_events: int = 2000,
    topk_report: int = 5,
    do_generation: bool = True,
    gen_samples: int = 256,
    seq_length: int = 256,
):
    """
    Localize ring/paren "pointer" heads and measure their causal effect.

    Steps:
      1. Load the tokenizer, the config and ``checkpoints/checkpoint_{ckpt_id}.pkl``
         from ``model_dir``.
      2. Run one teacher-forced pass over whole batches until at least
         ``num_examples`` sequences are seen. Keep the logits and every layer's
         RESID_PRE.
      3. Extract ring/paren events (at most ``max_events`` per type). For each
         layer and head, average the attention weight from the event's
         prediction step onto its opener.
      4. Rank heads. Take the top ``topk_report`` heads per task and ablate
         each one in turn by zeroing its W_o rows. Report:
         * the drop in gold-token logit at event positions and at the same
           number (max 5000) of non-event control positions;
         * top-1 accuracy at event positions, before and after ablation.
      5. Optionally (``do_generation``) compare ``stem.molstats`` generation
         stats for the baseline and up to three ablated heads.

    Writes ``pointer_heads_summary.csv`` and, when generation runs,
    ``generation_summary.json`` into ``out_dir``. If nothing can be analyzed
    (no examples, no usable events, or no heads), a header-only summary is
    written and the early-return dict below is returned.

    Returns:
        dict with:
          * ``ring_pointer`` / ``paren_pointer``: ``(num_layers, num_heads)``
            arrays, or ``None`` on early return;
          * ``summary_csv``: path of the summary CSV;
          * ``generation_json``: path of the generation JSON, or ``None``.
    """
    os.makedirs(out_dir, exist_ok=True)

    # --- load tokenizer & model (fixed length to avoid batch T drift)
    tokenizer = compat.load_tokenizer(
        tokenizer_path=f"{model_dir}/tokenizer.json",
        generation_config_file=f"{model_dir}/generation_config.json",
        mode="train",
        trunc_length=seq_length,
    )
    cfg = config_lib.load_from_dir(model_dir).copy(
        dict(
            bos_id=tokenizer.bos_token_id,
            eos_id=tokenizer.eos_token_id,
            pad_id=tokenizer.pad_token_id,
        )
    )
    params, *_ = train_utils.load_checkpoint(
        f"{model_dir}/checkpoints/checkpoint_{ckpt_id}.pkl"
    )
    num_layers = int(cfg["num_layers"])
    num_heads = int(cfg["num_heads"])

    # --- small dataset slice
    # NOTE(review): `split` is accepted (and exposed on the CLI) but is not
    # forwarded to load_and_tokenize, so it currently has no effect. Confirm
    # how the loader picks a split before relying on it.
    ds = data_tools.load_and_tokenize(
        dataset_dir=dataset_dir,
        tokenizer=tokenizer,
        batch_size=batch_size,
        num_processes=8,
        seed=2002,
        target_column="smiles",
        caching=True,
        limit=num_examples,
    )

    # --- collect one pass of teacher-forced logits and RESID_PRE at all layers
    hook_pairs = [(i, HookType.RESID_PRE) for i in range(num_layers)]

    batch_inputs, batch_targets, batch_positions = [], [], []
    resid_pre_per_batch: List[Dict[Tuple[int, HookType], np.ndarray]] = []
    logits_per_batch: List[np.ndarray] = []

    processed_examples = 0
    for batch in ds:
        inputs = jnp.asarray(batch["inputs"])
        targets = jnp.asarray(batch["targets"])
        positions = jnp.asarray(batch["positions"])
        logits, captured = run_with_hooks(inputs, positions, params, cfg, hook_pairs)

        batch_inputs.append(np.asarray(inputs))
        batch_targets.append(np.asarray(targets))
        batch_positions.append(np.asarray(positions))
        # Convert captured tensors to numpy now
        resid_pre_per_batch.append({k: np.asarray(v) for k, v in captured.items()})
        logits_per_batch.append(np.asarray(logits))

        # Whole batches are kept, so slightly more than num_examples may be analyzed.
        processed_examples += inputs.shape[0]
        if processed_examples >= num_examples:
            break

    if not batch_inputs:
        # np.concatenate below would fail with an opaque error on an empty list.
        return _write_empty_summary(
            out_dir, "No examples were loaded from the dataset; wrote empty summary."
        )

    # --- Normalize time dimension across batches (defensive fallback)
    T_max = max(arr.shape[1] for arr in batch_inputs)
    if any(arr.shape[1] != T_max for arr in batch_inputs):
        pad_id = tokenizer.pad_token_id
        batch_inputs = [_pad_time_np(a, T_max, pad_id) for a in batch_inputs]
        batch_targets = [_pad_time_np(a, T_max, 0) for a in batch_targets]
        batch_positions = [_pad_time_np(a, T_max, -1) for a in batch_positions]
        logits_per_batch = [_pad_time_np(a, T_max, 0.0) for a in logits_per_batch]
        resid_pre_per_batch = [
            {key: _pad_time_np(arr, T_max, 0.0) for key, arr in cap.items()}
            for cap in resid_pre_per_batch
        ]

    # --- concatenate along batch dimension
    all_inputs = np.concatenate(batch_inputs, axis=0)
    all_targets = np.concatenate(batch_targets, axis=0)
    all_positions = np.concatenate(batch_positions, axis=0)

    # --- event extraction (tokenized)
    events_all = extract_events(tokenizer, all_inputs)

    def take_first_n(ev_list, n):
        return ev_list[:n] if n is not None else ev_list

    ring_events = take_first_n(events_all.rings, max_events)
    paren_events = take_first_n(events_all.parens, max_events)

    # (row, pred_idx, open_idx) for events with a usable prediction step.
    # The max_events cap above is applied before this filter.
    ring_tuples = [
        (ev.batch, ev.pred_idx, ev.open_idx) for ev in ring_events if ev.pred_idx >= 0
    ]
    paren_tuples = [
        (ev.batch, ev.pred_idx, ev.open_idx) for ev in paren_events if ev.pred_idx >= 0
    ]
    ring_pred_positions = [(b, i) for b, i, _ in ring_tuples]
    paren_pred_positions = [(b, i) for b, i, _ in paren_tuples]

    if not ring_tuples and not paren_tuples:
        # Without events every pointer score is 0 and any "top" head would be
        # arbitrary, so do not rank or ablate anything.
        return _write_empty_summary(
            out_dir, "No ring/paren events found in the sample; wrote empty summary."
        )

    # --- compute pointer mass per head (layer by layer)
    positions_all = jnp.asarray(all_positions)

    def pointer_mass(resid_pre, attn_params, event_tuples):
        # Skip the full-batch Q/K projection when this event type is absent.
        if not event_tuples:
            return np.zeros((num_heads,), dtype=np.float32)
        return compute_pointer_mass_for_events(
            resid_pre, positions_all, attn_params, cfg, event_tuples
        )

    ring_pointer, paren_pointer = [], []
    for layer in range(num_layers):
        resid_pre = jnp.asarray(
            np.concatenate(
                [cap[(layer, HookType.RESID_PRE)] for cap in resid_pre_per_batch],
                axis=0,
            )
        )  # (B,T,HIDDEN)
        attn_params = params["layers"][layer]["attn"]
        ring_pointer.append(pointer_mass(resid_pre, attn_params, ring_tuples))
        paren_pointer.append(pointer_mass(resid_pre, attn_params, paren_tuples))

    ring_pointer = (
        np.stack(ring_pointer, axis=0) if ring_pointer else np.zeros((0, num_heads))
    )
    paren_pointer = (
        np.stack(paren_pointer, axis=0) if paren_pointer else np.zeros((0, num_heads))
    )

    # --- teacher-forced logits (baseline)
    logits_all = jnp.asarray(np.concatenate(logits_per_batch, axis=0))
    targets_all = jnp.asarray(all_targets)

    ring_acc_base = _top1_accuracy(logits_all, targets_all, ring_pred_positions)
    paren_acc_base = _top1_accuracy(logits_all, targets_all, paren_pred_positions)

    # --- head ranking (by pointer mass)
    entries = []
    for layer_idx in range(num_layers):
        for head_idx in range(num_heads):
            entries.append(
                dict(
                    layer=layer_idx,
                    head=head_idx,
                    pointer_ring=(
                        float(ring_pointer[layer_idx, head_idx])
                        if ring_pointer.size
                        else 0.0
                    ),
                    pointer_paren=(
                        float(paren_pointer[layer_idx, head_idx])
                        if paren_pointer.size
                        else 0.0
                    ),
                )
            )

    rank_df = pd.DataFrame(entries)
    if rank_df.empty:
        return _write_empty_summary(
            out_dir, "Model config has no layers/heads to rank; wrote empty summary."
        )

    rank_df["rank_ring"] = rank_df["pointer_ring"].rank(ascending=False, method="min")
    rank_df["rank_paren"] = rank_df["pointer_paren"].rank(ascending=False, method="min")

    # --- choose candidates (topk per task)
    top_ring = rank_df.sort_values("pointer_ring", ascending=False).head(topk_report)
    top_paren = rank_df.sort_values("pointer_paren", ascending=False).head(topk_report)

    # --- specificity controls: non-event positions, one per event (max 5000).
    # They do not depend on the ablated head, so compute them once on the host.
    event_positions = ring_pred_positions + paren_pred_positions
    ctrl_positions = _select_control_positions(
        all_positions, set(event_positions), min(len(event_positions), 5000)
    )

    # --- causal ablations (necessity): per-head W_o zeroing
    def run_ablation(layer: int, head: int):
        p_abl = ablate_head_in_Wo(params, layer, head, cfg["hidden_size"], num_heads)
        # Re-run only logits (no hooks)
        logits_abl = []
        for inputs, positions in zip(batch_inputs, batch_positions):
            cache = TransformerCache.create(
                jnp.asarray(positions), cfg, dtype=jnp.bfloat16, dynamic=False
            )
            batch_logits, *_ = transformer.run(jnp.asarray(inputs), cache, p_abl, cfg)
            logits_abl.append(np.asarray(batch_logits))
        logits_abl = jnp.asarray(np.concatenate(logits_abl, axis=0))

        return dict(
            layer=layer,
            head=head,
            delta_logit_ring=_gold_logit_drop(
                logits_all, logits_abl, targets_all, ring_pred_positions
            ),
            delta_logit_paren=_gold_logit_drop(
                logits_all, logits_abl, targets_all, paren_pred_positions
            ),
            ring_acc_base=ring_acc_base,
            ring_acc_abl=_top1_accuracy(logits_abl, targets_all, ring_pred_positions),
            paren_acc_base=paren_acc_base,
            paren_acc_abl=_top1_accuracy(logits_abl, targets_all, paren_pred_positions),
            delta_logit_controls=_gold_logit_drop(
                logits_all, logits_abl, targets_all, ctrl_positions
            ),
        )

    ablation_rows = []
    if len(top_ring) or len(top_paren):
        cand = pd.concat([top_ring, top_paren]).drop_duplicates(["layer", "head"])
        for _, r in cand.iterrows():
            # IMPORTANT: bracket indexing to avoid pandas Series.head() method collision
            ablation_rows.append(run_ablation(int(r["layer"]), int(r["head"])))

    ablate_df = (
        pd.DataFrame(ablation_rows)
        if ablation_rows
        else pd.DataFrame(
            columns=[
                "layer",
                "head",
                "delta_logit_ring",
                "delta_logit_paren",
                "ring_acc_base",
                "ring_acc_abl",
                "paren_acc_base",
                "paren_acc_abl",
                "delta_logit_controls",
            ]
        )
    )
    summary = rank_df.merge(ablate_df, on=["layer", "head"], how="left")
    out_csv = os.path.join(out_dir, "pointer_heads_summary.csv")
    summary.to_csv(out_csv, index=False)

    # --- Optional: small decoding sanity (validity %)
    gen_summary = {}
    gen_json = None
    if do_generation and ablation_rows:
        from lmkit.tools import stem

        eval_fn = train_utils.get_eval_fn(
            tokenizer=tokenizer,
            model_config=cfg,
            num_samples=gen_samples,
            batch_size=min(128, gen_samples),
            metrics_fn=stem.molstats,
            log_metrics=False,
        )
        key = jax.random.PRNGKey(0)
        gen_summary["baseline"] = eval_fn(key, step=-1, params=params)

        for row in ablation_rows[:3]:
            layer_idx, head_idx = int(row["layer"]), int(row["head"])
            p_abl = ablate_head_in_Wo(
                params, layer_idx, head_idx, cfg["hidden_size"], num_heads
            )
            gen_summary[f"L{layer_idx}H{head_idx}"] = eval_fn(
                jax.random.PRNGKey(1 + layer_idx * 100 + head_idx),
                step=-1,
                params=p_abl,
            )

        gen_json = os.path.join(out_dir, "generation_summary.json")
        with open(gen_json, "w") as f:
            json.dump(gen_summary, f, indent=2, default=_json_default)

    # console summary (short)
    print("\nTop pointer heads (ring):")
    print(top_ring[["layer", "head", "pointer_ring"]].to_string(index=False))
    print("\nTop pointer heads (paren):")
    print(top_paren[["layer", "head", "pointer_paren"]].to_string(index=False))
    if ablation_rows:
        print("\nAblation deltas (subset):")
        print(pd.DataFrame(ablation_rows).to_string(index=False))

    return dict(
        ring_pointer=ring_pointer,
        paren_pointer=paren_pointer,
        summary_csv=out_csv,
        generation_json=gen_json,
    )


# ------------------------------
# CLI
# ------------------------------
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the pointer-heads experiment."""
    parser = argparse.ArgumentParser(
        description="Pointer heads: localization + causal ablations"
    )
    # where the model, checkpoint and data live
    parser.add_argument("--model_dir", required=True)
    parser.add_argument("--ckpt_id", type=int, required=True)
    parser.add_argument("--dataset_dir", required=True)
    parser.add_argument("--split", default="train")
    # how much to analyze and where to write results
    parser.add_argument(
        "--num_examples", type=int, default=4096, help="number of sequences to analyze"
    )
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--out_dir", default="experiments/pointer_outputs")
    parser.add_argument(
        "--max_events",
        type=int,
        default=2000,
        help="cap events per type to keep it fast",
    )
    parser.add_argument("--topk_report", type=int, default=5)
    # optional decoding sanity check
    parser.add_argument(
        "--no_generation", action="store_true", help="skip decoding sanity"
    )
    parser.add_argument("--gen_samples", type=int, default=256)
    parser.add_argument(
        "--seq_length",
        type=int,
        default=256,
        help="Fixed tokenizer length (pad/truncate). Match your training length.",
    )
    return parser


def main():
    """Parse command-line arguments and forward them to ``run_pointer_experiments``."""
    args = _build_arg_parser().parse_args()
    run_pointer_experiments(
        model_dir=args.model_dir,
        ckpt_id=args.ckpt_id,
        dataset_dir=args.dataset_dir,
        split=args.split,
        num_examples=args.num_examples,
        batch_size=args.batch_size,
        out_dir=args.out_dir,
        max_events=args.max_events,
        topk_report=args.topk_report,
        # the CLI exposes an opt-out flag; the runner takes an opt-in bool
        do_generation=not args.no_generation,
        gen_samples=args.gen_samples,
        seq_length=args.seq_length,
    )


if __name__ == "__main__":
    main()
