'''
This file implements a baseline model for the regression of multivariate stochastic functions with one output.

5.  Stochastic Weight Averaging Gaussian (SWAG)
    Wesley J Maddox, et al, NeurIPS, 2019
    https://proceedings.neurips.cc/paper_files/paper/2019/file/118921efba23fc329e6560b27861f0c2-Paper.pdf

'''

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, Dict, Tuple, Optional, Callable, Union
from dataclasses import dataclass
import math
import copy
import time
from .common import compute_metrics


@dataclass
class SWAGReturn:
    '''
    Result bundle produced by ``SWAG.get_prediction``.

    Attributes
    ----------------
    mean: torch.Tensor
        Predictive mean, i.e. the average of the per-sample predictions.

    variance: torch.Tensor
        Variance of the per-sample predictions (model uncertainty).

    samples: torch.Tensor
        The individual predictions, one per posterior sample, stacked along
        the first dimension.

    x: Optional[torch.Tensor]
        The input the prediction was computed for, or None if not recorded.

    train_loss_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]]
        Maps the targets to a dictionary of named training losses.

    val_loss_fn: Callable[[torch.Tensor], torch.Tensor]
        Maps the targets to a single validation loss.
    '''
    mean: torch.Tensor
    variance: torch.Tensor
    samples: torch.Tensor
    # Optional[...] instead of ``X | None``: dataclass annotations are evaluated
    # when the class is created, and the ``|`` form fails before Python 3.10.
    x: Optional[torch.Tensor]
    train_loss_fn: Callable[[torch.Tensor], Dict[str, torch.Tensor]]
    val_loss_fn: Callable[[torch.Tensor], torch.Tensor]


class BaseModel(nn.Module):
    '''
    Fully connected network used as the base model for SWAG.

    The network is ``Linear -> activation -> ... -> Linear``: one input layer,
    ``n_hidden_layers - 1`` hidden-to-hidden layers and one output layer, with
    the activation applied between consecutive linear layers.

    Parameters
    ----------
    dim_input : int
        Number of input features

    dim_output : int
        Number of output features (typically 1 for regression)

    dim_hidden : int
        Number of neurons in each hidden layer

    n_hidden_layers : int
        Number of hidden layers. Values below 1 produce the same network as 1
        (input layer followed directly by the output layer).

    activation : nn.Module
        Activation module (default: nn.ReLU). The same module instance is
        inserted at every position, so any state it holds is shared.
    '''
    def __init__(self,
                 dim_input: int,
                 dim_output: int,
                 dim_hidden: int,
                 n_hidden_layers: int,
                 activation: nn.Module = nn.ReLU()):
        super().__init__()

        self.dim_input = dim_input
        self.dim_output = dim_output
        self.dim_hidden = dim_hidden
        self.n_hidden_layers = n_hidden_layers
        self.activation = activation

        # Input layer
        layers = [nn.Linear(dim_input, dim_hidden)]

        # Hidden layers
        for _ in range(n_hidden_layers - 1):
            layers.append(activation)
            layers.append(nn.Linear(dim_hidden, dim_hidden))

        # Output layer
        layers.append(activation)
        layers.append(nn.Linear(dim_hidden, dim_output))

        self.model = nn.Sequential(*layers)

        # Placeholder for an observation-noise parameter; it is never assigned
        # a tensor in this class, so nothing is learned through it.
        self.log_obs_var = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Forward pass through the network.

        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor

        Returns
        -------
        out : torch.Tensor [batch_size, dim_output]
            Output tensor
        '''
        return self.model(x)

    def get_weight_vector(self) -> torch.Tensor:
        '''
        Get all model parameters as a single flattened vector.

        Returns
        -------
        torch.Tensor
            Newly allocated 1-D tensor holding every parameter, concatenated in
            ``self.parameters()`` order. It carries no autograd history.
        '''
        # reshape (not view) so non-contiguous parameters are handled as well.
        return torch.cat([p.detach().reshape(-1) for p in self.parameters()])

    def load_flattened_weights(self, weights: torch.Tensor) -> None:
        '''
        Load model parameters from a flattened vector.

        Parameters
        ----------
        weights : torch.Tensor
            Parameter values laid out in ``self.parameters()`` order, as
            produced by ``get_weight_vector``.

        Raises
        ------
        ValueError
            If ``weights`` does not contain exactly one value per parameter
            element. The check happens before anything is written, so the
            model is never left partially overwritten.
        '''
        flat = weights.detach().reshape(-1)
        expected = sum(p.numel() for p in self.parameters())
        if flat.numel() != expected:
            raise ValueError(
                f"Expected {expected} weight values, got {flat.numel()}"
            )

        offset = 0
        for param in self.parameters():
            param_size = param.numel()
            # Write through .data so the copy is invisible to autograd.
            param.data.copy_(flat[offset:offset + param_size].view(param.size()))
            offset += param_size


class SWAG(nn.Module):
    '''
    Stochastic Weight Averaging Gaussian model for regression.

    This model implements SWAG (Maddox et al., 2019), which approximates the posterior
    distribution over neural network weights using a Gaussian distribution. The mean
    is estimated using Stochastic Weight Averaging (SWA), and the covariance is
    estimated using a low-rank plus diagonal approximation.

    Parameters
    ----------
    base_model : nn.Module
        Network being trained. It must provide ``get_weight_vector()`` and
        ``load_flattened_weights(tensor)``, both of which are called here.

    max_models : int
        Maximum number of weight deviations kept for the low-rank covariance
        term (must be at least 1)

    swa_start : int
        Epoch to start collecting models for SWA (stored only; read by the
        training routine)

    swa_lr : float
        Learning rate for SWA phase (stored only; read by the training routine)

    var_clamp : float
        Minimum value for diagonal variance

    full_cov : bool
        Whether ``get_prediction`` samples with the diagonal-plus-low-rank
        covariance (True) or the diagonal covariance only (False)

    prior_log_obs_var : float
        Accepted for interface compatibility; not used by this class.

    Raises
    ------
    ValueError
        If ``max_models`` is smaller than 1.
    '''
    def __init__(self,
                 base_model: nn.Module,
                 max_models: int = 20,
                 swa_start: int = 100,
                 swa_lr: float = 0.001,
                 var_clamp: float = 1e-6,
                 full_cov: bool = False,
                 prior_log_obs_var: float = -2.0):
        super().__init__()

        if max_models < 1:
            raise ValueError(f"max_models must be at least 1, got {max_models}")

        self.base_model = base_model
        self.max_models = max_models
        self.swa_start = swa_start
        self.swa_lr = swa_lr
        self.var_clamp = var_clamp
        self.full_cov = full_cov

        # Keep direct references to the base model's parameters and their names
        self.params = []
        self.param_names = []

        for name, param in self.base_model.named_parameters():
            self.param_names.append(name)
            self.params.append(param)

        # Initialize SWA model (copy of base model)
        self.swa_model = copy.deepcopy(base_model)

        # Number of parameters in the model
        self.n_params = sum(p.numel() for p in self.params)

        # Seed the running moments with the current weights so the model is
        # usable right away, but do NOT count them as a collected snapshot:
        # the counters stay at zero, so the first real update replaces these
        # values instead of averaging untrained weights into the SWA mean.
        # Building the buffers from the weight vector also keeps them on the
        # base model's device and dtype.
        initial_weights = self.base_model.get_weight_vector()
        self.register_buffer('mean', initial_weights.clone())
        self.register_buffer('sq_mean', initial_weights ** 2)

        # For low-rank covariance approximation (one row per stored deviation)
        self.register_buffer(
            'deviations',
            initial_weights.new_zeros((self.max_models, self.n_params))
        )

        # Counters for number of snapshots collected
        self.n_models = 0
        self.n_averaged = 0

        # Placeholder for an observation-noise parameter; never assigned here.
        self.log_obs_var = None

    @property
    def device(self) -> torch.device:
        '''
        Get the device of the model.
        '''
        return next(self.parameters()).device

    @property
    def name(self) -> str:
        '''
        Model name.
        '''
        return "SWAG"

    def _update_mean(self):
        '''
        Fold the base model's current weights into the SWAG statistics.

        Updates the running first and second moments, stores the deviation of
        the current weights from the updated mean in a circular buffer of
        ``max_models`` rows, and increments both snapshot counters.
        '''
        # Get flattened parameters from the model
        param_vector = self.base_model.get_weight_vector()

        if self.n_averaged == 0:
            # First snapshot: the moments are just this weight vector
            self.mean.copy_(param_vector)
            self.sq_mean.copy_(param_vector ** 2)
        else:
            # Incremental running average over n_averaged + 1 snapshots
            keep = self.n_averaged / (self.n_averaged + 1.0)
            new = 1.0 / (self.n_averaged + 1.0)
            self.mean.mul_(keep).add_(param_vector, alpha=new)
            self.sq_mean.mul_(keep).add_(param_vector ** 2, alpha=new)

        # Circular buffer: once full, the oldest deviation is overwritten.
        # The modulo also covers the not-yet-full case.
        slot = self.n_models % self.max_models
        self.deviations[slot].copy_(param_vector - self.mean)
        self.n_models += 1

        self.n_averaged += 1

    def update_swa(self):
        '''
        Update the SWA model with the current model parameters.
        '''
        # Update mean and covariance statistics
        self._update_mean()

        # Update SWA model parameters with the mean
        self.swa_model.load_flattened_weights(self.mean)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Forward pass using the SWA model.

        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor

        Returns
        -------
        out : torch.Tensor [batch_size, dim_output]
            Output tensor
        '''
        return self.swa_model(x)

    def sample_model(self, scale: float = 1.0, cov_type: str = 'diagonal') -> None:
        '''
        Sample a new model from the SWAG posterior.

        The sampled weights are written into ``base_model``, replacing
        whatever weights it held before.

        Parameters
        ----------
        scale : float
            Non-negative scaling factor for the covariance matrix; every
            perturbation is therefore multiplied by ``sqrt(scale)``.

        cov_type : str
            Type of covariance approximation to use: 'diagonal', 'low_rank', or 'full'.
            'low_rank' and 'full' are handled identically: both draw from the
            diagonal-plus-low-rank covariance.

        Raises
        ------
        ValueError
            If ``cov_type`` is not one of the supported values or ``scale``
            is negative.
        '''
        if cov_type not in ('diagonal', 'low_rank', 'full'):
            raise ValueError(f"Unknown covariance type: {cov_type}")
        if scale < 0:
            raise ValueError(f"scale must be non-negative, got {scale}")

        # Diagonal variance from the running moments, floored at var_clamp
        var = torch.clamp(self.sq_mean - self.mean ** 2, min=self.var_clamp)

        if cov_type == 'diagonal':
            sample = self.mean + torch.randn_like(var) * torch.sqrt(var * scale)
        else:
            diag_part = torch.randn_like(var) * torch.sqrt(var)

            # Use only the deviations that have actually been collected
            n_actual = min(self.n_models, self.max_models)

            if n_actual > 1:
                z2 = torch.randn(n_actual, device=var.device, dtype=var.dtype)
                low_rank_part = torch.matmul(
                    z2, self.deviations[:n_actual]
                ) / math.sqrt(n_actual - 1)

                # SWAG covariance is 1/2 * (diag + D D^T / (K - 1)), so both
                # components share the same 1/sqrt(2) factor.
                perturbation = (diag_part + low_rank_part) / math.sqrt(2.0)
            else:
                # Fewer than two deviations: no low-rank estimate is available,
                # fall back to the diagonal posterior.
                perturbation = diag_part

            sample = self.mean + math.sqrt(scale) * perturbation

        # Update model parameters with the sample
        self.base_model.load_flattened_weights(sample)

    def get_prediction(self, x: torch.Tensor, n_samples: int = 100) -> 'SWAGReturn':
        '''
        Get prediction with uncertainty estimation.

        Draws ``n_samples`` weight samples from the SWAG posterior, evaluates
        the base model for each, and summarises the resulting predictions.
        The base model's weights are restored afterwards, so calling this
        method does not disturb ongoing training. The module is switched to
        evaluation mode and left there.

        Parameters
        ----------
        x : torch.Tensor [batch_size, dim_input]
            Input tensor

        n_samples : int
            Number of Monte Carlo samples (at least 1)

        Returns
        -------
        SWAGReturn
            Object containing mean, variance, and samples. With a single
            sample the variance is reported as zeros, since it cannot be
            estimated.

        Raises
        ------
        ValueError
            If ``n_samples`` is smaller than 1.
        '''
        if n_samples < 1:
            raise ValueError(f"n_samples must be at least 1, got {n_samples}")

        # Ensure model is in evaluation mode
        self.eval()

        cov_type = 'full' if self.full_cov else 'diagonal'

        # sample_model overwrites the base model's weights; remember them so
        # they can be put back once sampling is done.
        trained_weights = self.base_model.get_weight_vector().clone()

        draws = []
        try:
            for _ in range(n_samples):
                # Sample a new model from the posterior and evaluate it
                self.sample_model(scale=1.0, cov_type=cov_type)
                draws.append(self.base_model(x))
        finally:
            self.base_model.load_flattened_weights(trained_weights)

        samples = torch.stack(draws)

        # Compute predictive mean and model uncertainty from samples
        model_mean = torch.mean(samples, dim=0)
        if n_samples > 1:
            model_var = torch.var(samples, dim=0)
        else:
            model_var = torch.zeros_like(model_mean)

        # Define loss functions
        def train_loss_fn(y_true: torch.Tensor) -> Dict[str, torch.Tensor]:
            # For SWAG, we use standard MSE loss during training
            mse_loss = F.mse_loss(model_mean, y_true)

            return {
                "loss": mse_loss,
                "mse": mse_loss
            }

        def val_loss_fn(y_true: torch.Tensor) -> torch.Tensor:
            # Use MSE loss for validation
            return F.mse_loss(model_mean, y_true)

        return SWAGReturn(
            mean=model_mean,
            variance=model_var,
            samples=samples,
            x=x,
            train_loss_fn=train_loss_fn,
            val_loss_fn=val_loss_fn
        )


def create_swag_model(dim_input: int,
                      dim_output: int,
                      dim_hidden: int,
                      n_hidden_layers: int,
                      max_models: int = 20,
                      swa_start: int = 100,
                      swa_lr: float = 0.001,
                      var_clamp: float = 1e-6,
                      full_cov: bool = False,
                      activation: nn.Module = nn.ReLU(),
                      prior_log_obs_var: float = -2.0) -> SWAG:
    '''
    Build a ``BaseModel`` network and wrap it in a ``SWAG`` model.

    The first four arguments and ``activation`` configure the network; all
    remaining arguments are forwarded unchanged to the ``SWAG`` constructor.

    Parameters
    ----------
    dim_input : int
        Number of input features

    dim_output : int
        Number of output features (typically 1 for regression)

    dim_hidden : int
        Number of neurons in each hidden layer

    n_hidden_layers : int
        Number of hidden layers

    max_models : int
        Maximum number of models to maintain for covariance estimation

    swa_start : int
        Epoch to start collecting models for SWA

    swa_lr : float
        Learning rate for SWA phase

    var_clamp : float
        Minimum value for diagonal variance

    full_cov : bool
        Whether to use the full covariance matrix (True) or diagonal (False)

    activation : nn.Module
        Activation function to use (default: nn.ReLU)

    prior_log_obs_var : float
        Forwarded to the ``SWAG`` constructor under the same name

    Returns
    -------
    SWAG
        SWAG model wrapping the newly created network
    '''
    network_options = {
        'dim_input': dim_input,
        'dim_output': dim_output,
        'dim_hidden': dim_hidden,
        'n_hidden_layers': n_hidden_layers,
        'activation': activation,
    }
    swag_options = {
        'max_models': max_models,
        'swa_start': swa_start,
        'swa_lr': swa_lr,
        'var_clamp': var_clamp,
        'full_cov': full_cov,
        'prior_log_obs_var': prior_log_obs_var,
    }

    # The network must exist first: SWAG takes it as its base model.
    network = BaseModel(**network_options)
    return SWAG(base_model=network, **swag_options)


def _run_training_epoch(network: nn.Module,
                        train_loader: torch.utils.data.DataLoader,
                        optimizer: torch.optim.Optimizer,
                        max_grad_norm: Optional[float] = None) -> float:
    '''
    Run one MSE-loss optimisation pass over ``train_loader``.

    Parameters
    ----------
    network : nn.Module
        Model being optimised; it is switched to training mode.

    train_loader : torch.utils.data.DataLoader
        Iterable of ``(inputs, targets)`` batches.

    optimizer : torch.optim.Optimizer
        Optimiser stepping ``network``'s parameters.

    max_grad_norm : float, optional
        If given, gradients are clipped to this norm before each step.

    Returns
    -------
    float
        Mean of the per-batch losses.

    Raises
    ------
    ValueError
        If ``train_loader`` yields no batches.
    '''
    network.train()
    total_loss = 0.0
    n_batches = 0

    for batch_X, batch_y in train_loader:
        loss = F.mse_loss(network(batch_X), batch_y)

        optimizer.zero_grad()
        loss.backward()
        if max_grad_norm is not None:
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(network.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return total_loss / n_batches


def train_swag_model(model: SWAG, train_loader: torch.utils.data.DataLoader, 
        X_train_tensor: torch.Tensor, y_train_tensor: torch.Tensor, 
        X_test_tensor: torch.Tensor, y_test_tensor: torch.Tensor, 
        num_epochs: int = 5000, learning_rate: float = 0.01, lr_step_size: int = 1000) -> Dict[str, object]:
    '''
    Train a SWAG model with an SGD phase followed by an SWA phase.

    The first ``min(model.swa_start, num_epochs)`` epochs train the base model
    with SGD (momentum 0.9, gradient norm clipped to 1.0). The remaining
    epochs use a fresh SGD optimiser at ``model.swa_lr`` with a StepLR decay
    (factor 0.8 every ``lr_step_size`` epochs) and call ``model.update_swa()``
    after every epoch. In total exactly ``num_epochs`` epochs are run. After
    every epoch the MSE on the test tensors is recorded as the validation loss.

    Parameters
    ----------
    model : SWAG
        Model to train; its ``base_model`` is optimised in place.

    train_loader : torch.utils.data.DataLoader
        Yields ``(inputs, targets)`` training batches.

    X_train_tensor, y_train_tensor : torch.Tensor
        Full training set, used only for the final metrics.

    X_test_tensor, y_test_tensor : torch.Tensor
        Held-out set, used for per-epoch validation and the final metrics.

    num_epochs : int
        Total number of epochs across both phases.

    learning_rate : float
        Learning rate of the SGD phase.

    lr_step_size : int
        Period (in SWA epochs) of the learning-rate decay.

    Returns
    -------
    Dict[str, object]
        ``train_losses`` and ``val_losses`` (one float per epoch),
        ``training_time`` in seconds, ``epoch`` (= ``num_epochs``), and every
        entry returned by ``compute_metrics`` for the training and test sets
        under ``train_<key>`` / ``test_<key>``. ``compute_metrics`` must
        provide at least ``mse`` and ``nll``.
    '''
    network = model.base_model

    train_losses = []
    val_losses = []

    start_time = time.perf_counter()

    # Split the epoch budget; the SGD phase must not exceed num_epochs, so a
    # swa_start beyond the budget no longer trains longer than requested.
    n_sgd_epochs = max(0, min(model.swa_start, num_epochs))
    n_swa_epochs = max(0, num_epochs - n_sgd_epochs)

    # SGD phase: only the base model's parameters are optimised
    optimizer = torch.optim.SGD(network.parameters(), lr=learning_rate, momentum=0.9)
    for _ in range(n_sgd_epochs):
        avg_train_loss = _run_training_epoch(
            network, train_loader, optimizer, max_grad_norm=1.0
        )

        # Validation with the current base model
        network.eval()
        with torch.no_grad():
            val_loss = F.mse_loss(network(X_test_tensor), y_test_tensor)

        train_losses.append(avg_train_loss)
        val_losses.append(val_loss.item())

    # SWA phase: new optimiser at the SWA learning rate (no gradient clipping
    # here; clipping is applied in the SGD phase only)
    optimizer = torch.optim.SGD(network.parameters(), lr=model.swa_lr, momentum=0.9)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.8, step_size=lr_step_size)

    for _ in range(n_swa_epochs):
        avg_train_loss = _run_training_epoch(network, train_loader, optimizer)

        # Step the learning rate scheduler once per epoch
        lr_scheduler.step()

        # Update SWA model
        model.update_swa()

        # Validation with the SWA model (model.forward uses it)
        model.eval()
        with torch.no_grad():
            val_loss = F.mse_loss(model(X_test_tensor), y_test_tensor)

        train_losses.append(avg_train_loss)
        val_losses.append(val_loss.item())

    if n_swa_epochs == 0:
        # No SWA epoch ran, so fold the trained weights into the SWAG
        # statistics once; otherwise the posterior would only know the
        # weights captured when the model was constructed.
        model.update_swa()

    training_time = time.perf_counter() - start_time

    # Compute final metrics for train and test sets from posterior samples
    model.eval()
    with torch.no_grad():
        train_out = model.get_prediction(X_train_tensor)
        train_metrics = compute_metrics(
            train_out.mean, train_out.variance, y_train_tensor
        )

        test_out = model.get_prediction(X_test_tensor)
        test_metrics = compute_metrics(
            test_out.mean, test_out.variance, y_test_tensor
        )

    # Prepare results dictionary with all metrics
    results = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'training_time': training_time,
        'train_mse': train_metrics['mse'],
        'train_nll': train_metrics['nll'],
        'test_mse': test_metrics['mse'],
        'test_nll': test_metrics['nll'],
        'epoch': num_epochs,
    }

    # Add training metrics with 'train_' prefix
    for key, value in train_metrics.items():
        results[f'train_{key}'] = value

    # Add testing metrics with 'test_' prefix
    for key, value in test_metrics.items():
        results[f'test_{key}'] = value

    return results
