'''
Creates Pareto fronts for 3-model cascades.
Requires baselines created by paretos.py.

Only one condition is processed per run because execution can take days.
Several conditions can be processed in parallel by running multiple instances
of the script at once, each with a different condition.
Conditions are selected by integer via the -c flag according to the following table:
0   entropy
1   max softmax
2   softmax margin
3   logits margin
4   temperature scaled entropy
5   temperature scaled max softmax
6   temperature scaled softmax margin
7   temperature scaled logits margin
8   arithmetic mean logits with max softmax confidence
9   arithmetic mean softmax with max softmax confidence
10  arithmetic mean TS logits with max softmax confidence
11  arithmetic mean TS softmax with max softmax confidence
'''

import argparse
import datetime
from itertools import combinations
import json
from math import comb
import os
import pickle

import numpy as np
import torch
import torch.nn.functional as F


# creates Pareto from multiple cascades
def get_pareto(cascades):
    '''Merge (accuracy, cost) point sets into a single Pareto front.

    Points are scanned in order of increasing cost and kept whenever they
    improve on the best accuracy seen so far. When the accuracy jumps by more
    than 0.00003, intermediate accuracies are inserted in steps of 0.00002,
    each at the cost of the point that closes the gap. The step presumably
    corresponds to one sample of a 50,000-sample evaluation set; confirm it
    before reusing this on other datasets.

    Args:
        cascades: sequence of 2-D arrays whose column 0 is accuracy and
            column 1 is cost.

    Returns:
        (n, 2) array of (accuracy, cost) rows with non-decreasing cost and
        strictly increasing accuracy.

    Raises:
        ValueError: if ``cascades`` contains no points.
    '''
    arrays = list(cascades)
    if not arrays:
        raise ValueError('get_pareto needs at least one cascade')
    arr = np.concatenate(arrays)
    if arr.shape[0] == 0:
        raise ValueError('get_pareto needs at least one (accuracy, cost) point')
    # Sort by cost ascending and, among equal costs, by accuracy descending.
    # This keeps a same-cost dominated point off the front and makes the
    # result independent of the unstable default argsort.
    arr = arr[np.lexsort((-arr[:, 0], arr[:, 1]))]
    m = arr[0][0]  # current maximum accuracy
    acc = [arr[0][0]]
    cost = [arr[0][1]]
    # go through all points ordered by cost from small to large
    for row in arr[1:]:
        a, c = row[0], row[1]
        # if accuracy is new best, add to Pareto
        if a > m:
            m = a
            # fill gaps in discrete accuracy points
            k = 0
            while m > acc[-1]+0.00003:
                k += 1
                acc.append(round(acc[-1]+0.00002, 6))
                cost.append(c)
            if k:
                print('Filled Pareto gap of length',k,'prior to accuracy',m)
            acc.append(a)
            cost.append(c)
    return np.array([acc, cost]).T


# creates boolean correctness and condition values array for mean ensembles
def arithmetic_mean(models, condition, path_infer, path_logits, labels):
    '''Build per-sample correctness/confidence columns for mean-ensemble cascades.

    Args:
        models: names of the cascade models; only the first three are used.
        condition: 8-11, see the module docstring. Conditions 8 and 10 average
            logits, while 9 and 11 average softmax outputs. Conditions 10 and
            11 first divide each model's logits by the last entry of its
            inference file (temperature scaling).
        path_infer: path prefix of the per-model inference JSON files.
        path_logits: path prefix of the per-model logits ``.pt`` files.
        labels: ground-truth class indices compared with argmax predictions.

    Returns:
        List of five per-sample lists:
        [model 1 correct, model 1 max softmax,
         models 1+2 ensemble correct, models 1+2 ensemble max softmax,
         models 1+2+3 ensemble correct].
    '''
    # Load onto the device of ``labels``. Otherwise logits saved from a GPU
    # would be compared with CPU labels below and raise a device-mismatch error.
    device = getattr(labels, 'device', 'cpu')
    logits = [torch.load(path_logits+name+'.pt', map_location=device) for name in models[:3]]
    if condition in [10, 11]:  # temperature scaling
        scaled = []
        for name, lg in zip(models[:3], logits):
            with open(path_infer+name+'.txt', 'r') as f:
                scaled.append(lg / json.load(f)[-1])
        logits = scaled

    def correct(scores):
        # per-sample correctness of the argmax prediction
        return (torch.argmax(scores, 1) == labels).tolist()

    if condition in [8, 10]:  # averaged logits
        sum12 = logits[0] + logits[1]
        sum123 = sum12 + logits[2]
        conf12 = F.softmax(sum12 / 2, dim=1)
    else:  # averaged softmax, each softmax computed only once
        probs = [F.softmax(lg, dim=1) for lg in logits]
        sum12 = probs[0] + probs[1]
        sum123 = sum12 + probs[2]
        conf12 = sum12 / 2
    return [correct(logits[0]),
            torch.max(F.softmax(logits[0], dim=1), 1)[0].tolist(),
            correct(sum12),
            torch.max(conf12, 1)[0].tolist(),
            correct(sum123)]


# obtains Pareto front for 3-model cascade
def cascade_pareto_tri(models, cost, condition, path_infer, path_logits, labels, step=10):
    '''Compute the accuracy/cost Pareto front of a single 3-model cascade.

    Samples are ordered by model 1's confidence, and the most confident ones
    are answered by model 1. The remaining samples are ordered by the stage-2
    confidence, and the least confident of those are passed on to model 3.
    Both thresholds are swept, and all resulting (accuracy, cost) points are
    reduced with ``get_pareto``.

    Args:
        models: the three model names, used as file name suffixes.
        cost: the three models' costs (``main`` passes values it calls macs).
            Every sample pays cost[0]; deferred samples additionally pay
            cost[1] and, if deferred again, cost[2].
        condition: condition index, see the module docstring table.
        path_infer: path prefix of the per-model inference JSON files.
        path_logits: path prefix of the per-model logits ``.pt`` files
            (only read for conditions 8-11).
        labels: ground-truth labels (only used for conditions 8-11).
        step: stride of the sweep over the first threshold. Larger values run
            faster but sample the 3-model trade-off more coarsely. Default 10.

    Returns:
        (k, 2) array of (accuracy, cost) points, as returned by ``get_pareto``.

    Raises:
        ValueError: if ``step`` is smaller than 1.
    '''
    if step < 1:
        raise ValueError('step must be a positive integer')
    cas = []
    # load correctness and condition values array for models
    if condition in [8, 9, 10, 11]:
        # columns: model 1 correct, model 1 confidence, models 1+2 ensemble
        # correct, models 1+2 ensemble confidence, models 1+2+3 ensemble correct
        arr = np.array(arithmetic_mean(models, condition, path_infer, path_logits, labels)).T
    else:
        idx = condition+2  # maps condition to inference data index
        columns = []
        for name in models[:3]:
            with open(path_infer+name+'.txt', 'r') as f:
                ea = json.load(f)
            # ea[0] is used as per-sample correctness, ea[idx] as the condition value
            columns.extend([ea[0], ea[idx]])
        # columns: correct_1, value_1, correct_2, value_2, correct_3, value_3
        arr = np.array(columns).T
    n = arr.shape[0]
    # sort by model 1 threshold condition, most confident first
    if condition in [0, 4]:  # entropy threshold is reversed, smaller is better
        arr = arr[arr[:, 1].argsort()]
    else:
        arr = arr[arr[:, 1].argsort()[::-1]]
    if condition in [4, 5, 6, 7]:  # comparison ensemble
        # Stage 2 uses model 1's answer where model 1 is more confident than
        # model 2. Stage 3 keeps the stage-2 answer where model 1 or model 2
        # is more confident than model 3, i.e. the most confident model wins.
        if condition == 4:  # entropy is reversed
            arr[:,2] = np.where(arr[:,1]<arr[:,3],arr[:,0],arr[:,2])
            arr[:,4] = np.where(np.logical_or(arr[:,1]<arr[:,5], arr[:,3]<arr[:,5]),arr[:,2],arr[:,4])
        else:
            arr[:,2] = np.where(arr[:,1]>arr[:,3],arr[:,0],arr[:,2])
            arr[:,4] = np.where(np.logical_or(arr[:,1]>arr[:,5], arr[:,3]>arr[:,5]),arr[:,2],arr[:,4])
    # Add the 2-model cascade as a lower bound. Point k lets model 1 answer the
    # n-k most confident samples and model 2 the k remaining ones, at cost
    # cost[0] + cost[1]*k/n.
    cas.append(np.concatenate(
        ((np.concatenate(([0], arr[:,0])).cumsum()[::-1]+np.concatenate(([0], arr[:,2][::-1])).cumsum())[...,None]/n,
         np.linspace(cost[0], cost[0]+cost[1], num=n+1)[...,None]), axis=1))
    # Sweep the first threshold of the 3-model cascade. The first i samples stay
    # with model 1. The remaining n-i samples are re-sorted by the stage-2
    # confidence (column 3), and point k passes the k least confident of them
    # on to model 3.
    for i in range(0, n, step):  # step trades execution speed for resolution
        if condition in [0, 4]:  # entropy is reversed
            arr2 = arr[i:,2:][arr[i:,3].argsort()]
        else:
            arr2 = arr[i:,2:][arr[i:,3].argsort()[::-1]]
        cas.append(np.concatenate(((np.concatenate(
            ([0], arr2[:,0])).cumsum()[::-1]+np.concatenate(([0], arr2[:,2][::-1])).cumsum()+arr[:i,0].sum())[...,None]/n,
            np.linspace(cost[0]+cost[1]*(1-i/n), cost[0]+(cost[1]+cost[2])*(1-i/n), num=n+1-i)[...,None]), axis=1))
    pareto = get_pareto(cas)
    return pareto


# obtains 3-model cascade Pareto fronts via exhaustive search of all combinations
def tri_pareto(models, cost, condition, path_infer, path_logits, labels, verbose, backup_every=None):
    '''Evaluate every 3-model combination and merge the resulting fronts.

    Args:
        models: candidate model names. Combinations keep the list order, so
            ``models[i]`` is always an earlier cascade stage than
            ``models[j]`` for i < j.
        cost: per-model costs aligned with ``models``.
        condition, path_infer, path_logits, labels: forwarded to
            ``cascade_pareto_tri``.
        verbose: print a timestamped progress line after each combination.
        backup_every: if a positive integer, pickle the partial results
            ``[mod, mac, cas]`` to ``data/tri_backup<condition>.pkl`` every
            ``backup_every`` combinations. Off by default.

    Returns:
        (pareto, [mod, mac, cas]): the overall Pareto front, plus, for each
        combination, its model names, costs and individual Pareto front.

    Raises:
        ValueError: if fewer than three models are given.
    '''
    if len(models) < 3:
        raise ValueError('tri_pareto needs at least three models')
    total = comb(len(models),3)
    cas, mod, mac = [], [], []
    for counter, x in enumerate(combinations(range(len(models)), 3), start=1):
        trio = [models[j] for j in x]
        trio_cost = [cost[j] for j in x]
        cas.append(cascade_pareto_tri(trio, trio_cost, condition, path_infer, path_logits, labels))
        mod.append(trio)
        mac.append(trio_cost)
        if verbose:
            print(datetime.datetime.now().strftime('%H:%M:%S.%f'),counter,'of',total,'combination ',mod[-1],'done!')
        # optional checkpoint, since a full run can take days
        if backup_every and counter % backup_every == 0:
            with open('data/tri_backup'+str(condition)+'.pkl', 'wb') as f:
                pickle.dump([mod, mac, cas], f)
            print('Saved backup at',counter,'of',total,'combinations!')
    pareto = get_pareto(cas)
    return pareto, [mod, mac, cas]


def main():
    '''Command-line entry point: compute and save the Pareto front of one condition.'''
    # (description, output file name suffix) per condition index; the order
    # must match the table in the module docstring
    conditions = [('entropy', 'entropy'),
                  ('max softmax', 'softmax'),
                  ('softmax margin', 'softmax_margin'),
                  ('logits margin', 'logits_margin'),
                  ('temperature scaled entropy', 'ts_entropy'),
                  ('temperature scaled max softmax', 'ts_softmax'),
                  ('temperature scaled softmax margin', 'ts_softmax_margin'),
                  ('temperature scaled logits margin', 'ts_logits_margin'),
                  ('averaged logits', 'mean_logits'),
                  ('averaged softmax', 'mean_softmax'),
                  ('averaged temperature scaled logits', 'mean_ts_logits'),
                  ('averaged temperature scaled softmax', 'mean_ts_softmax'),
                  ]
    parser = argparse.ArgumentParser()
    parser.add_argument('-c','--condition', type=int, help='condition index, needs to be set')
    # '--verbose' is kept only for backward compatibility; like '-m'/'--mute'
    # it MUTES the progress output
    parser.add_argument('-m', '--mute', '--verbose', dest='verbose', action='store_false',
                        help='use flag to mute verbosity')
    args = parser.parse_args()

    if args.condition not in range(len(conditions)):
        raise ValueError('Condition needs to be set via \'-c\' flag as integer in interval [0,11].'
                         +' Check file documentation for description of available conditions')
    description, filename = conditions[args.condition]

    path_infer = 'data/infer/infer_'
    path_logits = 'data/logits/logits_'
    out_dir = 'data/mac/'
    # Create the output directory before the (possibly days-long) computation,
    # so a missing folder cannot make the final save fail.
    os.makedirs(out_dir, exist_ok=True)
    with open('data/labels_ImageNet_val.txt', 'r') as f: labels = torch.tensor(json.load(f))
    with open('data/pareto_mac.txt', 'r') as f: pareto_mac = json.load(f)
    # Each row of pareto_mac.txt is used as: index 1 -> model name, index 3 -> cost.
    # Presumably written by paretos.py; confirm the schema there.
    models = [i[1] for i in pareto_mac]
    macs = [i[3] for i in pareto_mac]
    print('Computing', description, 'Pareto front.')
    pareto, cascades = tri_pareto(models, macs, args.condition, path_infer, path_logits, labels, args.verbose)
    np.save(out_dir+'tri_pareto_'+filename+'.npy', pareto, allow_pickle=True)
    with open(out_dir+'tri_cascades_'+filename+'.pkl', 'wb') as f: pickle.dump(cascades, f)


if __name__ == '__main__':
    # run the command-line entry point only when executed as a script, not on import
    main()