"""Train classifier_v2 on silver_train + silver_topup.

Re-uses v1's train/val/test assignments at the trace level (trace IDs, NOT span_ids):
- New spans from v1-train traces → train.
- New spans from v1-val traces → val.
- New spans from v1-test traces → test.
- Spans from previously-unseen traces → train.

This keeps val/test apples-to-apples vs v1 while augmenting train with
the rare-class top-up data.
"""
from __future__ import annotations

import argparse
import json
import time
from pathlib import Path

import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score

from analysis.exploration.llm_validation._client import PRIMITIVES
from analysis.exploration.llm_validation.classifier.features import (
    fit_feature_pipeline,
    FeaturePipeline,
)
from analysis.exploration.llm_validation.classifier.prepare_dataset import (
    load_silver, report_split, trace_group_key,
)


def labels_to_y(rows, label_order):
    """Encode each row's ``llm_label`` as its position in ``label_order``.

    Returns a 1-D ``int32`` array with one entry per row, in row order.
    Raises ``KeyError`` if a row carries a label absent from ``label_order``.
    """
    position_of = {}
    for position, label in enumerate(label_order):
        position_of[label] = position
    encoded = [position_of[row["llm_label"]] for row in rows]
    return np.asarray(encoded, dtype=np.int32)


def _trace_keys_for(span_ids, rows_by_id):
    """Collect the trace keys of the rows referenced by ``span_ids``.

    Returns ``(trace_keys, n_missing)``: the set of ``trace_group_key`` values
    for every span ID present in ``rows_by_id``, and the number of span IDs
    that were not found (and therefore skipped).
    """
    keys = set()
    n_missing = 0
    for sid in span_ids:
        if sid not in rows_by_id:
            n_missing += 1
            continue
        keys.add(trace_group_key(rows_by_id[sid]))
    return keys, n_missing


def main():
    """CLI entry point: train classifier_v2 on base + top-up silver labels.

    Steps:
      1. Rebuild the v1 trace -> split assignment from the v1 splits file and
         the base silver file.
      2. Route every base and top-up row to train/val/test by its trace key;
         rows from traces not present in any v1 split go to train.
      3. Grid-search n-gram range and ``C`` for a class-balanced logistic
         regression, selecting on validation macro-F1.
      4. Report test macro-F1 for the best combo, then write the model bundle
         to ``--out`` and the span-ID splits next to it as ``.splits.json``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--silver-base", required=True, type=Path,
                    help="Original silver_train.jsonl")
    ap.add_argument("--silver-topup", required=True, type=Path,
                    help="silver_topup.jsonl from rare-class minting")
    ap.add_argument("--v1-splits", required=True, type=Path,
                    help="classifier_v1.splits.json (provides trace→split map)")
    ap.add_argument("--out", required=True, type=Path)
    ap.add_argument(
        "--grid", default="default", choices=["default", "fast"],
    )
    args = ap.parse_args()

    # ---- Load v1 splits and reconstruct trace assignments ----
    v1_silver = load_silver(args.silver_base)
    v1_by_id = {r["span_id"]: r for r in v1_silver}
    # Context manager so the splits file handle is closed deterministically.
    with open(args.v1_splits, encoding="utf-8") as f:
        v1_splits = json.load(f)
    train_traces, miss_train = _trace_keys_for(v1_splits["train_ids"], v1_by_id)
    val_traces, miss_val = _trace_keys_for(v1_splits["val_ids"], v1_by_id)
    test_traces, miss_test = _trace_keys_for(v1_splits["test_ids"], v1_by_id)
    print(f"v1 trace counts: train={len(train_traces)}, val={len(val_traces)}, test={len(test_traces)}")
    n_missing = miss_train + miss_val + miss_test
    if n_missing:
        # A skipped ID means its trace may not be pinned to its v1 split, so
        # spans from that trace would fall through to train below.
        print(
            f"WARNING: {n_missing} v1 split span_ids not found in "
            f"{args.silver_base} (train={miss_train}, val={miss_val}, "
            f"test={miss_test}); their traces may not keep their v1 split"
        )

    # ---- Combine silver datasets ----
    topup = load_silver(args.silver_topup)
    print(f"Loaded {len(v1_silver)} base + {len(topup)} topup labels")
    all_rows = v1_silver + topup

    train, val, test = [], [], []
    new_trace_spans = 0  # counts spans (not traces) whose trace is unknown to v1
    for r in all_rows:
        tk = trace_group_key(r)
        # Held-out splits take precedence: a trace seen in v1 val/test never
        # contributes to train, even if it also appears in the v1 train set.
        if tk in val_traces:
            val.append(r)
        elif tk in test_traces:
            test.append(r)
        elif tk in train_traces:
            train.append(r)
        else:
            new_trace_spans += 1
            train.append(r)
    print(f"Assigned {new_trace_spans} spans from previously-unseen traces -> train")

    report_split("train", train)
    report_split("val",   val)
    report_split("test",  test)

    label_order = list(PRIMITIVES)
    grid = [
        ((1, 2), 0.1), ((1, 2), 1.0), ((1, 2), 10.0),
        ((1, 3), 0.1), ((1, 3), 1.0), ((1, 3), 10.0),
    ]
    if args.grid == "fast":
        grid = [((1, 3), 1.0)]

    # Label vectors do not depend on the hyperparameters; encode them once.
    y_train = labels_to_y(train, label_order)
    y_val = labels_to_y(val, label_order)

    print()
    print(f"=== Grid search ({len(grid)} combos) ===")
    best = None
    fitted_ngram = None
    pipe = X_train = X_val = None
    for ngram, C in grid:
        t0 = time.time()
        # Features depend only on the n-gram range, so re-fit the pipeline
        # only when it changes (the grid is ordered n-gram-major). The timing
        # printed for the first combo of each n-gram includes feature fitting.
        if ngram != fitted_ngram:
            pipe, X_train = fit_feature_pipeline(train, ngram_range=ngram)
            X_val = pipe.transform(val)
            fitted_ngram = ngram
        clf = LogisticRegression(
            class_weight="balanced", max_iter=2000, C=C, n_jobs=-1, solver="lbfgs",
        )
        clf.fit(X_train, y_train)
        macro = f1_score(y_val, clf.predict(X_val), average="macro", zero_division=0)
        print(f"  ngram={ngram} C={C:>5}: val_macro_f1={macro:.4f}  ({time.time()-t0:.1f}s)")
        # Strict ">" keeps the earliest combo on ties.
        if best is None or macro > best["macro"]:
            best = {"pipe": pipe, "clf": clf, "macro": macro, "ngram": ngram, "C": C}

    # Test eval (only for the combo selected on validation)
    X_test = best["pipe"].transform(test)
    y_test = labels_to_y(test, label_order)
    test_macro = f1_score(y_test, best["clf"].predict(X_test), average="macro", zero_division=0)
    print()
    print(f"Best: ngram={best['ngram']} C={best['C']} val={best['macro']:.4f} test={test_macro:.4f}")

    bundle = {
        "feature_pipeline": best["pipe"],
        "classifier": best["clf"],
        "label_order": label_order,
        "hyperparams": {"ngram_range": best["ngram"], "C": best["C"]},
        "pipeline_type": "tfidf_v2",
    }
    args.out.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(bundle, args.out)
    print(f"Saved bundle -> {args.out}")

    # Persist the span-level splits alongside the bundle.
    splits_path = args.out.with_suffix(".splits.json")
    splits = {
        "train_ids": [r["span_id"] for r in train],
        "val_ids":   [r["span_id"] for r in val],
        "test_ids":  [r["span_id"] for r in test],
    }
    with open(splits_path, "w", encoding="utf-8") as f:
        json.dump(splits, f)
    print(f"Saved splits -> {splits_path}")


# Run the training CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
