# policy_models/constrained_policy.py
# A stubbed constrained policy model that maps context vectors to a probability distribution (softmax over logits) over pathways or actions.
# In RLHF you would implement reward modeling, PPO/PG updates, etc. This is a simple supervised policy example.

from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConstrainedPolicy(nn.Module):
    """
    Policy that outputs action probabilities conditioned on a context embedding.

    A two-layer MLP (Linear -> ReLU -> Linear) produces one score per action.
    Actions could be: generate_constrained, escalate, safe_template, etc. The
    mapping from output index to action name is the caller's responsibility;
    this class only knows how many actions there are.

    Args:
        input_dim: size of the last dimension of the context tensor.
        hidden: width of the hidden layer.
        n_actions: number of actions (size of the output's last dimension).
    """

    def __init__(self, input_dim: int = 256, hidden: int = 128, n_actions: int = 3):
        super().__init__()
        # Expose the configured sizes so callers can introspect the model
        # without digging into the layers of ``self.net``.
        self.input_dim = input_dim
        self.n_actions = n_actions
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(
        self,
        context: torch.Tensor,
        *,
        return_logits: bool = False,
    ) -> torch.Tensor:
        """
        Score every action for the given context.

        Args:
            context: tensor of shape (B, input_dim) or (input_dim,). A 1-D
                input is treated as a batch of one.
            return_logits: if True, return the raw pre-softmax scores instead
                of probabilities. Use this for training with
                ``F.cross_entropy`` / ``F.log_softmax``: taking ``log`` of the
                softmax output underflows to ``-inf`` for very unlikely actions.

        Returns:
            Tensor of shape (B, n_actions): action probabilities that sum to 1
            over the last dimension (default), or raw logits when
            ``return_logits`` is True.
        """
        if context.dim() == 1:
            # Promote a single context vector to a batch of one so the output
            # always has a leading batch dimension.
            context = context.unsqueeze(0)
        logits = self.net(context)
        if return_logits:
            return logits
        return F.softmax(logits, dim=-1)
