"""
model_builders.py - Builders for the 6 Transformer model variants (baseline and Bayesian).

This module contains builders for:
1. Baseline Transformer (deterministic)
2. Full-Rank Transformer BBB
3. Low-Rank Transformer BBB (random initialization)
4. Low-Rank Transformer BBB (SVD initialization from baseline; same builder as 3,
   enabled via its ``init_from_deterministic`` argument)
5. Rank-1 Transformer BBB
6. Deep Ensemble (``DeepEnsemble`` class wrapping multiple baseline models)
"""

import gc
import os
import json
import time
import random
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from modules.bayesian_layers import (
    DenseVariational,
    EmbeddingVariational,
    LowRankDenseVariational,
    LowRankEmbeddingVariational,
    Rank1DenseVariational,
    Rank1EmbeddingVariational,
    IndependentDropout,
    set_dropout_active,
)

tfd = tfp.distributions


# =============================================================================
# 1. BASELINE TRANSFORMER (DETERMINISTIC)
# =============================================================================

def build_tiny_transformer_baseline(
    vocab_size,
    max_length,
    d_model=256,
    n_layers=4,
    n_heads=4,
    d_ff=512,
    dropout_rate=0.0,
    num_classes=2,
):
    """
    Deterministic Tiny Transformer encoder (pre-norm) used as a baseline.

    Same architecture as the full-rank / low-rank BBB models:
        - Token + positional embeddings
        - Stack of ``n_layers`` pre-norm Transformer blocks
          (self-attention + FFN with residual connections)
        - CLS pooling (position 0) + classifier head

    Differences versus BBB versions:
        - All linear maps are standard tf.keras.layers.Dense (no weight sampling)
        - No KL terms, no variational parameters

    Dense layer names (``layer{i}_Q/K/V/O``, ``layer{i}_ff1/ff2``, ``cls_head``)
    match the names looked up by the SVD initialization in
    ``build_tiny_transformer_lowrank_bbb``.

    Parameters
    ----------
    vocab_size : int
        Size of vocabulary
    max_length : int
        Maximum sequence length
    d_model : int
        Model dimension (default: 256)
    n_layers : int
        Number of transformer layers (default: 4)
    n_heads : int
        Number of attention heads (default: 4); must divide ``d_model``
    d_ff : int
        Feed-forward dimension (default: 512)
    dropout_rate : float
        Dropout rate (default: 0.0)
    num_classes : int
        Number of output classes (default: 2)

    Returns
    -------
    tf.keras.Model
        The baseline transformer model. Inputs are a dict with keys
        ``"input_ids"`` and ``"attention_mask"``; output is the float32
        softmax class probabilities of shape (B, num_classes).

    Raises
    ------
    ValueError
        If ``d_model`` is not divisible by ``n_heads``.
    """
    # Multi-head attention splits d_model evenly across heads; fail early with
    # a clear message instead of an opaque reshape error further down.
    if d_model % n_heads != 0:
        raise ValueError(
            f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        )

    # -------------------------------------------------------------------------
    # 1) Inputs
    # -------------------------------------------------------------------------
    input_ids = tf.keras.Input(
        shape=(max_length,), dtype="int32", name="input_ids"
    )
    attention_mask = tf.keras.Input(
        shape=(max_length,), dtype="int32", name="attention_mask"
    )

    # -------------------------------------------------------------------------
    # 2) Deterministic embeddings
    # -------------------------------------------------------------------------
    # Token embedding: map token ids -> d_model vectors
    tok_emb_layer = tf.keras.layers.Embedding(
        input_dim=vocab_size,
        output_dim=d_model,
        name="tok_emb",
    )
    token_emb = tok_emb_layer(input_ids)  # (B, L, d_model)

    # Positional embedding: one embedding per position 0..L-1.
    # The position ids are derived from the symbolic ``input_ids`` so that the
    # layer is wired into the functional graph. Calling it on a constant
    # ``tf.range(max_length)`` would execute it eagerly at build time: its
    # output would be baked into the graph as a constant and its weights would
    # be neither tracked by the model nor trained.
    pos_indices = tf.range(tf.shape(input_ids)[1])  # (L,)
    pos_emb_layer = tf.keras.layers.Embedding(
        input_dim=max_length,
        output_dim=d_model,
        name="pos_emb",
    )
    # (L, d_model) -> add batch dimension -> (1, L, d_model)
    pos_emb = pos_emb_layer(pos_indices)[tf.newaxis, :, :]

    # Initial hidden representation H^(0) = token + positional
    x = token_emb + pos_emb  # (B, L, d_model)

    if dropout_rate > 0.0:
        x = IndependentDropout(dropout_rate, name="emb_dropout")(x)

    # -------------------------------------------------------------------------
    # 3) Attention mask
    # -------------------------------------------------------------------------
    # attention_mask: 1 for real tokens, 0 for padding
    mask_raw = tf.cast(attention_mask, tf.float32)           # (B, L)
    mask_broadcast = mask_raw[:, tf.newaxis, tf.newaxis, :]  # (B, 1, 1, L)

    # Per-head dimension (divisibility validated above)
    d_k = d_model // n_heads

    # -------------------------------------------------------------------------
    # 4) Transformer encoder stack (pre-norm)
    # -------------------------------------------------------------------------
    for layer_idx in range(n_layers):
        # ===================== Attention sub-layer (pre-norm) =====================
        # LayerNorm BEFORE attention (pre-norm)
        ln_attn = tf.keras.layers.LayerNormalization(
            epsilon=1e-6,
            name=f"layer{layer_idx}_ln_attn",
        )
        attn_input = ln_attn(x)  # (B, L, d_model)

        # Standard Dense projections for Q, K, V: each output d_model
        q_proj = tf.keras.layers.Dense(
            units=d_model,
            activation=None,
            name=f"layer{layer_idx}_Q",
        )
        k_proj = tf.keras.layers.Dense(
            units=d_model,
            activation=None,
            name=f"layer{layer_idx}_K",
        )
        v_proj = tf.keras.layers.Dense(
            units=d_model,
            activation=None,
            name=f"layer{layer_idx}_V",
        )

        Q = q_proj(attn_input)  # (B, L, d_model)
        K = k_proj(attn_input)  # (B, L, d_model)
        V = v_proj(attn_input)  # (B, L, d_model)

        def split_heads(t, n_heads=n_heads, d_k=d_k):
            """
            Reshape from (B, L, d_model) -> (B, n_heads, L, d_k)
            so that attention operates per-head.
            """
            batch_size = tf.shape(t)[0]
            seq_len = tf.shape(t)[1]
            # (B, L, n_heads, d_k)
            t = tf.reshape(t, [batch_size, seq_len, n_heads, d_k])
            # (B, n_heads, L, d_k)
            return tf.transpose(t, perm=[0, 2, 1, 3])

        # Project into multiple heads
        Qh = split_heads(Q)  # (B, n_heads, L, d_k)
        Kh = split_heads(K)  # (B, n_heads, L, d_k)
        Vh = split_heads(V)  # (B, n_heads, L, d_k)

        # Scaled dot-product attention scores: S = Q K^T / sqrt(d_k)
        scores = tf.matmul(Qh, Kh, transpose_b=True)  # (B, n_heads, L, L)
        # Cast scaling factor to match scores dtype (for mixed precision)
        scale = tf.cast(tf.math.sqrt(tf.cast(d_k, tf.float32)), scores.dtype)
        scores = scores / scale

        # Apply mask: set scores for padding positions to a large negative value
        # Cast mask values to match scores dtype (for mixed precision)
        mask_value = tf.cast((1.0 - mask_broadcast) * (-1e9), scores.dtype)
        scores += mask_value

        # Normalized attention weights along key dimension
        attn_weights = tf.nn.softmax(scores, axis=-1)  # (B, n_heads, L, L)

        # Context vectors = sum_j a_ij * V_j
        context = tf.matmul(attn_weights, Vh)  # (B, n_heads, L, d_k)

        # ----------------- Explicit concatenation of heads -----------------
        # 1. Move sequence dimension back: (B, L, n_heads, d_k)
        context = tf.transpose(context, perm=[0, 2, 1, 3])
        # 2. Merge heads and d_k into d_model: (B, L, d_model)
        batch_size = tf.shape(context)[0]
        seq_len = tf.shape(context)[1]
        context = tf.reshape(context, [batch_size, seq_len, n_heads * d_k])

        # Output projection W_O: (B, L, d_model) -> (B, L, d_model)
        o_proj = tf.keras.layers.Dense(
            units=d_model,
            activation=None,
            name=f"layer{layer_idx}_O",
        )
        attn_output = o_proj(context)  # (B, L, d_model)

        if dropout_rate > 0.0:
            attn_output = IndependentDropout(
                dropout_rate,
                name=f"layer{layer_idx}_attn_dropout",
            )(attn_output)

        # Residual connection after attention
        x = x + attn_output  # (B, L, d_model)

        # ========================= FFN sub-layer (pre-norm) ========================
        ln_ff = tf.keras.layers.LayerNormalization(
            epsilon=1e-6,
            name=f"layer{layer_idx}_ln_ff",
        )
        ffn_input = ln_ff(x)  # (B, L, d_model)

        # First FFN layer: expand to d_ff, with GELU activation
        ff1 = tf.keras.layers.Dense(
            units=d_ff,
            activation=tf.keras.activations.gelu,
            name=f"layer{layer_idx}_ff1",
        )
        ff_hidden = ff1(ffn_input)  # (B, L, d_ff)

        # Second FFN layer: project back to d_model
        ff2 = tf.keras.layers.Dense(
            units=d_model,
            activation=None,
            name=f"layer{layer_idx}_ff2",
        )
        ff_output = ff2(ff_hidden)  # (B, L, d_model)

        if dropout_rate > 0.0:
            ff_output = IndependentDropout(
                dropout_rate,
                name=f"layer{layer_idx}_ff_dropout",
            )(ff_output)

        # Residual connection after FFN
        x = x + ff_output  # (B, L, d_model)

    # -------------------------------------------------------------------------
    # 5) CLS pooling + deterministic classifier
    # -------------------------------------------------------------------------
    # Take representation of the [CLS] token at position 0
    cls_token = x[:, 0, :]  # (B, d_model)

    # Final classifier head to num_classes logits
    # Keep in float32 for numerical stability with mixed precision
    logits = tf.keras.layers.Dense(
        units=num_classes,
        activation=None,
        dtype='float32',
        name="cls_head",
    )(cls_token)  # (B, num_classes)

    probs = tf.nn.softmax(logits, axis=-1, name="probs")

    # -------------------------------------------------------------------------
    # 6) Build model
    # -------------------------------------------------------------------------
    model = tf.keras.Model(
        inputs={"input_ids": input_ids, "attention_mask": attention_mask},
        outputs=probs,
        name="tiny_transformer_baseline",
    )
    return model


# =============================================================================
# 2. FULL-RANK TRANSFORMER BBB
# =============================================================================

def build_tiny_transformer_fullrank_bbb(
    vocab_size,
    max_length,
    d_model=256,
    n_layers=4,
    n_heads=4,
    d_ff=512,
    dropout_rate=0.0,
    num_classes=2,
    kl_scale=1.0,
    pretrained_embedding_weights=None,
):
    """
    Build a tiny *pre-norm* Transformer encoder with full-rank Bayes-by-Backprop.

    Every embedding and dense map (Q/K/V/O, FFN, classifier head) is a
    variational layer from ``modules.bayesian_layers``; LayerNorms are
    deterministic. Pooling uses the token at position 0 (CLS).

    Parameters
    ----------
    vocab_size : int
        Size of vocabulary
    max_length : int
        Maximum sequence length
    d_model : int
        Model dimension (default: 256)
    n_layers : int
        Number of transformer layers (default: 4)
    n_heads : int
        Number of attention heads (default: 4); must divide ``d_model``
    d_ff : int
        Feed-forward dimension (default: 512)
    dropout_rate : float
        Dropout rate (default: 0.0)
    num_classes : int
        Number of output classes (default: 2)
    kl_scale : float
        KL divergence scaling factor passed to every variational layer
        (default: 1.0)
    pretrained_embedding_weights : list, optional
        Currently unused by this builder; accepted for signature
        compatibility with the other builders.

    Returns
    -------
    tf.keras.Model
        The full-rank BBB transformer model. Inputs are a dict with keys
        ``"input_ids"`` and ``"attention_mask"``; output is the float32
        softmax class probabilities of shape (B, num_classes).

    Raises
    ------
    ValueError
        If ``d_model`` is not divisible by ``n_heads``.
    """
    # Multi-head attention splits d_model evenly across heads; fail early with
    # a clear message instead of an opaque reshape error further down.
    if d_model % n_heads != 0:
        raise ValueError(
            f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        )

    # -------------------------------------------------------------------------
    # 1) Inputs
    # -------------------------------------------------------------------------
    input_ids = tf.keras.Input(
        shape=(max_length,), dtype="int32", name="input_ids"
    )
    attention_mask = tf.keras.Input(
        shape=(max_length,), dtype="int32", name="attention_mask"
    )

    # -------------------------------------------------------------------------
    # 2) Variational embeddings (token + position)
    # -------------------------------------------------------------------------
    # Token embedding (variational)
    token_emb_layer = EmbeddingVariational(
        input_dim=vocab_size,
        output_dim=d_model,
        kl_scale=kl_scale,
        name="tok_emb_variational",
    )
    token_emb = token_emb_layer(input_ids)  # (B, L, d_model)

    # Positional embedding (variational)
    pos_emb_layer = EmbeddingVariational(
        input_dim=max_length,
        output_dim=d_model,
        kl_scale=kl_scale,
        name="pos_emb_variational"
    )
    # Position ids are derived from the symbolic ``input_ids`` so the layer is
    # part of the functional graph. Calling it on a constant
    # ``tf.range(max_length)`` would execute it eagerly at build time: one
    # weight sample would be frozen into the graph as a constant, the layer's
    # variables would not be tracked or trained, and any KL loss it adds would
    # presumably not reach ``model.losses``.
    pos_indices = tf.range(tf.shape(input_ids)[1])  # (L,)
    pos_emb = pos_emb_layer(pos_indices)[tf.newaxis, :, :]  # (1, L, d_model)

    # Initial hidden representation H^(0)
    x = token_emb + pos_emb

    if dropout_rate > 0.0:
        x = IndependentDropout(dropout_rate, name="emb_dropout")(x)

    # -------------------------------------------------------------------------
    # 3) Attention mask preparation
    # -------------------------------------------------------------------------
    # attention_mask: 1 for real tokens, 0 for padding
    mask_raw = tf.cast(attention_mask, tf.float32)
    mask_broadcast = mask_raw[:, tf.newaxis, tf.newaxis, :]  # (B, 1, 1, L)
    d_k = d_model // n_heads

    # -------------------------------------------------------------------------
    # 4) Transformer encoder stack (pre-norm)
    # -------------------------------------------------------------------------
    for layer_idx in range(n_layers):
        # LayerNorm (Pre-Norm)
        ln_attn = tf.keras.layers.LayerNormalization(
            epsilon=1e-6,
            name=f"layer{layer_idx}_ln_attn"
        )
        attn_input = ln_attn(x)

        # Variational Projections
        q_proj = DenseVariational(d_model, kl_scale=kl_scale, name=f"layer{layer_idx}_Q")
        k_proj = DenseVariational(d_model, kl_scale=kl_scale, name=f"layer{layer_idx}_K")
        v_proj = DenseVariational(d_model, kl_scale=kl_scale, name=f"layer{layer_idx}_V")

        Q = q_proj(attn_input)
        K = k_proj(attn_input)
        V = v_proj(attn_input)

        # Split Heads: (B, L, d_model) -> (B, n_heads, L, d_k)
        def split_heads(t, n_heads=n_heads, d_k=d_k):
            batch_size = tf.shape(t)[0]
            seq_len = tf.shape(t)[1]
            t = tf.reshape(t, [batch_size, seq_len, n_heads, d_k])
            return tf.transpose(t, perm=[0, 2, 1, 3])

        Qh = split_heads(Q)
        Kh = split_heads(K)
        Vh = split_heads(V)

        # Attention Scores
        scores = tf.matmul(Qh, Kh, transpose_b=True)
        # Cast scaling factor to match scores dtype (for mixed precision)
        scale = tf.cast(tf.math.sqrt(tf.cast(d_k, tf.float32)), scores.dtype)
        scores = scores / scale
        # Cast mask values to match scores dtype (for mixed precision)
        mask_value = tf.cast((1.0 - mask_broadcast) * (-1e9), scores.dtype)
        scores += mask_value

        # Attention Weights & Context
        attn_weights = tf.nn.softmax(scores, axis=-1)
        context = tf.matmul(attn_weights, Vh)

        # Merge Heads: (B, n_heads, L, d_k) -> (B, L, d_model)
        context = tf.transpose(context, perm=[0, 2, 1, 3])
        batch_size = tf.shape(context)[0]
        seq_len = tf.shape(context)[1]
        context = tf.reshape(context, [batch_size, seq_len, n_heads * d_k])

        # Output Projection (Variational)
        o_proj = DenseVariational(d_model, kl_scale=kl_scale, name=f"layer{layer_idx}_O")
        attn_output = o_proj(context)

        if dropout_rate > 0.0:
            attn_output = IndependentDropout(
                dropout_rate,
                name=f"layer{layer_idx}_attn_dropout",
            )(attn_output)

        # Residual 1
        x = x + attn_output

        # FFN (Variational)
        ln_ff = tf.keras.layers.LayerNormalization(
            epsilon=1e-6, name=f"layer{layer_idx}_ln_ff"
        )
        ffn_input = ln_ff(x)

        ff1 = DenseVariational(
            d_ff, kl_scale=kl_scale, activation="gelu", name=f"layer{layer_idx}_ff1"
        )
        ff_hidden = ff1(ffn_input)

        ff2 = DenseVariational(
            d_model, kl_scale=kl_scale, activation=None, name=f"layer{layer_idx}_ff2"
        )
        ff_output = ff2(ff_hidden)

        if dropout_rate > 0.0:
            ff_output = IndependentDropout(
                dropout_rate,
                name=f"layer{layer_idx}_ff_dropout",
            )(ff_output)

        # Residual 2
        x = x + ff_output

    # -------------------------------------------------------------------------
    # 5) Classifier Head
    # -------------------------------------------------------------------------
    cls_token = x[:, 0, :]
    cls_head = DenseVariational(
        num_classes, kl_scale=kl_scale, activation=None, name="cls_head"
    )
    logits = cls_head(cls_token)
    # Cast to float32 for numerical stability in mixed precision training
    logits = tf.cast(logits, tf.float32)
    probs = tf.nn.softmax(logits, axis=-1, name="probs")

    # -------------------------------------------------------------------------
    # 6) Build Model
    # -------------------------------------------------------------------------
    model = tf.keras.Model(
        inputs={"input_ids": input_ids, "attention_mask": attention_mask},
        outputs=probs,
        name="tiny_transformer_fullrank_bbb",
    )
    return model


# =============================================================================
# 3 & 4. LOW-RANK TRANSFORMER BBB (Random and SVD Initialization)
# =============================================================================

def build_tiny_transformer_lowrank_bbb(
    vocab_size,
    max_length,
    d_model=256,
    n_layers=4,
    n_heads=4,
    d_ff=512,
    rank=16,
    rank_emb=10,
    dropout_rate=0.2,
    num_classes=2,
    kl_scale=1.0,
    pretrained_embedding_weights=None,
    init_from_deterministic=None,
):
    """
    Tiny *pre-norm* Transformer encoder with Low-Rank Bayes-by-Backprop.

    Serves both the random-init variant and the SVD-init variant: when
    ``init_from_deterministic`` is given, every ``LowRankDenseVariational``
    layer whose name exists in that model is initialized from the
    deterministic layer's kernel/bias via ``init_from_full_matrix``.
    Embedding layers are not SVD-initialized.

    Parameters
    ----------
    vocab_size : int
        Size of vocabulary
    max_length : int
        Maximum sequence length
    d_model : int
        Model dimension (default: 256)
    n_layers : int
        Number of transformer layers (default: 4)
    n_heads : int
        Number of attention heads (default: 4); must divide ``d_model``
    d_ff : int
        Feed-forward dimension (default: 512)
    rank : int
        Rank for the low-rank factorization of the token embedding and all
        dense layers (default: 16)
    rank_emb : int
        Rank for the low-rank factorization of the positional embedding
        (default: 10)
    dropout_rate : float
        Dropout rate (default: 0.2)
    num_classes : int
        Number of output classes (default: 2)
    kl_scale : float
        KL divergence scaling factor passed to every variational layer
        (default: 1.0)
    pretrained_embedding_weights : list, optional
        Currently unused by this builder; accepted for signature
        compatibility with the other builders.
    init_from_deterministic : tf.keras.Model, optional
        Trained deterministic model to initialize low-rank layers via SVD.
        Layers are matched by name (``layer{i}_Q`` ... ``cls_head``).

    Returns
    -------
    tf.keras.Model
        The low-rank BBB transformer model. Inputs are a dict with keys
        ``"input_ids"`` and ``"attention_mask"``; output is the float32
        softmax class probabilities of shape (B, num_classes).

    Raises
    ------
    ValueError
        If ``d_model`` is not divisible by ``n_heads``.
    """
    # Multi-head attention splits d_model evenly across heads; fail early with
    # a clear message instead of an opaque reshape error further down.
    if d_model % n_heads != 0:
        raise ValueError(
            f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        )

    # -------------------------------------------------------------------------
    # 1) Inputs
    # -------------------------------------------------------------------------
    input_ids = tf.keras.Input(
        shape=(max_length,), dtype="int32", name="input_ids"
    )
    attention_mask = tf.keras.Input(
        shape=(max_length,), dtype="int32", name="attention_mask"
    )

    # -------------------------------------------------------------------------
    # 2) Variational embeddings
    # -------------------------------------------------------------------------
    # Token embeddings
    tok_emb_layer = LowRankEmbeddingVariational(
        input_dim=vocab_size,
        output_dim=d_model,
        rank=rank,
        kl_scale=kl_scale,
        name="tok_emb_lowrank_variational",
    )
    token_emb = tok_emb_layer(input_ids)  # (B, max_length, d_model)

    # Positional embedding (low-rank variational)
    pos_emb_layer = LowRankEmbeddingVariational(
        input_dim=max_length,
        output_dim=d_model,
        rank=rank_emb,
        kl_scale=kl_scale,
        name="pos_emb_lowrank_variational",
    )
    # Position ids are derived from the symbolic ``input_ids`` so the layer is
    # part of the functional graph. Calling it on a constant
    # ``tf.range(max_length)`` would execute it eagerly at build time: one
    # weight sample would be frozen into the graph as a constant, the layer's
    # variables would not be tracked or trained, and any KL loss it adds would
    # presumably not reach ``model.losses``.
    pos_indices = tf.range(tf.shape(input_ids)[1])  # (L,)
    pos_emb = pos_emb_layer(pos_indices)[tf.newaxis, :, :]  # (1, L, d_model)

    # Initial hidden representation
    x = token_emb + pos_emb

    if dropout_rate > 0.0:
        x = IndependentDropout(dropout_rate, name="emb_dropout")(x)

    # -------------------------------------------------------------------------
    # 3) Attention mask
    # -------------------------------------------------------------------------
    # attention_mask: 1 for real tokens, 0 for padding
    mask_raw = tf.cast(attention_mask, tf.float32)
    mask_broadcast = mask_raw[:, tf.newaxis, tf.newaxis, :]  # (B, 1, 1, L)
    d_k = d_model // n_heads

    # -------------------------------------------------------------------------
    # 4) Transformer encoder stack (pre-norm)
    # -------------------------------------------------------------------------
    for layer_idx in range(n_layers):
        # LayerNorm
        ln_attn = tf.keras.layers.LayerNormalization(
            epsilon=1e-6,
            name=f"layer{layer_idx}_ln_attn"
        )
        attn_input = ln_attn(x)

        # Low-rank BBB projections
        q_proj = LowRankDenseVariational(
            d_model, rank=rank, kl_scale=kl_scale, activation=None,
            name=f"layer{layer_idx}_Q"
        )
        k_proj = LowRankDenseVariational(
            d_model, rank=rank, kl_scale=kl_scale, activation=None,
            name=f"layer{layer_idx}_K"
        )
        v_proj = LowRankDenseVariational(
            d_model, rank=rank, kl_scale=kl_scale, activation=None,
            name=f"layer{layer_idx}_V"
        )

        Q = q_proj(attn_input)
        K = k_proj(attn_input)
        V = v_proj(attn_input)

        # Split Heads: (B, L, d_model) -> (B, n_heads, L, d_k)
        def split_heads(t, n_heads=n_heads, d_k=d_k):
            batch_size = tf.shape(t)[0]
            seq_len = tf.shape(t)[1]
            t = tf.reshape(t, [batch_size, seq_len, n_heads, d_k])
            return tf.transpose(t, perm=[0, 2, 1, 3])

        Qh = split_heads(Q)
        Kh = split_heads(K)
        Vh = split_heads(V)

        # Attention Scores
        scores = tf.matmul(Qh, Kh, transpose_b=True)
        # Cast scaling factor to match scores dtype (for mixed precision)
        scale = tf.cast(tf.math.sqrt(tf.cast(d_k, tf.float32)), scores.dtype)
        scores = scores / scale
        # Cast mask values to match scores dtype (for mixed precision)
        mask_value = tf.cast((1.0 - mask_broadcast) * (-1e9), scores.dtype)
        scores += mask_value

        # Attention Weights & Context
        attn_weights = tf.nn.softmax(scores, axis=-1)
        context = tf.matmul(attn_weights, Vh)

        # Merge Heads: (B, n_heads, L, d_k) -> (B, L, d_model)
        context = tf.transpose(context, perm=[0, 2, 1, 3])
        batch_size = tf.shape(context)[0]
        seq_len = tf.shape(context)[1]
        context = tf.reshape(context, [batch_size, seq_len, n_heads * d_k])

        # Low-rank output projection
        o_proj = LowRankDenseVariational(
            d_model, rank=rank, kl_scale=kl_scale, activation=None,
            name=f"layer{layer_idx}_O"
        )
        attn_output = o_proj(context)

        if dropout_rate > 0.0:
            attn_output = IndependentDropout(
                dropout_rate,
                name=f"layer{layer_idx}_attn_dropout",
            )(attn_output)

        # Residual 1
        x = x + attn_output

        # FFN
        ln_ff = tf.keras.layers.LayerNormalization(
            epsilon=1e-6, name=f"layer{layer_idx}_ln_ff"
        )
        ffn_input = ln_ff(x)

        ff1 = LowRankDenseVariational(
            d_ff, rank=rank, kl_scale=kl_scale, activation="gelu",
            name=f"layer{layer_idx}_ff1"
        )
        ff_hidden = ff1(ffn_input)

        ff2 = LowRankDenseVariational(
            d_model, rank=rank, kl_scale=kl_scale, activation=None,
            name=f"layer{layer_idx}_ff2"
        )
        ff_output = ff2(ff_hidden)

        if dropout_rate > 0.0:
            ff_output = IndependentDropout(
                dropout_rate,
                name=f"layer{layer_idx}_ff_dropout",
            )(ff_output)

        # Residual 2
        x = x + ff_output

    # -------------------------------------------------------------------------
    # 5) CLS pooling + low-rank BBB classifier
    # -------------------------------------------------------------------------
    cls_token = x[:, 0, :]
    cls_head = LowRankDenseVariational(
        num_classes, rank=rank, kl_scale=kl_scale, activation=None,
        name="cls_head"
    )
    logits = cls_head(cls_token)
    # Cast to float32 for numerical stability in mixed precision training
    logits = tf.cast(logits, tf.float32)
    probs = tf.nn.softmax(logits, axis=-1, name="probs")

    # -------------------------------------------------------------------------
    # 6) Build model
    # -------------------------------------------------------------------------
    model = tf.keras.Model(
        inputs={"input_ids": input_ids, "attention_mask": attention_mask},
        outputs=probs,
        name=f"tiny_transformer_lowrank_bbb_r{rank}",
    )

    # -------------------------------------------------------------------------
    # 7) Optional SVD initialization from deterministic model
    # -------------------------------------------------------------------------
    if init_from_deterministic is not None:
        # Run one forward pass on a dummy batch so every layer's weights exist
        dummy_input = {
            "input_ids": tf.zeros((1, max_length), dtype=tf.int32),
            "attention_mask": tf.ones((1, max_length), dtype=tf.int32)
        }
        _ = model(dummy_input, training=False)

        # Initialize low-rank layers from deterministic model via SVD
        initialized_count = 0
        for layer in model.layers:
            if isinstance(layer, LowRankDenseVariational) and hasattr(layer, "init_from_full_matrix"):
                try:
                    det_layer = init_from_deterministic.get_layer(layer.name)
                except ValueError:
                    print(f"  [SVD Init] Layer '{layer.name}' not found in deterministic model, skipping...")
                    continue

                det_weights = det_layer.get_weights()
                if det_weights:
                    # det_weights[0] is kernel, det_weights[1] is bias (if present)
                    W_full = det_weights[0]
                    b_full = det_weights[1] if len(det_weights) > 1 else None
                    layer.init_from_full_matrix(W_full, b_full)
                    initialized_count += 1

        print(f"  [SVD Init] Initialized {initialized_count} LowRankDenseVariational layers from deterministic model")

    return model


# =============================================================================
# 5. RANK-1 TRANSFORMER BBB
# =============================================================================

def build_tiny_transformer_rank1_bbb(
    vocab_size,
    max_length,
    d_model=256,
    n_layers=4,
    n_heads=4,
    d_ff=512,
    dropout_rate=0.2,
    num_classes=2,
    kl_scale=1.0,
    pretrained_embedding_weights=None,
):
    """
    Tiny *pre-norm* Transformer encoder with Rank-1 Bayes-by-Backprop.

    Uses Rank1DenseVariational for all dense layers and Rank1EmbeddingVariational
    for embeddings. Rank-1 factorization (Dusenberry et al., 2020) is the most
    parameter-efficient Bayesian approach.

    Architectural differences versus the other builders in this module:
        - Dropout is also applied to the attention weights and the FFN hidden
          activations
        - A final LayerNorm is applied after the encoder stack
        - Pooling is a masked mean over real (non-padding) tokens instead of
          taking the CLS token at position 0
        - Model inputs are the list ``[input_ids, attention_mask]`` rather
          than a dict

    Parameters
    ----------
    vocab_size : int
        Size of vocabulary
    max_length : int
        Maximum sequence length
    d_model : int
        Model dimension (default: 256)
    n_layers : int
        Number of transformer layers (default: 4)
    n_heads : int
        Number of attention heads (default: 4); must divide ``d_model``
    d_ff : int
        Feed-forward dimension (default: 512)
    dropout_rate : float
        Dropout rate (default: 0.2)
    num_classes : int
        Number of output classes (default: 2)
    kl_scale : float
        KL divergence scaling factor passed to every variational layer
        (default: 1.0)
    pretrained_embedding_weights : list, optional
        Currently unused by this builder; accepted for signature
        compatibility with the other builders.

    Returns
    -------
    tf.keras.Model
        The Rank-1 BBB transformer model; output is the float32 softmax class
        probabilities of shape (B, num_classes).

    Raises
    ------
    ValueError
        If ``d_model`` is not divisible by ``n_heads``.
    """
    # Multi-head attention splits d_model evenly across heads; fail early with
    # a clear message instead of an opaque reshape error further down.
    if d_model % n_heads != 0:
        raise ValueError(
            f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        )

    # -------------------------------------------------------------------------
    # 1) Inputs
    # -------------------------------------------------------------------------
    input_ids = tf.keras.Input(
        shape=(max_length,), dtype="int32", name="input_ids"
    )
    attention_mask = tf.keras.Input(
        shape=(max_length,), dtype="int32", name="attention_mask"
    )

    # -------------------------------------------------------------------------
    # 2) Rank-1 variational embeddings
    # -------------------------------------------------------------------------
    # Token embeddings
    tok_emb_layer = Rank1EmbeddingVariational(
        input_dim=vocab_size,
        output_dim=d_model,
        kl_scale=kl_scale,
        name="tok_emb_rank1_variational",
    )
    token_emb = tok_emb_layer(input_ids)  # (B, max_length, d_model)

    # Positional embedding (rank-1 variational)
    pos_emb_layer = Rank1EmbeddingVariational(
        input_dim=max_length,
        output_dim=d_model,
        kl_scale=kl_scale,
        name="pos_emb_rank1_variational",
    )
    # Position ids are derived from the symbolic ``input_ids`` so the layer is
    # part of the functional graph. Calling it on a constant
    # ``tf.range(max_length)`` would execute it eagerly at build time: one
    # weight sample would be frozen into the graph as a constant, the layer's
    # variables would not be tracked or trained, and any KL loss it adds would
    # presumably not reach ``model.losses``.
    pos_indices = tf.range(tf.shape(input_ids)[1])  # (L,)
    pos_emb = pos_emb_layer(pos_indices)[tf.newaxis, :, :]  # (1, L, d_model)

    # Initial hidden representation
    x = token_emb + pos_emb

    if dropout_rate > 0.0:
        x = IndependentDropout(dropout_rate, name="emb_dropout")(x)

    # -------------------------------------------------------------------------
    # 3) Attention mask
    # -------------------------------------------------------------------------
    # attention_mask: 1 for real tokens, 0 for padding
    mask_raw = tf.cast(attention_mask, tf.float32)           # (B, L)
    mask_broadcast = mask_raw[:, tf.newaxis, tf.newaxis, :]  # (B, 1, 1, L)
    d_k = d_model // n_heads

    # -------------------------------------------------------------------------
    # 4) Transformer encoder stack (pre-norm)
    # -------------------------------------------------------------------------
    for layer_idx in range(n_layers):
        # LayerNorm
        ln_attn = tf.keras.layers.LayerNormalization(
            epsilon=1e-6,
            name=f"layer{layer_idx}_ln_attn"
        )
        attn_input = ln_attn(x)

        # Rank-1 BBB projections for Q, K, V
        q_proj = Rank1DenseVariational(
            d_model, kl_scale=kl_scale, activation=None,
            name=f"layer{layer_idx}_Q"
        )
        k_proj = Rank1DenseVariational(
            d_model, kl_scale=kl_scale, activation=None,
            name=f"layer{layer_idx}_K"
        )
        v_proj = Rank1DenseVariational(
            d_model, kl_scale=kl_scale, activation=None,
            name=f"layer{layer_idx}_V"
        )
        out_proj = Rank1DenseVariational(
            d_model, kl_scale=kl_scale, activation=None,
            name=f"layer{layer_idx}_O"
        )

        Q = q_proj(attn_input)  # (B, L, d_model)
        K = k_proj(attn_input)
        V = v_proj(attn_input)

        # Reshape for multi-head: (B, L, d_model) -> (B, n_heads, L, d_k)
        def reshape_for_heads(t, max_length=max_length, n_heads=n_heads, d_k=d_k):
            bs = tf.shape(t)[0]
            t = tf.reshape(t, [bs, max_length, n_heads, d_k])
            return tf.transpose(t, [0, 2, 1, 3])

        Q = reshape_for_heads(Q)
        K = reshape_for_heads(K)
        V = reshape_for_heads(V)

        # Scaled dot-product attention
        scale = tf.math.sqrt(tf.cast(d_k, Q.dtype))
        scores = tf.matmul(Q, K, transpose_b=True) / scale
        # Cast mask values to match scores dtype (for mixed precision)
        mask_value = tf.cast((1.0 - mask_broadcast) * (-1e9), scores.dtype)
        scores += mask_value
        attn_weights = tf.nn.softmax(scores, axis=-1)

        if dropout_rate > 0.0:
            attn_weights = IndependentDropout(
                dropout_rate, name=f"layer{layer_idx}_attn_dropout"
            )(attn_weights)

        # Merge heads: (B, n_heads, L, d_k) -> (B, L, d_model)
        context = tf.matmul(attn_weights, V)
        context = tf.transpose(context, [0, 2, 1, 3])
        context = tf.reshape(context, [tf.shape(context)[0], max_length, d_model])

        attn_out = out_proj(context)
        if dropout_rate > 0.0:
            attn_out = IndependentDropout(
                dropout_rate, name=f"layer{layer_idx}_attn_out_dropout"
            )(attn_out)

        x = x + attn_out

        # FFN block
        ln_ffn = tf.keras.layers.LayerNormalization(
            epsilon=1e-6,
            name=f"layer{layer_idx}_ln_ffn"
        )
        ffn_input = ln_ffn(x)

        ff1 = Rank1DenseVariational(
            d_ff, kl_scale=kl_scale, activation="gelu",
            name=f"layer{layer_idx}_ff1"
        )
        ff2 = Rank1DenseVariational(
            d_model, kl_scale=kl_scale, activation=None,
            name=f"layer{layer_idx}_ff2"
        )

        ffn_out = ff1(ffn_input)
        if dropout_rate > 0.0:
            ffn_out = IndependentDropout(
                dropout_rate, name=f"layer{layer_idx}_ffn_dropout"
            )(ffn_out)
        ffn_out = ff2(ffn_out)
        if dropout_rate > 0.0:
            ffn_out = IndependentDropout(
                dropout_rate, name=f"layer{layer_idx}_ffn_out_dropout"
            )(ffn_out)

        x = x + ffn_out

    # -------------------------------------------------------------------------
    # 5) Final layer norm + pooling + classification
    # -------------------------------------------------------------------------
    final_ln = tf.keras.layers.LayerNormalization(epsilon=1e-6, name="final_ln")
    x = final_ln(x)

    # Masked global average pooling over real tokens only. A plain mean over
    # axis 1 would also average the padding positions, making the pooled vector
    # depend on how much padding a sequence carries. Pooling is done in float32
    # (safe under mixed precision) and cast back to the activation dtype; the
    # denominator is clamped to 1 so an all-padding row does not divide by 0.
    pool_mask = mask_raw[:, :, tf.newaxis]  # (B, L, 1), float32
    token_sum = tf.reduce_sum(tf.cast(x, tf.float32) * pool_mask, axis=1)  # (B, d_model)
    token_count = tf.maximum(tf.reduce_sum(pool_mask, axis=1), 1.0)  # (B, 1)
    x = tf.cast(token_sum / token_count, x.dtype)  # (B, d_model)

    # Classification head with Rank-1 BBB. It emits raw logits; the softmax is
    # applied in float32 below, matching the other builders, instead of inside
    # a layer that may compute in float16 under mixed precision.
    classifier = Rank1DenseVariational(
        num_classes, kl_scale=kl_scale, activation=None,
        name="classifier"
    )
    logits = classifier(x)
    # Cast to float32 for numerical stability in mixed precision training
    logits = tf.cast(logits, tf.float32)
    probs = tf.nn.softmax(logits, axis=-1, name="probs")

    model = tf.keras.Model(
        inputs=[input_ids, attention_mask],
        outputs=probs,
        name="TinyTransformer_Rank1BBB"
    )
    return model


# =============================================================================
# 6. DEEP ENSEMBLE
# =============================================================================

class DeepEnsemble:
    """
    Deep Ensemble wrapper for uncertainty quantification.

    This class manages multiple independently trained neural network models
    and provides ensemble predictions with uncertainty estimates.

    The prediction helpers (``predict``, ``__call__``, ``evaluate``) assume a
    binary classifier: every member must return an (N, 2) matrix of class
    probabilities, and column 1 is read as P(class = 1).

    Parameters
    ----------
    builder_fn : callable
        Function that returns a new Keras model when called
    n_members : int
        Number of ensemble members to train (default: 5)
    """

    def __init__(self, builder_fn, n_members=5):
        self.n_members = n_members
        self.builder_fn = builder_fn
        self.members = []
        self.histories = []

    @staticmethod
    def _count_batches(dataset):
        """
        Return the number of batches in one pass over ``dataset``.

        Uses ``dataset.cardinality()`` when it is known, so the dataset is
        never materialized in memory. Falls back to a single streaming pass
        when the cardinality is unknown or the object has no
        ``cardinality`` method.

        Raises
        ------
        ValueError
            If the dataset reports infinite cardinality (a schedule length
            cannot be derived from it).
        """
        cardinality_fn = getattr(dataset, "cardinality", None)
        n_batches = int(cardinality_fn()) if callable(cardinality_fn) else -2
        if n_batches == -1:  # tf.data.INFINITE_CARDINALITY
            raise ValueError(
                "train_dataset is infinite; cannot derive the learning-rate "
                "schedule length. Use a finite dataset (e.g. .take(n))."
            )
        if n_batches < 0:  # tf.data.UNKNOWN_CARDINALITY (-2): count by streaming
            n_batches = sum(1 for _ in dataset)
        return n_batches

    def train(self, train_dataset, val_dataset, epochs=10, verbose=1,
              learning_rate=2e-4, weight_decay=0.01, base_seed=42):
        """
        Train all ensemble members with different random seeds.

        Each member is trained independently with a different random
        initialization, which is the key to capturing epistemic uncertainty.
        The Keras session is cleared and numpy / TensorFlow / ``random`` are
        reseeded before each member is built.

        Parameters
        ----------
        train_dataset : tf.data.Dataset
            Training dataset (must be finite)
        val_dataset : tf.data.Dataset
            Validation dataset
        epochs : int
            Number of training epochs per member
        verbose : int
            Verbosity level (0, 1, or 2)
        learning_rate : float
            Initial learning rate of the cosine schedule (default: 2e-4)
        weight_decay : float
            AdamW weight decay (default: 0.01)
        base_seed : int
            Member ``i`` is seeded with ``base_seed + i * 1000`` (default: 42)

        Returns
        -------
        list
            List of training histories for each member
        """
        self.members = []
        self.histories = []

        # Length of the cosine schedule. max(1, ...) keeps CosineDecay from
        # dividing by zero on an empty dataset or epochs=0.
        steps_per_epoch = self._count_batches(train_dataset)
        total_steps = max(1, steps_per_epoch * epochs)

        for i in range(self.n_members):
            print(f"\n{'='*60}")
            print(f"Training Ensemble Member {i+1}/{self.n_members}")
            print(f"{'='*60}")

            # Clear session and set unique seed for each member
            tf.keras.backend.clear_session()
            gc.collect()

            member_seed = base_seed + i * 1000  # Different seed for each member
            np.random.seed(member_seed)
            tf.random.set_seed(member_seed)
            random.seed(member_seed)

            # Build fresh model
            model = self.builder_fn()

            lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
                initial_learning_rate=learning_rate,
                decay_steps=total_steps,
                alpha=0.01
            )
            optimizer = tf.keras.optimizers.AdamW(
                learning_rate=lr_schedule,
                weight_decay=weight_decay
            )

            # Members end in a softmax layer, hence from_logits=False.
            model.compile(
                optimizer=optimizer,
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='acc')],
                jit_compile=False
            )

            # NOTE: no ReduceLROnPlateau here. The optimizer's learning rate is
            # a LearningRateSchedule, which Keras does not allow to be
            # overwritten: the callback either raises TypeError (Keras 3) or
            # silently has no effect. The cosine schedule already decays the LR.
            early_stopping = tf.keras.callbacks.EarlyStopping(
                monitor='val_acc',
                patience=5,
                restore_best_weights=True,
                verbose=1
            )

            history = model.fit(
                train_dataset,
                validation_data=val_dataset,
                epochs=epochs,
                verbose=verbose,
                callbacks=[early_stopping]
            )

            self.members.append(model)
            self.histories.append(history.history)

            val_acc = history.history.get('val_acc', [])
            if val_acc:
                # Best value is also shown because EarlyStopping may restore
                # the weights of a better epoch than the last one.
                print(f"Member {i+1} Final Val Accuracy: {val_acc[-1]:.4f} "
                      f"(best: {max(val_acc):.4f})")

        print(f"\n{'='*60}")
        print("Deep Ensemble Training Complete!")
        print(f"{'='*60}")

        return self.histories

    def predict(self, X, return_individual=False):
        """
        Make ensemble predictions.

        Parameters
        ----------
        X : dict or np.ndarray
            Input data (dictionary with 'input_ids' and 'attention_mask')
        return_individual : bool
            If True, also return individual member predictions

        Returns
        -------
        mean_pred : np.ndarray
            Mean prediction across ensemble members (probability of class 1)
        std_pred : np.ndarray
            Standard deviation across ensemble members (uncertainty)
        individual_preds : np.ndarray (optional)
            Individual predictions from each member, shape (n_members, N)

        Raises
        ------
        ValueError
            If the ensemble has no members.
        """
        if not self.members:
            raise ValueError(
                "DeepEnsemble has no members; call train() or assign members first."
            )

        # Column 1 of each member's (N, 2) probability output is P(class 1).
        all_preds = np.stack(
            [np.asarray(model.predict(X, verbose=0))[:, 1] for model in self.members],
            axis=0,
        )  # (n_members, N)

        # NaN-aware statistics: a member that produced NaN for a sample is
        # ignored for that sample instead of turning the ensemble output NaN.
        mean_pred = np.nanmean(all_preds, axis=0)  # (N,)
        std_pred = np.nanstd(all_preds, axis=0)    # (N,)
        if return_individual:
            return mean_pred, std_pred, all_preds
        return mean_pred, std_pred

    def __call__(self, X, training=False):
        """
        Callable interface for compatibility with existing evaluation code.

        When training=True, returns a single member prediction (for MC-style sampling)
        When training=False, returns ensemble mean

        Parameters
        ----------
        X : dict or np.ndarray
            Input data
        training : bool
            If True, return random member prediction; if False, return ensemble mean

        Returns
        -------
        tf.Tensor
            Predictions with shape (N, 2) for compatibility
        """
        if training:
            # "Training" mode is used for MC-style sampling: pick one member.
            # Sample from the members actually present, which may differ from
            # n_members (e.g. after a partial load).
            idx = np.random.randint(0, len(self.members))
            return self.members[idx](X, training=False)

        mean_pred, _ = self.predict(X)
        # Re-expand to (N, 2): [P(class=0), P(class=1)]
        probs = np.stack([1 - mean_pred, mean_pred], axis=1)
        return tf.constant(probs, dtype=tf.float32)

    def evaluate(self, dataset, verbose=1):
        """
        Evaluate ensemble on a dataset.

        Computes binary cross-entropy loss and accuracy from the ensemble-mean
        probability of class 1. Batches may be ``(x, y)`` or
        ``(x, y, sample_weight)``; sample weights are ignored. Labels may be
        shaped (B,) or (B, 1).

        Parameters
        ----------
        dataset : iterable of batches (e.g. tf.data.Dataset)
            Dataset to evaluate on
        verbose : int
            Verbosity level

        Returns
        -------
        tuple
            (loss, accuracy)

        Raises
        ------
        ValueError
            If the dataset yields no batches.
        """
        all_y_true = []
        all_y_pred = []

        for batch in dataset:
            X_batch, y_batch = batch[0], batch[1]
            mean_pred, _ = self.predict(X_batch)
            # Flatten labels so (B, 1) labels cannot broadcast against (B,)
            # predictions into a (B, B) comparison.
            all_y_true.append(np.asarray(y_batch).reshape(-1))
            all_y_pred.append(mean_pred)

        if not all_y_true:
            raise ValueError("Cannot evaluate DeepEnsemble on an empty dataset.")

        y_true = np.concatenate(all_y_true)
        y_pred = np.concatenate(all_y_pred)

        # Accuracy at the 0.5 decision threshold
        y_pred_binary = (y_pred >= 0.5).astype(int)
        accuracy = (y_pred_binary == y_true).mean()

        # Binary cross-entropy; clip to avoid log(0)
        epsilon = 1e-12
        y_pred_clipped = np.clip(y_pred, epsilon, 1 - epsilon)
        loss = -np.mean(y_true * np.log(y_pred_clipped) +
                        (1 - y_true) * np.log(1 - y_pred_clipped))

        if verbose:
            print(f"Ensemble - Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")

        return loss, accuracy


def build_deep_ensemble(builder_fn, n_members=5):
    """
    Create an untrained Deep Ensemble.

    Convenience factory mirroring the ``build_*`` functions of this module;
    training happens later via ``DeepEnsemble.train``.

    Parameters
    ----------
    builder_fn : callable
        Zero-argument callable returning a fresh baseline transformer model
    n_members : int
        How many members the ensemble will train (default: 5)

    Returns
    -------
    DeepEnsemble
        Ensemble object with no trained members yet
    """
    ensemble = DeepEnsemble(builder_fn, n_members=n_members)
    return ensemble


# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================

def clean_summary(model):
    """
    Print a clean, table-style summary of model layers.

    Layers whose class is ``TFOpLambda``, ``SlicingOpLambda`` or
    ``InputLayer`` are omitted from the table and from the totals.

    Trainable parameters are counted from each layer's ``trainable_weights``.
    Non-trainable variables belonging to an otherwise trainable layer are
    therefore reported as non-trainable. Layers that are not built yet
    (``count_params`` raises ``ValueError``) are listed with 0 parameters
    instead of aborting the summary.

    Parameters
    ----------
    model : tf.keras.Model
        The model to summarize
    """
    print(f"{'Layer (type)':<30} {'Output Shape':<25} {'Param #':<15}")
    print("=" * 70)

    total_params = 0
    trainable_params = 0

    # Compare class *names* so no Keras-internal classes need importing.
    ignored_layer_names = frozenset({"TFOpLambda", "SlicingOpLambda", "InputLayer"})

    for layer in model.layers:
        layer_type = layer.__class__.__name__
        if layer_type in ignored_layer_names:
            continue

        # Output shape: not available on every layer / Keras version
        try:
            output_shape = layer.output_shape
        except AttributeError:
            output_shape = "multiple"
        except RuntimeError:
            # Sometimes happens with subclassed models not yet built
            output_shape = "?"

        # Parameter counts
        try:
            layer_params = layer.count_params()
        except ValueError:
            # Keras raises ValueError when the layer has not been built yet.
            layer_params = 0
        layer_trainable = sum(
            int(np.prod(tuple(w.shape)))
            for w in getattr(layer, "trainable_weights", [])
        )

        name_str = f"{layer.name} ({layer_type})"
        print(f"{name_str:<30} {str(output_shape):<25} {layer_params:<15,}")

        total_params += layer_params
        trainable_params += layer_trainable

    print("=" * 70)
    print(f"Total params: {total_params:,}")
    print(f"Trainable params: {trainable_params:,}")
    print(f"Non-trainable params: {total_params - trainable_params:,}")
    print("-" * 70)


# =============================================================================
# MODEL SAVE/LOAD FUNCTIONS
# =============================================================================

def save_trained_models(models, save_dir="checkpoints"):
    """
    Save trained model weights (works with custom Bayesian layers).

    This function saves model weights only (not full models) to avoid issues
    with custom layers requiring custom_objects during loading. A key mapping
    JSON file (``key_mapping.json``) is saved to preserve the original model
    names. Entry ``idx`` of ``models`` is stored as ``model_<idx>.h5`` (single
    model) or as directory ``model_<idx>/`` holding ``member_<i>.h5`` files
    (ensemble).

    Re-saving into the same directory overwrites previous files. For
    ensembles, any ``member_*.h5`` files already in the member directory are
    deleted first, so a smaller ensemble does not leave stale members behind.

    Parameters
    ----------
    models : dict
        Dictionary of trained models. Keys are model names, values are either:
        - tf.keras.Model: Single model
        - list or tuple: Ensemble of models (e.g., DeepEnsemble members)
        - DeepEnsemble: DeepEnsemble instance (will extract .members)
    save_dir : str
        Directory to save weights (default: "checkpoints")

    Returns
    -------
    str
        Path to the save directory

    Examples
    --------
    >>> models = {
    ...     "baseline": baseline_model,
    ...     "fullrank_bbb": fullrank_model,
    ...     "deep_ensemble": ensemble.members  # or ensemble object
    ... }
    >>> save_trained_models(models, save_dir="my_checkpoints")
    """
    os.makedirs(save_dir, exist_ok=True)

    print(f"\n{'='*80}")
    print("SAVING MODEL WEIGHTS")
    print(f"Save directory: {save_dir}")
    print(f"{'='*80}")

    key_mapping = {}

    for idx, (name, model) in enumerate(models.items()):
        safe_name = f"model_{idx}"

        # Anything exposing ``.members`` (e.g. DeepEnsemble) is saved as
        # its list of member models.
        if hasattr(model, 'members'):
            model = model.members

        if isinstance(model, (list, tuple)):
            model_dir = os.path.join(save_dir, safe_name)
            os.makedirs(model_dir, exist_ok=True)

            # Drop member files from an earlier, larger save: a loader that
            # discovers members by listing this directory would otherwise
            # silently pick them up again.
            for fname in os.listdir(model_dir):
                if fname.startswith("member_") and fname.endswith(".h5"):
                    os.remove(os.path.join(model_dir, fname))

            for i, member in enumerate(model):
                member.save_weights(os.path.join(model_dir, f"member_{i}.h5"))
            key_mapping[name] = {
                "path": safe_name,
                "type": "ensemble",
                "n_members": len(model)
            }
            print(f"  Saved ensemble: {name} ({len(model)} members)")
        else:
            weight_path = os.path.join(save_dir, f"{safe_name}.h5")
            model.save_weights(weight_path)
            key_mapping[name] = {
                "path": f"{safe_name}.h5",
                "type": "single"
            }
            print(f"  Saved: {name}")

    # Save key mapping for loading
    mapping_path = os.path.join(save_dir, "key_mapping.json")
    with open(mapping_path, 'w') as f:
        json.dump(key_mapping, f, indent=2)

    print(f"\nKey mapping saved to: {mapping_path}")
    print(f"{'='*80}\n")

    return save_dir


def load_trained_models(model_configs, save_dir="checkpoints", max_length=64):
    """
    Load trained model weights into freshly built models.

    This function rebuilds models from configs and loads saved weights.
    It preserves the exact dictionary keys used during saving.

    IMPORTANT: Models are built with a dummy forward pass before loading weights
    to ensure all weight tensors are properly initialized.

    Parameters
    ----------
    model_configs : dict
        Dictionary mapping model names to their configuration. Each config should have:
        - "builder": callable that builds the model (takes no args or uses defaults)
        - For ensembles: "n_members" (optional, will use saved value if not provided)

        Example:
        {
            "baseline": {"builder": lambda: build_tiny_transformer_baseline(...)},
            "fullrank_bbb": {"builder": lambda: build_tiny_transformer_fullrank_bbb(...)},
            "deep_ensemble": {
                "builder": lambda: build_tiny_transformer_baseline(...),
                "n_members": 5
            }
        }
    save_dir : str
        Directory where weights are saved (default: "checkpoints")
    max_length : int
        Maximum sequence length for building the model (default: 64)

    Returns
    -------
    dict
        Dictionary of loaded models with original keys

    Examples
    --------
    >>> # Define builders for each model
    >>> configs = {
    ...     "baseline": {
    ...         "builder": lambda: build_tiny_transformer_baseline(
    ...             vocab_size=30522, max_length=128
    ...         )
    ...     },
    ...     "fullrank_bbb": {
    ...         "builder": lambda: build_tiny_transformer_fullrank_bbb(
    ...             vocab_size=30522, max_length=128
    ...         )
    ...     },
    ...     "deep_ensemble": {
    ...         "builder": lambda: build_tiny_transformer_baseline(
    ...             vocab_size=30522, max_length=128
    ...         ),
    ...         "n_members": 5
    ...     }
    ... }
    >>> models = load_trained_models(configs, save_dir="my_checkpoints")
    """
    print(f"\n{'='*80}")
    print(f"LOADING MODEL WEIGHTS")
    print(f"Load directory: {save_dir}")
    print(f"{'='*80}")

    # Load key mapping
    mapping_path = os.path.join(save_dir, "key_mapping.json")
    if not os.path.exists(mapping_path):
        raise FileNotFoundError(
            f"Key mapping file not found at {mapping_path}. "
            f"Make sure the save_dir is correct and models were saved with save_trained_models()."
        )

    with open(mapping_path, 'r') as f:
        key_mapping = json.load(f)

    models = {}

    # Create dummy input for building models before loading weights
    dummy_input = {
        "input_ids": tf.zeros((1, max_length), dtype=tf.int32),
        "attention_mask": tf.ones((1, max_length), dtype=tf.int32)
    }

    for original_key, info in key_mapping.items():
        if original_key not in model_configs:
            print(f"  Warning: No config provided for '{original_key}', skipping...")
            continue

        config = model_configs[original_key]
        item_path = os.path.join(save_dir, info["path"])

        if info["type"] == "ensemble":
            # Rebuild and load ensemble
            n_members = config.get("n_members", info.get("n_members", 5))
            member_files = sorted([f for f in os.listdir(item_path) if f.endswith('.h5')])

            if len(member_files) != n_members:
                print(f"  Warning: Expected {n_members} members but found {len(member_files)} files")
                n_members = len(member_files)

            ensemble = []
            for f in member_files:
                # Build fresh model
                model = config["builder"]()
                # Build model with dummy forward pass before loading weights
                _ = model(dummy_input, training=False)

                # Now load weights into properly built model
                try:
                    model.load_weights(os.path.join(item_path, f), by_name=True, skip_mismatch=False)
                except Exception as e:
                    print(f"    ERROR loading ensemble member {f}: {e}")
                    raise

                ensemble.append(model)

            models[original_key] = ensemble
            print(f"  Loaded ensemble: {original_key} ({len(ensemble)} members)")
        else:
            # Rebuild single model from config
            model = config["builder"]()

            # Build model with dummy forward pass before loading weights
            # This ensures all layer weights (including variational parameters like
            # w_mu, w_rho, A_mu, A_rho, B_mu, B_rho, etc.) are properly initialized
            _ = model(dummy_input, training=False)

            # Get weight count and sample values before loading
            num_weights_before = len(model.weights)
            weight_sample_before = model.weights[0].numpy().flatten()[:5].copy() if model.weights else None

            # Now load weights into properly built model
            # IMPORTANT: Use by_name=True to match weights by layer/variable name, not position
            try:
                # Load weights by name for robust loading across sessions
                model.load_weights(item_path, by_name=True, skip_mismatch=False)
            except Exception as e:
                print(f"    ERROR loading weights for {original_key}: {e}")
                print(f"    Model has {len(model.weights)} weights")
                print(f"    First few weight names: {[w.name for w in model.weights[:5]]}")
                raise

            # Verify weights actually changed after loading
            num_weights_after = len(model.weights)
            weight_sample_after = model.weights[0].numpy().flatten()[:5].copy() if model.weights else None

            if num_weights_before != num_weights_after:
                print(f"    WARNING: Weight count changed! Before: {num_weights_before}, After: {num_weights_after}")

            if weight_sample_before is not None and weight_sample_after is not None:
                if np.allclose(weight_sample_before, weight_sample_after):
                    print(f"    ⚠️  WARNING: First weight did not change after loading for {original_key}!")
                    print(f"       This suggests weights were not loaded correctly.")

            models[original_key] = model
            print(f"  Loaded: {original_key} ({len(model.weights)} weights)")

    # Check for any missing models
    missing = set(model_configs.keys()) - set(models.keys())
    if missing:
        print(f"\n  Warning: The following models were in configs but not found in saved files: {missing}")

    print(f"{'='*80}\n")

    return models


# =============================================================================
# MODULE INFO
# =============================================================================

if __name__ == "__main__":
    # Running the module directly just prints an overview of its public API.
    _overview_lines = (
        "Model builders module loaded successfully!",
        "\nAvailable model builders:",
        "  1. build_tiny_transformer_baseline - Deterministic baseline",
        "  2. build_tiny_transformer_fullrank_bbb - Full-rank BBB",
        "  3. build_tiny_transformer_lowrank_bbb - Low-rank BBB (random init)",
        "  4. build_tiny_transformer_lowrank_bbb - Low-rank BBB (SVD init, via init_from_deterministic)",
        "  5. build_tiny_transformer_rank1_bbb - Rank-1 BBB",
        "  6. DeepEnsemble / build_deep_ensemble - Deep Ensemble wrapper",
        "\nSave/Load utilities:",
        "  - save_trained_models(models, save_dir) - Save all model weights",
        "  - load_trained_models(model_configs, save_dir) - Load model weights",
    )
    for _overview_line in _overview_lines:
        print(_overview_line)
