"""
Decoder components for GLEAM-AI.

This module contains the DecoderRNN_5 class used for generating predictions
from latent variables and embeddings in the STNP framework.
"""

import torch
import torch.nn as nn
from typing import List, Tuple


class DecoderRNN_5(nn.Module):
    """
    Recurrent Neural Network decoder for generating predictions.

    This decoder runs a GRU over the concatenation of embeddings and latent
    variables. It then maps every time step through a shared 1024-unit layer
    into two heads that produce mean and dispersion parameters for a Negative
    Binomial distribution.

    Args:
        embed_out_dim: Input embedding dimension
        z_dim: Latent variable dimension
        hidden_dims: Hidden layer dimensions (list or tuple). Only the first
            entry is used, as the GRU hidden size. The full sequence is stored
            as a copy and returned by ``get_hidden_dims``.
        y_dim: Output dimension (number of target variables)
        num_rnn: Number of stacked GRU layers

    Raises:
        ValueError: If any dimension is non-positive or ``hidden_dims`` is
            empty.
    """

    def __init__(
        self,
        embed_out_dim: int,
        z_dim: int,
        hidden_dims: List[int],
        y_dim: int,
        num_rnn: int = 1,
    ):
        super().__init__()

        if embed_out_dim <= 0:
            raise ValueError("embed_out_dim must be positive")
        if z_dim <= 0:
            raise ValueError("z_dim must be positive")
        if not hidden_dims or any(dim <= 0 for dim in hidden_dims):
            raise ValueError("hidden_dims must contain positive values")
        if y_dim <= 0:
            raise ValueError("y_dim must be positive")
        if num_rnn <= 0:
            raise ValueError("num_rnn must be positive")

        # Store a private list copy. This keeps later mutation of the caller's
        # list from leaking in, and lets get_hidden_dims() work for tuples.
        self.hidden_dims = list(hidden_dims)
        self.y_dim = y_dim
        self.num_rnn = num_rnn

        # Input dimension for RNN (embedding + latent variable)
        input_dim = embed_out_dim + z_dim

        # GRU layer for temporal processing
        self.rnn = nn.GRU(
            input_size=input_dim,
            hidden_size=self.hidden_dims[0],
            num_layers=num_rnn,
            batch_first=True,
        )

        # Learnable initial hidden state, shape [num_rnn, 1, hidden];
        # broadcast over the batch in forward().
        self.h0 = nn.Parameter(torch.zeros(num_rnn, 1, self.hidden_dims[0]))

        # Common feature processing layer (ends in ReLU, so its output is
        # already non-negative).
        self.fc_common = nn.Sequential(
            nn.Linear(self.hidden_dims[0], 1024),
            nn.ReLU(),
        )

        # Mean parameter prediction layers
        self.mu_layer = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, y_dim),
        )

        # Dispersion parameter prediction layers
        self.phi_layer = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, y_dim),
        )

        # Activation functions
        self.softplus = nn.Softplus()
        self.relu = nn.ReLU()

    def _to_positive(self, logits: torch.Tensor) -> torch.Tensor:
        """Map unconstrained logits to strictly positive values.

        Computes ``0.0001 + 0.9999 * softplus(logits)``. Because softplus is
        non-negative, the result is bounded below by 1e-4. There is no upper
        clip.
        """
        return 0.0001 + 0.9999 * self.softplus(logits)

    def forward(
        self,
        y0: torch.Tensor,
        embed_out: torch.Tensor,
        zs: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the decoder.

        Args:
            y0: Initial conditions [batch_size, y_dim]. Only used for shape
                validation; it does not enter the computation.
            embed_out: Embedding outputs [batch_size, seq_len, embed_out_dim]
            zs: Latent variables [batch_size, seq_len, z_dim]

        Returns:
            Tuple of:
            - mu_t: Mean parameters [batch_size, seq_len, y_dim], each >= 1e-4
            - phi_t: Dispersion parameters [batch_size, seq_len, y_dim],
              each >= 1e-4

        Raises:
            ValueError: If input ranks or batch/sequence sizes are inconsistent.
        """
        if embed_out.dim() != 3:
            raise ValueError(
                "embed_out must be 3-D [batch_size, seq_len, embed_out_dim], "
                f"got shape {tuple(embed_out.shape)}"
            )
        batch_size, seq_len, _ = embed_out.size()

        # Validate input shapes
        if y0.dim() < 2:
            raise ValueError(
                f"y0 must be at least 2-D [batch_size, y_dim], got shape {tuple(y0.shape)}"
            )
        if y0.size(0) != batch_size:
            raise ValueError("Batch size mismatch between y0 and embed_out")
        if y0.size(1) != self.y_dim:
            raise ValueError(f"Expected y0 to have {self.y_dim} dimensions")
        if zs.dim() != 3:
            raise ValueError(
                f"zs must be 3-D [batch_size, seq_len, z_dim], got shape {tuple(zs.shape)}"
            )
        if zs.size(0) != batch_size or zs.size(1) != seq_len:
            raise ValueError("Shape mismatch between embed_out and zs")

        # Expand initial hidden state to batch size
        h0 = self.h0.expand(-1, batch_size, -1).contiguous().to(embed_out.device)

        # Concatenate embeddings and latent variables along the feature axis
        inp = torch.cat([embed_out, zs], dim=-1)

        # Process through GRU
        output, _ = self.rnn(inp, h0)
        output = self.relu(output)

        # Shared features; fc_common already ends in ReLU, so no extra ReLU.
        h_t = self.fc_common(output)

        # Mean and dispersion heads, mapped to strictly positive values.
        mu_t = self._to_positive(self.mu_layer(h_t))
        phi_t = self._to_positive(self.phi_layer(h_t))

        return mu_t, phi_t

    def get_output_dim(self) -> int:
        """Get the output dimension."""
        return self.y_dim

    def get_hidden_dims(self) -> List[int]:
        """Return a fresh list copy of the configured hidden dimensions."""
        return list(self.hidden_dims)
