"""
Model Builder Functions for Bayesian LSTM

This module contains all model builder functions:
- build_lstm_baseline: Deterministic LSTM baseline
- build_bayesian_lstm_fullrank: Full-rank Bayesian LSTM (Bayes by Backprop)
- build_bayesian_lstm_lowrank: Low-rank Bayesian LSTM
- build_bayesian_lstm_rank1: Rank-1 Bayesian LSTM
"""

import tensorflow as tf
import numpy as np

from modules.bayeslstm import DenseVariational, LowRankDenseVariational, Rank1DenseVariational


# ==============================================================================
# Deterministic LSTM Baseline
# ==============================================================================

def build_lstm_baseline(
    input_size,
    sequence_length,
    lstm_hidden_size=128,
    num_lstm_layers=2,
    dropout_rate=0.0,
    output_dim=1,
    output_mode="last",        # "last" or "sequence"
    forget_bias_init=1.0,      # Standard practice: init forget gate bias to 1.0
    gate_order="ifco",         # Gate ordering: [input | forget | cell | output]
):
    """
    Build a deterministic, hand-unrolled LSTM regressor as a Keras functional model.

    The recurrence is written out step by step instead of using
    tf.keras.layers.LSTM. The Bayesian builders in this module reuse exactly
    this structure and only swap the two gate projections for variational
    layers.

    Per-timestep recurrence, with the four gates packed into one 4H-wide vector
    in "ifco" order:

        z_t        = x_t @ W_x + b + h_{t-1} @ W_h          # (B, 4H)
        i, f, g, o = split(z_t, 4)
        c_t        = sigmoid(f) * c_{t-1} + sigmoid(i) * tanh(g)
        h_t        = sigmoid(o) * tanh(c_t)

    Args:
        input_size: Features per timestep (F; 15 for the Beijing PM2.5 data).
        sequence_length: Number of timesteps unrolled (T; 24 for our data).
            It is used as a Python ``range`` bound, so it must be a concrete int.
        lstm_hidden_size: Hidden/cell state width H, shared by every layer.
        num_lstm_layers: Number of stacked LSTM layers (>= 1).
        dropout_rate: Dropout applied to a layer's output sequence only
            *between* stacked layers (never after the top layer, never inside
            the recurrence). Ignored when <= 0.
        output_dim: Width of the linear prediction head (1 for PM2.5).
        output_mode: "last" feeds only h_T to the head, giving
            (B, output_dim). "sequence" feeds every h_t, giving
            (B, T, output_dim).
        forget_bias_init: Constant written into the forget-gate slice of the
            input-projection bias (left at zero when this is 0.0).
        gate_order: Gate packing order; only "ifco" is supported.

    Returns:
        tf.keras.Model named "lstm_baseline_deterministic".

    Raises:
        ValueError: for an unsupported output_mode or gate_order, or when
            num_lstm_layers < 1 (checked in that order).
    """
    # Argument checks; the first failing one determines the error message.
    if output_mode not in {"last", "sequence"}:
        raise ValueError("output_mode must be 'last' or 'sequence'")
    if gate_order != "ifco":
        raise ValueError("gate_order must be 'ifco' (input, forget, cell, output)")
    if num_lstm_layers < 1:
        raise ValueError("num_lstm_layers must be >= 1")

    def _forget_biased_vector(width):
        """Return a fresh (4*width,) bias whose forget slice (2nd quarter) is set."""
        vec = np.zeros((4 * width,), dtype=np.float32)
        if forget_bias_init != 0.0:
            vec[width:2 * width] = float(forget_bias_init)
        return vec

    def _advance_cell(pre_acts, cell_prev):
        """Apply one LSTM step to packed pre-activations; return (h_t, c_t)."""
        in_pre, forget_pre, cand_pre, out_pre = tf.split(
            pre_acts, num_or_size_splits=4, axis=1
        )
        forget_gate = tf.nn.sigmoid(forget_pre)    # how much of c_{t-1} survives
        input_gate = tf.nn.sigmoid(in_pre)         # how much new content enters
        candidate = tf.nn.tanh(cand_pre)           # proposed new content
        cell_next = forget_gate * cell_prev + input_gate * candidate
        output_gate = tf.nn.sigmoid(out_pre)       # how much of the cell is exposed
        hidden_next = output_gate * tf.nn.tanh(cell_next)
        return hidden_next, cell_next

    x_in = tf.keras.Input(
        shape=(sequence_length, input_size),  # per-sample (T, F)
        dtype=tf.float32,
        name="x_input",
    )

    seq = x_in  # (B, T, F) for the first layer, (B, T, H) afterwards
    width = lstm_hidden_size
    last_layer = num_lstm_layers - 1

    for depth in range(num_lstm_layers):
        # Input projection: holds W_x plus the (forget-biased) gate bias b.
        proj_x = tf.keras.layers.Dense(
            units=4 * width,
            use_bias=True,
            activation=None,
            bias_initializer=tf.keras.initializers.Constant(_forget_biased_vector(width)),
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            name=f"layer{depth}_x_to_gates",
        )
        # Recurrent projection: holds W_h only; the bias lives in proj_x.
        proj_h = tf.keras.layers.Dense(
            units=4 * width,
            use_bias=False,
            activation=None,
            kernel_initializer=tf.keras.initializers.Orthogonal(),
            name=f"layer{depth}_h_to_gates",
        )

        # Zero initial states sized to the dynamic batch dimension.
        n_batch = tf.shape(seq)[0]
        hidden = tf.zeros((n_batch, width), dtype=tf.float32)
        cell = tf.zeros((n_batch, width), dtype=tf.float32)

        # The same two layer objects are applied at every step, so the weights
        # are shared across time.
        hidden_seq = []
        for step in range(sequence_length):
            packed = proj_x(seq[:, step, :]) + proj_h(hidden)  # (B, 4H)
            hidden, cell = _advance_cell(packed, cell)
            hidden_seq.append(hidden)

        seq = tf.stack(hidden_seq, axis=1)  # (B, T, H); feeds the next layer

        if dropout_rate > 0.0 and depth < last_layer:
            seq = tf.keras.layers.Dropout(
                rate=dropout_rate,
                name=f"layer{depth}_dropout_between",
            )(seq)

    encoded = seq[:, -1, :] if output_mode == "last" else seq

    y = tf.keras.layers.Dense(
        units=output_dim,
        activation=None,  # linear regression head
        name="output_head",
    )(encoded)

    return tf.keras.Model(
        inputs=x_in,
        outputs=y,
        name="lstm_baseline_deterministic",
    )


# ==============================================================================
# Full-Rank Bayesian LSTM
# ==============================================================================

def build_bayesian_lstm_fullrank(
    input_size,
    sequence_length,
    lstm_hidden_size=128,
    num_lstm_layers=2,
    output_dim=1,
    output_mode="last",
    forget_bias_init=1.0,
    prior_params=None,
    gate_order="ifco",
):
    """
    Build a Bayesian LSTM whose gate projections are full-rank DenseVariational layers.

    This follows Bayes by Backprop for RNNs (DeepMind Algorithm 2). At t=0
    each layer is called with use_cached=False, so it draws fresh weights.
    For t=1..T-1 it is called with use_cached=True, so the t=0 draw is reused
    and one weight sample covers the whole sequence. Every layer is invoked
    with training=True.

    Args:
        input_size: Input features per timestep (F=15 for Beijing PM2.5).
        sequence_length: Number of timesteps unrolled (T=24); must be a concrete int.
        lstm_hidden_size: Hidden state width H for every layer.
        num_lstm_layers: Number of stacked LSTM layers (>= 1).
        output_dim: Width of the Bayesian output head (1 for regression).
        output_mode: "last" feeds only h_T to the head; "sequence" feeds all h_t.
        forget_bias_init: Value written into the forget-gate slice of each
            x_to_gates bias initializer (1.0 recommended).
        prior_params: Prior hyperparameters {pi, sigma1, sigma2}, forwarded
            unchanged to every DenseVariational layer.
        gate_order: Gate packing order; only "ifco" is accepted.

    Returns:
        model: Keras Model named "bayesian_lstm_fullrank".
        kl_loss_fn: Zero-argument callable summing kl_divergence() over all
            variational layers.
        variational_layers: Every DenseVariational layer, in creation order
            (useful for cache clearing).

    Raises:
        ValueError: for an unsupported output_mode or gate_order, or when
            num_lstm_layers < 1 (checked in that order).
    """
    if output_mode not in {"last", "sequence"}:
        raise ValueError("output_mode must be 'last' or 'sequence'")
    if gate_order != "ifco":
        raise ValueError("gate_order must be 'ifco'")
    if num_lstm_layers < 1:
        raise ValueError("num_lstm_layers must be >= 1")

    x_in = tf.keras.Input(
        shape=(sequence_length, input_size),
        dtype=tf.float32,
        name="x_input",
    )

    # Shared with the KL closure below and returned to the caller.
    variational_layers = []
    seq = x_in  # (B, T, F) first, then (B, T, H)
    width = lstm_hidden_size

    for depth in range(num_lstm_layers):
        # Gate bias layout is [i | f | c~ | o]; only the forget quarter is nonzero.
        forget_biased = np.zeros((4 * width,), dtype=np.float32)
        forget_biased[width:2 * width] = float(forget_bias_init)

        proj_x = DenseVariational(
            units=4 * width,
            use_bias=True,
            bias_initializer=tf.keras.initializers.Constant(forget_biased),
            prior_params=prior_params,
            name=f"layer{depth}_x_to_gates",
        )
        proj_h = DenseVariational(
            units=4 * width,
            use_bias=False,  # the gate bias lives in proj_x only
            prior_params=prior_params,
            name=f"layer{depth}_h_to_gates",
        )
        variational_layers += [proj_x, proj_h]

        n_batch = tf.shape(seq)[0]
        hidden = tf.zeros((n_batch, width), dtype=tf.float32)
        cell = tf.zeros((n_batch, width), dtype=tf.float32)

        hidden_seq = []
        for step in range(sequence_length):
            # Step 0 asks for a fresh draw; every later step reuses it.
            reuse = step != 0
            packed = (
                proj_x(seq[:, step, :], training=True, use_cached=reuse)
                + proj_h(hidden, training=True, use_cached=reuse)
            )
            gi, gf, gg, go = tf.split(packed, num_or_size_splits=4, axis=1)
            forget_gate = tf.nn.sigmoid(gf)
            input_gate = tf.nn.sigmoid(gi)
            candidate = tf.nn.tanh(gg)
            cell = forget_gate * cell + input_gate * candidate
            output_gate = tf.nn.sigmoid(go)
            hidden = output_gate * tf.nn.tanh(cell)
            hidden_seq.append(hidden)

        seq = tf.stack(hidden_seq, axis=1)  # (B, T, H)

    encoded = seq[:, -1, :] if output_mode == "last" else seq

    output_head = DenseVariational(
        units=output_dim,
        use_bias=True,
        prior_params=prior_params,
        name="output_head",
    )
    variational_layers.append(output_head)
    # The head is applied once per forward pass, so it never reuses a cache.
    y = output_head(encoded, training=True, use_cached=False)

    model = tf.keras.Model(inputs=x_in, outputs=y, name="bayesian_lstm_fullrank")

    def kl_loss_fn():
        """
        Return the summed KL divergence of every variational layer.
        Should be scaled by 1/(B*C) in training loop (DeepMind paper).
        """
        return sum(
            (layer.kl_divergence() for layer in variational_layers),
            tf.constant(0.0, dtype=tf.float32),
        )

    return model, kl_loss_fn, variational_layers


# ==============================================================================
# Low-Rank Bayesian LSTM
# ==============================================================================

def build_bayesian_lstm_lowrank(
    input_size,
    sequence_length,
    lstm_hidden_size=128,
    num_lstm_layers=2,
    ranks=16,
    output_dim=1,
    output_mode="last",
    forget_bias_init=1.0,
    prior_params=None,
    gate_order="ifco",
    init_from_deterministic=None,
):
    """
    Bayesian LSTM with Low-Rank Bayes by Backprop.

    This is the same hand-unrolled LSTM as the full-rank builder, but each gate
    projection is a LowRankDenseVariational layer. Each layer is called with
    use_cached=False at t=0 and use_cached=True for t=1..T-1, so one weight
    draw covers the whole sequence. Every layer is invoked with training=True.

    Args:
        input_size: Number of input features per timestep (F).
        sequence_length: Number of timesteps unrolled (T); must be a concrete int.
        lstm_hidden_size: Hidden state dimension (H) for every layer.
        num_lstm_layers: Number of stacked LSTM layers (>= 1).
        ranks: Rank for low-rank factorization. Can be:
               - int (Python or NumPy integer): same rank for all LSTM layers
                 (e.g., ranks=10)
               - sequence: different rank per LSTM layer (e.g., ranks=[15, 10]);
                 must have length = num_lstm_layers
               Every rank must be an integer >= 1. A layer's x_to_gates and
               h_to_gates share its rank. Output head always uses rank=1.
        output_dim: Final output dimension.
        output_mode: "last" = h_T only, "sequence" = all h_t.
        forget_bias_init: Value written into the forget-gate slice of each
            x_to_gates bias initializer.
        prior_params: Prior hyperparameters, forwarded unchanged to every
            LowRankDenseVariational layer.
        gate_order: Must be "ifco" (input, forget, cell, output).
        init_from_deterministic: Optional Keras model whose layers are matched
            to this model's variational layers by name. build_lstm_baseline
            uses the same layer names. For each match, the deterministic
            layer's first weight array is passed to init_from_full_matrix.

    Returns:
        model: Keras Model named "bayesian_lstm_lowrank"
        kl_loss_fn: Zero-argument function returning the summed KL divergence
        variational_layers: All LowRankDenseVariational layers (for cache clearing)

    Raises:
        ValueError: bad output_mode / gate_order, num_lstm_layers < 1, a ranks
            sequence of the wrong length, or any rank < 1.
        TypeError: ranks (or one of its entries) is not an integer. Bools are
            rejected even though bool subclasses int.
    """
    # Validation
    if output_mode not in {"last", "sequence"}:
        raise ValueError("output_mode must be 'last' or 'sequence'")
    if gate_order != "ifco":
        raise ValueError("gate_order must be 'ifco'")
    if num_lstm_layers < 1:
        raise ValueError("num_lstm_layers must be >= 1")
    # Handle ranks parameter. NumPy integers count as scalars too; a plain
    # `isinstance(ranks, int)` would send np.int64 into the sequence branch and
    # fail on len().
    if isinstance(ranks, (bool, np.bool_)):
        raise TypeError("ranks must be an int or a sequence of ints, not a bool")
    if isinstance(ranks, (int, np.integer)):
        # Same rank for all LSTM layers
        rank_list = [ranks] * num_lstm_layers
    else:
        # User provided a per-layer sequence
        if len(ranks) != num_lstm_layers:
            raise ValueError(
                f"ranks list must have length {num_lstm_layers} "
                f"(one rank per LSTM layer), got {len(ranks)}"
            )
        rank_list = list(ranks)
    for r in rank_list:
        if isinstance(r, (bool, np.bool_)) or not isinstance(r, (int, np.integer)):
            raise TypeError(f"each rank must be an int, got {r!r}")
        if r < 1:
            raise ValueError(f"each rank must be >= 1, got {r}")
    # Normalize NumPy integers to plain ints for the layers (and a clean print).
    rank_list = [int(r) for r in rank_list]
    print("Low-Rank LSTM configuration:")
    print(f"  LSTM layer ranks: {rank_list}")
    print("  Output head rank: 1")
    # Input layer
    x_in = tf.keras.Input(
        shape=(sequence_length, input_size),
        dtype=tf.float32,
        name="x_input",
    )
    x = x_in  # Shape: (B, T, F)
    # All variational layers, in creation order (read by kl_loss_fn and the caller)
    variational_layers = []
    # ===========================================================================
    # Stacked Bayesian LSTM Layers
    # ===========================================================================
    for layer_idx in range(num_lstm_layers):
        H = lstm_hidden_size
        layer_rank = rank_list[layer_idx]  # Same rank for both x_to_gates and h_to_gates
        # Forget gate bias vector: layout [i | f | c~ | o], forget = 2nd quarter
        bias_vec = np.zeros((4 * H,), dtype=np.float32)
        bias_vec[H:2*H] = float(forget_bias_init)
        # Bayesian Dense layers (LOW-RANK), both using this layer's rank
        x_to_gates = LowRankDenseVariational(
            units=4 * H,
            rank=layer_rank,
            use_bias=True,
            bias_initializer=tf.keras.initializers.Constant(bias_vec),
            prior_params=prior_params,
            name=f"layer{layer_idx}_x_to_gates",
        )
        h_to_gates = LowRankDenseVariational(
            units=4 * H,
            rank=layer_rank,
            use_bias=False,  # Bias already in x_to_gates
            prior_params=prior_params,
            name=f"layer{layer_idx}_h_to_gates",
        )
        variational_layers.extend([x_to_gates, h_to_gates])
        # Zero initial states sized to the dynamic batch dimension
        batch_size = tf.shape(x)[0]
        h = tf.zeros((batch_size, H), dtype=tf.float32)
        c = tf.zeros((batch_size, H), dtype=tf.float32)
        outputs = []
        # Explicit time loop: fresh draw at t=0, cached draw for t>0
        for t in range(sequence_length):
            x_t = x[:, t, :]  # (B, F_in)
            use_cached = (t > 0)
            z = (
                x_to_gates(x_t, training=True, use_cached=use_cached)
                + h_to_gates(h, training=True, use_cached=use_cached)
            )
            # Split into gates: [input | forget | cell | output]
            i_pre, f_pre, ctilde_pre, o_pre = tf.split(z, num_or_size_splits=4, axis=1)
            f_t = tf.nn.sigmoid(f_pre)
            i_t = tf.nn.sigmoid(i_pre)
            c_tilde = tf.nn.tanh(ctilde_pre)
            c = f_t * c + i_t * c_tilde
            o_t = tf.nn.sigmoid(o_pre)
            h = o_t * tf.nn.tanh(c)
            outputs.append(h)
        # Stack outputs: (B, T, H)
        x = tf.stack(outputs, axis=1)
    # Choose output representation
    if output_mode == "last":
        encoded = x[:, -1, :]  # (B, H)
    else:
        encoded = x  # (B, T, H)
    # Final output head (rank=1 always); applied once, so never cached
    output_head = LowRankDenseVariational(
        units=output_dim,
        rank=1,
        use_bias=True,
        prior_params=prior_params,
        name="output_head",
    )
    variational_layers.append(output_head)
    y = output_head(encoded, training=True, use_cached=False)
    model = tf.keras.Model(inputs=x_in, outputs=y, name="bayesian_lstm_lowrank")
    # Optional: initialize low-rank means from deterministic weights
    if init_from_deterministic is not None:
        # One concrete forward pass before copying; any caches it populates
        # are cleared below.
        _ = model(tf.zeros((1, sequence_length, input_size), dtype=tf.float32), training=False)
        for layer in variational_layers:
            if hasattr(layer, "init_from_full_matrix"):
                try:
                    det_layer = init_from_deterministic.get_layer(layer.name)
                except ValueError:
                    continue  # no same-named layer in the deterministic model
                det_weights = det_layer.get_weights()
                if det_weights:
                    # NOTE(review): only det_weights[0] (the kernel for a Keras
                    # Dense) is transferred; a deterministic bias, if present,
                    # is not copied. Confirm this is intended.
                    layer.init_from_full_matrix(det_weights[0])
        # Clear caches after manual init
        for layer in variational_layers:
            if hasattr(layer, "clear_cache"):
                layer.clear_cache()

    def kl_loss_fn():
        """Return the summed KL divergence of every variational layer."""
        total_kl = tf.constant(0.0, dtype=tf.float32)
        for layer in variational_layers:
            total_kl += layer.kl_divergence()
        return total_kl

    return model, kl_loss_fn, variational_layers


# ==============================================================================
# Rank-1 Bayesian LSTM
# ==============================================================================

def build_bayesian_lstm_rank1(
    input_size,
    sequence_length,
    lstm_hidden_size=128,
    num_lstm_layers=2,
    output_dim=1,
    output_mode="last",
    forget_bias_init=1.0,
    gate_order="ifco",
):
    """
    Build a Bayesian LSTM whose gate projections are Rank1DenseVariational layers.

    The stochasticity comes from rank-1 multiplicative factors rather than a
    stochastic full matrix. The factors are requested fresh at t=0
    (use_cached=False) and reused for every later timestep (use_cached=True),
    matching the full- and low-rank builders. Every layer is invoked with
    training=True. Unlike those builders, no prior_params are passed to the
    Rank1DenseVariational layers.

    Args:
        input_size: Input features per timestep (F).
        sequence_length: Number of timesteps unrolled (T); must be a concrete int.
        lstm_hidden_size: Hidden state width H for every layer.
        num_lstm_layers: Number of stacked LSTM layers (>= 1).
        output_dim: Width of the rank-1 Bayesian output head.
        output_mode: "last" feeds only h_T to the head; "sequence" feeds all h_t.
        forget_bias_init: Value written into the forget-gate slice of each
            x_to_gates bias initializer.
        gate_order: Gate packing order; only "ifco" is accepted.

    Returns:
        (model, kl_loss_fn, variational_layers): the Keras Model named
        "bayesian_lstm_rank1", a zero-argument callable summing
        kl_divergence() over all variational layers, and the list of those
        layers in creation order.

    Raises:
        ValueError: for an unsupported output_mode or gate_order, or when
            num_lstm_layers < 1 (checked in that order).
    """
    if output_mode not in {"last", "sequence"}:
        raise ValueError("output_mode must be 'last' or 'sequence'")
    if gate_order != "ifco":
        raise ValueError("gate_order must be 'ifco'")
    if num_lstm_layers < 1:
        raise ValueError("num_lstm_layers must be >= 1")

    x_in = tf.keras.Input(
        shape=(sequence_length, input_size),
        dtype=tf.float32,
        name="x_input",
    )

    variational_layers = []  # read by kl_loss_fn and returned to the caller
    seq = x_in
    width = lstm_hidden_size

    for depth in range(num_lstm_layers):
        # Bias layout [i | f | c~ | o]; only the forget quarter is nonzero.
        forget_biased = np.zeros((4 * width,), dtype=np.float32)
        forget_biased[width:2 * width] = float(forget_bias_init)

        proj_x = Rank1DenseVariational(
            units=4 * width,
            use_bias=True,
            bias_initializer=tf.keras.initializers.Constant(forget_biased),
            name=f"layer{depth}_x_to_gates",
        )
        proj_h = Rank1DenseVariational(
            units=4 * width,
            use_bias=False,  # the gate bias lives in proj_x only
            name=f"layer{depth}_h_to_gates",
        )
        variational_layers += [proj_x, proj_h]

        n_batch = tf.shape(seq)[0]
        hidden = tf.zeros((n_batch, width), dtype=tf.float32)
        cell = tf.zeros((n_batch, width), dtype=tf.float32)

        hidden_seq = []
        for step in range(sequence_length):
            reuse = step != 0  # fresh factors only at the first step
            packed = (
                proj_x(seq[:, step, :], training=True, use_cached=reuse)
                + proj_h(hidden, training=True, use_cached=reuse)
            )
            gi, gf, gg, go = tf.split(packed, num_or_size_splits=4, axis=1)
            forget_gate = tf.nn.sigmoid(gf)
            input_gate = tf.nn.sigmoid(gi)
            candidate = tf.nn.tanh(gg)
            cell = forget_gate * cell + input_gate * candidate
            output_gate = tf.nn.sigmoid(go)
            hidden = output_gate * tf.nn.tanh(cell)
            hidden_seq.append(hidden)

        seq = tf.stack(hidden_seq, axis=1)  # (B, T, H)

    encoded = seq[:, -1, :] if output_mode == "last" else seq

    output_head = Rank1DenseVariational(
        units=output_dim,
        use_bias=True,
        name="output_head",
    )
    variational_layers.append(output_head)
    y = output_head(encoded, training=True, use_cached=False)

    model = tf.keras.Model(inputs=x_in, outputs=y, name="bayesian_lstm_rank1")

    def kl_loss_fn():
        """Return the summed KL divergence of every variational layer."""
        return sum(
            (layer.kl_divergence() for layer in variational_layers),
            tf.constant(0.0, dtype=tf.float32),
        )

    return model, kl_loss_fn, variational_layers
