# ============================================================================
# GSM8K Custom Evaluation Configuration
# ============================================================================
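# REQUIRED: This file must be placed in the same directory as utils.py (its
# code is included at the bottom of this file); the `!function utils.*`
# references below are loaded from that utils.py.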
# ============================================================================

# Grouping labels attached to this task.
tag:
  - math_word_problems
  - custom_eval

# Name under which this task configuration is registered.
task: gsm8k_custom_standalone

# Dataset source (path + configuration name) and request type: free-form
# generation that is cut off at the `until` strings in generation_kwargs.
dataset_path: gsm8k
dataset_name: main
output_type: generate_until

# Split assignment: few-shot examples come from `train`, evaluation documents
# from `test`.
training_split: train
fewshot_split: train
test_split: test

# Prompt/target builders implemented in utils.py:
#   - doc_to_text strips calculator annotations (<<...>>) and any "####"
#     answer marker from the question, collapses whitespace, and formats the
#     prompt as "Question: <question>" followed by an "Answer:" line.
#   - doc_to_target returns the text after "####" in the answer field, or the
#     whole stripped answer when no "####" marker is present.
doc_to_text: !function utils.doc_to_text
doc_to_target: !function utils.doc_to_target

# Score with exact_match: the filtered response is compared with the string
# returned by doc_to_target. The custom filter emits numbers with thousands
# separators removed, while doc_to_target keeps the gold answer verbatim, so
# commas are ignored on both sides; otherwise a gold answer such as "1,000"
# could never equal the filter's "1000".
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    regexes_to_ignore:
      - ","

# Generation settings: at most 512 new tokens (a cap, not a fixed length),
# stopping early at any of the `until` strings.
# NOTE(review): no sampling parameters are active here (they are commented out
# below), so decoding follows the backend's defaults — presumably greedy in
# lm-eval; confirm for your model backend before assuming sampling is used.
generation_kwargs:
  max_gen_toks: 512
  until:
    - "Question:"
    - "</s>"
    - "<|im_end|>"
  # Uncomment to request sampling explicitly and tune it.
  # do_sample: true
  # temperature: 1.0
  # top_p: 0.9

# One generation per document, prompted with five few-shot examples (taken
# from the `fewshot_split` configured above).
repeats: 1
num_fewshot: 5

# Custom filter (utils.ExtractAllNumbersFilter): pulls every number out of the
# response and, if any of them equals the target, emits the normalized target;
# otherwise it emits the first number found, or "[invalid]" when the response
# contains no number at all.
# Note this is lenient: a response is scored as correct when the target number
# appears anywhere in it, not only as the final answer.
filter_list:
  - name: "extract-all-numbers"
    filter:
      - function: !function utils.ExtractAllNumbersFilter

# Descriptive metadata: config version, a short description, and a reference
# (title + URL) to the paper cited for this benchmark.
metadata:
  version: 1.0
  description: "Custom GSM8K evaluation with flexible number extraction"
  paper: "Training Verifiers to Solve Math Word Problems (Cobbe et al., 2021)"
  url: "https://arxiv.org/abs/2110.14168"

# ============================================================================
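# REQUIRED utils.py FILE - Copy the code between the two triple-quote marker
# lines below, remove the leading "# " from every line, and save it as:
# lm_eval/tasks/gsm8k/utils.py
# in a checkout of https://github.com/EleutherAI/lm-evaluation-harness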
# ============================================================================
# """
# import re
# from lm_eval.api.filter import Filter
# from lm_eval.api.registry import register_filter
#
#
# def doc_to_text(doc):
#     """Remove calculator notation and answer markers from question."""
#     question = doc["question"]
#     question = re.sub(r"<<[^>]*>>", "", question)
#     question = re.sub(r"####.*$", "", question, flags=re.DOTALL)
#     question = re.sub(r"\s+", " ", question).strip()
#     return f"Question: {question}\nAnswer:"
#
#
# def doc_to_target(doc):
#     """Extract final answer from answer field."""
#     answer = doc["answer"]
#     match = re.search(r"####\s*(.+)", answer)
#     if match:
#         return match.group(1).strip()
#     return answer.strip()
#
#
# @register_filter("extract_all_numbers")
# class ExtractAllNumbersFilter(Filter):
#     """Extract all numbers and check if answer is present."""
#
#     def __init__(self, fallback: str = "[invalid]") -> None:
#         self.fallback = fallback
#
#     def extract_numbers(self, text: str) -> list:
#         """Extract all numbers as strings, with thousands separators removed."""
#         numbers = []
#         pattern = r'-?\b\d+(?:,\d{3})*(?:\.\d+)?\b'
#         matches = re.finditer(pattern, text)
#         for match in matches:
#             num_str = match.group().replace(',', '')
#             numbers.append(num_str)
#         return numbers
#
#     def normalize_number(self, num_str: str) -> str:
#         """Normalize number for comparison."""
#         num_str = num_str.replace(',', '').replace('$', '').strip()
#         try:
#             num_float = float(num_str)
#             if num_float == int(num_float):
#                 return str(int(num_float))
#             return str(num_float)
#         except ValueError:
#             return num_str
#
#     def numbers_match(self, extracted: str, target: str) -> bool:
#         """Check if two numbers match."""
#         norm_extracted = self.normalize_number(extracted)
#         norm_target = self.normalize_number(target)
#         if norm_extracted == norm_target:
#             return True
#         try:
#             return float(norm_extracted) == float(norm_target)
#         except (ValueError, TypeError):
#             return False
#
#     def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
#         """Apply filter to extract and match numbers."""
#         filtered_resps = []
#         for resp_set, doc in zip(resps, docs):
#             filtered = []
#             target = doc_to_target(doc)
#             target_normalized = self.normalize_number(target)
#             for resp in resp_set:
#                 extracted_numbers = self.extract_numbers(resp)
#                 match_found = None
#                 for num in extracted_numbers:
#                     if self.numbers_match(num, target):
#                         match_found = target_normalized
#                         break
#                 if match_found:
#                     filtered.append(match_found)
#                 elif extracted_numbers:
#                     filtered.append(self.normalize_number(extracted_numbers[0]))
#                 else:
#                     filtered.append(self.fallback)
#             filtered_resps.append(filtered)
#         return filtered_resps
# """
# ============================================================================
