'''
Code to perform temperature scaling. Adapted from https://github.com/gpleiss/temperature_scaling
'''
import torch
import numpy as np
from torch import nn, optim
from torch.nn import functional as F
import time

from Metrics.metrics import ECELoss


class AdaptiveEpsilonOptimizer:
    """Grid-search helper that picks noise strengths matching target accuracies.

    Given a callable that maps a noise strength (epsilon) to an accuracy, the
    optimizer builds a linearly spaced list of target accuracies and, for each
    target, searches for the epsilon whose accuracy is closest to it.
    """

    def __init__(self, tolerance_factor=0.03, tolerance_factor_last=0.05, max_iterations=50):
        """Store the optimizer settings.

        NOTE: none of the methods of this class read these three attributes;
        they are only stored on the instance.
        """
        self.tolerance_factor = tolerance_factor
        self.tolerance_factor_last = tolerance_factor_last
        self.max_iterations = max_iterations

    def simple_minimize(self, func, x0, args=(), bounds=None, tol=1e-4):
        """Minimize ``func`` over a 1-D interval with a coarse-then-fine grid.

        Args:
            func: Callable invoked as ``func(x, *args)`` returning a comparable
                scalar.
            x0: Initial guess. It is evaluated as-is (it is not clipped to
                ``bounds``) and is returned if no grid point beats it.
            args: Extra positional arguments forwarded to ``func``.
            bounds: ``(lower, upper)`` search interval; defaults to
                ``(0.01, 2.0)``.
            tol: Accepted for interface compatibility; not used by the search.

        Returns:
            Tuple ``(best_x, best_val)`` with the best point found and the
            value of ``func`` there.
        """
        if bounds is None:
            bounds = (0.01, 2.0)
        lower, upper = bounds[0], bounds[1]

        best_x = x0
        best_val = func(x0, *args)

        # Coarse pass: 20 evenly spaced candidates across the whole interval.
        for candidate in np.linspace(lower, upper, 20):
            value = func(candidate, *args)
            if value < best_val:
                best_val = value
                best_x = candidate

        # Fine pass: 20 candidates in a window around the current best point,
        # clipped so the window never leaves the original interval.
        search_range = min(0.1, (upper - lower) / 10)
        fine_lower = max(lower, best_x - search_range)
        fine_upper = min(upper, best_x + search_range)

        for candidate in np.linspace(fine_lower, fine_upper, 20):
            value = func(candidate, *args)
            if value < best_val:
                best_val = value
                best_x = candidate

        return best_x, best_val

    def accuracy_objective(self, epsilon, target_accuracy, model_eval_func):
        """Return ``|target_accuracy - model_eval_func(epsilon)|``."""
        actual_accuracy = model_eval_func(epsilon)
        return abs(target_accuracy - actual_accuracy)

    def calculate_target_accuracies(self, max_accuracy, n_classes, n_levels):
        """Build ``n_levels`` targets from ``max_accuracy`` down to chance level.

        The targets are linearly spaced; the first equals ``max_accuracy`` and
        the last equals ``1 / n_classes``.

        Args:
            max_accuracy: Accuracy assigned to the first level.
            n_classes: Number of classes; ``1 / n_classes`` is the last target.
            n_levels: Number of targets to produce.

        Returns:
            List of ``n_levels`` target accuracies. A single level yields
            ``[max_accuracy]``; a non-positive ``n_levels`` yields ``[]``.
        """
        min_accuracy = 1.0 / n_classes

        if n_levels <= 0:
            return []
        if n_levels == 1:
            # The interpolation below divides by (n_levels - 1); with a single
            # level only the un-noised accuracy is a meaningful target.
            return [max_accuracy]

        return [
            max_accuracy - (max_accuracy - min_accuracy) * level / (n_levels - 1)
            for level in range(n_levels)
        ]

    def optimize_epsilon_sequence(self, model_eval_func, max_accuracy, n_classes,
                                  n_levels, start_epsilon=0.05):
        """Find one epsilon per target accuracy level.

        Level 0 always uses ``epsilon = 0.0``. Every later level runs
        ``simple_minimize`` on ``accuracy_objective`` over ``(0.01, 2.0)``.

        Args:
            model_eval_func: Callable mapping an epsilon to an accuracy.
            max_accuracy: Accuracy used as the first target.
            n_classes: Number of classes (sets the lowest target).
            n_levels: Number of levels / epsilons to produce.
            start_epsilon: Initial guess for level 1 and lower limit for the
                initial guess of later levels.

        Returns:
            Tuple ``(optimized_epsilons, actual_accuracies)``; both lists have
            one entry per level. Each accuracy comes from a fresh call to
            ``model_eval_func`` with the chosen epsilon.
        """
        target_accuracies = self.calculate_target_accuracies(max_accuracy, n_classes, n_levels)

        optimized_epsilons = []
        actual_accuracies = []

        for i, target_acc in enumerate(target_accuracies):
            if i == 0:
                epsilon = 0.0
            else:
                if i == 1:
                    x0 = start_epsilon
                else:
                    # Start slightly above the previous level's epsilon, since
                    # a lower target accuracy is expected to need more noise.
                    x0 = max(optimized_epsilons[-1] + 0.05, start_epsilon)

                epsilon, _ = self.simple_minimize(
                    self.accuracy_objective,
                    x0,
                    args=(target_acc, model_eval_func),
                    bounds=(0.01, 2.0)
                )

            optimized_epsilons.append(epsilon)
            actual_accuracies.append(model_eval_func(epsilon))

        return optimized_epsilons, actual_accuracies


class ModelWithTemperature(nn.Module):
    """
    A thin decorator, which wraps a model with temperature scaling
    model (nn.Module):
        A classification neural network
        NB: Output of the neural network should be the classification logits,
            NOT the softmax (or log softmax)!
    log (bool):
        When true, progress and results are printed to stdout.

    The temperature is stored as a plain Python float (not a parameter) and is
    chosen by a grid search in ``set_temperature``.
    """
    def __init__(self, model, log=True):
        super().__init__()
        self.model = model
        self.temperature = 1.0
        self.log = log
        self.architecture = self.model.architecture if hasattr(self.model, 'architecture') else None

        self.adaptive_optimizer = AdaptiveEpsilonOptimizer()

    def _model_returns_features(self):
        """Return True when the wrapped model is tagged ``architecture == 'CNN'``.

        Such models are called as ``logits, feature = model(x)``; all other
        models are called as ``logits = model(x)``.
        """
        return getattr(self.model, 'architecture', None) == 'CNN'

    @staticmethod
    def _unpack_batch(batch_data):
        """Split a loader batch into ``(inputs, labels)``.

        Accepts ``(x, label)`` or ``(x, label, extra)``; the third element is
        ignored.

        Raises:
            ValueError: If the batch has any other number of elements.
        """
        if len(batch_data) == 2:
            x, label = batch_data
        elif len(batch_data) == 3:
            x, label, _ = batch_data
        else:
            raise ValueError(
                "Expected a batch of 2 or 3 elements (x, label[, extra]), "
                "got {}".format(len(batch_data))
            )
        return x, label

    def forward(self, input):
        """Run the wrapped model and divide its logits by the temperature.

        Returns ``(scaled_logits, feature)`` for models tagged as 'CNN' and
        ``scaled_logits`` otherwise.
        """
        if self._model_returns_features():
            logits, feature = self.model(input)
            return self.temperature_scale(logits), feature
        logits = self.model(input)
        return self.temperature_scale(logits)

    def temperature_scale(self, logits):
        """
        Perform temperature scaling on logits
        """
        # The temperature is a scalar float, so plain division scales every logit.
        return logits / self.temperature

    def apply_gaussian_refinement(self, logits, labels, epsilon_levels=None):
        """
        Apply Gaussian Data Refinement strategy for robust calibration

        Args:
            logits: Original logits [N, C]
            labels: True labels [N,]
            epsilon_levels: List of Gaussian noise strength levels
                (defaults to [0.1, 0.2, 0.3])

        Returns:
            augmented_logits: The original logits followed by one noisy copy
                per epsilon level, concatenated along the first dimension
            augmented_labels: Corresponding labels (the labels repeated once
                per copy)
        """
        if epsilon_levels is None:
            epsilon_levels = [0.1, 0.2, 0.3]

        augmented_logits = [logits]
        augmented_labels = [labels]

        # Add different levels of Gaussian noise to logits
        for epsilon in epsilon_levels:
            noise = torch.randn_like(logits) * epsilon
            augmented_logits.append(logits + noise)
            augmented_labels.append(labels)

        return torch.cat(augmented_logits), torch.cat(augmented_labels)

    def evaluate_accuracy_with_gaussian_noise(self, logits, labels, epsilon):
        """Return the top-1 accuracy of ``logits`` after adding Gaussian noise.

        Args:
            logits: Logits; the prediction is the argmax over dimension 1.
            labels: True labels compared against the predictions.
            epsilon: Noise scale. ``0.0`` evaluates the logits unchanged;
                any other value draws fresh random noise on every call, so
                repeated calls can return different accuracies.

        Returns:
            Accuracy as a Python float.
        """
        if epsilon == 0.0:
            noisy_logits = logits
        else:
            noise = torch.randn_like(logits) * epsilon
            noisy_logits = logits + noise

        # Compute accuracy
        _, predicted = torch.max(noisy_logits, 1)
        correct = (predicted == labels).float()
        return correct.mean().item()

    def adaptive_gaussian_refinement(self, logits, labels, n_classes, n_levels=6):
        """Search for noise levels that degrade accuracy in evenly spaced steps.

        Delegates the search to ``self.adaptive_optimizer``, using the accuracy
        of the un-noised logits as the highest target.

        Args:
            logits: Logits to evaluate.
            labels: True labels.
            n_classes: Number of classes (sets the lowest target accuracy).
            n_levels: Number of noise levels to find.

        Returns:
            optimized_epsilons: One epsilon per level.
            performance_info: Dict with the original accuracy, the target and
                actual accuracies, the number of levels and an
                'optimization_success' flag (always True).
        """
        original_accuracy = self.evaluate_accuracy_with_gaussian_noise(logits, labels, 0.0)
        if self.log:
            print(f"Acc original: {original_accuracy:.4f}")

        def model_eval_func(epsilon):
            return self.evaluate_accuracy_with_gaussian_noise(logits, labels, epsilon)

        optimized_epsilons, actual_accuracies = self.adaptive_optimizer.optimize_epsilon_sequence(
            model_eval_func=model_eval_func,
            max_accuracy=original_accuracy,
            n_classes=n_classes,
            n_levels=n_levels,
            start_epsilon=0.05
        )

        performance_info = {
            'original_accuracy': original_accuracy,
            'target_accuracies': self.adaptive_optimizer.calculate_target_accuracies(
                original_accuracy, n_classes, n_levels
            ),
            'actual_accuracies': actual_accuracies,
            'optimization_success': True,
            'n_levels': n_levels
        }

        if self.log:
            print(f"  epsilon: {[f'{eps:.6f}' for eps in optimized_epsilons]}")
            print(f"  Acc: {[f'{acc:.4f}' for acc in actual_accuracies]}")

        return optimized_epsilons, performance_info

    def set_temperature(self, valid_loader, cross_validate='ece', use_gaussian_refinement=False,
                        gaussian_eps=None, adaptive_gaussian=False, n_classes=None):
        """
        Tune the temperature of the model (using the validation set) with cross-validation on ECE or NLL

        Args:
            valid_loader: Validation data loader yielding (x, label) or
                (x, label, extra) batches
            cross_validate: 'ece' selects the temperature with the lowest ECE;
                any other value selects the one with the lowest NLL
            use_gaussian_refinement: Whether to apply Gaussian Data Refinement
            gaussian_eps: List of Gaussian noise levels for refinement
                (defaults to [0.1, 0.2, 0.3]); in adaptive mode only its
                length is used, as the number of levels
            adaptive_gaussian: Whether to use adaptive gaussian refinement
            n_classes: Number of classes (required for adaptive refinement)

        Returns:
            self, with ``self.temperature`` set to the selected value.

        Raises:
            ValueError: If adaptive refinement is requested without
                ``n_classes``, or a batch does not have 2 or 3 elements.
        """
        if gaussian_eps is None:
            gaussian_eps = [0.1, 0.2, 0.3]

        self.cuda()
        self.model.eval()
        nll_criterion = nn.CrossEntropyLoss().cuda()
        ece_criterion = ECELoss().cuda()

        # First: collect all the logits and labels for the validation set
        logits_list = []
        labels_list = []
        with torch.no_grad():
            for batch_data in valid_loader:
                x, label = self._unpack_batch(batch_data)
                x = x.cuda()
                if self._model_returns_features():
                    logits, _ = self.model(x)
                else:
                    logits = self.model(x)

                logits_list.append(logits)
                labels_list.append(label)
            logits = torch.cat(logits_list).cuda()
            labels = torch.cat(labels_list).cuda()

        if use_gaussian_refinement:
            if adaptive_gaussian:
                if n_classes is None:
                    raise ValueError("n_classes must be specified for adaptive gaussian refinement")

                optimized_eps, perf_info = self.adaptive_gaussian_refinement(
                    logits, labels, n_classes, n_levels=len(gaussian_eps)
                )

                # NOTE(review): the optimizer's first level is epsilon 0.0, so
                # this adds an un-noised copy of the logits on top of the
                # original ones — confirm that the duplication is intended.
                logits, labels = self.apply_gaussian_refinement(logits, labels, optimized_eps)

                if self.log:
                    print(f'use epsilon: {optimized_eps}')
                    print(f'sample num: {logits.shape[0]} ')
            else:
                if self.log:
                    print(f'epsilon: {gaussian_eps}')

                logits, labels = self.apply_gaussian_refinement(logits, labels, gaussian_eps)

                if self.log:
                    print(f'sample num: {logits.shape[0]}')

        # Calculate NLL and ECE before temperature scaling
        # (computed for reference only; the values are not reported or returned)
        before_temperature_nll = nll_criterion(logits, labels).item()
        before_temperature_ece = ece_criterion(logits, labels).item()

        # Grid search: 500 candidate temperatures starting at 0.1 in steps of
        # 0.01. The strict comparisons keep the smallest temperature on ties.
        nll_val = 10 ** 7
        ece_val = 10 ** 7
        T_opt_nll = 1.0
        T_opt_ece = 1.0
        T = 0.1
        for _ in range(500):
            # The temperature is a plain float, so no device transfer is
            # needed when it changes.
            self.temperature = T
            scaled_logits = self.temperature_scale(logits)
            after_temperature_nll = nll_criterion(scaled_logits, labels).item()
            after_temperature_ece = ece_criterion(scaled_logits, labels).item()
            if nll_val > after_temperature_nll:
                T_opt_nll = T
                nll_val = after_temperature_nll

            if ece_val > after_temperature_ece:
                T_opt_ece = T
                ece_val = after_temperature_ece
            T += 0.01

        if cross_validate == 'ece':
            self.temperature = T_opt_ece
        else:
            self.temperature = T_opt_nll

        # Calculate NLL and ECE after temperature scaling
        after_temperature_nll = nll_criterion(self.temperature_scale(logits), labels).item()
        after_temperature_ece = ece_criterion(self.temperature_scale(logits), labels).item()
        if self.log:
            print('best Temp: {:.3f}'.format(self.temperature))
            print('after TS - NLL: {:.3f}, ECE: {:.3f}'.format(after_temperature_nll, after_temperature_ece))

        return self

    def get_temperature(self):
        """Return the current temperature (a float)."""
        return self.temperature