""" Just some utils to build calibration plots """


import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import torch.optim as optim
import numpy as np


@torch.no_grad()
def get_preds(model, test_loader, data_keys, label_key, device,
              calibration_method=None, cal_args=None, disable_tqdm=False):
    """Run ``model`` over ``test_loader`` and collect flattened softmax outputs.

    Args:
        model: Module called as ``model(*inputs)``. It is switched to eval
            mode and moved to ``device``.
        test_loader: Iterable of batches; each batch is indexed with the keys
            in ``data_keys`` and with ``label_key``.
        data_keys: Keys whose batch entries are moved to ``device`` and passed
            positionally to the model, in this order.
        label_key: Key of the integer class labels inside a batch.
        device: Device on which inference runs.
        calibration_method: Optional callable ``f(logits, cal_args)`` applied
            to the model output before the softmax.
        cal_args: Passed through unchanged as the second argument of
            ``calibration_method``.
        disable_tqdm: If true, the progress bar is disabled.

    Returns:
        Tuple ``(preds, labels_oneh, total_acc)``: the softmax probabilities
        of all samples flattened to 1-D, the matching one-hot labels flattened
        to 1-D, and the top-1 accuracy in percent.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    # NOTE(review): the softmax over dim=1 combined with num_classes taken from
    # the last dim assumes model output of shape (batch, num_classes) - confirm.
    preds = []
    labels_oneh = []
    correct, total = 0, 0
    model = model.eval().to(device)

    for batch in tqdm(test_loader, disable=disable_tqdm):
        data = [batch[k].to(device) for k in data_keys]
        label = batch[label_key].to(device)

        logits = model(*data)
        num_classes = logits.shape[-1]

        if calibration_method:
            logits = calibration_method(logits, cal_args)

        probs = F.softmax(logits, dim=1)
        predicted = torch.max(probs, dim=1)[1]

        # Keep one array per batch and concatenate once at the end.
        preds.append(probs.detach().cpu().numpy())
        labels_oneh.append(
            F.one_hot(label, num_classes=num_classes).detach().cpu().numpy()
        )

        total += data[0].shape[0]
        # Tensor-level reduction: the builtin sum() would iterate the tensor
        # element by element in Python and break on an empty batch.
        correct += (predicted == label).sum().item()

    # Evaluated before concatenation so an empty loader still raises
    # ZeroDivisionError, as it always did.
    total_acc = 100 * correct / total

    preds = np.concatenate(preds).flatten()
    labels_oneh = np.concatenate(labels_oneh).flatten()

    return preds, labels_oneh, total_acc


def calc_bins(preds, labels_oneh, num_bins=10):
    """Group predictions into equal-width confidence bins over [0, 1].

    Bin ``b`` holds the predictions ``p`` with ``bins[b-1] <= p < bins[b]``
    (bin 0 holds everything below ``bins[0]``). Predictions at or above the
    last edge, in particular probabilities that are exactly 1.0, are counted
    in the last bin instead of being dropped.

    Args:
        preds: Array-like of predicted probabilities.
        labels_oneh: Array-like of 0/1 targets with the same shape as
            ``preds``.
        num_bins: Number of bins (default 10).

    Returns:
        Tuple ``(bins, binned, bin_accs, bin_confs, bin_sizes)``: the upper
        edge of every bin, the bin index of every prediction, and per bin the
        mean label, the mean prediction and the number of predictions.
        Empty bins keep 0 for both means.

    Raises:
        ValueError: If ``num_bins`` is smaller than 1.
    """
    if num_bins < 1:
        raise ValueError("num_bins must be a positive integer")

    preds = np.asarray(preds)
    labels_oneh = np.asarray(labels_oneh)

    # Upper edges of the bins: 1/num_bins, 2/num_bins, ..., 1.
    bins = np.linspace(1 / num_bins, 1, num_bins)
    # np.digitize returns num_bins for values >= the last edge (e.g. a
    # saturated softmax output of exactly 1.0); clamp them into the last bin
    # so they are not left out of the statistics.
    binned = np.minimum(np.digitize(preds, bins), num_bins - 1)

    bin_accs = np.zeros(num_bins)
    bin_confs = np.zeros(num_bins)
    bin_sizes = np.zeros(num_bins)

    for b in range(num_bins):
        in_bin = binned == b
        bin_sizes[b] = np.count_nonzero(in_bin)
        if bin_sizes[b] > 0:
            bin_accs[b] = labels_oneh[in_bin].sum() / bin_sizes[b]
            bin_confs[b] = preds[in_bin].sum() / bin_sizes[b]
    return bins, binned, bin_accs, bin_confs, bin_sizes


def get_metrics(preds, labels_oneh):
    """Compute the expected and maximum calibration error.

    The predictions are binned with ``calc_bins``. ECE is the sum over bins of
    ``|accuracy - confidence|`` weighted by the share of binned predictions in
    that bin; MCE is the largest ``|accuracy - confidence|`` of any bin.

    Returns:
        Tuple ``(ECE, MCE)``.
    """
    edges, _, accs, confs, sizes = calc_bins(preds, labels_oneh)
    n_binned = sum(sizes)

    ece = 0
    mce = 0
    for acc, conf, size in zip(accs, confs, sizes):
        gap = abs(acc - conf)
        ece += (size / n_binned) * gap
        mce = max(mce, gap)
    return ece, mce


def draw_reliability_graph(preds, labels_oneh):
    """Draw a reliability diagram for the given predictions.

    Creates a new 8x8 figure with one bar per bin from ``calc_bins``: a
    hatched red bar whose height equals the bin edge, and on top a blue bar
    with the bin accuracy. A dashed diagonal is added, and the legend shows
    ECE and MCE (from ``get_metrics``) as percentages.

    Returns:
        The matplotlib axes holding the plot.
    """
    ECE, MCE = get_metrics(preds, labels_oneh)
    binning = calc_bins(preds, labels_oneh)
    edges, accs = binning[0], binning[2]

    figure = plt.figure(figsize=(8, 8))
    axes = figure.gca()

    # Axis ranges, labels and a grid drawn behind the bars.
    axes.set_xlim(0, 1.05)
    axes.set_ylim(0, 1.)
    axes.set_xlabel('Confidence')
    axes.set_ylabel('Accuracy')
    axes.set_axisbelow(True)
    axes.grid(color='gray', linestyle='dashed')

    # Reference bars (height == bin edge) first, measured accuracy on top.
    axes.bar(edges, edges, width=0.1, alpha=0.3, edgecolor='black', color='r', hatch='\\')
    axes.bar(edges, accs, width=0.1, alpha=1, edgecolor='black', color='b')
    axes.plot([0, 1], [0, 1], '--', color='gray', linewidth=2)

    axes.set_aspect('equal', adjustable='box')

    legend_handles = [
        mpatches.Patch(color='green', label='ECE = {:.2f}%'.format(ECE*100)),
        mpatches.Patch(color='red', label='MCE = {:.2f}%'.format(MCE*100)),
    ]
    axes.legend(handles=legend_handles)

    return axes

# =================================================
# =           Temperature scaling stuff           =
# =================================================


def T_scaling(logits, args):
    """Divide ``logits`` by the temperature stored in ``args``.

    Args:
        logits: Tensor of model outputs.
        args: Mapping that may hold a ``'temperature'`` entry (number or
            tensor). ``None``, or a mapping without that key, means a
            temperature of 1.0.

    Returns:
        ``logits / temperature`` as computed by ``torch.div``.
    """
    # get_preds forwards its cal_args default (None) unchanged, so accept it
    # here instead of failing on ``None.get``.
    temperature = 1.0 if args is None else args.get('temperature', 1.0)
    return torch.div(logits, temperature)



def learn_temperature(model, test_loader, data_keys, label_key, device, disable_tqdm=False):
    """Fit a single softmax temperature by minimising cross-entropy with L-BFGS.

    The model is run once over ``test_loader`` without gradients; the cached
    logits are then divided by a learnable temperature (initialised to 1) and
    the cross-entropy against the labels is minimised.

    Args:
        model: Module called as ``model(*inputs)``. It is switched to eval
            mode and moved to ``device``.
        test_loader: Iterable of batches indexed with ``data_keys`` and
            ``label_key``.
        data_keys: Keys whose batch entries are passed positionally to the
            model, in this order.
        label_key: Key of the class labels inside a batch.
        device: Device used for inference and optimisation.
        disable_tqdm: If true, the progress bar is disabled.

    Returns:
        Tuple ``(args, temps, losses)``: ``args`` is ``{'temperature': param}``
        holding the fitted ``nn.Parameter`` (usable as ``cal_args`` for
        ``T_scaling``); ``temps`` and ``losses`` record the temperature value
        and the detached loss tensor seen at every closure evaluation.
    """
    temperature = nn.Parameter(torch.ones(1).to(device))
    args = {'temperature': temperature}
    optimizer = optim.LBFGS([temperature], lr=0.001, max_iter=10_000, line_search_fn='strong_wolfe')

    logits_list = []
    labels_list = []
    temps = []
    losses = []

    model = model.eval().to(device)
    for batch in tqdm(test_loader, disable=disable_tqdm):
        data = [batch[k].to(device) for k in data_keys]
        with torch.no_grad():
            logits_list.append(model(*data))
            labels_list.append(batch[label_key].to(device))

    logits = torch.cat(logits_list).to(device)
    labels = torch.cat(labels_list).to(device)

    def _eval():
        # L-BFGS evaluates the closure many times per step (line search).
        # Without clearing, gradients from earlier evaluations accumulate in
        # temperature.grad and the optimizer sees a wrong gradient.
        optimizer.zero_grad()
        loss = F.cross_entropy(T_scaling(logits, args), labels)
        loss.backward()
        temps.append(temperature.item())
        # Store a detached copy so the history does not hold on to the
        # autograd graph of every evaluation.
        losses.append(loss.detach())
        return loss

    optimizer.step(_eval)
    return args, temps, losses



# Script entry point.
# NOTE(review): no `main` function is defined in the visible part of this
# file, so running the module as a script would raise NameError here - confirm
# whether `main` is defined elsewhere or whether this guard should be removed.
if __name__ == '__main__':
    main()

