from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Use Avenir for all figure text. NOTE(review): if Avenir is not installed,
# matplotlib warns ("findfont") and falls back to its default font.
plt.rcParams['font.family'] = 'Avenir'

# Per-model integer scores for each "<aspect> <difficulty>" column, where the
# difficulty is Easy/Medium/Hard. Each list has one entry per model, in the same
# order as 'Model'. Values span 1-11 across the 11 models and include ties
# (e.g. two 4s in 'Basic Understanding Easy'). Presumably these are ranks with
# 1 = best — TODO confirm against the source table.
data = {'Model': ['GPT-4o', 'GPT-4o mini', 'Gemini-Flash', 'Gemini-Pro', 'GLM-4v-Plus', 'Qwen2-VL', 'Claude-3.5', 'Claude-3', 'Intern_VL-2.5', 'Gemma-3', 'Llama-3.2'], 'Atmospheric Understanding Easy': [5, 8, 6, 1, 9, 7, 4, 10, 2, 3, 11], 'Atmospheric Understanding Medium': [4, 8, 6, 5, 7, 6, 1, 10, 2, 3, 9], 'Atmospheric Understanding Hard': [3, 10, 7, 6, 8, 5, 2, 11, 1, 4, 9], 'Basic Understanding Easy': [5, 8, 4, 4, 7, 6, 1, 9, 3, 2, 10], 'Basic Understanding Medium': [3, 8, 10, 5, 7, 4, 2, 9, 1, 6, 11], 'Basic Understanding Hard': [3, 8, 6, 4, 6, 1, 1, 7, 2, 5, 9], 'Reasoning Capacity Easy': [4, 9, 6, 2, 8, 7, 5, 10, 1, 3, 11], 'Reasoning Capacity Medium': [3, 9, 7, 2, 8, 1, 4, 11, 5, 6, 10], 'Reasoning Capacity Hard': [4, 7, 10, 6, 9, 2, 5, 11, 1, 3, 8], 'Semantic Understanding Easy': [1, 8, 5, 4, 9, 7, 2, 10, 6, 3, 11], 'Semantic Understanding Medium': [2, 8, 6, 5, 7, 3, 2, 8, 1, 4, 9], 'Semantic Understanding Hard': [5, 8, 4, 2, 7, 4, 3, 10, 6, 1, 9], 'Spatial Understanding Easy': [7, 9, 6, 5, 8, 4, 1, 10, 3, 2, 11], 'Spatial Understanding Medium': [6, 11, 8, 2, 7, 5, 3, 9, 4, 1, 10], 'Spatial Understanding Hard': [7, 10, 6, 3, 8, 5, 4, 9, 2, 1, 11]}

# Wide table: one row per model, one column per aspect/difficulty pair.
df = pd.DataFrame(data)

# Top-to-bottom row order shared by every heatmap panel. It currently matches
# the order of ``data['Model']``; edit this list to reorder the rows. Every
# entry must be a model present in ``df`` (``.loc`` raises KeyError otherwise).
new_model_order = [
    'GPT-4o',
    'GPT-4o mini',
    'Gemini-Flash',
    'Gemini-Pro',
    'GLM-4v-Plus',
    'Qwen2-VL',
    'Claude-3.5',
    'Claude-3',
    'Intern_VL-2.5',
    'Gemma-3',
    'Llama-3.2',
]

# Reorder rows by model name, then restore 'Model' as an ordinary column.
df = (
    df.set_index('Model')
      .loc[new_model_order]
      .reset_index()
)

# Panel title -> the three source columns (Easy, Medium, Hard) shown in that
# panel. Insertion order fixes the left-to-right order of the panels.
aspects = {
    title: [f'{prefix} {level}' for level in ('Easy', 'Medium', 'Hard')]
    for title, prefix in (
        ('Basic.', 'Basic Understanding'),
        ('Spatial.', 'Spatial Understanding'),
        ('Reasoning.', 'Reasoning Capacity'),
        ('Semantic.', 'Semantic Understanding'),
        ('Atmospheric.', 'Atmospheric Understanding'),
    )
}

# Reversed "Blues" colormap: smaller values are drawn darker (presumably the
# better ranks). The variable name is historical; this is Blues_r, not cubehelix.
cubehelix_cmap = sns.color_palette("Blues_r", as_cmap=True)

# Use a single colour scale for all panels. Without explicit vmin/vmax, seaborn
# normalises each heatmap to its own min/max, so the same value could get
# different shades in different panels. The one colorbar (drawn on the last
# panel only) would then describe just that panel.
cell_values = df[[col for cols in aspects.values() for col in cols]].to_numpy()
vmin, vmax = cell_values.min(), cell_values.max()

n_panels = len(aspects)

# One heatmap per aspect, side by side. The short figure height keeps the PDF
# compact. squeeze=False keeps `axs` indexable even with a single aspect.
fig, axs = plt.subplots(1, n_panels, figsize=(12, 3.5), squeeze=False)
axs = axs[0]

for i, (aspect, columns) in enumerate(aspects.items()):
    ax = axs[i]
    is_first = i == 0
    sns.heatmap(
        df[columns],
        annot=True,
        cmap=cubehelix_cmap,
        vmin=vmin,
        vmax=vmax,
        # Only the left-most panel carries the model names.
        yticklabels=(df['Model'] if is_first else False),
        ax=ax,
        # A single colorbar, attached to the right-most panel.
        cbar=(i == n_panels - 1),
        annot_kws={"size": 11},  # font size of the in-cell numbers
        square=False,
        linewidths=0.5,  # thin grid lines between cells
    )

    ax.set_title(aspect, fontsize=13)
    # Short difficulty labels derived from the column names ('... Easy' -> 'E',
    # '... Medium' -> 'M', '... Hard' -> 'H'), so the label count always
    # matches the number of columns in this panel.
    ax.set_xticklabels(
        [col.split()[-1][0] for col in columns], rotation=0, fontsize=10
    )

    if is_first:
        for label in ax.get_yticklabels():
            label.set_ha('right')
            label.set_rotation(315)
            label.set_fontstyle('italic')
            label.set_fontsize(11)
    else:
        ax.set_yticks([])

# tight_layout computes the subplot margins itself; an earlier subplots_adjust
# call would simply be overridden. The rect confines the subplots to the lower
# 92% of the figure.
fig.tight_layout(rect=[0, 0, 1, 0.92])

# savefig does not create missing directories, so make sure 'figure/' exists.
output_path = Path('figure') / 'heatmap_compressed.pdf'
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, dpi=500, bbox_inches='tight')
# plt.show()  # uncomment to preview the figure interactively