'''
Plot ImageNet 3-model accuracy-MAC and 2-model accuracy-time Pareto improvement
for various cascading methods.
'''

import json

import matplotlib.pyplot as plt
import numpy as np

from figuresettings import font, lw, ms, fstitle, fsaxes, fslabels


# Plot settings used by the accuracy-MAC figure.
al = 0.8  # line alpha for the coloured method curves
xlim = (0.8, 5.2)
ylim = (75, 89)

# Single-model baselines: index 0 is plotted as accuracy, index 1 is the
# MAC baseline that gets divided by each cascade's Pareto values.
with open('../data/baseline_mac.txt', 'r') as f:
    baseline_mac = json.load(f)
accB = np.array(baseline_mac[0])
macB = np.array(baseline_mac[1])

# Three-model cascade Pareto results, one array per cascading method.
pareto_entropy = np.load('../data/mac/tri_pareto_entropy.npy')
pareto_softmax = np.load('../data/mac/tri_pareto_softmax.npy')
pareto_softmax_margin = np.load('../data/mac/tri_pareto_softmax_margin.npy')
pareto_logits_margin = np.load('../data/mac/tri_pareto_logits_margin.npy')
pareto_ts_entropy = np.load('../data/mac/tri_pareto_ts_entropy.npy')
pareto_ts_softmax = np.load('../data/mac/tri_pareto_ts_softmax.npy')
pareto_ts_softmax_margin = np.load('../data/mac/tri_pareto_ts_softmax_margin.npy')
pareto_ts_logits_margin = np.load('../data/mac/tri_pareto_ts_logits_margin.npy')
pareto_mean_logits = np.load('../data/mac/tri_pareto_mean_logits.npy')
pareto_mean_softmax = np.load('../data/mac/tri_pareto_mean_softmax.npy')
pareto_mean_ts_logits = np.load('../data/mac/tri_pareto_mean_ts_logits.npy')
pareto_mean_ts_softmax = np.load('../data/mac/tri_pareto_mean_ts_softmax.npy')


# Panel layout shared by both figures. Each curve entry is
#   (pareto key, legend label, colour, faded, zorder)
# where ``faded`` curves are drawn with the alpha passed to
# _plot_improvement_figure, and a zorder of None leaves matplotlib's default.
# The black 'Max Softmax' curve is repeated in the two ensemble panels as a
# common reference.
_REFERENCE_CURVE = ('softmax', 'Max Softmax', 'k', False, None)
_PANELS = (
    ('No Ensemble', (
        ('entropy', 'Entropy', 'C0', True, None),
        ('softmax', 'Max Softmax', 'C1', True, 2.5),
        ('softmax_margin', 'Softmax Margin', 'C2', True, None),
        ('logits_margin', 'Logits Margin', 'C3', True, None),
    )),
    ('Comparison Ensemble', (
        _REFERENCE_CURVE,
        ('ts_entropy', 'TS Entropy', 'C0', True, None),
        ('ts_softmax', 'TS Max Softmax', 'C1', True, None),
        ('ts_softmax_margin', 'TS Softmax Margin', 'C2', True, None),
        ('ts_logits_margin', 'TS Logits Margin', 'C3', True, None),
    )),
    ('Mean Ensemble', (
        _REFERENCE_CURVE,
        ('mean_logits', 'Mean Logits', 'C0', True, 2.5),
        ('mean_softmax', 'Mean Softmax', 'C1', True, None),
        ('mean_ts_logits', 'Mean TS Logits', 'C2', True, None),
        ('mean_ts_softmax', 'Mean TS Softmax', 'C3', True, None),
    )),
)


def _current_pareto():
    """Return the module-level ``pareto_*`` arrays keyed by method name.

    The globals are read at call time, so the result reflects whichever data
    set (MAC or time) was loaded most recently.
    """
    return {
        'entropy': pareto_entropy,
        'softmax': pareto_softmax,
        'softmax_margin': pareto_softmax_margin,
        'logits_margin': pareto_logits_margin,
        'ts_entropy': pareto_ts_entropy,
        'ts_softmax': pareto_ts_softmax,
        'ts_softmax_margin': pareto_ts_softmax_margin,
        'ts_logits_margin': pareto_ts_logits_margin,
        'mean_logits': pareto_mean_logits,
        'mean_softmax': pareto_mean_softmax,
        'mean_ts_logits': pareto_mean_ts_logits,
        'mean_ts_softmax': pareto_mean_ts_softmax,
    }


def _plot_improvement_figure(acc, cost, pareto, x_limits, y_limits, x_ticks,
                             alpha, stem):
    """Draw the 1x3 Pareto-improvement figure and save it as ``stem + '.pdf'``.

    Args:
        acc: baseline accuracies, plotted on the y-axis after scaling by 100
            (presumably fractions in [0, 1] -- the axis is labelled in %).
        cost: baseline cost values; each curve's x-values are
            ``cost / pareto[key][:len(cost), 1]``.
        pareto: mapping from the method keys used in ``_PANELS`` to arrays.
        x_limits, y_limits: axis limits applied to every panel.
        x_ticks: x tick positions applied to every panel.
        alpha: transparency of the coloured (non-reference) curves.
        stem: output file name without the ``.pdf`` extension.

    Returns:
        The created matplotlib figure; it is left open, as before.
    """
    fig = plt.figure(figsize=(8, 4), dpi=200)
    plt.rc('font', family=font)
    n = len(cost)
    y = acc * 100
    for position, (title, curves) in enumerate(_PANELS, start=1):
        plt.subplot(1, 3, position)
        for key, label, color, faded, zorder in curves:
            # Only pass alpha/zorder when the original styling set them, so
            # the reference curves keep matplotlib's defaults.
            extra = {}
            if faded:
                extra['alpha'] = alpha
            if zorder is not None:
                extra['zorder'] = zorder
            plt.plot(cost / pareto[key][:n, 1], y, '-', color=color,
                     lw=lw, ms=ms, label=label, **extra)
        plt.xlim(x_limits)
        plt.ylim(y_limits)
        plt.xticks(x_ticks, fontsize=fslabels)
        plt.yticks(fontsize=fslabels)
        plt.xlabel('Improvement Factor', fontsize=fsaxes)
        if position == 1:
            # Only the leftmost panel carries the shared y-axis label.
            plt.ylabel('ImageNet Validation Top-1 %', fontsize=fsaxes)
        plt.title(title, fontsize=fstitle)
        plt.legend(loc='lower right', fontsize=fslabels, framealpha=0.5)
        plt.grid(lw=0.4)
    plt.savefig(stem + '.pdf', bbox_inches='tight', pad_inches=0)
    return fig


# plot 3-model cascades accuracy-MAC Pareto improvement
plotname = 'figure_7_appendix_improvement_tri'
_plot_improvement_figure(accB, macB, _current_pareto(), xlim, ylim,
                         [1, 2, 3, 4, 5], al, plotname)


# 2-model cascade timing data; rebinds the same globals used above.
with open('../data/baseline_time.txt', 'r') as f:
    baseline_time = json.load(f)
pareto_entropy = np.load('../data/time/time_bi_pareto_entropy.npy')
pareto_softmax = np.load('../data/time/time_bi_pareto_softmax.npy')
pareto_softmax_margin = np.load('../data/time/time_bi_pareto_softmax_margin.npy')
pareto_logits_margin = np.load('../data/time/time_bi_pareto_logits_margin.npy')
pareto_ts_entropy = np.load('../data/time/time_bi_pareto_ts_entropy.npy')
pareto_ts_softmax = np.load('../data/time/time_bi_pareto_ts_softmax.npy')
pareto_ts_softmax_margin = np.load('../data/time/time_bi_pareto_ts_softmax_margin.npy')
pareto_ts_logits_margin = np.load('../data/time/time_bi_pareto_ts_logits_margin.npy')
pareto_mean_logits = np.load('../data/time/time_bi_pareto_mean_logits.npy')
pareto_mean_softmax = np.load('../data/time/time_bi_pareto_mean_softmax.npy')
pareto_mean_ts_logits = np.load('../data/time/time_bi_pareto_mean_ts_logits.npy')
pareto_mean_ts_softmax = np.load('../data/time/time_bi_pareto_mean_ts_softmax.npy')

accB = np.array(baseline_time[0])
# NOTE: despite its name, macB now holds the second entry of
# baseline_time.txt (presumably timing values); the name is kept so the
# module's globals match the original script.
macB = np.array(baseline_time[1])
xlim = (0.8, 3.2)  # al and ylim are unchanged from the MAC figure


# plot 2-model cascades accuracy-time Pareto improvement
plotname = 'figure_8_appendix_improvement_time'
_plot_improvement_figure(accB, macB, _current_pareto(), xlim, ylim,
                         [1, 2, 3], al, plotname)