'''
Plot ImageNet top5-MAC Pareto improvement for 2-model cascades.
'''

import json

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from figuresettings import font, lw, ms, fstitle, fsaxes, fslabels


# --- Inputs ----------------------------------------------------------------
# Pickled table of pretrained models (presumably a DataFrame); its 'mac' and
# 'top5' columns are what the scatter plot reads.
df_timm = pd.read_pickle('../data/df_timm.pkl')

# The baseline file is JSON despite its .txt extension.
with open('../data/baseline_top5.txt', 'r') as baseline_file:
    baseline_top5 = json.load(baseline_file)

# One array per cascade confidence criterion. Column 1 is drawn on the MAC
# axis and column 0 (times 100) on the accuracy axis.
pareto_entropy = np.load('../data/top5/top5_pareto_entropy.npy')
pareto_softmax = np.load('../data/top5/top5_pareto_softmax.npy')
pareto_softmax_sum = np.load('../data/top5/top5_pareto_softmax_sum.npy')
pareto_sum_mean_logits = np.load('../data/top5/top5_pareto_softmax_sum_mean_logits.npy')

# Entry 0 of the baseline is used as accuracy, entry 1 as MAC.
accB, macB = (np.array(baseline_top5[idx]) for idx in (0, 1))

# Shared y-range for both panels and the opacity of the cascade curves.
ylim, al = (80, 100), 0.75


# --- Canvas ----------------------------------------------------------------
# One figure hosts both panels; the font family comes from figuresettings and
# applies to all text created from here on.
canvas_options = {'figsize': (6, 3), 'dpi': 200}
plt.figure(**canvas_options)
plt.rc('font', family=font)


# --- Left panel: top-5 accuracy against inference MAC ----------------------
# Occupies grid cells 1-4 of a 1x6 layout.
plt.subplot(1, 6, (1, 4))

# Individual pretrained models as dots; semilogx also puts the x-axis on a
# log scale for everything drawn afterwards on these axes.
model_mac = df_timm['mac'].tolist()
model_top5 = df_timm['top5'].tolist()
plt.semilogx(model_mac, model_top5, '.', color='C0', lw=lw, ms=ms, label='Pretrained Models')

# Baseline Pareto front; accuracy is scaled by 100 to match the percent axis.
plt.plot(macB, accB * 100, '-', color='k', lw=lw, ms=ms, label='Baseline Pareto')

# Cascade Pareto fronts, in legend order. 'Max Softmax' deliberately carries
# no explicit zorder, so the other three curves are drawn above it.
cascade_curves = (
    (pareto_entropy, '-C2', 'Entropy', {'zorder': 2.5}),
    (pareto_softmax, '-C1', 'Max Softmax', {}),
    (pareto_softmax_sum, '-C3', 'Softmax Sum', {'zorder': 2.5}),
    (pareto_sum_mean_logits, '-C9', 'Mean Logits', {'zorder': 2.5}),
)
for front, line_style, legend_text, extra_kwargs in cascade_curves:
    plt.plot(
        front[:, 1],
        front[:, 0] * 100,
        line_style,
        alpha=al,
        lw=lw,
        ms=ms,
        label=legend_text,
        **extra_kwargs,
    )

# Axes cosmetics.
plt.ylim(ylim)
plt.xticks([10**exponent for exponent in range(8, 13)], fontsize=fslabels)
plt.yticks(fontsize=fslabels)
plt.xlabel('Average Inference MAC', fontsize=fsaxes)
plt.ylabel('ImageNet Validation Top-5 %', fontsize=fsaxes)
plt.legend(loc='lower right', fontsize=fslabels, framealpha=0.5)
plt.grid(lw=0.4)


# --- Right panel: MAC improvement factor over the baseline ------------------
# Occupies grid cells 5-6 of the 1x6 layout.
ax = plt.subplot(1, 6, (5, 6))

# x is the baseline MAC divided by the cascade MAC; y is the baseline accuracy
# in percent.
# NOTE(review): this assumes the first len(macB) rows of each cascade array
# line up one-to-one with the baseline points - confirm against the code that
# generates the .npy files.
n_baseline = len(macB)
baseline_percent = accB * 100
factor_curves = (
    (pareto_entropy, '-C2', {'zorder': 2.5}),
    (pareto_softmax, '-C1', {}),
    (pareto_softmax_sum, '-C3', {'zorder': 2.5}),
    (pareto_sum_mean_logits, '-C9', {'zorder': 2.5}),
)
for cascade, line_fmt, draw_kwargs in factor_curves:
    improvement = macB / cascade[:n_baseline, 1]
    plt.plot(improvement, baseline_percent, line_fmt, alpha=al, lw=lw, ms=ms, **draw_kwargs)

# Same y-range as the left panel so the two panels can be read side by side.
ax.set_xlim(0.9, 3.2)
ax.set_ylim(ylim)
ax.set_xlabel('Improvement Factor', fontsize=fsaxes)
ax.grid(lw=0.4)

# Move the y-axis label position and ticks to the right-hand edge of the panel.
ax.yaxis.set_label_position("right")
ax.yaxis.tick_right()
plt.xticks([1, 2, 3], fontsize=fslabels)
plt.yticks(fontsize=fslabels)


# --- Output ----------------------------------------------------------------
# Write the figure as a tightly cropped PDF (no padding) into the current
# working directory.
plotname = 'figure_9_appendix_top5'
plt.savefig(f'{plotname}.pdf', bbox_inches='tight', pad_inches=0)