"""
Baseline Calibration Methods for Comparison with HPC

This module implements standard post-hoc calibration methods used as baselines
in the HPC paper:
- Temperature Scaling (TS)
- Vector Scaling 
- Matrix Scaling
- Dirichlet Calibration
- Histogram Binning
- Isotonic Regression
- Ensemble methods
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, Tuple, Dict, List
from scipy.optimize import minimize
from sklearn.isotonic import IsotonicRegression
from sklearn.calibration import calibration_curve
import warnings


class TemperatureScaling(nn.Module):
    """
    Temperature Scaling calibration method.

    Applies a single learnable temperature parameter T to all logits:
    p' = softmax(z/T)

    This is the most common baseline method mentioned in the paper.
    """

    def __init__(self, temperature: float = 1.0):
        """
        Args:
            temperature: Initial value of the temperature parameter.
        """
        super().__init__()
        # Stored as a 1-element parameter so it broadcasts over any logits shape.
        self.temperature = nn.Parameter(torch.ones(1) * temperature)

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Apply temperature scaling to logits.

        Args:
            logits: Raw model logits (batch_size, num_classes)

        Returns:
            Temperature-scaled probabilities (softmax over dim 1)
        """
        return F.softmax(logits / self.temperature, dim=1)

    def fit(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        lr: float = 0.01,
        max_iter: int = 50
    ) -> float:
        """
        Fit temperature parameter using validation data.

        Minimises the cross-entropy of ``logits / T`` with a single L-BFGS
        ``step`` call (which internally runs up to ``max_iter`` iterations).

        Args:
            logits: Validation logits (N, K)
            targets: True labels (N,)
            lr: Learning rate for optimization
            max_iter: Maximum optimization iterations

        Returns:
            Final NLL loss
        """
        # Only the temperature is being optimised. Detaching keeps backward()
        # from running through (and writing gradients into) whatever graph
        # produced the logits; L-BFGS evaluates the closure several times, so
        # an attached graph would also fail on the second backward pass.
        logits = logits.detach()

        optimizer = torch.optim.LBFGS([self.temperature], lr=lr, max_iter=max_iter)

        def eval_loss():
            optimizer.zero_grad()
            loss = F.cross_entropy(logits / self.temperature, targets)
            loss.backward()
            return loss

        optimizer.step(eval_loss)

        # Report the loss at the fitted temperature without tracking gradients.
        with torch.no_grad():
            final_loss = F.cross_entropy(logits / self.temperature, targets)
            return final_loss.item()


class VectorScaling(nn.Module):
    """
    Vector Scaling - per-class temperature and bias (a multiclass extension
    of Platt scaling).

    Applies a different temperature and bias to each class logit:
    p'_k = exp(z_k/T_k + b_k) / Σ_j exp(z_j/T_j + b_j)
    """

    def __init__(self, num_classes: int):
        """
        Args:
            num_classes: Number of classes K; one temperature and one bias
                are learned per class.
        """
        super().__init__()
        self.num_classes = num_classes
        # Identity initialisation: T = 1, b = 0 leaves the logits unchanged.
        self.temperatures = nn.Parameter(torch.ones(num_classes))
        self.bias = nn.Parameter(torch.zeros(num_classes))

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Apply vector scaling to logits.

        Args:
            logits: Raw model logits (batch_size, num_classes)

        Returns:
            Calibrated probabilities (softmax over dim 1)
        """
        scaled_logits = logits / self.temperatures + self.bias
        return F.softmax(scaled_logits, dim=1)

    def fit(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        lr: float = 0.01,
        max_iter: int = 100
    ) -> float:
        """
        Fit per-class temperature and bias parameters.

        Args:
            logits: Validation logits (N, K)
            targets: True labels (N,)
            lr: Learning rate for optimization
            max_iter: Maximum optimization iterations

        Returns:
            Final NLL loss
        """
        # Only the calibration parameters are optimised. Detaching prevents
        # backward() from touching the graph that produced the logits, which
        # would leak gradients into the caller's model and fail on the second
        # L-BFGS closure evaluation.
        logits = logits.detach()

        optimizer = torch.optim.LBFGS(
            [self.temperatures, self.bias],
            lr=lr,
            max_iter=max_iter
        )

        def eval_loss():
            optimizer.zero_grad()
            scaled_logits = logits / self.temperatures + self.bias
            loss = F.cross_entropy(scaled_logits, targets)
            loss.backward()
            return loss

        optimizer.step(eval_loss)

        # Report the loss at the fitted parameters without tracking gradients.
        with torch.no_grad():
            scaled_logits = logits / self.temperatures + self.bias
            final_loss = F.cross_entropy(scaled_logits, targets)
            return final_loss.item()


class MatrixScaling(nn.Module):
    """
    Matrix Scaling - full affine transformation of logits.

    Applies: z' = W @ z + b
    where W is a learned (K, K) matrix and b is a learned bias, followed by
    a softmax.
    """

    def __init__(self, num_classes: int):
        """
        Args:
            num_classes: Number of classes K.
        """
        super().__init__()
        self.num_classes = num_classes
        # Identity initialisation: W = I, b = 0 leaves the logits unchanged.
        self.weight = nn.Parameter(torch.eye(num_classes))
        self.bias = nn.Parameter(torch.zeros(num_classes))

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Apply matrix scaling to logits.

        Args:
            logits: Raw model logits (batch_size, num_classes)

        Returns:
            Calibrated probabilities (softmax over dim 1)
        """
        # Row-wise form of W @ z + b for a batch of row vectors.
        scaled_logits = torch.matmul(logits, self.weight.T) + self.bias
        return F.softmax(scaled_logits, dim=1)

    def fit(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        lr: float = 0.01,
        max_iter: int = 100
    ) -> float:
        """
        Fit weight matrix and bias parameters.

        Args:
            logits: Validation logits (N, K)
            targets: True labels (N,)
            lr: Learning rate for optimization
            max_iter: Maximum optimization iterations

        Returns:
            Final NLL loss
        """
        # Only W and b are optimised. Detaching prevents backward() from
        # touching the graph that produced the logits, which would leak
        # gradients into the caller's model and fail on the second L-BFGS
        # closure evaluation.
        logits = logits.detach()

        optimizer = torch.optim.LBFGS(
            [self.weight, self.bias],
            lr=lr,
            max_iter=max_iter
        )

        def eval_loss():
            optimizer.zero_grad()
            scaled_logits = torch.matmul(logits, self.weight.T) + self.bias
            loss = F.cross_entropy(scaled_logits, targets)
            loss.backward()
            return loss

        optimizer.step(eval_loss)

        # Report the loss at the fitted parameters without tracking gradients.
        with torch.no_grad():
            scaled_logits = torch.matmul(logits, self.weight.T) + self.bias
            final_loss = F.cross_entropy(scaled_logits, targets)
            return final_loss.item()


class DirichletCalibration:
    """
    Dirichlet-style calibration: an L2-regularised affine map on the logits.

    Learns a (K, K) weight matrix W and a bias b, and predicts
    softmax(W @ z + b). Both W and b are penalised with
    ``reg_lambda * (||W||^2 + ||b||^2)`` during fitting.

    NOTE(review): the map is applied to the raw logits, which makes this
    L2-regularised matrix scaling. Dirichlet calibration as usually described
    applies the map to log-probabilities - confirm which form the paper's
    baseline intends.
    """

    def __init__(self, reg_lambda: float = 1e-3):
        """
        Args:
            reg_lambda: Strength of the L2 penalty on weights and bias.
        """
        self.reg_lambda = reg_lambda
        self.weights = None  # (K, K) numpy array once fitted
        self.bias = None  # (K,) numpy array once fitted

    @staticmethod
    def _log_softmax(scores: np.ndarray) -> np.ndarray:
        """Return the row-wise log-softmax of *scores*, computed stably."""
        shifted = scores - np.max(scores, axis=1, keepdims=True)
        return shifted - np.log(np.sum(np.exp(shifted), axis=1, keepdims=True))

    def fit(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor
    ) -> float:
        """
        Fit Dirichlet calibration parameters.

        Minimises NLL plus the L2 penalty with BFGS, supplying the analytic
        gradient so SciPy does not have to approximate it by finite
        differences over all K*K + K parameters.

        Args:
            logits: Validation logits (N, K)
            targets: True labels (N,)

        Returns:
            Final NLL loss (without the regularisation term)

        Raises:
            ValueError: If ``logits`` is not a non-empty 2-D tensor or the
                number of targets does not match the number of rows.
        """
        logits_np = logits.detach().cpu().numpy().astype(np.float64)
        targets_np = targets.detach().cpu().numpy().astype(np.int64)

        if logits_np.ndim != 2 or logits_np.shape[0] == 0:
            raise ValueError("logits must be a non-empty (N, K) tensor")
        if targets_np.shape[0] != logits_np.shape[0]:
            raise ValueError("logits and targets must have the same length")

        num_samples, num_classes = logits_np.shape
        num_weights = num_classes * num_classes
        sample_idx = np.arange(num_samples)
        reg_lambda = self.reg_lambda

        # Start from the identity map (W = I, b = 0).
        initial_params = np.concatenate([
            np.eye(num_classes).flatten(),  # Weight matrix
            np.zeros(num_classes)  # Bias
        ])

        def unpack(params):
            weights = params[:num_weights].reshape(num_classes, num_classes)
            bias = params[num_weights:]
            return weights, bias

        def log_probs_for(weights, bias):
            return self._log_softmax(np.dot(logits_np, weights.T) + bias)

        def objective(params):
            weights, bias = unpack(params)
            log_probs = log_probs_for(weights, bias)

            nll = -np.mean(log_probs[sample_idx, targets_np])
            reg = reg_lambda * (np.sum(weights**2) + np.sum(bias**2))

            # d(NLL)/d(scores) = (softmax - one_hot) / N
            grad_scores = np.exp(log_probs)
            grad_scores[sample_idx, targets_np] -= 1.0
            grad_scores /= num_samples

            grad_weights = np.dot(grad_scores.T, logits_np) + 2.0 * reg_lambda * weights
            grad_bias = grad_scores.sum(axis=0) + 2.0 * reg_lambda * bias

            return nll + reg, np.concatenate([grad_weights.ravel(), grad_bias])

        # jac=True tells SciPy the objective returns (value, gradient).
        result = minimize(objective, initial_params, method='BFGS', jac=True)

        # Store fitted parameters (kept even if the optimiser reports
        # imperfect convergence, so the best point found is still usable).
        weights, bias = unpack(result.x)
        self.weights = weights.copy()
        self.bias = bias.copy()

        # result.fun includes the penalty; recompute the pure NLL to report.
        final_log_probs = log_probs_for(self.weights, self.bias)
        return float(-np.mean(final_log_probs[sample_idx, targets_np]))

    def predict_proba(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Apply fitted calibration to new logits.

        Args:
            logits: Logits to calibrate (N, K)

        Returns:
            Calibrated probabilities as a float32 tensor on the same device
            as ``logits``.

        Raises:
            ValueError: If called before ``fit``.
        """
        if self.weights is None:
            raise ValueError("Model must be fitted before prediction")

        logits_np = logits.detach().cpu().numpy().astype(np.float64)

        # Apply transformation, then a numerically stable softmax.
        transformed_logits = np.dot(logits_np, self.weights.T) + self.bias
        probs = np.exp(self._log_softmax(transformed_logits))

        return torch.from_numpy(probs).float().to(logits.device)


class HistogramBinning:
    """
    Histogram Binning calibration method.

    Bins predictions by top-label confidence (equal-mass bins taken from the
    validation confidence quantiles) and replaces the confidence of every
    prediction in a bin with that bin's empirical accuracy.
    """

    def __init__(self, num_bins: int = 15):
        """
        Args:
            num_bins: Number of confidence bins.

        Raises:
            ValueError: If ``num_bins`` is smaller than 1.
        """
        if num_bins < 1:
            raise ValueError("num_bins must be at least 1")
        self.num_bins = num_bins
        self.bin_boundaries = None  # (num_bins + 1,) tensor once fitted
        self.bin_calibrators = {}  # bin index -> empirical accuracy (float)

    def _assign_bins(self, confidences: torch.Tensor) -> torch.Tensor:
        """
        Map each confidence to a bin index in [0, num_bins - 1].

        Bin i covers [boundary[i], boundary[i + 1]); the last bin also
        includes its upper edge. Values outside the fitted range fall into
        the first or last bin, so unseen confidences are still calibrated.
        """
        interior = self.bin_boundaries[1:-1].to(
            device=confidences.device, dtype=confidences.dtype
        )
        if interior.numel() == 0:
            # A single bin: everything belongs to bin 0.
            return torch.zeros_like(confidences, dtype=torch.long)
        # right=True yields i with interior[i - 1] <= value < interior[i],
        # i.e. the number of interior edges that are <= value.
        return torch.bucketize(confidences.contiguous(), interior.contiguous(), right=True)

    def fit(
        self,
        probabilities: torch.Tensor,
        targets: torch.Tensor
    ) -> float:
        """
        Fit histogram binning calibrator.

        Args:
            probabilities: Model probabilities (N, K)
            targets: True labels (N,)

        Returns:
            Mean binary cross-entropy between each sample's bin accuracy and
            whether its prediction was correct.

        Raises:
            ValueError: If ``probabilities`` is not a non-empty 2-D tensor.
        """
        if probabilities.dim() != 2 or probabilities.shape[0] == 0:
            raise ValueError("probabilities must be a non-empty (N, K) tensor")

        probabilities = probabilities.detach()
        confidences, predictions = probabilities.max(dim=1)

        # Create bin boundaries based on confidence quantiles. The quantile
        # levels must share dtype and device with the input tensor.
        quantile_levels = torch.linspace(
            0, 1, self.num_bins + 1,
            dtype=confidences.dtype, device=confidences.device
        )
        self.bin_boundaries = torch.quantile(confidences, quantile_levels)

        # Drop calibrators from any previous fit so stale bins cannot leak.
        self.bin_calibrators = {}

        correct = (predictions == targets).to(confidences.dtype)
        bin_index = self._assign_bins(confidences)
        calibrated_confidence = torch.zeros_like(confidences)

        for i in range(self.num_bins):
            members = bin_index == i
            if members.any():
                bin_accuracy = correct[members].mean()
                self.bin_calibrators[i] = bin_accuracy.item()
                calibrated_confidence[members] = bin_accuracy

        # Every sample lies in exactly one bin, so the mean over samples is
        # the per-bin summed loss divided by the total sample count.
        loss = F.binary_cross_entropy(calibrated_confidence, correct, reduction='mean')
        return loss.item()

    def predict_proba(self, probabilities: torch.Tensor) -> torch.Tensor:
        """
        Apply histogram binning calibration.

        The predicted class receives its bin's accuracy; the remaining mass
        (1 - accuracy) is shared among the other classes in proportion to
        their original probabilities (uniformly if they were all zero), so
        each calibrated row sums to 1. Rows whose bin had no validation
        samples are returned unchanged.

        Args:
            probabilities: Model probabilities (N, K)

        Returns:
            Calibrated probabilities (N, K)

        Raises:
            ValueError: If called before ``fit``.
        """
        if self.bin_boundaries is None:
            raise ValueError("Model must be fitted before prediction")

        calibrated_probs = probabilities.clone()
        num_samples, num_classes = probabilities.shape
        if num_samples == 0 or num_classes < 2 or not self.bin_calibrators:
            return calibrated_probs

        confidences, predictions = probabilities.max(dim=1)
        bin_index = self._assign_bins(confidences)

        # Lookup table of bin accuracies; NaN marks bins without a calibrator.
        num_bins = self.bin_boundaries.numel() - 1
        accuracy_table = torch.full(
            (num_bins,), float('nan'),
            dtype=probabilities.dtype, device=probabilities.device
        )
        for i, accuracy in self.bin_calibrators.items():
            accuracy_table[i] = accuracy
        target_confidence = accuracy_table[bin_index]
        has_calibrator = ~torch.isnan(target_confidence)
        if not has_calibrator.any():
            return calibrated_probs

        rows = torch.arange(num_samples, device=probabilities.device)

        # Probability mass of the non-predicted classes.
        others = probabilities.clone()
        others[rows, predictions] = 0
        other_mass = others.sum(dim=1, keepdim=True)

        # Fallback split when the non-predicted classes carry no mass.
        uniform = torch.full_like(probabilities, 1.0 / (num_classes - 1))
        uniform[rows, predictions] = 0

        tiny = torch.finfo(probabilities.dtype).tiny
        shares = torch.where(
            other_mass > 0,
            others / other_mass.clamp_min(tiny),
            uniform
        )

        adjusted = shares * (1 - target_confidence).unsqueeze(1)
        adjusted[rows, predictions] = target_confidence

        calibrated_probs[has_calibrator] = adjusted[has_calibrator]
        return calibrated_probs


class IsotonicRegressionCalibration:
    """
    Isotonic Regression calibration for binary and multiclass problems.

    Fits isotonic regression to map confidences to calibrated probabilities.
    Binary problems use one calibrator on the positive-class probability;
    multiclass problems fit one calibrator per class (one-vs-rest) and
    renormalise the calibrated columns.
    """

    def __init__(self):
        # class index -> fitted IsotonicRegression, or a float constant used
        # when a per-class fit was rejected. Binary problems use key 0 only.
        self.calibrators = {}
        self.is_binary = False

    @staticmethod
    def _calibrate_column(calibrator, scores: np.ndarray) -> np.ndarray:
        """Apply one stored calibrator to a 1-D array of class scores."""
        if hasattr(calibrator, 'predict'):
            return np.asarray(calibrator.predict(scores))
        # Constant fallback. Plain callables are still accepted so objects
        # fitted by an earlier version of this class keep working.
        value = calibrator(scores) if callable(calibrator) else calibrator
        return np.broadcast_to(
            np.asarray(value, dtype=scores.dtype), scores.shape
        ).copy()

    def fit(
        self,
        probabilities: torch.Tensor,
        targets: torch.Tensor
    ) -> float:
        """
        Fit isotonic regression calibrators.

        Args:
            probabilities: Model probabilities (N, K)
            targets: True labels (N,)

        Returns:
            Fitted calibration loss (cross-entropy of the calibrated
            probabilities on the fitting data)
        """
        probabilities = probabilities.detach()
        num_classes = probabilities.shape[1]
        self.is_binary = (num_classes == 2)

        # Build into a fresh dict so calibrators from an earlier fit
        # (possibly with more classes) cannot survive a refit.
        fitted = {}

        if self.is_binary:
            # Binary classification - single calibrator
            confidences = probabilities[:, 1].cpu().numpy()
            binary_targets = (targets == 1).float().cpu().numpy()

            calibrator = IsotonicRegression(out_of_bounds='clip')
            calibrator.fit(confidences, binary_targets)
            fitted[0] = calibrator

        else:
            # Multiclass - one-vs-rest approach
            for class_idx in range(num_classes):
                class_probs = probabilities[:, class_idx].cpu().numpy()
                binary_targets = (targets == class_idx).float().cpu().numpy()

                calibrator = IsotonicRegression(out_of_bounds='clip')
                try:
                    calibrator.fit(class_probs, binary_targets)
                except ValueError:
                    # The fit was rejected: fall back to this class's base
                    # rate. The value is computed now rather than captured in
                    # a closure, which would see a later iteration's
                    # binary_targets.
                    fitted[class_idx] = float(np.mean(binary_targets))
                else:
                    fitted[class_idx] = calibrator

        self.calibrators = fitted

        # Compute calibration loss
        calibrated_probs = self.predict_proba(probabilities)
        loss = F.cross_entropy(torch.log(calibrated_probs + 1e-8), targets)
        return loss.item()

    def predict_proba(self, probabilities: torch.Tensor) -> torch.Tensor:
        """
        Apply isotonic regression calibration.

        Args:
            probabilities: Model probabilities (N, K)

        Returns:
            Calibrated probabilities (N, K). In the multiclass case rows are
            renormalised to sum to 1; a row whose calibrated values are all
            zero falls back to its uncalibrated probabilities.

        Raises:
            ValueError: If called before ``fit``.
        """
        if not self.calibrators:
            raise ValueError("Model must be fitted before prediction")

        probabilities = probabilities.detach()
        calibrated_probs = torch.zeros_like(probabilities)

        def to_column(values: np.ndarray) -> torch.Tensor:
            return torch.as_tensor(
                values, dtype=probabilities.dtype, device=probabilities.device
            )

        if self.is_binary:
            # Binary case
            confidences = probabilities[:, 1].cpu().numpy()
            calibrated_pos = self._calibrate_column(self.calibrators[0], confidences)
            calibrated_probs[:, 0] = to_column(1 - calibrated_pos)
            calibrated_probs[:, 1] = to_column(calibrated_pos)
            return calibrated_probs

        # Multiclass case
        num_classes = probabilities.shape[1]
        for class_idx in range(num_classes):
            if class_idx in self.calibrators:
                class_probs = probabilities[:, class_idx].cpu().numpy()
                calibrated_probs[:, class_idx] = to_column(
                    self._calibrate_column(self.calibrators[class_idx], class_probs)
                )
            else:
                calibrated_probs[:, class_idx] = probabilities[:, class_idx]

        # Renormalize to ensure probabilities sum to 1. Independent
        # one-vs-rest calibrators can map every class of a row to 0, in which
        # case there is nothing to normalise and the input row is kept.
        row_sums = calibrated_probs.sum(dim=1, keepdim=True)
        tiny = torch.finfo(probabilities.dtype).tiny
        calibrated_probs = torch.where(
            row_sums > 0,
            calibrated_probs / row_sums.clamp_min(tiny),
            probabilities
        )

        return calibrated_probs


def create_ensemble_predictions(
    predictions_list: List[torch.Tensor],
    method: str = 'average',
    weights: Optional[List[float]] = None
) -> torch.Tensor:
    """
    Create ensemble predictions from multiple models.

    Args:
        predictions_list: List of probability tensors from different models;
            all tensors must have the same shape
        method: Ensemble method ('average', 'weighted_average')
        weights: Optional per-model weights for 'weighted_average'. They must
            be non-negative with a positive sum and are normalised to sum
            to 1. When omitted, every model is weighted equally.

    Returns:
        Ensemble predictions, same shape as each input tensor

    Raises:
        ValueError: If ``method`` is unknown, ``predictions_list`` is empty,
            ``weights`` is given with 'average', or ``weights`` has the wrong
            length or invalid values.
    """
    if method not in ('average', 'weighted_average'):
        raise ValueError(f"Unknown ensemble method: {method}")
    if len(predictions_list) == 0:
        raise ValueError("predictions_list must contain at least one tensor")

    stacked = torch.stack(predictions_list)

    if method == 'average':
        if weights is not None:
            # Refuse rather than silently ignore the caller's weights.
            raise ValueError("weights are only supported with method='weighted_average'")
        return stacked.mean(dim=0)

    # method == 'weighted_average'
    if weights is None:
        return stacked.mean(dim=0)

    if len(weights) != stacked.shape[0]:
        raise ValueError(
            f"Expected {stacked.shape[0]} weights, got {len(weights)}"
        )

    weight_tensor = torch.as_tensor(weights, dtype=stacked.dtype, device=stacked.device)
    total = weight_tensor.sum()
    if bool((weight_tensor < 0).any()) or not bool(total > 0):
        raise ValueError("weights must be non-negative and sum to a positive value")

    # Shape (M, 1, ..., 1) so the weights broadcast over each model's tensor.
    weight_tensor = (weight_tensor / total).reshape(-1, *([1] * (stacked.dim() - 1)))
    return (stacked * weight_tensor).sum(dim=0)


# Convenience function for fitting all baseline methods
def fit_all_baselines(
    logits: torch.Tensor,
    probabilities: torch.Tensor,
    targets: torch.Tensor,
    methods: Optional[List[str]] = None
) -> Dict[str, object]:
    """
    Fit all baseline calibration methods on validation data.

    Progress and the final loss of each method are printed. A method whose
    fit raises is reported and skipped; unknown method names are reported
    and skipped as well.

    Args:
        logits: Validation logits (N, K)
        probabilities: Validation probabilities (N, K)
        targets: True labels (N,)
        methods: List of methods to fit (None = all)

    Returns:
        Dictionary mapping method names to fitted calibrators (only methods
        that were fitted successfully appear)
    """
    if methods is None:
        methods = [
            'temperature_scaling', 'vector_scaling', 'matrix_scaling',
            'dirichlet', 'histogram_binning', 'isotonic_regression'
        ]

    num_classes = logits.shape[1]

    # Method name -> (factory building an unfitted calibrator, the validation
    # input that calibrator is fitted on).
    registry = {
        'temperature_scaling': (TemperatureScaling, logits),
        'vector_scaling': (lambda: VectorScaling(num_classes), logits),
        'matrix_scaling': (lambda: MatrixScaling(num_classes), logits),
        'dirichlet': (DirichletCalibration, logits),
        'histogram_binning': (HistogramBinning, probabilities),
        'isotonic_regression': (IsotonicRegressionCalibration, probabilities),
    }

    calibrators = {}

    for method in methods:
        if method not in registry:
            # Report typos instead of silently fitting nothing.
            print(f"Skipping unknown method: {method}")
            continue

        print(f"Fitting {method}...")
        factory, fit_input = registry[method]

        # Best-effort: one failing baseline must not stop the others.
        try:
            calibrator = factory()
            loss = calibrator.fit(fit_input, targets)
        except Exception as e:
            print(f"  Failed to fit {method}: {e}")
            continue

        calibrators[method] = calibrator

        summary = f"  Final loss: {loss:.4f}"
        if method == 'temperature_scaling':
            summary += f", Temperature: {calibrator.temperature.item():.4f}"
        print(summary)

    return calibrators


# Example usage and testing
if __name__ == "__main__":
    # Create synthetic validation data (fixed seed for reproducible output)
    torch.manual_seed(42)
    n_samples = 1000
    n_classes = 10

    # Generate overconfident logits: large-magnitude logits paired with
    # labels drawn independently of them.
    logits = torch.randn(n_samples, n_classes) * 3
    probabilities = F.softmax(logits, dim=1)
    targets = torch.randint(0, n_classes, (n_samples,))

    print("Testing baseline calibration methods...")
    print(f"Data: {n_samples} samples, {n_classes} classes")

    # Compute initial metrics
    initial_accuracy = (torch.argmax(probabilities, dim=1) == targets).float().mean()
    initial_nll = F.cross_entropy(logits, targets)

    print(f"Initial accuracy: {initial_accuracy:.4f}")
    print(f"Initial NLL: {initial_nll:.4f}")

    # Fit all baseline methods
    calibrators = fit_all_baselines(logits, probabilities, targets)

    # Test each fitted calibrator (on the same data it was fitted on, so
    # this is a smoke test rather than a held-out evaluation)
    print("\nTesting fitted calibrators:")

    for method_name, calibrator in calibrators.items():
        try:
            if hasattr(calibrator, 'forward'):
                # PyTorch modules operate on logits
                with torch.no_grad():
                    calibrated_probs = calibrator(logits)
            elif isinstance(calibrator, DirichletCalibration):
                # Fitted on logits, so it must be evaluated on logits too
                calibrated_probs = calibrator.predict_proba(logits)
            else:
                # Scikit-learn style calibrators operate on probabilities
                calibrated_probs = calibrator.predict_proba(probabilities)

            # Compute metrics
            cal_accuracy = (torch.argmax(calibrated_probs, dim=1) == targets).float().mean()
            cal_nll = F.cross_entropy(torch.log(calibrated_probs + 1e-8), targets)

            print(f"  {method_name}: Acc={cal_accuracy:.4f}, NLL={cal_nll:.4f}")

        except Exception as e:
            print(f"  {method_name}: Error during evaluation: {e}")

    print("\nBaseline calibration methods test completed.")
