import dataclasses
import re
import string
import typing as tp

import numpy as np
import pandas as pd

from llm_inference import eval_utils, output_parsers
from llm_inference.tasks.task import HFTask


# Adapted from the official eval script:
# https://github.com/mandarjoshi90/triviaqa/blob/ca43b5820b107f3970cf4b7d67f7db7a98117b79/evaluation/triviaqa_evaluation.py#L14
def normalize_answer(s):
  """Normalize an answer string for exact-match comparison.

  Steps, applied in this order:
    1. lowercase the text;
    2. delete every ASCII punctuation character (`string.punctuation`);
    3. replace the standalone articles "a", "an" and "the" with a space;
    4. collapse whitespace runs to a single space and trim both ends.

  Args:
    s: Raw answer text.

  Returns:
    The normalized string (possibly empty).
  """
  # The punctuation table must hold the raw characters. Running them through
  # `re.escape` turns regex-special ones (".", "-", "(", "?", ...) into
  # two-character strings that a per-character lookup can never match, so
  # they would survive normalization.
  # NOTE(review): the upstream TriviaQA script appears to substitute a space
  # for punctuation instead of deleting it -- confirm whether to align.
  strip_punctuation = str.maketrans("", "", string.punctuation)

  lowered = s.lower()
  without_punct = lowered.translate(strip_punctuation)
  without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punct)
  return " ".join(without_articles.split())


@dataclasses.dataclass
class TriviaQA(HFTask):
  """TriviaQA task definition scored by exact match against answer aliases.

  The defaults select the `rc.nocontext` configuration of
  `mandarjoshi/trivia_qa` and its `validation` split.
  """

  dataset_path: str = "mandarjoshi/trivia_qa"
  dataset_name: str = "rc.nocontext"
  dataset_split: str = "validation"

  @property
  def output_keys(self) -> tp.List[str]:
    """Names of the example fields used by this task."""
    return ["question", "answer"]

  @property
  def stop_tokens(self) -> tp.List[str]:
    """Strings registered as stop tokens for this task.

    NOTE(review): "." and "," would also cut answers that legitimately
    contain them (abbreviations, lists) -- confirm this is intended.
    """
    return ["\n", ".", ","]

  def get_reference_solutions(self, example: dict) -> tp.List[str]:
    """Return all accepted aliases stored under `answer.normalized_aliases`."""
    return example["answer"]["normalized_aliases"]

  def get_reference_solution(self, example: dict) -> str:
    """Return the single canonical answer stored under `answer.normalized_value`."""
    return example["answer"]["normalized_value"]

  @staticmethod
  def _normalized_references(example: dict) -> tp.Tuple[str, ...]:
    """Return every accepted alias of `example`, passed through `normalize_answer`.

    The aliases are normalized with the same function as the predictions so
    that both sides of the comparison share one canonical form.
    """
    return tuple(
      normalize_answer(alias) for alias in example["answer"]["normalized_aliases"]
    )

  def get_evaluation_cfg(self) -> eval_utils.EvaluationConfig:
    """Build the evaluation config: alias references, parser and metric.

    Returns:
      An `eval_utils.EvaluationConfig` wired with `TriviaQAMetric`, the
      normalized alias tuple as reference, and `_extract_answer` as the only
      output parser.
    """
    return eval_utils.EvaluationConfig(
      metric=TriviaQAMetric(),
      # A named static function (rather than a lambda) can be pickled by
      # reference, which matters if the "process" strategy ships this config
      # to worker processes -- not visible from this file.
      get_reference=self._normalized_references,
      execution_strategy="process",
      output_parser=[_extract_answer],
    )


def _extract_answer(example: str) -> str:
  """Turn a raw model completion into a normalized answer string.

  A leading "A: " marker is dropped when present, then the text is passed
  through `normalize_answer`.

  Args:
    example: Raw completion text.

  Returns:
    The normalized answer.

  Raises:
    output_parsers.OutputParserException: If any step fails (for instance
      when `example` is not a string). The original error is attached as
      the exception's cause.
  """
  answer_prefix = "A: "
  try:
    if example.startswith(answer_prefix):
      example = example[len(answer_prefix):]
    return normalize_answer(example)
  except Exception as e:
    # Chain explicitly so the traceback shows the root failure as the cause.
    raise output_parsers.OutputParserException(
      f"Failed to extract answer: {e}"
    ) from e


class TriviaQAMetric(eval_utils.Metric):
  """Exact-match metric: a prediction counts if it equals any reference alias."""

  def compute(self, preds: np.ndarray, refs: np.ndarray) -> pd.DataFrame:
    """Score each prediction against its own collection of references.

    Args:
      preds: Predictions, one per row.
      refs: For each prediction, an iterable of acceptable answers.

    Returns:
      A DataFrame with one boolean column, "correct", holding one row per
      (prediction, references) pair.
    """
    matches = [
      any(prediction == alias for alias in aliases)
      for prediction, aliases in zip(preds, refs)
    ]
    return pd.DataFrame({"correct": matches})

  def summarize(self, details: pd.DataFrame, index_key: str) -> dict:
    """Aggregate per-row correctness into summary numbers.

    Args:
      details: Frame containing a "correct" column and the `index_key` column.
      index_key: Column used to group rows.

    Returns:
      A dict with "accuracy" (mean of "correct" over all rows) and
      "pass_rate" (share of `index_key` groups with at least one correct row).
    """
    correct = details["correct"]
    hits_per_group = details.groupby(index_key)["correct"].sum()
    any_hit = hits_per_group > 0
    return {"accuracy": correct.mean(), "pass_rate": any_hit.mean()}
