"""
Inference and Prediction Functions for Bayesian Neural Networks

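This module implements Monte Carlo (MC) inference utilities for:
- Bayesian prediction with uncertainty quantification
- Entropy-based uncertainty decomposition (total, aleatoric, and
  Mutual Information (MI) as the epistemic part)
- Deep ensemble predictions (each ensemble member acts as one sample)
- Per-class epistemic uncertainty analysis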
"""

import numpy as np
import tensorflow as tf
import gc
from typing import Tuple, List
from scipy.special import xlogy


def compute_total_uncertainty(mc_samples: np.ndarray, epsilon: float = 1e-10) -> np.ndarray:
    """
    Compute the predictive (total) entropy H[y|x,D] from MC samples.

    The sampled probabilities are averaged over axis 0 (the sample axis) and
    the binary entropy of that mean, -[p log p + (1-p) log(1-p)], is returned
    using the natural logarithm (nats).

    Parameters:
    -----------
    mc_samples : probabilities p(y=1|x) with the MC / ensemble draws stacked
        along axis 0, e.g. (S, N)
    epsilon : small constant added inside the logarithms to avoid log(0)

    Returns:
    --------
    total : array of shape mc_samples.shape[1:] - predictive entropy per input
    """
    p_bar = mc_samples.mean(axis=0)
    q_bar = 1 - p_bar
    # xlogy(0, y) is defined as 0, so exact 0/1 probabilities contribute nothing.
    return -(xlogy(p_bar, p_bar + epsilon) + xlogy(q_bar, q_bar + epsilon))


def compute_aleatoric_uncertainty(mc_samples: np.ndarray, epsilon: float = 1e-10) -> np.ndarray:
    """
    Compute the expected entropy E_ω[H[y|x,ω]] from MC samples.

    The binary entropy (nats) is evaluated for every individual draw and then
    averaged over axis 0 (the sample axis).

    Parameters:
    -----------
    mc_samples : probabilities p(y=1|x) with the MC / ensemble draws stacked
        along axis 0, e.g. (S, N)
    epsilon : small constant added inside the logarithms to avoid log(0)

    Returns:
    --------
    aleatoric : array of shape mc_samples.shape[1:] - mean per-draw entropy
    """
    complement = 1 - mc_samples
    per_draw_entropy = -(xlogy(mc_samples, mc_samples + epsilon)
                         + xlogy(complement, complement + epsilon))
    return per_draw_entropy.mean(axis=0)


def _is_bayesian_model(model: tf.keras.Model) -> bool:
    """
    Heuristically decide whether ``model`` contains variational layers.

    Returns True as soon as any entry of ``model.layers`` has one of the
    attributes ``w_mu``, ``A_mu``, ``B_mu``, ``s_mu`` or ``r_mu``
    (presumably the posterior-mean parameters of the project's custom
    Bayesian layers, defined elsewhere -- confirm). Only the direct entries of
    ``model.layers`` are inspected; layers nested inside them are not visited.
    """
    posterior_attrs = ("w_mu", "A_mu", "B_mu", "s_mu", "r_mu")
    return any(
        hasattr(layer, attr)
        for layer in model.layers
        for attr in posterior_attrs
    )


def deterministic_predictions(model: tf.keras.Model,
                              X: np.ndarray,
                              batch_size: int = 128) -> np.ndarray:
    """
    Run a single, non-sampled forward pass through ``model.predict``.

    Parameters:
    -----------
    model : trained Keras model
    X : input features; objects exposing a ``.values`` attribute (e.g. a
        pandas DataFrame) are unwrapped to that attribute first
    batch_size : batch size forwarded to ``model.predict``

    Returns:
    --------
    preds : the model output with all size-1 axes removed (``squeeze``)
    """
    if hasattr(X, "values"):
        X = X.values
    raw_output = model.predict(X, batch_size=batch_size, verbose=0)
    return raw_output.squeeze()


def mc_predictions(model: tf.keras.Model,
                   X: np.ndarray,
                   n_samples: int = 512,
                   seed: int = 42) -> Tuple[np.ndarray, np.ndarray]:
    """
    Perform Monte Carlo sampling to get predictions and uncertainties.

    For Bayesian models (as detected by ``_is_bayesian_model``):
    - Call the model T=n_samples times with ``training=True``, relying on that
      flag to make the variational layers sample a new weight configuration
    - Return the mean of the sampled outputs (final prediction) and their
      standard deviation (uncertainty)

    For deterministic models:
    - Single ``model.predict`` pass; the uncertainty is reported as zeros

    Parameters:
    -----------
    model : trained deterministic or Bayesian neural network; in the Bayesian
        path its output is stored into an (N, 1) slot, so an (N, 1) output is
        assumed
    X : input features (N, feature_dim); objects with a ``.values`` attribute
        (e.g. a pandas DataFrame) are unwrapped to it first
    n_samples : number of MC samples (default 512, same as paper); ignored
        for deterministic models
    seed : seed passed to ``tf.random.set_seed`` and ``np.random.seed`` before
        sampling (default 42, the previously hard-coded value). NOTE: this
        reseeds the *global* RNGs of both libraries.

    Returns:
    --------
    predictions : (N,) - predictive mean p̂(y=1|x) for each input
    uncertainties : (N,) - standard deviation of the sampled outputs
        (zeros for deterministic models)
    Both arrays are ``squeeze``d, so a single-row input yields 0-d arrays.

    Raises:
    -------
    ValueError : if the model is Bayesian and ``n_samples < 1``
    """
    if not _is_bayesian_model(model):
        predictions = deterministic_predictions(model, X)
        uncertainties = np.zeros_like(predictions)
        return predictions, uncertainties

    if n_samples < 1:
        # Previously this silently produced NaN means/stds from an empty slice.
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    # Unwrap DataFrame-like inputs exactly as deterministic_predictions does,
    # so both code paths accept the same input types.
    X_values = X.values if hasattr(X, "values") else X

    print(f"  Computing {n_samples} MC samples...")
    tf.random.set_seed(seed)
    np.random.seed(seed)

    # Storage for all samples: (n_samples, N, 1), float64 regardless of model dtype.
    all_outputs = np.zeros((n_samples, len(X_values), 1))
    for i in range(n_samples):
        # NOTE(review): training=True also puts layers such as Dropout or
        # BatchNormalization (if present) into training mode -- confirm this
        # is intended for the architectures used.
        all_outputs[i] = model(X_values, training=True).numpy()

    # Predictive mean: p̂(y=1|x) = (1/T) Σ_t p_t(x)
    predictions = all_outputs.mean(axis=0).squeeze()
    # Spread across sampled weights (epistemic uncertainty proxy), ddof=0.
    uncertainties = all_outputs.std(axis=0).squeeze()

    return predictions, uncertainties


def mc_predictions_with_mi(model: tf.keras.Model,
                           X: np.ndarray,
                           n_samples: int = 512,
                           epsilon: float = 1e-10,
                           seed: int = 42,
                           mi_max: float = 100.0):
    """
    Perform Monte Carlo sampling and compute:
    - Predictive mean (μ)
    - Predictive variance (Var[p])
    - Mutual Information (MI) as epistemic uncertainty (approximation)

    Formula:
        MI ≈ Var[p] / (2 * μ * (1 - μ))

    This is the second-order Taylor approximation of H(μ) - E[H(p)] for the
    binary entropy H (in nats), using H''(μ) = -1 / (μ (1 - μ)). The exact
    entropy-based estimate is ``compute_mutual_information``.

    Parameters:
    -----------
    model : trained deterministic or Bayesian neural network; in the Bayesian
        path its output is stored into an (N, 1) slot, so an (N, 1) output is
        assumed
    X : input features (N, feature_dim); objects with a ``.values`` attribute
        (e.g. a pandas DataFrame) are unwrapped to it first
    n_samples : number of MC samples (default 512); ignored for deterministic
        models
    epsilon : small constant added to the MI denominator to avoid division by
        zero when μ is 0 or 1
    seed : seed passed to ``tf.random.set_seed`` and ``np.random.seed``
        (reseeds the global RNGs of both libraries)
    mi_max : upper clip bound for the MI approximation, which grows without
        bound as μ approaches 0 or 1 (default 100, the previously hard-coded
        value)

    Returns:
    --------
    predictions : (N,) - predictive mean μ(x)
    variance : (N,) - predictive variance Var[p(x)] (ddof=0)
    mi : (N,) - approximate mutual information, clipped to [0, mi_max]
    std : (N,) - standard deviation sqrt(Var[p(x)]) (old metric, for comparison)
    For deterministic models the last three arrays are zeros.

    Raises:
    -------
    ValueError : if the model is Bayesian and ``n_samples < 1``
    """
    if not _is_bayesian_model(model):
        predictions = deterministic_predictions(model, X)
        variance = np.zeros_like(predictions)
        mi = np.zeros_like(predictions)
        std = np.zeros_like(predictions)
        return predictions, variance, mi, std

    if n_samples < 1:
        # Previously this silently produced NaN statistics from an empty slice.
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    # Unwrap DataFrame-like inputs exactly as deterministic_predictions does.
    X_values = X.values if hasattr(X, "values") else X

    print(f"  Computing {n_samples} MC samples...")
    # Reset Keras' global state and collect garbage before the sampling loop.
    # NOTE(review): the model is still used after clear_session(); this is
    # presumably fine for eager models -- confirm for the TF/Keras version used.
    tf.keras.backend.clear_session()
    gc.collect()
    tf.random.set_seed(seed)
    np.random.seed(seed)

    # Storage for all samples: (n_samples, N, 1), float64 regardless of model dtype.
    all_outputs = np.zeros((n_samples, len(X_values), 1))
    for i in range(n_samples):
        all_outputs[i] = model(X_values, training=True).numpy()

    # Predictive mean: μ(x) = (1/S) Σ p^{(s)}(x)
    predictions = all_outputs.mean(axis=0).squeeze()
    # Predictive variance: Var[p] = (1/S) Σ (p^{(s)} - μ)²
    variance = all_outputs.var(axis=0).squeeze()
    # std is sqrt(var) by definition; reuse the variance instead of a second pass.
    std = np.sqrt(variance)

    # epsilon keeps the denominator positive when μ is exactly 0 or 1; the
    # clip bounds the blow-up of the approximation near those extremes.
    denominator = 2 * predictions * (1 - predictions) + epsilon
    mi = np.clip(variance / denominator, 0, mi_max)

    return predictions, variance, mi, std


def get_uncertainties(model, X, n_samples=512):
    """
    Return the standard deviation of ``n_samples`` stochastic forward passes.

    Unlike ``mc_predictions``, no Bayesian/deterministic check is made: the
    model is always called ``n_samples`` times with ``training=True``, after
    reseeding TensorFlow's and NumPy's global RNGs with 42.

    Parameters:
    -----------
    model : trained Bayesian neural network; its output is stored into an
        (N, 1) slot, so an (N, 1) output is assumed
    X : input features (N, feature_dim), passed to the model unchanged
    n_samples : number of MC samples (default 512)

    Returns:
    --------
    uncertainties : (N,) - standard deviation (ddof=0) across the draws,
        ``squeeze``d
    """
    print(f"  Computing {n_samples} MC samples...")
    n_points = len(X)
    draws = np.zeros((n_samples, n_points, 1))
    # Fixed global seeds; presumably makes the draws reproducible provided the
    # sampling layers use these global RNGs -- confirm.
    tf.random.set_seed(42)
    np.random.seed(42)
    for draw_idx in range(n_samples):
        draws[draw_idx] = model(X, training=True).numpy()
    spread = draws.std(axis=0)
    return spread.squeeze()


def ensemble_predictions(models_list, X, batch_size=128):
    """
    Compute the ensemble predictive mean and standard deviation across members.

    Each member is run once with ``model.predict``. The mean is used as the
    final predictive probability, and the standard deviation (ddof=0) across
    members provides an epistemic uncertainty estimate.

    Parameters:
    -----------
    models_list : iterable of trained keras models (consumed once)
    X : input features (pandas DataFrame or numpy array). Objects whose
        ``.values`` is an attribute (DataFrame) are unwrapped to it; objects
        whose ``.values`` is a method (e.g. a dict of named inputs) are passed
        through unchanged.
    batch_size : batch size forwarded to every ``model.predict`` call
        (default 128, the previously hard-coded value)

    Returns:
    --------
    mean_pred : (N,) - ensemble mean prediction
    std_pred : (N,) - ensemble standard deviation (epistemic uncertainty)

    Raises:
    -------
    ValueError : if ``models_list`` contains no models
    """
    models = list(models_list)
    if not models:
        # Previously an empty ensemble silently returned NaN scalars.
        raise ValueError("models_list must contain at least one model")

    # Loop-invariant input conversion, done once. The callable() guard keeps
    # dict inputs intact (dict.values is a method, not the data).
    if hasattr(X, "values") and not callable(X.values):
        X_values = X.values
    else:
        X_values = X

    # verbose=0 matches the other predict calls in this module (no progress bars).
    member_preds = np.array([
        model.predict(X_values, batch_size=batch_size, verbose=0).squeeze()
        for model in models
    ])  # shape: (n_members, N)
    mean_pred = member_preds.mean(axis=0)
    std_pred = member_preds.std(axis=0)
    return mean_pred, std_pred


def per_class_epistemic_uncertainty(model, X, n_samples=512, eps=1e-12):
    """
    Compute per-class epistemic uncertainty for a model with one probability
    output per input (e.g. Dense(1, sigmoid)).

    This addresses variance suppression near extreme probabilities by
    normalising the sample variance by each class's predicted probability:

        C0 = Var[p] / (2 * (1 - μ + eps))   (class 0, negative)
        C1 = Var[p] / (2 * (μ + eps))       (class 1, positive)

    Up to ``eps``, C0 + C1 = Var[p] / (2 μ (1 - μ)). That is the same
    second-order MI approximation used by ``mc_predictions_with_mi``, except
    that the variance here is the unbiased (ddof=1) one.

    No seeding and no Bayesian/deterministic check are performed. The model is
    always called ``n_samples`` times with ``training=True``.

    Parameters:
    -----------
    model : trained Bayesian neural network; each call's ``.numpy()`` output
        is flattened to one value per input, so an (N,) or (N, 1) output is
        assumed
    X : input features (N, feature_dim), passed to the model unchanged
    n_samples : number of MC samples (default 512); must be >= 2
    eps : small constant to avoid division by zero

    Returns:
    --------
    C : (N, 2) - per-class epistemic uncertainty [class_0, class_1]
    mu : (N,) - mean prediction for class 1
    var : (N,) - unbiased (ddof=1) variance of the class-1 predictions

    Raises:
    -------
    ValueError : if ``n_samples < 2`` (a ddof=1 variance needs two samples)
    """
    if n_samples < 2:
        raise ValueError(
            f"n_samples must be >= 2 for a ddof=1 variance, got {n_samples}"
        )

    # reshape(-1) instead of squeeze(): for a single-row batch, squeeze() would
    # also drop the batch axis, and np.stack(..., axis=1) below would then fail.
    mc_ps = np.stack(
        [np.asarray(model(X, training=True).numpy()).reshape(-1)
         for _ in range(n_samples)],
        axis=0,
    )  # (S, N)
    mu = mc_ps.mean(axis=0)          # (N,)
    var = mc_ps.var(axis=0, ddof=1)  # (N,)

    C0 = 0.5 * var / (1 - mu + eps)  # class 0 (negative)
    C1 = 0.5 * var / (mu + eps)      # class 1 (positive)
    C = np.stack([C0, C1], axis=1)   # (N, 2)
    return C, mu, var


def compute_mutual_information(mc_samples: np.ndarray, epsilon: float = 1e-10) -> np.ndarray:
    """
    Compute the information-theoretic mutual information for binary classification.

        MI = H[y|x,D] - E_ω[H[y|x,ω]]

    where H[y|x,D] is the binary entropy of the averaged prediction and
    E_ω[H[y|x,ω]] is the mean binary entropy of the individual draws
    (natural log, i.e. nats).

    Parameters:
    -----------
    mc_samples : (S, N) - probabilities p(y=1|x), one row per MC sample or
        ensemble member
    epsilon : small constant added inside the logarithms to avoid log(0)

    Returns:
    --------
    mi : (N,) - mutual information (epistemic uncertainty), clamped at 0
    """
    def _binary_entropy(p):
        # -[p log p + (1-p) log(1-p)]; xlogy(0, y) is 0, so p in {0, 1} is safe.
        q = 1 - p
        return -(xlogy(p, p + epsilon) + xlogy(q, q + epsilon))

    predictive_entropy = _binary_entropy(mc_samples.mean(axis=0))  # H[y|x,D], (N,)
    expected_entropy = _binary_entropy(mc_samples).mean(axis=0)    # E_ω[H], (N,)

    # Binary entropy is concave, so the exact MI is >= 0; round-off and the
    # epsilon shift can push it slightly below zero, hence the clamp.
    return np.maximum(predictive_entropy - expected_entropy, 0)


def mc_predictions_with_mi_v2(model,
                               X: np.ndarray,
                               n_samples: int = 512,
                               seed: int = 42) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Perform Monte Carlo sampling with proper information-theoretic MI computation.

    Supports deterministic models, Bayesian models, and Deep Ensembles.
    - Deep Ensemble (list or tuple of models): each member's ``predict`` output
      is treated as one sample; ``n_samples`` is not used.
    - Bayesian model (see ``_is_bayesian_model``): the model is called
      ``n_samples`` times with ``training=True``.
    - Deterministic model: a single ``predict`` pass with zero uncertainties.

    Parameters:
    -----------
    model : deterministic/Bayesian model OR list/tuple of models (Deep Ensemble)
    X : input features (N, feature_dim); objects with a ``.values`` attribute
        (e.g. a pandas DataFrame) are unwrapped to it first
    n_samples : number of MC samples for a Bayesian model (default 512)
    seed : seed passed to ``tf.random.set_seed`` and ``np.random.seed``
        (reseeds the global RNGs of both libraries)

    Returns:
    --------
    predictions : (N,) - predictive mean μ(x)
    std_unc : (N,) - standard deviation (ddof=0) across samples/members
    mi_unc : (N,) - mutual information (epistemic uncertainty), computed by
        ``compute_mutual_information``

    Raises:
    -------
    ValueError : for an empty ensemble, or ``n_samples < 1`` with a Bayesian model
    """
    is_ensemble = isinstance(model, (list, tuple))
    if is_ensemble and len(model) == 0:
        # Previously an empty ensemble produced NaN scalars instead of (N,) arrays.
        raise ValueError("Deep ensemble must contain at least one member model")

    # Reset Keras' global state and collect garbage before predicting.
    # NOTE(review): the model(s) are still used after clear_session(); this is
    # presumably fine for eager models -- confirm for the TF/Keras version used.
    tf.keras.backend.clear_session()
    gc.collect()

    tf.random.set_seed(seed)
    np.random.seed(seed)

    # Unwrap DataFrame-like inputs exactly as deterministic_predictions does.
    X_values = X.values if hasattr(X, "values") else X

    if is_ensemble:
        # Deep Ensemble: each member gives one sample.
        mc_samples = np.array([
            member.predict(X_values, batch_size=128, verbose=0).squeeze()
            for member in model
        ])  # (n_members, N)
    elif not _is_bayesian_model(model):
        predictions = deterministic_predictions(model, X)
        std_unc = np.zeros_like(predictions)
        mi_unc = np.zeros_like(predictions)
        return predictions, std_unc, mi_unc
    else:
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")
        # Single Bayesian model: draw n_samples stochastic forward passes.
        mc_samples = np.zeros((n_samples, len(X_values)))
        for i in range(n_samples):
            mc_samples[i] = model(X_values, training=True).numpy().squeeze()

    predictions = mc_samples.mean(axis=0)            # (N,)
    std_unc = mc_samples.std(axis=0)                 # (N,)
    mi_unc = compute_mutual_information(mc_samples)  # (N,)

    return predictions, std_unc, mi_unc
