import torch
import torch.nn as nn
from utils import one_hot_encode

class ValueFunction(nn.Module):
    """Small MLP that maps an encoded length-``h`` sequence to a single score.

    Architecture: ``Linear -> ReLU -> Linear -> Sigmoid``. The final sigmoid
    squashes every output into [0, 1].

    Args:
        h: Sequence length the network is built for. The expected input
            width is ``features_per_step * h``.
        hidden: Width of the hidden layer. Defaults to 128, the value the
            original implementation hard-coded.
        features_per_step: Number of input features contributed by each
            sequence position. Defaults to 3, the original hard-coded factor.
            NOTE(review): presumably the one-hot alphabet size produced by
            ``utils.one_hot_encode`` — confirm against that helper.
    """

    def __init__(self, h: int, *, hidden: int = 128, features_per_step: int = 3):
        super().__init__()
        # Layer order/indices are kept identical to the original so that
        # previously saved state_dicts (keys fc.0.*, fc.2.*) still load.
        self.fc = nn.Sequential(
            nn.Linear(features_per_step * h, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, x_onehot: torch.Tensor) -> torch.Tensor:
        """Score each input row.

        Args:
            x_onehot: Tensor whose last dimension has size
                ``features_per_step * h``; any leading dimensions are treated
                as batch dimensions by ``nn.Linear``.

        Returns:
            The scores with the trailing size-1 output dimension removed,
            e.g. shape ``(batch,)`` for a ``(batch, in_features)`` input, or a
            0-d tensor for a 1-D input.
        """
        return self.fc(x_onehot).squeeze(-1)

class EstimatedValue:
    """Callable scorer for a sequence of arbitrary length.

    A sequence whose length equals the horizon ``H`` is scored exactly by
    ``reward_fn``. A sequence of any other length is one-hot encoded and
    scored by ``value_functions[len(seq)]``.
    """

    def __init__(self, value_functions, reward_fn, H, device):
        # Indexable collection of value networks, looked up by sequence length.
        self.value_functions = value_functions
        # Exact scorer used once the sequence reaches the horizon.
        self.reward = reward_fn
        self.H = H
        # Device the encoded input is moved to before calling a network.
        self.device = device

    def __call__(self, seq):
        """Return the reward at the horizon, otherwise the network estimate.

        Args:
            seq: Sequence to score; only ``len(seq)`` and the encoder /
                reward function inspect its contents.

        Returns:
            ``reward_fn(seq)`` when ``len(seq) == H``; otherwise the Python
            scalar obtained via ``.item()`` from the value network's output.
        """
        length = len(seq)
        if length != self.H:
            # Add a batch axis of size one and move it to the target device.
            batch = one_hot_encode(seq, length).unsqueeze(0).to(self.device)
            # Pure inference: skip building an autograd graph.
            with torch.no_grad():
                estimate = self.value_functions[length](batch).item()
            return estimate
        return self.reward(seq)
