'''
Creates top-5 accuracy vs. MACs Pareto fronts of 2-model cascades for the appendix.
'''

import json
import os
import pickle

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

from inference import infer
from paretos import pareto_filter, paretoRtoL, paretoLfull, get_pareto
from prepare_dataframe import pareto_models


# compute per-sample top-5 inference data for one model and save it as JSON
def inference_data(model, path_logits, path_infer, path_labels='data/labels_ImageNet_val.txt'):
    '''
    Compute per-sample top-5 statistics from saved logits and write them to disk.

    Logits are loaded from ``path_logits + model + '.pt'``, and the integer labels
    come from the JSON list in ``path_labels``. The result is written as JSON to
    ``path_infer + model + '.txt'``: a list of four per-sample lists, indexed as

        0   top-5 correctness (bool)
        1   softmax entropy
        2   max (top-1) softmax probability
        3   sum of the top-5 softmax probabilities

    The model's top-5 accuracy (in percent) is printed. The same list is also
    returned; callers that ignore the return value are unaffected.

    Raises:
        ValueError: if the number of logit rows differs from the number of labels.
    '''
    logits = torch.load(path_logits+model+'.pt')
    with open(path_labels, 'r') as f:
        labels = torch.tensor(json.load(f))
    if logits.shape[0] != len(labels):
        raise ValueError(f'{model}: {logits.shape[0]} logit rows but {len(labels)} labels')
    # the logits may have been saved on a GPU; keep the labels on the same device
    labels = labels.to(logits.device)

    probs = F.softmax(logits, dim=1)
    log_probs = F.log_softmax(logits, dim=1)
    predicted = logits.topk(5, 1)[1]
    # vectorized "label is among the 5 predicted classes" check, one bool per sample
    correct = (predicted == labels[:, None]).any(dim=1)

    ea = [correct.tolist(), # correctness
          (-(probs*log_probs).sum(dim=1)).tolist(), # entropy
          probs.max(dim=1).values.tolist(), # top-1 softmax
          probs.topk(5, 1)[0].sum(dim=1).tolist()] # sum top-5 softmax
    with open(path_infer+model+'.txt', 'w') as f:
        json.dump(ea, f, indent=2)
    print(f'{model} validation top-5: {100*sum(ea[0])/len(labels)}')
    return ea


# creates correctness and condition values array for top-5 mean ensemble
def arithmetic_mean(models, condition, path_infer, path_logits, labels):
    '''
    Build per-sample data for a cascade whose second stage is a 2-model ensemble.

    Args:
        models: pair [first, second] of model names; logits are loaded from
            ``path_logits + name + '.pt'`` for both.
        condition: 4 ensembles by summed logits; any other value ensembles by
            summed softmax probabilities.
        path_infer: prefix of the first model's inference JSON file
            (``path_infer + models[0] + '.txt'``).
        path_logits: prefix of the logits files.
        labels: per-sample integer labels (tensor or list).

    Returns:
        [correct1, cond1, correct_ens]: correct1 and cond1 are lists 0 and 3
        (top-5 correctness and top-5 softmax sum) of the first model's inference
        file. correct_ens is the per-sample top-5 correctness of the ensemble.
    '''
    logits1 = torch.load(path_logits+models[0]+'.pt')
    logits2 = torch.load(path_logits+models[1]+'.pt')
    with open(path_infer+models[0]+'.txt', 'r') as f:
        ea1 = json.load(f)
    # a sum ranks classes exactly like a mean, so the division is skipped
    if condition == 4: # mean logits
        scores = logits1 + logits2
    else: # mean softmax
        scores = F.softmax(logits1, dim=1) + F.softmax(logits2, dim=1)
    predicted = scores.topk(5, 1)[1]
    # accept list or tensor labels and match the logits' device before comparing
    labels = torch.as_tensor(labels, device=predicted.device)
    correct = (predicted == labels[:, None]).any(dim=1).tolist()
    return [ea1[0], ea1[3], correct]


# builds the top-5 accuracy vs. cost Pareto front over every 2-model cascade
def cascade_pareto_bi(models, cost, condition, path_infer, path_logits, labels):
    '''
    conditions are selected by their index in the inference data list

    (0 is prediction correctness)
    1   entropy
    2   max softmax
    3   top-5 softmax sum
    4   mean logits with top-5 softmax sum
    5   mean softmax with top-5 softmax sum

    Every ordered pair (models[a], models[b]) with a < b is evaluated as a
    cascade. Samples are ranked by the first model's condition value: ascending
    for entropy (condition 1), descending otherwise. For t = 0..N, the first
    model answers the first N - t ranked samples and the second stage answers
    the remaining t.

    Returns:
        (pareto, [mod, cst, cas]) where pareto is get_pareto(cas), mod holds the
        model-name pairs, cst the matching cost pairs, and cas one (N+1, 2)
        array per pair: column 0 is accuracy (a fraction), column 1 is the cost,
        interpolated linearly from cost[a] to cost[a] + cost[b].
    '''
    curves, name_pairs, cost_pairs = [], [], []
    total = len(models)
    for back in range(1, total):
        # the second-stage model is addressed from the end of the list
        second_name, second_cost = models[-back], cost[-back]
        for front in range(total - back):
            first_name, first_cost = models[front], cost[front]

            # rows: first-model correctness, first-model condition, second-stage correctness
            if condition in (4, 5):
                rows = arithmetic_mean([first_name, second_name], condition, path_infer, path_logits, labels)
            else:
                with open(path_infer + first_name + '.txt', 'r') as f:
                    first_data = json.load(f)
                with open(path_infer + second_name + '.txt', 'r') as f:
                    second_data = json.load(f)
                rows = [first_data[0], first_data[condition], second_data[0]]
            table = np.array(rows).T
            n_samples = table.shape[0]

            # most confident samples first; low entropy means high confidence
            order = table[:, 1].argsort()
            if condition != 1:
                order = order[::-1]
            table = table[order]

            # entry t: correct answers of model 1 on the first n_samples - t rows
            kept_by_first = np.concatenate(([0], table[:, 0])).cumsum()[::-1]
            # entry t: correct answers of the second stage on the last t rows
            deferred_to_second = np.concatenate(([0], table[:, 2][::-1])).cumsum()
            accuracy = (kept_by_first + deferred_to_second)[..., None] / n_samples
            spend = np.linspace(first_cost, first_cost + second_cost, num=n_samples + 1)[..., None]

            curves.append(np.concatenate((accuracy, spend), axis=1))
            name_pairs.append([first_name, second_name])
            cost_pairs.append([first_cost, second_cost])
    front_points = get_pareto(curves)
    return front_points, [name_pairs, cost_pairs, curves]


def main():
    '''
    Run the top-5 appendix pipeline end to end.

    1. Select the top-5 vs. MAC Pareto-optimal models.
    2. Compute their logits (only when no cached file exists) and their top-5
       inference data.
    3. Overwrite each model's top-5 accuracy with the value measured from those
       logits, and save the model list and the baseline front.
    4. Build and save 2-model cascade Pareto fronts for confidence conditions 1-4.
    '''
    path_infer = 'data/top5/infer_top5_'
    path_logits = 'data/logits/logits_'
    os.makedirs('data/top5', exist_ok=True)
    os.makedirs(os.path.dirname(path_logits), exist_ok=True)

    df_timm = pd.read_pickle('data/df_timm.pkl')

    # obtain Pareto optimal models
    prepareto_top5 = pareto_models(df_timm, cost = 'mac', metric = 'top5')
    pareto_top5 = pareto_filter(prepareto_top5)
    n_filtered = len(prepareto_top5) - len(pareto_top5)
    if n_filtered:
        print('Filtered', n_filtered, 'models:\n', [i for i in prepareto_top5 if i not in pareto_top5])

    # NOTE(review): entries are used as [.., name, top-5 accuracy, MACs, ..];
    # the schema comes from pareto_models, so confirm it there
    models = [entry[1] for entry in pareto_top5]

    # obtain inference data (logits are computed only if not cached on disk)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    for model in models:
        if not os.path.exists(path_logits + model + '.pt'):
            infer(model, path_logits, device, process = False)
        inference_data(model, path_logits, path_infer)

    # update top-5 accuracies with true values (entries are mutated in place)
    for entry in pareto_top5:
        with open(path_infer + entry[1] + '.txt', 'r') as f:
            correct = json.load(f)[0]
        entry[2] = round(sum(correct)/len(correct)*100, 5)
    with open('data/pareto_top5.txt', 'w') as f:
        json.dump(pareto_top5, f, indent=2)

    # create baseline Pareto
    pareto_top5L = paretoRtoL(pareto_top5)
    pareto_top5Lf = paretoLfull(pareto_top5L)
    with open('data/baseline_top5.txt', 'w') as f:
        json.dump(pareto_top5Lf, f, indent=2)

    with open('data/labels_ImageNet_val.txt', 'r') as f:
        labels = torch.tensor(json.load(f))
    macs = [entry[3] for entry in pareto_top5]

    # (condition index, progress-message label, output file suffix)
    # NOTE: condition 5 (mean softmax ensemble) is accepted by cascade_pareto_bi
    # but is not run here
    experiments = [
        (1, 'entropy', 'entropy'),
        (2, 'max softmax', 'softmax'),
        (3, 'top-5 softmax sum', 'softmax_sum'),
        (4, 'mean logits top-5 softmax sum', 'softmax_sum_mean_logits'),
    ]
    for condition, description, suffix in experiments:
        print(f'Computing {description} Pareto.')
        pareto, cascades = cascade_pareto_bi(models, macs, condition, path_infer, path_logits, labels)
        np.save(f'data/top5/top5_pareto_{suffix}.npy', pareto, allow_pickle=True)
        with open(f'data/top5/top5_cascades_{suffix}.pkl', 'wb') as f:
            pickle.dump(cascades, f)


if __name__ == '__main__':
    main()