import numpy as np
from scipy import optimize
from scipy.special import softmax
from model_transformer.utils import one_hot_encode
from sklearn.isotonic import IsotonicRegression
from model_transformer.spline import get_spline_calib_func, spline_calibrate
from tqdm import tqdm
from functools import partial

from multiprocessing import Pool, cpu_count

import torch
import torch.nn as nn
import torch.optim as optim

class BaseCalibration:
    """
    Template for calibration methods with a scikit-learn-like API.

    Concrete calibrators override ``fit`` (learn calibration parameters from
    held-out logits and labels) and ``transform`` (apply them to new logits).
    ``fit_transform`` chains the two on the same data.

    Attributes:
        is_fitted (bool): ``False`` until a subclass's ``fit`` marks the
            instance as fitted.
    """

    def __init__(self):
        """Create an unfitted calibrator."""
        self.is_fitted = False

    def fit(self, logits, labels):
        """
        Learn calibration parameters. Must be overridden.

        Args:
            logits (np.ndarray): Uncalibrated logits.
            labels (np.ndarray): True labels.

        Returns:
            self

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses should implement this!")

    def transform(self, logits):
        """
        Calibrate new logits. Must be overridden.

        Args:
            logits (np.ndarray): Uncalibrated logits to be transformed.

        Returns:
            np.ndarray: Calibrated probabilities or logits.

        Raises:
            ValueError: If the instance has not been fitted.
            NotImplementedError: If fitted (the base class has no implementation).
        """
        if self.is_fitted:
            raise NotImplementedError("Subclasses should implement this!")
        raise ValueError("This instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.")

    def fit_transform(self, logits, labels):
        """
        Fit on ``(logits, labels)`` and return the calibrated ``logits``.

        Args:
            logits (np.ndarray): Uncalibrated logits.
            labels (np.ndarray): True labels.

        Returns:
            np.ndarray: Calibrated probabilities or logits.
        """
        fitted = self.fit(logits, labels)
        return fitted.transform(logits)

class TemperatureScaling(BaseCalibration):
    """
    Guo et al. - On Calibration of Modern Neural Networks (https://arxiv.org/abs/1706.04599)

    Temperature Scaling calibration method.

    A single scalar temperature ``T`` is fitted on held-out data by minimising
    the chosen loss of ``softmax(logits / T)`` over the bounded interval
    ``[0.05, 5.0]``. ``transform`` returns the rescaled logits
    ``logits / T`` (not probabilities).

    Attributes:
        loss (str): Loss function to use for optimization ('ce', 'mse', or 'brier').
        temperature (float): The fitted temperature (``None`` until ``fit`` runs).
    """

    _SUPPORTED_LOSSES = ('ce', 'mse', 'brier')

    def __init__(self, loss='mse'):
        """
        Initialize the Temperature Scaling calibrator.

        Args:
            loss (str): Loss function to use for optimization ('ce', 'mse', or 'brier').
        """
        super().__init__()
        self.loss = loss
        self.temperature = None

    def fit(self, logits, labels):
        """
        Fit the Temperature Scaling calibrator.

        Args:
            logits (np.ndarray): Uncalibrated logits, one row per sample.
            labels (np.ndarray): True labels, either one-hot rows or class
                indices (1-D or a single column); indices are one-hot encoded.

        Returns:
            self

        Raises:
            ValueError: If ``self.loss`` is not 'ce', 'mse' or 'brier'.
        """
        # log_softmax avoids log(0) = -inf (and 0 * -inf = nan) that
        # np.log(softmax(...)) produces when a probability underflows.
        from scipy.special import log_softmax

        if self.loss not in self._SUPPORTED_LOSSES:
            raise ValueError(
                f"Unsupported loss {self.loss!r}; expected one of {self._SUPPORTED_LOSSES}."
            )

        # Check if labels are one-hot encoded, if not, convert them
        if len(labels.shape) == 1 or labels.shape[1] == 1:
            labels = one_hot_encode(labels, logits.shape[1])

        def objective(t):
            scaled_logits = logits / t
            if self.loss == 'ce':
                # Mean negative log-likelihood of the true class.
                return -np.sum(labels * log_softmax(scaled_logits, axis=1)) / scaled_logits.shape[0]
            sq_err = (softmax(scaled_logits, axis=1) - labels) ** 2
            if self.loss == 'mse':
                # Mean over every (sample, class) entry.
                return np.mean(sq_err)
            # 'brier': squared error summed over classes, averaged over samples.
            return np.mean(np.sum(sq_err, axis=1))

        self.temperature = optimize.minimize_scalar(
            objective,
            bounds=(0.05, 5.0),
            method='bounded'
        ).x
        self.is_fitted = True
        return self

    def transform(self, logits):
        """
        Apply Temperature Scaling calibration to new data.

        Args:
            logits (np.ndarray): Uncalibrated logits to be transformed.

        Returns:
            np.ndarray: Calibrated logits (``logits / temperature``).

        Raises:
            ValueError: If the calibrator has not been fitted.
        """
        if not self.is_fitted:
            raise ValueError("This TemperatureScaling instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.")
        return logits / self.temperature

class EnsembleTemperatureScaling(BaseCalibration):
    """
    Zhang et al. Mix-n-Match: Ensemble and Compositional Methods for Uncertainty Calibration in Deep Learning (http://arxiv.org/abs/2003.07329)

    Ensemble Temperature Scaling calibration method.

    The calibrated distribution is a row-normalised mixture of three
    components: the temperature-scaled softmax, the uncalibrated softmax and
    the uniform distribution over classes. The temperature is fitted with
    ``TemperatureScaling``; the three mixing weights are then fitted with
    SLSQP subject to ``0 <= w_i <= 1`` and ``sum(w) == 1``.

    Attributes:
        loss (str): Loss function to use for optimization ('ce', 'mse', or 'brier').
        temperature (float): The temperature parameter for scaling.
        weights (np.ndarray): Mixing weights for (temperature-scaled softmax,
            raw softmax, uniform), in that order.
    """

    _SUPPORTED_LOSSES = ('ce', 'mse', 'brier')

    def __init__(self, loss='mse'):
        """
        Initialize the Ensemble Temperature Scaling calibrator.

        Args:
            loss (str): Loss function to use for optimization ('ce', 'mse', or 'brier').
        """
        super().__init__()
        self.loss = loss
        self.temperature = None
        self.weights = None

    @staticmethod
    def _mix(w, p0, p1, p2):
        """Return the row-normalised mixture ``w0*p0 + w1*p1 + w2*p2``."""
        p = w[0] * p0 + w[1] * p1 + w[2] * p2
        return p / np.sum(p, 1)[:, None]

    def mse_w(self, w, p0, p1, p2, label):
        """MSE objective for the weights: mean over all (sample, class) entries."""
        p = self._mix(w, p0, p1, p2)
        return np.mean((p - label) ** 2)

    def brier_w(self, w, p0, p1, p2, label):
        """Brier objective for the weights: per-sample squared error summed over classes, averaged."""
        p = self._mix(w, p0, p1, p2)
        return np.mean(np.sum((p - label) ** 2, axis=1))

    def ll_w(self, w, p0, p1, p2, label):
        """Cross-entropy objective for the weights (mean negative log-likelihood)."""
        p = self._mix(w, p0, p1, p2)
        # Clip exact zeros to the smallest positive float: an underflowed
        # probability on a non-target class would otherwise give 0 * -inf = nan.
        log_p = np.log(np.clip(p, np.finfo(p.dtype).tiny, None))
        return -np.sum(label * log_p) / p.shape[0]

    def fit(self, logits, labels):
        """
        Fit the Ensemble Temperature Scaling calibrator.

        Args:
            logits (np.ndarray): Uncalibrated logits, one row per sample.
            labels (np.ndarray): True labels, either one-hot rows or class
                indices (1-D or a single column); indices are one-hot encoded.

        Returns:
            self

        Raises:
            ValueError: If ``self.loss`` is not 'ce', 'mse' or 'brier'.
        """
        if self.loss not in self._SUPPORTED_LOSSES:
            raise ValueError(
                f"Unsupported loss {self.loss!r}; expected one of {self._SUPPORTED_LOSSES}."
            )
        n_class = logits.shape[1]

        # First, fit temperature scaling (it encodes labels itself).
        ts = TemperatureScaling(self.loss)
        self.temperature = ts.fit(logits, labels).temperature

        # The weight objectives compare (N, C) probabilities to labels
        # elementwise, so class-index labels must be one-hot encoded here too.
        if len(labels.shape) == 1 or labels.shape[1] == 1:
            labels = one_hot_encode(labels, n_class)

        p0 = softmax(logits / self.temperature, axis=1)
        p1 = softmax(logits, axis=1)
        p2 = np.ones_like(p0) / n_class

        objective = {'ce': self.ll_w, 'mse': self.mse_w, 'brier': self.brier_w}[self.loss]
        constraints = {'type': 'eq', 'fun': lambda x: np.sum(x) - 1}
        bounds = [(0, 1)] * 3
        self.weights = optimize.minimize(
            objective,
            (1.0, 0.0, 0.0),
            args=(p0, p1, p2, labels),
            method='SLSQP',
            constraints=constraints,
            bounds=bounds,
            tol=1e-12,
        ).x

        self.is_fitted = True
        return self

    def transform(self, logits):
        """
        Apply Ensemble Temperature Scaling calibration to new data.

        Args:
            logits (np.ndarray): Uncalibrated logits to be transformed.

        Returns:
            np.ndarray: Calibrated probabilities, normalised to sum to 1 per
            row exactly as during fitting.

        Raises:
            ValueError: If the calibrator has not been fitted.
        """
        if not self.is_fitted:
            raise ValueError("This EnsembleTemperatureScaling instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.")
        p0 = softmax(logits / self.temperature, axis=1)
        p1 = softmax(logits, axis=1)
        p2 = np.ones_like(p0) / logits.shape[1]
        return self._mix(self.weights, p0, p1, p2)

class IsotonicRegressionCalibration(BaseCalibration):
    """
    Zhang et al. Mix-n-Match: Ensemble and Compositional Methods for Uncertainty Calibration in Deep Learning (http://arxiv.org/abs/2003.07329)

    Isotonic Regression calibration method.

    A single isotonic map is fitted over all (sample, class) softmax
    probabilities against the matching one-hot targets, and then applied
    elementwise to new probabilities. Output rows are not renormalised.

    Attributes:
        ir_model: The fitted Isotonic Regression model (``None`` until fitted).
    """

    def __init__(self):
        """Initialize the Isotonic Regression calibrator."""
        super().__init__()
        self.ir_model = None

    def fit(self, logits, labels):
        """
        Fit the Isotonic Regression calibrator.

        Args:
            logits (np.ndarray): Uncalibrated logits, one row per sample.
            labels (np.ndarray): True labels, either one-hot rows or class
                indices (1-D or a single column); indices are one-hot encoded
                so that they line up with the flattened probabilities.

        Returns:
            self
        """
        probs = softmax(logits, axis=1)
        # Flattened probs have N*C entries; class-index labels would have only N.
        if len(labels.shape) == 1 or labels.shape[1] == 1:
            labels = one_hot_encode(labels, logits.shape[1])
        self.ir_model = IsotonicRegression(out_of_bounds='clip')
        self.ir_model.fit(probs.flatten(), labels.flatten())
        self.is_fitted = True
        return self

    def transform(self, logits):
        """
        Apply Isotonic Regression calibration to new data.

        Args:
            logits (np.ndarray): Uncalibrated logits to be transformed.

        Returns:
            np.ndarray: Calibrated per-class probabilities, same shape as ``logits``.

        Raises:
            ValueError: If the calibrator has not been fitted.
        """
        if not self.is_fitted:
            raise ValueError("This IsotonicRegressionCalibration instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.")
        probs = softmax(logits, axis=1)
        return self.ir_model.predict(probs.flatten()).reshape(logits.shape)

class SplineCalibration(BaseCalibration):
    """
    Gupta et al. Spline Calibration (https://arxiv.org/abs/2006.12800)

    Spline calibration method.

    This class implements the Spline calibration method as described in the
    spline_tomani.py file, adapted to fit the BaseCalibration interface.
    Unlike the other calibrators here, ``transform`` needs the labels as well
    as the logits and returns a ``(scores, labels)`` pair.

    Attributes:
        spline_method (str): The spline method to use ('natural', 'parabolic', or 'cubic').
        splines (int): The number of spline knots to use.
        frecal (callable): The recalibration function. It is built from the
            data passed to ``fit`` on the first ``transform`` call and then
            reused until the next ``fit``.
        val_logits, val_labels: The calibration data passed to ``fit``.
    """

    def __init__(self, spline_method='natural', splines=6):
        """
        Initialize the Spline calibrator.

        Args:
            spline_method (str): The spline method to use ('natural', 'parabolic', or 'cubic').
            splines (int): The number of spline knots to use.
        """
        super().__init__()
        # NOTE(review): spline_method and splines are stored but never passed
        # to get_spline_calib_func; confirm whether that helper should get them.
        self.spline_method = spline_method
        self.splines = splines
        self.frecal = None
        self.val_logits = None
        self.val_labels = None

    def fit(self, logits, labels):
        """
        Fit the Spline calibrator by storing the calibration data.

        Args:
            logits (np.ndarray): Uncalibrated logits.
            labels (np.ndarray): True labels.

        Returns:
            self
        """
        self.val_logits = logits
        self.val_labels = labels
        # Drop any recalibration function built from a previous fit.
        self.frecal = None
        self.is_fitted = True
        return self

    def transform(self, logits, labels):
        """
        Apply Spline calibration to new data.

        Args:
            logits (np.ndarray): Uncalibrated logits to be transformed.
            labels (np.ndarray): Labels for ``logits``; required by ``spline_calibrate``.

        Returns:
            tuple[np.ndarray, np.ndarray]: The calibrated scores and the labels
            returned by ``spline_calibrate``.

        Raises:
            ValueError: If the calibrator has not been fitted.
        """
        if not self.is_fitted:
            raise ValueError("This SplineCalibration instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.")

        # Build the recalibration function once per fit instead of on every call.
        if self.frecal is None:
            self.frecal, _, _ = get_spline_calib_func(self.val_logits, self.val_labels)
        scores, calibrated_labels, _ = spline_calibrate(self.frecal, logits, labels, n=-1, inclusive=False)

        return np.array(scores), np.array(calibrated_labels)

    def fit_transform(self, logits, labels):
        """
        Fit on ``(logits, labels)`` and calibrate the same data.

        Overridden because this class's ``transform`` also requires ``labels``.
        The inherited version calls ``transform(logits)`` and raises TypeError.

        Args:
            logits (np.ndarray): Uncalibrated logits.
            labels (np.ndarray): True labels.

        Returns:
            tuple[np.ndarray, np.ndarray]: Same as ``transform``.
        """
        return self.fit(logits, labels).transform(logits, labels)

class ParameterizedTemperatureScaling(BaseCalibration):
    """
    Parameterized Temperature Scaling (PTS) calibration method.

    This class implements the PTS method for calibrating neural network outputs
    as described in the paper. A small MLP maps each sample's top-k sorted
    logits to a positive per-sample temperature (Softplus output). The logits
    are divided by that temperature, and the MLP is trained with cross-entropy.

    Attributes:
        epochs (int): Number of epochs for PTS model tuning. It is recomputed
            in ``fit`` so that training runs about ``total_steps`` optimizer steps.
        total_steps (int): Target number of optimizer steps for ``fit``.
        lr (float): Learning rate of PTS model.
        batch_size (int): Batch size for tuning.
        n_layers (int): Number of hidden layers in the PTS model.
        n_nodes (int): Number of nodes in each hidden layer.
        top_k_logits (int): Top k logits used for tuning.
        model (nn.Module): The PTS neural network model (``None`` until fitted or loaded).
        device (torch.device): Device used for the model and input tensors.
    """

    def __init__(self, lr=0.00005, batch_size=1000, n_layers=2, n_nodes=5, top_k_logits=10,
                 total_steps=100000):
        """
        Initialize the Parameterized Temperature Scaling calibrator.

        Args:
            lr (float): Learning rate of PTS model. Default is 0.00005.
            batch_size (int): Batch size for tuning. Default is 1000.
            n_layers (int): Number of hidden layers in the PTS model. Default is 2.
            n_nodes (int): Number of nodes in each hidden layer. Default is 5.
            top_k_logits (int): Top k logits used for tuning. Default is 10.
            total_steps (int): Target number of optimizer steps used to derive
                the epoch count in ``fit``. Default is 100000.
        """
        super().__init__()
        self.epochs = 100  # placeholder; recomputed in fit from total_steps
        self.total_steps = total_steps
        self.lr = lr
        self.batch_size = batch_size
        self.n_layers = n_layers
        self.n_nodes = n_nodes
        self.top_k_logits = top_k_logits
        self.model = None
        self.device = self._select_device()

    @staticmethod
    def _select_device():
        """Return CUDA if available, else Apple MPS if available, else CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        # torch.backends.mps does not exist on older torch builds.
        mps = getattr(torch.backends, "mps", None)
        if mps is not None and mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def build_model(self):
        """Build the PTS neural network model."""
        layers = [nn.Linear(self.top_k_logits, self.n_nodes), nn.ReLU()]
        for _ in range(self.n_layers - 1):
            layers.append(nn.Linear(self.n_nodes, self.n_nodes))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(self.n_nodes, 1))
        layers.append(nn.Softplus())  # Ensure positive temperature
        return nn.Sequential(*layers)

    def _to_logits_tensor(self, logits):
        """Convert array-like or tensor logits to float32 on ``self.device`` (the model's dtype)."""
        return torch.as_tensor(logits, dtype=torch.float32).to(self.device)

    def _to_target_tensor(self, labels, n_classes):
        """Convert labels to a CrossEntropyLoss target: int64 class indices or float32 per-class targets."""
        target = torch.as_tensor(labels)
        if target.dim() == 2 and target.shape[1] == 1 and n_classes > 1:
            target = target.squeeze(1)
        if target.dim() == 1:
            return target.long().to(self.device)
        return target.float().to(self.device)

    def _top_k(self, logits):
        """Return the ``top_k_logits`` largest logits of each row, sorted descending."""
        sorted_logits, _ = torch.sort(logits, dim=1, descending=True)
        return sorted_logits[:, :self.top_k_logits]

    def fit(self, logits, labels, verbose=False):
        """
        Fit the Parameterized Temperature Scaling calibrator.

        Args:
            logits (np.ndarray | torch.Tensor): Uncalibrated logits, one row per sample.
            labels (np.ndarray | torch.Tensor): True labels, either class indices
                (1-D or a single column) or per-class targets such as one-hot rows.
            verbose (bool): Show a tqdm progress bar over epochs.

        Returns:
            self

        Raises:
            ValueError: If there are fewer classes than ``top_k_logits``.
        """
        logits = self._to_logits_tensor(logits)
        n_classes = logits.shape[1]
        if n_classes < self.top_k_logits:
            raise ValueError(
                f"top_k_logits={self.top_k_logits} exceeds the number of classes ({n_classes})."
            )
        labels = self._to_target_tensor(labels, n_classes)

        # The batch loop below also runs the trailing partial batch, so count
        # steps with a ceiling division to hit roughly total_steps updates.
        total_samples = logits.shape[0]
        steps_per_epoch = max(1, -(-total_samples // self.batch_size))
        self.epochs = max(1, self.total_steps // steps_per_epoch)

        self.model = self.build_model().to(self.device)
        self.model.train()

        optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        criterion = nn.CrossEntropyLoss()  # expects logits, not probabilities

        epoch_iterator = tqdm(range(self.epochs), desc="Epochs") if verbose else range(self.epochs)

        for _ in epoch_iterator:
            for start in range(0, total_samples, self.batch_size):
                batch_logits = logits[start:start + self.batch_size]
                batch_labels = labels[start:start + self.batch_size]

                temperature = self.model(self._top_k(batch_logits))
                calibrated_logits = batch_logits / temperature

                loss = criterion(calibrated_logits, batch_labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        self.is_fitted = True
        return self

    def transform(self, logits):
        """
        Apply Parameterized Temperature Scaling calibration to new data.

        Args:
            logits (np.ndarray | torch.Tensor): Uncalibrated logits to be transformed.

        Returns:
            np.ndarray: Calibrated probabilities.

        Raises:
            ValueError: If the calibrator has not been fitted.
        """
        if not self.is_fitted:
            raise ValueError("This ParameterizedTemperatureScaling instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.")

        logits = self._to_logits_tensor(logits)
        self.model.eval()
        with torch.no_grad():
            temperature = self.model(self._top_k(logits))
            calibrated_probs = torch.softmax(logits / temperature, dim=1)

        return calibrated_probs.cpu().numpy()

    def save(self, path):
        """Save PTS model parameters (state dict) to ``path``."""
        if self.model is None:
            raise ValueError("This ParameterizedTemperatureScaling instance has no model to save. Call 'fit' first.")
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        """
        Load PTS model parameters from ``path`` and mark the instance fitted.

        The model is built from the current hyperparameters if it does not
        exist yet, so this also works on a freshly constructed instance.
        """
        if self.model is None:
            self.model = self.build_model()
        self.model = self.model.to(self.device)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.is_fitted = True


