import os
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


import os
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def evaluate_calibrated_dataset(pkl_path, save_dir="./", use_calibrated=True, *, save_plot=True):
    """
    Evaluate a pickled dataset's calibration (ECE, Brier score, reliability diagram).

    Samples are binned strictly by their ``hd_label`` value (described by the
    original author as the raw softmax confidence from the EM step) into ten
    equal-width bins over [0, 1]. Each bin is reported with its real interval,
    not just its index.

    Args:
        pkl_path: Path to a pickle file whose content ``pd.DataFrame`` can
            consume. It must provide the columns ``correctness`` and
            ``hd_label`` (plus ``hd_target`` when ``use_calibrated`` is true).
        save_dir: Directory in which the reliability diagram is written. It is
            created if it does not exist.
        use_calibrated: If true, per-bin confidence and the Brier score use
            ``hd_target``; otherwise they use ``hd_label``. Binning always uses
            ``hd_label``.
        save_plot: If false, skip the figure entirely; matplotlib is then
            never touched and ``"Reliability Path"`` is ``None``.

    Returns:
        A dict with the keys ``"ECE"`` (float), ``"Brier Score"`` (mean squared
        difference between confidence and correctness), ``"Bin Stats"``
        (DataFrame with columns ``bin``, ``bin_range``, ``acc``, ``conf``,
        ``count``) and ``"Reliability Path"`` (path of the saved PNG or None).

    Raises:
        ValueError: If a required column is missing from the dataset.
    """

    # ===== Load dataset =====
    # NOTE(security): pickle.load can execute arbitrary code. Only pass files
    # that come from a trusted source.
    with open(pkl_path, "rb") as f:
        dataset = pickle.load(f)
    df = pd.DataFrame(dataset)

    if not {"correctness", "hd_label"}.issubset(df.columns):
        raise ValueError("Dataset must contain 'correctness' and 'hd_label' columns.")

    # Ground truth: make sure it is 0/1 (True/False -> 1/0).
    df["correctness"] = df["correctness"].astype(int)
    y_true = df["correctness"].to_numpy()
    # Raw confidence used for binning.
    y_pred = df["hd_label"].astype(float).to_numpy()

    # Confidence that is actually scored (calibrated or raw).
    if use_calibrated:
        if "hd_target" not in df.columns:
            raise ValueError("use_calibrated=True but 'hd_target' not in dataset.")
        y_conf = df["hd_target"].astype(float).to_numpy()
    else:
        y_conf = y_pred

    # ===== Bin by hd_label =====
    edges = np.linspace(0.0, 1.0, 11)  # 0.0, 0.1, ..., 1.0
    n_bins = len(edges) - 1
    bin_idx = np.digitize(y_pred, edges, right=True)
    # With right=True the intervals are (lo, hi], so a confidence of exactly
    # 0.0 would land in bin 0 and be dropped from the ECE. Make the first bin
    # closed on the left instead. Values outside [0, 1] (and NaN) still fall
    # into bin 0 / bin n_bins + 1 and are reported as "out".
    bin_idx[y_pred == edges[0]] = 1

    def _bin_label(b):
        """Return the interval text for bin index *b*, or "out" if out of range."""
        if not 0 < b <= n_bins:
            return "out"
        opening = "[" if b == 1 else "("
        return f"{opening}{edges[b - 1]:.1f}, {edges[b]:.1f}]"

    df["bin"] = bin_idx
    df["bin_range"] = [_bin_label(b) for b in bin_idx]

    # ===== Bin statistics =====
    df["conf"] = y_conf
    bin_stats = df.groupby(["bin", "bin_range"]).agg(
        acc=("correctness", "mean"),
        conf=("conf", "mean"),
        count=("correctness", "size")
    ).reset_index()

    # ===== Compute ECE =====
    # Weighted mean of |accuracy - confidence| over the in-range bins; the
    # weights are relative to ALL samples, so "out" samples contribute zero.
    n = len(df)
    in_range = bin_stats[(bin_stats["bin"] > 0) & (bin_stats["bin"] <= n_bins)]
    gaps = (in_range["acc"] - in_range["conf"]).abs()
    ece = float((gaps * in_range["count"]).sum() / n) if n else 0.0

    # ===== Brier Score =====
    brier = np.mean((y_conf - y_true) ** 2)

    # ===== Plot Reliability Diagram =====
    save_path = None
    if save_plot:
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        # splitext keeps the output name distinct from the input even when the
        # input does not end in ".pkl".
        stem = os.path.splitext(os.path.basename(pkl_path))[0]
        save_path = os.path.join(save_dir, f"{stem}_reliability.png")

        fig, ax = plt.subplots(figsize=(6, 6))
        try:
            ax.plot([0, 1], [0, 1], "--", color="gray", label="Perfect calibration")
            # One bar per bin: x is the bin's mean confidence, height its accuracy.
            ax.bar(
                x=bin_stats["conf"],
                height=bin_stats["acc"],
                width=0.08, alpha=0.6, edgecolor="black", label="Bin accuracy"
            )
            ax.set_xlabel("Confidence")
            ax.set_ylabel("Accuracy")
            ax.set_title("Reliability Diagram")
            ax.legend()
            ax.grid(True)
            fig.savefig(save_path)
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)

    return {
        "ECE": ece,
        "Brier Score": brier,
        "Bin Stats": bin_stats,
        "Reliability Path": save_path
    }


# Example usage. Guarded so that importing this module does not try to load
# the hard-coded dataset path (which only exists on the original machine).
if __name__ == "__main__":
    results = evaluate_calibrated_dataset("/data/DERI-Gong/jh015/grace_codes/data/triviaqa_brief/llama2-7b/hd_data_llama.pkl")
    print(results["ECE"], results["Brier Score"])
    print(results["Bin Stats"])

# import pickle
# with open("/data/DERI-Gong/jh015/grace_codes/data/triviaqa_brief/llama2-7b/hd_data_em.pkl", "rb") as f:
#     dataset = pickle.load(f)

# if isinstance(dataset, list) and isinstance(dataset[0], dict):
#     print(dataset[0].keys())
# elif isinstance(dataset, dict):
#     print(dataset.keys())
# else:
#     print(type(dataset))

# import pickle
# import pandas as pd

# # Assumed path of the saved file
# save_path = "/data/DERI-Gong/jh015/grace_codes/data/triviaqa_brief/llama2-7b/hd_data_em.pkl"

# # Load the pkl file
# with open(save_path, "rb") as f:
#     dataset = pickle.load(f)

# # Convert to a DataFrame
# df = pd.DataFrame(dataset)

# # Save as CSV for easy inspection
# csv_path = save_path.replace(".pkl", ".csv")
# df.to_csv(csv_path, index=False)