import dataclasses
import pydantic
import re
import string
from llm_inference import eval_utils
import pandas as pd
import numpy as np


def exact_match(
  predictions: list[str],
  references: list[str],
  regexes_to_ignore=None,
  ignore_case=False,
  ignore_punctuation=False,
  ignore_numbers=False,
):
  """Compare predictions and references element-wise for exact string equality.

  Both sides are normalized the same way, in this order, before comparing:

    1. every pattern in `regexes_to_ignore` is deleted (`re.sub(p, "", s)`),
       applied sequentially in the given order;
    2. lower-casing, if `ignore_case`;
    3. deletion of `string.punctuation` characters, if `ignore_punctuation`;
    4. deletion of `string.digits` characters, if `ignore_numbers`.

  Args:
    predictions: Sequence of predicted strings.
    references: Sequence of reference strings, aligned with `predictions`.
    regexes_to_ignore: Optional iterable of regex patterns (str or compiled)
      whose matches are removed from every string before comparison.
    ignore_case: Lower-case both sides before comparing.
    ignore_punctuation: Remove ASCII punctuation before comparing.
    ignore_numbers: Remove ASCII digits before comparing.

  Returns:
    A boolean numpy array that is True where the normalized prediction equals
    the normalized reference.
  """
  # Compile once; re.compile returns an already-compiled pattern unchanged.
  patterns = [re.compile(p) for p in (regexes_to_ignore or ())]

  def _remove_patterns(text):
    # Applying every pattern to each string in turn is equivalent to running
    # each pattern over the whole sequence, but iterates the input only once
    # (so generators work too).
    for pattern in patterns:
      text = pattern.sub("", text)
    return text

  if patterns:
    predictions = [_remove_patterns(t) for t in predictions]
    references = [_remove_patterns(t) for t in references]

  # Always materialize str-dtype arrays. Without this, an empty
  # `regexes_to_ignore` list leaves plain lists whose `==` yields a single
  # bool, and empty / object-dtype inputs are rejected by np.char functions.
  predictions = np.asarray(predictions, dtype=str)
  references = np.asarray(references, dtype=str)

  if ignore_case:
    predictions = np.char.lower(predictions)
    references = np.char.lower(references)

  # Punctuation and digit removal are both pure deletions, so one combined
  # translation table gives the same result as two separate passes.
  chars_to_delete = ""
  if ignore_punctuation:
    chars_to_delete += string.punctuation
  if ignore_numbers:
    chars_to_delete += string.digits
  if chars_to_delete:
    table = str.maketrans("", "", chars_to_delete)
    predictions = np.char.translate(predictions, table)
    references = np.char.translate(references, table)

  return predictions == references


@dataclasses.dataclass
class ExactMatch(eval_utils.Metric):
  """Metric that scores predictions by (optionally normalized) exact match.

  Each normalization option below is forwarded unchanged to `exact_match`.
  """

  # Regex patterns forwarded as `exact_match(regexes_to_ignore=...)`;
  # None applies no regex removal.
  regexes_to_ignore: list[str] | None = None
  # Forwarded as `exact_match(ignore_case=...)`.
  ignore_case: bool = False
  # Forwarded as `exact_match(ignore_punctuation=...)`.
  ignore_punctuation: bool = False
  # Forwarded as `exact_match(ignore_numbers=...)`.
  ignore_numbers: bool = False

  def compute(
    self,
    preds: np.ndarray,
    refs: np.ndarray,
  ):
    """Score each prediction against its reference.

    Returns:
      A DataFrame with a single `correct` column holding the per-element
      result of `exact_match`.
    """
    normalization = dict(
      regexes_to_ignore=self.regexes_to_ignore,
      ignore_case=self.ignore_case,
      ignore_punctuation=self.ignore_punctuation,
      ignore_numbers=self.ignore_numbers,
    )
    is_correct = exact_match(predictions=preds, references=refs, **normalization)
    return pd.DataFrame(data={"correct": is_correct})

  def summarize(self, details, index_key: str):
    """Aggregate per-row results into summary statistics.

    `details` must have a `correct` column and an `index_key` column; the
    latter is presumably attached by the caller (it is not produced by
    `compute`) — NOTE(review): confirm against the framework.

    Returns:
      A dict with:
        accuracy: mean of `correct` over all rows.
        pass_rate: fraction of `index_key` groups whose `correct` values sum
          to more than zero (at least one correct row, for boolean data).
    """
    correct_column = details["correct"]
    hits_per_group = details.groupby(index_key)["correct"].sum()
    group_passed = hits_per_group > 0
    return {
      "accuracy": correct_column.mean(),
      "pass_rate": group_passed.mean(),
    }
