import torch
import torch.nn as nn


class LightValueFunction(nn.Module):
    """Small MLP that maps a feature vector to a single sigmoid score.

    Architecture: ``Linear(input_length, 16) -> ReLU -> Linear(16, 1) -> Sigmoid``.

    Args:
        input_length: Size of the last dimension of the input tensor.
    """

    def __init__(self, input_length):
        super().__init__()
        hidden_width = 16
        layers = [
            nn.Linear(input_length, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, 1),
            nn.Sigmoid(),
        ]
        # A single ``nn.Sequential`` named ``fc`` keeps the state_dict keys
        # as ``fc.0.*`` / ``fc.2.*``.
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Score ``x``.

        Args:
            x: Tensor whose last dimension has size ``input_length``.

        Returns:
            Tensor of shape ``x.shape[:-1]`` with values squashed by a sigmoid
            into (0, 1) (up to floating-point saturation); the trailing size-1
            output dimension is removed.
        """
        scores = self.fc(x)
        return torch.squeeze(scores, dim=-1)

class ValueFunction(nn.Module):
    """MLP mapping an ``(input_length, vocab_size)`` input to a sigmoid score.

    Architecture: ``Linear(input_length * vocab_size, 64) -> ReLU ->
    Linear(64, 1) -> Sigmoid``.

    Args:
        input_length: Number of positions per sample (presumably a sequence
            length -- confirm with callers).
        vocab_size: Number of features per position (presumably a one-hot or
            probability vector over a vocabulary -- confirm with callers).
    """

    def __init__(self, input_length, vocab_size):
        super().__init__()
        # Plain int attributes; they are not parameters/buffers, so the
        # state_dict keys are unchanged (``fc.0.*`` / ``fc.2.*``).
        self.input_length = input_length
        self.vocab_size = vocab_size
        self.fc = nn.Sequential(
            nn.Linear(input_length * vocab_size, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score ``x``.

        Accepts either an already-flattened input whose last dimension is
        ``input_length * vocab_size``, or an unflattened input whose last two
        dimensions are ``(input_length, vocab_size)``. The latter is flattened
        row-major over those two dimensions before the MLP.

        Args:
            x: Input tensor in either layout described above.

        Returns:
            Tensor of the leading (batch) shape with values squashed by a
            sigmoid into (0, 1) (up to floating-point saturation); the
            trailing size-1 output dimension is removed.
        """
        flat_features = self.input_length * self.vocab_size
        # Only flatten when the input is not already flat. With
        # input_length == 1, a flat (1, vocab_size) tensor also matches the
        # (input_length, vocab_size) pattern and would lose its batch dim.
        if (
            x.dim() >= 2
            and x.shape[-1] != flat_features
            and tuple(x.shape[-2:]) == (self.input_length, self.vocab_size)
        ):
            x = x.flatten(start_dim=-2)
        return self.fc(x).squeeze(-1)
