import json
import matplotlib
matplotlib.use('Agg')  # Use a non-interactive backend for matplotlib
from matplotlib import pyplot as plt
from src.early_stop_cot import EarlyStopCoT
from pathlib import Path
import numpy as np


# Input layout: result files are read from
# <rootpath>/<model>/<dataset>/<dataset>_step_results.jsonl.
rootpath = '/data/project/Reasoning/results_0714/'
# Model directory names under rootpath; also used as the legend labels.
model_lis = ['DeepSeek-R1-Distill-Llama-8B', 'QwQ-32B', 'Qwen3-8B']
# Dataset directory names; each also prefixes its result file name.
dataset_lis = ['aime', 'gpqa', 'math', 'minerva', 'olympiadbench']

# Aesthetic tweaks
plt.rcParams.update({
    "font.family": "serif",
    "font.size": 16,         # base font size
    "axes.titlesize": 17,    # subplot title
    "axes.labelsize": 16,    # axis labels
    "xtick.labelsize": 15,
    "ytick.labelsize": 15,
    "legend.fontsize": 14,
})
FIGSIZE = (9.2, 5.4)  # passed to plt.subplots as the figure size
# NOTE(review): LINEWIDTH, ALPHA_BAND, LINESTYLES and MARKER_EVERY are not
# referenced anywhere in the visible part of this script.
LINEWIDTH = 2.4
ALPHA_BAND = 0.18
LINESTYLES = ['-', '--', '-.', ':', (0, (3, 1, 1, 1))]
MARKER_EVERY = None  # may be set to an int to make points on a curve easier to read, e.g. 20
NUM_BINS = 10  # number of histogram bins, and of bar groups in the plot

def relative_positions_np(lists, queries, bins=10):
    """Histogram the relative positions at which each query occurs in its list.

    ``lists`` and ``queries`` are paired element-wise with ``zip``. For every
    pair, each 1-based index ``i`` whose element equals the query contributes
    the relative position ``i / len(list)``, a value in ``(0, 1]``. All
    positions are binned over ``[0, 1]`` and every bin count is divided by
    ``total_len / bins``, where ``total_len`` is the summed length of all
    lists -- i.e. the number of elements one bin would hold if the elements
    were spread evenly over the bins.

    Args:
        lists: Sequence of sequences to search; elements may be of mixed
            types (they are compared through an object array).
        queries: One value per list to look for.
        bins: Number of equal-width bins over ``[0, 1]``.

    Returns:
        A float ``numpy.ndarray`` with one normalized count per bin. When the
        lists contain no elements at all, the result is all zeros instead of
        the NaNs a 0/0 division would produce.
    """
    # The denominator counts every list, including any that zip() leaves
    # unpaired when ``queries`` is shorter.
    total_len = sum(len(seq) for seq in lists)

    positions = []
    for seq, query in zip(lists, queries):
        n = len(seq)
        if n == 0:
            # An empty list has no positions to contribute.
            continue
        arr = np.asarray(seq, dtype=object)  # object dtype tolerates mixed types
        hits = np.flatnonzero(arr == query) + 1  # +1 converts to 1-based indices
        positions.extend((hits / n).tolist())

    # range=(0, 1) keeps the bin edges fixed; the last bin includes 1.0.
    hist, _ = np.histogram(positions, bins=bins, range=(0, 1))
    if total_len == 0:
        # Nothing to normalize by; avoid dividing zero counts by zero.
        return hist.astype(float)
    return hist / (total_len / bins)

# For each model, pool the step-answer sequences of every dataset and compute
# the normalized histogram of where the final answer occurs among the steps.
# NOTE: a ground-truth variant (querying each sequence for data['answer']
# instead of the final answer) was sketched here earlier but is not computed.
model_to_finals = {}
for model in model_lis:
    step_sequences = []
    final_answers = []
    model_dir = Path(rootpath) / model
    for dataset in dataset_lis:
        path = model_dir / dataset / f'{dataset}_step_results.jsonl'
        early_stop_cot = EarlyStopCoT(path)
        for data in early_stop_cot.data:
            # NOTE(review): extend() assumes 'step_answer' holds one sequence
            # of step answers per sample and 'generated_answer' one final
            # answer per sample, in the same order -- confirm against the
            # result-file schema.
            step_sequences.extend(data.get('step_answer'))
            final_answers.extend(data.get('generated_answer'))
    finals = relative_positions_np(step_sequences, final_answers, bins=NUM_BINS)
    # Stored as a plain list with one entry per bin.
    model_to_finals[model] = list(finals)

# Grouped bar chart: one group per progress bin, one bar per model.
fig, ax = plt.subplots(figsize=FIGSIZE)

num_models = len(model_to_finals)
num_bins = NUM_BINS
bin_edges = np.linspace(0, 1, num_bins + 1)
centers = (bin_edges[:-1] + bin_edges[1:]) / 2

# Each group takes 85% of its bin's width, shared equally among the models.
group_width = (1 / num_bins) * 0.85
bar_width = group_width / num_models

# Draw the models' bars side by side within every bin.
for i, (model, hist_vals) in enumerate(model_to_finals.items()):
    # hist_vals is expected to be a sequence of length NUM_BINS.
    # Shift each model's bars so that the group is centred on the bin centre.
    offset = (i - (num_models - 1) / 2) * bar_width
    ax.bar(centers + offset, hist_vals, width=bar_width,
           edgecolor='black', alpha=0.9, label=model)

# Horizontal reference line at 1.0.
# NOTE(review): this was originally described as the density of a uniform
# distribution, but the plotted values are match counts divided by an even
# per-bin share of all steps (see relative_positions_np) -- confirm which
# reading is intended.
ax.axhline(1.0, linestyle=':', linewidth=1.3, color='gray', alpha=0.8)

# Axes and grid.
ax.set_xlabel("Reasoning progress (measured as normalized step progress)")
ax.set_ylabel(r"$P(X_t = X_T)$")
ax.set_xticks(np.linspace(0, 1, 5))
ax.set_xticklabels(["0%", "25%", "50%", "75%", "100%"])
ax.grid(True, alpha=0.3)
ax.margins(x=0.01, y=0.05)

leg = ax.legend(title="Model", frameon=True, loc="upper left", bbox_to_anchor=(0.02, 0.98))
# Left-align the legend title with its entries. Prefer the public API and
# fall back to the private box on matplotlib versions that lack it; the
# alignment is purely cosmetic, so it is skipped if neither is available.
if hasattr(leg, "set_alignment"):
    leg.set_alignment("left")
else:
    legend_box = getattr(leg, "_legend_box", None)
    if legend_box is not None:
        legend_box.align = "left"
plt.tight_layout()
plt.savefig('final_answer_position_distribution.png', dpi=400, bbox_inches='tight')

