import torch
from torch import nn

class PropertyPredictionNet(nn.Module):
    """Chain a base model and a prediction head into a single module.

    Calling the network applies ``base_model`` to the input. It then applies
    ``prediction_head`` to that intermediate result and returns the head's
    output unchanged.

    Args:
        base_model: Callable (typically an ``nn.Module``) applied to the input
            first.
        prediction_head: Callable (typically an ``nn.Module``) applied to
            whatever ``base_model`` returns.
    """

    def __init__(self, base_model, prediction_head):
        super().__init__()

        # Assigning nn.Module instances as attributes registers them as
        # submodules. The attribute names become the state_dict key prefixes,
        # and the assignment order fixes the parameter iteration order, so
        # both are kept stable.
        self._base_model = base_model
        self._prediction_head = prediction_head

    def forward(self, x):
        """Return ``prediction_head(base_model(x))``.

        Args:
            x: Input accepted by the base model. Its expected shape and dtype
                depend on ``base_model`` and are not constrained here.

        Returns:
            Whatever the prediction head produces for the base model's output.
        """
        intermediate = self._base_model(x)
        prediction = self._prediction_head(intermediate)
        return prediction
