# lmkit/feature_bank/cli.py
import argparse
import os

from tqdm.auto import tqdm

from ..sparse.sae import SAEKit
from ..tools.data import load_and_tokenize
from .mining import (
    align_feature,
    collect_top_smiles_for_layer,
    pick_features,
    save_feature_bank,
)


# Batch size used both for tokenization and for sizing the dataset limit.
# Kept in one place so the help text, the loader and the limit cannot drift.
_SCAN_BATCH_SIZE = 1024


def _build_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the feature-bank mining CLI.

    Returns:
        A parser exposing the model/SAE/dataset locations, the layer to scan,
        the feature-selection settings, the alignment settings and the output
        directory.
    """
    # The text must go in ``description=``: the first positional parameter of
    # ArgumentParser is ``prog``, so passing it positionally replaces the
    # program name in the usage line and leaves the help without a description.
    ap = argparse.ArgumentParser(
        description="Feature-bank mining (per-feature, statistical enrichment)"
    )

    # --- Inputs: LM, SAEs, background dataset
    ap.add_argument(
        "--model_dir",
        required=True,
        help="LM directory (with tokenizer.json and checkpoints)",
    )
    ap.add_argument(
        "--checkpoint_id", required=True, help="LM checkpoint id (number or 'final')"
    )
    ap.add_argument(
        "--sae_dir",
        required=True,
        help="Directory with per-layer SAE checkpoints/configs",
    )
    ap.add_argument(
        "--dataset_dir",
        required=True,
        help="HuggingFace datasets 'load_from_disk' directory for background SMILES",
    )

    # --- Scan settings
    ap.add_argument(
        "--layer", type=int, required=True, help="Layer id for SAE features"
    )
    ap.add_argument(
        "--num_batches",
        type=int,
        default=250,
        help=(
            "Batches to scan for top activations "
            f"(default batch size: {_SCAN_BATCH_SIZE})"
        ),
    )
    ap.add_argument(
        "--top_sequences",
        type=int,
        default=1024,
        help="Top sequences to keep per feature in StatsCollector",
    )

    # --- Feature selection
    ap.add_argument(
        "--feature_metric",
        type=str,
        default="gini",
        choices=["mean", "max", "sparsity", "mean*max", "selectivity", "gini"],
        help="Metric for selecting features to align",
    )
    ap.add_argument(
        "--topk_features",
        type=int,
        default=256,
        help="How many features to align (after metric ranking)",
    )

    # --- Alignment settings
    ap.add_argument(
        "--top_pos",
        type=int,
        default=3000,
        help="Positives per feature (top activating sequences)",
    )
    ap.add_argument(
        "--neg_per_pos",
        type=int,
        default=3,
        help="Matched negatives per positive (same Murcko scaffold)",
    )
    ap.add_argument(
        "--fdr_alpha", type=float, default=0.05, help="BH-FDR significance threshold"
    )
    ap.add_argument(
        "--setcover_target",
        type=float,
        default=0.8,
        help="Target positive coverage for motifs",
    )
    ap.add_argument(
        "--max_motifs",
        type=int,
        default=3,
        help="Max motifs to pick per feature via set cover",
    )

    # --- Output
    ap.add_argument(
        "--out_dir", required=True, help="Where to save the feature bank JSON/CSV"
    )
    return ap


def main(argv=None) -> None:
    """Run the feature-bank mining pipeline from the command line.

    The steps run in this order: load the LM and SAEs with ``SAEKit.load``,
    tokenize the background dataset with ``load_and_tokenize``, scan one layer
    with ``collect_top_smiles_for_layer``, rank features with
    ``pick_features``, align every selected feature with ``align_feature``,
    and write the alignments that have at least one selected motif with
    ``save_feature_bank``.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.

    Raises:
        SystemExit: Raised by argparse for ``--help`` and for missing or
            invalid arguments.
    """
    args = _build_parser().parse_args(argv)

    # Create the output directory up front so an unusable path fails before
    # the expensive scan/alignment instead of after it.
    os.makedirs(args.out_dir, exist_ok=True)

    # --- Load LM + SAEs
    print("Loading model and SAE kit...")
    sae_kit = SAEKit.load(
        model_dir=args.model_dir, checkpoint_id=args.checkpoint_id, sae_dir=args.sae_dir
    )

    # --- Dataset
    print("Loading and tokenizing dataset...")
    tokenizer = sae_kit.tokenizer
    ds = load_and_tokenize(
        dataset_dir=args.dataset_dir,
        tokenizer=tokenizer,
        batch_size=_SCAN_BATCH_SIZE,
        num_processes=4,
        caching=False,  # Disable caching for one-off scans
        # Load slightly more than needed (10% headroom over the scanned rows)
        limit=int(args.num_batches * _SCAN_BATCH_SIZE * 1.1),
    )

    # --- Scan layer for top sequences & collect background smiles
    print(f"Scanning layer {args.layer} to collect top activating SMILES...")
    col, background_smiles = collect_top_smiles_for_layer(
        ds,
        sae_kit,
        layer_id=args.layer,
        top_sequences=args.top_sequences,
        num_batches=args.num_batches,
    )

    # --- Pick features to align
    print(
        f"Selecting top {args.topk_features} features based on '{args.feature_metric}' metric..."
    )
    feature_ids = pick_features(
        col, metric=args.feature_metric, topk=args.topk_features
    )

    # --- Align, per feature
    print(f"Aligning {len(feature_ids)} features to motifs...")
    alignments = []
    for fid in tqdm(feature_ids, desc=f"Aligning L{args.layer} features"):
        a = align_feature(
            layer_id=args.layer,
            feature_id=fid,
            col=col,
            background_smiles=background_smiles,
            top_pos=args.top_pos,
            neg_per_pos=args.neg_per_pos,
            fdr_alpha=args.fdr_alpha,
            setcover_target=args.setcover_target,
            max_motifs=args.max_motifs,
        )
        if a.selected_motifs:  # Only keep features that align to something
            alignments.append(a)

    # --- Save bank
    print(
        f"Found alignments for {len(alignments)} features. Saving to {args.out_dir}..."
    )
    save_feature_bank(args.out_dir, args.layer, alignments)
    print("Done.")


# Script entry point: run the CLI only when this module is executed, not when
# it is imported. NOTE(review): the relative imports at the top of the file
# suggest it has to be launched as a package module (presumably
# `python -m lmkit.feature_bank.cli`) rather than as a bare file path — confirm.
if __name__ == "__main__":
    main()
