'''
Evaluate two-model cascades and build a DataFrame of their quality metrics.
'''

import json
import pickle

import numpy as np
import pandas as pd


# Pareto front extraction for a single cascade curve.
def get_pareto(cascade):
    """Return the rows of ``cascade`` that improve on every earlier accuracy.

    ``cascade`` is an iterable of rows whose first two entries are
    ``(accuracy, mac)``. A row is kept when its accuracy is strictly greater
    than the accuracy of every row before it. If the rows are ordered by
    increasing cost (presumably true for the cascade curves -- TODO confirm),
    the result is the accuracy/cost Pareto front.

    The running maximum starts at ``-inf``, so the first row is always kept
    (unless its accuracy is NaN). This keeps the front aligned with row 0
    even when the first accuracy is 0, which ``evaluate_cascades`` relies on.

    Returns:
        ``(n, 2)`` array of ``[accuracy, mac]`` rows; shape ``(0, 2)`` for
        empty input.
    """
    best = -np.inf
    accuracies, macs = [], []
    for row in cascade:
        accuracy = row[0]
        if accuracy > best:
            best = accuracy
            accuracies.append(accuracy)
            macs.append(row[1])
    return np.array([accuracies, macs]).T


def _exit_rate(mask, steps):
    """Return ``1 - j / steps`` for the first row ``j`` where ``mask`` is True.

    A cascade curve has ``steps + 1`` rows; row ``j`` maps to the early-exit
    fraction ``1 - j / steps`` (row 0 -> 1.0, last row -> 0.0).
    Raises IndexError if ``mask`` has no True entry.
    """
    return 1 - np.flatnonzero(mask)[0] / steps


def _grid_exit_rate(accuracy, front_idx, steps):
    """Exit rate of the first row whose accuracy is ``front_idx`` grid steps above row 0.

    Accuracies are compared exactly after rounding the target to 6 decimals,
    so the stored accuracies are assumed to be rounded the same way.
    """
    target = round(accuracy[0] + front_idx / steps, 6)
    return _exit_rate(accuracy == target, steps)


def _cost_ratio_stats(reference_mac, cascade_mac):
    """Return mean, max and argmax of ``reference_mac / cascade_mac`` (computed once)."""
    ratio = reference_mac / cascade_mac
    return ratio.mean(), ratio.max(), ratio.argmax()


# evaluate cascades to obtain various quality metrics
def evaluate_cascades(cascades, baseline):
    """Compute quality metrics for every two-model cascade.

    Args:
        cascades: indexable triple ``(model_pairs, cost_pairs, curves)``.
            For cascade ``i``, ``model_pairs[i]`` is ``(first, second)``,
            ``cost_pairs[i]`` is ``(first_mac, second_mac)`` and
            ``curves[i]`` is an array of ``steps + 1`` rows
            ``[accuracy, mac]``. Row 0 holds the first model's accuracy and
            the last row holds the second model's accuracy. Row ``j`` is
            treated as an early-exit fraction of ``1 - j / steps``.
        baseline: array of ``[accuracy, mac]`` rows. It must contain the
            first model's accuracy exactly.

    The improvement metrics index fronts by accuracy. They assume the
    accuracies lie on a grid of spacing ``1 / steps``, rounded to 6
    decimals (presumably accuracy = correct / steps -- TODO confirm).

    Returns:
        One list of 15 metrics per cascade (see the comments in the loop).

    Raises:
        IndexError: if no curve row satisfies a threshold or matches a grid
            accuracy.
        ValueError: if the first model's accuracy is not in ``baseline``.
    """
    # The baseline accuracy column is the same for every cascade; build it once.
    baseline_accuracies = baseline[:, 0].tolist()
    ret = []
    for i, models in enumerate(cascades[0]):
        costs = cascades[1][i]
        curve = np.asarray(cascades[2][i])
        accuracy = curve[:, 0]
        steps = len(curve) - 1
        acc_first, acc_second = accuracy[0], accuracy[-1]

        # Threshold: 0.001 absolute (0.1 percentage points) below the
        # second model's accuracy.
        drop01 = _exit_rate(accuracy > acc_second - 0.001, steps)
        # Threshold: acc2 - 0.05 * (acc2 - acc1), i.e. losing 5% of the gap
        # between the two models.
        drop5 = _exit_rate(accuracy > 0.95 * acc_second + 0.05 * acc_first, steps)

        pareto = get_pareto(curve)
        # One linearly interpolated cost per accuracy-grid step from acc1 to acc2.
        n_points = int(round((acc_second - acc_first) * steps + 1))
        linear_mac = np.linspace(costs[0], costs[1], n_points)
        pareto_mac = pareto[:n_points, 1]

        avgimp_l, maximp_l, idx_l = _cost_ratio_stats(linear_mac, pareto_mac)
        maximprate_l = _grid_exit_rate(accuracy, idx_l, steps)

        start = baseline_accuracies.index(pareto[0, 0])
        baseline_mac = baseline[start:start + n_points, 1]
        avgimp_p, maximp_p, idx_p = _cost_ratio_stats(baseline_mac, pareto_mac)
        maximprate_p = _grid_exit_rate(accuracy, idx_p, steps)

        ret.append([models[0],      # first model
                    models[1],      # second model
                    costs[0],       # first model cost
                    costs[1],       # second model cost
                    acc_first,      # first model accuracy
                    acc_second,     # second model accuracy
                    accuracy.max(), # max cascade accuracy
                    drop01,         # largest early-exit fraction within 0.001 of acc2
                    drop5,          # largest early-exit fraction within 5% of the acc gap
                    avgimp_l,       # mean cost ratio linear / cascade front
                    maximp_l,       # max cost ratio linear / cascade front
                    maximprate_l,   # early-exit fraction at maximp_l
                    avgimp_p,       # mean cost ratio baseline front / cascade front
                    maximp_p,       # max cost ratio baseline front / cascade front
                    maximprate_p,   # early-exit fraction at maximp_p
                    ])
    return ret


def main():
    """Evaluate the softmax two-model cascades and save the metrics DataFrame.

    Loads the pickled cascades and the JSON baseline from ``data/``, runs
    ``evaluate_cascades`` and pickles the resulting DataFrame to
    ``data/mac/df_softmax.pkl``.
    """
    metric_columns = [
        'model1', 'model2', 'mac1', 'mac2', 'accuracy1', 'accuracy2',
        'max_accuracy', 'accdrop01', 'reldrop5',
        'avg_imp_l', 'max_imp_l', 'max_imp_rate_l',
        'avg_imp_p', 'max_imp_p', 'max_imp_rate_p',
    ]

    # NOTE: pickle.load can execute arbitrary code; only load cascade files
    # produced locally.
    with open('data/mac/bi_cascades_softmax.pkl', 'rb') as pkl_file:
        softmax_cascades = pickle.load(pkl_file)

    # Transposed so each row is one baseline point (presumably
    # [accuracy, mac] -- confirm against the file's producer).
    with open('data/baseline_mac.txt', 'r') as txt_file:
        baseline_front = np.array(json.load(txt_file)).T

    metrics = evaluate_cascades(softmax_cascades, baseline_front)
    result = pd.DataFrame(metrics, columns=metric_columns)
    result.to_pickle('data/mac/df_softmax.pkl')


# Run the evaluation only when executed as a script, not on import.
if __name__ == '__main__':
    main()