'''
This file implements a baseline model for the regression of multivariate stochastic functions with one output.

4.  Deterministic Variational Inference (DVI, dDVI, DVI-MC)
    Anqi Wu, et al, arXiv, 2019.
    https://arxiv.org/pdf/1810.03958

'''

import math
import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from baseline.common import compute_metrics


@dataclass
class DVIReturn:
    '''
    Container for the output of a Deterministic Variational Inference model.

    Instances are built by ``DVI.get_prediction``; the two loss callables are
    closures over the prediction stored in the same instance.

    Attributes
    ----------------
    mean: torch.Tensor
        The mean prediction of the model (output of the network's forward
        pass).

    variance: torch.Tensor
        The predictive variance: the output-layer variance plus the
        observation-noise variance, clamped to be at least 1e-8.

    train_loss_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]]
        The training loss function. Takes the target tensor and returns a
        dictionary with the keys "loss", "nll" and "kl".

    val_loss_fn: Callable[[torch.Tensor], torch.Tensor]
        The validation loss function. Takes the target tensor and returns the
        mean squared error between ``mean`` and the target.
    '''
    # Mean prediction of the model.
    mean: torch.Tensor
    # Predictive variance (same shape as ``mean``).
    variance: torch.Tensor
    # Target -> {"loss", "nll", "kl"} used for optimisation.
    train_loss_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]]
    # Target -> scalar MSE used for validation.
    val_loss_fn: Callable[[torch.Tensor], torch.Tensor]


class DVILayer(nn.Module):
    '''
    Bayesian linear layer with a mixture-of-Gaussians posterior over weights.

    The weight matrix has ``n_components`` factorised Gaussian components
    (mean ``weight_mu[k]``, log-variance ``weight_log_var[k]``) combined with
    mixture weights ``softmax(log_alpha)``. The bias has a single factorised
    Gaussian posterior (``bias_mu``, ``bias_log_var``).

    After every call to ``forward`` the attribute ``kl_divergence`` holds the
    KL term of this layer for the current parameters.

    Parameters
    ----------
    in_features : int
        Number of input features

    out_features : int
        Number of output features

    prior_mean : float
        Mean of the prior distribution

    prior_var : float
        Variance of the prior distribution

    n_components : int
        Number of mixture components in the variational posterior
    '''
    def __init__(self,
                 in_features: int,
                 out_features: int,
                 prior_mean: float = 0.0,
                 prior_var: float = 1.0,
                 n_components: int = 2):
        super(DVILayer, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.prior_mean = prior_mean
        self.prior_var = prior_var
        self.n_components = n_components

        # Unnormalised mixture weights; softmax of zeros gives a uniform mixture.
        self.log_alpha = nn.Parameter(torch.zeros(n_components))

        # Per-component weight means, drawn from N(0, 0.1^2).
        self.weight_mu = nn.ParameterList([
            nn.Parameter(torch.empty(out_features, in_features).normal_(0, 0.1))
            for _ in range(n_components)
        ])

        # Per-component weight log-variances, initialised to -3.
        self.weight_log_var = nn.ParameterList([
            nn.Parameter(torch.empty(out_features, in_features).fill_(-3.0))
            for _ in range(n_components)
        ])

        # Bias posterior (single Gaussian, shared by all components).
        self.bias_mu = nn.Parameter(torch.empty(out_features).normal_(0, 0.1))
        self.bias_log_var = nn.Parameter(torch.empty(out_features).fill_(-3.0))

        # Refreshed by every forward() call.
        self.kl_divergence = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Forward pass through the DVI layer using the posterior means only.

        As a side effect, ``self.kl_divergence`` is overwritten with the sum of
        the alpha-weighted per-component weight KL terms, the bias KL term and
        the KL between the mixture weights and a uniform prior.

        Parameters
        ----------
        x : torch.Tensor [batch_size, in_features]
            Input tensor

        Returns
        -------
        out : torch.Tensor [batch_size, out_features]
            Output tensor
        '''
        # Reset KL divergence
        self.kl_divergence = 0

        # Normalize mixture weights (alpha) using softmax
        alpha = F.softmax(self.log_alpha, dim=0)

        output = 0
        for k in range(self.n_components):
            weight_mu_k = self.weight_mu[k]
            weight_var_k = torch.exp(self.weight_log_var[k])

            # Mean output of this component, weighted by its mixture weight
            output = output + alpha[k] * F.linear(x, weight_mu_k)

            # KL contribution of this component's weights
            kl_weights = self._kl_gaussian_mixture_component(
                weight_mu_k, weight_var_k,
                self.prior_mean, self.prior_var
            )
            self.kl_divergence = self.kl_divergence + alpha[k] * kl_weights

        # Add bias mean
        output = output + self.bias_mu

        # KL contribution of the bias
        bias_var = torch.exp(self.bias_log_var)
        kl_bias = self._kl_gaussian(
            self.bias_mu, bias_var,
            self.prior_mean, self.prior_var
        )
        self.kl_divergence = self.kl_divergence + kl_bias

        # KL contribution of the mixture weights against a uniform prior
        uniform_prior = torch.ones_like(alpha) / self.n_components
        kl_alpha = torch.sum(alpha * torch.log(alpha / uniform_prior + 1e-10))
        self.kl_divergence = self.kl_divergence + kl_alpha

        return output

    def _kl_gaussian(self, mu_q, var_q, mu_p, var_p):
        '''
        Calculate KL divergence between two Gaussian distributions, summed
        over all elements.

        KL(q||p) = 0.5 * (log(var_p/var_q) + (var_q + (mu_q - mu_p)^2)/var_p - 1)

        Small constants (1e-10) guard the division and the logarithm.
        '''
        kl = 0.5 * torch.sum(
            torch.log(torch.clamp(var_p / (var_q + 1e-10), min=1e-10)) +
            (var_q + (mu_q - mu_p)**2) / (var_p + 1e-10) - 1
        )
        return kl

    def _kl_gaussian_mixture_component(self, mu_q, var_q, mu_p, var_p):
        '''
        Calculate KL divergence between a Gaussian component and the prior.

        Currently identical to ``_kl_gaussian``.
        '''
        return self._kl_gaussian(mu_q, var_q, mu_p, var_p)

    def sample_weights(self):
        '''
        Sample weights from the variational posterior.

        One mixture component is drawn according to ``softmax(log_alpha)`` and
        the weight matrix is sampled from that component's Gaussian. The bias
        is sampled from its own Gaussian. Log-variances are clamped to at most
        10 before exponentiation.

        Returns
        -------
        weight : torch.Tensor
            Sampled weight matrix
        bias : torch.Tensor
            Sampled bias vector
        '''
        # Sample component index based on mixture weights
        alpha = F.softmax(self.log_alpha, dim=0)
        component_idx = torch.multinomial(alpha, 1).item()

        # Get selected component parameters
        weight_mu = self.weight_mu[component_idx]
        weight_var = torch.exp(torch.clamp(self.weight_log_var[component_idx], max=10.0))

        # randn_like already allocates on the same device/dtype as its argument
        weight_epsilon = torch.randn_like(weight_mu)
        weight = weight_mu + weight_epsilon * torch.sqrt(weight_var)

        # Sample bias
        bias_var = torch.exp(torch.clamp(self.bias_log_var, max=10.0))
        bias_epsilon = torch.randn_like(self.bias_mu)
        bias = self.bias_mu + bias_epsilon * torch.sqrt(bias_var)

        return weight, bias

    def get_predictive_mean_var(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        '''
        Get predictive mean and variance for the layer output.

        The input ``x`` is treated as deterministic. The linear part ``W x`` is
        a mixture over components, so its variance is
        ``E[(W x)^2] - E[W x]^2``; the bias is an independent Gaussian whose
        variance is added on top.

        Parameters
        ----------
        x : torch.Tensor [batch_size, in_features]
            Input tensor

        Returns
        -------
        mean : torch.Tensor [batch_size, out_features]
            Predictive mean
        var : torch.Tensor [batch_size, out_features]
            Predictive variance, clamped to be at least 1e-8
        '''
        # Normalize mixture weights (alpha)
        alpha = F.softmax(self.log_alpha, dim=0)

        # Loop-invariant: squared inputs used for the per-component variance
        x_squared = x.pow(2)

        # First and second moments of the bias-free linear part W x
        linear_mean = 0
        linear_second_moment = 0

        for k in range(self.n_components):
            weight_var_k = torch.exp(torch.clamp(self.weight_log_var[k], max=10.0))

            # Mean of W_k x
            component_mean = F.linear(x, self.weight_mu[k])

            # For each output unit j: var_j = sum_i x_i^2 * var_w_ji
            component_var = torch.matmul(x_squared, weight_var_k.t())

            linear_mean = linear_mean + alpha[k] * component_mean
            linear_second_moment = linear_second_moment + alpha[k] * (
                component_mean.pow(2) + component_var
            )

        bias_var = torch.exp(torch.clamp(self.bias_log_var, max=10.0))

        # The variance must be taken around the mean of the SAME quantity whose
        # second moment was accumulated (the bias-free part). Subtracting the
        # squared bias-shifted mean instead would bias the result by
        # -(2 * b * E[W x] + b^2).
        var = linear_second_moment - linear_mean ** 2 + bias_var
        # Ensure variance is always positive to prevent NaN in sqrt
        var = torch.clamp(var, min=1e-8)

        # The bias only shifts the mean
        mean = linear_mean + self.bias_mu

        return mean, var


class DVI(nn.Module):
    '''
    Deterministic Variational Inference model for regression.

    This model implements Deterministic Variational Inference (Wu et al., 2019),
    which uses a deterministic approximation to the posterior distribution
    over the weights of a neural network.

    The network is a stack of ``n_hidden_layers + 1`` ``DVILayer`` modules
    (for ``n_hidden_layers >= 1``): one input layer, ``n_hidden_layers - 1``
    hidden-to-hidden layers and one output layer. The activation is applied
    after every layer except the last.

    Parameters
    ----------
    dim_input : int
        Number of input features

    dim_output : int
        Number of output features (typically 1 for regression)

    dim_hidden : int
        Number of neurons in each hidden layer

    n_hidden_layers : int
        Number of hidden layers

    prior_mean : float
        Mean of the prior distribution

    prior_var : float
        Variance of the prior distribution

    n_components : int
        Number of mixture components in the variational posterior

    activation : nn.Module
        Activation function to use (default: nn.ReLU)

    prior_log_obs_var : float
        Initial value of the learnable log observation-noise variance

    ratio_kl : float
        Multiplier applied to the KL term in the training loss
    '''
    def __init__(self,
                 dim_input: int,
                 dim_output: int,
                 dim_hidden: int,
                 n_hidden_layers: int,
                 prior_mean: float = 0.0,
                 prior_var: float = 1.0,
                 n_components: int = 2,
                 activation: nn.Module = nn.ReLU(),
                 prior_log_obs_var: float = -2.0,
                 ratio_kl: float = 0.01):

        super(DVI, self).__init__()

        self.dim_input = dim_input
        self.dim_output = dim_output
        self.dim_hidden = dim_hidden
        self.n_hidden_layers = n_hidden_layers
        self.prior_mean = prior_mean
        self.prior_var = prior_var
        self.n_components = n_components
        self.activation = activation
        self.ratio_kl = ratio_kl

        # Input layer
        self.layers = nn.ModuleList([
            DVILayer(dim_input, dim_hidden, prior_mean, prior_var, n_components)
        ])

        # Hidden layers
        for _ in range(n_hidden_layers - 1):
            self.layers.append(
                DVILayer(dim_hidden, dim_hidden, prior_mean, prior_var, n_components)
            )

        # Output layer
        self.layers.append(
            DVILayer(dim_hidden, dim_output, prior_mean, prior_var, n_components)
        )

        # Observation noise variance (learnable), stored as a log-variance
        self.log_obs_var = nn.Parameter(torch.ones(1) * prior_log_obs_var)

    @property
    def device(self) -> torch.device:
        '''
        Get the device of the model (the device of its first parameter).
        '''
        return next(self.parameters()).device

    @property
    def name(self) -> str:
        '''
        Model name.
        '''
        return "DVI"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Forward pass through the network.

        As a side effect, ``self.kl_divergence`` is overwritten with the sum of
        the KL terms of all layers.

        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor

        Returns
        -------
        out : torch.Tensor [batch_size, dim_output]
            Output tensor
        '''
        # Reset KL divergence
        self.kl_divergence = 0

        # All layers but the last are followed by the activation
        for layer in self.layers[:-1]:
            x = layer(x)
            x = self.activation(x)

            # Accumulate KL divergence
            self.kl_divergence += layer.kl_divergence

        # Final layer (no activation)
        x = self.layers[-1](x)

        # Add KL divergence from the final layer
        self.kl_divergence += self.layers[-1].kl_divergence

        return x

    def sample_predictive(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Sample from the predictive distribution.

        One set of weights and biases is sampled per layer, the input is
        propagated through the resulting network, and Gaussian observation
        noise with variance ``exp(log_obs_var)`` is added to the output.

        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor

        Returns
        -------
        out : torch.Tensor [batch_size, dim_output]
            Output tensor
        '''
        current_input = x

        for layer in self.layers[:-1]:
            # Sample weights and biases
            weight, bias = layer.sample_weights()

            # Linear transformation with sampled weights, then activation
            current_input = F.linear(current_input, weight, bias)
            current_input = self.activation(current_input)

        # Final layer (no activation)
        weight, bias = self.layers[-1].sample_weights()
        output = F.linear(current_input, weight, bias)

        # Add observation noise
        obs_var = torch.exp(torch.clamp(self.log_obs_var, max=10.0))
        noise = torch.randn_like(output) * torch.sqrt(obs_var)
        output = output + noise

        return output

    def get_prediction(self, x: torch.Tensor) -> DVIReturn:
        '''
        Get prediction with uncertainty estimation.

        Note that this method switches the model to evaluation mode.

        The predictive variance is the variance reported by the output layer
        for the mean activations of the last hidden layer, plus the
        observation-noise variance. Variances of the hidden layers are not
        propagated forward.

        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor

        Returns
        -------
        DVIReturn
            Object containing mean, variance and the training / validation
            loss closures for this prediction
        '''
        # Ensure model is in evaluation mode
        self.eval()

        # Forward pass to get mean prediction (also refreshes the KL term)
        mean = self.forward(x)

        # Bind the KL term of THIS forward pass so that a later forward call
        # cannot change what train_loss_fn uses.
        kl_divergence = self.kl_divergence

        # Propagate the mean activations up to the last hidden layer
        current_input = x
        for layer in self.layers[:-1]:
            layer_mean, _ = layer.get_predictive_mean_var(current_input)
            current_input = self.activation(layer_mean)

        # Variance contributed by the final layer
        _, output_var = self.layers[-1].get_predictive_mean_var(current_input)

        # Add observation noise variance
        obs_var = torch.exp(torch.clamp(self.log_obs_var, max=10.0))
        total_var = torch.clamp(output_var + obs_var, min=1e-8)

        # Define loss functions
        def train_loss_fn(y_true: torch.Tensor) -> Dict[str, torch.Tensor]:
            # Negative ELBO loss (data likelihood term + KL divergence term)
            # Data likelihood term (assuming Gaussian likelihood)
            mse_loss = F.mse_loss(mean, y_true, reduction='none')
            log_likelihood = -0.5 * (torch.log(2 * math.pi * total_var) + mse_loss / total_var)
            log_likelihood = log_likelihood.mean()  # Reduce to scalar

            # KL divergence term, divided by the batch size
            kl_weight = 1.0 / x.size(0)

            # Total loss is negative ELBO with the KL term scaled by ratio_kl
            total_loss = -log_likelihood + self.ratio_kl * kl_weight * kl_divergence

            return {
                "loss": total_loss,
                "nll": -log_likelihood,
                "kl": kl_weight * kl_divergence
            }

        def val_loss_fn(y_true: torch.Tensor) -> torch.Tensor:
            # Use MSE loss for validation
            return F.mse_loss(mean, y_true)

        return DVIReturn(
            mean=mean,
            variance=total_var,
            train_loss_fn=train_loss_fn,
            val_loss_fn=val_loss_fn
        )


def train_dvi_model(model: DVI, train_loader: torch.utils.data.DataLoader, 
        X_train_tensor: torch.Tensor, y_train_tensor: torch.Tensor, 
        X_test_tensor: torch.Tensor, y_test_tensor: torch.Tensor, 
        num_epochs: int = 5000, learning_rate: float = 0.01, lr_step_size: int = 1000) -> Dict[str, Any]:
    '''
    Train a DVI model with Adam and evaluate it on the train and test sets.

    Parameters
    ----------
    model : DVI
        The model to train (modified in place).

    train_loader : torch.utils.data.DataLoader
        Loader yielding ``(batch_X, batch_y)`` pairs used for optimisation.

    X_train_tensor, y_train_tensor : torch.Tensor
        Full training inputs / targets, used only for the final metrics.

    X_test_tensor, y_test_tensor : torch.Tensor
        Test inputs / targets, used for the per-epoch validation loss and
        for the final metrics.

    num_epochs : int
        Number of passes over ``train_loader``.

    learning_rate : float
        Initial Adam learning rate.

    lr_step_size : int
        ``step_size`` of the ``StepLR`` scheduler (decay factor 0.8).

    Returns
    -------
    Dict[str, Any]
        ``train_losses`` and ``val_losses`` (one float per epoch),
        ``training_time`` in seconds, ``epoch`` (= ``num_epochs``), and every
        entry returned by ``compute_metrics`` under a ``train_`` / ``test_``
        prefix (``mse`` and ``nll`` are required to be present).
    '''
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.8, step_size=lr_step_size)

    train_losses = []
    val_losses = []

    start_time = time.perf_counter()

    for epoch in range(num_epochs):

        model.train()
        epoch_nll = 0.0
        epoch_kl = 0.0

        for batch_X, batch_y in train_loader:

            # Forward pass (get_prediction puts the model in eval mode)
            prediction = model.get_prediction(batch_X)

            results = prediction.train_loss_fn(batch_y)
            loss = results['loss']

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Additional, tighter per-parameter clipping for the log-variance
            # parameters to prevent NaN
            for name, param in model.named_parameters():
                if 'log_var' in name and param.grad is not None:
                    torch.nn.utils.clip_grad_norm_(param, max_norm=0.5)

            optimizer.step()
            # NOTE(review): the scheduler is stepped once per batch, so
            # lr_step_size counts optimizer steps, not epochs - confirm intent.
            lr_scheduler.step()

            epoch_nll += results['nll'].item()
            epoch_kl += results['kl'].item()

        # Validation
        model.eval()
        with torch.no_grad():
            prediction = model.get_prediction(X_test_tensor)
            val_loss = prediction.val_loss_fn(y_test_tensor)

        # Record losses. The recorded training loss is the mean NLL plus the
        # mean "kl" entry of the loss dictionary (ratio_kl is not applied).
        avg_epoch_nll = epoch_nll / len(train_loader)
        avg_epoch_kl = epoch_kl / len(train_loader)
        avg_train_loss = avg_epoch_nll + avg_epoch_kl
        train_losses.append(avg_train_loss)
        val_losses.append(val_loss.item())

    training_time = time.perf_counter() - start_time

    # Compute final metrics for train and test sets
    model.eval()
    with torch.no_grad():
        # Training data evaluation
        train_prediction = model.get_prediction(X_train_tensor)
        train_metrics = compute_metrics(
            train_prediction.mean, train_prediction.variance, y_train_tensor
        )

        # Testing data evaluation
        test_prediction = model.get_prediction(X_test_tensor)
        test_metrics = compute_metrics(
            test_prediction.mean, test_prediction.variance, y_test_tensor
        )

    # Prepare results dictionary with all metrics
    results = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'training_time': training_time,
        'train_mse': train_metrics['mse'],
        'train_nll': train_metrics['nll'],
        'test_mse': test_metrics['mse'],
        'test_nll': test_metrics['nll'],
        'epoch': num_epochs,
    }

    # Add training metrics with 'train_' prefix
    for key, value in train_metrics.items():
        results[f'train_{key}'] = value

    # Add testing metrics with 'test_' prefix
    for key, value in test_metrics.items():
        results[f'test_{key}'] = value

    return results


