"""Stage 3: train TF-IDF + LogReg classifier on silver labels.

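Searches a small grid over (ngram_range, C); picks the model with the best
val macro-F1; saves the bundle (feature pipeline, i.e. vectorizer + scaler,
+ model + label_order, plus the winning hyperparameters and split seed)
to a single joblib file. The train/val/test span ids are written next to it
in a ``.splits.json`` file.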
"""
from __future__ import annotations

import argparse
import json
import time
from pathlib import Path

import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score

from analysis.exploration.llm_validation._client import PRIMITIVES
from analysis.exploration.llm_validation.classifier.features import (
    fit_feature_pipeline,
    FeaturePipeline,
)
from analysis.exploration.llm_validation.classifier.prepare_dataset import (
    load_silver,
    report_split,
    split_by_trace,
)


def labels_to_y(rows: list[dict], label_order: list[str]) -> np.ndarray:
    """Encode each row's ``llm_label`` as its position in ``label_order``.

    Returns a 1-D ``int32`` array with one entry per row, in row order.
    Raises ``KeyError`` if a row has no ``llm_label`` key or its label is not
    listed in ``label_order``.
    """
    position_of: dict[str, int] = {}
    for position, name in enumerate(label_order):
        # A later duplicate overwrites an earlier one, so the last index wins.
        position_of[name] = position

    encoded = [position_of[row["llm_label"]] for row in rows]
    return np.asarray(encoded, dtype=np.int32)


def train_one(
    train_rows: list[dict], val_rows: list[dict],
    ngram_range: tuple[int, int], C: float,
    label_order: list[str],
) -> tuple[FeaturePipeline, LogisticRegression, float]:
    """Train one (ngram_range, C) configuration and score it on ``val_rows``.

    The feature pipeline comes from ``fit_feature_pipeline`` called on
    ``train_rows`` and is then applied to ``val_rows``. Labels on both sides
    are encoded with ``labels_to_y`` against ``label_order``.

    Returns the feature pipeline, the fitted classifier, and the macro-F1 of
    the classifier's predictions on ``val_rows``.
    """
    features, train_matrix = fit_feature_pipeline(train_rows, ngram_range=ngram_range)
    train_targets = labels_to_y(train_rows, label_order)
    val_matrix = features.transform(val_rows)
    val_targets = labels_to_y(val_rows, label_order)

    model_options = {
        "C": C,
        "solver": "lbfgs",
        "class_weight": "balanced",
        "max_iter": 2000,
        "n_jobs": -1,
    }
    model = LogisticRegression(**model_options)
    model.fit(train_matrix, train_targets)

    # zero_division=0 scores a class with no predictions as 0 instead of warning.
    score = f1_score(
        val_targets,
        model.predict(val_matrix),
        average="macro",
        zero_division=0,
    )
    return features, model, score


def main():
    """Run the stage-3 CLI: split the silver rows, grid-search, save the winner.

    Loads ``--silver`` with ``load_silver``, splits it with ``split_by_trace``
    using ``--seed``, trains one model per ``(ngram_range, C)`` pair of the
    selected ``--grid`` and keeps the one with the highest validation
    macro-F1 (the earliest pair wins ties).

    Writes two files:
      * ``--out``: a joblib bundle holding the feature pipeline, classifier,
        label order, winning hyperparameters and split seed;
      * ``--out`` with its suffix replaced by ``.splits.json``: the seed and
        the ``span_id`` of every row in each split.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--silver", required=True, type=Path)
    ap.add_argument("--out", required=True, type=Path)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument(
        "--grid", default="default",
        choices=["default", "fast"],
        help="default: (1,2)+(1,3) x C{0.1,1,10}. fast: (1,3) x C{1.0}",
    )
    args = ap.parse_args()

    rows = load_silver(args.silver)
    print(f"Loaded {len(rows)} silver rows.")

    train, val, test = split_by_trace(rows, seed=args.seed)
    report_split("train", train)
    report_split("val",   val)
    report_split("test",  test)

    label_order = list(PRIMITIVES)

    # One entry per --grid choice; argparse's `choices` guarantees the key exists.
    grids = {
        "default": [
            ((1, 2), 0.1), ((1, 2), 1.0), ((1, 2), 10.0),
            ((1, 3), 0.1), ((1, 3), 1.0), ((1, 3), 10.0),
        ],
        "fast": [((1, 3), 1.0)],
    }
    grid = grids[args.grid]

    best = None
    print()
    print(f"=== Grid search ({len(grid)} combos) ===")
    for ngram, C in grid:
        # perf_counter is monotonic, so the duration cannot be skewed by
        # system clock adjustments the way time.time() can.
        t0 = time.perf_counter()
        pipe, clf, macro_f1 = train_one(train, val, ngram, C, label_order)
        dt = time.perf_counter() - t0
        print(f"  ngram={ngram} C={C:>5}: val_macro_f1={macro_f1:.4f}  ({dt:.1f}s)")
        # Strict '>' keeps the earliest combination when scores tie.
        if best is None or macro_f1 > best["macro_f1"]:
            best = {
                "pipe": pipe, "clf": clf, "macro_f1": macro_f1,
                "ngram": ngram, "C": C,
            }

    print()
    print(f"Best: ngram={best['ngram']} C={best['C']} val_macro_f1={best['macro_f1']:.4f}")

    args.out.parent.mkdir(parents=True, exist_ok=True)
    bundle = {
        "feature_pipeline": best["pipe"],
        "classifier": best["clf"],
        "label_order": label_order,
        "hyperparams": {"ngram_range": best["ngram"], "C": best["C"]},
        "split_seed": args.seed,
    }
    joblib.dump(bundle, args.out)
    print(f"Saved bundle -> {args.out}")

    # Persist split assignments for evaluate.py to reuse the same test set.
    splits_path = args.out.with_suffix(".splits.json")
    splits = {
        "seed": args.seed,
        "train_ids": [r["span_id"] for r in train],
        "val_ids":   [r["span_id"] for r in val],
        "test_ids":  [r["span_id"] for r in test],
    }
    # Explicit encoding so the file does not depend on the platform default.
    with open(splits_path, "w", encoding="utf-8") as f:
        json.dump(splits, f)
    print(f"Saved splits -> {splits_path}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
