'''
Plot calibration of pretrained models on ImageNet.
'''

import json
import os

import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

from figuresettings import font, lw, ms, fstitle, fsaxes, fslabels


# plot model calibration
#
# One subplot per 'infer_*' file: examples are sorted by their maximum softmax
# value and a sliding-window average of correctness is plotted against the
# sliding-window average of maximum softmax, next to the diagonal that a
# perfectly calibrated model would follow.
path = '../data/infer/'
names = sorted(f for f in os.listdir(path) if f.startswith('infer_'))

width = 1000 # width of sliding window / moving average


def _moving_average(values, window):
    '''
    Return the mean of every full run of `window` consecutive `values`.

    The result has len(values) - window + 1 entries (none if there are fewer
    than `window` values). `window` must be a positive integer.
    '''
    # prefix sums with a leading 0 so that totals[j + window] - totals[j]
    # is the sum of values[j:j + window]
    totals = np.concatenate(([0.0], np.cumsum(values, dtype=float)))
    return (totals[window:] - totals[:-window]) / window


plt.figure(figsize=(20, 27), dpi=200)
plt.rc('font', family=font)

for idx, name in enumerate(names):
    # load inference data
    with open(path + name, 'r') as f:
        infer = json.load(f)

    # columns: infer[0] (plotted as accuracy) and infer[3] (plotted as
    # maximum softmax); rows sorted by ascending maximum softmax
    arr = np.array([infer[0], infer[3]]).T
    arr = arr[arr[:, 1].argsort()]

    # compute moving average over however many examples the file contains
    avg_acc = _moving_average(arr[:, 0], width)
    avg_conf = _moving_average(arr[:, 1], width)

    # plot
    # 9x5 grid; from the 41st file on one cell is skipped, which shifts the
    # last row one column to the right (presumably to centre a shorter row)
    plt.subplot(9, 5, idx + (2 if idx >= 40 else 1))

    plt.plot(avg_conf, avg_acc, '-', color='C0', lw=lw, label=f'Moving Average of {width}')
    plt.plot([0, 1], [0, 1], '--', color='C3', lw=lw, label='Expected')

    plt.xlim((0, 1))
    plt.ylim((0, 1))
    plt.xticks(fontsize=fslabels)
    plt.yticks(fontsize=fslabels)
    plt.xlabel('Maximum Softmax', fontsize=fsaxes)
    plt.ylabel('Accuracy', fontsize=fsaxes)
    # drop the 6-character 'infer_' prefix and the 4-character file extension
    plt.title(name[6:-4], fontsize=fstitle)
    plt.legend(loc='lower right', fontsize=fslabels, framealpha=0.5)
    plt.grid(lw=0.4)

plt.subplots_adjust(hspace=0.4)

plotname = 'figure_13_appendix_calibration_models'
plt.savefig(plotname+'.pdf', bbox_inches = 'tight', pad_inches = 0)



# plot calibration analysis

# file name prefixes; the model name and an extension are appended when loading
path_infer = '../data/infer/infer_'
path_logits = '../data/logits/logits_'

# models shown in the analysis, one subplot each
models = [
    'mobilenetv3_large_100_miil',
    'tf_efficientnet_b4_ns',
]

# per model: position in the confidence-sorted validation examples at which
# the 'ImageNet Test Submission' marker is drawn
tests = [
    13519,
    8031,
]

# equiv_temp is, per model, the temperature which scales the average maximum
# softmax to be equal to the accuracy. It was found with the absolute loss:
#
#   loss = (F.softmax(logits/temperature, dim=1).max(1)[0].mean()-accuracy).abs()
#
# Squared loss (replace '.abs()' with '.square()') was also tried but was
# marginally less accurate.
#
# mobilenetv3_large_100_miil:
#     accuracy: 0.77924
#     mean max softmax: 0.8187013320521079
#     equiv_temp[0] scaled mean max softmax: 0.7792400121688843
#
# tf_efficientnet_b4_ns:
#     accuracy: 0.85152
#     mean max softmax: 0.7509132722109556
#     equiv_temp[1] scaled mean max softmax: 0.8515222072601318
equiv_temp = [
    1.1594198942184448,
    0.8725124597549438,
]


# number of examples in validation set
n = 50_000
# delta (number of examples) trimmed from the plotted cumulative accuracies
# to remove the unstable start
d = 100
# alpha (opacity) of the coloured curves
al = 0.8


def _sorted_by_confidence(correct, confidence):
    '''
    Stack `correct` and `confidence` as the two columns of one array and
    return its rows sorted by ascending confidence (column 1).
    '''
    pairs = np.array([correct, confidence]).T
    return pairs[pairs[:, 1].argsort()]


def _plot_threshold_curves(thresholds, outcomes, color, start, stop,
                           threshold_style, accuracy_style):
    '''
    Draw one group of three curves over examples ordered by confidence.

    thresholds: y values of the solid threshold curve, drawn at x = 1..len.
    outcomes: per-example values whose running means give the two accuracy
        curves (dashed: mean of the first x examples; dash-dotted: mean of
        all examples after the first x).
    color: matplotlib color shared by the three curves.
    start: number of leading points dropped from the dashed curve.
    stop: number of leading points kept of the dash-dotted curve.
    threshold_style / accuracy_style: extra keyword arguments for plt.plot.
    '''
    ranks = np.arange(1, len(outcomes) + 1)
    below = np.cumsum(outcomes) / ranks
    # cumulate from the most confident end, then flip back so that entry x
    # is the mean over the examples that come after the first x
    above = (np.cumsum(outcomes[::-1]) / ranks)[::-1]

    plt.plot(ranks, thresholds, '-', color=color, **threshold_style)
    plt.plot(ranks[start:], below[start:], '--', color=color, **accuracy_style)
    # an explicit stop index (instead of a negative slice) also works when
    # nothing is trimmed
    plt.plot(np.arange(stop), above[:stop], '-.', color=color, **accuracy_style)


plt.figure(figsize=(7, 7), dpi=200)
plt.rc('font', family=font)

# line styles: translucent curves for the three model-based groups,
# opaque curves drawn above them for the ideal reference
translucent = dict(lw=lw, alpha=al)
translucent_front = dict(lw=lw, alpha=al, zorder=5)
opaque_front = dict(lw=lw, zorder=6)

for idx, model in enumerate(models):
    plt.subplot(2, 1, idx+1)
    with open(path_infer+model+'.txt', 'r') as f:
        infer = json.load(f)
    # map_location keeps the script usable on machines without the device
    # the logits were saved from
    logits = torch.load(path_logits+model+'.pt', map_location='cpu')
    count = len(infer[0])

    # uncalibrated: infer[0] is used as per-example correctness and
    # infer[3] as the maximum softmax value
    arr = _sorted_by_confidence(infer[0], infer[3])
    _plot_threshold_curves(arr[:, 1], arr[:, 0], 'C0', d, count - d,
                           translucent, translucent_front)
    # test submission
    plt.plot(tests[idx], arr[tests[idx], 1], '+r', ms=ms, zorder=10)

    # temperature scaled: the last entry of the inference record is used
    # as the temperature
    scaled = F.softmax(logits/infer[-1], dim=1).max(1)[0].tolist()
    arr = _sorted_by_confidence(infer[0], scaled)
    _plot_threshold_curves(arr[:, 1], arr[:, 0], 'C1', d, count - d,
                           translucent, translucent_front)

    # assumed perfect calibration: the confidence itself stands in for the
    # per-example correctness, so the accuracy curves are running means of
    # the confidence
    expected = np.sort(np.array(
        F.softmax(logits/equiv_temp[idx], dim=1).max(1)[0].tolist()))
    _plot_threshold_curves(expected, expected, 'C2', d, len(expected) - d,
                           translucent, translucent_front)

    # ideal: all wrong examples ranked before all correct ones; only the
    # part of each accuracy curve that is not constant is drawn
    n_true = int(sum(infer[0]))
    n_false = count - n_true
    ideal = np.array([0.]*n_false + [1.]*n_true)
    _plot_threshold_curves(ideal, ideal, 'k', n_false, n_false,
                           opaque_front, opaque_front)

    plt.xticks(fontsize=fslabels)
    plt.yticks(fontsize=fslabels)
    plt.xlabel('Validation Set Example',fontsize=fsaxes)
    plt.ylabel('Accuracy or Threshold',fontsize=fsaxes)
    plt.title(model,fontsize=fstitle)
    # proxy handles: gray lines explain the line styles, coloured lines
    # explain the colours
    plt.legend(handles=[
        mlines.Line2D([], [], color='gray', linestyle='-', linewidth=lw, label='Maximum Softmax Threshold'),
        mlines.Line2D([], [], color='gray', linestyle='--', linewidth=lw, label='Accuracy Below Threshold'),
        mlines.Line2D([], [], color='gray', linestyle='-.', linewidth=lw, label='Accuracy Above Threshold'),
        mlines.Line2D([], [], color='C0', alpha=al, linestyle='-', linewidth=lw, label='Uncalibrated'),
        mlines.Line2D([], [], color='C1', alpha=al, linestyle='-', linewidth=lw, label='Temperature Scaled'),
        mlines.Line2D([], [], color='C2', alpha=al, linestyle='-', linewidth=lw, label='Assumed Perfect Calibration'),
        mlines.Line2D([], [], color='k', linestyle='-', linewidth=lw, label='Ideal'),
        mlines.Line2D([], [], color='r', linestyle='None', marker='+', markersize=ms, label='ImageNet Test Submission'),
        ], loc='lower right', fontsize=fslabels, framealpha=0.5)
    plt.grid(lw=0.4)

plt.subplots_adjust(hspace=0.3)


plotname = 'figure_14_appendix_calibration_analysis'
plt.savefig(plotname+'.pdf', bbox_inches = 'tight', pad_inches = 0)