'''
Plot ImageNet accuracy-MAC Pareto improvement for max softmax cascades.
'''

import json
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from figuresettings import font, lw, ms, fstitle, fsaxes, fslabels


# --- Inputs ------------------------------------------------------------------
# Table of pretrained models; its 'mac' and 'top1' columns are scattered below.
df_timm = pd.read_pickle('../data/df_timm.pkl')

# Baseline Pareto front. Despite the .txt suffix, the file is parsed as JSON.
with open('../data/baseline_mac.txt', 'r') as fh:
    baseline_mac = json.load(fh)

# Pareto front of the softmax cascades, stored as a NumPy array.
pareto_softmax = np.load('../data/mac/bi_pareto_softmax.npy')

# Individual cascade curves (pickled container indexed as cascades[2][i] below).
with open('../data/mac/bi_cascades_softmax.pkl', 'rb') as fh:
    cascades = pickle.load(fh)

# Entry 0 is plotted on the accuracy axis, entry 1 on the MAC axis.
accB, macB = (np.array(series) for series in baseline_mac[:2])

ylim = (56, 90)  # fixed y-range shared by both panels


# Create the canvas, then select the font family for the text added afterwards.
plt.figure(figsize=(6, 3), dpi=200)
plt.rc('font', family=font)


# Left panel: accuracy versus MAC, spanning slots 1-4 of a 1x6 grid.
plt.subplot(1, 6, (1, 4))

main_kw = dict(lw=lw, ms=ms)        # line width / marker size of the main series
thin_kw = dict(lw=lw * 0.5, ms=ms)  # half-width lines for the individual cascades

# Scatter of all pretrained models; semilogx also makes the x-axis logarithmic.
timm_mac = df_timm['mac'].tolist()
timm_top1 = df_timm['top1'].tolist()
plt.semilogx(timm_mac, timm_top1, '.', color='C0', label='Pretrained Models', **main_kw)

# Baseline front: stored accuracies are scaled by 100 to match the % axis.
baseline_top1 = [acc * 100 for acc in baseline_mac[0]]
plt.plot(baseline_mac[1], baseline_top1, '-', color='C2', label='Baseline Pareto', **main_kw)

# Cascade front: column 1 goes on the MAC axis, column 0 (x100) on the accuracy axis.
plt.plot(pareto_softmax[:, 1], pareto_softmax[:, 0] * 100, '-', color='C1',
         label='Softmax Cascade Pareto', **main_kw)

# Three selected cascades drawn as thin black curves; only the first carries a
# legend label so the legend shows a single 'Tested Cascades' entry.
for cascade_id, extra_kw in ((91, {'label': 'Tested Cascades'}), (239, {}), (316, {})):
    curve = cascades[2][cascade_id]
    plt.plot(curve[:, 1], curve[:, 0] * 100, '-k', **thin_kw, **extra_kw)

# Hard-coded (MAC, top-1 %) points shown as red '+' markers.
# NOTE(review): presumably measured results for the three cascades above - confirm.
test_mac = [135877545.78176, 735225919.7856001, 14379633777.443842]
test_top1 = [75.908, 83.126, 87.207]
plt.plot(test_mac, test_top1, '+', color='r', label='ImageNet Tests', **main_kw)

# Axis cosmetics.
plt.ylim(ylim)
plt.xticks([10**exponent for exponent in range(8, 13)], fontsize=fslabels)
plt.yticks(fontsize=fslabels)
plt.xlabel('Average Inference MAC', fontsize=fsaxes)
plt.ylabel('ImageNet Validation Top-1 %', fontsize=fsaxes)
plt.legend(loc='lower right', fontsize=fslabels, framealpha=0.5)
plt.grid(lw=0.4)


# Right panel: improvement factor, spanning slots 5-6 of the 1x6 grid.
ax = plt.subplot(1, 6, (5, 6))

# Baseline MAC divided by the cascade-front MAC, element-wise over the first
# len(macB) rows of the front.
# NOTE(review): assumes those rows line up one-to-one with the baseline points - confirm.
improvement = macB / pareto_softmax[:len(macB), 1]
ax.plot(improvement, accB * 100, '-', color='C1', lw=lw, ms=ms)

# Same y-range as the left panel so the two panels can be read across.
ax.set_ylim(ylim)
ax.set_xlabel('Improvement Factor', fontsize=fsaxes)
ax.grid(lw=0.4)

# Put the y-axis ticks (and label position) on the right-hand side.
ax.yaxis.set_label_position("right")
ax.yaxis.tick_right()

# Tick font sizes are applied after the ticks have been moved to the right.
plt.xticks([1, 2, 3, 4], fontsize=fslabels)
plt.yticks(fontsize=fslabels)


# Save the figure as '<plotname>.pdf' (relative path), cropped tightly with no padding.
plotname = 'figure_1_pareto_softmax'
plt.savefig(f'{plotname}.pdf', bbox_inches='tight', pad_inches=0)