'''
Run the Pareto-optimal models on the ImageNet validation set to obtain logits
and inference data, from which cascades can be constructed.

The inference data is a nested list; every entry except the final scalar
temperature is a list with one value per validation input:
[prediction correctness,
 predicted label,
 entropy,
 max softmax,
 softmax margin,
 logits margin,
 temperature scaled entropy,
 temperature scaled max softmax,
 temperature scaled softmax margin,
 temperature scaled logits margin,
 temperature]

Temperature scaling is done on CPU to be more accurate.
'''

import argparse
import json
import math
import os

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from config import path_imagenet, batch_size, workers


# inference function which saves logits
@torch.no_grad()
def infer(model, path_logits, device = 0, process = True):
    '''Save the ImageNet validation logits of a pretrained timm model.

    Args:
        model: timm model name; also appended to ``path_logits`` to form the
            output file name ``<path_logits><model>.pt``.
        path_logits: path prefix for the saved logits tensor.
        device: device the model and the input batches are moved to.
        process: if True, also call ``inference_data`` on the logits, writing
            its output with the ``data/infer/infer_`` prefix.
    '''
    # load model and build the evaluation transform from its pretrained config
    net = timm.create_model(model, pretrained=True).to(device).eval()
    config = net.default_cfg
    if 'test_input_size' in config:
        input_size = config['test_input_size']
        print('Using test input size',input_size)
    else:
        input_size = config['input_size']
    if config['interpolation'] == 'bicubic':
        interpolation = transforms.InterpolationMode.BICUBIC
    else:
        interpolation = transforms.InterpolationMode.BILINEAR
    crop_size = input_size[-1]
    # resize to crop_size / crop_pct so the center crop keeps crop_pct of it
    resize_size = int(math.floor(crop_size / config['crop_pct']))
    tf = transforms.Compose([
        transforms.Resize(resize_size, interpolation=interpolation),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(config['mean'], config['std']),
    ])
    print('Starting model',model,'with transform:\n',tf)
    # create dataloader; shuffle=False keeps logits rows in dataset order,
    # i.e. the same order as dataset.targets (which main saves as labels)
    dataset = datasets.ImageNet(
        root=path_imagenet, split='val', transform=tf)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=True)
    # forward every batch and collect the logits on the CPU
    logits = []
    for images, _ in dataloader:
        # pinned host memory allows an asynchronous host-to-device copy
        images = images.to(device, non_blocking=True)
        logits.append(net(images).to('cpu'))
    logits = torch.cat(logits, 0)
    torch.save(logits, path_logits+model+'.pt')
    if process:
        inference_data(model, logits, 'data/infer/infer_')


# obtain nested list with model inference data
def inference_data(model, logits, path_infer):
    '''Compute confidence metrics for one model and save them as JSON.

    Args:
        model: model name; appended to the path prefixes to form file names.
        logits: tensor of validation logits (one row per input, one column
            per class), or a path prefix from which ``<prefix><model>.pt``
            is loaded.
        path_infer: output path prefix; the list is written to
            ``<prefix><model>.txt``.

    Ground-truth labels are read from ``data/labels_ImageNet_val.txt`` and
    must be in the same order as the logits rows. The saved list contains:
    correctness, predicted label, entropy, max softmax, softmax margin,
    logits margin, the same four confidence metrics computed on
    temperature-scaled logits, and finally the fitted temperature (a float).
    '''
    # loads logits if path is given
    if isinstance(logits, str):
        logits = torch.load(logits + model + '.pt', map_location='cpu')
    with open('data/labels_ImageNet_val.txt', 'r') as f:
        labels = torch.tensor(json.load(f))

    temperature = _fit_temperature(model, logits, labels)

    # obtain inference data
    predicted = torch.argmax(logits, 1)
    is_correct = predicted == labels
    ea = [is_correct.tolist(), predicted.tolist()]
    ea += _confidence_metrics(logits)
    ea += _confidence_metrics(torch.div(logits, temperature))
    ea.append(temperature)

    correct = is_correct.sum().item()
    with open(path_infer + model + '.txt', 'w') as f:
        json.dump(ea, f, indent=2)
    print(f'{model} validation accuracy: {100*correct/len(labels)} Temperature: {ea[-1]}')


def _fit_temperature(model, logits, labels):
    '''Fit a scalar temperature T minimizing the cross-entropy of ``logits / T``.

    Runs up to 50 LBFGS steps, stopping early once a step leaves the
    temperature unchanged, and returns the fitted value as a Python float.
    ``model`` is only used in the early-stop log message.
    '''
    temperature = nn.Parameter(torch.ones(1))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.LBFGS([temperature], lr=0.001, max_iter=10000,
                                  line_search_fn='strong_wolfe')

    def closure():
        # LBFGS evaluates the closure several times per step (initial point
        # plus line search), so gradients must be cleared on every
        # evaluation; otherwise .grad accumulates across evaluations.
        optimizer.zero_grad()
        loss = criterion(torch.div(logits, temperature), labels)
        loss.backward()
        return loss

    last = 0.
    for step in range(50):
        optimizer.step(closure)
        current = temperature.item()
        # end early once a full LBFGS step no longer moves the temperature
        if current == last:
            print(model, 'ending temperature optimization early at iteration', step + 1)
            break
        last = current
    return temperature.item()


def _confidence_metrics(logits):
    '''Return [entropy, max softmax, softmax margin, logits margin] per row.

    Margins are the difference between the largest and second-largest value
    in each row; every metric is returned as a plain Python list.
    '''
    probs = F.softmax(logits, dim=1)
    entropy = -(probs * F.log_softmax(logits, dim=1)).sum(dim=1)
    top_probs = probs.topk(2, 1)[0]
    top_logits = logits.topk(2, 1)[0]
    return [entropy.tolist(),
            top_probs[:, 0].tolist(),
            (top_probs[:, 0] - top_probs[:, 1]).tolist(),
            (top_logits[:, 0] - top_logits[:, 1]).tolist()]


def main():
    '''Command-line entry point: run inference and/or build inference data.

    Reads the model list from ``data/models_<models>.txt``. Without ``--skip``
    it runs ``infer`` for every model whose logits file is missing; with
    ``--skip`` it runs ``inference_data`` on the saved logits for every model
    whose inference-data file is missing. ``--force`` disables both checks.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', '--force', action='store_true', help='force when file is already found')
    parser.add_argument('-l', '--labels', action='store_true', help='save dataset labels even when found')
    parser.add_argument('-s', '--skip', action='store_true', help='skip inference, use when logits already exist')
    parser.add_argument('-n', '--process', action='store_false', help='infer without computing inference data')
    parser.add_argument('-m', '--models', default='all', type=str, help='models for inference, can be \'mac\' or \'time\', default: \'all\'')
    args = parser.parse_args()

    # create output directories first: the label file below is written into
    # data/, and opening it would fail if that directory did not exist yet
    os.makedirs('data/logits', exist_ok=True)
    os.makedirs('data/infer', exist_ok=True)

    # save validation set labels for later usage; use the same dataset root
    # as infer() so the labels match the order of the saved logits
    if not os.path.exists('data/labels_ImageNet_val.txt') or args.labels:
        dataset = datasets.ImageNet(root=path_imagenet, split='val')
        with open('data/labels_ImageNet_val.txt', 'w') as f:
            json.dump(dataset.targets, f, indent=2)

    # load list of Pareto optimal model names
    with open('data/models_'+args.models+'.txt', 'r') as f:
        models = json.load(f)

    if not args.skip:
        # infer
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        for i in models:
            if not args.force and os.path.exists('data/logits/logits_'+i+'.pt'):
                print('Skipping model', i, 'because logits already exist.')
                continue
            infer(i, 'data/logits/logits_', device, args.process)
    else:
        # compute inference data from saved logits without running inference;
        # skip based on the inference-data output, since in this mode the
        # logits are expected to exist for every model
        for i in models:
            if not args.force and os.path.exists('data/infer/infer_'+i+'.txt'):
                print('Skipping model', i, 'because infer data already exist.')
                continue
            inference_data(i, 'data/logits/logits_', 'data/infer/infer_')


# Run the command-line interface only when executed as a script, not on import.
if '__main__' == __name__:
    main()