import os, json, argparse, re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patheffects as pe

# Global Matplotlib defaults for every figure produced by this script:
# small (7-9 pt) text for axes labels, titles, ticks and legends.
plt.rcParams.update({
    # "font.family": "serif",
    # "mathtext.fontset": "cm",
    "font.size": 8,  # base font size
    "axes.unicode_minus": False,  # render minus signs with an ASCII hyphen
    "axes.labelsize": 8,
    "axes.titlesize": 9,
    "xtick.labelsize": 7,
    "ytick.labelsize": 7,
    "legend.fontsize": 7,
    "legend.title_fontsize": 7,
})


# Canonical left-to-right panel order for the per-dataset figures.
DATASET_ORDER = ["reward_bench", "reward_bench_v2", "judgebench"]


# Dataset key -> human-readable name (looked up by get_display_name).
DATASET_NAME_MAP = {
    "reward_bench": "RewardBench",
    "reward_bench_v2": "RewardBench 2",
    "judgebench": "JudgeBench"
}

def get_display_name(dataset_key):
    """Return the display name for ``dataset_key``, or the key itself if unmapped."""
    try:
        return DATASET_NAME_MAP[dataset_key]
    except KeyError:
        return dataset_key

# Agreement outcomes between the instruct and reasoning runs.
# CASE_LABELS fixes the column / stacking / legend order; CASE_COLORS gives
# the bar color used for each outcome.
CASE_LABELS = [
    "Only reasoning correct",
    "Only instruct correct",
    "Both correct",
    "Both wrong",
]
CASE_COLORS = {
    "Only reasoning correct": "#2ca02c",  # green
    "Only instruct correct": "#d62728",  # red
    "Both correct": "#4d4d4d",  # dark gray
    "Both wrong": "#bdbdbd",  # light gray
}


def load_json(path):
    """Read the UTF-8 encoded JSON file at ``path`` and return the decoded object."""
    with open(path, mode="r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def check_format_and_correctness(judge_text, flip):
    """Parse a judge verdict and score it against the ``flip`` flag.

    The text is lower-cased and searched for verdict markers; markers for
    answer A are checked before markers for answer B, so a text containing
    both is read as "A".

    Args:
        judge_text: Raw judge output (``None``/empty is treated as "").
        flip: Choice "A" is correct when ``flip == 0``; choice "B" is correct
            when ``flip == 1``.

    Returns:
        ``(valid, correct, choice)``. ``valid`` is False and the other two
        are ``None`` when no marker is found.
    """
    text = (judge_text or "").lower()
    marker_table = (
        ("A", 0, ("[[a]]", "[a]", "\\boxed{assistant 1}")),
        ("B", 1, ("[[b]]", "[b]", "\\boxed{assistant 2}")),
    )
    for letter, winning_flip, markers in marker_table:
        if any(marker in text for marker in markers):
            return True, (flip == winning_flip), letter
    return False, None, None

def safe_extract_scalar(value, default=0):
    """Unwrap the first element of a list/tuple; pass anything else through.

    An empty list/tuple yields ``default``.
    """
    if not isinstance(value, (list, tuple)):
        return value
    if len(value) == 0:
        return default
    return value[0]

def format_domain_label(domain):
    """Return a display label for a domain name.

    "precise if" (any casing) becomes "Precise IF"; every other value is
    converted to ``str`` and capitalized. Converting first means non-string
    domains (numbers, ``None``) are labelled instead of raising
    ``AttributeError``.
    """
    text = str(domain)
    if text.lower() == "precise if":
        return "Precise IF"
    return text.capitalize()

def build_valid_table(gen_data, dataset_name):
    """Build a per-example table of verdict validity and correctness.

    Examples without any judge output are skipped. For ``reward_bench`` the
    domain is the ``category`` field, with "unknown" categories remapped from
    ``subset`` ("donotanswer" -> "safety", "mt-bench-med" -> "chat"). For any
    other dataset the first non-empty of ``domain``, ``category``, ``subset``,
    ``source`` is used, falling back to "overall".

    Args:
        gen_data: Iterable of example dicts loaded from ``generations.json``.
        dataset_name: Dataset key; only "reward_bench" gets special handling.

    Returns:
        DataFrame with columns ``id, domain, valid, correct``. The columns are
        present even when no example qualifies, so callers can merge on
        ``id``/``domain`` without a ``KeyError``.
    """
    columns = ["id", "domain", "valid", "correct"]
    is_reward_bench = dataset_name.lower() == "reward_bench"

    rows = []
    for ex in gen_data:
        judges = ex.get("judge", [])
        if not judges:
            continue

        if is_reward_bench:
            raw_domain = ex.get("category", "overall")
            if str(raw_domain).lower() == "unknown":
                subset_lower = str(ex.get("subset", "")).lower()
                if subset_lower == "donotanswer":
                    raw_domain = "safety"
                elif subset_lower == "mt-bench-med":
                    raw_domain = "chat"
        else:
            raw_domain = (ex.get("domain") or
                          ex.get("category") or
                          ex.get("subset") or
                          ex.get("source") or
                          "overall")

        # Fields may be stored as one-element lists; unwrap them.
        _id = safe_extract_scalar(ex.get("id"))
        domain = safe_extract_scalar(raw_domain, "overall")
        flip = safe_extract_scalar(ex.get("flip", 0), default=0)
        judge_text = safe_extract_scalar(judges[0], "")

        valid, correct, _ = check_format_and_correctness(judge_text, flip)
        rows.append({"id": _id, "domain": domain, "valid": valid, "correct": correct})

    return pd.DataFrame(rows, columns=columns)

def classify(ci, cr):
    """Map (instruct correct, reasoning correct) to its agreement label.

    Returns ``None`` when either flag is neither 0/False nor 1/True.
    """
    if cr == 1:
        if ci == 1:
            return "Both correct"
        if ci == 0:
            return "Only reasoning correct"
        return None
    if cr == 0:
        if ci == 1:
            return "Only instruct correct"
        if ci == 0:
            return "Both wrong"
    return None

def infer_model_size(model_name):
    """Return the parameter-count tag in ``model_name`` (e.g. "8B"), or ""."""
    found = re.search(r'(\d+\.?\d*)\s*[bB]\b', model_name)
    if found is None:
        return ""
    return f"{found.group(1)}B"

def extract_model_size_for_sorting(model_name):
    """Return the model size in billions as a float sort key.

    Names without a recognizable "<number>B" tag get ``inf`` so they sort last.
    """
    hit = re.search(r'(\d+\.?\d*)\s*[bB]\b', model_name.lower())
    return float(hit.group(1)) if hit else float('inf')

def sort_models_by_size(models):
    return sorted(models, key=extract_model_size_for_sorting)

def safe_to_scalar(x):
    """Collapse a list, tuple or NumPy array to its first element.

    Empty containers, including empty arrays, yield ``None``. A one-element
    array is returned as a Python scalar via ``item()``; a larger array
    yields its first element in flattened order, so the result is a scalar
    even for multi-dimensional input. Any other value is returned unchanged.
    """
    if isinstance(x, (list, tuple)):
        return x[0] if len(x) > 0 else None
    if isinstance(x, np.ndarray):
        if x.size == 0:
            # Same answer as for an empty list; x[0] would raise IndexError.
            return None
        return x.item() if x.size == 1 else x.flat[0]
    return x

def clean_dataframe_columns(df, columns_to_clean):
    """Return a copy of ``df`` with the listed columns reduced to scalars.

    Each requested column that exists is passed element-wise through
    ``safe_to_scalar``; names not present in ``df`` are ignored. The input
    frame is not modified.
    """
    result = df.copy()
    present = [name for name in columns_to_clean if name in result.columns]
    for name in present:
        result[name] = result[name].apply(safe_to_scalar)
    return result

def compute_agreement(resdir, model, dataset,
                      instruct_suffix="_instruct", reasoning_suffix="_reasoning"):
    """Count per-domain agreement cases between a model's instruct and reasoning runs.

    Loads ``<resdir>/<model><suffix>/<dataset>/generations.json`` for both
    variants, joins the examples on ``id`` and ``domain``, keeps those whose
    verdict was parseable in both runs, and tallies each into one of
    ``CASE_LABELS``.

    Args:
        resdir: Root directory holding the per-run result folders.
        model: Model name without the variant suffix.
        dataset: Dataset key (sub-directory name).
        instruct_suffix: Folder suffix of the instruct run.
        reasoning_suffix: Folder suffix of the reasoning run.

    Returns:
        DataFrame with columns ``model, dataset, domain, *CASE_LABELS, total``:
        one row per domain plus an "overall" row. An empty frame with the
        same columns is returned when a file is missing, when no example is
        valid in both runs, or when processing raises (the error is printed
        with its traceback rather than propagated).
    """
    out_columns = ["model", "dataset", "domain", *CASE_LABELS, "total"]

    def _empty():
        return pd.DataFrame(columns=out_columns)

    p_i = os.path.join(resdir, f"{model}{instruct_suffix}", dataset, "generations.json")
    p_r = os.path.join(resdir, f"{model}{reasoning_suffix}", dataset, "generations.json")

    if not (os.path.exists(p_i) and os.path.exists(p_r)):
        print(f"⚠️  Missing files for {model} on {dataset}")
        return _empty()

    try:
        di, dr = load_json(p_i), load_json(p_r)
        # The dataset name selects the domain-resolution rules.
        dfi, dfr = build_valid_table(di, dataset), build_valid_table(dr, dataset)

        # Nothing to join: bail out before merging on columns that may be absent.
        if dfi.empty or dfr.empty:
            print(f"⚠️  No valid data for {model} on {dataset}")
            return _empty()

        # Unwrap any cells that still hold lists/arrays.
        key_cols = ["id", "domain", "valid", "correct"]
        dfi = clean_dataframe_columns(dfi, key_cols)
        dfr = clean_dataframe_columns(dfr, key_cols)

        df = pd.merge(dfi, dfr, on=["id", "domain"], suffixes=("_i", "_r"))
        both_valid = df["valid_i"].eq(True) & df["valid_r"].eq(True)
        # .copy() so the column assignment below does not write to a slice.
        df = df[both_valid].copy()

        df["case"] = [classify(ci, cr) for ci, cr in zip(df["correct_i"], df["correct_r"])]
        df = df.dropna(subset=["case"])

        if df.empty:
            print(f"⚠️  No valid data for {model} on {dataset}")
            return _empty()

        # domain-level
        dom = df.groupby(["domain", "case"]).size().unstack(fill_value=0)
        dom = dom.reindex(columns=CASE_LABELS, fill_value=0)
        dom["total"] = dom.sum(axis=1)
        dom = dom.reset_index()

        # overall
        ov = df.groupby("case").size().reindex(CASE_LABELS, fill_value=0)
        ov_row = pd.DataFrame([{"domain": "overall", **ov.to_dict(), "total": ov.sum()}])

        out = pd.concat([dom, ov_row], ignore_index=True)
        out["model"] = model
        out["dataset"] = dataset

        return out[out_columns]

    except Exception as e:
        # Best-effort: one broken model/dataset pair must not abort the run.
        print(f"❌ Error processing {model} on {dataset}: {e}")
        import traceback
        traceback.print_exc()
        return _empty()

# ====== Scatter plots ======
def compute_domain_stats(data, dataset_name, mode=None, token_limit=1000):
    """Compute per-domain accuracy and mean token count for one run.

    Examples without judge output are skipped and do not count toward the
    totals. Domains are resolved as in ``build_valid_table``: ``category``
    with "unknown" remapped via ``subset`` for ``reward_bench``, otherwise the
    first non-empty of ``domain``/``category``/``subset``/``source`` or
    "overall". List-valued ``domain``, ``flip``, ``num_tokens`` and judge
    fields are unwrapped to their first element; a missing or null
    ``num_tokens`` counts as 0 tokens.

    Args:
        data: Iterable of example dicts loaded from ``generations.json``.
        dataset_name: Dataset key; only "reward_bench" gets special handling.
        mode: When "instruct", responses longer than ``token_limit`` tokens
            are treated as invalid.
        token_limit: Token cut-off applied in "instruct" mode.

    Returns:
        Tuple ``(accuracy, tokens, valid_count, total_count, stats)`` of dicts
        keyed by domain; ``stats`` holds all four values per domain. Accuracy
        and tokens are means over valid examples (0.0 if none). Five empty
        dicts are returned when no example has judge output.
    """
    is_reward_bench = dataset_name.lower() == "reward_bench"

    rows = []
    for ex in data:
        judges = ex.get("judge", [])
        if not judges:
            continue

        if is_reward_bench:
            raw_domain = ex.get("category", "overall")
            if str(raw_domain).lower() == "unknown":
                subset_lower = str(ex.get("subset", "")).lower()
                if subset_lower == "donotanswer":
                    raw_domain = "safety"
                elif subset_lower == "mt-bench-med":
                    raw_domain = "chat"
        else:
            raw_domain = (ex.get("domain") or ex.get("category") or
                          ex.get("subset") or ex.get("source") or "overall")

        # Unwrap list-valued fields; a list domain would be unhashable below.
        domain = safe_extract_scalar(raw_domain, "overall")
        flip = safe_extract_scalar(ex.get("flip", 0), default=0)
        jt = safe_extract_scalar(judges[0], "")

        # num_tokens may be a list, a bare number, or missing/null.
        t = safe_extract_scalar(ex.get("num_tokens"), default=0)
        if t is None:
            t = 0

        valid, correct, _ = check_format_and_correctness(jt, flip)

        if mode == "instruct" and t > token_limit:
            valid = False
            correct = None

        rows.append({"domain": domain, "valid": valid, "correct": correct, "tokens": t})

    if not rows:
        return {}, {}, {}, {}, {}

    df = pd.DataFrame(rows)
    stats = {}
    for dom in df["domain"].unique():
        sub = df[df["domain"] == dom]
        vsub = sub[sub["valid"] == True]
        vcnt = len(vsub)
        stats[dom] = {
            # "correct" can be object dtype (bools mixed with None), so cast first.
            "accuracy": vsub["correct"].astype(float).mean() if vcnt > 0 else 0.0,
            "tokens": vsub["tokens"].astype(float).mean() if vcnt > 0 else 0.0,
            "valid_count": vcnt,
            "total_count": len(sub),
        }
    return (
        {d: s["accuracy"] for d, s in stats.items()},
        {d: s["tokens"] for d, s in stats.items()},
        {d: s["valid_count"] for d, s in stats.items()},
        {d: s["total_count"] for d, s in stats.items()},
        stats,
    )

def collect_one_scatter(resdir, model_name, dataset, instruct_suffix="_instruct", reasoning_suffix="_reasoning"):
    """Return a tidy frame [model,dataset,domain,delta_acc,cost_ratio,n_valid].

    For every domain present in both runs, ``delta_acc`` is reasoning accuracy
    minus instruct accuracy, ``cost_ratio`` is mean reasoning tokens divided
    by mean instruct tokens (NaN when the instruct mean is 0 or NaN), and
    ``n_valid`` is the smaller of the two valid-example counts. If either
    ``generations.json`` is missing, a warning is printed and an empty frame
    with those columns is returned.
    """
    def _generations_path(suffix):
        return os.path.join(resdir, f"{model_name}{suffix}", dataset, "generations.json")

    instruct_path = _generations_path(instruct_suffix)
    reasoning_path = _generations_path(reasoning_suffix)
    if not os.path.exists(instruct_path) or not os.path.exists(reasoning_path):
        print(f"⚠️  Missing: {model_name} × {dataset}")
        return pd.DataFrame(columns=["model","dataset","domain","delta_acc","cost_ratio","n_valid"])

    instruct_data = load_json(instruct_path)
    reasoning_data = load_json(reasoning_path)

    acc_i, tok_i, valid_i, *_ = compute_domain_stats(instruct_data, dataset, mode="instruct")
    acc_r, tok_r, valid_r, *_ = compute_domain_stats(reasoning_data, dataset)

    records = []
    for dom in sorted(acc_i.keys() & acc_r.keys()):
        instruct_tokens = tok_i[dom]
        reasoning_tokens = tok_r[dom]
        if instruct_tokens is None or instruct_tokens == 0 or np.isnan(instruct_tokens):
            ratio = np.nan
        else:
            ratio = reasoning_tokens / instruct_tokens
        records.append({
            "model": model_name,
            "dataset": dataset,
            "domain": dom,
            "delta_acc": acc_r[dom] - acc_i[dom],
            "cost_ratio": ratio,
            "n_valid": min(valid_i.get(dom, 0), valid_r.get(dom, 0)),
        })
    return pd.DataFrame(records)

def _annotate_all(ax, sub, fontsize=7):
    """Label every point of ``sub`` on ``ax`` with its formatted domain name.

    Labels are drawn at (cost_ratio, delta_acc) with a white outline. When the
    optional ``adjustText`` package is importable it is used to de-overlap the
    labels; otherwise labels get small alternating offsets.
    """
    def _outline():
        return [pe.withStroke(linewidth=1.5, foreground="white", alpha=0.85)]

    points = sub.reset_index(drop=True)
    try:
        from adjustText import adjust_text

        placed = []
        for _, rec in points.iterrows():
            px, py = float(rec["cost_ratio"]), float(rec["delta_acc"])
            caption = format_domain_label(rec['domain'])
            placed.append(
                ax.text(px, py, caption, fontsize=fontsize,
                        ha="left", va="bottom", path_effects=_outline())
            )

        adjust_text(
            placed,
            arrowprops=dict(arrowstyle='->', color='gray', lw=0.5, alpha=0.5),
            ax=ax,
        )

    except ImportError:
        print("⚠️  adjustText not installed, using simple positioning")
        for pos, rec in points.iterrows():
            px, py = float(rec["cost_ratio"]), float(rec["delta_acc"])
            # Alternate the nudge direction so neighbouring labels spread out.
            shift_x = 0.06 if (pos % 2 == 0) else -0.06
            shift_y = 0.004 if (pos % 3 == 0) else -0.004
            ax.text(px + shift_x, py + shift_y, format_domain_label(rec['domain']),
                    fontsize=fontsize, ha="left", va="bottom",
                    path_effects=_outline())

def plot_scatter_facets_by_dataset(df_all, out_png, panel_w=2.5, panel_h=2.0, fontsize=8, sharey=True):
    """Draw one cost-ratio vs. accuracy-gain scatter panel per dataset.

    Each point is one (model, domain) pair: x is ``cost_ratio``, y is
    ``delta_acc`` shown in percent. Marker area grows with the square root of
    ``n_valid`` and colors identify models (sorted by size). Rows with a NaN
    ``cost_ratio`` or ``delta_acc`` are dropped.

    Args:
        df_all: Tidy frame with columns model, dataset, domain, delta_acc,
            cost_ratio and n_valid.
        out_png: Output image path.
        panel_w: Width of one panel in inches.
        panel_h: Height of one panel in inches.
        fontsize: Axis-label size in points. Titles use ``fontsize + 1``,
            ticks and legend ``fontsize - 1``, point labels ``fontsize - 2``.
        sharey: Whether the panels share their y axis.

    Returns:
        None. The figure is written to ``out_png``; if no dataset has
        plottable rows a warning is printed and nothing is saved.
    """
    from matplotlib.patches import Patch
    from matplotlib.ticker import FuncFormatter

    df = df_all.dropna(subset=["cost_ratio", "delta_acc"]).copy()

    # Known datasets come first in canonical order; any other dataset present
    # in the data is appended rather than silently dropped.
    present = df["dataset"].unique().tolist()
    datasets = [ds for ds in DATASET_ORDER if ds in present]
    datasets += sorted(ds for ds in present if ds not in DATASET_ORDER)
    if not datasets:
        # plt.subplots cannot create a grid with zero columns.
        print("⚠️  No plottable rows for scatter plot; nothing saved.")
        return

    models = sort_models_by_size(df["model"].unique().tolist())

    n = len(datasets)
    fig, axes = plt.subplots(1, n, figsize=(n * panel_w, panel_h),
                             sharey=sharey, squeeze=False)
    axes = axes.flatten()

    fixed_colors = ['#d62728', '#2ca02c', '#1f77b4']
    model_color = {m: fixed_colors[i % len(fixed_colors)] for i, m in enumerate(models)}

    label_fs = fontsize
    title_fs = fontsize + 1
    tick_fs = fontsize - 1
    legend_fs = fontsize - 1
    annot_fs = fontsize - 2

    # Format fractions as percentages at draw time. A formatter stays correct
    # when the tick locations change later, and avoids int() truncation
    # (int(0.29 * 100) == 28). "+ 0.0" turns -0.0 into 0.0.
    percent = FuncFormatter(lambda y, _pos: f"{round(y * 100, 6) + 0.0:g}")

    for idx, (ax, ds) in enumerate(zip(axes, datasets)):
        sub = df[df["dataset"] == ds].copy()
        for m, g in sub.groupby("model"):
            sizes = 30 + 2 * np.sqrt(np.maximum(g.get("n_valid", 1), 1))
            ax.scatter(g["cost_ratio"], g["delta_acc"],
                       s=sizes, color=model_color[m], alpha=0.9,
                       edgecolors="white", linewidths=0.5, label=m)
        _annotate_all(ax, sub, fontsize=annot_fs)
        # Reference lines: no accuracy change (y=0) and equal cost (x=1).
        ax.axhline(0, ls="--", c="gray", lw=0.8, alpha=0.5)
        ax.axvline(1, ls="--", c="gray", lw=0.8, alpha=0.5)
        ax.grid(alpha=0.2)
        ax.set_title(get_display_name(ds), fontsize=title_fs, fontweight="bold")
        ax.set_xlabel("Cost ratio", fontsize=label_fs)
        ax.tick_params(axis='x', length=0, labelsize=tick_fs)
        ax.tick_params(axis='y', direction='in', labelsize=tick_fs)
        ax.yaxis.set_major_formatter(percent)
        if idx == 0:
            ax.set_ylabel("ΔAccuracy (%)", fontsize=label_fs)

    handles = [Patch(color=model_color[m]) for m in models]
    labels = list(models)

    fig.legend(handles, labels,
               loc="lower center", bbox_to_anchor=(0.5, -0.08),
               ncol=len(labels), frameon=False, fontsize=legend_fs)

    fig.tight_layout()
    fig.savefig(out_png, dpi=600, bbox_inches="tight")
    plt.close(fig)
    print(f"✅ Saved: {out_png}")

# ====== Stacked barplots ======
def plot_grouped_stacked(df_all, datasets, models, out_png, panel_w=2.5, panel_h=2.4, fontsize=8):
    """Draw grouped, stacked agreement bars: one panel per dataset.

    Within a panel each domain (plus "overall", placed last) gets one stacked
    bar per model, split into the ``CASE_LABELS`` proportions in percent. The
    model-size tag is printed above every bar.

    Args:
        df_all: Frame with columns model, dataset, domain, the four
            ``CASE_LABELS`` counts and total.
        datasets: Dataset keys to plot. Known keys follow ``DATASET_ORDER``;
            other keys are appended alphabetically.
        models: Model names; they are sorted by parameter count.
        out_png: Output image path.
        panel_w: Width of one panel in inches.
        panel_h: Height of the figure in inches.
        fontsize: Reference font size in points. All text sizes scale with
            ``fontsize / 8``, so the default 8 gives 8 pt axis labels, 7 pt
            legend/y-ticks, 5.5 pt x-ticks and 3.5 pt size tags.

    Returns:
        None. The figure is written to ``out_png``; if ``datasets`` is empty
        a warning is printed and nothing is saved.
    """
    requested = list(datasets)
    datasets = [ds for ds in DATASET_ORDER if ds in requested]
    datasets += sorted(ds for ds in set(requested) if ds not in DATASET_ORDER)
    if not datasets:
        # plt.subplots cannot create a grid with zero columns.
        print("⚠️  No datasets to plot for stacked bar chart; nothing saved.")
        return

    models = sort_models_by_size(models)

    n = len(datasets)
    fig, axes = plt.subplots(1, n, figsize=(panel_w * n, panel_h), sharey=True, squeeze=False)
    axes = axes.flatten()

    scale = fontsize / 8
    numeric_cols = CASE_LABELS + ["total"]
    # 0.26 fits three bars per domain; shrink for more models so that bars of
    # neighbouring domains do not overlap.
    width = min(0.26, 0.8 / max(len(models), 1))

    for idx, (ax, ds) in enumerate(zip(axes, datasets)):
        sub = df_all[df_all["dataset"] == ds]
        if sub.empty:
            ax.text(0.5, 0.5, f"No data for {ds}", ha="center", va="center",
                    transform=ax.transAxes, fontsize=7 * scale)
            ax.axis("off")
            continue

        domains = sorted(d for d in sub["domain"].unique().tolist() if d != "overall") + ["overall"]
        x = np.arange(len(domains))

        for mi, model in enumerate(models):
            model_data = sub[sub["model"] == model].copy()

            print(f"\n=== Processing {ds} - {model} ===")
            for col in model_data.columns:
                has_list = model_data[col].apply(lambda v: isinstance(v, (list, tuple))).any()
                if has_list:
                    print(f"⚠️  Column '{col}' contains lists/tuples, converting to scalars...")
                    model_data[col] = model_data[col].apply(safe_to_scalar)

            if model_data["domain"].duplicated().any():
                # reindex() needs unique labels, so collapse repeated domains.
                print(f"⚠️  Found duplicate domains for {model}, aggregating by mean...")
                duplicates = model_data[model_data["domain"].duplicated(keep=False)]
                print(f"   Duplicate rows:\n{duplicates[['domain'] + CASE_LABELS]}")
                model_data = model_data.groupby("domain")[numeric_cols].mean().reset_index()

            # Domains this model has no row for become all-zero bars.
            g = (model_data.set_index("domain")[numeric_cols]
                 .reindex(domains)
                 .fillna(0)
                 .astype(float))

            # Dividing by NaN instead of 0 gives NaN proportions, reset to 0.
            totals = g["total"].where(g["total"] != 0)
            props = g[CASE_LABELS].div(totals, axis=0).fillna(0)

            # Centre the group of model bars on each domain tick.
            offset = (mi - (len(models) - 1) / 2) * width
            x_m = x + offset
            bottom = np.zeros(len(domains))

            for label in CASE_LABELS:
                values_pct = props[label].to_numpy(dtype=float) * 100

                ax.bar(x_m, values_pct, width, bottom=bottom,
                       color=CASE_COLORS[label], edgecolor="white", linewidth=0.3,
                       label=label if mi == 0 else "_nolegend_")
                bottom = bottom + values_pct

            size_str = infer_model_size(model)
            for xi, top in zip(x_m, bottom):
                ax.text(xi, top + 2, size_str, ha="center", va="bottom",
                        fontsize=3.5 * scale, color="#333")

        ax.set_xticks(x)
        ax.set_xticklabels([format_domain_label(d) for d in domains], fontsize=5.5 * scale)
        ax.tick_params(axis='x', length=0, labelsize=5.5 * scale)
        ax.tick_params(axis='y', direction='in', labelsize=7 * scale)
        ax.set_ylim(0, 110)  # headroom above 100% for the size tags
        ax.grid(axis="y", alpha=0.2)

        if idx == 0:
            ax.set_ylabel("Proportion (%)", fontsize=8 * scale)

    handles = [plt.Line2D([0], [0], color=CASE_COLORS[c], lw=6) for c in CASE_LABELS]
    labels = CASE_LABELS
    fig.legend(handles, labels, ncol=4, loc="lower center",
               bbox_to_anchor=(0.5, -0.08), frameon=False,
               fontsize=7 * scale)

    fig.tight_layout()
    fig.savefig(out_png, dpi=600, bbox_inches="tight")
    plt.close(fig)
    print(f"✅ Saved: {out_png}")



def main(args):
    """Build both figures from the parsed command-line arguments.

    First the agreement table is computed for every model/dataset pair and
    drawn as stacked bars (``agreement.png``); then the accuracy-gain vs.
    cost data is collected and drawn as scatter panels (``delta_cost.png``).
    Either figure is skipped with a warning when no data was found.
    """
    def _csv_items(raw):
        return [piece.strip() for piece in raw.split(",") if piece.strip()]

    models = _csv_items(args.models)
    datasets = _csv_items(args.datasets)
    os.makedirs(args.outdir, exist_ok=True)

    combos = [(m, ds) for m in models for ds in datasets]

    # ---- stacked agreement bars ----
    agreement_frames = []
    for m, ds in combos:
        frame = compute_agreement(args.resdir, m, ds)
        if frame.empty:
            continue
        agreement_frames.append(frame)

    if not agreement_frames:
        print("⚠️ No data found for stacked bar chart.")
    else:
        df_all = pd.concat(agreement_frames, ignore_index=True)

        print("\n=== Final data cleaning ===")
        for col in df_all.columns:
            if df_all[col].apply(lambda x: isinstance(x, (list, tuple))).any():
                print(f"⚠️  Cleaning column '{col}' in final dataframe...")
                df_all[col] = df_all[col].apply(safe_to_scalar)

        plot_grouped_stacked(
            df_all, datasets, models,
            os.path.join(args.outdir, "agreement.png"),
            args.panel_w, args.panel_h, args.fontsize,
        )

    # ---- accuracy gain vs. cost scatter ----
    scatter_frames = []
    for m, ds in combos:
        frame = collect_one_scatter(args.resdir, m, ds, args.instruct_suffix, args.reasoning_suffix)
        if frame.empty:
            continue
        scatter_frames.append(frame)

    if not scatter_frames:
        print("⚠️ No data found for scatter plot.")
    else:
        df_scatter = pd.concat(scatter_frames, ignore_index=True)
        plot_scatter_facets_by_dataset(
            df_scatter,
            os.path.join(args.outdir, "delta_cost.png"),
            panel_w=args.panel_w_scatter,
            panel_h=args.panel_h_scatter,
            fontsize=args.fontsize_scatter,
            sharey=not args.no_sharey,
        )

# Command-line entry point: parse the options and hand them to main().
if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    # Input root (holds <model><suffix>/<dataset>/generations.json) and figure output directory.
    ap.add_argument("--resdir", type=str, default="./result")
    ap.add_argument("--outdir", type=str, default="./figs")
    # Comma-separated lists of model names and dataset keys.
    ap.add_argument("--models", type=str, required=True)
    ap.add_argument("--datasets", type=str, required=True)
    # Result-folder suffixes of the two runs (passed to collect_one_scatter).
    ap.add_argument("--instruct_suffix", type=str, default="_instruct")
    ap.add_argument("--reasoning_suffix", type=str, default="_reasoning")
    # Per-panel size in inches for the stacked-bar figure ...
    ap.add_argument("--panel_w", type=float, default=2.5)
    ap.add_argument("--panel_h", type=float, default=2.2)
    # ... and for the scatter figure.
    ap.add_argument("--panel_w_scatter", type=float, default=2.5)
    ap.add_argument("--panel_h_scatter", type=float, default=2.3)
    # Font sizes handed to the two plotting functions.
    ap.add_argument("--fontsize", type=int, default=8)
    ap.add_argument("--fontsize_scatter", type=int, default=8)
    # Give every scatter panel its own y axis instead of a shared one.
    ap.add_argument("--no_sharey", action="store_true")
    args = ap.parse_args()
    main(args)