"""
Plot PAC-Bayes and Gaussian Complexity Bounds vs Rank
Using ACTUAL EMPIRICAL VALUES from trained models

This script:
1. Loads empirical C_max and ||X||_F from empirical_bounds_results.json
2. Computes bounds for varying rank r
3. Creates publication-ready figures
"""

import numpy as np
import matplotlib.pyplot as plt
import json
from pathlib import Path

# ==============================================================================
# Load Empirical Results
# ==============================================================================

# Locate the results file; without it there is nothing to plot.
results_file = Path('empirical_bounds_results.json')
if not results_file.exists():
    print("ERROR: empirical_bounds_results.json not found!")
    print("Please run compute_empirical_bounds_notebook.py first.")
    # Raise SystemExit directly instead of calling exit(): the bare `exit`
    # helper is installed by the `site` module for interactive sessions and
    # is not guaranteed to exist (e.g. `python -S`, frozen interpreters).
    raise SystemExit(1)

# JSON text is UTF-8; do not depend on the platform's default encoding.
with open(results_file, 'r', encoding='utf-8') as f:
    empirical_results = json.load(f)

metadata = empirical_results['metadata']

# Extract empirical values
N_train = metadata['N_train']
F = metadata['F']  # loaded for completeness; not referenced later in this script
delta = metadata['delta']
L = metadata['L']

# Get R (RMS input radius) - with backward compatibility
if 'R' in metadata:
    R = metadata['R']
else:
    # Older results files only store the Frobenius norm of X; derive
    # R = ||X||_F / sqrt(N_train) from it.
    X_norm_F = metadata['X_norm_F_actual']
    R = X_norm_F / np.sqrt(N_train)

# Empirical C_max values
C_max_full = empirical_results['full_rank']['C_max']
C_max_low = empirical_results['low_rank']['C_max']

# Parameter counts
D_full = empirical_results['full_rank']['D']
D_low_empirical = empirical_results['low_rank']['D']

print("="*80)
print("Loaded Empirical Values:")
print("="*80)
print(f"  R (RMS input radius) = {R:.4f}")
print(f"  N_train = {N_train}")
print(f"  C_max (Full-Rank) = {C_max_full:.6f}")
print(f"  C_max (Low-Rank [14,20]) = {C_max_low:.6f}")
print(f"  D_full = {D_full}")
print(f"  D_low [14,20] = {D_low_empirical}")

# ==============================================================================
# LSTM Architecture Parameters
# ==============================================================================

# Architecture constants (values taken from the notebook)
LSTM_HIDDEN = 64  # H
NUM_LAYERS = 2
INPUT_SIZE = 15  # F_in
OUTPUT_DIM = 1

# One (input_dim, output_dim, has_bias) entry per weight matrix.
# Every LSTM layer contributes two matrices that both map into the 4*H gate
# pre-activations:
#   x_to_gates: (layer input, 4*H), carries the bias
#   h_to_gates: (H, 4*H), no bias
# The first layer reads the raw input (F_in); deeper layers read the hidden
# state of the layer below (H). A final (H, OUTPUT_DIM) output layer follows.
layer_dims = [
    (in_dim, 4 * LSTM_HIDDEN, has_bias)
    for layer in range(NUM_LAYERS)
    for in_dim, has_bias in (
        (INPUT_SIZE if layer == 0 else LSTM_HIDDEN, True),  # x_to_gates
        (LSTM_HIDDEN, False),                               # h_to_gates
    )
]
layer_dims.append((LSTM_HIDDEN, OUTPUT_DIM, True))  # Output layer

def compute_parameter_count(ranks):
    """
    Count the parameters of the low-rank model for the given rank choice.

    Every entry of ``layer_dims`` except the last is treated as a low-rank
    matrix W = A B^T with A of shape (m, r) and B of shape (n, r), costing
    r * (m + n) weights instead of m * n. The last entry (the output layer)
    is always counted at full rank. Biases, where present, add n parameters
    and are unaffected by the rank.

    Args:
        ranks: One of
            - a scalar (int, float or NumPy integer): the same rank, truncated
              to int, is used for every low-rank matrix;
            - a sequence of length NUM_LAYERS: one rank per LSTM layer, shared
              by that layer's two matrices (e.g. [14, 20]);
            - any other sequence: used as-is, one rank per matrix in
              ``layer_dims`` order. Matrices beyond the end of the sequence
              fall back to rank min(m, n).

    Returns:
        Total number of parameters (weights plus biases).
    """
    # Normalise `ranks` to one rank per low-rank weight matrix.
    if isinstance(ranks, (int, float, np.integer)):
        per_matrix = [int(ranks)] * (NUM_LAYERS * 2)
    elif hasattr(ranks, '__len__') and len(ranks) == NUM_LAYERS:
        per_matrix = [ranks[0], ranks[0], ranks[1], ranks[1]]
    else:
        per_matrix = ranks

    output_idx = len(layer_dims) - 1

    def _matrix_params(idx, m, n, has_bias):
        """Return the parameter count contributed by one layer_dims entry."""
        if idx == output_idx:
            weights = m * n  # output layer stays full-rank
        else:
            r = per_matrix[idx] if idx < len(per_matrix) else min(m, n)
            weights = r * (m + n)
        return weights + n if has_bias else weights

    return sum(_matrix_params(idx, *dims) for idx, dims in enumerate(layer_dims))

# ==============================================================================
# Estimate C_max for varying ranks
# ==============================================================================

def estimate_C_max(rank, C_max_full, C_max_low_14_20, D_full, D_low_14_20):
    """
    Estimate C_max at a given rank by piecewise-linear interpolation.

    Only two values are measured: C_max of the full-rank model and C_max of
    the low-rank model trained with ranks [14, 20]. Everything else is a
    heuristic built on the hypothesis that C_max grows with rank (a higher
    rank lets the posterior move further from the prior).

    Anchor points of the interpolation:
        rank 1                       -> 0.5 * C_max_low_14_20 (rough guess)
        rank 17 (mean of 14 and 20)  -> C_max_low_14_20       (measured)
        rank 4 * LSTM_HIDDEN         -> C_max_full            (measured)
    Ranks outside [1, 4 * LSTM_HIDDEN] are clamped to the nearest anchor.

    Args:
        rank: Scalar rank, or a sequence of per-layer ranks whose mean is used.
        C_max_full: Measured C_max of the full-rank model.
        C_max_low_14_20: Measured C_max of the ranks=[14, 20] model.
        D_full: Unused; kept so existing call sites keep working.
        D_low_14_20: Unused; kept so existing call sites keep working.

    Returns:
        The interpolated C_max estimate as a float.
    """
    # The measured low-rank model uses ranks [14, 20]; represent it by the mean.
    avg_rank_empirical = (14 + 20) / 2

    # Rank at which the estimate reaches the full-rank value. This is the
    # widest matrix dimension (4*H = 256), i.e. a generous upper limit rather
    # than the true maximal rank min(m, n) of any single matrix.
    full_rank_approx = max(INPUT_SIZE, 4 * LSTM_HIDDEN)

    # np.mean accepts scalars as well as lists, tuples and arrays.
    avg_rank = np.mean(rank)

    # Heuristic: a rank-1 model is more constrained, so assume half the C_max
    # measured at [14, 20]. The interpolation starts at rank 1 (not rank 0),
    # so the estimate at rank 1 equals this value exactly.
    C_max_rank1 = 0.5 * C_max_low_14_20

    return np.interp(
        avg_rank,
        [1.0, avg_rank_empirical, full_rank_approx],
        [C_max_rank1, C_max_low_14_20, C_max_full],
    )

# ==============================================================================
# Compute Bounds vs Rank
# ==============================================================================

# Rank range: rank 1 up to the first uniform rank whose low-rank parameter
# count reaches the full-rank count D_full.
r_max = 100  # Upper bound for search (ranks 1 .. r_max - 1 are tried)
D_values = [compute_parameter_count(candidate) for candidate in range(1, r_max)]
intersection_idx = np.flatnonzero(np.asarray(D_values) >= D_full)
# Position 0 of D_values belongs to rank 1, hence the +1. If no searched rank
# reaches D_full, fall back to the search limit itself.
r_intersection = intersection_idx[0] + 1 if intersection_idx.size > 0 else r_max

print(f"\nRank intersection point: r = {r_intersection} (where D_lr ≥ D_full)")

# Rank grid [1, r_intersection] and the matching parameter counts
ranks = np.arange(1, r_intersection + 1)
D_lr = np.array(list(map(compute_parameter_count, ranks)))

# Interpolated C_max estimate at every rank of the grid
C_max_values = np.array(
    [estimate_C_max(rank, C_max_full, C_max_low, D_full, D_low_empirical) for rank in ranks]
)

# ==============================================================================
# Bound Functions - EXACT (with all constants)
# ==============================================================================

def pac_bayes_bound(D, C_max):
    """
    Evaluate the PAC-Bayes bound with all constants kept.

    Computes sqrt((C_max * D + ln(2 * sqrt(N_train) / delta)) / (2 * N_train)),
    reading N_train and delta from module scope.

    Args:
        D: Parameter count (NumPy broadcasting applies to array input).
        C_max: C_max value paired with D.

    Returns:
        The bound value.
    """
    numerator = C_max * D + np.log(2 * np.sqrt(N_train) / delta)
    return np.sqrt(numerator / (2 * N_train))

def gaussian_complexity_bound(D):
    """
    Evaluate the Gaussian complexity bound with all constants kept.

    Computes sqrt(pi) * L * R * sqrt(D / N_train)
             + 3 * sqrt(ln(2 / delta) / (2 * N_train)),
    reading L, R, N_train and delta from module scope.

    Args:
        D: Parameter count (NumPy broadcasting applies to array input).

    Returns:
        The bound value: complexity term plus confidence term.
    """
    scale = np.sqrt(np.pi) * L
    complexity = R * np.sqrt(D / N_train)
    confidence = 3 * np.sqrt(np.log(2 / delta) / (2 * N_train))
    return scale * complexity + confidence

# ==============================================================================
# Bound Functions - BIG-O (dominant terms only)
# ==============================================================================

def pac_bayes_bound_bigO(D, C_max):
    """
    Evaluate the dominant term of the PAC-Bayes bound.

    Computes sqrt(C_max * D / N_train); the logarithmic confidence term and
    the factor 2 of the exact form are dropped. N_train comes from module scope.
    """
    dominant = C_max * D / N_train
    return np.sqrt(dominant)

def gaussian_complexity_bound_bigO(D):
    """
    Evaluate the dominant term of the Gaussian complexity bound.

    Computes R * sqrt(D / N_train); the sqrt(pi) * L factor and the confidence
    term of the exact form are dropped. R and N_train come from module scope.
    """
    root_term = np.sqrt(D / N_train)
    return R * root_term

# ==============================================================================
# Compute Bounds vs Rank
# ==============================================================================

# Exact bounds over the rank grid. The bound functions are built from NumPy
# ufuncs, so they are applied to the whole arrays at once (elementwise).
pac_bounds = pac_bayes_bound(D_lr, C_max_values)
gaussian_bounds = gaussian_complexity_bound(D_lr)

# Big-O (dominant-term) bounds over the same grid
pac_bounds_bigO = pac_bayes_bound_bigO(D_lr, C_max_values)
gaussian_bounds_bigO = gaussian_complexity_bound_bigO(D_lr)

# Full-rank reference values, drawn later as constant horizontal lines
pac_full_line = pac_bayes_bound(D_full, C_max_full)
pac_full_line_bigO = pac_bayes_bound_bigO(D_full, C_max_full)
gaussian_full_line = gaussian_complexity_bound(D_full)
gaussian_full_line_bigO = gaussian_complexity_bound_bigO(D_full)

# Grid position closest to the measured configuration, represented by the
# average of its ranks [14, 20]
r_empirical = 17
empirical_idx = np.argmin(np.abs(ranks - r_empirical))

# ==============================================================================
# Create Figure 1: PAC-Bayes Bounds
# ==============================================================================

fig, ax = plt.subplots(figsize=(10, 6))

# Plot low-rank bound vs rank
ax.plot(ranks, pac_bounds, 'b-', linewidth=2, label='Low-Rank Bayesian')

# Full-rank bound (horizontal line)
ax.axhline(pac_full_line, color='orange', linestyle='--', linewidth=2,
           label='Full-Rank Bayesian')

# Empirical point (grid rank closest to the measured [14, 20] configuration)
ax.plot(ranks[empirical_idx], pac_bounds[empirical_idx], 'ro', markersize=10,
        label=f'Empirical (r=[14,20], C_max={C_max_low:.4f})', zorder=5)

# Intersection point (rank where the low-rank parameter count reaches D_full)
ax.axvline(r_intersection, color='gray', linestyle=':', alpha=0.5)
ax.text(r_intersection, ax.get_ylim()[1]*0.95, f'r={r_intersection}\n(D_lr=D_full)',
        ha='center', fontsize=9, color='gray')

# Vacuous line: a bound at or above 1 carries no information
ax.axhline(1.0, color='red', linestyle=':', linewidth=1, alpha=0.5, label='Vacuous (bound=1)')

ax.set_xlabel('Rank r', fontsize=12)
ax.set_ylabel('PAC-Bayes Bound', fontsize=12)
ax.set_title(f'PAC-Bayes Bounds vs Rank (Empirical C_max, R={R:.2f})', fontsize=13)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
ax.set_xlim(0, r_intersection + 5)

# savefig() does not create missing directories. This is the first save of
# the script, so make sure the output folder exists here; the later figures
# write into the same folder.
Path('figures').mkdir(parents=True, exist_ok=True)

plt.tight_layout()
plt.savefig('figures/pac_bayes_bounds_empirical.pdf', dpi=300, bbox_inches='tight')
plt.savefig('figures/pac_bayes_bounds_empirical.png', dpi=300, bbox_inches='tight')
print(f"\n✓ Saved: pac_bayes_bounds_empirical.pdf")

# ==============================================================================
# Create Figure 2: Gaussian Complexity Bounds
# ==============================================================================

fig, ax = plt.subplots(figsize=(10, 6))

# Low-rank curve first, then the constant full-rank reference level
ax.plot(ranks, gaussian_bounds, 'g-', label='Low-Rank Bayesian', linewidth=2)
ax.axhline(
    gaussian_full_line,
    label='Full-Rank Bayesian', color='orange', linestyle='--', linewidth=2,
)

# Marker at the grid rank closest to the measured [14, 20] configuration
emp_rank = ranks[empirical_idx]
emp_bound = gaussian_bounds[empirical_idx]
ax.plot(emp_rank, emp_bound, 'ro', zorder=5, markersize=10,
        label='Empirical (r=[14,20])')

# Vertical guide at the rank where the low-rank parameter count reaches
# D_full; its label is placed relative to the current upper y-limit
ax.axvline(r_intersection, alpha=0.5, color='gray', linestyle=':')
label_y = ax.get_ylim()[1] * 0.95
ax.text(r_intersection, label_y, f'r={r_intersection}\n(D_lr=D_full)',
        color='gray', fontsize=9, ha='center')

# Axis cosmetics
ax.set_title(f'Gaussian Complexity Bounds vs Rank (Empirical R={R:.2f})', fontsize=13)
ax.set_ylabel('Gaussian Complexity Bound', fontsize=12)
ax.set_xlabel('Rank r', fontsize=12)
ax.grid(True, alpha=0.3)
ax.legend(fontsize=10)
ax.set_xlim(0, r_intersection + 5)

plt.tight_layout()
for out_path in ('figures/gaussian_complexity_bounds_empirical.pdf',
                 'figures/gaussian_complexity_bounds_empirical.png'):
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
print("✓ Saved: gaussian_complexity_bounds_empirical.pdf")

# ==============================================================================
# Create Figure 3: PAC-Bayes Bounds (Big-O Asymptotic)
# ==============================================================================

fig, ax = plt.subplots(figsize=(10, 6))

# Low-rank dominant-term curve, then the constant full-rank reference level
ax.plot(ranks, pac_bounds_bigO, 'b-', label='Low-Rank Bayesian (Big-O)', linewidth=2)
ax.axhline(
    pac_full_line_bigO,
    label='Full-Rank Bayesian (Big-O)', color='orange', linestyle='--', linewidth=2,
)

# Marker at the grid rank closest to the measured [14, 20] configuration
emp_rank = ranks[empirical_idx]
emp_bound = pac_bounds_bigO[empirical_idx]
ax.plot(emp_rank, emp_bound, 'ro', zorder=5, markersize=10,
        label='Empirical (r=[14,20])')

# Vertical guide at the rank where the low-rank parameter count reaches
# D_full; its label is placed relative to the current upper y-limit
ax.axvline(r_intersection, alpha=0.5, color='gray', linestyle=':')
label_y = ax.get_ylim()[1] * 0.95
ax.text(r_intersection, label_y, f'r={r_intersection}\n(D_lr=D_full)',
        color='gray', fontsize=9, ha='center')

# Axis cosmetics
ax.set_title('PAC-Bayes Bounds vs Rank - Big-O Form ∝ √(C_max·D/N)', fontsize=13)
ax.set_ylabel('PAC-Bayes Bound (Big-O)', fontsize=12)
ax.set_xlabel('Rank r', fontsize=12)
ax.grid(True, alpha=0.3)
ax.legend(fontsize=10)
ax.set_xlim(0, r_intersection + 5)

plt.tight_layout()
for out_path in ('figures/pac_bayes_bounds_bigO.pdf',
                 'figures/pac_bayes_bounds_bigO.png'):
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
print("✓ Saved: pac_bayes_bounds_bigO.pdf")

# ==============================================================================
# Create Figure 4: Gaussian Complexity Bounds (Big-O Asymptotic)
# ==============================================================================

fig, ax = plt.subplots(figsize=(10, 6))

# Plot low-rank bound vs rank
ax.plot(ranks, gaussian_bounds_bigO, 'g-', linewidth=2, label='Low-Rank Bayesian (Big-O)')

# Full-rank bound
ax.axhline(gaussian_full_line_bigO, color='orange', linestyle='--', linewidth=2,
           label='Full-Rank Bayesian (Big-O)')

# Empirical point (grid rank closest to the measured [14, 20] configuration)
ax.plot(ranks[empirical_idx], gaussian_bounds_bigO[empirical_idx], 'ro', markersize=10,
        label=f'Empirical (r=[14,20])', zorder=5)

# Intersection point (rank where the low-rank parameter count reaches D_full)
ax.axvline(r_intersection, color='gray', linestyle=':', alpha=0.5)
ax.text(r_intersection, ax.get_ylim()[1]*0.95, f'r={r_intersection}\n(D_lr=D_full)',
        ha='center', fontsize=9, color='gray')

ax.set_xlabel('Rank r', fontsize=12)
ax.set_ylabel('Gaussian Complexity Bound (Big-O)', fontsize=12)
ax.set_title(f'Gaussian Complexity Bounds vs Rank - Big-O Form ∝ R·√(D/N)', fontsize=13)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
ax.set_xlim(0, r_intersection + 5)

plt.tight_layout()
# Write into figures/ like every other figure of this script; previously this
# one figure was saved to the current working directory instead.
plt.savefig('figures/gaussian_complexity_bounds_bigO.pdf', dpi=300, bbox_inches='tight')
plt.savefig('figures/gaussian_complexity_bounds_bigO.png', dpi=300, bbox_inches='tight')
print(f"✓ Saved: gaussian_complexity_bounds_bigO.pdf")

# ==============================================================================
# Create Figure 5: Comparison - Exact vs Big-O Bounds
# ==============================================================================

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# One row per panel: axes, line formats, curves, full-rank reference values,
# labels, and whether to draw the "vacuous" level at 1.0 (PAC-Bayes only).
comparison_panels = (
    # Left: PAC-Bayes
    (ax1, 'b-', 'b--', pac_bounds, pac_bounds_bigO,
     pac_full_line, pac_full_line_bigO,
     'PAC-Bayes Bound', 'PAC-Bayes: Exact vs Big-O', True),
    # Right: Gaussian Complexity
    (ax2, 'g-', 'g--', gaussian_bounds, gaussian_bounds_bigO,
     gaussian_full_line, gaussian_full_line_bigO,
     'Gaussian Complexity Bound', 'Gaussian Complexity: Exact vs Big-O', False),
)

for (panel_ax, exact_fmt, bigo_fmt, exact_curve, bigo_curve,
     exact_ref, bigo_ref, y_label, panel_title, mark_vacuous) in comparison_panels:
    # Exact curve (solid) against its dominant-term approximation (dashed)
    panel_ax.plot(ranks, exact_curve, exact_fmt, linewidth=2, label='Exact (with constants)')
    panel_ax.plot(ranks, bigo_curve, bigo_fmt, linewidth=2, alpha=0.7, label='Big-O (dominant term)')

    # Full-rank reference levels, same solid/dashed convention, not in legend
    panel_ax.axhline(exact_ref, color='orange', linestyle='-', linewidth=2, alpha=0.5)
    panel_ax.axhline(bigo_ref, color='orange', linestyle='--', linewidth=2, alpha=0.5)

    # Marker on the exact curve at the empirical rank
    panel_ax.plot(ranks[empirical_idx], exact_curve[empirical_idx], 'ro',
                  markersize=10, label='Empirical', zorder=5)
    if mark_vacuous:
        panel_ax.axhline(1.0, color='red', linestyle=':', linewidth=1, alpha=0.5, label='Vacuous')

    panel_ax.set_xlabel('Rank r', fontsize=12)
    panel_ax.set_ylabel(y_label, fontsize=12)
    panel_ax.set_title(panel_title, fontsize=13)
    panel_ax.legend(fontsize=10)
    panel_ax.grid(True, alpha=0.3)
    panel_ax.set_xlim(0, r_intersection + 5)

plt.tight_layout()
for out_path in ('figures/bounds_comparison_exact_vs_bigO.pdf',
                 'figures/bounds_comparison_exact_vs_bigO.png'):
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
print("✓ Saved: bounds_comparison_exact_vs_bigO.pdf")

# ==============================================================================
# Create Figure 6: Combined (for paper - 2 panels side-by-side)
# ==============================================================================

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# One row per panel: axes, line format, low-rank curve, full-rank reference,
# y-label, title, and whether to draw the level at 1.0 (PAC-Bayes only).
paper_panels = (
    # Panel A: PAC-Bayes
    (ax1, 'b-', pac_bounds, pac_full_line,
     'PAC-Bayes Bound', '(A) PAC-Bayes Bounds', True),
    # Panel B: Gaussian Complexity
    (ax2, 'g-', gaussian_bounds, gaussian_full_line,
     'Gaussian Complexity Bound', '(B) Gaussian Complexity Bounds', False),
)

for panel_ax, curve_fmt, curve, full_ref, y_label, panel_title, mark_vacuous in paper_panels:
    panel_ax.plot(ranks, curve, curve_fmt, linewidth=2.5, label='Low-Rank')
    panel_ax.axhline(full_ref, color='orange', linestyle='--', linewidth=2.5, label='Full-Rank')

    # Unlabelled marker at the empirical rank
    panel_ax.plot(ranks[empirical_idx], curve[empirical_idx], 'ro', markersize=12, zorder=5)
    if mark_vacuous:
        panel_ax.axhline(1.0, color='red', linestyle=':', linewidth=1, alpha=0.5)

    # Rank where the low-rank parameter count reaches D_full
    panel_ax.axvline(r_intersection, color='gray', linestyle=':', alpha=0.5)

    panel_ax.set_xlabel('Rank r', fontsize=13, fontweight='bold')
    panel_ax.set_ylabel(y_label, fontsize=13, fontweight='bold')
    panel_ax.set_title(panel_title, fontsize=14, fontweight='bold')
    panel_ax.legend(fontsize=11, loc='best')
    panel_ax.grid(True, alpha=0.3)
    panel_ax.set_xlim(0, r_intersection + 5)

plt.tight_layout()
for out_path in ('figures/combined_bounds_empirical.pdf',
                 'figures/combined_bounds_empirical.png'):
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
print("✓ Saved: combined_bounds_empirical.pdf (Publication figure)")

# ==============================================================================
# Create Figure 7: Complexity Ratio (Main Result for Paper)
# ==============================================================================

fig, ax = plt.subplots(figsize=(10, 6))

# Ratios of each low-rank quantity to its full-rank counterpart. These three
# arrays are also reported by the summary printed after this figure.
pac_ratio = pac_bounds / pac_full_line
gaussian_ratio = gaussian_bounds / gaussian_full_line
param_ratio = D_lr / D_full

ax.plot(ranks, param_ratio, 'k-', linewidth=2.5, label='Parameter Ratio (D_lr/D_full)', alpha=0.7)
ax.plot(ranks, pac_ratio, 'b-', linewidth=2.5, label='PAC-Bayes Bound Ratio')
ax.plot(ranks, gaussian_ratio, 'g-', linewidth=2.5, label='Gaussian Bound Ratio')

# Empirical point (marked on the PAC-Bayes ratio curve)
ax.plot(ranks[empirical_idx], pac_ratio[empirical_idx], 'ro', markersize=12, zorder=5,
        label='Empirical (r=[14,20])')

# Reference lines: ratio 1 (parity with full-rank) and the intersection rank
ax.axhline(1.0, color='red', linestyle='--', linewidth=1, alpha=0.5)
ax.axvline(r_intersection, color='gray', linestyle=':', alpha=0.5)

ax.set_xlabel('Rank r', fontsize=13, fontweight='bold')
ax.set_ylabel('Ratio (Low-Rank / Full-Rank)', fontsize=13, fontweight='bold')
ax.set_title('Complexity Reduction: Low-Rank vs Full-Rank', fontsize=14, fontweight='bold')
ax.legend(fontsize=11, loc='best')
ax.grid(True, alpha=0.3)
ax.set_xlim(0, r_intersection + 5)
ax.set_ylim(0, 1.1)

# Annotate the empirical PAC-Bayes ratio just right of and below its marker
ax.text(r_empirical + 5, pac_ratio[empirical_idx] - 0.05,
        f'PAC: {pac_ratio[empirical_idx]:.2f}×',
        fontsize=10, color='blue', fontweight='bold')

plt.tight_layout()
plt.savefig('figures/complexity_ratio_empirical.pdf', dpi=300, bbox_inches='tight')
plt.savefig('figures/complexity_ratio_empirical.png', dpi=300, bbox_inches='tight')
print("✓ Saved: complexity_ratio_empirical.pdf (Main result)")

# ==============================================================================
# Summary Statistics
# ==============================================================================

print("\n" + "="*80)
print("SUMMARY: Bounds vs Rank (Using Empirical Values)")
print("="*80)

# Values at the grid rank closest to the measured [14, 20] configuration.
# Note these come from the uniform-rank grid, not from the stored D of the
# trained low-rank model.
print(f"\nEmpirical point (r=[14,20], avg r={r_empirical}):")
print(f"  Parameters:    {D_lr[empirical_idx]} ({param_ratio[empirical_idx]:.2%} of full-rank)")
print(f"  PAC-Bayes:     {pac_bounds[empirical_idx]:.6f} ({pac_ratio[empirical_idx]:.2%} of full-rank)")
print(f"  Gaussian:      {gaussian_bounds[empirical_idx]:.6f} ({gaussian_ratio[empirical_idx]:.2%} of full-rank)")
print(f"  C_max:         {C_max_values[empirical_idx]:.6f}")

print(f"\nFull-rank:")
print(f"  Parameters:    {D_full}")
print(f"  PAC-Bayes:     {pac_full_line:.6f}")
print(f"  Gaussian:      {gaussian_full_line:.6f}")
print(f"  C_max:         {C_max_full:.6f}")

# param_ratio is D_lr / D_full, i.e. the fraction of parameters that REMAIN.
# The reduction is its complement; printing the ratio itself as a
# "reduction" misstated the result.
param_reduction = 1 - param_ratio[empirical_idx]

print(f"\nKey findings:")
print(f"  1. Low-rank achieves {param_reduction:.1%} parameter reduction")
print(f"  2. PAC-Bayes bound reduced to {pac_ratio[empirical_idx]:.1%}")
print(f"  3. Gaussian bound reduced to {gaussian_ratio[empirical_idx]:.1%}")
print(f"  4. Intersection at rank r={r_intersection}")

# Check which bounds are vacuous (a bound above 1 carries no information)
if pac_full_line > 1.0:
    print(f"\n  ⚠️  Full-rank PAC-Bayes bound is VACUOUS ({pac_full_line:.4f} > 1)")
else:
    print(f"\n  ✓ Full-rank PAC-Bayes bound is non-vacuous ({pac_full_line:.4f} < 1)")

if pac_bounds[empirical_idx] > 1.0:
    print(f"  ⚠️  Low-rank PAC-Bayes bound is VACUOUS ({pac_bounds[empirical_idx]:.4f} > 1)")
else:
    print(f"  ✓ Low-rank PAC-Bayes bound is non-vacuous ({pac_bounds[empirical_idx]:.4f} < 1)")

print("\n" + "="*80)
print("All figures saved!")
print("="*80)
print("\nGenerated files:")
print("  EXACT BOUNDS (with all constants):")
print("    1. pac_bayes_bounds_empirical.pdf")
print("    2. gaussian_complexity_bounds_empirical.pdf")
print("  BIG-O BOUNDS (dominant terms):")
print("    3. pac_bayes_bounds_bigO.pdf")
print("    4. gaussian_complexity_bounds_bigO.pdf")
print("  COMPARISON:")
print("    5. bounds_comparison_exact_vs_bigO.pdf")
print("  OTHER:")
print("    6. combined_bounds_empirical.pdf (2-panel figure)")
print("    7. complexity_ratio_empirical.pdf (main result)")
