# taken from model_transformer package
import numpy as np
from scipy import optimize
from scipy.interpolate import LSQUnivariateSpline
from scipy.special import softmax
from sklearn.isotonic import IsotonicRegression
from tqdm import tqdm
from functools import partial

from multiprocessing import Pool, cpu_count

import torch
import torch.nn as nn
import torch.optim as optim


def one_hot_encode(labels, num_classes):
    """
    Convert integer labels to one-hot encoded labels.

    Column-vector labels of shape (N, 1) are treated the same as flat labels
    of shape (N,), so the result is always (N, num_classes) for those inputs.
    Without this, indexing with an (N, 1) array yields an (N, 1, num_classes)
    array that silently broadcasts against (N, num_classes) probabilities.

    Args:
        labels (np.ndarray): Integer class indices, shape (N,) or (N, 1).
            Array-likes (e.g. lists) are accepted.
        num_classes (int): Number of classes.

    Returns:
        np.ndarray: One-hot encoded labels (float64, as produced by np.eye).

    Raises:
        IndexError: If a label is out of range or not of integer dtype.
    """
    labels = np.asarray(labels)
    # Collapse a trailing singleton column so (N, 1) behaves like (N,).
    if labels.ndim == 2 and labels.shape[1] == 1:
        labels = labels[:, 0]
    return np.eye(num_classes)[labels]


class BaseCalibration:
    """
    Common parent for all calibrators in this module.

    It fixes a scikit-learn-style contract (``fit`` / ``transform`` /
    ``fit_transform``) and tracks whether the calibrator has been fitted.
    ``fit`` and ``transform`` are placeholders that concrete calibrators
    are expected to override.

    Attributes:
        is_fitted (bool): Set to False on construction; subclasses flip it
            to True once fitting has completed.
    """

    def __init__(self):
        """Create an unfitted calibrator."""
        self.is_fitted = False

    def fit(self, logits, labels):
        """
        Learn calibration parameters (abstract).

        Args:
            logits (np.ndarray): Uncalibrated logits.
            labels (np.ndarray): True labels.

        Returns:
            self

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses should implement this!")

    def transform(self, logits):
        """
        Calibrate new data (abstract).

        Args:
            logits (np.ndarray): Uncalibrated logits to be transformed.

        Returns:
            np.ndarray: Calibrated probabilities or logits.

        Raises:
            ValueError: If the calibrator has not been fitted.
            NotImplementedError: If fitted, because the base class has no
                transformation of its own.
        """
        if self.is_fitted:
            raise NotImplementedError("Subclasses should implement this!")
        raise ValueError("This instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.")

    def fit_transform(self, logits, labels):
        """
        Fit on the given data, then calibrate that same data.

        Args:
            logits (np.ndarray): Uncalibrated logits.
            labels (np.ndarray): True labels.

        Returns:
            np.ndarray: Calibrated probabilities or logits, i.e. whatever
            ``transform`` returns on the object returned by ``fit``.
        """
        fitted = self.fit(logits, labels)
        return fitted.transform(logits)

class TemperatureScaling(BaseCalibration):
    """
    Guo et al. - On Calibration of Modern Neural Networks (https://arxiv.org/abs/1706.04599)

    Temperature Scaling calibration method.

    A single scalar temperature is fitted by minimizing the chosen loss
    between ``softmax(logits / temperature)`` and the one-hot labels.

    Attributes:
        loss (str): Loss function to use for optimization ('ce', 'mse', or 'brier').
        temperature (float): The temperature parameter for scaling
            (None until ``fit`` has been called).
    """

    # Loss names understood by ``fit``.
    _SUPPORTED_LOSSES = ('ce', 'mse', 'brier')

    def __init__(self, loss='mse'):
        """
        Initialize the Temperature Scaling calibrator.

        Args:
            loss (str): Loss function to use for optimization ('ce', 'mse', or 'brier').
        """
        super().__init__()
        self.loss = loss
        self.temperature = None

    def fit(self, logits, labels):
        """
        Fit the Temperature Scaling calibrator.

        Args:
            logits (np.ndarray): Uncalibrated logits, shape (N, num_classes).
            labels (np.ndarray): True labels, either integer class indices of
                shape (N,) / (N, 1) or one-hot rows of shape (N, num_classes).

        Returns:
            self

        Raises:
            ValueError: If ``self.loss`` is not one of 'ce', 'mse', 'brier'.
        """
        # Fail early with a clear message; otherwise the objective would
        # return None and the optimizer would fail obscurely.
        if self.loss not in self._SUPPORTED_LOSSES:
            raise ValueError(
                f"Unsupported loss {self.loss!r}; expected one of {self._SUPPORTED_LOSSES}."
            )

        logits = np.asarray(logits)
        labels = np.asarray(labels)

        # Check if labels are one-hot encoded, if not, convert them.
        # Flatten first so (N, 1) labels become (N, C) rather than (N, 1, C).
        if labels.ndim == 1 or labels.shape[1] == 1:
            num_classes = logits.shape[1]
            labels = one_hot_encode(labels.reshape(-1), num_classes)

        def objective(t):
            probs = softmax(logits / t, axis=1)
            if self.loss == 'ce':
                # Small epsilon guards against log(0).
                return -np.sum(labels * np.log(probs + 1e-12)) / probs.shape[0]
            if self.loss == 'mse':
                return np.mean((probs - labels) ** 2)
            # 'brier': squared error summed over classes, averaged over samples.
            return np.mean(np.sum((probs - labels) ** 2, axis=1))

        # Bounded scalar search over the temperature.
        self.temperature = optimize.minimize_scalar(
            objective,
            bounds=(0.05, 10.0),
            method='bounded'
        ).x
        self.is_fitted = True
        return self

    def transform(self, logits):
        """
        Apply Temperature Scaling calibration to new data.

        Args:
            logits (np.ndarray): Uncalibrated logits to be transformed.

        Returns:
            np.ndarray: Calibrated logits (``logits / temperature``); apply a
            softmax to obtain probabilities.

        Raises:
            ValueError: If no temperature is available yet.
        """
        # Guard on the parameter itself so a manually assigned temperature
        # keeps working, while an unfitted instance gets a clear error.
        if self.temperature is None:
            raise ValueError("This instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.")
        return logits / self.temperature

class EnsembleTemperatureScaling(BaseCalibration):
    """
    Zhang et al. Mix-n-Match: Ensemble and Compositional Methods for Uncertainty Calibration in Deep Learning (http://arxiv.org/abs/2003.07329)

    Ensemble Temperature Scaling calibration method.

    The calibrated output is a convex combination of three components:
    temperature-scaled probabilities, the uncalibrated softmax probabilities,
    and the uniform distribution over classes.

    Attributes:
        loss (str): Loss function to use for optimization ('ce', 'mse', or 'brier').
        temperature (float): The temperature parameter for scaling.
        weights (np.ndarray): Weights for combining different calibration components.
    """

    def __init__(self, loss='mse'):
        """
        Initialize the Ensemble Temperature Scaling calibrator.

        Args:
            loss (str): Loss function to use for optimization ('ce', 'mse', or 'brier').
        """
        super().__init__()
        self.loss = loss
        self.temperature = None
        self.weights = None

    @staticmethod
    def _mixture(w, p0, p1, p2):
        """Return the row-normalized weighted mixture of the three components."""
        p = w[0] * p0 + w[1] * p1 + w[2] * p2
        return p / np.sum(p, 1)[:, None]

    def mse_w(self, w, p0, p1, p2, label):
        """Mean squared error of the weighted mixture against one-hot labels."""
        p = self._mixture(w, p0, p1, p2)
        return np.mean((p - label) ** 2)

    def ll_w(self, w, p0, p1, p2, label):
        """Cross-entropy of the weighted mixture against one-hot labels."""
        p = self._mixture(w, p0, p1, p2)
        N = p.shape[0]
        # Epsilon matches TemperatureScaling and avoids 0 * log(0) = NaN
        # when a mixture probability underflows to zero.
        return -np.sum(label * np.log(p + 1e-12)) / N

    def brier_w(self, w, p0, p1, p2, label):
        """Brier score of the weighted mixture against one-hot labels."""
        p = self._mixture(w, p0, p1, p2)
        return np.mean(np.sum((p - label) ** 2, axis=1))

    def fit(self, logits, labels):
        """
        Fit the Ensemble Temperature Scaling calibrator.

        First fits a TemperatureScaling model with the same loss, then
        optimizes the three mixture weights under the constraints
        ``0 <= w_i <= 1`` and ``sum(w) == 1``.

        Args:
            logits (np.ndarray): Uncalibrated logits, shape (N, num_classes).
            labels (np.ndarray): True labels, either integer class indices of
                shape (N,) / (N, 1) or one-hot rows of shape (N, num_classes).

        Returns:
            self

        Raises:
            ValueError: If ``self.loss`` is not one of 'ce', 'mse', 'brier'.
        """
        loss_fns = {'ce': self.ll_w, 'mse': self.mse_w, 'brier': self.brier_w}
        if self.loss not in loss_fns:
            raise ValueError(
                f"Unsupported loss {self.loss!r}; expected one of {tuple(loss_fns)}."
            )
        loss_fn = loss_fns[self.loss]

        logits = np.asarray(logits)
        labels = np.asarray(labels)

        # Check if labels are one-hot encoded, if not, convert them.
        # Flatten first so (N, 1) labels become (N, C) rather than (N, 1, C).
        if labels.ndim == 1 or labels.shape[1] == 1:
            num_classes = logits.shape[1]
            labels = one_hot_encode(labels.reshape(-1), num_classes)

        n_class = logits.shape[1]

        # First, fit temperature scaling
        ts = TemperatureScaling(self.loss)
        self.temperature = ts.fit(logits, labels).temperature

        # The three ensemble components: tempered, raw, and uniform.
        p0 = softmax(logits / self.temperature, axis=1)
        p1 = softmax(logits, axis=1)
        p2 = np.ones_like(p0) / n_class

        constraints = ({'type': 'eq', 'fun': lambda x: np.sum(x) - 1})
        bounds = [(0, 1)] * 3
        # Start from pure temperature scaling (all weight on p0).
        self.weights = optimize.minimize(
            loss_fn,
            (1.0, 0.0, 0.0),
            args=(p0, p1, p2, labels),
            method='SLSQP',
            constraints=constraints,
            bounds=bounds,
            tol=1e-12,
        ).x

        self.is_fitted = True
        return self

    def transform(self, logits):
        """
        Apply Ensemble Temperature Scaling calibration to new data.

        Args:
            logits (np.ndarray): Uncalibrated logits to be transformed.

        Returns:
            np.ndarray: Calibrated probabilities.

        Raises:
            ValueError: If temperature or weights are not available yet.
        """
        # Guard on the parameters themselves so manually assigned values
        # keep working, while an unfitted instance gets a clear error.
        if self.temperature is None or self.weights is None:
            raise ValueError("This instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.")
        p0 = softmax(logits / self.temperature, axis=1)
        p1 = softmax(logits, axis=1)
        p2 = np.ones_like(p0) / logits.shape[1]
        return self.weights[0] * p0 + self.weights[1] * p1 + self.weights[2] * p2

class IsotonicRegressionCalibration(BaseCalibration):
    """
    Zhang et al. Mix-n-Match: Ensemble and Compositional Methods for Uncertainty Calibration in Deep Learning (http://arxiv.org/abs/2003.07329)

    Isotonic Regression calibration method.

    A single isotonic regressor is fitted on all (probability, one-hot label)
    pairs pooled across classes, and applied element-wise at transform time.

    Attributes:
        ir_model: The fitted Isotonic Regression model (None until ``fit``).
    """

    def __init__(self):
        """Initialize the Isotonic Regression calibrator."""
        super().__init__()
        self.ir_model = None

    def fit(self, logits, labels):
        """
        Fit the Isotonic Regression calibrator.

        Args:
            logits (np.ndarray): Uncalibrated logits, shape (N, num_classes).
            labels (np.ndarray): True labels, either integer class indices of
                shape (N,) / (N, 1) or one-hot rows of shape (N, num_classes).

        Returns:
            self
        """
        logits = np.asarray(logits)
        labels = np.asarray(labels)

        # Check if labels are one-hot encoded, if not, convert them.
        # Flatten first so (N, 1) labels become (N, C) rather than (N, 1, C),
        # keeping the flattened labels aligned with the flattened probabilities.
        if labels.ndim == 1 or labels.shape[1] == 1:
            num_classes = logits.shape[1]
            labels = one_hot_encode(labels.reshape(-1), num_classes)

        from sklearn.isotonic import IsotonicRegression
        probs = softmax(logits, axis=1)
        # out_of_bounds='clip' so predictions outside the fitted input range
        # are clipped instead of producing NaN.
        self.ir_model = IsotonicRegression(out_of_bounds='clip')
        self.ir_model.fit(probs.flatten(), labels.flatten())
        self.is_fitted = True
        return self

    def transform(self, logits):
        """
        Apply Isotonic Regression calibration to new data.

        Args:
            logits (np.ndarray): Uncalibrated logits to be transformed.

        Returns:
            np.ndarray: Calibrated probabilities with the same shape as
            ``logits``. Rows are not re-normalized after the element-wise
            mapping.

        Raises:
            ValueError: If no fitted model is available yet.
        """
        # Guard on the model itself so a manually assigned model keeps
        # working, while an unfitted instance gets a clear error.
        if self.ir_model is None:
            raise ValueError("This instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.")
        probs = softmax(logits, axis=1)
        return self.ir_model.predict(probs.flatten()).reshape(logits.shape)


class ParameterizedTemperatureScaling(BaseCalibration):
    """
    Parameterized Temperature Scaling (PTS) calibration method.

    A small fully connected network maps the top-k sorted logits of each
    sample to a positive, sample-specific temperature; the logits are divided
    by that temperature before the softmax.

    Attributes:
        epochs (int): Number of epochs for PTS model tuning. Recomputed in
            ``fit`` from the data size and batch size.
        lr (float): Learning rate of PTS model.
        batch_size (int): Batch size for tuning.
        n_layers (int): Number of layers in the PTS model.
        n_nodes (int): Number of nodes in each hidden layer.
        top_k_logits (int): Top k logits used for tuning.
        model (nn.Module): The PTS neural network model (None until built).
        device (torch.device): Device the model and data are moved to.
    """

    def __init__(self, lr=0.00005, batch_size=1000, n_layers=2, n_nodes=5, top_k_logits=10):
        """
        Initialize the Parameterized Temperature Scaling calibrator.

        The number of epochs is not configurable here; ``fit`` derives it
        from a fixed budget of optimizer steps.

        Args:
            lr (float): Learning rate of PTS model. Default is 0.00005.
            batch_size (int): Batch size for tuning. Default is 1000.
            n_layers (int): Number of layers in the PTS model. Default is 2.
            n_nodes (int): Number of nodes in each hidden layer. Default is 5.
            top_k_logits (int): Top k logits used for tuning. Default is 10.
        """
        super().__init__()
        self.epochs = 100  # placeholder; will be adjusted in fit
        self.lr = lr
        self.batch_size = batch_size
        self.n_layers = n_layers
        self.n_nodes = n_nodes
        self.top_k_logits = top_k_logits
        self.model = None

        # Determine the device to use. The MPS backend attribute does not
        # exist on older torch builds, so look it up defensively.
        mps_backend = getattr(torch.backends, "mps", None)
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif mps_backend is not None and mps_backend.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")

    def build_model(self):
        """Build the PTS neural network model."""
        layers = []
        layers.append(nn.Linear(self.top_k_logits, self.n_nodes))
        layers.append(nn.ReLU())

        for _ in range(self.n_layers - 1):
            layers.append(nn.Linear(self.n_nodes, self.n_nodes))
            layers.append(nn.ReLU())

        layers.append(nn.Linear(self.n_nodes, 1))
        layers.append(nn.Softplus())  # Ensure positive temperature
        return nn.Sequential(*layers)

    def _to_device_tensor(self, values):
        """Convert an array-like or tensor to a float32 tensor on ``self.device``."""
        return torch.as_tensor(values, dtype=torch.float32).to(self.device)

    def _top_k(self, logits):
        """Return each row's ``top_k_logits`` largest logits in descending order."""
        if logits.shape[1] < self.top_k_logits:
            # The first linear layer expects exactly top_k_logits inputs.
            raise ValueError(
                f"top_k_logits={self.top_k_logits} exceeds the number of "
                f"classes ({logits.shape[1]})."
            )
        sorted_logits, _ = torch.sort(logits, dim=1, descending=True)
        return sorted_logits[:, :self.top_k_logits]

    def fit(self, logits, labels, verbose=False):
        """
        Fit the Parameterized Temperature Scaling calibrator.

        Args:
            logits (np.ndarray or torch.Tensor): Uncalibrated logits,
                shape (N, num_classes).
            labels (np.ndarray or torch.Tensor): True labels, either integer
                class indices of shape (N,) / (N, 1) or one-hot rows of
                shape (N, num_classes).
            verbose (bool): If True, show a tqdm progress bar over epochs.

        Returns:
            self

        Raises:
            ValueError: If ``top_k_logits`` exceeds the number of classes.
        """
        logits = self._to_device_tensor(logits)
        num_classes = logits.shape[1]
        if num_classes < self.top_k_logits:
            raise ValueError(
                f"top_k_logits={self.top_k_logits} exceeds the number of "
                f"classes ({num_classes})."
            )

        # Bring labels to numpy so the one-hot helper can index with them.
        if isinstance(labels, torch.Tensor):
            labels = labels.detach().cpu().numpy()
        else:
            labels = np.asarray(labels)

        # Check if labels are one-hot encoded, if not, convert them.
        # Flatten first so (N, 1) labels become (N, C) rather than (N, 1, C).
        if labels.ndim == 1 or labels.shape[1] == 1:
            labels = one_hot_encode(labels.reshape(-1), num_classes)

        labels = self._to_device_tensor(labels)

        # Calculate epochs from a budget of 100,000 optimizer steps. The
        # batch loop below also processes a final partial batch, so the
        # steps per epoch is the ceiling of N / batch_size.
        total_samples = logits.shape[0]
        steps_per_epoch = max(1, -(-total_samples // self.batch_size))
        total_steps = 100000
        self.epochs = max(1, total_steps // steps_per_epoch)

        # Build the model and move it to the selected device
        self.model = self.build_model().to(self.device)

        optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        # Cross-entropy on the tempered logits; the targets are one-hot
        # rows, i.e. class probabilities.
        criterion = nn.CrossEntropyLoss()

        if verbose:
            epoch_iterator = tqdm(range(self.epochs), desc="Epochs")
        else:
            epoch_iterator = range(self.epochs)

        for _ in epoch_iterator:
            for start in range(0, total_samples, self.batch_size):
                batch_logits = logits[start:start + self.batch_size]
                batch_labels = labels[start:start + self.batch_size]

                # Per-sample temperature predicted from the top-k sorted logits
                temperature = self.model(self._top_k(batch_logits))
                calibrated_logits = batch_logits / temperature  # Keep as logits, not probabilities

                loss = criterion(calibrated_logits, batch_labels)  # CrossEntropyLoss expects logits
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        self.is_fitted = True
        return self

    def transform(self, logits):
        """
        Apply Parameterized Temperature Scaling calibration to new data.

        Args:
            logits (np.ndarray or torch.Tensor): Uncalibrated logits to be
                transformed, shape (N, num_classes).

        Returns:
            torch.Tensor: Calibrated probabilities as a CPU tensor.

        Raises:
            ValueError: If the calibrator is not fitted, or if
                ``top_k_logits`` exceeds the number of classes.
        """
        if not self.is_fitted:
            raise ValueError("This ParameterizedTemperatureScaling instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.")

        logits = self._to_device_tensor(logits)
        top_k_logits = self._top_k(logits)

        with torch.no_grad():
            temperature = self.model(top_k_logits)
            calibrated_probs = torch.softmax(logits / temperature, dim=1)

        return calibrated_probs.cpu()

    def save(self, path):
        """
        Save PTS model parameters.

        Raises:
            ValueError: If there is no model to save yet.
        """
        if self.model is None:
            raise ValueError("No model to save. Call 'fit' or 'load' first.")
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        """
        Load PTS model parameters and mark the calibrator as fitted.

        Works on a fresh instance too: the network is built from this
        instance's hyperparameters, which must match those used when the
        parameters were saved.
        """
        if self.model is None:
            self.model = self.build_model()
        self.model = self.model.to(self.device)
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        # Load to CPU first so checkpoints saved on another device still
        # open; load_state_dict copies the values into the model's device.
        state_dict = torch.load(path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.is_fitted = True

