# RQ6 synergy analysis (standalone) — DARE method
# Reads per-step merge CSVs like results_dare_32B.csv, computes donor→receiver synergy,
# and produces (1) a 9×9 heatmap and (2) a top ± pairs chart with bootstrap 95% CIs, plus CSV tables.
#
# Expected CSV columns (case-insensitive subset used):
#   - 'model'  : merge sequence as hyphen-separated domain IDs, e.g., '1-8-3'
#   - 'problem': evaluation domain name in {'algebra','analysis','discrete','geometry',
#                                           'number_theory','biology','chemistry','physics','code'}
#   - 'CE Loss': cross-entropy loss for that (sequence, eval-domain)
#
# Notes:
# - No seaborn; single-plot charts; no explicit colors are set.
# - Figures are saved under figs/ as PNG (dpi=300); tables are saved under out/ as CSV.

import os, numpy as np, pandas as pd, matplotlib.pyplot as plt, glob
# Nine-entry color palette. NOTE(review): not referenced by the plotting helpers
# in this script, which set no explicit colors.
PALETTE = [
    "#008A45", "#468BCA", "#5F5F5E",
    "#7DD2F6", "#80C5A2", "#B384BA",
    "#D9C2DD", "#F27873", "#FFD373",
]

# Global matplotlib setting: draw axes frames with a 1.7 pt line width.
plt.rcParams["axes.linewidth"] = 1.7

# Output folders for figures and tables; created up front so savefig/to_csv can write into them.
os.makedirs("figs", exist_ok=True)
os.makedirs("out", exist_ok=True)

# Evaluation domains in canonical order; this order fixes the heatmap's rows and columns.
doms = [
    'algebra', 'analysis', 'discrete', 'geometry', 'number_theory',
    'biology', 'chemistry', 'physics', 'code',
]
# Domain ID token used in the hyphen-separated 'model' sequences -> domain name
# ('1' -> 'algebra', ..., '9' -> 'code'), derived from the ordered list above.
id2dom = {str(idx): name for idx, name in enumerate(doms, start=1)}

def _read_csv(path):
    df = pd.read_csv(path)
    # normalize cols
    prob_col = None
    for c in df.columns:
        if c.strip().lower() in ("problem","eval_domain","domain","eval"):
            prob_col = c; break
    if prob_col is None: raise ValueError(f"no 'problem' col in {path}")
    ce_col = None
    for c in df.columns:
        if c.strip().lower() in ("ce loss","ce_loss","celoss","loss"):
            ce_col = c; break
    if ce_col is None: raise ValueError(f"no CE Loss col in {path}")
    if 'model' not in df.columns: raise ValueError(f"no 'model' col in {path}")
    df = df.rename(columns={prob_col:"problem", ce_col:"CE Loss"})
    return df[['model','problem','CE Loss']].astype({'model':str,'problem':str})

def compute_synergy(df, *, domains=None, id_map=None):
    """Compute donor→receiver synergy from per-step merge results.

    For every merge sequence 'a-b-...-k' with at least two steps, the last token
    k is the donor and each evaluation domain is a receiver. One step's synergy is

        delta = CE(prefix, receiver) - CE(sequence, receiver)

    where prefix is the sequence without its last token. A positive delta means
    appending the donor lowered the receiver's CE loss.

    Parameters
    ----------
    df : pandas.DataFrame
        Columns 'model' (hyphen-separated domain IDs as str), 'problem', 'CE Loss'.
    domains : list of str, optional
        Evaluation domains to keep; also the row/column order of the matrix.
        Defaults to the module-level ``doms``.
    id_map : dict of str -> str, optional
        Sequence token -> domain name. Defaults to the module-level ``id2dom``.

    Returns
    -------
    synergy : pandas.DataFrame
        Columns [donor, receiver, mean, count]: mean and count of delta per pair.
    M : pandas.DataFrame
        Donor × receiver matrix of mean delta, reindexed to ``domains``
        (NaN where there is no data).
    pair_df : pandas.DataFrame
        Columns [donor, receiver, delta], one row per (sequence, receiver) step.
    """
    if domains is None:
        domains = doms
    if id_map is None:
        id_map = id2dom
    df = df[df['problem'].isin(domains)].copy()
    # Repeated rows for the same (sequence, eval-domain) are averaged.
    agg = df.groupby(['model', 'problem'])['CE Loss'].mean().reset_index()
    key2ce = {(m, p): ce for m, p, ce in zip(agg['model'], agg['problem'], agg['CE Loss'])}
    rows = []
    for seq, recv, ce in zip(agg['model'], agg['problem'], agg['CE Loss']):
        prefix, sep, last = seq.rpartition('-')
        if not sep:
            continue  # single-domain model: no predecessor to compare against
        donor = id_map.get(last)
        prev = key2ce.get((prefix, recv))
        # Skip unknown donor IDs and missing predecessors. Also skip NaN losses:
        # a NaN delta would turn the pair's mean into NaN and break the ranking.
        if donor is None or prev is None or pd.isna(prev) or pd.isna(ce):
            continue
        rows.append((donor, recv, float(prev - ce)))
    # Force a float delta column so the groupby mean also works when no pairs exist.
    pair_df = pd.DataFrame(rows, columns=["donor", "receiver", "delta"]).astype({"delta": float})
    synergy = pair_df.groupby(['donor', 'receiver'])['delta'].agg(['mean', 'count']).reset_index()
    M = synergy.pivot(index='donor', columns='receiver', values='mean').reindex(index=domains, columns=domains)
    return synergy, M, pair_df

def block_stats(M):
    """Summarize the synergy matrix *M* by domain group.

    Each value is the mean of the non-NaN cells M.loc[donor, receiver] over a
    donor-group × receiver-group block. Cells with donor == receiver are skipped
    for the within-group blocks and the 'code' rows. A block with no usable
    cells yields NaN.
    """
    math_group = ['algebra', 'analysis', 'discrete', 'geometry', 'number_theory']
    science_group = ['biology', 'chemistry', 'physics']

    def block_mean(donors, receivers, skip_diagonal=True):
        cells = (
            M.loc[donor, receiver]
            for donor in donors
            for receiver in receivers
            if not (skip_diagonal and donor == receiver)
        )
        kept = [v for v in cells if not pd.isna(v)]
        if not kept:
            return np.nan
        return float(np.mean(kept))

    layout = [
        ("math→math", math_group, math_group, True),
        ("science→science", science_group, science_group, True),
        ("math→science", math_group, science_group, False),
        ("science→math", science_group, math_group, False),
        ("code→math", ['code'], math_group, True),
        ("code→science", ['code'], science_group, True),
    ]
    return {label: block_mean(src, dst, skip) for label, src, dst, skip in layout}

def top_pairs_with_ci(pair_df, top_n=5, exclude_diag=True, B=1000, seed=0):
    """Rank donor→receiver pairs by mean delta and bootstrap 95% CIs for the extremes.

    Parameters
    ----------
    pair_df : pandas.DataFrame
        Columns 'donor', 'receiver', 'delta' (one row per observed step).
    top_n : int
        Maximum number of pairs returned in each list.
    exclude_diag : bool
        Drop pairs with donor == receiver.
    B : int
        Number of bootstrap resamples per pair.
    seed : int
        Seed for ``numpy.random.default_rng``.

    Returns
    -------
    (pos_ci, neg_ci)
        Lists of (donor, receiver, mean, lo95, hi95, n) tuples.
        pos_ci holds up to top_n pairs with the largest means, in descending order.
        neg_ci holds up to top_n pairs with the smallest means, drawn only from pairs
        not already in pos_ci, also in descending order (most negative last).
    """
    rng = np.random.default_rng(seed)
    groups = {}
    for (d, e), sub in pair_df.groupby(['donor', 'receiver']):
        if exclude_diag and d == e:
            continue
        arr = sub['delta'].to_numpy(dtype=float)
        arr = arr[~np.isnan(arr)]  # a NaN delta would make the mean NaN and the sort meaningless
        if len(arr) == 0:
            continue
        groups[(d, e)] = arr
    means = {k: float(np.mean(v)) for k, v in groups.items()}
    items = sorted(means.items(), key=lambda kv: kv[1], reverse=True)
    top_n = max(top_n, 0)
    top_pos = items[:top_n]
    # Pick negatives only from pairs not already listed as positive. This keeps the
    # two lists disjoint when there are fewer than 2*top_n pairs, and avoids
    # items[-0:] returning every pair when top_n == 0.
    rest = items[len(top_pos):]
    top_neg = rest[max(len(rest) - top_n, 0):]

    def ci(arr):
        n = len(arr)
        bs = [np.mean(arr[rng.integers(0, n, n)]) for _ in range(B)]
        lo, hi = np.percentile(bs, [2.5, 97.5])
        return float(np.mean(arr)), float(lo), float(hi), n

    pos_ci = [(k[0], k[1], *ci(groups[k])) for k, _ in top_pos]
    neg_ci = [(k[0], k[1], *ci(groups[k])) for k, _ in top_neg]
    return pos_ci, neg_ci

def plot_heatmap(M, title, out_pdf):
    """Draw the donor × receiver synergy matrix *M* as a heatmap and save it to *out_pdf*.

    Rows of M (donors) run along the y axis; columns (receivers) run along the x axis.
    The color scale is symmetric around 0 (±max |ΔCE| over finite cells), so "help"
    and "hurt" share the same range. The output format follows the extension of
    *out_pdf*.
    """
    values = M.to_numpy(dtype=float)
    finite_abs = np.abs(values[np.isfinite(values)])
    # An all-NaN matrix would make nanmax return NaN (with a warning), and an
    # all-zero one gives a zero-width norm; fall back to ±1 in both cases.
    vmax = float(finite_abs.max()) if finite_abs.size and finite_abs.max() > 0 else 1.0
    fig = plt.figure(figsize=(6.4, 5.2))
    im = plt.imshow(values, vmin=-vmax, vmax=vmax)
    # Label ticks from M itself so a matrix in a different domain order is not mislabeled.
    plt.xticks(range(M.shape[1]), [str(c) for c in M.columns], rotation=45, ha='right')
    plt.yticks(range(M.shape[0]), [str(r) for r in M.index])
    plt.xlabel("receiver")
    plt.ylabel("donor")
    plt.colorbar(im, fraction=0.046, pad=0.04, label="ΔCE (help + / hurt −)")
    plt.title(title)
    plt.tight_layout()
    plt.savefig(out_pdf, dpi=300)
    plt.close(fig)

def plot_top_pairs(pos_ci, neg_ci, title, out_pdf):
    """Plot the mean ΔCE of the top positive and negative pairs with 95% CI error bars.

    *pos_ci* and *neg_ci* are sequences of (donor, receiver, mean, lo95, hi95, n)
    tuples. Positives are drawn first (left), then negatives; a dashed
    horizontal line marks ΔCE = 0. The figure is saved to *out_pdf*.
    """
    records = list(pos_ci) + list(neg_ci)
    labels, centers, lowers, uppers = [], [], [], []
    for donor, receiver, mean, lo, hi, _n in records:
        labels.append(f"{donor}→{receiver}")
        centers.append(mean)
        lowers.append(lo)
        uppers.append(hi)
    center_arr = np.array(centers)
    # errorbar wants distances below/above each point, not absolute bounds.
    yerr = [center_arr - np.array(lowers), np.array(uppers) - center_arr]
    positions = np.arange(len(labels))
    width = max(6, 0.6 * len(labels))

    plt.figure(figsize=(width, 3.6))
    plt.errorbar(positions, centers, yerr=yerr, fmt='o', capsize=3)
    plt.axhline(0.0, linestyle='--', linewidth=1)
    plt.xticks(positions, labels, rotation=45, ha='right', fontsize=9)
    plt.ylabel("Mean ΔCE (95% CI)")
    plt.title(title)
    plt.tight_layout()
    plt.savefig(out_pdf, dpi=300)
    plt.close()

def save_tables(N_tag, M, pos_ci, neg_ci, out_dir="out"):
    """Write the synergy matrix and the top ± pairs table as CSV files.

    Files written:
      - {out_dir}/rq6_synergy_matrix_{N_tag}_DARE.csv  (M with its index, 6 decimals)
      - {out_dir}/rq6_top_pairs_{N_tag}_DARE.csv        (one row per pair)

    *pos_ci* / *neg_ci* are sequences of (donor, receiver, mean, lo95, hi95, n)
    tuples. The pairs table always has a header, even when both lists are empty.
    *out_dir* is created if it does not exist.
    """
    os.makedirs(out_dir, exist_ok=True)
    M.to_csv(os.path.join(out_dir, f"rq6_synergy_matrix_{N_tag}_DARE.csv"), float_format='%.6f')
    columns = ["sign", "donor", "receiver", "mean", "lo95", "hi95", "n"]
    rows = [
        dict(zip(columns, (sign, d, e, m, lo, hi, n)))
        for sign, arr in (("positive", pos_ci), ("negative", neg_ci))
        for d, e, m, lo, hi, n in arr
    ]
    # Explicit columns: pd.DataFrame([]) would otherwise write a header-less, unreadable file.
    pd.DataFrame(rows, columns=columns).to_csv(
        os.path.join(out_dir, f"rq6_top_pairs_{N_tag}_DARE.csv"), index=False, float_format='%.6f'
    )

# Run for all files — only when executed as a script, so importing this module has no
# side effects beyond the setup above.
if __name__ == "__main__":
    # glob order is arbitrary; sort so files are processed (and logged) deterministically.
    # NOTE(review): a tag present in both locations writes to the same output names,
    # and the DARE/ copy (processed last) wins.
    for path in sorted(glob.glob("results_dare_*.csv")) + sorted(glob.glob("DARE/results_dare_*.csv")):
        # "results_dare_<N>.csv" -> "<N>", used in titles and output file names.
        N_tag = os.path.basename(path).split("results_dare_")[1].split(".csv")[0]
        try:
            df = _read_csv(path)
            synergy, M, pair_df = compute_synergy(df)
            pos_ci, neg_ci = top_pairs_with_ci(pair_df, top_n=5, exclude_diag=True, B=1000, seed=42)
            plot_heatmap(M, f"Synergy heatmap (DARE) @ N={N_tag}", f"figs/rq6_synergy_heatmap_{N_tag}_DARE.png")
            plot_top_pairs(pos_ci, neg_ci, f"Top ± donor→receiver pairs @ N={N_tag} (DARE)",
                           f"figs/rq6_top_pairs_{N_tag}_DARE.png")
            save_tables(N_tag, M, pos_ci, neg_ci)
            print(f"[ok] {path} -> figs/ & out/")
        except Exception as e:
            # Best-effort batch run: report the failing file and continue with the rest.
            print(f"[skip] {path}: {e}")
