
from argparse import Namespace

import numpy as np

from lm_understanding.baselines.baseline import Baseline
from lm_understanding.question_template import TemplateModelBehavior


class NearestNeighbor(Baseline):
    """k-nearest-neighbor baseline.

    Predicts a question's score as the mean of the model's ``'train'``-split
    answers for the ``k`` neighbors returned by ``self.retriever``. No fitting
    is performed.
    """

    def __init__(self, model_behavior: TemplateModelBehavior, baseline_config: Namespace, *args, **kwargs):
        """Initialize the baseline.

        Args:
            model_behavior: Model behavior whose ``answers('train')`` provide
                the neighbor scores that are averaged in ``predict``.
            baseline_config: Config namespace; its ``knn`` attribute sets the
                number of neighbors ``k``.
            *args, **kwargs: Forwarded unchanged to ``Baseline.__init__``.
        """
        super().__init__(model_behavior, baseline_config, *args, **kwargs)
        # Number of neighbors requested from the retriever in ``predict``.
        self.k = baseline_config.knn

    def train(self):
        """No-op: nearest-neighbor prediction has nothing to fit."""

    def predict(self, split: str, question_idx: int) -> float:
        """Return the mean train-split answer of the question's nearest neighbors.

        Args:
            split: Split that contains the question being predicted.
            question_idx: Index of that question within ``split``.

        Returns:
            Mean of the ``'train'`` answers at the indices returned by
            ``self.retriever(split, question_idx, self.k)``; ``nan`` if the
            retriever returns no indices.

        Raises:
            RuntimeError: If ``self.retriever`` is ``None``. (Presumably the
                attribute is set up by ``Baseline`` — verify.)
        """
        # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
        if self.retriever is None:
            raise RuntimeError('NearestNeighbor.predict requires a retriever, but self.retriever is None')
        nearest_idx = self.retriever(split, question_idx, self.k)
        # The train answers do not depend on the neighbor, so fetch them once.
        train_answers = self.model_behavior.answers('train')
        scores = [train_answers[i] for i in nearest_idx]
        if not scores:
            # Same value np.mean gives for an empty input, but without its RuntimeWarning.
            return float('nan')
        return np.array(scores).mean()
