
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
读取 Overcooked 用户实验 JSON（3 tasks, 3 models, 多 layouts, 新问卷），输出 results.csv，并绘图：
1) reward：按 model 跨 layout 的总体均值（mean±SEM）
2) reward：按 layout×model 的均值（mean±SEM）
3) questionnaire：按 question×model 的均值（mean±SEM）
4) per-participant：12 个 round（reward / persona / model / layout）

新增：把 12 个 round 的 personaFidelity 也写入 CSV（persona_fidelity_round1..12）
"""

import json
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt


# ======== ICLR/NeurIPS 风格：统一样式 ========
def set_paper_style():
    """Apply the shared publication-style matplotlib configuration.

    Mutates the global ``plt.rcParams`` in place, so every figure created
    after this call picks up these settings.
    """
    paper_rc = {
        # Resolution and background.
        "figure.dpi": 160,
        "savefig.dpi": 300,
        "savefig.transparent": True,
        # Axes frame: drop the top/right spines.
        "axes.spines.top": False,
        "axes.spines.right": False,
        "axes.linewidth": 1.0,
        # Light dashed grid.
        "axes.grid": True,
        "grid.linestyle": "--",
        "grid.linewidth": 0.6,
        "grid.alpha": 0.35,
        # Font sizes.
        "axes.titlesize": 13,
        "axes.labelsize": 12,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "legend.fontsize": 10,
        "font.size": 11,
        "mathtext.default": "regular",
        # Layout is tightened explicitly by the callers via tight_layout().
        "figure.autolayout": False,
    }
    plt.rcParams.update(paper_rc)

# 色弱友好调色板（Paul Tol）
_TOL_COLORS = [
    "#4477AA",  # blue
    "#66CCEE",  # cyan
    "#228833",  # green
    "#CCBB44",  # yellow
    "#EE6677",  # red
    "#AA3377",  # purple
    "#BBBBBB",  # grey
]
PALETTE = _TOL_COLORS  # 替换原先的 tab10


# ======== Questionnaire: key mapping / exclusion / display order ========
# Raw JSON questionnaire key -> display label (adjust the keys to the JSON fields).
# Keys not listed here are dropped from the questionnaire statistics.
QUESTION_LABEL_MAP = {
    "q_trust": "I Trust Agent",
    "q_agenttrust": "Agent Trusts Me",
    "q_adaptivity": "Agent's Adaptability",
    "q_understandability": "I Understand Agent",
    "q_understandability2": "Agent Understands Me",
    "q_satisfaction": "Satisfaction",
    "q_willingness": "Cooperation Willingness",
    # "q_attention" is intentionally not mapped (never displayed).
}

# Raw questionnaire keys skipped before aggregation.
EXCLUDED_QUESTIONS = {"q_attention"}

# x-axis order of the questions, given as display labels.
QUESTION_DISPLAY_ORDER = [
    "I Trust Agent",
    "Agent Trusts Me",
    "Agent's Adaptability",
    "I Understand Agent",       # formerly "Understandability"
    "Agent Understands Me",     # formerly "Understandability2"
    "Satisfaction",
    "Cooperation Willingness",
]



def _finalize(ax, tight=True):
    """Apply the common tick/margin tweaks to *ax*.

    Args:
        ax: matplotlib Axes to adjust.
        tight: when true, also call ``tight_layout()`` on the owning figure.
    """
    ax.tick_params(length=3.5, width=0.9)
    ax.margins(x=0.02)
    if not tight:
        return
    ax.figure.tight_layout()

def _save_fig(fig, out_path: Path):
    """Save *fig* twice: a 300-dpi PNG and a vector PDF (for the paper).

    The suffix of *out_path* is replaced by ``.png`` / ``.pdf``, so callers
    may pass a path with or without an extension.
    """
    base = Path(out_path)
    png_target = base.with_suffix(".png")
    pdf_target = base.with_suffix(".pdf")
    fig.savefig(png_target, dpi=300, transparent=True)
    fig.savefig(pdf_target, transparent=True)


# Fixed pastel colours so the same model has the same colour in every figure.
MODEL_COLOR_MAP = {
    "model1": "#8ECFC9",
    "model2": "#FFBE7A",
    "model3": "#BEB8DC",
}


def color_for_model(model: str) -> str:
    """Return the fixed bar colour for *model*; neutral grey for unknown ids."""
    fallback_grey = "#BDBDBD"
    return MODEL_COLOR_MAP.get(model, fallback_grey)


EDGE_COLOR = "#3A3A3A"  # bar outline / error-bar colour
BAR_ALPHA = 0.88        # bar fill opacity


# ======== Paths ========
# Directory scanned for *.json result files (relative to the working directory).
JSON_FOLDER = Path("./")

# ======== Human-readable names (edit as needed) ========
NAME_MAP_MODEL = {
    "model1": "TrustPOMDP",
    "model2": "FCP",
    "model3": "MEP",
}
# Layout ids (layout1/2/3...) are shown verbatim unless mapped here.
NAME_MAP_LAYOUT = {}

# NOTE(review): this rebinds PALETTE, which was already assigned the Paul Tol
# colours earlier in the file, to matplotlib's tab10 colours. The plotting
# functions in this file colour bars via color_for_model()/MODEL_COLOR_MAP and
# never read PALETTE, so this assignment only matters to external importers.
PALETTE = plt.get_cmap("tab10").colors

MODEL_REGEX = re.compile(r"(model\d+)", flags=re.IGNORECASE)
LAYOUT_REGEX = re.compile(r"(layout\d+)", flags=re.IGNORECASE)


def _first_capture_lower(pattern: "re.Pattern", config_id: Optional[str]) -> Optional[str]:
    """Return the lower-cased first capture group of *pattern* in *config_id*.

    Returns None for empty/non-string input or when the pattern does not match.
    """
    if not config_id or not isinstance(config_id, str):
        return None
    found = pattern.search(config_id)
    if found is None:
        return None
    return found.group(1).lower()


def parse_model_id(config_id: Optional[str]) -> Optional[str]:
    """Extract the model id (e.g. ``"model2"``) from a config id, case-insensitively."""
    return _first_capture_lower(MODEL_REGEX, config_id)


def parse_layout_id(config_id: Optional[str]) -> Optional[str]:
    """Extract the layout id (e.g. ``"layout3"``) from a config id, case-insensitively."""
    return _first_capture_lower(LAYOUT_REGEX, config_id)

def safe_get_last_cumulative_reward(step_logs: List[Dict[str, Any]]) -> Optional[float]:
    """Return ``cumulative_reward`` of the final step log.

    Returns None when *step_logs* is empty/None or when the last entry has no
    ``cumulative_reward`` key.
    """
    return step_logs[-1].get("cumulative_reward") if step_logs else None

def mean_ignore_none(values: List[Optional[float]]) -> Optional[float]:
    """Return the mean of *values* as a float, skipping None entries.

    Returns None when no non-None value remains.
    """
    present = list(filter(lambda v: v is not None, values))
    return float(np.mean(present)) if present else None

def _series_sem(std: pd.Series, n: pd.Series) -> pd.Series:
    sem = std / np.sqrt(n.clip(lower=1))
    sem = sem.fillna(0.0)
    sem[n <= 1] = 0.0
    return sem

def discover_question_keys(questionnaires: Dict[str, Dict[str, Any]]) -> List[str]:
    """Collect every ``q_``-prefixed key found under task1/task2/task3, sorted.

    Missing or falsy task entries are treated as empty; non-string keys are ignored.
    """
    found = set()
    for task_name in ("task1", "task2", "task3"):
        answers = questionnaires.get(task_name) or {}
        found.update(
            key for key in answers.keys()
            if isinstance(key, str) and key.startswith("q_")
        )
    return sorted(found)

def pretty_label_from_key(k: str) -> str:
    """Turn a questionnaire key into a title-cased label.

    The ``q_`` prefix is dropped and runs of underscores become single spaces,
    e.g. ``q_understandability2`` -> ``Understandability2``.
    """
    stem = k[len("q_"):] if k.startswith("q_") else k
    spaced = re.sub(r"_+", " ", stem)
    return spaced.strip().title()

def extract_from_file(path: Path) -> Dict[str, Any]:
    """Parse one participant JSON file into a flat CSV row.

    The row contains, in column order:
      * ``prolific_id``, ``age``, ``gender``;
      * ``configid_task{1,2,3}`` from ``assignment.task{t}.configId``;
      * ``avg_reward_task{1,2,3}``: mean final cumulative reward over that
        task's 4 rounds (None rounds ignored);
      * ``task{t}_{q_key}`` for every ``q_*`` questionnaire key;
      * ``cumulative_reward_round{i}``, ``persona_round{i}``,
        ``persona_fidelity_round{i}`` for the 12 rounds.

    Raises whatever ``open``/``json.load`` raise for unreadable or invalid files.
    """
    with path.open("r", encoding="utf-8") as f:
        data = json.load(f)

    assignment = data.get("assignment", {}) or {}
    configs = [(assignment.get(f"task{t}") or {}).get("configId") for t in (1, 2, 3)]

    # (task, round) -> flat slot 0..11: task t, round r -> 4*(t-1) + (r-1).
    idx_map = {(t, r): 4 * (t - 1) + (r - 1) for t in (1, 2, 3) for r in (1, 2, 3, 4)}

    per_round_reward: List[Any] = [None] * 12
    per_round_persona: List[Any] = [None] * 12
    per_round_fid: List[Any] = [None] * 12

    # Each round is written to a fixed slot, so no sorting is needed; a later
    # duplicate (task, round) entry overwrites an earlier one. Sorting by
    # (task, round) used to raise TypeError when a value was null.
    for r in data.get("rounds", []) or []:
        if not isinstance(r, dict):
            continue
        i = idx_map.get((r.get("task"), r.get("round")))
        if i is None:
            continue
        per_round_reward[i] = safe_get_last_cumulative_reward(r.get("stepLogs", []))
        per_round_persona[i] = r.get("persona")
        fid = r.get("personaFidelity")
        if fid is not None:
            try:
                per_round_fid[i] = float(fid)
            except (TypeError, ValueError, OverflowError):
                per_round_fid[i] = fid  # non-numeric fidelity is kept verbatim

    row: Dict[str, Any] = {
        "prolific_id": data.get("prolificId"),
        "age": data.get("age"),
        "gender": data.get("gender"),
    }
    for t, cfg in enumerate(configs, start=1):
        row[f"configid_task{t}"] = cfg
    for t in (1, 2, 3):
        row[f"avg_reward_task{t}"] = mean_ignore_none(per_round_reward[4 * (t - 1):4 * t])

    # Per-question answers for each task, keys discovered from the file itself.
    questionnaires = data.get("questionnaires", {}) or {}
    for qk in discover_question_keys(questionnaires):
        for t in (1, 2, 3):
            row[f"task{t}_{qk}"] = (questionnaires.get(f"task{t}") or {}).get(qk)

    for i, v in enumerate(per_round_reward, start=1):
        row[f"cumulative_reward_round{i}"] = v
    for i, v in enumerate(per_round_persona, start=1):
        row[f"persona_round{i}"] = v
    for i, v in enumerate(per_round_fid, start=1):
        row[f"persona_fidelity_round{i}"] = v

    return row

def long_samples_from_df(df: pd.DataFrame) -> pd.DataFrame:
    """Reshape per-participant task means into a long table.

    Emits one row per (participant, task) whose ``avg_reward_task{t}`` is not
    missing and whose ``configid_task{t}`` is a string. Columns: prolific_id,
    task, model and layout (parsed from the config id), value.
    """
    records = []
    for _, participant in df.iterrows():
        for task in (1, 2, 3):
            config = participant.get(f"configid_task{task}")
            reward = participant.get(f"avg_reward_task{task}")
            if not (pd.notna(reward) and isinstance(config, str)):
                continue
            records.append({
                "prolific_id": participant.get("prolific_id"),
                "task": task,
                "model": parse_model_id(config),
                "layout": parse_layout_id(config),
                "value": float(reward),
            })
    return pd.DataFrame(records)

def order_models(models: List[str]) -> List[str]:
    """Sort model ids by numeric suffix (model1 < model2 < model10).

    Ids without a ``model<N>`` part sort last, alphabetically.
    """
    def _rank(name: str) -> Tuple[int, str]:
        text = name if name else ""
        digits = re.findall(r"model(\d+)", text)
        number = int(digits[0]) if digits else 1_000_000
        return number, text

    return sorted(models, key=_rank)

def order_layouts(layouts: List[str]) -> List[str]:
    """Sort layout ids by numeric suffix (layout1 < layout2 < layout10).

    Names without a ``layout<N>`` part sort last, alphabetically.
    """
    def _rank(name: str) -> Tuple[int, str]:
        text = name if name else ""
        digits = re.findall(r"layout(\d+)", text)
        number = int(digits[0]) if digits else 1_000_000
        return number, text

    return sorted(layouts, key=_rank)

# ======== 统计：Reward（1）按 model 跨 layout 总体 ========
def build_reward_by_model_overall(samples: pd.DataFrame) -> pd.DataFrame:
    """Mean ± SEM of per-task reward for each model, pooled over layouts.

    Returns columns ['model', 'mean', 'sem', 'n'] ordered by model number.
    """
    if samples.empty:
        return pd.DataFrame(columns=["model","mean","sem","n"])

    grouped = samples.groupby("model")["value"]
    means = grouped.mean()
    counts = grouped.count()
    errors = _series_sem(grouped.std(ddof=1), counts)

    table = pd.DataFrame({
        "model": means.index,
        "mean": means.values,
        "sem": errors.values,
        "n": counts.values,
    })
    table["model"] = table["model"].astype(str)
    ordered = order_models(list(table["model"]))
    return table.set_index("model").loc[ordered].reset_index()

def build_reward_by_model_by_layout(samples: pd.DataFrame) -> pd.DataFrame:
    """Mean ± SEM of per-task reward for every (layout, model) pair.

    Returns columns ['layout', 'model', 'mean', 'sem', 'n'], sorted by layout
    number and then model number. Known layout ids are replaced by descriptive
    names *after* sorting: order_layouts() ranks by the numeric suffix of
    "layoutN", which the descriptive names no longer carry (sorting them would
    fall back to alphabetical order).
    """
    if samples.empty:
        return pd.DataFrame(columns=["layout","model","mean","sem","n"])

    grp = samples.groupby(["layout","model"])["value"]
    mean, n = grp.mean(), grp.count()
    sem = _series_sem(grp.std(ddof=1), n)
    out = pd.DataFrame({"mean": mean, "sem": sem, "n": n}).reset_index()
    out["layout"] = out["layout"].astype(str)
    out["model"] = out["model"].astype(str)

    # Sort while layouts are still raw "layoutN" ids.
    layout_rank = {l: i for i, l in enumerate(order_layouts(out["layout"].unique().tolist()))}
    model_rank = {m: i for i, m in enumerate(order_models(out["model"].unique().tolist()))}
    out = (
        out.assign(_l_order=out["layout"].map(layout_rank), _m_order=out["model"].map(model_rank))
        .sort_values(["_l_order", "_m_order"])
        .drop(columns=["_l_order", "_m_order"])
        .reset_index(drop=True)
    )

    # Descriptive layout names; unknown ids are kept as-is.
    layout_name_map = {
        "layout1": "Divided Room",
        "layout2": "Resource Asymmetry",
        "layout3": "Divided Room-easy",
        "layout4": "Resource Asymmetry-easy",
    }
    out["layout"] = out["layout"].map(layout_name_map).fillna(out["layout"])
    return out


# ======== 统计：Question（3）按 question×model ========
def _as_rating(val: Any) -> Optional[float]:
    """Convert one raw questionnaire answer to float.

    Returns None for missing values (None/NaN, including the string "nan"),
    non-numeric strings and non-scalar values such as lists, so these are
    skipped instead of crashing the aggregation.
    """
    try:
        num = float(val)
    except (TypeError, ValueError, OverflowError):
        return None
    return None if np.isnan(num) else num


def build_questions_by_model(df: pd.DataFrame) -> Tuple[pd.DataFrame, List[str], Dict[str,str]]:
    """Aggregate questionnaire ratings into mean ± SEM per (question, model).

    Reads every ``task{t}_q_*`` column of *df*; the model for task t is parsed
    from ``configid_task{t}``. Keys in EXCLUDED_QUESTIONS are skipped.

    Returns:
        stats_df: columns ['question', 'model', 'mean', 'sem', 'n'], where
            ``question`` is the raw JSON key (e.g. ``q_trust``). Only keys present
            in QUESTION_LABEL_MAP are kept, sorted by QUESTION_DISPLAY_ORDER and
            then by model number.
        question_keys_in_order: raw keys in display order.
        label_map: raw key -> display label.
    """
    recs = []
    for _, r in df.iterrows():
        for t in (1, 2, 3):
            cfg = r.get(f"configid_task{t}")
            model = parse_model_id(cfg) if isinstance(cfg, str) else None
            if not model:
                continue
            prefix = f"task{t}_"
            for col in r.index:
                if not (isinstance(col, str) and col.startswith(f"{prefix}q_")):
                    continue
                qkey = col[len(prefix):]  # back to the raw q_* key
                if qkey in EXCLUDED_QUESTIONS:
                    continue
                value = _as_rating(r[col])
                if value is not None:
                    recs.append({"model": str(model), "question": qkey, "value": value})

    if not recs:
        return pd.DataFrame(columns=["question","model","mean","sem","n"]), [], {}

    qdf = pd.DataFrame(recs)
    grp = qdf.groupby(["question","model"])["value"]
    mean, n = grp.mean(), grp.count()
    sem = _series_sem(grp.std(ddof=1), n)
    stats = pd.DataFrame({"mean": mean, "sem": sem, "n": n}).reset_index()

    # Keep only mapped questions and attach their display labels.
    stats = stats[stats["question"].isin(QUESTION_LABEL_MAP.keys())].copy()
    stats["label"] = stats["question"].map(QUESTION_LABEL_MAP)

    # Sort by display order, then by model number.
    order_map = {name: i for i, name in enumerate(QUESTION_DISPLAY_ORDER)}
    stats["q_order"] = stats["label"].map(order_map)
    stats["m_order"] = stats["model"].map(
        {m: i for i, m in enumerate(order_models(stats["model"].astype(str).unique().tolist()))}
    )
    stats = stats.sort_values(["q_order","m_order"]).drop(columns=["q_order","m_order"])

    # Raw keys in display order (first key wins if two keys share a label).
    qkeys_in_order = []
    for disp in QUESTION_DISPLAY_ORDER:
        keys_this = stats.loc[stats["label"] == disp, "question"].unique().tolist()
        if keys_this:
            qkeys_in_order.append(keys_this[0])

    label_map = {k: QUESTION_LABEL_MAP[k] for k in qkeys_in_order}
    return stats.drop(columns=["label"]), qkeys_in_order, label_map


# ======== 绘图工具 ========
def _label_models(ms: List[str]) -> List[str]:
    """Map model ids to display names via NAME_MAP_MODEL; unknown ids pass through."""
    return list(map(lambda m: NAME_MAP_MODEL.get(m, m), ms))


def _label_layout(l: str) -> str:
    """Map a layout id to its display name via NAME_MAP_LAYOUT; default is unchanged."""
    display = NAME_MAP_LAYOUT.get(l, l)
    return display

def plot_reward_by_model_overall(stats_df: pd.DataFrame, out: Path):
    """Bar chart of overall reward per model (mean ± SEM across all layouts).

    Expects the model/mean/sem columns of build_reward_by_model_overall and
    draws bars in row order. Saves PNG + PDF derived from *out* and prints
    the paths; prints a notice and returns if *stats_df* is empty.
    """
    if stats_df.empty:
        print("[跳过绘图] reward_by_model_overall 数据为空")
        return

    set_paper_style()
    model_ids = stats_df["model"].astype(str).tolist()
    tick_labels = _label_models(model_ids)
    positions = np.arange(len(stats_df))
    heights = stats_df["mean"].values
    errors = stats_df["sem"].values

    fig_width = max(5.5, 0.9 * len(tick_labels) + 4)
    fig, ax = plt.subplots(figsize=(fig_width, 4.6))

    ax.bar(
        positions, heights, 0.62,
        color=[color_for_model(mid) for mid in model_ids],
        edgecolor=EDGE_COLOR, linewidth=1.0, alpha=BAR_ALPHA,
    )
    ax.errorbar(
        positions, heights, yerr=errors, fmt="none",
        ecolor=EDGE_COLOR, elinewidth=1.2, capsize=4, capthick=1.0, zorder=3,
    )

    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels)
    ax.set_ylabel("Avg reward (per task mean)")
    ax.set_title("Reward by Model (across all layouts)")
    _finalize(ax)
    _save_fig(fig, out)
    plt.close(fig)
    print(f"图已保存：{out.with_suffix('.png')} / {out.with_suffix('.pdf')}")


def plot_reward_by_model_by_layout(stats_df: pd.DataFrame, out: Path):
    """Grouped bar chart: one group per layout, one bar per model (mean ± SEM).

    Layout groups keep the row order of *stats_df*, i.e. the order produced by
    build_reward_by_model_by_layout. Re-sorting here with order_layouts() would
    fall back to alphabetical order as soon as layouts carry descriptive names
    instead of "layoutN" ids. Models are ordered with order_models(); a missing
    (layout, model) cell is drawn as a NaN bar with zero error.
    Saves PNG + PDF derived from *out* and prints the paths; prints a notice and
    returns if *stats_df* is empty.
    """
    if stats_df.empty:
        print("[跳过绘图] reward_by_model_by_layout 数据为空")
        return

    set_paper_style()
    # dict.fromkeys de-duplicates while keeping first-appearance order.
    layouts = list(dict.fromkeys(stats_df["layout"].astype(str)))
    models = order_models(stats_df["model"].astype(str).unique().tolist())

    x = np.arange(len(layouts))
    width = 0.22
    fig, ax = plt.subplots(figsize=(max(7.2, 0.8*len(layouts)+4), 4.8))

    for i, m in enumerate(models):
        sub = stats_df[stats_df["model"] == m]
        means, sems = [], []
        for l in layouts:
            row = sub[sub["layout"] == l]
            means.append(float(row["mean"].iloc[0]) if len(row) else np.nan)
            sems.append(float(row["sem"].iloc[0]) if len(row) else 0.0)

        # Centre the model bars around each layout tick.
        pos = x + (i - (len(models)-1)/2) * width
        color = color_for_model(m)

        ax.bar(
            pos, means, width,
            label=NAME_MAP_MODEL.get(m, m),
            facecolor=color, edgecolor=EDGE_COLOR, linewidth=1.0, alpha=BAR_ALPHA
        )
        ax.errorbar(
            pos, means, yerr=sems, fmt="none",
            ecolor=EDGE_COLOR, elinewidth=1.0, capsize=4, capthick=1.0, zorder=3
        )

    ax.set_xticks(x)
    ax.set_xticklabels([_label_layout(l) for l in layouts])
    ax.set_ylabel("Avg reward (per task mean)")
    ax.set_xlabel("Layout")
    ax.set_title("Reward by Model within each Layout")
    ax.legend(frameon=False, ncol=min(3, len(models)), handlelength=1.2, columnspacing=0.9)
    _finalize(ax)
    _save_fig(fig, out)
    plt.close(fig)
    print(f"图已保存：{out.with_suffix('.png')} / {out.with_suffix('.pdf')}")


# ======== 每参与者 12 个 round（含 persona / model / layout）========
def _participant_file_tag(pid: Any, row_idx: Any) -> str:
    """Filesystem-safe identifier for a participant's figure file.

    Uses the prolific id with every character outside ``[\\w.-]`` replaced by
    "_"; falls back to ``row<row_idx>`` when the id is None/NaN/blank, so that
    participants without an id do not overwrite one another's figure.
    """
    missing = pid is None or (isinstance(pid, float) and np.isnan(pid))
    text = "" if missing else str(pid).strip()
    if not text:
        return f"row{row_idx}"
    return re.sub(r"[^\w.-]", "_", text)


def plot_round_persona_rewards_per_participant(df: pd.DataFrame, outdir: Path):
    """
    Draw one 12-round bar chart per participant (reward / persona / model / layout).

    - x axis: rounds R1..R12; rounds 1-4, 5-8, 9-12 belong to task1/2/3.
    - y axis: ``cumulative_reward_round{i}``.
    - bar colour: model parsed from that task's config id (MODEL_COLOR_MAP,
      "gray" when unknown).
    - bar hatch: layout parsed from that task's config id.
    - text above each bar: ``persona_round{i}`` when it is a non-empty string.

    Each figure is saved as ``outdir/persona_model_layout_reward_<tag>.png``,
    where <tag> is a filesystem-safe form of the prolific id (or the row index
    when the id is missing).
    """
    from matplotlib.patches import Patch

    outdir.mkdir(parents=True, exist_ok=True)

    model_colors = MODEL_COLOR_MAP
    # Hatch pattern per layout id (extend as needed).
    hatch_map = {
        "layout1": "//",
        "layout2": "\\\\",
        "layout3": "xx",
        "layout4": "..",
        "unknown": "",
    }
    rounds = list(range(1, 12+1))

    for row_idx, row in df.iterrows():
        pid = row.get("prolific_id")
        # (model, layout) for task1..3; round r belongs to task index (r - 1) // 4.
        task_model_layout = []
        for t in (1, 2, 3):
            cfg = row.get(f"configid_task{t}")
            task_model_layout.append((parse_model_id(cfg), parse_layout_id(cfg)))

        rewards, personas, models, layouts = [], [], [], []
        for r in rounds:
            reward = row.get(f"cumulative_reward_round{r}")
            persona = row.get(f"persona_round{r}")
            m, l = task_model_layout[(r - 1) // 4]
            rewards.append(np.nan if reward is None else reward)
            personas.append(persona if persona is not None else "")
            models.append(m or "unknown")
            layouts.append(l or "unknown")

        fig, ax = plt.subplots(figsize=(11, 6))
        max_y = np.nanmax(rewards) if np.any(~np.isnan(rewards)) else 1.0
        width = 0.6

        for i, r in enumerate(rounds):
            val = rewards[i]
            color = model_colors.get(models[i], "gray")
            ax.bar(r, val, width, color=color, alpha=0.7,
                   edgecolor=color, linewidth=1, hatch=hatch_map.get(layouts[i], ""))
            # Persona label slightly above the bar top.
            if isinstance(personas[i], str) and len(personas[i].strip()) > 0 and pd.notna(val):
                ax.text(r, val + 0.05 * max_y, personas[i],
                        ha="center", va="bottom", fontsize=9, rotation=45)

        ax.set_xticks(rounds)
        ax.set_xticklabels([f"R{r}" for r in rounds])
        ax.set_ylabel("Cumulative Reward")
        ax.set_title(f"Participant {pid}: Round-wise Reward / Persona / Model / Layout")

        # Legend: one colour patch per model used in the three tasks ...
        legend_model = [Patch(facecolor=model_colors.get(m, "gray"),
                              edgecolor=model_colors.get(m, "gray"),
                              alpha=0.7,
                              label=NAME_MAP_MODEL.get(m, m))
                        for m in order_models(list({models[i] for i in (0, 4, 8)}))]
        # ... plus one hatch patch per known layout that appeared.
        layouts_seen = sorted(set(layouts), key=lambda l: (l=="unknown", l))
        legend_layout = [Patch(facecolor="white", edgecolor="black",
                               hatch=hatch_map.get(l, ""), label=_label_layout(l))
                         for l in layouts_seen if l != "unknown"]

        ax.legend(handles=legend_model + legend_layout, title="Legend", loc="best")
        ax.grid(axis="y", linestyle="--", alpha=0.35)

        fig.tight_layout()
        outfile = outdir / f"persona_model_layout_reward_{_participant_file_tag(pid, row_idx)}.png"
        fig.savefig(outfile, dpi=200)
        plt.close(fig)
        print(f"图已保存：{outfile}")


def plot_questions_by_model_per_question(stats_df: pd.DataFrame, qkeys: List[str], label_map: Dict[str,str], out: Path):
    """Grouped bar chart of questionnaire ratings: one group per question, one bar per model.

    *qkeys* fixes the question order on the x axis and *label_map* supplies each
    key's tick label. A missing (question, model) cell is drawn as a NaN bar with
    zero error. Saves PNG + PDF derived from *out* and prints the paths; prints a
    notice and returns if *stats_df* is empty.
    """
    if stats_df.empty:
        print("[跳过绘图] question_by_model 数据为空")
        return

    set_paper_style()
    model_ids = order_models(stats_df["model"].astype(str).unique().tolist())
    centers = np.arange(len(qkeys))
    bar_w = 0.22
    fig, ax = plt.subplots(figsize=(max(8.5, 0.55*len(qkeys)+4), 4.8))

    for rank, model_id in enumerate(model_ids):
        model_rows = stats_df[stats_df["model"] == model_id]
        heights, errors = [], []
        for qkey in qkeys:
            hit = model_rows[model_rows["question"] == qkey]
            if len(hit):
                heights.append(float(hit["mean"].iloc[0]))
                errors.append(float(hit["sem"].iloc[0]))
            else:
                heights.append(np.nan)
                errors.append(0.0)

        # Centre the model bars around each question tick.
        offsets = centers + (rank - (len(model_ids) - 1) / 2) * bar_w
        ax.bar(
            offsets, heights, bar_w,
            label=NAME_MAP_MODEL.get(model_id, model_id),
            facecolor=color_for_model(model_id), edgecolor=EDGE_COLOR,
            linewidth=1.0, alpha=BAR_ALPHA,
        )
        ax.errorbar(
            offsets, heights, yerr=errors, fmt="none",
            ecolor=EDGE_COLOR, elinewidth=1.0, capsize=4, capthick=1.0,
        )

    ax.set_xticks(centers)
    ax.set_xticklabels([label_map[q] for q in qkeys], rotation=15, ha="right")
    ax.set_ylabel("Rating (1–7)")
    ax.set_xlabel("Question")
    ax.set_title("Questionnaire Scores by Model")
    ax.legend(frameon=False, ncol=min(3, len(model_ids)))
    _finalize(ax)
    _save_fig(fig, out)
    plt.close(fig)
    print(f"图已保存：{out.with_suffix('.png')} / {out.with_suffix('.pdf')}")




def plot_overall_and_questions_side_by_side(
    reward_stats_df: pd.DataFrame,
    q_stats_df: pd.DataFrame,
    qkeys: List[str],
    label_map: Dict[str, str],
    out: Path
):
    """
    Two-panel figure (width ratio 1:3):

    - (a) left: overall reward per model (mean ± SEM), bars in the row order of
      *reward_stats_df*.
    - (b) right: questionnaire rating per question and model; questions follow
      *qkeys* and are labelled via *label_map*.

    Prints a notice and returns when either table is empty or *qkeys* is empty.
    Saves PNG + PDF derived from *out* and prints the paths.
    """
    if reward_stats_df.empty or q_stats_df.empty or not qkeys:
        print("[跳过绘图] 合成图数据不足")
        return

    set_paper_style()
    fig, (ax_reward, ax_q) = plt.subplots(
        1, 2,
        figsize=(20, 6),
        sharey=False,
        gridspec_kw={"width_ratios": [1, 3]},
    )

    # ---- (a) overall reward ----
    reward_models = reward_stats_df["model"].astype(str).tolist()
    reward_x = np.arange(len(reward_stats_df))
    reward_means = reward_stats_df["mean"].values
    ax_reward.bar(
        reward_x, reward_means, 0.62,
        color=[color_for_model(mid) for mid in reward_models],
        edgecolor=EDGE_COLOR, linewidth=1.0, alpha=BAR_ALPHA,
    )
    ax_reward.errorbar(
        reward_x, reward_means, yerr=reward_stats_df["sem"].values, fmt="none",
        ecolor=EDGE_COLOR, elinewidth=1.2, capsize=4, capthick=1.0, zorder=3,
    )
    ax_reward.set_xticks(reward_x)
    ax_reward.set_xticklabels(_label_models(reward_models), fontsize=18)
    ax_reward.set_ylabel("Avg reward", fontsize=24)
    ax_reward.set_title("Overall Team Reward", fontsize=24)
    ax_reward.tick_params(axis='y', labelsize=20)

    # ---- (b) questionnaire per question ----
    q_models = order_models(q_stats_df["model"].astype(str).unique().tolist())
    q_x = np.arange(len(qkeys))
    bar_w = 0.22
    for rank, model_id in enumerate(q_models):
        model_rows = q_stats_df[q_stats_df["model"] == model_id]
        heights, errors = [], []
        for qkey in qkeys:
            hit = model_rows[model_rows["question"] == qkey]
            found = len(hit) > 0
            heights.append(float(hit["mean"].iloc[0]) if found else np.nan)
            errors.append(float(hit["sem"].iloc[0]) if found else 0.0)
        offsets = q_x + (rank - (len(q_models) - 1) / 2) * bar_w
        ax_q.bar(
            offsets, heights, bar_w, label=NAME_MAP_MODEL.get(model_id, model_id),
            facecolor=color_for_model(model_id), edgecolor=EDGE_COLOR,
            linewidth=1.0, alpha=BAR_ALPHA,
        )
        ax_q.errorbar(
            offsets, heights, yerr=errors, fmt="none",
            ecolor=EDGE_COLOR, elinewidth=1.0, capsize=4, capthick=1.0,
        )

    ax_q.set_xticks(q_x)
    ax_q.set_xticklabels([label_map[q] for q in qkeys], rotation=15, ha="right", fontsize=20)
    ax_q.set_ylabel("Rating (1–7)", fontsize=24)
    ax_q.set_title("Humans' Subjective Perception", fontsize=24)
    ax_q.legend(frameon=False, ncol=min(3, len(q_models)), fontsize=20)
    ax_q.tick_params(axis='y', labelsize=20)

    # Panel tags just above the top-left corner of each axes.
    for panel_ax, tag in ((ax_reward, "(a)"), (ax_q, "(b)")):
        panel_ax.text(0.0, 1.02, tag, transform=panel_ax.transAxes, fontsize=20,
                      fontweight="bold", ha="left", va="bottom")

    fig.tight_layout()
    _save_fig(fig, out)
    plt.close(fig)
    print(f"图已保存：{out.with_suffix('.png')} / {out.with_suffix('.pdf')}")






# ======== 主流程 ========
def _print_demographics(df: pd.DataFrame) -> None:
    """Print gender counts, age mean/sample SD and an age histogram.

    Gender strings are stripped, lower-cased and mapped onto a small vocabulary;
    anything unmapped (including missing values) becomes "unknown". Ages are
    coerced to numbers and non-numeric values are ignored. Missing columns are
    treated as empty.
    """
    gender_map = {
        "male": "male", "m": "male", "man": "male",
        "female": "female", "f": "female", "woman": "female",
        "non-binary": "non-binary", "nonbinary": "non-binary", "nb": "non-binary",
        "other": "other", "prefer not to say": "prefer not to say", "na": "unknown", "none": "unknown", "": "unknown"
    }
    genders = df["gender"] if "gender" in df.columns else pd.Series(dtype=object)
    gender_norm = (
        genders
        .astype(str)
        .str.strip()
        .str.lower()
        .map(gender_map)
        .fillna("unknown")
    )
    gender_counts = gender_norm.value_counts(dropna=False).rename_axis("gender").reset_index(name="count")

    ages = df["age"] if "age" in df.columns else pd.Series(dtype=object)
    age_num = pd.to_numeric(ages, errors="coerce")
    age_mean = float(age_num.mean(skipna=True)) if not age_num.dropna().empty else float("nan")
    age_sd = float(age_num.std(ddof=1))  # sample SD; NaN with fewer than 2 ages

    print("\n=== Gender counts ===")
    print(gender_counts.to_string(index=False))
    print("\n=== Age stats ===")
    print(f"mean = {age_mean:.3f}, SD (sample) = {age_sd:.3f}")

    age_counts = (
        age_num.dropna()
            .astype(int)                 # truncates fractional ages; drop this line to keep them
            .value_counts(dropna=False)
            .sort_index()
            .rename_axis("age")
            .reset_index(name="count")
    )
    print("\n=== Age counts ===")
    print(age_counts.to_string(index=False))


def main():
    """Run the whole pipeline on every *.json in JSON_FOLDER.

    Writes results.csv, prints reward/questionnaire/demographic summaries and
    saves all figures into the same folder. Files that fail to parse are
    reported and skipped; if none parse, an empty results.csv is still written
    and the statistics/plots are skipped.
    """
    folder = JSON_FOLDER.expanduser().resolve()
    files = sorted(folder.glob("*.json"))
    if not files:
        print(f"未找到 JSON：{folder}")
        return

    rows = []
    for fp in files:
        try:
            rows.append(extract_from_file(fp))
        except Exception as e:  # top-level boundary: report the bad file, keep going
            print(f"[解析失败] {fp.name}: {e}")
    df = pd.DataFrame(rows)

    # Export the flat table for manual checking.
    out_csv = folder / "results.csv"
    df.to_csv(out_csv, index=False, encoding="utf-8-sig")
    print(f"已生成：{out_csv}")

    # Without any parsed row the frame has no columns and every step below would fail.
    if df.empty:
        print("[跳过] 没有成功解析的 JSON，跳过统计与绘图")
        return

    # ===== Reward statistics =====
    samples = long_samples_from_df(df)  # one row per participant x task, with model/layout
    reward_overall = build_reward_by_model_overall(samples)
    reward_by_ml   = build_reward_by_model_by_layout(samples)

    if not reward_overall.empty:
        print("\n==== Reward by Model (overall across layouts) ====")
        print(reward_overall.to_string(index=False))
    if not reward_by_ml.empty:
        print("\n==== Reward by Model within each Layout ====")
        piv = reward_by_ml.pivot_table(index="layout", columns="model", values="mean")
        print(piv.to_string())

    plot_reward_by_model_overall(reward_overall, folder / "reward_by_model_overall.png")
    plot_reward_by_model_by_layout(reward_by_ml, folder / "reward_by_model_by_layout.png")

    # ===== Questionnaire statistics =====
    q_stats, qkeys, label_map = build_questions_by_model(df)
    if not q_stats.empty:
        print("\n==== Questionnaire mean by Question × Model ====")
        pivq = q_stats.pivot_table(index="question", columns="model", values="mean")
        print(pivq.to_string())

    # Plot with the filtered/ordered keys and display labels.
    plot_questions_by_model_per_question(q_stats, qkeys, label_map, folder / "question_by_model_per_question.png")

    # Combined figure: overall reward (left) + questionnaire per question (right).
    plot_overall_and_questions_side_by_side(
        reward_overall, q_stats, qkeys, label_map, folder / "reward_and_question_combined"
    )

    # ===== Per-participant 12-round figures =====
    plot_round_persona_rewards_per_participant(df, folder / "persona_round_plots")

    # ===== Demographics: gender counts, age mean/SD, age histogram =====
    _print_demographics(df)



# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
