'''
Updates the Pareto-optimal models with their accuracy measured during
inference and creates the baseline and two-model cascade Pareto fronts.

Conditions are selected by integer according to the following table:
0   entropy
1   max softmax
2   softmax margin
3   logits margin
4   temperature scaled entropy
5   temperature scaled max softmax
6   temperature scaled softmax margin
7   temperature scaled logits margin
8   arithmetic mean logits with max softmax confidence
9   arithmetic mean softmax with max softmax confidence
10  arithmetic mean TS logits with max softmax confidence
11  arithmetic mean TS softmax with max softmax confidence
12  arithmetic mean logits with softmax margin confidence
13  arithmetic mean softmax with softmax margin confidence
'''

import argparse
import json
import os
import pickle

import numpy as np
import torch
import torch.nn.functional as F


# filters model duplicates and bad outliers
def pareto_filter(pareto, factor = 2):
    '''Drop models whose successor gains too much accuracy for its extra cost.

    Entries are indexed as entry[2] = accuracy and entry[3] = cost; the list is
    assumed to be ordered by increasing cost (TODO confirm against the
    prepareto files). Entry i is kept only if the relative accuracy gain of
    entry i+1 over entry i is smaller than ``factor`` times the relative cost
    increase. When entry i+1 has the same cost as entry i and is at least as
    accurate, entry i is dropped (the cost term is 0), which removes
    duplicates. The last entry is always kept.

    Returns a new list of shallow copies of the kept entries, so mutating the
    result never modifies ``pareto``. An empty input yields an empty list.
    '''
    if not pareto:
        return []
    kept = [cur[:] for cur, nxt in zip(pareto, pareto[1:])
            if (nxt[2]/cur[2]-1) < factor*(nxt[3]/cur[3]-1)]
    # copy the last entry too, like every other kept entry
    return kept + [pareto[-1][:]]


# converts rectangular to linear Pareto front
def paretoRtoL(pareto):
    '''Reduce a step-shaped (rectangular) Pareto front to its linear front.

    Entries are indexed as entry[2] = accuracy and entry[3] = cost. Starting at
    the first entry, repeatedly jump to the later entry with the largest
    accuracy gain per unit of cost (the earliest such entry on ties) until the
    last entry is reached. The visited entries themselves (not copies) are
    returned in order.

    Slopes are computed for every start entry, including entries that are
    never visited, so (for plain Python numbers) two entries of equal cost
    raise ZeroDivisionError.
    '''
    total = len(pareto)
    successor = []
    for start in range(total - 1):
        base_acc = pareto[start][2]
        base_cost = pareto[start][3]
        slopes = [(pareto[end][2] - base_acc) / (pareto[end][3] - base_cost)
                  for end in range(start + 1, total)]
        # slopes[k] belongs to entry start+1+k; index() picks the first maximum
        successor.append(start + 1 + slopes.index(max(slopes)))
    path = [0]
    while path[-1] < len(successor):
        path.append(successor[path[-1]])
    return [pareto[k] for k in path]


# linear interpolation to create the baseline Pareto
def paretoLfull(pareto, step = 0.00002):
    '''Resample a linear Pareto front onto a uniform accuracy grid.

    Entries are indexed as entry[2] = accuracy in percent and entry[3] = cost.
    Accuracies are converted to fractions rounded to 6 decimals. A grid is
    built from the first to the last accuracy with spacing ``step``. The cost
    at each grid point is linearly interpolated between the surrounding front
    entries.

    The default ``step`` of 0.00002 equals 1/50000, presumably one sample of
    the ImageNet validation set (TODO confirm). Accuracies are expected to lie
    on multiples of ``step`` so the per-segment grids line up with the overall
    grid. Because values are rounded to 6 decimals, ``step`` should not be
    finer than 1e-6.

    Returns:
        ``[accuracies, costs]``, two lists that have equal length when the
        accuracies lie on the ``step`` grid.
    '''
    acc = [round(entry[2]/100, 6) for entry in pareto]
    cost = [entry[3] for entry in pareto]
    points = int(round((acc[-1]-acc[0])/step+1))
    grid = [round(a, 6) for a in np.linspace(acc[0], acc[-1], points).tolist()]
    costs = []
    for acc_lo, acc_hi, cost_lo, cost_hi in zip(acc, acc[1:], cost, cost[1:]):
        # each segment's end point is dropped: it starts the next segment
        num = int(round((acc_hi-acc_lo)/step+1))
        costs.extend(np.linspace(cost_lo, cost_hi, num).tolist()[:-1])
    costs.append(cost[-1])
    return [grid, costs]


# obtain baseline Pareto fronts for each cost metric
def baseline(costs):
    '''Build the filtered and baseline Pareto fronts for each cost metric.

    For every metric name in ``costs`` this:
      * reads ``data/prepareto_<metric>.txt``;
      * drops outliers/duplicates with ``pareto_filter``, printing the dropped
        entries;
      * replaces each remaining model's accuracy (entry[2]) with the mean of
        entry 0 of ``data/infer/infer_<name>.txt``, as a percentage rounded to
        5 decimals;
      * writes the filtered front to ``data/pareto_<metric>.txt`` and the
        ``paretoLfull`` resampling of its linear front (``paretoRtoL``) to
        ``data/baseline_<metric>.txt``.
    '''
    for metric in costs:
        with open('data/prepareto_' + metric + '.txt', 'r') as handle:
            prepareto = json.load(handle)
        filtered = pareto_filter(prepareto)
        dropped = len(prepareto) - len(filtered)
        if dropped:
            removed = [entry for entry in prepareto if entry not in filtered]
            print('Filtered', dropped, 'models:\n', removed)
        for entry in filtered:
            with open('data/infer/infer_' + entry[1] + '.txt', 'r') as handle:
                inference = json.load(handle)
            correct = inference[0]
            entry[2] = round(sum(correct) / len(correct) * 100, 5)
        # compute the resampled front before writing anything
        resampled = paretoLfull(paretoRtoL(filtered))
        with open('data/pareto_' + metric + '.txt', 'w') as handle:
            json.dump(filtered, handle, indent=2)
        with open('data/baseline_' + metric + '.txt', 'w') as handle:
            json.dump(resampled, handle, indent=2)


# creates Pareto from multiple cascades
def get_pareto(cascades):
    '''Merge cascade curves into one accuracy-indexed Pareto front.

    ``cascades`` is a sequence of 2-D arrays whose rows are [accuracy, cost].
    All rows are pooled and visited in order of increasing cost. A row joins
    the front when its accuracy strictly exceeds the best accuracy of every
    row visited before it.

    When such a row lies more than 0.00003 (1.5 steps of 0.00002) above the
    last front accuracy, intermediate accuracies are inserted in 0.00002
    steps with that row's cost. The number of inserted points is then printed.

    Returns:
        An (m, 2) array of [accuracy, cost] rows.
    '''
    points = np.concatenate(cascades)
    points = points[np.argsort(points[:, 1])]
    best = points[0][0]
    acc = [best]
    cost = [points[0][1]]
    for row in points[1:]:
        if not row[0] > best:
            continue
        best = row[0]
        # fill gaps in the discrete accuracy grid below the new best
        filled = 0
        while best > acc[-1] + 0.00003:
            acc.append(round(acc[-1] + 0.00002, 6))
            cost.append(row[1])
            filled += 1
        if filled:
            print('Filled Pareto gap of length', filled, 'prior to accuracy', best)
        acc.append(best)
        cost.append(row[1])
    return np.array([acc, cost]).T


# creates correctness and condition values array for arithmetic mean ensembles
def arithmetic_mean(models, condition, path_infer, path_logits, labels):
    '''Return per-sample columns for a two-model averaging ensemble.

    Loads ``<path_logits><model>.pt`` for both entries of ``models``. For
    conditions 10 and 11, each model's logits are first divided by the
    temperature stored as the last element of its ``<path_infer><model>.txt``
    JSON file.

    Args:
        models: two model names; models[0] is the cascade's first model.
        condition: 8-13 (see module docstring).
        path_infer: path prefix of the inference JSON files (only read for
            conditions 10 and 11).
        path_logits: path prefix of the logits ``.pt`` files.
        labels: ground-truth labels as a torch tensor.

    Returns:
        A list of three per-sample lists:
          0. whether model 1's argmax prediction equals ``labels``;
          1. model 1's confidence: the softmax margin (top-1 minus top-2
             probability) for conditions 12/13, otherwise the maximum
             softmax probability;
          2. whether the ensemble's argmax equals ``labels``. The ensemble
             sums the logits for conditions 8, 10 and 12, and the softmax
             probabilities otherwise.
    '''
    # Load onto the labels' device: logits saved from a GPU session would
    # otherwise be restored onto the GPU and fail to compare with CPU labels.
    logits1 = torch.load(path_logits+models[0]+'.pt', map_location=labels.device)
    logits2 = torch.load(path_logits+models[1]+'.pt', map_location=labels.device)
    if condition in [10, 11]: # temperature scaling
        with open(path_infer+models[0]+'.txt', 'r') as f:
            ea1 = json.load(f)
        with open(path_infer+models[1]+'.txt', 'r') as f:
            ea2 = json.load(f)
        logits1 /= ea1[-1]
        logits2 /= ea2[-1]
    # model 1's softmax is needed for the confidence and possibly the ensemble
    probs1 = F.softmax(logits1, dim=1)
    if condition in [12, 13]: # softmax margin confidence
        top2 = probs1.topk(2, 1)[0]
        confidence = (top2[:, 0]-top2[:, 1]).tolist()
    else: # maximum softmax confidence
        confidence = torch.max(probs1, 1)[0].tolist()
    if condition in [8, 10, 12]: # averaged logits
        combined = logits1+logits2
    else: # averaged softmax
        combined = probs1+F.softmax(logits2, dim=1)
    return [
        (torch.argmax(logits1, 1) == labels).tolist(),
        confidence,
        (torch.argmax(combined, 1) == labels).tolist(),
    ]


# obtains Pareto front for all possible 2 model cascades
def cascade_pareto_bi(models, cost, condition, path_infer, path_logits, labels):
    '''Evaluate every two-model cascade and return their joint Pareto front.

    The loops visit every pair (models[j], models[-i]) with
    1 <= i < len(models) and 0 <= j < len(models)-i, i.e. every pair in which
    the first model comes earlier in ``models`` than the second. The first
    model answers the samples it ranks as most confident under ``condition``.
    The remaining samples are answered by the second model (for conditions
    8-13, by the two-model ensemble column from ``arithmetic_mean``).
    Sweeping how many samples are forwarded, from none to all of them, yields
    one accuracy/cost curve per pair.

    Args:
        models: model names, used to build the ``path_infer``/``path_logits``
            file paths.
        cost: per-model cost values, indexed in parallel with ``models``.
        condition: confidence criterion index (see module docstring).
            Conditions 0-7 read entry ``condition+2`` of each model's
            inference file; conditions 8-13 are delegated to
            ``arithmetic_mean``.
        path_infer: path prefix of the per-model inference ``.txt`` (JSON)
            files.
        path_logits: path prefix of the per-model logits ``.pt`` files
            (only used for conditions 8-13).
        labels: ground-truth labels, only forwarded to ``arithmetic_mean``.

    Returns:
        ``(pareto, [mod, cst, cas])``:
          * ``pareto`` is ``get_pareto(cas)``;
          * ``mod[k]`` and ``cst[k]`` are the model names and costs of
            cascade k;
          * ``cas[k]`` is an (l+1, 2) array of [accuracy, cost] rows, where l
            is the number of samples and row t corresponds to forwarding the
            t least-confident samples to the second model.
    '''
    # store cascades, the used models and their cost (MAC, time)
    cas, mod, cst = [], [], []
    n = len(models)
    for i in range(1,n):
        for j in range(n-i):
            # load correctness and condition values array for models
            if condition in [8, 9, 10, 11, 12, 13]:
                arr = np.array(arithmetic_mean([models[j],models[-i]], condition, path_infer, path_logits, labels)).T
            else:
                idx = condition+2 # maps condition to inference data index
                with open(path_infer+models[j]+'.txt', 'r') as f: ea1 = json.load(f)
                with open(path_infer+models[-i]+'.txt', 'r') as f: ea2 = json.load(f)
                arr = np.array([ea1[0], ea1[idx], ea2[0], ea2[idx]]).T
            # arr rows are samples; columns:
            #   0 = model 1 correctness
            #   1 = model 1 condition value
            #   2 = model 2 (or ensemble) correctness
            #   3 = model 2 condition value (conditions 0-7 only)
            l = arr.shape[0]
            # sort by model 1 threshold condition, most confident sample first
            if condition in [0, 4]: # entropy threshold is reversed, smaller is better
                arr = arr[arr[:, 1].argsort()]
            else:
                arr = arr[arr[:, 1].argsort()[::-1]]
            # Conditions 4-7 (temperature scaled): where model 1's condition
            # value is better than model 2's, model 1's correctness replaces
            # model 2's in column 2.
            if condition in [4, 5, 6, 7]: # comparison ensemble
                if condition == 4: # entropy is reversed
                    arr[:,2] = np.where(arr[:,1]<arr[:,3],arr[:,0],arr[:,2])
                else:
                    arr[:,2] = np.where(arr[:,1]>arr[:,3],arr[:,0],arr[:,2])
            # create cascade with accuracy and cost columns.
            # Row t (t = 0..l) scores the first l-t samples with column 0 and
            # the last t samples with column 2. Reversed prefix sums of
            # column 0 plus prefix sums of reversed column 2 give all l+1
            # totals at once; dividing by l turns them into accuracies.
            # Cost grows linearly from cost[j] (nothing forwarded) to
            # cost[j]+cost[-i] (everything forwarded).
            cas.append(np.concatenate(
                ((np.concatenate(([0], arr[:,0])).cumsum()[::-1]+np.concatenate(([0], arr[:,2][::-1])).cumsum())[...,None]/l,
                 np.linspace(cost[j], cost[j]+cost[-i], num=l+1)[...,None]), axis=1))
            mod.append([models[j],models[-i]]) 
            cst.append([cost[j],cost[-i]])
    pareto = get_pareto(cas)
    return pareto, [mod, cst, cas]


def main():
    '''Command-line entry point.

    Optionally rebuilds the baseline Pareto fronts (skipped with ``-s``).
    Then, for each selected cost metric and condition, it computes the Pareto
    front over all two-model cascades. Results are written to
      * ``data/mac/bi_pareto_<name>.npy`` and
        ``data/mac/bi_cascades_<name>.pkl`` for 'mac';
      * ``data/time/time_bi_pareto_<name>.npy`` and
        ``data/time/time_bi_cascades_<name>.pkl`` for 'time'.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--skip', action='store_true', help='skip baseline creation, use when already done')
    # choices rejects indices outside [0,13] at parse time, instead of an
    # IndexError after the (slow) baseline step has already run.
    parser.add_argument('-c','--conditions', type=int, nargs='+', choices=range(14), metavar='C',
                        help='condition indices, must be integers in interval [0,13] default: all')
    # choices guarantees an output prefix exists for every accepted metric;
    # an unknown metric used to fail only later (missing file / unbound path).
    parser.add_argument('--cost', default='all', type=str, choices=['mac', 'time', 'all'],
                        help='cost metric, can be \'mac\' or \'time\', default: \'all\'')
    args = parser.parse_args()

    # create baseline Pareto fronts
    metrics = ['mac', 'time'] if args.cost == 'all' else [args.cost]
    if not args.skip:
        baseline(metrics)

    descriptions = ['entropy',
                    'max softmax',
                    'softmax margin',
                    'logits margin',
                    'temperature scaled entropy',
                    'temperature scaled max softmax',
                    'temperature scaled softmax margin',
                    'temperature scaled logits margin',
                    'averaged logits',
                    'averaged softmax',
                    'averaged temperature scaled logits',
                    'averaged temperature scaled softmax',
                    'softmax margin averaged logits',
                    'softmax margin averaged softmax',
                    ]
    filenames = ['entropy',
                 'softmax',
                 'softmax_margin',
                 'logits_margin',
                 'ts_entropy',
                 'ts_softmax',
                 'ts_softmax_margin',
                 'ts_logits_margin',
                 'mean_logits',
                 'mean_softmax',
                 'mean_ts_logits',
                 'mean_ts_softmax',
                 'mean_m_logits',
                 'mean_m_softmax',
                 ]
    # output file prefix (relative to data/) for each cost metric
    prefixes = {'mac': 'mac/', 'time': 'time/time_'}
    path_infer = 'data/infer/infer_'
    path_logits = 'data/logits/logits_'
    with open('data/labels_ImageNet_val.txt', 'r') as f:
        labels = torch.tensor(json.load(f))

    # compute Pareto fronts
    for x in metrics:
        print(f'Computing accuracy-{x} Pareto fronts.')
        os.makedirs('data/'+x, exist_ok=True)
        with open('data/pareto_'+x+'.txt', 'r') as f:
            front = json.load(f)
        # front entries: [1] model name, [3] cost value
        models = [entry[1] for entry in front]
        costs = [entry[3] for entry in front]
        if args.conditions:
            conditions = args.conditions
        elif x == 'mac':
            conditions = list(range(14))
        else:
            # by default only conditions 0-11 are computed for 'time'
            conditions = list(range(12))
        path = prefixes[x]
        for i in conditions:
            print('Computing', descriptions[i], 'Pareto front.')
            pareto, cascades = cascade_pareto_bi(models, costs, i, path_infer, path_logits, labels)
            np.save('data/'+path+'bi_pareto_'+filenames[i]+'.npy', pareto, allow_pickle=True)
            with open('data/'+path+'bi_cascades_'+filenames[i]+'.pkl', 'wb') as f:
                pickle.dump(cascades, f)


# run the command-line entry point only when executed as a script, not on import
if __name__ == '__main__':
    main()