# lmkit/atlas/causal_steering.py
# Causal SMARTS steering with SAE features
# - Select top 10 SMARTS by delta = frac(rep) - frac(baseline) (descending)
# - Rank SAE features by fraction of firing on representatives
# - Steer by single-feature latent bump; measure delta in SMARTS frequency vs baseline
# - Save full results and leaderboards to --out_dir

from __future__ import annotations

import argparse
import os
from typing import Dict, List, Optional, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from flax.core import freeze
from rdkit import Chem
from scipy.stats import fisher_exact
from tqdm.auto import tqdm

from lmkit.impl import config as config_lib
from lmkit.impl import transformer
from lmkit.impl.hooks import ActivationEditor, Edit
from lmkit.impl.sampler import generate
from lmkit.sparse import sae as sae_mod
from lmkit.sparse import utils as sae_utils
from lmkit.sparse.sae import decode_latent, normalize

# ---- lmkit imports (match your working notebook path) ----
from lmkit.tools import compat, train_utils
from lmkit.tools.stem import clean_smiles

# SMARTS lib (atlas → feature_bank fallback)
# NOTE(review): both import attempts below target the same module
# (lmkit.atlas.smarts_lib), so the inner try is not an actual fallback.
# Presumably the second attempt was meant to import from the "feature_bank"
# location named above — confirm and fix the path if so.
# When the import fails, get_smarts_library is set to None and the callers in
# this file treat the SMARTS catalog as empty.
try:
    from lmkit.atlas.smarts_lib import get_smarts_library
except Exception:
    try:
        from lmkit.atlas.smarts_lib import get_smarts_library
    except Exception:
        get_smarts_library = None


# ------------------------- Load model + SAEs ------------------------- #


def load_model_and_saes(model_dir: str, checkpoint_id: str | int, sae_dir: str):
    """Load the tokenizer, LM config/params, one SAE per layer and hook tools.

    Files read:
      - ``{model_dir}/tokenizer.json`` and ``{model_dir}/generation_config.json``
      - ``{model_dir}/checkpoints/checkpoint_{checkpoint_id}.pkl``
      - ``{sae_dir}/sae_{i}_config.json`` and
        ``{sae_dir}/sae_{i}/checkpoint_final.pkl`` for every layer ``i`` in
        ``range(num_layers)`` of the model config.

    Every SAE config is loaded with ``act_fn`` replaced by ``sae_mod.relu`` and
    ``loss_fn`` set to ``None``.

    Returns:
        Tuple ``(tokenizer, lm_config, lm_params, sae_configs, sae_params,
        run_fn, mask_fn, hooks)``; the last three come from
        ``sae_utils.get_hook_tools``.
    """
    tokenizer = compat.load_tokenizer(
        mode="train",
        tokenizer_path=f"{model_dir}/tokenizer.json",
        generation_config_file=f"{model_dir}/generation_config.json",
        trunc_length=256,
    )

    # The LM config is extended with the tokenizer's special ids, then frozen.
    raw_config = config_lib.load_from_dir(model_dir)
    special_ids = dict(
        bos_id=tokenizer.bos_token_id,
        eos_id=tokenizer.eos_token_id,
        pad_id=tokenizer.pad_token_id,
    )
    lm_config = freeze(raw_config.copy(special_ids))

    lm_checkpoint = f"{model_dir}/checkpoints/checkpoint_{checkpoint_id}.pkl"
    lm_params, *_ = train_utils.load_checkpoint(lm_checkpoint)

    # One SAE (config + params) per transformer layer, indexed by layer number.
    layer_indices = range(raw_config["num_layers"])
    sae_configs = []
    for i in layer_indices:
        base_cfg = sae_mod.SAEConfig.from_file(f"{sae_dir}/sae_{i}_config.json")
        sae_configs.append(base_cfg.replace(act_fn=sae_mod.relu, loss_fn=None))

    sae_params = []
    for i in layer_indices:
        loaded = train_utils.load_checkpoint(f"{sae_dir}/sae_{i}/checkpoint_final.pkl")
        sae_params.append(loaded[0])

    # TrainConfig is only used here as a container for get_hook_tools; no
    # optimizers or checkpoint root are configured.
    train_config = sae_mod.TrainConfig(
        sae_configs=sae_configs,
        lm_params=lm_params,
        lm_config=lm_config,
        tokenizer=tokenizer,
        optimizers=None,
        checkpoint_root=None,
    )
    run_fn, mask_fn, hooks = sae_utils.get_hook_tools(train_config)

    return (
        tokenizer,
        lm_config,
        lm_params,
        sae_configs,
        sae_params,
        run_fn,
        mask_fn,
        hooks,
    )


# ------------------------- SMARTS helpers ------------------------- #


def _suggest_keys(catalog: dict, needle: str, limit=10):
    n = needle.lower()
    hits = [k for k in catalog.keys() if n in k.lower()]
    hits.sort(key=lambda k: (not k.lower().startswith(n), k))
    return hits[:limit]


def compile_smarts_map(
    ids_or_smarts: List[str], custom_map: Optional[Dict[str, str]] = None
):
    """Resolve catalog ids / raw SMARTS strings into compiled RDKit queries.

    Each entry of ``ids_or_smarts`` is resolved by the first rule that applies:

    1. a key of ``custom_map`` (its value is the SMARTS string);
    2. an exact key of the SMARTS catalog;
    3. if the entry contains no ``":"`` and the catalog is non-empty, the best
       catalog key suggested by ``_suggest_keys`` (the result is stored under
       the *catalog* key and a notice is printed);
    4. the entry itself interpreted as a raw SMARTS string.

    Note that rule 3 runs before rule 4, so an entry without ``":"`` that is a
    substring of some catalog key resolves to that catalog entry even if it is
    also valid SMARTS.

    Returns:
        Dict mapping resolved key -> ``(smarts_string, compiled_query)``.

    Raises:
        ValueError: if a custom/catalog SMARTS fails to compile, or if the
            entry cannot be resolved by any rule.
    """
    library = get_smarts_library() if get_smarts_library is not None else {}
    overrides = custom_map or {}
    compiled: Dict[str, Tuple[str, Chem.Mol]] = {}

    for key in ids_or_smarts:
        # Rule 1: caller-supplied override.
        if key in overrides:
            pattern = overrides[key]
            query = Chem.MolFromSmarts(pattern)
            if query is None:
                raise ValueError(f"Invalid custom SMARTS for '{key}': {pattern}")
            compiled[key] = (pattern, query)
            continue

        # Rule 2: exact catalog id.
        if key in library:
            pattern = library[key]
            query = Chem.MolFromSmarts(pattern)
            if query is None:
                raise ValueError(
                    f"Catalog entry for '{key}' invalid SMARTS: {pattern}"
                )
            compiled[key] = (pattern, query)
            continue

        # Rule 3: fuzzy catalog lookup (falls through when nothing usable).
        if ":" not in key and library:
            guesses = _suggest_keys(library, key, limit=1)
            if guesses:
                guessed_key = guesses[0]
                pattern = library[guessed_key]
                query = Chem.MolFromSmarts(pattern)
                if query:
                    compiled[guessed_key] = (pattern, query)
                    print(f"[SMARTS] Interpreting '{key}' as '{guessed_key}'")
                    continue

        # Rule 4: treat the entry as raw SMARTS.
        query = Chem.MolFromSmarts(key)
        if query is not None:
            compiled[key] = (key, query)
            continue

        # Nothing matched: build an error with up to five suggestions.
        hint = ""
        if library:
            alternatives = _suggest_keys(library, key, limit=5)
            if alternatives:
                hint = f"\nDid you mean: {', '.join(alternatives)}"
        raise ValueError(f"Invalid SMARTS or unknown key '{key}'.{hint}")

    return compiled


# ------------------------- Representatives from dataset ------------------------- #


def load_smiles_from_dataset(dataset_dir: str, limit: int) -> List[str]:
    """Read cleaned, RDKit-parsable SMILES from a dataset saved on disk.

    Scans the first ``min(limit, len(dataset))`` rows (all rows when ``limit``
    is ``None``), reads each row's ``"smiles"`` field, runs it through
    ``clean_smiles`` and keeps the cleaned string when RDKit can parse it.
    Rows with a missing/empty field or an unusable SMILES are skipped, so the
    result may be shorter than the number of rows scanned.
    """
    from datasets import load_from_disk

    dataset = load_from_disk(dataset_dir)
    total_rows = len(dataset)
    rows_to_scan = total_rows if limit is None else min(limit, total_rows)

    cleaned_smiles = []
    for row_idx in range(rows_to_scan):
        raw = dataset[row_idx].get("smiles")
        if raw:
            candidate = clean_smiles(raw)
            if candidate and Chem.MolFromSmiles(candidate) is not None:
                cleaned_smiles.append(candidate)
    return cleaned_smiles


def collect_reps_for_patterns(
    smarts_map: Dict[str, Tuple[str, Chem.Mol]],
    smiles_pool: List[str],
    per_pattern: int,
) -> Tuple[Dict[str, List[str]], List[str]]:
    """Pick up to ``per_pattern`` representative SMILES for every SMARTS pattern.

    The pool is scanned in order; a molecule is added to every pattern it
    matches that still has room. Molecules that fail to parse, and individual
    substructure matches that raise, are skipped (best effort).

    Args:
        smarts_map: key -> ``(smarts_string, compiled_query)``.
        smiles_pool: candidate SMILES, scanned front to back.
        per_pattern: maximum number of representatives kept per pattern.

    Returns:
        ``(reps_by_pat, union)`` where ``reps_by_pat`` maps every key of
        ``smarts_map`` to its representatives (possibly empty) and ``union``
        lists each selected SMILES once, in order of first selection. The
        order is deterministic across runs, unlike iteration over a ``set`` of
        strings, which depends on hash randomization.
    """
    reps_by_pat: Dict[str, List[str]] = {k: [] for k in smarts_map}

    # Patterns that can still accept representatives, in smarts_map order.
    pending = (
        {key: q for key, (_, q) in smarts_map.items()} if per_pattern > 0 else {}
    )
    # A dict is used as an insertion-ordered set of selected SMILES.
    selected: Dict[str, None] = {}

    for s in smiles_pool:
        if not pending:
            # Every pattern is full: nothing more can be added, stop parsing.
            break
        try:
            m = Chem.MolFromSmiles(s)
        except Exception:
            continue
        if m is None:
            continue

        filled = []
        for key, q in pending.items():
            try:
                matched = m.HasSubstructMatch(q)
            except Exception:
                continue
            if not matched:
                continue
            reps_by_pat[key].append(s)
            selected[s] = None
            if len(reps_by_pat[key]) >= per_pattern:
                filled.append(key)

        # Remove full patterns after the inner loop to avoid mutating the dict
        # while iterating over it.
        for key in filled:
            del pending[key]

    return reps_by_pat, list(selected)


# ------------------------- Token packing ------------------------- #


def pack_batch(tokenizer, smiles_list: List[str]):
    """Tokenize SMILES and pack them into right-padded id / position arrays.

    Returns:
        ``(ids, positions)`` as int32 JAX arrays of shape
        ``(len(smiles_list), longest_sequence)``. ``ids`` is padded with the
        tokenizer's pad id. ``positions`` holds the column index for ordinary
        tokens and ``-1`` wherever the token is pad, BOS or EOS.
    """
    token_rows = [enc.ids for enc in tokenizer.encode_batch_fast(smiles_list)]
    width = max(map(len, token_rows))
    pad_id = tokenizer.pad_token_id
    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id

    ids = np.full((len(token_rows), width), pad_id, dtype=np.int32)
    for row_idx, tokens in enumerate(token_rows):
        ids[row_idx, : len(tokens)] = tokens

    # Special tokens get position -1 so downstream code can mask them out.
    is_special = (ids == pad_id) | (ids == bos_id) | (ids == eos_id)
    column_index = np.broadcast_to(np.arange(width, dtype=np.int32), ids.shape)
    positions = np.where(is_special, -1, column_index).astype(np.int32)
    return jnp.asarray(ids), jnp.asarray(positions)


# ------------------------- Fractions & stds on reps ------------------------- #


def fractions_and_stds_on_reps(
    rep_smiles: List[str],
    tokenizer,
    run_fn,
    mask_fn,
    hooks,
    lm_params,
    lm_config,
    sae_configs,
    sae_params,
    batch_size: int = 64,
    act_threshold: float = 1e-6,
):
    """Compute per-feature firing fractions and activation stds on representatives.

    For every batch of ``rep_smiles`` the LM is run with activation capture and
    each layer's SAE is applied to the captured input. Activations at positions
    marked ``-1`` by ``pack_batch`` (pad/BOS/EOS) are zeroed before reduction.

    Returns:
        ``(rep_frac, stds)``, both dicts keyed by layer index:
          - ``rep_frac[L][k]``: fraction of representative molecules on which
            feature ``k`` exceeds ``act_threshold`` at any position;
          - ``stds[L][k]``: ``sqrt(max(E[a^2] - E[a]^2, 0))`` of feature ``k``
            with the means taken over all non-masked token positions.

    Notes:
        ``mask_fn`` is accepted but not used in this function.
        ``rep_smiles`` must be non-empty: with no batches the accumulators stay
        ``None`` and the final reduction fails.
    """
    num_layers = int(lm_config["num_layers"])
    total_reps = len(rep_smiles)
    layer_range = range(num_layers)

    # Per-layer accumulators; arrays are created lazily once the latent width
    # is known from the first batch.
    fire_counts = dict.fromkeys(layer_range)
    act_sum = dict.fromkeys(layer_range)
    act_sumsq = dict.fromkeys(layer_range)
    token_counts = dict.fromkeys(layer_range, 0)

    for offset in tqdm(range(0, total_reps, batch_size), desc="Reps (fractions/std)"):
        chunk = rep_smiles[offset : offset + batch_size]
        ids, positions = pack_batch(tokenizer, chunk)
        residuals = sae_mod.utils.run_and_capture(
            run_fn, ids, positions, lm_params, lm_config, hooks
        )
        # 1.0 at ordinary tokens, 0.0 at pad/BOS/EOS; trailing axis broadcasts
        # over features.
        valid_mask = (np.asarray(positions) >= 0)[..., None].astype(np.float32)
        n_valid = int(valid_mask.sum())

        for L, cfg in enumerate(sae_configs):
            _, _, act = sae_mod.run(
                residuals[(cfg.layer_id, cfg.placement)],
                sae_params[L],
                cfg,
                return_latents=True,
                return_act=True,
            )
            masked = np.asarray(act) * valid_mask  # (B,T,K)

            # A feature "fires" on a molecule if its max over positions
            # exceeds the threshold.
            fired_per_feature = (
                (masked.max(axis=1) > act_threshold).astype(np.int32).sum(axis=0)
            )
            if fire_counts[L] is None:
                fire_counts[L] = np.zeros(masked.shape[-1], dtype=np.int64)
            fire_counts[L] += fired_per_feature

            batch_sum = masked.sum(axis=(0, 1))
            batch_sumsq = (masked**2).sum(axis=(0, 1))
            if act_sum[L] is None:
                act_sum[L] = batch_sum
                act_sumsq[L] = batch_sumsq
            else:
                act_sum[L] += batch_sum
                act_sumsq[L] += batch_sumsq
            token_counts[L] += n_valid

    rep_frac = {}
    stds = {}
    for L in layer_range:
        rep_frac[L] = fire_counts[L].astype(np.float64) / max(1, total_reps)
        denom = max(1, token_counts[L])
        # Clamp at zero: rounding can make E[a^2] - E[a]^2 slightly negative.
        variance = act_sumsq[L] / denom - (act_sum[L] / denom) ** 2
        stds[L] = np.sqrt(np.maximum(variance, 0.0))
    return rep_frac, stds


# ------------------------- Generation & counting ------------------------- #


def generate_valid(
    tokenizer,
    lm_params,
    lm_config,
    n_target: int,
    max_attempts: int,
    temp: float,
    top_p: float,
    max_new_tokens: int,
    editor: Optional[ActivationEditor] = None,
    seed: int = 2025,
    batch: int = 512,
    dedup: bool = False,
) -> List[str]:
    """Sample molecules from the LM until ``n_target`` valid SMILES are collected.

    Each round samples ``min(batch, remaining)`` sequences from a single BOS
    token, using ``PRNGKey(seed + round_index)``, and keeps those that survive
    ``clean_smiles`` and parse with RDKit. Sampling stops once ``n_target``
    molecules are collected or ``max_attempts`` sequences have been drawn.

    Args:
        editor: optional activation editor; when given, every forward pass goes
            through ``transformer.run(..., editor=editor)``.
        dedup: when true, duplicates are dropped after every round (first
            occurrence kept).

    Returns:
        At most ``n_target`` cleaned SMILES; fewer if the attempt budget ran
        out first.
    """
    if editor is None:
        run_fn_local = transformer.run
    else:

        def run_fn_local(*args, **kwargs):
            return transformer.run(*args, editor=editor, **kwargs)

    collected = []
    attempts = 0
    step = 0
    while len(collected) < n_target and attempts < max_attempts:
        round_size = min(batch, n_target - len(collected))
        prompts = jnp.array([[tokenizer.bos_token_id]] * round_size, dtype=jnp.int32)
        texts = generate(
            max_new_tokens=max_new_tokens,
            tokenizer=tokenizer,
            params=lm_params,
            config=lm_config,
            random_key=jax.random.PRNGKey(seed + step),
            tokenized_inputs=prompts,
            temp=temp,
            top_p=top_p,
            return_text=True,
            run_fn=run_fn_local,
            verbose=False,
        )
        attempts += round_size
        step += 1

        for text in texts:
            cleaned = clean_smiles(text)
            if cleaned is None:
                continue
            if Chem.MolFromSmiles(cleaned) is None:
                continue
            collected.append(cleaned)

        if dedup:
            # dict.fromkeys keeps the first occurrence and preserves order.
            collected = list(dict.fromkeys(collected))

    return collected[:n_target]


def count_hits(smiles_list: List[str], q: Chem.Mol) -> int:
    """Count the SMILES in ``smiles_list`` whose molecule matches query ``q``.

    Strings that fail to parse, or for which parsing/matching raises, count as
    non-matches.
    """

    def _is_hit(smiles):
        try:
            mol = Chem.MolFromSmiles(smiles)
            return mol is not None and mol.HasSubstructMatch(q)
        except Exception:
            return False

    return sum(1 for smiles in smiles_list if _is_hit(smiles))


# ------------------------- Steering editor ------------------------- #


def make_editor_single_add(
    layer_id: int, add_vec_1d: np.ndarray, sae_configs, sae_params
) -> ActivationEditor:
    """Build an editor that adds a fixed vector in one layer's SAE latent space.

    The returned editor carries a single "call" edit at ``layer_id`` and the
    SAE's placement. Its callback projects the activation into latent space,
    adds ``add_vec_1d`` to every position, decodes with ``decode_latent`` and
    casts back to the input dtype.

    NOTE(review): the latent here is the plain projection
    ``x_enc @ W_enc``; no encoder bias or activation function is applied in
    this block before decoding — confirm this is the intended steering form.
    NOTE(review): ``layer_id`` indexes ``sae_configs`` by list position and is
    also passed as the edit's layer; this assumes list position equals the
    SAE's ``layer_id`` — confirm.
    """
    sae_cfg = sae_configs[layer_id]
    sae_prm = sae_params[layer_id]
    latent_offset = jnp.asarray(add_vec_1d, dtype=jnp.float32)

    def _steer(x):
        # Optionally standardize the input; the statistics are handed to the
        # decoder together with the latents.
        x_norm, x_mean, x_std = (
            normalize(x) if sae_cfg.rescale_inputs else (x, 0.0, 1.0)
        )
        encoder_input = x_norm
        if sae_cfg.pre_enc_bias:
            encoder_input = x_norm - sae_prm["b_dec"]
        latents = encoder_input @ sae_prm["W_enc"]  # (B,T,K)
        steered = latents + latent_offset
        x_new = decode_latent(
            steered, params=sae_prm, config=sae_cfg, x_mean=x_mean, x_std=x_std
        )
        return x_new.astype(x.dtype)

    steer_edit = Edit(
        layer=layer_id, kind=sae_cfg.placement, op="call", callback=_steer
    )
    return ActivationEditor(edits=(steer_edit,))


# ------------------------- Pattern selection (Δ ranking) ------------------------- #


def select_top_patterns_by_delta(
    smarts_map: Dict[str, Tuple[str, Chem.Mol]],
    reps_by_pat: Dict[str, List[str]],
    baseline: List[str],
    top_k: int = 10,
) -> List[str]:
    """
    Return the ``top_k`` motif ids ranked by delta = frac(reps) - frac(baseline).

    Ranking is descending in delta, ties broken by higher frac(reps); this
    mirrors the notebook behavior. Motifs without any representative are
    ignored, and an empty list is returned when no motif qualifies.
    """
    baseline_size = len(baseline)
    records = []
    for key, (_, query) in smarts_map.items():
        reps = reps_by_pat.get(key, [])
        if reps:
            # rep_hits usually equals len(reps) because of how reps are chosen.
            rep_hits = count_hits(reps, query)
            base_hits = count_hits(baseline, query)
            rep_fraction = rep_hits / max(1, len(reps))
            base_fraction = base_hits / max(1, baseline_size)
            records.append(
                {
                    "smarts_id": key,
                    "rep_size": len(reps),
                    "f_rep": rep_fraction,
                    "f_base": base_fraction,
                    "delta": rep_fraction - base_fraction,
                }
            )

    if not records:
        return []

    ranked = (
        pd.DataFrame(records)
        .sort_values(["delta", "f_rep"], ascending=[False, False])
        .reset_index(drop=True)
    )
    return ranked["smarts_id"].head(top_k).tolist()


# ------------------------- Main experiment ------------------------- #


def run_experiment(
    model_dir: str,
    checkpoint_id: str | int,
    sae_dir: str,
    out_dir: str,
    dataset_dir: Optional[str],
    reps_scan_limit: int,
    reps_per_pattern: int,
    top_patterns: int,
    top_features_per_layer: int,
    frac_min: float,
    target_valid: int,
    max_attempts: int,
    temp: float,
    top_p: float,
    max_new_tokens: int,
    dedup_generations: bool,
    steer_by_std_mult: float,
    steer_absolute: Optional[float],
    both_directions: bool,
    layers: Optional[List[int]],
    rep_batch: int,
    act_threshold: float,
    baseline_seed: int,
) -> None:
    """Run the causal-steering experiment end to end and write CSVs to ``out_dir``.

    Steps:
      1. Compile every SMARTS motif of the catalog.
      2. Collect up to ``reps_per_pattern`` representative molecules per motif
         from the dataset (or, without ``dataset_dir``, from baseline samples).
      3. Sample a small baseline and keep the ``top_patterns`` motifs with the
         largest ``frac(reps) - frac(baseline)``.
      4-5. On the representatives of the kept motifs, compute per-feature
         firing fractions and activation stds for every layer's SAE.
      6. Sample the evaluation baseline and count motif hits in it.
      7. For every tested layer, take the ``top_features_per_layer`` features
         with firing fraction ``>= frac_min``, steer generation with each one
         (``+value`` and optionally ``-value``) and compare motif frequencies
         against the baseline with a one-sided Fisher exact test.

    Outputs (in ``out_dir``): ``causal_steering_results.csv`` (all rows),
    ``causal_steering_top10.csv`` (best 10 rows by ``delta_frac`` per motif and
    layer) and ``selected_patterns.csv``. When no feature passes ``frac_min``
    the result files are still written, with headers only.

    Raises:
        ValueError: if ``layers`` contains an index outside ``[0, num_layers)``.
        RuntimeError: if the SMARTS catalog is empty/unavailable or no motif
            survives the selection stage.
    """
    os.makedirs(out_dir, exist_ok=True)
    tokenizer, lm_config, lm_params, sae_configs, sae_params, run_fn, mask_fn, hooks = (
        load_model_and_saes(model_dir, checkpoint_id, sae_dir)
    )
    num_layers = int(lm_config["num_layers"])

    # Validate the requested layers before any expensive work; an invalid index
    # would otherwise only fail inside the steering loop.
    layers_to_test = list(range(num_layers)) if not layers else list(layers)
    bad_layers = [L for L in layers_to_test if not 0 <= L < num_layers]
    if bad_layers:
        raise ValueError(
            f"Invalid layer ids {bad_layers}; expected values in [0, {num_layers - 1}]."
        )

    # 1) SMARTS catalog (all motifs candidate)
    catalog = get_smarts_library() if get_smarts_library is not None else {}
    if not catalog:
        # Without motifs the selection stage can never succeed; fail before
        # sampling anything.
        raise RuntimeError(
            "SMARTS catalog is empty or could not be imported; nothing to steer towards."
        )
    candidate_keys = list(catalog.keys())
    smarts_map = compile_smarts_map(candidate_keys)
    print(f"[info] Loaded {len(smarts_map)} SMARTS motifs from catalog.")

    # 2) Build representatives per motif
    if dataset_dir:
        print(f"[info] Loading up to {reps_scan_limit:,} SMILES from dataset …")
        pool = load_smiles_from_dataset(dataset_dir, reps_scan_limit)
    else:
        print("[info] No dataset provided; sampling a pool via baseline generation …")
        pool = generate_valid(
            tokenizer=tokenizer,
            lm_params=lm_params,
            lm_config=lm_config,
            n_target=min(20000, reps_scan_limit),
            max_attempts=max_attempts,
            temp=temp,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            editor=None,
            seed=baseline_seed,
            batch=512,
            dedup=True,
        )
    print(f"[info] Pool size for representatives: {len(pool):,}")

    reps_by_pat, union_reps = collect_reps_for_patterns(
        smarts_map, pool, reps_per_pattern
    )
    union_reps = list(dict.fromkeys(union_reps))
    print(
        f"[info] Built representatives for {sum(1 for k in reps_by_pat if reps_by_pat[k])} motifs; union size={len(union_reps):,}"
    )

    # 3) Small baseline just for motif selection (not final eval)
    print("[info] Generating small baseline (for motif selection) …")
    baseline_small = generate_valid(
        tokenizer=tokenizer,
        lm_params=lm_params,
        lm_config=lm_config,
        n_target=min(8000, target_valid),
        max_attempts=max_attempts,
        temp=temp,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        editor=None,
        seed=baseline_seed + 1,
        batch=512,
        dedup=True,
    )

    # ---- selection by DELTA only (descending), like in the notebook ----
    top_motifs = select_top_patterns_by_delta(
        smarts_map, reps_by_pat, baseline_small, top_k=top_patterns
    )
    if not top_motifs:
        raise RuntimeError(
            "No SMARTS patterns found to be significant in selection stage."
        )
    print("[info] Selected patterns (by delta):", top_motifs)

    # 4) Representatives for feature ranking = union of selected motifs
    rep_smiles = []
    for key in top_motifs:
        rep_smiles.extend(reps_by_pat.get(key, []))
    rep_smiles = list(dict.fromkeys(rep_smiles))
    print(f"[info] Representative set for feature ranking: {len(rep_smiles):,}")

    # 5) Fractions and std on reps
    rep_frac, stds = fractions_and_stds_on_reps(
        rep_smiles,
        tokenizer,
        run_fn,
        mask_fn,
        hooks,
        lm_params,
        lm_config,
        sae_configs,
        sae_params,
        batch_size=rep_batch,
        act_threshold=act_threshold,
    )

    # 6) Final baseline (evaluation)
    print("[info] Generating final baseline …")
    baseline = generate_valid(
        tokenizer=tokenizer,
        lm_params=lm_params,
        lm_config=lm_config,
        n_target=target_valid,
        max_attempts=max_attempts,
        temp=temp,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        editor=None,
        seed=baseline_seed + 2,
        batch=512,
        dedup=dedup_generations,
    )
    n_base = len(baseline)
    print(f"[info] Final baseline valid: {n_base}/{target_valid}")

    baseline_hits = {}
    baseline_frac = {}
    for key in top_motifs:
        _, q = smarts_map[key]
        c = count_hits(baseline, q)
        baseline_hits[key] = c
        baseline_frac[key] = c / max(1, n_base)

    # 7) Brute-force steering
    rows = []
    print(f"[info] Testing layers: {layers_to_test}")

    for L in layers_to_test:
        K = sae_configs[L].latent_size
        frac = rep_frac[L]
        stdv = stds[L]

        # Candidate features: firing fraction above the floor, most frequent
        # first, capped at top_features_per_layer.
        cand = np.where(frac >= frac_min)[0]
        order = cand[np.argsort(-frac[cand])]
        order = order[:top_features_per_layer]
        if len(order) == 0:
            print(f"[warn] Layer {L}: no features pass frac_min={frac_min}")
            continue

        print(f"[info] Layer {L}: testing {len(order)} of {K} features")

        for fid in tqdm(order, desc=f"Layer {L}"):
            if steer_absolute is not None:
                add_val = float(steer_absolute)
            else:
                # Scale by the feature's std; fall back to 0.05 when the std is
                # zero or not finite so the bump is never zero.
                est_std = stdv[fid] if np.isfinite(stdv[fid]) else 0.0
                add_val = float(steer_by_std_mult * (est_std if est_std > 0 else 0.05))

            dirs = [("pos", +add_val)]
            if both_directions:
                dirs.append(("neg", -add_val))

            for dn, steer_val in dirs:
                add_vec = np.zeros(K, dtype=np.float32)
                add_vec[int(fid)] = steer_val
                editor = make_editor_single_add(L, add_vec, sae_configs, sae_params)

                # The seed depends only on the feature id, so both directions
                # (and equal feature ids in different layers) share random keys.
                steered = generate_valid(
                    tokenizer=tokenizer,
                    lm_params=lm_params,
                    lm_config=lm_config,
                    n_target=target_valid,
                    max_attempts=max_attempts,
                    temp=temp,
                    top_p=top_p,
                    max_new_tokens=max_new_tokens,
                    editor=editor,
                    seed=202600 + int(fid),
                    batch=512,
                    dedup=dedup_generations,
                )
                n_steer = len(steered)

                for key in top_motifs:
                    smarts_str, q = smarts_map[key]
                    c_steer = count_hits(steered, q)
                    f_steer = c_steer / max(1, n_steer)
                    c_base = baseline_hits[key]
                    f_base = baseline_frac[key]
                    delta_f = f_steer - f_base
                    # 2x2 contingency table: rows steered/baseline, columns
                    # hit/miss. One-sided test for enrichment under steering.
                    table = [[c_steer, n_steer - c_steer], [c_base, n_base - c_base]]
                    try:
                        odds, p = fisher_exact(table, alternative="greater")
                    except Exception:
                        odds, p = (np.nan, np.nan)

                    rows.append(
                        {
                            "smarts_id": key,
                            "smarts": smarts_str,
                            "layer": L,
                            "feature_id": int(fid),
                            "direction": dn,
                            "rep_frac": float(frac[fid]),
                            "steer_value": float(steer_val),
                            "baseline_hits": int(c_base),
                            "baseline_n": int(n_base),
                            "baseline_frac": float(f_base),
                            "steered_hits": int(c_steer),
                            "steered_n": int(n_steer),
                            "steered_frac": float(f_steer),
                            "delta_frac": float(delta_f),
                            "odds_ratio": float(odds) if np.isfinite(odds) else np.nan,
                            "p_value_fisher_greater": float(p),
                        }
                    )

    # Explicit column list so that an empty result still yields a well-formed
    # frame (and CSV header) instead of failing in sort_values.
    result_columns = [
        "smarts_id",
        "smarts",
        "layer",
        "feature_id",
        "direction",
        "rep_frac",
        "steer_value",
        "baseline_hits",
        "baseline_n",
        "baseline_frac",
        "steered_hits",
        "steered_n",
        "steered_frac",
        "delta_frac",
        "odds_ratio",
        "p_value_fisher_greater",
    ]
    results_df = pd.DataFrame(rows, columns=result_columns)
    if results_df.empty:
        print(
            "[warn] No steering results were produced (no feature passed "
            f"frac_min={frac_min} on the tested layers); writing empty result files."
        )
        leaderboard = results_df.copy()
    else:
        results_df = results_df.sort_values(
            ["smarts_id", "layer", "delta_frac"], ascending=[True, True, False]
        ).reset_index(drop=True)
        # results_df is already sorted by delta_frac within each group and
        # groupby().head() keeps row order, so this is the top 10 per group.
        leaderboard = (
            results_df.groupby(["smarts_id", "layer"]).head(10).reset_index(drop=True)
        )

    # Save outputs
    results_csv = os.path.join(out_dir, "causal_steering_results.csv")
    results_df.to_csv(results_csv, index=False)

    leaderboard_csv = os.path.join(out_dir, "causal_steering_top10.csv")
    leaderboard.to_csv(leaderboard_csv, index=False)

    sel_diag = pd.DataFrame(
        dict(smarts_id=top_motifs, baseline_frac=[baseline_frac[k] for k in top_motifs])
    )
    sel_diag_csv = os.path.join(out_dir, "selected_patterns.csv")
    sel_diag.to_csv(sel_diag_csv, index=False)

    print(f"[done] Saved:\n  {results_csv}\n  {leaderboard_csv}\n  {sel_diag_csv}")


# ------------------------- CLI ------------------------- #


def parse_args():
    """Parse the command-line options of the steering experiment.

    Returns:
        ``argparse.Namespace`` whose attribute names match the keyword
        arguments of ``run_experiment``; ``layers`` is still the raw
        comma-separated string (or ``None``).
    """
    # The text is passed as ``description``: the first positional parameter of
    # ArgumentParser is ``prog``, which would put it into the usage line as the
    # program name.
    ap = argparse.ArgumentParser(
        description="Causal SMARTS steering with SAE features (delta-ranked patterns)"
    )
    ap.add_argument(
        "--model_dir",
        required=True,
        help="Model directory (tokenizer.json, generation_config.json, checkpoints/)",
    )
    ap.add_argument(
        "--checkpoint_id",
        required=True,
        help="Checkpoint id; loads checkpoints/checkpoint_<id>.pkl from --model_dir",
    )
    ap.add_argument(
        "--sae_dir",
        required=True,
        help="Directory with sae_<i>_config.json and sae_<i>/checkpoint_final.pkl",
    )
    ap.add_argument("--out_dir", required=True, help="Output directory for result CSVs")

    ap.add_argument(
        "--dataset_dir",
        default=None,
        help="Optional HF dataset (load_from_disk) with 'smiles' column for reps",
    )
    ap.add_argument(
        "--reps_scan_limit",
        type=int,
        default=200_000,
        help="Max molecules scanned to pick representatives",
    )
    ap.add_argument(
        "--reps_per_pattern",
        type=int,
        default=128,
        help="Max representatives per SMARTS during selection",
    )
    ap.add_argument(
        "--top_patterns",
        type=int,
        default=10,
        help="How many SMARTS to keep for steering",
    )

    ap.add_argument(
        "--top_features_per_layer",
        type=int,
        default=32,
        help="Max SAE features steered per layer (ranked by firing fraction on reps)",
    )
    ap.add_argument(
        "--frac_min",
        type=float,
        default=0.02,
        help="Min fraction of representatives a feature must fire on to be tested",
    )

    ap.add_argument(
        "--target_valid",
        type=int,
        default=5000,
        help="Number of valid molecules to collect per generation run",
    )
    ap.add_argument(
        "--max_attempts",
        type=int,
        default=30_000,
        help="Max sequences sampled per generation run",
    )
    ap.add_argument("--temp", type=float, default=0.6, help="Sampling temperature")
    ap.add_argument("--top_p", type=float, default=0.9, help="Nucleus sampling top-p")
    ap.add_argument(
        "--max_new_tokens", type=int, default=256, help="Max tokens sampled per sequence"
    )
    ap.add_argument(
        "--dedup_generations",
        action="store_true",
        help="Deduplicate generated molecules in the final baseline and steered runs",
    )

    ap.add_argument(
        "--steer_by_std_mult",
        type=float,
        default=4.0,
        help="Steering magnitude as a multiple of the feature's activation std on reps",
    )
    ap.add_argument(
        "--steer_absolute",
        type=float,
        default=None,
        help="Absolute steering magnitude; overrides --steer_by_std_mult when set",
    )
    ap.add_argument(
        "--both_directions",
        action="store_true",
        help="Also steer each feature with the negated magnitude",
    )

    ap.add_argument(
        "--layers",
        type=str,
        default=None,
        help="Comma-separated layer ids, e.g., '2,3,4'; default: all layers",
    )

    ap.add_argument(
        "--rep_batch",
        type=int,
        default=64,
        help="Batch size when running the model on representatives",
    )
    ap.add_argument(
        "--act_threshold",
        type=float,
        default=1e-6,
        help="Activation above which a feature counts as firing on a molecule",
    )
    ap.add_argument(
        "--baseline_seed",
        type=int,
        default=2025,
        help="Base random seed for pool/baseline generation",
    )

    return ap.parse_args()


def main():
    """CLI entry point: parse the arguments and launch ``run_experiment``."""
    args = parse_args()

    # "--layers 2,3,4" -> [2, 3, 4]; empty pieces (e.g. trailing comma) are
    # dropped. No/empty value means "all layers" (None).
    layer_spec = args.layers
    layers = None
    if layer_spec:
        layers = [int(piece) for piece in layer_spec.split(",") if piece.strip() != ""]

    # Every option destination defined in parse_args has the same name as the
    # corresponding run_experiment keyword, so the namespace can be forwarded
    # as-is once ``layers`` is replaced by its parsed form.
    experiment_kwargs = dict(vars(args))
    experiment_kwargs["layers"] = layers
    run_experiment(**experiment_kwargs)


if __name__ == "__main__":
    main()
