"""
Compute PAC-Bayes and Gaussian Complexity Bounds with ACTUAL EMPIRICAL VALUES

This module provides functions to:
1. Compute actual ||X||_F from training data
2. Extract posterior parameters from trained Bayesian models
3. Compute actual C_max from KL divergences
4. Calculate bounds with real values
5. Verify theoretical predictions against empirical results
"""

import numpy as np
import tensorflow as tf
from pathlib import Path
import json


# ==============================================================================
# Prior Parameters (from config.py)
# ==============================================================================

def get_prior_params():
    """
    Return the hyperparameters of the two-component scale-mixture prior.

    Returns:
        dict with keys:
            'pi':     mixture weight of the first Gaussian component
            'sigma1': std of the first Gaussian component
            'sigma2': std of the second Gaussian component (exp(-6))
    """
    narrow_sigma = np.exp(-6.0)  # ≈ 0.00247875
    return dict(pi=0.5, sigma1=1.0, sigma2=narrow_sigma)


# ==============================================================================
# Parameter Count Functions
# ==============================================================================

# Default architecture (from config.py).
# These values are the default arguments of the parameter-count helpers
# (get_layer_dims, compute_D_full, compute_D_lr, validate_parameter_counts).
DEFAULT_LSTM_HIDDEN = 64  # LSTM hidden size h
DEFAULT_INPUT_SIZE = 15  # input feature dimension d_in
DEFAULT_OUTPUT_DIM = 1  # output dimension of the final layer
DEFAULT_NUM_LAYERS = 2  # NOTE(review): not referenced in the visible code; get_layer_dims hard-codes a 2-layer layout


def get_layer_dims(h=DEFAULT_LSTM_HIDDEN, d_in=DEFAULT_INPUT_SIZE,
                   output_dim=DEFAULT_OUTPUT_DIM):
    """
    Describe every weight matrix of the 2-layer LSTM + output layer.

    Each entry is a tuple (m, n, has_bias) for a weight matrix W ∈ R^{m×n}.
    The order of the returned list is:
        0. Layer 0 x_to_gates: (d_in, 4h), with bias of size 4h
        1. Layer 0 h_to_gates: (h, 4h), no bias (shared with x_to_gates)
        2. Layer 1 x_to_gates: (h, 4h), with bias of size 4h
        3. Layer 1 h_to_gates: (h, 4h), no bias
        4. Output layer:       (h, output_dim), with bias of size output_dim

    Args:
        h: LSTM hidden size
        d_in: Input feature dimension
        output_dim: Output dimension

    Returns:
        List of five (m, n, has_bias) tuples in the order above.
    """
    gate_width = 4 * h  # the four LSTM gates are stacked along the column axis

    dims = []
    # Layer 0 reads the raw input (d_in rows); layer 1 reads layer 0's hidden state (h rows).
    for x_rows in (d_in, h):
        dims.append((x_rows, gate_width, True))   # x_to_gates carries the bias
        dims.append((h, gate_width, False))       # h_to_gates has no bias of its own
    dims.append((h, output_dim, True))            # output projection
    return dims


def compute_D_full(h=DEFAULT_LSTM_HIDDEN, d_in=DEFAULT_INPUT_SIZE,
                   output_dim=DEFAULT_OUTPUT_DIM, verbose=False):
    """
    Count the parameters of the full-rank 2-layer LSTM + output layer.

    Every weight matrix reported by get_layer_dims contributes m*n weights,
    plus n bias entries when it carries a bias:

        weights only (output_dim = 1):  4h*d_in + 12h² + h
        with biases:                    4h*d_in + 12h² + 8h + (h + 1)*output_dim

    Args:
        h: LSTM hidden size (default: 64)
        d_in: Input feature dimension (default: 15)
        output_dim: Output dimension (default: 1)
        verbose: Print a per-matrix breakdown if True

    Returns:
        Total parameter count D_full
    """
    names = ('L0_x_to_gates', 'L0_h_to_gates', 'L1_x_to_gates',
             'L1_h_to_gates', 'Output')

    rows = []
    for name, (m, n, has_bias) in zip(names, get_layer_dims(h, d_in, output_dim)):
        n_weights = m * n
        n_bias = n if has_bias else 0
        rows.append({
            'name': name,
            'shape': (m, n),
            'weight_params': n_weights,
            'bias_params': n_bias,
            'total': n_weights + n_bias,
        })

    total = sum(row['total'] for row in rows)

    if verbose:
        rule = "-" * 60
        print(f"\nFull-Rank Parameter Count (h={h}, d_in={d_in}):")
        print(rule)
        for row in rows:
            print(f"  {row['name']:15s}: {row['shape']} = {row['weight_params']:6d} + {row['bias_params']:3d} = {row['total']:6d}")
        print(rule)
        print(f"  {'TOTAL':15s}: {total:6d}")

    return total


def compute_D_lr(r, h=DEFAULT_LSTM_HIDDEN, d_in=DEFAULT_INPUT_SIZE,
                 output_dim=DEFAULT_OUTPUT_DIM, output_full_rank=True, verbose=False):
    """
    Compute total parameter count for low-rank 2-layer LSTM with rank r.

    For low-rank factorization W = A @ B.T:
        - A ∈ R^{m × r}, B ∈ R^{n × r}
        - params = r * (m + n)

    Formula (uniform rank r, output full-rank, with biases):
        D_lr(r) = r*(d_in + 4h) + 4h           # L0 x_to_gates + bias
                + r*(h + 4h)                    # L0 h_to_gates
                + r*(h + 4h) + 4h               # L1 x_to_gates + bias
                + r*(h + 4h)                    # L1 h_to_gates
                + h*output_dim + output_dim     # Output (full-rank)

        D_lr(r) = r*(d_in + 19h) + 8h + (h + 1)*output_dim

    Simplified (h=64, d_in=15, output_dim=1):
        D_lr(r) = 1231*r + 577

    Accepted forms of ``r``:
        - scalar (int, float, NumPy integer/float): same rank everywhere
        - [r0, r1]: one rank per LSTM layer, expanded to [r0, r0, r1, r1]
        - sequence of 4 or more: one rank per LSTM weight matrix; a 5th entry,
          if present, is the output-layer rank (used only when
          output_full_rank=False). Without a 5th entry the output layer falls
          back to the first rank.

    Args:
        r: Rank for low-rank factorization (scalar or sequence, see above)
        h: LSTM hidden size (default: 64)
        d_in: Input feature dimension (default: 15)
        output_dim: Output dimension (default: 1)
        output_full_rank: Keep output layer full-rank (default: True)
        verbose: Print breakdown if True

    Returns:
        Total parameter count D_lr

    Raises:
        ValueError: if ``r`` is a sequence whose length is neither 2 nor >= 4.
    """
    layer_dims = get_layer_dims(h, d_in, output_dim)

    # Normalise every accepted form of `r` to four LSTM ranks + one output rank.
    # NumPy scalars are handled together with Python scalars so that they never
    # reach the sequence code path (indexing / list() on a scalar would raise).
    if isinstance(r, (int, float, np.integer, np.floating)):
        scalar_rank = int(r)
        lstm_ranks = [scalar_rank] * 4
        output_rank = scalar_rank
        r_display = scalar_rank
    else:
        given = [int(x) for x in r]
        if len(given) == 2:
            # Per-LSTM-layer ranks [r0, r1] -> [r0, r0, r1, r1]
            lstm_ranks = [given[0], given[0], given[1], given[1]]
        elif len(given) >= 4:
            lstm_ranks = given[:4]
        else:
            raise ValueError(
                "r must be a scalar, a pair [r0, r1], or a sequence with at "
                f"least 4 per-matrix ranks; got {len(given)} value(s)"
            )
        # An explicit 5th rank belongs to the output layer; otherwise reuse the first.
        output_rank = given[4] if len(given) > 4 else given[0]
        r_display = given

    params_breakdown = []
    total = 0

    layer_names = ['L0_x_to_gates', 'L0_h_to_gates', 'L1_x_to_gates',
                   'L1_h_to_gates', 'Output']

    for i, (m, n, has_bias) in enumerate(layer_dims):
        if i < 4:  # LSTM weight matrices are always factorised
            rank_label = lstm_ranks[i]
            weight_params = rank_label * (m + n)
        elif output_full_rank:
            rank_label = 'full'
            weight_params = m * n
        else:
            rank_label = output_rank
            weight_params = output_rank * (m + n)

        # Biases are never factorised.
        bias_params = n if has_bias else 0
        layer_total = weight_params + bias_params
        total += layer_total

        params_breakdown.append({
            'name': layer_names[i],
            'shape': (m, n),
            'rank': rank_label,
            'weight_params': weight_params,
            'bias_params': bias_params,
            'total': layer_total
        })

    if verbose:
        print(f"\nLow-Rank Parameter Count (r={r_display}, h={h}, d_in={d_in}):")
        print("-" * 70)
        for p in params_breakdown:
            print(f"  {p['name']:15s}: {p['shape']} r={p['rank']:4} = {p['weight_params']:6d} + {p['bias_params']:3d} = {p['total']:6d}")
        print("-" * 70)
        print(f"  {'TOTAL':15s}: {total:6d}")

    return total


def validate_parameter_counts(model, h=DEFAULT_LSTM_HIDDEN, d_in=DEFAULT_INPUT_SIZE,
                              output_dim=DEFAULT_OUTPUT_DIM, verbose=True):
    """
    Compare the analytically computed parameter count with the model's own.

    The model is treated as low-rank when at least one of its layers exposes
    an ``A_mu`` attribute; the rank of each such layer is read from the last
    axis of ``A_mu``. Otherwise the full-rank formula is used.

    Args:
        model: Trained model exposing ``trainable_weights`` and ``layers``
        h, d_in, output_dim: Architecture parameters
        verbose: Print comparison if True

    Returns:
        dict with keys 'model_type', 'computed', 'actual', 'difference',
        'relative_difference' and 'match' (True when the counts differ by at
        most one).
    """
    # Number of scalars across all trainable weights of the model.
    actual_params = sum(np.prod(w.shape) for w in model.trainable_weights)

    # One rank per layer that carries a low-rank factor A_mu.
    ranks = [layer.A_mu.shape[-1] for layer in model.layers if hasattr(layer, 'A_mu')]

    if ranks:
        # With four or more factorised layers use per-matrix ranks,
        # otherwise apply the first detected rank uniformly.
        rank_arg = ranks[:4] if len(ranks) >= 4 else ranks[0]
        computed_params = compute_D_lr(rank_arg, h, d_in, output_dim, verbose=False)
        model_type = f"Low-Rank (ranks={ranks})"
    else:
        computed_params = compute_D_full(h, d_in, output_dim, verbose=False)
        model_type = "Full-Rank"

    diff = actual_params - computed_params
    if actual_params > 0:
        rel_diff = abs(diff) / actual_params
    else:
        rel_diff = 0
    counts_match = abs(diff) <= 1  # Allow off-by-one for rounding

    result = dict(
        model_type=model_type,
        computed=computed_params,
        actual=actual_params,
        difference=diff,
        relative_difference=rel_diff,
        match=counts_match,
    )

    if verbose:
        print(f"\nParameter Count Validation ({model_type}):")
        print("-" * 50)
        print(f"  Computed:   {computed_params:,}")
        print(f"  Actual:     {actual_params:,}")
        print(f"  Difference: {diff:,} ({rel_diff:.4%})")
        print(f"  Match:      {'✓' if result['match'] else '✗'}")

    return result


# ==============================================================================
# KL Divergence Computation
# ==============================================================================

def kl_gaussian_to_scale_mixture(mu, sigma, pi, sigma1, sigma2):
    """
    Compute KL(N(μ, σ²) || π·N(0, σ₁²) + (1-π)·N(0, σ₂²)) per parameter.

    The KL from a Gaussian to a Gaussian mixture has no closed form. This
    function returns the closed-form surrogate

        -log( π·exp(-KL₁) + (1-π)·exp(-KL₂) )

    where KL_i is the exact KL from N(μ, σ²) to the i-th zero-mean component.
    By Jensen's inequality this is an upper bound on the exact mixture KL.

    The combination is evaluated in log-space with ``np.logaddexp``. With a
    very narrow component (e.g. σ₂ = exp(-6)) the two component KLs can differ
    by tens of thousands, and exponentiating that difference directly
    overflows to inf and yields a KL of -inf.

    Args:
        mu: Mean of variational posterior (scalar, array, or object with .numpy())
        sigma: Std deviation of variational posterior (same forms as ``mu``)
        pi: Mixture weight of the first component
        sigma1: Std of first Gaussian component
        sigma2: Std of second Gaussian component

    Returns:
        KL divergence for each parameter (float64, same shape as the inputs)
    """
    # Duck-typed conversion: covers eager tf.Tensor as well as tf.Variable,
    # which is not an instance of tf.Tensor.
    if hasattr(mu, 'numpy'):
        mu = mu.numpy()
    if hasattr(sigma, 'numpy'):
        sigma = sigma.numpy()
    mu = np.asarray(mu, dtype=np.float64)
    sigma = np.asarray(sigma, dtype=np.float64)

    # Clip sigma to prevent numerical problems:
    # lower bound avoids log(sigma_i / 0); upper bound keeps sigma**2 finite.
    sigma = np.clip(sigma, 1e-10, 1e10)

    # Exact KL from N(mu, sigma^2) to each zero-mean Gaussian component.
    second_moment = sigma**2 + mu**2
    kl_to_comp1 = np.log(sigma1 / sigma) + second_moment / (2 * sigma1**2) - 0.5
    kl_to_comp2 = np.log(sigma2 / sigma) + second_moment / (2 * sigma2**2) - 0.5

    # log(0) = -inf is the correct log-weight for a switched-off component
    # (pi == 0 or pi == 1); logaddexp handles it, so silence the warning.
    with np.errstate(divide='ignore'):
        log_w1 = np.log(pi)
        log_w2 = np.log(1 - pi)

    # -log(pi*exp(-kl1) + (1-pi)*exp(-kl2)), computed stably in log-space.
    return -np.logaddexp(log_w1 - kl_to_comp1, log_w2 - kl_to_comp2)


# ==============================================================================
# Extract Posterior Parameters
# ==============================================================================

def extract_posterior_params(model, model_name, prior_params, verbose=True):
    """
    Collect per-parameter KL divergences from every Bayesian layer of a model.

    A layer is treated as full-rank when it exposes ``mu_W``/``rho_W`` (optional
    bias ``mu_b``/``rho_b``) and as low-rank when it exposes ``A_mu``/``B_mu``
    (with ``A_rho``/``B_rho``, optional bias ``b_mu``/``b_rho``). Standard
    deviations are obtained as softplus(rho). Layers with neither attribute
    set are skipped.

    Args:
        model: Trained Bayesian LSTM model (iterated via ``model.layers``)
        model_name: Name for display
        prior_params: Prior parameters {pi, sigma1, sigma2}
        verbose: Whether to print detailed info

    Returns:
        kl_values: Array of KL divergences for each parameter (weights and biases)
        total_params: Total number of parameters
        C_max: Maximum KL divergence
        layer_info: List of per-layer information (weight statistics only)

    Raises:
        ValueError: if no Bayesian layer is found in the model.
    """
    pi, sigma1, sigma2 = (prior_params[key] for key in ('pi', 'sigma1', 'sigma2'))

    def _flat_posterior(mu_var, rho_var):
        # Flattened (mu, sigma) pair; sigma = softplus(rho).
        return mu_var.numpy().flatten(), tf.nn.softplus(rho_var).numpy().flatten()

    def _kl(mu, sigma):
        # Per-parameter KL to the scale-mixture prior.
        return kl_gaussian_to_scale_mixture(mu, sigma, pi, sigma1, sigma2)

    if verbose:
        print(f"\n{'='*60}")
        print(f"Model: {model_name}")
        print(f"{'='*60}")

    kl_values = []
    total_params = 0
    layer_info = []

    for layer in model.layers:
        layer_name = layer.name

        if hasattr(layer, 'mu_W') and hasattr(layer, 'rho_W'):
            # Full-rank Bayesian layer (mu_W, rho_W)
            mu_W, sigma_W = _flat_posterior(layer.mu_W, layer.rho_W)
            kl_W = _kl(mu_W, sigma_W)
            kl_values.extend(kl_W.tolist())
            total_params += mu_W.size

            info = {
                'name': layer_name,
                'type': 'Full-Rank',
                'weight_shape': tuple(layer.mu_W.shape),
                'num_params': int(mu_W.size),
                'kl_min': float(kl_W.min()),
                'kl_max': float(kl_W.max()),
                'kl_mean': float(kl_W.mean()),
            }
            layer_info.append(info)

            if verbose:
                print(f"  {layer_name}:")
                print(f"    Shape: {info['weight_shape']}")
                print(f"    Params: {info['num_params']}")
                print(f"    KL: [{info['kl_min']:.4f}, {info['kl_max']:.4f}], mean={info['kl_mean']:.4f}")

            bias_mu_attr, bias_rho_attr = 'mu_b', 'rho_b'

        elif hasattr(layer, 'A_mu') and hasattr(layer, 'B_mu'):
            # Low-rank Bayesian layer (A_mu, A_rho, B_mu, B_rho)
            mu_A, sigma_A = _flat_posterior(layer.A_mu, layer.A_rho)
            mu_B, sigma_B = _flat_posterior(layer.B_mu, layer.B_rho)
            kl_A = _kl(mu_A, sigma_A)
            kl_B = _kl(mu_B, sigma_B)

            kl_values.extend(kl_A.tolist())
            kl_values.extend(kl_B.tolist())
            factor_params = mu_A.size + mu_B.size
            total_params += factor_params

            info = {
                'name': layer_name,
                'type': 'Low-Rank',
                'A_shape': tuple(layer.A_mu.shape),
                'B_shape': tuple(layer.B_mu.shape),
                'num_params': int(factor_params),
                'kl_A_range': [float(kl_A.min()), float(kl_A.max())],
                'kl_B_range': [float(kl_B.min()), float(kl_B.max())],
                'kl_mean': float(np.concatenate([kl_A, kl_B]).mean()),
            }
            layer_info.append(info)

            if verbose:
                print(f"  {layer_name}:")
                print(f"    A: {info['A_shape']}, B: {info['B_shape']}")
                print(f"    Params: {info['num_params']}")
                print(f"    KL_A: [{info['kl_A_range'][0]:.4f}, {info['kl_A_range'][1]:.4f}]")
                print(f"    KL_B: [{info['kl_B_range'][0]:.4f}, {info['kl_B_range'][1]:.4f}]")
                print(f"    KL mean: {info['kl_mean']:.4f}")

            bias_mu_attr, bias_rho_attr = 'b_mu', 'b_rho'

        else:
            # Not a Bayesian layer: nothing to collect.
            continue

        # Optional bias: counted in kl_values / total_params, not in `info`.
        if hasattr(layer, bias_mu_attr) and hasattr(layer, bias_rho_attr):
            mu_b, sigma_b = _flat_posterior(getattr(layer, bias_mu_attr),
                                            getattr(layer, bias_rho_attr))
            kl_values.extend(_kl(mu_b, sigma_b).tolist())
            total_params += mu_b.size

    # Handle case where no Bayesian layers were found
    if not kl_values:
        raise ValueError(
            f"No Bayesian layers found in model '{model_name}'!\n"
            f"Expected layers with attributes: mu_W/rho_W (full-rank) or A_mu/B_mu (low-rank).\n"
            f"Model layers: {[layer.name for layer in model.layers]}"
        )

    kl_values = np.array(kl_values)
    C_max = kl_values.max()

    if verbose:
        print(f"\n  {'TOTAL SUMMARY':-^60}")
        print(f"    Total parameters D: {total_params}")
        print(f"    C_max (max KL):     {C_max:.6f}")
        print(f"    Mean KL:            {kl_values.mean():.6f}")
        print(f"    Median KL:          {np.median(kl_values):.6f}")
        print(f"    Std KL:             {kl_values.std():.6f}")
        print(f"    KL(Q||P) ≤ C_max·D = {C_max:.6f} × {total_params} = {C_max * total_params:.2f}")

    return kl_values, total_params, C_max, layer_info


# ==============================================================================
# Bounds Computation
# ==============================================================================

def pac_bayes_bound(D, N, C_max, delta=0.05):
    """
    PAC-Bayes bound (McAllester 1998, Theorem 4.8)

    L(Q) ≤ L̂(Q) + √[(KL(Q||P) + log(2√N/δ)) / (2N)]

    The total KL is replaced by its upper bound KL(Q||P) ≤ C_max · D.

    Args:
        D: Total parameter count
        N: Number of training samples
        C_max: Maximum per-parameter KL divergence
        delta: Confidence parameter (default: 0.05)

    Returns:
        The square-root term of the bound (the bound on L(Q) - L̂(Q))
    """
    kl_upper = C_max * D
    confidence = np.log(2 * np.sqrt(N) / delta)
    return np.sqrt((kl_upper + confidence) / (2 * N))


def gaussian_complexity_bound(D, N, R, delta=0.05, L=1.0):
    """
    Gaussian complexity bound (Pinto et al. 2025, Theorem 4.9)

    E[ℓ(f)] ≤ (1/m)Σ ℓ(f(xⱼ)) + √π·L·Ĝ_S(F) + 3√[log(2/δ)/(2m)]

    The empirical Gaussian complexity is replaced by R · √(D/N), with R the
    RMS input radius per sequence. Since R = ||X||_F / √N, this equals
    ||X||_F · √D / N.

    Args:
        D: Total parameter count
        N: Number of training samples
        R: RMS input radius (R = √(mean ||x_i||²))
        delta: Confidence parameter (default: 0.05)
        L: Lipschitz constant (default: 1.0)

    Returns:
        Sum of the complexity and confidence terms, i.e.
        √π·L·R·√(D/N) + 3√[log(2/δ)/(2N)]
    """
    complexity = R * np.sqrt(D / N)
    slack = 3 * np.sqrt(np.log(2 / delta) / (2 * N))
    return np.sqrt(np.pi) * L * complexity + slack


# ==============================================================================
# Big-O Asymptotic Bounds (without constants)
# ==============================================================================

def pac_bayes_bound_bigO(D, N, C_max):
    """
    Dominant term of the PAC-Bayes bound (Big-O asymptotic form).

    Gap_PAC ∝ √(C_max · D / N)

    Confidence terms and constants are dropped, leaving only the scaling
    behavior for comparison against empirical scaling.

    Args:
        D: Total parameter count
        N: Number of training samples
        C_max: Maximum per-parameter KL divergence

    Returns:
        Big-O asymptotic complexity term
    """
    kl_per_sample = C_max * D / N
    return np.sqrt(kl_per_sample)


def gaussian_complexity_bound_bigO(D, N, R):
    """
    Dominant term of the Gaussian complexity bound (Big-O asymptotic form).

    Gap_GC ∝ R · √(D/N)

    Lipschitz constants, confidence terms and π factors are dropped, leaving
    only the rank-dependent scaling.

    Args:
        D: Total parameter count
        N: Number of training samples
        R: RMS input radius

    Returns:
        Big-O asymptotic complexity term
    """
    params_per_sample = D / N
    return R * np.sqrt(params_per_sample)


# ==============================================================================
# Main Function
# ==============================================================================

def compute_empirical_bounds(X_train, y_train, bayes_model, lowrank_model,
                            delta=0.05, L=1.0, save_results=True, verbose=True):
    """
    Compute PAC-Bayes and Gaussian complexity bounds with empirical values.

    Pipeline:
        1. Measure the RMS input radius R (and ||X||_F) from X_train.
        2. Read the scale-mixture prior parameters.
        3. Extract per-parameter KL divergences, D and C_max from both models.
        4. Evaluate the exact and Big-O forms of both bounds.
        5. If 'results_csv/point_prediction_results.csv' exists (relative to
           the working directory) and contains rows for both models, compare
           the PAC-Bayes bounds with the range-normalized test MAE.
        6. Assemble everything into a JSON-serializable dictionary.
        7. Optionally write it to 'empirical_bounds_results.json'.

    Args:
        X_train: Training data (N, T, F)
        y_train: Training labels (N, 1)
        bayes_model: Trained full-rank Bayesian LSTM
        lowrank_model: Trained low-rank Bayesian LSTM
        delta: Confidence parameter (default: 0.05)
        L: Lipschitz constant (default: 1.0)
        save_results: Whether to save results to JSON (default: True)
        verbose: Whether to print detailed output (default: True)

    Returns:
        results_dict: Dictionary with all computed values. The key
        'empirical_comparison' is present only when step 5 found usable data.

    Raises:
        ValueError: if X_train is not 3-dimensional, or if one of the models
            contains no Bayesian layers.
    """
    if verbose:
        print("=" * 80)
        print("Computing PAC-Bayes Bounds with ACTUAL EMPIRICAL VALUES")
        print("=" * 80)

    # =========================================================================
    # STEP 1: Compute R (RMS input radius) from actual training data
    # =========================================================================
    if verbose:
        print("\n" + "=" * 80)
        print("STEP 1: Computing R (RMS input radius) from training data")
        print("=" * 80)

    # Accept array-likes as well as ndarrays, and fail early on a wrong rank.
    X_train = np.asarray(X_train)
    if X_train.ndim != 3:
        raise ValueError(
            f"X_train must be 3-dimensional (N, T, F); got shape {X_train.shape}"
        )

    # X_train is 3D: (N, T, F)
    N_train = X_train.shape[0]
    T = X_train.shape[1]
    F = X_train.shape[2]

    # Flatten each sequence to compute per-sequence norms
    X_flat = X_train.reshape(N_train, -1)  # (N, T*F)

    # Per-sequence squared norms
    seq_sq_norms = np.sum(X_flat**2, axis=1)  # (N,)

    # RMS radius: R = sqrt(mean of ||x_i||^2)
    R = np.sqrt(np.mean(seq_sq_norms))

    # Also compute ||X||_F for reference (||X||_F = sqrt(N) * R)
    X_norm_F_actual = np.sqrt(np.sum(seq_sq_norms))  # = sqrt(N) * R

    # For normalized data (mean=0, std=1), expected values:
    # E[||x_i||^2] = T*F (since each element has variance 1)
    # R_expected = sqrt(T*F)
    R_expected = np.sqrt(T * F)

    if verbose:
        print(f"\nRMS input radius R = {R:.4f}")
        print(f"Expected R (normalized data): √(T×F) = {R_expected:.4f}")
        print(f"Ratio R/R_expected: {R/R_expected:.4f}")
        print(f"\nDerived ||X||_F = √N × R = {X_norm_F_actual:.4f}")
        print(f"\nTraining data:")
        print(f"  Shape: {X_train.shape} (N, T, F)")
        print(f"  N_train = {N_train}")
        print(f"  Sequence length T = {T}")
        print(f"  Features F = {F}")
        print(f"  Total elements: {N_train * T * F:,}")

    # =========================================================================
    # STEP 2: Prior parameters
    # =========================================================================
    if verbose:
        print("\n" + "=" * 80)
        print("STEP 2: Prior parameters")
        print("=" * 80)

    prior_params = get_prior_params()
    pi = prior_params['pi']
    sigma1 = prior_params['sigma1']
    sigma2 = prior_params['sigma2']

    if verbose:
        print(f"\nScale mixture prior:")
        print(f"  π = {pi}")
        print(f"  σ₁ = {sigma1}")
        print(f"  σ₂ = {sigma2:.6f} (= exp(-6))")

    # =========================================================================
    # STEP 3: Extract posterior parameters and compute C_max
    # =========================================================================
    if verbose:
        print("\n" + "=" * 80)
        print("STEP 3: Extracting posteriors and computing C_max")
        print("=" * 80)

    kl_full, D_full, C_max_full, layers_full = extract_posterior_params(
        bayes_model, "Full-Rank Bayesian", prior_params, verbose=verbose
    )

    if verbose:
        print("\n" + "="*80 + "\n")

    kl_low, D_low, C_max_low, layers_low = extract_posterior_params(
        lowrank_model, "Low-Rank Bayesian", prior_params, verbose=verbose
    )

    # =========================================================================
    # STEP 4: Compute bounds
    # =========================================================================
    if verbose:
        print("\n" + "=" * 80)
        print("STEP 4: Computing bounds with EMPIRICAL values")
        print("=" * 80)

    # Exact bounds (with all constants)
    pac_full = pac_bayes_bound(D_full, N_train, C_max_full, delta)
    pac_low = pac_bayes_bound(D_low, N_train, C_max_low, delta)

    gauss_full = gaussian_complexity_bound(D_full, N_train, R, delta, L)
    gauss_low = gaussian_complexity_bound(D_low, N_train, R, delta, L)

    # Big-O asymptotic bounds (dominant terms only)
    pac_full_bigO = pac_bayes_bound_bigO(D_full, N_train, C_max_full)
    pac_low_bigO = pac_bayes_bound_bigO(D_low, N_train, C_max_low)

    gauss_full_bigO = gaussian_complexity_bound_bigO(D_full, N_train, R)
    gauss_low_bigO = gaussian_complexity_bound_bigO(D_low, N_train, R)

    if verbose:
        print(f"\n{'EXACT BOUNDS (with all constants)':^90}")
        print(f"{'Model':<30} {'D':<10} {'C_max':<15} {'PAC-Bayes':<15} {'Gaussian'}")
        print("=" * 90)
        print(f"{'Full-Rank Bayesian':<30} {D_full:<10} {C_max_full:<15.6f} {pac_full:<15.6f} {gauss_full:.6f}")
        print(f"{'Low-Rank Bayesian':<30} {D_low:<10} {C_max_low:<15.6f} {pac_low:<15.6f} {gauss_low:.6f}")

        print(f"\n{'BIG-O ASYMPTOTIC BOUNDS (dominant terms)':^90}")
        print(f"{'Model':<30} {'D':<10} {'C_max':<15} {'PAC-Bayes':<15} {'Gaussian'}")
        print("=" * 90)
        print(f"{'Full-Rank Bayesian':<30} {D_full:<10} {C_max_full:<15.6f} {pac_full_bigO:<15.6f} {gauss_full_bigO:.6f}")
        print(f"{'Low-Rank Bayesian':<30} {D_low:<10} {C_max_low:<15.6f} {pac_low_bigO:<15.6f} {gauss_low_bigO:.6f}")

        print(f"\n{'Complexity Reduction Ratios:':-^90}")
        print(f"  Parameter reduction:  D_low/D_full = {D_low}/{D_full} = {D_low/D_full:.4f}")
        print(f"  C_max ratio:          C_low/C_full = {C_max_low:.4f}/{C_max_full:.4f} = {C_max_low/C_max_full:.4f}")
        print(f"  PAC-Bayes bound ratio:              {pac_low:.4f}/{pac_full:.4f} = {pac_low/pac_full:.4f}")
        print(f"  Gaussian bound ratio:               {gauss_low:.4f}/{gauss_full:.4f} = {gauss_low/gauss_full:.4f}")

        print(f"\n{'Vacuousness Check (bound > 1.0 = vacuous):':-^90}")
        for name, bound in [
            ("Full-Rank PAC-Bayes", pac_full),
            ("Full-Rank Gaussian", gauss_full),
            ("Low-Rank PAC-Bayes", pac_low),
            ("Low-Rank Gaussian", gauss_low),
        ]:
            status = "VACUOUS ❌" if bound > 1.0 else "Non-vacuous ✓"
            print(f"  {name:<25} {bound:>10.6f}  {status}")

    # =========================================================================
    # STEP 5: Empirical verification
    # =========================================================================
    if verbose:
        print("\n" + "=" * 80)
        print("STEP 5: Empirical Verification")
        print("=" * 80)

    # Try to load empirical results
    empirical_comparison = None

    results_file = Path("results_csv/point_prediction_results.csv")
    if results_file.exists():
        # Imported lazily: pandas is only needed when the optional CSV exists.
        import pandas as pd

        results = pd.read_csv(results_file)
        full_rank_row = results[results['Model'] == 'Full-Rank Bayesian']
        low_rank_row = results[results['Model'] == 'Low-Rank Bayesian']

        if len(full_rank_row) > 0 and len(low_rank_row) > 0:
            mae_full = full_rank_row['MAE'].values[0]
            mae_low = low_rank_row['MAE'].values[0]

            # Normalize by range
            y_values = np.asarray(y_train)
            y_range = float(y_values.max() - y_values.min())
            norm_error_full = mae_full / y_range
            norm_error_low = mae_low / y_range

            # Computed unconditionally: the gaps are stored in the results
            # below, so they must exist even when nothing is printed.
            gap_full = pac_full - norm_error_full
            gap_low = pac_low - norm_error_low

            if verbose:
                print(f"\nEmpirical test error (MAE):")
                print(f"  Full-Rank: {mae_full:.4f}")
                print(f"  Low-Rank:  {mae_low:.4f}")
                print(f"  Ratio:     {mae_low/mae_full:.4f}")

                print(f"\nNormalized test error (for comparison with bounds):")
                print(f"  y_range = {y_range:.2f}")
                print(f"  Full-Rank: {norm_error_full:.6f}")
                print(f"  Low-Rank:  {norm_error_low:.6f}")

                print(f"\n{'Theoretical vs Empirical:':-^90}")
                print(f"{'Model':<20} {'PAC-Bayes':<15} {'Empirical':<15} {'Gap':<15} {'Valid?'}")
                print("=" * 90)
                print(f"{'Full-Rank':<20} {pac_full:<15.6f} {norm_error_full:<15.6f} {gap_full:<15.6f} {'✓' if gap_full > 0 else '❌'}")
                print(f"{'Low-Rank':<20} {pac_low:<15.6f} {norm_error_low:<15.6f} {gap_low:<15.6f} {'✓' if gap_low > 0 else '❌'}")

                if gap_full > 0 and gap_low > 0:
                    print("\n✓ Bounds HOLD: Theory ≥ Empirical (as required by PAC-Bayes theorem)")
                else:
                    print("\n❌ WARNING: Bounds VIOLATED! This indicates a problem.")

            empirical_comparison = {
                'mae_full': float(mae_full),
                'mae_low': float(mae_low),
                'y_range': float(y_range),
                'norm_error_full': float(norm_error_full),
                'norm_error_low': float(norm_error_low),
                'gap_full': float(gap_full),
                'gap_low': float(gap_low),
                'bounds_hold': bool(gap_full > 0 and gap_low > 0),
            }

    # =========================================================================
    # STEP 6: Prepare results dictionary
    # =========================================================================
    # Every value is cast to a built-in type so the dictionary is JSON-serializable.
    results_dict = {
        'metadata': {
            'R': float(R),
            'R_expected': float(R_expected),
            'X_norm_F_actual': float(X_norm_F_actual),  # = sqrt(N) * R
            'N_train': int(N_train),
            'T': int(T),
            'F': int(F),
            'delta': float(delta),
            'L': float(L),
        },
        'prior_params': {
            'pi': float(pi),
            'sigma1': float(sigma1),
            'sigma2': float(sigma2),
        },
        'full_rank': {
            'D': int(D_full),
            'C_max': float(C_max_full),
            'pac_bayes_bound': float(pac_full),
            'gaussian_bound': float(gauss_full),
            'pac_bayes_bound_bigO': float(pac_full_bigO),
            'gaussian_bound_bigO': float(gauss_full_bigO),
            'kl_mean': float(kl_full.mean()),
            'kl_median': float(np.median(kl_full)),
            'kl_std': float(kl_full.std()),
            'kl_min': float(kl_full.min()),
            'kl_max': float(kl_full.max()),
        },
        'low_rank': {
            'D': int(D_low),
            'C_max': float(C_max_low),
            'pac_bayes_bound': float(pac_low),
            'gaussian_bound': float(gauss_low),
            'pac_bayes_bound_bigO': float(pac_low_bigO),
            'gaussian_bound_bigO': float(gauss_low_bigO),
            'kl_mean': float(kl_low.mean()),
            'kl_median': float(np.median(kl_low)),
            'kl_std': float(kl_low.std()),
            'kl_min': float(kl_low.min()),
            'kl_max': float(kl_low.max()),
        },
        'complexity_ratios': {
            'parameter_ratio': float(D_low / D_full),
            'C_max_ratio': float(C_max_low / C_max_full),
            'pac_bayes_ratio': float(pac_low / pac_full),
            'gaussian_ratio': float(gauss_low / gauss_full),
        },
    }

    if empirical_comparison is not None:
        results_dict['empirical_comparison'] = empirical_comparison

    # =========================================================================
    # STEP 7: Save results
    # =========================================================================
    if save_results:
        output_file = Path('empirical_bounds_results.json')
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(results_dict, f, indent=2)

        if verbose:
            print("\n" + "=" * 80)
            print("STEP 6: Results saved")
            print("=" * 80)
            print(f"\n✓ Results saved to: {output_file}")

    # =========================================================================
    # Summary
    # =========================================================================
    if verbose:
        print("\n" + "=" * 80)
        print("SUMMARY: Empirical Bounds Computation Complete!")
        print("=" * 80)

        print(f"\n{'KEY FINDINGS:':-^80}")
        print(f"\n1. ACTUAL C_max values (not arbitrary!):")
        print(f"   Full-Rank: {C_max_full:.6f}")
        print(f"   Low-Rank:  {C_max_low:.6f}")
        print(f"   Ratio:     {C_max_low/C_max_full:.4f}")

        print(f"\n2. RMS input radius R:")
        print(f"   Measured:  {R:.4f}")
        print(f"   Expected:  {R_expected:.4f} (for normalized data)")
        print(f"   ||X||_F = √N × R = {X_norm_F_actual:.4f}")

        print(f"\n3. PAC-Bayes bounds (with REAL values):")
        print(f"   Full-Rank: {pac_full:.6f} ({'VACUOUS' if pac_full > 1 else 'Non-vacuous'})")
        print(f"   Low-Rank:  {pac_low:.6f} ({'VACUOUS' if pac_low > 1 else 'Non-vacuous'})")

        print(f"\n4. Complexity reduction:")
        print(f"   Parameters: {D_low/D_full:.1%} of full-rank")
        print(f"   PAC-Bayes:  {pac_low/pac_full:.1%} of full-rank")

        print("\n" + "=" * 80)

    return results_dict
