'''
This file implements a baseline model for the regression of multivariate stochastic functions with one output.

2.  Mixture Density Network (MDN)
    Christopher M. Bishop, 1994

'''

import time
from typing import Any, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from baseline.common import compute_metrics


class MixtureDensityNetwork(nn.Module):
    '''
    Mixture Density Network (MDN) for probabilistic regression.

    For every output dimension the network emits the parameters of a
    mixture of Gaussians:
    - Mixing coefficients (π): softmax-normalised weights of the components
    - Means (μ): mean of each Gaussian component
    - Log variances (log σ²): clamped to [-10, 10] for numerical stability

    Parameters
    ----------
    dim_input : int
        Number of input features

    dim_output : int
        Number of output features (typically 1 for regression)

    num_hidden_layers : int
        Number of hidden blocks stacked after the input block, so the
        backbone contains ``num_hidden_layers + 1`` linear layers in total

    dim_hidden : int
        Dimensionality of the hidden layers

    num_components : int
        Number of mixture components

    num_samples : int
        Default number of samples drawn per input point by
        ``plot_prediction_1d`` when its ``n_samples`` argument is None
    '''

    def __init__(self,
                 dim_input: int,
                 dim_output: int = 1,
                 num_hidden_layers: int = 2,
                 dim_hidden: int = 20,
                 num_components: int = 5,
                 num_samples: int = 100):

        super().__init__()

        self.dim_input = dim_input
        self.dim_output = dim_output
        self.num_hidden_layers = num_hidden_layers
        self.dim_hidden = dim_hidden
        self.num_components = num_components
        self.num_samples = num_samples
        self.name = "MixtureDensityNetwork"

        # Input block: Linear -> ReLU -> BatchNorm (no dropout).
        layers = [
            nn.Linear(dim_input, dim_hidden),
            nn.ReLU(),
            nn.BatchNorm1d(dim_hidden),
        ]

        # Hidden blocks: Linear -> ReLU -> BatchNorm -> Dropout.
        for _ in range(num_hidden_layers):
            layers.extend([
                nn.Linear(dim_hidden, dim_hidden),
                nn.ReLU(),
                nn.BatchNorm1d(dim_hidden),
                nn.Dropout(0.1),
            ])

        self.backbone = nn.Sequential(*layers)

        # One (π logit, μ, log σ²) triple per component and output dimension.
        total_params = num_components * 3 * dim_output
        self.output_layer = nn.Linear(dim_hidden, total_params)

        # Initialize output layer with small weights for stability
        nn.init.normal_(self.output_layer.weight, 0, 0.01)
        nn.init.constant_(self.output_layer.bias, 0)

    @property
    def device(self) -> torch.device:
        '''
        Device of the first model parameter.
        '''
        return next(self.parameters()).device

    @property
    def name(self) -> str:
        '''
        Model name.
        '''
        return self._name

    @name.setter
    def name(self, value: str) -> None:
        self._name = value

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        '''
        Forward pass through the network.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor [batch_size, dim_input]

        Returns
        -------
        Dict[str, torch.Tensor]
            Dictionary with keys 'mixing_coeffs', 'means' and 'log_vars',
            each of shape [batch_size, dim_output, num_components]
        '''
        batch_size = x.size(0)

        features = self.backbone(x)
        output = self.output_layer(features)

        # The flat output is interpreted as interleaved triples: the last
        # axis of size 3 holds (π logit, μ, log σ²) for one component.
        output = output.view(batch_size, self.dim_output, self.num_components, 3)
        pi_logits = output[:, :, :, 0]
        means = output[:, :, :, 1]
        log_vars = output[:, :, :, 2]

        # Softmax over the component axis so the weights sum to 1.
        mixing_coeffs = F.softmax(pi_logits, dim=-1)

        # Clamp log variances for numerical stability
        log_vars = torch.clamp(log_vars, min=-10, max=10)

        return {
            'mixing_coeffs': mixing_coeffs,
            'means': means,
            'log_vars': log_vars
        }

    def mixture_loss(self, params: Dict[str, torch.Tensor], targets: torch.Tensor) -> torch.Tensor:
        '''
        Compute the negative log-likelihood loss for the mixture density network.

        Parameters
        ----------
        params : Dict[str, torch.Tensor]
            Output from forward pass
        targets : torch.Tensor
            Target values holding batch_size * dim_output elements, with the
            batch along the first axis (e.g. [batch_size, dim_output])

        Returns
        -------
        torch.Tensor
            Scalar negative log-likelihood, averaged over batch and output
            dimensions
        '''
        mixing_coeffs = params['mixing_coeffs']  # [batch_size, dim_output, num_components]
        means = params['means']                   # [batch_size, dim_output, num_components]
        log_vars = params['log_vars']            # [batch_size, dim_output, num_components]

        # Trailing singleton axis broadcasts the target against all components.
        # reshape (rather than view) also accepts non-contiguous targets.
        batch_size = targets.size(0)
        targets = targets.reshape(batch_size, self.dim_output, 1)

        variances = torch.exp(log_vars)

        # Log PDF: -0.5 * log(2πσ²) - (x-μ)²/(2σ²)
        log_prob_components = -0.5 * torch.log(2 * torch.pi * variances) - \
                             0.5 * ((targets - means) ** 2) / variances

        # log(π_k) + log p(x | k); the epsilon guards against log(0).
        log_weighted_probs = torch.log(mixing_coeffs + 1e-8) + log_prob_components

        # Log sum exp over components for numerical stability
        log_likelihood = torch.logsumexp(log_weighted_probs, dim=-1)  # [batch_size, dim_output]

        return -torch.mean(log_likelihood)

    def train_model(self, train_x: torch.Tensor, train_y: torch.Tensor,
                    num_epochs: int = 1000, learning_rate: float = 0.001, verbose: bool = True) -> list:
        '''
        Train the MDN model with full-batch Adam updates.

        Every epoch performs a single optimisation step on the whole of
        ``train_x`` / ``train_y``. The model is left in training mode.

        Parameters
        ----------
        train_x : torch.Tensor
            Training inputs

        train_y : torch.Tensor
            Training targets

        num_epochs : int
            Number of training epochs (all of them are always run)

        learning_rate : float
            Learning rate for optimization

        verbose : bool
            Currently has no effect; accepted for API compatibility. No
            progress output is produced.

        Returns
        -------
        list
            Loss value (float) recorded at every epoch, before the update
        '''
        self.train()

        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate, weight_decay=1e-4)

        losses = []
        for _ in range(num_epochs):
            optimizer.zero_grad()

            params = self.forward(train_x)
            loss = self.mixture_loss(params, train_y)
            losses.append(loss.item())

            loss.backward()

            # Clip gradients for stability
            torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)

            optimizer.step()

        return losses

    @staticmethod
    def _mixture_moments(params: Dict[str, torch.Tensor]):
        '''
        Return (mean, variance) of the mixture described by ``params``.

        Both tensors have shape [batch_size, dim_output]; the variance is
        clamped to at least 1e-6.
        '''
        mixing_coeffs = params['mixing_coeffs']
        means = params['means']
        variances = torch.exp(params['log_vars'])

        # E[X] = Σ π_k * μ_k
        mixture_mean = torch.sum(mixing_coeffs * means, dim=-1)

        # Var[X] = Σ π_k * (σ_k² + μ_k²) - (E[X])²
        second_moment = torch.sum(mixing_coeffs * (variances + means ** 2), dim=-1)
        mixture_var = second_moment - mixture_mean ** 2

        # Rounding in the subtraction can produce tiny negative values.
        mixture_var = torch.clamp(mixture_var, min=1e-6)
        return mixture_mean, mixture_var

    def get_prediction(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        '''
        Get prediction with uncertainty estimation.

        Switches the model to evaluation mode (and leaves it there) and
        evaluates without tracking gradients.

        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor

        Returns
        -------
        Dict[str, torch.Tensor]
            Dictionary with the mixture 'mean' and 'variance', each of shape
            [batch_size, dim_output]
        '''
        self.eval()

        with torch.no_grad():
            mixture_mean, mixture_var = self._mixture_moments(self.forward(x))

        return {
            'mean': mixture_mean,
            'variance': mixture_var,
        }

    def plot_prediction_1d(self, x_min=0, x_max=1, num_points=1001, n_samples: Optional[int] = None):
        '''
        Evaluate mean, standard deviation and random samples on a 1D grid.

        Despite its name this method does not draw anything; it returns the
        arrays needed for plotting. The model is switched to evaluation mode.

        Parameters
        ----------
        x_min : float
            Minimum x value

        x_max : float
            Maximum x value

        num_points : int
            Number of points to evaluate

        n_samples : int, optional
            Number of samples to draw from the mixture at every point.
            Defaults to ``self.num_samples`` when None.

        Returns
        -------
        dict
            NumPy arrays: 'x', 'mean' and 'std' of shape [num_points, 1] and
            'samples' of shape [num_points, n_samples]

        Raises
        ------
        ValueError
            If the model does not have exactly one input and one output
            dimension
        '''
        if self.dim_input != 1:
            raise ValueError("This method is only for 1D input")
        if self.dim_output != 1:
            # The sampling below drops the output axis, which is only valid
            # when that axis has size 1.
            raise ValueError("This method is only for 1D output")

        if n_samples is None:
            n_samples = self.num_samples

        # Generate input points
        x = torch.linspace(x_min, x_max, num_points, device=self.device).reshape(-1, 1)

        self.eval()
        with torch.no_grad():
            # A single forward pass serves both the moments and the samples.
            params = self.forward(x)
            mixture_mean, mixture_var = self._mixture_moments(params)

            mixing_coeffs = params['mixing_coeffs'].squeeze(1)     # [num_points, num_components]
            means = params['means'].squeeze(1)                     # [num_points, num_components]
            stds = torch.exp(0.5 * params['log_vars']).squeeze(1)  # [num_points, num_components]

            # Pick a component per sample according to the mixing weights,
            # one row of draws per input point ...
            component_indices = torch.multinomial(mixing_coeffs, n_samples, replacement=True)

            # ... then draw from the Gaussian of each selected component.
            samples = torch.normal(
                torch.gather(means, 1, component_indices),
                torch.gather(stds, 1, component_indices),
            )

        return {
            'x': x.cpu().numpy(),
            'mean': mixture_mean.cpu().numpy(),
            'std': torch.sqrt(mixture_var).cpu().numpy(),
            'samples': samples.cpu().numpy()
        }


def train_mdn_model(model: MixtureDensityNetwork, train_loader: torch.utils.data.DataLoader, 
        X_train_tensor: torch.Tensor, y_train_tensor: torch.Tensor, 
        X_test_tensor: torch.Tensor, y_test_tensor: torch.Tensor, 
        num_epochs: int = 2000, learning_rate: float = 0.001, lr_step_size: int = 1000) -> Dict[str, Any]:
    '''
    Train MDN model and return results in the same format as other baseline models.

    The model is trained on the full training tensors via
    ``model.train_model`` and then evaluated on both the training and the
    test tensors with ``compute_metrics``.

    Parameters
    ----------
    model : MixtureDensityNetwork
        The MDN model to train
    train_loader : torch.utils.data.DataLoader
        Training data loader (not used, for API compatibility)
    X_train_tensor : torch.Tensor
        Training input data
    y_train_tensor : torch.Tensor
        Training target data
    X_test_tensor : torch.Tensor
        Test input data
    y_test_tensor : torch.Tensor
        Test target data
    num_epochs : int
        Number of training epochs
    learning_rate : float
        Learning rate
    lr_step_size : int
        Learning rate step size (not used, for API compatibility)

    Returns
    -------
    Dict[str, Any]
        Results dictionary with 'train_losses', an empty 'val_losses' list,
        'training_time' (seconds spent in ``model.train_model``), 'epoch'
        (number of recorded losses) and every entry returned by
        ``compute_metrics`` under a 'train_' or 'test_' prefix.

    Raises
    ------
    KeyError
        If ``compute_metrics`` does not return 'mse' and 'nll' entries
    '''
    # Only the optimisation itself is timed, not the evaluation below.
    start_time = time.perf_counter()
    losses = model.train_model(X_train_tensor, y_train_tensor, num_epochs=num_epochs, learning_rate=learning_rate)
    training_time = time.perf_counter() - start_time

    # Compute final metrics for train and test sets
    model.eval()
    with torch.no_grad():
        train_out = model.get_prediction(X_train_tensor)
        train_metrics = compute_metrics(
            train_out['mean'], train_out['variance'], y_train_tensor
        )

        test_out = model.get_prediction(X_test_tensor)
        test_metrics = compute_metrics(
            test_out['mean'], test_out['variance'], y_test_tensor
        )

    results = {
        'train_losses': losses,
        'val_losses': [],  # MDN doesn't compute validation loss during training
        'training_time': training_time,
        'train_mse': train_metrics['mse'],
        'train_nll': train_metrics['nll'],
        'test_mse': test_metrics['mse'],
        'test_nll': test_metrics['nll'],
        'epoch': len(losses),  # Actual number of epochs trained
    }

    # Expose every metric under a 'train_' / 'test_' prefix; the explicit
    # mse/nll entries above keep their position and are simply rewritten.
    results.update({f'train_{key}': value for key, value in train_metrics.items()})
    results.update({f'test_{key}': value for key, value in test_metrics.items()})

    return results