import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# ==============================
# 1. Plot style (publication conventions)
# ==============================
# Set all font-related defaults in a single call: bold serif text resolved to
# Times New Roman, and axes.unicode_minus switched off.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'axes.unicode_minus': False,
    'font.weight': 'bold',
    'axes.labelweight': 'bold',
    'axes.titleweight': 'bold',
})

# ==============================
# 2. Load and merge data
# ==============================
# Composite scores (output of an earlier analysis step).
perf_file = "综合得分分析结果.xlsx"
df_perf = pd.read_excel(perf_file, sheet_name="Final_Score")
# Whatever the first column is called in the sheet, use it as the task key.
df_perf = df_perf.rename(columns={df_perf.columns[0]: "Tasks"})

# Execution time and cost (from the user-specified analysis workbook).
# Passing a list of sheet names returns a dict keyed by sheet name.
cost_time_file = "Complexity_Cost_Time_Analysis.xlsx"
raw_sheets = pd.read_excel(
    cost_time_file,
    sheet_name=["Raw_Time_With_Complexity", "Raw_Cost_With_Complexity"],
)
df_time_raw = raw_sheets["Raw_Time_With_Complexity"]
df_cost_raw = raw_sheets["Raw_Cost_With_Complexity"]

# Reshape to long format so the tables can be merged
def to_long(df, val_name):
    """Melt a wide per-model table into one row per (task, model) pair.

    "Tasks" (and "Complexity", when that column exists) are kept as
    identifier columns. Every other column is treated as a model: its name
    goes into the "Model" column and its values into ``val_name``.

    Args:
        df: Wide DataFrame containing a "Tasks" column.
        val_name: Name of the value column in the result.

    Returns:
        The melted DataFrame.
    """
    key_cols = ["Tasks"]
    if "Complexity" in df.columns:
        key_cols.append("Complexity")

    measured_cols = []
    for col in df.columns:
        if col not in key_cols:
            measured_cols.append(col)

    return pd.melt(
        df,
        id_vars=key_cols,
        value_vars=measured_cols,
        var_name="Model",
        value_name=val_name,
    )

# Score table -> long format. A "Complexity" column, if present, is neither
# an identifier nor a model here, so it is left out of the melt entirely.
score_model_cols = [c for c in df_perf.columns if c not in ("Tasks", "Complexity")]
perf_long = df_perf.melt(id_vars="Tasks", value_vars=score_model_cols,
                         var_name="Model", value_name="Score")
time_long = to_long(df_time_raw, "Time")
cost_long = to_long(df_cost_raw, "Cost")

# Combine everything: keep only (task, model) pairs present in all three tables.
merged_df = perf_long.merge(time_long, on=["Tasks", "Model"], how="inner")
merged_df = merged_df.merge(cost_long[["Tasks", "Model", "Cost"]], on=["Tasks", "Model"], how="inner")

# Drop rows with any missing value.
merged_df = merged_df.dropna()

# Identify our own model (name contains "Ours" or "R3") so it can be
# highlighted. Fail with a clear message instead of a bare IndexError when
# nothing matches, e.g. because the inner joins above left no rows.
ours_candidates = [
    m for m in merged_df["Model"].unique()
    if "Ours" in str(m) or "R3" in str(m)
]
if not ours_candidates:
    raise ValueError(
        "No model whose name contains 'Ours' or 'R3' was found in the merged "
        f"data ({len(merged_df)} rows; models: {list(merged_df['Model'].unique())}). "
        "Check that task and model names match across the score, time and cost sheets."
    )
ours_name = ours_candidates[0]

# ==============================
# 3. Plotting (two panels: Score vs Time & Score vs Cost)
# ==============================
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16, 7))

# Custom palette: our model is red; every other model takes the next unused
# colour of seaborn's "colorblind" palette, in order of first appearance.
models = merged_df["Model"].unique()
standard_colors = sns.color_palette("colorblind", len(models))
_baseline_colors = iter(standard_colors)
palette = {
    m: "#D62728" if m == ours_name else next(_baseline_colors)
    for m in models
}

def setup_plot(ax, x_col, x_label, title):
    """Draw one score-versus-resource scatter panel on ``ax``.

    Reads the module-level ``merged_df``, ``palette`` and ``models``. Every
    row of ``merged_df`` is drawn as a faint marker and each model's mean
    position is overlaid as a large star. The x axis is logarithmic with
    major ticks labelled as powers of ten; the y axis spans 0 to 1.05. Any
    legend created on ``ax`` is removed before returning.

    Args:
        ax: Matplotlib axes to draw on.
        x_col: Column of ``merged_df`` plotted on the x axis.
        x_label: Text of the x-axis label.
        title: Panel title.
    """
    from matplotlib.ticker import FuncFormatter, LogLocator

    # Background: one semi-transparent marker per row.
    sns.scatterplot(data=merged_df, x=x_col, y="Score", hue="Model", style="Model",
                    palette=palette, s=120, alpha=0.35, ax=ax, edgecolor='none')

    # Per-model mean as a large star. These are drawn by hand and carry no
    # label, so they add nothing to the legend entries.
    means = merged_df.groupby("Model")[[x_col, "Score"]].mean()
    for m in models:
        center = means.loc[m]
        ax.scatter(center[x_col], center["Score"], color=palette[m], marker='*',
                   s=800, edgecolor='black', linewidth=1.5, zorder=5)

    # Log-scaled x axis with decade ticks rendered as 10^n.
    ax.set_xscale('log')
    ax.xaxis.set_major_locator(LogLocator(base=10.0, numticks=10))

    def log_formatter(val, pos=None):
        if val <= 0:
            return "0"
        # Major ticks sit on decades, so round to the nearest integer
        # exponent; flooring would print the decade below whenever log10
        # returns a value a hair under the integer.
        exponent = int(np.round(np.log10(val)))
        return f"$10^{{{exponent}}}$"

    ax.xaxis.set_major_formatter(FuncFormatter(log_formatter))

    # Title and axis labels (large, bold).
    font_label_2d = {'family': 'Times New Roman', 'weight': 'bold', 'size': 20}
    font_title_2d = {'family': 'Times New Roman', 'weight': 'bold', 'size': 22}

    ax.set_title(title, fontdict=font_title_2d, pad=20)
    ax.set_xlabel(x_label, fontdict=font_label_2d, labelpad=12)
    ax.set_ylabel("Comprehensive Score", fontdict=font_label_2d, labelpad=12)

    ax.grid(True, which="major", linestyle='--', alpha=0.6, linewidth=1.0)
    ax.set_ylim(0, 1.05)

    # Tick label font. The size goes through tick_params so it also applies
    # to ticks created after this call; family and weight are set on the
    # label objects that exist now.
    ax.tick_params(axis='both', which='major', labelsize=18)
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontname('Times New Roman')
        label.set_weight('bold')

    # Drop the per-panel legend; a shared one is built at figure level.
    panel_legend = ax.get_legend()
    if panel_legend is not None:
        panel_legend.remove()

# Panel (a): score vs. execution time
setup_plot(ax1, "Time", "Execution Time (min)", "(a) Score vs. Execution Time")

# Panel (b): score vs. execution cost
setup_plot(ax2, "Cost", "Execution Cost (tokens)", "(b) Score vs. Execution Cost")

# ==============================
# 3.1 Shared legend (top centre)
# ==============================
# Take the legend entries from the first panel.
handles, labels = ax1.get_legend_handles_labels()

# Legend order: every other model first, our model last.
hl_dict = dict(zip(labels, handles))
sorted_model_names = [m for m in models if m != ours_name] + [ours_name]

# Keep only entries whose label is a model name; this filters out any extra
# entries (such as titles) that seaborn may have produced.
final_labels = [m for m in sorted_model_names if m in hl_dict]
final_handles = [hl_dict[m] for m in final_labels]

# One-row legend whose lower-centre point is anchored at (0.5, 0.92), with
# enlarged font and markers drawn 2.5x larger than in the plot.
leg = fig.legend(
    handles=final_handles,
    labels=final_labels,
    loc='lower center',
    bbox_to_anchor=(0.5, 0.92),
    ncol=len(final_labels),
    prop={'family': 'Times New Roman', 'weight': 'bold', 'size': 18},
    frameon=False,
    columnspacing=1.5,
    markerscale=2.5,
)

fig.tight_layout()
# Pull the top of the subplots down to leave room for the legend.
fig.subplots_adjust(top=0.82)

# Save the 2D figure.
output_2d = "Performance_Cost_Tradeoff_Scatter.png"
fig.savefig(output_2d, dpi=600, bbox_inches='tight')
print(f"✅ 2D 效能-成本权衡散点图已生成: {output_2d}")

# ==============================
# 4. 3D scatter plot (ICML-style figure)
# ==============================
# Model order for centroids and legend: our model goes last.
models_sorted = [m for m in models if m != ours_name] + [ours_name]

# Work in log10 space. The small offsets (+1e-2 for time, +1 for cost) mean a
# zero value still maps to a finite log10.
# NOTE(review): the 10^n tick labels therefore describe the offset values,
# which differ slightly from the raw ones for small inputs.
merged_df["Log_Time"] = np.log10(merged_df["Time"] + 1e-2)
merged_df["Log_Cost"] = np.log10(merged_df["Cost"] + 1)
agg_df_3d = merged_df.groupby("Model")[["Log_Time", "Log_Cost", "Score"]].mean().reset_index()

fig_3d = plt.figure(figsize=(10, 10))  # square canvas
ax_3d = fig_3d.add_subplot(111, projection='3d')

# A. Background points, one faint dot per row.
# The colour is passed as `color=` rather than `c=`: a bare RGB tuple given
# through `c` is interpreted as values to colormap when its length equals the
# number of points, so a model with exactly three rows would be miscoloured.
for m in models:
    sub = merged_df[merged_df["Model"] == m]
    ax_3d.scatter(sub["Log_Cost"], sub["Log_Time"], sub["Score"],
                  color=palette[m], marker='o', s=45, alpha=0.35,
                  edgecolor='none', depthshade=False)

# B. Per-model centroid plus a dashed drop line down to Score = 0.
legend_elements_3d = []
for m in models_sorted:
    row = agg_df_3d[agg_df_3d["Model"] == m].iloc[0]
    sc = ax_3d.scatter(row["Log_Cost"], row["Log_Time"], row["Score"],
                       color=palette[m], marker='o', s=450, alpha=1.0,
                       edgecolor='white', linewidth=1.5, label=m, depthshade=False)
    legend_elements_3d.append(sc)

    ax_3d.plot([row["Log_Cost"], row["Log_Cost"]], [row["Log_Time"], row["Log_Time"]], [0, row["Score"]],
               color=palette[m], linestyle='--', linewidth=2.0, alpha=0.7)

# C. Axis styling: fully transparent panes.
ax_3d.xaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
ax_3d.yaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))
ax_3d.zaxis.set_pane_color((1.0, 1.0, 1.0, 0.0))

# Grid line colour and width. This goes through the private `_axinfo` dict of
# each 3D axis.
grid_style_3d = {'color': (0.5, 0.5, 0.5, 0.3), 'linewidth': 0.8}
ax_3d.xaxis._axinfo["grid"].update(grid_style_3d)
ax_3d.yaxis._axinfo["grid"].update(grid_style_3d)
ax_3d.zaxis._axinfo["grid"].update(grid_style_3d)

font_label_3d = {'family': 'Times New Roman', 'weight': 'bold', 'size': 16}
font_tick_3d = {'family': 'Times New Roman', 'weight': 'bold', 'size': 14}

ax_3d.set_xlabel("Tokens Cost", fontdict=font_label_3d, labelpad=15)
ax_3d.set_ylabel("Execution Time(min)", fontdict=font_label_3d, labelpad=15)
ax_3d.set_zlabel("Comprehensive Score", fontdict=font_label_3d, labelpad=10)
ax_3d.set_zlim(0, 1.05)

# Show log-scale ticks as $10^n$
def set_log_ticks_3d(axis_min, axis_max, ax_setter, ax_tick_setter):
    """Place integer-exponent ticks on an axis and label them as powers of ten.

    Ticks are put at every integer from floor(axis_min) to ceil(axis_max),
    inclusive. Label text is styled with the module-level ``font_tick_3d``.

    Args:
        axis_min: Smallest value on the axis.
        axis_max: Largest value on the axis.
        ax_setter: Callable that receives the tick positions.
        ax_tick_setter: Callable that receives the label strings plus the
            font settings as keyword arguments.
    """
    lowest = np.floor(axis_min)
    highest = np.ceil(axis_max)
    exponents = np.arange(lowest, highest + 1)
    ax_setter(exponents)

    # Mathtext exponent form, e.g. "$10^{3}$".
    tick_text = []
    for exponent in exponents:
        tick_text.append("$10^{{{}}}$".format(int(exponent)))
    ax_tick_setter(tick_text, **font_tick_3d)

# Decade ticks on the cost (x) and time (y) axes, spanning the range of the
# log-transformed data. Lower cost and lower time are the better direction.
for log_col, place_ticks, label_ticks in (
    ("Log_Cost", ax_3d.set_xticks, ax_3d.set_xticklabels),
    ("Log_Time", ax_3d.set_yticks, ax_3d.set_yticklabels),
):
    log_values = merged_df[log_col]
    set_log_ticks_3d(log_values.min(), log_values.max(), place_ticks, label_ticks)

# Score axis: fixed ticks every 0.2, printed with one decimal.
z_ticks_3d = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
ax_3d.set_zticks(z_ticks_3d)
ax_3d.set_zticklabels([format(z, ".1f") for z in z_ticks_3d], **font_tick_3d)

# Camera angle.
ax_3d.view_init(elev=25, azim=130)

# D. Legend at the top centre of the axes. With more than three models the
# entries are split over roughly two rows; otherwise they share one row.
n_models = len(models_sorted)
legend_cols = n_models // 2 if n_models > 3 else n_models
ax_3d.legend(
    handles=legend_elements_3d,
    labels=models_sorted,
    loc='upper center',
    bbox_to_anchor=(0.5, 1.0),
    ncol=legend_cols,
    columnspacing=2,      # gap between columns
    labelspacing=1.0,     # gap between rows
    handletextpad=1.5,    # gap between marker and text
    prop={'family': 'Times New Roman', 'weight': 'bold', 'size': 14},
    frameon=False,
)

fig_3d.tight_layout()
output_3d = "Performance_Cost_Tradeoff_3D.png"
# Note: margins of saved 3D figures may need manual adjustment.
fig_3d.savefig(output_3d, dpi=600, bbox_inches='tight', pad_inches=0.1)
print(f"✅ 3D 效能-成本权衡散点图已优化: {output_3d}")

# plt.show()
