import numpy as np
from scipy.optimize import minimize
import scipy.stats
import pandas as pd
import time
from sklearn.metrics import log_loss, brier_score_loss
from os.path import join
import sklearn.metrics as metrics
import sys
from os import path
from sklearn.isotonic import IsotonicRegression
from sklearn.linear_model import LogisticRegression

import torch
import torch.nn.functional as F
import torch.nn as nn
import pickle
from tools import ECE, MCE, ECE_balanced, MCE_balanced
import spline

import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
import matplotlib.transforms as transforms

# Mini-batch size used for the torch DataLoaders constructed in this module.
BATCH_SIZE = 256
#from tools import unpickle_probs
# from utils import progress_bar
def unpickle_probs(file, verbose = 0):
    """
    Load ``((y_probs_val, y_val), (y_probs_test, y_test))`` from a pickle file.

    Security: this uses ``pickle.load``, which can execute arbitrary code while
    loading. Only open files from a trusted source.

    Params:
        file: path of the pickle file.
        verbose: when truthy, print the ``.shape`` of each of the four arrays.

    Returns:
        ((y_probs_val, y_val), (y_probs_test, y_test)) exactly as stored.
    """
    with open(file, 'rb') as handle:
        validation_part, test_part = pickle.load(handle)
        y_probs_val, y_val = validation_part
        y_probs_test, y_test = test_part

    if verbose:
        shape_report = (
            ("y_probs_val:", y_probs_val),
            ("y_true_val:", y_val),
            ("y_probs_test:", y_probs_test),
            ("y_true_test:", y_test),
        )
        for caption, array in shape_report:
            print(caption, array.shape)

    return (y_probs_val, y_val), (y_probs_test, y_test)

def softmax(x):
    """
    Compute row-wise softmax values for each set of scores in x.

    Parameters:
        x (numpy.ndarray): array containing m samples with n-dimensions (m,n)
    Returns:
        x_softmax (numpy.ndarray) softmaxed values for initial (m,n) array
    """
    # Subtract each row's own maximum for numerical stability. Subtracting a single
    # global maximum lets rows whose scores are far below it underflow to all zeros,
    # and the normalisation below then yields 0/0 = NaN.
    e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
    return e_x / e_x.sum(axis=1, keepdims=True)

# def get_pred_conf(y_probs, normalize = False):
#
#     y_preds = np.argmax(y_probs, axis=1)  # Take maximum confidence as prediction
#
#     if normalize:
#         y_confs = np.max(y_probs, axis=1)/np.sum(y_probs, axis=1)
#     else:
#         y_confs = np.max(y_probs, axis=1)  # Take only maximum confidence
#
#     return y_preds, y_confs

def confidence_ellipse(x, y, ax, n_std=3.0, facecolor='none', **kwargs):
    """
    Draw the covariance confidence ellipse of the samples (x, y) onto ``ax``.

    An ellipse with semi-axes sqrt(1 + rho) and sqrt(1 - rho) (rho = Pearson
    correlation of x and y) is built around the origin. It is rotated by 45 degrees,
    scaled by ``n_std`` standard deviations along each axis, and moved to the sample
    mean.

    Params:
        x, y: numpy arrays of equal size.
        ax: matplotlib axes the patch is added to.
        n_std: number of standard deviations spanned by the ellipse.
        facecolor: fill colour of the patch.
        **kwargs: forwarded to ``matplotlib.patches.Ellipse``.

    Raises:
        ValueError: if x and y differ in size.
    """
    if x.size != y.size:
        raise ValueError("x and y must be the same size")

    covariance = np.cov(x, y)
    var_x = covariance[0, 0]
    var_y = covariance[1, 1]
    # Tiny epsilon keeps the correlation finite when one input is constant.
    rho = covariance[0, 1] / (np.sqrt(var_x * var_y) + 1e-10)

    # Eigenvalues of a normalised 2x2 covariance matrix are 1 + rho and 1 - rho.
    half_width = np.sqrt(1 + rho)
    half_height = np.sqrt(1 - rho)
    patch = Ellipse((0, 0), width=half_width * 2, height=half_height * 2,
                    facecolor=facecolor, **kwargs)

    stretch_x = np.sqrt(var_x) * n_std
    stretch_y = np.sqrt(var_y) * n_std
    centre_x = np.mean(x)
    centre_y = np.mean(y)

    placement = (
        transforms.Affine2D()
        .rotate_deg(45)
        .scale(stretch_x, stretch_y)
        .translate(centre_x, centre_y)
    )
    patch.set_transform(placement + ax.transData)
    ax.add_patch(patch)


def confidence_ball(y_true, y_prob, n, name, ax=None, n_bins=3, seed=2021):
    """
    Bootstrap per-bin calibration statistics and optionally draw them.

    Samples are split into ``n_bins`` equal-frequency (quantile) bins of ``y_prob``.
    For every non-empty bin, its members are resampled with replacement ``n`` times.
    Each resample records the mean of ``y_true`` (empirical) and of ``y_prob``
    (expected).

    Params:
        y_true: array of targets, indexable by integer arrays.
        y_prob: array of predicted probabilities, same length as ``y_true``.
        n: number of bootstrap resamples per bin.
        name: axes title (used only when ``ax`` is given).
        ax: optional matplotlib axes. When truthy, draws one 95% (1.96 std) ellipse
            per bin, the curve through the bin means, and the y = x diagonal.
        n_bins: number of quantile bins.
        seed: seed of the private random generator. The default reproduces the
            original bootstrap draws.

    Returns:
        [empirical, expected]: two lists holding one length-``n`` array per
        non-empty bin, in bin order.
    """
    # A private legacy RandomState yields the same stream np.random.seed(seed) would,
    # without reseeding numpy's process-wide global generator as a hidden side effect.
    rng = np.random.RandomState(seed)
    empirical, expected = [], []

    quantiles = np.linspace(0, 1, n_bins + 1)
    bins = np.percentile(y_prob, quantiles * 100)
    # Nudge the top edge so the maximum probability falls inside the last bin.
    bins[-1] = bins[-1] + 1e-8
    binids = np.digitize(y_prob, bins) - 1
    bin_total = np.bincount(binids, minlength=len(bins))

    for bin_id in np.flatnonzero(bin_total):
        bin_idx = np.where(binids == bin_id)[0]
        bin_empirical_list = []
        bin_expected_list = []
        for _ in range(n):
            resample = rng.choice(len(bin_idx), size=len(bin_idx), replace=True)
            picked = bin_idx[resample]
            bin_empirical_list.append(y_true[picked].mean())
            bin_expected_list.append(y_prob[picked].mean())
        empirical.append(np.array(bin_empirical_list))
        expected.append(np.array(bin_expected_list))

    if ax:
        for emp_draws, exp_draws in zip(empirical, expected):
            confidence_ellipse(exp_draws, emp_draws, ax, facecolor='blue', n_std=1.96, alpha=0.5)
        ax.plot([draws.mean() for draws in expected],
                [draws.mean() for draws in empirical], '-s')
        ax.plot([0, 1], [0, 1], "k:")
        ax.set_title(name)
    return [empirical, expected]


def plot_calibration_curve(y, output_prob, ece, mce, ax=None, n_bins=10, confidence=True):
    """
    Plot a reliability diagram of positive-class probability vs. observed labels.

    Params:
        y: numpy array of observed labels.
        output_prob: array (samples, >=2) of probabilities; column 1 is plotted.
        ece, mce: values written into the axes title.
        ax: matplotlib axes to draw on. A new 7x7 figure is created when None.
        n_bins: number of equal-frequency bins.
        confidence: if True, draw bootstrap confidence ellipses via
            ``confidence_ball`` (1000 resamples). Otherwise draw one red point per
            bin plus a dashed y = x line.
    """
    idx_sorted = np.argsort(output_prob[:, 1])
    sorted_prob = output_prob[idx_sorted, 1]
    sorted_labels = y[idx_sorted]
    # NOTE: global matplotlib setting, kept so existing figures render unchanged.
    plt.rcParams.update({'font.size': 28})
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(7, 7))
    if confidence:
        confidence_ball(sorted_labels, sorted_prob, 1000, '', ax, n_bins)
    else:
        all_probs = []
        # array_split spreads the n % n_bins leftover samples over the first bins
        # rather than silently dropping them from the diagram.
        for prob_chunk, label_chunk in zip(np.array_split(sorted_prob, n_bins),
                                           np.array_split(sorted_labels, n_bins)):
            if len(prob_chunk) == 0:  # happens when there are fewer samples than bins
                continue
            avg_pred_prob = np.mean(prob_chunk)
            avg_true_prob = np.mean(label_chunk)
            all_probs.extend((avg_pred_prob, avg_true_prob))
            ax.scatter(avg_pred_prob, avg_true_prob, color='r')
        if all_probs:
            lower = np.min(all_probs) * 0.95
            upper = np.max(all_probs) * 1.05
            ax.set_xlim([lower, upper])
            ax.set_ylim([lower, upper])
            ax.plot(np.linspace(lower, upper), np.linspace(lower, upper), '--k')
    ax.set_xlabel(r"Predicted Probability")
    ax.set_ylabel(r"Empirical Probability")
    ax.axis('equal')
    ax.set_title('ECE = {:.4f}, MCE = {:.4f}'.format(ece, mce))


def plot_histogram(prob, label, gt, ax=None):
    """
    Overlay 10-bin, [0, 1]-range histograms of three series on one axes.

    The series are the positive-class probability ``prob[:, 1]``, ``label`` and
    ``gt``, drawn in that order.

    Params:
        prob: 2-D array of probabilities; column 1 is histogrammed.
        label: observed labels.
        gt: ground-truth values.
        ax: matplotlib axes. A new 7x7 figure is created when falsy.
    """
    if not ax:
        _, ax = plt.subplots(1, 1, figsize=(7, 7))
    shared_style = dict(range=(0, 1), histtype="stepfilled", bins=10, alpha=0.6)
    for series in (prob[:, 1], label, gt):
        ax.hist(series, **shared_style)

def plot_scatter(prob, label, gt, ax=None):
    """
    Scatter the estimated positive-class probability against ground truth.

    A red y = x reference line is added, and the axes are labelled.

    Params:
        prob: 2-D array of probabilities; column 1 is plotted on the y axis.
        label: accepted for signature parity with the other plot helpers; unused.
        gt: ground-truth values for the x axis.
        ax: matplotlib axes. A new 7x7 figure is created when falsy.
    """
    if not ax:
        _, ax = plt.subplots(1, 1, figsize=(7, 7))
    estimated = prob[:, 1]
    ax.plot(gt, estimated, '.', alpha=0.6)
    reference_x = np.linspace(0, 1, 100)
    reference_y = np.linspace(0, 1, 100)
    ax.plot(reference_x, reference_y, 'r')
    ax.set_xlabel('Ground Truth')
    ax.set_ylabel('Estimated probability')


def plot_empirical_distribution(y_probs, y_label, y_gt, ax=None, showplots=False):
    """
    Compute the Kolmogorov-Smirnov calibration error and optionally plot it.

    Samples are sorted by the positive-class score ``y_probs[:, 1]``. The cumulative
    sums of scores and of labels are then taken, each divided by the sample count.
    The KS error is the maximum absolute gap between these two cumulative curves.

    Params:
        y_probs: 2-D array of probabilities; column 1 is used as the score.
        y_label: observed labels (numpy array, reordered by fancy indexing).
        y_gt: ground-truth values (reordered like y_label).
        ax: matplotlib axes forwarded to ``spline.compute_accuracy``. A new 7x7
            figure is created when falsy.
        showplots: forwarded to ``spline.compute_accuracy``. When True, a separate
            1x4 figure with the cumulative and fitted curves is also drawn.

    Returns:
        KS_error_max (float): max |cumulative score - cumulative label| / n.

    NOTE(review): a new figure is opened whenever ``ax`` is falsy, even with
    showplots=False (as ``evaluate`` calls it). Repeated evaluation therefore
    accumulates empty figures. Whether ``spline.compute_accuracy`` needs a real
    axes when showplots is False cannot be seen here; confirm before changing.
    """
    if not ax:
        fig, ax = plt.subplots(1,1,figsize=(7,7))
    prob = y_probs[:, 1]
    order = prob.argsort()
    prob = prob[order]
    label = y_label[order]
    gt = y_gt[order]

    # Accumulate and normalize by dividing by num samples
    nsamples = len(prob)
    integrated_scores = np.cumsum(prob) / nsamples
    integrated_accuracy = np.cumsum(label) / nsamples
    integrated_gts = np.cumsum(gt) / nsamples  # NOTE(review): computed but never used below
    percentile = np.linspace(0.0, 1.0, nsamples)
    spline_method = 'natural'
    splines = 6
    # Project helper; returns (fitted_accuracy, fitted_error). fitted_accuracy is plotted
    # against `percentile`/`prob` below, so presumably it has length nsamples -- confirm in spline.
    fitted_accuracy, fitted_error = spline.compute_accuracy(prob, label, spline_method, splines, showplots=showplots, ax=ax)

    # Work out the Kolmogorov-Smirnov error
    KS_error_max = np.amax(np.absolute(integrated_scores - integrated_accuracy))
    if showplots:
        # Set up the graphs
        # NOTE: this rebinds `ax` to a fresh 1x4 array of axes; the caller's `ax` was
        # only used by spline.compute_accuracy above.
        f, ax = plt.subplots(1, 4, figsize=(20, 5))
        size = 0.2
        f.suptitle(f"\nKS-error = {spline.str(float(KS_error_max) * 100.0)}%, "
                           f"Probability={spline.str(float(integrated_accuracy[-1]) * 100.0)}%"
                   , fontsize=18, fontweight="bold")

        # First graph, (accumualated) integrated_scores and integrated_accuracy vs sample number
        ax[0].plot(100.0 * percentile, integrated_scores, linewidth=3, label='Cumulative Score')
        ax[0].plot(100.0 * percentile, integrated_accuracy, linewidth=3, label='Cumulative Probability')
        ax[0].set_xlabel("Percentile", fontsize=16, fontweight="bold")
        ax[0].set_ylabel("Cumulative Score / Probability", fontsize=16, fontweight="bold")
        ax[0].legend(fontsize=13)
        ax[0].set_title('(a)', y=-size, fontweight="bold", fontsize=16)  # increase or decrease y as needed
        ax[0].grid()

        # Second graph, (accumualated) integrated_scores and integrated_accuracy versus
        # integrated_scores
        ax[1].plot(integrated_scores, integrated_scores, linewidth=3, label='Cumulative Score')
        ax[1].plot(integrated_scores, integrated_accuracy, linewidth=3,
                   label="Cumulative Probability")
        ax[1].set_xlabel("Cumulative Score", fontsize=16, fontweight="bold")
        # ax[1].set_ylabel("Cumulative Score / Probability", fontsize=12)
        ax[1].legend(fontsize=13)
        ax[1].set_title('(b)', y=-size, fontweight="bold", fontsize=16)  # increase or decrease y as needed
        ax[1].grid()

        # Third graph, scores and accuracy vs percentile
        ax[2].plot(100.0 * percentile, prob, linewidth=3, label='Score')
        ax[2].plot(100.0 * percentile, fitted_accuracy, linewidth=3, label=f"Probability")
        ax[2].set_xlabel("Percentile", fontsize=16, fontweight="bold")
        ax[2].set_ylabel("Score / Probability", fontsize=16, fontweight="bold")
        ax[2].legend(fontsize=13)
        ax[2].set_title('(c)', y=-size, fontweight="bold", fontsize=16)  # increase or decrease y as needed
        ax[2].grid()

        # Fourth graph, score and spline-fitted probability versus score
        # integrated_scores
        ax[3].plot(prob, prob, linewidth=3, label=f"Score")
        ax[3].plot(prob, fitted_accuracy, linewidth=3, label='Probability')
        ax[3].set_xlabel("Score", fontsize=16, fontweight="bold")
        # ax[3].set_ylabel("Score / Probability", fontsize=12)
        ax[3].legend(fontsize=13)
        ax[3].set_title('(d)', y=-size, fontweight="bold", fontsize=16)  # increase or decrease y as needed
        ax[3].grid()
    return KS_error_max


class HistogramBinning():
    """
    Histogram Binning as a calibration method. The [0, 1] range is divided into
    ``n_bins`` equal-length bins, and each bin is mapped to the fraction of
    positive labels observed in it.

    The class contains two methods:
        - fit(probs, true), that should be used with validation data to train the calibration model.
        - predict(probs), this method is used to calibrate the confidences.
    """

    def __init__(self, n_bins=15):
        """
        Params:
            n_bins (int): the number of equal-length bins used.
        """
        self.bin_size = 1. / n_bins  # Width of each bin
        self.conf = []  # Per-bin calibrated confidence; filled in by fit()
        # Upper edges bin_size, 2*bin_size, ..., 1.0. linspace guarantees exactly n_bins
        # edges ending at 1.0, whereas np.arange with a float step can emit an extra edge.
        self.upper_bounds = np.linspace(self.bin_size, 1., n_bins)

    def _get_conf(self, conf_thresh_lower, conf_thresh_upper, probs, true):
        """
        Inner method to calculate the calibrated confidence for one probability range.

        Params:
            - conf_thresh_lower (float): start of the interval (not included)
            - conf_thresh_upper (float): end of the interval (included)
            - probs : array-like of probabilities.
            - true : array-like of true labels (1 is the positive class, 0 the negative).

        Returns:
            Mean label of the samples whose probability lies in the interval (i.e. the
            bin's empirical positive rate), or 0 when the interval is empty.
        """
        probs = np.asarray(probs)
        true = np.asarray(true)
        in_bin = (probs > conf_thresh_lower) & (probs <= conf_thresh_upper)
        if not in_bin.any():
            return 0
        return true[in_bin].mean()

    def fit(self, probs, true):
        """
        Fit the calibration model, finding the calibrated confidence of every bin.

        Params:
            probs: probabilities of data
            true: true labels of data
        """
        # Bin i covers (upper_bounds[i-1], upper_bounds[i]]. The outermost edges are opened
        # to -inf/+inf so fit() assigns every value to the same bin predict() does: a
        # probability of exactly 0 lands in bin 0 instead of being dropped.
        lowers = np.concatenate(([-np.inf], self.upper_bounds[:-1]))
        uppers = np.concatenate((self.upper_bounds[:-1], [np.inf]))
        self.conf = np.array(
            [self._get_conf(lo, hi, probs=probs, true=true) for lo, hi in zip(lowers, uppers)],
            dtype=float,
        )

    def predict(self, probs):
        """
        Calibrate the confidences.

        Param:
            probs: 1-D array-like of probabilities for one class.

        Returns:
            A new array of calibrated probabilities (same length as ``probs``). The input
            is left unmodified.
        """
        probs = np.asarray(probs, dtype=float)
        # searchsorted(side='left') picks the bin whose interval (lower, upper] contains p.
        idx = np.searchsorted(self.upper_bounds, probs)
        # Values above the last edge (e.g. 1 + rounding error) go to the last bin.
        idx = np.minimum(idx, len(self.upper_bounds) - 1)
        return np.asarray(self.conf)[idx]



class TemperatureScaling():
    """
    Temperature scaling: divide logits by a single scalar temperature, then apply softmax.

    The temperature is fitted on validation data by minimising log-loss.
    """

    def __init__(self, temp = 1, maxiter = 50, solver = "BFGS"):
        """
        Initialize class

        Params:
            temp (float): starting temperature for fit(), and the temperature used by
                predict_proba() before fitting; default 1
            maxiter (int): maximum iterations done by optimizer, however 8 iterations have been maximum.
            solver (str): method name passed to scipy.optimize.minimize; default "BFGS"
        """
        self.temp = temp
        self.maxiter = maxiter
        self.solver = solver

    def _loss_fun(self, x, probs, true):
        """Log-loss (cross-entropy) of logits ``probs`` scaled by candidate temperature ``x``."""
        scaled_probs = self.predict_proba(probs, x)
        loss = log_loss(y_true=true, y_pred=scaled_probs)
        return loss

    # Find the temperature
    def fit(self, logtis, true):
        """
        Trains the model and finds the optimal temperature.

        Params:
            logtis: logits, the output from the neural network for each class (shape
                [samples, classes]). The misspelled name is kept for backward compatibility.
            true: integer class labels (shape [samples]); one-hot encoded internally.

        Returns:
            the results of optimizer after minimizing is finished.
        """
        logtis = np.asarray(logtis)
        # One-hot with as many columns as there are logit columns (not hard-coded to 2).
        true = np.eye(logtis.shape[1])[np.asarray(true).astype(int).ravel()]
        opt = minimize(self._loss_fun, x0=self.temp, args=(logtis, true),
                       options={'maxiter': self.maxiter}, method=self.solver)
        self.temp = opt.x[0]

        return opt

    def predict_proba(self, logits, temp = None):
        """
        Scales logits based on the temperature and returns calibrated probabilities.

        Params:
            logits: logits values of data (output from neural network) for each class (shape [samples, classes])
            temp: temperature to use; when None, the fitted (or initially set) temperature.

        Returns:
            calibrated probabilities (nd.array with shape [samples, classes])
        """
        # `is None` rather than truthiness so an explicitly passed temperature (including
        # the optimizer's candidate values) is always honoured.
        if temp is None:
            temp = self.temp
        return softmax(logits / temp)


class Our_method():
    """
    Recalibrate a binary classifier by fine-tuning it on smoothed probability targets.

    Each epoch of ``fit``:
      1. runs the network over the training set to get positive-class probabilities;
      2. sorts the samples by that probability and replaces every label by a local
         average of neighbouring labels ('bin': equal-frequency bins; 'kde':
         Gaussian-weighted window), storing the result in
         ``train_dataset.proposed_probs``;
      3. trains one epoch on those soft targets (epochs 0, calpertrain+1, ...) or on
         the hard labels otherwise;
      4. saves a checkpoint whenever validation log-loss improves.

    NOTE(review): datasets are assumed to yield
    ``(inputs, label, target_prob, index, gt_label)`` tuples, as unpacked below, with
    ``target_prob`` read from ``proposed_probs``; confirm against the dataset class.
    """

    def __init__(self, net, optimizer, train_dataset, val_dataset, finetune_type='bin', num_epoch=10, n_bins=10,
                 calpertrain=5, sigma=0.1, window=500, save_dir="best_checkpoint.pth"):
        """
        Initialize class

        Params:
            net : early stopped neural network
            optimizer : torch optimizer stepping ``net``'s parameters
            train_dataset : pytorch training dataset
            val_dataset : pytorch validation dataset
            finetune_type : 'bin' or 'kde'; how probability targets are smoothed
            num_epoch : epochs of finetuning the model
            n_bins : number of bins to calculate probability targets ('bin' mode)
            calpertrain : soft-target training happens once every calpertrain + 1 epochs
            sigma : Gaussian kernel width ('kde' mode)
            window : half-width, in sorted-sample positions, of the kde neighbourhood
            save_dir : file path the best checkpoint is written to
        """
        self.net = net
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.n_bins = n_bins
        self.optimizer = optimizer
        self.num_epoch = num_epoch
        self.calpertrain = calpertrain
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.save_dir = save_dir
        self.sigma = sigma
        self.window = window
        self.finetune_type = finetune_type

    def _inference(self, net, dataloader, return_gt=False):
        """
        Run ``net`` in eval mode over ``dataloader`` and collect per-sample results.

        Each result is stored at the sample's dataset index (4th tuple element), so the
        loader order does not matter.

        Returns:
            (targets_probs, labels, indices) or, with return_gt,
            (targets_probs, labels, indices, gt_labels). All are numpy arrays of length
            len(dataloader.dataset); targets_probs is the softmax probability of class 1.
        """
        num_samples = len(dataloader.dataset)
        targets_probs = np.zeros(num_samples)
        labels = np.zeros(num_samples)
        indices = np.zeros(num_samples)
        gt_labels = np.zeros(num_samples)
        net.eval()
        with torch.no_grad():
            for inputs, label, _, idx, gt_label in dataloader:
                idx = np.asarray(idx)
                outputs = net(inputs.to(self.device))
                out_prob = F.softmax(outputs, dim=1)
                targets_probs[idx] = out_prob[:, 1].cpu().numpy()
                labels[idx] = label
                indices[idx] = idx
                gt_labels[idx] = gt_label
        if return_gt:
            return targets_probs, labels, indices, gt_labels
        return targets_probs, labels, indices

    def get_new_prob(self, mean_value, prob_array, true_array, scale):
        """Gaussian-kernel (centre ``mean_value``, std ``scale``) weighted mean of ``true_array``."""
        weight = scipy.stats.norm.pdf(prob_array, loc=mean_value, scale=scale)
        return np.sum(weight * true_array) / np.sum(weight)

    def _sort_update(self, targets_probs, labels, indices, n_bins = 10):
        """
        Compute smoothed per-sample targets and store them on the training dataset.

        Samples are sorted by predicted probability. Each one gets either the mean label
        of its equal-frequency bin ('bin') or a Gaussian-weighted mean of the labels
        within ``self.window`` sorted positions ('kde'). The targets are written to
        ``self.train_dataset.proposed_probs`` at each sample's original index.

        Raises:
            NotImplementedError: for an unknown ``finetune_type``.
        """
        order = np.argsort(targets_probs)
        sorted_probs = targets_probs[order]
        sorted_labels = labels[order]
        sorted_indices = indices[order].astype(int)
        num_sample = len(sorted_labels)
        new_labels = np.zeros(num_sample)
        if self.finetune_type == 'bin':
            for i in range(n_bins):
                left = int(i * num_sample / n_bins)
                right = int((i + 1) * num_sample / n_bins)
                if right > left:  # skip empty bins (more bins than samples)
                    new_labels[left:right] = np.mean(sorted_labels[left:right])
        elif self.finetune_type == 'kde':
            for i in range(num_sample):
                left = np.maximum(0, i - self.window)
                right = np.minimum(i + self.window, num_sample)
                new_labels[i] = self.get_new_prob(sorted_probs[i], sorted_probs[left:right],
                                                  sorted_labels[left:right], scale=self.sigma)
        else:
            raise NotImplementedError

        # Scatter the sorted targets back to dataset order in one vectorised assignment.
        proposed_probs = np.zeros(num_sample)
        proposed_probs[sorted_indices] = new_labels
        self.train_dataset.proposed_probs = proposed_probs
        print('%'*40)
        print('\nUpdated the Probs Labels\n')

    def _calibrate(self, trainloader, use_prob=True):
        """
        Train ``self.net`` for one pass over ``trainloader``.

        With use_prob, the loss is cross-entropy against the soft targets
        [1 - p, p]; otherwise it is standard cross-entropy on the hard labels.

        Returns:
            mean batch loss (0.0 for an empty loader).
        """
        self.net.train()
        train_loss = 0
        num_batches = 0

        if use_prob:
            print('Using Binned Probs')
        else:
            print('Using GT')

        criterion = nn.CrossEntropyLoss()
        for inputs, labels, targets_probs, _, _ in trainloader:
            self.optimizer.zero_grad()
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            outputs = self.net(inputs)
            if use_prob:
                log_probs = outputs.log_softmax(dim=1)
                targets = torch.stack([1 - targets_probs, targets_probs]).T.to(self.device)
                loss = torch.mean(torch.sum(-targets * log_probs, dim=1))
            else:
                loss = criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            num_batches += 1
        return train_loss / max(num_batches, 1)

    @staticmethod
    def _sigmoid_rampup(current, rampup_length):
        """Exponential rampup from  2"""
        if rampup_length == 0:
            return 1.0
        else:
            current = np.clip(current, 0.0, rampup_length)
            phase = 1.0 - current / rampup_length
            return float(np.exp(-5.0 * phase * phase))

    def fit(self):
        """
        Fine-tune ``self.net`` for ``self.num_epoch`` epochs (see class docstring).

        Whenever validation log-loss improves, ``{'net', 'val_loss', 'epoch'}`` is saved
        to ``self.save_dir`` with ``torch.save``. ``self.net`` itself is left at its
        final-epoch weights.
        """
        # Predictions are stored by dataset index, so an unshuffled loader suffices for
        # inference; the shuffled one is used for training. The loaders are built once:
        # workers are re-created per iteration, so dataset updates are still seen.
        inference_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
        trainloader = torch.utils.data.DataLoader(self.train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
        valloader = torch.utils.data.DataLoader(self.val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
        min_val_loss = float('inf')

        for epoch in range(self.num_epoch):
            targets_probs, labels, indices = self._inference(self.net, inference_loader)
            self._sort_update(targets_probs, labels, indices, n_bins=self.n_bins)

            # Soft targets on epochs 0, calpertrain+1, 2*(calpertrain+1), ...; hard labels otherwise.
            use_prob = (epoch % (self.calpertrain + 1) == 0)
            loss = self._calibrate(trainloader, use_prob=use_prob)

            # evaluate
            val_targets_probs, val_labels, _ = self._inference(self.net, valloader)
            val_loss = log_loss(y_true=val_labels, y_pred=val_targets_probs)

            if val_loss < min_val_loss:
                torch.save({
                    'net': self.net.state_dict(),
                    'val_loss': val_loss,
                    'epoch': epoch, }, self.save_dir)
                min_val_loss = val_loss
            print('Epoch: ', epoch, ' Loss: %.3f' % (loss), 'Val Loss: %.3f' % (val_loss))

    def predict(self, test_dataset, file=None):
        """
        Return class probabilities for ``test_dataset``.

        Params:
            test_dataset: dataset yielding the same 5-tuples as the training data.
            file: optional checkpoint path (as written by ``fit``). When given, its
                'net' state dict is loaded into ``self.net`` first, mapped onto
                ``self.device`` so GPU-saved checkpoints also load on CPU-only hosts.

        Returns:
            (probs, labels, gt_labels): probs has shape (samples, 2) with columns
            [P(class 0), P(class 1)].
        """
        if file is not None:
            checkpoint = torch.load(file, map_location=self.device)
            self.net.load_state_dict(checkpoint['net'])
        testloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
        test_targets_probs, labels, _, gt_labels = self._inference(self.net, testloader, return_gt=True)

        test_probs = np.stack([1 - test_targets_probs, test_targets_probs], axis=1)
        return test_probs, labels, gt_labels




def evaluate(probs, y_true, y_true_gt, verbose = False, normalize = False, n_bins = 15):
    """
    Evaluate predictions with label-based and ground-truth-probability-based scores.

    Params:
        probs: array (samples, classes) of predicted probabilities. Columns 0 and 1
            are used as negative/positive class by ECE/MCE, Brier and the KL terms,
            so binary problems are assumed.
        y_true: a list containing the actual class labels
        y_true_gt: ground-truth probability of the positive class for each sample
        verbose: (bool) are the scores printed out. (default = False)
        normalize: (bool) in case of 1-vs-K calibration, the probabilities need to be
            normalized; divides probs[:, 1] by the row sum before ECE/MCE.
        n_bins: (int) - into how many bins are probabilities divided for ECE/MCE (default = 15)

    Returns:
        (error, ece, mce, loss, brier, gt_brier, gt_cel_loss, kl_gt_prob, kl_prob_gt, ks_error)
    """
    preds = np.argmax(probs, axis=1)  # Take maximum confidence as prediction
    p_neg = probs[:, 0]
    p_pos = probs[:, 1]

    if normalize:
        confs = p_pos / np.sum(probs, axis=1)
    else:
        confs = p_pos

    accuracy = metrics.accuracy_score(y_true, preds) * 100
    error = 100 - accuracy

    ece = ECE_balanced(confs, y_true, n_bins=n_bins)
    mce = MCE_balanced(confs, y_true, n_bins=n_bins)

    loss = log_loss(y_true=y_true, y_pred=probs)
    eps = 1e-10
    # Cross-entropy of the predictions against the ground-truth probabilities.
    gt_cel_loss = -np.mean(y_true_gt * np.log(p_pos) + (1 - y_true_gt) * np.log(p_neg))
    # KL(gt || prob) = sum gt * log(gt / prob); eps guards log(0) and division by zero.
    kl_gt_prob = -np.mean(y_true_gt * np.log((p_pos + eps) / (y_true_gt + eps))
                          + (1 - y_true_gt) * np.log((p_neg + eps) / (1 - y_true_gt + eps)))
    # KL(prob || gt) = sum prob * log(prob / gt). The eps must be inside the denominator,
    # matching the terms above; otherwise p_neg == 0 divides by zero.
    kl_prob_gt = -np.mean(p_pos * np.log((y_true_gt + eps) / (p_pos + eps))
                          + p_neg * np.log((1 - y_true_gt + eps) / (p_neg + eps)))

    brier = brier_score_loss(y_true=y_true, y_prob=p_pos)  # Brier Score (MSE)
    gt_brier = np.mean((y_true_gt - p_pos) ** 2)  # Brier Score for gt probs (MSE)
    ks_error = plot_empirical_distribution(probs, y_true, y_true_gt, ax=None, showplots=False)

    if verbose:
        print("Accuracy:", accuracy)
        print("Error:", error)
        print("ECE:", ece)
        print("MCE:", mce)
        print("Loss:", loss)
        print("brier:", brier)
        print("brier with gt:", gt_brier)
        print("Loss with gt:", gt_cel_loss)
        print("KL(gt || prob)", kl_gt_prob)
        print("KL(prob || gt )", kl_prob_gt)
        print("KS error", ks_error)

    return error, ece, mce, loss, brier, gt_brier, gt_cel_loss, kl_gt_prob, kl_prob_gt, ks_error


def _plot_calibration_panels(y_test, probs_test, y_test_gt, ece, mce, n_bins_4_plot):
    """Draw the four diagnostic panels (reliability, histogram, scatter, KS) for one set of test probabilities."""
    fig, ax = plt.subplots(1, 4, figsize=(40, 10))
    plot_calibration_curve(y_test, probs_test, ece, mce, n_bins=n_bins_4_plot, ax=ax[0])
    plot_histogram(probs_test, y_test, y_test_gt, ax=ax[1])
    plot_scatter(probs_test, y_test, y_test_gt, ax=ax[2])
    plot_empirical_distribution(probs_test, y_test, y_test_gt, ax=ax[3], showplots=True)
    return fig


def calibrate_model(method, name='ours', m_kwargs=None, net=None, train_dataset=None, val_dataset=None, test_dataset=None,
                    approach='single', n_bins_4_plot=10, finetune=True, finetuned_model_path=None, plot_figure=True):
    """
    Evaluate a network before and after applying a calibration method.

    Params:
        method (class): calibration class, instantiated as ``method(**m_kwargs)``. It must provide
                    "fit" and a prediction method ("predict" for histogram binning and Our_method,
                    "predict_proba" otherwise).
        name (str): row label. If it contains 'ours', the Our_method path is used; if it contains
                    'histogram' (with approach='single'), the model is fitted and applied on probabilities.
        m_kwargs (dict | None): keyword arguments for ``method``; None means no arguments.
        net: network used to compute logits (non-'ours' path).
        train_dataset: not used by this function.
        val_dataset, test_dataset: datasets yielding (inputs, label, _, index, gt_label) tuples.
        approach (str): 'single' calibrates each class one-vs-rest; anything else fits
                    ``method`` once on the full logit matrix.
        n_bins_4_plot (int): bins used by the reliability diagram.
        finetune (bool): 'ours' path only; True runs ``model.fit()``, False loads
                    ``finetuned_model_path`` instead.
        finetuned_model_path (str | None): checkpoint used when finetune is False.
        plot_figure (bool): draw the diagnostic panels.

    Returns:
        df (pandas.DataFrame): row 0 holds the uncalibrated scores, row 1 the calibrated scores.
    """
    if m_kwargs is None:
        m_kwargs = {}
    df = pd.DataFrame(columns=["Name", "Error", "ECE", "MCE", "Loss", "Brier", "Brier_w_gt", "Loss_w_gt",
                               "KL_gt_prob", "KL_prob_gt", "ks_error"])
    if 'ours' not in name:
        valloader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
        testloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

        val_logits, y_val, y_val_gt = inference(net, valloader)
        test_logits, y_test, y_test_gt = inference(net, testloader)

        probs_val = softmax(val_logits)
        probs_test = softmax(test_logits)

        before = evaluate(probs_test, y_test, y_test_gt, verbose=True)  # Test before scaling
        if plot_figure:
            _plot_calibration_panels(y_test, probs_test, y_test_gt, before[1], before[2], n_bins_4_plot)

        if approach == 'single':
            # Calibrate copies so arrays already handed to the uncalibrated plots are not
            # modified in place.
            probs_val = probs_val.copy()
            probs_test = probs_test.copy()
            for k in range(probs_test.shape[1]):
                y_cal = np.array(y_val == k, dtype="int")

                model = method(**m_kwargs)
                if "histogram" in name:
                    # Histogram binning bins values in [0, 1]; fit it on the probabilities it is
                    # applied to below, not on the logits.
                    model.fit(probs_val[:, k], y_cal)
                    probs_val[:, k] = model.predict(probs_val[:, k])
                    probs_test[:, k] = model.predict(probs_test[:, k])
                else:
                    model.fit(val_logits[:, k], y_cal)  # Only the logit column of class "k"
                    probs_val[:, k] = model.predict_proba(val_logits[:, k])
                    probs_test[:, k] = model.predict_proba(test_logits[:, k])

                # Replace NaN with 0, as it should be close to zero  # TODO is it needed?
                probs_test[np.isnan(probs_test)] = 0
                probs_val[np.isnan(probs_val)] = 0

            # One-vs-rest outputs need not sum to 1; renormalise each row.
            probs_val = probs_val / probs_val.sum(axis=1, keepdims=True)
            probs_test = probs_test / probs_test.sum(axis=1, keepdims=True)
        else:
            model = method(**m_kwargs)
            model.fit(val_logits, y_val)
            probs_test = model.predict_proba(test_logits)

        after = evaluate(probs_test, y_test, y_test_gt, verbose=False)
        if plot_figure:
            # Title the calibrated plot with the calibrated ECE/MCE.
            _plot_calibration_panels(y_test, probs_test, y_test_gt, after[1], after[2], n_bins_4_plot)

        df.loc[0] = [name, *before]
        df.loc[1] = [(name + "_calibrated"), *after]

    else:
        model = method(**m_kwargs)
        probs_test, y_test, y_test_gt = model.predict(test_dataset, file=None)

        before = evaluate(probs_test, y_test, y_test_gt, verbose=True)  # Test before recalibration
        df.loc[0] = ['Ours', *before]
        if finetune:
            model.fit()
            # NOTE(review): this predicts with the final-epoch weights, not the best checkpoint
            # that fit() saved; confirm this is intended.
            probs_test, y_test, y_test_gt = model.predict(test_dataset)
        else:
            probs_test, y_test, y_test_gt = model.predict(test_dataset, file=finetuned_model_path)
        after = evaluate(probs_test, y_test, y_test_gt, verbose=False)
        if plot_figure:
            _plot_calibration_panels(y_test, probs_test, y_test_gt, after[1], after[2], n_bins_4_plot)
        df.loc[1] = ['Ours' + "_calibrated", *after]

    return df



def inference(net, dataloader):
    """
    Collect raw network outputs (logits), labels and ground-truth values for a dataset.

    Results are written at each sample's dataset index (4th element of the tuples the
    loader yields), so the loader order does not matter.

    Params:
        net: torch module, assumed to already live on the selected device.
        dataloader: torch DataLoader yielding (inputs, label, _, index, gt_label).

    Returns:
        (targets_logits, labels, gt_labels): targets_logits has shape
        (len(dataset), num_outputs); an empty loader yields shape (0, 2).
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    num_samples = len(dataloader.dataset)
    targets_logits = None
    labels = np.zeros(num_samples)
    gt_labels = np.zeros(num_samples)
    net.eval()
    with torch.no_grad():
        for inputs, label, _, idx, gt_label in dataloader:
            idx = np.asarray(idx)
            outputs = net(inputs.to(device))
            if targets_logits is None:
                # Size the buffer from the network output so any number of classes works.
                targets_logits = np.zeros((num_samples, outputs.shape[1]))
            targets_logits[idx] = outputs.cpu().numpy()
            labels[idx] = np.asarray(label)
            gt_labels[idx] = np.asarray(gt_label)
    if targets_logits is None:
        # No batches at all: keep the historical binary-shaped result.
        targets_logits = np.zeros((num_samples, 2))
    return targets_logits, labels, gt_labels
