#!/usr/bin/env python3
"""
Continuous Ranked Probability Score (CRPS) calculation utilities.

This module provides functions to calculate CRPS for uncertainty quantification
in Physics-Informed Neural Networks.
"""

import torch
import numpy as np
from typing import Union, List, Tuple, Optional
import warnings


def crps_gaussian(
    observations: torch.Tensor,
    mean: torch.Tensor,
    variance: torch.Tensor,
    eps: float = 1e-8
) -> torch.Tensor:
    """
    Closed-form CRPS of a Gaussian predictive distribution.

    Evaluates ``sigma * (z * erf(z / sqrt(2)) + 2 * pdf(z) - 1 / sqrt(pi))``
    element-wise, where ``sigma = sqrt(variance)``,
    ``z = (observations - mean) / sigma`` and ``pdf`` is the standard normal
    density.

    Args:
        observations: True values [batch_size, output_dim]
        mean: Predicted mean [batch_size, output_dim]
        variance: Predicted variance [batch_size, output_dim]
        eps: Lower bound applied to the variance before taking the square root

    Returns:
        CRPS values [batch_size, output_dim]
    """
    # Flooring the variance keeps the square root real and the division finite.
    sigma = torch.sqrt(torch.clamp(variance, min=eps))
    standardized = (observations - mean) / sigma

    # z * erf(z / sqrt(2)) is the same as z * (2 * Phi(z) - 1).
    erf_term = standardized * torch.erf(standardized / np.sqrt(2))
    # Twice the standard normal density at z.
    density_term = 2 * torch.exp(-standardized**2 / 2) / np.sqrt(2 * np.pi)
    offset = 1 / np.sqrt(np.pi)

    return sigma * (erf_term + density_term - offset)


def crps_ensemble(
    observations: torch.Tensor,
    predictions: torch.Tensor,
    weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Calculate CRPS for ensemble predictions.

    Uses the energy form of the empirical CRPS,
    ``sum_i w_i |x_i - y| - sum_{i<j} w_i w_j |x_i - x_j|``,
    where ``x_i`` are the ensemble members and ``y`` the observation.

    Args:
        observations: True values [batch_size, output_dim]
        predictions: Ensemble predictions [n_ensemble, batch_size, output_dim]
        weights: Optional weights for ensemble members [n_ensemble]; they are
            normalized to sum to one. Uniform weights are used when omitted.

    Returns:
        CRPS values [batch_size, output_dim]

    Raises:
        ValueError: If the ensemble is empty or ``weights`` does not have
            shape [n_ensemble].
    """
    return _weighted_empirical_crps(observations, predictions, weights)


def crps_from_samples(
    observations: torch.Tensor,
    samples: torch.Tensor,
    weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Calculate CRPS from Monte Carlo samples.

    Same estimator as ``crps_ensemble``: the samples are treated as a
    (weighted) empirical distribution.

    Args:
        observations: True values [batch_size, output_dim]
        samples: Monte Carlo samples [n_samples, batch_size, output_dim]
        weights: Optional weights for samples [n_samples]; they are
            normalized to sum to one. Uniform weights are used when omitted.

    Returns:
        CRPS values [batch_size, output_dim]

    Raises:
        ValueError: If there are no samples or ``weights`` does not have
            shape [n_samples].
    """
    return _weighted_empirical_crps(observations, samples, weights)


def _weighted_empirical_crps(
    observations: torch.Tensor,
    members: torch.Tensor,
    weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Energy-form CRPS of a weighted empirical distribution (members on dim 0)."""
    n_members = members.shape[0]
    if n_members == 0:
        raise ValueError("At least one ensemble member / sample is required")

    # Work in one floating dtype so float64 inputs are not forced through
    # float32 accumulators.
    dtype = torch.promote_types(observations.dtype, members.dtype)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    members = members.to(dtype)
    observations = observations.to(dtype)

    if weights is None:
        weights = torch.full(
            (n_members,), 1.0 / n_members, dtype=dtype, device=members.device
        )
    else:
        weights = torch.as_tensor(weights, dtype=dtype, device=members.device)
        if weights.shape != (n_members,):
            raise ValueError(
                f"weights must have shape ({n_members},), got {tuple(weights.shape)}"
            )
        weights = weights / weights.sum()  # Normalize weights

    # Members are deliberately NOT sorted: weight i belongs to member i, and
    # neither term of the score depends on member order.
    column = weights.view(n_members, *([1] * (members.dim() - 1)))

    # First term: weighted mean absolute error of the members.
    accuracy = (column * torch.abs(members - observations)).sum(dim=0)

    # Second term: weighted spread over unordered pairs i < j, which equals
    # half the expected absolute difference between two independent draws.
    # One vectorised pass per member keeps memory at O(n_members * batch).
    spread = members.new_zeros(members.shape[1:])
    for i in range(n_members - 1):
        gaps = torch.abs(members[i + 1:] - members[i])
        spread = spread + weights[i] * (column[i + 1:] * gaps).sum(dim=0)

    # The spread is subtracted: a wider ensemble is rewarded for covering the
    # observation, it is not penalised twice.
    return accuracy - spread


def calculate_crps_for_model(
    model,
    test_data: Tuple[torch.Tensor, torch.Tensor],
    method: str = "ensemble",
    n_samples: int = 100,
    device: str = "cpu"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Calculate CRPS for a given model on test data.

    The model is switched to eval mode and all forward passes run without
    gradient tracking. Inputs and targets are moved to ``device``; the model
    itself is not moved.

    Branch selection:
        * ``"gaussian"`` and the model has ``forward_with_uncertainty``:
          closed-form Gaussian CRPS from the returned (mean, variance).
        * ``"ensemble"`` and the model has ``ensemble``: empirical CRPS over
          the outputs of every ensemble member.
        * ``"samples"``: empirical CRPS over ``n_samples`` stochastic forward
          passes (a randomly chosen ensemble member per draw if the model has
          ``ensemble``, otherwise the model itself with dropout active).
        * anything else: a single deterministic prediction, for which CRPS
          equals the absolute error. A warning is emitted in this case.

    Args:
        model: Trained model (BayesianPINN, RPITPINN, or StandardPINN)
        test_data: Tuple of (inputs, targets)
        method: Method for uncertainty estimation ("ensemble", "gaussian", "samples")
        n_samples: Number of samples for Monte Carlo estimation
        device: Device to run computation on

    Returns:
        Tuple of (CRPS values, predictions). For the "ensemble" branch the
        predictions are the stacked member outputs
        [n_ensemble, batch, output]; every other branch returns a single
        prediction per input.

    Raises:
        ValueError: If ``method`` is "samples" and ``n_samples`` is less than 1.
    """
    model.eval()
    inputs, targets = test_data
    inputs = inputs.to(device)
    targets = targets.to(device)

    with torch.no_grad():
        if method == "gaussian" and hasattr(model, 'forward_with_uncertainty'):
            # Use model's built-in uncertainty estimation
            mean, variance = model.forward_with_uncertainty(inputs)
            crps = crps_gaussian(targets, mean, variance)
            predictions = mean

        elif method == "ensemble" and hasattr(model, 'ensemble'):
            # One forward pass per ensemble member
            predictions = torch.stack(
                [network(inputs) for network in model.ensemble], dim=0
            )  # [n_ensemble, batch, output]
            crps = crps_ensemble(targets, predictions)

        elif method == "samples":
            if n_samples < 1:
                raise ValueError("n_samples must be at least 1")

            if hasattr(model, 'ensemble'):
                # Each draw is the output of a uniformly chosen member
                draws = []
                for _ in range(n_samples):
                    pick = int(torch.randint(0, len(model.ensemble), (1,)))
                    draws.append(model.ensemble[pick](inputs))
            else:
                # model.eval() above disabled dropout, which would make every
                # draw identical; re-enable the dropout layers only for the
                # duration of the sampling loop.
                activated = _activate_dropout(model)
                try:
                    draws = [model(inputs) for _ in range(n_samples)]
                finally:
                    for layer in activated:
                        layer.eval()

            samples = torch.stack(draws, dim=0)  # [n_samples, batch, output]
            crps = crps_from_samples(targets, samples)
            predictions = samples.mean(dim=0)

        else:
            warnings.warn(
                f"Method {method!r} is not available for "
                f"{type(model).__name__}; falling back to a single "
                "deterministic prediction (CRPS equals absolute error).",
                stacklevel=2,
            )
            predictions = model(inputs)
            # For single prediction, CRPS reduces to MAE
            crps = torch.abs(targets - predictions)

    return crps, predictions


def _activate_dropout(model) -> list:
    """Put the model's inactive dropout layers into train mode and return them."""
    dropout_types = tuple(
        getattr(torch.nn, name)
        for name in (
            "Dropout", "Dropout1d", "Dropout2d", "Dropout3d",
            "AlphaDropout", "FeatureAlphaDropout",
        )
        if hasattr(torch.nn, name)  # Dropout1d is absent from older torch
    )

    # Objects without a ``modules`` method expose no layers to switch.
    if not hasattr(model, "modules"):
        return []

    activated = []
    for layer in model.modules():
        if isinstance(layer, dropout_types) and not layer.training:
            layer.train()
            activated.append(layer)
    return activated


def calculate_crps_metrics(
    crps_values: torch.Tensor,
    predictions: torch.Tensor,
    targets: torch.Tensor
) -> dict:
    """
    Calculate comprehensive CRPS metrics.

    Args:
        crps_values: CRPS values [batch_size, output_dim]
        predictions: Model predictions [batch_size, output_dim]. Currently
            unused; kept for interface compatibility.
        targets: True targets [batch_size, output_dim]. Moved to the device
            of ``crps_values`` before use.

    Returns:
        Dictionary with:
            'mean_crps': float, mean over all entries of ``crps_values``
            'crps_per_dim': numpy array, mean CRPS per output dimension
            'relative_crps': numpy array, per-dimension CRPS divided by the
                standard deviation of the targets
            'crps_skill': numpy array, ``1 - crps / climatology_crps``
            'climatology_crps': numpy array, mean absolute deviation of the
                targets from their per-dimension mean
    """
    # Gradients are irrelevant for reporting and would block .numpy().
    crps_values = crps_values.detach()
    # The caller may hand over targets that never left their original device.
    targets = targets.detach().to(crps_values.device)

    # Mean CRPS
    mean_crps = crps_values.mean().item()

    # CRPS per output dimension
    crps_per_dim = crps_values.mean(dim=0)

    # Relative CRPS (normalized by the standard deviation of the targets)
    target_std = torch.sqrt(torch.var(targets, dim=0))
    relative_crps = crps_per_dim / (target_std + 1e-8)

    # CRPS skill score (compared to climatology).
    # The climatological forecast is the per-dimension target mean used as a
    # single deterministic prediction, so its CRPS is its mean absolute error.
    climatology_pred = targets.mean(dim=0, keepdim=True)
    climatology_crps = torch.abs(targets - climatology_pred).mean(dim=0)
    crps_skill = 1 - crps_per_dim / (climatology_crps + 1e-8)

    return {
        'mean_crps': mean_crps,
        'crps_per_dim': crps_per_dim.cpu().numpy(),
        'relative_crps': relative_crps.cpu().numpy(),
        'crps_skill': crps_skill.cpu().numpy(),
        'climatology_crps': climatology_crps.cpu().numpy()
    }


def compare_crps_methods(
    model,
    test_data: Tuple[torch.Tensor, torch.Tensor],
    methods: Optional[List[str]] = None,
    n_samples: int = 100,
    device: str = "cpu"
) -> dict:
    """
    Compare CRPS across different uncertainty estimation methods.

    Each method is evaluated independently on a best-effort basis: if one
    fails, the error is printed and its entry is set to None so the remaining
    methods still run.

    Args:
        model: Trained model
        test_data: Test data tuple (inputs, targets)
        methods: List of methods to compare; defaults to
            ["ensemble", "gaussian", "samples"] when omitted
        n_samples: Number of samples for Monte Carlo
        device: Device to run on

    Returns:
        Dictionary keyed by method name. Each value is either None (the
        method failed) or a dict with 'crps_values' and 'predictions' as
        numpy arrays and 'metrics' as returned by calculate_crps_metrics.
    """
    # None sentinel instead of a list default: a list literal in the
    # signature would be shared (and mutable) across calls.
    if methods is None:
        methods = ["ensemble", "gaussian", "samples"]

    results = {}

    for method in methods:
        try:
            crps, predictions = calculate_crps_for_model(
                model, test_data, method, n_samples, device
            )
            # CRPS lives on `device`; bring the raw targets alongside it so
            # the metrics never mix tensors from different devices.
            targets = test_data[1].to(crps.device)
            metrics = calculate_crps_metrics(crps, predictions, targets)
            results[method] = {
                'crps_values': crps.detach().cpu().numpy(),
                'predictions': predictions.detach().cpu().numpy(),
                'metrics': metrics
            }
        except Exception as e:
            print(f"Error calculating CRPS for method {method}: {e}")
            results[method] = None

    return results
