#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Ablation experiment for pre-generative SAE features.

For each selected neuron:
  1) Find top-K (prefix, position) where it fires the strongest (token-level).
  2) Prefill to that token, then generate continuations under:
        (a) Baseline (no intervention)
        (b) Ablation (zero that feature from the firing step onward)
  3) Count SMARTS in generated molecules and report deltas.

Outputs (under --out_dir):
  - per_neuron_<id>.csv     : fragment counts (baseline, ablated, delta) aggregated over samples × contexts
  - per_neuron_<id>_examples.json: a few example generations (baseline vs ablated) for that neuron
  - summary.csv             : one row per neuron with totals and top-affected fragments
"""

from __future__ import annotations

import argparse
import json
import math
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import jax.numpy as jnp
from rdkit import Chem

# --- LMKit imports you already use ---
from lmkit.sparse.sae import SAEKit
from lmkit.tools.compat import load_tokenizer
from lmkit.tools import data as data_tools

from lmkit.sparse.fragment_mapper import compile_smarts, find_fragments_and_tokens

# ======================
# Utility / data types
# ======================


@dataclass
class TopContext:
    """A dataset position where a feature peaked, plus the prefix to prefill from.

    Instances are built by ``top_k_contexts_for_neuron`` (one candidate per
    sequence) and consumed by ``run_ablation_experiment``, which generates
    continuations from ``prefix_tokens``.

    NOTE(review): the default dataclass ``__eq__`` compares the ndarray fields,
    so ``==`` between two distinct instances will likely raise on an ambiguous
    truth value — avoid comparing or ordering instances directly.
    """

    tokens: np.ndarray  # full token sequence of the source example (1D, BOS at index 0)
    positions: np.ndarray  # position ids of the same example (1D)
    t_fire: int  # index into `tokens` where the neuron fired (BOS-indexed)
    act: float  # activation value at `t_fire` (used for ranking)
    smiles_full: str  # decoded gold SMILES of the whole sequence, for reference
    prefix_tokens: np.ndarray  # tokens[:t_fire+1] (includes the firing token)
    prefix_len: int  # number of tokens in `prefix_tokens`


def ensure_dir(p: Path):
    """Create directory ``p`` together with any missing parents.

    An already existing directory is accepted silently.
    """
    p.mkdir(exist_ok=True, parents=True)


# ======================
# SMARTS support
# ======================


def load_smarts_yaml(yaml_path: str, profile: str = "leadlike") -> Dict[str, str]:
    """
    Read the MolSAE-SMARTS YAML file and return the patterns a profile selects.

    Layout read from the file:
      - ``profiles[profile]`` with optional ``include_prefixes`` and ``drop`` lists
      - optional ``curated`` entries, each with ``id`` and ``smarts``
      - optional ``ring_seeds`` entries, each with ``id`` and ``smiles``

    One-line ring seeds written as ``- id: foo ; smiles: "bar"`` are rewritten
    into the regular two-line mapping before the text is handed to the parser.

    An id is selected when it starts with one of the profile's prefixes and is
    not listed under ``drop``. Ring seeds contribute their SMILES string and
    replace a curated entry that has the same id; seeds with an empty or
    missing ``smiles`` are skipped.

    Returns:
        Mapping of pattern id -> pattern string.

    Raises:
        KeyError: if ``profiles`` or the requested profile is absent.
        ValueError: if the profile selects no pattern at all.
    """
    import yaml
    import re

    with open(yaml_path, "r", encoding="utf-8") as handle:
        text = handle.read()

    # Matches the compact single-line form: `- id: foo ; smiles: "bar"`.
    one_line_seed = re.compile(
        r'^(\s*)-\s+id:\s*([^;\n]+?)\s*;\s*smiles:\s*"(.*?)"\s*$',
        flags=re.MULTILINE,
    )

    def _as_two_lines(match: re.Match) -> str:
        pad, seed_id, seed_smiles = (
            match.group(1),
            match.group(2).strip(),
            match.group(3),
        )
        return f'{pad}- id: {seed_id}\n{pad}  smiles: "{seed_smiles}"'

    doc = yaml.safe_load(one_line_seed.sub(_as_two_lines, text))

    profile_cfg = doc["profiles"][profile]
    prefixes = tuple(profile_cfg.get("include_prefixes", []))
    dropped = set(profile_cfg.get("drop", []))

    def _is_selected(sid: str) -> bool:
        return sid.startswith(prefixes) and sid not in dropped

    # Curated SMARTS first ...
    selected = {
        entry["id"]: entry["smarts"]
        for entry in doc.get("curated", [])
        if _is_selected(entry["id"])
    }
    # ... then ring seeds, which win on id collisions.
    selected.update(
        (seed["id"], seed["smiles"])
        for seed in doc.get("ring_seeds", [])
        if _is_selected(seed["id"]) and seed.get("smiles")
    )

    if not selected:
        raise ValueError("No SMARTS selected by the profile.")
    return selected


def count_smarts_in_smiles(smiles: str, queries) -> Dict[str, int]:
    """
    Count unique substructure matches of every named query in one molecule.

    ``queries`` is a mapping of fragment name -> query pattern (presumably the
    compiled RDKit molecules produced by ``compile_smarts`` — verify at the caller).

    A SMILES string RDKit cannot parse gives an empty dict. A pattern whose
    matching raises is reported with a count of 0 rather than aborting.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return {}

    def _n_matches(pattern) -> int:
        try:
            return len(mol.GetSubstructMatches(pattern, uniquify=True))
        except Exception:
            return 0

    return {frag: _n_matches(pattern) for frag, pattern in queries.items()}


# ======================
# Context mining
# ======================


def decode_trim(seq: np.ndarray, tokenizer) -> str:
    """Decode ``seq`` without its first (BOS) token, keep the text before the
    first EOS marker, and remove every PAD marker from it."""
    decoded = tokenizer.decode(seq[1:], skip_special_tokens=False)
    before_eos = decoded.split(tokenizer.eos_token)[0]
    return before_eos.replace(tokenizer.pad_token, "")


def top_k_contexts_for_neuron(
    ds,
    sae_kit: SAEKit,
    layer_id: int,
    neuron_id: int,
    tokenizer,
    *,
    top_k: int = 5,
    scan_batches: Optional[int] = None,
) -> List[TopContext]:
    """
    Scan the dataset and return the top-K contexts where one SAE feature is strongest.

    Each sequence contributes at most one candidate: the position, among those
    marked valid by ``sae_kit.mask_fn``, with the highest activation of
    ``neuron_id`` at ``layer_id``. A candidate whose firing position is the last
    index of its sequence is dropped because no token could follow it.

    Args:
        ds: iterable of batches; each batch is a mapping with "inputs" and
            "positions" arrays (assumed (B, T) — TODO confirm against the loader).
        sae_kit: provides ``mask_fn(inputs)`` and
            ``get_encoded(inputs, positions, layer_id)``.
        layer_id: layer whose SAE activations are read.
        neuron_id: index of the feature along the last axis of the activations.
        tokenizer: used only to decode the full sequence for reference.
        top_k: maximum number of contexts to return; values <= 0 yield ``[]``.
        scan_batches: if given, stop after this many batches.

    Returns:
        Up to ``top_k`` TopContext objects sorted by activation (descending).
        Among equal activations the context encountered first ranks first.
    """
    import heapq
    import itertools

    if top_k <= 0:
        return []

    # Min-heap of (activation, tie_breaker, context). The unique tie_breaker
    # guarantees tuple comparison never falls through to TopContext, which
    # defines no ordering (and whose ndarray fields cannot be compared as a
    # plain bool). It is the negated arrival index, so that among equal
    # activations the most recent arrival sits at the top and is evicted first.
    heap: List[Tuple[float, int, TopContext]] = []
    arrival = itertools.count()

    for bidx, batch in enumerate(ds):
        if scan_batches is not None and bidx >= scan_batches:
            break

        inputs = batch["inputs"]
        positions = batch["positions"]
        # 1 for real tokens (BOS/EOS/PAD excluded, per the original annotation).
        valid = np.asarray(sae_kit.mask_fn(inputs), dtype=bool)

        # Move activations to the host, then slice out the requested feature.
        acts = np.asarray(sae_kit.get_encoded(inputs, positions, layer_id))
        # np.where builds a new array, so read-only host views are fine here.
        feat = np.where(valid, acts[:, :, neuron_id], -1e30)

        # Only the best position of each sequence is a candidate.
        top_t = np.argmax(feat, axis=1)
        top_v = feat[np.arange(feat.shape[0]), top_t]

        for i in range(feat.shape[0]):
            v = float(top_v[i])
            # Sequences without any valid position still carry the sentinel.
            if not np.isfinite(v) or v <= -1e20:
                continue
            # Cheap rejection before any decoding: it cannot beat the current minimum.
            if len(heap) >= top_k and v <= heap[0][0]:
                continue

            t_fire = int(top_t[i])
            tok_seq = np.asarray(inputs[i])
            # Need at least one step left after the firing token to continue sampling.
            if t_fire >= tok_seq.shape[0] - 1:
                continue

            prefix = tok_seq[: t_fire + 1]
            ctx = TopContext(
                tokens=tok_seq,
                positions=np.asarray(positions[i]),
                t_fire=t_fire,
                act=v,
                smiles_full=decode_trim(tok_seq, tokenizer),
                prefix_tokens=prefix,
                prefix_len=prefix.shape[0],
            )
            entry = (ctx.act, -next(arrival), ctx)
            if len(heap) < top_k:
                heapq.heappush(heap, entry)
            else:
                heapq.heapreplace(heap, entry)

    # Highest activation first; ties resolved by arrival order (earliest first).
    ranked = sorted(heap, key=lambda item: (item[0], item[1]), reverse=True)
    return [ctx for _, _, ctx in ranked]


# ======================
# Generation adapter
# ======================


def _decode_one(sample_tokens: np.ndarray, tokenizer) -> str:
    s = tokenizer.decode(sample_tokens, skip_special_tokens=False)
    s = s.split(tokenizer.eos_token)[0]
    s = s.replace(tokenizer.pad_token, "")
    # also strip BOS if present at pos 0
    if s and s[0] == tokenizer.bos_token:
        s = s[1:]
    return s


def _call_repo_generator(
    sae_kit: SAEKit,
    tokenizer,
    prefix_tokens: np.ndarray,
    *,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    num_samples: int,
    ablate_config: Optional[dict] = None,
) -> List[str]:
    """
    Tries several common entry points in LMKit to generate sequences.
    Wire your project-specific generator here if needed (see TODO).
    The ablate_config is passed through if the generator supports it.

    ablate_config format (convention used here):
      {
        "layer_id": int,
        "feature_id": int,
        "from_step": int,       # ablate at and after this decoding step (prefix length)
      }
    """
    # --- TODO: if your project exposes a specific generator, call it here and return SMILES strings. ---
    # Try SAEKit.generate(...)
    if hasattr(sae_kit, "generate"):
        return sae_kit.generate(
            prefix_tokens=prefix_tokens,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            num_samples=num_samples,
            ablate_config=ablate_config,  # your SAEKit.generate should use this to zero the feature
            decode_fn=lambda toks: _decode_one(toks, tokenizer),
        )

    # Try model.generate(...)
    if hasattr(sae_kit, "model") and hasattr(sae_kit.model, "generate"):
        return sae_kit.model.generate(
            prefix_tokens=prefix_tokens,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            num_samples=num_samples,
            ablate_config=ablate_config,
            decode_fn=lambda toks: _decode_one(toks, tokenizer),
        )

    # Try a global tools generator if you have one
    try:
        from lmkit.tools import generation as gen  # type: ignore

        return gen.generate(
            model=sae_kit.model,
            tokenizer=tokenizer,
            prefix_tokens=prefix_tokens,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            num_samples=num_samples,
            ablate_config=ablate_config,
        )
    except Exception:
        pass

    raise RuntimeError(
        "Could not find a generation entry point. "
        "Please wire your generator in _call_repo_generator()."
    )


# ======================
# Core experiment
# ======================


def _add_fragment_counts(smiles_list, queries, totals) -> None:
    """Add the SMARTS match counts of every SMILES in ``smiles_list`` to ``totals``."""
    for smi in smiles_list:
        for frag, n_hits in count_smarts_in_smiles(smi, queries).items():
            totals[frag] += int(n_hits)


def run_ablation_experiment(
    *,
    model_dir: str,
    sae_dir: str,
    ckpt_id: int,
    dataset_dir: str,
    smarts_yaml: str,
    profile: str,
    layer_id: int,
    neuron_ids: List[int],
    top_k_contexts: int,
    num_samples: int,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    out_dir: str,
    limit_mols: int = 100_000,  # upper bound scan to find contexts fast
    batch_size: int = 256,
    num_proc: int = 1,
):
    """
    Run the baseline-vs-ablation generation experiment for a list of SAE features.

    For every neuron id the function mines its top contexts, generates
    ``num_samples`` continuations per context with and without ablating the
    feature, counts SMARTS matches in both sets, and writes under ``out_dir``:

      - ``per_neuron_<id>.csv``           fragment, baseline_count, ablated_count, delta
      - ``per_neuron_<id>_examples.json`` up to 5 baseline/ablated pairs per context
      - ``summary.csv``                   one row per neuron (written via pandas)

    ``delta`` is always ``ablated - baseline``. Generations that RDKit cannot
    parse contribute no counts. Neurons for which no context is found are
    skipped and do not appear in the summary.

    Args:
        model_dir: directory holding ``tokenizer.json``, ``generation_config.json``
            and the model loaded by ``SAEKit.load``.
        sae_dir / ckpt_id: forwarded to ``SAEKit.load``.
        dataset_dir, limit_mols, batch_size, num_proc: forwarded to
            ``data_tools.load_and_tokenize``.
        smarts_yaml / profile: forwarded to ``load_smarts_yaml``.
        layer_id: layer whose SAE feature is mined and ablated.
        neuron_ids: feature indices to process, in order.
        top_k_contexts: number of contexts mined per neuron.
        num_samples, max_new_tokens, temperature, top_p: sampling options passed
            to the generator for both the baseline and the ablated run.
        out_dir: output directory; created if missing.
    """
    import csv

    out_path = Path(out_dir)
    ensure_dir(out_path)

    # 1) Load tokenizer & model/SAE
    tokenizer = load_tokenizer(
        tokenizer_path=f"{model_dir}/tokenizer.json",
        generation_config_file=f"{model_dir}/generation_config.json",
        trunc_length=None,
    )
    sae_kit = SAEKit.load(model_dir=model_dir, checkpoint_id=ckpt_id, sae_dir=sae_dir)

    # 2) Dataset
    # NOTE(review): `ds` is iterated once per neuron below — this presumes it
    # can be re-iterated from the start; confirm against load_and_tokenize.
    ds = data_tools.load_and_tokenize(
        dataset_dir=dataset_dir,
        tokenizer=tokenizer,
        batch_size=batch_size,
        num_processes=num_proc,
        seed=2002,
        caching=True,
        limit=limit_mols,
    )

    # 3) SMARTS
    queries = compile_smarts(load_smarts_yaml(smarts_yaml, profile=profile))

    # Sampling options shared by the baseline and the ablated run.
    sampling = dict(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        num_samples=num_samples,
    )

    # 4) For each neuron: find top-K contexts
    summary_rows = []
    for nid in neuron_ids:
        print(
            f"\n[neuron {nid}] mining top-{top_k_contexts} contexts on layer {layer_id} …"
        )
        contexts = top_k_contexts_for_neuron(
            ds, sae_kit, layer_id, nid, tokenizer, top_k=top_k_contexts
        )
        if not contexts:
            print(f"  ! no valid contexts found for neuron {nid} (skipping)")
            continue

        # 5) Generate continuations: baseline vs ablation
        agg_baseline = defaultdict(int)  # fragment -> total count across all samples
        agg_ablated = defaultdict(int)
        example_log = []  # a few paired samples per neuron

        for cidx, ctx in enumerate(contexts):
            print(
                f"  - context #{cidx + 1}: t_fire={ctx.t_fire} act={ctx.act:.3g} prefix_len={ctx.prefix_len}"
            )

            # (a) Baseline
            base_smiles = _call_repo_generator(
                sae_kit, tokenizer, ctx.prefix_tokens, ablate_config=None, **sampling
            )

            # (b) Ablation.
            # NOTE(review): from_step is the prefix length, i.e. the first
            # generated step; whether that also covers the firing token itself
            # depends on the generator's step convention — confirm.
            ablate_cfg = dict(
                layer_id=layer_id, feature_id=nid, from_step=ctx.prefix_len
            )
            abl_smiles = _call_repo_generator(
                sae_kit,
                tokenizer,
                ctx.prefix_tokens,
                ablate_config=ablate_cfg,
                **sampling,
            )

            # Count SMARTS occurrences
            _add_fragment_counts(base_smiles, queries, agg_baseline)
            _add_fragment_counts(abl_smiles, queries, agg_ablated)

            # Keep a few examples for qualitative inspection. The two runs are
            # sampled by separate calls, so a pair only shares its prefix.
            prefix_text = _decode_one(ctx.prefix_tokens, tokenizer)
            for base, abl in zip(base_smiles[:5], abl_smiles[:5]):
                example_log.append(
                    dict(
                        neuron_id=nid,
                        context_index=cidx,
                        act=ctx.act,
                        prefix=prefix_text,
                        baseline=base,
                        ablated=abl,
                    )
                )

        # 6) Aggregate & write per-neuron CSV
        all_frags = sorted(set(agg_baseline) | set(agg_ablated))
        deltas = {
            frag: agg_ablated.get(frag, 0) - agg_baseline.get(frag, 0)
            for frag in all_frags
        }
        per_neuron_csv = out_path / f"per_neuron_{nid}.csv"
        # csv.writer quotes fragment ids that contain commas or quotes;
        # newline="" plus an explicit "\n" terminator keeps plain rows unchanged.
        with per_neuron_csv.open("w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f, lineterminator="\n")
            writer.writerow(["fragment", "baseline_count", "ablated_count", "delta"])
            writer.writerows(
                (
                    frag,
                    agg_baseline.get(frag, 0),
                    agg_ablated.get(frag, 0),
                    deltas[frag],
                )
                for frag in all_frags
            )

        # 7) Save paired examples
        examples_json = out_path / f"per_neuron_{nid}_examples.json"
        with examples_json.open("w", encoding="utf-8") as f:
            json.dump(example_log, f, indent=2)

        # 8) Add to global summary (five largest absolute deltas)
        top_aff = sorted(deltas.items(), key=lambda kv: -abs(kv[1]))[:5]
        summary_rows.append(
            dict(
                neuron_id=nid,
                layer_id=layer_id,
                total_baseline=sum(agg_baseline.values()),
                total_ablated=sum(agg_ablated.values()),
                total_delta=sum(deltas.values()),
                top_affected_fragments="; ".join([f"{k}:{v:+d}" for k, v in top_aff]),
            )
        )

    # 9) Global summary
    if summary_rows:
        import pandas as pd

        summary_path = out_path / "summary.csv"
        pd.DataFrame(summary_rows).to_csv(summary_path, index=False)
        print(f"\n✓ Wrote {len(summary_rows)} neuron summaries to {summary_path}")
    else:
        print("\n(no neuron results were produced)")


# ======================
# CLI
# ======================


def parse_int_list(s: str) -> List[int]:
    """Parse integers separated by commas and/or whitespace, e.g. ``"1125, 201 77"``.

    An empty or blank string gives ``[]``; a non-integer token raises ValueError.
    """
    tokens = s.replace(",", " ").split()
    return list(map(int, tokens))


def main():
    """Parse the command line and run the ablation experiment with those options."""
    parser = argparse.ArgumentParser(
        description="Causal ablation of pre-generative SAE features (generation & SMARTS counts)."
    )

    # (flag, add_argument options), registered in this order.
    option_table = [
        ("--model_dir", dict(required=True)),
        ("--sae_dir", dict(required=True)),
        ("--ckpt_id", dict(type=int, required=True)),
        ("--dataset_dir", dict(required=True)),
        ("--smarts_yaml", dict(required=True)),
        ("--profile", dict(default="leadlike")),
        ("--layer_id", dict(type=int, default=5)),
        (
            "--neurons",
            dict(
                type=str,
                required=True,
                help="Space/comma-separated neuron IDs, e.g. '1125, 201, 77'",
            ),
        ),
        ("--top_k_contexts", dict(type=int, default=5)),
        ("--num_samples", dict(type=int, default=64)),
        ("--max_new_tokens", dict(type=int, default=64)),
        ("--temperature", dict(type=float, default=0.8)),
        ("--top_p", dict(type=float, default=0.95)),
        ("--out_dir", dict(default="pre_ablation_runs")),
        ("--limit_mols", dict(type=int, default=100000)),
        ("--batch_size", dict(type=int, default=256)),
        ("--num_proc", dict(type=int, default=1)),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    # Every option name equals the matching keyword of run_ablation_experiment,
    # except --neurons, which is parsed into the `neuron_ids` list first.
    settings = vars(parser.parse_args())
    settings["neuron_ids"] = parse_int_list(settings.pop("neurons"))
    run_ablation_experiment(**settings)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
