#!/usr/bin/env python3
"""
Continuous Ranked Probability Score (CRPS) and Uncertainty Calibration Metrics

This module implements CRPS and related uncertainty quantification metrics
for evaluating the quality of probabilistic predictions in PINNs.
"""

import numpy as np
import torch
import torch.nn.functional as F
from typing import Union, Tuple, List, Dict
from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns


class CRPSMetric:
    """
    Continuous Ranked Probability Score (CRPS) for probabilistic predictions.

    For a predictive CDF F and an observation y,

        CRPS(F, y) = ∫ (F(x) - 1{x >= y})² dx.

    The score is non-negative, is expressed in the units of y, and lower is
    better. For a point forecast it reduces to the absolute error.
    """

    def __init__(self, device: str = 'cpu'):
        """
        Initialize CRPS metric calculator.

        Args:
            device: Device that inputs are moved to before computing
                ('cpu' or 'cuda')
        """
        self.device = device

    def crps_normal(self, y_true: torch.Tensor, y_pred_mean: torch.Tensor,
                    y_pred_var: torch.Tensor) -> torch.Tensor:
        """
        Calculate the closed-form CRPS of Gaussian predictions.

        For a normal distribution N(μ, σ²) and z = (y - μ) / σ:

            CRPS = σ * ( z * (2*Φ(z) - 1) + 2*φ(z) - 1/√π )

        where φ is the standard normal PDF and Φ is the standard normal CDF.

        Args:
            y_true: True values (N,)
            y_pred_mean: Predicted means (N,)
            y_pred_var: Predicted variances (N,); values below 1e-6 are
                clamped to 1e-6 before taking the square root

        Returns:
            CRPS values (N,), each >= 0 up to floating-point rounding
        """
        y_true = y_true.to(self.device)
        y_pred_mean = y_pred_mean.to(self.device)
        y_pred_var = y_pred_var.to(self.device)

        # Clamp so that sqrt and the division below stay finite.
        y_pred_std = torch.sqrt(torch.clamp(y_pred_var, min=1e-6))

        # Standardized error
        z = (y_true - y_pred_mean) / y_pred_std

        # Plain Python constants: no per-call tensor allocation and no
        # influence on the dtype of the result.
        sqrt_two = 2.0 ** 0.5
        sqrt_pi = torch.pi ** 0.5
        sqrt_two_pi = (2.0 * torch.pi) ** 0.5

        pdf_z = torch.exp(-0.5 * z ** 2) / sqrt_two_pi
        cdf_z = 0.5 * (1.0 + torch.erf(z / sqrt_two))

        # Sign convention matters: at z = 0 this gives
        # σ * (2/√(2π) - 1/√π) ≈ 0.2337σ, which is positive as CRPS must be.
        return y_pred_std * (z * (2.0 * cdf_z - 1.0) + 2.0 * pdf_z - 1.0 / sqrt_pi)

    def crps_ensemble(self, y_true: torch.Tensor, y_pred_samples: torch.Tensor) -> torch.Tensor:
        """
        Calculate the CRPS of ensemble predictions.

        The ensemble is scored through its empirical CDF, for which the CRPS
        integral has the exact form

            CRPS = mean_i |x_i - y| - 1/(2 M²) * Σ_i Σ_j |x_i - x_j|.

        The double sum is evaluated from the sorted members x_(1) <= ... <=
        x_(M) as 2 * Σ_i (2i - M - 1) * x_(i), so no M x M matrix is built.

        Args:
            y_true: True values (N,)
            y_pred_samples: Ensemble predictions (N, M) where M is ensemble size

        Returns:
            CRPS values (N,)

        Raises:
            ValueError: If y_pred_samples is not 2-D or has no members (M == 0)
        """
        y_true = y_true.to(self.device)
        y_pred_samples = y_pred_samples.to(self.device)

        if y_pred_samples.dim() != 2:
            raise ValueError(
                "y_pred_samples must have shape (N, M), got "
                f"{tuple(y_pred_samples.shape)}"
            )
        n_obs, n_members = y_pred_samples.shape
        if n_members == 0:
            raise ValueError("y_pred_samples must contain at least one ensemble member")

        # mean() is undefined for integer tensors.
        if not y_pred_samples.is_floating_point():
            y_pred_samples = y_pred_samples.float()

        # E|X - y|: mean absolute error of the members against the observation.
        observed = y_true.reshape(n_obs, 1)
        abs_error = torch.abs(y_pred_samples - observed).mean(dim=1)

        # 0.5 * E|X - X'| from the order statistics.
        sorted_members, _ = torch.sort(y_pred_samples, dim=1)
        ranks = torch.arange(1, n_members + 1, device=sorted_members.device,
                             dtype=sorted_members.dtype)
        rank_weights = 2.0 * ranks - n_members - 1
        half_spread = (sorted_members * rank_weights).sum(dim=1) / (n_members * n_members)

        return abs_error - half_spread

    def mean_crps(self, y_true: torch.Tensor, y_pred_mean: torch.Tensor,
                  y_pred_var: torch.Tensor) -> float:
        """
        Calculate the mean Gaussian CRPS across all samples.

        Args:
            y_true: True values (N,)
            y_pred_mean: Predicted means (N,)
            y_pred_var: Predicted variances (N,)

        Returns:
            Mean of crps_normal(...) as a Python float
        """
        crps_values = self.crps_normal(y_true, y_pred_mean, y_pred_var)
        return torch.mean(crps_values).item()

    def crps_skill_score(self, y_true: torch.Tensor, y_pred_mean: torch.Tensor,
                         y_pred_var: torch.Tensor, baseline_crps: float) -> float:
        """
        Calculate CRPS skill score relative to a baseline.

        Skill Score = 1 - CRPS_method / CRPS_baseline

        A score of 1 is a perfect forecast, 0 matches the baseline, and
        negative values are worse than the baseline.

        Args:
            y_true: True values (N,)
            y_pred_mean: Predicted means (N,)
            y_pred_var: Predicted variances (N,)
            baseline_crps: Baseline CRPS value; must be non-zero (a Python
                float of 0.0 raises ZeroDivisionError in the division below)

        Returns:
            CRPS skill score
        """
        method_crps = self.mean_crps(y_true, y_pred_mean, y_pred_var)
        return 1 - method_crps / baseline_crps


class CalibrationMetrics:
    """
    Uncertainty calibration metrics for Gaussian predictions.

    All methods take the predictive mean and variance per sample and move
    their inputs to the configured device before computing.
    """

    def __init__(self, device: str = 'cpu'):
        """
        Initialize calibration metrics calculator.

        Args:
            device: Device to use for computations
        """
        self.device = device

    def calibration_error(self, y_true: torch.Tensor, y_pred_mean: torch.Tensor,
                          y_pred_var: torch.Tensor,
                          n_bins: int = 10) -> Dict[str, Union[float, List[float]]]:
        """
        Calculate calibration error from central prediction-interval coverage.

        For n_bins nominal coverage levels p evenly spaced in [0.1, 0.9], the
        central Gaussian interval mean ± z_p * std with z_p = √2 * erfinv(p)
        is built, and the fraction of y_true inside it is compared with p.

        Args:
            y_true: True values (N,)
            y_pred_mean: Predicted means (N,)
            y_pred_var: Predicted variances (N,); clamped to at least 1e-6
            n_bins: Number of nominal coverage levels to evaluate (>= 1)

        Returns:
            Dictionary with:
                'ece': mean of |empirical coverage - nominal level| (float)
                'mce': maximum of the same absolute gaps (float)
                'reliabilities': empirical coverage per level (list of float)
                'confidences': nominal levels (list of float)
                'calibration_errors': absolute gap per level (list of float)

        Raises:
            ValueError: If n_bins is smaller than 1
        """
        if n_bins < 1:
            raise ValueError(f"n_bins must be at least 1, got {n_bins}")

        y_true = y_true.to(self.device)
        y_pred_mean = y_pred_mean.to(self.device)
        y_pred_var = y_pred_var.to(self.device)

        # Ensure variance is positive
        y_pred_std = torch.sqrt(torch.clamp(y_pred_var, min=1e-6))

        # Nominal coverage levels and their two-sided standard-normal
        # quantiles: P(|Z| <= z) = erf(z / √2) = p  =>  z = √2 * erfinv(p).
        levels = torch.linspace(0.1, 0.9, n_bins, device=self.device)
        z_scores = torch.erfinv(levels) * (2.0 ** 0.5)

        calibration_errors = []
        reliabilities = []
        confidences = []

        for level, z_score in zip(levels, z_scores):
            lower_bound = y_pred_mean - z_score * y_pred_std
            upper_bound = y_pred_mean + z_score * y_pred_std

            # Fraction of observations that fall inside the interval
            in_interval = (y_true >= lower_bound) & (y_true <= upper_bound)
            empirical_coverage = torch.mean(in_interval.float()).item()

            nominal = level.item()
            calibration_errors.append(abs(empirical_coverage - nominal))
            reliabilities.append(empirical_coverage)
            confidences.append(nominal)

        return {
            # float(...) so callers get builtin floats rather than np.float64
            'ece': float(np.mean(calibration_errors)),
            'mce': float(np.max(calibration_errors)),
            'reliabilities': reliabilities,
            'confidences': confidences,
            'calibration_errors': calibration_errors
        }

    def sharpness(self, y_pred_var: torch.Tensor) -> float:
        """
        Calculate prediction sharpness as the mean predicted standard deviation.

        Smaller values mean narrower (sharper) predictive distributions.

        Args:
            y_pred_var: Predicted variances (N,); clamped to at least 1e-6

        Returns:
            Mean of sqrt(variance) over all samples
        """
        y_pred_var = y_pred_var.to(self.device)
        y_pred_var = torch.clamp(y_pred_var, min=1e-6)
        return torch.mean(torch.sqrt(y_pred_var)).item()

    def resolution(self, y_true: torch.Tensor, y_pred_mean: torch.Tensor,
                   y_pred_var: torch.Tensor) -> float:
        """
        Calculate prediction resolution (variance of the predicted means).

        Args:
            y_true: True values (N,); not used by the computation
            y_pred_mean: Predicted means (N,)
            y_pred_var: Predicted variances (N,); not used by the computation

        Returns:
            torch.var of the predicted means as a Python float
        """
        y_pred_mean = y_pred_mean.to(self.device)
        return torch.var(y_pred_mean).item()


class UncertaintyAnalyzer:
    """
    Comprehensive uncertainty analysis for PINN predictions.

    Combines CRPS, calibration, sharpness, resolution and point-accuracy
    metrics for one method, and compares or plots the results of several.
    """

    def __init__(self, device: str = 'cpu'):
        """
        Initialize uncertainty analyzer.

        Args:
            device: Device to use for computations
        """
        self.device = device
        self.crps_metric = CRPSMetric(device)
        self.calibration_metrics = CalibrationMetrics(device)

    def analyze_uncertainty(self, y_true: torch.Tensor, y_pred_mean: torch.Tensor,
                            y_pred_var: torch.Tensor, method_name: str = "Method") -> Dict:
        """
        Perform comprehensive uncertainty analysis.

        Args:
            y_true: True values (N,)
            y_pred_mean: Predicted means (N,)
            y_pred_var: Predicted variances (N,)
            method_name: Name of the method for reporting

        Returns:
            Dictionary with keys 'method', 'n_samples', 'crps' (mean, std,
            median, per-sample values as a NumPy array), 'calibration',
            'sharpness', 'resolution' and 'accuracy' (mse, mae, rmse)
        """
        # Move everything to one device up front so the accuracy metrics
        # below also work when the inputs arrive on different devices.
        y_true = y_true.to(self.device)
        y_pred_mean = y_pred_mean.to(self.device)
        y_pred_var = y_pred_var.to(self.device)

        results = {
            'method': method_name,
            'n_samples': len(y_true)
        }

        # CRPS analysis
        crps_values = self.crps_metric.crps_normal(y_true, y_pred_mean, y_pred_var)
        results['crps'] = {
            'mean': torch.mean(crps_values).item(),
            'std': torch.std(crps_values).item(),
            'median': torch.median(crps_values).item(),
            'values': crps_values.cpu().numpy()
        }

        # Calibration analysis
        results['calibration'] = self.calibration_metrics.calibration_error(
            y_true, y_pred_mean, y_pred_var
        )

        # Sharpness and resolution
        results['sharpness'] = self.calibration_metrics.sharpness(y_pred_var)
        results['resolution'] = self.calibration_metrics.resolution(
            y_true, y_pred_mean, y_pred_var
        )

        # Prediction accuracy of the predictive mean
        residual = y_true - y_pred_mean
        mse = torch.mean(residual ** 2).item()
        mae = torch.mean(torch.abs(residual)).item()
        results['accuracy'] = {
            'mse': mse,
            'mae': mae,
            'rmse': np.sqrt(mse)
        }

        return results

    def compare_methods(self, results_dict: Dict[str, Dict]) -> Dict:
        """
        Compare uncertainty quality across multiple methods.

        Args:
            results_dict: Dictionary of method results from analyze_uncertainty

        Returns:
            Dictionary with 'methods' (list of names) and, keyed by method
            name, 'crps_comparison' (mean CRPS), 'calibration_comparison'
            (ece, mce) and 'accuracy_comparison' (mse, mae, rmse)
        """
        comparison = {
            'methods': list(results_dict.keys()),
            'crps_comparison': {},
            'calibration_comparison': {},
            'accuracy_comparison': {}
        }

        for method, results in results_dict.items():
            comparison['crps_comparison'][method] = results['crps']['mean']
            comparison['calibration_comparison'][method] = {
                'ece': results['calibration']['ece'],
                'mce': results['calibration']['mce']
            }
            comparison['accuracy_comparison'][method] = {
                'mse': results['accuracy']['mse'],
                'mae': results['accuracy']['mae'],
                'rmse': results['accuracy']['rmse']
            }

        return comparison

    @staticmethod
    def _draw_bar_panel(ax, labels, values, colors, title: str, ylabel: str) -> None:
        """Draw one bar per method on ax, with a title, y-label and rotated x ticks."""
        ax.bar(labels, values, color=colors)
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.tick_params(axis='x', rotation=45)

    def plot_uncertainty_analysis(self, results_dict: Dict[str, Dict],
                                  save_path: str = None) -> None:
        """
        Create a 2x2 figure comparing CRPS, calibration error, sharpness and RMSE.

        Args:
            results_dict: Dictionary of method results
            save_path: Path to save the plot; nothing is saved when it is
                None or empty
        """
        methods = list(results_dict.keys())
        colors = plt.cm.Set1(np.linspace(0, 1, len(methods)))
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))

        # CRPS comparison
        self._draw_bar_panel(
            axes[0, 0], methods,
            [results_dict[method]['crps']['mean'] for method in methods],
            colors, 'CRPS Comparison', 'CRPS'
        )

        # Calibration error comparison: ECE and MCE side by side per method
        ece_values = [results_dict[method]['calibration']['ece'] for method in methods]
        mce_values = [results_dict[method]['calibration']['mce'] for method in methods]
        x = np.arange(len(methods))
        width = 0.35

        cal_ax = axes[0, 1]
        cal_ax.bar(x - width/2, ece_values, width, label='ECE', color=colors)
        cal_ax.bar(x + width/2, mce_values, width, label='MCE', color=colors, alpha=0.7)
        cal_ax.set_title('Calibration Error Comparison')
        cal_ax.set_ylabel('Calibration Error')
        cal_ax.set_xticks(x)
        cal_ax.set_xticklabels(methods, rotation=45)
        cal_ax.legend()

        # Sharpness comparison
        self._draw_bar_panel(
            axes[1, 0], methods,
            [results_dict[method]['sharpness'] for method in methods],
            colors, 'Sharpness Comparison', 'Sharpness (√Var)'
        )

        # Accuracy comparison
        self._draw_bar_panel(
            axes[1, 1], methods,
            [results_dict[method]['accuracy']['rmse'] for method in methods],
            colors, 'RMSE Comparison', 'RMSE'
        )

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Uncertainty analysis plot saved to: {save_path}")

        plt.show()

    def plot_reliability_diagram(self, results_dict: Dict[str, Dict],
                                 save_path: str = None) -> None:
        """
        Create reliability diagrams for calibration assessment.

        Plots empirical coverage against nominal confidence level for each
        method, together with the diagonal of perfect calibration.

        Args:
            results_dict: Dictionary of method results
            save_path: Path to save the plot; nothing is saved when it is
                None or empty
        """
        colors = plt.cm.Set1(np.linspace(0, 1, len(results_dict)))
        fig, ax = plt.subplots(1, 1, figsize=(10, 8))

        # Perfect calibration line
        ax.plot([0, 1], [0, 1], 'k--', label='Perfect Calibration', linewidth=2)

        for color, (method, results) in zip(colors, results_dict.items()):
            calibration = results['calibration']
            ax.plot(calibration['confidences'], calibration['reliabilities'], 'o-',
                    color=color, label=method, linewidth=2, markersize=6)

        ax.set_xlabel('Confidence Level')
        ax.set_ylabel('Empirical Coverage')
        ax.set_title('Reliability Diagram')
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Reliability diagram saved to: {save_path}")

        plt.show()


def test_crps_implementation():
    """
    Smoke-run the CRPS and calibration metrics on random synthetic data.

    Draws 1000 standard-normal targets, perturbs them slightly to form the
    predicted means, pairs them with variances in [0.5, 0.6), and prints the
    resulting CRPS and calibration summaries. Nothing is asserted; the
    function only verifies that the calls complete.
    """
    print("Testing CRPS implementation...")

    # Synthetic targets and a near-perfect Gaussian forecast of them
    sample_count = 1000
    targets = torch.randn(sample_count)
    mean_noise = torch.randn(sample_count)
    forecast_mean = targets + 0.1 * mean_noise
    forecast_var = 0.5 + 0.1 * torch.rand(sample_count)
    forecast = (targets, forecast_mean, forecast_var)

    # CRPS: per-sample values and their average
    scorer = CRPSMetric()
    per_sample_crps = scorer.crps_normal(*forecast)
    average_crps = scorer.mean_crps(*forecast)

    print(f"Mean CRPS: {average_crps:.4f}")
    print(f"CRPS std: {torch.std(per_sample_crps):.4f}")

    # Calibration summary
    calibration = CalibrationMetrics().calibration_error(*forecast)

    print(f"Expected Calibration Error: {calibration['ece']:.4f}")
    print(f"Maximum Calibration Error: {calibration['mce']:.4f}")

    print("CRPS implementation test completed successfully!")


# When executed as a script (not imported), run the synthetic-data smoke check.
if __name__ == "__main__":
    test_crps_implementation()
