'''
This file implements a baseline model for the regression of multivariate stochastic functions with one output.

6.  Latent Derivative Bayesian Last Layer Networks (GBLL, LDGBLL)
    Joe Watson et al., AISTATS 2021 (PMLR vol. 130).
    https://proceedings.mlr.press/v130/watson21a/watson21a.pdf

'''

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, Dict, Tuple, Optional, Callable, Union
from dataclasses import dataclass
import math
import time
from baseline.common import compute_metrics


class FeatureExtractor(nn.Module):
    '''
    Multilayer perceptron mapping inputs to a latent feature space.

    The network is ``Linear(dim_input, dim_hidden) -> activation``, followed by
    ``n_hidden_layers - 1`` blocks of ``Linear(dim_hidden, dim_hidden) -> activation``,
    and finally a linear projection ``Linear(dim_hidden, dim_latent)`` with no
    activation. A value of ``n_hidden_layers`` below 1 still yields the input
    layer and the output projection.

    Parameters
    ----------
    dim_input : int
        Number of input features

    dim_hidden : int
        Number of neurons in each hidden layer

    dim_latent : int
        Number of output features (dimension of the feature space)

    n_hidden_layers : int
        Number of hidden layers

    activation : nn.Module
        Activation module (default: nn.ReLU). The same instance is inserted
        after every hidden linear layer.
    '''
    def __init__(self,
                 dim_input: int,
                 dim_hidden: int,
                 dim_latent: int,
                 n_hidden_layers: int,
                 activation: nn.Module = nn.ReLU()):
        super().__init__()

        self.dim_input = dim_input
        self.dim_hidden = dim_hidden
        self.dim_latent = dim_latent
        self.n_hidden_layers = n_hidden_layers
        self.activation = activation

        # Input width of every "Linear -> activation" block: the input layer
        # first, then the hidden-to-hidden layers.
        block_in_widths = [dim_input] + [dim_hidden] * max(n_hidden_layers - 1, 0)

        modules = []
        for in_width in block_in_widths:
            modules.extend([nn.Linear(in_width, dim_hidden), activation])

        # Final projection into the latent space (no activation afterwards).
        modules.append(nn.Linear(dim_hidden, dim_latent))

        self.model = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Map inputs to latent features.

        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor

        Returns
        -------
        out : torch.Tensor [batch_size, dim_latent]
            Feature tensor
        '''
        out = self.model(x)
        return out


class GaussianProcessLastLayer(nn.Module):
    '''
    Gaussian Process Last Layer for Bayesian Last Layer Networks.

    Exact GP regression on feature vectors. It uses a stationary kernel with one
    shared length scale, an output scale and a learnable homoscedastic
    observation-noise variance.

    Parameters
    ----------
    dim_latent : int
        Dimension of the feature space

    kernel_type : str
        Type of kernel to use ('rbf', 'matern32', 'matern52')

    jitter : float
        Small value added to the diagonal of the kernel matrix for numerical stability

    prior_log_obs_var : float
        Initial value of the log observation-noise variance
    '''
    def __init__(self,
                 dim_latent: int,
                 kernel_type: str = 'rbf',
                 jitter: float = 1e-6,
                 prior_log_obs_var: float = -3.0):
        super(GaussianProcessLastLayer, self).__init__()

        self.dim_latent = dim_latent
        self.kernel_type = kernel_type
        self.jitter = jitter

        # Kernel hyperparameters (log-parameterised to stay positive)
        self.log_length_scale = nn.Parameter(torch.zeros(1))
        self.log_output_scale = nn.Parameter(torch.zeros(1))

        # Observation noise variance
        self.log_obs_var = nn.Parameter(torch.ones(1)*prior_log_obs_var)

        # Posterior state, populated by `fit`
        self.register_buffer('X_train_features', None)
        self.register_buffer('y_train', None)
        self.register_buffer('K_inv', None)
        self.register_buffer('alpha', None)

    def kernel(self, x1: torch.Tensor, x2: torch.Tensor = None) -> torch.Tensor:
        '''
        Compute the kernel matrix between x1 and x2.

        Parameters
        ----------
        x1 : torch.Tensor [n1, dim_latent]
            First set of feature vectors

        x2 : torch.Tensor [n2, dim_latent]
            Second set of feature vectors (if None, use x1)

        Returns
        -------
        K : torch.Tensor [n1, n2]
            Kernel matrix

        Raises
        ------
        ValueError
            If ``kernel_type`` is not one of the supported kernels.
        '''
        length_scale = torch.exp(self.log_length_scale)
        output_scale = torch.exp(self.log_output_scale)

        if x2 is None:
            x2 = x1

        # Squared distances between the length-scaled inputs
        dist_squared = torch.cdist(x1 / length_scale, x2 / length_scale, p=2).pow(2)

        if self.kernel_type == 'rbf':
            # RBF kernel: k(x, x') = output_scale * exp(-0.5 * ||x - x'||^2 / length_scale^2)
            K = output_scale * torch.exp(-0.5 * dist_squared)

        elif self.kernel_type == 'matern32':
            # Matern 3/2 kernel
            dist = torch.sqrt(dist_squared)
            K = output_scale * (1 + math.sqrt(3) * dist) * torch.exp(-math.sqrt(3) * dist)

        elif self.kernel_type == 'matern52':
            # Matern 5/2 kernel
            dist = torch.sqrt(dist_squared)
            K = output_scale * (1 + math.sqrt(5) * dist + 5 * dist_squared / 3) * torch.exp(-math.sqrt(5) * dist)

        else:
            raise ValueError(f"Unknown kernel type: {self.kernel_type}")

        return K

    def _kernel_diag(self, X_features: torch.Tensor) -> torch.Tensor:
        '''
        Diagonal k(x, x) of the kernel matrix, shape [n_samples, 1].

        Every supported kernel evaluates to ``output_scale`` at zero distance.
        The diagonal therefore does not require building the full
        [n_samples, n_samples] matrix.
        '''
        output_scale = torch.exp(self.log_output_scale)
        return output_scale.expand(X_features.shape[0]).unsqueeze(1)

    def fit(self, X_features: torch.Tensor, y: torch.Tensor):
        '''
        Fit the Gaussian Process to the training data.

        Stores the training data, the inverse of the noisy kernel matrix and
        ``alpha = (K + (obs_var + jitter) I)^-1 y``.

        Parameters
        ----------
        X_features : torch.Tensor [n_samples, dim_latent]
            Feature vectors of the training inputs

        y : torch.Tensor [n_samples, 1]
            Training targets
        '''
        self.X_train_features = X_features
        self.y_train = y

        K = self.kernel(X_features)
        eye = torch.eye(K.shape[0], device=K.device, dtype=K.dtype)

        # Observation noise plus jitter on the diagonal
        obs_var = torch.exp(self.log_obs_var)
        K_noisy = K + obs_var * eye
        K_noisy = K_noisy + self.jitter * eye

        L = torch.linalg.cholesky(K_noisy)
        self.K_inv = torch.cholesky_inverse(L)

        # Solve with the Cholesky factor rather than multiplying by the explicit inverse
        self.alpha = torch.cholesky_solve(y, L)

    def predict(self, X_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        '''
        Make predictions with the Gaussian Process.

        Parameters
        ----------
        X_features : torch.Tensor [n_samples, dim_latent]
            Feature vectors of the test inputs

        Returns
        -------
        mean : torch.Tensor [n_samples, 1]
            Predictive mean

        var : torch.Tensor [n_samples, 1]
            Predictive variance of noisy observations (latent variance + observation noise)

        Raises
        ------
        ValueError
            If ``fit`` has not been called.
        '''
        if self.X_train_features is None:
            raise ValueError("GP has not been fitted to training data yet.")

        # Cross-covariance between test and training points: [n_test, n_train]
        K_star = self.kernel(X_features, self.X_train_features)

        # Predictive mean: K_star @ alpha
        mean = torch.matmul(K_star, self.alpha)

        # Latent variance: k(x, x) - diag(K_star K^-1 K_star^T). Only the
        # diagonal is needed, so evaluate the row-wise quadratic form instead
        # of the full [n_test, n_test] product.
        explained = (torch.matmul(K_star, self.K_inv) * K_star).sum(dim=1, keepdim=True)
        latent_var = self._kernel_diag(X_features) - explained
        # Round-off can push the difference slightly below zero
        latent_var = torch.clamp(latent_var, min=0.0)

        obs_var = torch.exp(self.log_obs_var)
        var = latent_var + obs_var

        return mean, var

    def sample(self, X_features: torch.Tensor, n_samples: int = 1) -> torch.Tensor:
        '''
        Sample from the (per-point, independent) predictive distribution.

        Parameters
        ----------
        X_features : torch.Tensor [n_points, dim_latent]
            Feature vectors of the test inputs

        n_samples : int
            Number of samples to draw

        Returns
        -------
        samples : torch.Tensor [n_samples, n_points, 1]
            Samples from the predictive distribution
        '''
        mean, var = self.predict(X_features)

        # Independent Gaussian draws per point (predictive covariances are ignored)
        epsilon = torch.randn(n_samples, mean.shape[0], 1, device=mean.device)
        samples = mean.unsqueeze(0) + torch.sqrt(var).unsqueeze(0) * epsilon

        return samples


class LatentDerivativeLayer(nn.Module):
    '''
    Computes derivatives of the latent features with respect to the inputs.

    The result is the Jacobian of ``feature_extractor`` evaluated at every
    sample of a batch.

    Parameters
    ----------
    dim_latent : int
        Dimension of the feature space

    dim_input : int
        Dimension of the input space

    kernel_type : str
        Type of kernel ('rbf', 'matern32', 'matern52'); stored but not used by ``forward``
    '''
    def __init__(self,
                 dim_latent: int,
                 dim_input: int,
                 kernel_type: str = 'rbf'):
        super(LatentDerivativeLayer, self).__init__()

        self.dim_latent = dim_latent
        self.dim_input = dim_input
        self.kernel_type = kernel_type

        # Kernel hyperparameters (registered, but not used by `forward`)
        self.log_length_scale = nn.Parameter(torch.zeros(1))
        self.log_output_scale = nn.Parameter(torch.zeros(1))

    def forward(self, x, feature_extractor):
        '''
        Compute derivatives of features with respect to inputs.

        Autograd is enabled internally, so exact derivatives are obtained even
        when the caller runs inside ``torch.no_grad()``. Gradients are taken
        only with respect to ``x``, so the ``.grad`` fields of the
        extractor's parameters are left untouched.

        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor

        feature_extractor : nn.Module
            Feature extractor module mapping [batch_size, dim_input] to
            [batch_size, dim_latent]. Samples are assumed to be processed
            independently. Summing a feature over the batch before
            differentiating relies on this.

        Returns
        -------
        derivatives : torch.Tensor [batch_size, dim_latent, dim_input]
            ``derivatives[b, i, j]`` = d feature_i(x_b) / d x_bj (detached)
        '''
        batch_size = x.shape[0]
        derivatives = torch.zeros(batch_size, self.dim_latent, self.dim_input,
                                  device=x.device, dtype=x.dtype)

        with torch.enable_grad():
            x_with_grad = x.detach().clone().requires_grad_(True)
            features = feature_extractor(x_with_grad)

            if features.requires_grad:
                for i in range(self.dim_latent):
                    # Summing over the batch yields every sample's gradient at
                    # once, because samples do not interact.
                    (grad_i,) = torch.autograd.grad(
                        features[:, i].sum(),
                        x_with_grad,
                        retain_graph=i < self.dim_latent - 1,
                        allow_unused=True,
                    )
                    # A None gradient means feature i does not depend on x
                    if grad_i is not None:
                        derivatives[:, i, :] = grad_i
                return derivatives

        # The extractor cut the autograd graph: fall back to finite differences
        return self._finite_difference(x, feature_extractor, derivatives)

    def _finite_difference(self, x, feature_extractor, derivatives):
        '''
        Fill ``derivatives`` using central finite differences and return it.

        The step eps_machine^(1/3) balances truncation and round-off error for
        central differences (about 5e-3 in float32 and 6e-6 in float64).
        '''
        print("Warning: Using numerical approximation for derivatives")
        eps = torch.finfo(derivatives.dtype).eps ** (1.0 / 3.0)
        with torch.no_grad():
            x_base = x.detach()
            for j in range(self.dim_input):
                x_plus = x_base.clone()
                x_plus[:, j] += eps
                x_minus = x_base.clone()
                x_minus[:, j] -= eps

                diff = feature_extractor(x_plus) - feature_extractor(x_minus)
                derivatives[:, :, j] = diff[:, :self.dim_latent] / (2 * eps)
        return derivatives


class LDGBLL(nn.Module):
    '''
    Latent Derivative Gaussian Bayesian Last Layer Networks.

    A neural network feature extractor followed by a Gaussian process last
    layer. With ``use_derivatives=True`` (LDGBLL), the GP inputs are the
    features concatenated with the flattened Jacobian of the features with
    respect to the network inputs. Otherwise (GBLL) they are the features only.

    Parameters
    ----------
    dim_input : int
        Number of input features

    dim_hidden : int
        Number of neurons in each hidden layer

    dim_latent : int
        Dimension of the feature space

    n_hidden_layers : int
        Number of hidden layers

    kernel_type : str
        Type of kernel to use ('rbf', 'matern32', 'matern52')

    use_derivatives : bool
        Whether to use derivative information

    activation : nn.Module
        Activation function to use (default: nn.ReLU)

    prior_log_obs_var : float
        Initial log observation-noise variance of the GP layer
    '''
    def __init__(self,
                 dim_input: int,
                 dim_hidden: int,
                 dim_latent: int,
                 n_hidden_layers: int,
                 kernel_type: str = 'rbf',
                 use_derivatives: bool = True,
                 activation: nn.Module = nn.ReLU(),
                 prior_log_obs_var: float = -2.0):
        super(LDGBLL, self).__init__()

        self.dim_input = dim_input
        self.dim_hidden = dim_hidden
        self.dim_latent = dim_latent
        self.n_hidden_layers = n_hidden_layers
        self.kernel_type = kernel_type
        self.use_derivatives = use_derivatives

        # Feature extractor
        self.feature_extractor = FeatureExtractor(
            dim_input=dim_input,
            dim_hidden=dim_hidden,
            dim_latent=dim_latent,
            n_hidden_layers=n_hidden_layers,
            activation=activation
        )

        # Gaussian process last layer
        self.gp_layer = GaussianProcessLastLayer(
            dim_latent=dim_latent,
            kernel_type=kernel_type,
            prior_log_obs_var=prior_log_obs_var
        )

        # Latent derivative layer (if using derivatives)
        if use_derivatives:
            self.derivative_layer = LatentDerivativeLayer(
                dim_latent=dim_latent,
                dim_input=dim_input,
                kernel_type=kernel_type
            )

    @property
    def device(self) -> torch.device:
        '''
        Get the device of the model.
        '''
        return next(self.parameters()).device

    @property
    def name(self) -> str:
        '''
        Model name.
        '''
        return "LDGBLL" if self.use_derivatives else "GBLL"

    def _gp_inputs(self, x: torch.Tensor, features: torch.Tensor) -> torch.Tensor:
        '''
        Build the GP input vectors for ``x`` from its precomputed ``features``.

        The same construction is used for fitting and for prediction, so the GP
        always sees inputs of the dimension it was trained on.

        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor

        features : torch.Tensor [batch_size, dim_latent]
            ``self.feature_extractor(x)``

        Returns
        -------
        torch.Tensor
            ``features`` for GBLL. For LDGBLL, the tensor
            ``[features, flattened Jacobian]`` of shape
            [batch_size, dim_latent + dim_latent * dim_input].
        '''
        if not self.use_derivatives:
            return features

        # Input gradients need autograd even when the caller is inside torch.no_grad()
        with torch.enable_grad():
            derivatives = self.derivative_layer(x, self.feature_extractor)

        flat_derivatives = derivatives.reshape(x.shape[0], -1).to(dtype=features.dtype)
        return torch.cat([features, flat_derivatives], dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Forward pass through the network.

        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor

        Returns
        -------
        out : torch.Tensor [batch_size, 1]
            GP predictive mean, or zeros if the GP has not been fitted yet
        '''
        features = self.feature_extractor(x)

        if getattr(self.gp_layer, 'X_train_features', None) is None:
            # No posterior yet (before `fit_gp`): return zeros
            return torch.zeros(x.shape[0], 1, device=x.device)

        mean, _ = self.gp_layer.predict(self._gp_inputs(x, features))
        return mean

    def fit_gp(self, X: torch.Tensor, y: torch.Tensor):
        '''
        Fit the Gaussian process last layer to the training data.

        Leaves the feature extractor in eval mode.

        Parameters
        ----------
        X : torch.Tensor [n_samples, dim_input]
            Training inputs

        y : torch.Tensor [n_samples, 1]
            Training targets
        '''
        self.feature_extractor.eval()
        with torch.no_grad():
            features = self.feature_extractor(X)

        self.gp_layer.fit(self._gp_inputs(X, features), y)

    def get_prediction(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        '''
        Get prediction with uncertainty estimation.

        Puts the whole model in eval mode.

        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor

        Returns
        -------
        mean, var : torch.Tensor [batch_size, 1]
            Mean and variance of the predictive distribution
        '''
        self.eval()

        with torch.no_grad():
            features = self.feature_extractor(x)

        mean, var = self.gp_layer.predict(self._gp_inputs(x, features))
        return mean, var
    

def create_bll_model(dim_input: int,
                     dim_hidden: int,
                     dim_latent: int,
                     n_hidden_layers: int,
                     kernel_type: str = 'rbf',
                     use_derivatives: bool = True,
                     activation: nn.Module = nn.ReLU(),
                     prior_log_obs_var: float = -3.0) -> LDGBLL:
    '''
    Factory for (LD)GBLL models.

    Parameters
    ----------
    dim_input : int
        Number of input features

    dim_hidden : int
        Number of neurons in each hidden layer

    dim_latent : int
        Dimension of the feature space

    n_hidden_layers : int
        Number of hidden layers

    kernel_type : str
        Type of kernel to use ('rbf', 'matern32', 'matern52')

    use_derivatives : bool
        Whether to use derivative information (LDGBLL if True, GBLL if False)

    activation : nn.Module
        Activation function to use (default: nn.ReLU)

    prior_log_obs_var : float
        Initial log observation-noise variance of the GP layer. The default
        here (-3.0) differs from the LDGBLL constructor's default (-2.0).

    Returns
    -------
    LDGBLL
        Bayesian Last Layer Networks model
    '''
    config = dict(
        dim_input=dim_input,
        dim_hidden=dim_hidden,
        dim_latent=dim_latent,
        n_hidden_layers=n_hidden_layers,
        kernel_type=kernel_type,
        use_derivatives=use_derivatives,
        activation=activation,
        prior_log_obs_var=prior_log_obs_var,
    )
    return LDGBLL(**config)


def _sanitize_variance(var: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
    '''
    Replace NaN variances by ``epsilon`` and clamp all variances to be >= ``epsilon``.
    '''
    var = torch.where(torch.isnan(var), torch.ones_like(var) * epsilon, var)
    return torch.clamp(var, min=epsilon)


def _detached_prediction(model: nn.Module, X: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    '''
    Call ``model.get_prediction(X)`` with autograd enabled and return detached outputs.

    Latent-derivative features need input gradients. Running the prediction
    under ``torch.no_grad()`` would block those gradients, so the outputs are
    detached afterwards instead.
    '''
    with torch.enable_grad():
        mean, var = model.get_prediction(X)
    return mean.detach(), var.detach()


def train_bll_model(
        model: LDGBLL, train_loader: torch.utils.data.DataLoader, 
        X_train_tensor: torch.Tensor, y_train_tensor: torch.Tensor, 
        X_test_tensor: torch.Tensor, y_test_tensor: torch.Tensor, 
        num_epochs: int = 5000, learning_rate: float = 0.01, lr_step_size: int = 1000) -> Dict[str, object]:
    '''
    Train the feature extractor, fit the GP last layer, and evaluate the model.

    Parameters
    ----------
    model : LDGBLL
        Model to train (modified in place)

    train_loader : torch.utils.data.DataLoader
        Yields ``(batch_X, batch_y)`` pairs for feature-extractor training

    X_train_tensor, y_train_tensor : torch.Tensor
        Full training set, used to fit the GP and compute training metrics

    X_test_tensor, y_test_tensor : torch.Tensor
        Test set, used for the per-epoch validation loss and test metrics

    num_epochs : int
        Number of passes over ``train_loader``

    learning_rate : float
        Initial Adam learning rate

    lr_step_size : int
        StepLR period (gamma = 0.8), counted in optimizer steps (batches)

    Returns
    -------
    results : dict
        'train_losses', 'val_losses': per-epoch losses;
        'training_time': seconds spent on feature training and GP fitting;
        'epoch': ``num_epochs``; plus every entry returned by
        ``compute_metrics``, prefixed with 'train_' / 'test_'.
    '''
    # Only the feature extractor is optimised; GP hyperparameters stay at their initial values
    optimizer = torch.optim.Adam(model.feature_extractor.parameters(), lr=learning_rate)
    # NOTE(review): the scheduler is stepped once per batch below, so the
    # learning rate decays every `lr_step_size` batches, not epochs.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.8, step_size=lr_step_size)

    train_losses = []
    val_losses = []

    start_time = time.perf_counter()

    print("Training feature extractor...")
    for _ in range(num_epochs):
        model.feature_extractor.train()
        epoch_loss = 0.0
        n_batches = 0

        for batch_X, _batch_y in train_loader:
            features = model.feature_extractor(batch_X)

            # NOTE(review): unsupervised objective. It matches the Gram matrix of
            # the features to that of the inputs; the targets are not used.
            loss = F.mse_loss(features @ features.t(), batch_X @ batch_X.t())

            optimizer.zero_grad()
            loss.backward()
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.feature_extractor.parameters(), max_norm=1.0)
            optimizer.step()
            lr_scheduler.step()

            epoch_loss += loss.item()
            n_batches += 1

        # Validation (same objective, on the test inputs)
        model.feature_extractor.eval()
        with torch.no_grad():
            val_features = model.feature_extractor(X_test_tensor)
            val_loss = F.mse_loss(val_features @ val_features.t(),
                                  X_test_tensor @ X_test_tensor.t())

        # An empty loader yields NaN instead of raising ZeroDivisionError
        train_losses.append(epoch_loss / n_batches if n_batches else float('nan'))
        val_losses.append(val_loss.item())

    print("Fitting GP last layer...")
    model.fit_gp(X_train_tensor, y_train_tensor)

    training_time = time.perf_counter() - start_time

    epsilon = 1e-6
    model.eval()

    # Training data evaluation
    train_mean, train_var = _detached_prediction(model, X_train_tensor)
    train_var = _sanitize_variance(train_var, epsilon)
    with torch.no_grad():
        train_metrics = compute_metrics(train_mean, train_var, y_train_tensor)

    # Testing data evaluation
    test_mean, test_var = _detached_prediction(model, X_test_tensor)
    test_var = _sanitize_variance(test_var, epsilon)
    with torch.no_grad():
        test_metrics = compute_metrics(test_mean, test_var, y_test_tensor)

    results = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'training_time': training_time,
        'train_mse': train_metrics['mse'],
        'train_nll': train_metrics['nll'],
        'test_mse': test_metrics['mse'],
        'test_nll': test_metrics['nll'],
        'epoch': num_epochs,
    }

    # Add all training / testing metrics with their prefixes
    for key, value in train_metrics.items():
        results[f'train_{key}'] = value
    for key, value in test_metrics.items():
        results[f'test_{key}'] = value

    return results

# When calculating NLL, ensure variance is valid (no NaNs and positive)
def safe_nll_calculation(mean, var, targets):
    '''
    Mean Gaussian negative log-likelihood with NaN-safe variances.

    NaN variances are replaced by 1e-6, and every variance is clamped to be at
    least 1e-6 before the Normal log-density of ``targets`` is evaluated.

    Returns
    -------
    float
        The negative of the mean log-density.
    '''
    floor = 1e-6
    cleaned_var = torch.where(var.isnan(), torch.ones_like(var) * floor, var).clamp(min=floor)
    predictive = torch.distributions.Normal(mean, cleaned_var.sqrt())
    log_density = predictive.log_prob(targets)
    return -log_density.mean().item()



