'''
Regression model for Heteroscedastic Variational Bayesian Last Layers (H-VBLL)
'''

import time
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Dict, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from baseline.common import compute_metrics
from hvbll.distributions import DiagonalNormal, DenseNormal, gaussian_kl


@dataclass
class HVBLLReturn:
    '''
    Container bundling everything produced by one forward pass of the H-VBLL model.

    Attributes
    ----------------
    predictive : DiagonalNormal | DenseNormal
        Predictive distribution over the outputs.

    latent : torch.Tensor | None
        Latent features that were handed to the Bayesian last layer.

    x : torch.Tensor | None
        Raw inputs the output was computed from.

    train_loss_fn : Callable[[torch.Tensor], Dict[str, torch.Tensor]]
        Takes the targets and returns a dictionary of training loss terms.

    val_loss_fn : Callable[[torch.Tensor], torch.Tensor]
        Takes the targets and returns the validation loss.

    ood_scores : None | Callable[[torch.Tensor], torch.Tensor]
        Optional out-of-distribution scoring function. Defaults to None.
    '''
    predictive: DiagonalNormal | DenseNormal
    latent: torch.Tensor | None
    x: torch.Tensor | None
    train_loss_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]]
    val_loss_fn: Callable[[torch.Tensor], torch.Tensor]
    ood_scores: None | Callable[[torch.Tensor], torch.Tensor] = None


class HeteroBayesianLinReg(nn.Module):
    '''
    Heteroscedastic Variational Bayesian Linear Regression.
    
    A linear model: y = W @ phi(x) + noise(x).
    
    - The weight matrix W and noise covariance matrix are treated as random variables.
    - The weight matrix W is assumed to have a Gaussian prior.
    - The noise covariance matrix is assumed to have a Wishart prior.
    - The weight matrix W and noise covariance matrix are assumed to be independent.
    - The model is trained using the evidence lower bound (ELBO) objective.

    Parameters
    ----------
    dim_input : int
        Dimension of input, x

    dim_feature : int
        Dimension of feature, phi(x)
        
    dim_output : int
        Dimension of output, y
        
    dim_hidden : int
        Number of hidden neurons in the hidden layer
        
    n_hidden_layers : int
        Number of hidden layers for the noise variance
        
    reg_weight_latent : float
        Weight on regularization term of weights in the evidence lower bound (ELBO)
        
    reg_weight_noise : float
        Weight on regularization term of noise in the evidence lower bound (ELBO)
        
    covariance_type : str
        Parameterization of covariance matrix. Currently supports {'dense', 'diagonal'}
        
    prior_scale : float
        Scale of the square root of the prior covariance matrix,
        the same order of magnitude as the standard deviation of the data being modeled.
        
    wishart_scale : float
        Scale of the square root of the Wishart prior on noise covariance, 
        the same order of magnitude as the standard deviation of the data being modeled.
        
        The Wishart distribution is a generalization of the gamma distribution to multiple dimensions.
        The Wishart distribution is a probability distribution over positive definite matrices. 
        It is often used in statistics and machine learning as the distribution of the precision matrix 
        (inverse covariance matrix) in a multivariate normal model.
        
    dof : float
        Degrees of freedom of Wishart prior on noise covariance
    '''
    def __init__(self,
                 dim_input: int,
                 dim_feature: int,
                 dim_output: int,
                 dim_hidden: int,
                 n_hidden_layers: int,
                 reg_weight_latent: float,
                 reg_weight_noise: float,
                 covariance_type='dense',
                 prior_scale=1.0,
                 wishart_scale=1e-2,
                 dof=1.):
        
        super(HeteroBayesianLinReg, self).__init__()
        
        self.dim_input = dim_input
        self.dim_feature = dim_feature
        self.dim_output = dim_output
        self.dim_hidden = dim_hidden
        self.n_hidden_layers = n_hidden_layers
        self.reg_weight_latent = reg_weight_latent
        self.reg_weight_noise = reg_weight_noise
        self.covariance_type = covariance_type
        self.wishart_scale = wishart_scale
        self.dof = (dof + dim_output + 1.)/2.

        # Define weight's Gaussian prior, currently fixing zero mean and arbitrarily scaled cov
        self.prior_scale = prior_scale * (1. / dim_feature) 

        # Noise distribution
        self.noise_mean = nn.Parameter(torch.zeros(dim_output), requires_grad = False)
        self.init_noise_network()

        # Stochastic weight distribution
        self.weight_mean = nn.Parameter(torch.randn(dim_output, dim_feature))
        self.weight_logCov_diag = nn.Parameter(
            torch.randn(dim_output, dim_feature) - 0.5 * np.log(dim_feature))
        
        if covariance_type == 'diagonal':
            pass
            
        elif covariance_type == 'dense':
            
            self.weight_logCov_offdiag = nn.Parameter(
                torch.randn(dim_output, dim_feature, dim_feature)/dim_feature)
            
        else:
            raise Exception(f'Covariance parameterization method {covariance_type} is not supported.')
        
    @property
    def device(self) -> torch.device:
        '''
        Get the device of the model.
        '''
        return next(self.parameters()).device
    
    def init_noise_network(self):
        '''
        Initialize the noise network.
        '''
        self._noise_logCov_diag = nn.ModuleList([nn.Linear(self.dim_input, self.dim_hidden), nn.ELU()])
        
        for _ in range(self.n_hidden_layers):
            self._noise_logCov_diag.append(nn.Linear(self.dim_hidden, self.dim_hidden))
            self._noise_logCov_diag.append(nn.ELU())
        
        self._noise_logCov_diag.append(nn.Linear(self.dim_hidden, self.dim_output))
        
    def noise_logCov_diag(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Logarithm of the diagonal of the noise covariance matrix.
        
        Parameters
        ----------
        x : torch.Tensor [n, dim_input]
            Input tensor
            
        Returns
        -------
        out : torch.Tensor [n, dim_output]
            Logarithm of the diagonal of the noise covariance matrix
        '''
        for layer in self._noise_logCov_diag:
            x = layer(x)
            
        return x
        
    def get_weight_distribution(self) -> DiagonalNormal | DenseNormal:
        '''
        Assemble the variational distribution over the last-layer weights.

        For 'diagonal' the exponentiated `weight_logCov_diag` is handed to
        `DiagonalNormal` as its second argument. For 'dense' it forms the
        diagonal of a lower-triangular factor whose strictly-lower part is
        taken from `weight_logCov_offdiag`; that factor is handed to
        `DenseNormal`.

        Returns
        -------
        dist = DiagonalNormal | DenseNormal
            Weight distribution with mean of shape [dim_output, dim_feature]

        Raises
        ------
        Exception
            If `self.covariance_type` is not one of {'diagonal', 'dense'}.
        '''
        diag = torch.exp(self.weight_logCov_diag)
        kind = self.covariance_type

        if kind == 'diagonal':
            return DiagonalNormal(self.weight_mean, diag)

        if kind == 'dense':
            # Strictly-lower triangle from the free parameters, positive diagonal from exp(.)
            strict_lower = torch.tril(self.weight_logCov_offdiag, diagonal=-1)
            factor = strict_lower + torch.diag_embed(diag)
            return DenseNormal(self.weight_mean, factor)

        raise Exception(f'Covariance parameterization method {self.covariance_type} is not supported.')

    def get_noise_distribution(self, x: torch.Tensor) -> DiagonalNormal:
        '''
        Build the input-dependent noise distribution.

        The location is the fixed `noise_mean` parameter; the scale is the
        exponential of the noise network output.

        Parameters
        ----------
        x : torch.Tensor [n, dim_input]
            Input tensor

        Returns
        -------
        dist = DiagonalNormal
            Noise distribution, [n, dim_output]
        '''
        log_scale = self.noise_logCov_diag(x)
        return DiagonalNormal(self.noise_mean, torch.exp(log_scale))

    def predictive_distribution(self, phi: torch.Tensor, x: torch.Tensor) -> DiagonalNormal | DenseNormal:
        '''
        Combine weight and noise distributions into the predictive distribution.

        The weight distribution is applied to the features (via the `@`
        operator of the distribution class) and the noise distribution is added.

        Parameters
        ----------
        phi : torch.Tensor [n, dim_feature]
            Feature tensor

        x : torch.Tensor [n, dim_input]
            Input tensor

        Returns
        -------
        dist = DiagonalNormal | DenseNormal
            Predictive distribution
        '''
        # Weight distribution:
        #   loc        [dim_output, dim_feature]
        #   scale_tril [dim_output, dim_feature, dim_feature]   (dense case)
        w_dist = self.get_weight_distribution()

        # Noise distribution:
        #   loc   [n, dim_output]
        #   scale [n, dim_output]
        eps_dist = self.get_noise_distribution(x)

        # phi gets a trailing singleton axis so it acts as a batch of column vectors.
        projected = w_dist @ phi[..., None]

        return projected.squeeze(-1) + eps_dist

    def _get_train_loss_fn(self, phi: torch.Tensor, x: torch.Tensor) -> Callable[[torch.Tensor], Dict[str, torch.Tensor]]:
        '''
        Create the closure that evaluates the training objective for given targets.

        The closure captures `phi` and `x`; the weight and noise distributions
        are rebuilt from the current parameters each time it is called.

        Parameters
        ----------
        phi : torch.Tensor [n, dim_feature]
            Feature tensor

        x : torch.Tensor [n, dim_input]
            Input tensor

        Returns
        -------
        loss_fn : Callable[[torch.Tensor], Dict[str, torch.Tensor]]
            Function of the targets `y` returning a dictionary with keys
            'neg_total_elbo', 'nll', 'mse', 'kl_term' and 'wishart_term'.
        '''
        def loss_fn(y: torch.Tensor) -> Dict[str, torch.Tensor]:
            # Shapes (dense case), for orientation:
            #   w_dist.loc                    [dim_output, dim_feature]
            #   w_dist.scale_tril             [dim_output, dim_feature, dim_feature]
            #   eps_dist.loc / .scale         [n, dim_output]
            #   eps_dist.precision_diagonal   [n, dim_output]
            #   phi                           [n, dim_feature]
            #   phi[..., None]                [n, dim_feature, 1]
            #   phi.unsqueeze(-2)[..., None]  [n, 1, dim_feature, 1]
            #   mean_pred, log_lik            [n, dim_output]
            w_dist = self.get_weight_distribution()
            eps_dist = self.get_noise_distribution(x)

            # Log-likelihood of y under N(E[W] @ phi, noise scale)
            mean_pred = (w_dist.mean @ phi[..., None]).squeeze(-1)
            log_lik = DiagonalNormal(mean_pred, eps_dist.scale).log_prob(y)

            # Penalty from the weight covariance, weighted by the noise precision
            quad_form = w_dist.covariance_weighted_inner_prod(phi.unsqueeze(-2)[..., None])
            trace_penalty = 0.5*(quad_form * eps_dist.precision_diagonal)

            data_fit = torch.mean(log_lik - trace_penalty)

            # Regularizer on the weights: KL against the Gaussian prior
            kl_weights = gaussian_kl(w_dist, q_variance=self.prior_scale)

            # Regularizer on the noise, built from its log-det and trace of precision
            noise_reg = torch.mean(self.dof * eps_dist.logdet_precision
                                   - 0.5 * self.wishart_scale * eps_dist.trace_precision)

            elbo = data_fit - self.reg_weight_latent * kl_weights + self.reg_weight_noise * noise_reg

            return {
                'neg_total_elbo': -elbo,
                'nll': - torch.mean(log_lik),
                'mse': F.mse_loss(mean_pred, y),
                'kl_term': kl_weights,
                'wishart_term': noise_reg,
            }

        return loss_fn

    def _get_val_loss_fn(self, phi: torch.Tensor, x: torch.Tensor) -> Callable[[torch.Tensor], torch.Tensor]:
        '''
        Internal method to get the validation loss function.

        The returned closure scores the targets under a `DiagonalNormal`
        centred at the mean-weight prediction with the input-dependent noise
        scale, i.e. the weight covariance does not enter this loss. The
        likelihood under the full predictive distribution is available from
        `self.predictive_distribution(phi, x).log_prob(y)` if needed.

        Parameters
        ----------
        phi : torch.Tensor [n, dim_feature]
            Feature tensor

        x : torch.Tensor [n, dim_input]
            Input tensor

        Returns
        -------
        loss_fn : Callable[[torch.Tensor], torch.Tensor]
            Validation loss function: mean negative log-likelihood over all
            elements of the batch.
        '''
        def loss_fn(y):
            '''
            Negative log likelihood loss function.
            '''
            weight_dist = self.get_weight_distribution()
            noise_dist = self.get_noise_distribution(x)

            pred_mean = (weight_dist.mean @ phi[..., None]).squeeze(-1)
            pred_density = DiagonalNormal(pred_mean, noise_dist.scale)

            return - torch.mean(pred_density.log_prob(y))

        return loss_fn

    def forward(self, phi: torch.Tensor, x: torch.Tensor) -> HVBLLReturn:
        '''
        Evaluate the layer and package the results.

        Parameters
        ----------
        phi : torch.Tensor [batch_size, dim_feature]
            Feature tensor

        x : torch.Tensor [batch_size, dim_input]
            Input tensor

        Returns
        -------
        out: HVBLLReturn
            Predictive distribution, the features (as `latent`), the inputs,
            and the training / validation loss closures bound to this batch.
        '''
        return HVBLLReturn(
            predictive=self.predictive_distribution(phi, x),
            latent=phi,
            x=x,
            train_loss_fn=self._get_train_loss_fn(phi, x),
            val_loss_fn=self._get_val_loss_fn(phi, x),
        )

    @property
    def weight_mean_numpy(self) -> np.ndarray:
        '''
        Get the weight mean as numpy array.
        '''
        return self.weight_mean.detach().cpu().numpy()

    @property
    def weight_std_numpy(self) -> np.ndarray:
        '''
        Get the weight standard deviation as numpy array.
        (This is the square root of the diagonal of the covariance matrix.)
        '''
        return torch.exp(self.weight_logCov_diag).detach().cpu().numpy()

    def get_aleatoric_uncertainty(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Get the aleatoric uncertainty, i.e., the noise variance.
        
        Parameters
        ----------
        x : torch.Tensor [n, dim_input]
            Input data
            
        Returns
        -------
        noise_var : torch.Tensor [n, dim_output]
            Noise variance for each output dimension.
        '''
        return torch.exp(self.noise_logCov_diag(x)) ** 2

    def get_epistemic_uncertainty(self, phi: torch.Tensor) -> torch.Tensor:
        '''
        Epistemic uncertainty, i.e. phi^T @ Cov(W) @ phi, floored at 1e-12.

        Parameters
        ----------
        phi : torch.Tensor [n, dim_feature]
            Feature tensor

        Returns
        -------
        epistemic_var : torch.Tensor [n, dim_output]
            Epistemic uncertainty for each input data point.
        '''
        w_dist = self.get_weight_distribution()

        # Treat every feature vector as a column: [n, dim_feature, 1]
        phi_col = phi.unsqueeze(-1)

        assert phi_col.shape[-2] == w_dist.loc.shape[-1]
        assert phi_col.shape[-1] == 1

        # The extra axis at -3 lets the columns broadcast against the per-output weight covariances.
        quad_form = w_dist.covariance_weighted_inner_prod(phi_col.unsqueeze(-3))

        # Floor the result so downstream square roots / logs stay finite.
        return torch.clip(quad_form, min = 1e-12)

    def plot_prediction_1d(self, x_min=0, x_max=1, num_points=1001):
        '''
        Plot the mean function and noise level.
        '''
        if self.dim_input != 1 or self.dim_output != 1:
            raise Exception('The function is only supported for 1-d input and output.')
        
        xx = np.linspace(x_min, x_max, num_points, endpoint=True)
        yy = self.weight_mean_numpy[0] * xx

        tX = torch.from_numpy(xx).float().unsqueeze(-1).to(self.device)
        
        a_uncertainty = self.get_aleatoric_uncertainty(tX).detach().cpu().numpy()
        e_uncertainty = self.get_epistemic_uncertainty(tX).detach().cpu().numpy()
        t_uncertainty = a_uncertainty + e_uncertainty

        a_uncertainty = np.sqrt(a_uncertainty)[:,0]
        t_uncertainty = np.sqrt(t_uncertainty)[:,0]
        
        plt.plot(xx, yy, 'r--', label='Mean prediction')

        plt.fill_between(xx, yy - t_uncertainty, yy + t_uncertainty, 
                            alpha=0.2, color='g', label='Total uncertainty (1 std)')
        
        plt.fill_between(xx, yy - a_uncertainty, yy + a_uncertainty, 
                            alpha=0.2, color='r', label='Aleatoric uncertainty (1 std)')


class HVBLL(nn.Module):
    '''
    Heteroscedastic Variational Bayesian Last Layer.
    
    A neural network model with a heteroscedastic variational Bayesian last layer.
    
    Parameters
    ----------
    dim_input : int
        Number of input features
        
    dim_output : int
        Number of output features
        
    dim_latent : int
        Number of latent features
        
    dim_hidden : int
        Number of neurons in each hidden layer for the latent features
        
    dim_hidden_noise : int
        Number of neurons in each hidden layer for the noise variance
        
    n_hidden_layers : int
        Number of hidden layers for the latent features
        
    n_noise_layers : int
        Number of hidden layers for the noise variance
        
    reg_weight_latent : float
        Weight on regularization term of latent variables in the evidence lower bound (ELBO)
        
    reg_weight_noise : float
        Weight on regularization term of noise in the evidence lower bound (ELBO)
        
    covariance_type : str
        Parameterization of covariance matrix. Currently supports {'dense', 'diagonal'}
        
    prior_scale : float
        Scale of the square root of the prior covariance matrix,
        the same order of magnitude as the standard deviation of the data being modeled.
        
    wishart_scale : float
        Scale of the square root of the Wishart prior on noise covariance, 
        the same order of magnitude as the standard deviation of the data being modeled.
        
        The Wishart distribution is a generalization of the gamma distribution to multiple dimensions.
        
    dof : float
        Degrees of freedom of Wishart prior on noise covariance
    '''
 
    def __init__(self,
                 dim_input: int,
                 dim_output: int,
                 dim_latent: int,
                 dim_hidden: int,
                 dim_hidden_noise: int,
                 n_hidden_layers: int,
                 n_noise_layers: int,
                 reg_weight_latent: float,
                 reg_weight_noise: float,
                 covariance_type='dense',
                 prior_scale=1.0,
                 wishart_scale=1e-2,
                 dof=1.):
        '''
        Build the feature extractor and attach the Bayesian last layer.

        The extractor consists of an input Linear layer, `n_hidden_layers - 1`
        hidden Linear layers and one Linear layer projecting to `dim_latent`;
        `n_hidden_layers` ELU activations are created alongside. The two
        regularization weights are forwarded to the last layer only.
        '''
        super().__init__()

        # Keep the configuration around for inspection.
        self.dim_input = dim_input
        self.dim_output = dim_output
        self.dim_latent = dim_latent
        self.dim_hidden = dim_hidden
        self.dim_hidden_noise = dim_hidden_noise
        self.n_hidden_layers = n_hidden_layers
        self.n_noise_layers = n_noise_layers
        self.covariance_type = covariance_type
        self.prior_scale = prior_scale
        self.wishart_scale = wishart_scale
        self.dof = dof

        # Sub-modules are created in the order: input layer, core layers, last layer.
        in_layer = nn.Linear(dim_input, dim_hidden)

        core_layers = [nn.Linear(dim_hidden, dim_hidden) for _ in range(n_hidden_layers-1)]
        core_layers.append(nn.Linear(dim_hidden, dim_latent))

        out_layer = HeteroBayesianLinReg(
            dim_input=dim_input,
            dim_feature=dim_latent,
            dim_output=dim_output,
            dim_hidden=dim_hidden_noise,
            n_hidden_layers=n_noise_layers,
            reg_weight_latent=reg_weight_latent,
            reg_weight_noise=reg_weight_noise,
            covariance_type=covariance_type,
            prior_scale=prior_scale,
            wishart_scale=wishart_scale,
            dof=dof,
        )

        self.layers = nn.ModuleDict({
            'in_layer': in_layer,
            'core': nn.ModuleList(core_layers),
            'out_layer': out_layer,
        })

        self.activations = nn.ModuleList([nn.ELU() for _ in range(n_hidden_layers)])

    @property
    def device(self) -> torch.device:
        '''
        Get the device of the model.
        '''
        return next(self.parameters()).device
    
    @property
    def name(self) -> str:
        '''
        Model name; always the constant string 'HVBLL'.
        '''
        return 'HVBLL'

    def forward(self, x: torch.Tensor) -> HVBLLReturn:
        '''
        Forward pass: extract latent features and evaluate the Bayesian last layer.

        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor

        Returns
        -------
        out : HVBLLReturn
            Output of the last layer, which receives both the latent features
            and the raw input `x` (the latter drives the noise network).
        '''
        v = self.layers['in_layer'](x)

        # Each core layer is followed by its activation; no activation is
        # applied directly after the input layer.
        for layer, ac in zip(self.layers['core'], self.activations):
            v = ac(layer(v))

        # Call the module itself rather than `.forward` so that any registered
        # forward hooks on the last layer are honoured.
        return self.layers['out_layer'](v, x)
    
    def get_prediction(self, xs: torch.Tensor) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        '''
        Evaluate the model on `xs` and return numpy summaries of the prediction.

        The model is switched to eval mode for the evaluation and its previous
        train/eval mode is restored afterwards, even if an error occurs.

        Parameters
        ----------
        xs : torch.Tensor [n, dim_input]
            Input tensor

        Returns
        -------
        y_mean : np.ndarray
            Mean of the predictive distribution, with singleton axes squeezed.

        y_var : np.ndarray
            `covariance` of the predictive distribution, with singleton axes squeezed.

        a_uncertainty : np.ndarray [n]
            Aleatoric (noise) variance of the first output dimension only.
        '''
        was_training = self.training
        self.eval()

        try:
            # Results are converted to numpy, so no autograd graph is needed.
            with torch.no_grad():
                out = self(xs)

                dist_y = out.predictive
                y_mean = dist_y.mean.cpu().detach().numpy().squeeze()
                y_var  = dist_y.covariance.squeeze().cpu().detach().numpy()

                hvblr = self.layers['out_layer']

                a_uncertainty = hvblr.get_aleatoric_uncertainty(xs).detach().cpu().numpy()
                a_uncertainty = a_uncertainty[:,0]
        finally:
            # Restore whichever mode the caller had, instead of forcing train mode.
            self.train(was_training)

        return y_mean, y_var, a_uncertainty
    
    def get_validation(self, xs: torch.Tensor, ys: torch.Tensor) -> Tuple[float, float, HVBLLReturn]:
        '''
        Evaluate the model on a labelled set.

        The model is switched to eval mode for the evaluation and its previous
        train/eval mode is restored afterwards, even if an error occurs.
        Gradients are not disabled here because the returned `out` carries
        loss closures that the caller may still want to differentiate.

        Parameters
        ----------
        xs : torch.Tensor [n, dim_input]
            Input tensor

        ys : torch.Tensor
            Targets, in a shape compatible with the predictive mean.

        Returns
        -------
        mse : float
            Mean squared error between the predictive mean and `ys`.

        nll : float
            Negative mean of the predictive log-probability of `ys`.

        out : HVBLLReturn
            Full model output for `xs`.
        '''
        was_training = self.training
        self.eval()

        try:
            out = self(xs)
            mean = out.predictive.mean

            mse = F.mse_loss(mean, ys).item()
            nll = -torch.mean(out.predictive.log_prob(ys)).item()
        finally:
            # Restore whichever mode the caller had, instead of forcing train mode.
            self.train(was_training)

        return mse, nll, out
    
    def plot_prediction_1d(self, x_min=0, x_max=1, num_points=1001):
        '''
        Plot the mean function and noise level.
        '''
        if self.dim_input != 1 or self.dim_output != 1:
            raise Exception('The function is only supported for 1-d input and output.')
        
        self.eval()
        
        tX = torch.linspace(x_min, x_max, num_points).float().unsqueeze(-1).to(self.device)
        xx = tX.cpu().detach().numpy().squeeze()
        
        out = self.forward(tX)
        
        dist_y = out.predictive
        y_mean = dist_y.mean.cpu().detach().numpy().squeeze()
        y_std  = torch.sqrt(dist_y.covariance.squeeze()).cpu().detach().numpy()
        
        hvblr = self.layers['out_layer']

        a_uncertainty = hvblr.get_aleatoric_uncertainty(tX).cpu().detach().cpu().numpy()
        a_uncertainty = np.sqrt(a_uncertainty)[:,0]

        plt.plot(xx, y_mean, 'r--', label='Mean prediction')

        plt.fill_between(xx, y_mean - y_std, y_mean + y_std, 
                            alpha=0.2, color='g', label='Total uncertainty (1 std)')
        
        plt.fill_between(xx, y_mean - a_uncertainty, y_mean + a_uncertainty, 
                            alpha=0.2, color='r', label='Aleatoric uncertainty (1 std)')
        
        self.train()
        
        

def train_hvbll_model(
        model: HVBLL, train_loader: torch.utils.data.DataLoader, 
        X_train_tensor: torch.Tensor, y_train_tensor: torch.Tensor, 
        X_test_tensor: torch.Tensor, y_test_tensor: torch.Tensor, 
        num_epochs: int = 5000, learning_rate: float = 0.01, lr_step_size: int = 1000) -> Dict[str, Any]:
    '''
    Train an H-VBLL model with Adam on the negative ELBO and track metrics.

    Metrics are computed on the full train and test tensors at the first epoch
    and then every 50 epochs. The "best" evaluation is the one with the lowest
    sum of train and test NLL (note that the test set therefore takes part in
    model selection). Training stops early once more than a quarter of
    `num_epochs` has elapsed since the best evaluation.

    Parameters
    ----------
    model : HVBLL
        Model to train (updated in place).

    train_loader : torch.utils.data.DataLoader
        Yields `(batch_X, batch_y)` pairs.

    X_train_tensor, y_train_tensor : torch.Tensor
        Full training set used for periodic evaluation.

    X_test_tensor, y_test_tensor : torch.Tensor
        Full test set used for periodic evaluation.

    num_epochs : int
        Maximum number of epochs; must be at least 1.

    learning_rate : float
        Initial Adam learning rate.

    lr_step_size : int
        `step_size` of the StepLR scheduler (decay factor 0.8).

    Returns
    -------
    results : Dict[str, Any]
        'train_losses' / 'val_losses' (one entry per evaluation),
        'training_time' in seconds, the best 'train_mse', 'train_nll',
        'test_mse', 'test_nll' and its 'epoch', plus every entry returned by
        `compute_metrics` prefixed with 'train_' or 'test_'.

    Raises
    ------
    ValueError
        If `num_epochs` is smaller than 1 (no evaluation would ever run).
    '''
    if num_epochs < 1:
        raise ValueError('num_epochs must be at least 1.')

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.8, step_size=lr_step_size)
    
    train_losses = []
    val_losses = []
    best_result = {}

    # Early-stopping horizon: a quarter of the epoch budget.
    patience = 0.25*int(num_epochs)
    
    start_time = time.perf_counter()
    
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        
        for batch_X, batch_y in train_loader:
            
            # Forward pass (call the module so registered hooks run)
            out = model(batch_X)
            
            # Calculate loss (negative ELBO)
            result = out.train_loss_fn(batch_y)
            loss = result['neg_total_elbo']
            
            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()

            # NOTE(review): the scheduler is stepped once per batch, so
            # `lr_step_size` counts optimizer steps rather than epochs.
            # Confirm this is intended before moving it to the epoch level.
            lr_scheduler.step()
            
            epoch_loss += loss.item()

        avg_train_loss = epoch_loss / len(train_loader)

        # Evaluate only at the first epoch and then every 50 epochs.
        if (epoch + 1) % 50 != 0 and epoch != 0:
            continue

        model.eval()
        with torch.no_grad():
            
            # Training data evaluation
            train_out = model(X_train_tensor)
            train_metrics = compute_metrics(
                train_out.predictive.mean, train_out.predictive.covariance, y_train_tensor
            )
            
            # Testing data evaluation
            test_out = model(X_test_tensor)
            test_metrics = compute_metrics(
                test_out.predictive.mean, test_out.predictive.covariance, y_test_tensor
            )
            
            train_mse = train_metrics['mse']
            train_nll = train_metrics['nll']
            test_mse = test_metrics['mse']
            test_nll = test_metrics['nll']
        
        # Record losses
        train_losses.append(avg_train_loss)
        val_losses.append(test_nll)
        
        # Update best (the first evaluation always becomes the initial best)
        if not best_result or (train_nll + test_nll < best_result['train_nll'] + best_result['test_nll']):
            best_result = {
                'train_metrics': train_metrics,
                'test_metrics': test_metrics,
                'train_mse': train_mse,
                'train_nll': train_nll,
                'test_mse': test_mse,
                'test_nll': test_nll,
                'epoch': epoch + 1,
            }

        # Early stopping: no improvement for more than a quarter of the budget.
        if epoch > patience and epoch - best_result['epoch'] > patience:
            print(f"Early stopping at epoch {epoch + 1}")
            break
    
    training_time = time.perf_counter() - start_time
    
    # Prepare results dictionary with all metrics
    results = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'training_time': training_time,
        'train_mse': best_result['train_mse'],
        'train_nll': best_result['train_nll'],
        'test_mse': best_result['test_mse'],
        'test_nll': best_result['test_nll'],
        'epoch': best_result['epoch'],
    }
    
    # Add training metrics with 'train_' prefix
    for key, value in best_result['train_metrics'].items():
        results[f'train_{key}'] = value
    
    # Add testing metrics with 'test_' prefix  
    for key, value in best_result['test_metrics'].items():
        results[f'test_{key}'] = value
    
    return results


