import dataclasses

from llm_inference import eval_utils, output_parsers
from llm_inference.metrics import exact_match
from llm_inference.tasks.task import HFTask


@dataclasses.dataclass
class GSM8k(HFTask):
  """GSM8K math word-problem task backed by a Hugging Face dataset.

  Examples come from the ``gsm8k`` dataset (``test`` split by default). Each
  example's ``answer`` field holds a worked solution whose final answer
  follows a ``####`` marker; scoring is exact match on that final answer.
  """

  dataset_path: str = "gsm8k"
  dataset_split: str = "test"

  @property
  def output_keys(self):
    """Example fields exposed by this task: ``question`` and ``answer``."""
    return ["question", "answer"]

  @property
  def stop_tokens(self):
    """Stop sequences for generation.

    ``"Q:"`` presumably marks the start of the next few-shot question in the
    prompt format (confirm against the prompt template); the other two are
    end-of-sequence / end-of-turn markers.
    """
    return ["Q:", "</s>", "<|im_end|>"]

  def get_reference_solutions(self, example):
    """Return the full worked solution (raw ``answer`` field) as a 1-item list."""
    return [example["answer"]]

  def get_evaluation_cfg(self) -> eval_utils.EvaluationConfig:
    """Build the exact-match evaluation config for GSM8K.

    Model outputs are parsed with ``gsm8k_first_number_from_last_line``.
    References come from ``_get_reference``, which yields the comma-free
    text after ``####``. The two are compared case-insensitively by
    ``ExactMatch``.

    Returns:
      The ``eval_utils.EvaluationConfig`` for this task.
    """
    # Regex patterns that ExactMatch removes before comparing (presumably from
    # both prediction and reference -- confirm in ExactMatch). They are raw
    # strings, so a SINGLE backslash escapes a metacharacter: r"\$" is a
    # literal dollar sign and r"\.$" a trailing period. A doubled backslash
    # (r"\\$") would instead match a literal backslash character.
    regexes_to_ignore = [
      r",",  # thousands separators, e.g. "1,234"
      r"\$",  # currency symbol
      r"(?s).*####",  # everything up to and including the last "####" marker
      r"\.$",  # trailing period
    ]
    return eval_utils.EvaluationConfig(
      metric=exact_match.ExactMatch(
        regexes_to_ignore=regexes_to_ignore, ignore_case=True
      ),
      get_reference=_get_reference,
      output_parser=[output_parsers.gsm8k_first_number_from_last_line],
      execution_strategy="process",
    )


def _get_reference(x):
  return x["answer"].replace(",", "").split("####")[1].strip()
