import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib

def plot_from_binned_csv(path, save_path="reliability_final.png"):
    """Draw a reliability diagram from a pre-binned CSV and save it as an image.

    The CSV is expected to have the columns ``bin``, ``count``, ``mean_prob``
    and ``target`` (one row per bin). Only ``bin``, ``count`` and ``target``
    are read here; ``mean_prob`` is not used by this function.

    The figure contains:
      * blue bars (left y-axis, "Accuracy") with height ``target``, drawn only
        for bins whose ``count`` is positive and shaded darker for larger
        counts;
      * a dashed diagonal from (0, 0) to (1, 1);
      * a red line (right y-axis, "Sample Count") through ``count`` for every
        row, including rows with a zero count.

    Bars and points are placed at ``(bin + 0.5) / number_of_rows`` on the
    x-axis.
    NOTE(review): this assumes ``bin`` is a zero-based index covering
    0..number_of_rows-1 so that positions fall inside [0, 1] -- confirm
    against whatever produces the CSV.

    Args:
        path: Location of the CSV file, passed to ``pandas.read_csv``.
        save_path: Where the rendered figure is written.

    Returns:
        None. The figure is written to ``save_path``, a confirmation line is
        printed, and the figure is closed afterwards.

    Raises:
        KeyError: If one of the ``bin``, ``target`` or ``count`` columns is
            missing.
        ValueError: If the CSV contains no data rows.
    """
    df = pd.read_csv(path)

    num_bins = len(df)
    bins = df['bin'].values
    target = df['target'].values
    counts = df['count'].values

    # Fail early with a clear message; otherwise the bar width computation
    # below divides by zero after a figure has already been allocated.
    if num_bins == 0:
        raise ValueError(f"No bin rows found in {path!r}; nothing to plot.")

    fig, ax1 = plt.subplots(figsize=(5, 4), dpi=200)
    try:
        ax1.grid(visible=True, axis="both", which="major", linestyle=":", color="grey")

        # Bars are drawn only for bins that actually contain samples.
        valid_mask = counts > 0
        bins_valid = bins[valid_mask]
        target_valid = target[valid_mask]
        counts_valid = counts[valid_mask]

        # Single-hue gradient: darker blue means more samples in the bin.
        # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in
        # Matplotlib 3.9.
        cmap = plt.get_cmap("Blues")
        max_count = counts_valid.max() if counts_valid.size > 0 else 1
        colors = [cmap(c / max_count) for c in counts_valid]

        # Map bin indices to bin centres on the x-axis.
        x_pos = (bins_valid + 0.5) / num_bins

        # Bars: per-bin accuracy, at 90% of the bin width so neighbours
        # do not touch.
        ax1.bar(
            x_pos,
            target_valid,
            width=1.0 / num_bins * 0.9,
            alpha=0.9,
            color=colors,
            edgecolor="black",
            label="Accuracy (per bin)",
        )

        # Diagonal reference line (accuracy == confidence).
        ax1.plot([0, 1], [0, 1], color="black", alpha=0.4, linestyle="--")

        # Second y-axis: red sample-count line over ALL bins, empty ones included.
        ax2 = ax1.twinx()
        x_all = (bins + 0.5) / num_bins
        ax2.plot(
            x_all,
            counts,
            color="red",
            marker="o",
            linestyle="-",
            label="Sample Count",
        )
        ax2.set_ylabel("Sample Count", fontsize=12, color="red")
        ax2.tick_params(axis='y', labelcolor="red")

        # Axis ranges, ticks and labels.
        ax1.set_xlim([0, 1])
        ax1.set_ylim([0, 1])
        ax1.set_xticks(np.linspace(0.2, 1.0, 5))  # ticks at 0.2, 0.4, ..., 1.0
        ax1.set_xlabel("Confidence", fontsize=12)
        ax1.set_ylabel("Accuracy", fontsize=12, color="black")

        # Thicken the outer frame on both axes.
        for axis in (ax1, ax2):
            for spine in axis.spines.values():
                spine.set_linewidth(1.2)

        # Place both legends above the plotting area.
        ax1.legend(loc="upper left", bbox_to_anchor=(0, 1.15))
        ax2.legend(loc="upper right", bbox_to_anchor=(1, 1.15))

        # NOTE: no ECE / AUC / Brier annotation is drawn; those metrics are
        # not computed in this function.
        fig.tight_layout()
        fig.savefig(save_path)
        print(f"Saved figure to {save_path}")
    finally:
        # Release the figure so repeated calls do not accumulate open
        # figures in pyplot's global registry.
        plt.close(fig)


if __name__ == "__main__":
    # NOTE: change the CSV path below to point at your own binned data file.
    plot_from_binned_csv(path="bins.csv", save_path="reliability_final.png")
