import torch
import torch.nn as nn


class PolicyNetwork(nn.Module):
    """
    # & Policy head producing the parameters of an action distribution

    A single hidden layer (Linear + ReLU) feeds one of two output heads,
    selected once at construction time:

    * discrete mode (default): a linear layer emitting one logit per action;
    * continuous mode: a linear layer emitting a per-dimension mean, paired
      with a learnable log standard deviation that does not depend on the
      state (one scalar per action dimension, initialised to zero).
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 64,
                 continuous_actions: bool = False) -> None:
        """
        # & Build the shared trunk and the head matching the action space

        Args:
            state_dim: Size of the last dimension of the state input.
            action_dim: Number of discrete actions, or number of continuous
                action dimensions.
            hidden_dim: Width of the hidden layer.
            continuous_actions: When True build the mean/log_std head,
                otherwise build the logits head.
        """
        super().__init__()

        self.continuous_actions = continuous_actions

        # & Feature extractor used by both kinds of head
        input_layer = nn.Linear(state_dim, hidden_dim)
        self.shared = nn.Sequential(input_layer, nn.ReLU())

        if not continuous_actions:
            # & Discrete action space: unnormalised scores, one per action
            self.logits = nn.Linear(hidden_dim, action_dim)
        else:
            # & Continuous action space: state-dependent mean plus a free
            # & log_std parameter shared across every input state
            self.mean = nn.Linear(hidden_dim, action_dim)
            initial_log_std = torch.zeros(action_dim)
            self.log_std = nn.Parameter(initial_log_std)

    def forward(self, state: torch.Tensor):
        """
        # & Map a state tensor to action distribution parameters

        Args:
            state: Tensor whose last dimension equals ``state_dim``.

        Returns:
            Discrete mode: a tensor of logits with last dimension
            ``action_dim``.
            Continuous mode: a ``(mean, log_std)`` tuple where ``log_std``
            is the learnable parameter expanded to the shape of ``mean``.
        """
        features = self.shared(state)

        if not self.continuous_actions:
            return self.logits(features)

        action_mean = self.mean(features)
        # & expand_as yields a broadcast view, so every row refers to the
        # & same underlying log_std parameter
        return action_mean, self.log_std.expand_as(action_mean)
