import os

import matplotlib.pyplot as plt
import pandas as pd


# Models compared side by side: one subplot per entry, in this order.
model_lis = [
    'QwQ-32B',
    'Qwen3-8B',
    'DeepSeek-R1-Distill-Llama-8B',
]

# Benchmarks to plot: one output figure per entry.
dataset_lis = [
    'aime',
    'gpqa',
    'math',
    'minerva',
    'olympiadbench',
]

# Sweep values for the other two parameters.
# NOTE(review): neither list is referenced in the visible part of this
# script -- presumably they mirror the sweep that produced the CSV; confirm.
warmup_lis = list(range(30, 80, 10))  # 30, 40, 50, 60, 70
min_slope_lis = [3, 5, 7, 10, 15, 20]

# Values matched against the 'threshold' column; each one becomes a curve
# labelled "p: <value>" in the legend.
threshold_lis = [0.01, 0.05, 0.1, 0.15, 0.2]

# Line width and marker size shared by every curve drawn in this script.
LINEWIDTH = 2.4
MARKERSIZE = 6

# Select the base style sheet first; the explicit settings that follow are
# applied afterwards and therefore take effect on top of it.
plt.style.use('tableau-colorblind10')

# Typography overrides, set one key at a time.
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.size"] = 16          # base font size
plt.rcParams["axes.titlesize"] = 17     # subplot titles
plt.rcParams["axes.labelsize"] = 16     # axis labels
plt.rcParams["xtick.labelsize"] = 15    # x tick labels
plt.rcParams["ytick.labelsize"] = 15    # y tick labels
plt.rcParams["legend.fontsize"] = 14    # legend entries

# Read the results table, discard every row whose dataset is 'amc', and
# derive two accuracy columns as correct-answer counts divided by the
# number of samples.
df = pd.read_csv('/data/project/Reasoning/results/early_stop_cot_results_slope.csv')
df = df.loc[df['dataset'].ne('amc')]
df = df.assign(
    cot_acc=df['cot_correct_answers'] / df['total_samples'],
    es_acc=df['es_correct_answers'] / df['total_samples'],
)

# Directory that receives one PNG per dataset. It is created on demand so
# that savefig does not fail when the folder is missing.
plot_dir = '/data/project/Reasoning/results/plots'
os.makedirs(plot_dir, exist_ok=True)

for dataset in dataset_lis:
    # One figure per dataset with one panel per model, all sharing the y axis.
    # squeeze=False keeps `axes` two-dimensional even for a single model, and
    # the width scales with the number of models (18 inches for three).
    n_models = len(model_lis)
    fig, axes = plt.subplots(1, n_models, figsize=(6 * n_models, 6),
                             sharey=True, squeeze=False)

    for ax, model in zip(axes.flat, model_lis):
        # Filter by model and dataset once instead of once per threshold.
        panel_rows = df[(df['model'] == model) & (df['dataset'] == dataset)]

        has_curve = False
        for threshold in threshold_lis:
            curve = panel_rows[panel_rows['threshold'] == threshold]
            if curve.empty:
                continue

            # Draw accuracy (x) against average token count (y) for this
            # threshold.
            # NOTE(review): points are connected in the row order of the
            # CSV; if rows are not ordered by accuracy the line doubles
            # back on itself -- confirm this ordering is intended.
            ax.plot(curve['es_acc'], curve['es_average_tokens'], marker='o',
                    label=f'p: {threshold}',
                    linewidth=LINEWIDTH, markersize=MARKERSIZE)
            has_curve = True

        # Title and axis labels; only the part of the model name before the
        # first '-' is shown.
        model_name = model.split('-')[0]
        ax.set_title(f'Model: {model_name}, Dataset: {dataset}')
        ax.set_xlabel('Accuracy')
        ax.set_ylabel('Average Tokens')
        # An empty panel has nothing to list; skipping the call avoids
        # matplotlib's "no artists with labels" warning.
        if has_curve:
            ax.legend(loc='upper left')
        ax.grid(True)

    # Operate on this figure explicitly rather than pyplot's implicit
    # "current figure".
    fig.tight_layout()
    # Save the image.
    fig.savefig(os.path.join(plot_dir, f'{dataset}_average_tokens.png'))
    # Release the figure; otherwise pyplot keeps every figure alive until
    # the process exits.
    plt.close(fig)
