'''
Plot comparison of confidence metrics for mean ensembles.
'''

import json

import matplotlib.pyplot as plt
import numpy as np

from figuresettings import font, lw, ms, fstitle, fsaxes, fslabels


# --- Data -------------------------------------------------------------------
# Baseline curve stored as JSON: row 0 is plotted as Top-1 accuracy (scaled by
# 100 below, so presumably a fraction), row 1 is the baseline cost ("mac"),
# used as the numerator of the improvement factor.
with open('../data/baseline_mac.txt', 'r') as f:
    baseline_mac = json.load(f)

accB = np.array(baseline_mac[0])
macB = np.array(baseline_mac[1])

# One entry per ensemble confidence variant:
#   (loaded pareto array, line colour, legend label, extra plt.plot kwargs)
# The "Max Softmax" curves get zorder 2.5 so they are drawn over the
# "Softmax Margin" curves, which keep matplotlib's default line zorder.
# NOTE(review): column 1 of each pareto array is assumed to be the ensemble's
# cost at the matching baseline accuracy -- confirm against the producer.
# All arrays are loaded up front, before any figure is created.
ensemble_curves = [
    (np.load('../data/mac/bi_pareto_mean_logits.npy'),
     'C1', 'Max Softmax Mean Logits', {'zorder': 2.5}),
    (np.load('../data/mac/bi_pareto_mean_softmax.npy'),
     'C2', 'Max Softmax Mean Softmax', {'zorder': 2.5}),
    (np.load('../data/mac/bi_pareto_mean_m_logits.npy'),
     'C3', 'Softmax Margin Mean Logits', {}),
    (np.load('../data/mac/bi_pareto_mean_m_softmax.npy'),
     'C9', 'Softmax Margin Mean Softmax', {}),
]

al = 0.8  # alpha shared by every curve
ylim = (82, 89)

# --- Figure -----------------------------------------------------------------
plt.figure(figsize=(3, 3), dpi=200)
plt.rc('font', family=font)

top1_percent = accB * 100
n_points = len(macB)
for pareto, colour, legend_label, extra_kwargs in ensemble_curves:
    # Improvement factor: baseline cost divided by the ensemble's cost,
    # truncated to the number of baseline points.
    improvement = macB / pareto[:n_points, 1]
    plt.plot(improvement, top1_percent, '-',
             color=colour, alpha=al, lw=lw, ms=ms,
             label=legend_label, **extra_kwargs)

plt.ylim(ylim)
plt.xlabel('Improvement Factor', fontsize=fsaxes)
plt.ylabel('ImageNet Validation Top-1 %', fontsize=fsaxes)
plt.grid(lw=0.4)
plt.xticks([1, 2, 3, 4], fontsize=fslabels)
plt.yticks(fontsize=fslabels)
plt.legend(loc='lower right', fontsize=fslabels, framealpha=0.5)

# --- Output -----------------------------------------------------------------
plotname = 'figure_11_appendix_mean_ensemble_confidence'
plt.savefig(f'{plotname}.pdf', bbox_inches='tight', pad_inches=0)