import dataclasses
import multiprocessing as mp
import operator
import re
from typing import Sequence

import datasets as ds
from numpy import ndarray
import pandas as pd
from loguru import logger
from tqdm.auto import tqdm

from llm_inference import eval_utils, output_parsers
from llm_inference.eval_utils import Metric
from llm_inference.tasks import math_utils
from llm_inference.tasks.task import Task

# Configuration names of the "EleutherAI/hendrycks_math" dataset; each one is
# loaded separately by ``MATH.load_dataset`` and recorded in the "subset" column.
SUBSETS = [
  "algebra",
  "counting_and_probability",
  "geometry",
  "intermediate_algebra",
  "number_theory",
  "precalculus",
  "prealgebra",
]


# Display math ($$...$$) is tried before inline math ($...$) so that doubled
# delimiters are not misread as an empty inline expression. Inline math stays
# on a single line; display math may span several lines.
_LATEX_MATH_RE = re.compile(r"\$\$([^$]+)\$\$|\$([^$\n]+)\$")
# Optionally signed integer or decimal number.
_NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")


def extract_answer(llm_response):
  """Pull the most likely final answer out of a free-form model response.

  The last non-empty LaTeX expression delimited by ``$...$`` or ``$$...$$``
  wins. When the response contains no such expression, the last integer or
  decimal number is used instead.

  Args:
    llm_response: Text produced by the model. ``None`` and the empty string
      are accepted and yield ``None``.

  Returns:
    The stripped LaTeX content, or the number as a string, or ``None`` when
    neither is present.
  """
  if not llm_response:
    return None

  latex_candidates = []
  for match in _LATEX_MATH_RE.finditer(llm_response):
    # Exactly one of the two groups participates in any given match.
    content = (match.group(1) or match.group(2)).strip()
    # Whitespace-only math such as "$ $" carries no answer; skip it so the
    # number fallback below still gets a chance.
    if content:
      latex_candidates.append(content)
  if latex_candidates:
    return latex_candidates[-1]

  number_matches = _NUMBER_RE.findall(llm_response)
  return number_matches[-1] if number_matches else None


@dataclasses.dataclass
class MATH(Task):
  """Task definition for the Hendrycks MATH benchmark.

  Examples are drawn from every configuration named in ``SUBSETS`` of the
  ``EleutherAI/hendrycks_math`` dataset and merged into one dataset.
  """

  # Dataset split requested from ``ds.load_dataset``.
  dataset_split: str = "test"

  def load_dataset(self):
    """Load each subset, tag its rows with the subset name, and merge them.

    Returns:
      The concatenated dataset after ``math_utils.process_docs`` was applied.
    """
    tagged_parts = []
    for subset_name in SUBSETS:
      part = ds.load_dataset(
        "EleutherAI/hendrycks_math",
        subset_name,
        split=self.dataset_split,
        trust_remote_code=True,
      )
      # Record the originating subset in a "subset" column on every row.
      tagged_parts.append(part.map(lambda row: {**row, "subset": subset_name}))

    combined = ds.concatenate_datasets(tagged_parts)
    return math_utils.process_docs(combined)

  @property
  def stop_tokens(self):
    """Stop strings for this task (a fresh list on every access)."""
    tokens = ["Problem:"]
    return tokens

  @property
  def output_keys(self):
    """Names of the example fields reported for this task."""
    keys = ["problem", "solution", "answer", "type", "subset", "level"]
    return keys

  def get_reference_solutions(self, example: dict) -> list[str]:
    """Return the example's ``"answer"`` field wrapped in a one-element list."""
    answer = example["answer"]
    return [answer]

  def get_evaluation_cfg(self):
    """Build the evaluation config that scores answers with ``MathMetric``."""
    metric = MathMetric()
    reference_getter = operator.itemgetter("answer")
    parsers = [math_utils.extract_final_answer]
    return eval_utils.EvaluationConfig(
      metric,
      get_reference=reference_getter,
      output_parser=parsers,
      execution_strategy="process",
    )


class MathMetric(eval_utils.Metric):
  """Metric that scores predictions against references via ``math_utils.is_equiv``."""

  def compute(self, preds: ndarray, refs: ndarray) -> pd.DataFrame:
    """Compare predictions with references pairwise.

    Returns:
      A frame with a single ``"correct"`` column holding one
      ``math_utils.is_equiv`` result per (prediction, reference) pair.
    """
    verdicts = []
    for prediction, reference in zip(preds, refs):
      verdicts.append(math_utils.is_equiv(prediction, reference))
    return pd.DataFrame({"correct": verdicts})

  def summarize(self, details: pd.DataFrame, dataset_index_key: str) -> dict:
    """Aggregate per-sample verdicts into summary scores.

    Args:
      details: Frame with a ``"correct"`` column and a ``dataset_index_key`` column.
      dataset_index_key: Column used to group rows.

    Returns:
      ``"accuracy"``: mean of ``"correct"`` over all rows.
      ``"pass_rate"``: fraction of groups whose ``"correct"`` sum is above zero.
    """
    accuracy = details["correct"].mean()
    correct_per_group = details.groupby(dataset_index_key)["correct"].sum()
    group_solved = correct_per_group > 0
    return {
      "accuracy": accuracy,
      "pass_rate": group_solved.mean(),
    }
