'''
This file implements a baseline model for regression of multivariate stochastic functions with a single output.

2.  Bayes-by-Backprop
    Charles Blundell, et al, PMLR, 2015.
    http://proceedings.mlr.press/v37/blundell15.pdf

'''

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, Dict, Tuple, Optional, Callable
from dataclasses import dataclass
import math


@dataclass
class BayesByBackpropReturn:
    '''
    Bundle of everything a Bayes-by-Backprop prediction call hands back.

    Attributes
    ----------------
    mean: torch.Tensor
        Average over the Monte Carlo predictive samples.

    variance: torch.Tensor
        Variance over the Monte Carlo predictive samples.

    samples: torch.Tensor
        The raw Monte Carlo draws from the predictive distribution,
        stacked along the first dimension.

    x: Optional[torch.Tensor]
        The inputs the prediction was made for (may be None).

    train_loss_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]]
        Takes the targets and returns a dictionary of training-loss terms.

    val_loss_fn: Callable[[torch.Tensor], torch.Tensor]
        Takes the targets and returns the validation loss.
    '''
    # Field order defines the positional constructor signature; keep it stable.
    mean: torch.Tensor
    variance: torch.Tensor
    samples: torch.Tensor
    x: Optional[torch.Tensor]
    train_loss_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]]
    val_loss_fn: Callable[[torch.Tensor], torch.Tensor]


class BayesianLinear(nn.Module):
    '''
    Bayesian Linear layer with a factorised Gaussian variational posterior.

    Every weight and bias has a posterior N(mu, sigma^2) with
    sigma = softplus(rho). Each forward pass draws a fresh sample with the
    reparameterization trick (w = mu + sigma * eps, eps ~ N(0, 1)), so
    gradients reach both mu and rho. The log-density of the sampled weights
    under the scale-mixture prior and under the variational posterior is
    stored in ``self.log_prior`` and ``self.log_variational_posterior`` after
    every forward pass.

    Parameters
    ----------
    in_features : int
        Number of input features

    out_features : int
        Number of output features

    prior_sigma1 : float
        Standard deviation of the first Gaussian in the scale mixture prior (> 0)

    prior_sigma2 : float
        Standard deviation of the second Gaussian in the scale mixture prior (> 0)

    prior_pi : float
        Mixing proportion of the first (prior_sigma1) component, in [0, 1]

    Raises
    ------
    ValueError
        If a prior standard deviation is not positive or prior_pi is outside [0, 1].
    '''
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 prior_sigma1: float = 0.1,
                 prior_sigma2: float = 0.001,
                 prior_pi: float = 0.5):
        super(BayesianLinear, self).__init__()

        if prior_sigma1 <= 0 or prior_sigma2 <= 0:
            raise ValueError("prior_sigma1 and prior_sigma2 must be positive")
        if not 0.0 <= prior_pi <= 1.0:
            raise ValueError("prior_pi must lie in [0, 1]")

        self.in_features = in_features
        self.out_features = out_features

        # Prior distribution parameters
        self.prior_sigma1 = prior_sigma1
        self.prior_sigma2 = prior_sigma2
        self.prior_pi = prior_pi

        # Variational parameters (mean and rho). rho is initialised around -3,
        # so the initial posterior std is softplus(-3) ~= 0.049.
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features).normal_(0, 0.1))
        self.weight_rho = nn.Parameter(torch.empty(out_features, in_features).normal_(-3, 0.1))
        self.bias_mu = nn.Parameter(torch.empty(out_features).normal_(0, 0.1))
        self.bias_rho = nn.Parameter(torch.empty(out_features).normal_(-3, 0.1))

        # Overwritten on every forward pass.
        self.log_prior = 0
        self.log_variational_posterior = 0

        # log(sqrt(2*pi)), the normalising constant of a Gaussian log-density.
        self.log_sqrt_2pi = math.log(math.sqrt(2 * math.pi))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Sample weights and biases, record their log-densities, apply the affine map.

        Parameters
        ----------
        x : torch.Tensor [batch_size, in_features]
            Input tensor

        Returns
        -------
        out : torch.Tensor [batch_size, out_features]
            Output tensor
        '''
        # softplus keeps sigma positive; F.softplus is overflow-safe for large rho,
        # unlike log1p(exp(rho)), which returns inf once exp(rho) overflows.
        weight_sigma = F.softplus(self.weight_rho)
        weight = self.weight_mu + torch.randn_like(self.weight_mu) * weight_sigma

        bias_sigma = F.softplus(self.bias_rho)
        bias = self.bias_mu + torch.randn_like(self.bias_mu) * bias_sigma

        self.log_prior = (self._log_scale_mixture_prior(weight)
                          + self._log_scale_mixture_prior(bias))
        self.log_variational_posterior = (
            self._log_variational_posterior(weight, self.weight_mu, weight_sigma)
            + self._log_variational_posterior(bias, self.bias_mu, bias_sigma)
        )

        return F.linear(x, weight, bias)

    def _log_scale_mixture_prior(self, w: torch.Tensor) -> torch.Tensor:
        '''
        Summed log-density of ``w`` under the scale mixture prior.

        The prior is pi * N(0, sigma1^2) + (1 - pi) * N(0, sigma2^2).

        Parameters
        ----------
        w : torch.Tensor
            Weights or biases

        Returns
        -------
        log_prior : torch.Tensor
            Scalar sum of the element-wise log prior densities
        '''
        log_pdf1 = -0.5 * (w / self.prior_sigma1) ** 2 - math.log(self.prior_sigma1) - self.log_sqrt_2pi
        log_pdf2 = -0.5 * (w / self.prior_sigma2) ** 2 - math.log(self.prior_sigma2) - self.log_sqrt_2pi

        # Mix in log space. A zero mixing weight becomes -inf, which
        # logaddexp handles exactly (it returns the other component).
        log_pi1 = math.log(self.prior_pi) if self.prior_pi > 0 else -math.inf
        log_pi2 = math.log1p(-self.prior_pi) if self.prior_pi < 1 else -math.inf

        return torch.sum(torch.logaddexp(log_pdf1 + log_pi1, log_pdf2 + log_pi2))

    def _log_variational_posterior(self, w: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        '''
        Summed log-density of ``w`` under the factorised Gaussian N(mu, sigma^2).

        Parameters
        ----------
        w : torch.Tensor
            Sampled weights or biases

        mu : torch.Tensor
            Mean of the posterior

        sigma : torch.Tensor
            Standard deviation of the posterior

        Returns
        -------
        log_posterior : torch.Tensor
            Scalar sum of the element-wise log posterior densities
        '''
        return torch.sum(-0.5 * ((w - mu) / sigma) ** 2 - torch.log(sigma) - self.log_sqrt_2pi)


class BayesByBackprop(nn.Module):
    '''
    Bayes-by-Backprop model for regression.

    This model implements Bayes-by-Backprop (Blundell et al., 2015), which uses
    variational inference to approximate the posterior distribution over the weights
    of a neural network. Every forward pass samples new weights from the
    variational posterior (in both training and inference), and predictive
    uncertainty is estimated from several such passes.

    Parameters
    ----------
    dim_input : int
        Number of input features

    dim_output : int
        Number of output features (typically 1 for regression)

    dim_hidden : int
        Number of neurons in each hidden layer

    n_hidden_layers : int
        Number of hidden layers. Values below 1 still build one hidden layer
        (input -> hidden -> output).

    prior_sigma1 : float
        Standard deviation for the first Gaussian in the scale mixture prior

    prior_sigma2 : float
        Standard deviation for the second Gaussian in the scale mixture prior

    prior_pi : float
        Mixing proportion for the scale mixture prior (between 0 and 1)

    activation : nn.Module | None
        Activation applied after every hidden layer. ``None`` (default) creates
        a fresh ``nn.ReLU()`` for this instance, so instances never share one
        module object through a default argument.
    '''
    def __init__(self,
                 dim_input: int,
                 dim_output: int,
                 dim_hidden: int,
                 n_hidden_layers: int,
                 prior_sigma1: float = 0.1,
                 prior_sigma2: float = 0.001,
                 prior_pi: float = 0.5,
                 activation: Optional[nn.Module] = None):

        super(BayesByBackprop, self).__init__()

        self.dim_input = dim_input
        self.dim_output = dim_output
        self.dim_hidden = dim_hidden
        self.n_hidden_layers = n_hidden_layers
        self.prior_sigma1 = prior_sigma1
        self.prior_sigma2 = prior_sigma2
        self.prior_pi = prior_pi
        self.activation = activation if activation is not None else nn.ReLU()

        # Defined up front so they exist before the first forward pass.
        self.log_prior = 0
        self.log_variational_posterior = 0

        # Input layer
        self.layers = nn.ModuleList([BayesianLinear(dim_input, dim_hidden,
                                                    prior_sigma1, prior_sigma2, prior_pi)])

        # Additional hidden layers
        for _ in range(n_hidden_layers - 1):
            self.layers.append(BayesianLinear(dim_hidden, dim_hidden,
                                              prior_sigma1, prior_sigma2, prior_pi))

        # Output layer
        self.layers.append(BayesianLinear(dim_hidden, dim_output,
                                          prior_sigma1, prior_sigma2, prior_pi))

    @property
    def device(self) -> torch.device:
        '''
        Device of the model's first parameter.
        '''
        return next(self.parameters()).device

    @property
    def name(self) -> str:
        '''
        Model name.
        '''
        return "BayesByBackprop"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        One stochastic forward pass, with a fresh weight sample in every layer.

        Side effect: ``self.log_prior`` and ``self.log_variational_posterior``
        are set to the sums of the per-layer values for the weights sampled
        in this pass.

        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor

        Returns
        -------
        out : torch.Tensor [batch_size, dim_output]
            Output tensor
        '''
        log_prior = 0
        log_variational_posterior = 0

        last = len(self.layers) - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            # Hidden layers get the activation; the output layer stays linear.
            if idx < last:
                x = self.activation(x)
            log_prior = log_prior + layer.log_prior
            log_variational_posterior = log_variational_posterior + layer.log_variational_posterior

        self.log_prior = log_prior
        self.log_variational_posterior = log_variational_posterior
        return x

    @staticmethod
    def _align_target(y_true: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        '''
        Reshape ``y_true`` to ``reference.shape`` when only the layout differs.

        This prevents e.g. a [B] target against a [B, 1] prediction from
        silently broadcasting to [B, B]. Targets whose element count differs
        are returned unchanged.
        '''
        if y_true.shape != reference.shape and y_true.numel() == reference.numel():
            return y_true.reshape(reference.shape)
        return y_true

    def get_prediction(self,
                       x: torch.Tensor,
                       n_samples: int = 20,
                       kl_weight: Optional[float] = None,
                       kl_clip: Optional[float] = 1000.0) -> BayesByBackpropReturn:
        '''
        Get prediction with uncertainty estimation.

        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor

        n_samples : int
            Number of Monte Carlo samples (must be >= 1)

        kl_weight : float | None
            Multiplier on the KL term of the training loss. ``None`` (default)
            uses ``1 / batch_size``. For the usual per-example ELBO scaling,
            pass ``1 / number_of_training_examples``.

        kl_clip : float | None
            The scaled KL term is clamped to ``[-kl_clip, kl_clip]``. Note that
            clamping zeroes the KL gradient whenever the bound is active.
            ``None`` disables clamping.

        Returns
        -------
        BayesByBackpropReturn
            Object containing mean, variance, samples, and loss closures

        Raises
        ------
        ValueError
            If ``n_samples`` is smaller than 1.
        '''
        if n_samples < 1:
            raise ValueError("n_samples must be at least 1")

        # NOTE: switches the model to eval mode and leaves it there. BayesianLinear
        # samples weights regardless of mode, so predictions remain stochastic.
        self.eval()

        # Draw predictions and keep the KL estimate of every draw, not just the last one.
        draws = []
        kl_terms = []
        for _ in range(n_samples):
            draws.append(self.forward(x))
            kl_terms.append(self.log_variational_posterior - self.log_prior)

        samples = torch.stack(draws, dim=0)
        kl_divergence = sum(kl_terms) / n_samples

        mean = torch.mean(samples, dim=0)
        # Unbiased variance is undefined (NaN) for a single sample.
        variance = torch.var(samples, dim=0) if n_samples > 1 else torch.zeros_like(mean)

        weight = 1.0 / x.size(0) if kl_weight is None else kl_weight

        def train_loss_fn(y_true: torch.Tensor) -> Dict[str, torch.Tensor]:
            # Negative ELBO = expected data loss under q + weighted KL(q || prior).
            target = self._align_target(y_true, mean)

            # Monte Carlo estimate of the expected data term: average the error
            # of every sampled prediction. Scoring only the averaged prediction
            # would ignore the spread of the samples.
            mse_loss = ((samples - target) ** 2).mean()

            kl_scaled = kl_divergence * weight
            if kl_clip is not None:
                kl_scaled = torch.clamp(kl_scaled, -kl_clip, kl_clip)

            total_loss = mse_loss + kl_scaled

            return {
                "loss": total_loss,
                "mse": mse_loss,
                "kl": kl_scaled
            }

        def val_loss_fn(y_true: torch.Tensor) -> torch.Tensor:
            # Validation scores the predictive mean.
            return F.mse_loss(mean, self._align_target(y_true, mean))

        return BayesByBackpropReturn(
            mean=mean,
            variance=variance,
            samples=samples,
            x=x,
            train_loss_fn=train_loss_fn,
            val_loss_fn=val_loss_fn
        )
    