'''
Creates the baseline Pareto fronts and the 2-model cascade Pareto fronts.
'''

import json

import numpy as np
import pandas as pd


# returns list of Pareto optimal models as [index, name, metric, cost]
def pareto_models(df, cost = 'mac', metric = 'top1'):
    '''Pick the rows of df whose cost undercuts every row before them.

    Rows are scanned in their stored order. Row 0 is always kept; a later
    row is kept only when its cost is strictly lower than the lowest cost
    seen so far.
    NOTE(review): this is a Pareto front only if df is ordered from best
    to worst metric - confirm against the code that builds the data frame.

    df     -- data frame with a 'model' column plus the cost/metric columns
    cost   -- name of the cost column
    metric -- name of the metric column

    Returns a list of [row index, model name, metric, cost] entries,
    reversed with respect to scan order (i.e. ascending cost).
    '''
    costs = df[cost].tolist()
    scores = df[metric].tolist()
    names = df['model'].tolist()

    cheapest = costs[0]
    front = [[0, names[0], scores[0], cheapest]]
    later_rows = zip(names[1:], scores[1:], costs[1:])
    for row, (name, score, price) in enumerate(later_rows, start=1):
        if price < cheapest:
            cheapest = price
            front.append([row, name, score, price])
    front.reverse()
    return front


# filters model duplicates and bad outliers 
def pareto_filter(pareto, factor = 2):
    '''Drop entries that are poor value compared with their successor.

    Entry i is kept (as a shallow copy) when the relative metric gain of
    entry i+1 over it is smaller than factor times the relative cost
    increase of entry i+1 over it. The last entry has no successor and is
    always kept (the original object, not a copy).

    pareto -- list of [index, name, metric, cost] entries
    factor -- multiplier applied to the relative cost increase
    '''
    kept = []
    for current, following in zip(pareto, pareto[1:]):
        metric_gain = following[2]/current[2]-1
        cost_gain = following[3]/current[3]-1
        if metric_gain < factor*cost_gain:
            kept.append(current[:])
    kept.append(pareto[-1])
    return kept


# converts rectangular to linear Pareto front
def paretoRtoL(pareto):
    '''Reduce a rectangular (staircase) front to the points of a linear one.

    For every entry, the steepest metric-per-cost slope to any later entry
    is determined (first one wins on ties). Starting at entry 0, the walk
    then jumps to that steepest successor until the last entry is reached.

    pareto -- list of [index, name, metric, cost] entries

    Returns the visited entries in order (same objects as in pareto).
    '''
    count = len(pareto)

    # successor[a] = position of the later entry with the steepest slope from a
    successor = []
    for a in range(count-1):
        base_metric = pareto[a][2]
        base_cost = pareto[a][3]
        slopes = [(pareto[b][2]-base_metric)/(pareto[b][3]-base_cost)
                  for b in range(a+1, count)]
        successor.append(a+1+slopes.index(max(slopes)))

    visited = [0]
    position = 0
    while position < len(successor):
        position = successor[position]
        visited.append(position)
    return [pareto[k] for k in visited]


# linear interpolation to create exact baseline Pareto
def paretoLfull(pareto, size):
    '''Interpolate a linear front at every attainable number of correct samples.

    The metric of each entry is taken as a percentage and converted to a
    rounded count of correct samples out of size. Between two consecutive
    entries the cost is interpolated linearly, one value per count.

    pareto -- list of [index, name, metric, cost] entries
    size   -- number of samples the metric refers to

    Returns [metrics, costs]: metrics are count/size for every count from
    the first entry's to the last entry's, costs are the matching
    interpolated values (the last one is the last entry's cost).
    '''
    costs = [entry[3] for entry in pareto]
    hits = [round(entry[2]/100*size) for entry in pareto]

    metric_axis = [count/size for count in range(hits[0], hits[-1]+1)]
    cost_axis = []
    for seg in range(len(pareto)-1):
        steps = hits[seg+1]-hits[seg]
        low, high = costs[seg], costs[seg+1]
        # a segment with steps <= 0 contributes no points
        for k in range(steps):
            cost_axis.append(low*(steps-k)/steps + high*k/steps)
    cost_axis.append(costs[-1])
    return [metric_axis, cost_axis]


# linear interpolation adjusted for f1 score
def f1paretoLfull(pareto, fill):
    '''Interpolate a linear front with a fixed number of points per segment.

    Both the metric (taken as a percentage and divided by 100) and the
    cost are interpolated linearly between consecutive entries, fill
    points per segment starting at the segment's first entry. The last
    entry is appended once at the end.

    pareto -- list of [index, name, metric, cost] entries
    fill   -- number of points generated per segment

    Returns [metrics, costs] as two lists of equal length.
    '''
    scores = [entry[2]/100 for entry in pareto]
    costs = [entry[3] for entry in pareto]

    metric_axis = []
    cost_axis = []
    for seg in range(len(scores)-1):
        for k in range(fill):
            metric_axis.append(scores[seg]*(fill-k)/fill + scores[seg+1]*k/fill)
        for k in range(fill):
            cost_axis.append(costs[seg]*(fill-k)/fill + costs[seg+1]*k/fill)
    metric_axis.append(scores[-1])
    cost_axis.append(costs[-1])
    return [metric_axis, cost_axis]


# creates Pareto from multiple cascades
def get_pareto(cascades):
    '''Merge several cascade curves into a single Pareto front.

    cascades -- sequence of 2-D arrays whose column 0 is the metric and
                column 1 the cost

    All points are pooled and ordered by ascending cost. The cheapest
    point is always kept; each later point is kept only if its metric is
    strictly higher than that of every point kept so far.

    Returns an array with one [metric, cost] row per kept point.
    '''
    points = np.concatenate(cascades)
    points = points[np.argsort(points[:, 1])]

    best = points[0][0]  # highest metric kept so far
    kept_metric = [best]
    kept_cost = [points[0][1]]
    # remaining points, cheapest first
    for point in points[1:]:
        if point[0] > best:
            best = point[0]
            kept_metric.append(point[0])
            kept_cost.append(point[1])
    return np.array([kept_metric, kept_cost]).T


# obtains Pareto front for all possible 2 model cascades
def cascade_pareto_bi(models, cost, path_infer, get_f1 = False):
    '''Return the Pareto front over all 2-model cascades built from models.

    models     -- list of model names; for each name the JSON file
                  path_infer + name + '.txt' is read
    cost       -- list of costs, aligned with models
    path_infer -- path prefix of the inference result files
    get_f1     -- score cascades by F1 instead of top-1 accuracy

    Each file must hold three equally long per-sample lists: [0] is the
    0/1 correctness, [2] the condition value used for ranking, and [1] is
    only read for F1, where it is compared against 1.
    NOTE(review): [1] is presumably the predicted label - confirm against
    the code that writes the inference files.

    Every pair (first, second) with the first model earlier in models
    than the second is evaluated. Samples are ranked by the first model's
    condition value, highest first. For each split s = 0..l the s
    lowest-ranked samples take the second model's result and the others
    keep the first model's result; the cost grows linearly from cost of
    the first model (s = 0) to the sum of both costs (s = l).

    Returns the array produced by get_pareto over all cascade curves
    (rows of [metric, cost]).
    '''
    # Every model takes part in several pairs; parse its file once on first
    # use instead of re-reading it from disk for every pair.
    infer_cache = {}

    def load_infer(name):
        if name not in infer_cache:
            with open(path_infer+name+'.txt', 'r') as f:
                infer_cache[name] = json.load(f)
        return infer_cache[name]

    # one (l+1, 2) array of [metric, cost] rows per model pair
    cas = []
    n = len(models)
    for i in range(1,n):
        for j in range(n-i):
            # boolean correctness and condition values of both models
            ea1 = load_infer(models[j])
            ea2 = load_infer(models[-i])
            l = len(ea1[0])
            # f1 score
            if get_f1:
                arr = np.array([ea1[0], ea1[2], ea2[0], ea2[2], ea1[1], ea2[1]]).T
                # rank samples by the first model's condition value, highest first
                arr = arr[arr[:, 1].argsort()[::-1]]
                # Reversed prefix sums count what the first model contributes on
                # the top-ranked samples; suffix sums count what the second model
                # contributes on the remaining ones.
                tp1 = np.logical_and(arr[:,0]==1,arr[:,4]==1).tolist()
                tp2 = np.logical_and(arr[:,2]==1,arr[:,5]==1).tolist()
                tp = np.cumsum([0]+tp1)[::-1]+np.cumsum([0]+tp2[::-1])
                fp1 = np.logical_and(arr[:,0]==0,arr[:,4]==1).tolist()
                fp2 = np.logical_and(arr[:,2]==0,arr[:,5]==1).tolist()
                fp = np.cumsum([0]+fp1)[::-1]+np.cumsum([0]+fp2[::-1])
                fn1 = np.logical_and(arr[:,0]==0,arr[:,4]==0).tolist()
                fn2 = np.logical_and(arr[:,2]==0,arr[:,5]==0).tolist()
                fn = np.cumsum([0]+fn1)[::-1]+np.cumsum([0]+fn2[::-1])
                cas.append(np.concatenate(
                    ((tp / (tp + (fp+fn)/2))[...,None],
                     np.linspace(cost[j], cost[j]+cost[-i], num=l+1)[...,None]), axis=1))
            # top-1 accuracy
            else:
                arr = np.array([ea1[0], ea1[2], ea2[0], ea2[2]]).T
                # rank samples by the first model's condition value, highest first
                arr = arr[arr[:, 1].argsort()[::-1]]
                # correct answers of the first model on the top-ranked samples plus
                # correct answers of the second model on the rest, for every split
                cas.append(np.concatenate(
                    ((np.concatenate(([0], arr[:,0])).cumsum()[::-1]+np.concatenate(([0], arr[:,2][::-1])).cumsum())[...,None]/l,
                     np.linspace(cost[j], cost[j]+cost[-i], num=l+1)[...,None]), axis=1))
    pareto = get_pareto(cas)
    return pareto


def main():
    '''Build and store all baseline and 2-model cascade Pareto fronts.

    First the baseline files (data/pareto_models_*.txt, data/baseline_*.txt)
    are written for every dataset, then the cascade fronts
    (data/pareto_*.npy) are computed from the stored pareto_models files.
    '''
    # [dataset name, size passed to paretoLfull]
    # NOTE(review): sizes look like evaluation-set sample counts - confirm.
    top1_sets = [['sst2',872],['mrpc',408],['qqp',40430],['qnli',5463]]
    f1_sets = [['mrpc',408],['qqp',40430]]

    # obtain baseline Pareto fronts
    for ds, k in top1_sets:
        _write_baseline(ds, k, 'top1')
    for ds, k in f1_sets:
        _write_baseline(ds, k, 'f1')

    # obtain MAC Pareto fronts for 2-model cascades
    for ds, _ in top1_sets:
        _write_cascade_pareto(ds, get_f1 = False)
    for ds, _ in f1_sets:
        _write_cascade_pareto(ds, get_f1 = True)


# writes the filtered Pareto models and the interpolated baseline of one dataset
def _write_baseline(ds, size, metric, fill = 200):
    '''Compute and store the baseline front of dataset ds for one metric.

    ds     -- dataset name used in the file names
    size   -- size handed to paretoLfull (only used for metric 'top1')
    metric -- 'top1' or 'f1'; any metric other than 'top1' is interpolated
              with f1paretoLfull and gets '_' + metric appended to the
              file names
    fill   -- points per segment handed to f1paretoLfull
    '''
    suffix = '' if metric == 'top1' else '_'+metric
    # NOTE: read_pickle can execute code stored in the file; only load
    # trusted data files.
    df = pd.read_pickle('data/df_'+ds+'.pkl')
    preparetomacR = pareto_models(df, cost = 'mac', metric = metric)
    paretomacR = pareto_filter(preparetomacR)
    removed = len(preparetomacR)-len(paretomacR)
    if removed:
        print('Filtered',removed,'models:\n',[i for i in preparetomacR if i not in paretomacR])
    paretomacL = paretoRtoL(paretomacR)
    if metric == 'top1':
        paretomacLfull = paretoLfull(paretomacL, size)
    else:
        paretomacLfull = f1paretoLfull(paretomacL, fill)
    with open('data/pareto_models_'+ds+suffix+'.txt', 'w') as f: json.dump(paretomacR,f,indent=2)
    with open('data/baseline_'+ds+suffix+'.txt', 'w') as f: json.dump(paretomacLfull,f,indent=2)


# computes and stores the 2-model cascade Pareto front of one dataset
def _write_cascade_pareto(ds, get_f1 = False):
    '''Compute the cascade front from the stored Pareto models of ds.

    ds     -- dataset name used in the file names
    get_f1 -- use the F1 variant (files carry the '_f1' suffix)
    '''
    tag = ds+'_f1' if get_f1 else ds
    label = ds+' f1' if get_f1 else ds
    with open('data/pareto_models_'+tag+'.txt', 'r') as f: pareto_base = json.load(f)
    models = [i[1] for i in pareto_base]
    macs = [i[3] for i in pareto_base]
    print('Computing '+label+' Pareto.')
    pareto = cascade_pareto_bi(models, macs, 'data/infer/infer_', get_f1 = get_f1)
    np.save('data/pareto_'+tag+'.npy', pareto, allow_pickle=True)


if __name__ == '__main__':
    main()