import pickle
import numpy as np
import matplotlib.pyplot as plt

# Embed fonts as TrueType (Type 42) instead of Type 3 in PDF/PS output.
# Use a 15pt base font for all text created afterwards.
# The figure itself is created further down by plt.subplots; no stray
# placeholder figure is opened here (its constrained layout would also clash
# with the explicit subplots_adjust call applied to the real figure).
plt.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.size': 15,
})

def get_results(file_name):
    """Read the 14 consecutive pickle records stored in ``file_name``.

    Records are loaded one after another from the same binary stream, in
    this order: the seven accuracy series (rd, erm, margin, DBAL, coreset,
    BADGE, ours), then the seven ECE series for the same methods.

    Returns a 14-tuple in that order. Raises ``EOFError`` if the file holds
    fewer than 14 records; any records after the 14th are left unread.

    Only load files you produced yourself: unpickling untrusted data can
    execute arbitrary code.
    """
    methods_per_metric = 7  # rd, erm, margin, DBAL, coreset, BADGE, ours
    metric_count = 2  # accuracy block first, then ECE block
    with open(file_name, 'rb') as stream:
        loaded = [pickle.load(stream) for _ in range(metric_count * methods_per_metric)]
    return tuple(loaded)

def plot_by_normal(ax, value, label, color=None):
    """Plot the mean of ``value`` over axis 0 with a shaded +/- one-std band.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    value : array-like
        Repeated measurements; ``np.asarray(value)`` is reduced over axis 0,
        so each entry along that axis is one run and the remaining axis gives
        successive x positions (presumably acquisition rounds -- confirm with
        the producer of the results files).
    label : str
        Legend label attached to the mean line (the band is unlabelled).
    color : optional
        Colour used for both line and band. When ``None`` the line takes the
        next colour from the Axes' line colour cycle and the band reuses it.
    """
    runs = np.asarray(value)
    mean = runs.mean(axis=0)
    std = runs.std(axis=0)
    xs = np.arange(len(mean))

    line_kwargs = {} if color is None else {"color": color}
    (line,) = ax.plot(xs, mean, marker='.', label=label, **line_kwargs)
    # Take the band colour from the line actually drawn, so band and line
    # always match instead of relying on the separate line and patch colour
    # cycles advancing in lockstep.
    ax.fill_between(xs, mean - std, mean + std, alpha=0.3, color=line.get_color())

# Per-dataset panel configuration, in column order:
# (panel title, results file, ECE y-limits, accuracy y-limits, x-limits or None).
DATASETS = [
    ("MNIST", "out/new/mnist_plt", [0.02, 0.25], [70, 97], [0, 40]),
    ("SVHN", "out/new/svhn_plt", [0.078, 0.16], [76, 90.5], None),
    ("Fashion MNIST", "out/new/f_mnist_plt", [0.16, 0.30], [64, 83], None),
    ("CIFAR-10", "out/new/cifar10_plt", [0.25, 0.36], [48, 68], None),
]

# Legend labels, in the same order as the seven series get_results returns for
# each metric (rd, erm, margin, DBAL, coreset, BADGE, ours); the "erm" series
# is the one shown as "Least-conf".
METHOD_LABELS = ["Random", "Least-conf", "Margin", "DBAL", "Coreset", "BADGE", "Ours"]
# Our method gets a fixed colour; every baseline uses the default colour cycle.
METHOD_COLORS = {"Ours": "blue"}

# get_results returns all accuracy series first, then all ECE series.
num_methods = len(METHOD_LABELS)

fig, axs = plt.subplots(2, 4, figsize=(20, 10))

for col, (title, path, ece_ylim, acc_ylim, xlim) in enumerate(DATASETS):
    results = get_results(path)
    acc_series = results[:num_methods]
    ece_series = results[num_methods:]

    # Row 0 shows calibration error, row 1 shows accuracy.
    panel_rows = (
        (ece_series, ece_ylim, 'Expected Calibration Error'),
        (acc_series, acc_ylim, 'Accuracy'),
    )
    for row, (series, ylim, ylabel) in enumerate(panel_rows):
        ax = axs[row, col]
        for values, label in zip(series, METHOD_LABELS):
            plot_by_normal(ax, values, label=label, color=METHOD_COLORS.get(label))
        if xlim is not None:
            ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.set_title(title)
        if col == 0:
            # Only the leftmost column carries a y-axis label.
            ax.set_ylabel(ylabel, fontsize=18)
        ax.grid(True)

# One shared legend, anchored below the bottom row, roughly centred between
# the second and third columns.
axs[1, 1].legend(loc="upper center", bbox_to_anchor=(1.1, -0.1), fancybox=True,
                 shadow=True, ncol=num_methods, fontsize=20)
fig.subplots_adjust(left=0.05, bottom=0.11, right=0.98, top=0.95, wspace=0.25, hspace=0.2)

fig.savefig("main_fig.pdf")