import os
import pandas as pd
import numpy as np
import pickle
import matplotlib
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, brier_score_loss
import argparse
import pandas as pd

def ece(y_true: np.ndarray, y_pred: np.ndarray, n_bins: int = 10) -> float:
    """
    Calculate the Expected Calibration Error (ECE).

    Predictions are grouped into ``n_bins`` equal-width confidence bins with
    left edges ``0, 1/n_bins, ..., (n_bins-1)/n_bins`` (the last bin is
    open-ended on the right, so a confidence of 1.0 lands in it; values below
    0 would form their own extra group). For each non-empty bin, the absolute
    difference between the mean label (fraction of positives) and the mean
    predicted probability is taken. The ECE is the mean of these differences
    weighted by the fraction of samples that fall into each bin.

    Parameters
    ----------
    y_true: np.ndarray
        The true labels (0/1 or bool), aligned element-wise with ``y_pred``.
    y_pred: np.ndarray
        The predicted probabilities.
    n_bins: int
        The number of bins to use.

    Returns
    -------
    ece: float
        The expected calibration error; 0.0 for empty input.
    """
    # Work on plain arrays so pandas Series with differing indexes are paired
    # positionally rather than aligned (and silently NaN-filled) by index.
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    n = len(y_pred)
    if n == 0:
        return 0.0

    # Left bin edges computed as i / n_bins (correctly rounded). A float-step
    # np.arange(0, 1, 1 / n_bins) yields edges such as 0.30000000000000004 and
    # 0.7000000000000001, which would push confidences of exactly 0.3 or 0.7
    # into the bin below.
    bins = np.arange(n_bins) / n_bins
    bins_per_prediction = np.digitize(y_pred, bins)

    df = pd.DataFrame({"y_pred": y_pred, "y": y_true, "pred_bins": bins_per_prediction})
    grouped_by_bins = df.groupby("pred_bins")

    # Mean label and mean predicted probability per bin.
    binned = grouped_by_bins.mean()
    # Fraction of all samples that fall into each bin.
    weight = grouped_by_bins["y"].count() / n

    weighted_diff = (binned["y_pred"] - binned["y"]).abs() * weight
    return float(weighted_diff.sum())

def plot_reliability_diagram(
    all_confidences,
    all_correctness,
    num_bins,
    save_path,
    ece,
):
    """
    Draw a reliability diagram annotated with ECE, AUROC and Brier score and
    save it as a PNG next to ``save_path``.

    Each bar shows the empirical accuracy (mean of ``all_correctness``) of the
    predictions whose confidence falls into that bin; empty bins are drawn at
    height 0. The bar shade and a rotated text label encode the percentage of
    samples in the bin. The dashed diagonal marks perfect calibration.

    Parameters
    ----------
    all_confidences: array-like
        Predicted confidences, presumably in [0, 1] (the axes are clipped to
        that range).
    all_correctness: array-like
        Correctness labels (0/1 or bool), aligned element-wise with
        ``all_confidences``.
    num_bins: int
        Number of equal-width confidence bins.
    save_path: str
        Path of the source CSV; the figure is written to the same path with
        its extension replaced by ``.png``.
    ece: float
        Pre-computed expected calibration error, shown in the plot text.
    """
    all_confidences = np.asarray(all_confidences)
    all_correctness = np.asarray(all_correctness)

    step_size = 1.0 / num_bins
    # Left bin edges as i / num_bins: correctly rounded and exactly num_bins of
    # them. A float-step np.arange gives edges like 0.30000000000000004 that
    # misbin confidences of exactly 0.3 / 0.6 / 0.7.
    bins = np.arange(num_bins) / num_bins
    centers = bins + step_size / 2
    # np.digitize numbers the bins 1..num_bins for values in [0, inf).
    bin_ids = range(1, num_bins + 1)

    df = pd.DataFrame(
        {
            "y_pred": all_confidences,
            "y": all_correctness,
            "pred_bins": np.digitize(all_confidences, bins),
        }
    )
    grouped_y = df.groupby("pred_bins")["y"]
    # Per-bin accuracy (0 for empty bins) and per-bin sample counts.
    bin_values = grouped_y.mean().reindex(bin_ids, fill_value=0).values
    bin_sizes = grouped_y.count().reindex(bin_ids, fill_value=0).values
    total = bin_sizes.sum()

    fig = plt.figure(figsize=(4, 4), dpi=200)
    ax = plt.gca()
    ax.grid(visible=True, axis="both", which="major", linestyle=":", color="grey")

    # Shade each bar by its share of samples; +0.2 keeps near-empty bars
    # visible and 0.9999 caps the lookup inside the colormap.
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap (removed in Matplotlib 3.9).
    cmap = plt.get_cmap("Blues")
    bar_colors = [cmap(min(0.9999, size / total + 0.2)) for size in bin_sizes]

    plt.bar(
        centers,
        bin_values,
        width=0.9 * step_size,  # scales with num_bins (0.09 for 10 bins)
        alpha=0.8,
        color=bar_colors,
        edgecolor="black",
    )
    # Perfect-calibration reference line.
    plt.plot(
        np.arange(0, 1 + 0.05, 0.05),
        np.arange(0, 1 + 0.05, 0.05),
        color="black",
        alpha=0.4,
        linestyle="--",
    )

    # Label every bar (including empty ones) with its share of samples.
    for center, size in zip(centers, bin_sizes):
        bin_percentage = round(size / total * 100, 2)
        plt.annotate(
            f"{bin_percentage} %",
            xy=(center, 0.5),
            ha="center",
            va="top",
            rotation=90,
            color="black",
            alpha=0.7 if bin_percentage > 40 else 0.8,
            fontsize=10,
        )

    auc = roc_auc_score(all_correctness, all_confidences)
    brier = brier_score_loss(all_correctness, all_confidences)
    print(f"auc: {auc}, brier: {brier}")

    plt.text(0.4, 0.8, f'ece:{ece:.4f} \nauc:{auc:.4f} \nbrier:{brier:.4f}')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.xlabel("Confidence", fontsize=14, alpha=0.8)
    plt.ylabel("Accuracy", fontsize=14, alpha=0.8)
    plt.tight_layout()

    # Swap only the extension, so ".csv" elsewhere in the path is untouched
    # and a non-.csv input can never be overwritten by the figure.
    png_path = os.path.splitext(save_path)[0] + ".png"
    plt.savefig(png_path)
    # Free the figure; main() draws several per run.
    plt.close(fig)


def main():
    """
    Command-line entry point: compute the ECE for a CSV of predictions and
    save two reliability diagrams next to the CSV.

    The CSV is read with ``pd.read_csv`` and must provide the columns
    ``all_conf`` (predicted confidence) and ``all_correctnesses``
    (correctness label). The ECE is printed after both plots are written.
    """
    parser = argparse.ArgumentParser()
    # ``required=True`` made the former hard-coded default unreachable, so
    # the default was dropped.
    parser.add_argument(
        "--path",
        type=str,
        required=True,
        help="Path to the CSV file",
    )
    parser.add_argument(
        "--num_bins",
        type=int,
        default=10,
        help="Number of bins for ECE calculation",
    )
    args = parser.parse_args()
    if args.num_bins < 1:
        # Fail with a CLI message instead of a ZeroDivisionError deep in numpy.
        parser.error("--num_bins must be a positive integer")

    results = pd.read_csv(args.path)
    all_confidences = results['all_conf']
    all_correctness = results['all_correctnesses']

    ece_score = ece(all_correctness, all_confidences, args.num_bins)
    plot_reliability_diagram(all_confidences, all_correctness, args.num_bins, args.path, ece_score)
    plot_reliability_diagram_1(all_confidences, all_correctness, args.num_bins, args.path, ece_score)
    print(f'ece: {ece_score}')

def plot_reliability_diagram_1(
    all_confidences,
    all_correctness,
    num_bins,
    save_path,
    ece,
):
    """
    Draw a minimal reliability diagram and save it as
    ``<save_path without extension>_new.png``.

    Bars show the per-bin empirical accuracy (mean of ``all_correctness``;
    0 for empty bins) and are coloured with a light-to-dark "Blues" gradient
    by bin index. No per-bin percentages or metric text are drawn. AUROC and
    Brier score are computed and only printed to stdout.

    Parameters
    ----------
    all_confidences: array-like
        Predicted confidences, presumably in [0, 1] (the axes are clipped to
        that range).
    all_correctness: array-like
        Correctness labels (0/1 or bool), aligned element-wise with
        ``all_confidences``.
    num_bins: int
        Number of equal-width confidence bins.
    save_path: str
        Path of the source CSV; only its extension is replaced.
    ece: float
        Accepted for signature parity with ``plot_reliability_diagram``;
        not used here.
    """
    all_confidences = np.asarray(all_confidences)
    all_correctness = np.asarray(all_correctness)

    step_size = 1.0 / num_bins
    # Left bin edges as i / num_bins: correctly rounded and exactly num_bins of
    # them. A float-step np.arange gives edges like 0.30000000000000004 that
    # misbin confidences of exactly 0.3 / 0.6 / 0.7.
    bins = np.arange(num_bins) / num_bins
    # np.digitize numbers the bins 1..num_bins for values in [0, inf).
    bin_ids = range(1, num_bins + 1)

    df = pd.DataFrame(
        {
            "y_pred": all_confidences,
            "y": all_correctness,
            "pred_bins": np.digitize(all_confidences, bins),
        }
    )
    # Per-bin accuracy; empty bins are drawn at height 0.
    bin_values = df.groupby("pred_bins")["y"].mean().reindex(bin_ids, fill_value=0).values

    fig = plt.figure(figsize=(4, 4), dpi=200)
    ax = plt.gca()
    ax.grid(visible=True, axis="both", which="major", linestyle=":", color="grey")

    # Colour intensity grows with bin index; any sequential colormap (e.g.
    # Reds/Greens) would work. pyplot.get_cmap replaces
    # matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
    cmap = plt.get_cmap("Blues")
    bar_colors = [cmap(i / num_bins) for i in range(1, num_bins + 1)]

    plt.bar(
        bins + step_size / 2,
        bin_values,
        width=0.9 * step_size,  # scales with num_bins (0.09 for 10 bins)
        alpha=0.9,
        color=bar_colors,
        edgecolor="black",
    )

    # Perfect-calibration reference line.
    plt.plot(
        np.arange(0, 1 + 0.05, 0.05),
        np.arange(0, 1 + 0.05, 0.05),
        color="black",
        alpha=0.4,
        linestyle="--",
    )

    auc = roc_auc_score(all_correctness, all_confidences)
    brier = brier_score_loss(all_correctness, all_confidences)
    print(f"auc: {auc}, brier: {brier}")

    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.xlabel("Confidence", fontsize=14, alpha=0.8)
    plt.ylabel("Accuracy", fontsize=14, alpha=0.8)
    plt.tight_layout()

    # Swap only the extension; a blanket str.replace(".csv", ...) would also
    # rewrite ".csv" inside directory names.
    png_path = os.path.splitext(save_path)[0] + "_new.png"
    plt.savefig(png_path)
    # Free the figure; main() draws several per run.
    plt.close(fig)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()