from typing import Tuple

import torch
import torch.nn as nn

class TransitionNetwork(nn.Module):
    """
    # & Transition network for DVRL, models p(z_t|h_{t-1}, a_{t-1})

    The previous action is passed through a Linear+ReLU encoder of width
    ``int(h_dim * action_factor)``. The result is concatenated with the
    previous hidden state and fed through a stack of Linear+ReLU layers of
    width ``h_dim``. Two linear heads then produce the output pair.

    Args:
        h_dim: Size of the last dimension of ``h_prev`` and width of every
            hidden layer.
        z_dim: Size of the last dimension of both outputs.
        action_dim: Size of the last dimension of ``a_prev``.
        hidden_layers: Number of Linear+ReLU layers in the combined trunk.
            Values below 1 still yield exactly one layer.
        action_factor: Multiplier on ``h_dim`` giving the action-encoding
            width (truncated with ``int``).
    """
    def __init__(self, h_dim: int, z_dim: int, action_dim: int,
                 hidden_layers: int = 1, action_factor: float = 0.5):
        super().__init__()

        self.action_enc_dim = int(h_dim * action_factor)
        enc_width = self.action_enc_dim

        # & Action encoder: action_dim -> enc_width
        self.action_encoder = nn.Sequential(
            nn.Linear(action_dim, enc_width),
            nn.ReLU(),
        )

        # & Input width of each trunk Linear, in order. The first consumes
        # [h_prev, encoded action]; any extra layers map h_dim -> h_dim.
        # Multiplying a list by a non-positive count yields [], so
        # hidden_layers <= 1 produces just the first layer.
        input_widths = [h_dim + enc_width] + [h_dim] * (hidden_layers - 1)

        trunk = []
        for width in input_widths:
            # Modules are created in sequence so parameter init order (and
            # Sequential indices 0, 2, 4, ...) match a layer-by-layer build.
            trunk.append(nn.Linear(width, h_dim))
            trunk.append(nn.ReLU())
        self.combined_network = nn.Sequential(*trunk)

        # & Output heads; presumably the mean / log-variance of a Gaussian
        # over z_t (TODO confirm with the sampling code).
        self.mean = nn.Linear(h_dim, z_dim)
        self.logvar = nn.Linear(h_dim, z_dim)

    def forward(self, h_prev: torch.Tensor, a_prev: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        # & Compute the (mean, logvar) heads for the next latent.

        Args:
            h_prev: Tensor whose last dimension is ``h_dim``.
            a_prev: Tensor whose last dimension is ``action_dim``; its leading
                dimensions must match ``h_prev``'s for the concatenation.

        Returns:
            Tuple ``(mean, logvar)``, each with last dimension ``z_dim``.
        """
        features = torch.cat((h_prev, self.action_encoder(a_prev)), dim=-1)
        trunk_out = self.combined_network(features)
        return self.mean(trunk_out), self.logvar(trunk_out)

