import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Global plot styling.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("tab10")
plt.rcParams.update({
    'font.size': 12,
    # Every title/label in this figure is Chinese. The sans-serif fonts that
    # are active after applying the style sheet typically have no CJK glyphs,
    # which makes the text render as empty boxes. Prefer whichever of these
    # common CJK-capable fonts is installed (macOS / Windows / Linux) and keep
    # the previously configured fonts as the fallback.
    'font.sans-serif': [
        'Arial Unicode MS', 'PingFang SC', 'Heiti TC', 'Songti SC',
        'Microsoft YaHei', 'SimHei',
        'Noto Sans CJK SC', 'WenQuanYi Micro Hei',
    ] + list(plt.rcParams['font.sans-serif']),
    # Many CJK fonts lack U+2212 (the typographic minus); use ASCII '-' so a
    # negative tick label never renders as a missing glyph.
    'axes.unicode_minus': False,
})

# Load each model's score summary and tag every row with the model name
# so the source is still identifiable after the frames are stacked.
dalle_data = pd.read_csv(
    '/Users/wad3/Downloads/Research/AutoBench-V-private/document/results/dall_e_3_score_summary.csv'
).assign(model='DALL-E 3')
sd_data = pd.read_csv(
    '/Users/wad3/Downloads/Research/AutoBench-V-private/document/results/stable_diffusion_score_summary.csv'
).assign(model='Stable Diffusion 3.5')
flux_data = pd.read_csv(
    '/Users/wad3/Downloads/Research/AutoBench-V-private/document/results/flux_score_summary.csv'
).assign(model='Flux')

# Stack the three summaries into a single frame (row order: DALL-E 3,
# Stable Diffusion 3.5, Flux).
model_frames = [dalle_data, sd_data, flux_data]
combined_data = pd.concat(model_frames)

# Short x-axis labels, keyed by the value stored in the 'model' column.
model_display = dict(zip(
    ('DALL-E 3', 'Stable Diffusion 3.5', 'Flux'),
    ('DALL-E 3', 'SD 3.5', 'Flux'),
))

# User-facing (Chinese) names for each difficulty level, keyed by the value
# stored in the 'difficulty' column.
difficulty_display = dict(easy='简单', medium='中等', hard='困难')

# Difficulty levels in display order; each one gets its own subplot.
difficulties = ['easy', 'medium', 'hard']

# A single row of three subplots that share one y-axis, so bar heights are
# directly comparable across difficulty levels.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(18, 6), sharey=True)
fig.suptitle(
    '图像生成模型在不同难度下的平均评分',
    fontsize=20,
    y=1.05,
)

# Bar colour for each model (keys match the values in the 'model' column).
# Defined once, before the loop: it is loop-invariant and is read again when
# the figure legend is built.
colors = {
    'DALL-E 3': '#1f77b4',  # blue
    'Stable Diffusion 3.5': '#ff7f0e',  # orange
    'Flux': '#2ca02c'  # green
}

# Draw one bar-chart subplot per difficulty level.
for i, difficulty in enumerate(difficulties):
    ax = axes[i]

    # Rows belonging to the current difficulty level.
    diff_data = combined_data[combined_data['difficulty'] == difficulty]
    models = diff_data['model'].tolist()

    # One x position per plotted row. Ticks and tick labels below are derived
    # from the same list, so their counts always match the bars; a fixed tick
    # count makes set_xticklabels raise ValueError whenever a difficulty does
    # not have exactly that many rows (including none at all).
    positions = list(range(len(models)))

    # Bars only, no error bars.
    bars = ax.bar(positions, diff_data['mean'],
                  color=[colors[model] for model in models],
                  width=0.6, edgecolor='black', linewidth=1)

    # Print each bar's value just above it.
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                f'{height:.3f}', ha='center', va='bottom', fontsize=14, fontweight='bold')

    # Subplot title and axis labels; the y label is only drawn on the
    # left-most subplot because the y-axis is shared.
    ax.set_title(f'{difficulty_display[difficulty]}难度', fontsize=16)
    ax.set_xlabel('模型', fontsize=14)
    if i == 0:
        ax.set_ylabel('平均评分', fontsize=14)

    # Label each bar with the model's short display name.
    ax.set_xticks(positions)
    ax.set_xticklabels([model_display[model] for model in models], fontsize=12)

    # Deliberately narrow y-range so small differences between models are
    # visible (the axis does not start at zero).
    ax.set_ylim(0.75, 0.95)

    # Horizontal grid lines only.
    ax.grid(True, axis='y', linestyle='--', alpha=0.7)

# Figure-level legend: one colour swatch per model, using the same colours as
# the bars. It is anchored by its top edge at the bottom of the figure
# (y=0 in figure coordinates) and therefore extends downwards from there.
legend_elements = [
    plt.Rectangle((0, 0), 1, 1, facecolor=colors[name], edgecolor='black', label=name)
    for name in ('DALL-E 3', 'Stable Diffusion 3.5', 'Flux')
]

fig.legend(handles=legend_elements, loc='upper center',
           bbox_to_anchor=(0.5, 0), ncol=3, fontsize=12)

# Explanatory footnote. It is top-aligned and placed below the legend row so
# the two do not overlap: the legend occupies the band just under y=0, where
# a footnote at y=-0.05 would be drawn on top of it. The text lies outside
# the figure area and is kept in the saved image by bbox_inches='tight'.
plt.figtext(0.5, -0.08,
            "注：柱形高度表示平均评分。分数范围为0-1，分数越高表示生成图像与提示的匹配度越高。",
            ha='center', va='top', fontsize=12, style='italic')

# Fit the subplots, then reserve space at the top and bottom of the figure
# for the overall title and the footnote/legend.
fig.tight_layout()
fig.subplots_adjust(bottom=0.15, top=0.85)

# Write the figure to disk; bbox_inches='tight' keeps artists that lie
# outside the figure area in the saved image.
output_path = '/Users/wad3/Downloads/Research/AutoBench-V-private/document/results/model_average_scores_simple.png'
fig.savefig(output_path, dpi=300, bbox_inches='tight')

# Display the figure interactively, then report where the image was written.
plt.show()

print(f"图表已保存至: {output_path}")