import json
import matplotlib
matplotlib.use('Agg')  # Use a non-interactive backend for matplotlib
from matplotlib import pyplot as plt
from src.early_stop_cot import EarlyStopCoT
from pathlib import Path
import numpy as np


# Root directory of the result files. Each file is read from
# <rootpath>/<model>/<dataset>/<dataset>_step_results.jsonl.
rootpath = '/data/project/Reasoning/results/'
# Model names: used as sub-directory names and as legend labels.
model_lis = ['DeepSeek-R1-Distill-Llama-8B', 'QwQ-32B', 'Qwen3-8B']
# Dataset names: used as sub-directory names and as file-name prefixes.
dataset_lis = ['aime', 'gpqa', 'math', 'minerva', 'olympiadbench']

# Aesthetic tweaks
plt.rcParams.update({
    "font.family": "serif",
    "font.size": 16,         # base font size
    "axes.titlesize": 17,    # subplot title
    "axes.labelsize": 16,    # axis labels
    "xtick.labelsize": 15,
    "ytick.labelsize": 15,
    "legend.fontsize": 14,
})
FIGSIZE = (9.2, 5.4)  # passed to plt.subplots(figsize=...)
LINEWIDTH = 2.4       # line width of every plotted curve
ALPHA_BAND = 0.18     # opacity of the (currently disabled) fill_between band
# Line styles are cycled per model; the last entry is a custom dash pattern.
LINESTYLES = ['-', '--', '-.', ':', (0, (3, 1, 1, 1))]
MARKER_EVERY = None  # set to an integer (e.g. 20) to draw a marker every N points so the curves are easier to read

# Collect, for every model, all sequences stored under the
# 'merged_step_answer_count' key of each record, pooled across datasets.
# Each collected element must support len(); its length is used below.
model_to_seqs = {m: [] for m in model_lis}
for model in model_lis:
    for dataset in dataset_lis:
        path = Path(rootpath) / model / dataset / f'{dataset}_step_results.jsonl'
        early_stop_cot = EarlyStopCoT(path)
        for data in early_stop_cot.data:
            seq = data.get('merged_step_answer_count')
            if seq is None:
                # .get() yields None when the key is absent; extend(None)
                # would raise TypeError, so skip such records.
                continue
            # Zero-length sequences carry no data and cannot be interpolated
            # later on, so they are dropped here.
            model_to_seqs[model].extend(s for s in seq if len(s) > 0)

# Fail early with a message that names the affected models instead of the
# anonymous "max() of empty sequence" error raised further down.
models_without_data = [m for m, seqs in model_to_seqs.items() if not seqs]
if models_without_data:
    raise ValueError(
        "No non-empty 'merged_step_answer_count' sequences found for: "
        + ", ".join(models_without_data)
    )

# Longest sequence per model; used as the common resampling length.
model_to_maxlen = {m: max(len(s) for s in seqs) for m, seqs in model_to_seqs.items()}

def resample_to_length(arr, N: int) -> np.ndarray:
    """Linearly interpolate a 1D sequence to length N over relative position 0..1.

    The input samples are placed at evenly spaced positions on [0, 1] and the
    result is evaluated at N evenly spaced positions on the same interval.

    Args:
        arr: 1D sequence of numbers (anything ``np.asarray`` can convert to
            a float array).
        N: Number of samples in the returned array.

    Returns:
        Float array of length N. A single-element input yields a constant
        array. An empty input yields an all-NaN array, so NaN-aware
        aggregations such as ``np.nanmean`` ignore it instead of the call
        failing inside ``np.interp``.
    """
    values = np.asarray(arr, dtype=float)
    if values.size == 0:
        # np.interp rejects an empty sample array; return "no data" instead.
        return np.full(N, np.nan)
    x_orig = np.linspace(0, 1, len(values))
    x_new = np.linspace(0, 1, N)
    return np.interp(x_new, x_orig, values)

# Per-model average curve: every sequence of a model is resampled to that
# model's longest sequence length, then averaged position-wise (NaN-aware).
model_to_mean = {}
for model in model_to_seqs:
    target_len = model_to_maxlen[model]
    rows = [resample_to_length(run, target_len) for run in model_to_seqs[model]]
    stacked = np.vstack(rows)  # one row per resampled sequence
    model_to_mean[model] = {
        "mean": np.nanmean(stacked, axis=0),
        # Band limits are placeholders while the band is disabled. For a
        # 2.5th-97.5th percentile band use
        # np.nanpercentile(stacked, 2.5, axis=0) and
        # np.nanpercentile(stacked, 97.5, axis=0) instead.
        "lo": 0,
        "hi": 0,
    }

fig, ax = plt.subplots(figsize=FIGSIZE)

# Draw one mean curve per model against relative position in [0, 1].
for idx, (model, stats) in enumerate(model_to_mean.items()):
    mean_curve = stats["mean"]
    x_rel = np.linspace(0, 1, len(mean_curve))
    plot_kwargs = {
        "ls": LINESTYLES[idx % len(LINESTYLES)],  # cycle styles if models outnumber them
        "linewidth": LINEWIDTH,
        "label": model,
        "antialiased": True,
    }
    if MARKER_EVERY:
        # Markers are only added when a marker interval is configured.
        plot_kwargs.update(marker='o', markevery=MARKER_EVERY)
    # Optional band (disabled):
    # ax.fill_between(x_rel, stats["lo"], stats["hi"], alpha=ALPHA_BAND, linewidth=0)
    # Mean curve
    ax.plot(x_rel, mean_curve, **plot_kwargs)

ax.set_xlabel("Reasoning progress (measured as normalized run progress)")
ax.set_ylabel("Average Run Length")
# ax.set_title("Consecutive identical answers along reasoning steps")

# Percentage tick labels
ax.set_xticks(np.linspace(0, 1, 5))
ax.set_xticklabels(["0%", "25%", "50%", "75%", "100%"])

# Grid and margins
ax.grid(True, alpha=0.3)
ax.margins(x=0.01, y=0.05)

leg = ax.legend(title="Model", frameon=True, loc="upper left", bbox_to_anchor=(0.02, 0.98))
# Left-align the legend title and entries. Prefer the public API
# (Legend.set_alignment, matplotlib >= 3.6); fall back to the private
# attribute on older versions and stay best-effort if it is absent.
set_alignment = getattr(leg, "set_alignment", None)
if set_alignment is not None:
    set_alignment("left")
else:
    try:
        leg._legend_box.align = "left"
    except AttributeError:
        pass

plt.tight_layout()
plt.savefig('mean_runlength_by_model.png', dpi=400, bbox_inches='tight')
