'''
This file implements a baseline model for the regression of multivariate stochastic functions with one output.

3.  Deep Gaussian Process Regression (GPR, deep GPR)
    Thang Bui, et al, PMLR, 2016.
    http://proceedings.mlr.press/v48/bui16.pdf

'''

import time
from typing import Any, Callable, Dict, List, Optional, Tuple

import gpytorch
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from baseline.common import compute_metrics


class SingleLayerGP(gpytorch.models.ApproximateGP):
    """
    Sparse variational Gaussian Process with a constant mean and a scaled ARD RBF kernel.

    The variational posterior over the inducing values uses a Cholesky-parameterised
    distribution, and the inducing locations are learned together with the other
    parameters (``learn_inducing_locations=True``).

    Parameters
    ----------
    num_features : int
        Dimensionality of the inputs fed to the GP (one RBF lengthscale per dimension).

    num_inducing : int
        Number of inducing points.
    """
    def __init__(self, num_features, num_inducing=64):
        # Starting locations of the inducing points: an evenly spaced grid on
        # [-2, 2] for scalar inputs, otherwise standard-normal draws scaled by 0.5.
        if num_features != 1:
            start_locations = 0.5 * torch.randn(num_inducing, num_features)
        else:
            start_locations = torch.linspace(-2, 2, num_inducing).unsqueeze(-1)

        # The strategy takes a reference to this model, so it has to be
        # built before the parent constructor is invoked with it.
        q_u = gpytorch.variational.CholeskyVariationalDistribution(
            num_inducing_points=num_inducing
        )
        strategy = gpytorch.variational.VariationalStrategy(
            self,
            start_locations,
            q_u,
            learn_inducing_locations=True,
        )
        super().__init__(strategy)

        # Prior mean function
        self.mean_module = gpytorch.means.ConstantMean()

        # Prior covariance: outputscale * RBF with a Gamma(3.0, 6.0) lengthscale prior
        rbf_kernel = gpytorch.kernels.RBFKernel(
            ard_num_dims=num_features,
            lengthscale_prior=gpytorch.priors.GammaPrior(3.0, 6.0),
        )
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf_kernel)

        # Start every lengthscale at 0.5 and the outputscale at 1.0
        self.covar_module.base_kernel.lengthscale = torch.full((num_features,), 0.5)
        self.covar_module.outputscale = 1.0

    def forward(self, x):
        """Return the GP prior at ``x`` as a ``MultivariateNormal``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )


class DeepGPRegression(nn.Module):
    '''
    Deep Gaussian Process Regression model.

    This implementation uses a neural network feature extractor followed by a GP layer,
    which is a common simplification of deep GPs that avoids many of the numerical issues.

    Parameters
    ----------
    dim_input : int
        Number of input features

    dim_output : int
        Number of output features (typically 1 for regression). Stored only;
        the GP layer always produces a single output.

    num_hidden_layers : int
        Number of hidden blocks in the feature extractor

    dim_hidden : int
        Dimensionality of the hidden layers. The input layer is twice as wide
        and the GP receives ``dim_hidden // 2`` features.

    num_inducing : int
        Number of inducing points for sparse GP approximation

    num_samples : int
        Default number of samples drawn by ``plot_prediction_1d``
    '''
    def __init__(self,
                 dim_input: int,
                 dim_output: int = 1,
                 num_hidden_layers: int = 2,
                 dim_hidden: int = 20,
                 num_inducing: int = 128,
                 num_samples: int = 100):

        super().__init__()

        self.dim_input = dim_input
        self.dim_output = dim_output
        self.num_hidden_layers = num_hidden_layers
        self.dim_hidden = dim_hidden
        self.num_inducing = num_inducing
        self.num_samples = num_samples
        self.name = "DeepGPRegression"

        # Input layer with wider initial representation
        self.input_layer = nn.Sequential(
            nn.Linear(dim_input, dim_hidden * 2),
            nn.SiLU()
        )

        # Hidden blocks, registered as res_block_0, res_block_1, ...
        # The first block maps 2 * dim_hidden -> dim_hidden, so its output shape
        # differs from its input and no skip connection is applied to it (see
        # _extract_features); later blocks keep the width and are residual.
        prev_dim = dim_hidden * 2
        for i in range(num_hidden_layers):
            self.add_module(
                f"res_block_{i}",
                nn.Sequential(
                    nn.Linear(prev_dim, dim_hidden),
                    nn.SiLU(),
                    nn.LayerNorm(dim_hidden),
                    nn.Linear(dim_hidden, dim_hidden),
                    nn.SiLU(),
                    nn.LayerNorm(dim_hidden),
                ),
            )
            prev_dim = dim_hidden

        # Output layer (features for GP) - project to a lower dimension
        self.output_layer = nn.Linear(dim_hidden, dim_hidden // 2)

        # GP layer operating on the projected features
        self.gp_layer = SingleLayerGP(dim_hidden // 2, num_inducing)

        # Gaussian likelihood with a Gamma(1.1, 0.05) noise prior, noise started at 0.1
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood(
            noise_prior=gpytorch.priors.GammaPrior(1.1, 0.05)
        )
        self.likelihood.noise = 0.1

    @property
    def device(self) -> torch.device:
        '''
        Get the device of the model (device of its first parameter).
        '''
        return next(self.parameters()).device

    @property
    def name(self) -> str:
        '''
        Model name.
        '''
        return self._name

    @name.setter
    def name(self, value: str) -> None:
        self._name = value

    def _extract_features(self, x):
        '''
        Map raw inputs to the feature vectors consumed by the GP layer.

        Shared by training, ``forward`` and ``get_prediction`` so the three
        code paths cannot drift apart.
        '''
        features = self.input_layer(x)

        for i in range(self.num_hidden_layers):
            block_out = getattr(self, f"res_block_{i}")(features)

            # Skip connection only when the block preserves the shape
            if features.shape == block_out.shape:
                features = features + block_out
            else:
                features = block_out

        # Final projection to GP feature space
        return self.output_layer(features)

    def train_model(self, train_x, train_y, num_epochs=1000, learning_rate=0.01, verbose=True):
        '''
        Train the Deep GP model with full-batch Adam on the negative variational ELBO.

        The learning rate is halved by ``ReduceLROnPlateau`` (patience 20) and
        training stops early once the loss has not improved for 30 consecutive
        epochs, but never before epoch index 100 has been passed.

        Parameters
        ----------
        train_x : torch.Tensor
            Training inputs

        train_y : torch.Tensor
            Training targets; flattened to 1D before being passed to the ELBO

        num_epochs : int
            Maximum number of training epochs

        learning_rate : float
            Initial learning rate for optimization

        verbose : bool
            Whether to print a message when early stopping triggers

        Returns
        -------
        list of float
            Loss value of every epoch that was actually run
        '''
        # Set model to training mode
        self.train()
        self.gp_layer.train()
        self.likelihood.train()

        # Weight decay is applied to the network weights only, not to the
        # GP or likelihood hyper-parameters.
        optimizer = torch.optim.Adam([
            {'params': self.input_layer.parameters(), 'weight_decay': 1e-4},
            {'params': self.output_layer.parameters(), 'weight_decay': 1e-4},
            *[{'params': getattr(self, f"res_block_{i}").parameters(), 'weight_decay': 1e-4}
              for i in range(self.num_hidden_layers)],
            {'params': self.gp_layer.parameters()},
            {'params': self.likelihood.parameters()}
        ], lr=learning_rate)

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=20
        )

        mll = gpytorch.mlls.VariationalELBO(
            self.likelihood,
            self.gp_layer,
            num_data=train_x.size(0)
        )

        # Flatten the targets once. reshape(-1) keeps a length-1 vector for a
        # single training sample, whereas squeeze() would collapse it to 0-d.
        targets = train_y.reshape(-1)

        losses = []
        best_loss = float('inf')
        patience_counter = 0
        patience = 30  # Early stopping patience

        for i in range(num_epochs):
            optimizer.zero_grad()

            # Feature extractor + GP layer
            output = self(train_x)

            loss = -mll(output, targets)
            loss_value = loss.item()
            losses.append(loss_value)

            loss.backward()

            # Clip gradients for stability
            torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)

            optimizer.step()
            scheduler.step(loss_value)

            # Early-stopping bookkeeping
            if loss_value < best_loss:
                best_loss = loss_value
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience and i > 100:
                if verbose:
                    print(f'Early stopping at epoch {i+1}')
                break

        return losses

    def forward(self, x):
        '''
        Forward pass through the model.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor

        Returns
        -------
        gpytorch.distributions.MultivariateNormal
            Output distribution of the GP layer evaluated on the extracted features
        '''
        return self.gp_layer(self._extract_features(x))

    def get_prediction(self, x: torch.Tensor, n_samples: Optional[int] = None) -> Dict[str, torch.Tensor]:
        '''
        Get prediction with uncertainty estimation.

        Switches the model, GP layer and likelihood to evaluation mode.

        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor

        n_samples : int, optional
            When given, additionally draw this many samples from the GP output
            distribution. Must be at least 1.

        Returns
        -------
        dict
            ``'mean'`` and ``'variance'`` of the GP output, each with a trailing
            singleton dimension added; ``'kl_divergence'``, the summed KL term of
            the variational strategy divided by ``x.size(0)``; and, only when
            ``n_samples`` is given, ``'samples'`` with a leading sample dimension.

        Raises
        ------
        ValueError
            If ``n_samples`` is given but smaller than 1
        '''
        if n_samples is not None and n_samples < 1:
            raise ValueError("n_samples must be a positive integer")

        # Set model to evaluation mode
        self.eval()
        self.gp_layer.eval()
        self.likelihood.eval()

        samples = None
        with torch.no_grad():
            output = self(x)

            # NOTE(review): these are the moments of the GP layer output;
            # self.likelihood is not applied here, so the likelihood noise is
            # not part of the returned variance - confirm this is intended.
            mean = output.mean.unsqueeze(-1)
            variance = output.variance.unsqueeze(-1)

            if n_samples is not None:
                samples = output.sample(torch.Size([n_samples])).unsqueeze(-1)

        # Evaluated outside the no_grad block above
        kl_divergence = self.gp_layer.variational_strategy.kl_divergence().sum() / x.size(0)

        prediction = {
            'mean': mean,
            'variance': variance,
            'kl_divergence': kl_divergence,
        }
        if samples is not None:
            prediction['samples'] = samples
        return prediction

    def plot_prediction_1d(self, x_min=0, x_max=1, num_points=1001, n_samples=None):
        '''
        Evaluate the prediction on a 1D grid and return the data needed for plotting.

        No figure is drawn; the caller plots the returned arrays.

        Parameters
        ----------
        x_min : float
            Minimum x value

        x_max : float
            Maximum x value

        num_points : int
            Number of points to evaluate

        n_samples : int
            Number of samples to draw; defaults to ``self.num_samples``

        Returns
        -------
        dict
            Dictionary of numpy arrays with keys 'x', 'mean', 'std' and 'samples'

        Raises
        ------
        ValueError
            If the model was built with ``dim_input != 1``
        '''
        if self.dim_input != 1:
            raise ValueError("This method is only for 1D input")

        if n_samples is None:
            n_samples = self.num_samples

        # Generate input points
        x = torch.linspace(x_min, x_max, num_points, device=self.device).reshape(-1, 1)

        # get_prediction returns a dict, so index it by key
        prediction = self.get_prediction(x, n_samples)

        return {
            'x': x.cpu().numpy(),
            'mean': prediction['mean'].cpu().numpy(),
            'std': torch.sqrt(prediction['variance']).cpu().numpy(),
            'samples': prediction['samples'].cpu().numpy(),
        }


def train_deep_gp_model(model: DeepGPRegression, train_loader: torch.utils.data.DataLoader, 
        X_train_tensor: torch.Tensor, y_train_tensor: torch.Tensor, 
        X_test_tensor: torch.Tensor, y_test_tensor: torch.Tensor, 
        num_epochs: int = 5000, learning_rate: float = 0.01, lr_step_size: int = 1000) -> Dict[str, Any]:
    '''
    Train a Deep GP model on the full training set and evaluate it on train and test data.

    Training is delegated to ``model.train_model``, which runs its own loop.

    Parameters
    ----------
    model : DeepGPRegression
        Model to train (modified in place)

    train_loader : torch.utils.data.DataLoader
        Not used by this function

    X_train_tensor, y_train_tensor : torch.Tensor
        Training inputs and targets

    X_test_tensor, y_test_tensor : torch.Tensor
        Test inputs and targets

    num_epochs : int
        Maximum number of epochs passed to ``model.train_model``

    learning_rate : float
        Learning rate passed to ``model.train_model``

    lr_step_size : int
        Not used by this function

    Returns
    -------
    dict
        'train_losses', an empty 'val_losses' list, 'training_time' in seconds,
        'epoch' (number of epochs actually run), plus every entry returned by
        ``compute_metrics`` under a 'train_' or 'test_' prefix.
    '''
    # Time only the optimisation, not the evaluation below
    start_time = time.perf_counter()
    losses = model.train_model(X_train_tensor, y_train_tensor, num_epochs=num_epochs, learning_rate=learning_rate)
    training_time = time.perf_counter() - start_time

    model.eval()

    with torch.no_grad():
        # Training data evaluation
        train_out = model.get_prediction(X_train_tensor)
        train_metrics = compute_metrics(
            train_out['mean'], train_out['variance'], y_train_tensor
        )

        # Testing data evaluation
        test_out = model.get_prediction(X_test_tensor)
        test_metrics = compute_metrics(
            test_out['mean'], test_out['variance'], y_test_tensor
        )

    results = {
        'train_losses': losses,
        'val_losses': [],  # Deep GP doesn't compute validation loss during training
        'training_time': training_time,
        # Explicit lookups so a missing 'mse'/'nll' metric fails loudly here
        'train_mse': train_metrics['mse'],
        'train_nll': train_metrics['nll'],
        'test_mse': test_metrics['mse'],
        'test_nll': test_metrics['nll'],
        # train_model may stop early, so report the epochs that actually ran
        # rather than the requested maximum
        'epoch': len(losses),
    }

    # Add every metric with its split prefix
    results.update({f'train_{key}': value for key, value in train_metrics.items()})
    results.update({f'test_{key}': value for key, value in test_metrics.items()})

    return results


