"""
Evaluation Metrics for Human-Prior Correction

This module implements all evaluation metrics used in the HPC paper:
- Expected Calibration Error (ECE) 
- Negative Log-Likelihood for true labels (NLL_true)
- Negative Log-Likelihood for human distributions (NLL_human)
- Brier Score
- Reliability diagrams 
- Human-centric calibration analysis
- Area Under Risk-Coverage curve (AURC)
"""

import torch
import torch.nn.functional as F
import numpy as np
from typing import Tuple, List, Optional, Dict, Union
import matplotlib.pyplot as plt
from sklearn.metrics import brier_score_loss
import warnings


class CalibrationMetrics:
    """
    Calibration metrics for evaluating HPC and baseline methods.

    Covers traditional calibration (ECE, NLL against true labels, Brier score,
    reliability diagrams) and human alignment (NLL against human label
    distributions, human-targeted reliability diagram).

    All methods take ``probabilities`` as a 2-D tensor whose rows are per-sample
    class probabilities; the predicted class and its confidence are the argmax
    and max along ``dim=1``.
    """

    def __init__(self, num_bins: int = 15, bin_strategy: str = 'uniform'):
        """
        Initialize calibration metrics calculator.

        Args:
            num_bins: Number of confidence bins used by ECE and the diagrams.
            bin_strategy: 'uniform' for equal-width bins on [0, 1]; any other
                value selects adaptive (quantile-based) bins for ECE. The
                reliability diagrams always use uniform bins.
        """
        self.num_bins = num_bins
        self.bin_strategy = bin_strategy

    # ------------------------------------------------------------------
    # Private helpers shared by ECE and the two reliability diagrams
    # ------------------------------------------------------------------

    @staticmethod
    def _unit_interval_grid(num_bins: int, reference: torch.Tensor) -> torch.Tensor:
        """Return ``num_bins + 1`` evenly spaced points on [0, 1] matching ``reference``'s device/dtype."""
        # torch.quantile requires q to share dtype and device with its input,
        # so the grid is built to match the confidences it will be used with.
        dtype = reference.dtype if reference.is_floating_point() else torch.float32
        return torch.linspace(0, 1, num_bins + 1, dtype=dtype, device=reference.device)

    @staticmethod
    def _bin_membership(
        confidences: torch.Tensor,
        bin_boundaries: torch.Tensor,
        index: int
    ) -> torch.Tensor:
        """Return a boolean mask of the samples falling in bin ``index``."""
        lower = bin_boundaries[index]
        upper = bin_boundaries[index + 1]
        # Bins are (lower, upper], except the first one which also includes its
        # lower edge. Without this, adaptive (quantile) binning would drop the
        # minimum-confidence sample, because the first edge equals that minimum.
        if index == 0:
            above_lower = confidences >= lower
        else:
            above_lower = confidences > lower
        return above_lower & (confidences <= upper)

    def _binned_statistics(
        self,
        confidences: torch.Tensor,
        scores: torch.Tensor,
        bin_boundaries: torch.Tensor
    ) -> Tuple[List[int], List[float], List[float]]:
        """
        Compute per-bin sample count, mean score and mean confidence.

        Empty bins report a score of 0.0 and the bin midpoint as confidence.
        All returned values are plain Python numbers.
        """
        counts: List[int] = []
        mean_scores: List[float] = []
        mean_confidences: List[float] = []

        for i in range(bin_boundaries.numel() - 1):
            in_bin = self._bin_membership(confidences, bin_boundaries, i)
            count = int(in_bin.sum().item())

            if count > 0:
                mean_score = scores[in_bin].float().mean().item()
                mean_confidence = confidences[in_bin].mean().item()
            else:
                mean_score = 0.0
                mean_confidence = ((bin_boundaries[i] + bin_boundaries[i + 1]) / 2).item()

            counts.append(count)
            mean_scores.append(mean_score)
            mean_confidences.append(mean_confidence)

        return counts, mean_scores, mean_confidences

    @staticmethod
    def _weighted_gap(
        counts: List[int],
        mean_scores: List[float],
        mean_confidences: List[float]
    ) -> float:
        """Return Σ_m (n_m / N) |score_m - conf_m| over populated bins (0.0 if N == 0)."""
        total = sum(counts)
        if total == 0:
            return 0.0
        return sum(
            (count / total) * abs(conf - score)
            for count, score, conf in zip(counts, mean_scores, mean_confidences)
            if count > 0
        )

    @staticmethod
    def _plot_reliability(
        bin_boundaries: torch.Tensor,
        bar_heights: List[float],
        bin_conf: List[float],
        bar_label: str,
        conf_label: str,
        ideal_label: str,
        xlabel: str,
        ylabel: str,
        title: str,
        save_path: Optional[str]
    ) -> None:
        """Draw one reliability diagram, optionally save it, show it, then close it."""
        num_bins = len(bin_conf)
        positions = list(range(num_bins))

        fig = plt.figure(figsize=(8, 6))
        try:
            # Per-bin score as bars (bar i is centred on x = i with width 1)
            plt.bar(positions, bar_heights, alpha=0.7,
                    width=1.0, edgecolor='black', label=bar_label)

            # Mean confidence of each bin
            plt.plot(positions, bin_conf, 'r-', marker='o',
                     linewidth=2, markersize=6, label=conf_label)

            # The bars cover x in [-0.5, num_bins - 0.5], which corresponds to
            # confidence in [0, 1]; the diagonal must span the same interval.
            plt.plot([-0.5, num_bins - 0.5], [0, 1], 'k--', alpha=0.5,
                     label=ideal_label)

            plt.xlabel(xlabel)
            plt.ylabel(ylabel)
            plt.title(title)
            plt.legend()
            plt.grid(True, alpha=0.3)

            # Two decimals keep adjacent bin labels distinct when bins are
            # narrower than 0.1 (e.g. the default of 15 bins).
            edges = bin_boundaries.tolist()
            bin_labels = [f'{lo:.2f}-{hi:.2f}' for lo, hi in zip(edges[:-1], edges[1:])]
            plt.xticks(positions, bin_labels, rotation=45)

            plt.tight_layout()

            if save_path:
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
            plt.show()
        finally:
            # Release the figure so repeated calls do not accumulate open figures.
            plt.close(fig)

    # ------------------------------------------------------------------
    # Public metrics
    # ------------------------------------------------------------------

    def expected_calibration_error(
        self,
        probabilities: torch.Tensor,
        targets: torch.Tensor,
        return_components: bool = False
    ) -> Union[float, Tuple[float, Dict]]:
        """
        Compute Expected Calibration Error (ECE).

        ECE = Σ_m (n_m/N) |acc_m - conf_m|

        Args:
            probabilities: Predicted probabilities (N, K)
            targets: True labels (N,)
            return_components: Whether to return per-bin breakdown

        Returns:
            ECE value as a float. With ``return_components=True``, a tuple of
            (ECE, dict) where the dict maps ``'bin_<i>'`` to the accuracy,
            confidence, proportion and error of each non-empty bin.
            Empty input yields an ECE of 0.0 (and an empty dict).
        """
        confidences, predictions = torch.max(probabilities, dim=1)
        correct = predictions.eq(targets)

        total = confidences.numel()
        if total == 0:
            # No samples: nothing to bin (and quantiles are undefined).
            return (0.0, {}) if return_components else 0.0

        grid = self._unit_interval_grid(self.num_bins, confidences)
        if self.bin_strategy == 'uniform':
            bin_boundaries = grid
        else:
            # Adaptive binning: bin edges are quantiles of the confidences
            bin_boundaries = torch.quantile(confidences, grid)

        counts, bin_acc, bin_conf = self._binned_statistics(
            confidences, correct, bin_boundaries
        )

        bin_info = {}
        if return_components:
            for i, (count, acc, conf) in enumerate(zip(counts, bin_acc, bin_conf)):
                if count == 0:
                    continue
                bin_info[f'bin_{i}'] = {
                    'accuracy': acc,
                    'confidence': conf,
                    'proportion': count / total,
                    'error': abs(conf - acc)
                }

        ece = self._weighted_gap(counts, bin_acc, bin_conf)

        if return_components:
            return ece, bin_info
        return ece

    def negative_log_likelihood_true(
        self,
        probabilities: torch.Tensor,
        targets: torch.Tensor
    ) -> float:
        """
        Compute NLL with respect to true (one-hot) labels.

        NLL_true = -(1/N) Σ_n log p(y_n | x_n)

        Args:
            probabilities: Predicted probabilities (N, K)
            targets: True labels (N,)

        Returns:
            Average negative log-likelihood
        """
        # Add small epsilon to avoid log(0)
        eps = 1e-8
        log_probs = torch.log(probabilities + eps)
        nll = F.nll_loss(log_probs, targets, reduction='mean')
        return nll.item()

    def negative_log_likelihood_human(
        self,
        probabilities: torch.Tensor,
        human_distributions: torch.Tensor
    ) -> float:
        """
        Compute NLL with respect to human label distributions.

        NLL_human = -(1/N) Σ_n Σ_k h_k^(n) log p_k^(n)

        Args:
            probabilities: Model probabilities (N, K)
            human_distributions: Human label distributions (N, K)

        Returns:
            Average negative log-likelihood to human distributions
        """
        # Add small epsilon to avoid log(0)
        eps = 1e-8
        log_probs = torch.log(probabilities + eps)

        # Cross-entropy of the model under the human distribution, per sample,
        # then averaged over samples
        nll = -(human_distributions * log_probs).sum(dim=1).mean()
        return nll.item()

    def brier_score(
        self,
        probabilities: torch.Tensor,
        targets: torch.Tensor
    ) -> float:
        """
        Compute the multi-class Brier Score.

        BS = (1/N) Σ_n ||p^(n) - y^(n)||_2^2

        Args:
            probabilities: Predicted probabilities (N, K)
            targets: True labels (N,)

        Returns:
            Brier score (squared error summed over classes, averaged over samples)
        """
        # Convert targets to one-hot
        num_classes = probabilities.shape[1]
        targets_onehot = F.one_hot(targets, num_classes).float()

        # Sum squared differences over classes, then average over samples, as in
        # the formula above (a mean over every element would divide by K too).
        squared_error = (probabilities - targets_onehot) ** 2
        brier = squared_error.sum(dim=1).mean()
        return brier.item()

    def reliability_diagram(
        self,
        probabilities: torch.Tensor,
        targets: torch.Tensor,
        save_path: Optional[str] = None,
        title: str = "Reliability Diagram"
    ) -> Dict:
        """
        Generate reliability diagram and compute calibration statistics.

        The diagram always uses uniform bins; the reported ``'ece'`` comes from
        :meth:`expected_calibration_error` and therefore follows
        ``self.bin_strategy``.

        Args:
            probabilities: Predicted probabilities (N, K)
            targets: True labels (N,)
            save_path: Optional path to save figure
            title: Plot title

        Returns:
            Dictionary with ``'bin_accuracies'``, ``'bin_confidences'``,
            ``'bin_counts'`` (lists with one entry per bin) and ``'ece'``.
        """
        confidences, predictions = torch.max(probabilities, dim=1)
        correct = predictions.eq(targets)

        bin_boundaries = self._unit_interval_grid(self.num_bins, confidences)
        bin_counts, bin_acc, bin_conf = self._binned_statistics(
            confidences, correct, bin_boundaries
        )

        self._plot_reliability(
            bin_boundaries, bin_acc, bin_conf,
            bar_label='Accuracy',
            conf_label='Confidence',
            ideal_label='Perfect Calibration',
            xlabel='Confidence Bin',
            ylabel='Accuracy / Confidence',
            title=title,
            save_path=save_path
        )

        return {
            'bin_accuracies': bin_acc,
            'bin_confidences': bin_conf,
            'bin_counts': bin_counts,
            'ece': self.expected_calibration_error(probabilities, targets)
        }

    def human_reliability_diagram(
        self,
        probabilities: torch.Tensor,
        human_distributions: torch.Tensor,
        save_path: Optional[str] = None,
        title: str = "Human-Targeted Reliability Diagram"
    ) -> Dict:
        """
        Generate reliability diagram against human probability distributions.

        Instead of 0/1 correctness, each sample is scored by the probability
        mass that the human distribution puts on the model's predicted class.

        Args:
            probabilities: Model probabilities (N, K)
            human_distributions: Human label distributions (N, K)
            save_path: Optional path to save figure
            title: Plot title

        Returns:
            Dictionary with ``'bin_human_accuracies'``, ``'bin_confidences'``,
            ``'bin_counts'`` (lists with one entry per bin) and ``'human_ece'``
            (0.0 for empty input).
        """
        confidences, predictions = torch.max(probabilities, dim=1)

        # "Human accuracy": human probability mass on the predicted class.
        # gather keeps the lookup on the tensors' own device.
        human_accuracies = human_distributions.gather(
            1, predictions.unsqueeze(1)
        ).squeeze(1)

        bin_boundaries = self._unit_interval_grid(self.num_bins, confidences)
        bin_counts, bin_human_acc, bin_conf = self._binned_statistics(
            confidences, human_accuracies, bin_boundaries
        )

        self._plot_reliability(
            bin_boundaries, bin_human_acc, bin_conf,
            bar_label='Human "Accuracy"',
            conf_label='Model Confidence',
            ideal_label='Perfect Human Calibration',
            xlabel='Model Confidence Bin',
            ylabel='Human Agreement / Model Confidence',
            title=title,
            save_path=save_path
        )

        return {
            'bin_human_accuracies': bin_human_acc,
            'bin_confidences': bin_conf,
            'bin_counts': bin_counts,
            'human_ece': self._weighted_gap(bin_counts, bin_human_acc, bin_conf)
        }


class DecisionUtilityMetrics:
    """
    Metrics for evaluating decision-making utility of calibrated predictions.

    Both metrics rank samples by the model's confidence (max probability along
    ``dim=1``) and look at the error rate among the ``k`` most confident
    samples for ``k = 1..N`` (the selective risk at coverage ``k / N``).
    """

    def __init__(self):
        pass

    @staticmethod
    def _selective_risks(
        probabilities: torch.Tensor,
        targets: torch.Tensor
    ) -> torch.Tensor:
        """Return the error rate among the k most confident samples, for k = 1..N."""
        confidences, predictions = torch.max(probabilities, dim=1)
        errors = predictions.ne(targets).float()

        # Most confident samples first
        order = torch.argsort(confidences, descending=True)
        sorted_errors = errors[order]

        n = sorted_errors.numel()
        # Built on the inputs' device so the division also works for CUDA tensors.
        accepted = torch.arange(
            1, n + 1, device=sorted_errors.device, dtype=sorted_errors.dtype
        )
        return torch.cumsum(sorted_errors, dim=0) / accepted

    def area_under_risk_coverage_curve(
        self,
        probabilities: torch.Tensor,
        targets: torch.Tensor
    ) -> float:
        """
        Compute Area Under Risk-Coverage curve (AURC).

        AURC measures selective prediction quality - lower is better. The curve
        is integrated with the trapezoidal rule over coverages 1/N, 2/N, ..., 1.

        Args:
            probabilities: Predicted probabilities (N, K)
            targets: True labels (N,)

        Returns:
            AURC value. Empty input returns 0.0; a single sample returns its
            own risk (0.0 if correct, 1.0 if wrong).
        """
        risks = self._selective_risks(probabilities, targets)
        n = risks.numel()

        if n == 0:
            return 0.0
        if n == 1:
            # A one-point curve has zero trapezoidal area, which would report a
            # perfect score even for a wrong prediction.
            return risks[0].item()

        coverages = torch.arange(
            1, n + 1, device=risks.device, dtype=risks.dtype
        ) / n
        return torch.trapz(risks, coverages).item()

    def coverage_at_risk_threshold(
        self,
        probabilities: torch.Tensor,
        targets: torch.Tensor,
        risk_threshold: float = 0.05
    ) -> float:
        """
        Compute coverage achieved at a given risk threshold.

        Samples are accepted in order of decreasing confidence until the
        running error rate first exceeds ``risk_threshold``; the coverage just
        before that point is returned.

        Args:
            probabilities: Predicted probabilities (N, K)
            targets: True labels (N,)
            risk_threshold: Maximum acceptable risk level

        Returns:
            Fraction of samples accepted before the risk first exceeds the
            threshold; 1.0 if it never does (including for empty input).
        """
        risks = self._selective_risks(probabilities, targets)
        n = risks.numel()
        if n == 0:
            return 1.0

        # Running risks come from one cumulative sum, so the search is O(N)
        # rather than re-averaging a growing prefix at every step.
        violations = torch.nonzero(risks > risk_threshold)
        if violations.numel() == 0:
            return 1.0  # All samples satisfy risk threshold

        # Zero-based index of the first violation == number of samples accepted
        accepted = int(violations[0].item())
        return accepted / n


class RobustnessMetrics:
    """
    Metrics for evaluating robustness under distribution shift.

    Compares metric values measured on corrupted test sets against the values
    measured on a clean baseline and reports relative improvements in percent.
    """

    # Substrings identifying metrics where a smaller value is better; for these
    # a reduction relative to the baseline counts as a positive improvement.
    _LOWER_IS_BETTER_MARKERS = ('nll', 'ece', 'brier', 'aurc')

    def __init__(self):
        pass

    def corruption_robustness_analysis(
        self,
        baseline_metrics: Dict[str, float],
        corrupted_metrics: Dict[str, Dict[str, float]],  # corruption_type -> metrics
    ) -> Dict[str, float]:
        """
        Analyze robustness across corruption types.

        For every metric present in both the baseline and a corruption's
        metrics, the relative improvement over the baseline is computed in
        percent: a reduction for lower-is-better metrics (names containing
        'nll', 'ece', 'brier' or 'aurc'), an increase otherwise.

        Args:
            baseline_metrics: Metrics on clean test set (name -> value)
            corrupted_metrics: Metrics on corrupted test sets
                (corruption type -> {name -> value})

        Returns:
            Dictionary with ``mean_improvement``, ``std_improvement``,
            ``min_improvement``, ``max_improvement`` over all comparisons, plus
            ``<corruption_type>_improvement`` (mean) for every corruption with
            at least one comparable metric. The four aggregate values are NaN
            when nothing could be compared. Metrics whose baseline value is 0
            are skipped with a warning, since a relative change is undefined.
        """
        all_improvements = []
        corruption_improvements = {}

        for corruption_type, metrics in corrupted_metrics.items():
            for metric_name, value in metrics.items():
                if metric_name not in baseline_metrics:
                    continue

                baseline_val = baseline_metrics[metric_name]
                if baseline_val == 0:
                    warnings.warn(
                        f"Skipping '{metric_name}' for '{corruption_type}': "
                        "baseline value is 0, relative improvement is undefined."
                    )
                    continue

                lowered = metric_name.lower()
                if any(marker in lowered for marker in self._LOWER_IS_BETTER_MARKERS):
                    improvement = (baseline_val - value) / baseline_val * 100  # % reduction
                else:
                    improvement = (value - baseline_val) / baseline_val * 100  # % increase

                all_improvements.append(improvement)
                corruption_improvements.setdefault(corruption_type, []).append(improvement)

        results = {}

        if all_improvements:
            results['mean_improvement'] = float(np.mean(all_improvements))
            results['std_improvement'] = float(np.std(all_improvements))
            results['min_improvement'] = float(np.min(all_improvements))
            results['max_improvement'] = float(np.max(all_improvements))
        else:
            # np.min / np.max raise on empty input; report NaN so the result
            # keeps the same keys.
            nan = float('nan')
            results['mean_improvement'] = nan
            results['std_improvement'] = nan
            results['min_improvement'] = nan
            results['max_improvement'] = nan

        # Per-corruption statistics
        for corruption_type, improvements in corruption_improvements.items():
            results[f'{corruption_type}_improvement'] = float(np.mean(improvements))

        return results


# Convenience function for comprehensive evaluation
def evaluate_calibration_comprehensive(
    probabilities: torch.Tensor,
    targets: torch.Tensor,
    human_distributions: Optional[torch.Tensor] = None,
    method_name: str = "Method",
    save_plots: bool = False,
    save_dir: str = "./plots/"
) -> Dict[str, float]:
    """
    Comprehensive calibration evaluation using all metrics.

    Always computes accuracy, ECE, NLL_true, Brier score, AURC and coverage at
    5% risk. When ``human_distributions`` is given, NLL_human is added. When
    ``save_plots`` is also set, reliability diagrams are drawn and saved under
    ``save_dir`` (created if missing), and ``human_ece`` is added if human
    distributions were given (it is a by-product of the human diagram).

    Args:
        probabilities: Model probabilities (N, K)
        targets: True labels (N,)
        human_distributions: Optional human label distributions (N, K)
        method_name: Name for the method being evaluated; used in plot titles
            and as the file-name prefix of saved diagrams
        save_plots: Whether to draw and save reliability diagrams
        save_dir: Directory to save plots

    Returns:
        Dictionary with all calibration metrics
    """
    import os  # function-scope import: only needed for plot output paths

    metrics = CalibrationMetrics()
    decision_metrics = DecisionUtilityMetrics()

    if save_plots and save_dir:
        # savefig does not create missing directories
        os.makedirs(save_dir, exist_ok=True)

    results = {}

    # Basic calibration metrics
    results['accuracy'] = (torch.argmax(probabilities, dim=1) == targets).float().mean().item()
    results['ece'] = metrics.expected_calibration_error(probabilities, targets)
    results['nll_true'] = metrics.negative_log_likelihood_true(probabilities, targets)
    results['brier_score'] = metrics.brier_score(probabilities, targets)

    # Decision utility metrics
    results['aurc'] = decision_metrics.area_under_risk_coverage_curve(probabilities, targets)
    results['coverage_at_5_percent_risk'] = decision_metrics.coverage_at_risk_threshold(
        probabilities, targets, risk_threshold=0.05
    )

    # Human alignment metrics (if human distributions provided)
    if human_distributions is not None:
        results['nll_human'] = metrics.negative_log_likelihood_human(probabilities, human_distributions)

        # Human reliability diagram
        if save_plots:
            human_stats = metrics.human_reliability_diagram(
                probabilities, human_distributions,
                save_path=os.path.join(save_dir, f"{method_name}_human_reliability.png"),
                title=f"{method_name} - Human Reliability"
            )
            results['human_ece'] = human_stats['human_ece']

    # Standard reliability diagram (called for the saved figure; its returned
    # statistics duplicate values already in `results`)
    if save_plots:
        metrics.reliability_diagram(
            probabilities, targets,
            save_path=os.path.join(save_dir, f"{method_name}_reliability.png"),
            title=f"{method_name} - Reliability Diagram"
        )

    return results


# Example usage and testing
def _run_self_test() -> None:
    """Smoke-test the metrics on synthetic data and print the results."""
    import os  # function-scope import: only needed to prepare the plot directory

    # Fixed seed so repeated runs print the same numbers
    torch.manual_seed(0)

    # Create synthetic test data
    n_samples = 1000
    n_classes = 10

    # Generate synthetic probabilities (poorly calibrated): scaling the logits
    # sharpens the softmax, while the labels are drawn independently of it
    logits = torch.randn(n_samples, n_classes) * 3  # High confidence
    probabilities = F.softmax(logits, dim=1)
    targets = torch.randint(0, n_classes, (n_samples,))

    # Generate synthetic human distributions
    human_distributions = F.softmax(torch.randn(n_samples, n_classes), dim=1)

    print("Testing calibration metrics...")

    # The diagrams are written below this directory; savefig does not create it
    plot_dir = "./plots/"
    os.makedirs(plot_dir, exist_ok=True)

    # Comprehensive evaluation
    results = evaluate_calibration_comprehensive(
        probabilities, targets, human_distributions,
        method_name="Test_Method", save_plots=True, save_dir=plot_dir
    )

    print("\nCalibration Results:")
    for metric, value in results.items():
        print(f"  {metric}: {value:.4f}")

    # Test individual metrics
    metrics = CalibrationMetrics()

    print("\nDetailed ECE breakdown:")
    ece, bin_info = metrics.expected_calibration_error(probabilities, targets, return_components=True)
    print(f"Overall ECE: {ece:.4f}")
    for bin_name, info in list(bin_info.items())[:3]:  # Show first 3 bins
        print(f"  {bin_name}: acc={info['accuracy']:.3f}, conf={info['confidence']:.3f}, "
              f"prop={info['proportion']:.3f}, error={info['error']:.3f}")

    # Test robustness metrics
    print("\nTesting robustness metrics...")
    robustness = RobustnessMetrics()

    baseline_metrics = {'nll_human': 0.8, 'ece': 0.1}
    corrupted_metrics = {
        'gaussian_noise': {'nll_human': 0.7, 'ece': 0.08},
        'motion_blur': {'nll_human': 0.75, 'ece': 0.09}
    }

    robustness_stats = robustness.corruption_robustness_analysis(baseline_metrics, corrupted_metrics)
    print("Robustness analysis:")
    for stat, value in robustness_stats.items():
        print(f"  {stat}: {value:.2f}")


if __name__ == "__main__":
    _run_self_test()
