"""Train option-B classifier: BGE embeddings + linear head.

Uses the saved splits.json from the TF-IDF run when available so the
two models are evaluated on identical splits (apples-to-apples).
"""
from __future__ import annotations

import argparse
import json
import time
from pathlib import Path

import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score

from analysis.exploration.llm_validation._client import PRIMITIVES
from analysis.exploration.llm_validation.classifier.embed_features import (
    fit_embed_pipeline,
    encode_rows,
    DEFAULT_ENCODER,
    _get_encoder,
)
from analysis.exploration.llm_validation.classifier.features import (
    hand_crafted_row,
)
from analysis.exploration.llm_validation.classifier.prepare_dataset import (
    load_silver,
    report_split,
    split_by_trace,
)


def labels_to_y(rows: list[dict], label_order: list[str]) -> np.ndarray:
    """Convert each row's ``llm_label`` into its integer class id.

    Args:
        rows: Records that each carry an ``"llm_label"`` key.
        label_order: Class names; a name's position is its integer id. If a
            name appears more than once, its last position is used.

    Returns:
        A 1-D ``int32`` array holding one class id per row, in row order.

    Raises:
        KeyError: If a row lacks ``"llm_label"`` or its label is not in
            *label_order*.
    """
    position_of = dict(zip(label_order, range(len(label_order))))
    class_ids = [position_of[row["llm_label"]] for row in rows]
    return np.array(class_ids, dtype=np.int32)


def _load_reused_splits(
    path: Path, rows: list[dict],
) -> tuple[list[dict], list[dict], list[dict], dict]:
    """Rebuild train/val/test row lists from a saved splits JSON file.

    The file is expected to hold ``train_ids`` / ``val_ids`` / ``test_ids``
    lists of ``span_id`` values (the format this script writes). Ids not
    present in *rows* are dropped. A warning reports how many were dropped,
    so a shrinking split is visible in the log instead of passing silently.

    Args:
        path: Existing splits JSON file.
        rows: Silver rows; each must carry a ``"span_id"`` key.

    Returns:
        ``(train, val, test, splits)`` where *splits* is the parsed JSON dict.
    """
    with open(path, encoding="utf-8") as f:  # closed deterministically
        splits = json.load(f)
    by_id = {r["span_id"]: r for r in rows}
    parts = []
    for key in ("train_ids", "val_ids", "test_ids"):
        ids = splits[key]
        kept = [by_id[i] for i in ids if i in by_id]
        n_missing = len(ids) - len(kept)
        if n_missing:
            print(f"  WARNING: {n_missing}/{len(ids)} {key} not in silver set; dropped.")
        parts.append(kept)
    train, val, test = parts
    return train, val, test, splits


def _grid_search(
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_val: np.ndarray,
    y_val: np.ndarray,
    Cs: tuple[float, ...] = (0.1, 1.0, 10.0, 100.0),
) -> dict:
    """Fit one class-balanced LogisticRegression per C and keep the best on val.

    Returns:
        ``{"clf": fitted model, "C": chosen C, "macro": val macro-F1}``.
        Ties keep the earlier C in *Cs* because the comparison is strict.
    """
    best = None
    for C in Cs:
        t0 = time.time()
        clf = LogisticRegression(
            class_weight="balanced", max_iter=2000, C=C, n_jobs=-1, solver="lbfgs",
        )
        clf.fit(X_train, y_train)
        macro = f1_score(y_val, clf.predict(X_val), average="macro", zero_division=0)
        print(f"  C={C:>6}: val_macro_f1={macro:.4f}  ({time.time()-t0:.1f}s)")
        if best is None or macro > best["macro"]:
            best = {"clf": clf, "C": C, "macro": macro}
    return best


def main():
    """CLI entry point: train the embedding + hand-crafted-feature classifier.

    Steps:
      1. Load silver rows.
      2. Build train/val/test splits. They are reused from ``--reuse-splits``
         when that file exists; otherwise a fresh split is made with
         ``split_by_trace``.
      3. Encode each split with one shared encoder.
      4. Scale the hand-crafted features, fitting the scaler on train only.
      5. Grid-search ``C`` on val and report test macro-F1.
      6. Save a joblib bundle plus ``<out>.splits.json`` next to it.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--silver", required=True, type=Path)
    ap.add_argument("--out", required=True, type=Path)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument(
        "--reuse-splits", type=Path, default=None,
        help="Path to classifier_v1.splits.json — reuses train/val/test for fair comparison",
    )
    ap.add_argument("--encoder", default=DEFAULT_ENCODER)
    args = ap.parse_args()

    rows = load_silver(args.silver)
    print(f"Loaded {len(rows)} silver rows.")

    # A missing --reuse-splits file falls back to a fresh split. Clear the path
    # here so the save step below writes the fresh splits instead of trying to
    # copy a file that does not exist.
    reuse_path = args.reuse_splits
    if reuse_path is not None and not reuse_path.exists():
        print(f"WARNING: --reuse-splits {reuse_path} not found; creating a fresh split instead.")
        reuse_path = None

    if reuse_path is not None:
        train, val, test, saved = _load_reused_splits(reuse_path, rows)
        # Record the seed stored with the reused splits (fall back to --seed
        # if the file has no "seed" key).
        split_seed = saved.get("seed", args.seed)
        print(f"Reusing splits from {reuse_path}")
    else:
        train, val, test = split_by_trace(rows, seed=args.seed)
        split_seed = args.seed
    report_split("train", train)
    report_split("val",   val)
    report_split("test",  test)

    label_order = list(PRIMITIVES)

    # Load the encoder once and reuse it for every split, so the load cost
    # (~6s) and CUDA warm-up are paid a single time.
    print()
    print(f"Encoding with {args.encoder}...")
    encoder = _get_encoder(args.encoder)
    t0 = time.time()
    emb_train = encode_rows(train, encoder=encoder, batch_size=128)
    print(f"  train: {emb_train.shape} ({time.time()-t0:.1f}s)")
    t0 = time.time()
    emb_val = encode_rows(val, encoder=encoder, batch_size=128, show_progress=False)
    print(f"  val:   {emb_val.shape} ({time.time()-t0:.1f}s)")
    t0 = time.time()
    emb_test = encode_rows(test, encoder=encoder, batch_size=128, show_progress=False)
    print(f"  test:  {emb_test.shape} ({time.time()-t0:.1f}s)")

    # Fit the hand-crafted scaler on train only, then apply it to val/test
    # so no val/test statistics leak into the features.
    from sklearn.preprocessing import StandardScaler
    hand_train = np.array([hand_crafted_row(r) for r in train], dtype=np.float32)
    hand_val = np.array([hand_crafted_row(r) for r in val], dtype=np.float32)
    hand_test = np.array([hand_crafted_row(r) for r in test], dtype=np.float32)
    scaler = StandardScaler()
    hand_train_s = scaler.fit_transform(hand_train)
    hand_val_s = scaler.transform(hand_val)
    hand_test_s = scaler.transform(hand_test)

    X_train = np.hstack([emb_train, hand_train_s]).astype(np.float32)
    X_val   = np.hstack([emb_val,   hand_val_s]).astype(np.float32)
    X_test  = np.hstack([emb_test,  hand_test_s]).astype(np.float32)
    y_train = labels_to_y(train, label_order)
    y_val   = labels_to_y(val,   label_order)
    y_test  = labels_to_y(test,  label_order)

    print()
    print("=== Grid search on val ===")
    best = _grid_search(X_train, y_train, X_val, y_val)

    test_macro = f1_score(y_test, best["clf"].predict(X_test), average="macro", zero_division=0)
    print()
    print(f"Best: C={best['C']} val_macro_f1={best['macro']:.4f}  test_macro_f1={test_macro:.4f}")

    from analysis.exploration.llm_validation.classifier.embed_features import (
        EmbedFeaturePipeline,
    )
    pipe = EmbedFeaturePipeline(
        encoder_name=args.encoder,
        scaler=scaler,
        embedding_dim=emb_train.shape[1],
    )
    bundle = {
        "feature_pipeline": pipe,
        "classifier": best["clf"],
        "label_order": label_order,
        "hyperparams": {"C": best["C"]},
        "split_seed": split_seed,
        "pipeline_type": "embed",
    }
    args.out.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(bundle, args.out)
    print(f"Saved bundle -> {args.out}")

    # Keep a splits file next to the bundle so evaluate.py can find it.
    splits_path = args.out.with_suffix(".splits.json")
    if reuse_path is None:
        with open(splits_path, "w", encoding="utf-8") as f:
            json.dump({
                "seed": args.seed,
                "train_ids": [r["span_id"] for r in train],
                "val_ids":   [r["span_id"] for r in val],
                "test_ids":  [r["span_id"] for r in test],
            }, f)
        print(f"Saved splits -> {splits_path}")
    else:
        import shutil
        try:
            shutil.copy(reuse_path, splits_path)
            print(f"Copied splits -> {splits_path}")
        except shutil.SameFileError:
            # --reuse-splits already points at <out>.splits.json; nothing to copy.
            print(f"Splits already at {splits_path}")


if __name__ == "__main__":
    # Run the training CLI only when executed as a script, not when imported.
    main()
