from typing import Tuple

import torch
import torch.nn as nn


class EncoderNetwork(nn.Module):
    """
        Encoder network for DVRL, maps (h_{t-1}, a_{t-1}, o_t) to z_t parameters.

        Pipeline:
            o_t      -> Linear(obs_dim, h_dim) + ReLU
            a_{t-1}  -> Linear(action_dim, action_enc_dim) + ReLU
            [h_{t-1}, encoded action, encoded obs] -> stack of Linear + ReLU
            -> two independent linear heads named ``mean`` and ``logvar``.

        The heads are plain linear projections; reading them as the mean and
        log-variance of a Gaussian over z_t is the caller's convention
        (NOTE(review): inferred from the names only).
    """
    def __init__(self, h_dim: int, z_dim: int, action_dim: int, obs_dim: int,
                 hidden_layers: int = 1, action_factor: float = 0.5):
        """
            Initialize encoder network

            Args:
                h_dim: Dimension of RNN hidden state (also the width of the
                    observation encoding and of every combined-encoder layer)
                z_dim: Dimension of latent state (output width of both heads)
                action_dim: Dimension of action space
                obs_dim: Dimension of observation space
                hidden_layers: Number of Linear+ReLU layers in the combined
                    encoder; must be >= 1 because the first layer, which
                    consumes the concatenated inputs, always exists
                action_factor: Factor determining action encoding dimension as
                    int(h_dim * action_factor), i.e. truncated toward zero

            Raises:
                ValueError: if hidden_layers < 1
        """
        super(EncoderNetwork, self).__init__()

        # & Previously a value <= 0 silently produced one layer anyway;
        # & reject it so the built depth always matches the request.
        if hidden_layers < 1:
            raise ValueError(
                f"hidden_layers must be >= 1, got {hidden_layers}"
            )

        self.action_enc_dim = int(h_dim * action_factor)

        # & Simple encoders for inputs (submodule names/indices are part of
        # & the state_dict keys, so they are kept stable)
        self.obs_encoder = nn.Sequential(
            nn.Linear(obs_dim, h_dim),
            nn.ReLU()
        )

        self.action_encoder = nn.Sequential(
            nn.Linear(action_dim, self.action_enc_dim),
            nn.ReLU()
        )

        # & Build combined encoder with configurable depth
        # & Input width: h_prev + action_enc + obs_enc
        combined_input_dim = h_dim + self.action_enc_dim + h_dim

        # & First layer consumes the concatenated inputs ...
        combined_layers = [nn.Linear(combined_input_dim, h_dim), nn.ReLU()]

        # & ... followed by (hidden_layers - 1) square h_dim -> h_dim layers
        for _ in range(hidden_layers - 1):
            combined_layers.extend([
                nn.Linear(h_dim, h_dim),
                nn.ReLU()
            ])

        self.combined_encoder = nn.Sequential(*combined_layers)

        # & Output heads
        self.mean = nn.Linear(h_dim, z_dim)
        self.logvar = nn.Linear(h_dim, z_dim)

    def forward(self, h_prev: torch.Tensor, a_prev: torch.Tensor,
                o_curr: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
            Forward pass through encoder network

            Args:
                h_prev: Previous hidden state; last dimension must be h_dim
                a_prev: Previous action; last dimension must be action_dim
                o_curr: Current observation; last dimension must be obs_dim
                (All three are concatenated along the last dimension, so their
                leading dimensions must agree — presumably (batch, ...); TODO
                confirm against callers.)

            Returns:
                Tuple (mean, logvar), each with last dimension z_dim and the
                same leading dimensions as the inputs.
        """
        # & Process and combine inputs (order must match combined_input_dim)
        obs_encoded = self.obs_encoder(o_curr)
        action_encoded = self.action_encoder(a_prev)
        combined = torch.cat([h_prev, action_encoded, obs_encoded], dim=-1)

        # & Process through combined encoder
        hidden = self.combined_encoder(combined)

        # & Produce distribution parameters
        return self.mean(hidden), self.logvar(hidden)
    