import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# Global plot font: DejaVu Sans, 14 pt.
plt.rcParams.update({'font.size': 14, 'font.family': 'DejaVu Sans'})

# Path of the JSON results file.
file_path = "data.json"

# Read and parse the JSON file. JSON text is UTF-8 (RFC 8259); decode it
# explicitly instead of relying on the platform's default locale encoding,
# which can mis-read or reject non-ASCII names (e.g. on Windows).
with open(file_path, 'r', encoding='utf-8') as file:
    data = json.load(file)

# Flatten the nested mapping aspect -> difficulty -> model -> score dict into
# one record per (aspect, difficulty, model). A model entry without an
# "Objective Score" key is counted as 0.
scores = []
for aspect, difficulties in data.items():
    for difficulty, models in difficulties.items():
        for model, values in models.items():
            scores.append({
                "Aspect": aspect,
                "Difficulty": difficulty,
                "Model": model,
                "Objective Score": values.get("Objective Score", 0),
            })

# Convert the records to a DataFrame.
df = pd.DataFrame(scores)

# Aspect names use underscores between words; display them with spaces.
df['Aspect'] = df['Aspect'].str.replace('_', ' ')

# Pivot so every difficulty level becomes its own column, one row per
# (Aspect, Model). pivot_table averages duplicate entries (its default
# aggfunc is mean), e.g. two aspects that differ only by '_' vs ' '.
pivot_df = df.pivot_table(index=["Aspect", "Model"], columns="Difficulty",
                          values="Objective Score").reset_index()

# The differences below need all three difficulty columns; fail with a clear
# message instead of a bare KeyError on whichever column is missing first.
_missing_levels = [level for level in ("easy", "medium", "hard")
                   if level not in pivot_df.columns]
if _missing_levels:
    raise KeyError(f"{file_path} has no scores for difficulty level(s): {_missing_levels}")

# Score differences between difficulty levels.
pivot_df['Medium-Easy'] = pivot_df['medium'] - pivot_df['easy']
pivot_df['Hard-Medium'] = pivot_df['hard'] - pivot_df['medium']
pivot_df['Hard-Easy'] = pivot_df['hard'] - pivot_df['easy']

def clean_model_name(name):
    """Map a raw model name to the label shown on the x axis.

    No shortening is applied: the complete model name is kept, so the
    argument is returned unchanged.
    """
    label = name  # keep the full model name as the display label
    return label

# Display label for each model.
pivot_df["Short Model"] = pivot_df["Model"].apply(clean_model_name)

# The three pairwise score differences, in bar/legend order.
difficulty_changes = ["Medium-Easy", "Hard-Medium", "Hard-Easy"]

# Long format: one row per (Aspect, Short Model, Difficulty Change).
melted_df = pivot_df.melt(id_vars=["Aspect", "Short Model"], value_vars=difficulty_changes,
                          var_name="Difficulty Change", value_name="Score Difference")

# Keep one row per combination (several models could share a display label).
melted_df_unique = melted_df.drop_duplicates(subset=["Aspect", "Short Model", "Difficulty Change"])

# Every Aspect x Short Model x Difficulty Change combination, so each subplot
# has a bar slot for every model.
complete_index = pd.MultiIndex.from_product(
    [pivot_df["Aspect"].unique(), pivot_df["Short Model"].unique(), difficulty_changes],
    names=["Aspect", "Short Model", "Difficulty Change"],
)

# Reindex onto the full grid; combinations absent from the data become 0.
melted_df_filled = (
    melted_df_unique.set_index(["Aspect", "Short Model", "Difficulty Change"])
    .reindex(complete_index, fill_value=0)
    .reset_index()
)

# seaborn style presets also set font.family, so the DejaVu Sans font is
# re-applied after set_style.
sns.set_style("whitegrid")
plt.rcParams.update({'font.size': 14, 'font.family': 'DejaVu Sans'})

# Preferred subplot order of aspect categories, matched as substrings of the
# lower-cased aspect name.
desired_order = ['basic', 'spatial', 'semantic', 'reasoning', 'atmospheric']


def _aspect_rank(aspect_name):
    """Return the index of the first desired_order keyword found in aspect_name.

    Aspects that match no keyword sort after all known categories instead of
    raising IndexError.
    """
    lowered = aspect_name.lower()
    return next((rank for rank, keyword in enumerate(desired_order) if keyword in lowered),
                len(desired_order))


unique_aspects = sorted(melted_df_filled["Aspect"].unique(), key=_aspect_rank)

# x-axis model order: the order in data['Atmospheric.']['easy'] first (when
# that entry exists), then any model that only appears elsewhere in the JSON,
# in first-seen order, so no model is silently dropped from the subplots.
model_order = list(data.get('Atmospheric.', {}).get('easy', {}))
for difficulties in data.values():
    for models in difficulties.values():
        model_order.extend(models)
model_order = list(dict.fromkeys(clean_model_name(model) for model in model_order))

# Five subplots per row; at least the original 3 rows, more if needed.
n_cols = 5
n_rows = max(3, -(-len(unique_aspects) // n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(40, 10 * n_rows), sharey=True)

colors = sns.color_palette("husl", len(difficulty_changes))
bar_width = 0.33

for i, aspect in enumerate(unique_aspects):
    row, col = divmod(i, n_cols)
    ax = axes[row, col]
    aspect_df = melted_df_filled[melted_df_filled["Aspect"] == aspect]

    # Models of this aspect, in model_order.
    present_models = set(aspect_df["Short Model"])
    x_labels = [model for model in model_order if model in present_models]
    x_pos = list(range(len(x_labels)))

    # One group of bars per model, one bar per difficulty change.
    for col_idx, difficulty in enumerate(difficulty_changes):
        subset = aspect_df[aspect_df["Difficulty Change"] == difficulty]
        subset = subset.set_index("Short Model").loc[x_labels].reset_index()  # reorder to x_labels
        adjusted_x_pos = [x + col_idx * bar_width for x in x_pos]
        ax.bar(adjusted_x_pos, subset["Score Difference"], width=bar_width,
               color=colors[col_idx], label=difficulty)

    # Tick under the middle bar of each group.
    ax.set_xticks([x + bar_width for x in x_pos])
    ax.set_xticklabels(x_labels, rotation=45, ha='right', fontsize=32)
    ax.set_title(f"{aspect}", fontsize=40)

    if col == 0:
        ax.set_ylabel("Score Difference", fontsize=28)

    ax.tick_params(axis='y', labelsize=22, colors='black')
    ax.tick_params(axis='x', colors='black')
    ax.grid(True, axis='y', linestyle='--', alpha=0.7)

    # Black frame around each subplot.
    for spine in ax.spines.values():
        spine.set_edgecolor('black')

# Remove the grid cells left without an aspect.
for unused_ax in axes.flat[len(unique_aspects):]:
    fig.delaxes(unused_ax)

# Figure-level legend at the top, built from the bars' own labels so each
# colour is paired with the right difficulty change.
handles, labels = axes.flat[0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', ncol=len(difficulty_changes),
           fontsize=40, bbox_to_anchor=(0.5, 0.95))

# tight_layout sets the left/right margins; subplots_adjust then overrides
# top/bottom and the spacing between subplots.
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.subplots_adjust(top=0.88, bottom=0.07, hspace=0.2, wspace=0.1)
plt.savefig('variance.pdf', dpi=500)
# plt.savefig('variance.png')
# plt.show()