import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import numpy as np
import pandas as pd

# Global matplotlib text settings.
# Use Arial for sans-serif text.  The font list has to be registered under the
# 'sans-serif' key: with family='sans-serif' the 'serif' list is never consulted.
matplotlib.rc('font', **{'family': 'sans-serif', 'sans-serif': ['Arial']})
# Render text with matplotlib's own engine rather than LaTeX (a real boolean
# instead of the string 'false').
matplotlib.rc('text', usetex=False)

# Data for weight sweep
# NOTE(review): `depths` and `methods` are not referenced anywhere else in the
# visible part of this script; the plots below hard-code their own labels.
# depths = [0, 1, 2, 3, 4]
depths = [1, 2, 3]
methods = ["eagle", "topk1", "topk4"]


# Each active list below holds five values.  When the records are built
# further down, index i is labelled with the rank itonth(i) ("1st" ... "5th").
# The name suffixes map to the plotted method labels as:
#   eagle -> "EAGLE", topk1 -> "HASS", topk4 -> "TALS".

# valid x
# (Author's tag for the active data set; its meaning is not documented here.)
# Brier scores are kept for reference only; they are not plotted.
# sample_brier_score_eagle = [0.67578125, 0.77294921875, 0.79296875, 0.81005859375, 0.81982421875]
# sample_brier_score_topk1 = [0.56591796875, 0.7021484375, 0.7373046875, 0.763671875, 0.7763671875]
# sample_brier_score_topk4 = [0.54248046875, 0.61962890625, 0.654296875, 0.66943359375, 0.68603515625]

# ECE values per rank (shown in the panel titled "ECE ↓").
sample_ece_score_eagle = [0.03959261713300306, 0.06598885184311702, 0.07212399043571328, 0.08030220409310468, 0.08349327025909047]
sample_ece_score_topk1 = [0.01856310699733183, 0.0659048072288394, 0.07888780881785332, 0.09153476412069567, 0.0957433310845468]
sample_ece_score_topk4 = [0.006798583931409069, 0.02327595046043856, 0.03621047114459126, 0.039894723881773085, 0.04772736349436951]

# Accuracy values per rank (shown in the panel titled "Accuracy ↑").
sample_acc_score_eagle = [0.5503495931625366, 0.47467708587646484, 0.45690247416496277, 0.4445194900035858, 0.4366038739681244]
sample_acc_score_topk1 = [0.6156535148620605, 0.5184618830680847, 0.4922502636909485, 0.4730062782764435, 0.4649839997291565]
sample_acc_score_topk4 = [0.6293755173683167, 0.5694987773895264, 0.5444839596748352, 0.5328238010406494, 0.5205237865447998]

# # valid o
# (Alternative data set, currently disabled; same layout as the lists above.)
# sample_brier_score_eagle = [0.62158203125, 0.71875, 0.73876953125, 0.7724609375, 0.7939453125]
# sample_brier_score_topk1 = [0.49462890625, 0.611328125, 0.6484375, 0.67333984375, 0.67724609375]
# sample_brier_score_topk4 = [0.4853515625, 0.5634765625, 0.6162109375, 0.6259765625, 0.6171875]

# sample_ece_score_eagle = [0.03397627745299543, 0.037690370130954845, 0.029958654371342264, 0.042696076784182024, 0.04223600975319745]
# sample_ece_score_topk1 = [0.016292740010041593, 0.03195718854681316, 0.03228001476418495, 0.04709291425047589, 0.03028493209918352]
# sample_ece_score_topk4 = [0.018924747757867958, 0.013729464522705136, 0.021220823002713975, 0.02110003870579564, 0.01242811708328916]

# sample_acc_score_eagle = [0.5936505198478699, 0.52935791015625, 0.5210105180740356, 0.5022789239883423, 0.4979367256164551]
# sample_acc_score_topk1 = [0.6679000854492188, 0.5907415151596069, 0.5712848901748657, 0.5493798851966858, 0.5578181743621826]
# sample_acc_score_topk4 = [0.6729102730751038, 0.6155007481575012, 0.585197925567627, 0.5793324112892151, 0.592383623123169]

# Ten counts, presumably one per rank starting at the 1st (TODO confirm).
# Further down, the entries from index 4 onwards are summed into one bucket
# labelled "≥5th" and the five resulting values are plotted as proportions.
cnts = [112372, 35680, 16950, 10215, 7118, 5385, 4301, 3498, 2872, 2386]

# NOTE(review): sns.set_theme(style="whitegrid") is called further down, before
# the figure is created, and presumably supersedes this style — confirm whether
# this line is still needed.
plt.style.use('seaborn-v0_8-whitegrid')   # ggplot

def itonth(i):
    """Return the English ordinal label for a zero-based index.

    Args:
        i: Zero-based position (0 -> "1st", 1 -> "2nd", ...).

    Returns:
        The one-based ordinal as a string, e.g. "3rd", "11th", "21st".
    """
    n = i + 1
    # 11, 12 and 13 (also 111, 212, ...) take "th" even though they end in
    # 1/2/3.  Non-positive values keep the plain "th" fallback.
    if n < 1 or 11 <= n % 100 <= 13:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return f"{n}{suffix}"

# Build long-form records: one row per (rank, method, metric).
# Each tuple pairs the method label shown in the figure with that method's
# ECE and accuracy series; tuple order fixes the row order within a rank.
method_series = [
    ("EAGLE", sample_ece_score_eagle, sample_acc_score_eagle),
    ("HASS", sample_ece_score_topk1, sample_acc_score_topk1),
    ("TALS", sample_ece_score_topk4, sample_acc_score_topk4),
]

num_ranks = len(sample_ece_score_eagle)

data = []
for i in range(num_ranks):
    nth = itonth(i)  # x-axis label for this rank ("1st", "2nd", ...)
    for method, ece_values, acc_values in method_series:
        # ECE first, then accuracy, for every method.
        for score, values in (("ECE", ece_values), ("Accuracy", acc_values)):
            data.append({
                "nd_sample": nth,
                "method": method,
                "depth": i + 1,
                "score": score,
                "value": values[i],
            })

df = pd.DataFrame(data)

# data = []
# for i, depth in enumerate(depths):
#     brier = first_sample_brier_score_eagle[i]
#     ece = first_sample_ece_score_eagle[i]
#     acc = first_sample_acc_score_eagle[i]
#     data.append({"nd_sample": "1st", "method": "EAGLE", "depth": i, "score": "Brier score", "value": brier})
#     data.append({"nd_sample": "1st", "method": "EAGLE", "depth": i, "score": "ECE", "value": ece})
#     data.append({"nd_sample": "1st", "method": "EAGLE", "depth": i, "score": "Accuracy", "value": acc})

#     brier = fifth_sample_brier_score_eagle[i]
#     ece = fifth_sample_ece_score_eagle[i]
#     acc = fifth_sample_acc_score_eagle[i]
#     data.append({"nd_sample": "5th", "method": "EAGLE", "depth": i, "score": "Brier score", "value": brier})
#     data.append({"nd_sample": "5th", "method": "EAGLE", "depth": i, "score": "ECE", "value": ece})
#     data.append({"nd_sample": "5th", "method": "EAGLE", "depth": i, "score": "Accuracy", "value": acc})

#     brier = first_sample_brier_score_topk1[i]
#     ece = first_sample_ece_score_topk1[i]
#     acc = first_sample_acc_score_topk1[i]
#     data.append({"nd_sample": "1st", "method": "HASS", "depth": i, "score": "Brier score", "value": brier})
#     data.append({"nd_sample": "1st", "method": "HASS", "depth": i, "score": "ECE", "value": ece})
#     data.append({"nd_sample": "1st", "method": "HASS", "depth": i, "score": "Accuracy", "value": acc})

#     brier = fifth_sample_brier_score_topk1[i]
#     ece = fifth_sample_ece_score_topk1[i]
#     acc = fifth_sample_acc_score_topk1[i]
#     data.append({"nd_sample": "5th", "method": "HASS", "depth": i, "score": "Brier score", "value": brier})
#     data.append({"nd_sample": "5th", "method": "HASS", "depth": i, "score": "ECE", "value": ece})
#     data.append({"nd_sample": "5th", "method": "HASS", "depth": i, "score": "Accuracy", "value": acc})

#     brier = first_sample_brier_score_topk4[i]
#     ece = first_sample_ece_score_topk4[i]
#     acc = first_sample_acc_score_topk4[i]
#     data.append({"nd_sample": "1st", "method": "TALS", "depth": i, "score": "Brier score", "value": brier})
#     data.append({"nd_sample": "1st", "method": "TALS", "depth": i, "score": "ECE", "value": ece})
#     data.append({"nd_sample": "1st", "method": "TALS", "depth": i, "score": "Accuracy", "value": acc})

#     brier = fifth_sample_brier_score_topk4[i]
#     ece = fifth_sample_ece_score_topk4[i]
#     acc = fifth_sample_acc_score_topk4[i]
#     data.append({"nd_sample": "5th", "method": "TALS", "depth": i, "score": "Brier score", "value": brier})
#     data.append({"nd_sample": "5th", "method": "TALS", "depth": i, "score": "ECE", "value": ece})
#     data.append({"nd_sample": "5th", "method": "TALS", "depth": i, "score": "Accuracy", "value": acc})

# Long-form frame with columns nd_sample / method / depth / score / value.
# NOTE(review): `df` is already built from the same `data` right after the
# record loop above, so this rebuild looks redundant.
df = pd.DataFrame(data)

# Seaborn theme plus "poster" scaling, with explicit font sizes on top.
sns.set_theme(style="whitegrid")
poster_font_sizes = {
    "axes.titlesize": 40,
    "axes.labelsize": 40,
    "xtick.labelsize": 32,
    "ytick.labelsize": 32,
    "legend.fontsize": 32,
    "legend.title_fontsize": 32,
}
sns.set_context("poster", rc=poster_font_sizes)

# Four columns: three plot panels plus a very thin column (index 1) that only
# carries the vertical divider drawn later.
panel_width_ratios = [1, 0.01, 1, 1]
fig, axs = plt.subplots(
    1,
    4,
    figsize=(28, 8),
    gridspec_kw={'width_ratios': panel_width_ratios},
)
fig.subplots_adjust(wspace=0.4)

# Accuracy panel (third column): bars grouped by rank, one bar per method.
accuracy_rows = df[df["score"] == "Accuracy"]
accuracy_ax = sns.barplot(
    data=accuracy_rows,
    x="nd_sample",
    y="value",
    hue="method",
    order=["1st", "2nd", "3rd", "4th", "5th"],
    hue_order=["EAGLE", "HASS", "TALS"],
    palette=['#A5DAF6', '#DBA1D9', '#F05626'],
    errorbar="sd",
    alpha=1.0,
    legend=False,  # the shared legend is attached to the figure elsewhere
    ax=axs[2],
)
accuracy_ax.set_title("Accuracy ↑", size=40)
# Axis labels are intentionally blank; the title and tick labels carry the meaning.
accuracy_ax.set_xlabel("", size=32)
accuracy_ax.set_ylabel("", size=32)
# Fixed y-range for the accuracy bars.
accuracy_ax.set_ylim(0.4, 0.65)

# ECE panel (fourth column): same layout as the accuracy panel.
ece_rows = df[df["score"] == "ECE"]
ece_ax = sns.barplot(
    data=ece_rows,
    x="nd_sample",
    y="value",
    hue="method",
    order=["1st", "2nd", "3rd", "4th", "5th"],
    hue_order=["EAGLE", "HASS", "TALS"],
    palette=['#A5DAF6', '#DBA1D9', '#F05626'],
    errorbar="sd",
    alpha=1.0,
    legend=True,  # needed so the legend entries exist and can be reused below
    ax=axs[3],
)
ece_ax.set_title("ECE ↓", size=40)
ece_ax.set_xlabel("", size=32)
ece_ax.set_ylabel("", size=32)

# Reuse the panel's legend entries for a single figure-level legend placed to
# the right of the panels, then drop the per-axes legend.
handles, labels = ece_ax.get_legend_handles_labels()
fig.legend(
    handles,
    labels,
    title="Method",
    loc='center right',
    bbox_to_anchor=(1.02, 0.5),
    ncol=1,
    frameon=False,
    fontsize=32,
    title_fontsize=32,
)
axs[3].legend_.remove()

# The thin second column only shows a vertical black divider line; its own
# axes decorations are hidden and the line is allowed to extend past the axes.
divider_ax = axs[1]
divider_ax.plot([1.5, 1.5], [0, 1], color='black', lw=2, clip_on=False)
divider_ax.axis('off')



# Rank-proportion panel (first column).
# Merge every count from index 4 onwards into one bucket so that exactly five
# bars ("1st" ... "≥5th") are shown.  A new list is built instead of mutating
# the original in place.
cnts = cnts[:4] + [sum(cnts[4:])]

# Normalise once; the total is loop-invariant.
total_count = sum(cnts)
proportions = [count / total_count for count in cnts]

# One colour per bar.
colors = ['#A5DAF6', '#2D5F7F', '#2D5F7F', '#2D5F7F', '#DBA1D9']

g = sns.barplot(
    data=proportions,
    palette=colors,
    ax=axs[0],
)
g.set_ylabel("")
g.set_title("Proportion of\neach rank", size=40)
g.set_xticklabels(["1st", "2nd", "3rd", "4th", "≥5th"])

# Sub-figure tags placed below the axes, in axes-relative coordinates:
# "(a)" centred under the proportion panel, "(b)" at the right edge of the
# accuracy panel.
axs[0].text(0.5, -0.2, "(a)", transform=axs[0].transAxes, fontsize=40)
axs[2].text(1.0, -0.2, "(b)", transform=axs[2].transAxes, fontsize=40)

plt.savefig("counts.pdf", dpi=300, bbox_inches='tight')
