import torch
from torch import Tensor
from torch_geometric.data.batch import Batch

import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

from building_blocks.neural_network import NeuralNetworkEstimator


class ZeroBaseline(NeuralNetworkEstimator):
    """Trivial baseline estimator whose prediction is always zero.

    ``forward`` ignores the values in the batch and returns one ``0.0`` per
    entry of ``batch.covariates``.
    """

    def __init__(self, args):
        # ``args`` is forwarded unchanged to NeuralNetworkEstimator; this
        # class adds no state of its own.
        super().__init__(args)

    def forward(self, batch: Batch) -> Tensor:
        """Return a zero prediction for every entry of ``batch.covariates``.

        Args:
            batch: Batch whose ``covariates`` attribute determines how many
                predictions are produced (``len(batch.covariates)``; for a
                tensor this is the size of its first dimension).

        Returns:
            1-D tensor of zeros with ``len(batch.covariates)`` elements in
            torch's default floating dtype. If ``batch.covariates`` has a
            ``device`` (e.g. it is a tensor), the result is allocated on that
            same device. Otherwise it is allocated on torch's default device.
        """
        covariates = batch.covariates
        # The legacy ``Tensor(list)`` constructor ignored where the inputs
        # live, producing CPU zeros for GPU batches and breaking any later
        # arithmetic with device-resident targets. ``getattr`` keeps
        # non-tensor (e.g. list) covariates working as before.
        device = getattr(covariates, "device", None)
        return torch.zeros(len(covariates), device=device)
