"""Greedy forward selection of ensemble members based on val macro-F1.

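Iteratively adds the bundle whose inclusion gives the best ensemble val
macro-F1 (members are combined by averaging their class probabilities);
stops when no remaining candidate strictly improves val macro-F1. Reports
test-set macro-F1 for the chosen ensemble.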
"""
from __future__ import annotations

import argparse
import json
from pathlib import Path

import joblib
import numpy as np
from sklearn.metrics import cohen_kappa_score, f1_score, confusion_matrix

from analysis.exploration.llm_validation._client import PRIMITIVES
from analysis.exploration.llm_validation.classifier.evaluate import (
    render_confusion_md, per_class_table,
)
from analysis.exploration.llm_validation.classifier.prepare_dataset import load_silver
from analysis.exploration.llm_validation.classifier.ensemble import predict_proba

# Base directory for this script's inputs and outputs: the default silver
# data, splits file, report path, and the classifier_v*.joblib bundle glob.
RD = (
    Path("<PROJECT_DIR>")
    / "results"
    / "exploration_analysis"
    / "llm_validation"
    / "classifier"
)


def _load_probs(
    bundles: list[Path],
    val_rows: list,
    test_rows: list,
    y_test: np.ndarray,
) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
    """Run every candidate bundle on the val and test rows.

    Returns ``(val_probs, test_probs)``: dicts keyed by bundle stem holding the
    matrices returned by ``predict_proba``. A bundle is recorded only when
    loading, predicting on BOTH splits, and scoring on test all succeed, so the
    two dicts always have identical keys. Failing bundles, and bundles whose
    stem repeats an earlier one, are reported and skipped.
    """
    val_probs: dict[str, np.ndarray] = {}
    test_probs: dict[str, np.ndarray] = {}
    for b in bundles:
        name = b.stem
        if name in val_probs:
            # Results are keyed by stem; a second bundle with the same stem
            # (e.g. from another directory) would silently overwrite the first.
            print(f"  SKIP {b}: duplicate bundle name {name!r}")
            continue
        try:
            bundle = joblib.load(b)
            vp = predict_proba(bundle, val_rows)
            tp = predict_proba(bundle, test_rows)
            single_test_f1 = f1_score(
                y_test, tp.argmax(axis=1),
                average="macro", zero_division=0,
            )
        except Exception as e:
            # Best-effort: one broken bundle should not abort the whole search.
            print(f"  SKIP {name}: {e}")
            continue
        # Store only after everything succeeded, so a bundle can never be
        # selectable on val while missing from the test results.
        val_probs[name] = vp
        test_probs[name] = tp
        print(f"  {name}: test F1 = {single_test_f1:.4f}")
    return val_probs, test_probs


def _greedy_select(
    val_probs: dict[str, np.ndarray],
    y_val: np.ndarray,
) -> tuple[list[str], float]:
    """Greedy forward selection (without replacement) on val macro-F1.

    Each round adds the candidate that gives the highest val macro-F1 when its
    probabilities are combined with the current members. The search stops
    when no candidate strictly improves the score. Candidates are tried in
    ``val_probs`` insertion order, so ties go to the earlier one.

    Returns ``(selected names in the order added, best val macro-F1)``. The
    score is -1.0 when nothing was selected.
    """
    selected: list[str] = []
    running_sum = None
    best_val = -1.0

    while True:
        best_candidate = None
        best_candidate_val = best_val
        best_sum = None
        for name, probs in val_probs.items():
            if name in selected:
                continue
            # Summing gives the same per-row argmax as averaging, since every
            # row is divided by the same positive member count.
            trial = probs if running_sum is None else running_sum + probs
            f1 = f1_score(y_val, trial.argmax(axis=1), average="macro", zero_division=0)
            if f1 > best_candidate_val:
                best_candidate_val = f1
                best_candidate = name
                best_sum = trial
        if best_candidate is None:
            break
        selected.append(best_candidate)
        running_sum = best_sum
        best_val = best_candidate_val
        print(f"  added {best_candidate}: val_macro_f1 = {best_val:.4f} (size={len(selected)})")

    return selected, best_val


def main():
    """CLI entry point: pick an ensemble greedily on val and report it on test.

    Loads silver rows and the val/test id splits, and runs every candidate
    bundle on both splits. Members are selected greedily by val macro-F1. The
    chosen ensemble's test macro/weighted F1, Cohen's kappa, per-class table
    and confusion matrix are printed and written as markdown to ``--out``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--bundles", nargs="+", type=Path,
                    help="Candidate bundles (default: all classifier_v*.joblib)")
    ap.add_argument("--silver", type=Path, default=RD / "silver_combined_v3.jsonl")
    ap.add_argument("--splits", type=Path,
                    default=RD / "classifier_v12_roberta_large_sqrt.splits.json")
    ap.add_argument("--out", type=Path, default=RD / "eval_ensemble_greedy.md")
    args = ap.parse_args()

    if not args.bundles:
        args.bundles = [
            b for b in sorted(RD.glob("classifier_v*.joblib"))
            if "splits" not in b.stem
        ]

    rows_all = load_silver(args.silver)
    by_id = {r["span_id"]: r for r in rows_all}
    with open(args.splits, encoding="utf-8") as fh:
        splits = json.load(fh)
    # Ids listed in the splits file but absent from the silver data are dropped.
    val_rows = [by_id[i] for i in splits["val_ids"] if i in by_id]
    test_rows = [by_id[i] for i in splits["test_ids"] if i in by_id]

    label_idx = {label: i for i, label in enumerate(PRIMITIVES)}
    y_val = np.array([label_idx[r["llm_label"]] for r in val_rows])
    y_test = np.array([label_idx[r["llm_label"]] for r in test_rows])

    print(f"Val rows: {len(val_rows)}, Test rows: {len(test_rows)}")
    print(f"Candidate bundles: {len(args.bundles)}")
    if not val_rows or not test_rows:
        print("Empty val or test split; nothing to evaluate.")
        return

    val_probs, test_probs = _load_probs(args.bundles, val_rows, test_rows, y_test)
    selected, best_val = _greedy_select(val_probs, y_val)

    if not selected:
        print("No bundles selected.")
        return
    # Sum in the same order as the val search so val and test use identical
    # arithmetic; the argmax matches that of the mean.
    test_sum = sum(test_probs[n] for n in selected)
    test_pred = test_sum.argmax(axis=1)
    macro = f1_score(y_test, test_pred, average="macro", zero_division=0)
    weighted = f1_score(y_test, test_pred, average="weighted", zero_division=0)
    kappa = cohen_kappa_score(y_test, test_pred)
    cm = confusion_matrix(y_test, test_pred, labels=list(range(len(PRIMITIVES))))

    print()
    print(f"Greedy ensemble ({len(selected)} bundles):")
    for n in selected:
        print(f"  - {n}")
    print()
    print(f"  Val macro-F1 (search):  {best_val:.4f}")
    print(f"  Test macro-F1:          {macro:.4f}")
    print(f"  Test weighted-F1:       {weighted:.4f}")
    print(f"  Test kappa:             {kappa:.4f}")

    md = [
        f"# Greedy ensemble selection ({len(selected)} bundles)\n",
        "Selected (in order added):",
        *(f"- `{n}`" for n in selected),
        "",
        f"- Val macro-F1: {best_val:.4f}",
        f"- **Test macro-F1: {macro:.4f}**",
        f"- Test weighted-F1: {weighted:.4f}",
        f"- Test kappa: {kappa:.4f}",
        "",
        "## Per-class metrics on test\n",
        per_class_table(y_test, test_pred, PRIMITIVES),
        "",
        "## Confusion matrix on test\n",
        render_confusion_md(cm, PRIMITIVES),
    ]

    # A custom --out may point into a directory that does not exist yet.
    args.out.parent.mkdir(parents=True, exist_ok=True)
    args.out.write_text("\n".join(md), encoding="utf-8")
    print(f"\nWrote {args.out}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
