from typing import Tuple, Union

import torch
import torch.nn as nn


class StateUpdateRNN(nn.Module):
    """
    # & RNN for deterministic state updates: h_t = ψ_RNN(h_{t-1}, z_t, a_{t-1}, o_t)

    Each of the three inputs (z_t, a_{t-1}, o_t) goes through its own
    Linear + ReLU encoder. The encodings are concatenated along the last
    dimension and fed, together with the previous recurrent state, to a
    single GRU or LSTM cell.

    Args:
        h_dim: Hidden size of the recurrent cell; also the output width of
            the z and observation encoders.
        z_dim: Feature size of ``z_curr``.
        action_dim: Feature size of ``a_prev``.
        obs_dim: Feature size of ``o_curr``.
        rnn_type: ``'gru'`` or ``'lstm'`` (compared case-insensitively).
        action_factor: The action encoding has
            ``int(h_dim * action_factor)`` features.

    Raises:
        ValueError: If ``rnn_type`` is neither ``'gru'`` nor ``'lstm'``.
    """
    def __init__(self, h_dim: int, z_dim: int, action_dim: int, obs_dim: int,
                 rnn_type: str = 'gru', action_factor: float = 0.5):
        super().__init__()

        # & Normalise and validate the cell type once, before allocating layers
        cell_kind = rnn_type.lower()
        if cell_kind not in ('gru', 'lstm'):
            raise ValueError(f"Unsupported RNN type: {rnn_type}")
        self.rnn_type = cell_kind

        # & Width of the action encoding (truncated towards zero by int())
        self.action_enc_dim = int(h_dim * action_factor)

        # & Encoders for inputs (construction order kept: z, action, obs)
        self.z_encoder = nn.Sequential(
            nn.Linear(z_dim, h_dim),
            nn.ReLU(),
        )

        self.action_encoder = nn.Sequential(
            nn.Linear(action_dim, self.action_enc_dim),
            nn.ReLU(),
        )

        self.obs_encoder = nn.Sequential(
            nn.Linear(obs_dim, h_dim),
            nn.ReLU(),
        )

        # & RNN Cell: input is [z_encoded | action_encoded | obs_encoded]
        combined_input_size = h_dim + self.action_enc_dim + h_dim
        cell_cls = nn.GRUCell if cell_kind == 'gru' else nn.LSTMCell
        self.rnn_cell = cell_cls(combined_input_size, h_dim)

    def forward(
        self,
        h_prev: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        z_curr: torch.Tensor,
        a_prev: torch.Tensor,
        o_curr: torch.Tensor,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        # & Forward pass through state update RNN

        Args:
            h_prev: Previous recurrent state, handed to the cell unchanged.
                In GRU mode this is the hidden-state tensor; in LSTM mode it
                is the ``(h, c)`` pair, i.e. exactly what this method
                returned on the previous step.
            z_curr: Current z_t, last dimension ``z_dim``.
            a_prev: Previous action a_{t-1}, last dimension ``action_dim``.
            o_curr: Current observation o_t, last dimension ``obs_dim``.

        Returns:
            GRU mode: the new hidden state tensor.
            LSTM mode: the tuple ``(h_next, c_next)``.
        """
        # & Encode each input, then join the encodings on the feature axis.
        # & Assumes all three inputs share their leading (batch) dimensions.
        combined = torch.cat(
            [
                self.z_encoder(z_curr),
                self.action_encoder(a_prev),
                self.obs_encoder(o_curr),
            ],
            dim=-1,
        )

        # & Update state with RNN cell
        if self.rnn_type == 'gru':
            return self.rnn_cell(combined, h_prev)

        # & __init__ only accepts 'gru' or 'lstm', so this is the LSTM path;
        # & the cell yields both hidden and cell state, returned as a pair.
        h_next, c_next = self.rnn_cell(combined, h_prev)
        return h_next, c_next

