import json
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np

# === Tunable parameters ===
DETAIL_PATH = Path("tune_log/llama2/buddy_8/alpaca/training_details.json")
METRICS = ["loss", "policy", "entropy", "CE", "budget", "reward"]

# NOTE(review): SMOOTH_WINDOW is not read anywhere in the visible code -- only the
# EMA below is applied to the curves. Either wire up an SMA or drop this setting.
SMOOTH_WINDOW = 20
EMA_BETA = 0.95  # EMA weight on the previous smoothed value; closer to 1 => smoother

BIN_SIZE = None  # e.g. 1000 or 10000; None (or any value <= 1) disables binning
START_FROM_ONE = False  # True: step axis starts at 1; False: starts at 0

SAVE_DIR = Path("plots/grpo")
# Created at import time so that main() can save figures into it.
SAVE_DIR.mkdir(parents=True, exist_ok=True)


def ema(y, beta=0.95):
    """Exponential moving average that tolerates NaN/inf gaps.

    The smoothed value is updated as ``beta * previous + (1 - beta) * sample``.
    The first finite sample seeds the average. A non-finite sample does not
    update the state; its output slot repeats the last smoothed value, which
    keeps the plotted line continuous. Leading non-finite samples therefore
    map to NaN.

    Args:
        y: Sequence of values (converted to a float array).
        beta: Weight on the previous smoothed value; closer to 1 is smoother.

    Returns:
        Float array with the same shape as ``y``.
    """
    values = np.asarray(y, dtype=float)
    smoothed = np.empty_like(values)
    smoothed[:] = np.nan
    weight_new = 1.0 - beta
    state = np.nan

    for pos, sample in enumerate(values):
        if np.isfinite(sample):
            # Restart from the raw sample whenever there is no finite state yet.
            state = beta * state + weight_new * sample if np.isfinite(state) else sample
        smoothed[pos] = state
    return smoothed


def maybe_bin(steps, series, bin_size):
    """Optionally bin a series by step and average each bin.

    Each step is floored to the left edge of its half-open bin
    ``[k * bin_size, (k + 1) * bin_size)``. The finite values that fall into the
    same bin are averaged. NaN and +/-inf are ignored, and bins that contain no
    finite value are dropped.

    Args:
        steps: 1-D sequence of step indices (cast to int).
        series: 1-D sequence of values aligned element-wise with ``steps``.
        bin_size: Bin width. ``None`` or any value ``<= 1`` disables binning.

    Returns:
        ``(bin_steps, bin_values)``: left bin edges and per-bin means. When
        binning is disabled, the inputs are returned unchanged (not converted
        to arrays).

    Raises:
        ValueError: if ``steps`` and ``series`` do not have the same shape.
    """
    if bin_size is None or bin_size <= 1:
        return steps, series

    steps = np.asarray(steps, dtype=int)
    series = np.asarray(series, dtype=float)
    if steps.shape != series.shape:
        raise ValueError(
            f"steps and series must have the same shape, got {steps.shape} and {series.shape}"
        )

    bin_left = (steps // bin_size) * bin_size

    # Map every sample to its bin index in one pass instead of one mask per bin.
    uniq, inverse = np.unique(bin_left, return_inverse=True)
    inverse = inverse.reshape(-1)

    # Only finite values contribute. This matches the emptiness test, so a bin
    # mixing finite values with inf no longer averages to inf.
    finite = np.isfinite(series)
    counts = np.bincount(inverse[finite], minlength=uniq.size)
    sums = np.bincount(inverse[finite], weights=series[finite], minlength=uniq.size)

    keep = counts > 0  # drop bins with no finite value
    return uniq[keep], (sums[keep] / counts[keep]).astype(float)


def _extract_series(data, metric):
    """Return ``metric`` from every record of ``data`` as a 1-D float array.

    Records that are not dicts, lack the key, or hold JSON ``null`` become NaN.
    """
    raw = [rec.get(metric) if isinstance(rec, dict) else None for rec in data]
    series = np.array(raw, dtype=float)  # None -> NaN under dtype=float
    # "loss" and "policy" are plotted sign-flipped. Negating after the float
    # conversion keeps null entries as NaN instead of raising TypeError on -None.
    # NOTE(review): presumably flipped so that "up is better"; confirm intent.
    if metric in ("loss", "policy"):
        series = -series
    return series


def _plot_metric(metric, steps, series):
    """Plot the raw and EMA-smoothed curve of one metric and save PNG + PDF to SAVE_DIR."""
    # Bin first, then smooth (generally more stable).
    s_steps, s_series = maybe_bin(steps, series, BIN_SIZE)
    smoothed = ema(s_series, beta=EMA_BETA)
    smooth_note = f"EMA β={EMA_BETA}"
    bin_note = f", bin={BIN_SIZE}" if BIN_SIZE else ""
    out_path = SAVE_DIR / f"{metric}.png"

    fig = plt.figure(figsize=(8.5, 4.4))
    try:
        # Raw curve uses the un-binned data; the very low alpha keeps it in the background.
        plt.plot(steps, series, linewidth=0.8, alpha=0.05, label="raw")
        # Main line: binned (if enabled) and EMA-smoothed.
        plt.plot(s_steps, smoothed, linewidth=1.8, label=f"smoothed ({smooth_note})")

        plt.title(f"{metric} curve ({smooth_note}{bin_note})", fontsize=22)
        plt.xlabel("step", fontsize=18)
        plt.ylabel(metric, fontsize=18)
        plt.grid(True, linestyle="--", alpha=0.5)
        plt.legend(loc="best", fontsize=18)
        plt.tick_params(axis='both', labelsize=16)
        plt.tight_layout()

        plt.savefig(out_path, dpi=300)
        plt.savefig(SAVE_DIR / f"{metric}.pdf")
        plt.show()
    finally:
        # Without this, every metric leaves an open figure behind: memory grows and
        # matplotlib warns once more than figure.max_open_warning figures are open.
        plt.close(fig)
    print(f"Saved: {out_path.resolve()}")


def main():
    """Load DETAIL_PATH (a JSON list of per-step records) and plot every metric in METRICS."""
    with open(DETAIL_PATH, "r", encoding="utf-8") as f:
        data = json.load(f)
    # Message: "JSON content is not a non-empty list, please check the file."
    if not isinstance(data, list) or not data:
        raise ValueError("JSON 内容不是非空列表，请检查文件。")

    # The x axis is the record index, offset by one when START_FROM_ONE is set.
    num_steps = len(data)
    steps = np.arange(1, num_steps + 1) if START_FROM_ONE else np.arange(num_steps)

    for m in METRICS:
        _plot_metric(m, steps, _extract_series(data, m))


# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
