# import pandas as pd

# # Read the data, skipping the first row
# df_original = pd.read_csv('document/basic_understanding/all_scores.csv', skiprows=1)
# df_balanced = pd.read_csv('document/basic_understanding/all_scores_balanced.csv', skiprows=1)

# # Make sure both tables are sorted into the same order so their rows line up
# df_original = df_original.sort_values(['Difficulty', 'Model'])
# df_balanced = df_balanced.sort_values(['Difficulty', 'Model'])

# # Compute the relative change
# relative_change = pd.DataFrame({
#     'Difficulty': df_original['Difficulty'],
#     'Model': df_original['Model'],
#     'Relative Change': (df_balanced['Objective Score'] - df_original['Objective Score']) / df_original['Objective Score'],
#     'Alignment Rate Change': (df_balanced['Alignment Rate'] - df_original['Alignment Rate']) / df_original['Alignment Rate']
# })

# # Save the result
# output_path = 'document/basic_understanding/relative_change.csv'
# relative_change.to_csv(output_path, index=False)

# # Print the result to verify it
# print(relative_change.head())


import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Draw a grouped bar chart of the 'Relative Change' column from
# relative_change.csv: one group per model, one bar per difficulty level,
# with the x axis ticks placed on top. The figure is saved as a PDF and shown.
plt.rcParams['font.family'] = 'avenir'

df = pd.read_csv('document/basic_understanding/relative_change.csv')

# plt.style.use('ggplot')
colors = ['#c9def4', '#f5ccd4', '#b8a4c9']
plt.figure(figsize=(12, 6))

# Models in order of first appearance; this order defines the bar groups
# and the tick labels.
models = df['Model'].unique()
difficulties = ['easy', 'medium', 'hard']

x = np.arange(len(models))
width = 0.25

for idx, difficulty in enumerate(difficulties):
    data = df[df['Difficulty'] == difficulty]
    # Look every bar height up by model name instead of relying on the row
    # order of the CSV. This keeps each bar under its own tick label even if
    # the rows are ordered differently per difficulty, and a model with no row
    # for this difficulty becomes NaN (no bar) rather than a length mismatch.
    heights = data.set_index('Model')['Relative Change'].reindex(models)
    plt.bar(x + idx*width, heights.to_numpy(),
            width, label=difficulty, color=colors[idx], alpha=0.7)


# plt.xlabel('Models', fontsize=12)
# plt.ylabel('Relative Change', fontsize=15)
# plt.title('Relative Performance Change by Model and Difficulty', fontsize=14, pad=20)
# x + width is the centre of each three-bar group (the middle bar's position).
plt.xticks(x + width, models, rotation=30, ha='center', fontsize=18)
plt.yticks(fontsize=18)
plt.grid(True, linestyle='--', alpha=0.3)
plt.legend(fontsize=18)

# Move the x axis (ticks and label position) to the top of the plot
ax = plt.gca()
ax.xaxis.set_label_position('top')
ax.xaxis.tick_top()

plt.tight_layout()
plt.savefig('figure/position_bias.pdf', 
            bbox_inches='tight', dpi=500)
plt.show()