"""Average predictions from multiple trained classifiers (soft voting).

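Loads K bundles, runs predict_proba on a target span set, averages the
class probability matrices, takes the argmax, and reports macro-F1,
weighted-F1 and Cohen's kappa against the reference ``llm_label`` labels.
With ``--out``, it also writes a Markdown report with per-class metrics and
a confusion matrix.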

Useful for combining the top 2-4 models from the leaderboard.
"""
from __future__ import annotations

import argparse
import json
from pathlib import Path

import numpy as np


def predict_proba(bundle, rows):
    """Return class-probability estimates for ``rows`` from one trained bundle.

    Args:
        bundle: Mapping with a ``"feature_pipeline"`` entry (exposing
            ``transform``) and a ``"classifier"`` entry (exposing either
            ``predict_proba`` or ``decision_function``).
        rows: Input records accepted by the feature pipeline's ``transform``.

    Returns:
        A 2-D array of shape ``(n_rows, n_classes)`` whose rows sum to 1.
        Columns follow the classifier's own class order (for scikit-learn
        estimators, the order of ``clf.classes_``).
    """
    pipe = bundle["feature_pipeline"]
    clf = bundle["classifier"]
    X = pipe.transform(rows)
    if hasattr(clf, "predict_proba"):
        return clf.predict_proba(X)
    # Fallback for margin-only classifiers (e.g. LinearSVC): softmax over the
    # decision scores. These are not calibrated probabilities, just
    # normalized scores that can be averaged across models.
    z = np.asarray(clf.decision_function(X), dtype=float)
    if z.ndim == 1:
        # Binary case: scikit-learn returns one score per sample (for
        # classes_[1]). Pad with a zero column so softmax([0, z]) == sigmoid(z)
        # for the positive class; without this, axis=1 below would fail.
        z = np.column_stack([np.zeros_like(z), z])
    # Subtract the row max before exponentiating for numerical stability.
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def _parse_args():
    """Parse the command-line options for the ensemble evaluation."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--bundles", nargs="+", required=True, type=Path)
    ap.add_argument("--mode", required=True, choices=["in_dist", "calibration", "math"])
    ap.add_argument(
        "--silver", type=Path,
        default=Path("<PROJECT_DIR>/results/exploration_analysis/llm_validation/classifier/silver_combined_v3.jsonl"),
    )
    ap.add_argument(
        "--splits", type=Path,
        help="splits.json from one of the bundles (defines test_ids)",
    )
    ap.add_argument(
        "--spans", type=Path,
        help="for calibration mode",
        default=Path("<PROJECT_DIR>/results/exploration_analysis/llm_validation/sampled_spans.jsonl"),
    )
    ap.add_argument(
        "--r1-judgments", type=Path,
        default=Path("<PROJECT_DIR>/results/exploration_analysis/llm_validation/llm_judgments_sc_v4.jsonl"),
    )
    ap.add_argument(
        "--v3-judgments", type=Path,
        default=Path("<PROJECT_DIR>/results/exploration_analysis/llm_validation/llm_judgments_v3_sc_v4.jsonl"),
    )
    ap.add_argument(
        "--math-spans", type=Path,
        default=Path("<PROJECT_DIR>/results/exploration_analysis/llm_validation/classifier/silver_math_spans.jsonl"),
    )
    ap.add_argument(
        "--math-judgments", type=Path,
        default=Path("<PROJECT_DIR>/results/exploration_analysis/llm_validation/classifier/silver_math.jsonl"),
    )
    ap.add_argument("--out", type=Path)
    return ap.parse_args()


def _load_jsonl_by_id(path):
    """Read a JSONL file and return its records keyed by ``span_id``.

    Each line is parsed once and the file is closed afterwards. Blank lines
    are skipped. When a ``span_id`` repeats, the last record wins.
    """
    records = {}
    with open(path, encoding="utf-8") as fh:
        for line in fh:
            if not line.strip():
                continue
            record = json.loads(line)
            records[record["span_id"]] = record
    return records


def _join_labels(spans, judgments, valid_labels):
    """Pair span records with their judged ``llm_label``.

    Args:
        spans: ``span_id`` -> span record, iterated in insertion order.
        judgments: ``span_id`` -> judgment record containing ``llm_label``.
        valid_labels: Container of accepted labels. Spans whose label is not
            in it, or that have no judgment at all, are dropped.

    Returns:
        ``(rows, labels)``: parallel lists of kept span records and labels.
    """
    rows, labels = [], []
    for sid, span in spans.items():
        judgment = judgments.get(sid)
        if judgment and judgment["llm_label"] in valid_labels:
            rows.append(span)
            labels.append(judgment["llm_label"])
    return rows, labels


def _average_proba(bundle_paths, rows, load):
    """Return the mean ``predict_proba`` matrix of all bundles on ``rows``.

    Args:
        bundle_paths: Non-empty sequence of bundle file paths.
        rows: Evaluation records passed to each bundle's feature pipeline.
        load: Callable that deserializes a bundle from a path (joblib.load).
            joblib unpickles, so only load bundles you trust.

    Raises:
        ValueError: if a bundle's probability matrix shape differs from
            the earlier ones (e.g. it was trained on a different label set).
    """
    total = None
    for path in bundle_paths:
        proba = np.asarray(predict_proba(load(path), rows), dtype=float)
        if total is None:
            # Copy so the in-place accumulation never aliases a bundle's output.
            total = proba.copy()
        elif proba.shape != total.shape:
            raise ValueError(
                f"{path.name}: proba shape {proba.shape} does not match "
                f"{total.shape} from earlier bundles; all bundles must "
                "share the same label set"
            )
        else:
            total += proba
        print(f"  loaded {path.stem}, proba shape {proba.shape}")
    return total / len(bundle_paths)


def main():
    """Evaluate a soft-voting ensemble of classifier bundles.

    Builds the evaluation set for ``--mode``, averages every bundle's class
    probabilities, and predicts the argmax class. It then prints macro-F1,
    weighted-F1 and Cohen's kappa. With ``--out``, it also writes a Markdown
    report with per-class metrics and the confusion matrix.
    """
    import joblib
    from sklearn.metrics import cohen_kappa_score, confusion_matrix, f1_score
    from analysis.exploration.llm_validation._client import PRIMITIVES
    from analysis.exploration.llm_validation.classifier.evaluate import (
        render_confusion_md, per_class_table,
    )
    from analysis.exploration.llm_validation.classifier.prepare_dataset import load_silver

    args = _parse_args()

    label_order = list(PRIMITIVES)
    label_idx = {label: i for i, label in enumerate(label_order)}

    # Build (rows, labels) for the chosen mode.
    if args.mode == "in_dist":
        if not args.splits:
            # Default: "<bundle stem>.splits.json" next to the first bundle.
            args.splits = args.bundles[0].with_suffix(".splits.json")
        by_id = {r["span_id"]: r for r in load_silver(args.silver)}
        with open(args.splits, encoding="utf-8") as fh:
            splits = json.load(fh)
        rows = [by_id[sid] for sid in splits["test_ids"] if sid in by_id]
        labels = [r["llm_label"] for r in rows]
    elif args.mode == "calibration":
        rows, labels = _join_labels(
            _load_jsonl_by_id(args.spans),
            _load_jsonl_by_id(args.r1_judgments),
            label_idx,
        )
        for row in rows:
            # Calibration spans may lack this field; default it to 1.
            row.setdefault("n_episodes_in_trace", 1)
    else:  # math
        rows, labels = _join_labels(
            _load_jsonl_by_id(args.math_spans),
            _load_jsonl_by_id(args.math_judgments),
            label_idx,
        )
    y_true = np.array([label_idx[label] for label in labels])

    print(f"Mode: {args.mode}, eval rows: {len(rows)}")
    if not rows:
        raise SystemExit(f"No evaluation rows for mode {args.mode!r}; nothing to score.")

    avg_proba = _average_proba(args.bundles, rows, joblib.load)
    # NOTE(review): column j is assumed to mean label_order[j]. This holds
    # only if the bundles were trained on integer labels in PRIMITIVES order
    # (clf.classes_ == range(len(PRIMITIVES))) — confirm against training.
    if avg_proba.shape[1] != len(label_order):
        raise ValueError(
            f"Bundles predict {avg_proba.shape[1]} classes but PRIMITIVES has "
            f"{len(label_order)}; predictions cannot be mapped to labels"
        )
    y_pred = avg_proba.argmax(axis=1)

    macro = f1_score(y_true, y_pred, average="macro", zero_division=0)
    weighted = f1_score(y_true, y_pred, average="weighted", zero_division=0)
    kappa = cohen_kappa_score(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(label_order))))

    print()
    print(f"Ensemble of {len(args.bundles)} bundles ({args.mode}):")
    print(f"  Macro-F1     : {macro:.4f}")
    print(f"  Weighted-F1  : {weighted:.4f}")
    print(f"  Kappa        : {kappa:.4f}")

    if args.out:
        md = [f"# Ensemble {args.mode} eval ({len(args.bundles)} bundles)\n", "Bundles:"]
        md.extend(f"- {b.name}" for b in args.bundles)
        md += [
            "",
            f"- Macro-F1: **{macro:.4f}**",
            f"- Weighted-F1: {weighted:.4f}",
            f"- Kappa: {kappa:.4f}",
            "",
            "## Per-class metrics\n",
            per_class_table(y_true, y_pred, label_order),
            "",
            "## Confusion matrix\n",
            render_confusion_md(cm, label_order),
        ]
        args.out.write_text("\n".join(md), encoding="utf-8")
        print(f"Wrote {args.out}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
