"""
bayesian_layers.py - All Bayesian / variational layer classes.

This module contains:
- log_mixture_prior: Log-density of Gaussian mixture prior
- DenseVariational: Full-rank Bayes by Backprop dense layer
- EmbeddingVariational: Full-rank variational embedding layer
- LowRankDenseVariational: Low-rank BBB dense layer
- LowRankEmbeddingVariational: Low-rank variational embedding layer
- Rank1DenseVariational: Rank-1 BBB dense layer (Dusenberry et al., 2020)
- Rank1EmbeddingVariational: Rank-1 variational embedding layer
- IndependentDropout: Dropout layer with controllable activation
- Helper functions: set_kl_scale, set_dropout_active, compile_binary
"""

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


# =============================================================================
# LOG MIXTURE PRIOR
# =============================================================================

def log_mixture_prior(w, pi=0.5, sigma1=1.0, sigma2=tf.exp(-6.0)):
    """
    Log-density of a zero-mean Gaussian mixture prior:
        p(w) = pi * N(0, sigma1^2) + (1-pi) * N(0, sigma2^2)

    Returns log p(w) elementwise, with broadcasting over w. All mixture
    parameters are cast to the dtype of ``w``, so float16/float32/float64
    weights are all supported.

    Parameters
    ----------
    w : tf.Tensor or tf.Variable
        Floating-point weight tensor to compute the prior log-density for
    pi : float
        Mixing coefficient of the first (sigma1) component (default: 0.5)
    sigma1 : float
        Standard deviation of first Gaussian (default: 1.0)
    sigma2 : float or tf.Tensor
        Standard deviation of second Gaussian (default: exp(-6)). Note that
        the default is a tensor evaluated once, when the module is imported.

    Returns
    -------
    tf.Tensor
        Log-density of the mixture prior for each element of w (same shape
        and dtype as w)
    """
    w = tf.convert_to_tensor(w)
    dtype = w.dtype

    # Bring every parameter into w's dtype; mixing float32 constants with
    # float64 weights would otherwise raise a dtype error.
    pi = tf.cast(pi, dtype)
    zero = tf.zeros([], dtype=dtype)
    n1 = tfd.Normal(loc=zero, scale=tf.cast(sigma1, dtype))
    n2 = tfd.Normal(loc=zero, scale=tf.cast(sigma2, dtype))

    # log(pi * N1(w)) and log((1 - pi) * N2(w)); log1p(-pi) == log(1 - pi)
    a = tf.math.log(pi) + n1.log_prob(w)
    b = tf.math.log1p(-pi) + n2.log_prob(w)

    # Combine the two components in log space to avoid underflow.
    return tf.reduce_logsumexp(tf.stack([a, b], axis=0), axis=0)


# =============================================================================
# FULL-RANK DENSE VARIATIONAL LAYER
# =============================================================================

class DenseVariational(tf.keras.layers.Layer):
    """
    Full-rank Bayes by Backprop dense layer:
        y = x W + b
    with a factorised Gaussian variational posterior on W and b, and a
    log-mixture prior (``log_mixture_prior``) on every entry.

    Posterior standard deviations are parameterised as
    sigma = softplus(rho) + 1e-5. When ``training`` is truthy, W and b are
    sampled with the reparameterisation trick; otherwise the posterior means
    are used. Every call adds ``kl_scale * (log q(W, b) - log p(W, b))``,
    evaluated at the weights used in the forward pass, via ``add_loss``.

    Parameters
    ----------
    units : int
        Number of output units
    kl_scale : float
        Scaling factor for KL divergence term (default: 1.0)
    activation : str or callable
        Activation function (default: None)
    """

    def __init__(self, units, kl_scale=1.0, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = int(units)
        self.kl_scale = float(kl_scale)
        self.activation = tf.keras.activations.get(activation)

    def build(self, input_shape):
        din = int(input_shape[-1])

        # Variational means (mu) and softplus-parameterised stds (rho).
        self.w_mu = self.add_weight(
            name="w_mu", shape=[din, self.units],
            initializer=tf.keras.initializers.RandomUniform(-0.2, 0.2)
        )
        self.w_rho = self.add_weight(
            name="w_rho", shape=[din, self.units],
            initializer=tf.keras.initializers.RandomUniform(-5.0, -4.0)
        )
        self.b_mu = self.add_weight(
            name="b_mu", shape=[self.units],
            initializer=tf.keras.initializers.RandomUniform(-0.2, 0.2)
        )
        self.b_rho = self.add_weight(
            name="b_rho", shape=[self.units],
            initializer=tf.keras.initializers.RandomUniform(-5.0, -4.0)
        )
        super().build(input_shape)

    def _sample(self, mu, rho):
        """Reparameterised sample mu + sigma * eps; returns (sample, sigma)."""
        sigma = tf.nn.softplus(rho) + 1e-5
        eps = tf.random.normal(tf.shape(mu))
        return mu + sigma * eps, sigma

    def call(self, x, training=None):
        # Sample weights during training; use posterior mean at eval
        if training:
            w, w_sigma = self._sample(self.w_mu, self.w_rho)
            b, b_sigma = self._sample(self.b_mu, self.b_rho)
        else:
            w, b = self.w_mu, self.b_mu
            # Same 1e-5 floor as _sample so the KL is computed consistently.
            w_sigma = tf.nn.softplus(self.w_rho) + 1e-5
            b_sigma = tf.nn.softplus(self.b_rho) + 1e-5

        y = tf.linalg.matmul(x, w) + b
        if self.activation is not None:
            y = self.activation(y)

        # Single-sample KL estimate (log q - log p) at the weights actually
        # used above (the sample in training, the means otherwise).
        qw = tfd.Normal(self.w_mu, w_sigma)
        qb = tfd.Normal(self.b_mu, b_sigma)
        log_q = tf.reduce_sum(qw.log_prob(w)) + tf.reduce_sum(qb.log_prob(b))
        log_p = (
            tf.reduce_sum(log_mixture_prior(w))
            + tf.reduce_sum(log_mixture_prior(b))
        )

        self.add_loss(self.kl_scale * (log_q - log_p))
        return y

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({
            "units": self.units,
            "kl_scale": self.kl_scale,
            "activation": tf.keras.activations.serialize(self.activation),
        })
        return config


# =============================================================================
# FULL-RANK EMBEDDING VARIATIONAL LAYER
# =============================================================================

class EmbeddingVariational(tf.keras.layers.Layer):
    """
    Full-rank variational embedding layer with a factorised Gaussian
    posterior over the embedding table.

    Only the rows corresponding to tokens in the current batch are sampled
    and contribute to the KL term, which keeps the cost proportional to the
    number of unique tokens rather than the vocabulary size.

    Posterior stds are sigma = softplus(rho) + 1e-5. The KL estimate added
    via ``add_loss`` is ``kl_scale * 0.01 * (log q - log p)``; the extra 0.01
    factor is hard-coded.

    Parameters
    ----------
    input_dim : int
        Vocabulary size
    output_dim : int
        Embedding dimension
    kl_scale : float
        Scaling factor for KL divergence term (default: 1.0)
    """

    def __init__(self, input_dim, output_dim, kl_scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = int(input_dim)
        self.output_dim = int(output_dim)
        self.kl_scale = float(kl_scale)
        self.eps = 1e-5

    def build(self, input_shape):
        self.emb_mu = self.add_weight(
            name="emb_mu", shape=[self.input_dim, self.output_dim],
            initializer=tf.keras.initializers.RandomUniform(-0.2, 0.2)
        )
        self.emb_rho = self.add_weight(
            name="emb_rho", shape=[self.input_dim, self.output_dim],
            initializer=tf.keras.initializers.RandomUniform(-5.0, -4.0)
        )
        super().build(input_shape)

    def call(self, inputs, training=None):
        if training is None:
            training = tf.keras.backend.learning_phase()
        training = tf.cast(training, tf.bool)

        # Token ids must be integers for tf.unique/tf.gather; Keras inputs
        # default to float32, so cast like tf.keras.layers.Embedding does.
        inputs = tf.convert_to_tensor(inputs)
        if inputs.dtype not in (tf.int32, tf.int64):
            inputs = tf.cast(inputs, tf.int32)

        flat = tf.reshape(inputs, [-1])                         # (N,) all token ids
        unique_ids, inv = tf.unique(flat)                       # unique ids + inverse map
        mu_u = tf.gather(self.emb_mu, unique_ids)               # (U, d)
        rho_u = tf.gather(self.emb_rho, unique_ids)             # (U, d)
        sigma_u = tf.nn.softplus(rho_u) + self.eps              # (U, d)

        # Sample only the used rows (or take the mean at eval)
        eps = tf.random.normal(tf.shape(mu_u), dtype=mu_u.dtype)
        w_u = tf.where(training, mu_u + sigma_u * eps, mu_u)    # (U, d)

        # Reconstruct per-token embeddings: (N, d) -> inputs.shape + (d,)
        out_flat = tf.gather(w_u, inv)                          # (N, d)
        out = tf.reshape(
            out_flat, tf.concat([tf.shape(inputs), [self.output_dim]], axis=0)
        )

        # KL only on the used rows (U, d)
        q_u = tfd.Normal(loc=mu_u, scale=sigma_u)
        log_q = tf.reduce_sum(q_u.log_prob(w_u))
        log_p = tf.reduce_sum(log_mixture_prior(w_u))
        kl = log_q - log_p

        self.add_loss(self.kl_scale * 0.01 * kl)
        return out

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({
            "input_dim": self.input_dim,
            "output_dim": self.output_dim,
            "kl_scale": self.kl_scale,
        })
        return config


# =============================================================================
# LOW-RANK DENSE VARIATIONAL LAYER
# =============================================================================

class LowRankDenseVariational(tf.keras.layers.Layer):
    """
    Low-rank Bayes by Backprop dense layer:
        W = A B^T, bias b
    with factorised Gaussian variational posteriors on A (din x r),
    B (units x r) and b, and a log-mixture prior (``log_mixture_prior``).

    Posterior stds are parameterised as sigma = softplus(rho) + 1e-5. When
    ``training`` is truthy, A, B and b are sampled; otherwise the posterior
    means are used. Every call adds ``kl_scale * (log q - log p)`` via
    ``add_loss``.

    Parameters
    ----------
    units : int
        Number of output units
    rank : int
        Rank of the low-rank factorization
    kl_scale : float
        Scaling factor for KL divergence term (default: 1.0)
    activation : str or callable
        Activation function (default: None)
    """

    def __init__(self, units, rank, kl_scale=1.0, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = int(units)
        self.rank = int(rank)
        self.kl_scale = float(kl_scale)
        self.activation = tf.keras.activations.get(activation)

    def _is_relu_family(self, activation):
        """
        Return True if ``activation`` is a ReLU-family activation (selects He
        instead of Glorot target variance).

        Accepts a name or a callable (matched via ``__name__``). Names are
        compared lower-cased with underscores removed, so ``leaky_relu``
        matches ``leakyrelu``.
        """
        if activation is None:
            return False

        if isinstance(activation, str):
            name = activation
        else:
            name = getattr(activation, '__name__', '')

        relu_family = {'relu', 'elu', 'selu', 'leakyrelu', 'relu6', 'gelu'}
        return name.lower().replace('_', '') in relu_family

    def build(self, input_shape):
        din = int(input_shape[-1])   # input dimension
        dout = self.units            # output dimension
        r = self.rank                # rank of the factorization

        # ---- Adaptive, rank-aware initialisation -------------------------
        # Target variance of W = A B^T: He for ReLU-family, else Glorot.
        if self._is_relu_family(self.activation):
            sigma_w_sq = 2.0 / din
        else:
            sigma_w_sq = 2.0 / (din + dout)

        # Rank-dependent damping (author-tuned for a Transformer, rank=16).
        damping = 0.46 if r <= 5 else 0.85

        # With A, B entries ~ U[-a, a] (variance a^2/3), each entry of A B^T
        # has variance r * (a^2/3)^2. Choosing
        #     a = damping * sqrt(3) * (sigma_w_sq / r)^(1/4)
        # gives Var(W_ij) = damping^4 * sigma_w_sq at initialisation.
        a = damping * np.sqrt(3.0) * np.power(sigma_w_sq / r, 0.25)

        # Initial posterior std of each factor entry, on the same
        # (sigma_w_sq / r)^(1/4) scale as the means.
        eta = 0.15
        sigma_init = eta * np.power(sigma_w_sq / r, 0.25)

        # rho = softplus^{-1}(sigma_init) = log(expm1(sigma_init)); for tiny
        # sigma_init, expm1(x) ~= x, so log(x) is used directly.
        if sigma_init > 1e-4:
            rho_center = np.log(np.expm1(sigma_init))
        else:
            rho_center = np.log(sigma_init)

        # Draw rho from a width-1 window around the centre so initial
        # posterior stds vary across entries.
        rho_min = rho_center - 0.5
        rho_max = rho_center + 0.5

        # ---- Factor means: U[-a, a] ---------------------------------------
        self.A_mu = self.add_weight(
            name="A_mu",
            shape=[din, r],
            initializer=tf.keras.initializers.RandomUniform(-a, a),
            trainable=True
        )
        self.B_mu = self.add_weight(
            name="B_mu",
            shape=[dout, r],
            initializer=tf.keras.initializers.RandomUniform(-a, a),
            trainable=True
        )

        # ---- Factor std parameters (softplus pre-activations) -------------
        self.A_rho = self.add_weight(
            name="A_rho",
            shape=[din, r],
            initializer=tf.keras.initializers.RandomUniform(rho_min, rho_max),
            trainable=True
        )
        self.B_rho = self.add_weight(
            name="B_rho",
            shape=[dout, r],
            initializer=tf.keras.initializers.RandomUniform(rho_min, rho_max),
            trainable=True
        )

        # ---- Bias: zero mean, rho initialised to -5.0 (small std) ---------
        self.b_mu = self.add_weight(
            name="b_mu",
            shape=[dout],
            initializer=tf.keras.initializers.Zeros(),
            trainable=True
        )
        self.b_rho = self.add_weight(
            name="b_rho",
            shape=[dout],
            initializer=tf.keras.initializers.Constant(-5.0),
            trainable=True
        )

        super().build(input_shape)

    def _sample(self, mu, rho):
        """Reparameterised sample mu + sigma * eps; returns (sample, sigma)."""
        sigma = tf.nn.softplus(rho) + 1e-5
        eps = tf.random.normal(tf.shape(mu))
        return mu + sigma * eps, sigma

    def init_from_full_matrix(self, W_full, b_full=None):
        """
        Initialize A_mu and B_mu from a full-rank matrix using truncated SVD, so
        that A_mu @ B_mu^T equals the rank-r approximation of W_full.

        If the layer is not built yet it is built here, with the input
        dimension taken from ``W_full.shape[0]``. When ``rank`` exceeds the
        number of singular values, the extra factor columns are zero.

        Parameters
        ----------
        W_full : np.ndarray
            Full-rank weight matrix of shape (din, units)
        b_full : np.ndarray, optional
            Bias vector of shape (units,)

        Raises
        ------
        ValueError
            If W_full is not 2-D with ``units`` columns, its row count does not
            match the built input dimension, or b_full is not of shape (units,).
        """
        W_full = np.asarray(W_full, dtype=np.float64)
        if W_full.ndim != 2 or W_full.shape[1] != self.units:
            raise ValueError(
                f"W_full must have shape (din, {self.units}); got {W_full.shape}"
            )
        if not self.built:
            self.build((None, W_full.shape[0]))

        din = int(self.A_mu.shape[0])
        if W_full.shape[0] != din:
            raise ValueError(
                f"W_full has {W_full.shape[0]} rows but the layer input "
                f"dimension is {din}"
            )

        U, s, Vt = np.linalg.svd(W_full, full_matrices=False)
        r_eff = min(self.rank, len(s))
        sqrt_s = np.sqrt(s[:r_eff])

        # A_mu = U_r * sqrt(s_r), B_mu = V_r * sqrt(s_r); columns beyond r_eff
        # stay zero.
        A = np.zeros((din, self.rank))
        B = np.zeros((self.units, self.rank))
        A[:, :r_eff] = U[:, :r_eff] * sqrt_s
        B[:, :r_eff] = Vt[:r_eff, :].T * sqrt_s

        self.A_mu.assign(A.astype(np.float32))
        self.B_mu.assign(B.astype(np.float32))

        if b_full is not None:
            b_full = np.asarray(b_full, dtype=np.float32)
            if b_full.shape != (self.units,):
                raise ValueError(
                    f"b_full must have shape ({self.units},); got {b_full.shape}"
                )
            self.b_mu.assign(b_full)

    def call(self, x, training=None):
        if training:
            A, A_sigma = self._sample(self.A_mu, self.A_rho)
            B, B_sigma = self._sample(self.B_mu, self.B_rho)
        else:
            A, B = self.A_mu, self.B_mu
            A_sigma = tf.nn.softplus(self.A_rho) + 1e-5
            B_sigma = tf.nn.softplus(self.B_rho) + 1e-5

        b_sigma = tf.nn.softplus(self.b_rho) + 1e-5
        b_dist = tfd.Normal(self.b_mu, b_sigma)
        if training:
            b = b_dist.sample()
        else:
            b = self.b_mu

        # Forward pass: y = (x A) B^T + b, never materialising W = A B^T
        y = tf.linalg.matmul(tf.linalg.matmul(x, A), B, transpose_b=True) + b
        if self.activation is not None:
            y = self.activation(y)

        # Single-sample KL estimate at the factors actually used above
        qA = tfd.Normal(self.A_mu, A_sigma)
        qB = tfd.Normal(self.B_mu, B_sigma)
        log_q = (
            tf.reduce_sum(qA.log_prob(A))
            + tf.reduce_sum(qB.log_prob(B))
            + tf.reduce_sum(b_dist.log_prob(b))
        )
        log_p = (
            tf.reduce_sum(log_mixture_prior(A))
            + tf.reduce_sum(log_mixture_prior(B))
            + tf.reduce_sum(log_mixture_prior(b))
        )

        self.add_loss(self.kl_scale * (log_q - log_p))
        return y

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({
            "units": self.units,
            "rank": self.rank,
            "kl_scale": self.kl_scale,
            "activation": tf.keras.activations.serialize(self.activation),
        })
        return config


# =============================================================================
# LOW-RANK EMBEDDING VARIATIONAL LAYER
# =============================================================================

class LowRankEmbeddingVariational(tf.keras.layers.Layer):
    """
    Low-rank variational embedding layer:
        E = A B^T where A in R^{V x r} and B in R^{d x r}

    Exploits batch sparsity by sampling only the rows of A corresponding to tokens
    in the current batch while sampling the full B matrix, reducing cost from
    O(Vd) to O(|U|r + dr) where |U| is the number of unique tokens per batch.

    Posterior stds are sigma = softplus(rho) + 1e-5. The KL estimate added via
    ``add_loss`` is ``kl_scale * 0.01 * (KL_A + KL_B)``; the extra 0.01 factor
    is hard-coded.

    Parameters
    ----------
    input_dim : int
        Vocabulary size (V)
    output_dim : int
        Embedding dimension (d)
    rank : int
        Rank of the low-rank factorization (r)
    kl_scale : float
        Scaling factor for KL divergence term (default: 1.0)
    """

    def __init__(self, input_dim, output_dim, rank, kl_scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = int(input_dim)     # V
        self.output_dim = int(output_dim)   # d
        self.rank = int(rank)               # r
        self.kl_scale = float(kl_scale)
        self.eps = 1e-5

    def _is_relu_family(self, activation):
        """
        Return True if ``activation`` is a ReLU-family activation.

        Not used by this layer (embeddings have no activation); kept for
        parity with LowRankDenseVariational.
        """
        if activation is None:
            return False

        if isinstance(activation, str):
            name = activation
        else:
            name = getattr(activation, '__name__', '')

        relu_family = {'relu', 'elu', 'selu', 'leakyrelu', 'relu6', 'gelu'}
        return name.lower().replace('_', '') in relu_family

    def build(self, input_shape):
        din = self.input_dim      # vocabulary size (V)
        dout = self.output_dim    # embedding dimension (d)
        r = self.rank             # rank of factorization

        # ---- Adaptive, rank-aware initialisation -------------------------
        # No activation for embeddings, so use the Glorot target variance.
        sigma_w_sq = 2.0 / (din + dout)

        # Rank-dependent damping (author-tuned for a Transformer).
        damping = 0.46 if r <= 5 else 0.85

        # With A, B entries ~ U[-a, a], Var((A B^T)_ij) = r * (a^2/3)^2, so
        # this bound gives Var(E_ij) = damping^4 * sigma_w_sq at init.
        a = damping * np.sqrt(3.0) * np.power(sigma_w_sq / r, 0.25)

        # Initial posterior std of each factor entry.
        eta = 0.15
        sigma_init = eta * np.power(sigma_w_sq / r, 0.25)

        # rho = softplus^{-1}(sigma_init); log(x) approximates it for tiny x.
        if sigma_init > 1e-4:
            rho_center = np.log(np.expm1(sigma_init))
        else:
            rho_center = np.log(sigma_init)

        # Width-1 window around the centre so initial stds vary per entry.
        rho_min = rho_center - 0.5
        rho_max = rho_center + 0.5

        # ---- Factor means: U[-a, a] ---------------------------------------
        self.A_mu = self.add_weight(
            name="A_mu",
            shape=[din, r],
            initializer=tf.keras.initializers.RandomUniform(-a, a),
            trainable=True
        )
        self.B_mu = self.add_weight(
            name="B_mu",
            shape=[dout, r],
            initializer=tf.keras.initializers.RandomUniform(-a, a),
            trainable=True
        )

        # ---- Factor std parameters (softplus pre-activations) -------------
        self.A_rho = self.add_weight(
            name="A_rho",
            shape=[din, r],
            initializer=tf.keras.initializers.RandomUniform(rho_min, rho_max),
            trainable=True
        )
        self.B_rho = self.add_weight(
            name="B_rho",
            shape=[dout, r],
            initializer=tf.keras.initializers.RandomUniform(rho_min, rho_max),
            trainable=True
        )

        super().build(input_shape)

    def call(self, inputs, training=None):
        if training is None:
            training = tf.keras.backend.learning_phase()
        training = tf.cast(training, tf.bool)

        # Token ids must be integers for tf.unique/tf.gather; Keras inputs
        # default to float32, so cast like tf.keras.layers.Embedding does.
        inputs = tf.convert_to_tensor(inputs)
        if inputs.dtype not in (tf.int32, tf.int64):
            inputs = tf.cast(inputs, tf.int32)

        # ---- A: only the rows of tokens present in the batch ----
        flat = tf.reshape(inputs, [-1])
        unique_ids, inv = tf.unique(flat)                    # (U,), (N,)
        A_mu_u = tf.gather(self.A_mu, unique_ids)            # (U, r)
        A_rho_u = tf.gather(self.A_rho, unique_ids)          # (U, r)
        A_sigma_u = tf.nn.softplus(A_rho_u) + self.eps

        epsA = tf.random.normal(tf.shape(A_mu_u), dtype=A_mu_u.dtype)
        A_u = tf.where(training, A_mu_u + A_sigma_u * epsA, A_mu_u)   # (U, r)

        # Map back to per-token A rows: (N, r) -> inputs.shape + (r,)
        A_rows_flat = tf.gather(A_u, inv)                    # (N, r)
        A_rows = tf.reshape(
            A_rows_flat, tf.concat([tf.shape(inputs), [self.rank]], axis=0)
        )

        # ---- B: full matrix (d x r, small) ----
        B_sigma = tf.nn.softplus(self.B_rho) + self.eps      # (d, r)
        epsB = tf.random.normal(tf.shape(self.B_mu), dtype=self.B_mu.dtype)
        B = tf.where(training, self.B_mu + B_sigma * epsB, self.B_mu)  # (d, r)

        # Forward: out = A_rows @ B^T => (..., r) x (r, d) = (..., d)
        out = tf.linalg.matmul(A_rows, B, transpose_b=True)

        # KL: A on used rows only, B on the full matrix
        qA = tfd.Normal(loc=A_mu_u, scale=A_sigma_u)
        klA = (
            tf.reduce_sum(qA.log_prob(A_u))
            - tf.reduce_sum(log_mixture_prior(A_u))
        )

        qB = tfd.Normal(loc=self.B_mu, scale=B_sigma)
        klB = (
            tf.reduce_sum(qB.log_prob(B))
            - tf.reduce_sum(log_mixture_prior(B))
        )

        self.add_loss(self.kl_scale * 0.01 * (klA + klB))
        return out

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({
            "input_dim": self.input_dim,
            "output_dim": self.output_dim,
            "rank": self.rank,
            "kl_scale": self.kl_scale,
        })
        return config


# =============================================================================
# RANK-1 VARIATIONAL DENSE LAYER (Dusenberry et al., 2020)
# =============================================================================

class Rank1DenseVariational(tf.keras.layers.Layer):
    """
    Rank-1 BBB dense layer with effective weight
        W_eff = W0 * outer(1 + s, 1 + r)      (elementwise product)
    computed efficiently as
        y = ((x * (1 + s)) @ W0) * (1 + r) + b

    Only r (units) and s (din) are stochastic; W0 and b are deterministic,
    so there are just d_in + d_out stochastic parameters.

    Posterior stds are sigma = softplus(rho) + 1e-5. When ``training`` is
    truthy r and s are sampled; otherwise their means are used. Every call
    adds ``kl_scale * (log q - log p)`` via ``add_loss``.

    Based on Dusenberry et al. (2020) "Efficient and Scalable Bayesian Neural
    Nets with Rank-1 Factors"

    Parameters
    ----------
    units : int
        Number of output units
    kl_scale : float
        Scaling factor for KL divergence term (default: 1.0)
    activation : str or callable
        Activation function (default: None)
    """

    def __init__(self, units, kl_scale=1.0, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.units = int(units)
        self.kl_scale = float(kl_scale)
        self.activation = tf.keras.activations.get(activation)

    def build(self, input_shape):
        din = int(input_shape[-1])
        # Deterministic shared weights
        self.W0 = self.add_weight(
            name="W0", shape=[din, self.units],
            initializer="glorot_uniform"
        )
        self.b = self.add_weight(
            name="bias", shape=[self.units],
            initializer=tf.keras.initializers.RandomUniform(-0.2, 0.2)
        )
        # Stochastic rank-1 factors (mean and softplus std parameter)
        self.r_mu = self.add_weight(
            name="r_mu", shape=[self.units],
            initializer=tf.keras.initializers.RandomUniform(-0.2, 0.2)
        )
        self.r_rho = self.add_weight(
            name="r_rho", shape=[self.units],
            initializer=tf.keras.initializers.RandomUniform(-5.0, -4.0)
        )
        self.s_mu = self.add_weight(
            name="s_mu", shape=[din],
            initializer=tf.keras.initializers.RandomUniform(-0.2, 0.2)
        )
        self.s_rho = self.add_weight(
            name="s_rho", shape=[din],
            initializer=tf.keras.initializers.RandomUniform(-5.0, -4.0)
        )
        super().build(input_shape)

    def _sample(self, mu, rho):
        """Reparameterised sample mu + sigma * eps; returns (sample, sigma)."""
        sigma = tf.nn.softplus(rho) + 1e-5
        eps = tf.random.normal(tf.shape(mu))
        return mu + sigma * eps, sigma

    def call(self, x, training=None):
        # Sample factors during training; use posterior mean at eval
        if training:
            r, r_sigma = self._sample(self.r_mu, self.r_rho)
            s, s_sigma = self._sample(self.s_mu, self.s_rho)
        else:
            r, s = self.r_mu, self.s_mu
            r_sigma = tf.nn.softplus(self.r_rho) + 1e-5
            s_sigma = tf.nn.softplus(self.s_rho) + 1e-5

        # Forward pass: y = ((x * (1+s)) @ W0) * (1+r) + b
        x_scaled = x * (1.0 + s)
        y = tf.linalg.matmul(x_scaled, self.W0)
        y = y * (1.0 + r) + self.b
        if self.activation is not None:
            y = self.activation(y)

        # Single-sample KL estimate at the factors actually used above
        qr = tfd.Normal(self.r_mu, r_sigma)
        qs = tfd.Normal(self.s_mu, s_sigma)
        log_q = tf.reduce_sum(qr.log_prob(r)) + tf.reduce_sum(qs.log_prob(s))
        log_p = (
            tf.reduce_sum(log_mixture_prior(r))
            + tf.reduce_sum(log_mixture_prior(s))
        )

        self.add_loss(self.kl_scale * (log_q - log_p))
        return y

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({
            "units": self.units,
            "kl_scale": self.kl_scale,
            "activation": tf.keras.activations.serialize(self.activation),
        })
        return config


# =============================================================================
# RANK-1 VARIATIONAL EMBEDDING LAYER
# =============================================================================

class Rank1EmbeddingVariational(tf.keras.layers.Layer):
    """
    Rank-1 variational embedding layer applying the same factorization as
    Rank1DenseVariational.

    For token embedding E (V x d), we use:
        E_eff = E0 * outer(1 + r, 1 + s)      (elementwise product)
    where:
        - E0 is the deterministic base embedding matrix (V x d)
        - r is a stochastic vector of size vocab_size (V)
        - s is a stochastic vector of size embedding_dim (d)

    Only the entries of r for tokens in the current batch are sampled and
    enter the KL term, which keeps the cost independent of the vocabulary
    size. The KL estimate added via ``add_loss`` is
    ``kl_scale * 0.01 * (KL_r + KL_s)``; the extra 0.01 factor is hard-coded.

    Parameters
    ----------
    input_dim : int
        Vocabulary size (V)
    output_dim : int
        Embedding dimension (d)
    kl_scale : float
        Scaling factor for KL divergence term (default: 1.0)
    """

    def __init__(self, input_dim, output_dim, kl_scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = int(input_dim)     # V (vocab size)
        self.output_dim = int(output_dim)   # d (embedding dim)
        self.kl_scale = float(kl_scale)
        self.eps = 1e-5

    def build(self, input_shape):
        # Base embedding matrix (deterministic)
        self.E0 = self.add_weight(
            name="E0", shape=[self.input_dim, self.output_dim],
            initializer="glorot_uniform"
        )

        # r vector (one per vocab token) - stochastic
        self.r_mu = self.add_weight(
            name="r_mu", shape=[self.input_dim],
            initializer=tf.keras.initializers.RandomUniform(-0.2, 0.2)
        )
        self.r_rho = self.add_weight(
            name="r_rho", shape=[self.input_dim],
            initializer=tf.keras.initializers.RandomUniform(-5.0, -4.0)
        )

        # s vector (one per embedding dimension) - stochastic
        self.s_mu = self.add_weight(
            name="s_mu", shape=[self.output_dim],
            initializer=tf.keras.initializers.RandomUniform(-0.2, 0.2)
        )
        self.s_rho = self.add_weight(
            name="s_rho", shape=[self.output_dim],
            initializer=tf.keras.initializers.RandomUniform(-5.0, -4.0)
        )
        super().build(input_shape)

    def call(self, inputs, training=None):
        if training is None:
            training = tf.keras.backend.learning_phase()
        training = tf.cast(training, tf.bool)

        # Token ids must be integers for tf.unique/tf.gather; Keras inputs
        # default to float32, so cast like tf.keras.layers.Embedding does.
        inputs = tf.convert_to_tensor(inputs)
        if inputs.dtype not in (tf.int32, tf.int64):
            inputs = tf.cast(inputs, tf.int32)

        # ---- s: full embedding-dimension vector ----
        s_sigma = tf.nn.softplus(self.s_rho) + self.eps
        eps_s = tf.random.normal(tf.shape(self.s_mu), dtype=self.s_mu.dtype)
        s = tf.where(training, self.s_mu + s_sigma * eps_s, self.s_mu)  # (d,)

        # ---- r: only the entries of unique tokens in the batch ----
        flat = tf.reshape(inputs, [-1])
        unique_ids, inv = tf.unique(flat)                    # (U,), (N,)

        r_mu_u = tf.gather(self.r_mu, unique_ids)            # (U,)
        r_rho_u = tf.gather(self.r_rho, unique_ids)          # (U,)
        r_sigma_u = tf.nn.softplus(r_rho_u) + self.eps

        eps_r = tf.random.normal(tf.shape(r_mu_u), dtype=r_mu_u.dtype)
        r_u = tf.where(training, r_mu_u + r_sigma_u * eps_r, r_mu_u)   # (U,)

        # ---- Forward: E_eff[i] = E0[i] * (1 + r[i]) * (1 + s) ----
        E0_u = tf.gather(self.E0, unique_ids)                # (U, d)
        r_scale = (1.0 + r_u)[:, tf.newaxis]                 # (U, 1)
        s_scale = (1.0 + s)[tf.newaxis, :]                   # (1, d)
        E_eff_u = E0_u * r_scale * s_scale                   # (U, d)

        # Map back to per-token embeddings: (N, d) -> inputs.shape + (d,)
        out_flat = tf.gather(E_eff_u, inv)                   # (N, d)
        out = tf.reshape(
            out_flat, tf.concat([tf.shape(inputs), [self.output_dim]], axis=0)
        )

        # ---- KL divergence ----
        # r: only the used entries
        qr = tfd.Normal(loc=r_mu_u, scale=r_sigma_u)
        kl_r = (
            tf.reduce_sum(qr.log_prob(r_u))
            - tf.reduce_sum(log_mixture_prior(r_u))
        )

        # s: full dimension
        qs = tfd.Normal(loc=self.s_mu, scale=s_sigma)
        kl_s = (
            tf.reduce_sum(qs.log_prob(s))
            - tf.reduce_sum(log_mixture_prior(s))
        )

        self.add_loss(self.kl_scale * 0.01 * (kl_r + kl_s))
        return out

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super().get_config()
        config.update({
            "input_dim": self.input_dim,
            "output_dim": self.output_dim,
            "kl_scale": self.kl_scale,
        })
        return config


# =============================================================================
# INDEPENDENT DROPOUT LAYER
# =============================================================================

class IndependentDropout(tf.keras.layers.Layer):
    """
    Dropout layer switched on/off by a non-trainable boolean ``tf.Variable``
    instead of Keras' ``training`` flag.

    The flag is read inside ``tf.cond`` at run time, so toggling it with
    ``set_dropout`` also affects already-traced graphs. This makes the layer
    usable for MC-dropout evaluation (dropout active at inference). Dropout
    starts disabled. If ``rate`` is not strictly between 0 and 1 the layer is
    the identity.

    Parameters
    ----------
    rate : float
        Dropout rate (between 0 and 1)
    """

    def __init__(self, rate, **kwargs):
        super().__init__(**kwargs)
        self.rate = float(rate)
        # Graph-safe mutable flag (read at run time inside tf.cond)
        self.use_dropout = tf.Variable(False, trainable=False, dtype=tf.bool)

    def call(self, inputs):
        if not (0.0 < self.rate < 1.0):
            return inputs

        def dropped():
            # tf.nn.dropout rescales kept units by 1 / (1 - rate)
            return tf.nn.dropout(inputs, rate=self.rate)

        return tf.cond(self.use_dropout, dropped, lambda: inputs)

    def set_dropout(self, use_dropout: bool):
        """Enable (True) or disable (False) dropout."""
        self.use_dropout.assign(bool(use_dropout))

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized.

        The on/off flag is runtime state and is not part of the config; a
        re-created layer starts with dropout disabled.
        """
        config = super().get_config()
        config.update({"rate": self.rate})
        return config


# =============================================================================
# HELPER FUNCTIONS
# =============================================================================

def set_dropout_active(model, active: bool):
    """
    Set dropout state for all IndependentDropout layers in a model.

    The model is walked recursively, and each object is visited at most once.
    The walk follows:

    - ``.layers`` containers, which covers sub-models at any depth;
    - ``submodules``, when present, which also reaches layers held as
      attributes of custom layers.

    NOTE(review): DeepEnsemble objects are only handled if they expose
    ``layers`` or ``submodules``. Confirm against that class.

    Parameters
    ----------
    model : tf.keras.Model or DeepEnsemble
        The model containing IndependentDropout layers
    active : bool
        Whether to enable (True) or disable (False) dropout

    Returns
    -------
    int
        Number of IndependentDropout layers updated (also printed)
    """

    def _children(obj):
        # Direct `.layers` plus transitive `submodules`; either may be absent.
        kids = list(getattr(obj, "layers", None) or ())
        kids.extend(getattr(obj, "submodules", None) or ())
        return kids

    hit = 0
    seen = {id(model)}
    pending = _children(model)
    while pending:
        layer = pending.pop()
        if id(layer) in seen:
            continue
        seen.add(id(layer))
        if isinstance(layer, IndependentDropout):
            layer.set_dropout(active)
            hit += 1
        pending.extend(_children(layer))

    print(f"IndependentDropout layers updated: {hit}")
    return hit


def set_kl_scale(model, value):
    """
    Set kl_scale on all variational layers in the model.

    Every descendant exposing a ``kl_scale`` attribute is updated. The model
    is walked recursively, and each object is visited at most once. The walk
    follows:

    - ``.layers`` containers, which covers nested models at any depth;
    - ``submodules``, when present, which covers layers held as attributes of
      custom layers.

    NOTE(review): the layers read ``self.kl_scale`` as a Python float inside
    ``call``. A training step that is already traced by ``tf.function`` has
    probably baked in the old value. Re-compile or retrace for the new value
    to take effect.

    Parameters
    ----------
    model : tf.keras.Model
        The model containing variational layers
    value : float
        The KL scale value to set (converted with ``float``)

    Returns
    -------
    int
        Number of layers whose ``kl_scale`` was set
    """
    value = float(value)

    def _children(obj):
        # Direct `.layers` plus transitive `submodules`; either may be absent.
        kids = list(getattr(obj, "layers", None) or ())
        kids.extend(getattr(obj, "submodules", None) or ())
        return kids

    updated = 0
    seen = {id(model)}
    pending = _children(model)
    while pending:
        layer = pending.pop()
        if id(layer) in seen:
            continue
        seen.add(id(layer))
        if hasattr(layer, 'kl_scale'):
            layer.kl_scale = value
            updated += 1
        pending.extend(_children(layer))
    return updated


def compile_binary(model):
    """
    Compile model with a binary classification objective.

    Uses Adam (learning rate 1e-3) and binary cross-entropy on predicted
    probabilities. Tracks accuracy at a 0.5 threshold, ROC-AUC and PR-AUC.

    Parameters
    ----------
    model : tf.keras.Model
        The model to compile. Its output is assumed to be a probability, since
        the loss is "binary_crossentropy" with default (non-logit) inputs.

    Returns
    -------
    tf.keras.Model
        The compiled model (the same object, compiled in place)
    """
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-3),
        loss="binary_crossentropy",
        metrics=[
            # BinaryAccuracy thresholds probabilities at 0.5. Plain Accuracy
            # checks exact equality with the 0/1 labels, which sigmoid
            # outputs essentially never satisfy.
            tf.keras.metrics.BinaryAccuracy(name='accuracy'),
            tf.keras.metrics.AUC(name="auc"),
            tf.keras.metrics.AUC(curve="PR", name="auprc"),
        ],
    )
    return model


# =============================================================================
# MODULE INFO
# =============================================================================

if __name__ == "__main__":
    # Running the module directly just prints an overview of its contents.
    _summary_lines = (
        "Bayesian layers module loaded successfully!",
        "\nAvailable layers:",
        "  - DenseVariational: Full-rank BBB dense layer",
        "  - EmbeddingVariational: Full-rank variational embedding",
        "  - LowRankDenseVariational: Low-rank BBB dense layer",
        "  - LowRankEmbeddingVariational: Low-rank variational embedding",
        "  - Rank1DenseVariational: Rank-1 BBB dense layer",
        "  - Rank1EmbeddingVariational: Rank-1 variational embedding",
        "  - IndependentDropout: Controllable dropout layer",
        "\nHelper functions:",
        "  - log_mixture_prior: Gaussian mixture prior log-density",
        "  - set_kl_scale: Set KL scale on all variational layers",
        "  - set_dropout_active: Control IndependentDropout layers",
    )
    for _line in _summary_lines:
        print(_line)
