'''
Plot ImageNet accuracy-MAC Pareto improvement for 3-model cascades.
'''

import json

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from figuresettings import font, lw, ms, fstitle, fsaxes, fslabels


# ---- Input data -------------------------------------------------------------
# Per-model table; the plotting code below reads its 'mac' and 'top1' columns.
# NOTE(review): read_pickle unpickles the file -- only point it at trusted data.
df_timm = pd.read_pickle('../data/df_timm.pkl')

# Baseline Pareto front stored as JSON: entry 0 is used as accuracy and
# entry 1 as MAC count further down.
with open('../data/baseline_mac.txt', 'r') as f:
    baseline_mac = json.load(f)

# Cascade Pareto fronts. Column 0 is plotted as accuracy (scaled by 100) and
# column 1 as MAC -- presumably accuracies are stored as fractions; confirm.
bi_softmax = np.load('../data/mac/bi_pareto_softmax.npy')
pareto_softmax = np.load('../data/mac/tri_pareto_softmax.npy')
pareto_mean_logits = np.load('../data/mac/tri_pareto_mean_logits.npy')

# Array views of the baseline front for the element-wise maths later on.
accB, macB = np.array(baseline_mac[0]), np.array(baseline_mac[1])

# Both panels share this fixed accuracy range so their y axes line up.
ylim = (56, 90)


# Create the figure, then set the font family taken from figuresettings.
plt.figure(figsize=(6, 3), dpi=200)
plt.rc('font', family=font)


# ---- Left panel: accuracy vs. MAC (cells 1-4 of a 1x6 grid) ------------------
plt.subplot(1, 6, (1, 4))

# Individual pretrained models as a scatter on a logarithmic x axis.
plt.semilogx(
    df_timm['mac'].tolist(),
    df_timm['top1'].tolist(),
    '.',
    color='C0',
    lw=lw,
    ms=ms,
    label='Pretrained Models',
)

# Baseline front; accuracies are scaled by 100 to share the percent axis.
plt.plot(
    baseline_mac[1],
    [acc * 100 for acc in baseline_mac[0]],
    '-',
    color='k',
    lw=lw,
    ms=ms,
    label='Baseline Pareto',
)

# Cascade fronts, drawn in legend order: (array, line format, legend entry).
cascade_fronts = (
    (bi_softmax, '-C1', '2-model Max Softmax'),
    (pareto_softmax, '-C2', '3-model Max Softmax'),
    (pareto_mean_logits, '-C3', '3-model Mean Logits'),
)
for front, fmt, legend_entry in cascade_fronts:
    plt.plot(front[:, 1], front[:, 0] * 100, fmt, lw=lw, ms=ms, label=legend_entry)

# Axis cosmetics: fixed y range, one tick per decade from 1e8 to 1e12.
plt.ylim(ylim)
plt.xticks([10**exponent for exponent in range(8, 13)], fontsize=fslabels)
plt.yticks(fontsize=fslabels)
plt.xlabel('Average Inference MAC', fontsize=fsaxes)
plt.ylabel('ImageNet Validation Top-1 %', fontsize=fsaxes)
plt.legend(loc='lower right', fontsize=fslabels, framealpha=0.5)
plt.grid(lw=0.4)


# ---- Right panel: MAC improvement factor (cells 5-6 of the 1x6 grid) --------
ax = plt.subplot(1, 6, (5, 6))

# Each cascade front is truncated to the baseline length and divided into the
# baseline MAC, giving a ratio plotted against the baseline accuracies.
# NOTE(review): assumes row i of each front corresponds to baseline point i
# and that every front has at least len(macB) rows -- confirm upstream.
n_baseline = len(macB)
top1_percent = accB * 100
for front, fmt in (
    (bi_softmax, '-C1'),
    (pareto_softmax, '-C2'),
    (pareto_mean_logits, '-C3'),
):
    ax.plot(macB / front[:n_baseline, 1], top1_percent, fmt, lw=lw, ms=ms)

# Same y range as the left panel; ticks are moved to the right-hand side.
ax.set_ylim(ylim)
ax.set_xlabel('Improvement Factor', fontsize=fsaxes)
ax.grid(lw=0.4)
ax.yaxis.set_label_position("right")
ax.yaxis.tick_right()
plt.xticks(list(range(1, 6)), fontsize=fslabels)
plt.yticks(fontsize=fslabels)


# Save the figure as '<plotname>.pdf' (relative path), cropped tightly with
# no padding around the bounding box.
plotname = 'figure_4_pareto_tri'
plt.savefig(f'{plotname}.pdf', bbox_inches='tight', pad_inches=0)