import torch
import torch.nn as nn

class ValueFunction(nn.Module):
    """MLP that maps features to a single sigmoid-squashed score.

    Layout of ``self.fc`` (a ``nn.Sequential``):
    ``num_layers`` hidden blocks of ``Linear -> ReLU`` (at least one block is
    always built), followed by ``Linear(hidden_dim, 1) -> Sigmoid``.

    Args:
        input_length: Size of the input's last dimension (the first
            ``nn.Linear`` consumes it).
        hidden_dim: Width of every hidden layer.
        num_layers: Number of hidden ``Linear -> ReLU`` blocks; any value
            below 1 still yields exactly one block.
    """

    def __init__(self, input_length, hidden_dim=64, num_layers=1):
        super().__init__()
        # Consecutive (fan_in, fan_out) pairs: input -> hidden, then
        # hidden -> hidden for the remaining blocks.
        widths = [input_length] + [hidden_dim] * max(num_layers, 1)
        modules = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Scalar output head squashed by a sigmoid.
        modules += [nn.Linear(hidden_dim, 1), nn.Sigmoid()]
        self.fc = nn.Sequential(*modules)

    def forward(self, x):
        """Score ``x`` of shape ``(..., input_length)``; returns shape ``(...)``.

        The head emits a trailing dimension of size 1, which is dropped; an
        unbatched ``(input_length,)`` input therefore yields a 0-d tensor.
        """
        scores = self.fc(x)
        return scores.squeeze(-1)

class NormalizedValueFunction(nn.Module):
    """MLP score head that layer-normalises its input before the hidden stack.

    Layout of ``self.fc`` (a ``nn.Sequential``):
    ``LayerNorm(input_length)``, then ``num_layers`` hidden blocks of
    ``Linear -> ReLU`` (at least one block is always built), then
    ``Linear(hidden_dim, 1) -> Sigmoid``.

    Args:
        input_length: Size of the input's last dimension; it is both the
            ``LayerNorm`` normalised shape and the first layer's fan-in.
        hidden_dim: Width of every hidden layer.
        num_layers: Number of hidden ``Linear -> ReLU`` blocks; any value
            below 1 still yields exactly one block.
    """

    def __init__(self, input_length, hidden_dim=64, num_layers=1):
        super().__init__()

        def hidden_block(fan_in):
            # One fully connected layer into hidden_dim followed by ReLU.
            return [nn.Linear(fan_in, hidden_dim), nn.ReLU()]

        # Normalising the raw features first makes the head insensitive to a
        # per-sample shift/rescale of the input vector.
        stack = [nn.LayerNorm(input_length), *hidden_block(input_length)]
        for _ in range(num_layers - 1):
            stack += hidden_block(hidden_dim)
        stack += [nn.Linear(hidden_dim, 1), nn.Sigmoid()]
        self.fc = nn.Sequential(*stack)

    def forward(self, x):
        """Score ``x`` of shape ``(..., input_length)``; returns shape ``(...)``.

        The trailing size-1 output dimension is squeezed away.
        """
        scores = self.fc(x)
        return scores.squeeze(-1)

class __DeprecatedValueFunction(nn.Module):
    """Legacy one-hidden-layer score head over a flattened input.

    ``self.fc`` is ``Linear(input_length * vocab_size, 64) -> ReLU ->
    Linear(64, 1) -> Sigmoid``; the hidden width is fixed at 64.

    Args:
        input_length: Multiplied by ``vocab_size`` to get the expected size of
            the input's last dimension.
        vocab_size: See ``input_length``. NOTE(review): presumably the input
            is a flattened per-position encoding over a vocabulary — confirm
            against callers; this class does no flattening itself.
    """

    def __init__(self, input_length, vocab_size):
        super().__init__()
        flat_dim = input_length * vocab_size
        hidden = nn.Linear(flat_dim, 64)
        head = nn.Linear(64, 1)
        self.fc = nn.Sequential(hidden, nn.ReLU(), head, nn.Sigmoid())

    def forward(self, x):
        """Score ``x`` whose last dimension is ``input_length * vocab_size``.

        Returns the sigmoid scores with the trailing size-1 dimension removed.
        """
        scores = self.fc(x)
        return scores.squeeze(-1)
