'''
Regression model for Variational Bayesian Last Layers (VBLL)

This is modified from the original version in the vbll package.

Reference:

    https://arxiv.org/abs/2404.11599
    
    https://github.com/VectorInstitute/vbll

'''

import time
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Dict, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from baseline.common import compute_metrics
from hvbll.distributions import DiagonalNormal, DenseNormal, gaussian_kl


@dataclass
class VBLLReturn:
    '''
    Container bundling everything a VBLL forward pass hands back.

    Attributes
    -------------
    predictive: DiagonalNormal | DenseNormal
        Predictive distribution for the inputs given to the model.

    train_loss_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]]
        Maps targets to a dictionary of training-loss terms.

    val_loss_fn: Callable[[torch.Tensor], torch.Tensor]
        Maps targets to a validation loss.

    ood_scores: Callable[[torch.Tensor], torch.Tensor] | None = None
        Optional out-of-distribution scoring function; None when not provided.

    '''
    predictive: DiagonalNormal | DenseNormal
    train_loss_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]]
    val_loss_fn: Callable[[torch.Tensor], torch.Tensor]
    ood_scores: Callable[[torch.Tensor], torch.Tensor] | None = None


class BayesianLinReg(nn.Module):
    '''
    Variational Bayesian Linear Regression.
    
    A linear model: y = W @ x + noise.
    
    - The weight matrix W and noise covariance matrix are treated as random variables.
    - The weight matrix W is assumed to have a Gaussian prior.
    - The noise covariance matrix is assumed to have a Wishart prior.
    - The weight matrix W and noise covariance matrix are assumed to be independent.
    - The model is trained using the evidence lower bound (ELBO) objective.

    Parameters
    ----------
    dim_input : int
        Number of input features
        
    dim_output : int
        Number of output features
        
    reg_weight_latent : float
        Weight on regularization term of weights in the evidence lower bound (ELBO)
        
    reg_weight_noise : float
        Weight on regularization term of noise in the evidence lower bound (ELBO)
        
    covariance_type : str
        Parameterization of covariance matrix. Currently supports {'dense', 'diagonal'}
        
    prior_scale : float
        Scale of the square root of the prior covariance matrix,
        the same order of magnitude as the standard deviation of the data being modeled.
        It is divided by ``dim_input`` before being stored on the instance.
        
    wishart_scale : float
        Scale of the square root of the Wishart prior on noise covariance, 
        the same order of magnitude as the standard deviation of the data being modeled.
        
        The Wishart distribution is a generalization of the gamma distribution to multiple dimensions.
        The Wishart distribution is a probability distribution over positive definite matrices. 
        It is often used in statistics and machine learning as the distribution of the precision matrix 
        (inverse covariance matrix) in a multivariate normal model.
        
    dof : float
        Degrees of freedom of Wishart prior on noise covariance.
        Stored on the instance as ``(dof + dim_output + 1) / 2``.

    Raises
    ------
    ValueError
        If ``covariance_type`` is neither 'dense' nor 'diagonal'.
    '''
    def __init__(self,
                 dim_input: int,
                 dim_output: int,
                 reg_weight_latent: float,
                 reg_weight_noise: float,
                 covariance_type='dense',
                 prior_scale=1.0,
                 wishart_scale=1e-2,
                 dof=1.):
        
        super().__init__()
        
        # Fail fast on an unsupported parameterization, before any parameter is allocated.
        if covariance_type not in ('diagonal', 'dense'):
            raise ValueError(f'Covariance parameterization method {covariance_type} is not supported.')
        
        self.dim_input = dim_input
        self.dim_output = dim_output
        self.reg_weight_latent = reg_weight_latent
        self.reg_weight_noise = reg_weight_noise
        self.covariance_type = covariance_type
        self.wishart_scale = wishart_scale
        self.dof = (dof + dim_output + 1.)/2.

        # Define weight's Gaussian prior, currently fixing zero mean and arbitrarily scaled cov
        self.prior_scale = prior_scale * (1. / dim_input) 

        # Noise distribution: fixed zero mean, learnable log-scale per output dimension.
        self.noise_mean = nn.Parameter(torch.zeros(dim_output), requires_grad = False)
        self.noise_logCov_diag = nn.Parameter(
            torch.randn(dim_output) * (np.log(wishart_scale)))

        # Stochastic weight distribution
        self.weight_mean = nn.Parameter(torch.randn(dim_output, dim_input))
        self.weight_logCov_diag = nn.Parameter(
            torch.randn(dim_output, dim_input) - 0.5 * np.log(dim_input))
        
        if covariance_type == 'dense':
            # Despite the name, this tensor is used as-is (no exponential); only its
            # strictly lower triangle enters the factor built in get_weight_distribution.
            self.weight_logCov_offdiag = nn.Parameter(
                torch.randn(dim_output, dim_input, dim_input)/dim_input)

    @property
    def device(self) -> torch.device:
        '''
        Get the device of the model (the device of its first parameter).
        '''
        return next(self.parameters()).device

    def get_weight_distribution(self) -> DiagonalNormal | DenseNormal:
        '''
        Get the weight distribution.
        
        Returns
        -------
        weight_dist : DiagonalNormal | DenseNormal
            ``DiagonalNormal(weight_mean, exp(weight_logCov_diag))`` for the 'diagonal'
            parameterization; for 'dense', a ``DenseNormal`` built from a lower-triangular
            matrix whose diagonal is ``exp(weight_logCov_diag)`` and whose strictly lower
            part comes from ``weight_logCov_offdiag``.
        
        Raises
        ------
        ValueError
            If ``self.covariance_type`` is not supported.
        '''
        cov_diag = torch.exp(self.weight_logCov_diag)
        
        if self.covariance_type == 'diagonal':
            return DiagonalNormal(self.weight_mean, cov_diag)
            
        if self.covariance_type == 'dense':
            # Construct the lower triangular matrix
            tril = torch.tril(self.weight_logCov_offdiag, diagonal=-1) + torch.diag_embed(cov_diag)
            return DenseNormal(self.weight_mean, tril)
            
        raise ValueError(f'Covariance parameterization method {self.covariance_type} is not supported.')

    def get_noise_distribution(self) -> DiagonalNormal:
        '''
        Get the noise distribution: ``DiagonalNormal(noise_mean, exp(noise_logCov_diag))``.
        '''
        return DiagonalNormal(self.noise_mean, torch.exp(self.noise_logCov_diag))

    def predictive_distribution(self, x: torch.Tensor) -> DiagonalNormal | DenseNormal:
        '''
        Compute the predictive distribution, i.e. the weight distribution applied
        to ``x`` plus the noise distribution.
        
        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor
        
        Returns
        -------
        dist = DiagonalNormal | DenseNormal
            Predictive distribution
        '''
        weight_dist = self.get_weight_distribution()
        noise_dist = self.get_noise_distribution()
        
        # x gets a trailing axis so it acts as a column vector under `@`.
        dist = (weight_dist @ x[..., None]).squeeze(-1) + noise_dist
        
        return dist

    def _get_train_loss_fn(self, x: torch.Tensor) -> Callable[[torch.Tensor], Dict[str, torch.Tensor]]:
        '''
        Internal method to get the training loss function.
        
        The returned closure captures ``x`` and maps targets ``y`` to a dictionary
        with the keys 'neg_total_elbo' (the quantity to minimize), 'nll',
        'kl_term' and 'wishart_term'.
        '''
        def loss_fn(y: torch.Tensor) -> Dict[str, torch.Tensor]:
            
            # construct predictive density N(W @ phi, Sigma)
            
            weight_dist = self.get_weight_distribution()
            noise_dist = self.get_noise_distribution()
            
            #* Reconstruction term
            pred_density = DiagonalNormal((weight_dist.mean @ x[...,None]).squeeze(-1), noise_dist.scale)
            pred_likelihood = pred_density.log_prob(y)

            trace_term = 0.5*((weight_dist.covariance_weighted_inner_prod(x.unsqueeze(-2)[..., None])) 
                                * noise_dist.trace_precision)
            
            reconstruction_term = torch.mean(pred_likelihood - trace_term)

            #* KL divergence between weight distribution and prior
            kl_term = gaussian_kl(weight_dist, q_variance=self.prior_scale)
            
            #* KL divergence between noise distribution and prior
            wishart_term = (self.dof * noise_dist.logdet_precision 
                            - 0.5 * self.wishart_scale * noise_dist.trace_precision)
            
            #* Evidence lower bound loss
            total_elbo = reconstruction_term - self.reg_weight_latent * kl_term + self.reg_weight_noise * wishart_term
            
            #* Negative log likelihood
            nll = - torch.mean(pred_likelihood)
            
            result = {  'neg_total_elbo': -total_elbo, 
                        'nll': nll,
                        'kl_term': kl_term,
                        'wishart_term': wishart_term,}
            
            return result

        return loss_fn

    def _get_val_loss_fn(self, x: torch.Tensor) -> Callable[[torch.Tensor], torch.Tensor]:
        '''
        Internal method to get the validation loss function.
        
        The returned closure captures ``x`` and scores targets under a density
        centred at ``weight_mean @ x`` with the noise scale only; the weight
        covariance does not enter this loss.
        '''
        def loss_fn(y):
            '''
            Negative log likelihood loss function.
            '''
            weight_dist = self.get_weight_distribution()
            noise_dist = self.get_noise_distribution()
            
            pred_density = DiagonalNormal((weight_dist.mean @ x[...,None]).squeeze(-1), noise_dist.scale)
            pred_likelihood = pred_density.log_prob(y)

            return - torch.mean(pred_likelihood)

        return loss_fn

    def forward(self, x: torch.Tensor) -> VBLLReturn:
        '''
        Forward pass of the model.
        
        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor
        
        Returns
        -------
        out: VBLLReturn
            Dataclass of the model output: the predictive distribution for ``x``
            plus training and validation loss closures bound to the same ``x``.
        '''
        out = VBLLReturn(self.predictive_distribution(x),
                         self._get_train_loss_fn(x),
                         self._get_val_loss_fn(x))
        return out

    @property
    def noise_std_numpy(self) -> np.ndarray:
        '''
        Get the noise standard deviation as numpy array.
        
        ndarray [dim_output]
        '''
        return torch.exp(self.noise_logCov_diag).detach().cpu().numpy()

    @property
    def weight_mean_numpy(self) -> np.ndarray:
        '''
        Get the weight mean as numpy array.
        
        ndarray [dim_output, dim_input]
        '''
        return self.weight_mean.detach().cpu().numpy()

    @property
    def weight_std_numpy(self) -> np.ndarray:
        '''
        Get the weight standard deviation as numpy array.
        (This is the square root of the diagonal of the covariance matrix.)
        
        NOTE(review): the value is ``exp(weight_logCov_diag)``. For the 'dense'
        parameterization this is the diagonal of the lower-triangular factor, which
        presumably differs from the marginal standard deviation -- confirm against
        ``DenseNormal``.
        '''
        return torch.exp(self.weight_logCov_diag).detach().cpu().numpy()

    def get_aleatoric_uncertainty(self) -> torch.Tensor:
        '''
        Get the aleatoric uncertainty, i.e., the noise variance.
        
        Returns
        -------
        noise_var : torch.Tensor [dim_output]
            Noise variance for each output dimension (square of the noise scale).
        '''
        return torch.exp(self.noise_logCov_diag) ** 2

    def get_epistemic_uncertainty(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Get the epistemic uncertainty, i.e., x^T @ Cov(W) @ x.
        
        Parameters
        ----------
        x : torch.Tensor [n, dim_input]
            Input data
        
        Returns
        -------
        epistemic_var : torch.Tensor
            Result of ``covariance_weighted_inner_prod`` on the weight distribution.
            NOTE(review): the output shape is determined by that helper and is not
            verified here. The training loss calls the same helper with
            ``x.unsqueeze(-2)[..., None]`` whereas this method passes ``x[..., None]``;
            confirm the two agree when ``dim_output > 1``.
        '''
        weight_dist = self.get_weight_distribution()
        
        new_cov = weight_dist.covariance_weighted_inner_prod(x[..., None])
        
        return new_cov

    def plot_prediction_1d(self, x_min=0, x_max=1, num_points=1001):
        '''
        Plot the mean function and noise level on the current matplotlib axes.
        
        Parameters
        ----------
        x_min, x_max : float
            Range of inputs to plot (both endpoints included).
            
        num_points : int
            Number of evenly spaced inputs.
        
        Raises
        ------
        ValueError
            If the model is not 1-d input and 1-d output.
        '''
        if self.dim_input != 1 or self.dim_output != 1:
            raise ValueError('The function is only supported for 1-d input and output.')
        
        xx = np.linspace(x_min, x_max, num_points, endpoint=True)
        yy = self.weight_mean_numpy[0] * xx
        
        tX = torch.from_numpy(xx).float().unsqueeze(-1).to(self.device)
        
        # Aleatoric part is a standard deviation; epistemic part is combined as a variance.
        a_uncertainty = self.noise_std_numpy[0]
        
        e_uncertainty = self.get_epistemic_uncertainty(tX).detach().cpu().numpy()
        
        t_uncertainty = np.sqrt(a_uncertainty**2 + e_uncertainty)

        plt.plot(xx, yy, 'r--', label='Mean prediction')

        plt.fill_between(xx, yy - t_uncertainty, yy + t_uncertainty, 
                            alpha=0.2, color='g', label='Total uncertainty (1 std)')
        
        plt.fill_between(xx, yy - a_uncertainty, yy + a_uncertainty, 
                            alpha=0.2, color='r', label='Aleatoric uncertainty (1 std)')


class VBLL(nn.Module):
    '''
    Variational Bayesian Last Layer.
    
    A neural network model with a variational Bayesian last layer.
    
    The network is ``in_layer`` (linear), followed by ``n_hidden_layers`` linear
    "core" layers each followed by an ELU, followed by a ``BayesianLinReg`` head.
    
    NOTE(review): no activation is applied between ``in_layer`` and the first core
    layer, so those two linear maps compose into one; the ELU after the last core
    layer is applied to the latent features. Confirm this is intended.
    
    Parameters
    ----------
    dim_input : int
        Number of input features
        
    dim_output : int
        Number of output features
        
    dim_latent : int
        Number of latent features
        
    dim_hidden : int
        Number of hidden units
        
    n_hidden_layers : int
        Number of hidden layers
        
    reg_weight_latent : float
        Weight on regularization term of latent variables in the evidence lower bound (ELBO)
        
    reg_weight_noise : float
        Weight on regularization term of noise in the evidence lower bound (ELBO)
        
    covariance_type : str
        Parameterization of covariance matrix. Currently supports {'dense', 'diagonal'}
        
    prior_scale : float
        Scale of the square root of the prior covariance matrix,
        the same order of magnitude as the standard deviation of the data being modeled.
        
    wishart_scale : float
        Scale of the square root of the Wishart prior on noise covariance, 
        the same order of magnitude as the standard deviation of the data being modeled.
        
        The Wishart distribution is a generalization of the gamma distribution to multiple dimensions.
        
    dof : float
        Degrees of freedom of Wishart prior on noise covariance
    '''
 
    def __init__(self,
                 dim_input: int,
                 dim_output: int,
                 dim_latent: int,
                 dim_hidden: int,
                 n_hidden_layers: int,
                 reg_weight_latent: float,
                 reg_weight_noise: float,
                 covariance_type='dense',
                 prior_scale=1.0,
                 wishart_scale=1e-2,
                 dof=1.):
        
        super().__init__()
        
        self.dim_input = dim_input
        self.dim_output = dim_output
        self.dim_latent = dim_latent
        self.dim_hidden = dim_hidden
        self.n_hidden_layers = n_hidden_layers
        self.covariance_type = covariance_type
        self.prior_scale = prior_scale
        self.wishart_scale = wishart_scale
        self.dof = dof

        self.layers = nn.ModuleDict({
            
            'in_layer': nn.Linear(dim_input, dim_hidden),
            
            # n_hidden_layers linear maps in total; the last one projects to the latent size.
            'core': nn.ModuleList(
                [nn.Linear(dim_hidden, dim_hidden) for _ in range(n_hidden_layers-1)] +
                [nn.Linear(dim_hidden, dim_latent)] ),
            
            'out_layer': BayesianLinReg(dim_latent, dim_output, 
                            reg_weight_latent, reg_weight_noise, covariance_type, prior_scale, wishart_scale, dof)
            })

        # One ELU per core layer.
        self.activations = nn.ModuleList([nn.ELU() for _ in range(n_hidden_layers)])

    @property
    def device(self) -> torch.device:
        '''
        Get the device of the model (the device of its first parameter).
        '''
        return next(self.parameters()).device
    
    @property
    def name(self) -> str:
        '''
        Model name.
        '''
        return 'VBLL'

    def forward(self, x: torch.Tensor) -> VBLLReturn:
        '''
        Forward pass of the model.
        
        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor
        
        Returns
        -------
        out: VBLLReturn
            Output of the Bayesian last layer evaluated on the latent features.
        '''
        x = self.layers['in_layer'](x)

        for layer, ac in zip(self.layers['core'], self.activations):
            x = ac(layer(x))

        out = self.layers['out_layer'](x)

        return out
    
    def get_prediction(self, xs: torch.Tensor) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        '''
        Evaluate the predictive distribution on ``xs`` and return numpy arrays.
        
        The model is switched to eval mode for the evaluation and its previous
        train/eval mode is restored afterwards. No gradients are recorded.
        
        Parameters
        ----------
        xs : torch.Tensor [n, dim_input]
            Input data
        
        Returns
        -------
        y_mean : np.ndarray
            Predictive mean with singleton dimensions squeezed out.
            
        y_var : np.ndarray
            Predictive covariance with singleton dimensions squeezed out.
            
        a_uncertainty : np.ndarray
            Aleatoric variance (squared noise std), broadcast to the shape of ``y_mean``.
        '''
        # Remember the caller's mode so that an eval-mode model stays in eval mode.
        was_training = self.training
        self.eval()
        
        try:
            with torch.no_grad():
                out = self.forward(xs)
                
                dist_y = out.predictive
                y_mean = dist_y.mean.cpu().detach().numpy().squeeze()
                y_var  = dist_y.covariance.squeeze().cpu().detach().numpy()
            
            vblr = self.layers['out_layer']
            a_uncertainty = np.ones_like(y_mean) * (vblr.noise_std_numpy)**2
        finally:
            self.train(was_training)
        
        return y_mean, y_var, a_uncertainty
    
    def plot_prediction_1d(self, x_min=0, x_max=1, num_points=1001):
        '''
        Plot the mean function and noise level on the current matplotlib axes.
        
        The model is switched to eval mode for the evaluation and its previous
        train/eval mode is restored afterwards.
        
        Parameters
        ----------
        x_min, x_max : float
            Range of inputs to plot.
            
        num_points : int
            Number of evenly spaced inputs.
        
        Raises
        ------
        ValueError
            If the model is not 1-d input and 1-d output.
        '''
        if self.dim_input != 1 or self.dim_output != 1:
            raise ValueError('The function is only supported for 1-d input and output.')
        
        # Remember the caller's mode so that an eval-mode model stays in eval mode.
        was_training = self.training
        self.eval()
        
        try:
            with torch.no_grad():
                tX = torch.linspace(x_min, x_max, num_points).float().unsqueeze(-1).to(self.device)
                xx = tX.cpu().detach().numpy().squeeze()

                dist_y = self.forward(tX).predictive
                y_mean = dist_y.mean.cpu().detach().numpy().squeeze()
                y_std  = torch.sqrt(dist_y.covariance.squeeze()).cpu().detach().numpy()
            
            vblr = self.layers['out_layer']
            a_uncertainty = vblr.noise_std_numpy[0]
        finally:
            self.train(was_training)
        
        plt.plot(xx, y_mean, 'r--', label='Mean prediction')

        plt.fill_between(xx, y_mean - y_std, y_mean + y_std, 
                            alpha=0.2, color='g', label='Total uncertainty (1 std)')
        
        plt.fill_between(xx, y_mean - a_uncertainty, y_mean + a_uncertainty, 
                            alpha=0.2, color='r', label='Aleatoric uncertainty (1 std)')


def train_vbll_model(model: VBLL, train_loader: torch.utils.data.DataLoader, 
        X_train_tensor: torch.Tensor, y_train_tensor: torch.Tensor, 
        X_test_tensor: torch.Tensor, y_test_tensor: torch.Tensor, 
        num_epochs: int = 5000, learning_rate: float = 0.01, lr_step_size: int = 1000) -> Dict[str, Any]:
    '''
    Train a VBLL model with Adam on the negative ELBO and track the best evaluation.
    
    Every 50 epochs (and after the first epoch) the model is evaluated on the full
    training and test tensors with ``compute_metrics``. The evaluation minimizing
    ``train_nll + 0.2 * test_nll`` is remembered, and training stops early once no
    better evaluation has been seen for more than ``0.25 * num_epochs`` epochs.
    
    Only the metrics of the best evaluation are kept; the model itself is left in
    the state reached at the end of training.
    
    NOTE(review): the selection criterion uses the test-set NLL, so the reported
    test metrics are not an unbiased estimate.
    
    Parameters
    ----------
    model : VBLL
        Model to train (modified in place).
        
    train_loader : torch.utils.data.DataLoader
        Yields ``(batch_X, batch_y)`` mini-batches.
        
    X_train_tensor, y_train_tensor : torch.Tensor
        Full training set used for periodic evaluation.
        
    X_test_tensor, y_test_tensor : torch.Tensor
        Full test set used for periodic evaluation.
        
    num_epochs : int
        Maximum number of epochs; must be at least 1.
        
    learning_rate : float
        Initial Adam learning rate.
        
    lr_step_size : int
        ``step_size`` of the StepLR scheduler (decay factor 0.8).
    
    Returns
    -------
    results : Dict[str, Any]
        'train_losses' (average negative ELBO at each evaluation), 'val_losses'
        (test NLL at each evaluation), 'training_time' in seconds, 'epoch' of the
        best evaluation, and every entry of the best train/test metric
        dictionaries under a 'train_' / 'test_' prefix.
    
    Raises
    ------
    ValueError
        If ``num_epochs`` is smaller than 1 (no evaluation could be recorded).
    '''
    if num_epochs < 1:
        raise ValueError(f'num_epochs must be at least 1, got {num_epochs}.')
    
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.8, step_size=lr_step_size)
    
    train_losses = []
    val_losses = []
    best_result = {}
    
    # Early-stopping patience, in epochs.
    patience = 0.25 * int(num_epochs)
    
    start_time = time.perf_counter()
    
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        
        for batch_X, batch_y in train_loader:
            
            # Forward pass
            out = model.forward(batch_X)
            
            # Calculate loss (negative ELBO)
            loss = out.train_loss_fn(batch_y)['neg_total_elbo']
            
            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            # NOTE(review): the scheduler advances once per mini-batch, so lr_step_size
            # counts optimizer steps rather than epochs -- confirm this is intended.
            lr_scheduler.step()
            
            epoch_loss += loss.item()

        avg_train_loss = epoch_loss / len(train_loader)

        # Evaluate only after the first epoch and then every 50 epochs.
        if (epoch + 1) % 50 != 0 and epoch != 0:
            continue
        
        # Compute MSE and NLL for train and test sets
        model.eval()
        with torch.no_grad():
            
            # Training data evaluation
            train_out = model.forward(X_train_tensor)
            train_metrics = compute_metrics(
                train_out.predictive.mean, train_out.predictive.covariance, y_train_tensor
            )
            
            # Testing data evaluation
            test_out = model.forward(X_test_tensor)
            test_metrics = compute_metrics(
                test_out.predictive.mean, test_out.predictive.covariance, y_test_tensor
            )
            
            train_mse = train_metrics['mse']
            train_nll = train_metrics['nll']
            test_mse = test_metrics['mse']
            test_nll = test_metrics['nll']
        
        # Record losses
        train_losses.append(avg_train_loss)
        val_losses.append(test_nll)
        
        # Update best
        if not best_result or (train_nll + 0.2*test_nll < best_result['train_nll'] + 0.2*best_result['test_nll']):
            best_result = {
                'train_metrics': train_metrics,
                'test_metrics': test_metrics,
                'train_mse': train_mse,
                'train_nll': train_nll,
                'test_mse': test_mse,
                'test_nll': test_nll,
                'epoch': epoch + 1,
            }

        # Stop once the best evaluation is more than `patience` epochs old. The stored
        # best epoch is at least 1, so this also implies epoch > patience.
        if epoch - best_result['epoch'] > patience:
            print(f"Early stopping at epoch {epoch + 1}")
            break
    
    training_time = time.perf_counter() - start_time
    
    # Prepare results dictionary with all metrics
    results = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'training_time': training_time,
        'train_mse': best_result['train_mse'],
        'train_nll': best_result['train_nll'],
        'test_mse': best_result['test_mse'],
        'test_nll': best_result['test_nll'],
        'epoch': best_result['epoch'],
    }
    
    # Add training metrics with 'train_' prefix
    for key, value in best_result['train_metrics'].items():
        results[f'train_{key}'] = value
    
    # Add testing metrics with 'test_' prefix  
    for key, value in best_result['test_metrics'].items():
        results[f'test_{key}'] = value
    
    return results



