"""
Training Utilities for Bayesian LSTM

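This module contains the training-related functions and callbacks:
- Seed setting and cache clearing utilities
- KL scaling functions
- Training functions for Bayesian and deterministic models
- Deep Ensemble training and prediction
- Multi-seed experiment tracking and model-weight save/load helpers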
"""

import tensorflow as tf
import numpy as np
import gc

from modules.model_builders import build_lstm_baseline


# ==============================================================================
# Seed Setting Function
# ==============================================================================

def set_seed(seed):
    """
    Reset Keras global state and seed every random number generator in use.

    Clears the Keras session and runs the garbage collector, then seeds
    Python's ``random`` module, NumPy and TensorFlow so that subsequent model
    construction and training are reproducible.

    Args:
        seed: Integer seed applied to all generators.
    """
    import os
    import random

    # PYTHONHASHSEED only affects Python processes started after this point
    # (e.g. worker subprocesses); it cannot change hashing in this process.
    os.environ["PYTHONHASHSEED"] = str(seed)
    tf.keras.backend.clear_session()
    gc.collect()
    # Seed Python's RNG explicitly instead of relying on set_random_seed
    # below to cover it.
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    tf.keras.utils.set_random_seed(seed)


# ==============================================================================
# Cache Clearing Utilities
# ==============================================================================

def clear_model_cache(variational_layers):
    """
    Drop the cached weight sample held by each variational layer.

    Meant to run at the start of every batch so that the next forward pass
    draws a fresh set of weights.

    Args:
        variational_layers: Iterable of layers exposing a ``clear_cache()``
            method (e.g. DenseVariational layers).
    """
    for var_layer in variational_layers:
        var_layer.clear_cache()


class CacheClearingCallback(tf.keras.callbacks.Callback):
    """
    Keras callback that resets variational weight caches before every batch.

    It hooks training, evaluation and prediction batches, so each batch of a
    Bayesian RNN is processed with exactly one freshly drawn weight sample.

    Args:
        variational_layers: Layers handed to ``clear_model_cache``.
    """

    def __init__(self, variational_layers):
        super().__init__()
        self.variational_layers = variational_layers

    def _reset(self):
        """Clear the cached weights of all tracked layers."""
        clear_model_cache(self.variational_layers)

    def on_train_batch_begin(self, batch, logs=None):
        """Reset caches before a training batch."""
        self._reset()

    def on_test_batch_begin(self, batch, logs=None):
        """Reset caches before a validation/evaluation batch."""
        self._reset()

    def on_predict_batch_begin(self, batch, logs=None):
        """Reset caches before a prediction batch."""
        self._reset()


# ==============================================================================
# KL Scaling Utilities
# ==============================================================================

def compute_kl_scale(num_train_samples, batch_size, num_truncations=1):
    """
    Compute the KL divergence scaling factor per the DeepMind Bayesian RNN paper.

    The KL term is spread evenly over every minibatch and every sequence
    truncation:

        scale = 1 / (num_batches * num_truncations),  num_batches = ceil(N / B)

    Use ``num_truncations=1`` for non-truncated (full) sequences.

    Args:
        num_train_samples: Total number of training samples (N), must be > 0.
        batch_size: Batch size (B), must be > 0.
        num_truncations: Number of sequence truncations (C), must be >= 1.
    Returns:
        kl_scale: Scaling factor for the KL divergence (float).
    Raises:
        ValueError: If any argument is out of range.
    """
    if num_train_samples <= 0:
        raise ValueError(f"num_train_samples must be > 0, got {num_train_samples}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be > 0, got {batch_size}")
    if num_truncations < 1:
        raise ValueError(f"num_truncations must be >= 1, got {num_truncations}")
    # A trailing partial batch still counts as one minibatch.
    num_batches = int(np.ceil(num_train_samples / batch_size))
    return 1.0 / (num_batches * num_truncations)


def set_kl_scale(model, value, verbose=False):
    """
    Assign a new KL scale to every layer of ``model`` that has a ``kl_scale``.

    Nested ``tf.keras.Model`` instances are searched recursively. If a layer's
    ``kl_scale`` has an ``assign`` method (e.g. a ``tf.Variable``) it is updated
    in place; otherwise the attribute is overwritten with a Python float.

    Args:
        model: Keras model containing variational layers.
        value: New kl_scale value (converted to float).
        verbose: If True, print a single summary line with the total number of
            layers updated, including those inside nested models.
    Returns:
        int: Number of layers updated (callers may ignore it).
    """
    new_value = float(value)

    def _apply(container):
        # Returns how many layers under `container` received the new value.
        count = 0
        for layer in container.layers:
            if isinstance(layer, tf.keras.Model):
                count += _apply(layer)
            elif hasattr(layer, "kl_scale"):
                current = layer.kl_scale
                if hasattr(current, "assign"):
                    current.assign(new_value)
                else:
                    layer.kl_scale = new_value
                count += 1
        return count

    updated = _apply(model)
    if verbose and updated:
        print(f"[set_kl_scale] Set kl_scale={value:.8f} on {updated} layers")
    return updated


# ==============================================================================
# Model Compilation
# ==============================================================================

def compile_regression(model, learning_rate=1e-3):
    """
    Compile ``model`` for regression with MSE loss.

    Uses an Adam optimizer and tracks MSE and MAE as metrics (named ``mse`` and
    ``mae``, so callbacks can monitor ``val_mse`` / ``val_mae``).

    Args:
        model: Keras model to compile (modified in place).
        learning_rate: Adam learning rate.
    """
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    tracked_metrics = [
        tf.keras.metrics.MeanSquaredError(name="mse"),
        tf.keras.metrics.MeanAbsoluteError(name="mae"),
    ]
    model.compile(optimizer=optimizer, loss="mse", metrics=tracked_metrics)


# ==============================================================================
# Bayesian LSTM Training
# ==============================================================================

class _KLWarmupCallback(tf.keras.callbacks.Callback):
    """
    Piecewise-linear KL annealing schedule, applied at the start of each epoch.

    Epochs ``[0, zero_epochs)`` use a KL scale of 0. The next ``ramp_epochs``
    epochs increase it linearly, reaching ``final_scale`` on the last ramp
    epoch. Every later epoch uses ``final_scale``.
    """

    def __init__(self, final_scale, zero_epochs, ramp_epochs, verbose=True):
        super().__init__()
        self.final_scale = float(final_scale)
        self.zero_epochs = int(zero_epochs)
        self.ramp_epochs = int(ramp_epochs)
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        if epoch < self.zero_epochs:
            current = 0.0
        elif epoch < self.zero_epochs + self.ramp_epochs:
            # +1 so the last ramp epoch already trains at the full scale.
            frac = (epoch - self.zero_epochs + 1) / float(self.ramp_epochs)
            current = self.final_scale * min(frac, 1.0)
        else:
            current = self.final_scale
        # Assign to all variational layers.
        set_kl_scale(self.model, current, verbose=False)
        if self.verbose and (epoch % 5 == 0 or epoch == self.zero_epochs + self.ramp_epochs - 1):
            print(f"[KLWarmup] epoch {epoch:02d} -> kl_scale = {current:.8f}")


def train_bayesian_lstm(
    model,
    variational_layers,
    X_train, y_train,
    X_val, y_val,
    batch_size=64,
    epochs=50,
    learning_rate=1e-3,
    kl_warmup_epochs=0,
    verbose=1,
    kl_scale=0.09,
    seed=42,
    kl_zero_epochs=10,
    kl_ramp_epochs=20,
):
    """
    Train a Bayesian LSTM with per-batch cache clearing and KL scaling.

    The final per-layer KL scale is ``kl_scale / len(X_train)``. When
    ``kl_warmup_epochs > 0`` the scale is annealed: it stays at 0 for
    ``kl_zero_epochs`` epochs, then ramps linearly over ``kl_ramp_epochs``
    epochs up to the final value.

    Args:
        model: Bayesian LSTM model (already built).
        variational_layers: List of DenseVariational layers whose caches are
            cleared before every batch.
        X_train, y_train: Training data (N, T, F) and (N, 1).
        X_val, y_val: Validation data.
        batch_size: Training batch size.
        epochs: Maximum number of training epochs.
        learning_rate: Learning rate for Adam.
        kl_warmup_epochs: Enables KL annealing when > 0 (0 = no warmup). Only
            its sign is used; the schedule comes from kl_zero_epochs and
            kl_ramp_epochs.
        verbose: Verbosity level passed to ``model.fit``.
        kl_scale: KL scale factor, divided by the number of training samples.
        seed: Random seed.
        kl_zero_epochs: Warmup epochs trained with a KL scale of 0.
        kl_ramp_epochs: Warmup epochs of linear ramp after the zero phase.
    Returns:
        history: Keras History returned by ``model.fit``.
    Raises:
        ValueError: If kl_zero_epochs or kl_ramp_epochs is negative.
    """
    if kl_zero_epochs < 0 or kl_ramp_epochs < 0:
        raise ValueError("kl_zero_epochs and kl_ramp_epochs must be non-negative")

    # NOTE(review): set_seed() also calls clear_session() although `model` was
    # already built by the caller; confirm this is safe for the Keras version
    # in use.
    set_seed(seed)

    # Per-sample scaling: the user-supplied factor divided by N.
    kl_scale_final = kl_scale / len(X_train)
    warmup_enabled = kl_warmup_epochs > 0

    print("Training Bayesian LSTM:")
    print(f"  Train samples: {len(X_train):,}")
    print(f"  Val samples: {len(X_val):,}")
    print(f"  Batch size: {batch_size}")
    print(f"  Batches per epoch: {int(np.ceil(len(X_train) / batch_size))}")
    print(f"  KL scale (final): {kl_scale_final:.8f}")
    print(f"  KL warmup epochs: {kl_warmup_epochs}")
    if warmup_enabled:
        print(f"  KL schedule: 0 for {kl_zero_epochs} epochs, "
              f"then linear ramp over {kl_ramp_epochs} epochs")

    # Start at zero when annealing, otherwise at the final scale.
    set_kl_scale(model, 0.0 if warmup_enabled else kl_scale_final, verbose=True)

    compile_regression(model, learning_rate=learning_rate)

    callbacks = [
        # One weight sample per batch for the Bayesian RNN.
        CacheClearingCallback(variational_layers),
        tf.keras.callbacks.EarlyStopping(
            monitor='val_mae',
            patience=30,
            restore_best_weights=True,
            verbose=1
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor='val_mae',
            factor=0.5,
            patience=5,
            min_lr=1e-6,
            verbose=1
        ),
    ]
    if warmup_enabled:
        callbacks.append(_KLWarmupCallback(final_scale=kl_scale_final,
                                           zero_epochs=kl_zero_epochs,
                                           ramp_epochs=kl_ramp_epochs))

    print("\nStarting training...")
    history = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        batch_size=batch_size,
        epochs=epochs,
        callbacks=callbacks,
        verbose=verbose
    )
    return history


# ==============================================================================
# Rank-1 Bayesian LSTM Training (wrapper)
# ==============================================================================

def train_rank1_bayesian_lstm(
    input_size,
    sequence_length,
    lstm_hidden_size,
    num_lstm_layers,
    X_train, y_train,
    X_val, y_val,
    batch_size=64,
    epochs=100,
    learning_rate=1e-3,
    kl_warmup_epochs=0,
    verbose=1,
    seed=42,
    kl_scale=0.09,
):
    """
    Build and train a Rank-1 Bayesian LSTM model.

    The model is built with a single output and ``output_mode="last"``, then
    trained via ``train_bayesian_lstm``.

    Args:
        input_size: Number of input features (F).
        sequence_length: Sequence length (T).
        lstm_hidden_size: LSTM hidden size.
        num_lstm_layers: Number of LSTM layers.
        X_train, y_train, X_val, y_val: Training and validation data.
        batch_size, epochs, learning_rate, kl_warmup_epochs, verbose, seed:
            Forwarded to ``train_bayesian_lstm``.
        kl_scale: KL scale factor forwarded to ``train_bayesian_lstm``
            (default matches that function's default).
    Returns:
        (model, variational_layers, history)
    """
    # Same package path as every other model_builders import in this module;
    # the bare `model_builders` path does not match those imports.
    from modules.model_builders import build_bayesian_lstm_rank1

    model, _, variational_layers = build_bayesian_lstm_rank1(
        input_size=input_size,
        sequence_length=sequence_length,
        lstm_hidden_size=lstm_hidden_size,
        num_lstm_layers=num_lstm_layers,
        output_dim=1,
        output_mode="last",
    )
    history = train_bayesian_lstm(
        model=model,
        variational_layers=variational_layers,
        X_train=X_train,
        y_train=y_train,
        X_val=X_val,
        y_val=y_val,
        batch_size=batch_size,
        epochs=epochs,
        learning_rate=learning_rate,
        kl_warmup_epochs=kl_warmup_epochs,
        verbose=verbose,
        kl_scale=kl_scale,
        seed=seed,
    )
    return model, variational_layers, history


# ==============================================================================
# Deep Ensemble Training
# ==============================================================================

def train_deep_ensemble(
    M,  # Number of ensemble members
    input_size,
    sequence_length,
    lstm_hidden_size,
    num_lstm_layers,
    X_train, y_train,
    X_val, y_val,
    batch_size=64,
    epochs=150,
    learning_rate=1e-3,
    verbose=1,
    base_seed=42,
):
    """
    Train M independent deterministic LSTM models (Deep Ensemble).

    Member ``i`` is seeded with ``base_seed + i``, so members differ only in
    their random initialization and data shuffling.

    Args:
        M: Number of ensemble members (>= 1).
        input_size, sequence_length, lstm_hidden_size, num_lstm_layers:
            Architecture arguments forwarded to ``build_lstm_baseline``.
        X_train, y_train, X_val, y_val: Training and validation data.
        batch_size: Training batch size.
        epochs: Maximum epochs per member (early stopping on val_loss).
        learning_rate: Adam learning rate.
        verbose: Verbosity passed to ``model.fit``.
        base_seed: Seed of the first member.
    Returns:
        models: List of M trained models.
        histories: List of M training histories.
    Raises:
        ValueError: If M < 1.
    """
    if M < 1:
        raise ValueError(f"M must be >= 1, got {M}")

    models = []
    histories = []
    print(f"\nTraining Deep Ensemble with {M} members...")
    print(f"Each member: {lstm_hidden_size}H x {num_lstm_layers}L")
    for i in range(M):
        # Different seed for each member ensures diverse initializations.
        member_seed = base_seed + i
        # NOTE(review): set_seed() calls clear_session() while earlier members
        # are still held in `models`; confirm they remain usable afterwards.
        set_seed(member_seed)
        print(f"\n{'='*80}")
        print(f"ENSEMBLE MEMBER {i+1}/{M} (seed={member_seed})")
        print(f"{'='*80}")
        model = build_lstm_baseline(
            input_size=input_size,
            sequence_length=sequence_length,
            lstm_hidden_size=lstm_hidden_size,
            num_lstm_layers=num_lstm_layers,
            output_dim=1,
            output_mode="last",
            forget_bias_init=1.0,
        )
        # Adam + MSE loss + MSE/MAE metrics, shared with the Bayesian models.
        compile_regression(model, learning_rate=learning_rate)
        callbacks = [
            tf.keras.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=10,
                restore_best_weights=True,
                verbose=0
            ),
            tf.keras.callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=5,
                min_lr=1e-6,
                verbose=0
            ),
        ]
        history = model.fit(
            X_train, y_train,
            validation_data=(X_val, y_val),
            batch_size=batch_size,
            epochs=epochs,
            callbacks=callbacks,
            verbose=verbose
        )
        models.append(model)
        histories.append(history)
        best_val_loss = min(history.history['val_loss'])
        print(f"  Member {i+1} best val loss: {best_val_loss:.6f}")
    print(f"\n{'='*80}")
    print(f"  ALL {M} ENSEMBLE MEMBERS TRAINED")
    print(f"{'='*80}")
    return models, histories


def ensemble_predict(models, X, return_std=False):
    """
    Average the predictions of a Deep Ensemble.

    Args:
        models: Non-empty iterable of ensemble members, each exposing
            ``predict(X, verbose=0)``.
        X: Input data, passed unchanged to every member.
        return_std: If True, also return the across-member standard deviation
            (population std, ddof=0).
    Returns:
        mean_pred, or (mean_pred, std_pred) when ``return_std`` is True. Both
        have the shape of a single member's prediction (e.g. (N, 1)).
    Raises:
        ValueError: If ``models`` is empty.
    """
    member_preds = [member.predict(X, verbose=0) for member in models]
    if not member_preds:
        # Averaging zero members would silently yield NaN.
        raise ValueError("ensemble_predict requires at least one model")
    stacked = np.stack(member_preds, axis=0)  # (M, *member_output_shape)
    mean_pred = np.mean(stacked, axis=0)
    if return_std:
        std_pred = np.std(stacked, axis=0)
        return mean_pred, std_pred
    return mean_pred




#------------------------------------------------
### Multi-seed experiment tracker class and configuration
from pathlib import Path
from datetime import datetime
import json
import pickle
import numpy as np 
import pandas as pd
# Seeds evaluated in the multi-seed study (4 in total).
SEEDS = [
    42,
    123,
    456,
    2026,
]

# Output locations. EXISTING_MODELS_DIR presumably holds previously trained
# (seed 42) models that may be reused; confirm against callers.
CHECKPOINT_DIR = "checkpoints/multi_seed"
RESULTS_DIR = "multi_seed_results"
EXISTING_MODELS_DIR = "checkpoints"

# Model variants tracked per seed; these strings are the names that
# train_single_model dispatches on.
MODEL_NAMES = [
    "Deterministic",
    "Full-Rank Bayesian",
    "Low-Rank Bayesian",
    "Low-Rank (SVD Init)",
    "Rank-1 Bayesian",
    "Deep Ensemble",
]

class ExperimentTracker:
    """
    Track multi-seed experiment progress and persist it to disk.

    Progress (completed keys, timings, reused keys) is stored as JSON in
    ``<checkpoint_dir>/progress.json``. Per-experiment results are pickled to
    ``<checkpoint_dir>/results.pkl``. Experiments are keyed as
    ``"<seed>_<model_name>"``.
    """

    def __init__(self, checkpoint_dir):
        self.checkpoint_dir = Path(checkpoint_dir)
        # Create the directory up front so save_result() can always write.
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.progress_file = self.checkpoint_dir / "progress.json"
        self.results_file = self.checkpoint_dir / "results.pkl"
        self.load_progress()

    @staticmethod
    def _key(seed, model_name):
        """Return the experiment key used in progress and results."""
        return f"{seed}_{model_name}"

    @staticmethod
    def _atomic_write(path, payload):
        """Write ``payload`` bytes to ``path`` via a temp file and rename, so an
        interrupted save never leaves a truncated checkpoint behind."""
        tmp_path = path.with_name(path.name + ".tmp")
        tmp_path.write_bytes(payload)
        tmp_path.replace(path)

    def load_progress(self):
        """Load existing progress/results, or initialize new tracking."""
        if self.progress_file.exists():
            with open(self.progress_file, 'r') as f:
                self.progress = json.load(f)
        else:
            self.progress = {
                'completed': [],
                'start_time': datetime.now().isoformat(),
                'timings': {},
                'reused_seed_42': []
            }
        # Progress files written by older versions may lack some keys.
        self.progress.setdefault('completed', [])
        self.progress.setdefault('timings', {})
        self.progress.setdefault('reused_seed_42', [])
        if self.progress_file.exists():
            print(f"\nResuming: {len(self.progress['completed'])} experiments already completed")

        # NOTE: pickle can execute arbitrary code on load; only point this at
        # checkpoint directories you created yourself.
        if self.results_file.exists():
            with open(self.results_file, 'rb') as f:
                self.results = pickle.load(f)
        else:
            self.results = {}

    def is_completed(self, seed, model_name):
        """Return True if this (seed, model) experiment is marked completed."""
        return self._key(seed, model_name) in self.progress['completed']

    def mark_reused(self, seed, model_name):
        """Record that an existing model was reused.

        Only updates in-memory progress; it is written to disk on the next
        ``save_result`` call.
        """
        key = self._key(seed, model_name)
        if key not in self.progress['reused_seed_42']:
            self.progress['reused_seed_42'].append(key)

    def save_result(self, seed, model_name, metrics, training_time, reused=False):
        """Store an experiment result and checkpoint progress and results."""
        key = self._key(seed, model_name)

        self.results[key] = {
            'seed': seed,
            'model': model_name,
            'metrics': metrics,
            'training_time': training_time,
            'reused': reused,
            'timestamp': datetime.now().isoformat()
        }

        if key not in self.progress['completed']:
            self.progress['completed'].append(key)
        self.progress['timings'][key] = training_time

        self._atomic_write(self.progress_file,
                           json.dumps(self.progress, indent=2).encode("utf-8"))
        self._atomic_write(self.results_file, pickle.dumps(self.results))

        status = "Reused" if reused else "Trained"
        print(f"{status}: {key} ({training_time:.1f}s)")

    def get_summary(self):
        """Summarize progress against the full SEEDS x MODEL_NAMES grid."""
        total = len(SEEDS) * len(MODEL_NAMES)
        completed = len(self.progress['completed'])
        return {
            'total': total,
            'completed': completed,
            'remaining': total - completed,
            'progress_pct': (completed / total) * 100 if total else 0.0,
            'reused': len(self.progress.get('reused_seed_42', []))
        }




import os
import json

def save_trained_models(models, save_dir="checkpoints"):
    """
    Persist model weights (not full models) under ``save_dir``.

    Entries get positional names ``model_<idx>``. A single model is written to
    ``model_<idx>.weights.h5``. An ensemble (a list of models) is written to a
    directory ``model_<idx>/`` holding ``member_<i>.weights.h5`` files.
    ``key_mapping.json`` records, for each original name, the stored path and
    type (plus member count for ensembles), so weights can be loaded back
    into freshly built models without custom_objects.

    Args:
        models: dict of model_name -> model, or model_name -> list of models
            for an ensemble.
        save_dir: Directory to write into (created if missing).
    """
    os.makedirs(save_dir, exist_ok=True)

    banner = '=' * 80
    print(f"\n{banner}")
    print(f"SAVING MODEL WEIGHTS TO: {save_dir}")
    print(banner)

    key_mapping = {}
    for idx, name in enumerate(models):
        entry = models[name]
        safe_name = f"model_{idx}"

        if not isinstance(entry, list):
            weight_file = f"{safe_name}.weights.h5"
            entry.save_weights(os.path.join(save_dir, weight_file))
            key_mapping[name] = {"path": weight_file, "type": "single"}
            print(f"  Saved: {name}")
            continue

        # Ensemble: one weights file per member inside its own directory.
        ensemble_dir = os.path.join(save_dir, safe_name)
        os.makedirs(ensemble_dir, exist_ok=True)
        for member_idx, member in enumerate(entry):
            member.save_weights(os.path.join(ensemble_dir, f"member_{member_idx}.weights.h5"))
        key_mapping[name] = {"path": safe_name, "type": "ensemble", "n_members": len(entry)}
        print(f"  Saved ensemble: {name} ({len(entry)} members)")

    mapping_file = os.path.join(save_dir, "key_mapping.json")
    with open(mapping_file, 'w') as fh:
        json.dump(key_mapping, fh, indent=2)

    print(f"\nKey mapping saved to: {mapping_file}")
    print(f"{banner}\n")


def _ensemble_member_files(ensemble_dir, n_members):
    """Return ensemble member weight file names in member-index order."""
    if n_members is not None:
        # Mapping written by save_trained_models: take exactly the saved
        # members and ignore stale files from earlier, larger ensembles.
        return [f"member_{i}.weights.h5" for i in range(n_members)]
    import re

    # Fallback for mappings without a member count: sort numerically so that
    # member_10 follows member_9 rather than member_1.
    def _index(file_name):
        match = re.search(r"(\d+)", file_name)
        return (int(match.group(1)) if match else -1, file_name)

    return sorted((f for f in os.listdir(ensemble_dir) if f.endswith('.h5')), key=_index)


def load_trained_models(model_builders, feature_dim, sequence_length, save_dir="checkpoints",
                        lstm_hidden_size=64, num_lstm_layers=2, rank=16,
                        ensemble_builder=None, n_ensemble_members=5):
    """
    Load weights saved by ``save_trained_models`` into freshly built models.

    Args:
        model_builders: dict of model_name -> builder callable for single
            models. A builder is called with keyword arguments ``input_size``,
            ``sequence_length``, ``lstm_hidden_size``, ``num_lstm_layers``. It
            may return ``(model, _, variational_layers)`` or a bare model, in
            which case the variational layers are recorded as None. Ensemble
            entries are built with ``ensemble_builder`` instead.
        feature_dim: Input feature dimension (F).
        sequence_length: Sequence length (T).
        save_dir: Directory containing ``key_mapping.json`` and the weights.
        lstm_hidden_size: LSTM hidden size.
        num_lstm_layers: Number of LSTM layers.
        rank: Currently unused; kept for backward compatibility.
        ensemble_builder: Callable ``(feature_dim, sequence_length,
            lstm_hidden_size, num_lstm_layers) -> model`` building one member.
            Required if the mapping contains an ensemble.
        n_ensemble_members: Currently unused (the member count comes from the
            mapping); kept for backward compatibility.

    Returns:
        models: dict of model_name -> loaded model (list of models for ensembles)
        layers: dict of model_name -> variational layers (None if not applicable)

    Raises:
        ValueError: If an ensemble entry is present but ``ensemble_builder`` is None.
    """
    banner = '=' * 80
    print(f"\n{banner}")
    print(f"LOADING MODEL WEIGHTS FROM: {save_dir}")
    print(banner)

    with open(os.path.join(save_dir, "key_mapping.json"), 'r') as f:
        key_mapping = json.load(f)

    models = {}
    layers = {}

    for original_key, info in key_mapping.items():
        item_path = os.path.join(save_dir, info["path"])

        if info["type"] == "ensemble":
            if ensemble_builder is None:
                raise ValueError(
                    f"ensemble_builder is required to load ensemble '{original_key}'"
                )
            ensemble = []
            for member_file in _ensemble_member_files(item_path, info.get("n_members")):
                member = ensemble_builder(feature_dim, sequence_length, lstm_hidden_size, num_lstm_layers)
                member.load_weights(os.path.join(item_path, member_file))
                ensemble.append(member)
            models[original_key] = ensemble
            layers[original_key] = None
            print(f"  Loaded ensemble: {original_key} ({len(ensemble)} members)")
            continue

        builder = model_builders.get(original_key)
        if builder is None:
            print(f"  WARNING: No builder found for {original_key}, skipping...")
            continue

        built = builder(
            input_size=feature_dim,
            sequence_length=sequence_length,
            lstm_hidden_size=lstm_hidden_size,
            num_lstm_layers=num_lstm_layers
        )
        if isinstance(built, tuple):
            model, _, var_layers = built
        else:
            # e.g. a deterministic builder that returns the model alone.
            model, var_layers = built, None
        model.load_weights(item_path)
        models[original_key] = model
        layers[original_key] = var_layers
        print(f"  Loaded: {original_key}")

    print(f"{banner}\n")
    return models, layers


# ==============================================================================
# Model Loading and Training Utilities for Multi-Seed Experiments
# ==============================================================================

def _match_mapping_key(model_name, key_mapping):
    """Return the key_mapping key for ``model_name``: an exact match if present,
    otherwise the first key that contains or is contained in it, else None."""
    if model_name in key_mapping:
        return model_name
    for key in key_mapping:
        if model_name in key or key in model_name:
            return key
    return None


def _build_architecture(model_name, input_size, sequence_length,
                        lstm_hidden_size, num_lstm_layers, ranks):
    """Build an untrained model whose architecture matches ``model_name``.

    Returns ``(model, variational_layers)``, or None for an unrecognised name.
    """
    from modules.model_builders import (
        build_lstm_baseline,
        build_bayesian_lstm_fullrank,
        build_bayesian_lstm_lowrank,
        build_bayesian_lstm_rank1
    )

    shape_kwargs = dict(
        input_size=input_size, sequence_length=sequence_length,
        lstm_hidden_size=lstm_hidden_size, num_lstm_layers=num_lstm_layers,
    )
    # NOTE(review): training builds the baseline with output_dim=1,
    # output_mode="last", forget_bias_init=1.0; here builder defaults are used.
    # Confirm the defaults yield the same weight shapes.
    if "Deterministic" in model_name:
        return build_lstm_baseline(**shape_kwargs), None
    if "Full-Rank" in model_name:
        model, _, var_layers = build_bayesian_lstm_fullrank(**shape_kwargs)
        return model, var_layers
    # SVD-initialised and plain low-rank variants are built identically here.
    # This assumes init_from_deterministic only changes initial values, which
    # load_weights overwrites anyway.
    if "svd" in model_name.lower() or "Low-Rank" in model_name:
        model, _, var_layers = build_bayesian_lstm_lowrank(**shape_kwargs, ranks=ranks)
        return model, var_layers
    if "Rank-1" in model_name:
        model, _, var_layers = build_bayesian_lstm_rank1(**shape_kwargs)
        return model, var_layers
    return None


def load_existing_model(model_name,
                        models_dir,
                        input_size,
                        sequence_length,
                        lstm_hidden_size,
                        num_lstm_layers,
                        ranks=None,
                        ensemble_size=5,
                        seed=42,
                        reuse_seed_42=True):
    """
    Load an existing seed-42 model from a checkpoints folder.

    The entry for ``model_name`` is looked up in ``key_mapping.json`` (exact
    name first, then a substring match). The matching architecture is rebuilt
    and its saved weights are loaded.

    Args:
        model_name: Name of the model to load.
        models_dir: Directory containing saved models and key_mapping.json.
        input_size: Number of input features (F).
        sequence_length: Sequence length (T).
        lstm_hidden_size: LSTM hidden size.
        num_lstm_layers: Number of LSTM layers.
        ranks: List of ranks for low-rank models (default: [14, 20]).
        ensemble_size: Member count used when the mapping lacks ``n_members``.
        seed: Random seed (only seed=42 models are loaded).
        reuse_seed_42: Whether to reuse seed 42 models.

    Returns:
        tuple: (model, variational_layers), (list_of_models, None) for an
        ensemble, or (None, None) if nothing suitable was found or loading failed.

    Raises:
        ImportError: If the model builder module cannot be imported.
    """
    import json
    from pathlib import Path

    if ranks is None:
        ranks = [14, 20]

    if seed != 42 or not reuse_seed_42:
        return None, None

    models_dir = Path(models_dir)
    mapping_path = models_dir / "key_mapping.json"
    if not mapping_path.exists():
        print(f"  No key_mapping.json found in {models_dir}")
        return None, None

    with open(mapping_path, 'r') as f:
        key_mapping = json.load(f)

    model_key = _match_mapping_key(model_name, key_mapping)
    if model_key is None:
        print(f"  Model '{model_name}' not found in key_mapping")
        return None, None

    info = key_mapping[model_key]

    try:
        if info["type"] == "ensemble":
            model_dir = models_dir / info["path"]
            if not model_dir.exists():
                print(f"  Ensemble directory not found: {model_dir}")
                return None, None

            from modules.model_builders import build_lstm_baseline

            ensemble_models = []
            n_members = info.get("n_members", ensemble_size)
            for i in range(n_members):
                member_path = model_dir / f"member_{i}.weights.h5"
                if not member_path.exists():
                    print(f"  Ensemble member {i} not found")
                    return None, None
                member = build_lstm_baseline(
                    input_size=input_size, sequence_length=sequence_length,
                    lstm_hidden_size=lstm_hidden_size, num_lstm_layers=num_lstm_layers
                )
                member.load_weights(str(member_path))
                ensemble_models.append(member)
            print(f"  Loaded {len(ensemble_models)} ensemble members from {model_dir}")
            return ensemble_models, None

        weight_path = models_dir / info["path"]
        if not weight_path.exists():
            print(f"  Weights not found: {weight_path}")
            return None, None

        built = _build_architecture(model_name, input_size, sequence_length,
                                    lstm_hidden_size, num_lstm_layers, ranks)
        if built is None:
            print(f"  Unknown model type: {model_name}")
            return None, None
        model, var_layers = built
        model.load_weights(str(weight_path))
        print(f"  Loaded existing model from {weight_path}")
        return model, var_layers

    except ImportError:
        # A missing builder module is a setup error, not a missing checkpoint;
        # surface it instead of silently reporting "not found".
        raise
    except Exception as e:
        # Best-effort: any other load failure falls back to (None, None).
        print(f"  Failed to load model: {e}")
        import traceback
        traceback.print_exc()
        return None, None


def _train_deterministic_baseline(X_train, y_train, X_val, y_val,
                                  input_size, sequence_length,
                                  lstm_hidden_size, num_lstm_layers,
                                  batch_size, epochs, learning_rate, verbose):
    """Build, compile and fit a deterministic LSTM baseline; return the model.

    Uses MSE loss with an MAE metric, early stopping on val_loss (patience 20,
    best weights restored) and ReduceLROnPlateau (factor 0.5, patience 5).
    """
    model = build_lstm_baseline(
        input_size=input_size, sequence_length=sequence_length,
        lstm_hidden_size=lstm_hidden_size, num_lstm_layers=num_lstm_layers,
        output_dim=1, output_mode="last", forget_bias_init=1.0
    )
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss="mse",
        metrics=[tf.keras.metrics.MeanAbsoluteError(name="mae")]
    )
    model.fit(
        X_train, y_train, validation_data=(X_val, y_val),
        batch_size=batch_size, epochs=epochs, verbose=verbose,
        callbacks=[
            tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True),
            tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6)
        ]
    )
    return model


def train_single_model(model_name,
                       X_train, y_train,
                       X_val, y_val,
                       input_size,
                       sequence_length,
                       lstm_hidden_size,
                       num_lstm_layers,
                       batch_size=64,
                       epochs=100,
                       learning_rate=1e-3,
                       kl_scale=0.25,
                       ranks=None,
                       ensemble_size=5,
                       ensemble_epochs=50,
                       seed=42,
                       verbose=1):
    """
    Train a single model variant from scratch.

    Some variants use hard-coded hyperparameters that override the arguments:
    - "Low-Rank Bayesian": learning_rate=2e-3 and kl_scale=0.1.
    - "Rank-1 Bayesian": epochs=150, kl_scale=0.25 and KL warmup enabled.

    Args:
        model_name: One of "Deterministic", "Full-Rank Bayesian",
            "Low-Rank Bayesian", "Low-Rank (SVD Init)", "Rank-1 Bayesian",
            "Deep Ensemble".
        X_train, y_train: Training data.
        X_val, y_val: Validation data.
        input_size: Number of input features (F).
        sequence_length: Sequence length (T).
        lstm_hidden_size: LSTM hidden size.
        num_lstm_layers: Number of LSTM layers.
        batch_size: Batch size.
        epochs: Number of training epochs.
        learning_rate: Learning rate.
        kl_scale: KL divergence scale for Bayesian models.
        ranks: List of ranks for low-rank models (default: [14, 20]).
        ensemble_size: Number of ensemble members.
        ensemble_epochs: Epochs for ensemble training.
        seed: Random seed.
        verbose: Verbosity level.

    Returns:
        tuple: (model, variational_layers, training_time). For "Deep Ensemble"
        the model is a list of members. variational_layers is None for
        deterministic models.

    Raises:
        ValueError: If ``model_name`` is not recognised.
    """
    import time
    from modules.model_builders import (
        build_bayesian_lstm_fullrank,
        build_bayesian_lstm_lowrank,
        build_bayesian_lstm_rank1
    )

    if ranks is None:
        ranks = [14, 20]

    # set_seed() clears the Keras session and runs gc.collect() before seeding,
    # releasing memory held by previously trained models.
    set_seed(seed)

    start_time = time.perf_counter()
    shape_kwargs = dict(
        input_size=input_size, sequence_length=sequence_length,
        lstm_hidden_size=lstm_hidden_size, num_lstm_layers=num_lstm_layers,
    )

    if model_name == "Deterministic":
        model = _train_deterministic_baseline(
            X_train, y_train, X_val, y_val, input_size, sequence_length,
            lstm_hidden_size, num_lstm_layers, batch_size, epochs,
            learning_rate, verbose
        )
        var_layers = None

    elif model_name == "Full-Rank Bayesian":
        model, _, var_layers = build_bayesian_lstm_fullrank(**shape_kwargs)
        train_bayesian_lstm(
            model=model, variational_layers=var_layers,
            X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val,
            batch_size=batch_size, epochs=epochs, learning_rate=learning_rate,
            kl_scale=kl_scale, verbose=verbose, seed=seed
        )

    elif model_name == "Low-Rank Bayesian":
        model, _, var_layers = build_bayesian_lstm_lowrank(**shape_kwargs, ranks=ranks)
        # NOTE: fixed learning_rate=2e-3 and kl_scale=0.1; the arguments are ignored.
        train_bayesian_lstm(
            model=model, variational_layers=var_layers,
            X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val,
            batch_size=batch_size, epochs=epochs, learning_rate=2e-3,
            kl_scale=0.1, verbose=verbose, seed=seed
        )

    elif model_name == "Low-Rank (SVD Init)":
        # Train a deterministic baseline first; its weights seed the SVD init.
        print("  Training baseline for SVD initialization...")
        det_for_svd = _train_deterministic_baseline(
            X_train, y_train, X_val, y_val, input_size, sequence_length,
            lstm_hidden_size, num_lstm_layers, batch_size, epochs,
            learning_rate, 0
        )
        model, _, var_layers = build_bayesian_lstm_lowrank(
            **shape_kwargs,
            ranks=ranks,
            init_from_deterministic=det_for_svd
        )
        train_bayesian_lstm(
            model=model, variational_layers=var_layers,
            X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val,
            batch_size=batch_size, epochs=epochs, learning_rate=learning_rate,
            kl_scale=kl_scale, verbose=verbose, seed=seed
        )
        # The baseline is only needed for initialization.
        del det_for_svd
        gc.collect()

    elif model_name == "Rank-1 Bayesian":
        model, _, var_layers = build_bayesian_lstm_rank1(**shape_kwargs)
        # NOTE: fixed epochs=150 and kl_scale=0.25 with KL warmup; the epochs
        # and kl_scale arguments are ignored.
        train_bayesian_lstm(
            model=model, variational_layers=var_layers,
            X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val,
            batch_size=batch_size, epochs=150, learning_rate=learning_rate,
            kl_warmup_epochs=10, kl_scale=0.25, verbose=verbose, seed=seed
        )

    elif model_name == "Deep Ensemble":
        model, _ = train_deep_ensemble(
            M=ensemble_size, input_size=input_size, sequence_length=sequence_length,
            lstm_hidden_size=lstm_hidden_size, num_lstm_layers=num_lstm_layers,
            X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val,
            batch_size=batch_size, epochs=ensemble_epochs, learning_rate=learning_rate,
            verbose=0, base_seed=seed
        )
        var_layers = None

    else:
        raise ValueError(f"Unknown model: {model_name}")

    training_time = time.perf_counter() - start_time
    return model, var_layers, training_time

