import numpy as np

from lm_understanding.baselines.baseline import Baseline


class PredictAverage(Baseline):
    """Constant-prediction baseline that outputs the mean training answer.

    ``train`` reduces the training-split answers to a single number. ``predict``
    then returns that number for every query and ignores its inputs.
    """

    def train(self):
        """Compute and cache the mean of the training-split answers.

        The answers returned by ``self.model_behavior.answers('train')`` are
        cast to ``float`` and averaged. The result is stored on
        ``self.answer`` as a numpy scalar.
        """
        # NOTE(review): presumably answers are numeric or boolean labels, one
        # per training example; confirm against model_behavior.
        raw_answers = self.model_behavior.answers('train')
        float_answers = np.array(raw_answers).astype(float)
        # An empty training split makes numpy return NaN (with a RuntimeWarning).
        self.answer = float_answers.mean()

    def predict(self, *args, **kwargs) -> float:
        """Return the mean cached by ``train``; all arguments are ignored.

        Raises ``AttributeError`` if ``train`` has not been called yet,
        because ``self.answer`` is only assigned there.
        """
        cached_mean = self.answer
        return cached_mean