'''
Plot GLUE benchmark Pareto fronts.
'''

import json

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from figuresettings import font, lw, ms, fsaxes, fslabels


def convertpareto(pareto):
    """Turn a Pareto front into a rectangular (staircase) curve for plotting.

    ``pareto`` is a 2-D array whose column 0 is plotted as quality and column 1
    as cost (MAC). The result has ``2 * len(pareto) - 1`` rows. Each quality
    value is held while the cost steps to the next point, so consecutive points
    are joined by horizontal/vertical segments instead of diagonals. An empty
    front raises ``IndexError``.
    """
    quality = pareto[:, 0]
    cost = pareto[:, 1]
    # q0, q0, q1, q1, ..., q_{n-2}, q_{n-2}, q_{n-1}
    step_quality = np.concatenate((np.repeat(quality[:-1], 2), [quality[-1]]))
    # c0, c1, c1, c2, c2, ..., c_{n-1}, c_{n-1}
    step_cost = np.concatenate(([cost[0]], np.repeat(cost[1:], 2)))
    return np.column_stack((step_quality, step_cost))


# Load, per GLUE task, the pretrained-model dataframe (.pkl), the cascade
# Pareto front (.npy) and the baseline Pareto front (JSON stored in a .txt).
def _read_baseline(tag):
    """Return the JSON content of ../data/baseline_<tag>.txt."""
    with open(f'../data/baseline_{tag}.txt', 'r') as handle:
        return json.load(handle)


def _read_pareto(tag):
    """Return the array stored in ../data/pareto_<tag>.npy."""
    return np.load(f'../data/pareto_{tag}.npy')


df_sst2 = pd.read_pickle('../data/df_sst2.pkl')
baseline_sst2 = _read_baseline('sst2')
pareto_sst2 = _read_pareto('sst2')

df_mrpc = pd.read_pickle('../data/df_mrpc.pkl')
baseline_mrpc = _read_baseline('mrpc')
pareto_mrpc = _read_pareto('mrpc')

df_qqp = pd.read_pickle('../data/df_qqp.pkl')
baseline_qqp = _read_baseline('qqp')
pareto_qqp = _read_pareto('qqp')

df_qnli = pd.read_pickle('../data/df_qnli.pkl')
baseline_qnli = _read_baseline('qnli')
pareto_qnli = _read_pareto('qnli')

# F1-based fronts reuse the MRPC / QQP model dataframes loaded above
baseline_mrpc_f1 = _read_baseline('mrpc_f1')
pareto_mrpc_f1 = _read_pareto('mrpc_f1')

baseline_qqp_f1 = _read_baseline('qqp_f1')
pareto_qqp_f1 = _read_pareto('qqp_f1')


# Table driving the plotting loop: one row per subfigure, placed row-major
# into the 2 x 3 grid. Each row is
#   [dataset name, quality metric (dataframe column), models dataframe,
#    cascade Pareto, baseline Pareto]
data = [
    ['SST-2', 'top1', df_sst2, pareto_sst2, baseline_sst2],
    ['MRPC', 'top1', df_mrpc, pareto_mrpc, baseline_mrpc],
    ['MRPC', 'f1', df_mrpc, pareto_mrpc_f1, baseline_mrpc_f1],
    ['QNLI', 'top1', df_qnli, pareto_qnli, baseline_qnli],
    ['QQP', 'top1', df_qqp, pareto_qqp, baseline_qqp],
    ['QQP', 'f1', df_qqp, pareto_qqp_f1, baseline_qqp_f1],
]


fig = plt.figure(figsize=(12, 5), dpi=200)
plt.rc('font', family=font)
# 2 x 3 grid of panels; each panel is split into a Pareto plot (4/6 of the
# width) and an improvement-factor plot (2/6) that shares its y axis.
outer = fig.add_gridspec(2, 3)

# plot 6 subfigures, one per row of ``data``
for i, (task, metric, df_models, cascade, base) in enumerate(data):
    inner = outer[i].subgridspec(1, 6, wspace=0.1)
    lower_half = i > 2  # x labels are drawn only for the lower row
    quality_name = "F1" if metric == "f1" else "Accuracy"

    # plot Pareto front
    pareto = convertpareto(cascade)
    ax1 = plt.subplot(inner[0:4])
    plt.semilogx(df_models['mac'].tolist(), df_models[metric].tolist(), '.', color='C0', lw=lw, ms=ms, label='Pretrained Models')
    plt.plot(pareto[:, 1], pareto[:, 0]*100, '-', color='C1', lw=lw, ms=ms, label='Cascade Pareto')
    # base[0] is plotted as quality (scaled by 100 to %), base[1] as MAC
    plt.plot(base[1], [q*100 for q in base[0]], '-', color='C2', lw=lw, ms=ms, label='Baseline Pareto')

    plt.xticks(fontsize=fslabels)
    plt.yticks(fontsize=fslabels)
    if lower_half:
        plt.xlabel('Average Inference MAC', fontsize=fsaxes)
    plt.ylabel(f'{task} Validation {quality_name} %', fontsize=fsaxes)
    plt.grid(lw=0.4)

    # plot Pareto improvement (baseline MAC / cascade MAC at equal quality)
    ax2 = plt.subplot(inner[4:6], sharey=ax1)
    baseline = np.array(base).T  # rows of (quality, MAC)
    if metric != 'f1':  # accuracy improvement
        # NOTE(review): assumes the baseline and cascade accuracy fronts are
        # sampled at the same quality levels, so MACs divide row by row -- confirm
        plt.plot(baseline[:, 1]/cascade[:baseline.shape[0], 1], baseline[:, 0]*100, '-', color='C1', lw=lw, ms=ms)
    else:  # F1 improvement, requires interpolation
        # NOTE(review): np.interp needs baseline[:, 0] ascending, and the kept
        # cascade points are assumed to be a leading prefix of ``cascade`` -- confirm
        within_baseline = cascade[:, 0] <= np.max(baseline[:, 0])
        interpolated = np.interp(cascade[:, 0][within_baseline], baseline[:, 0], baseline[:, 1])
        n = interpolated.shape[0]
        plt.plot(interpolated/cascade[:n, 1], cascade[:n, 0]*100, '-', color='C1', lw=lw, ms=ms)

    if lower_half:
        plt.xlabel('Improvement Factor', fontsize=fsaxes)
    plt.grid(lw=0.4)
    plt.xticks(fontsize=fslabels)
    # Hide y tick marks and labels (the axis is shared with ax1). Tick.label was
    # removed in Matplotlib 3.8; tick_params also covers ticks created at draw time.
    ax2.tick_params(axis='y', which='major', left=False, right=False, labelleft=False)

# legend proxies: empty lines that only carry style and label
l1, = plt.plot([], [], '.', color='C0', lw=lw, ms=ms, label='Pretrained Models')
l2, = plt.plot([], [], '-', color='C1', lw=lw, ms=ms, label='Cascade Pareto')
l3, = plt.plot([], [], '-', color='C2', lw=lw, ms=ms, label='Baseline Pareto')
fig.legend(handles=[l1, l2, l3], bbox_to_anchor=(0.13, 0.98, 0.761, 0.2), mode='expand', loc='lower left', ncol=3, borderaxespad=0, fontsize=fslabels, framealpha=0.)
fig.tight_layout()


plotname = 'figure_7_glue_paretos'
plt.savefig(plotname+'.pdf', bbox_inches='tight', pad_inches=0)


# NOTE: the string literal below is disabled code kept for reference. It is a
# bare string expression, so nothing in it runs. If enabled, it would draw the
# six Pareto-front panels without improvement plots and save them as
# 'figure_6_glue_paretos_small.pdf'.
'''
# figure without improvement

fig = plt.figure(figsize=(9, 4.5), dpi=200)
plt.rc('font', family=font)

for i in range(6):
    pareto = convertpareto(data[i][3])
    plt.subplot(2,3,i+1)
    plt.semilogx(data[i][2]['mac'].tolist(), data[i][2][data[i][1]].tolist(), '.', color='C0', lw=lw, ms=ms, label='Pretrained Models')
    plt.plot(pareto[:,1], pareto[:,0]*100, '-', color='C1', lw=lw, ms=ms, label='Cascade Pareto')
    plt.plot(data[i][4][1],[i*100 for i in data[i][4][0]], '-', color='C2', lw=lw, ms=ms, label='Base Pareto')
    
    plt.xticks(fontsize=fslabels)
    plt.yticks(fontsize=fslabels)
    if i > 2: # plot xlabel only for lower half
        plt.xlabel('Average Inference MAC', fontsize=fsaxes)
    plt.ylabel(f'{data[i][0]} Validation {"F1" if data[i][1] == "f1" else "Accuracy"} %', fontsize=fsaxes)
    plt.grid(lw=0.4)

l1, = plt.plot([], [], '.', color='C0', lw=lw, ms=ms, label='Pretrained Models')
l2, = plt.plot([], [], '-', color='C1', lw=lw, ms=ms, label='Cascade Pareto')
l3, = plt.plot([], [], '-', color='C2', lw=lw, ms=ms, label='Baseline Pareto')
fig.legend(handles = [l1,l2,l3], bbox_to_anchor=(0.112, 0.98, 0.802, 0.2), mode='expand', loc = 'lower left', ncol=3, borderaxespad=0, fontsize=fslabels, framealpha=0.)#, frameon=False)
plt.tight_layout()

plotname = 'figure_6_glue_paretos_small'
plt.savefig(plotname+'.pdf', bbox_inches = 'tight', pad_inches = 0)
'''