import torch.nn as nn

from .commons import ResidualConv


class Value(nn.Module):
    """Value head that reduces its input to a single output per sample.

    The input is passed through a ``ResidualConv`` block, max-pooled over
    its two trailing (spatial) dimensions down to 1x1, flattened, and finally
    projected to one output feature by a linear layer.
    """

    def __init__(self, d_hidden):
        """Build the layer stack.

        Args:
            d_hidden: Size handed to ``ResidualConv`` and used as the input
                width of the final linear layer (presumably the channel
                count of the incoming feature map -- confirm against
                ``ResidualConv``).
        """
        super().__init__()

        # Submodules are created in the order they are applied, which keeps
        # the parameter-initialisation order and the ``layers.<idx>``
        # state-dict keys stable.
        trunk = ResidualConv(d_hidden)
        global_pool = nn.AdaptiveMaxPool2d((1, 1))
        flatten = nn.Flatten()
        head = nn.Linear(d_hidden, 1)

        self.layers = nn.Sequential(trunk, global_pool, flatten, head)

    def forward(self, states):
        """Run ``states`` through the layer stack and return the result.

        Args:
            states: Input tensor; presumably laid out as
                ``(batch, d_hidden, height, width)`` -- TODO confirm with
                the caller.

        Returns:
            The output of the final linear layer, which has one feature in
            its last dimension.
        """
        values = self.layers(states)
        return values
