'''
This module implements a baseline model — a probabilistic neural network (PNN) trained with the beta-NLL loss — for regressing multivariate stochastic functions with a single output.

7.  On the Pitfalls of Heteroscedastic Uncertainty Estimation with Probabilistic Neural Networks (beta-NLL)
    Maximilian Seitzer et al., arXiv, 2022
    https://arxiv.org/abs/2203.09168

'''

import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .common import compute_metrics


@dataclass
class PNNReturn:
    '''
    Output container of ``ProbabilisticNeuralNetwork.get_prediction``.

    Bundles the predicted Gaussian moments with closures that compute
    losses against a target for that same prediction.

    Attributes
    ----------------
    mean: torch.Tensor
        The mean prediction of the model.

    variance: torch.Tensor
        The predicted (strictly positive) variance.

    train_loss_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]]
        Given the targets, returns a dict of training losses for this
        prediction (see ``ProbabilisticNeuralNetwork.get_prediction`` for
        the keys).

    val_loss_fn: Callable[[torch.Tensor], torch.Tensor]
        Given the targets, returns the scalar validation loss for this
        prediction.
    '''
    mean: torch.Tensor
    variance: torch.Tensor
    train_loss_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]]
    val_loss_fn: Callable[[torch.Tensor], torch.Tensor]


class BetaNLLLoss(nn.Module):
    '''
    Beta negative log-likelihood (beta-NLL) loss for heteroscedastic regression.

    For a Gaussian predictive distribution N(mean, variance) this computes,
    element-wise,

        L = stopgrad(variance ** beta)
            * (0.5 * log(variance) + 0.5 * (target - mean) ** 2 / variance)

    i.e. the Gaussian NLL (without the constant 0.5 * log(2 * pi)) re-weighted
    by the *detached* variance raised to ``beta``. Because the weight is
    detached it only rescales gradients and is never optimised itself:

    * ``beta = 0`` recovers the standard Gaussian NLL;
    * ``beta = 1`` gives the mean the MSE-like gradient ``mean - target``.

    Intermediate values interpolate between these two regimes.

    Parameters
    ----------
    beta : float
        Exponent of the detached variance weight (default 0.5).

    reduction : str
        Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
    '''
    def __init__(self, beta: float = 0.5, reduction: str = 'mean'):
        super(BetaNLLLoss, self).__init__()
        self.beta = beta
        self.reduction = reduction

    def forward(self, mean: torch.Tensor, variance: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        '''
        Compute the beta negative log-likelihood loss.

        Parameters
        ----------
        mean : torch.Tensor
            Predicted mean.

        variance : torch.Tensor
            Predicted variance; clamped below at 1e-6 before use.

        target : torch.Tensor
            Ground truth values (broadcast against ``mean``).

        Returns
        -------
        loss : torch.Tensor
            Element-wise loss when ``reduction == 'none'``, otherwise its
            mean or sum.

        Raises
        ------
        ValueError
            If ``reduction`` is not one of 'none', 'mean' or 'sum'.
        '''
        # Guard against log(0) and division by zero.
        variance = torch.clamp(variance, min=1e-6)

        residuals = target - mean
        # Gaussian negative log-likelihood, up to an additive constant.
        loss = 0.5 * (torch.log(variance) + residuals.pow(2) / variance)

        if self.beta != 0:
            # Detached weight: rescales gradients without being optimised.
            loss = loss * variance.detach().pow(self.beta)

        # Apply reduction
        if self.reduction == 'none':
            return loss
        elif self.reduction == 'mean':
            return torch.mean(loss)
        elif self.reduction == 'sum':
            return torch.sum(loss)
        else:
            raise ValueError(f"Unknown reduction: {self.reduction}")


class ProbabilisticNeuralNetwork(nn.Module):
    '''
    Probabilistic Neural Network for heteroscedastic uncertainty estimation.

    A shared MLP trunk feeds two linear heads: one predicts the mean and the
    other predicts the log-variance, which is exponentiated so the returned
    variance is always positive. This lets the model represent
    input-dependent (heteroscedastic) uncertainty. Training uses
    ``BetaNLLLoss``.

    Parameters
    ----------
    dim_input : int
        Number of input features

    dim_output : int
        Number of output features (typically 1 for regression)

    dim_hidden : int
        Number of neurons in each hidden layer

    n_hidden_layers : int
        Number of hidden layers in the shared trunk. The first layer
        (dim_input -> dim_hidden) is always created, so values below 1
        behave like 1.

    beta : float
        Beta parameter for the beta NLL loss

    activation : nn.Module, optional
        Activation applied after every trunk layer. When ``None`` (default),
        a fresh ``nn.ReLU()`` is created for this model instead of sharing
        one module instance between models.
    '''
    def __init__(self,
                 dim_input: int,
                 dim_output: int,
                 dim_hidden: int,
                 n_hidden_layers: int,
                 beta: float = 0.5,
                 activation: Optional[nn.Module] = None):
        super(ProbabilisticNeuralNetwork, self).__init__()

        self.dim_input = dim_input
        self.dim_output = dim_output
        self.dim_hidden = dim_hidden
        self.n_hidden_layers = n_hidden_layers
        self.beta = beta
        self.activation = activation if activation is not None else nn.ReLU()

        # Shared trunk: input layer followed by (n_hidden_layers - 1) hidden layers.
        self.shared_layers = nn.ModuleList()
        self.shared_layers.append(nn.Linear(dim_input, dim_hidden))
        for _ in range(n_hidden_layers - 1):
            self.shared_layers.append(nn.Linear(dim_hidden, dim_hidden))

        # Mean prediction head
        self.mean_head = nn.Linear(dim_hidden, dim_output)

        # Variance head outputs log-variance; exp() in forward keeps it positive.
        self.log_var_head = nn.Linear(dim_hidden, dim_output)

        # Loss function
        self.loss_fn = BetaNLLLoss(beta=beta)

    @property
    def device(self) -> torch.device:
        '''
        Device of the model's first parameter.
        '''
        return next(self.parameters()).device

    @property
    def name(self) -> str:
        '''
        Model name.
        '''
        return "PNN"

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        '''
        Forward pass through the network.

        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor

        Returns
        -------
        mean : torch.Tensor [batch_size, dim_output]
            Mean prediction

        variance : torch.Tensor [batch_size, dim_output]
            Variance prediction, ``exp`` of the log-variance head
        '''
        h = x
        for layer in self.shared_layers:
            h = self.activation(layer(h))

        mean = self.mean_head(h)
        variance = torch.exp(self.log_var_head(h))

        return mean, variance

    def get_prediction(self, x: torch.Tensor) -> PNNReturn:
        '''
        Get prediction with uncertainty estimation.

        This method does not change the module's train/eval mode or gradient
        tracking: it is called inside the training loop, so the caller is
        responsible for ``eval()`` / ``torch.no_grad()`` during inference.

        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor

        Returns
        -------
        PNNReturn
            Mean, variance, and loss closures bound to this prediction.
            ``train_loss_fn(y)`` returns {"loss", "beta_nll", "mse"}, where
            "loss" and "beta_nll" are the same beta-NLL tensor;
            ``val_loss_fn(y)`` returns the MSE of the mean.
        '''
        mean, variance = self.forward(x)

        def train_loss_fn(y_true: torch.Tensor) -> Dict[str, torch.Tensor]:
            beta_nll_loss = self.loss_fn(mean, variance, y_true)
            # MSE of the mean, reported for comparison only.
            mse_loss = F.mse_loss(mean, y_true)
            return {
                "loss": beta_nll_loss,
                "beta_nll": beta_nll_loss,
                "mse": mse_loss
            }

        def val_loss_fn(y_true: torch.Tensor) -> torch.Tensor:
            # Use MSE loss for validation
            return F.mse_loss(mean, y_true)

        return PNNReturn(
            mean=mean,
            variance=variance,
            train_loss_fn=train_loss_fn,
            val_loss_fn=val_loss_fn
        )
    

def create_pnn_model(dim_input: int,
                     dim_output: int,
                     dim_hidden: int,
                     n_hidden_layers: int,
                     beta: float = 0.5,
                     activation: nn.Module = nn.ReLU()) -> ProbabilisticNeuralNetwork:
    '''
    Factory for a ``ProbabilisticNeuralNetwork``.

    Every argument is forwarded unchanged, by keyword, to the model
    constructor.

    Parameters
    ----------
    dim_input : int
        Size of each input sample.

    dim_output : int
        Size of each prediction (usually 1 for scalar regression).

    dim_hidden : int
        Width of every hidden layer.

    n_hidden_layers : int
        Depth of the shared hidden trunk.

    beta : float
        Exponent used by the beta-NLL training loss.

    activation : nn.Module
        Non-linearity inserted after each hidden layer (defaults to ReLU).

    Returns
    -------
    ProbabilisticNeuralNetwork
        A freshly initialised, untrained model.
    '''
    constructor_kwargs = {
        'dim_input': dim_input,
        'dim_output': dim_output,
        'dim_hidden': dim_hidden,
        'n_hidden_layers': n_hidden_layers,
        'beta': beta,
        'activation': activation,
    }
    return ProbabilisticNeuralNetwork(**constructor_kwargs)


def _evaluate_pnn(model: ProbabilisticNeuralNetwork,
                  X: torch.Tensor,
                  y: torch.Tensor) -> Dict[str, Any]:
    '''Return ``compute_metrics`` for the model's prediction on (X, y), in eval mode without gradients.'''
    model.eval()
    with torch.no_grad():
        out = model.get_prediction(X)
        return compute_metrics(out.mean, out.variance, y)


def train_pnn_model(model: ProbabilisticNeuralNetwork, train_loader: torch.utils.data.DataLoader, 
        X_train_tensor: torch.Tensor, y_train_tensor: torch.Tensor, 
        X_test_tensor: torch.Tensor, y_test_tensor: torch.Tensor, 
        num_epochs: int = 5000, learning_rate: float = 0.01, lr_step_size: int = 1000) -> Dict[str, Any]:
    '''
    Train a PNN with Adam, beta-NLL loss, periodic evaluation and early stopping.

    Every 50 epochs (and after the first epoch), the model is evaluated on
    the full train and test tensors. The weights with the lowest
    ``train_nll + test_nll`` are kept. Note that model selection therefore
    uses the test set. Training stops early once more than
    ``0.25 * num_epochs`` epochs have passed since the best evaluation. On
    return, ``model`` holds the best weights.

    Parameters
    ----------
    model : ProbabilisticNeuralNetwork
        Model to train in place.
    train_loader : torch.utils.data.DataLoader
        Yields ``(batch_X, batch_y)`` pairs.
    X_train_tensor, y_train_tensor : torch.Tensor
        Full training set used for periodic evaluation.
    X_test_tensor, y_test_tensor : torch.Tensor
        Full test set used for periodic evaluation and model selection.
    num_epochs : int
        Maximum number of epochs.
    learning_rate : float
        Initial Adam learning rate.
    lr_step_size : int
        Number of epochs between learning-rate decays by a factor of 0.8.

    Returns
    -------
    Dict[str, Any]
        'train_losses' (mean batch loss at each evaluation), 'val_losses'
        (test NLL at each evaluation), 'training_time' (seconds), 'epoch'
        (1-based epoch of the best evaluation, 0 if none ran), and every key
        of ``compute_metrics`` prefixed with 'train_' / 'test_' for the
        restored best model. Assumes ``compute_metrics`` returns at least
        'mse' and 'nll'.
    '''
    eval_every = 50
    patience = 0.25 * int(num_epochs)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # StepLR counts epochs; it is stepped once per epoch below so the learning
    # rate decays every ``lr_step_size`` epochs rather than every N batches.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.8, step_size=lr_step_size)

    train_losses = []
    val_losses = []
    best_result = {}
    best_state = None

    start_time = time.perf_counter()

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        n_batches = 0

        for batch_X, batch_y in train_loader:
            prediction = model.get_prediction(batch_X)
            loss = prediction.train_loss_fn(batch_y)['loss']

            optimizer.zero_grad()
            loss.backward()
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            epoch_loss += loss.item()
            n_batches += 1

        lr_scheduler.step()

        avg_train_loss = epoch_loss / max(n_batches, 1)

        if (epoch + 1) % eval_every == 0 or epoch == 0:
            train_metrics = _evaluate_pnn(model, X_train_tensor, y_train_tensor)
            test_metrics = _evaluate_pnn(model, X_test_tensor, y_test_tensor)

            train_losses.append(avg_train_loss)
            val_losses.append(test_metrics['nll'])

            # Keep the checkpoint with the lowest combined train + test NLL.
            if not best_result or (train_metrics['nll'] + test_metrics['nll']
                                   < best_result['train_nll'] + best_result['test_nll']):
                best_result = {
                    'train_mse': train_metrics['mse'],
                    'train_nll': train_metrics['nll'],
                    'test_mse': test_metrics['mse'],
                    'test_nll': test_metrics['nll'],
                    'epoch': epoch + 1,
                }
                # Clone so later optimizer steps do not overwrite the snapshot.
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

            if epoch > patience and epoch - best_result['epoch'] > patience:
                print(f"Early stopping at epoch {epoch + 1}")
                break

    training_time = time.perf_counter() - start_time

    # Final metrics are computed on the best checkpoint, not the last weights.
    if best_state is not None:
        model.load_state_dict(best_state)

    train_metrics = _evaluate_pnn(model, X_train_tensor, y_train_tensor)
    test_metrics = _evaluate_pnn(model, X_test_tensor, y_test_tensor)

    results = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'training_time': training_time,
        'train_mse': train_metrics['mse'],
        'train_nll': train_metrics['nll'],
        'test_mse': test_metrics['mse'],
        'test_nll': test_metrics['nll'],
        'epoch': best_result.get('epoch', 0),
    }

    # Add every training / testing metric with its split prefix.
    for key, value in train_metrics.items():
        results[f'train_{key}'] = value
    for key, value in test_metrics.items():
        results[f'test_{key}'] = value

    return results


