import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Global plot styling, applied to every figure drawn by this script: the
# seaborn white-grid theme, then Arial text with explicit label/tick sizes.
# The rcParams overrides run after the seaborn call, which installs its own
# font defaults.
sns.set(style="whitegrid")
plt.rcParams.update({
    'font.family': 'Arial',
    'axes.labelsize': 12,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
})

# Experiment results: for each model size, one score per alpha value
# 0.1, 0.2, ..., 0.9. Every dict has the key 'alpha' first, followed by the
# model columns in plotting order.
alphas = [step / 10 for step in range(1, 10)]

# NOTE(review): three Dolly entries (138M @ 0.9 -> 29.9, 220M @ 0.9 -> 25.9,
# 277M @ 0.1 -> 26.3) carry fewer decimals than every other value; confirm
# they are measured results rather than placeholders.
dolly_data = {
    'alpha': alphas,
    '138M': [25.1286, 25.0443, 25.4995, 24.8588, 25.6343, 24.353, 24.9178, 25.0527, 29.9],
    '220M': [25.0021, 23.8388, 25.2381, 25.1117, 24.8672, 24.353, 25.1286, 25.3056, 25.9],
    '277M': [26.3, 24.8335, 24.5975, 24.3867, 25.1201, 25.2634, 24.8588, 24.8335, 24.6059],
}

mmlu_data = {
    'alpha': alphas,
    '138M': [11.6058, 11.8224, 11.3038, 12.7539, 17.0413, 18.6096, 20.7717, 21.3898, 22.1026],
    '220M': [15.3926, 12.5552, 11.6673, 10.5641, 12.7517, 11.8979, 14.6128, 14.9872, 15.8889],
    '277M': [16.2014, 12.8679, 11.3475, 10.8897, 11.7452, 11.6877, 12.3343, 11.9199, 14.7218],
}

# One DataFrame per dataset; dict keys become columns in the same order.
dolly_df, mmlu_df = (pd.DataFrame(table) for table in (dolly_data, mmlu_data))

# Directory the PNG files are written to (editable).
output_dir = "./"

# Plotting helper
def save_single_barplot(df, model, dataset_name, color, ylim=None, ylabel='Rouge-L'):
    """Draw one model's scores as a bar chart over alpha and save it as a PNG.

    Args:
        df: DataFrame with an ``'alpha'`` column (x axis) and a score column
            named ``model`` (bar heights).
        model: Name of the score column to plot; also used in the filename.
        dataset_name: Dataset label; lower-cased and used as the filename prefix.
        color: Single bar color passed to ``seaborn.barplot``.
        ylim: Optional ``(bottom, top)`` y-axis limits; ignored when falsy.
            The two bounds should differ -- matplotlib treats equal bounds
            as singular and expands them to a narrow band around that value.
        ylabel: Y-axis label. Defaults to ``'Rouge-L'``.

    The figure is saved to ``<output_dir>/<dataset_name lower>_<model>.png``
    at 300 dpi (``output_dir`` is a module-level setting) and always closed
    afterwards.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        sns.barplot(x='alpha', y=model, data=df, color=color, ax=ax)
        ax.set_xlabel('Alpha')
        ax.set_ylabel(ylabel)
        if ylim:
            ax.set_ylim(*ylim)
        # Annotate each bar with its value, placed just above the bar top.
        for bar in ax.patches:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2, height + 0.1, f'{height:.1f}',
                    ha='center', va='bottom', fontsize=9)
        # No title is set on purpose.
        fig.tight_layout()
        # os.path.join works whether or not output_dir ends with a separator;
        # plain concatenation silently produced e.g. "outdolly_138M.png".
        filename = os.path.join(output_dir, f"{dataset_name.lower()}_{model}.png")
        fig.savefig(filename, dpi=300)
    finally:
        # Release the figure even if drawing or saving fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)

# Custom colors (adjustable): one entry per model, in the same order as `models`.
colors = ['#fc8d62', '#fc8d62', '#fc8d62']
models = ['138M', '220M', '277M']

# Y-axis limits shared by all models of a dataset so their plots are directly
# comparable. Bars start at 0. The preferred top (27 for Dolly, 31 for MMLU)
# is raised when the tallest bar plus room for its value label would not fit.
# The bottom and top must differ: an identical pair such as (27, 27) is
# singular, and matplotlib expands it to a narrow band around 27, which cuts
# the bars out of view.
LABEL_HEADROOM = 1.08
dolly_ylim = (0, max(27, dolly_df[models].to_numpy().max() * LABEL_HEADROOM))
mmlu_ylim = (0, max(31, mmlu_df[models].to_numpy().max() * LABEL_HEADROOM))

# Draw and save one figure per (model, dataset) pair.
for model, color in zip(models, colors):
    save_single_barplot(dolly_df, model, "dolly", color, ylim=dolly_ylim)
    save_single_barplot(mmlu_df, model, "mmlu", color, ylim=mmlu_ylim)
