'''
This file implements a baseline model for the regression of multivariate stochastic functions with one output.

1.  Monte-Carlo Dropout (Bayesian Dropout)
    Yarin Gal, Zoubin Ghahramani, PMLR, 2016.
    https://proceedings.mlr.press/v48/gal16.pdf

'''

import time
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .common import compute_metrics


class MCDropout(nn.Module):
    '''
    Monte Carlo Dropout model for regression.
    
    This model uses dropout at test time to perform approximate Bayesian inference.
    It is a fully connected network with dropout applied after every hidden layer.
    Dropout stays active during both training and inference, so repeated forward
    passes on the same input produce different outputs. The sample mean and
    variance of those outputs serve as the predictive mean and variance.
    
    Parameters
    ----------
    dim_input : int
        Number of input features
        
    dim_output : int
        Number of output features (typically 1 for regression)
        
    dim_hidden : int
        Number of neurons in each hidden layer
        
    n_hidden_layers : int
        Number of hidden layers. At least one hidden layer is always built,
        so values below 1 behave like 1.
        
    dropout_rate : float
        Dropout probability (typically between 0.1 and 0.5)
        
    activation : nn.Module, optional
        Activation applied after every hidden layer. Defaults to a new
        ``nn.ReLU()`` created for each instance.
    '''
    def __init__(self,
                 dim_input: int,
                 dim_output: int,
                 dim_hidden: int,
                 n_hidden_layers: int,
                 dropout_rate: float = 0.2,
                 activation: Optional[nn.Module] = None):
        
        super(MCDropout, self).__init__()
        
        self.dim_input = dim_input
        self.dim_output = dim_output
        self.dim_hidden = dim_hidden
        self.n_hidden_layers = n_hidden_layers
        self.dropout_rate = dropout_rate
        # Create the default activation per instance. A module used as a default
        # argument is built once at definition time and would be shared (and
        # registered as a submodule) by every model.
        self.activation = nn.ReLU() if activation is None else activation
        
        # Input layer
        self.layers = nn.ModuleList([nn.Linear(dim_input, dim_hidden)])
        
        # Hidden layers (the input layer already counts as the first one)
        for _ in range(n_hidden_layers - 1):
            self.layers.append(nn.Linear(dim_hidden, dim_hidden))
        
        # Output layer
        self.layers.append(nn.Linear(dim_hidden, dim_output))
        
        # Xavier-normal weights and zero biases for every linear layer
        for layer in self.layers:
            nn.init.xavier_normal_(layer.weight)
            nn.init.zeros_(layer.bias)
    
    @property
    def device(self) -> torch.device:
        '''
        Device of the model's first parameter.
        '''
        return next(self.parameters()).device
    
    @property
    def name(self) -> str:
        '''
        Model name.
        '''
        return "MCDropout"
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Forward pass through the network (one stochastic sample).
        
        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor
            
        Returns
        -------
        out : torch.Tensor [batch_size, dim_output]
            Output tensor
        '''
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))
            # training=True keeps dropout active regardless of train()/eval() mode;
            # this is what makes the Monte Carlo sampling work at inference.
            x = F.dropout(x, p=self.dropout_rate, training=True)
        
        # Final layer (no activation or dropout)
        return self.layers[-1](x)
    
    def get_prediction(self, x: torch.Tensor, n_samples: int = 100) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        '''
        Get prediction with uncertainty estimation.
        
        The model is put in eval mode for the sampling (dropout in ``forward``
        stays active). The caller's previous train/eval mode is restored
        afterwards. Autograd is not disabled here; wrap the call in
        ``torch.no_grad()`` when gradients are not needed.
        
        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor
            
        n_samples : int
            Number of Monte Carlo samples (must be >= 1; with 1 sample the
            unbiased variance is NaN)
            
        Returns
        -------
        mean : torch.Tensor [batch_size, dim_output]
            Sample mean over the Monte Carlo samples
        variance : torch.Tensor [batch_size, dim_output]
            Unbiased sample variance over the Monte Carlo samples
        samples : torch.Tensor [n_samples, batch_size, dim_output]
            The individual stochastic forward passes
            
        Raises
        ------
        ValueError
            If ``n_samples`` is smaller than 1.
        '''
        if n_samples < 1:
            raise ValueError(f"n_samples must be at least 1, got {n_samples}")
        
        was_training = self.training
        self.eval()
        try:
            samples = torch.stack([self.forward(x) for _ in range(n_samples)], dim=0)
        finally:
            # Do not leave the model in eval mode as a hidden side effect
            self.train(was_training)
        
        mean = torch.mean(samples, dim=0)
        variance = torch.var(samples, dim=0)
        
        return mean, variance, samples
    
    def plot_prediction_1d(self, x_min=0, x_max=1, num_points=1001, n_samples=100):
        '''
        Plot the mean function and uncertainty for 1D input.
        
        Parameters
        ----------
        x_min : float
            Minimum x value
            
        x_max : float
            Maximum x value
            
        num_points : int
            Number of points to evaluate
            
        n_samples : int
            Number of Monte Carlo samples
        '''
        import matplotlib.pyplot as plt
        
        # Generate input points as a [num_points, 1] column
        x = torch.linspace(x_min, x_max, num_points).unsqueeze(1).to(self.device)
        
        # get_prediction returns a plain tuple. no_grad is required so the
        # tensors can be converted with .numpy().
        with torch.no_grad():
            mean, variance, _ = self.get_prediction(x, n_samples)
        
        # Convert to numpy for plotting
        x_np = x.cpu().numpy()
        mean_np = mean.cpu().numpy()
        std_np = torch.sqrt(variance).cpu().numpy()
        
        # Plot
        plt.figure(figsize=(10, 6))
        
        # Plot mean
        plt.plot(x_np, mean_np, 'b-', label='Mean prediction')
        
        # Plot uncertainty (±2 standard deviations)
        plt.fill_between(
            x_np.reshape(-1), 
            (mean_np - 2 * std_np).reshape(-1), 
            (mean_np + 2 * std_np).reshape(-1),
            alpha=0.3, 
            color='b', 
            label='Uncertainty (±2σ)'
        )
        
        plt.xlabel('Input')
        plt.ylabel('Output')
        plt.title('Monte Carlo Dropout Prediction with Uncertainty')
        plt.legend()
        plt.grid(True)
        plt.show()


def train_mc_dropout_model(model: MCDropout, train_loader: torch.utils.data.DataLoader, 
        X_train_tensor: torch.Tensor, y_train_tensor: torch.Tensor, 
        X_test_tensor: torch.Tensor, y_test_tensor: torch.Tensor, 
        num_epochs: int = 5000, learning_rate: float = 0.01, lr_step_size: int = 1000) -> Dict[str, Any]:
    '''
    Train an MCDropout model with an MSE objective and periodically evaluate it.
    
    Parameters
    ----------
    model : MCDropout
        Model to train (its parameters are updated in place).
    train_loader : torch.utils.data.DataLoader
        Yields ``(batch_X, batch_y)`` pairs used for the MSE updates. It must
        yield at least one batch per epoch.
    X_train_tensor, y_train_tensor : torch.Tensor
        Full training set, used only for metric evaluation.
    X_test_tensor, y_test_tensor : torch.Tensor
        Test set, used only for metric evaluation.
    num_epochs : int
        Number of passes over ``train_loader``; must be at least 1.
    learning_rate : float
        Initial Adam learning rate.
    lr_step_size : int
        StepLR period (decay factor 0.8). The scheduler is stepped once per
        batch, so this counts optimizer steps, not epochs.
    
    Returns
    -------
    results : Dict[str, Any]
        ``train_losses`` / ``val_losses`` hold the average training MSE and
        the test NLL at each evaluation (epoch 1 and every 50th epoch).
        ``training_time`` is the wall-clock time in seconds. ``epoch`` and
        ``train_var`` come from the most recent evaluation. Every key returned
        by ``compute_metrics`` is also included, prefixed with ``train_`` or
        ``test_``.
    
    Raises
    ------
    ValueError
        If ``num_epochs`` is smaller than 1 (no evaluation would be recorded).
    '''
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be at least 1, got {num_epochs}")
    
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # NOTE(review): lr_scheduler.step() is called once per batch below, so
    # lr_step_size counts optimizer steps rather than epochs. Kept unchanged to
    # preserve existing results; confirm this is intended.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.8, step_size=lr_step_size)
    
    train_losses = []
    val_losses = []
    best_result = {}
    
    start_time = time.perf_counter()
    
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        
        for batch_X, batch_y in train_loader:
            outputs = model(batch_X)
            loss = F.mse_loss(outputs, batch_y)
            
            optimizer.zero_grad()
            loss.backward()
            
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            lr_scheduler.step()
            
            epoch_loss += loss.item()

        avg_train_loss = epoch_loss / len(train_loader)

        # Evaluate on the first epoch and every 50th epoch afterwards
        if (epoch + 1) % 50 == 0 or epoch == 0:
            model.eval()
            with torch.no_grad():
                train_mean, train_var, _ = model.get_prediction(X_train_tensor)
                train_metrics = compute_metrics(train_mean, train_var, y_train_tensor)
                
                test_mean, test_var, _ = model.get_prediction(X_test_tensor)
                test_metrics = compute_metrics(test_mean, test_var, y_test_tensor)
            
            train_losses.append(avg_train_loss)
            val_losses.append(test_metrics['nll'])
            
            # No model selection: the reported result is the latest evaluation.
            best_result = {
                'train_metrics': train_metrics,
                'test_metrics': test_metrics,
                'epoch': epoch + 1,
                'train_var': train_var,
            }
            
            # NOTE(review): best_result is refreshed at every evaluation, so
            # epoch - best_result['epoch'] is always -1 and this early-stopping
            # branch never fires. It only matters if a selection criterion is
            # reintroduced above.
            if epoch > 0.25 * num_epochs and epoch - best_result['epoch'] > 0.25 * num_epochs:
                print(f"Early stopping at epoch {epoch + 1}")
                break
    
    training_time = time.perf_counter() - start_time
    
    results = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'training_time': training_time,
        'train_mse': best_result['train_metrics']['mse'],
        'train_nll': best_result['train_metrics']['nll'],
        'test_mse': best_result['test_metrics']['mse'],
        'test_nll': best_result['test_metrics']['nll'],
        'epoch': best_result['epoch'],
        'train_var': best_result['train_var'],
    }
    
    # Add every training / testing metric with its split prefix
    for key, value in best_result['train_metrics'].items():
        results[f'train_{key}'] = value
    for key, value in best_result['test_metrics'].items():
        results[f'test_{key}'] = value
    
    return results
