import numpy as np
from scipy.optimize import minimize
from scipy.special import softmax
from sklearn.metrics import log_loss
from sklearn.preprocessing import label_binarize
from errors import BS, NLL, accuracy

# https://github.com/MLO-lab/better_uncertainty_calibration/blob/main/temperature_scaling.py
class TemperatureScaling():
    """Fit a single scalar temperature and use it to rescale logits.

    The temperature is found by minimizing a loss (selected by ``loss``) of
    the rescaled logits with ``scipy.optimize.minimize``. Rescaling is a plain
    element-wise division ``logits / temp``; no softmax is applied here.
    """

    def __init__(self, temp=1, maxiter=50, solver="BFGS", loss='NLL'):
        """
        Initialize class

        Params:
            temp (float): starting temperature, default 1. Used as the initial
                guess by ``fit`` and as the divisor by ``predict`` until
                ``fit`` overwrites it.
            maxiter (int): maximum iterations done by optimizer, however 8 iterations have been maximum.
            solver (str): optimization method name forwarded to
                ``scipy.optimize.minimize`` as ``method``.
            loss (str): which loss to minimize, either 'BS' or 'NLL'
                (dispatching to ``errors.BS`` / ``errors.NLL``).
        """
        self.temp = temp
        self.maxiter = maxiter
        self.solver = solver
        self.loss = loss

    def _loss_fun(self, x, logits, true):
        """
        Objective handed to the optimizer: loss of the logits scaled by ``x``.

        Params:
            x: candidate temperature proposed by the optimizer.
            logits: logits to rescale.
            true: targets, passed through unchanged to the loss function.

        Returns:
            whatever ``errors.BS`` or ``errors.NLL`` returns for the scaled logits.

        Raises:
            ValueError: if ``self.loss`` is not 'BS' or 'NLL'.
        """
        scaled_l = self.predict(logits, x)
        if self.loss == 'BS':
            return BS(scaled_l, true)
        if self.loss == 'NLL':
            return NLL(scaled_l, true)
        # Fail with a clear message instead of an UnboundLocalError from an
        # unassigned local when the loss name is misspelled.
        raise ValueError(
            "Unknown loss {!r}; expected 'BS' or 'NLL'".format(self.loss)
        )

    # Find the temperature
    def fit(self, logits, true, verbose=False):
        """
        Trains the model and finds optimal temperature

        Params:
            logits: the output from neural network for each class (shape [samples, classes])
            true: true labels; flattened to 1-D before being passed to the loss.
                NOTE(review): previously documented as a one-hot encoding --
                confirm what errors.BS / errors.NLL expect after flattening.
            verbose (bool): if True, print a summary line after fitting.

        Returns:
            the results of optimizer after minimizing is finished.
        """

        true = true.flatten() # Flatten y_val
        opt = minimize(
            self._loss_fun,
            # Start from the configured temperature (1 by default) rather
            # than a hard-coded constant, as the constructor documents.
            x0=self.temp,
            args=(logits, true),
            options={'maxiter': self.maxiter},
            method=self.solver,
        )
        self.temp = opt.x[0]

        if verbose:
            # NOTE(review): this prints the reciprocal of the fitted value,
            # while predict() divides by self.temp -- confirm which quantity
            # is meant to be reported as "Temperature".
            print("Temperature:", 1/self.temp)

        return opt

    def predict(self, logits, temp = None):
        """
        Scales logits by the temperature and returns the scaled logits

        Params:
            logits: logits values of data (output from neural network) for each class (shape [samples, classes])
            temp: temperature to divide by; if None, use the temperature found
                by ``fit`` or previously set on the instance.

        Returns:
            ``logits / temp`` (same shape as ``logits`` under the usual
            broadcasting rules). No softmax is applied.
        """

        # Compare against None explicitly: truth-testing would raise for
        # multi-element arrays and would silently discard a falsy temperature.
        if temp is None:
            temp = self.temp
        return logits/temp