import torch 
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F

def plot_model_calibration(att_mod, com_mod, dataloader, plot_name=None):
    """
    Draw side-by-side reliability diagrams for two models on the same data.

    att_mod, com_mod - models called as ``mod(images, tasks)`` that return a
        3-tuple whose first element is the per-class output (logits). Each must
        expose boolean ``is_attention`` and ``is_comodulation`` attributes and
        support ``.eval()`` / ``.parameters()`` (presumably ``nn.Module``s --
        TODO confirm).
    dataloader - iterable yielding ``(images, labels, tasks)`` batches.
    plot_name - if given, the figure is saved to this path in PDF format.

    Returns a dict ``{'att_ece': float, 'com_ece': float}`` with the expected
    calibration error of each model. Note: both models are left in eval mode.
    """
    def make_plot_for_model(mod, ax):
        # Run `mod` over the whole dataloader and draw its diagram on `ax`.
        mod.eval()
        # Send inputs to wherever the model lives; fall back to the current CUDA
        # device (the previous hard-coded behaviour) for a parameter-less model.
        first_param = next(mod.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda')

        all_outputs = []
        all_labels = []
        # Pure inference: no_grad avoids building an autograd graph per batch.
        with torch.no_grad():
            for images, labels, tasks in dataloader:
                images = images.to(device)
                tasks = tasks.to(device)
                labels = labels.to(device)
                outs, _, _ = mod(images, tasks)
                # Models that are neither attention nor comodulation models are
                # scored against the task ids; the others against the labels.
                targets = labels if (mod.is_attention or mod.is_comodulation) else tasks
                all_outputs.append(outs.cpu())
                all_labels.append(targets.cpu())

        all_outputs = torch.cat(all_outputs)
        all_labels = torch.cat(all_labels)
        # Only the attention model's panel gets axis labels and a legend.
        return make_model_diagrams(all_outputs, all_labels, ax=ax, put_labels=mod.is_attention)

    fig, (ax1, ax2) = plt.subplots(1, 2)
    att_ece = make_plot_for_model(att_mod, ax1)
    com_ece = make_plot_for_model(com_mod, ax2)

    if plot_name is not None:
        fig.savefig(plot_name, format="pdf", bbox_inches="tight")
    return {'att_ece': att_ece, 'com_ece': com_ece}

def calculate_ece(logits, labels, n_bins=10):
    """
    Return the Expected Calibration Error (ECE) of a classifier as a float.

    logits - (n x num_classes) raw model outputs; softmax is applied here, so
        do NOT pass probabilities.
    labels - (n,) ground-truth class indices.
    n_bins - number of equal-width confidence bins covering (0, 1].

    Each sample is placed in the bin (lower, upper] containing its top-class
    softmax confidence. For every non-empty bin the absolute gap
    |mean confidence - accuracy| is weighted by the fraction of samples in that
    bin, and the weighted gaps are summed.

    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.
    "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI.
    2015.
    """
    probs = F.softmax(logits, dim=1)
    confidences, predictions = probs.max(dim=1)
    hits = predictions.eq(labels)

    edges = torch.linspace(0, 1, n_bins + 1)
    lowers = edges[:-1].tolist()
    uppers = edges[1:].tolist()

    ece = torch.zeros(1, device=logits.device)
    for lo, hi in zip(lowers, uppers):
        mask = (confidences > lo) & (confidences <= hi)
        weight = mask.float().mean()
        # Skip empty bins (weight is 0, or NaN when there are no samples at all).
        if not weight.item() > 0:
            continue
        bin_gap = confidences[mask].mean() - hits[mask].float().mean()
        ece += bin_gap.abs() * weight
    return ece.item()

def make_model_diagrams(outputs, labels, ax=None, n_bins=10, put_labels=True):
    """
    Draw a reliability diagram for a classifier and annotate it with its ECE.

    outputs - a torch tensor (size n x num_classes) with the outputs from the final linear layer
    - NOT the softmaxes
    labels - a torch tensor (size n) with the labels
    ax - matplotlib Axes to draw on; defaults to the current Axes (``plt.gca()``).
    n_bins - number of equal-width confidence bins on (0, 1].
    put_labels - if True, add axis labels and a legend.

    Returns the expected calibration error as computed by ``calculate_ece``.
    """
    if ax is None:
        ax = plt.gca()

    softmaxes = F.softmax(outputs, dim=1)
    confidences, predictions = softmaxes.max(1)
    accuracies = predictions.eq(labels).float()

    # Reliability diagram
    bins = torch.linspace(0, 1, n_bins + 1)
    width = 1.0 / n_bins
    bin_centers = np.linspace(0, 1.0 - width, n_bins) + width / 2

    # Empty bins stay NaN, so matplotlib draws no bar for them.
    bin_corrects = np.full(n_bins, np.nan)
    bin_scores = np.full(n_bins, np.nan)
    for i, (lower, upper) in enumerate(zip(bins[:-1].tolist(), bins[1:].tolist())):
        # (lower, upper] matches calculate_ece; a half-open [lower, upper) bin
        # would silently drop samples whose confidence is exactly 1.0.
        in_bin = confidences.gt(lower) & confidences.le(upper)
        if in_bin.any():
            # .item() yields plain floats and also works for CUDA tensors.
            bin_corrects[i] = accuracies[in_bin].mean().item()
            bin_scores[i] = confidences[in_bin].mean().item()

    gap = bin_scores - bin_corrects
    confs = ax.bar(bin_centers, bin_corrects, width=width, alpha=0.1, ec='black')
    gaps = ax.bar(bin_centers, gap, bottom=bin_corrects, color=[1, 0.7, 0.7], alpha=0.5, width=width, hatch='//', edgecolor='r')
    # Diagonal = perfect calibration.
    ax.plot([0, 1], [0, 1], '--', color='gray')

    ece = calculate_ece(outputs, labels, n_bins=n_bins)
    bbox_props = dict(boxstyle="round", fc="lightgrey", ec="brown", lw=2)
    ax.text(0.2, 0.85, "ECE: {:.3f}".format(ece), ha="center", va="center", size=10, weight='bold', bbox=bbox_props)

    if put_labels:
        ax.set_ylabel("Accuracy (P[y])", size=18)
        ax.set_xlabel("Confidence", size=18)
        ax.legend([confs, gaps], ['Outputs', 'Gap'], loc='best', fontsize='small')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    return ece
