from .base_learner import BaseOCGLearner

class OCGLearner(BaseOCGLearner):
    """Learner that fits its network on graph neighborhood inputs.

    Construction is delegated to ``BaseOCGLearner``. This subclass only
    turns ``center_features`` off afterwards.
    """

    def __init__(self, model, neighborhood_processor, args):
        """Initialise via the base learner, then disable feature centering.

        Args:
            model: Forwarded unchanged to ``BaseOCGLearner.__init__``.
            neighborhood_processor: Forwarded unchanged to the base class;
                ``observe`` later calls its ``extract`` method.
            args: Forwarded unchanged to the base class.
        """
        super().__init__(model, neighborhood_processor, args)
        # Assigned after the base initialiser runs, so this value overrides
        # anything the base class may have set for the same attribute.
        # NOTE(review): how ``center_features`` is consumed lives in the base
        # class, not here -- confirm its meaning there.
        self.center_features = False

    def observe(self, g, train_ids, train_labels):
        """Perform one training update on the given nodes.

        Args:
            g: Graph object handed to ``neighborhood_processor.extract``.
            train_ids: Identifiers passed to ``extract`` together with ``g``.
            train_labels: Raw labels, routed through ``before_passes`` before
                they are used for fitting.
        """
        # The labels are prepared first, before any input extraction happens.
        # Presumably ``before_passes`` does per-step bookkeeping on the labels
        # -- TODO confirm against BaseOCGLearner.
        prepared_labels = self.before_passes(train_labels)
        batch_inputs = self.neighborhood_processor.extract(g, train_ids)
        # ``self.net`` is not assigned in this class; presumably the base
        # initialiser sets it -- verify.
        self.net.fit_batch(batch_inputs, prepared_labels)
