import torch
import torch.nn as nn


class ValueNetwork(nn.Module):
    """
    # & Two-layer MLP that maps each input state to a single value estimate

    # & Layout: Linear(state_dim -> hidden_dim) -> ReLU -> Linear(hidden_dim -> 1),
    # & stored in ``self.network`` as an ``nn.Sequential`` (indices 0, 1, 2).
    """
    def __init__(self, state_dim: int, hidden_dim: int = 64):
        super().__init__()

        # & Layers are created in the same order as the Sequential's slots so that
        # & parameter initialisation consumes the RNG identically.
        hidden_layer = nn.Linear(state_dim, hidden_dim)
        activation = nn.ReLU()
        output_layer = nn.Linear(hidden_dim, 1)

        self.network = nn.Sequential(hidden_layer, activation, output_layer)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """
        # & Compute the value estimate for ``state``

        # & The MLP emits a trailing dimension of size 1, which is squeezed away:
        # & an input of shape (..., state_dim) yields an output of shape (...).
        """
        raw_values = self.network(state)  # trailing dim has size 1
        return raw_values.squeeze(dim=-1)
