'''
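Distribution shift robustness test on ImageNetV2.

Obtains the max softmax cascade Pareto front under distribution shift (ImageNetV2), both over all
cascades and over cascades restricted to the threshold ranges that are Pareto optimal on ImageNet.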
'''

import argparse
import json
import math
import os
import pickle

import numpy as np
import pandas as pd
import timm
import torch
import torch.nn.functional as F
from torchvision import datasets
import torchvision.transforms as transforms

from config import path_imagenet_v2, batch_size, workers
from paretos import cascade_pareto_bi, get_pareto


# inference function which saves logits
@torch.no_grad()
def infer(model, path_logits, device = 0):
    """Run a pretrained timm model on ImageNetV2 and save its logits.

    The evaluation transform is built from the model's timm default config:
    resize to ``input_size / crop_pct``, center crop, then normalize. The test
    input size is used when the config provides one. The dataset is iterated
    in order (``shuffle=False``). The concatenated logits tensor is saved to
    ``path_logits + model + '.pt'`` and the top-1 accuracy is printed.

    Args:
        model: timm model name.
        path_logits: path prefix for the saved logits file.
        device: torch device (or CUDA index) used for inference.
    """
    # load model and create transform
    net = timm.create_model(model, pretrained=True).to(device).eval()
    config = net.default_cfg
    if 'test_input_size' in config:
        input_size = config['test_input_size']
        print('Using test input size',input_size)
    else:
        input_size = config['input_size']
    # only the last dimension is used below, so non-square inputs are flagged
    if input_size[-1] != input_size[-2]:
        print('INSIZE ALERT:',input_size)
    if config['interpolation'] == 'bicubic':
        interpolation = transforms.InterpolationMode.BICUBIC
    else:
        interpolation = transforms.InterpolationMode.BILINEAR
    tf = transforms.Compose(
        [transforms.Resize(int(math.floor(input_size[-1] / config['crop_pct'])), interpolation=interpolation),
         transforms.CenterCrop(input_size[-1]),
         transforms.ToTensor(),
         transforms.Normalize(config['mean'], config['std'])
         ])
    print('Starting model',model,'with transform:\n',tf)
    # create dataloader
    dataset = datasets.ImageFolder(root=path_imagenet_v2, transform=tf)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True)
    logits = []
    # obtain and save logits
    for data in dataloader:
        images = data[0].to(device)
        logits.append(net(images).to('cpu'))
    logits = torch.cat(logits,0)
    torch.save(logits, path_logits+model+'.pt')
    # print accuracy
    targets = dataset.targets
    # ImageFolder numbers classes in lexicographic folder-name order. When the
    # class folders are numeric ids (the public ImageNetV2 archives use
    # '0'..'999'), folder '10' would get index 2; map each target back to its
    # numeric id. WNID-named or zero-padded folders are left unchanged.
    if all(name.isdigit() for name in dataset.classes):
        targets = [int(dataset.classes[t]) for t in targets]
    labels = torch.tensor(targets)
    predicted = torch.argmax(logits, 1)
    correct = (predicted == labels).sum().item()
    l = len(labels)
    print(f'{model} validation accuracy: {100*correct/l} with {correct} of {l}')


# obtain nested list with model inference data
def inference_data(model, path_logits, path_infer):
    """Derive per-sample statistics from saved logits and dump them as JSON.

    Reads ``path_logits + model + '.pt'`` and the ImageNetV2 labels from
    ``data/labels_ImageNetV2.txt``. Writes ``path_infer + model + '.txt'``
    containing six per-sample lists, in this order:
    correctness, predicted class, softmax entropy, max softmax,
    top-2 softmax margin, top-2 logit margin.
    Also prints the resulting accuracy in percent.
    """
    scores = torch.load(path_logits + model + '.pt')
    with open('data/labels_ImageNetV2.txt', 'r') as fh:
        targets = torch.tensor(json.load(fh))
    prediction = scores.argmax(dim=1)
    probs = F.softmax(scores, dim=1)
    log_probs = F.log_softmax(scores, dim=1)
    top2_prob = probs.topk(2, 1).values
    top2_logit = scores.topk(2, 1).values
    entropy = -(probs * log_probs).sum(dim=1)
    records = [
        (prediction == targets).tolist(),
        prediction.tolist(),
        entropy.tolist(),
        probs.max(dim=1).values.tolist(),
        (top2_prob[:, 0] - top2_prob[:, 1]).tolist(),
        (top2_logit[:, 0] - top2_logit[:, 1]).tolist(),
    ]
    with open(path_infer + model + '.txt', 'w') as fh:
        json.dump(records, fh, indent=2)
    print(f'{model} validation accuracy: {100*sum(records[0])/len(targets)}')


# converts rectangular to linear Pareto front
def paretoRtoL(pareto):
    """Reduce a (rectangular) Pareto front to its linear, convex-hull front.

    Each entry uses ``entry[2]`` as accuracy and ``entry[3]`` as cost. Starting
    at the first entry, repeatedly jump to the later entry reached by the
    steepest accuracy-per-cost slope (the first one on ties), until the last
    entry is reached. Entries are presumably ordered by increasing cost —
    confirm with the caller.

    Returns:
        The visited entries, in order.
    """
    count = len(pareto)
    # successor[a]: index of the later entry with the steepest slope from a
    successor = []
    for a in range(count - 1):
        slopes = [(pareto[b][2] - pareto[a][2]) / (pareto[b][3] - pareto[a][3])
                  for b in range(a + 1, count)]
        successor.append(a + 1 + slopes.index(max(slopes)))
    hull = [0]
    while hull[-1] < count - 1:
        hull.append(successor[hull[-1]])
    return [pareto[k] for k in hull]


# linear interpolation to create the baseline Pareto
def paretoLfull(pareto, step=0.0001):
    """Linearly interpolate a linear Pareto front onto a fixed accuracy grid.

    Args:
        pareto: entries with ``entry[2]`` the accuracy in percent and
            ``entry[3]`` the cost. Consecutive entries are the segment endpoints.
        step: accuracy grid spacing, as a fraction. The default 0.0001 matches
            a 10,000-sample dataset, where accuracy moves in steps of 1/10000.

    Returns:
        ``[accuracies, costs]``. Accuracies (fractions, rounded to 6 decimals)
        run from the first to the last entry in ``step`` increments. Costs are
        linearly interpolated along each segment. The two lists only line up
        when every consecutive accuracy difference is a multiple of ``step``.
    """
    acc = [round(entry[2] / 100, 6) for entry in pareto]
    mac = [entry[3] for entry in pareto]
    # number of grid points between a and b (inclusive) at the given spacing
    def _points(a, b):
        return int(round((b - a) / step + 1))
    acc_grid = [round(a, 6) for a in np.linspace(acc[0], acc[-1], _points(acc[0], acc[-1])).tolist()]
    mac_grid = []
    for i in range(len(acc) - 1):
        # drop each segment's endpoint; it is the next segment's start
        mac_grid.extend(np.linspace(mac[i], mac[i + 1], _points(acc[i], acc[i + 1])).tolist()[:-1])
    mac_grid.append(mac[-1])
    return [acc_grid, mac_grid]


# obtain the indices of cascades at the Pareto front
def get_pareto_cascades(cascades):
    """Return the cascade indices of points that raise the running best accuracy.

    Every row of every array in ``cascades`` is treated as a point with
    accuracy in column 0 and cost in column 1. It is tagged with the position
    of its cascade in the list. This assumes each array has exactly two
    columns, so the tag lands in column 2. Points are visited in order of
    increasing cost. Each point whose accuracy exceeds every previously
    visited point contributes its cascade index. The cheapest point only
    seeds the running best and is never reported. Indices may repeat.
    """
    rows_per_cascade = cascades[0].shape[0]
    points = np.concatenate([
        np.hstack([table, np.full((rows_per_cascade, 1), k)])
        for k, table in enumerate(cascades)
    ])
    points = points[points[:, 1].argsort()]
    best = points[0][0]
    found = []
    for point in points[1:]:
        if point[0] > best:
            best = point[0]
            found.append(int(point[2]))
    return found


# obtains a specific max softmax cascade with threshold values
def get_cascade(models, cost, path_infer):
    """Evaluate a two-model max softmax cascade at every possible threshold.

    Loads the inference files ``path_infer + models[j] + '.txt'`` of both
    models. From each it uses list 0 (per-sample correctness) and list 3
    (per-sample max softmax). Samples are ranked by model 1's max softmax,
    from most to least confident.

    Args:
        models: the two model names, in cascade order.
        cost: the two model costs; model 2's cost is added when a sample is deferred.
        path_infer: path prefix of the inference files.

    Returns:
        Array of shape (n + 1, 3) with columns [accuracy, cost, threshold].
        Row k sends the k least confident samples to model 2 and keeps the rest
        on model 1. Its cost is ``cost[0] + k / n * cost[1]``. Its threshold is
        model 1's confidence on the (n - k)-th most confident sample. The last
        row uses the upper limit 1.
    """
    loaded = []
    for name in (models[0], models[1]):
        with open(path_infer + name + '.txt', 'r') as f:
            loaded.append(json.load(f))
    first, second = loaded
    table = np.array([first[0], first[3], second[0], second[3]]).T
    n = table.shape[0]
    # most confident (according to model 1) first
    table = table[table[:, 1].argsort()[::-1]]
    limit = 1 # upper limit for max softmax threshold
    # model 1's correct answers on its (n - k) most confident samples, for k = 0..n
    hits_first = np.concatenate(([0], table[:, 0])).cumsum()[::-1]
    # model 2's correct answers on the k least confident samples, for k = 0..n
    hits_second = np.concatenate(([0], table[:, 2][::-1])).cumsum()
    accuracy = (hits_first + hits_second)[..., None] / n
    spend = np.linspace(cost[0], cost[0] + cost[1], num=n + 1)[..., None]
    threshold = np.concatenate(([limit], table[:, 1]))[::-1, None]
    return np.concatenate((accuracy, spend, threshold), axis=1)


# get pareto front with thresholds
def get_pareto_thresh(cascades, indices, size):
    """Build the accuracy/cost Pareto front of threshold cascades, with gap markers.

    Each array in ``cascades`` has rows [accuracy, cost, threshold]. It is
    tagged with the matching entry of ``indices``. Rows are visited by
    increasing cost. A row joins the front when its accuracy beats every
    previously visited row. A front point is a "gap" when its accuracy exceeds
    the previous front point by more than ``1.5 / size``; it is stored with
    threshold 0 and index -1.

    Args:
        cascades: list of (rows, 3) arrays; all must have the same row count.
        indices: cascade identifiers, parallel to ``cascades``.
        size: dataset size, used for the gap tolerance.

    Returns:
        Array with rows [accuracy, cost, threshold, cascade index]. The first
        (cheapest) row always has threshold 0 and index -1.
    """
    print(f'Creating Pareto for size {size} with {len(cascades)} cascades and {len(indices)} indices.')
    column = (cascades[0].shape[0], 1)
    stacked = np.concatenate([np.hstack([table, np.full(column, indices[k])])
                              for k, table in enumerate(cascades)])
    stacked = stacked[stacked[:, 1].argsort()]
    head = stacked[0]
    best = head[0] # running maximum accuracy
    acc, cost, thresh, idx = [head[0]], [head[1]], [0], [-1]
    for row in stacked[1:]:
        if not row[0] > best:
            continue
        best = row[0]
        # a jump larger than 1.5 samples is recorded as a gap
        is_gap = best > acc[-1] + 1/size*1.5
        acc.append(row[0])
        cost.append(row[1])
        thresh.append(0 if is_gap else row[2])
        idx.append(-1 if is_gap else int(row[3]))
    return np.array([acc, cost, thresh, idx]).T


# group to obtain a nested list of threshold ranges
def thresh_ranges(pareto,cascades,indices):
    """Group consecutive Pareto points by cascade into threshold ranges.

    ``pareto`` rows are [accuracy, cost, threshold, cascade index], where
    index -1 marks the leading row and gap rows. Row 0 is skipped, and every
    gap row closes the currently open group.

    A group opened at the start or right after a gap gets lower bound 0.
    A group opened because the cascade index changed gets, as its lower bound,
    the cascade's threshold just before the first grouped threshold
    (0 if none). The upper bound is the threshold of the group's last point.

    Args:
        pareto: array as produced for the Pareto front with thresholds.
        cascades: list of arrays with thresholds in column 2, parallel to ``indices``.
        indices: cascade identifiers.

    Returns:
        List of ``[cascade index, t1, t2]``, meant for a ``t1 < threshold <= t2`` filter.
    """
    ret = []
    idx = None # cascade index of the currently open group; None when no group is open
    t1 = 0
    for i in range(1, len(pareto)):
        if pareto[i,3] == -1:
            # a gap in the Pareto front closes the open group, if any
            print('Skipping',i,pareto[i])
            if idx is not None:
                ret.append([int(idx),t1,pareto[i-1,2]])
            idx = None
        elif idx is None:
            # first group, or first point after a gap
            idx = pareto[i,3]
            t1 = 0
        elif pareto[i,3] != idx:
            ret.append([int(idx),t1,pareto[i-1,2]])
            idx = pareto[i,3]
            # lower bound: the cascade's threshold just below this point's threshold
            a = indices.index(int(idx))
            b = np.nonzero(cascades[a][:,2]==pareto[i,2])[0][0]
            t1 = 0 if b == 0 else cascades[a][b-1,2]
    if idx is not None:
        ret.append([int(idx),t1,pareto[-1,2]])
    return ret


# filter cascades based on threshold ranges in groups
def filter_cascades(cascades, groups, indices):
    """Keep only the cascade rows whose threshold lies inside each group's range.

    Args:
        cascades: list of arrays with thresholds in column 2, parallel to ``indices``.
        groups: iterable of ``[cascade index, t1, t2]``.
        indices: cascade identifiers.

    Returns:
        One array per group, holding the rows of that group's cascade with
        ``t1 < threshold <= t2``, in their original order.
    """
    selected = []
    for cascade_id, low, high in groups:
        table = cascades[indices.index(cascade_id)]
        thr = table[:, 2]
        selected.append(table[(thr > low) & (thr <= high)])
    return selected


def main():
    """Run the ImageNetV2 distribution-shift experiment for max softmax cascades.

    Steps:
      1. Read model names and costs from data/pareto_mac.txt.
      2. Save ImageNetV2 labels if missing (or when --labels is given).
      3. Unless --skip, run inference for every model and save logits.
      4. Derive per-sample inference statistics for every model.
      5. Write ImageNetV2 accuracies of the Pareto models and the linear baseline front.
      6. Build the ImageNetV2 cascade Pareto front. Also build the front of
         ImageNetV2 cascades restricted to threshold ranges that were Pareto
         optimal on ImageNet.

    Existing logits and inference files are reused unless --force is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', '--force', action='store_true', help='force when file is already found')
    parser.add_argument('-l', '--labels', action='store_true', help='save dataset labels even when found')
    parser.add_argument('-s', '--skip', action='store_true', help='skip inference, use when logits already exist')
    args = parser.parse_args()

    path_infer = 'data/infer/v2_infer_'
    path_logits = 'data/logits/v2_logits_'
    path_labels = 'data/labels_ImageNetV2.txt'
    # torch.save / open(..., 'w') do not create missing parent directories
    for directory in ('data/logits', 'data/infer', 'data/mac'):
        os.makedirs(directory, exist_ok=True)

    # get models and costs
    with open('data/pareto_mac.txt', 'r') as f: pareto_mac = json.load(f)
    models = [i[1] for i in pareto_mac]
    costs = [i[3] for i in pareto_mac]

    # save ImageNetV2 labels for later usage
    if not os.path.exists(path_labels) or args.labels:
        dataset = datasets.ImageFolder(root=path_imagenet_v2)
        targets = dataset.targets
        # ImageFolder numbers classes in lexicographic folder-name order. When the
        # class folders are numeric ids (the public ImageNetV2 archives use
        # '0'..'999'), folder '10' would get index 2; map back to the numeric id.
        # WNID-named or zero-padded folders are left unchanged.
        if all(name.isdigit() for name in dataset.classes):
            targets = [int(dataset.classes[t]) for t in targets]
        with open(path_labels, 'w') as f: json.dump(targets,f,indent=2)

    # infer models on ImageNetV2
    if not args.skip:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        for i in models:
            if not args.force and os.path.exists(path_logits+i+'.pt'):
                print('Skipping model', i, 'because logits already exist.')
                continue
            infer(i, path_logits, device)

    # compute inference data
    for i in models:
        if not args.force and os.path.exists(path_infer+i+'.txt'):
            print('Skipping model', i, 'because infer data already exist.')
            continue
        inference_data(i, path_logits, path_infer)

    # update Pareto models with ImageNetV2 accuracy (fresh copy; pareto_mac is left as loaded)
    with open('data/pareto_mac.txt', 'r') as f: v2_pareto_mac = json.load(f)
    for entry in v2_pareto_mac:
        with open(path_infer+entry[1]+'.txt', 'r') as f: ea = json.load(f)
        entry[2] = round(sum(ea[0])/len(ea[0])*100,5)
    with open('data/v2_pareto_mac.txt', 'w') as f: json.dump(v2_pareto_mac,f,indent=2)

    # create baseline Pareto front (the last hull point is dropped, as before)
    v2_paretoL = paretoRtoL(v2_pareto_mac)[:-1]
    v2_baseline_mac = paretoLfull(v2_paretoL)
    with open('data/v2_baseline_mac.txt', 'w') as f: json.dump(v2_baseline_mac,f,indent=2)

    # create cascade Pareto front
    with open(path_labels, 'r') as f: labels = torch.tensor(json.load(f))

    v2_pareto, v2_cascades = cascade_pareto_bi(models, costs, 1, path_infer, path_logits, labels)
    np.save('data/mac/v2_bi_pareto_softmax.npy', v2_pareto, allow_pickle=True)
    with open('data/mac/v2_bi_cascades_softmax.pkl', 'wb') as f: pickle.dump(v2_cascades, f)

    # obtain cascades that are Pareto optimal for max softmax on ImageNet ...
    with open('data/mac/bi_cascades_softmax.pkl', 'rb') as f: cascades = pickle.load(f)
    pareto_cascade_indices = sorted(set(get_pareto_cascades(cascades[2])))

    # ... for ImageNet
    pareto_cascades = [get_cascade(cascades[0][i], cascades[1][i], 'data/infer/infer_')
                       for i in pareto_cascade_indices]

    # ... for ImageNetV2
    pareto_cascades_v2 = [get_cascade(v2_cascades[0][i], v2_cascades[1][i], path_infer)
                          for i in pareto_cascade_indices]

    # construct Pareto with thresholds (50000 presumably = ImageNet validation size)
    pareto_thresh = get_pareto_thresh(pareto_cascades, pareto_cascade_indices, 50000)

    # obtain the Pareto optimal cascade threshold ranges for ImageNet
    groups = thresh_ranges(pareto_thresh, pareto_cascades, pareto_cascade_indices)

    # filter ImageNetV2 cascades based on optimal ImageNet threshold ranges
    cascades_filtered_v2 = filter_cascades(pareto_cascades_v2, groups, pareto_cascade_indices)
    with open('data/mac/v2_bi_cascades_softmax_filtered.pkl', 'wb') as f: pickle.dump(cascades_filtered_v2, f)

    # obtain ImageNetV2 Pareto front of filtered cascades (10000 presumably = ImageNetV2 size)
    v2_pareto_filtered = get_pareto(cascades_filtered_v2, 10000)
    np.save('data/mac/v2_bi_pareto_softmax_filtered.npy', v2_pareto_filtered, allow_pickle=True)


# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == '__main__':
    main()