import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# Raw data: for each of the four pairwise comparisons, the percentage shown
# for PyraMotion and the percentage shown for the compared method. Keeping the
# label next to its values prevents the two lists from drifting out of sync.
_comparisons = [
    ("PyraMotion vs CaMN", 64.5, 35.5),
    ("PyraMotion vs EMAGE", 90.3, 9.7),
    ("PyraMotion vs APVQ-VAE Rec", 30.3, 69.7),
    ("PyraMotion vs GT", 10.0, 90.0),
]

# Group (x-axis) labels and the matching per-group value rows, in the same order.
groups = [label for label, _, _ in _comparisons]
data = [[ours, theirs] for _, ours, theirs in _comparisons]

# The two stacked categories: PyraMotion first (bottom), the compared method second.
categories = ['PyraMotion', 'Comparison Method']

# One row per comparison, one column per category.
df = pd.DataFrame(data, index=groups, columns=categories)

# Colour scheme: PyraMotion segments share one fixed blue; the comparison
# segments get one orange shade per group.
base_blue = "#1f77b4"  # fixed blue
# sns.light_palette blends from a near-white tint up to the base colour, so its
# first entry is almost invisible on the white "whitegrid" background and
# against the white bar edges. Generate one extra shade and drop that lightest
# one, leaving len(groups) clearly orange shades that end at the base colour.
orange_palette = sns.light_palette("#FF4500",  # deep orange-red base colour
                                   n_colors=len(groups) + 1,
                                   input="rgb",
                                   reverse=False)[1:]

# Create the canvas. The seaborn theme is applied before the figure exists so
# that the whole figure, not only the axes, is created under the theme's
# rcParams. The figure and its single axes are then created together.
sns.set_theme(style="whitegrid")
fig, ax = plt.subplots(figsize=(14, 8))

# Draw the stacked bars: one ax.bar call per category, each placed on top of
# the running totals of the categories drawn before it.
bottom = np.zeros(len(groups))  # running top edge of every stack
for position, category in enumerate(categories):
    # The first category is drawn in one solid blue; later categories use the
    # orange palette, one shade per group.
    if position == 0:
        fill = base_blue
    else:
        fill = orange_palette

    heights = df[category]
    bars = ax.bar(x=groups, height=heights, bottom=bottom,
                  color=fill, edgecolor="white", width=0.7)

    # Put a bold percentage label at the centre of every segment.
    for rect, pct in zip(bars, heights):
        centre_x = rect.get_x() + rect.get_width() / 2
        centre_y = rect.get_y() + rect.get_height() / 2
        ax.text(centre_x, centre_y, f"{pct}%",
                ha='center', va='center', color='black',
                fontsize=24, fontweight='bold')

    # Raise the stack base for the next category.
    bottom += heights.values

# Axis settings. Each group's two percentages add up to 100, so the y-range is
# fixed to [0, 100].
plt.ylim(0, 100)
plt.ylabel("Percentage (%)", fontsize=24)
plt.yticks(fontsize=20)
# Anchor each rotated label at its right end, on its own tick. With the default
# centre alignment the middle of a long slanted label sits under the tick,
# which makes it unclear which bar the label names.
plt.xticks(fontsize=20, rotation=20, ha="right", rotation_mode="anchor")

# Custom legend: one swatch per category. The comparison bars use several
# orange shades, so one representative shade from the palette stands in for
# all of them.
legend_handles = [
    plt.Rectangle((0, 0), 1, 1, fc=base_blue, edgecolor="white"),
    plt.Rectangle((0, 0), 1, 1, fc=orange_palette[2], edgecolor="white"),
]
# Every stack reaches the top of the fixed 0-100 y-range, so any legend placed
# inside the axes covers part of a bar. Put it just above the axes, in a
# single row.
plt.legend(legend_handles,
           categories,
           loc='lower center',
           bbox_to_anchor=(0.5, 1.0),
           ncol=len(categories),
           fontsize=20)

plt.tight_layout()
plt.show()