import pickle
import seaborn as sns
import matplotlib.pyplot as plt
from medmnist import INFO 
import medmnist
import torch
from collections import Counter
import numpy as np
import pandas as pd

# MedMNIST subsets to analyse; each name is used both as the "<name>mnist"
# key into medmnist.INFO and as the prefix of the local "<name>mnist_unc.pkl".
data = ["breast","blood", "path", "organc"]
datadict = {}  # dataset name -> object unpickled from "<name>mnist_unc.pkl"
labelall = {}  # dataset name -> {str(class label): train count}, ascending by count
orderall = {}  # not populated here; kept so the module-level name still exists
for d in data:
    # NOTE(review): pickle.load must only be used on trusted files; these are
    # presumably produced locally by an earlier step of this project.
    with open("{}mnist_unc.pkl".format(d), mode="rb") as f:
        datadict[d] = pickle.load(f)

    DataClass = getattr(medmnist, INFO[d + "mnist"]['python_class'])
    ds = DataClass(split='train', download=True, root='/scratch/hy190/data/')

    # Read the split's label array directly rather than iterating the dataset,
    # which would decode every image just to get at its label. The labels of
    # these single-label subsets are flattened to one integer per sample.
    y_data = np.asarray(ds.labels).reshape(-1)
    labelcount = Counter(y_data.tolist())
    print(labelcount)

    # Order classes from smallest to largest training-set size; sorted() is
    # stable, so tied classes keep their first-appearance order.
    labelall[d] = {
        str(label): count
        for label, count in sorted(labelcount.items(), key=lambda kv: kv[1])
    }

# Apply the plot style once. matplotlib 3.6 renamed the bundled 'seaborn' style
# to 'seaborn-v0_8' and 3.8 removed the old name; fall back for older versions.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
plt.rcParams['figure.facecolor'] = 'white'
import matplotlib 
matplotlib.rc('xtick', labelsize=8) 
matplotlib.rc('ytick', labelsize=8) 
# One bar panel per dataset: mean (+/- 1 sd) score per class, classes ordered
# by training-set size as prepared in `labelall`.
fig, axes = plt.subplots(figsize=(11, 3), nrows=1, ncols=4, dpi=200)
for i, d in enumerate(data):
    ax = axes[i]
    norms = datadict[d]  # per-class score collections, indexed by integer class label
    Ks = labelall[d]     # {str(class label): train count}, ascending by count

    xticklabels = []  # x category (class-label string) for every individual score
    normlist = []     # flattened scores, aligned element-wise with xticklabels
    label = []        # text annotated above each bar
    avg = []          # per-class mean score, in Ks order
    sd = []           # per-class (population) std of the score, in Ks order
    for k in Ks:
        # Keys are full label strings such as "10"; parse the whole key. Using
        # only its last character made classes >= 10 (e.g. OrganC's class 10)
        # alias onto classes 0..9 and plot the wrong scores.
        tempk = int(k)
        scores = norms[tempk]
        xticklabels += [k] * len(scores)
        normlist.extend(scores)
        # NOTE(review): annotation is the score count divided by 10; the reason
        # for the /10 scaling is not evident here — confirm.
        label.append(int(len(scores) / 10))
        avg.append(np.mean(scores))
        sd.append(np.std(scores))

    df = pd.DataFrame({"": xticklabels, " ": normlist})
    pal = sns.color_palette("Blues", len(avg))
    # rank[j] = position of class j when classes are sorted by mean score; the
    # reversed palette therefore gives the lowest-mean class the darkest blue.
    rank = np.array(avg).argsort().argsort()
    sns.barplot(data=df, x="", y=" ", ax=ax, errorbar=('sd', 1), hue="", palette=list(np.array(pal[::-1])[rank]))
    ax.set_title(d.title())
    # Assumes ax.patches holds exactly one bar per class, in Ks order.
    for rect, la, s in zip(ax.patches, label, sd):
        height = rect.get_height()
        # Place the text just above the top of the error bar.
        ax.text(rect.get_x() + rect.get_width() / 2, height + 0.03 + s, la, ha="center", va="bottom", size=6)

    if i == 2:
        # A single shared x label, shifted left so it sits roughly centred
        # under the whole row of panels.
        ax.set_xlabel("Class Index Sorted by Class Size")
        ax.xaxis.set_label_coords(-0.15, -0.25)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    ax.set_ylim([0, 1])
plt.tight_layout()

plt.savefig("uncertaintyscore.pdf", bbox_inches="tight", facecolor=fig.get_facecolor())