'''
Plot ImageNet accuracy-time Pareto improvement for 2-model cascades.
'''

import json

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from figuresettings import font, lw, ms, fstitle, fsaxes, fslabels


# ---------------------------------------------------------------------------
# Inputs
# ---------------------------------------------------------------------------

# Per-model table; the left panel reads its 'secperinf' and 'top1' columns.
df_timm = pd.read_pickle('../data/df_timm.pkl')

# The baseline front is JSON stored in a .txt file. Element 0 is used as the
# accuracies and element 1 as the matching times.
with open('../data/baseline_time.txt', 'r') as f:
    baseline_mac = json.load(f)

# Cascade Pareto fronts, one array per confidence rule. The plotting code
# reads column 0 as accuracy and column 1 as time.
pareto_softmax = np.load('../data/time/time_bi_pareto_softmax.npy')
pareto_softmax_margin = np.load(
    '../data/time/time_bi_pareto_softmax_margin.npy'
)
pareto_mean_logits = np.load('../data/time/time_bi_pareto_mean_logits.npy')

# Array views of the baseline front, used for element-wise arithmetic.
# NOTE(review): accuracies are presumably fractions in [0, 1], since they are
# multiplied by 100 before plotting on a percent axis -- confirm.
accB = np.asarray(baseline_mac[0])
timeB = np.asarray(baseline_mac[1])

# Fixed y-range applied to both panels so their vertical axes line up.
ylim = (56, 90)


# Create the figure, then select the font family for the text drawn below.
plt.figure(figsize=(6, 3), dpi=200)
plt.rc('font', family=font)


# Left panel: spans cells 1-4 of a 1x6 subplot grid.
plt.subplot(1, 6, (1, 4))

# Each pretrained model as a dot; semilogx makes the time axis logarithmic.
plt.semilogx(
    df_timm['secperinf'].tolist(),
    df_timm['top1'].tolist(),
    '.',
    color='C0', lw=lw, ms=ms, label='Pretrained Models',
)

# Baseline front; stored accuracies are multiplied by 100 for the % axis.
plt.plot(
    baseline_mac[1],
    [acc * 100 for acc in baseline_mac[0]],
    '-',
    color='k', lw=lw, ms=ms, label='Baseline Pareto',
)

# Cascade fronts: (array, format string, legend label, extra keyword args).
# Only the first curve gets an explicit zorder, which raises it above the
# other lines.
for front, fmt, label, extra in (
    (pareto_softmax, '-C1', 'Max Softmax', {'zorder': 2.5}),
    (pareto_softmax_margin, '-C2', 'Softmax Margin', {}),
    (pareto_mean_logits, '-C3', 'Mean Logits', {}),
):
    plt.plot(
        front[:, 1], front[:, 0] * 100, fmt,
        lw=lw, ms=ms, label=label, **extra
    )

# Axis range, tick/label font sizes, legend and grid.
plt.ylim(ylim)
plt.xticks(fontsize=fslabels)
plt.yticks(fontsize=fslabels)
plt.xlabel('Average Seconds/Inference', fontsize=fsaxes)
plt.ylabel('ImageNet Validation Top-1 %', fontsize=fsaxes)
plt.legend(loc='lower right', fontsize=fslabels, framealpha=0.5)
plt.grid(lw=0.4)


# Right panel: spans cells 5-6 of the same 1x6 subplot grid.
ax = plt.subplot(1, 6, (5, 6))

# Improvement factor = baseline time / cascade time, taken row by row and
# plotted against the baseline accuracy (in percent). Only the first
# len(timeB) rows of each cascade front are used.
# NOTE(review): this assumes those rows line up with the baseline points --
# confirm against the script that produces the .npy files.
for front, fmt, label, extra in (
    (pareto_softmax, '-C1', 'Max Softmax', {'zorder': 2.5}),
    (pareto_softmax_margin, '-C2', 'Softmax Margin', {}),
    (pareto_mean_logits, '-C3', 'Mean Logits', {}),
):
    ax.plot(
        timeB / front[:len(timeB), 1], accB * 100, fmt,
        lw=lw, ms=ms, label=label, **extra
    )

ax.set_xlim(0.85, 3)
ax.set_ylim(ylim)  # same y-range as the left panel
ax.set_xlabel('Improvement Factor', fontsize=fsaxes)
ax.grid(lw=0.4)

# Put the y axis label position and ticks on the right-hand side.
ax.yaxis.set_label_position('right')
ax.yaxis.tick_right()
plt.xticks([1, 2, 3], fontsize=fslabels)
plt.yticks(fontsize=fslabels)


# Save the figure as '<plotname>.pdf', cropped tightly with zero padding.
plotname = 'figure_5_pareto_time'
plt.savefig(f'{plotname}.pdf', bbox_inches='tight', pad_inches=0)