import pandas as pd
import matplotlib.pyplot as plt
import re
from tqdm import tqdm
import os
import numpy as np
dataset = "suzuki"
os.makedirs(f"visual_outputs/{dataset}", exist_ok=True)

# ===== Read CSV 1: per-combination yield statistics =====
# The log is in long format: each row is one (key, value) pair. Every
# combination contributes one row for each of these keys:
#   "组合类别" -> combination label   "产率均值" -> mean yield
#   "最大值"   -> max yield           "最小值"   -> min yield
# The four series below are paired up purely by order of appearance.
df_yields = pd.read_csv(f"experiment_log_{dataset}_yields.csv")

combo = df_yields.loc[df_yields["key"] == "组合类别", "value"].reset_index(drop=True)
mean_yield = df_yields.loc[df_yields["key"] == "产率均值", "value"].reset_index(drop=True).astype(float)
max_yield = df_yields.loc[df_yields["key"] == "最大值", "value"].reset_index(drop=True).astype(float)
min_yield = df_yields.loc[df_yields["key"] == "最小值", "value"].reset_index(drop=True).astype(float)

# Pairing is positional, so a missing or duplicated entry for any key would
# silently shift every later value onto the wrong combination (and pad the
# tail with NaN). Refuse to continue with misaligned data.
_key_counts = {
    "组合类别": len(combo),
    "产率均值": len(mean_yield),
    "最大值": len(max_yield),
    "最小值": len(min_yield),
}
if len(set(_key_counts.values())) != 1:
    raise ValueError(
        f"experiment_log_{dataset}_yields.csv: every key must appear the same "
        f"number of times, got {_key_counts}"
    )

result = pd.DataFrame({
    "组合类别": combo,
    "产率均值": mean_yield,
    "最大值": max_yield,
    "最小值": min_yield
})
# 1-based row number for each combination.
result["序号"] = result.index + 1

# ===== Read CSV 2: per-iteration visit counts =====
df_visits = pd.read_csv(f"experiment_log_{dataset}_visits.csv")

# 1. Iteration number from the section label (e.g. "iteration-3" -> 3).
#    Nullable Int64 so rows without an iteration tag become <NA>.
df_visits["iter_num"] = (
    df_visits["section"]
    .str.extract(r"iteration-(\d+)", expand=False)
    .astype("Int64")
)

# 2. Group name from keys shaped like "Group '<name>' ...". Only these rows
#    are used downstream (they are matched against "组合类别" later); other
#    rows are dropped so that a non-numeric value on them cannot break the
#    integer parse below.
df_visits["group"] = (
    df_visits["key"].astype(str).str.extract(r"Group '([^']+)'", expand=False)
)
df_visits = df_visits[df_visits["group"].notna()].copy()

# 3. Visit count = first run of digits in the value column. astype(str) makes
#    this work whether read_csv parsed the column as text or as numbers; a
#    Group row with no digits is a data error and still raises here.
df_visits["visits"] = (
    df_visits["value"].astype(str).str.extract(r"(\d+)", expand=False).astype(int)
)

# 4. Final tidy table: iteration, group, visit count.
visit_result = df_visits[["iter_num", "group", "visits"]].rename(
    columns={"iter_num": "迭代轮次", "group": "组别", "visits": "访问次数"}
)


def _merge_visits(yields, iter_visits):
    """Attach one iteration's visit counts to the yield table and rank it.

    Args:
        yields: per-combination table with at least "组合类别" and "最大值".
        iter_visits: visit rows for a single iteration, with columns
            "组别" (combination label) and "访问次数" (visit count).

    Returns:
        A new DataFrame sorted by "最大值" descending. "访问次数" is 0 for
        combinations that have no visit row in this iteration, and "序号" is
        renumbered 1..N in sorted order (used as the bar x position).
    """
    merged = yields.merge(
        iter_visits.rename(columns={"组别": "组合类别"}),
        on="组合类别",
        how="left",
    )
    # Unvisited combinations get no match in the left join -> NaN -> 0.
    merged["访问次数"] = merged["访问次数"].fillna(0).astype(int)
    merged = merged.sort_values("最大值", ascending=False).reset_index(drop=True)
    merged["序号"] = range(1, len(merged) + 1)
    return merged


def _plot_yield_vs_visits(frame, out_path):
    """Save a chart of max/min yield (left axis) and visits (right axis).

    Each combination is a bar whose height is its max yield, with a dark-red
    whisker running down to its min yield. The visit counts are drawn as a
    red line on a twin y-axis. The figure is always closed, even on error.
    """
    fig, ax1 = plt.subplots(figsize=(12, 6))
    try:
        x = frame["序号"].to_numpy()
        y_max = frame["最大值"].to_numpy()
        y_min = frame["最小值"].to_numpy()

        # Left axis: bars at the max yield.
        ax1.bar(x, y_max, alpha=0.7, color="skyblue", label="Max")

        # For a (2, N) yerr, matplotlib reads row 0 as the distance BELOW each
        # point and row 1 as the distance ABOVE it. Row 0 = y_max - y_min and
        # row 1 = 0 therefore draw a whisker from y_max down to y_min only.
        yerr = np.vstack([y_max - y_min, np.zeros_like(y_min)])
        ax1.errorbar(x, y_max, yerr=yerr,
                     fmt="none",          # whiskers only, no data markers
                     ecolor="darkred",
                     elinewidth=2,
                     capsize=3,
                     label="Min")

        # Value labels, offset in points so placement does not depend on
        # the numeric scale of the yields.
        for xi, vmax, vmin in zip(x, y_max, y_min):
            ax1.annotate(f"{vmax:.2f}", (xi, vmax), xytext=(0, 3),
                         textcoords="offset points", ha="center", va="bottom",
                         fontsize=8, color="blue")
            ax1.annotate(f"{vmin:.2f}", (xi, vmin), xytext=(0, -4),
                         textcoords="offset points", ha="center", va="top",
                         fontsize=8, color="darkred")

        ax1.set_xlabel("Rank (by max yield)")
        ax1.set_ylabel("Yield")

        # Right axis: visit counts as a line.
        ax2 = ax1.twinx()
        ax2.plot(x, frame["访问次数"].to_numpy(), marker="o", color="red",
                 linewidth=2, label="Visits")
        ax2.set_ylabel("Visits", color="red")
        ax2.tick_params(axis="y", labelcolor="red")

        # One legend for both axes; drawn on ax2 so the line does not cover it.
        h1, l1 = ax1.get_legend_handles_labels()
        h2, l2 = ax2.get_legend_handles_labels()
        ax2.legend(h1 + h2, l1 + l2, loc="upper right")

        ax1.set_title("Cluster Yield vs Visits")
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
    finally:
        plt.close(fig)


# ===== One chart per iteration =====
for it, g in tqdm(visit_result.groupby("迭代轮次")):
    # The iteration column is not needed for the join.
    final_df = _merge_visits(result, g.drop(columns=["迭代轮次"]))
    _plot_yield_vs_visits(final_df, f"visual_outputs/{dataset}/output_{it}.png")