# evaluation/metrics.py
"""
Evaluation metrics for BSNP.
"""

import torch
import numpy as np
from typing import Dict, Optional, Tuple, List
from torch.utils.data import DataLoader


def compute_mse(pred: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None) -> float:
    """
    Compute mean squared error.

    Without a mask this is the mean over every element of ``(pred - target)**2``.
    With a mask, only elements at masked-in points contribute and the average is
    taken over those *elements* (points x output dims), so an all-ones mask gives
    the same result as no mask.

    Args:
        pred: Predictions, shape (batch_size, ...)
        target: Targets, same shape as ``pred``
        mask: Optional mask for valid points. Either the same shape as ``pred``,
            or one dimension smaller (e.g. (batch_size, n_points) for predictions
            of shape (batch_size, output_dim, n_points)), in which case it is
            broadcast over dim 1. Non-zero entries mark valid points; non-binary
            values act as weights.

    Returns:
        MSE value (NaN if the mask selects nothing)
    """
    sq_err = (pred - target) ** 2
    if mask is None:
        return sq_err.mean().item()

    # A per-point mask is broadcast over the output dimension (dim 1).
    weights = mask.unsqueeze(1) if mask.dim() < sq_err.dim() else mask
    weights = weights.to(sq_err.dtype).expand_as(sq_err)
    # torch.where (rather than a plain multiply) keeps NaN/inf values at padded
    # positions out of the sum, since 0 * nan is still nan.
    weighted = torch.where(weights != 0, sq_err * weights, torch.zeros_like(sq_err))
    # Normalise by the masked element count, not the point count, so that
    # output_dim > 1 does not inflate the result.
    return (weighted.sum() / weights.sum()).item()


def compute_rmse(pred: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None) -> float:
    """Compute root mean squared error, i.e. the square root of ``compute_mse``."""
    mse_value = compute_mse(pred, target, mask)
    return np.sqrt(mse_value)


def compute_mae(pred: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None) -> float:
    """
    Compute mean absolute error.

    Without a mask this is the mean of ``|pred - target|`` over every element.
    With a mask, the average is over masked-in elements (points x output dims),
    so an all-ones mask gives the same result as no mask.

    Args:
        pred: Predictions
        target: Targets, same shape as ``pred``
        mask: Optional mask for valid points. Either the same shape as ``pred``
            or one dimension smaller (broadcast over dim 1, e.g. (batch_size,
            n_points) for (batch_size, output_dim, n_points) predictions).
            Non-zero entries mark valid points; non-binary values act as weights.

    Returns:
        MAE value (NaN if the mask selects nothing)
    """
    abs_err = torch.abs(pred - target)
    if mask is None:
        return abs_err.mean().item()

    # A per-point mask is broadcast over the output dimension (dim 1).
    weights = mask.unsqueeze(1) if mask.dim() < abs_err.dim() else mask
    weights = weights.to(abs_err.dtype).expand_as(abs_err)
    # torch.where keeps NaN/inf at padded positions out of the sum.
    weighted = torch.where(weights != 0, abs_err * weights, torch.zeros_like(abs_err))
    # Divide by the masked element count so output_dim > 1 does not inflate MAE.
    return (weighted.sum() / weights.sum()).item()


def compute_relative_error(pred: torch.Tensor, target: torch.Tensor, 
                          mask: Optional[torch.Tensor] = None, eps: float = 1e-8) -> float:
    """
    Compute mean elementwise relative error ``|pred - target| / (|target| + eps)``.

    With a mask, the average is over masked-in elements (points x output dims),
    so an all-ones mask gives the same result as no mask.

    Args:
        pred: Predictions
        target: Targets, same shape as ``pred``
        mask: Optional mask for valid points. Either the same shape as ``pred``
            or one dimension smaller (broadcast over dim 1). Non-zero entries
            mark valid points; non-binary values act as weights.
        eps: Small constant for numerical stability

    Returns:
        Relative error (NaN if the mask selects nothing)
    """
    rel_err = torch.abs(pred - target) / (torch.abs(target) + eps)
    if mask is None:
        return torch.mean(rel_err).item()

    # A per-point mask is broadcast over the output dimension (dim 1).
    weights = mask.unsqueeze(1) if mask.dim() < rel_err.dim() else mask
    weights = weights.to(rel_err.dtype).expand_as(rel_err)
    # Padded points often have target == 0 and garbage predictions; torch.where
    # keeps any resulting inf/NaN out of the sum.
    weighted = torch.where(weights != 0, rel_err * weights, torch.zeros_like(rel_err))
    # Divide by the masked element count so output_dim > 1 does not inflate the error.
    return (weighted.sum() / weights.sum()).item()


def compute_mnse(pred: torch.Tensor, target: torch.Tensor, 
                 mask: Optional[torch.Tensor] = None, eps: float = 1e-10) -> float:
    """
    Compute Mean Normalized Squared Error (MNSE).

    MNSE = (1/N_t) * Σ ||μ_θ(x_j) - u(x_j)||²₂ / ||u(x_j)||²₂

    For every point, the squared error norm (summed over the output dimension)
    is divided by the squared norm of the target at that point. The result is
    averaged over the points that count: points whose target squared norm
    exceeds ``eps`` and, when a mask is given, that are masked in.

    Args:
        pred: Predictions, shape (batch_size, output_dim, n_points)
        target: Targets, shape (batch_size, output_dim, n_points)
        mask: Optional mask for valid points, shape (batch_size, n_points)
        eps: Points with target squared norm <= eps are skipped

    Returns:
        MNSE value, or 0.0 if no point qualifies
    """
    # Per-point squared norms, reducing over the output dimension (dim=1).
    err_sq = ((pred - target) ** 2).sum(dim=1)
    ref_sq = (target ** 2).sum(dim=1)

    if mask is None:
        keep = ref_sq > eps
    else:
        err_sq = err_sq * mask
        ref_sq = ref_sq * mask
        keep = (ref_sq > eps) & (mask > 0)

    n_keep = keep.sum()
    if n_keep == 0:
        return 0.0

    ratios = torch.zeros_like(err_sq)
    ratios[keep] = err_sq[keep] / ref_sq[keep]
    return (ratios.sum() / n_keep).item()


def compute_ecp(mean: torch.Tensor, sigma: torch.Tensor, target: torch.Tensor,
                mask: Optional[torch.Tensor] = None, confidence_level: float = 0.90) -> float:
    """
    Compute Empirical Coverage Probability (ECP) at specified confidence level.
    
    ECP measures the proportion of true values that fall within the predicted
    central Gaussian interval ``mean ± z * sigma``. For a well-calibrated model:
    - ECP ≈ confidence_level indicates good calibration
    - ECP < confidence_level indicates over-confident predictions (intervals too narrow)
    - ECP > confidence_level indicates under-confident predictions (intervals too wide)
    
    Default confidence level is 90% as commonly used in uncertainty quantification.
    
    Args:
        mean: Predicted means, shape (batch_size, output_dim, n_points)
        sigma: Predicted std devs, shape (batch_size, output_dim, n_points)
        target: Target values, shape (batch_size, output_dim, n_points)
        mask: Optional mask for valid points, shape (batch_size, n_points) (broadcast
            over the output dimension) or the same shape as ``mean``. Non-zero
            entries mark valid points.
        confidence_level: Confidence level (default: 0.90 for 90% CI)
    
    Returns:
        ECP value in [0, 1]; coverage is counted per element (point x output dim).
        NaN if the mask selects nothing.
    """
    from scipy import stats

    # Two-sided Gaussian quantile: 90% CI -> z ≈ 1.645, 95% CI -> z ≈ 1.960
    z_score = stats.norm.ppf((1 + confidence_level) / 2)

    lower_bound = mean - z_score * sigma
    upper_bound = mean + z_score * sigma
    covered = ((target >= lower_bound) & (target <= upper_bound)).float()

    if mask is None:
        return covered.mean().item()

    # A per-point mask is broadcast over the output dimension (dim 1).
    weights = mask.unsqueeze(1) if mask.dim() < covered.dim() else mask
    weights = weights.to(covered.dtype).expand_as(covered)
    # Normalise by the masked element count (points x output dims); dividing by
    # the point count alone would let ECP exceed 1 when output_dim > 1.
    return ((covered * weights).sum() / weights.sum()).item()


def compute_nll(mean: torch.Tensor, sigma: torch.Tensor, target: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> float:
    """
    Compute negative log-likelihood under Gaussian assumption.

    The elementwise NLL ``0.5 * (log(2π) + log(σ²) + (y - μ)² / σ²)`` is averaged
    over all elements, or over masked-in elements (points x output dims) when a
    mask is given.

    Args:
        mean: Predicted means, shape (batch_size, output_dim, n_target)
        sigma: Predicted std devs, shape (batch_size, output_dim, n_target)
        target: Target values, shape (batch_size, output_dim, n_target)
        mask: Optional mask for valid points, shape (batch_size, n_target)
            (broadcast over the output dimension) or the same shape as ``mean``.
            Non-zero entries mark valid points; non-binary values act as weights.

    Returns:
        NLL value (NaN if the mask selects nothing)
    """
    nll = 0.5 * (np.log(2 * np.pi) + 2 * torch.log(sigma) + 
                 ((target - mean) ** 2) / (sigma ** 2))

    if mask is None:
        return nll.mean().item()

    # A per-point mask is broadcast over the output dimension (dim 1).
    weights = mask.unsqueeze(1) if mask.dim() < nll.dim() else mask
    weights = weights.to(nll.dtype).expand_as(nll)
    # Padded points may have sigma == 0, giving -inf/NaN NLL; torch.where keeps
    # them out of the sum (multiplying by 0 would not).
    weighted = torch.where(weights != 0, nll * weights, torch.zeros_like(nll))
    # Divide by the masked element count so output_dim > 1 does not inflate NLL.
    return (weighted.sum() / weights.sum()).item()


def compute_calibration(mean: torch.Tensor, sigma: torch.Tensor, target: torch.Tensor,
                       mask: Optional[torch.Tensor] = None,
                       num_bins: int = 10) -> Dict[str, np.ndarray]:
    """
    Compute calibration of predictive uncertainty.

    For each nominal level p in {1/num_bins, 2/num_bins, ..., 1}, measures the
    fraction of standardized residuals ``|target - mean| / sigma`` that fall
    inside the central Gaussian interval with probability p.

    Args:
        mean: Predicted means, e.g. (batch_size, output_dim, n_points)
        sigma: Predicted std devs, same shape as ``mean``
        target: Target values, same shape as ``mean``
        mask: Optional mask for valid points, either one dimension smaller than
            ``mean`` (broadcast over dim 1) or the same shape. Any dtype is
            accepted; non-zero entries select elements.
        num_bins: Number of confidence levels to check

    Returns:
        Dictionary with:
            'expected_coverage': nominal levels, shape (num_bins,)
            'actual_coverage': empirical coverage at each level, shape (num_bins,)
            'calibration_error': mean absolute difference between the two (scalar)
    """
    from scipy import stats

    z_scores = torch.abs((target - mean) / sigma)

    if mask is not None:
        keep = mask.unsqueeze(1) if mask.dim() < z_scores.dim() else mask
        # Convert to bool before indexing: a float mask raises, and an integer
        # mask would be treated as integer indices and select the wrong elements.
        z_scores_flat = z_scores[keep.bool().expand_as(z_scores)]
    else:
        z_scores_flat = z_scores.flatten()

    # Nominal levels [1/num_bins, ..., 1.0]
    confidence_levels = np.linspace(0, 1, num_bins + 1)[1:]

    # Two-sided Gaussian threshold for each level (level 1.0 maps to +inf).
    z_threshold = [stats.norm.ppf((1 + p) / 2) for p in confidence_levels]

    actual_coverage = np.array([
        (z_scores_flat <= z).float().mean().item()
        for z in z_threshold
    ])

    calibration_error = np.mean(np.abs(actual_coverage - confidence_levels))

    return {
        'expected_coverage': confidence_levels,
        'actual_coverage': actual_coverage,
        'calibration_error': calibration_error
    }


def compute_sharpness(sigma: torch.Tensor, mask: Optional[torch.Tensor] = None) -> float:
    """
    Compute sharpness (average predicted standard deviation).

    Lower values indicate more confident predictions.
    Should be balanced with calibration (ECP).

    Args:
        sigma: Predicted standard deviations, e.g. (batch_size, output_dim, n_points)
        mask: Optional mask for valid points, either one dimension smaller than
            ``sigma`` (broadcast over dim 1) or the same shape. Non-zero entries
            mark valid points; non-binary values act as weights.

    Returns:
        Average standard deviation over all elements, or over masked-in
        elements (points x output dims). NaN if the mask selects nothing.
    """
    if mask is None:
        return sigma.mean().item()

    # A per-point mask is broadcast over the output dimension (dim 1).
    weights = mask.unsqueeze(1) if mask.dim() < sigma.dim() else mask
    weights = weights.to(sigma.dtype).expand_as(sigma)
    # torch.where keeps NaN/inf at padded positions out of the sum.
    weighted = torch.where(weights != 0, sigma * weights, torch.zeros_like(sigma))
    # Divide by the masked element count so output_dim > 1 does not inflate the average.
    return (weighted.sum() / weights.sum()).item()


def _points_by_dim(t: torch.Tensor, valid: Optional[torch.Tensor]) -> torch.Tensor:
    """Reshape (batch, output_dim, n) to (output_dim, k), keeping only ``valid`` (batch, n) points if given."""
    per_dim = t.permute(1, 0, 2)  # (output_dim, batch, n)
    if valid is None:
        return per_dim.reshape(per_dim.shape[0], -1)
    # Boolean selection over the (batch, n) dims, sample by sample in order.
    return per_dim[:, valid]


def _collect_predictions(model: torch.nn.Module,
                         data_loader: DataLoader,
                         device: str,
                         keep_sigma: bool) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
    """
    Run ``model`` over ``data_loader`` and gather predictions at valid target points.

    Returns ``(mean, sigma, target)``, each of shape (1, output_dim, n_valid_total);
    ``sigma`` is None when ``keep_sigma`` is False. Points from every sample of
    every batch are concatenated along the last dimension. The leading singleton
    batch dimension keeps the (batch, output_dim, n_points) layout the metric
    functions expect (``compute_mnse`` reduces over dim 1).

    Raises:
        RuntimeError: if ``data_loader`` yields no batches.
    """
    model.eval()
    model = model.to(device)

    means: List[torch.Tensor] = []
    sigmas: List[torch.Tensor] = []
    targets: List[torch.Tensor] = []

    with torch.no_grad():
        for batch in data_loader:
            x_context = batch['x_context'].to(device)
            y_context = batch['y_context'].to(device)
            x_target = batch['x_target'].to(device)
            y_target = batch['y_target'].to(device)

            lambda_params = batch.get('lambda_params', None)
            if lambda_params is not None:
                mean, sigma = model(x_context, y_context, x_target, lambda_params.to(device))
            else:
                mean, sigma = model(x_context, y_context, x_target)

            # Model outputs are laid out (batch, output_dim, n_target) and y_target
            # (batch, n_target, output_dim); this transpose aligns them.
            # NOTE(review): layout inferred from this code; confirm against model/dataset.
            y_target = y_target.transpose(1, 2)

            target_mask = batch.get('target_mask', None)
            valid = None
            if target_mask is not None:
                # Select valid points by boolean mask rather than assuming they
                # form a prefix; reshape tolerates a trailing singleton dim.
                valid = target_mask.to(device).reshape(mean.shape[0], mean.shape[-1]).bool()

            means.append(_points_by_dim(mean, valid))
            targets.append(_points_by_dim(y_target, valid))
            if keep_sigma:
                sigmas.append(_points_by_dim(sigma, valid))

    if not means:
        raise RuntimeError("evaluate_model: data_loader yielded no batches")

    mean_cat = torch.cat(means, dim=-1).unsqueeze(0)
    target_cat = torch.cat(targets, dim=-1).unsqueeze(0)
    sigma_cat = torch.cat(sigmas, dim=-1).unsqueeze(0) if keep_sigma else None
    return mean_cat, sigma_cat, target_cat


def evaluate_model(model: torch.nn.Module, 
                  data_loader: DataLoader,
                  device: str = 'cpu',
                  compute_uncertainty: bool = True,
                  confidence_level: float = 0.90) -> Dict[str, float]:
    """
    Evaluate model on a dataset.

    All valid target points across the dataset are pooled, so batches may have
    different sizes and numbers of target points, and masked/unmasked batches
    may be mixed.

    Args:
        model: BSNP model; called as ``model(x_context, y_context, x_target[, lambda_params])``
            and expected to return ``(mean, sigma)``
        data_loader: DataLoader yielding dict batches with 'x_context', 'y_context',
            'x_target', 'y_target' and optional 'target_mask' / 'lambda_params'
        device: Device to run on
        compute_uncertainty: Whether to compute uncertainty metrics
        confidence_level: Confidence level for ECP (default: 0.90)

    Returns:
        Dictionary with 'mse', 'rmse', 'mae', 'relative_error', 'mnse' and, when
        ``compute_uncertainty`` is True, 'ecp', 'confidence_level', 'nll',
        'sharpness', 'calibration_error'.

    Raises:
        RuntimeError: if ``data_loader`` yields no batches.
    """
    mean_cat, sigma_cat, target_cat = _collect_predictions(
        model, data_loader, device, compute_uncertainty
    )

    metrics = {
        'mse': compute_mse(mean_cat, target_cat),
        'rmse': compute_rmse(mean_cat, target_cat),
        'mae': compute_mae(mean_cat, target_cat),
        'relative_error': compute_relative_error(mean_cat, target_cat),
        'mnse': compute_mnse(mean_cat, target_cat),
    }

    if compute_uncertainty:
        metrics['ecp'] = compute_ecp(mean_cat, sigma_cat, target_cat,
                                     confidence_level=confidence_level)
        metrics['confidence_level'] = confidence_level

        metrics['nll'] = compute_nll(mean_cat, sigma_cat, target_cat)
        metrics['sharpness'] = compute_sharpness(sigma_cat)

        calibration = compute_calibration(mean_cat, sigma_cat, target_cat)
        metrics['calibration_error'] = calibration['calibration_error']

    return metrics


def compute_physics_residual(model: torch.nn.Module,
                            x_test: torch.Tensor,
                            lambda_params: torch.Tensor,
                            pde_func: callable) -> float:
    """
    Compute the mean absolute PDE residual of the model's predictions.

    The model is queried with ``x_test`` as the target points and ``x_test`` as
    the context inputs, with no context outputs (``None``). Its mean prediction
    is passed to ``pde_func``.

    Args:
        model: Trained model returning ``(mean, sigma)``
        x_test: Test points, shape (batch_size, n_points, spatial_dim)
        lambda_params: PDE parameters, shape (batch_size, n_params)
        pde_func: Callable ``pde_func(x_test, mean, lambda_params)`` returning a residual tensor

    Returns:
        Mean of ``|residual|`` over all its elements
    """
    model.eval()

    # NOTE(review): everything below runs under no_grad. If pde_func takes
    # derivatives of the prediction via autograd, that will fail here; confirm.
    with torch.no_grad():
        prediction, _unused_sigma = model(x_test, None, x_test, lambda_params)
        residual = pde_func(x_test, prediction, lambda_params)
        abs_residual = torch.abs(residual)
        return abs_residual.mean().item()


def print_metrics(metrics: Dict[str, float], title: str = "Evaluation Metrics"):
    """
    Print a formatted summary of evaluation metrics to stdout.

    Only metrics present in ``metrics`` are printed. They are grouped into basic
    error metrics, the headline metrics (MNSE and ECP) and uncertainty metrics,
    framed by 80-character '=' rules.

    Args:
        metrics: Mapping of metric name to value (e.g. as returned by ``evaluate_model``)
        title: Heading printed between the top two rules
    """
    rule = "=" * 80

    # (key, line template) pairs, printed in order when the key is present.
    error_rows = [
        ('mse', "  MSE:              {:.6f}"),
        ('rmse', "  RMSE:             {:.6f}"),
        ('mae', "  MAE:              {:.6f}"),
        ('relative_error', "  Relative Error:   {:.6f}"),
    ]
    uncertainty_rows = [
        ('nll', "  NLL:              {:.6f}"),
        ('sharpness', "  Sharpness:        {:.6f}"),
        ('calibration_error', "  Calibration Err:  {:.4f}"),
    ]

    def emit(rows):
        for key, template in rows:
            if key in metrics:
                print(template.format(metrics[key]))

    print("\n" + rule)
    print(f"{title}")
    print(rule)
    emit(error_rows)

    print("\n  📊 Key Metrics:")
    if 'mnse' in metrics:
        print("  MNSE:             {:.6f}".format(metrics['mnse']))
    if 'ecp' in metrics:
        conf_level = metrics.get('confidence_level', 0.90)
        print(f"  ECP ({conf_level*100:.0f}%):        {metrics['ecp']:.4f}")

    # The section header depends only on NLL/sharpness; a lone calibration
    # error is printed without it.
    if 'nll' in metrics or 'sharpness' in metrics:
        print("\n  🎯 Uncertainty Metrics:")
    emit(uncertainty_rows)

    print(rule + "\n")


def compute_multiple_confidence_levels(mean: torch.Tensor, sigma: torch.Tensor, 
                                       target: torch.Tensor,
                                       mask: Optional[torch.Tensor] = None,
                                       levels: Optional[List[float]] = None) -> Dict[str, float]:
    """
    Compute ECP at multiple confidence levels for detailed calibration analysis.

    Args:
        mean: Predicted means
        sigma: Predicted std devs
        target: Target values
        mask: Optional mask for valid points
        levels: Confidence levels to compute; defaults to [0.80, 0.90, 0.95, 0.99]

    Returns:
        Dictionary mapping 'ecp_<percent>' to the ECP at that level, e.g.
        'ecp_90'. The percentage uses '%g' formatting, so 0.58 -> 'ecp_58'
        and 0.975 -> 'ecp_97.5'.
    """
    # None sentinel instead of a shared mutable list default.
    if levels is None:
        levels = [0.80, 0.90, 0.95, 0.99]

    # ':g' instead of int(): int() truncates float products (0.58 * 100 ==
    # 57.99999999999999 -> 57) and maps 0.999 onto the same key as 0.99.
    return {
        f'ecp_{level * 100:g}': compute_ecp(mean, sigma, target, mask, level)
        for level in levels
    }