# -*- coding: utf-8 -*-
# RQ10 (DARE, quality-aware vs random) — compact 2-panel figure + CSV exports
# Reads: DARE/results_dare_*B.csv (relative to the working directory)
# Writes: figs/rq10_compact.pdf and figs/rq10_compact.png, out/*.csv
# ------------------------------------------------------------
import os, re, glob, math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Output locations: figures are written to figs/, CSV tables to out/.
os.makedirs("out", exist_ok=True)
os.makedirs("figs", exist_ok=True)

# ---------- helpers ----------
# Expert id (1..9) -> domain name.
DOMAP = dict(enumerate(
    (
        "algebra", "analysis", "discrete", "geometry", "number_theory",
        "biology", "chemistry", "physics", "code",
    ),
    start=1,
))
# Domain names listed in expert-id order.
DOMS = list(DOMAP.values())

def parse_model_to_tuple(s):
    """Parse a 'model' cell into a sorted tuple of expert ids.

    Examples: '1-3-7' -> (1, 3, 7); '[7, 3]' -> (3, 7); [3, 1] -> (1, 3).

    Args:
        s: A string holding the ids separated by any non-digit characters,
            a list/tuple of int-convertible items, or a missing value
            (None / NaN).

    Returns:
        Tuple of ints in ascending order (duplicates are kept); the empty
        tuple for a missing value.
    """
    # Sequences must be handled before the missing-value test: pd.isna() on a
    # list/tuple returns an element-wise array, whose truth value is ambiguous
    # and makes `if pd.isna(s)` raise ValueError.
    if isinstance(s, (list, tuple)):
        return tuple(sorted(int(x) for x in s))
    if pd.isna(s):
        return ()
    # Every run of digits is one id, so brackets, dashes, commas and spaces
    # all work as separators.
    return tuple(sorted(int(tok) for tok in re.findall(r"\d+", str(s))))

def detect_ce_column(df):
    """Choose the column that holds the CE metric.

    Column names are matched case-insensitively, in this order of preference:
    'average'; then, when at least five domain columns (names listed in DOMS)
    are present, None is returned as a signal to use the mean of the domain
    columns; then 'cd_loss', 'ce', 'loss', 'cd loss', 'cdloss'; finally the
    last numeric column.

    Returns:
        The original column label, or None for the domain-mean case.

    Raises:
        ValueError: If no rule applies and the frame has no numeric column.
    """
    # Lower-cased name -> first column carrying that name.
    first_by_lower = {}
    for col in df.columns:
        first_by_lower.setdefault(col.lower(), col)

    # 1) explicit 'average'
    if "average" in first_by_lower:
        return first_by_lower["average"]

    # 2) enough domain columns: let the caller average them
    n_domains = sum(1 for col in df.columns if col.lower() in DOMS)
    if n_domains >= 5:
        return None

    # 3) well-known loss column names
    for alias in ("cd_loss", "ce", "loss", "cd loss", "cdloss"):
        if alias in first_by_lower:
            return first_by_lower[alias]

    # 4) fallback: last numeric column
    numeric = [col for col in df.columns if pd.api.types.is_numeric_dtype(df[col])]
    if not numeric:
        raise ValueError("No numeric CE-like column found.")
    return numeric[-1]

def load_all_dare_results(pattern=""):
    """Load every DARE result CSV matching *pattern*, keyed by model size.

    The model size N is parsed from a filename suffix such as '_7B.csv',
    '_0.5B.csv' or '_72B.csv'. Each file needs a 'model' column (any letter
    case). The CE column is chosen by detect_ce_column(); when that returns
    None, CE is the row-wise mean of the domain columns (names in DOMS).

    Args:
        pattern: Glob pattern of the input CSV files.

    Returns:
        dict mapping N (float) to a DataFrame with columns model_set (tuple),
        k (int) and CE (float), one row per unique (model_set, k).

    Raises:
        RuntimeError: If no file could be loaded.
    """
    out = {}
    for fp in sorted(glob.glob(pattern)):
        # One pattern covers integer and decimal sizes alike
        # ('_7B.csv', '_0.5B.csv', '_72B.csv').
        m = re.search(r"_(\d+(?:\.\d+)?)B\.csv$", fp)
        if not m:
            print(f"[warn] skip (cannot parse N): {fp}")
            continue
        N = float(m.group(1))
        df = pd.read_csv(fp)
        # require 'model' (accept any letter case)
        if "model" not in df.columns:
            cand = [c for c in df.columns if c.lower() == "model"]
            if cand:
                df.rename(columns={cand[0]: "model"}, inplace=True)
        if "model" not in df.columns:
            print(f"[warn] skip (no 'model' col): {fp}")
            continue
        # parse CE
        ce_col = detect_ce_column(df)
        if ce_col is None:
            # macro mean over domain columns
            dom_hit = [c for c in df.columns if c.lower() in DOMS]
            df["_CE_"] = df[dom_hit].astype(float).mean(axis=1)
            ce_col = "_CE_"
        CE = df[ce_col].astype(float).to_numpy()
        # parse model sets and k
        M = df["model"].apply(parse_model_to_tuple)
        K = M.apply(len).astype(int).to_numpy()
        tmp = pd.DataFrame({"model_set": M, "k": K, "CE": CE})
        # Duplicate rows of the same set are collapsed to their mean CE.
        tmp = tmp.groupby(["model_set", "k"], as_index=False)["CE"].mean()
        if N in out:
            # Two files parsed to the same size: the later one wins, but say so.
            print(f"[warn] N={N} already loaded; overwritten by {fp}")
        out[N] = tmp
        print(f"[info] loaded {fp}: N={N}, rows={len(tmp)} (unique sets)")
    if not out:
        raise RuntimeError("No results_dare_*B.csv loaded. Check path/pattern.")
    return out

def build_index(df):
    """Map each model set (tuple) to its CE for dictionary lookups.

    When a set occurs in several rows, the last row's CE is kept.
    """
    return dict(zip(df["model_set"], df["CE"]))

def early_strength_from_sets(df, sets_sizes=(2,3)):
    r"""Estimate each expert's early strength from leave-one-out differences.

    For each set S with |S| in *sets_sizes* and each expert d in S, the
    contribution is c(d, S) = CE(S \ {d}) - CE(S); a positive value means
    that adding d lowered CE. Contributions are averaged per expert over
    every S for which the reduced set S \ {d} is also present in *df*.

    Args:
        df: DataFrame with columns model_set (sorted tuple), k and CE.
        sets_sizes: Set sizes whose members are evaluated.

    Returns:
        (strength, counts): two dicts keyed by expert id 1..9. strength[d]
        is the mean contribution (0.0 when there is none); counts[d] is the
        number of contributions averaged.

    Raises:
        KeyError: If a qualifying set contains an id outside 1..9.
    """
    idx = build_index(df)
    contrib = {i: [] for i in range(1, 10)}
    for ms, k, ce in df[["model_set", "k", "CE"]].itertuples(index=False):
        if k not in sets_sizes:
            continue
        S = set(ms)
        for d in S:
            prev = tuple(sorted(S - {d}))
            if prev not in idx:
                continue
            gain = idx[prev] - ce  # positive = improvement
            # A missing (NaN) CE on either side would turn the whole mean into
            # NaN and leave the later strength sort without a defined order.
            if np.isfinite(gain):
                contrib[d].append(gain)
    strength = {d: (np.mean(v) if v else 0.0) for d, v in contrib.items()}
    counts = {d: len(v) for d, v in contrib.items()}
    return strength, counts

def greedy_order_from_strength(strength):
    """Rank expert ids from strongest to weakest; ties go to the smaller id."""
    ranked = sorted(strength.items(), key=lambda item: (-item[1], item[0]))
    return [expert for expert, _ in ranked]

def get_ce_for_set(df, S):
    """Look up the CE of expert set *S*.

    Args:
        df: DataFrame with columns 'model_set' (sorted tuples) and 'CE'.
        S: Iterable of expert ids, in any order.

    Returns:
        CE (float) of the first row whose model_set equals the sorted ids of
        *S*; NaN when no such row exists.
    """
    key = tuple(sorted(S))
    # Compare the tuples in plain Python instead of `Series == tuple`, so the
    # match does not depend on how pandas treats a sequence operand; this
    # also stops at the first hit.
    for model_set, ce in zip(df["model_set"], df["CE"]):
        if model_set == key:
            return float(ce)
    return float("nan")

def greedy_curve(df, order):
    """Return the CE along the prefixes of *order*.

    Entry k-1 is the CE looked up for the first k experts of *order*
    (k = 1..len(order)); it is NaN when that set is not in *df*.
    """
    prefixes = (order[:size] for size in range(1, len(order) + 1))
    values = [get_ce_for_set(df, tuple(sorted(prefix))) for prefix in prefixes]
    return np.array(values, float)

def random_baseline_by_avg_sets(df):
    """Return the mean CE over all sets of each size k = 1..9.

    The result is a float array of length 9; entry k-1 is NaN when *df*
    holds no set of size k.
    """
    means = []
    for size in range(1, 10):
        ce_values = df[df["k"] == size]["CE"].to_numpy()
        if len(ce_values) > 0:
            means.append(float(np.mean(ce_values)))
        else:
            means.append(np.nan)
    return np.array(means, float)

# ---------- main compute ----------
res_all = load_all_dare_results("DARE/results_dare_*B.csv")
Ns_sorted = sorted(res_all.keys())

# Per-N results: early strength, greedy order, greedy & random curves.
rows_strength = []
rows_curves = []
heat_vals_abs, heat_vals_rel = [], []  # for Panel (b): rows=N, cols k in Ks_show
Ks_show = [3,6,9]

for N in Ns_sorted:
    dfN = res_all[N]
    strength, counts = early_strength_from_sets(dfN, sets_sizes=(2,3))
    order = greedy_order_from_strength(strength)
    g_curve = greedy_curve(dfN, order)   # CE(k) for k=1..9 (may contain NaN if set missing)
    r_curve = random_baseline_by_avg_sets(dfN)

    # export tables
    for d in range(1,10):
        rows_strength.append({
            "N": N, "expert_id": d, "domain": DOMAP[d],
            "early_strength(CE gain)": strength[d], "num_pairs": counts[d]
        })
    for k in range(1,10):
        rows_curves.append({
            "N":N, "k":k, "CE_greedy": g_curve[k-1], "CE_random_mean": r_curve[k-1]
        })
    # Heat-map entries: how much lower the greedy CE is than the random mean
    # at each shown k, in absolute terms and as a percentage of the random mean.
    abs_improve = []
    rel_improve = []
    for k in Ks_show:
        if np.isfinite(g_curve[k-1]) and np.isfinite(r_curve[k-1]):
            abs_imp = r_curve[k-1] - g_curve[k-1]
            # The floor on the denominator guards against division by zero.
            rel_imp = abs_imp / max(1e-12, r_curve[k-1]) * 100.0
        else:
            # Either curve is missing at this k: leave the cell empty.
            abs_imp, rel_imp = np.nan, np.nan
        abs_improve.append(abs_imp); rel_improve.append(rel_imp)
    heat_vals_abs.append(abs_improve); heat_vals_rel.append(rel_improve)

# save csvs
pd.DataFrame(rows_strength).to_csv("out/rq10_early_strength_by_expert.csv", index=False)
pd.DataFrame(rows_curves).to_csv("out/rq10_greedy_vs_random_curves.csv", index=False)

# ---------- Panel (a): early strength bar (pick one N to display; prefer 7B, then 32B, then 72B, else the largest N) ----------
def pick_display_N(preferred=(7.0, 32.0, 72.0), available=None):
    """Choose which model size to display.

    Args:
        preferred: Sizes to try, in priority order.
        available: Sequence of sizes to choose from; defaults to the
            module-level Ns_sorted, looked up when the function is called.

    Returns:
        The first entry of *preferred* found in *available*; otherwise the
        last element of *available* (the largest one when it is sorted in
        ascending order).
    """
    if available is None:
        available = Ns_sorted
    for p in preferred:
        if p in available:
            return p
    return available[-1]
N_disp = pick_display_N()

# Per-expert strengths for the displayed N, strongest first.
df_str = pd.read_csv("out/rq10_early_strength_by_expert.csv")
strN = df_str[df_str["N"]==N_disp].copy()
strN.sort_values("early_strength(CE gain)", ascending=False, inplace=True)

# ---------- Panel (b): heatmap (relative improvement %) ----------
H = np.array(heat_vals_rel, float)  # shape (len(Ns), len(Ks_show))
N_labels = [f"{int(n) if n.is_integer() else n}B" for n in Ns_sorted]
K_labels = [str(k) for k in Ks_show]

# ---------- make a single 2-panel figure ----------
plt.figure(figsize=(11,4.5))

# (a) early strength bar
ax1 = plt.subplot(1,2,1)
ax1.bar(range(len(strN)), strN["early_strength(CE gain)"], tick_label=strN["domain"])
ax1.set_title(f"(a) Early expert strength @ {int(N_disp) if N_disp.is_integer() else N_disp}B (ΔCE from |S|=2,3)")
ax1.set_xlabel("Expert (domain)")
ax1.set_ylabel("Avg marginal CE gain ↑")
# Applies to the current axes, which is still ax1 at this point.
plt.xticks(rotation=35, ha="right")


# (b) heatmap of relative improvement (%)
ax2 = plt.subplot(1,2,2)
mask = ~np.isfinite(H)
# Symmetric colour range around 0; fall back to 1.0 when H is all-NaN or ~0.
vmax = np.nanmax(np.abs(H))
if not np.isfinite(vmax) or vmax<1e-6: vmax = 1.0
# Missing cells are drawn as 0.0 and left without an annotation below.
im = ax2.imshow(np.where(mask, 0.0, H), cmap="RdYlGn", vmin=-vmax, vmax=vmax, aspect="auto")
ax2.set_xticks(range(len(Ks_show))); ax2.set_xticklabels(K_labels)
ax2.set_yticks(range(len(Ns_sorted))); ax2.set_yticklabels(N_labels)
ax2.set_xlabel("k (experts)"); ax2.set_ylabel("Model size N")
ax2.set_title("(b) Greedy vs Random: relative CE improvement (%)")
# annotate cells
for i in range(H.shape[0]):
    for j in range(H.shape[1]):
        val = H[i,j]
        txt = "" if (not np.isfinite(val)) else f"{val:.1f}"
        ax2.text(j, i, txt, ha="center", va="center", fontsize=8, color="black")
cbar = plt.colorbar(im, ax=ax2, fraction=0.046, pad=0.04)
cbar.set_label("Rel. improvement (%)")

plt.tight_layout()
plt.savefig("figs/rq10_compact.pdf", bbox_inches="tight")
plt.savefig("figs/rq10_compact.png", dpi=200, bbox_inches="tight")
plt.close()

print("[done] Wrote figs/rq10_compact.{pdf,png}")
print("[done] Tables: out/rq10_early_strength_by_expert.csv, out/rq10_greedy_vs_random_curves.csv")
