import torch
import torch.nn as nn


class DecoderNetwork(nn.Module):
    """
    # & DVRL decoder: turns (h_{t-1}, z_t, a_{t-1}) into the parameters of o_t

    The latent state and the action each go through a one-layer Linear+ReLU
    encoder. Both encodings are concatenated with the raw hidden state and
    passed through an MLP whose last layer is a plain Linear projection to
    ``obs_dim`` (no activation is applied to the output).
    """

    def __init__(self, h_dim: int, z_dim: int, action_dim: int, obs_dim: int,
                 hidden_layers: int = 1, action_factor: float = 0.5):
        """
        # & Assemble the input encoders and the combined decoder MLP

        Args:
            h_dim: Size of the RNN hidden state; also the width of the latent
                encoding and of every hidden layer in the combined decoder
            z_dim: Size of the latent state
            action_dim: Size of the action vector
            obs_dim: Size of the output (observation parameters)
            hidden_layers: Number of Linear+ReLU blocks in the combined
                decoder; one block is always built, so values below 1
                behave like 1
            action_factor: The action encoding has width
                int(h_dim * action_factor)
        """
        super().__init__()

        # & Width of the action encoding; int() drops any fractional part
        self.action_enc_dim = int(h_dim * action_factor)

        # & One-layer encoders for the latent state and the action
        self.z_encoder = nn.Sequential(*self._dense_relu(z_dim, h_dim))
        self.action_encoder = nn.Sequential(
            *self._dense_relu(action_dim, self.action_enc_dim)
        )

        # & Combined decoder; its input is [h_prev | action_enc | z_enc]
        fused_dim = h_dim + self.action_enc_dim + h_dim
        stack = list(self._dense_relu(fused_dim, h_dim))

        # Extra hidden blocks beyond the mandatory first one
        for _ in range(hidden_layers - 1):
            stack.extend(self._dense_relu(h_dim, h_dim))

        # Final projection to observation space, left without an activation
        stack.append(nn.Linear(h_dim, obs_dim))

        self.combined_decoder = nn.Sequential(*stack)

    @staticmethod
    def _dense_relu(in_features: int, out_features: int) -> tuple:
        """
        # & Build one (Linear, ReLU) pair, ready to unpack into a Sequential
        """
        return nn.Linear(in_features, out_features), nn.ReLU()

    def forward(self, h_prev: torch.Tensor, z_curr: torch.Tensor,
                a_prev: torch.Tensor) -> torch.Tensor:
        """
        # & Decode observation parameters from hidden state, latent and action

        Args:
            h_prev: Previous RNN hidden state (h_{t-1}), last dim h_dim
            z_curr: Current latent state (z_t), last dim z_dim
            a_prev: Previous action (a_{t-1}), last dim action_dim

        Returns:
            Parameters for observation distribution, last dim obs_dim
        """
        # & Encode the latent and the action, then join along the feature axis
        latent_feat = self.z_encoder(z_curr)
        action_feat = self.action_encoder(a_prev)
        decoder_input = torch.cat((h_prev, action_feat, latent_feat), dim=-1)

        # & Run the fused features through the decoder MLP
        return self.combined_decoder(decoder_input)
