# another implementation of temp scaling that only requires the logits
import time
from os.path import join

import numpy as np
import pandas as pd
import sklearn.metrics as metrics
from scipy.optimize import minimize
from scipy.special import xlogy
from sklearn.metrics import log_loss, mean_squared_error

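# Superseded version, kept for reference. Bug: it subtracts the global max of x instead of each row's max, so a row whose values are all far below the global max can underflow to 0/0 (nan).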
# def softmax(x):
#     """
#     Compute softmax values for each sets of scores in x.
    
#     Parameters:
#         x (numpy.ndarray): array containing m samples with n-dimensions (m,n)
#     Returns:
#         x_softmax (numpy.ndarray) softmaxed values for initial (m,n) array
#     """
#     e_x = np.exp(x - np.max(x))  # Subtract max so biggest is 0 to avoid numerical instability
    
#     # Axis 0 if only one dimensional array
#     axis = 0 if len(e_x.shape) == 1 else 1
    
#     return e_x / e_x.sum(axis=axis, keepdims=1)
    

def softmax(x):
    """
    Compute softmax values along the last axis of x.

    Parameters:
        x (numpy.ndarray or array-like): scores. Either a 1-D vector of n scores
            or an (m, n) array of m samples with n scores each; arrays with extra
            leading dimensions are also accepted.
    Returns:
        numpy.ndarray: array with the same shape as x. Entries along the last
        axis are non-negative and sum to 1.
    """
    x = np.asarray(x)
    # Subtract each row's own max (not the global max), so the largest entry of
    # every row becomes exp(0) = 1. This prevents overflow and keeps rows on very
    # different scales from underflowing to 0/0. keepdims keeps the shapes
    # broadcastable for 1-D and N-D input alike.
    e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e_x / e_x.sum(axis=-1, keepdims=True)

class TemperatureScaling():
    """
    Temperature scaling: divide all logits by one scalar temperature T and apply
    softmax. T is fitted by minimizing log-loss with scipy.optimize.minimize.
    """

    def __init__(self, temp = 1, maxiter = 50, solver = "BFGS"):
        """
        Initialize class

        Params:
            temp (float): temperature used by predict() until fit() overwrites it, default 1.
                Note that fit() always starts the optimizer from 1, not from this value.
            maxiter (int): maximum iterations done by the optimizer (the original author
                observed at most about 8 iterations being needed).
            solver (str): method name passed to scipy.optimize.minimize, default "BFGS".
        """
        self.temp = temp
        self.maxiter = maxiter
        self.solver = solver

    def _loss_fun(self, x, probs, true):
        # Objective for the optimizer: log-loss (cross-entropy) of the calibrated
        # probabilities at candidate temperature x. Despite its name, `probs`
        # holds the raw logits that fit() passes through.
        scaled_probs = self.predict(probs, x)
        loss = log_loss(y_true=true, y_pred=scaled_probs)
        return loss

    # Find the temperature
    def fit(self, logits, true):
        """
        Trains the model and finds optimal temperature

        Params:
            logits: the output from neural network for each class (shape [samples, classes])
            true: true labels, in a form accepted by sklearn.metrics.log_loss.

        Returns:
            the results of optimizer after minimizing is finished. The fitted
            temperature is also stored in self.temp.
        """
        opt = minimize(self._loss_fun, x0 = 1, args=(logits, true), options={'maxiter':self.maxiter}, method = self.solver)
        self.temp = opt.x[0]

        return opt

    def predict(self, logits, temp = None):
        """
        Scales logits based on the temperature and returns calibrated probabilities

        Params:
            logits: logits values of data (output from neural network) for each class (shape [samples, classes])
            temp: temperature to use. When None, the temperature found by fit() (or set
                in __init__) is used. A scalar or anything broadcastable against logits
                is accepted, e.g. an array of shape [samples, 1].

        Returns:
            calibrated probabilities (nd.array with shape [samples, classes])
        """
        # Test `is None` rather than truthiness. `not temp` silently replaced an
        # explicit temperature of 0 with self.temp, and it raises for array
        # temperatures with more than one element.
        if temp is None:
            temp = self.temp
        return softmax(logits / temp)




class TemperatureScaling_no_softmax():
    """
    Temperature scaling whose predict() returns the temperature-scaled logits
    (logits / T) instead of probabilities.

    The temperature is still fitted by minimizing log-loss on the probabilities
    softmax(logits / T).
    """

    def __init__(self, temp = 1, maxiter = 50, solver = "BFGS"):
        """
        Initialize class

        Params:
            temp (float): temperature used by predict() until fit() overwrites it, default 1.
                Note that fit() always starts the optimizer from 1, not from this value.
            maxiter (int): maximum iterations done by the optimizer (the original author
                observed at most about 8 iterations being needed).
            solver (str): method name passed to scipy.optimize.minimize, default "BFGS".
        """
        self.temp = temp
        self.maxiter = maxiter
        self.solver = solver

    def _loss_fun(self, x, probs, true):
        # Objective for the optimizer: log-loss at candidate temperature x.
        # predict() returns scaled *logits* here, so they must go through softmax
        # before scoring. log_loss expects probabilities, and raw logits (which can
        # be negative and do not sum to one) do not give a meaningful loss.
        # Despite its name, `probs` holds the raw logits that fit() passes through.
        scaled_logits = self.predict(probs, x)
        loss = log_loss(y_true=true, y_pred=softmax(scaled_logits))
        return loss

    # Find the temperature
    def fit(self, logits, true):
        """
        Trains the model and finds optimal temperature

        Params:
            logits: the output from neural network for each class (shape [samples, classes])
            true: true labels; flattened to 1-D before fitting.

        Returns:
            the results of optimizer after minimizing is finished. The fitted
            temperature is also stored in self.temp.
        """
        true = np.asarray(true).flatten()  # Flatten y_val
        opt = minimize(self._loss_fun, x0 = 1, args=(logits, true), options={'maxiter':self.maxiter}, method = self.solver)
        self.temp = opt.x[0]

        return opt

    def predict(self, logits, temp = None):
        """
        Scales logits by the temperature and returns the scaled logits (no softmax).

        Params:
            logits: logits values of data (output from neural network) for each class (shape [samples, classes])
            temp: temperature to use. When None, the temperature found by fit() (or set
                in __init__) is used.

        Returns:
            scaled logits, logits / temp (same shape as logits)
        """
        # `is None` rather than truthiness, so an explicit temperature of 0 or an
        # array of temperatures is honoured instead of being replaced by self.temp.
        if temp is None:
            temp = self.temp
        return logits / temp
        
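# Vector scaling on sorted logits: the per-rank weights and the per-rank bias terms are constrained to increase with rank (the 'ineq' constraints used below require the differences to be >= 0, i.e. non-decreasing).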


class TemperatureScalingWithBias():
    """
    Temperature scaling with an extra additive bias: softmax(logits / T + b).

    Note: when b is a scalar (as fitted here), the same constant is added to
    every logit. It therefore cancels inside the softmax, so b does not change
    the returned probabilities (up to floating-point rounding); only T does.
    """

    def __init__(self, temp=1, bias=0, maxiter=100, solver="BFGS"):
        """
        Set the starting parameters and the optimizer configuration.

        Params:
            temp (float): temperature used by predict() before fit(), default 1
            bias (float): bias added after dividing by the temperature, default 0
            maxiter (int): iteration cap for scipy.optimize.minimize
            solver (str): scipy.optimize.minimize method name, default "BFGS"
        """
        self.temp, self.bias = temp, bias
        self.maxiter, self.solver = maxiter, solver

    def _loss_fun(self, x, probs, true):
        # Optimizer objective: log-loss for the candidate pair x = [temperature, bias].
        # `probs` carries the raw logits handed over by fit().
        candidate_t, candidate_b = x[0], x[1]
        return log_loss(y_true=true, y_pred=self.predict(probs, temp=candidate_t, bias=candidate_b))

    # Find the temperature and bias
    def fit(self, logits, true):
        """
        Optimize the temperature and bias on held-out data.

        Params:
            logits: network outputs per class (shape [samples, classes])
            true: true labels in a form accepted by sklearn.metrics.log_loss

        Returns:
            the scipy optimization result; its x is also stored in self.temp and self.bias
        """
        result = minimize(
            self._loss_fun,
            x0=[1, 0],  # always start from temperature 1, bias 0
            args=(logits, true),
            options={'maxiter': self.maxiter},
            method=self.solver,
        )
        self.temp, self.bias = result.x[0], result.x[1]
        return result

    def predict(self, logits, temp=None, bias=None):
        """
        Divide the logits by the temperature, add the bias, and apply softmax.

        Params:
            logits: network outputs per class (shape [samples, classes])
            temp: temperature override; None means use self.temp
            bias: bias override; None means use self.bias

        Returns:
            calibrated probabilities (nd.array with shape [samples, classes])
        """
        t = self.temp if temp is None else temp
        b = self.bias if bias is None else bias
        return softmax(logits / t + b)
    

# Monotonic vector scaling with one weight and one bias per rank (class position after sorting).
# Although the parameter is named `temp`, it is a weight that multiplies the logits, not a temperature that divides them.
class VectorScaling_mono():
    """
    Monotonic vector scaling on rank-sorted logits.

    Each row of logits is sorted in ascending order and the value at rank i is
    mapped to w[i] * z + b[i]. The weights w (stored as ``temp``) and the
    biases b are constrained to be non-decreasing in rank. The scaled values
    are put back in the original class order and passed through softmax.
    """

    def __init__(self, temp=1, bias=0, maxiter=50, solver="SLSQP"):
        """
        Initialize class

        Params:
            temp (float or numpy.ndarray): per-rank multiplicative weights used by predict()
                before fit() (a scalar is broadcast to every rank), default 1. fit() replaces
                it with an array of length nclass.
            bias (float or numpy.ndarray): per-rank additive bias, default 0; replaced by fit().
            maxiter (int): maximum number of optimizer iterations.
            solver (str): method name passed to scipy.optimize.minimize, default "SLSQP";
                it should be a method that honours the `constraints` argument.
        """
        self.temp = temp
        self.bias = bias
        self.maxiter = maxiter
        self.solver = solver

    def _loss_fun(self, x, probs, true):
        # Optimizer objective: log-loss (cross-entropy) for candidate parameters
        # x = [weights (nclass), biases (nclass)]. `probs` holds the row-sorted logits.
        scaled_probs = self.predict_train(probs, temp=x[:self.nclass], bias=x[self.nclass:])
        loss = log_loss(y_true=true, y_pred=scaled_probs)
        return loss

    def ascending_constraints(self, params, nclass):
        # Extract the weights and biases from params.
        t = params[:nclass]
        b = params[nclass:]

        # Successive differences of both sequences. As a scipy 'ineq' constraint
        # every entry must be >= 0, i.e. t and b are non-decreasing.
        temp_constraints = np.diff(t)
        bias_constraints = np.diff(b)

        return np.concatenate((temp_constraints, bias_constraints))

    # Find the temperature and bias
    def fit(self, logits, true):
        """
        Trains the model and finds the optimal per-rank weights and biases

        Params:
            logits: the output from neural network for each class (shape [samples, classes])
            true: one-hot true labels (shape [samples, classes]); each row is permuted with
                the same sort as the corresponding logits row.

        Returns:
            the results of optimizer after minimizing is finished. The fitted parameters
            are also stored in self.temp and self.bias.
        """
        logits = np.asarray(logits)
        self.nclass = logits.shape[1]

        # Sort every row ascending and apply the same permutation to the labels.
        order = np.argsort(logits, axis=1)
        sorted_logits = np.take_along_axis(logits, order, axis=1)
        true_sorted = np.take_along_axis(np.asarray(true), order, axis=1)

        # Start from weights spread evenly over [0, 1] and zero biases, which
        # already satisfy the monotonicity constraints.
        temp_init = np.linspace(0, 1, num=self.nclass)
        bias_init = np.zeros(self.nclass)
        initial_guess = np.concatenate((temp_init, bias_init), axis=0)

        constraints = {'type': 'ineq', 'fun': self.ascending_constraints, 'args': (self.nclass,)}

        opt = minimize(self._loss_fun, x0=initial_guess, args=(sorted_logits, true_sorted),
                       options={'maxiter': self.maxiter}, method=self.solver, constraints=constraints)

        self.temp = opt.x[:self.nclass]  # per-rank weights
        self.bias = opt.x[self.nclass:]  # per-rank biases

        return opt

    def predict_train(self, logits, temp, bias):
        """
        Apply per-rank weights and biases to already row-sorted logits and return probabilities

        Params:
            logits: row-sorted logits (shape [samples, classes])
            temp: per-rank weights (length classes, or a scalar)
            bias: per-rank biases (length classes, or a scalar)

        Returns:
            probabilities in sorted order (nd.array with shape [samples, classes])
        """
        scaled_logits = (logits * temp) + bias
        return softmax(scaled_logits)

    def predict(self, logits):
        """
        Calibrate logits in their original class order.

        Params:
            logits: logits (shape [samples, classes])

        Returns:
            calibrated probabilities (nd.array with shape [samples, classes]) in the
            original class order
        """
        logits = np.asarray(logits)
        order = np.argsort(logits, axis=1)
        sorted_logits = np.take_along_axis(logits, order, axis=1)
        scaled_sorted = (sorted_logits * self.temp) + self.bias

        # Undo the sort: the value at rank i returns to column order[:, i].
        # The buffer takes the dtype of the scaled values; the former
        # zeros_like(logits) truncated float results when the logits were integers.
        scaled_logits_original = np.empty_like(scaled_sorted)
        np.put_along_axis(scaled_logits_original, order, scaled_sorted, axis=1)

        return softmax(scaled_logits_original)
    

class VectorScaling_mono_topk_2():
    """
    Monotonic vector scaling fitted on the top-k ranks only.

    Each row of logits is sorted in ascending order. During fit only the k
    largest logits of every row (and the matching entries of the one-hot
    labels) are used. The value at top-k position j is mapped to
    w[j] * z + b[j], with the weights w (stored as ``temp``) and the biases b
    constrained to be non-decreasing in rank.

    At prediction time the lower-ranked logits are mapped with the lowest
    top-k parameters w[0], b[0]. The row is then put back in class order and
    softmax is applied.
    """

    def __init__(self, maxiter=50, solver="SLSQP", topk=10, loss = "ce"):
        """
        Initialize class

        Params:
            maxiter (int): maximum number of optimizer iterations.
            solver (str): optimization algorithm used by scipy.optimize.minimize.
            topk (int): number of highest-ranked classes considered for monotonic vector scaling.
            loss (str): training objective, "ce" (mean cross-entropy) or "mse".
        """
        self.maxiter = maxiter
        self.solver = solver
        self.topk = topk
        self.loss = loss

    def _loss_fun(self, x, probs, true):
        """
        Optimizer objective for candidate parameters x = [weights (topk), biases (topk)].

        Raises:
            ValueError: if self.loss is neither "ce" nor "mse".
        """
        scaled_probs = self.predict_train(probs, temp=x[:self.topk], bias=x[self.topk:])
        if self.loss == "ce":
            # Mean cross-entropy. xlogy(0, p) is defined as 0, so a probability
            # that underflows to exactly 0 where the label is 0 no longer turns
            # the loss into nan (0 * log(0)).
            loss = -np.sum(xlogy(true, scaled_probs)) / probs.shape[0]
        elif self.loss == "mse":
            loss = mean_squared_error(y_true=true, y_pred=scaled_probs)
        else:
            raise ValueError("unknown loss {!r}; expected 'ce' or 'mse'".format(self.loss))

        return loss

    def ascending_constraints(self, params, topk):
        # Extract the weights and biases from params.
        t = params[:topk]
        b = params[topk:]

        # Successive differences of both sequences. As a scipy 'ineq' constraint
        # every entry must be >= 0, i.e. t and b are non-decreasing.
        temp_constraints = np.diff(t)
        bias_constraints = np.diff(b)

        return np.concatenate((temp_constraints, bias_constraints))

    # Find the temperature and bias
    def fit(self, logits, true):
        """
        Trains the model and finds the optimal per-rank weights and biases on the top-k ranks

        Params:
            logits: the output from neural network for each class (shape [samples, classes])
            true: one-hot true labels (shape [samples, classes]); each row is permuted with
                the same sort as the corresponding logits row.

        Returns:
            the results of optimizer after minimizing is finished. The fitted parameters
            are also stored in self.temp and self.bias.

        Raises:
            ValueError: if topk is not between 1 and the number of classes.
        """
        logits = np.asarray(logits)
        self.nclass = logits.shape[1]
        if not 1 <= self.topk <= self.nclass:
            raise ValueError(
                "topk must be between 1 and the number of classes ({}), got {}".format(self.nclass, self.topk))

        # Sort every row ascending, apply the same permutation to the labels,
        # then keep only the k highest ranks.
        order = np.argsort(logits, axis=1)
        sorted_logits = np.take_along_axis(logits, order, axis=1)[:, -self.topk:]
        true_sorted = np.take_along_axis(np.asarray(true), order, axis=1)[:, -self.topk:]
        # Samples whose label lies outside the top-k keep an all-zero label row and
        # are not removed. They add nothing to the "ce" sum but still count in its
        # denominator, and they do contribute to "mse".

        # Start from weights spread evenly over [0, 1] and zero biases. The
        # parameters are unbounded here; only the monotonicity constraints apply.
        temp_init = np.linspace(0, 1, num=self.topk)
        bias_init = np.zeros(self.topk)
        initial_guess = np.concatenate((temp_init, bias_init), axis=0)

        constraints = {'type': 'ineq', 'fun': self.ascending_constraints, 'args': (self.topk,)}

        opt = minimize(self._loss_fun, x0=initial_guess, args=(sorted_logits, true_sorted),
                       options={'maxiter': self.maxiter, 'disp': True}, method=self.solver, constraints=constraints)

        self.temp = opt.x[:self.topk]  # per-rank weights
        self.bias = opt.x[self.topk:]  # per-rank biases

        return opt

    def predict_train(self, logits, temp, bias):
        """
        Apply per-rank weights and biases to the top-k sorted logits and return probabilities

        Params:
            logits: row-sorted top-k logits (shape [samples, topk])
            temp: per-rank weights (length topk)
            bias: per-rank biases (length topk)

        Returns:
            probabilities over the top-k ranks (nd.array with shape [samples, topk])
        """
        scaled_logits = (logits * temp) + bias
        return softmax(scaled_logits)

    def predict(self, logits):
        """
        Calibrate full logits in their original class order.

        Params:
            logits: logits (shape [samples, classes])

        Returns:
            calibrated probabilities (nd.array with shape [samples, classes])
        """
        logits = np.asarray(logits)
        n_rest = logits.shape[1] - self.topk
        order = np.argsort(logits, axis=1)
        sorted_logits = np.take_along_axis(logits, order, axis=1)

        # Per-rank parameters for every column: ranks below the top-k reuse the
        # lowest top-k weight and bias.
        weights = np.concatenate((np.full(n_rest, self.temp[0]), self.temp))
        biases = np.concatenate((np.full(n_rest, self.bias[0]), self.bias))
        scaled_sorted = sorted_logits * weights + biases

        # Undo the sort into a buffer of the scaled values' dtype (the former
        # zeros_like(logits) truncated float results for integer logits).
        scaled_logits_original = np.empty_like(scaled_sorted)
        np.put_along_axis(scaled_logits_original, order, scaled_sorted, axis=1)

        return softmax(scaled_logits_original)
    

class VectorScaling_mono_topk_2_filter():
    """
    Monotonic vector scaling fitted on the top-k ranks of the samples whose
    label is among their k largest logits.

    Each row of logits is sorted in ascending order and the top-k part is kept.
    Samples whose one-hot label falls outside the top-k are dropped before
    fitting. The value at top-k position j is mapped to w[j] * z + b[j], with
    the weights w (stored as ``temp``, bounded to [0, 1]) and the biases b
    (bounded to [-1, 1]) constrained to be non-decreasing in rank.

    At prediction time the lower-ranked logits are mapped with w[0], b[0].
    """

    def __init__(self, maxiter=50, solver="SLSQP", topk=10, loss = "ce"):
        """
        Initialize class

        Params:
            maxiter (int): maximum number of optimizer iterations.
            solver (str): optimization algorithm used by scipy.optimize.minimize.
            topk (int): number of highest-ranked classes considered for monotonic vector scaling.
            loss (str): training objective, "ce" (mean cross-entropy) or "mse".
        """
        self.maxiter = maxiter
        self.solver = solver
        self.topk = topk
        self.loss = loss

    def _loss_fun(self, x, probs, true):
        """
        Optimizer objective for candidate parameters x = [weights (topk), biases (topk)].

        Raises:
            ValueError: if self.loss is neither "ce" nor "mse".
        """
        scaled_probs = self.predict_train(probs, temp=x[:self.topk], bias=x[self.topk:])
        if self.loss == "ce":
            # Mean cross-entropy. xlogy(0, p) is defined as 0, so a probability
            # that underflows to exactly 0 where the label is 0 no longer turns
            # the loss into nan (0 * log(0)).
            loss = -np.sum(xlogy(true, scaled_probs)) / probs.shape[0]
        elif self.loss == "mse":
            loss = mean_squared_error(y_true=true, y_pred=scaled_probs)
        else:
            raise ValueError("unknown loss {!r}; expected 'ce' or 'mse'".format(self.loss))

        return loss

    def ascending_constraints(self, params, topk):
        # Extract the weights and biases from params.
        t = params[:topk]
        b = params[topk:]

        # Successive differences of both sequences. As a scipy 'ineq' constraint
        # every entry must be >= 0, i.e. t and b are non-decreasing.
        temp_constraints = np.diff(t)
        bias_constraints = np.diff(b)

        return np.concatenate((temp_constraints, bias_constraints))

    # Find the temperature and bias
    def fit(self, logits, true):
        """
        Trains the model and finds the optimal per-rank weights and biases on the top-k ranks

        Params:
            logits: the output from neural network for each class (shape [samples, classes])
            true: one-hot true labels (shape [samples, classes]); each row is permuted with
                the same sort as the corresponding logits row.

        Returns:
            the results of optimizer after minimizing is finished. The fitted parameters
            are also stored in self.temp and self.bias.

        Raises:
            ValueError: if topk is not between 1 and the number of classes, or if no
                sample has its label among its top-k logits.
        """
        logits = np.asarray(logits)
        self.nclass = logits.shape[1]
        if not 1 <= self.topk <= self.nclass:
            raise ValueError(
                "topk must be between 1 and the number of classes ({}), got {}".format(self.nclass, self.topk))

        # Sort every row ascending, apply the same permutation to the labels,
        # then keep only the k highest ranks.
        order = np.argsort(logits, axis=1)
        sorted_logits = np.take_along_axis(logits, order, axis=1)[:, -self.topk:]
        true_sorted = np.take_along_axis(np.asarray(true), order, axis=1)[:, -self.topk:]

        # Keep only samples whose one-hot label falls inside the top-k.
        index_of_match = np.where(true_sorted.sum(axis=1) == 1)[0]
        if index_of_match.size == 0:
            # Otherwise the loss would be computed over zero samples (nan).
            raise ValueError(
                "no sample has its label among its top-{} logits; nothing to fit".format(self.topk))
        sorted_logits = sorted_logits[index_of_match, :]
        true_sorted = true_sorted[index_of_match, :]

        # Start from weights spread evenly over [0, 1] and zero biases, inside the bounds.
        temp_init = np.linspace(0, 1, num=self.topk)
        bias_init = np.zeros(self.topk)
        initial_guess = np.concatenate((temp_init, bias_init), axis=0)

        bounds_weight = [(0, 1)] * self.topk
        bounds_bias = [(-1, 1)] * self.topk
        bounds = bounds_weight + bounds_bias

        constraints = {'type': 'ineq', 'fun': self.ascending_constraints, 'args': (self.topk,)}

        opt = minimize(self._loss_fun, x0=initial_guess, args=(sorted_logits, true_sorted),
                       options={'maxiter': self.maxiter, 'disp': True}, method=self.solver,
                       constraints=constraints, bounds=bounds)

        self.temp = opt.x[:self.topk]  # per-rank weights
        self.bias = opt.x[self.topk:]  # per-rank biases

        return opt

    def predict_train(self, logits, temp, bias):
        """
        Apply per-rank weights and biases to the top-k sorted logits and return probabilities

        Params:
            logits: row-sorted top-k logits (shape [samples, topk])
            temp: per-rank weights (length topk)
            bias: per-rank biases (length topk)

        Returns:
            probabilities over the top-k ranks (nd.array with shape [samples, topk])
        """
        scaled_logits = (logits * temp) + bias
        return softmax(scaled_logits)

    def predict(self, logits):
        """
        Calibrate full logits in their original class order.

        Params:
            logits: logits (shape [samples, classes])

        Returns:
            calibrated probabilities (nd.array with shape [samples, classes])
        """
        logits = np.asarray(logits)
        n_rest = logits.shape[1] - self.topk
        order = np.argsort(logits, axis=1)
        sorted_logits = np.take_along_axis(logits, order, axis=1)

        # Per-rank parameters for every column: ranks below the top-k reuse the
        # lowest top-k weight and bias.
        weights = np.concatenate((np.full(n_rest, self.temp[0]), self.temp))
        biases = np.concatenate((np.full(n_rest, self.bias[0]), self.bias))
        scaled_sorted = sorted_logits * weights + biases

        # Undo the sort into a buffer of the scaled values' dtype (the former
        # zeros_like(logits) truncated float results for integer logits).
        scaled_logits_original = np.empty_like(scaled_sorted)
        np.put_along_axis(scaled_logits_original, order, scaled_sorted, axis=1)

        return softmax(scaled_logits_original)



class VectorScaling_mono_topk_3():
    """
    Monotonic vector scaling with top-k parameters, fitted over all classes.

    Each row of logits is sorted in ascending order. The value at one of the
    top-k ranks is mapped to w[j] * z + b[j] (j counts from the lowest of the
    top-k ranks), and every lower rank uses the lowest top-k parameters
    w[0], b[0]. The weights w (stored as ``temp``) and the biases b are
    constrained to be non-decreasing in rank. The loss is computed over all
    classes, not only the top-k.
    """

    def __init__(self, maxiter=50, solver="SLSQP", topk=10, loss = "ce"):
        """
        Initialize class

        Params:
            maxiter (int): maximum number of optimizer iterations.
            solver (str): optimization algorithm used by scipy.optimize.minimize.
            topk (int): number of highest-ranked classes that get their own weight and bias.
            loss (str): training objective, "ce" (mean cross-entropy) or "mse".
        """
        self.maxiter = maxiter
        self.solver = solver
        self.topk = topk
        self.loss = loss

    def _loss_fun(self, x, probs, true):
        """
        Optimizer objective for candidate parameters x = [weights (topk), biases (topk)].

        Raises:
            ValueError: if self.loss is neither "ce" nor "mse".
        """
        scaled_probs = self.predict_train(probs, temp=x[:self.topk], bias=x[self.topk:])
        if self.loss == "ce":
            # Mean cross-entropy. xlogy(0, p) is defined as 0, so a probability
            # that underflows to exactly 0 where the label is 0 no longer turns
            # the loss into nan (0 * log(0)).
            loss = -np.sum(xlogy(true, scaled_probs)) / probs.shape[0]
        elif self.loss == "mse":
            loss = mean_squared_error(y_true=true, y_pred=scaled_probs)
        else:
            raise ValueError("unknown loss {!r}; expected 'ce' or 'mse'".format(self.loss))

        return loss

    def ascending_constraints(self, params, topk):
        # Extract the weights and biases from params.
        t = params[:topk]
        b = params[topk:]

        # Successive differences of both sequences. As a scipy 'ineq' constraint
        # every entry must be >= 0, i.e. t and b are non-decreasing.
        temp_constraints = np.diff(t)
        bias_constraints = np.diff(b)

        return np.concatenate((temp_constraints, bias_constraints))

    # Find the temperature and bias
    def fit(self, logits, true):
        """
        Trains the model and finds the optimal top-k weights and biases

        Params:
            logits: the output from neural network for each class (shape [samples, classes])
            true: one-hot true labels (shape [samples, classes]); each row is permuted with
                the same sort as the corresponding logits row.

        Returns:
            the results of optimizer after minimizing is finished. The fitted parameters
            are also stored in self.temp and self.bias.

        Raises:
            ValueError: if topk is not between 1 and the number of classes.
        """
        logits = np.asarray(logits)
        self.nclass = logits.shape[1]
        if not 1 <= self.topk <= self.nclass:
            raise ValueError(
                "topk must be between 1 and the number of classes ({}), got {}".format(self.nclass, self.topk))

        # Sort every row ascending and apply the same permutation to the labels.
        # All columns are kept; the lower ranks share the lowest top-k parameters.
        order = np.argsort(logits, axis=1)
        sorted_logits = np.take_along_axis(logits, order, axis=1)
        true_sorted = np.take_along_axis(np.asarray(true), order, axis=1)

        # Start from weights spread evenly over [0, 1] and zero biases.
        temp_init = np.linspace(0, 1, num=self.topk)
        bias_init = np.zeros(self.topk)
        initial_guess = np.concatenate((temp_init, bias_init), axis=0)

        constraints = {'type': 'ineq', 'fun': self.ascending_constraints, 'args': (self.topk,)}

        opt = minimize(self._loss_fun, x0=initial_guess, args=(sorted_logits, true_sorted),
                       options={'maxiter': self.maxiter, 'disp': True}, method=self.solver,
                       constraints=constraints, tol=1e-12)

        self.temp = opt.x[:self.topk]  # per-rank weights
        self.bias = opt.x[self.topk:]  # per-rank biases

        return opt

    def _scale_sorted(self, sorted_logits, temp, bias):
        # Apply per-rank parameters to row-sorted logits: the last topk columns use
        # temp/bias elementwise, and every lower rank uses temp[0]/bias[0].
        n_rest = sorted_logits.shape[1] - self.topk
        weights = np.concatenate((np.full(n_rest, temp[0]), temp))
        biases = np.concatenate((np.full(n_rest, bias[0]), bias))
        # The result has a float dtype even for integer logits; the former
        # np.empty_like(logits) buffer truncated it.
        return sorted_logits * weights + biases

    def predict_train(self, logits, temp, bias):
        """
        Scale row-sorted logits with the given top-k weights and biases and return
        probabilities (still in sorted order).
        """
        return softmax(self._scale_sorted(logits, temp, bias))

    def predict(self, logits):
        """
        Calibrate logits in their original class order.

        Params:
            logits: logits (shape [samples, classes])

        Returns:
            calibrated probabilities (nd.array with shape [samples, classes])
        """
        logits = np.asarray(logits)
        order = np.argsort(logits, axis=1)
        sorted_logits = np.take_along_axis(logits, order, axis=1)
        scaled_sorted = self._scale_sorted(sorted_logits, self.temp, self.bias)

        # Undo the sort: the value at rank i returns to column order[:, i].
        scaled_logits_original = np.empty_like(scaled_sorted)
        np.put_along_axis(scaled_logits_original, order, scaled_sorted, axis=1)

        return softmax(scaled_logits_original)
    


class VectorScaling_mono_topk_3_bnd():
    """
    Bounded monotonic vector scaling with top-k parameters, fitted over all classes.

    Each row of logits is sorted in ascending order. The value at one of the
    top-k ranks is mapped to w[j] * z + b[j], and every lower rank uses the
    lowest top-k parameters w[0], b[0]. The weights w (stored as ``temp``) are
    bounded to [0, 1] and the biases b to [-1, 1]; both are constrained to be
    non-decreasing in rank. The loss is computed over all classes.
    """

    def __init__(self, maxiter=50, solver="SLSQP", topk=10, loss = "ce"):
        """
        Initialize class

        Params:
            maxiter (int): maximum number of optimizer iterations.
            solver (str): optimization algorithm used by scipy.optimize.minimize.
            topk (int): number of highest-ranked classes that get their own weight and bias.
            loss (str): training objective, "ce" (mean cross-entropy) or "mse".
        """
        self.maxiter = maxiter
        self.solver = solver
        self.topk = topk
        self.loss = loss

    def _loss_fun(self, x, probs, true):
        """
        Optimizer objective for candidate parameters x = [weights (topk), biases (topk)].

        Raises:
            ValueError: if self.loss is neither "ce" nor "mse".
        """
        # (The debug print(probs.shape) that ran on every evaluation was removed.)
        scaled_probs = self.predict_train(probs, temp=x[:self.topk], bias=x[self.topk:])
        if self.loss == "ce":
            # Mean cross-entropy. xlogy(0, p) is defined as 0, so a probability
            # that underflows to exactly 0 where the label is 0 no longer turns
            # the loss into nan (0 * log(0)).
            loss = -np.sum(xlogy(true, scaled_probs)) / probs.shape[0]
        elif self.loss == "mse":
            loss = mean_squared_error(y_true=true, y_pred=scaled_probs)
        else:
            raise ValueError("unknown loss {!r}; expected 'ce' or 'mse'".format(self.loss))

        return loss

    def ascending_constraints(self, params, topk):
        # Extract the weights and biases from params.
        t = params[:topk]
        b = params[topk:]

        # Successive differences of both sequences. As a scipy 'ineq' constraint
        # every entry must be >= 0, i.e. t and b are non-decreasing.
        temp_constraints = np.diff(t)
        bias_constraints = np.diff(b)

        return np.concatenate((temp_constraints, bias_constraints))

    # Find the temperature and bias
    def fit(self, logits, true):
        """
        Trains the model and finds the optimal bounded top-k weights and biases

        Params:
            logits: the output from neural network for each class (shape [samples, classes])
            true: one-hot true labels (shape [samples, classes]); each row is permuted with
                the same sort as the corresponding logits row.

        Returns:
            the results of optimizer after minimizing is finished. The fitted parameters
            are also stored in self.temp and self.bias.

        Raises:
            ValueError: if topk is not between 1 and the number of classes.
        """
        logits = np.asarray(logits)
        self.nclass = logits.shape[1]
        if not 1 <= self.topk <= self.nclass:
            raise ValueError(
                "topk must be between 1 and the number of classes ({}), got {}".format(self.nclass, self.topk))

        # Sort every row ascending and apply the same permutation to the labels.
        # All columns are kept; the lower ranks share the lowest top-k parameters.
        order = np.argsort(logits, axis=1)
        sorted_logits = np.take_along_axis(logits, order, axis=1)
        true_sorted = np.take_along_axis(np.asarray(true), order, axis=1)

        # Start from weights spread evenly over [0, 1] and zero biases, inside the bounds.
        temp_init = np.linspace(0, 1, num=self.topk)
        bias_init = np.zeros(self.topk)
        initial_guess = np.concatenate((temp_init, bias_init), axis=0)

        bounds_weight = [(0, 1)] * self.topk
        bounds_bias = [(-1, 1)] * self.topk
        bounds = bounds_weight + bounds_bias

        constraints = {'type': 'ineq', 'fun': self.ascending_constraints, 'args': (self.topk,)}

        opt = minimize(self._loss_fun, x0=initial_guess, args=(sorted_logits, true_sorted),
                       options={'maxiter': self.maxiter, 'disp': True}, method=self.solver,
                       constraints=constraints, bounds=bounds)

        self.temp = opt.x[:self.topk]  # per-rank weights
        self.bias = opt.x[self.topk:]  # per-rank biases

        return opt

    def _scale_sorted(self, sorted_logits, temp, bias):
        # Apply per-rank parameters to row-sorted logits: the last topk columns use
        # temp/bias elementwise, and every lower rank uses temp[0]/bias[0].
        n_rest = sorted_logits.shape[1] - self.topk
        weights = np.concatenate((np.full(n_rest, temp[0]), temp))
        biases = np.concatenate((np.full(n_rest, bias[0]), bias))
        # The result has a float dtype even for integer logits; the former
        # np.empty_like(logits) buffer truncated it.
        return sorted_logits * weights + biases

    def predict_train(self, logits, temp, bias):
        """
        Scale row-sorted logits with the given top-k weights and biases and return
        probabilities (still in sorted order).
        """
        return softmax(self._scale_sorted(logits, temp, bias))

    def predict(self, logits):
        """
        Calibrate logits in their original class order.

        Params:
            logits: logits (shape [samples, classes])

        Returns:
            calibrated probabilities (nd.array with shape [samples, classes])
        """
        logits = np.asarray(logits)
        order = np.argsort(logits, axis=1)
        sorted_logits = np.take_along_axis(logits, order, axis=1)
        scaled_sorted = self._scale_sorted(sorted_logits, self.temp, self.bias)

        # Undo the sort: the value at rank i returns to column order[:, i].
        scaled_logits_original = np.empty_like(scaled_sorted)
        np.put_along_axis(scaled_logits_original, order, scaled_sorted, axis=1)

        return softmax(scaled_logits_original)
    


class VectorScaling_mono_topk_3_filter():
    
    def __init__(self, maxiter=50, solver="SLSQP", topk=10, loss = "ce"):
        """
        Initialize class
        
        Params:
            maxiter (int): maximum iterations done by optimizer, however 8 iterations have been maximum.
            solver (str): optimization algorithm used by scipy.optimize.minimize
            topk (int): number of classes to be considered for monotonic vector scaling
        """
        self.maxiter = maxiter
        self.solver = solver
        self.topk = topk
        self.loss = loss
    
    def _loss_fun(self, x, probs, true):
        """
        Objective minimised by ``fit``.

        Params:
            x: flat parameter vector; the first ``self.topk`` entries are temperatures and
               the rest are biases.
            probs: despite the name, ascending-sorted logits; they are passed straight to
                   ``predict_train``.
            true: one-hot labels permuted into the same column order as ``probs``.

        Returns:
            Mean cross-entropy over samples when ``self.loss == "ce"``, or
            ``mean_squared_error`` when ``self.loss == "mse"``.

        Raises:
            ValueError: if ``self.loss`` is neither "ce" nor "mse".
        """
        scaled_probs = self.predict_train(probs, temp=x[:self.topk], bias=x[self.topk:])
        if self.loss == "ce":
            # Softmax can underflow to exactly 0, and 0 * log(0) evaluates to nan, which
            # would poison the whole objective. Clipping at the smallest positive normal
            # value keeps every term finite without affecting non-zero probabilities.
            eps = np.finfo(scaled_probs.dtype).tiny
            clipped = np.clip(scaled_probs, eps, 1.0)
            return -np.sum(true * np.log(clipped)) / probs.shape[0]
        if self.loss == "mse":
            return mean_squared_error(y_true=true, y_pred=scaled_probs)
        raise ValueError(f"unsupported loss {self.loss!r}; expected 'ce' or 'mse'")
    
    def ascending_constraints(self, params, topk):
        """
        Inequality-constraint function for scipy.optimize.minimize.

        ``params`` holds ``topk`` temperatures followed by the biases. The returned vector
        contains the consecutive differences of the temperatures, then those of the biases.
        An 'ineq' constraint requires every entry to be >= 0, so both sequences are forced
        to be non-decreasing with rank.
        """
        temps, biases = params[:topk], params[topk:]
        gaps = [np.diff(temps), np.diff(biases)]
        return np.concatenate(gaps)
    
    # Find the temperature and bias
    def fit(self, logits, true):
        """
        Fit per-rank temperatures and biases on held-out logits.

        Each row of ``logits`` is sorted ascending, and the one-hot labels are permuted with
        the same per-row order. Only samples whose true class lies among that row's
        ``self.topk`` largest logits are kept for optimisation. The parameters are then
        found with ``scipy.optimize.minimize`` under ``ascending_constraints``, which makes
        both temperatures and biases non-decreasing with rank.

        Params:
            logits: the output from neural network for each class (shape [samples, classes])
            true: one-hot labels (shape [samples, classes]) in the original class order.

        Returns:
            the ``OptimizeResult`` from ``scipy.optimize.minimize``. As side effects, the
            method sets ``self.nclass``, ``self.indicator``, ``self.temp`` and ``self.bias``.

        Raises:
            ValueError: if no sample has its true class among its top-k logits.
        """
        logits = np.asarray(logits)
        true = np.asarray(true)
        self.nclass = logits.shape[1]

        # A single argsort drives both the sorted logits and the matching label permutation.
        logit_sort_index = np.argsort(logits, axis=1)
        sorted_logits = np.take_along_axis(logits, logit_sort_index, axis=1)
        true_sorted = np.take_along_axis(true, logit_sort_index, axis=1)

        # indicator[a] == 1 when sample a's true class falls in its top-k (last k) columns;
        # only those samples are used for fitting.
        indicator = true_sorted[:, -self.topk:].sum(axis=1)
        self.indicator = indicator
        index_of_match = np.where(indicator == 1)[0]
        if index_of_match.size == 0:
            raise ValueError(
                "no sample has its true label within the top-k logits; cannot fit"
            )
        sorted_logits = sorted_logits[index_of_match, :]
        true_sorted = true_sorted[index_of_match, :]

        # Start from temperatures increasing linearly over [0, 1] and zero biases.
        temp_init = np.linspace(0, 1, num=self.topk)
        bias_init = np.zeros(self.topk)
        initial_guess = np.concatenate((temp_init, bias_init), axis=0)

        constraints = {'type': 'ineq', 'fun': self.ascending_constraints, 'args': (self.topk,)}

        opt = minimize(self._loss_fun, x0=initial_guess, args=(sorted_logits, true_sorted),
                       options={'maxiter': self.maxiter, 'disp': True}, method=self.solver,
                       constraints=constraints)

        self.temp = opt.x[:self.topk]  # Update temperature
        self.bias = opt.x[self.topk:]  # Update bias

        return opt
        
    def predict_train(self, logits, temp, bias):
        """
        Scale ascending-sorted logits with per-rank temperatures/biases and softmax them.

        The last ``self.topk`` columns (the largest logits, since each row is sorted in
        ascending order) get ``logit * temp[j] + bias[j]``; every remaining column reuses
        the lowest-rank parameters ``temp[0]`` and ``bias[0]``.

        Params:
            logits: 2-D array of logits with each row sorted in ascending order.
            temp: ``self.topk`` temperatures (multipliers), lowest rank first.
            bias: ``self.topk`` biases, lowest rank first.

        Returns:
            Row-wise softmax probabilities with the same shape as ``logits``; columns stay
            in the sorted order.

        Raises:
            ValueError: if ``self.topk`` is not in ``[1, logits.shape[1]]``.
        """
        logits = np.asarray(logits)
        n_class = logits.shape[1]
        if not 1 <= self.topk <= n_class:
            raise ValueError(
                f"topk must be between 1 and the number of classes ({n_class}), got {self.topk}"
            )
        temp = np.asarray(temp, dtype=float)
        bias = np.asarray(bias, dtype=float)

        # One multiplier/offset per column: the non-top-k columns share temp[0]/bias[0].
        # A single broadcast expression then covers all columns and always yields a float
        # result (np.empty_like(logits) would truncate for integer logits).
        n_rest = n_class - self.topk
        col_temp = np.concatenate((np.full(n_rest, temp[0]), temp))
        col_bias = np.concatenate((np.full(n_rest, bias[0]), bias))

        return softmax(logits * col_temp + col_bias)
    
    def predict(self, logits):
        """
        Calibrate logits given in their original class order using the fitted parameters.

        Each row is sorted ascending and scaled rank-wise as in ``predict_train``. The
        top ``self.topk`` ranks use ``self.temp``/``self.bias``, and lower ranks reuse
        ``self.temp[0]``/``self.bias[0]``. The scaled values are then scattered back to the
        original column order and softmaxed.

        Params:
            logits: 2-D array of shape (samples, classes).

        Returns:
            Calibrated probabilities in the original class order, same shape as ``logits``.

        Raises:
            ValueError: if ``self.topk`` is not in ``[1, logits.shape[1]]``.
        """
        logits = np.asarray(logits)
        n_class = logits.shape[1]
        if not 1 <= self.topk <= n_class:
            raise ValueError(
                f"topk must be between 1 and the number of classes ({n_class}), got {self.topk}"
            )

        # A single argsort gives both the sorted values and the permutation to undo.
        order = np.argsort(logits, axis=1)
        logits_sorted = np.take_along_axis(logits, order, axis=1)

        temp = np.asarray(self.temp, dtype=float)
        bias = np.asarray(self.bias, dtype=float)
        n_rest = n_class - self.topk
        col_temp = np.concatenate((np.full(n_rest, temp[0]), temp))
        col_bias = np.concatenate((np.full(n_rest, bias[0]), bias))
        scaled_sorted = logits_sorted * col_temp + col_bias

        # Undo the sort: sorted column j of row a returns to original column order[a, j].
        scaled = np.empty_like(scaled_sorted)
        np.put_along_axis(scaled, order, scaled_sorted, axis=1)

        return softmax(scaled)