import json
import argparse
import yaml
from collections import defaultdict
import os

def _get_json_path(cfg):
    model_short = cfg["model_name"].split("/")[-1]    # "Qwen2.5-Math-7B"
    json_out_path = os.path.join(
        "outputs",
        "eval",
        model_short,
        "evaluation_results.json"
    )
    return json_out_path

def _final_checkpoint_rows(aggs):
    """Return the entries at the highest step of each (dataset, algo_type), sorted."""
    last_step = {}
    for r in aggs:
        key = (r["dataset"], r["algo_type"])
        # Seed from the first observed step rather than a sentinel like -1,
        # so groups whose steps are all below the sentinel are not dropped.
        if key not in last_step or r["step"] > last_step[key]:
            last_step[key] = r["step"]

    rows = [r for r in aggs if r["step"] == last_step[(r["dataset"], r["algo_type"])]]
    return sorted(rows, key=lambda x: (x["dataset"], x["algo_type"]))


def _print_dataset_summary(rows, k_label):
    """Print one mean ± std line per final-checkpoint row."""
    print(f"\n📊 Final Checkpoint Summary (Pass@1, Pass@{k_label}): mean ± std")
    print("-"*90)

    for r in rows:
        ds = r["dataset"]
        algo = r["algo_type"]
        st = r["step"]
        p1_m, p1_s = r["pass_at_1_mean"], r["pass_at_1_std"]
        pk_m, pk_s = r["pass_at_k_mean"], r["pass_at_k_std"]

        print(f"[{ds}] {algo} @ step {st}: "
              f"P@1 {p1_m:.3f}±{p1_s:.3f} | "
              f"P@{k_label} {pk_m:.3f}±{pk_s:.3f}"
              )


def _print_algo_averages(rows, k_label):
    """Print, per algo_type, the unweighted mean of its per-dataset means."""
    algo_groups = defaultdict(list)
    for r in rows:
        algo_groups[r["algo_type"]].append(r)

    print("\n🔎 Algorithm-wise Averages Across Datasets")
    print("-"*90)

    for algo, items in sorted(algo_groups.items(), key=lambda kv: kv[0]):
        n = len(items)  # >= 1: groups are only created when a row is appended
        avg_p1 = sum(x["pass_at_1_mean"] for x in items) / n
        avg_pk = sum(x["pass_at_k_mean"] for x in items) / n

        # Final steps can differ between datasets; report all of them instead
        # of silently labelling the average with the first dataset's step.
        steps = sorted({x["step"] for x in items})
        if len(steps) == 1:
            step_label = f"step {steps[0]}"
        else:
            step_label = "steps " + "/".join(str(s) for s in steps)

        print(f"[AVG] {algo} @ {step_label}: "
              f"P@1 {avg_p1:.3f} | "
              f"P@{k_label} {avg_pk:.3f}"
              )


def analyze(json_path):
    """Print a final-checkpoint summary of an evaluation-results JSON file.

    The file must contain an ``"aggregates"`` list. Each entry needs the keys
    ``dataset``, ``algo_type``, ``step``, ``pass_at_1_mean``, ``pass_at_1_std``,
    ``pass_at_k_mean`` and ``pass_at_k_std``. An optional top-level ``"k"`` is
    used only to label the Pass@k columns and defaults to ``"K"``.

    Within each (dataset, algo_type) group, only the entries at the highest
    ``step`` are kept. Two tables are printed:
    - per dataset, mean ± std for Pass@1 and Pass@k;
    - per algorithm, the unweighted average of those means across datasets.

    Args:
        json_path: Path to the evaluation-results JSON file.

    Raises:
        FileNotFoundError: If ``json_path`` does not exist.
        KeyError: If ``"aggregates"`` or a required per-entry key is missing.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    K = data.get("k", "K")
    rows = _final_checkpoint_rows(data["aggregates"])
    _print_dataset_summary(rows, K)
    _print_algo_averages(rows, K)

def main():
    """Command-line entry point.

    Reads the YAML file given by ``--config`` and derives the results-JSON
    path from its ``model_name``. It then prints the summary produced by
    ``analyze``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True)
    config_path = cli.parse_args().config

    with open(config_path, "r", encoding="utf-8") as fh:
        config = yaml.safe_load(fh)

    analyze(_get_json_path(config))


if __name__ == "__main__":
    main()
