#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Iterable
import numpy as np
import pandas as pd
from scipy.stats import wilcoxon, norm

# -------------------- 配置 --------------------
JSON_FOLDER = Path("./")  # directory holding the per-participant JSON files

# Human-readable model names, used to make the printed/exported tables friendlier.
NAME_MAP_MODEL = dict(
    model1="TrustPOMDP",
    model2="FCP",
    model3="MEP",
)

# -------------------- 基础工具 --------------------
# Case-insensitive patterns capturing a "modelN" / "layoutN" token inside a configId.
MODEL_REGEX = re.compile(r"(model\d+)", re.IGNORECASE)
LAYOUT_REGEX = re.compile(r"(layout\d+)", re.IGNORECASE)

def parse_model_id(config_id: Optional[str]) -> Optional[str]:
    """Extract the ``modelN`` token from a configId.

    Returns the first case-insensitive ``MODEL_REGEX`` match, lower-cased
    (e.g. ``"Model2_layout1"`` -> ``"model2"``). Returns ``None`` for empty or
    non-string input, or when the string contains no such token.
    """
    if not (config_id and isinstance(config_id, str)):
        return None
    found = MODEL_REGEX.search(config_id)
    return found.group(1).lower() if found is not None else None

def parse_layout_id(config_id: Optional[str]) -> Optional[str]:
    """Extract the ``layoutN`` token from a configId.

    Returns the first case-insensitive ``LAYOUT_REGEX`` match, lower-cased
    (e.g. ``"model2_Layout1"`` -> ``"layout1"``). Returns ``None`` for empty or
    non-string input, or when the string contains no such token.
    """
    if not (config_id and isinstance(config_id, str)):
        return None
    found = LAYOUT_REGEX.search(config_id)
    return found.group(1).lower() if found is not None else None

def order_models(models: List[str]) -> List[str]:
    """Sort model ids by their numeric suffix (model1 < model2 < model10).

    Ids without a ``model<digits>`` substring sort after all numbered ones,
    alphabetically among themselves; a falsy id (e.g. ``None``) is keyed as "".
    """
    def _sort_key(name: str) -> Tuple[int, str]:
        text = name if name else ""
        hit = re.search(r"model(\d+)", text)
        rank = int(hit.group(1)) if hit else 1_000_000
        return rank, text

    return sorted(models, key=_sort_key)

def _bh_fdr(pvals: List[float]) -> List[float]:
    """Benjamini-Hochberg FDR correction.

    Args:
        pvals: raw p-values in any order. NaN (or None) entries mark untestable
            comparisons: they are excluded from the family size ``m`` and come
            back as NaN instead of contaminating every other q-value.

    Returns:
        BH-adjusted p-values (q-values, capped at 1.0), in the same order as
        *pvals*. An empty input yields an empty list.
    """
    if not pvals:
        return []
    p = np.asarray(pvals, dtype=float)
    out = np.full(p.shape, np.nan)
    valid = np.flatnonzero(~np.isnan(p))
    m = valid.size
    if m == 0:
        return out.tolist()
    # Indices of the valid p-values, ascending by p.
    order = valid[np.argsort(p[valid])]
    ranked = p[order]
    # q_(i) = min_{j >= i} p_(j) * m / j: the step-up running minimum taken from
    # the largest p downwards keeps q monotone in p.
    scaled = ranked * m / np.arange(1, m + 1)
    q = np.minimum.accumulate(scaled[::-1])[::-1]
    out[order] = np.minimum(q, 1.0)
    return out.tolist()

def _wilcoxon_effect_size(x: Iterable[float], y: Iterable[float], p: float) -> float:
    """Approximate Wilcoxon effect size ``r = z / sqrt(N)``.

    |z| is recovered from the two-sided p-value via the standard-normal inverse
    survival function. Its sign follows the sign of ``mean(x - y)`` (NaN
    ignored). N counts the pairs where neither value is NaN.

    Returns NaN when N is 0, or when *p* is ``None`` or not positive.
    """
    xs = np.asarray(x)
    ys = np.asarray(y)
    n_pairs = np.sum(~(np.isnan(xs) | np.isnan(ys)))
    if n_pairs <= 0 or p is None or p <= 0:
        return np.nan
    z_abs = norm.isf(p / 2.0)  # two-sided p -> |z|
    sign = -1.0 if np.nanmean(xs - ys) < 0 else 1.0
    return float(sign * z_abs / np.sqrt(n_pairs))

# -------------------- 数据抽取（最小集） --------------------
def safe_get_last_cumulative_reward(step_logs: List[Dict[str, Any]]) -> Optional[float]:
    """Return ``cumulative_reward`` from the final step log, or ``None``.

    ``None`` is returned when *step_logs* is empty/``None``, or when the final
    entry has no ``cumulative_reward`` key.
    """
    if step_logs:
        final_step = step_logs[-1]
        return final_step.get("cumulative_reward")
    return None

def mean_ignore_none(values: List[Optional[float]]) -> Optional[float]:
    """Mean of the non-``None`` entries as a Python float; ``None`` if there are none."""
    present = list(filter(lambda v: v is not None, values))
    return float(np.mean(present)) if present else None

def discover_question_keys(questionnaires: Dict[str, Dict[str, Any]]) -> List[str]:
    """Collect every questionnaire key starting with ``q_`` across task1/task2/task3.

    Missing or empty task sections are skipped; the result is de-duplicated
    and sorted.
    """
    found = {
        key
        for section in ("task1", "task2", "task3")
        for key in (questionnaires.get(section) or {}).keys()
        if isinstance(key, str) and key.startswith("q_")
    }
    return sorted(found)

def extract_from_file(path: Path) -> Dict[str, Any]:
    """Extract the minimal per-participant record from one JSON result file.

    The record holds:
      - the participant id (``prolificId``);
      - the ``configId`` assigned to each of the three tasks;
      - the mean final cumulative reward over each task's four rounds;
      - every ``q_*`` questionnaire answer per task.

    Args:
        path: JSON file to read (UTF-8).

    Returns:
        Flat dict with keys ``prolific_id``, ``configid_task{1,2,3}``,
        ``avg_reward_task{1,2,3}`` and, for every questionnaire key found in
        any task, ``task{1,2,3}_<q_key>``.

    Raises:
        OSError / json.JSONDecodeError: if the file cannot be read or parsed.
    """
    with path.open("r", encoding="utf-8") as f:
        data = json.load(f)

    prolific_id = data.get("prolificId")

    assignment = data.get("assignment", {}) or {}
    configs = {t: (assignment.get(f"task{t}") or {}).get("configId") for t in (1, 2, 3)}

    # Slot of each (task, round) in a flat 12-entry list: task t, round k -> 4*(t-1) + (k-1).
    idx_map = {(t, k): 4 * (t - 1) + (k - 1) for t in (1, 2, 3) for k in (1, 2, 3, 4)}

    per_round_reward: List[Optional[float]] = [None] * 12
    # Records are placed by their (task, round) key, so they are deliberately
    # not sorted: sorting would compare a null task/round against an int and
    # raise TypeError. If a (task, round) pair repeats, the last record in file
    # order wins.
    for rec in data.get("rounds", []) or []:
        if not isinstance(rec, dict):
            continue  # malformed entry; nothing to read from it
        slot = idx_map.get((rec.get("task"), rec.get("round")))
        if slot is not None:
            per_round_reward[slot] = safe_get_last_cumulative_reward(rec.get("stepLogs", []))

    # Mean over each task's four rounds, ignoring rounds with no reward.
    averages = {t: mean_ignore_none(per_round_reward[4 * (t - 1):4 * t]) for t in (1, 2, 3)}

    row: Dict[str, Any] = {"prolific_id": prolific_id}
    row.update({f"configid_task{t}": configs[t] for t in (1, 2, 3)})
    row.update({f"avg_reward_task{t}": averages[t] for t in (1, 2, 3)})

    # Each question's answer under task1/2/3 (None where that task lacks it).
    questionnaires = data.get("questionnaires", {}) or {}
    answers = {t: questionnaires.get(f"task{t}") or {} for t in (1, 2, 3)}
    for qk in discover_question_keys(questionnaires):
        for t in (1, 2, 3):
            row[f"task{t}_{qk}"] = answers[t].get(qk)

    return row

def long_samples_from_df(df: pd.DataFrame) -> pd.DataFrame:
    """Reshape per-participant task means into a long table (one row = one task sample).

    For each participant row and each task t in 1..3, a sample is emitted
    when ``avg_reward_task{t}`` is not missing and ``configid_task{t}`` is a
    string.

    Args:
        df: wide table built from ``extract_from_file`` records.

    Returns:
        DataFrame with columns ``prolific_id``, ``task``, ``model``,
        ``layout`` and ``value``.
        - The columns exist even when there are no samples, so downstream
          ``dropna(subset=[...])`` calls do not raise ``KeyError``.
        - A missing participant id stays ``None`` rather than becoming the
          string ``"None"``. Downstream ``dropna`` then discards those samples
          instead of pooling every id-less participant into one.
    """
    columns = ["prolific_id", "task", "model", "layout", "value"]
    recs = []
    for _, r in df.iterrows():
        pid = r.get("prolific_id")
        pid = None if (pd.api.types.is_scalar(pid) and pd.isna(pid)) else str(pid)
        for t in (1, 2, 3):
            cfg = r.get(f"configid_task{t}")
            val = r.get(f"avg_reward_task{t}")
            if pd.notna(val) and isinstance(cfg, str):
                recs.append({
                    "prolific_id": pid,
                    "task": t,
                    "model": parse_model_id(cfg),
                    "layout": parse_layout_id(cfg),  # unused by the current analyses; kept for reference
                    "value": float(val),
                })
    return pd.DataFrame(recs, columns=columns)

# -------------------- 两两配对比较（不分 layout） --------------------
def _pretty_model(m: str) -> str:
    """Display name for a model id (e.g. ``model1``), falling back to the id itself."""
    try:
        return NAME_MAP_MODEL[m]
    except KeyError:
        return m

def pairwise_reward_paired_all(samples: pd.DataFrame, out_csv: Path):
    """Pairwise paired Wilcoxon tests on reward, pooled over layouts.

    Values are first averaged per participant x model. Each pair of models is
    then compared with a two-sided paired Wilcoxon signed-rank test over the
    participants who have a value for both models.

    Pairs are skipped when fewer than 5 participants have both values, or
    when the test is undefined (e.g. every paired difference is zero).
    BH-FDR q-values are computed across all reported pairs.

    Args:
        samples: long table with at least ``prolific_id``, ``model`` and
            ``value`` columns (e.g. from ``long_samples_from_df``). Rows
            missing any of these are dropped.
        out_csv: destination CSV; an empty file is written when there is
            nothing to test.

    Output columns: model_a, model_b, model_a_name, model_b_name, n, mean_a,
    mean_b, mean_diff, stat, p, effect_r, q. Rows follow ``order_models``
    order.
    """
    from itertools import combinations

    no_data_msg = "\n==== Pairwise Reward (Overall, Paired) ====\n[无可用数据]"
    required = ["prolific_id", "model", "value"]

    # A frame lacking these columns (e.g. no usable samples at all) would make
    # dropna(subset=...) raise KeyError; treat it as "no data".
    if set(required).issubset(samples.columns):
        df = samples.dropna(subset=required).copy()
    else:
        df = pd.DataFrame(columns=required)
    if df.empty:
        Path(out_csv).write_text("", encoding="utf-8")
        print(no_data_msg)
        return

    df["prolific_id"] = df["prolific_id"].astype(str)
    df["model"] = df["model"].astype(str)

    # Participant x model means.
    piv = df.pivot_table(index="prolific_id", columns="model", values="value", aggfunc="mean")
    models = order_models([c for c in piv.columns if c is not None])

    rows = []
    for a, b in combinations(models, 2):
        xa = piv[a].to_numpy(dtype=float)
        xb = piv[b].to_numpy(dtype=float)
        mask = ~np.isnan(xa) & ~np.isnan(xb)  # participants with both models
        n = int(mask.sum())
        if n < 5:
            continue
        xa, xb = xa[mask], xb[mask]
        try:
            # Only long-standing keywords are passed. `mode` is a deprecated
            # alias of `method` (SciPy >= 1.9), and the default method="auto"
            # is what is wanted here.
            st, p = wilcoxon(xa, xb, zero_method="wilcox", correction=False,
                             alternative="two-sided")
        except ValueError:
            # Degenerate input, e.g. SciPy versions that reject all-zero differences.
            continue
        if not np.isfinite(p):
            # Undefined test (NaN p): skip rather than feed NaN into BH-FDR.
            continue
        r_eff = _wilcoxon_effect_size(xa, xb, p)
        rows.append({
            "model_a": a, "model_b": b,
            "model_a_name": _pretty_model(a), "model_b_name": _pretty_model(b),
            "n": n,
            "mean_a": float(np.mean(xa)), "mean_b": float(np.mean(xb)),
            "mean_diff": float(np.mean(xa - xb)),
            "stat": float(st), "p": float(p), "effect_r": float(r_eff),
        })

    out = pd.DataFrame(rows)
    if out.empty:
        Path(out_csv).write_text("", encoding="utf-8")
        print(no_data_msg)
        return

    out["q"] = _bh_fdr(out["p"].tolist())
    # Rows are already in order_models() order (numeric model suffix); a
    # lexicographic re-sort would place "model10" before "model2".
    out.to_csv(out_csv, index=False, encoding="utf-8-sig")

    print("\n==== Pairwise Reward (Overall, Paired Wilcoxon) ====")
    disp = out.copy()
    disp["model_a"] = disp["model_a_name"]
    disp["model_b"] = disp["model_b_name"]
    print(disp[["model_a", "model_b", "n", "mean_a", "mean_b", "mean_diff", "stat", "p", "q", "effect_r"]]
          .to_string(index=False))

def pairwise_questionnaire_paired_all(df: pd.DataFrame, out_csv: Path):
    """Pairwise paired Wilcoxon tests on questionnaire scores, pooled over layouts and questions.

    Every numeric ``task{t}_q_*`` answer is attributed to the model parsed
    from ``configid_task{t}``. Answers are averaged per participant x model,
    across tasks and questions. Each pair of models is then compared with a
    two-sided paired Wilcoxon signed-rank test over the participants rated
    under both.

    Rows without a participant id, and answers that are not numeric, are
    ignored. Pairs are skipped when fewer than 5 participants are paired, or
    when the test is undefined (e.g. every paired difference is zero).
    BH-FDR q-values are computed across all reported pairs.

    Args:
        df: wide per-participant table (``prolific_id``, ``configid_task{t}``,
            ``task{t}_q_*`` columns).
        out_csv: destination CSV; an empty file is written when there is
            nothing to test.

    Output columns: model_a, model_b, model_a_name, model_b_name, n, mean_a,
    mean_b, mean_diff, stat, p, effect_r, q. Rows follow ``order_models``
    order.
    """
    from itertools import combinations

    no_data_msg = "\n==== Pairwise Questionnaire (Overall, Paired) ====\n[无可用数据]"

    # 1) Long table: prolific_id, model, value (all q_* answers of task1/2/3).
    recs = []
    for _, r in df.iterrows():
        pid = r.get("prolific_id")
        if pd.api.types.is_scalar(pid) and pd.isna(pid):
            continue  # cannot pair answers without a participant id
        pid = str(pid)
        for t in (1, 2, 3):
            cfg = r.get(f"configid_task{t}")
            model = parse_model_id(cfg) if isinstance(cfg, str) else None
            if not model:
                continue
            prefix = f"task{t}_q_"
            for col in r.index:
                if not (isinstance(col, str) and col.startswith(prefix)):
                    continue
                try:
                    val = float(r[col])
                except (TypeError, ValueError):
                    continue  # missing or non-numeric answer (None, free text, list)
                if np.isnan(val):
                    continue
                recs.append({"prolific_id": pid, "model": str(model), "value": val})

    if not recs:
        Path(out_csv).write_text("", encoding="utf-8")
        print(no_data_msg)
        return

    qdf = pd.DataFrame(recs)

    # 2) Participant x model means (across tasks and questions).
    piv = qdf.pivot_table(index="prolific_id", columns="model", values="value", aggfunc="mean")
    models = order_models([c for c in piv.columns if c is not None])

    # 3) Pairwise paired Wilcoxon.
    rows = []
    for a, b in combinations(models, 2):
        xa = piv[a].to_numpy(dtype=float)
        xb = piv[b].to_numpy(dtype=float)
        mask = ~np.isnan(xa) & ~np.isnan(xb)  # participants rated under both models
        n = int(mask.sum())
        if n < 5:
            continue
        xa, xb = xa[mask], xb[mask]
        try:
            # Only long-standing keywords are passed. `mode` is a deprecated
            # alias of `method` (SciPy >= 1.9), and the default method="auto"
            # is what is wanted here.
            st, p = wilcoxon(xa, xb, zero_method="wilcox", correction=False,
                             alternative="two-sided")
        except ValueError:
            # Degenerate input, e.g. SciPy versions that reject all-zero differences.
            continue
        if not np.isfinite(p):
            # Undefined test (NaN p): skip rather than feed NaN into BH-FDR.
            continue
        r_eff = _wilcoxon_effect_size(xa, xb, p)
        rows.append({
            "model_a": a, "model_b": b,
            "model_a_name": _pretty_model(a), "model_b_name": _pretty_model(b),
            "n": n,
            "mean_a": float(np.mean(xa)), "mean_b": float(np.mean(xb)),
            "mean_diff": float(np.mean(xa - xb)),
            "stat": float(st), "p": float(p), "effect_r": float(r_eff),
        })

    out = pd.DataFrame(rows)
    if out.empty:
        Path(out_csv).write_text("", encoding="utf-8")
        print(no_data_msg)
        return

    out["q"] = _bh_fdr(out["p"].tolist())
    # Rows are already in order_models() order; no lexicographic re-sort.
    out.to_csv(out_csv, index=False, encoding="utf-8-sig")

    print("\n==== Pairwise Questionnaire (Overall, Paired Wilcoxon) ====")
    disp = out.copy()
    disp["model_a"] = disp["model_a_name"]
    disp["model_b"] = disp["model_b_name"]
    print(disp[["model_a", "model_b", "n", "mean_a", "mean_b", "mean_diff", "stat", "p", "q", "effect_r"]]
          .to_string(index=False))


def _pretty_question_label(qk: str) -> str:
    """Turn a questionnaire key into a display label, e.g. ``q_trust_me`` -> ``Trust Me``.

    A leading ``q_`` is dropped, runs of underscores become single spaces, and
    the result is stripped and title-cased.
    """
    body = qk[len("q_"):] if qk.startswith("q_") else qk
    spaced = re.sub(r"_+", " ", body)
    return spaced.strip().title()

def pairwise_questionnaire_paired_by_question(df: pd.DataFrame, out_csv: Path):
    """Per-question pairwise paired Wilcoxon tests between models (participant-level).

    Steps:
      1) Build a long table (prolific_id, model, question, value) from every
         numeric ``task{t}_q_*`` answer. Task t's answers are attributed to
         the model parsed from ``configid_task{t}``, and same-named questions
         from different tasks are merged.
      2) Within each question, average per participant x model (across tasks).
      3) Compare each pair of models with a two-sided paired Wilcoxon
         signed-rank test. BH-FDR is applied *within* each question, and the
         effect size r is reported.

    Rows without a participant id, and answers that are not numeric, are
    ignored. Pairs are skipped when fewer than 5 participants are paired, or
    when the test is undefined (e.g. every paired difference is zero).

    Args:
        df: wide per-participant table (``prolific_id``, ``configid_task{t}``,
            ``task{t}_q_*`` columns).
        out_csv: destination CSV; an empty file is written when there is
            nothing to test.

    Output columns: question, question_label, model_a, model_b, model_a_name,
    model_b_name, n, mean_a, mean_b, mean_diff, stat, p, effect_r, q. Rows are
    sorted by question label, and by ``order_models`` order within a question.
    """
    from itertools import combinations

    no_data_msg = "\n==== Pairwise Questionnaire by Question (Paired) ====\n[无可用数据]"

    # 1) Long table.
    recs = []
    for _, r in df.iterrows():
        pid = r.get("prolific_id")
        if pd.api.types.is_scalar(pid) and pd.isna(pid):
            continue  # cannot pair answers without a participant id
        pid = str(pid)
        for t in (1, 2, 3):
            cfg = r.get(f"configid_task{t}")
            model = parse_model_id(cfg) if isinstance(cfg, str) else None
            if not model:
                continue
            task_prefix = f"task{t}_"
            for col in r.index:
                if not (isinstance(col, str) and col.startswith(task_prefix + "q_")):
                    continue
                try:
                    val = float(r[col])
                except (TypeError, ValueError):
                    continue  # missing or non-numeric answer (None, free text, list)
                if np.isnan(val):
                    continue
                recs.append({
                    "prolific_id": pid,
                    "model": str(model),
                    "question": col[len(task_prefix):],  # restore the bare q_* key
                    "value": val,
                })

    if not recs:
        Path(out_csv).write_text("", encoding="utf-8")
        print(no_data_msg)
        return

    qdf = pd.DataFrame(recs)

    rows = []
    # 2) Test each question separately.
    for qk, sub in qdf.groupby("question"):
        # Participant x model means for this question (across tasks).
        piv = sub.pivot_table(index="prolific_id", columns="model", values="value", aggfunc="mean")
        models = order_models([c for c in piv.columns if c is not None])
        pair_rows = []
        for a, b in combinations(models, 2):
            xa = piv[a].to_numpy(dtype=float)
            xb = piv[b].to_numpy(dtype=float)
            mask = ~np.isnan(xa) & ~np.isnan(xb)  # participants rated under both models
            n = int(mask.sum())
            if n < 5:
                continue
            xa, xb = xa[mask], xb[mask]
            try:
                # Only long-standing keywords are passed. `mode` is a deprecated
                # alias of `method` (SciPy >= 1.9), and the default
                # method="auto" is what is wanted here.
                st, p = wilcoxon(xa, xb, zero_method="wilcox", correction=False,
                                 alternative="two-sided")
            except ValueError:
                # Degenerate input, e.g. SciPy versions that reject all-zero differences.
                continue
            if not np.isfinite(p):
                # Undefined test (NaN p): skip rather than feed NaN into BH-FDR.
                continue
            r_eff = _wilcoxon_effect_size(xa, xb, p)
            pair_rows.append({
                "question": qk,
                "question_label": _pretty_question_label(qk),
                "model_a": a, "model_b": b,
                "model_a_name": _pretty_model(a),
                "model_b_name": _pretty_model(b),
                "n": n,
                "mean_a": float(np.mean(xa)), "mean_b": float(np.mean(xb)),
                "mean_diff": float(np.mean(xa - xb)),
                "stat": float(st), "p": float(p), "effect_r": float(r_eff),
            })

        # 3) BH-FDR within this question's family of comparisons.
        if pair_rows:
            qvals = _bh_fdr([row["p"] for row in pair_rows])
            for row, q in zip(pair_rows, qvals):
                row["q"] = q
            rows.extend(pair_rows)

    out = pd.DataFrame(rows)
    if out.empty:
        Path(out_csv).write_text("", encoding="utf-8")
        print(no_data_msg)
        return

    # Stable sort by label keeps the order_models() pair order within each question.
    out = out.sort_values("question_label", kind="mergesort").reset_index(drop=True)
    out.to_csv(out_csv, index=False, encoding="utf-8-sig")

    print("\n==== Pairwise Questionnaire by Question (Paired Wilcoxon) ====")
    disp = out.copy()
    disp["model_a"] = disp["model_a_name"]
    disp["model_b"] = disp["model_b_name"]
    print(disp[["question_label", "model_a", "model_b", "n", "mean_a", "mean_b", "mean_diff", "stat", "p", "q", "effect_r"]]
          .to_string(index=False))



# -------------------- 主流程 --------------------
def main():
    """Parse every ``*.json`` in JSON_FOLDER, export a minimal table, and run the paired tests.

    Files written into the JSON folder:
      - results_minimal.csv: one row per successfully parsed file;
      - pairwise_reward_overall_paired.csv: pairwise reward comparisons;
      - pairwise_questionnaire_by_question_paired.csv: per-question
        questionnaire comparisons.
    """
    folder = JSON_FOLDER.expanduser().resolve()
    files = sorted(folder.glob("*.json"))
    if not files:
        print(f"未找到 JSON：{folder}")
        return

    rows = []
    for fp in files:
        try:
            rows.append(extract_from_file(fp))
        except Exception as e:  # top-level boundary: report the bad file and keep going
            print(f"[解析失败] {fp.name}: {e}")
    df = pd.DataFrame(rows)

    # Export the minimal table for traceability.
    out_csv = folder / "results_minimal.csv"
    df.to_csv(out_csv, index=False, encoding="utf-8-sig")
    print(f"已生成：{out_csv}")

    # Long-format reward samples (one row per participant x task).
    samples = long_samples_from_df(df)

    # Pairwise paired comparisons, pooled over layouts (each analysis runs once).
    pairwise_reward_paired_all(samples, folder / "pairwise_reward_overall_paired.csv")
    pairwise_questionnaire_paired_by_question(df, folder / "pairwise_questionnaire_by_question_paired.csv")
    # NOTE(review): pairwise_questionnaire_paired_all (all questions pooled) is
    # defined in this file but not invoked; call it here if that comparison is wanted.


if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()
