'''
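Plot GLUE benchmark Pareto fronts (validation accuracy / F1 versus average
inference MAC) for SST-2, MRPC, QNLI and QQP, comparing individual pretrained
models, the cascade Pareto front and the baseline Pareto front, and save the
result as figure_6_glue_paretos.pdf.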
'''

import json

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from figuresettings import font, lw, ms, fstitle, fsaxes, fslabels


def convertpareto(pareto):
    """Turn a Pareto front into the vertices of a staircase curve for plotting.

    For an input with N rows the result has 2N - 1 rows. Column 0 holds
    ``p0[0], p0[0], p1[0], p1[0], ..., p_{N-1}[0]`` and column 1 holds
    ``p0[1], p1[1], p1[1], ..., p_{N-1}[1], p_{N-1}[1]``. Consecutive vertices
    therefore differ in only one coordinate, so joining them with straight
    lines draws right-angled steps instead of diagonals.

    Only columns 0 and 1 of ``pareto`` are used. In this script column 1 is
    drawn on the x-axis and column 0 (times 100) on the y-axis.

    Parameters
    ----------
    pareto : numpy.ndarray
        2-D array with at least one row and at least two columns. Rows are
        presumably already ordered along the front; this function does not
        sort them.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(2 * N - 1, 2)``.

    Raises
    ------
    IndexError
        If ``pareto`` has no rows.
    """
    # Column 0: every value except the last appears twice, then the last once.
    first = np.append(np.repeat(pareto[:-1, 0], 2), pareto[-1, 0])
    # Column 1: the first value once, then every later value twice.
    second = np.insert(np.repeat(pareto[1:, 1], 2), 0, pareto[0, 1])
    return np.stack((first, second), axis=1)


# Data files are resolved relative to the current working directory.
_DATA_DIR = '../data'


def _load_df(task):
    """Read the pickled per-model results DataFrame ``df_<task>.pkl``.

    NOTE(review): unpickling can execute arbitrary code; only load trusted files.
    """
    return pd.read_pickle(f'{_DATA_DIR}/df_{task}.pkl')


def _load_baseline(task):
    """Read the baseline Pareto front stored as JSON in ``baseline_<task>.txt``.

    The plotting code indexes the result as ``[metric_values, mac_values]``.
    It is presumably a two-element list; verify against the generating script.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(f'{_DATA_DIR}/baseline_{task}.txt', 'r', encoding='utf-8') as f:
        return json.load(f)


def _load_pareto(task):
    """Load the cascade Pareto front ``pareto_<task>.npy`` as plot-ready staircase vertices."""
    return convertpareto(np.load(f'{_DATA_DIR}/pareto_{task}.npy'))


df_sst2 = _load_df('sst2')
baseline_sst2 = _load_baseline('sst2')
pareto_sst2 = _load_pareto('sst2')

df_mrpc = _load_df('mrpc')
baseline_mrpc = _load_baseline('mrpc')
pareto_mrpc = _load_pareto('mrpc')

df_qqp = _load_df('qqp')
baseline_qqp = _load_baseline('qqp')
pareto_qqp = _load_pareto('qqp')

df_qnli = _load_df('qnli')
baseline_qnli = _load_baseline('qnli')
pareto_qnli = _load_pareto('qnli')

# F1 variants reuse the MRPC / QQP DataFrames (they carry an 'f1' column).
baseline_mrpc_f1 = _load_baseline('mrpc_f1')
pareto_mrpc_f1 = _load_pareto('mrpc_f1')

baseline_qqp_f1 = _load_baseline('qqp_f1')
pareto_qqp_f1 = _load_pareto('qqp_f1')


# One entry per subplot, in row-major order of the 2x3 grid:
# (pretrained-model DataFrame, metric column, cascade front, baseline front, y-axis label)
_PANELS = [
    (df_sst2, 'top1', pareto_sst2, baseline_sst2, 'SST-2 Validation Accuracy %'),
    (df_mrpc, 'top1', pareto_mrpc, baseline_mrpc, 'MRPC Validation Accuracy %'),
    (df_mrpc, 'f1', pareto_mrpc_f1, baseline_mrpc_f1, 'MRPC Validation F1 %'),
    (df_qnli, 'top1', pareto_qnli, baseline_qnli, 'QNLI Validation Accuracy %'),
    (df_qqp, 'top1', pareto_qqp, baseline_qqp, 'QQP Validation Accuracy %'),
    (df_qqp, 'f1', pareto_qqp_f1, baseline_qqp_f1, 'QQP Validation F1 %'),
]
_NROWS, _NCOLS = 2, 3


def _plot_panel(df, metric, pareto, baseline, ylabel, show_xlabel):
    """Draw one panel on the current axes and return its three line handles.

    Parameters
    ----------
    df : pandas.DataFrame
        Per-model results with a ``'mac'`` column and the ``metric`` column.
    metric : str
        Column of ``df`` plotted on the y-axis (e.g. ``'top1'`` or ``'f1'``).
    pareto : numpy.ndarray
        Cascade front as staircase vertices; column 1 is plotted on x and
        column 0 (times 100) on y.
    baseline : sequence
        ``[metric_values, mac_values]``; the metric values are scaled by 100.
    ylabel : str
        Y-axis label for this panel.
    show_xlabel : bool
        Whether to draw the x-axis label (only the bottom row has one).

    Returns
    -------
    tuple
        ``(pretrained_line, cascade_line, baseline_line)`` Line2D handles.
    """
    # semilogx switches the axes to a log x-scale; the later plot calls inherit it.
    models, = plt.semilogx(df['mac'].tolist(), df[metric].tolist(), '.',
                           color='C0', lw=lw, ms=ms, label='Pretrained Models')
    cascade, = plt.plot(pareto[:, 1], pareto[:, 0]*100, '-',
                        color='C1', lw=lw, ms=ms, label='Cascade Pareto')
    base, = plt.plot(baseline[1], [i*100 for i in baseline[0]], '-',
                     color='C2', lw=lw, ms=ms, label='Baseline Pareto')

    plt.xticks(fontsize=fslabels)
    plt.yticks(fontsize=fslabels)
    if show_xlabel:
        plt.xlabel('Average Inference MAC', fontsize=fsaxes)
    plt.ylabel(ylabel, fontsize=fsaxes)
    plt.grid(lw=0.4)
    return models, cascade, base


# Set the font before creating the figure so every artist picks it up.
plt.rc('font', family=font)
fig = plt.figure(figsize=(9, 4.5), dpi=200)

panel_handles = []
for position, (df, metric, pareto, baseline, ylabel) in enumerate(_PANELS, start=1):
    plt.subplot(_NROWS, _NCOLS, position)
    # Only the bottom row (positions after the first _NCOLS) gets an x-axis label.
    panel_handles.append(_plot_panel(df, metric, pareto, baseline, ylabel,
                                     show_xlabel=position > _NCOLS))

# Every panel uses the same three styles, so the first panel's handles
# serve as the single shared legend above the grid.
l1, l2, l3 = panel_handles[0]

fig.legend(handles=[l1, l2, l3], bbox_to_anchor=(0.112, 0.98, 0.802, 0.2), mode='expand',
           loc='lower left', ncol=3, borderaxespad=0, fontsize=fslabels, framealpha=0.)
plt.tight_layout()


plotname = 'figure_6_glue_paretos'
plt.savefig(plotname+'.pdf', bbox_inches='tight', pad_inches=0)