'''
Plot the distribution-shift experiment results (ImageNet vs. ImageNetV2):
Figure 6 for the main text and Figure 12 for the appendix.
'''

import json
import pickle

import matplotlib.pyplot as plt
import numpy as np

from figuresettings import font, lw, ms, fsaxes, fslabels


def _load_json(path):
    """Return the parsed JSON content of the file at ``path``.

    The file is read as UTF-8 explicitly instead of with the platform's
    locale-dependent default encoding.
    """
    with open(path, 'r', encoding='utf-8') as fh:
        return json.load(fh)


def _load_pickle(path):
    """Return the object unpickled from the file at ``path``.

    NOTE: unpickling can execute arbitrary code; only load result files that
    were generated by this project, never untrusted/downloaded data.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# ImageNet results. As used by the plotting code: each ``pareto`` entry holds
# the Top-1 % at index 2 and the average inference MAC at index 3;
# ``baseline_mac`` is [accuracies, MACs]; ``pareto_softmax`` columns are
# (accuracy, MAC). Accuracies in the last two are scaled by 100 when plotted,
# so they are presumably fractions of 1.
pareto = _load_json('../data/pareto_mac.txt')
baseline_mac = _load_json('../data/baseline_mac.txt')
pareto_softmax = np.load('../data/mac/bi_pareto_softmax.npy')

# The same quantities evaluated on ImageNetV2.
v2_pareto = _load_json('../data/v2_pareto_mac.txt')
v2_baseline_mac = _load_json('../data/v2_baseline_mac.txt')
v2_pareto_softmax = np.load('../data/mac/v2_bi_pareto_softmax.npy')

# Filtered ImageNetV2 cascades (a collection of (accuracy, MAC) arrays) and
# their Pareto front, plotted as the "Distribution Shifted Cascade Pareto".
v2_cascades_filtered = _load_pickle('../data/mac/v2_bi_cascades_softmax_filtered.pkl')
v2_pareto_softmax_filtered = np.load('../data/mac/v2_bi_pareto_softmax_filtered.npy')

# Baseline Pareto fronts as arrays: accuracy (acc*) and average MAC (mac*).
accB = np.array(baseline_mac[0])
macB = np.array(baseline_mac[1])
v2_accB = np.array(v2_baseline_mac[0])
v2_macB = np.array(v2_baseline_mac[1])

# Fixed Top-1 % axis limits shared by both panels of a figure so their y-axes
# line up: ``ylim`` for Figure 6, ``ylim2`` for Figure 12.
ylim = (45, 90)
ylim2 = (45, 84)


# Shared x ticks: log-scaled MAC panels and linear improvement-factor panels.
MAC_XTICKS = [10**8, 10**9, 10**10, 10**11, 10**12]
FACTOR_XTICKS = [1, 2, 3, 4]


def _new_figure():
    """Start a new 6x3 inch, 400 dpi figure using the shared font family."""
    # Configure the font before any artist exists so all text picks it up.
    plt.rc('font', family=font)
    return plt.figure(figsize=(6, 3), dpi=400)


def _finish_mac_panel(y_range, ylabel):
    """Apply limits, ticks, labels, legend and grid to the current MAC panel."""
    plt.ylim(y_range)
    plt.xticks(MAC_XTICKS, fontsize=fslabels)
    plt.yticks(fontsize=fslabels)
    plt.xlabel('Average Inference MAC', fontsize=fsaxes)
    plt.ylabel(ylabel, fontsize=fsaxes)
    plt.legend(loc='lower right', fontsize=fslabels, framealpha=0.5)
    plt.grid(lw=0.4)


def _improvement_factor(mac_base, pareto_cascade):
    """Return baseline MAC divided by cascade MAC, row by row.

    Entry ``k`` of ``mac_base`` is divided by row ``k`` of ``pareto_cascade``
    (column 1 = MAC). NOTE(review): this assumes both were sampled at the same
    accuracy levels -- confirm against the script that produced the .npy files.

    Raises:
        ValueError: if ``pareto_cascade`` has fewer rows than ``mac_base``.
            Without this check numpy either fails with a broadcasting error or,
            for a single-row array, silently divides every baseline point by
            the same cascade cost.
    """
    if len(pareto_cascade) < len(mac_base):
        raise ValueError(
            f'cascade Pareto front has {len(pareto_cascade)} rows but the '
            f'baseline has {len(mac_base)} points'
        )
    return mac_base / pareto_cascade[:len(mac_base), 1]


def _plot_improvement_panel(curves, y_range):
    """Draw the right-hand panel: Top-1 % against MAC improvement factor.

    ``curves`` is a sequence of ``(mac_base, pareto_cascade, acc_base, color)``
    tuples, drawn in order. Returns the panel's Axes.
    """
    panel = plt.subplot(1, 6, (5, 6))
    for mac_base, pareto_cascade, acc_base, color in curves:
        plt.plot(_improvement_factor(mac_base, pareto_cascade), acc_base * 100,
                 '-', color=color, lw=lw, ms=ms)
    plt.ylim(y_range)
    plt.xlabel('Improvement Factor', fontsize=fsaxes)
    plt.grid(lw=0.4)
    # Ticks and labels on the right, presumably to keep clear of the left panel.
    panel.yaxis.set_label_position("right")
    panel.yaxis.tick_right()
    plt.xticks(FACTOR_XTICKS, fontsize=fslabels)
    plt.yticks(fontsize=fslabels)
    return panel


def _save_and_close(path):
    """Save the current figure tightly cropped to ``path``, then release it."""
    plt.savefig(path, bbox_inches='tight', pad_inches=0)
    plt.close()


# Figure 6 (main text): ImageNet Pareto fronts vs. their ImageNetV2
# (distribution-shifted) counterparts.
_new_figure()
plt.subplot(1, 6, (1, 4))
plt.semilogx([p[3] for p in pareto], [p[2] for p in pareto], '.', color='C0',
             lw=lw, ms=ms, label='Pretrained Models ImageNet')
plt.plot(macB, accB * 100, '-', color='k', lw=lw, ms=ms,
         label='Baseline Pareto ImageNet')
plt.plot(pareto_softmax[:, 1], pareto_softmax[:, 0] * 100, '-', color='C1',
         lw=lw, ms=ms, label='Softmax Cascade Pareto')

plt.plot([p[3] for p in v2_pareto], [p[2] for p in v2_pareto], '+', color='C0',
         lw=lw, ms=ms, label='Pretrained Models ImageNetV2')
plt.plot(v2_macB, v2_accB * 100, '-', color='C7', lw=lw, ms=ms,
         label='Baseline Pareto ImageNetV2')
plt.plot(v2_pareto_softmax_filtered[:, 1], v2_pareto_softmax_filtered[:, 0] * 100,
         '-', color='C2', lw=lw, ms=ms, label='Distribution Shifted Cascade Pareto')
_finish_mac_panel(ylim, 'Top-1 %')

ax = _plot_improvement_panel(
    [(macB, pareto_softmax, accB, 'C1'),
     (v2_macB, v2_pareto_softmax_filtered, v2_accB, 'C2')],
    ylim,
)
_save_and_close('figure_6_robustness_comparison.pdf')


# Figure 12 (appendix): distribution-shifted cascades vs. the cascade Pareto
# front computed directly on ImageNetV2.
_new_figure()
plt.subplot(1, 6, (1, 4))
plt.semilogx([p[3] for p in v2_pareto], [p[2] for p in v2_pareto], '+', color='C0',
             lw=lw, ms=ms, label='Pretrained Models ImageNetV2')
plt.plot(v2_macB, v2_accB * 100, '-', color='C7', lw=lw, ms=ms,
         label='Baseline Pareto ImageNetV2')
plt.plot(v2_pareto_softmax[:, 1], v2_pareto_softmax[:, 0] * 100, '-', color='C2',
         lw=lw, ms=ms, label='Softmax Cascade Pareto ImageNetV2')
for cascade in v2_cascades_filtered:
    plt.plot(cascade[:, 1], cascade[:, 0] * 100, '-', color='C1', lw=lw, ms=ms)
# Empty proxy line: one legend entry shared by all filtered cascade curves above.
plt.plot([], [], '-', color='C1', lw=lw, ms=ms,
         label='Distribution Shifted Cascade Pareto')
_finish_mac_panel(ylim2, 'ImageNetV2 Test Top-1 %')

ax = _plot_improvement_panel(
    [(v2_macB, v2_pareto_softmax, v2_accB, 'C2'),
     (v2_macB, v2_pareto_softmax_filtered, v2_accB, 'C1')],
    ylim2,
)
_save_and_close('figure_12_appendix_distribution_shift.pdf')