import torch
import torch.nn as nn


class SimpleRewardModel(nn.Module):
    """Score a sequence with a single linear layer over its final-token logits.

    The model owns one ``nn.Linear(vocabSize, 1)`` layer, stored as
    ``self.linear``, which projects a vocabulary-sized vector to one scalar.
    """

    def __init__(self, vocabSize: int) -> None:
        """Build the linear head.

        Args:
            vocabSize: Size of the last dimension of the logits passed to
                ``forward``; used as the input width of the linear layer.
        """
        super().__init__()

        # Single projection from a vocabulary-sized vector down to one value.
        self.linear = nn.Linear(vocabSize, 1)

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        """Project the logits at the last sequence position to a scalar.

        Args:
            logits: Tensor of shape (batch_size, N, vocabSize).

        Returns:
            Tensor of shape (batch_size, 1) produced by the linear layer.
        """
        # Only the final position along dim 1 is used; earlier ones are ignored.
        last_token_logits = logits[:, -1, :]
        reward = self.linear(last_token_logits)
        return reward




