from typing import Optional

from lm_understanding.explanations.explanations import LocalExplanationSet
from lm_understanding.question_template import TemplateModelBehavior

from .baseline import Baseline
from .distillation import Distillation
from .linear_regression import LinearRegression
from .llm_prompting import LLMPrompting
from .nearest_neighbor import NearestNeighbor
from .predict_average import PredictAverage

# Registry of available baseline implementations, keyed by the string that a
# baseline config supplies as its ``class_name`` (looked up in make_baseline).
BASELINES = {
    "predict_average": PredictAverage,
    "nearest_neighbor": NearestNeighbor,
    "linear_regression": LinearRegression,
    "distillation": Distillation,
    "llm_prompting": LLMPrompting,
}


def make_baseline(baseline_config, model_behavior: TemplateModelBehavior, explainer: Optional[LocalExplanationSet] = None) -> Baseline:
    """Instantiate the baseline selected by ``baseline_config.class_name``.

    Args:
        baseline_config: Configuration object. Its ``class_name`` attribute
            selects an entry of ``BASELINES``, and the whole object is passed
            as the second positional argument to the chosen class.
        model_behavior: Passed as the first positional argument to the chosen
            baseline class.
        explainer: Optional explanation set, forwarded as the ``explainer``
            keyword argument.

    Returns:
        A new instance of the baseline class registered under
        ``baseline_config.class_name``.

    Raises:
        KeyError: If ``baseline_config.class_name`` is not a registered
            baseline name. The message lists the valid names.
    """
    class_name = baseline_config.class_name
    try:
        baseline_class = BASELINES[class_name]
    except KeyError:
        # Re-raise with the valid choices so config typos are easy to spot.
        # It stays a KeyError so callers that already catch it keep working.
        available = ", ".join(sorted(BASELINES))
        raise KeyError(
            f"Unknown baseline {class_name!r}; expected one of: {available}"
        ) from None
    return baseline_class(model_behavior, baseline_config, explainer=explainer)