'''
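Prepares a dataframe with information on timm ImageNet models (accuracy, parameter count, MACs, inference speed) and derives the Pareto optimal model lists used for cascades.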
'''

import json
import os

from fvcore.nn import FlopCountAnalysis, parameter_count
import pandas as pd
import timm
import torch


# timm repository commit whose result files were used in the paper (permanent
# URLs); replace with 'master' to use the current results instead
_TIMM_REF = 'c45c6ee8e406861898acccb1818f91d0c9240e48'
_TIMM_RESULTS = ('https://raw.githubusercontent.com/rwightman/pytorch-image-models/'
                 + _TIMM_REF + '/results/')

# ImageNet accuracy results from timm
url = _TIMM_RESULTS + 'results-imagenet.csv'

# inference speed benchmark from timm (the file name indicates AMP, NHWC,
# PyTorch 1.12, CUDA 11.3 on an RTX 3090)
url_infer = _TIMM_RESULTS + 'benchmark-infer-amp-nhwc-pt112-cu113-rtx3090.csv'


# returns list of Pareto optimal models as [index, name, metric, cost]
def pareto_models(df, cost = 'mac', metric = 'top1'):
    '''
    Returns the Pareto optimal models of a dataframe.

    Higher `metric` is better and lower `cost` is better. A model is kept when
    it is strictly cheaper than every model ranked above it by metric. Among
    models with equal metric the cheaper one ranks first, so a more expensive
    tie is treated as dominated. Among exact duplicates the earlier row wins.

    Parameters:
        df: dataframe with a 'model' column and the `cost` and `metric` columns.
        cost: name of the cost column (default 'mac').
        metric: name of the metric column (default 'top1').

    Returns:
        List of [index, name, metric, cost] entries sorted by increasing cost.
        `index` is the positional row index in `df`. Rows with a missing (NaN)
        cost or metric are ignored. An empty dataframe gives [].
    '''
    costs = df[cost].tolist()
    metrics = df[metric].tolist()
    names = df['model'].tolist()
    # Rows without a value cannot be placed on the front, and a NaN cost would
    # make every '<' comparison below False. Example: models missing from the
    # inference benchmark after a left merge.
    rows = [i for i in range(len(costs))
            if not (pd.isna(costs[i]) or pd.isna(metrics[i]))]
    # best metric first; for equal metric the cheaper model first
    rows.sort(key=lambda i: (-metrics[i], costs[i]))
    pareto = []
    cheapest = float('inf')
    for i in rows:
        if costs[i] < cheapest:
            cheapest = costs[i]
            pareto.append([i, names[i], metrics[i], costs[i]])
    # entries were collected with decreasing cost; return them cheapest first
    return pareto[::-1]


# filters possible model duplicates and bad outliers from Pareto model list
def pareto_filter(pareto, factor = 2):
    '''
    Removes outliers from a Pareto list as returned by `pareto_models`.

    Entry i is kept when the relative metric gain of the next (more expensive)
    entry is smaller than `factor` times its relative cost increase:
        metric[i+1] / metric[i] - 1 < factor * (cost[i+1] / cost[i] - 1)
    An entry that is only slightly cheaper than its successor but much less
    accurate is therefore dropped. The last (most expensive) entry is always
    kept.

    Parameters:
        pareto: list of [index, name, metric, cost] sorted by increasing cost.
            Metric and cost values are used as divisors, so they must be nonzero.
        factor: weight of the relative cost increase (default 2).

    Returns:
        A new filtered list; [] for an empty input.
    '''
    if not pareto:
        return []
    kept = [entry[:] for entry, nxt in zip(pareto, pareto[1:])
            if nxt[2] / entry[2] - 1 < factor * (nxt[3] / entry[3] - 1)]
    return kept + [pareto[-1]]


def _count_macs(model, input_size, device):
    '''
    Returns fvcore's FlopCountAnalysis total for `model` on one random input.

    The input has shape (1, *input_size). fvcore counts one multiply-add as one
    flop, so the total is stored as a MAC count.
    '''
    return FlopCountAnalysis(model, torch.rand(input_size)[None].to(device)).total()


def _params_and_macs(name, device):
    '''
    Creates the timm model `name` and returns (parameter count, MAC count).

    A count that could not be obtained is returned as 0. MACs are measured with
    the model's 'test_input_size' when its default config defines one, and
    with its 'input_size' otherwise. If counting fails on a GPU it is retried
    once on the CPU, since a likely cause is running out of GPU memory.
    '''
    try:
        model = timm.create_model(name).to(device).eval()
    except Exception as err:
        # e.g. the installed timm does not know this model name; report and skip
        print('Error creating model', name, err)
        return 0, 0
    params = parameter_count(model)['']
    config = model.default_cfg
    # print(config) # for printing the model configuration
    if 'test_input_size' in config:
        input_size = config['test_input_size']
        print('Using test input size', input_size)
    else:
        input_size = config['input_size']
    try:
        return params, _count_macs(model, input_size, device)
    except Exception:
        if device.type == 'cpu':
            print('Error for model', name)
            return params, 0
    print('Error: Possible GPU memory issue, retry on CPU.')
    # release the GPU copy before retrying so its memory can be reclaimed
    del model
    torch.cuda.empty_cache()
    try:
        cpu_model = timm.create_model(name).eval()
        return params, _count_macs(cpu_model, input_size, torch.device('cpu'))
    except Exception:
        print('Error for model', name)
        return params, 0


def _save_json(path, obj):
    '''Writes `obj` to `path` as indented JSON for human readability.'''
    with open(path, 'w') as f:
        json.dump(obj, f, indent=2)


def main():
    '''
    Builds the model dataframe and the Pareto model lists and writes them to ./data.

    Files written:
    - df_timm.pkl: one row per model with top1, top5, mac, param, infpersec
      (infer_samples_per_sec from the benchmark) and secperinf (its inverse).
    - prepareto_mac.txt / prepareto_time.txt: Pareto optimal models with respect
      to MACs and to seconds per inference (JSON).
    - models_mac.txt / models_time.txt / models_all.txt: names of the filtered
      Pareto models used for cascades (JSON).
    '''
    # create data folder if it does not exist
    os.makedirs('data', exist_ok=True)

    # load and merge dataframes from timm repository
    df_timm_csv = pd.read_csv(url)
    df_timm_infer = pd.read_csv(url_infer)
    # The left merge keeps every model with accuracy results. Models missing
    # from the inference benchmark get a NaN inference speed.
    df_timm_csv = df_timm_csv.merge(df_timm_infer, on='model', how='left')

    # create new dataframe
    df_timm = df_timm_csv[['model', 'top1', 'top5']].copy()
    df_timm['mac'] = 0
    df_timm['param'] = 0
    df_timm['infpersec'] = df_timm_csv['infer_samples_per_sec']
    df_timm['secperinf'] = 1 / df_timm['infpersec']

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # obtains parameters and MAC for every model in dataframe
    for n in df_timm['model'].tolist():
        params, macs = _params_and_macs(n, device)
        rows = df_timm['model'] == n
        df_timm.loc[rows, 'param'] = params
        df_timm.loc[rows, 'mac'] = macs
        print('\033[38;2;255;0;0m', df_timm.loc[rows], '\033[0m\n')

    # Remove models without a MAC value, even when only a single model failed.
    # A zero cost would otherwise head the MAC Pareto list and be used as a
    # divisor when filtering it.
    failed = df_timm['mac'] == 0
    if failed.any():
        print('Failed to obtain MAC for the following models:\n',
              df_timm.loc[failed], sep='')
        df_timm = df_timm.loc[df_timm['mac'] > 0].reset_index(drop=True)

    # saves dataframe
    df_timm.to_pickle('data/df_timm.pkl')

    # obtain and save preliminary Pareto model lists for mac and inference time
    prepareto_mac = pareto_models(df_timm, 'mac')
    prepareto_time = pareto_models(df_timm, 'secperinf')

    # save as json for human readability
    _save_json('data/prepareto_mac.txt', prepareto_mac)
    _save_json('data/prepareto_time.txt', prepareto_time)

    # obtain lists of models used for cascades which need validation set inference
    models_mac = [entry[1] for entry in pareto_filter(prepareto_mac)]
    models_time = [entry[1] for entry in pareto_filter(prepareto_time)]
    # MAC models first, then the time-only models in their own order. A set
    # difference here would make the order depend on string hash randomisation.
    mac_set = set(models_mac)
    models_all = models_mac + [m for m in dict.fromkeys(models_time) if m not in mac_set]

    _save_json('data/models_mac.txt', models_mac)
    _save_json('data/models_time.txt', models_time)
    _save_json('data/models_all.txt', models_all)


# run the data preparation only when executed as a script, not when imported
if __name__ == '__main__':
    main()