# experiments/valence/core.py
from __future__ import annotations

import math, os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import jax
import jax.numpy as jnp
import numpy as np
from sklearn.linear_model import LogisticRegression

from lmkit.impl import config as config_lib
from lmkit.impl import transformer
from lmkit.impl.caching import TransformerCache, build_rope
from lmkit.impl.hooks import HookRequest, HookType, unpack_captured
from lmkit.impl.hooks import capture as capture_hooks
from lmkit.tools import compat, train_utils
from lmkit.tools import data as data_tools

# Token-string and ring-id helpers shared with the SMILES event code
from ..smiles_events import (
    _read_ring_id,
    _tok_str,
)  # assumes experiments/smiles_events.py exists


# ---------------------------------------------------------------------
# Common utilities
# ---------------------------------------------------------------------
def pad_time_np(a: np.ndarray, T_target: int, pad_value) -> np.ndarray:
    """Right-pad ``a`` along axis 1 (time) up to ``T_target`` steps.

    The padding block is filled with ``pad_value`` in ``a``'s dtype and keeps
    every axis other than axis 1 unchanged. ``a`` itself is returned when it
    already has ``T_target`` steps. Padding to fewer steps than ``a`` has
    raises from ``np.full`` (negative dimension).
    """
    missing = T_target - a.shape[1]
    if missing == 0:
        return a
    filler = np.full((a.shape[0], missing, *a.shape[2:]), pad_value, dtype=a.dtype)
    return np.concatenate((a, filler), axis=1)


def run_with_hooks(inputs, positions, params, config, hook_pairs, editor=None):
    """Run one forward pass, returning logits plus the requested hook captures.

    ``hook_pairs`` is an iterable of ``(layer, hook_type)`` pairs; each becomes
    a ``HookRequest``. A fresh ``TransformerCache`` (bfloat16, ``dynamic=False``)
    is built from ``positions``. ``editor`` (optional) is forwarded to
    ``transformer.run``; nothing is streamed.

    Returns:
        ``(logits, captured)`` where ``captured`` is whatever
        ``unpack_captured`` produces for the requested hooks (presumably a
        mapping keyed by ``(layer, hook_type)`` — confirm in lmkit).
    """
    requests = [HookRequest(layer, kind) for layer, kind in hook_pairs]
    hooks_to_return, _ = capture_hooks(*requests)
    fresh_cache = TransformerCache.create(
        positions, config, dtype=jnp.bfloat16, dynamic=False
    )
    logits, _, raw_captured = transformer.run(
        inputs,
        fresh_cache,
        params,
        config,
        hooks_to_return=hooks_to_return,
        hooks_to_stream=frozenset(),
        editor=editor,
    )
    captured_by_hook = unpack_captured(hooks_to_return, raw_captured)
    return logits, captured_by_hook


def capture_only(inputs, positions, params, config, hook_pairs, editor=None):
    """Alias of ``run_with_hooks``: returns ``(logits, captured)`` unchanged."""
    result = run_with_hooks(
        inputs=inputs,
        positions=positions,
        params=params,
        config=config,
        hook_pairs=hook_pairs,
        editor=editor,
    )
    return result


# ---------------------------------------------------------------------
# Event extraction (80/20 SMILES valence ledger)
# ---------------------------------------------------------------------
# Base allowed valence per element symbol, before the charge adjustment in
# _allowed_valence (symbols missing here default to 4 there). Lowercase keys
# are aromatic organic-subset atoms.
_ALLOWED_VAL = {
    "B": 3,
    "C": 4,
    "N": 3,
    "O": 2,
    "F": 1,
    "P": 3,
    "S": 2,
    "Cl": 1,
    "Br": 1,
    "I": 1,
    "Si": 4,
    "Se": 2,
    # Aromatic symbols. NOTE(review): "c" is 3 rather than 4, presumably to
    # offset aromatic bonds being counted as order 1 below — confirm intent.
    "b": 3,
    "c": 3,
    "n": 3,
    "o": 2,
    "s": 2,
    "p": 3,
}
# Bond symbol -> order charged to both endpoints (aromatic ":" counts as 1).
_BOND_ORDER = {"-": 1, "=": 2, "#": 3, ":": 1}
# Element symbols recognised inside brackets / as plain organic-subset tokens.
_TWO_LETTER = {"Cl", "Br", "Si", "Se"}
_ONE_LETTER = {"B", "C", "N", "O", "F", "P", "S", "I"}
# Aromatic organic-subset atoms (OpenSMILES: b, c, n, o, p, s).
_ARO = {"b", "c", "n", "o", "s", "p"}


@dataclass
class ValenceEvent:
    """One valence decision point found while walking a tokenized SMILES row.

    Created by ``extract_valence_events_for_ids``. The ledger fields
    (``allowed``, ``consumed_before``, ``remaining_before``) are a snapshot of
    the tracked atom taken *before* the bond implied by the next token is
    charged to it.
    """

    batch: int  # row index in the batch; -1 until extract_valence_events sets it
    pred_idx: int  # token position whose next-token prediction is the decision
    atom_idx: int  # index of the tracked atom in the per-row atom ledger
    event_type: str  # "explicit" | "implicit" | "ring_close"
    allowed: int  # allowed valence from _allowed_valence(elem, aromatic, charge)
    consumed_before: float  # sum of bond orders already charged to the atom
    remaining_before: float  # allowed - consumed_before - expH (explicit H count)
    token_at_pred: str  # token that will be predicted at pred_idx+1
    context: Dict[str, object]  # {'elem','aromatic','charge','expH'}; +'ring_id' for ring_close


def _allowed_valence(elem: str, aromatic: bool, charge: int) -> int:
    """Return a crude allowed valence for an atom.

    Starts from ``_ALLOWED_VAL`` (default 4 for unknown symbols), adds 1 for
    cationic N/P/S (either case, capped at 5) and for anionic ``B`` (capped at
    4). Other charges are not adjusted. ``aromatic`` is not consulted;
    aromaticity is carried by the lowercase symbol itself.
    """
    valence = _ALLOWED_VAL.get(elem, 4)
    if elem in ("N", "n", "P", "p", "S", "s") and charge > 0:
        valence = min(valence + 1, 5)
    elif elem == "B" and charge < 0:
        valence = min(valence + 1, 4)
    return int(valence)


def _parse_bracket_block(tokens: List[str], i: int):
    """Parse the SMILES bracket atom that starts at ``tokens[i]``.

    The atom may be a single token (``"[nH]"``) or a run of tokens
    ``"[", ..., "]"``. Parsing is deliberately crude:
      * a leading isotope number is skipped;
      * the element is matched against ``_TWO_LETTER``, ``_ONE_LETTER`` or the
        aromatic symbols ``b c n o p s``;
      * the first ``H`` after the element gives the explicit H count
        (``H`` alone = 1, ``H<digits>`` = that number);
      * ``+``/``-`` signs (``+``, ``++``, ``+2``, ...) give the formal charge.

    Args:
        tokens: token strings of the sequence.
        i: index of the candidate opening token.

    Returns:
        ``(info, end_i)``. When ``tokens[i]`` does not open a bracket atom,
        ``info == {"kind": "non_bracket", "elem": None}`` and ``end_i == i``.
        Otherwise ``info`` has keys ``kind="bracket"``, ``elem`` (``None`` for
        unsupported elements), ``aromatic``, ``charge``, ``expH``, and
        ``end_i`` is the index of the last token belonging to the atom.
        A split bracket that is never closed (e.g. a truncated sequence) is
        parsed from all remaining tokens and consumes them, so the caller
        does not re-read its contents as separate atoms.
    """
    import re

    first = tokens[i]
    if first.startswith("[") and first.endswith("]"):
        inner, end_i = first[1:-1], i
    elif first != "[":
        return dict(kind="non_bracket", elem=None), i
    else:
        close = i + 1
        while close < len(tokens) and tokens[close] != "]":
            close += 1
        if close < len(tokens):
            inner, end_i = "".join(tokens[i + 1 : close]), close
        else:
            # Unclosed bracket: there is no trailing "]" to strip, and the
            # rest of the sequence belongs to this (truncated) atom.
            inner, end_i = "".join(tokens[i + 1 :]), len(tokens) - 1

    elem, aromatic, charge, expH = None, False, 0, 0

    k = 0
    while k < len(inner) and inner[k].isdigit():
        k += 1  # skip isotope mass

    if k + 1 < len(inner) and inner[k : k + 2] in _TWO_LETTER:
        elem, k = inner[k : k + 2], k + 2
    elif k < len(inner):
        ch = inner[k]
        if ch in _ONE_LETTER:
            elem, k = ch, k + 1
        elif ch in {"b", "c", "n", "o", "p", "s"}:
            elem, aromatic, k = ch, True, k + 1

    rest = inner[k:]

    # Explicit hydrogen count.
    idx = rest.find("H")
    if idx != -1:
        j = idx + 1
        while j < len(rest) and rest[j].isdigit():
            j += 1
        expH = 1
        if j > idx + 1:
            try:
                expH = int(rest[idx + 1 : j])
            except ValueError:  # str.isdigit() accepts chars int() rejects
                expH = 1

    # Formal charge: "+", "++", "+2", "-", "--", "-2".
    plus_digits = re.search(r"\+(\d+)", rest)
    minus_digits = re.search(r"-(\d+)", rest)
    charge += int(plus_digits.group(1)) if plus_digits else rest.count("+")
    charge -= int(minus_digits.group(1)) if minus_digits else rest.count("-")

    return dict(
        kind="bracket", elem=elem, aromatic=aromatic, charge=charge, expH=expH
    ), end_i


def _classify_atom_token(ts: str) -> Optional[Dict]:
    """Classify a single token as an atom, or return ``None``.

    Whole ``"[...]"`` tokens are delegated to ``_parse_bracket_block``.
    Organic-subset symbols yield ``kind="plain"`` with zero charge and zero
    explicit H; lowercase aromatic symbols get ``aromatic=True``.
    """
    if ts.startswith("[") and ts.endswith("]"):
        parsed, _end = _parse_bracket_block([ts], 0)
        return parsed

    if ts in _TWO_LETTER or ts in _ONE_LETTER:
        is_aromatic = False
    elif ts in _ARO:
        is_aromatic = True
    else:
        return None
    return dict(kind="plain", elem=ts, aromatic=is_aromatic, charge=0, expH=0)


def extract_valence_events_for_ids(tokenizer, ids_row: np.ndarray):
    """Walk one tokenized SMILES row and record valence decision events.

    A crude per-atom ledger is kept: allowed valence (``_allowed_valence``),
    bond orders charged so far, and explicit H count. Whenever the *next*
    token commits the current atom to another bond, a ``ValenceEvent``
    snapshots that atom's ledger *before* the bond is charged:

    * ``"explicit"``   - next token is a bond symbol from ``_BOND_ORDER``;
    * ``"implicit"``   - next token starts a new atom bonded to the current
      one (order = pending explicit bond, else 1);
    * ``"ring_close"`` - next token(s) close a ring opened earlier (order =
      bond written at the closure, else at the opening, else 1).

    Ring bonds are charged to both atoms only when the ring closes. ``(`` and
    ``)`` push / pop the current atom. A ``"."`` token separates disconnected
    fragments: it clears the current atom and any pending bond, so the first
    atom of the next fragment is not bonded to (or charged against) the
    previous fragment. Any other unrecognised token is skipped.

    Args:
        tokenizer: tokenizer passed through to ``_tok_str``.
        ids_row: token ids for one example (each converted with ``int``).

    Returns:
        ``(events, debug)``: every event has ``batch=-1`` (the caller fills it
        in); ``debug`` is ``{"num_atoms": int, "atoms": [ledger dicts]}``.
    """
    toks = [_tok_str(tokenizer, int(t)) for t in ids_row]
    atoms: List[Dict] = []
    cur_atom: Optional[int] = None
    branch_stack: List[Optional[int]] = []
    rings: Dict[str, Tuple[Optional[int], Optional[int]]] = {}
    pending_bond: Optional[int] = None
    events: List[ValenceEvent] = []

    def _new_atom(elem, aromatic, charge, expH) -> int:
        # Register a fresh atom (nothing consumed yet) and return its index.
        allowed = _allowed_valence(elem, aromatic, charge)
        atoms.append(
            dict(
                elem=elem,
                aromatic=aromatic,
                charge=charge,
                expH=int(expH),
                allowed=int(allowed),
                consumed=0.0,
            )
        )
        return len(atoms) - 1

    def _connect(a: Optional[int], b: Optional[int], order: float) -> None:
        # Charge a bond of `order` to both endpoints; invalid indices are ignored.
        if a is None or b is None:
            return
        if a < 0 or b < 0 or a >= len(atoms) or b >= len(atoms):
            return
        atoms[a]["consumed"] += order
        atoms[b]["consumed"] += order

    def _emit(event_type: str, a: int, pred_idx: int, token: str, **extra_ctx):
        # Record an event holding atom `a`'s ledger as it is right now.
        atom = atoms[a]
        allowed, consumed = atom["allowed"], atom["consumed"]
        context = dict(
            elem=atom["elem"],
            aromatic=atom["aromatic"],
            charge=atom["charge"],
            expH=atom["expH"],
        )
        context.update(extra_ctx)
        events.append(
            ValenceEvent(
                batch=-1,
                pred_idx=pred_idx,
                atom_idx=a,
                event_type=event_type,
                allowed=allowed,
                consumed_before=consumed,
                remaining_before=allowed - consumed - atom["expH"],
                token_at_pred=token,
                context=context,
            )
        )

    def _add_atom(info: Dict, i: int) -> None:
        # Create the atom parsed at token i and bond it to the current atom.
        nonlocal cur_atom, pending_bond
        a_new = _new_atom(
            info["elem"],
            info.get("aromatic", False),
            info.get("charge", 0),
            info.get("expH", 0),
        )
        if cur_atom is not None:
            order = pending_bond if pending_bond is not None else 1
            _emit("implicit", cur_atom, max(0, i - 1), toks[i])
            _connect(cur_atom, a_new, float(order))
        cur_atom = a_new
        pending_bond = None

    i = 0
    while i < len(toks):
        ts = toks[i]

        if ts == "(":
            branch_stack.append(cur_atom)
            i += 1
            continue
        if ts == ")":
            if branch_stack:
                cur_atom = branch_stack.pop()
            i += 1
            continue
        if ts == ".":
            # Disconnected fragment: the next atom starts a new component.
            cur_atom = None
            pending_bond = None
            i += 1
            continue

        if ts in _BOND_ORDER:  # explicit bond symbol
            if cur_atom is not None:
                _emit("explicit", cur_atom, max(0, i - 1), toks[i])
            pending_bond = _BOND_ORDER[ts]
            i += 1
            continue

        rid, end_i = _read_ring_id(toks, i)
        if rid is not None:
            if rid in rings and cur_atom is not None:
                # Ring closure: snapshot first, then charge the ring bond.
                _emit(
                    "ring_close",
                    cur_atom,
                    max(0, end_i - 1),
                    toks[end_i],
                    ring_id=rid,
                )
                opener_idx, opener_bond = rings.pop(rid)
                if pending_bond is not None:
                    order = pending_bond
                elif opener_bond is not None:
                    order = opener_bond
                else:
                    order = 1
                _connect(cur_atom, opener_idx, float(order))
            elif cur_atom is not None:
                # Ring opening: remember the atom and any bond written before it.
                rings[rid] = (cur_atom, pending_bond)
            pending_bond = None
            i = end_i + 1
            continue

        if ts == "[" or (ts.startswith("[") and ts.endswith("]")):
            info, end_j = _parse_bracket_block(toks, i)
            if info.get("elem") is not None:
                _add_atom(info, i)
            i = end_j + 1
            continue

        atom_info = _classify_atom_token(ts)
        if atom_info and atom_info.get("elem"):
            _add_atom(atom_info, i)
            i += 1
            continue

        i += 1

    return events, dict(num_atoms=len(atoms), atoms=atoms)


def extract_valence_events(tokenizer, batch_ids: np.ndarray):
    """Run ``extract_valence_events_for_ids`` on every row of ``batch_ids``.

    Each event's ``batch`` field is set to its row index.

    Returns:
        ``(events, debugs)``: all events across rows (row order preserved)
        and one debug dict per row.
    """
    events_out: List[ValenceEvent] = []
    per_row_debug: List[Dict] = []
    n_rows = batch_ids.shape[0]
    for row_idx in range(n_rows):
        row_events, row_debug = extract_valence_events_for_ids(
            tokenizer, batch_ids[row_idx]
        )
        for event in row_events:
            event.batch = row_idx
        events_out += row_events
        per_row_debug.append(row_debug)
    return events_out, per_row_debug


# ---------------------------------------------------------------------
# Editors (JAX-safe, dtype-safe)
# ---------------------------------------------------------------------
class AddEditor:
    """Add ``alpha * w`` to the activation hooked at ``(layer, kind)``.

    ``pos_mask`` (stored as bool) and/or the per-call ``token_mask`` (used as
    ``token_mask > 0``) restrict the shift to selected positions; when both
    are present they are AND-ed. With neither, every position is shifted.
    """

    __slots__ = ("layer", "kind", "w", "alpha", "pos_mask")

    def __init__(
        self,
        layer: int,
        kind: HookType,
        w: np.ndarray,
        alpha: float,
        pos_mask: Optional[np.ndarray] = None,
    ):
        self.layer = int(layer)
        self.kind = kind
        self.w = jnp.asarray(w, dtype=jnp.float32)
        self.alpha = float(alpha)
        if pos_mask is None:
            self.pos_mask = None
        else:
            self.pos_mask = jnp.asarray(pos_mask, dtype=jnp.bool_)

    def __hash__(self):
        # Identity-based hash (no __eq__ is defined, so equality is identity too).
        return id(self)

    def apply(self, *, layer: int, kind: HookType, x, token_mask=None):
        """Return ``x`` shifted along ``w`` when (layer, kind) match, else ``x``."""
        if not (layer == self.layer and kind.value == self.kind.value):
            return x
        delta = (self.alpha * self.w).astype(x.dtype)[None, None, :]
        gates = []
        if self.pos_mask is not None:
            gates.append(self.pos_mask)
        if token_mask is not None:
            gates.append(token_mask > 0)
        if not gates:
            return x + delta
        gate = gates[0]
        for extra in gates[1:]:
            gate = gate & extra
        return x + delta * gate.astype(x.dtype)[..., None]


class ProjectOutEditor:
    """Project-out along unit w_hat at (layer, kind): x <- x - (x·w) w

    ``w`` is ``w_hat`` normalized to unit length at construction. When
    ``token_mask`` is given, only positions with ``token_mask > 0`` are
    edited; this is the same gating rule AddEditor uses, so a mask value
    other than 0/1 no longer scales the projection. The coefficient and
    subtraction are computed in float32 and the result is cast back to
    ``x.dtype``, so low-precision (e.g. bf16) activations keep as little
    residual component along ``w`` as their dtype allows.
    """

    __slots__ = ("layer", "kind", "w")

    def __init__(self, layer: int, kind: HookType, w_hat: np.ndarray):
        w = np.asarray(w_hat, np.float32)
        w = w / (np.linalg.norm(w) + 1e-8)
        self.layer = int(layer)
        self.kind = kind
        self.w = jnp.asarray(w, jnp.float32)

    def __hash__(self):
        # Identity-based hash (no __eq__ is defined, so equality is identity too).
        return id(self)

    def apply(self, *, layer: int, kind: HookType, x, token_mask=None):
        """Return ``x`` with its ``w`` component removed when (layer, kind) match."""
        if layer != self.layer or kind.value != self.kind.value:
            return x
        x32 = x.astype(jnp.float32)
        coeff = jnp.einsum("bth,h->bt", x32, self.w)
        if token_mask is not None:
            keep = (jnp.asarray(token_mask) > 0).astype(jnp.float32)
            coeff = coeff * keep
        return (x32 - coeff[..., None] * self.w).astype(x.dtype)


class ComposeEditor:
    """Chain editors: ``apply`` threads ``x`` through them left to right.

    Every editor receives the same ``layer``, ``kind`` and ``token_mask``;
    each one sees the previous editor's output as its ``x``.
    """

    __slots__ = ("editors",)

    def __init__(self, *editors):
        self.editors = tuple(editors)

    def __hash__(self):
        # Identity-based hash (no __eq__ is defined, so equality is identity too).
        return id(self)

    def apply(self, *, layer: int, kind: HookType, x, token_mask=None):
        """Return ``x`` after every contained editor has been applied in order."""
        current = x
        for editor in self.editors:
            current = editor.apply(
                layer=layer, kind=kind, x=current, token_mask=token_mask
            )
        return current


# ---------------------------------------------------------------------
# Direction fitting (binary LR on RESID_PRE at explicit decisions)
# ---------------------------------------------------------------------
def build_direction_for_layer(resid_batches: List[np.ndarray], events, threshold=2):
    """Fit a unit "remaining valence" direction on residual activations.

    The activations are concatenated along axis 0. Each event labels the
    activation at ``(ev.batch, ev.pred_idx)``: positive when
    ``remaining_before >= threshold`` (a missing attribute counts as 0),
    negative otherwise. Events whose indices fall outside the concatenated
    array are dropped. A class-balanced lbfgs logistic regression is then fit
    and its coefficient vector, L2-normalized, is returned.

    Args:
        resid_batches: per-batch activation arrays; ``ev.batch`` indexes rows
            of their axis-0 concatenation.
        events: sequence of ValenceEvent-like objects (iterated three times).
        threshold: minimum remaining valence for the positive class.

    Returns:
        ``(w, stats)``: float32 unit vector and
        ``{"N": rows used, "pos": positives, "neg": negatives}``.

    Raises:
        ValueError: if the kept events do not contain both classes.
    """
    acts = np.concatenate(resid_batches, axis=0).astype(np.float32, copy=False)
    n_rows, n_steps = acts.shape[0], acts.shape[1]
    rows = np.array([ev.batch for ev in events], np.int32)
    steps = np.array([ev.pred_idx for ev in events], np.int32)
    in_range = (steps >= 0) & (steps < n_steps) & (rows >= 0) & (rows < n_rows)
    rows, steps = rows[in_range], steps[in_range]
    labels = np.array(
        [
            int(getattr(ev, "remaining_before", 0) >= threshold)
            for ev in events
        ],
        np.int32,
    )[in_range]

    n_pos = int(labels.sum())
    n_neg = int(labels.size) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("Need both classes to fit valence direction.")

    features = acts[rows, steps, :]
    model = LogisticRegression(max_iter=1000, solver="lbfgs", class_weight="balanced")
    model.fit(features, labels)
    direction = model.coef_.reshape(-1).astype(np.float32)
    direction /= np.linalg.norm(direction) + 1e-8
    return direction, dict(N=int(features.shape[0]), pos=n_pos, neg=n_neg)


# ---------------------------------------------------------------------
# Batch loaders & helpers (shared)
# ---------------------------------------------------------------------
@dataclass
class BatchPack:
    """Host-side (numpy) arrays for one batch, as built by ``prepare_batches``."""

    inputs: np.ndarray  # token ids; right-padded along axis 1 with the pad id if batch lengths differ
    targets: np.ndarray  # target ids; right-padded along axis 1 with 0 if lengths differ
    positions: np.ndarray  # position ids; right-padded along axis 1 with -1 if lengths differ
    resid_pre: Optional[np.ndarray]  # RESID_PRE captured at the requested layer, float32
    logits: np.ndarray  # logits from the unedited (editor=None) forward pass
    row_start: int  # row index of this batch's first example in the concatenated arrays


def prepare_batches(
    model_dir, ckpt_id, dataset_dir, layer_id, num_examples, batch_size, seq_length
):
    """Load model + data and cache one unedited forward pass per batch.

    Loads the tokenizer, config and ``checkpoints/checkpoint_{ckpt_id}.pkl``
    from ``model_dir``, loads/tokenizes ``dataset_dir`` (``target_column``
    ``"smiles"``), and runs the model on each batch with ``editor=None``,
    capturing RESID_PRE at ``layer_id``. Iteration stops once at least
    ``num_examples`` rows have been collected (the last batch may overshoot).
    If batch lengths differ, every batch is right-padded along axis 1 to the
    longest one: inputs with the tokenizer pad id, targets with 0, positions
    with -1, activations and logits with 0.0.

    Returns:
        ``(tokenizer, cfg, params, batches, all_inputs, all_targets, all_pos)``
        where the ``all_*`` arrays are the batches concatenated along axis 0.

    Raises:
        ValueError: if the dataset yields no batches.
    """
    tokenizer = compat.load_tokenizer(
        tokenizer_path=f"{model_dir}/tokenizer.json",
        generation_config_file=f"{model_dir}/generation_config.json",
        mode="train",
        trunc_length=seq_length,
    )
    cfg = config_lib.load_from_dir(model_dir).copy(
        dict(
            bos_id=tokenizer.bos_token_id,
            eos_id=tokenizer.eos_token_id,
            pad_id=tokenizer.pad_token_id,
        )
    )
    params, *_ = train_utils.load_checkpoint(
        f"{model_dir}/checkpoints/checkpoint_{ckpt_id}.pkl"
    )

    ds = data_tools.load_and_tokenize(
        dataset_dir=dataset_dir,
        tokenizer=tokenizer,
        batch_size=batch_size,
        num_processes=8,
        seed=2002,
        target_column="smiles",
        caching=True,
        limit=num_examples,
    )

    hook_key = (layer_id, HookType.RESID_PRE)
    batches: List[BatchPack] = []
    row_cursor = 0  # rows collected so far == row_start of the next batch
    for batch in ds:
        inputs = jnp.asarray(batch["inputs"])
        targets = np.asarray(batch["targets"])
        positions = jnp.asarray(batch["positions"])
        logits_ref, captured = capture_only(
            inputs, positions, params, cfg, [hook_key], editor=None
        )
        batches.append(
            BatchPack(
                inputs=np.asarray(inputs),
                targets=targets,
                positions=np.asarray(positions),
                resid_pre=np.asarray(captured[hook_key], dtype=np.float32),
                logits=np.asarray(logits_ref),
                row_start=row_cursor,
            )
        )
        row_cursor += int(inputs.shape[0])
        if row_cursor >= num_examples:
            break

    if not batches:
        raise ValueError(
            f"Dataset at {dataset_dir!r} yielded no batches; nothing to prepare."
        )

    # Batches may have different sequence lengths; pad all to the longest.
    lengths = [b.inputs.shape[1] for b in batches]
    if len(set(lengths)) > 1:
        T = max(lengths)
        pad_id = tokenizer.pad_token_id
        for b in batches:
            b.inputs = pad_time_np(b.inputs, T, pad_id)
            b.targets = pad_time_np(b.targets, T, 0)
            b.positions = pad_time_np(b.positions, T, -1)
            b.resid_pre = pad_time_np(b.resid_pre, T, 0.0)
            b.logits = pad_time_np(b.logits, T, 0.0)

    all_inputs = np.concatenate([b.inputs for b in batches], axis=0)
    all_targets = np.concatenate([b.targets for b in batches], axis=0)
    all_pos = np.concatenate([b.positions for b in batches], axis=0)
    return tokenizer, cfg, params, batches, all_inputs, all_targets, all_pos


# ---------------------------------------------------------------------
# Attention/OV helpers for localization
# ---------------------------------------------------------------------
def rope_apply(x, sin, cos):
    """Apply rotate-half RoPE: ``x * cos + rotate_half(x) * sin``.

    ``rotate_half(x)`` is ``concat(-x[..., D/2:], x[..., :D/2])`` on the last
    axis. For 4-D ``x`` (B, T, H, D), ``sin``/``cos`` get a head axis inserted
    at position 2 (indexed as ``[:, :, None, :]``, so they are expected to be
    3-D); otherwise they are broadcast against ``x`` as given.
    """
    if x.ndim == 4:
        sin, cos = sin[:, :, None, :], cos[:, :, None, :]
    half = x.shape[-1] // 2
    rotated = jnp.concatenate([-x[..., half:], x[..., :half]], axis=-1)
    return x * cos + rotated * sin


def head_postov_outputs_for_positions(
    resid_pre: jnp.ndarray,
    positions: jnp.ndarray,
    attn_params: Dict,
    cfg,
    pos_list: List[Tuple[int, int]],
    head_idx: int,
) -> List[np.ndarray]:
    """Recompute one attention head's post-W_o output at selected positions.

    Q/K/V are ``resid_pre @ W_{q,k,v}`` reshaped to (B, T, heads, D) with
    ``D = hidden_size // num_heads``. RoPE from ``build_rope(positions, ...)``
    is applied to Q and K. When ``num_kv_heads != num_heads``, K/V heads are
    repeated ``num_heads // num_kv_heads`` times along the head axis. For each
    ``(b, t)`` the head attends softmax-causally over keys ``[0, kv_len)``,
    where ``kv_len = min(#positions >= 0 in row b, t + 1)``, and the mixed value
    is multiplied by this head's row block of ``W_o``.

    NOTE(review): projections are applied to ``resid_pre`` exactly as passed;
    no normalization happens here. If the model normalizes the residual
    before attention, pass the normalized stream - confirm against
    ``transformer.run``.

    Entries with ``b`` outside ``[0, B)`` or ``t`` outside ``[0, T)`` are
    skipped (JAX would otherwise clamp out-of-range indices silently), as are
    rows with no valid keys. The returned list can therefore be shorter than
    ``pos_list``.

    Returns:
        One float32 array of length ``hidden`` per kept position, in order.

    Raises:
        ValueError: if ``head_idx`` is not in ``[0, num_heads)``.
    """
    hidden = int(cfg["hidden_size"])
    H = int(cfg["num_heads"])
    Hkv = int(cfg.get("num_kv_heads", H))
    if not 0 <= head_idx < H:
        raise ValueError(f"head_idx={head_idx} out of range for num_heads={H}")
    D = hidden // H
    B, T, _ = resid_pre.shape
    sin, cos = build_rope(positions, D, cfg["rope_base"])

    Q = jnp.reshape(resid_pre @ attn_params["W_q"], (B, T, H, D))
    K = jnp.reshape(resid_pre @ attn_params["W_k"], (B, T, Hkv, D))
    V = jnp.reshape(resid_pre @ attn_params["W_v"], (B, T, Hkv, D))
    Qr = rope_apply(Q, sin, cos)
    Kr = rope_apply(K, sin, cos)
    if Hkv != H:
        rep = H // Hkv
        Kr = jnp.repeat(Kr, repeats=rep, axis=2)
        Vr = jnp.repeat(V, repeats=rep, axis=2)
    else:
        Vr = V

    row_start, row_end = head_idx * D, (head_idx + 1) * D
    Wo_block = attn_params["W_o"][row_start:row_end, :]  # (D, hidden)
    inv_sqrt_d = 1.0 / math.sqrt(D)

    # Count valid (position >= 0) tokens per row once instead of per entry.
    pos_np = np.asarray(positions)
    valid_len = (pos_np >= 0).reshape(pos_np.shape[0], -1).sum(axis=1)

    out = []
    for b, t in pos_list:
        if not (0 <= b < B) or not (0 <= t < T):
            continue
        kv_len = min(int(valid_len[b]), t + 1)
        if kv_len <= 0:
            continue
        q = Qr[b, t, head_idx]  # (D,)
        k = Kr[b, :kv_len, head_idx]  # (kv_len, D)
        v = Vr[b, :kv_len, head_idx]  # (kv_len, D)
        scores = jnp.einsum("d,td->t", q, k) * inv_sqrt_d
        weights = jax.nn.softmax(scores, axis=-1)
        mixed = jnp.sum(weights[:, None] * v, axis=0)  # (D,)
        out.append(np.asarray(mixed @ Wo_block, np.float32))  # (hidden,)
    return out


def ablate_heads_in_Wo(
    params, heads: List[Tuple[int, int]], hidden_size: int, num_heads: int
):
    """Return copied params where selected heads (layer,head) are zeroed in W_o.

    Head ``h`` owns rows ``[h*d, (h+1)*d)`` of ``W_o`` with
    ``d = hidden_size // num_heads``. Every layer dict is shallow-copied; for
    layers with selected heads the ``attn`` dict is copied too and ``W_o`` is
    replaced by a new array in the original dtype. ``params`` is not mutated.
    Heads naming a layer that does not exist are ignored.
    """
    head_dim = hidden_size // num_heads
    heads_by_layer: Dict[int, List[int]] = {}
    for layer_idx, head in heads:
        heads_by_layer.setdefault(layer_idx, []).append(head)

    out_layers = []
    for layer_idx, layer in enumerate(params["layers"]):
        layer_copy = dict(layer)
        targets = heads_by_layer.get(layer_idx)
        if targets:
            attn_copy = dict(layer_copy["attn"])
            wo = np.array(attn_copy["W_o"])
            for head in targets:
                wo[head * head_dim : (head + 1) * head_dim, :] = 0.0
            attn_copy["W_o"] = jnp.asarray(wo, dtype=attn_copy["W_o"].dtype)
            layer_copy["attn"] = attn_copy
        out_layers.append(layer_copy)

    result = dict(params)
    result["layers"] = out_layers
    return result


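# ---------------------------------------------------------------------
# Per-head W_o edits & token helpers
# ---------------------------------------------------------------------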


def edit_Wo_along_wh(
    params,
    layer: int,
    head: int,
    w_hat: np.ndarray,
    alpha: float,
    num_heads: Optional[int] = None,
):
    """
    Scale a single head's W_o block along w_hat by (1 + alpha).
      alpha = -1.0 → project-out (remove that component)
      alpha > 0   → amplify component

    The head's row block ``Wb = W_o[head*D:(head+1)*D]`` (shape ``(D, hidden)``)
    becomes ``Wb + alpha * (Wb @ w) w^T`` with ``w = w_hat / ||w_hat||``. This
    scales the block's component along ``w`` by ``1 + alpha`` and leaves the
    orthogonal part unchanged. The arithmetic is done in float32 and stored
    back in W_o's original dtype. No config files are read.

    ``params`` is NOT modified. The returned dict shares every leaf with the
    input except for a fresh ``layers`` list, a copied layer dict, a copied
    ``attn`` dict and the new ``W_o`` array.

    Args:
        params: model params; ``params["layers"][layer]["attn"]["W_o"]`` is
            laid out as ``(num_heads * D, hidden)``.
        layer: layer index to edit.
        head: head index within that layer.
        w_hat: direction in the ``hidden`` space (normalized here).
        alpha: scale offset applied to the component along ``w_hat``.
        num_heads: number of attention heads in ``layer``. Required: W_q has
            ``H*D`` columns and W_o has ``H*D`` rows, so shapes alone cannot
            separate ``H`` from ``D``. (A shape-only search always lands on
            ``H = 1``, which would edit all of W_o for head 0 and nothing for
            any other head.)

    Raises:
        ValueError: if ``num_heads`` is missing or does not divide W_o's rows,
            or if ``head`` is outside ``[0, num_heads)``.
    """
    if num_heads is None:
        raise ValueError(
            "edit_Wo_along_wh requires num_heads: W_q/W_o shapes cannot "
            "separate num_heads from head_dim."
        )
    old_attn = params["layers"][layer]["attn"]
    W_o = np.array(old_attn["W_o"])  # writable host copy, (H*D, hidden)
    H = int(num_heads)
    total_rows = W_o.shape[0]
    if H <= 0 or total_rows % H != 0:
        raise ValueError(f"num_heads={H} does not divide W_o rows (W_o={W_o.shape})")
    D = total_rows // H
    if not 0 <= head < H:
        raise ValueError(f"head={head} out of range for num_heads={H}")

    w = np.asarray(w_hat, np.float32).reshape(-1)
    w = w / (np.linalg.norm(w) + 1e-8)

    row_start, row_end = head * D, (head + 1) * D
    block = W_o[row_start:row_end, :].astype(np.float32)  # (D, hidden)
    coeff = block @ w  # (D,) component of each row along w
    W_o[row_start:row_end, :] = block + alpha * np.outer(coeff, w)

    # Copy the attn dict so the caller's params are left untouched.
    attn = dict(old_attn)
    attn["W_o"] = jnp.asarray(W_o, dtype=old_attn["W_o"].dtype)
    lyr = dict(params["layers"][layer])
    lyr["attn"] = attn
    new_params = dict(params)
    new_params["layers"] = list(params["layers"])
    new_params["layers"][layer] = lyr
    return new_params


def bond_token_ids(tokenizer):
    """Map the single, double and triple bond symbols to their token ids.

    Returns ``{"-": id, "=": id, "#": id}`` using ``tokenizer.token_to_id``
    (whatever it returns for a missing symbol is passed through).
    """
    symbols = ("-", "=", "#")
    return {sym: tokenizer.token_to_id(sym) for sym in symbols}
