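'''
Plot Figure 2 in three panels: an example two-model cascade compared with the
baseline Pareto front; the relationship between the two models' top-1 accuracy
difference and the accuracy breakoff point (early-exit %); and, for every model
pair, the accuracy difference against the MAC ratio, coloured by the cascade's
average improvement over the baseline Pareto front.
'''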

import json
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from figuresettings import font, lw, ms, fstitle, fsaxes, fslabels


# Load the cascade results, the baseline Pareto front and the per-pair summary table.
# Paths are relative to the current working directory.
# NOTE: pickle.load / pd.read_pickle can execute arbitrary code from the file;
# only point these at trusted, locally generated data.

# From the indexing below: cascades[0] supports .index() over [model_a, model_b]
# name pairs; cascades[2][i] is a 2-D array whose column 0 is top-1 accuracy
# (a fraction, scaled by 100 for plotting) and column 1 is average inference MAC.
with open('../data/mac/bi_cascades_softmax.pkl', 'rb') as f:
    cascades = pickle.load(f)

# Baseline Pareto front stored as JSON: [top-1 accuracies (fractions), MACs].
with open('../data/baseline_mac.txt', 'r', encoding='utf-8') as f:
    baseline_mac = json.load(f)

# One row per model pair; its columns feed the second and third panels.
df_softmax = pd.read_pickle('../data/mac/df_softmax.pkl')

# Position of the example pair drawn in the first panel.
idx = cascades[0].index(['tf_efficientnet_b4_ns', 'beit_large_patch16_224'])

thresh = df_softmax['reldrop5'].to_numpy()  # early-exit fraction (plotted x100 as 'Early Exit %')
acc1 = df_softmax['accuracy1'].to_numpy()   # top-1 of the first model in each pair
acc2 = df_softmax['accuracy2'].to_numpy()   # top-1 of the second model in each pair
mac1 = df_softmax['mac1'].to_numpy()        # MAC of the first model
mac2 = df_softmax['mac2'].to_numpy()        # MAC of the second model
c1 = df_softmax['avg_imp_p'].to_numpy()     # colour: average improvement over the Pareto front
m1 = np.full(len(c1), 4)                    # constant scatter marker size for every pair


plt.figure(figsize=(10, 3), dpi=200)
plt.rc('font', family=font)

# ---- Panel 1 (columns 1-4 of a 1x10 grid): the example cascade -------------
# presumably the two models' individual MACs (x-positions of the two markers)
example_macs = cascades[1][idx]
# column 0: top-1 accuracy (fraction), column 1: average inference MAC
example_curve = cascades[2][idx]

plt.subplot(1, 10, (1, 4))
# The endpoints of the cascade curve are drawn as the standalone models.
plt.plot(example_macs[0], example_curve[0, 0] * 100, 'oC0', ms=ms,
         label='EfficientNet B4 NS', zorder=2.5)
plt.plot(example_macs[1], example_curve[-1, 0] * 100, '*C0', ms=ms * 1.5,
         label='BEiT Large 224', zorder=2.5)
plt.plot(example_curve[:, 1], example_curve[:, 0] * 100, '-', color='C1', lw=lw,
         label='Max Softmax Cascade')

plt.xticks(fontsize=fslabels)
plt.yticks(fontsize=fslabels)

# Freeze the current limits so the baseline front and the marker below
# do not rescale the view chosen for the cascade.
plt.autoscale(False)
plt.plot(baseline_mac[1], np.asarray(baseline_mac[0]) * 100, '-', color='C2',
         lw=lw, ms=ms, label='Baseline Pareto')

# Vertical marker at a hand-picked operating point. The middle tick lies 5% of
# the accuracy gap below the top one: 87.3621 = 87.478 - 0.05 * (87.478 - 85.16).
# NOTE(review): these values are hard-coded; presumably they were read off the
# example cascade above — confirm against the data if it changes.
BREAKOFF_MAC = 15070204892
BREAKOFF_ACCS = [85.16, 87.3621, 87.478]
plt.plot([BREAKOFF_MAC] * len(BREAKOFF_ACCS), BREAKOFF_ACCS, '_-k', lw=lw, ms=ms,
         label='5% Relative Accuracy Drop')

plt.xlabel('Average Inference MAC', fontsize=fsaxes)
plt.ylabel('ImageNet Validation Top-1 %', fontsize=fsaxes)
plt.legend(loc='lower right', fontsize=fslabels, framealpha=0.5)
plt.grid(lw=0.4)


# ---- Panel 2 (columns 5-7): accuracy gap vs. early-exit percentage ---------
plt.subplot(1, 10, (5, 7))
# One point per model pair: x = early-exit percentage, y = top-1 difference
# between the second and the first model of the pair.
plt.plot(100 * thresh, acc2 - acc1, '.', color='C0', lw=lw, ms=ms,
         label='5% Relative Accuracy Drop')

# Keep the automatic lower limit but cap the top of the y-axis at 0.3.
plt.ylim(plt.ylim()[0], 0.3)
plt.xticks(list(range(40, 101, 10)), fontsize=fslabels)
plt.yticks(fontsize=fslabels)
plt.xlabel('Early Exit %', fontsize=fsaxes)
plt.ylabel('Models Top-1 Difference', fontsize=fsaxes)
plt.legend(loc='upper right', fontsize=fslabels, framealpha=0.5)
plt.grid(lw=0.4)


# ---- Panel 3 (columns 8-10): accuracy gap vs. compute ratio ----------------
plt.subplot(1, 10, (8, 10))
# One point per model pair, coloured by average improvement over the Pareto front.
pair_points = plt.scatter(mac2 / mac1, acc2 - acc1, s=m1, c=c1, zorder=2)
plt.xscale('log')

# Keep the automatic lower limit but cap the top of the y-axis at 0.3.
plt.ylim(bottom=plt.ylim()[0], top=0.3)
plt.xticks([10 ** k for k in range(5)], fontsize=fslabels)  # decades 1 .. 10^4
plt.yticks(fontsize=fslabels)
plt.xlabel('Models MAC Difference Factor', fontsize=fsaxes)
plt.ylabel('Models Top-1 Difference', fontsize=fsaxes)
plt.grid(lw=0.4)

# Colour bar for the scatter above, with its label rotated to read top-down.
cb1 = plt.colorbar(pair_points, ticks=[0, 0.5, 1, 1.5, 2, 2.5])
cb1.set_label('Average Improvement Over Pareto', fontsize=fsaxes,
              rotation=270, va='bottom')
cb1.ax.tick_params(labelsize=fslabels)


# Tighten subplot spacing, then write the figure to a PDF named after the plot,
# cropped to its content with no outer padding.
plt.tight_layout()
plotname = 'figure_2_breakoff_correlation'
plt.savefig(f'{plotname}.pdf', bbox_inches='tight', pad_inches=0)