"""
Encoder components for GLEAM-AI.

This module contains the EncoderRNN and LatentEncoder classes used for
temporal and latent variable encoding in the STNP framework.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple


class EncoderRNN(nn.Module):
    """
    Recurrent Neural Network encoder for temporal feature processing.

    This encoder runs a multi-layer GRU over a sequence and returns the
    top-layer hidden state for every timestep.

    Args:
        enc_in_dim: Input dimension for the encoder
        r_dim: Hidden dimension for the RNN
        num_rnn: Number of RNN layers

    Raises:
        ValueError: If any of the dimensions / layer counts is not positive.
    """

    def __init__(self, enc_in_dim: int, r_dim: int, num_rnn: int) -> None:
        super().__init__()

        if enc_in_dim <= 0:
            raise ValueError("enc_in_dim must be positive")
        if r_dim <= 0:
            raise ValueError("r_dim must be positive")
        if num_rnn <= 0:
            raise ValueError("num_rnn must be positive")

        self.num_rnn = num_rnn
        self.r_dim = r_dim

        # GRU layer for temporal processing (batch dimension first)
        self.rnn = nn.GRU(
            input_size=enc_in_dim,
            hidden_size=r_dim,
            num_layers=num_rnn,
            batch_first=True,
        )

    def forward(self, enc_in: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the RNN encoder.

        The GRU always starts from an all-zero hidden state.

        Args:
            enc_in: Input sequence [batch_size, seq_len, enc_in_dim]

        Returns:
            Encoded sequence [batch_size, seq_len, r_dim]
        """
        # Omitting h_0 makes nn.GRU start from zeros that match the input's
        # dtype and device. A hand-built float32 zero tensor would raise a
        # dtype mismatch whenever the module runs in another precision
        # (e.g. after .double() or .half()).
        output, _ = self.rnn(enc_in)

        return output


class LatentEncoder(nn.Module):
    """
    Latent encoder for generating latent variable distributions.

    This encoder maps temporal representations to the mean and variance
    parameters of a diagonal Gaussian over the latent variables, using one
    linear layer for each parameter.

    Args:
        r_dim: Input dimension (temporal representation dimension)
        z_dim: Output dimension (latent variable dimension)

    Raises:
        ValueError: If ``r_dim`` or ``z_dim`` is not positive.
    """

    def __init__(self, r_dim: int, z_dim: int) -> None:
        super().__init__()

        if r_dim <= 0:
            raise ValueError("r_dim must be positive")
        if z_dim <= 0:
            raise ValueError("z_dim must be positive")

        self.r_dim = r_dim
        self.z_dim = z_dim

        # Linear layers for mean and variance
        self.mu_layer = nn.Linear(r_dim, z_dim)
        self.var_layer = nn.Linear(r_dim, z_dim)

    def forward(self, r: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the latent encoder.

        Args:
            r: Temporal representations [..., r_dim]
               (e.g. [batch_size, r_dim] or [seq_len, r_dim])

        Returns:
            Tuple of:
            - mu_z: Mean parameters [..., z_dim]
            - var_z: Variance parameters [..., z_dim], never below 0.01
        """
        # Compute mean parameters
        mu_z = self.mu_layer(r)

        # Compute variance parameters. Softplus keeps the raw value
        # non-negative and the additive 0.01 floor keeps the variance away
        # from zero, so the sqrt taken in sample() stays well-defined.
        var_z = 0.01 + 0.99 * F.softplus(self.var_layer(r))

        return mu_z, var_z

    def sample(self, r: torch.Tensor) -> torch.Tensor:
        """
        Sample latent variables from the encoded distribution.

        Uses the reparameterization ``z = mu + sqrt(var) * eps`` with
        ``eps`` drawn from a standard normal.

        Args:
            r: Temporal representations [..., r_dim]
               (e.g. [batch_size, r_dim] or [seq_len, r_dim])

        Returns:
            Sampled latent variables [..., z_dim]
        """
        # Call the module itself rather than .forward() so that registered
        # forward / forward-pre hooks are not silently skipped.
        mu_z, var_z = self(r)

        # Sample from Gaussian distribution
        eps = torch.randn_like(mu_z)
        z = mu_z + torch.sqrt(var_z) * eps

        return z

    def get_latent_dim(self) -> int:
        """Get the latent variable dimension."""
        return self.z_dim
