from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn.functional as F
from scipy.stats import norm


def _interval_coverage_and_width(mean, std, targets, level):
    """Return (coverage, mean width) of the central Gaussian interval at `level`."""
    # Two-sided interval: (1 - level) / 2 probability mass is left in each tail.
    # Cast to a Python float so a NumPy scalar is never mixed into tensor arithmetic.
    z_score = float(norm.ppf(1 - (1 - level) / 2))
    lower_bound = mean - z_score * std
    upper_bound = mean + z_score * std

    # Coverage (PICP - Prediction Interval Coverage Probability)
    in_interval = (targets >= lower_bound) & (targets <= upper_bound)
    coverage = torch.mean(in_interval.float()).item()
    width = torch.mean(upper_bound - lower_bound).item()
    return coverage, width


def compute_metrics(predictions_mean: torch.Tensor, predictions_var: torch.Tensor,
                    targets: torch.Tensor,
                    confidence_levels: Optional[List[float]] = None) -> Dict[str, float]:
    """
    Compute regression metrics including MSE, NLL, MAE, CRPS, Coverage, Width, and ACE.

    The predictive distribution is treated as an independent Gaussian per
    element, with mean `predictions_mean` and variance `predictions_var`.

    Parameters
    ----------
    predictions_mean : torch.Tensor
        Predicted means
    predictions_var : torch.Tensor
        Predicted variances
    targets : torch.Tensor
        True target values
    confidence_levels : List[float], optional
        Confidence levels for coverage computation. Defaults to
        0.1, 0.2, ..., 0.9 when omitted or None.

    Returns
    -------
    Dict[str, float]
        Keys 'mse', 'nll', 'mae', 'crps', 'coverage_95', 'width_95', 'ace',
        plus 'coverage_<p>' and 'width_<p>' for every confidence level, where
        <p> is the level expressed as a percentage rounded to the nearest
        integer. All values are plain Python floats.
    """
    if confidence_levels is None:
        confidence_levels = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

    # Metrics are plain numbers; detach so tensors that require grad can be
    # converted to NumPy below without raising.
    predictions_mean = predictions_mean.detach()
    predictions_var = predictions_var.detach()
    targets = targets.detach()

    predictions_std = torch.sqrt(predictions_var)

    # MSE (Mean Squared Error)
    mse = F.mse_loss(predictions_mean, targets).item()

    # NLL (Negative Log-Likelihood)
    nll = -torch.mean(torch.distributions.Normal(
        predictions_mean,
        predictions_std
    ).log_prob(targets)).item()

    # MAE (Mean Absolute Error)
    mae = torch.mean(torch.abs(predictions_mean - targets)).item()

    # CRPS (Continuous Ranked Probability Score)
    # For normal distribution: CRPS = σ * (z * (2 * Φ(z) - 1) + 2 * φ(z) - 1/√π)
    # where z = (y - μ)/σ
    z = (targets - predictions_mean) / predictions_std
    std_cpu = predictions_std.cpu().numpy()

    # Handle numerical stability. NOTE: the clip caps the contribution of
    # targets lying more than 10 standard deviations from the mean.
    z_cpu = np.clip(z.cpu().numpy(), -10, 10)

    phi_z = norm.cdf(z_cpu)  # CDF
    pdf_z = norm.pdf(z_cpu)  # PDF

    crps_normalized = z_cpu * (2 * phi_z - 1) + 2 * pdf_z - 1 / np.sqrt(np.pi)
    crps = float(np.mean(std_cpu * crps_normalized))

    # Coverage and Width for different confidence levels
    coverage_results = {}
    width_results = {}
    empirical_coverage = []
    nominal_coverage = []

    for level in confidence_levels:
        coverage, width = _interval_coverage_and_width(
            predictions_mean, predictions_std, targets, level
        )

        # round() rather than int(): 0.29 * 100 == 28.999999999999996, which
        # int() would truncate to 28 and mislabel the key.
        percent = int(round(float(level) * 100))
        coverage_results[f'coverage_{percent}'] = coverage
        width_results[f'width_{percent}'] = width

        empirical_coverage.append(coverage)
        nominal_coverage.append(level)

    # ACE (Average Coverage Error): mean |empirical - nominal| over the levels.
    if empirical_coverage:
        ace = float(np.mean(np.abs(
            np.array(empirical_coverage) - np.array(nominal_coverage)
        )))
    else:
        # No levels requested: the average is undefined.
        ace = float('nan')

    # Also compute 95% coverage and width as standard metrics
    coverage_95, width_95 = _interval_coverage_and_width(
        predictions_mean, predictions_std, targets, 0.95
    )

    return {
        'mse': mse,
        'nll': nll,
        'mae': mae,
        'crps': crps,
        'coverage_95': coverage_95,
        'width_95': width_95,
        'ace': ace,
        **coverage_results,
        **width_results
    }

