from src.evaluator.metric.base_metric import BaseMetric
from src.evaluator.metric.em_metric import EMMetric
from src.evaluator.metric.f1_metric import F1Metric
from src.evaluator.metric.gps_metric import GPSMetric
from src.evaluator.metric.ips_metric import IPSMetric
from src.evaluator.metric.recall_at_k_metric import RecallAtKMetric


# Names exported by ``from <this module> import *``: the metric factory and
# the metric base class.
__all__ = ["metric_factory", "BaseMetric"]


# Metric names understood by ``metric_factory``. Keep these lists in sync with
# the names that function dispatches on.

# Retrieval-side metric names.
AVAILABLE_RETRIEVAL_METRICS = [
    "recall_at_k",
]

# Backward-compatible alias for the original, misspelled constant name
# ("RETREIVAL"). It refers to the same list object; prefer
# AVAILABLE_RETRIEVAL_METRICS in new code.
AVAILABLE_RETREIVAL_METRICS = AVAILABLE_RETRIEVAL_METRICS

# Generation-side metric names.
AVAILABLE_GENERATION_METRICS = [
    "em",
    "f1",
    "ips",
    "gps",
]

# Every known metric name: retrieval metrics first, then generation metrics.
AVAILABLE_METRICS = AVAILABLE_RETRIEVAL_METRICS + AVAILABLE_GENERATION_METRICS


def metric_factory(metric_name: str) -> BaseMetric:
    """Create the metric object registered under ``metric_name``.

    Args:
        metric_name: Metric identifier; one of ``"recall_at_k"``, ``"em"``,
            ``"f1"``, ``"ips"`` or ``"gps"``.

    Returns:
        A newly constructed metric instance (a new object on every call),
        created by calling the matching metric class with no arguments.

    Raises:
        ValueError: If ``metric_name`` does not name a known metric. The
            message lists the accepted names.
    """
    # Table-driven dispatch instead of an if/elif chain: one entry per metric.
    # Built inside the function so the metric classes are looked up at call
    # time, as the original per-branch references were.
    metric_classes = {
        "recall_at_k": RecallAtKMetric,
        "em": EMMetric,
        "f1": F1Metric,
        "ips": IPSMetric,
        "gps": GPSMetric,
    }
    try:
        metric_cls = metric_classes[metric_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs (e.g. a list). The original
        # equality chain reported those as an unknown metric too.
        available = ", ".join(metric_classes)
        raise ValueError(
            f"Unknown metric_name: {metric_name}. "
            f"Available metrics: {available}"
        ) from None
    return metric_cls()
