"""
Evaluation Functions for Bayesian LSTM

This module contains all evaluation-related functions:
- Point prediction metrics (MAE, RMSE, R2)
- Uncertainty quantification metrics (NLL, CRPS, ECE)
- Calibration metrics (PICP, MPIW, calibration curves)
- Selective prediction analysis
- OOD detection metrics
"""

import numpy as np
import json
from pathlib import Path
from scipy.stats import norm
from scipy.special import erf
from sklearn.metrics import (
    mean_absolute_error,
    mean_squared_error,
    r2_score,
    roc_auc_score,
    precision_recall_curve,
    auc,
    roc_curve
)


# ==============================================================================
# Multi-Seed Aggregation Utilities
# ==============================================================================

def load_aggregated_results(json_path):
    """
    Load aggregated multi-seed results from a JSON file.

    Args:
        json_path: Path (str or pathlib.Path) to the aggregated_results.json file

    Returns:
        dict: The decoded JSON content. The expected structure is:
              {model_name: {metric_name: {'mean': float, 'std': float,
                                          'min': float, 'max': float, 'values': list}}}
              (the structure itself is not validated here).

    Raises:
        FileNotFoundError: If json_path does not exist.

    Example:
        from modules.evaluation import load_aggregated_results
        aggregated_results = load_aggregated_results('multi_seed_results/aggregated_results.json')
    """
    json_path = Path(json_path)
    if not json_path.exists():
        raise FileNotFoundError(f"Aggregated results file not found: {json_path}")

    # Read explicitly as UTF-8 so that loading does not depend on the
    # platform's locale encoding (matters for non-ASCII model names).
    with json_path.open('r', encoding='utf-8') as f:
        aggregated_results = json.load(f)

    print(f"Loaded results for {len(aggregated_results)} models from {json_path}")
    return aggregated_results


# ==============================================================================
# Data Transformation
# ==============================================================================

def inverse_transform_predictions(y_scaled, scaler):
    """
    Inverse transform predictions from scaled to original space.

    Args:
        y_scaled: Scaled predictions/labels, shape (N, 1) or (N,). Plain
                  sequences (e.g. lists) are accepted and converted to arrays.
        scaler: Object exposing inverse_transform(2-D array), e.g. a fitted
                StandardScaler
    Returns:
        y_original: 1-D array of values in original scale (e.g., ug/m3)
    """
    # Inputs without an ndim attribute (lists, tuples) are converted first;
    # objects that already expose ndim are passed through untouched unless
    # they are 1-D.
    if not hasattr(y_scaled, "ndim"):
        y_scaled = np.asarray(y_scaled)
    if y_scaled.ndim == 1:
        # The scaler is called with a 2-D column, so lift 1-D input to (N, 1)
        y_scaled = np.asarray(y_scaled).reshape(-1, 1)
    y_original = scaler.inverse_transform(y_scaled)
    return np.asarray(y_original).flatten()


# ==============================================================================
# Point Prediction Metrics
# ==============================================================================

def calculate_point_metrics(y_true, y_pred):
    """
    Calculate point prediction metrics in ORIGINAL scale.

    Args:
        y_true: True values in original scale (N,)
        y_pred: Predicted values in original scale (N,)
    Returns:
        dict: {'MAE': mean absolute error,
               'RMSE': root mean squared error,
               'R2': coefficient of determination}
    """
    # A percentage-error array used to be computed here but was never
    # returned; it has been removed so the function only does the work it
    # reports (and no longer requires y_true to support boolean masking).
    mae = mean_absolute_error(y_true, y_pred)
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    r2 = r2_score(y_true, y_pred)
    return {
        'MAE': mae,
        'RMSE': rmse,
        'R2': r2,
    }


# ==============================================================================
# Uncertainty Quantification Metrics
# ==============================================================================

def compute_nll_gaussian(y_true, y_pred_mean, y_pred_std, epsilon=1e-6):
    """
    Mean Gaussian negative log-likelihood of the targets under N(mu, sigma^2).

    Per point:
        NLL = 0.5*log(2*pi*sigma^2) + (y - mu)^2 / (2*sigma^2)
    Lower is better.

    Args:
        y_true: True values (N,)
        y_pred_mean: Predicted mean (N,)
        y_pred_std: Predicted standard deviation (N,)
        epsilon: Lower bound applied to the std before use

    Returns:
        nll: Negative log-likelihood averaged over all points (scalar)
    """
    # Floor the std so neither the log nor the division can blow up
    sigma = np.maximum(y_pred_std, epsilon)
    variance = sigma**2

    log_term = 0.5 * np.log(2 * np.pi * variance)
    squared_error_term = 0.5 * ((y_true - y_pred_mean)**2) / variance

    return np.mean(log_term + squared_error_term)


def compute_crps_gaussian(y_true, y_pred_mean, y_pred_std, epsilon=1e-6):
    """
    Compute Continuous Ranked Probability Score for Gaussian predictive distribution.

    Closed form used here:
        CRPS = sigma * [ z*(2*Phi(z) - 1) + 2*phi(z) - 1/sqrt(pi) ]
    where z = (y - mu)/sigma, phi is the standard normal PDF and Phi the
    standard normal CDF.

    Lower is better. CRPS generalizes MAE to probabilistic predictions.

    Args:
        y_true: True values (N,)
        y_pred_mean: Predicted mean (N,)
        y_pred_std: Predicted standard deviation (N,)
        epsilon: Lower bound applied to the std before use. Without it a zero
                 std produces inf/NaN in z and turns the averaged score into
                 NaN. Same default as compute_nll_gaussian.

    Returns:
        crps: Continuous ranked probability score averaged over all points (scalar)
    """
    # Floor the std so the standardization below is always finite
    y_pred_std = np.maximum(y_pred_std, epsilon)

    # Standardized error
    z = (y_true - y_pred_mean) / y_pred_std

    # Standard normal PDF
    phi = (1 / np.sqrt(2 * np.pi)) * np.exp(-0.5 * z**2)

    # Standard normal CDF using error function
    Phi = 0.5 * (1 + erf(z / np.sqrt(2)))

    # CRPS closed form for a Gaussian
    crps = y_pred_std * (z * (2 * Phi - 1) + 2 * phi - 1 / np.sqrt(np.pi))

    return np.mean(crps)


def compute_ece_regression_fixed(y_true, y_pred_mean, y_pred_std, num_levels=10):
    """
    Expected Calibration Error for regression, based on interval coverage.

    For each of num_levels confidence levels p, evenly spaced between 0.1 and
    0.9 (inclusive):
      - build the central Gaussian prediction interval mu +/- z_p * sigma
      - expected coverage is p
      - observed coverage is the fraction of targets inside that interval
    ECE is the mean of |observed - expected| over the levels.

    Unlike classification ECE (probability bins), this checks whether the
    predicted std yields intervals with the advertised coverage.

    Lower is better (0 = perfectly calibrated).

    Args:
        y_true: True values (N,)
        y_pred_mean: Predicted mean (N,)
        y_pred_std: Predicted standard deviation (N,)
        num_levels: Number of confidence levels evaluated

    Returns:
        ece: Mean absolute coverage gap (scalar)
        calibration_data: One dict per level with keys
                          'confidence', 'coverage' and 'diff' (for plotting)
    """
    def _empirical_coverage(level):
        # Two-sided interval: split the excluded mass evenly between the tails
        tail_mass = 1 - level
        half_width = norm.ppf(1 - tail_mass / 2) * y_pred_std
        hits = np.logical_and(y_true >= y_pred_mean - half_width,
                              y_true <= y_pred_mean + half_width)
        return np.mean(hits)

    expected_coverage = np.linspace(0.1, 0.9, num_levels)
    observed_coverage = np.array([_empirical_coverage(p) for p in expected_coverage])

    ece = np.mean(np.abs(observed_coverage - expected_coverage))

    calibration_data = [
        {'confidence': p, 'coverage': cov, 'diff': abs(cov - p)}
        for p, cov in zip(expected_coverage, observed_coverage)
    ]

    return ece, calibration_data


def compute_calibration_error(y_true, y_pred_mean, y_pred_std, num_bins=10):
    """
    Compute a heuristic, bin-based calibration error for regression.

    Each point gets a confidence score 1/(1 + std). Points are sorted by that
    score and split into num_bins contiguous bins of N // num_bins points
    (the last bin also takes the remainder). Within each bin the mean
    confidence is compared with 1/(1 + mean absolute error); the weighted
    average of the absolute gaps is returned.

    Lower is better (0 = confidence proxy and error proxy agree in every bin).

    Args:
        y_true: True values (N,)
        y_pred_mean: Predicted mean (N,)
        y_pred_std: Predicted standard deviation (N,)
        num_bins: Number of confidence bins

    Returns:
        ece: Weighted mean absolute gap (scalar)
        calibration_data: List of dicts, one per non-empty bin, with keys
                          'confidence' (mean confidence score),
                          'error' (mean absolute error) and
                          'count' (number of points in the bin)
    """
    N = len(y_true)

    # Absolute errors of the point prediction
    errors = np.abs(y_true - y_pred_mean)

    # Lower std = higher confidence
    confidence_scores = 1.0 / (1.0 + y_pred_std)

    # Order points from least to most confident
    sorted_indices = np.argsort(confidence_scores)
    sorted_errors = errors[sorted_indices]
    sorted_confidence = confidence_scores[sorted_indices]

    # When N < num_bins, bin_size is 0: all leading bins are empty and the
    # final bin receives every point.
    bin_size = N // num_bins
    ece = 0.0
    calibration_data = []

    for i in range(num_bins):
        start_idx = i * bin_size
        end_idx = (i + 1) * bin_size if i < num_bins - 1 else N

        bin_errors = sorted_errors[start_idx:end_idx]
        bin_confidence = sorted_confidence[start_idx:end_idx]

        if len(bin_errors) > 0:
            avg_confidence = np.mean(bin_confidence)
            avg_error = np.mean(bin_errors)
            bin_weight = len(bin_errors) / N

            # Map the mean error onto the same 1/(1 + x) scale as the confidence
            ece += bin_weight * np.abs(avg_confidence - (1.0 / (1.0 + avg_error)))

            calibration_data.append({
                'confidence': avg_confidence,
                'error': avg_error,
                'count': len(bin_errors)
            })

    return ece, calibration_data


def compute_nll_ensemble_pooled_noise(y_true, ensemble_predictions_list,
                                      y_val, ensemble_val_predictions_list, epsilon=1e-6):
    """
    Gaussian NLL for a Deep Ensemble using one pooled noise variance.

    The spread between ensemble members (epistemic) is not used as the noise
    level. Instead a single variance is estimated from validation residuals:

    1. Residuals of every member on the validation set are computed
    2. The mean of all squared residuals gives sigma^2 (floored at epsilon)
    3. Test NLL is evaluated around the ensemble-mean prediction with that sigma^2

    Args:
        y_true: True test values (N,)
        ensemble_predictions_list: List of M predictions on test set, each (N,)
        y_val: True validation values (N_val,)
        ensemble_val_predictions_list: List of M predictions on validation set, each (N_val,)
        epsilon: Lower bound for the pooled variance

    Returns:
        nll_ensemble: Negative log-likelihood averaged over test points (scalar)
        sigma_noise: Estimated noise std (scalar)
    """
    def _stack_members(member_predictions):
        # One flattened row per ensemble member -> (M, n_points)
        return np.array([np.array(pred).flatten() for pred in member_predictions])

    val_targets = np.array(y_val).flatten()
    test_targets = np.array(y_true).flatten()

    # Residuals of every member on every validation point: (M, N_val)
    val_residuals = val_targets[np.newaxis, :] - _stack_members(ensemble_val_predictions_list)
    noise_variance = max(np.mean(val_residuals ** 2), epsilon)

    # Point prediction on the test set is the across-member mean: (N_test,)
    ensemble_mean = np.mean(_stack_members(ensemble_predictions_list), axis=0)

    per_point_nll = 0.5 * np.log(2 * np.pi * noise_variance) + \
        0.5 * ((test_targets - ensemble_mean)**2) / noise_variance

    return np.mean(per_point_nll), np.sqrt(noise_variance)


def compute_nll_ensemble_fixed(y_true, ensemble_predictions_list, epsilon=1e-6):
    """
    DEPRECATED: Incorrect NLL for Deep Ensemble (uses epistemic std as noise std).

    DO NOT USE THIS FUNCTION. Use compute_nll_ensemble_pooled_noise() instead.

    This function incorrectly uses the ensemble std (epistemic uncertainty)
    as the noise std in the NLL formula. The correct approach is to estimate
    a pooled noise variance from validation residuals.

    Kept for backwards compatibility only.

    Args:
        y_true: True values, any shape that flattens to (N,)
        ensemble_predictions_list: List of M member predictions, each
                                   flattening to (N,)
        epsilon: Lower bound applied to the across-member std

    Returns:
        nll_ensemble: Mean of the per-member NLL values (scalar)
        nll_per_member: Array of shape (M,) with each member's mean NLL
    """
    y_true_flat = np.array(y_true).flatten()

    # Flatten every member BEFORE stacking. Stacking (N, 1) members directly
    # would give an (N, 1) std that broadcasts against the flattened targets
    # into an (N, N) matrix and silently corrupts the result.
    members = np.array([np.array(p).flatten() for p in ensemble_predictions_list])  # (M, N)

    ensemble_std = np.maximum(np.std(members, axis=0), epsilon)  # (N,)
    ensemble_var = ensemble_std**2

    # Term shared by all members; computed once outside the loop
    log_term = 0.5 * np.log(2 * np.pi * ensemble_var)

    nll_per_member = np.array([
        np.mean(log_term + 0.5 * ((y_true_flat - y_pred)**2) / ensemble_var)
        for y_pred in members
    ])

    return np.mean(nll_per_member), nll_per_member


def compute_calibration_auc_fixed(expected_coverage, observed_coverage):
    """
    Compute Calibration AUC - area between calibration curve and diagonal.

    Uses trapezoidal integration of |observed - expected| over the expected
    coverage axis. Lower is better (0 = perfect calibration).

    Args:
        expected_coverage: Expected coverage levels (K,), array or sequence,
                           in any order
        observed_coverage: Observed coverage at those levels (K,)

    Returns:
        cal_auc: Integrated absolute calibration gap (scalar)
    """
    # Accept plain sequences as well as arrays (fancy indexing needs ndarrays)
    expected = np.asarray(expected_coverage, dtype=float)
    observed = np.asarray(observed_coverage, dtype=float)

    # Integration requires the x-axis in ascending order
    sort_idx = np.argsort(expected)
    expected_sorted = expected[sort_idx]
    gap_sorted = np.abs(observed[sort_idx] - expected_sorted)

    # np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid; use the
    # new name when it exists and fall back to the old one otherwise.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    cal_auc = integrate(gap_sorted, expected_sorted)

    return cal_auc


# ==============================================================================
# Prediction Intervals
# ==============================================================================

def compute_prediction_intervals(y_pred_samples, confidence_level=0.95):
    """
    Summarize sampled predictions as empirical intervals plus mean and std.

    The interval is central: its bounds are the (alpha/2) and (1 - alpha/2)
    percentiles across the sample axis, with alpha = 1 - confidence_level.

    Args:
        y_pred_samples: MC samples (M, N) where M=num_samples, N=num_points
        confidence_level: Confidence level (e.g., 0.95 for 95% PI)

    Returns:
        lower_bound: Lower bound of PI (N,)
        upper_bound: Upper bound of PI (N,)
        pred_mean: Mean prediction (N,)
        pred_std: Standard deviation (N,)
    """
    alpha = 1 - confidence_level
    tail_percentiles = [(alpha / 2) * 100, (1 - alpha / 2) * 100]

    # Both bounds from a single percentile call; rows are [lower, upper]
    lower_bound, upper_bound = np.percentile(y_pred_samples, tail_percentiles, axis=0)

    pred_mean = np.mean(y_pred_samples, axis=0)
    pred_std = np.std(y_pred_samples, axis=0)

    return lower_bound, upper_bound, pred_mean, pred_std


def compute_picp_mpiw(y_true, lower_bound, upper_bound):
    """
    Coverage (PICP) and average width (MPIW) of a set of prediction intervals.

    Args:
        y_true: True values (N,)
        lower_bound: Lower bound of PI (N,)
        upper_bound: Upper bound of PI (N,)

    Returns:
        picp: Fraction of true values lying inside [lower, upper], bounds
              inclusive (0-1, higher is better)
        mpiw: Mean of (upper - lower) (lower is better)
    """
    interval_width = upper_bound - lower_bound
    is_covered = (y_true >= lower_bound) & (y_true <= upper_bound)

    return np.mean(is_covered), np.mean(interval_width)


def compute_interval_score(y_true, lower_bound, upper_bound, confidence_level=0.95):
    """
    Mean Interval Score (Winkler Score) of a set of prediction intervals.

    A proper scoring rule that trades interval width against coverage.

    Per point, with alpha = 1 - confidence_level:
        IS = (upper - lower)
             + (2/alpha) * (lower - y)  when y < lower
             + (2/alpha) * (y - upper)  when y > upper

    - The width term rewards narrow intervals
    - The two miss terms penalize targets outside the interval, scaled by
      how far outside they fall

    Args:
        y_true: True values (N,)
        lower_bound: Lower bound of PI (N,)
        upper_bound: Upper bound of PI (N,)
        confidence_level: Confidence level (e.g., 0.95 for 95% intervals)

    Returns:
        interval_score: Mean interval score (LOWER IS BETTER)

    Reference:
        Winkler, R. L. (1972). A decision-theoretic approach to interval estimation.
        Journal of the American Statistical Association, 67(337), 187-191.
    """
    miss_weight = 2 / (1 - confidence_level)

    # Distance by which the target falls below / above the interval (0 if inside)
    below_penalty = miss_weight * np.maximum(0, lower_bound - y_true)
    above_penalty = miss_weight * np.maximum(0, y_true - upper_bound)

    per_point_score = (upper_bound - lower_bound) + below_penalty + above_penalty

    return np.mean(per_point_score)


def compute_cwc(picp, mpiw, y_range, target_coverage=0.95, eta=50):
    """
    Coverage Width-based Criterion (CWC).

    Folds PICP and the normalized interval width into one score: the width
    alone when coverage reaches the target, and the width inflated by an
    exponential penalty when it does not.

    Formula:
        NMPIW = MPIW / R  (where R is the target range)
        CWC = NMPIW                              if PICP >= target_coverage
        CWC = NMPIW * (1 + exp(-η*(PICP-μ)))     if PICP < target_coverage

    where μ = target_coverage, η controls penalty steepness.

    Args:
        picp: Prediction Interval Coverage Probability (0-1)
        mpiw: Mean Prediction Interval Width (unnormalized)
        y_range: Range of target variable (y_max - y_min)
        target_coverage: Target coverage level (default 0.95)
        eta: Penalty steepness parameter (default 50)

    Returns:
        cwc: Coverage Width-based Criterion (LOWER IS BETTER)

    Reference:
        Khosravi, A., et al. (2011). IEEE Trans. Neural Networks.
    """
    normalized_width = mpiw / y_range

    # Coverage target met: the score is just the normalized width
    if picp >= target_coverage:
        return normalized_width

    # Undercoverage: penalty grows exponentially with the shortfall
    shortfall_penalty = np.exp(-eta * (picp - target_coverage))
    return normalized_width * (1 + shortfall_penalty)

# ==============================================================================
# Calibration Curves
# ==============================================================================

def compute_calibration_curve(y_true, y_pred_mean, y_pred_std, num_bins=10):
    """
    Calibration curve for regression: expected vs. observed interval coverage.

    For each confidence level p:
    - Expected coverage: p
    - Observed coverage: fraction of true values inside the central Gaussian
      interval mu +/- z_p * sigma
    Args:
        y_true: True values (N,)
        y_pred_mean: Predicted mean (N,)
        y_pred_std: Predicted standard deviation (N,)
        num_bins: Number of confidence levels to evaluate
    Returns:
        expected_coverage: Expected coverage levels (num_bins,)
        observed_coverage: Observed coverage levels (num_bins,)
    """
    # num_bins levels evenly spaced from 0.1 to 0.99 (both included)
    levels = np.linspace(0.1, 0.99, num_bins)

    def _observed_at(level):
        # Two-sided z-score for this level, scaled to a per-point half width
        half_width = norm.ppf((1 + level) / 2) * y_pred_std
        inside = np.logical_and(y_true >= y_pred_mean - half_width,
                                y_true <= y_pred_mean + half_width)
        return np.mean(inside)

    observed = [_observed_at(level) for level in levels]
    return np.array(list(levels)), np.array(observed)


def compute_sharpness_curve(y_pred_std, num_bins=20):
    """
    Histogram of the predicted standard deviations (sharpness curve).
    Args:
        y_pred_std: Predicted standard deviations (N,)
        num_bins: Number of histogram bins
    Returns:
        bin_edges: Histogram bin edges (num_bins + 1 values)
        counts: Number of predictions falling in each bin
    """
    histogram = np.histogram(y_pred_std, bins=num_bins)
    # np.histogram yields (counts, edges); callers get edges first
    return histogram[1], histogram[0]


# ==============================================================================
# Selective Prediction
# ==============================================================================

def selective_prediction_analysis(y_true, y_pred_mean, y_pred_std, retention_rates):
    """
    Analyze model performance when rejecting most uncertain predictions.

    Points are ranked by predicted std (ascending). For each retention rate
    the int(N * retention) lowest-std points are kept and the point metrics
    are recomputed on that subset.

    Args:
        y_true: True values (N,)
        y_pred_mean: Predicted mean (N,)
        y_pred_std: Predicted standard deviation (N,)
        retention_rates: List of retention rates (e.g., [1.0, 0.9, 0.8, 0.7])
    Returns:
        results: List of dicts, one per retention rate, with keys
                 'retention', 'n_samples', 'MAE', 'RMSE', 'R2', 'mean_std'
    Raises:
        ValueError: If a retention rate keeps zero samples.
    """
    # Work on flat arrays: lists become indexable, and (N, 1) columns no
    # longer break the argsort / fancy-indexing below.
    y_true = np.asarray(y_true).ravel()
    y_pred_mean = np.asarray(y_pred_mean).ravel()
    y_pred_std = np.asarray(y_pred_std).ravel()

    # Sort by uncertainty (ascending - keep most confident predictions)
    sorted_indices = np.argsort(y_pred_std)
    n_total = len(y_true)

    results = []
    for retention in retention_rates:
        # Keep top retention% of most confident predictions
        n_keep = int(n_total * retention)
        if n_keep < 1:
            # The metrics below are undefined on an empty subset; fail with a
            # message that names the offending retention rate.
            raise ValueError(
                f"retention={retention} keeps no samples out of {n_total}"
            )
        keep_indices = sorted_indices[:n_keep]

        y_true_subset = y_true[keep_indices]
        y_pred_subset = y_pred_mean[keep_indices]

        results.append({
            'retention': retention,
            'n_samples': n_keep,
            'MAE': mean_absolute_error(y_true_subset, y_pred_subset),
            'RMSE': np.sqrt(mean_squared_error(y_true_subset, y_pred_subset)),
            'R2': r2_score(y_true_subset, y_pred_subset),
            # Mean uncertainty of retained predictions
            'mean_std': np.mean(y_pred_std[keep_indices]),
        })
    return results


# ==============================================================================
# OOD Detection Metrics
# ==============================================================================

def compute_ood_detection_metrics(id_uncertainty, ood_uncertainty):
    """
    Score how well uncertainty separates in-distribution from OOD inputs.

    The uncertainty value is used directly as the detection score, with OOD
    as the positive class (label 1) and in-distribution as label 0.

    Args:
        id_uncertainty: Uncertainty scores for in-distribution data (N_id,)
        ood_uncertainty: Uncertainty scores for out-of-distribution data (N_ood,)

    Returns:
        auroc: Area under ROC curve
        aupr: Area under precision-recall curve
        fpr_at_tpr95: FPR at the first ROC point whose TPR is >= 0.95
                      (NaN if there is none)
        tpr_at_fpr5: TPR at the last ROC point whose FPR is <= 0.05
                     (NaN if there is none)
    """
    # Inputs may be multi-dimensional; collapse each to 1-D
    id_scores = id_uncertainty.flatten()
    ood_scores = ood_uncertainty.flatten()

    labels = np.concatenate((
        np.zeros(len(id_scores), dtype=int),
        np.ones(len(ood_scores), dtype=int),
    ))
    scores = np.concatenate((id_scores, ood_scores))

    auroc = roc_auc_score(labels, scores)

    precision, recall, _ = precision_recall_curve(labels, scores)
    aupr = auc(recall, precision)

    fpr, tpr, _ = roc_curve(labels, scores)
    high_tpr_points = np.flatnonzero(tpr >= 0.95)
    low_fpr_points = np.flatnonzero(fpr <= 0.05)

    fpr_at_tpr95 = fpr[high_tpr_points[0]] if len(high_tpr_points) > 0 else np.nan
    tpr_at_fpr5 = tpr[low_fpr_points[-1]] if len(low_fpr_points) > 0 else np.nan

    return auroc, aupr, fpr_at_tpr95, tpr_at_fpr5


def compute_ood_crps_picp_mpiw(y_true, pred_mean, pred_std, confidence_level):
    """
    Compute CRPS, PICP, and MPIW for OOD predictions.

    Intervals are the central Gaussian intervals mean +/- z * std at the
    requested confidence level.

    Args:
        y_true: True values (N,)
        pred_mean: Predicted mean (N,)
        pred_std: Predicted standard deviation (N,)
        confidence_level: Confidence level of the interval (e.g., 0.95)

    Returns:
        crps: Gaussian CRPS from compute_crps_gaussian
        picp: Interval coverage from compute_picp_mpiw
        mpiw: Mean interval width from compute_picp_mpiw
    """
    half_width = norm.ppf((1 + confidence_level) / 2) * pred_std
    picp, mpiw = compute_picp_mpiw(y_true, pred_mean - half_width, pred_mean + half_width)
    return compute_crps_gaussian(y_true, pred_mean, pred_std), picp, mpiw


# ==============================================================================
# MC Sampling for Bayesian Models
# ==============================================================================

def mc_sample_bayesian(model, variational_layers, X, num_samples, scaler_y, seed=42):
    """
    Draw Monte Carlo predictions from a Bayesian model.

    For every draw:
    1. clear_model_cache(variational_layers) is called before the forward pass
       so that the pass does not reuse previously cached weights
    2. the model is called with training=True, which is what triggers weight
       sampling from the variational posterior
    3. the output is mapped back to the original scale with scaler_y

    Progress is printed every 10 draws.

    Args:
        model: Bayesian Keras model
        variational_layers: List of variational layers with clear_cache method
        X: Input data
        num_samples: Number of MC samples
        scaler_y: Scaler for inverse transform
        seed: Seed passed to set_seed once, before the first draw
    Returns:
        samples: Array of shape (num_samples, N)
    """
    from modules.training import set_seed, clear_model_cache

    set_seed(seed)
    draws = []
    for draw_number in range(1, num_samples + 1):
        clear_model_cache(variational_layers)
        scaled_draw = model(X, training=True).numpy()
        draws.append(inverse_transform_predictions(scaled_draw, scaler_y))
        if draw_number % 10 == 0:
            print(f"    MC sample {draw_number}/{num_samples}", end='\r')
    # Finish the carriage-return progress line
    print()
    return np.array(draws)


# ==============================================================================
# Singular Value Analysis
# ==============================================================================

def analyze_gate_singular_values(model, target_energy=70.0):
    """
    Suggest a low-rank size from the singular value decay of the gate kernels.

    Layers whose name ends in "_x_to_gates" or "_h_to_gates" and whose first
    weight is a 2-D matrix are analyzed. For each, the cumulative squared
    singular values are expressed as a percentage of the total ("energy").
    Gate layers are grouped by the index in their "layer<k>_" name prefix;
    the suggested rank is the smallest r at which the mean over groups of the
    per-group mean energy reaches target_energy (%).

    If no rank reaches the target (including when no layer name carries a
    "layer<k>_" prefix), the largest rank available in every gate layer is
    returned.

    Args:
        model: Object exposing .layers; each layer provides get_weights()
        target_energy: Required cumulative energy in percent

    Returns:
        proposed_rank: Suggested rank for low-rank approximation
                       (None if no gate layers were found)
        layer_energy: Dictionary of cumulative energy per layer
                      ({} if no gate layers were found)
    """
    import re

    gate_suffixes = ("_x_to_gates", "_h_to_gates")

    # Collect (name, kernel) for every gate layer with a 2-D first weight
    kernels = []
    for layer in model.layers:
        layer_name = getattr(layer, "name", "")
        if not layer_name.endswith(gate_suffixes):
            continue
        layer_weights = layer.get_weights()
        if not layer_weights:
            continue
        kernel = layer_weights[0]
        if kernel.ndim == 2:
            kernels.append((layer_name, kernel))

    if not kernels:
        print("No gate layers found.")
        return None, {}

    # Cumulative energy curve (in %) per gate layer
    layer_energy = {}
    for layer_name, kernel in kernels:
        singular_values = np.linalg.svd(kernel, full_matrices=False, compute_uv=False)
        total_energy = np.sum(singular_values ** 2)
        layer_energy[layer_name] = np.cumsum(singular_values ** 2) / total_energy * 100

    # Bucket gate layers by the LSTM layer index encoded in their name
    names_by_lstm_layer = {}
    for layer_name in layer_energy:
        prefix = re.match(r"layer(\d+)_", layer_name)
        if prefix is not None:
            names_by_lstm_layer.setdefault(int(prefix.group(1)), []).append(layer_name)
    grouped_names = [names_by_lstm_layer[idx] for idx in sorted(names_by_lstm_layer)]

    # Only ranks present in every gate layer can be evaluated
    max_rank = min(len(curve) for curve in layer_energy.values())

    def _overall_energy(rank):
        # Mean over LSTM layers of the mean energy of that layer's gates
        per_layer = [
            float(np.mean([layer_energy[name][rank - 1] for name in names]))
            for names in grouped_names
        ]
        return float(np.mean(per_layer)) if per_layer else 0.0

    proposed_rank = next(
        (rank for rank in range(1, max_rank + 1)
         if _overall_energy(rank) >= target_energy),
        max_rank,
    )

    return proposed_rank, layer_energy


def energy_at_rank(model, rank=15):
    """
    Report cumulative energy captured at a given rank for gate layers.

    Gate layers are those whose name ends in "_x_to_gates" or "_h_to_gates"
    and whose first weight is a 2-D matrix. Energy is the cumulative sum of
    squared singular values, in percent of the total. A rank larger than the
    number of singular values is clamped to that number; rank 0 captures 0%.

    Printed summary:
    - Per gate layer: energy at the requested rank
    - Per LSTM layer (from the "layer<k>_" name prefix): average of its gates
    - Overall: average across LSTM layers

    Args:
        model: Object exposing .layers; each layer provides get_weights()
        rank: Non-negative rank at which to evaluate the energy

    Returns:
        dict mapping gate layer name -> energy (%) at the given rank,
        or None if no gate layers were found.

    Raises:
        ValueError: If rank is negative.
    """
    import re

    if rank < 0:
        # A negative rank would otherwise index the energy curve from the end
        raise ValueError(f"rank must be non-negative, got {rank}")

    results = {}
    layer_groups = {}
    for layer in model.layers:
        name = getattr(layer, "name", "")
        if not name.endswith(("_x_to_gates", "_h_to_gates")):
            continue
        weights = layer.get_weights()
        if not weights:
            continue
        W = weights[0]
        if W.ndim != 2:
            continue

        s = np.linalg.svd(W, full_matrices=False, compute_uv=False)
        total_energy = np.sum(s ** 2)
        cumulative_energy = np.cumsum(s ** 2) / total_energy * 100

        # Clamp to the available rank. With zero components kept nothing is
        # captured; indexing with [-1] here would wrongly report 100%.
        n_components = min(rank, len(s))
        energy = cumulative_energy[n_components - 1] if n_components > 0 else 0.0
        results[name] = energy

        m = re.match(r"layer(\d+)_", name)
        if m:
            layer_groups.setdefault(int(m.group(1)), []).append(energy)

    if not results:
        print("No gate layers found.")
        return None

    print(f"Energy captured at rank r={rank}:")
    for name, energy in results.items():
        print(f"  {name}: {energy:.2f}%")
    if layer_groups:
        print("\nPer-layer averages:")
        layer_avgs = []
        for layer_id in sorted(layer_groups):
            avg = float(np.mean(layer_groups[layer_id]))
            layer_avgs.append(avg)
            print(f"  layer{layer_id}: {avg:.2f}%")
        overall_avg = float(np.mean(layer_avgs))
        print(f"\nOverall average across LSTM layers: {overall_avg:.2f}%")
    return results


# ==============================================================================
# Full Model Evaluation Pipeline
# ==============================================================================

def _mc_predict_samples(model, var_layers, X, num_mc_samples, scaler_y, clear_cache_fn):
    """Run num_mc_samples stochastic forward passes on X; returns (num_mc_samples, N) in original scale."""
    draws = []
    for _ in range(num_mc_samples):
        if clear_cache_fn:
            # Drop cached weights so this pass samples fresh ones
            clear_cache_fn(var_layers)
        # training=True enables weight sampling in the forward pass
        y_sample_scaled = model(X, training=True).numpy().flatten()
        draws.append(inverse_transform_predictions(y_sample_scaled, scaler_y))
    return np.array(draws)


def _ensemble_predict(members, X, scaler_y):
    """Predict X with every ensemble member; returns (M, N) in original scale."""
    return np.array([
        inverse_transform_predictions(m.predict(X, verbose=0), scaler_y)
        for m in members
    ])


def evaluate_model(model,
                   var_layers,
                   model_name,
                   X_test,
                   y_test_original,
                   X_val,
                   y_val,
                   X_ood,
                   scaler_y,
                   confidence_level=0.95,
                   num_mc_samples=100,
                   seed=42,
                   set_seed_fn=None,
                   clear_cache_fn=None,
                   verbose=True):
    """
    Evaluate a model and return all metrics.

    The model type is decided as follows:
    - model_name == "Deep Ensemble": `model` is iterated as a list of members
    - otherwise, var_layers is not None: Bayesian model, evaluated with MC
      sampling (forward passes with training=True)
    - otherwise: deterministic model; only point metrics are computed and all
      uncertainty / OOD metrics are NaN

    Args:
        model: Trained model (single model or list of ensemble members)
        var_layers: Variational layers for Bayesian models (None for deterministic/ensemble)
        model_name: Name of the model (used to detect ensemble)
        X_test: Test features
        y_test_original: Test targets in original scale; flattened to (N,)
                         internally
        X_val: Validation features (for ensemble NLL computation)
        y_val: Validation targets (scaled)
        X_ood: Out-of-distribution features (used for probabilistic models)
        scaler_y: Fitted scaler for y values
        confidence_level: Confidence level for prediction intervals (default: 0.95)
        num_mc_samples: Number of MC samples for Bayesian models (default: 100)
        seed: Random seed
        set_seed_fn: Function to set random seed (optional); called at the
                     start, after the point metrics, and before OOD scoring
        clear_cache_fn: Function to clear variational layer cache (optional)
        verbose: Print a progress line before OOD scoring (default: True)

    Returns:
        dict: All evaluation metrics including:
              - MAE, RMSE, R2 (point prediction)
              - NLL, CRPS, ECE, PICP, MPIW (uncertainty quantification)
              - AUROC_OOD, AUPR_OOD, FPR95_OOD (OOD detection)
    """
    if set_seed_fn:
        set_seed_fn(seed)

    # All predictions below are flat (N,) arrays. Flatten the targets once so
    # that an (N, 1) target cannot broadcast to (N, N) inside the metrics.
    y_test_original = np.asarray(y_test_original).flatten()

    metrics = {}
    is_ensemble = model_name == "Deep Ensemble"
    is_bayesian = var_layers is not None

    # Get predictions and samples
    if is_ensemble:
        samples = _ensemble_predict(model, X_test, scaler_y)
        y_pred = np.mean(samples, axis=0)
    elif is_bayesian:
        samples = _mc_predict_samples(model, var_layers, X_test, num_mc_samples,
                                      scaler_y, clear_cache_fn)
        y_pred = np.mean(samples, axis=0)
    else:
        # Deterministic model
        y_pred_scaled = model.predict(X_test, verbose=0)
        y_pred = inverse_transform_predictions(y_pred_scaled, scaler_y)
        samples = None

    # Point prediction metrics
    point_metrics = calculate_point_metrics(y_test_original, y_pred)
    metrics['MAE'] = point_metrics['MAE']
    metrics['RMSE'] = point_metrics['RMSE']
    metrics['R2'] = point_metrics['R2']

    if set_seed_fn:
        set_seed_fn(seed)

    if samples is None:
        # Deterministic model - no UQ or OOD metrics
        for key in ('NLL', 'CRPS', 'ECE', 'PICP', 'MPIW',
                    'AUROC_OOD', 'AUPR_OOD', 'FPR95_OOD'):
            metrics[key] = np.nan
        return metrics

    # Uncertainty quantification metrics (for probabilistic models)
    lower, upper, mean_pred, std_pred = compute_prediction_intervals(samples, confidence_level)

    if is_ensemble:
        # Use pooled noise (from validation residuals) for ensemble NLL
        ensemble_val_preds = _ensemble_predict(model, X_val, scaler_y)
        y_val_original = inverse_transform_predictions(y_val, scaler_y)
        nll, _ = compute_nll_ensemble_pooled_noise(
            y_test_original,
            [s.flatten() for s in samples],
            y_val_original.flatten(),
            [s.flatten() for s in ensemble_val_preds]
        )
    else:
        nll = compute_nll_gaussian(y_test_original, mean_pred, std_pred)

    crps = compute_crps_gaussian(y_test_original, mean_pred, std_pred)
    ece, _ = compute_ece_regression_fixed(y_test_original, mean_pred, std_pred)
    picp, mpiw = compute_picp_mpiw(y_test_original, lower, upper)

    metrics['NLL'] = nll
    metrics['CRPS'] = crps
    metrics['ECE'] = ece
    metrics['PICP'] = picp
    metrics['MPIW'] = mpiw

    # OOD Detection metrics
    if set_seed_fn:
        set_seed_fn(seed)
    if verbose:
        print("  Computing OOD detection metrics...")

    if is_ensemble:
        # Member predictions on the test set were already computed above and
        # predict() is deterministic, so reuse their std instead of
        # re-running every member on X_test.
        id_std = std_pred
        ood_std = np.std(_ensemble_predict(model, X_ood, scaler_y), axis=0)
    else:
        # The ID samples are drawn again (rather than reusing `samples`) so
        # that the sequence of stochastic forward passes following the
        # re-seed above stays exactly as before.
        id_samples = _mc_predict_samples(model, var_layers, X_test, num_mc_samples,
                                         scaler_y, clear_cache_fn)
        id_std = np.std(id_samples, axis=0).flatten()

        ood_samples = _mc_predict_samples(model, var_layers, X_ood, num_mc_samples,
                                          scaler_y, clear_cache_fn)
        ood_std = np.std(ood_samples, axis=0).flatten()

    # The fourth return value (TPR at 5% FPR) is not reported
    auroc, aupr, fpr_at_tpr95, _ = compute_ood_detection_metrics(
        id_std.flatten(), ood_std.flatten()
    )
    metrics['AUROC_OOD'] = auroc
    metrics['AUPR_OOD'] = aupr
    metrics['FPR95_OOD'] = fpr_at_tpr95

    return metrics
