"""
Bayesian LSTM Variational Layers

This module contains all variational layer implementations for Bayesian LSTMs:
- DenseVariational: Full-rank Bayes by Backprop
- LowRankDenseVariational: Low-rank factorization Bayes by Backprop
- Rank1DenseVariational: Rank-1 multiplicative reparameterization

"""

import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np

tfd = tfp.distributions


# ==============================================================================
# Full-Rank Dense Variational Layer
# ==============================================================================

class DenseVariational(tf.keras.layers.Layer):
    """
    Full-rank Bayes by Backprop Dense layer with weight caching for LSTM.

    Every weight and bias entry has an independent Gaussian posterior
    N(mu, softplus(rho)^2). The prior is a two-component scale mixture of
    zero-mean Gaussians configured through ``prior_params``.

    Key features:
    1. Sample weights once, then reuse across multiple calls (for LSTM timesteps)
    2. use_cached=False: Sample fresh weights
    3. use_cached=True: Reuse previously sampled weights; if the cache is
       empty a fresh sample is drawn (and its KL term is registered)

    Usage pattern for LSTM:
        layer.clear_cache()
        # First timestep: sample fresh
        output_t0 = layer(input_t0, training=True, use_cached=False)
        # Remaining timesteps: reuse same weights
        for t in range(1, T):
            output_t = layer(input_t, training=True, use_cached=True)
    This ensures one sampling per batch as required by DeepMind Algorithm 2.
    """
    def __init__(
        self,
        units,
        use_bias=True,
        bias_initializer='zeros',
        prior_params=None,
        name=None,
        **kwargs
    ):
        """
        Args:
            units: Output dimension of the layer.
            use_bias: Whether a (variational) bias vector is added.
            bias_initializer: Initializer for the bias posterior mean.
            prior_params: Dict with keys 'pi', 'sigma1', 'sigma2' describing
                the scale-mixture prior. Defaults to pi=0.5, sigma1=1.0,
                sigma2=exp(-6).
            name: Optional layer name.
            **kwargs: Forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(name=name, **kwargs)
        self.units = units
        self.use_bias = use_bias
        self.bias_initializer = tf.keras.initializers.get(bias_initializer)
        # Prior: mixture of two Gaussians. The default uses plain Python
        # floats (not eager tensors) so the configuration is independent of
        # the graph/device that happens to be active at construction time.
        if prior_params is None:
            prior_params = {
                'pi': 0.5,
                'sigma1': 1.0,
                'sigma2': float(np.exp(-6.0)),
            }
        self.prior_params = prior_params
        # Cached sampled weights (for reuse across timesteps)
        self.cached_W = None
        self.cached_b = None
        # Non-trainable multiplier applied to the KL term added via add_loss
        self.kl_scale = tf.Variable(1.0, trainable=False, dtype=tf.float32)
        # Weights used by the most recent forward pass (for KL computation)
        self.sampled_W = None
        self.sampled_b = None

    def build(self, input_shape):
        """Create posterior mean (mu_*) and pre-softplus scale (rho_*) variables."""
        input_dim = int(input_shape[-1])
        # Weight mean
        self.mu_W = self.add_weight(
            name='mu_W',
            shape=(input_dim, self.units),
            initializer=tf.keras.initializers.RandomUniform(-0.2, 0.2),
            trainable=True,
        )
        # Weight scale parameter: sigma_W = softplus(rho_W)
        self.rho_W = self.add_weight(
            name='rho_W',
            shape=(input_dim, self.units),
            initializer=tf.keras.initializers.RandomUniform(-3, -2),
            trainable=True,
        )
        if self.use_bias:
            # Bias mean - custom initializer for forget gate
            self.mu_b = self.add_weight(
                name='mu_b',
                shape=(self.units,),
                initializer=self.bias_initializer,
                trainable=True,
            )
            # Bias scale parameter: sigma_b = softplus(rho_b)
            self.rho_b = self.add_weight(
                name='rho_b',
                shape=(self.units,),
                initializer=tf.keras.initializers.RandomUniform(-3.5, -2.5),
                trainable=True,
            )
        super().build(input_shape)

    def _sample_weights(self, dtype):
        """Draw one reparameterized sample (W, b) from the posterior; b is None without bias."""
        sigma_W = tf.nn.softplus(self.rho_W)
        epsilon_W = tf.random.normal(shape=self.mu_W.shape, dtype=dtype)
        W = self.mu_W + sigma_W * epsilon_W
        if not self.use_bias:
            return W, None
        sigma_b = tf.nn.softplus(self.rho_b)
        epsilon_b = tf.random.normal(shape=self.mu_b.shape, dtype=dtype)
        return W, self.mu_b + sigma_b * epsilon_b

    def call(self, inputs, training=None, sample_at_inference=True, use_cached=False):
        """
        Forward pass with optional weight caching.
        Args:
            inputs: Input tensor
            training: Training mode flag
            sample_at_inference: Whether to sample during inference (True for Bayesian)
            use_cached: If True, reuse cached weights from previous call
                        (a fresh sample is drawn if the cache is empty).
                        If False, sample fresh weights
        Returns:
            output: Dense layer output
        """
        drew_fresh_sample = False
        if training or sample_at_inference:
            if use_cached and self.cached_W is not None:
                # Reuse cached weights (LSTM timesteps 1..T-1)
                W, b = self.cached_W, self.cached_b
            else:
                # Sample fresh weights (LSTM timestep 0) and cache them
                W, b = self._sample_weights(inputs.dtype)
                self.cached_W, self.cached_b = W, b
                drew_fresh_sample = True
        else:
            # Inference with posterior mean (deterministic)
            W = self.mu_W
            b = self.mu_b if self.use_bias else None
        # Remember the weights actually used, for KL computation
        self.sampled_W = W
        self.sampled_b = b
        # Dense operation
        output = tf.matmul(inputs, W)
        if b is not None:
            output = output + b
        # Register the KL term once per fresh sample. Keying on the sample
        # (rather than on the use_cached flag) also covers use_cached=True
        # with an empty cache, which would otherwise train without a KL term.
        if training and drew_fresh_sample:
            self.add_loss(self.kl_scale * self.kl_divergence())
        return output

    def clear_cache(self):
        """
        Clear cached weights. Call this at the start of each new batch.
        """
        self.cached_W = None
        self.cached_b = None

    def kl_divergence(self):
        """
        Single-sample estimate of KL(q(W) || p(W)): log q(w) - log p(w)
        evaluated at the weights used by the most recent forward pass.
        Falls back to the posterior means if no forward pass has run yet.
        """
        W = self.sampled_W if self.sampled_W is not None else self.mu_W
        # Posterior distribution for weights
        qw = tfd.Normal(loc=self.mu_W, scale=tf.nn.softplus(self.rho_W))
        log_q = tf.reduce_sum(qw.log_prob(W))
        log_p = tf.reduce_sum(self._log_mixture_prior(W))
        # Add bias contribution if used
        if self.use_bias:
            b = self.sampled_b if self.sampled_b is not None else self.mu_b
            qb = tfd.Normal(loc=self.mu_b, scale=tf.nn.softplus(self.rho_b))
            log_q += tf.reduce_sum(qb.log_prob(b))
            log_p += tf.reduce_sum(self._log_mixture_prior(b))
        return log_q - log_p

    def _log_mixture_prior(self, w):
        """
        Compute elementwise log p(w) under the scale mixture prior:
        p(w) = pi * N(0, sigma1^2) + (1-pi) * N(0, sigma2^2)
        """
        pi = self.prior_params['pi']
        sigma1 = self.prior_params['sigma1']
        sigma2 = self.prior_params['sigma2']
        log_p1 = tfd.Normal(loc=0.0, scale=sigma1).log_prob(w)
        log_p2 = tfd.Normal(loc=0.0, scale=sigma2).log_prob(w)
        # logsumexp over the two weighted components keeps the mixture
        # evaluation in log space.
        return tf.math.reduce_logsumexp(
            tf.stack([
                tf.math.log(pi) + log_p1,
                tf.math.log(1.0 - pi) + log_p2
            ], axis=0),
            axis=0
        )


# ==============================================================================
# Low-Rank Dense Variational Layer
# ==============================================================================

class LowRankDenseVariational(tf.keras.layers.Layer):
    """
    Low-Rank Bayes by Backprop with weight caching for LSTM.

    Factorization: W = A @ B^T where A: (input_dim, rank), B: (units, rank)
    Both A and B have their own mean and variance parameters, i.e. each
    entry of A, B (and the bias) has an independent Gaussian posterior
    N(mu, softplus(rho)^2).

    Caching follows the same protocol as DenseVariational: call with
    use_cached=False on the first timestep and use_cached=True afterwards,
    and call clear_cache() at the start of each batch.
    """
    def __init__(
        self,
        units,
        rank=10,
        use_bias=True,
        bias_initializer='zeros',
        activation=None,
        prior_params=None,
        name=None,
        **kwargs
    ):
        """
        Args:
            units: Output dimension of the layer.
            rank: Rank r of the factorization W = A @ B^T.
            use_bias: Whether a (variational) bias vector is added.
            bias_initializer: Initializer for the bias posterior mean.
            activation: Activation the layer feeds into. Used only in
                build() to choose the initialization variance (He vs
                Glorot); it is NOT applied to the output of call().
            prior_params: Dict with keys 'pi', 'sigma1', 'sigma2' describing
                the scale-mixture prior. Defaults to pi=0.5, sigma1=1.0,
                sigma2=exp(-6).
            name: Optional layer name.
            **kwargs: Forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(name=name, **kwargs)
        self.units = units
        self.rank = rank
        self.use_bias = use_bias
        self.bias_initializer = tf.keras.initializers.get(bias_initializer)
        self.activation = activation
        # Prior: mixture of two Gaussians. The default uses plain Python
        # floats (not eager tensors) so the configuration is independent of
        # the graph/device that happens to be active at construction time.
        if prior_params is None:
            prior_params = {
                'pi': 0.5,
                'sigma1': 1.0,
                'sigma2': float(np.exp(-6.0)),
            }
        self.prior_params = prior_params
        # Cached sampled weights
        self.cached_W = None
        self.cached_b = None
        # Non-trainable multiplier applied to the KL term added via add_loss
        self.kl_scale = tf.Variable(1.0, trainable=False, dtype=tf.float32)
        # Factors / weights used by the most recent forward pass (for KL)
        self.sampled_A = None
        self.sampled_B = None
        self.sampled_W = None
        self.sampled_b = None

    def _is_relu_family(self, activation):
        """Check if activation belongs to ReLU family for He initialization."""
        if activation is None:
            return False

        # Accept a string, a function, or an object (e.g. an activation
        # layer instance), falling back to the class name when the object
        # has no __name__ of its own.
        if isinstance(activation, str):
            name = activation
        else:
            name = getattr(activation, '__name__', None) or type(activation).__name__

        name_normalized = name.lower().replace('_', '')

        # ReLU family activations
        relu_family = {'relu', 'elu', 'selu', 'leakyrelu', 'relu6', 'gelu'}

        return name_normalized in relu_family

    def build(self, input_shape):
        """Create factor/bias posterior variables with rank-aware initialization."""
        din = int(input_shape[-1])   # Input dimension
        dout = self.units             # Output dimension
        r = self.rank                 # Rank of factorization

        # ADAPTIVE INITIALIZATION - Variance-preserving & rank-aware

        # Determine target variance of W based on activation
        if self._is_relu_family(self.activation):
            sigma_w_sq = 2.0 / din              # He initialization for ReLU family
        else:
            sigma_w_sq = 2.0 / (din + dout)     # Glorot initialization

        # Uniform bound for the factor means.
        # Formula: a = sqrt(3) * (sigma_W^2 / r)^(1/4)
        # Fourth root provides variance-preserving scaling with rank
        a = np.sqrt(3.0) * np.power(sigma_w_sq / r, 0.25)

        # Initial posterior std of the factors:
        # sigma = eta * (sigma_W^2 / r)^(1/4), with eta in (0, 1).
        # eta = 0.7 is the value used here.
        eta = 0.7
        sigma_init = eta * np.power(sigma_w_sq / r, 0.25)

        # Invert softplus: rho such that softplus(rho) = sigma_init
        if sigma_init > 1e-4:
            rho_center = np.log(np.expm1(sigma_init))
        else:
            # For small sigma: exp(sigma) - 1 ~= sigma, hence
            # log(exp(sigma) - 1) ~= log(sigma).
            rho_center = np.log(sigma_init)

        # Add diversity: small random perturbation +/-0.3 around center
        rho_min = rho_center - 0.3
        rho_max = rho_center + 0.3

        # WEIGHT INITIALIZATION

        # Factor A mean: U[-a, a]
        self.A_mu = self.add_weight(
            name="A_mu",
            shape=[din, r],
            initializer=tf.keras.initializers.RandomUniform(-a, a),
            trainable=True
        )

        # Factor B mean: U[-a, a]
        self.B_mu = self.add_weight(
            name="B_mu",
            shape=[dout, r],
            initializer=tf.keras.initializers.RandomUniform(-a, a),
            trainable=True
        )

        # Factor A scale parameter (sigma = softplus(rho)): random for diversity
        self.A_rho = self.add_weight(
            name="A_rho",
            shape=[din, r],
            initializer=tf.keras.initializers.RandomUniform(rho_min, rho_max),
            trainable=True
        )

        # Factor B scale parameter (sigma = softplus(rho)): random for diversity
        self.B_rho = self.add_weight(
            name="B_rho",
            shape=[dout, r],
            initializer=tf.keras.initializers.RandomUniform(rho_min, rho_max),
            trainable=True
        )

        # BIAS INITIALIZATION

        if self.use_bias:
            # Bias mean: honor the configured initializer (e.g. a custom LSTM
            # forget-gate bias). The default 'zeros' yields a zero bias.
            self.b_mu = self.add_weight(
                name="b_mu",
                shape=[dout],
                initializer=self.bias_initializer,
                trainable=True
            )

            # Bias scale parameter: random for diversity
            self.b_rho = self.add_weight(
                name="b_rho",
                shape=[dout],
                initializer=tf.keras.initializers.RandomUniform(rho_min, rho_max),
                trainable=True
            )

        super().build(input_shape)

    def init_from_full_matrix(self, W_full):
        """
        Initialize A_mu and B_mu from a full-rank matrix using truncated SVD.

        Only the posterior means are overwritten; the rho parameters and the
        bias are left untouched. With W_full = U diag(s) V^T, the factors are
        A = U_r * sqrt(s_r) and B = V_r * sqrt(s_r), zero-padded when the
        matrix has fewer singular values than ``rank``.

        Args:
            W_full: Array-like of shape (input_dim, units).
        Raises:
            ValueError: If the layer is not built yet or W_full has the
                wrong shape.
        """
        if not hasattr(self, 'A_mu') or not hasattr(self, 'B_mu'):
            raise ValueError(
                "init_from_full_matrix requires a built layer; "
                "call build() or run a forward pass first."
            )
        W_full = np.asarray(W_full)
        expected_shape = (int(self.A_mu.shape[0]), int(self.B_mu.shape[0]))
        if W_full.shape != expected_shape:
            raise ValueError(
                f"W_full must have shape {expected_shape}, got {W_full.shape}."
            )
        U, s, Vt = np.linalg.svd(W_full, full_matrices=False)
        r_eff = min(self.rank, len(s))
        # Split each singular value evenly between the two factors
        root_s = np.sqrt(s[:r_eff])
        A = U[:, :r_eff] * root_s
        B = Vt[:r_eff, :].T * root_s
        if r_eff < self.rank:
            A_pad = np.zeros((A.shape[0], self.rank), dtype=A.dtype)
            B_pad = np.zeros((B.shape[0], self.rank), dtype=B.dtype)
            A_pad[:, :r_eff] = A
            B_pad[:, :r_eff] = B
            A, B = A_pad, B_pad
        self.A_mu.assign(A)
        self.B_mu.assign(B)

    def _sample_factors(self, dtype):
        """Draw one reparameterized sample (A, B, b) from the posterior; b is None without bias."""
        epsilon_A = tf.random.normal(shape=self.A_mu.shape, dtype=dtype)
        A = self.A_mu + tf.nn.softplus(self.A_rho) * epsilon_A
        epsilon_B = tf.random.normal(shape=self.B_mu.shape, dtype=dtype)
        B = self.B_mu + tf.nn.softplus(self.B_rho) * epsilon_B
        if not self.use_bias:
            return A, B, None
        epsilon_b = tf.random.normal(shape=self.b_mu.shape, dtype=dtype)
        b = self.b_mu + tf.nn.softplus(self.b_rho) * epsilon_b
        return A, B, b

    def call(self, inputs, training=None, sample_at_inference=True, use_cached=False):
        """
        Forward pass with optional weight caching.
        Args:
            inputs: Input tensor
            training: Training mode flag
            sample_at_inference: Whether to sample during inference (True for Bayesian)
            use_cached: If True, reuse cached weights from previous call
                        (a fresh sample is drawn if the cache is empty).
                        If False, sample fresh factors
        Returns:
            output: inputs @ (A @ B^T) + b (no activation is applied)
        """
        drew_fresh_sample = False
        if training or sample_at_inference:
            if use_cached and self.cached_W is not None:
                # Reuse cached weights (LSTM timesteps 1..T-1); sampled_A/B
                # keep the factors that produced the cached W.
                W, b = self.cached_W, self.cached_b
            else:
                # Sample fresh factors (LSTM timestep 0)
                A, B, b = self._sample_factors(inputs.dtype)
                # Compute full weight matrix W = A @ B^T
                W = tf.matmul(A, B, transpose_b=True)  # (input_dim, units)
                # Store sampled factors for KL
                self.sampled_A = A
                self.sampled_B = B
                # Cache for subsequent calls
                self.cached_W, self.cached_b = W, b
                drew_fresh_sample = True
        else:
            # Inference with posterior mean (deterministic)
            W = tf.matmul(self.A_mu, self.B_mu, transpose_b=True)
            b = self.b_mu if self.use_bias else None
            self.sampled_A = self.A_mu
            self.sampled_B = self.B_mu
        # Remember the weights actually used
        self.sampled_W = W
        self.sampled_b = b
        # Dense operation
        output = tf.matmul(inputs, W)
        if b is not None:
            output = output + b
        # Register the KL term once per fresh sample. Keying on the sample
        # (rather than on the use_cached flag) also covers use_cached=True
        # with an empty cache, which would otherwise train without a KL term.
        if training and drew_fresh_sample:
            self.add_loss(self.kl_scale * self.kl_divergence())
        return output

    def clear_cache(self):
        """Clear cached weights. Call this at the start of each new batch."""
        self.cached_W = None
        self.cached_b = None

    def kl_divergence(self):
        """
        Single-sample estimate of KL(q(A,B,b) || p(A,B,b)): log q - log p
        evaluated at the factors used by the most recent fresh sample.
        KL for A + KL for B + KL for bias. Falls back to the posterior means
        if no forward pass has run yet.
        """
        A = self.sampled_A if self.sampled_A is not None else self.A_mu
        B = self.sampled_B if self.sampled_B is not None else self.B_mu
        # Factor A
        qA = tfd.Normal(loc=self.A_mu, scale=tf.nn.softplus(self.A_rho))
        log_q = tf.reduce_sum(qA.log_prob(A))
        log_p = tf.reduce_sum(self._log_mixture_prior(A))
        # Factor B
        qB = tfd.Normal(loc=self.B_mu, scale=tf.nn.softplus(self.B_rho))
        log_q += tf.reduce_sum(qB.log_prob(B))
        log_p += tf.reduce_sum(self._log_mixture_prior(B))
        # Add bias contribution if used
        if self.use_bias:
            b = self.sampled_b if self.sampled_b is not None else self.b_mu
            qb = tfd.Normal(loc=self.b_mu, scale=tf.nn.softplus(self.b_rho))
            log_q += tf.reduce_sum(qb.log_prob(b))
            log_p += tf.reduce_sum(self._log_mixture_prior(b))
        return log_q - log_p

    def _log_mixture_prior(self, w):
        """
        Compute elementwise log p(w) under the scale mixture prior:
        p(w) = pi * N(0, sigma1^2) + (1-pi) * N(0, sigma2^2)
        """
        pi = self.prior_params['pi']
        sigma1 = self.prior_params['sigma1']
        sigma2 = self.prior_params['sigma2']
        log_p1 = tfd.Normal(loc=0.0, scale=sigma1).log_prob(w)
        log_p2 = tfd.Normal(loc=0.0, scale=sigma2).log_prob(w)
        # logsumexp over the two weighted components keeps the mixture
        # evaluation in log space.
        return tf.math.reduce_logsumexp(
            tf.stack([
                tf.math.log(pi) + log_p1,
                tf.math.log(1.0 - pi) + log_p2
            ], axis=0),
            axis=0
        )

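# ==============================================================================
# Rank-1 Dense Variational Layer
# ==============================================================================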

def log_mixture_prior(w, pi=0.5, sigma1=1.0, sigma2=None):
    """
    Elementwise log-density of the Bayes by Backprop scale-mixture prior.

    p(w) = pi * N(0, sigma1^2) + (1-pi) * N(0, sigma2^2)

    Args:
        w: Tensor at which the prior is evaluated.
        pi: Mixture weight of the first (sigma1) component.
        sigma1: Standard deviation of the first component.
        sigma2: Standard deviation of the second component; exp(-6) when
            left as None.
    Returns:
        Tensor of log p(w), evaluated elementwise on w.
    """
    narrow_scale = tf.exp(-6.0) if sigma2 is None else sigma2
    wide_log_prob = tfd.Normal(loc=0.0, scale=sigma1).log_prob(w)
    narrow_log_prob = tfd.Normal(loc=0.0, scale=narrow_scale).log_prob(w)
    # Weight each component in log space, then combine with logsumexp.
    weighted_components = tf.stack(
        [
            tf.math.log(pi) + wide_log_prob,
            tf.math.log(1.0 - pi) + narrow_log_prob,
        ],
        axis=0,
    )
    return tf.math.reduce_logsumexp(weighted_components, axis=0)


class Rank1DenseVariational(tf.keras.layers.Layer):
    """
    Rank-1 multiplicative reparameterization (Dusenberry et al., 2020) with
    LSTM-friendly weight caching. Samples rank-1 factors once per batch and
    reuses them across timesteps when use_cached=True.

    The kernel W0 and the bias are ordinary trainable weights; only the
    per-input factor s and per-output factor r are stochastic, each with an
    independent Gaussian posterior N(mu, softplus(rho)^2). The output is
        y = activation(((x * (1 + s)) @ W0) * (1 + r) + b)
    """
    def __init__(
        self,
        units,
        use_bias=True,
        bias_initializer='zeros',
        activation=None,
        name=None,
        **kwargs,
    ):
        """
        Args:
            units: Output dimension of the layer.
            use_bias: Whether a (deterministic) bias vector is added.
            bias_initializer: Initializer for the bias.
            activation: Activation applied to the output, resolved through
                ``tf.keras.activations.get``.
            name: Optional layer name.
            **kwargs: Forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(name=name, **kwargs)
        self.units = int(units)
        self.use_bias = use_bias
        self.bias_initializer = tf.keras.initializers.get(bias_initializer)
        self.activation = tf.keras.activations.get(activation)
        # Non-trainable multiplier applied to the KL term added via add_loss
        self.kl_scale = tf.Variable(1.0, trainable=False, dtype=tf.float32)
        # Cached factors (for reuse across timesteps)
        self.cached_r = None
        self.cached_s = None
        # Factors used by the most recent forward pass (for KL computation)
        self.sampled_r = None
        self.sampled_s = None

    def build(self, input_shape):
        """Create the deterministic kernel/bias and the rank-1 factor posteriors."""
        input_dim = int(input_shape[-1])
        self.W0 = self.add_weight(
            name='W0',
            shape=(input_dim, self.units),
            # Glorot initializer matches the deterministic model initializer
            initializer='glorot_uniform',
            trainable=True,
        )
        if self.use_bias:
            self.b = self.add_weight(
                name='bias',
                shape=(self.units,),
                initializer=self.bias_initializer,
                trainable=True,
            )
        # Output-side factor r: one entry per unit
        self.r_mu = self.add_weight(
            name='r_mu',
            shape=(self.units,),
            initializer=tf.keras.initializers.RandomUniform(-0.2, 0.2),
            trainable=True,
        )
        self.r_rho = self.add_weight(
            name='r_rho',
            shape=(self.units,),
            initializer=tf.keras.initializers.RandomUniform(-3.0, -2.0),
            trainable=True,
        )
        # Input-side factor s: one entry per input feature
        self.s_mu = self.add_weight(
            name='s_mu',
            shape=(input_dim,),
            initializer=tf.keras.initializers.RandomUniform(-0.2, 0.2),
            trainable=True,
        )
        self.s_rho = self.add_weight(
            name='s_rho',
            shape=(input_dim,),
            initializer=tf.keras.initializers.RandomUniform(-3.0, -2.0),
            trainable=True,
        )
        super().build(input_shape)

    def _sample(self, mu, rho, dtype):
        """Return (mu + softplus(rho) * eps, softplus(rho)) with eps ~ N(0, 1)."""
        sigma = tf.nn.softplus(rho)
        eps = tf.random.normal(shape=tf.shape(mu), dtype=dtype)
        return mu + sigma * eps, sigma

    def call(self, inputs, training=None, sample_at_inference=True, use_cached=False):
        """
        Forward pass with optional factor caching.
        Args:
            inputs: Input tensor
            training: Training mode flag
            sample_at_inference: Whether to sample during inference (True for Bayesian)
            use_cached: If True, reuse cached factors from previous call
                        (a fresh sample is drawn if the cache is empty).
                        If False, sample fresh factors
        Returns:
            y: Layer output after bias and activation
        """
        drew_fresh_sample = False
        if training or sample_at_inference:
            if use_cached and self.cached_r is not None and self.cached_s is not None:
                # Reuse cached factors (LSTM timesteps 1..T-1)
                r, s = self.cached_r, self.cached_s
            else:
                # Sample fresh factors (LSTM timestep 0) and cache them
                r, _ = self._sample(self.r_mu, self.r_rho, inputs.dtype)
                s, _ = self._sample(self.s_mu, self.s_rho, inputs.dtype)
                self.cached_r, self.cached_s = r, s
                drew_fresh_sample = True
        else:
            # Inference with posterior mean (deterministic)
            r = self.r_mu
            s = self.s_mu

        # Remember the factors actually used, for KL computation
        self.sampled_r = r
        self.sampled_s = s

        # Multiplicative perturbation around the shared kernel W0
        x_scaled = inputs * (1.0 + s)
        y = tf.matmul(x_scaled, self.W0)
        y = y * (1.0 + r)
        if self.use_bias:
            y = y + self.b
        if self.activation is not None:
            y = self.activation(y)

        # Register the KL term once per fresh sample. Keying on the sample
        # (rather than on the use_cached flag) also covers use_cached=True
        # with an empty cache, which would otherwise train without a KL term.
        if training and drew_fresh_sample:
            self.add_loss(self.kl_scale * self.kl_divergence())

        return y

    def clear_cache(self):
        """Clear cached factors. Call this at the start of each new batch."""
        self.cached_r = None
        self.cached_s = None

    def kl_divergence(self):
        """
        Single-sample estimate of KL(q(r,s) || p(r,s)): log q - log p
        evaluated at the factors used by the most recent forward pass, under
        the default scale-mixture prior of ``log_mixture_prior``. Falls back
        to the posterior means if no forward pass has run yet.
        """
        r = self.sampled_r if self.sampled_r is not None else self.r_mu
        s = self.sampled_s if self.sampled_s is not None else self.s_mu
        qr = tfd.Normal(loc=self.r_mu, scale=tf.nn.softplus(self.r_rho))
        qs = tfd.Normal(loc=self.s_mu, scale=tf.nn.softplus(self.s_rho))
        log_q = tf.reduce_sum(qr.log_prob(r)) + tf.reduce_sum(qs.log_prob(s))
        log_p = tf.reduce_sum(log_mixture_prior(r)) + tf.reduce_sum(log_mixture_prior(s))
        return log_q - log_p
