"""
Compute PAC-Bayes and Gaussian Complexity Bounds with ACTUAL EMPIRICAL VALUES

This script:
1. Loads actual training data and computes ||X||_F
2. Loads trained models and extracts posterior parameters
3. Computes actual C_max from KL divergences
4. Computes bounds with real values (not arbitrary assumptions)
5. Verifies theoretical predictions against empirical results
"""

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from pathlib import Path
import sys

# Add modules to path
sys.path.insert(0, str(Path(__file__).parent / "modules"))

from config import configure_environment, get_prior_params, SEED
from data_loader import load_and_preprocess_data

# Configure environment
configure_environment()



# Load data
X_train, X_val, X_test, y_train, y_val, y_test, feature_scaler, target_scaler = load_and_preprocess_data()

# Compute Frobenius norm.
# X_train has at least 3 dimensions (shape[2] is read below), but
# np.linalg.norm(..., 'fro') only accepts 2-D input and raises
# ValueError("Improper number of dimensions to norm") otherwise. The Frobenius
# norm of any reshaping of X equals the 2-norm of its flattened entries
# (sqrt of the sum of squares), so flatten first.
X_norm_F_actual = np.linalg.norm(np.ravel(X_train))
print(f"\nActual ||X||_F = {X_norm_F_actual:.4f}")

# Compare with previous estimate
N_train = X_train.shape[0]
F = X_train.shape[2]
# NOTE: this estimate counts only N×F entries and ignores the size of axis 1
# (X_train.shape[1], presumably time steps), so it is not directly comparable
# with the actual norm when that axis is longer than 1.
X_norm_F_estimate = np.sqrt(N_train * F)
print(f"Previous estimate: √(N×F) = {X_norm_F_estimate:.4f}")
print(f"Relative difference: {abs(X_norm_F_actual - X_norm_F_estimate) / X_norm_F_actual * 100:.2f}%")

print(f"\nTraining data shape: {X_train.shape}")
print(f"N_train = {N_train}")
print(f"Features F = {F}")


# Get Prior Parameters


prior_params = get_prior_params()
pi, sigma1 = prior_params['pi'], prior_params['sigma1']
# sigma2 is turned into a plain Python float via the tensor's .numpy() value.
sigma2 = float(prior_params['sigma2'].numpy())

# Report the scale-mixture prior, one line per entry.
for report_line in (
    "\nScale mixture prior:",
    f"  π = {pi}",
    f"  σ₁ = {sigma1}",
    f"  σ₂ = {sigma2:.6f} (= exp(-6))",
):
    print(report_line)

# Define KL Divergence Computation

def kl_gaussian_to_scale_mixture(mu, sigma, pi, sigma1, sigma2):
    """
    Approximate KL(N(μ, σ²) || π·N(0, σ₁²) + (1-π)·N(0, σ₂²)) elementwise.

    The KL divergence from a Gaussian to a Gaussian mixture has no closed
    form, so this uses the variational approximation

        KL(q || Σ_k π_k p_k) ≈ -log Σ_k π_k · exp(-KL(q || p_k)),

    where each KL(q || p_k) between univariate Gaussians is computed exactly.

    Args:
        mu: Posterior means (anything np.asarray accepts, including eager
            TF tensors/variables).
        sigma: Posterior standard deviations (> 0), broadcastable with mu.
        pi: Weight of the first prior component (expected in (0, 1]).
        sigma1: Standard deviation of the first prior component.
        sigma2: Standard deviation of the second prior component.

    Returns:
        np.ndarray (float64) of per-parameter KL approximations, with the
        broadcast shape of the inputs.
    """
    # np.asarray handles NumPy arrays, Python scalars and eager TF values alike.
    # float64 keeps the (σ² + μ²)/(2σ_k²) terms finite for tiny σ_k.
    mu = np.asarray(mu, dtype=np.float64)
    sigma = np.asarray(sigma, dtype=np.float64)
    pi = np.asarray(pi, dtype=np.float64)
    sigma1 = np.asarray(sigma1, dtype=np.float64)
    sigma2 = np.asarray(sigma2, dtype=np.float64)

    # Exact KL between N(μ, σ²) and each zero-mean prior component:
    # KL = log(σ_k/σ) + (σ² + μ²)/(2σ_k²) - 1/2
    second_moment = sigma**2 + mu**2
    kl_to_comp1 = np.log(sigma1 / sigma) + second_moment / (2 * sigma1**2) - 0.5
    kl_to_comp2 = np.log(sigma2 / sigma) + second_moment / (2 * sigma2**2) - 0.5

    # Stable log-sum-exp: factor out the SMALLEST KL so every exponent is <= 0.
    # Factoring out the largest instead makes exp(max - min) overflow to inf
    # whenever the two KLs differ by more than ~700 (float64), and the result
    # collapses to -inf. That is common with a narrow component such as
    # σ₂ = exp(-6).
    min_kl = np.minimum(kl_to_comp1, kl_to_comp2)
    kl = min_kl - np.log(
        pi * np.exp(min_kl - kl_to_comp1)
        + (1 - pi) * np.exp(min_kl - kl_to_comp2)
    )

    return kl


# Load Trained Models and Extract Posterior Parameters



def _posterior_kl(mu_var, raw_sigma_var):
    """Return (flattened μ, per-parameter KL) for one variational tensor pair.

    σ is obtained by applying softplus to ``raw_sigma_var`` (presumably the
    layer's unconstrained ρ parameter — confirm against the layer code).
    The prior comes from the module-level ``pi``, ``sigma1`` and ``sigma2``.
    """
    mu = mu_var.numpy().flatten()
    sigma = tf.nn.softplus(raw_sigma_var).numpy().flatten()
    return mu, kl_gaussian_to_scale_mixture(mu, sigma, pi, sigma1, sigma2)


def extract_posterior_params(model, model_name):
    """
    Extract all μ and σ parameters from a Bayesian LSTM model and compute the
    per-parameter KL divergence to the scale-mixture prior.

    Only the top-level ``model.layers`` are scanned. A layer counts as
    Bayesian if it has ``mu_W``/``sigma_W`` (full-rank) or ``mu_A``/``mu_B``
    with ``sigma_A``/``sigma_B`` (low-rank, W = AB^T). For either kind, an
    optional ``mu_b``/``sigma_b`` bias pair is also included.

    Args:
        model: Keras model to inspect.
        model_name: Label used in the printed report.

    Returns:
        kl_values: 1-D float64 array with one KL value per variational parameter.
        total_params: Number of variational parameters (== len(kl_values)).
        C_max: Largest per-parameter KL. C_max * total_params upper-bounds
            the summed KL(Q||P).

    Raises:
        ValueError: If the model contains no Bayesian layer, because C_max
            is then undefined.
    """
    print(f"\n{'='*60}")
    print(f"Extracting posteriors from: {model_name}")
    print(f"{'='*60}")

    kl_chunks = []  # one 1-D KL array per weight/bias tensor
    total_params = 0

    for layer in model.layers:
        layer_name = layer.name

        if hasattr(layer, 'mu_W') and hasattr(layer, 'sigma_W'):
            # Full-rank Bayesian layer.
            mu_W, kl_W = _posterior_kl(layer.mu_W, layer.sigma_W)
            kl_chunks.append(kl_W)
            total_params += mu_W.size

            print(f"  Layer: {layer_name}")
            print(f"    Weight shape: {layer.mu_W.shape}")
            print(f"    Num params: {mu_W.size}")
            print(f"    KL range: [{kl_W.min():.4f}, {kl_W.max():.4f}]")
            print(f"    KL mean: {kl_W.mean():.4f}")

        elif hasattr(layer, 'mu_A') and hasattr(layer, 'mu_B'):
            # Low-rank layer W = AB^T: A and B carry independent posteriors.
            mu_A, kl_A = _posterior_kl(layer.mu_A, layer.sigma_A)
            mu_B, kl_B = _posterior_kl(layer.mu_B, layer.sigma_B)
            kl_chunks.extend((kl_A, kl_B))
            total_params += mu_A.size + mu_B.size

            print(f"  Layer: {layer_name}")
            print(f"    A shape: {layer.mu_A.shape}, B shape: {layer.mu_B.shape}")
            print(f"    Num params: {mu_A.size + mu_B.size}")
            print(f"    KL_A range: [{kl_A.min():.4f}, {kl_A.max():.4f}]")
            print(f"    KL_B range: [{kl_B.min():.4f}, {kl_B.max():.4f}]")

        else:
            # Deterministic layer: contributes no variational parameters.
            continue

        # Bias posterior: handled identically for full-rank and low-rank layers.
        if hasattr(layer, 'mu_b') and hasattr(layer, 'sigma_b'):
            mu_b, kl_b = _posterior_kl(layer.mu_b, layer.sigma_b)
            kl_chunks.append(kl_b)
            total_params += mu_b.size

            print(f"    Bias shape: {layer.mu_b.shape}")
            print(f"    Bias KL range: [{kl_b.min():.4f}, {kl_b.max():.4f}]")

    if not kl_chunks:
        raise ValueError(
            f"No Bayesian layers (mu_W/sigma_W or mu_A/mu_B) found in "
            f"{model_name!r}; cannot compute C_max."
        )

    kl_values = np.concatenate(kl_chunks).astype(np.float64)
    C_max = kl_values.max()

    print(f"\n  SUMMARY:")
    print(f"    Total parameters: {total_params}")
    print(f"    C_max (max KL): {C_max:.6f}")
    print(f"    Mean KL: {kl_values.mean():.6f}")
    print(f"    Total KL (sum): {kl_values.sum():.2f}")
    print(f"    KL(Q||P) bound: ≤ {C_max * total_params:.2f}")

    return kl_values, total_params, C_max

# Try to load models from checkpoints
checkpoint_dir = Path("checkpoints")

# Look for saved models
full_rank_path = checkpoint_dir / "full_rank_bayes"
low_rank_path = checkpoint_dir / "low_rank_bayes_r14_20"

if full_rank_path.exists() and low_rank_path.exists():
    print("\n✓ Loading models from checkpoints...")
    # str(): older tf.keras versions expect a string path rather than a Path.
    bayes_model = tf.keras.models.load_model(str(full_rank_path), compile=False)
    lowrank_model = tf.keras.models.load_model(str(low_rank_path), compile=False)
else:
    if not checkpoint_dir.exists():
        print("\n⚠️  WARNING: Checkpoint directory not found.")
        print("Please run the training cells in the notebook first to save model checkpoints.")
        print("\nTo save checkpoints, add this after training each model:")
        print("  model.save('checkpoints/model_name')")
    else:
        print("\n⚠️  WARNING: Model checkpoints not found.")

    # Fall back to models already defined in the __main__ namespace (e.g. when
    # this file is executed inside the notebook's namespace after training).
    # sys.exit() is deliberately NOT inside a try/except: a bare `except:`
    # would catch the SystemExit it raises and print a misleading message.
    session = sys.modules.get("__main__")
    if session is None:
        print("✗ Could not access notebook variables.")
        print("\nPlease save model checkpoints and rerun this script.")
        sys.exit(1)
    if not (hasattr(session, 'bayes_model') and hasattr(session, 'lowrank_model')):
        print("✗ Models not found in current session either.")
        print("\nPlease run this script from the notebook after training, or save checkpoints.")
        sys.exit(1)

    print("✓ Found models in current session!")
    bayes_model = session.bayes_model
    lowrank_model = session.lowrank_model

# Extract posterior parameters for both models (full-rank first, then low-rank).
(kl_full, D_full, C_max_full), (kl_low, D_low, C_max_low) = (
    extract_posterior_params(bayes_model, "Full-Rank Bayesian"),
    extract_posterior_params(lowrank_model, "Low-Rank Bayesian [14, 20]"),
)

# Compute PAC-Bayes and Gaussian Complexity Bounds

# Failure probability: the bounds hold with probability at least 1 - delta.
delta = 0.05
# Lipschitz constant of the loss (assuming bounded loss).
L = 1.0

def pac_bayes_bound(D, N, C_max, delta=0.05, empirical_risk=0.0):
    """PAC-Bayes bound (Theorem 4.8) with ACTUAL C_max.

    Returns

        empirical_risk + sqrt((C_max * D + ln(2 * sqrt(N) / delta)) / (2 * N))

    where C_max * D upper-bounds KL(Q||P) when each of the D per-parameter KL
    terms is at most C_max.

    With the default ``empirical_risk=0.0`` the result is only the complexity
    (generalization-gap) term. It is not a bound on the risk itself; pass the
    empirical training risk to obtain the full bound.

    Args:
        D: Number of variational parameters.
        N: Number of training samples (> 0).
        C_max: Maximum per-parameter KL divergence.
        delta: Failure probability in (0, 1).
        empirical_risk: Empirical (training) risk added to the gap term.

    Raises:
        ValueError: If N <= 0 or delta is not in (0, 1).
    """
    if N <= 0:
        raise ValueError(f"N must be positive, got {N}")
    if not 0 < delta < 1:
        raise ValueError(f"delta must lie in (0, 1), got {delta}")

    log_term = np.log(2 * np.sqrt(N) / delta)
    gap = np.sqrt((C_max * D + log_term) / (2 * N))
    return empirical_risk + gap

def gaussian_complexity_bound(D, N, X_norm_F, delta=0.05, L=1.0):
    """
    Simplified Gaussian complexity bound (Theorem 4.9, B.38) using ||X||_F.

    Theorem B.38 gives the per-class complexity

        Ĝ_S(F_Pinto_D(C,r)) = (C/√m) · ||X||_F · √[2r·log(D/r)]

    and the intended per-layer settings for this LSTM are:
    - Layer 0: r=14, D=15*256 + 64*256
    - Layer 1: r=20, D=64*256 + 64*256
    - Output: full-rank, D=64*1

    This function does NOT evaluate that per-layer expression. It uses the
    single averaged scaling ||X||_F · √D / √N (no C and no r, so rank enters
    only through D) and returns

        √π · L · ||X||_F · √D / √N  +  3 · √(ln(2/δ) / (2N))
    """
    scale = 1.0 / np.sqrt(N)
    complexity = scale * X_norm_F * np.sqrt(D)
    confidence = 3 * np.sqrt(np.log(2 / delta) / (2 * N))
    return np.sqrt(np.pi) * L * complexity + confidence

# Evaluate both bounds per model: N_train and ||X||_F are shared, while D and
# C_max come from each model's posterior.
pac_full, gauss_full = (
    pac_bayes_bound(D_full, N_train, C_max_full, delta),
    gaussian_complexity_bound(D_full, N_train, X_norm_F_actual, delta, L),
)
pac_low, gauss_low = (
    pac_bayes_bound(D_low, N_train, C_max_low, delta),
    gaussian_complexity_bound(D_low, N_train, X_norm_F_actual, delta, L),
)
# STEP 6: Compare with Empirical Test Error

print("\n" + "=" * 80)
print("STEP 6: Empirical verification")
print("=" * 80)

# Load empirical results
import pandas as pd

results_file = Path("results_csv/point_prediction_results.csv")
if not results_file.exists():
    # Report the skip explicitly instead of silently doing nothing.
    print(f"\n⚠️  Results file not found: {results_file} (skipping empirical comparison)")
else:
    results = pd.read_csv(results_file)

    # Select the rows for the two models
    full_rank_row = results[results['Model'] == 'Full-Rank Bayesian']
    low_rank_row = results[results['Model'] == 'Low-Rank Bayesian']

    if full_rank_row.empty or low_rank_row.empty:
        print(
            "\n⚠️  Missing 'Full-Rank Bayesian' and/or 'Low-Rank Bayesian' rows in "
            f"{results_file} (skipping empirical comparison)"
        )
    else:
        # Use MAE as empirical risk (normalized by target scale)
        # Note: These are in original scale, need to normalize
        mae_full = full_rank_row['MAE'].values[0]
        mae_low = low_rank_row['MAE'].values[0]

        print(f"\nEmpirical test error (MAE):")
        print(f"  Full-Rank: {mae_full:.4f}")
        print(f"  Low-Rank:  {mae_low:.4f}")
        print(f"  Ratio:     {mae_low / mae_full:.4f}")

        # For comparison with bounds, we need normalized error.
        # NOTE(review): y_range is an assumed constant (~[0, 300] for PM2.5),
        # not computed from the data.
        y_range = 300.0  # Approximate range of PM2.5 values

        norm_error_full = mae_full / y_range
        norm_error_low = mae_low / y_range

        print(f"\nNormalized test error (MAE / {y_range}):")
        print(f"  Full-Rank: {norm_error_full:.6f}")
        print(f"  Low-Rank:  {norm_error_low:.6f}")

        print(f"\nTheoretical vs Empirical comparison:")
        print(f"{'Model':<20} {'Theory (PAC)':<15} {'Empirical':<15} {'Gap'}")
        print("-" * 65)
        print(f"{'Full-Rank':<20} {pac_full:<15.6f} {norm_error_full:<15.6f} {pac_full - norm_error_full:.6f}")
        print(f"{'Low-Rank':<20} {pac_low:<15.6f} {norm_error_low:<15.6f} {pac_low - norm_error_low:.6f}")

        # pac_full / pac_low are computed without adding any empirical training
        # risk, so they are only the complexity term. This is therefore a
        # heuristic comparison, not a formal check of the bound. Use >= to
        # match the "≥" in the success message.
        if pac_full >= norm_error_full and pac_low >= norm_error_low:
            print("\n✓ Bounds hold: Theory ≥ Empirical (as expected)")
        else:
            print(
                "\n⚠️  WARNING: Theory < Empirical for at least one model "
                "(note: the PAC-Bayes value here excludes the empirical training risk)."
            )

# STEP 7: Save Results

print("\n" + "=" * 80)
print("STEP 7: Saving results")
print("=" * 80)

# Every value is cast to a built-in Python type. json cannot serialize NumPy
# scalars such as np.float32 / np.int64 (nor TF tensors), and pi / sigma1 are
# stored exactly as returned by get_prior_params(), so their type is not
# guaranteed to be a plain float.
results_dict = {
    'X_norm_F_actual': float(X_norm_F_actual),
    'X_norm_F_estimate': float(X_norm_F_estimate),
    'N_train': int(N_train),
    'F': int(F),
    'prior_params': {'pi': float(pi), 'sigma1': float(sigma1), 'sigma2': float(sigma2)},
    'full_rank': {
        'D': int(D_full),
        'C_max': float(C_max_full),
        'pac_bayes_bound': float(pac_full),
        'gaussian_bound': float(gauss_full),
        'kl_mean': float(kl_full.mean()),
        'kl_std': float(kl_full.std()),
    },
    'low_rank': {
        'D': int(D_low),
        'C_max': float(C_max_low),
        'pac_bayes_bound': float(pac_low),
        'gaussian_bound': float(gauss_low),
        'kl_mean': float(kl_low.mean()),
        'kl_std': float(kl_low.std()),
    },
    'delta': float(delta),
    'L': float(L),
}

import json
with open('empirical_bounds_results.json', 'w', encoding='utf-8') as f:
    json.dump(results_dict, f, indent=2)

print(f"\n✓ Results saved to: empirical_bounds_results.json")

# Closing banner followed by the key numbers, one item per line.
print("\n" + "=" * 80, "DONE: All empirical values computed!", "=" * 80, sep="\n")
print(
    "\nKey findings:",
    f"  1. C_max (Full-Rank): {C_max_full:.6f} (not arbitrary 1.0!)",
    f"  2. C_max (Low-Rank):  {C_max_low:.6f}",
    f"  3. ||X||_F: {X_norm_F_actual:.4f} (actual, not estimated)",
    f"  4. PAC-Bayes bounds: Full={pac_full:.4f}, Low={pac_low:.4f}",
    f"  5. Gaussian bounds:  Full={gauss_full:.4f}, Low={gauss_low:.4f}",
    sep="\n",
)
