from src.evaluator.metric.base_metric import BaseMetric
from src.evaluator.metric.recall_at_k_metric import RecallAtKMetric
from src.evaluator.metric.em_metric import EMMetric
from src.evaluator.metric.ips_metric import IPSMetric
from src.evaluator.metric.dps_metric import DPSMetric


# Public API of this module for ``from ... import *``. The metric-name lists
# are exported too, since they enumerate the names ``metric_factory`` accepts.
__all__ = [
    "metric_factory",
    "BaseMetric",
    "AVAILABLE_RETREIVAL_METRICS",
    "AVAILABLE_GENERATION_METRICS",
    "AVAILABLE_METRICS",
]


# Names of the retrieval metrics.
AVAILABLE_RETRIEVAL_METRICS = [
    "recall_at_k",
]
# Backward-compatible alias for the original, misspelled name. It refers to the
# very same list object, so existing imports of it keep working unchanged.
AVAILABLE_RETREIVAL_METRICS = AVAILABLE_RETRIEVAL_METRICS

# Names of the generation metrics.
AVAILABLE_GENERATION_METRICS = [
    "em",
    "ips",
    "dps",
]

# All metric names; keep in sync with the names handled by ``metric_factory``.
AVAILABLE_METRICS = AVAILABLE_RETRIEVAL_METRICS + AVAILABLE_GENERATION_METRICS


def metric_factory(metric_name: str) -> BaseMetric:
    """Create the metric registered under ``metric_name``.

    Args:
        metric_name: One of ``"recall_at_k"``, ``"em"``, ``"ips"`` or ``"dps"``.

    Returns:
        A new instance of the matching metric class, constructed with no
        arguments.

    Raises:
        ValueError: If ``metric_name`` is not a recognised metric name. The
            message lists the accepted names.
    """
    # Single name -> class mapping; the error message below is derived from it,
    # so the list of valid names shown to the user cannot drift from the code.
    registry = {
        "recall_at_k": RecallAtKMetric,
        "em": EMMetric,
        "ips": IPSMetric,
        "dps": DPSMetric,
    }
    try:
        metric_cls = registry[metric_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs (e.g. a list); those are reported
        # as an unknown name with ValueError, like any other unmatched value.
        available = ", ".join(registry)
        raise ValueError(
            f"Unknown metric_name: {metric_name}. Available metrics: {available}"
        ) from None
    return metric_cls()
