import argparse, os, glob, pickle, warnings
import numpy as np
import matplotlib.pyplot as plt

# Command-line interface: a single optional flag, --dir, naming the folder
# that holds the pickled result files (defaults to the current directory).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dir",
    help="folder with *.pkl files",
    default=".",
)
args = parser.parse_args()

# Folder that is searched for result files.
ROOT_DIR = args.dir

# Maps the lowercase keyword looked for in a result file's name to the label
# shown for that method in the printed ranking and in the plot legend.
# Insertion order matters: the first keyword found in a file name wins.
METHODS = {
    "reefl": "ReeFL",
    "scalefl": "ScaleFL",
    "heterofl": "HeteroFL",
    "depthfl": "DepthFL",
    "inclusivefl": "InclusiveFL",
    "snowfl": "SNOWFL",
    "snow_no_fl": "SNOWFL (Without SNIP AND OWEN)",
    "snow_ow_fl": "SNOWFL (Without SNIP)",
    "snow_os_fl": "SNOWFL (Without OWEN)",
}

# Candidate metric names for the accuracy curve, in order of preference:
# the first one present in a file's metrics is the one that gets used.
ACC_KEYS = [
    "centralized_test_exit_all_ensemble_acc",
    "centralized_test_exit_all_acc",
]

def _first_run_metrics(data):
    """Return the metrics container of the first scale / first run in *data*.

    The lookup is ``data["multi_tier"][<first key>][<first key>]``.

    Returns ``None`` when any level is missing, empty, or cannot be indexed,
    so that the caller can skip the file instead of aborting the whole run.
    """
    try:
        multi_tier = data["multi_tier"]
        scale_key = next(iter(multi_tier))
        runs = multi_tier[scale_key]
        run_key = next(iter(runs))
        return runs[run_key]
    except (LookupError, TypeError, StopIteration):
        return None


# method keyword -> (rounds array, accuracy array) for every usable file.
results = {}

# glob returns paths in arbitrary order; sorting makes runs reproducible, in
# particular which file wins when two files match the same method keyword.
for p in sorted(glob.glob(os.path.join(ROOT_DIR, "*.pkl"))):
    fname = os.path.basename(p).lower()
    print("→", fname)

    # The method is the first METHODS keyword contained in the lowercased
    # file name.
    method = next((m for m in METHODS if m in fname), None)
    if method is None:
        print("   ↳ no method keyword, skipping")
        continue

    # SECURITY: unpickling can execute arbitrary code. Only run this script
    # on result files that come from a trusted source.
    try:
        with open(p, "rb") as f:
            data = pickle.load(f)
    except (pickle.UnpicklingError, EOFError) as err:
        # A truncated or corrupt file should not abort the other curves.
        warnings.warn(f"{fname}: unreadable pickle ({err}), skipped")
        continue

    metrics = _first_run_metrics(data)
    if metrics is None:
        warnings.warn(f"{fname}: unexpected layout under 'multi_tier', skipped")
        continue

    acc_key = next((k for k in ACC_KEYS if k in metrics), None)
    if acc_key is None:
        warnings.warn(f"{fname}: no ensemble-accuracy key, skipped")
        continue

    if "round" not in metrics:
        warnings.warn(f"{fname}: no 'round' key, skipped")
        continue

    rounds, acc = metrics["round"], metrics[acc_key]
    if len(rounds) != len(acc):
        # Keep only the leading part both series have in common.
        mn = min(len(rounds), len(acc))
        warnings.warn(f"{fname}: trimming to {mn} points")
        rounds, acc = rounds[:mn], acc[:mn]

    if method in results:
        # Only one curve is kept per method; make the overwrite visible.
        warnings.warn(
            f"{fname}: a curve for '{method}' was already loaded, replacing it"
        )

    results[method] = (np.asarray(rounds), np.asarray(acc, dtype=float))

# ----- plot (sorted by max accuracy) -----
def _max_ignoring_nan(values):
    """Return the largest non-NaN entry of *values*, or NaN if there is none.

    Unlike ``np.nanmax`` this neither raises on an empty array nor emits a
    RuntimeWarning when every entry is NaN.
    """
    values = np.asarray(values, dtype=float)
    kept = values[~np.isnan(values)]
    return kept.max() if kept.size else np.nan


def _rank_key(entry):
    """Sort key for ``(method, rounds, acc, max_acc)`` tuples.

    Orders by descending ``max_acc``; entries whose maximum is NaN go last,
    because NaN does not compare consistently and would otherwise leave the
    ranking order ill-defined.
    """
    best = entry[3]
    return (1, 0.0) if np.isnan(best) else (0, -best)


if not results:
    raise RuntimeError("No curves collected – fix folder path or key names first.")

# Collect stats: (method, rounds, acc, max_acc)
stats = [
    (method, r, a, _max_ignoring_nan(a))
    for method, (r, a) in results.items()
]

# Sort by max accuracy, descending (ties keep their insertion order).
stats.sort(key=_rank_key)

# Optional: print the ranking
print("\nRanking by max accuracy:")
for rank, (method, _, __, m) in enumerate(stats, 1):
    print(f"{rank:>2}. {METHODS[method]} — max={m:.2f}")

# Draw one line per method, following the order of the ranking so that the
# legend entries come out in that same order; each entry reads “Method (max%)”.
for key, xs, ys, best in stats:
    legend_text = f"{METHODS[key]} ({best:.2f}%)"
    plt.plot(xs, ys, label=legend_text)

# Axes decoration.
plt.title("Accuracy vs. round (legend sorted by max)")
plt.xlabel("Round")
plt.ylabel("Accuracy (%)")
plt.grid(True)

# No re-ordering needed: legend entries follow the plotting order.
plt.legend(title="↑ Max accuracy", loc="best")
plt.tight_layout()
plt.show()
