import random
import re
import string
import unicodedata
import warnings
from collections import Counter

import numpy as np
from rouge import Rouge


def _int_from_last_numeric_token(text):
    """Parse the integer part of the last whitespace-separated token containing a digit.

    Everything from the first "." onward is dropped, and every remaining
    non-digit character (commas, currency signs, a minus sign, ...) is
    discarded. For example, "$1,234.50" -> 1234 and "18." -> 18.

    Raises:
        IndexError: if no token in ``text`` contains a digit.
        ValueError: if the kept part has no digits (e.g. the token ".5").
    """
    token = [tok for tok in text.split() if any(ch.isdigit() for ch in tok)][-1]
    integer_part = token.split(".")[0]
    return int("".join(ch for ch in integer_part if ch.isdigit()))


def extract_math_number(text):
    """Extract the final numeric answer from a model output or a reference.

    If ``text`` contains ``\\boxed{...}``, the content of the *last* box is
    used. Otherwise the last line of the stripped text is used. In both cases
    the value is reduced to an int via the last digit-bearing token, so
    boxed and unboxed answers compare equal (``\\boxed{42}`` == "... 42.").

    Returns:
        An int when a number can be parsed. For boxed content with no
        parsable digits, the stripped boxed string itself. Otherwise 0.
    """
    text = text.strip()

    boxed = re.findall(r"\\boxed\{([^}]*)\}", text)
    if boxed:
        content = boxed[-1].strip()
        try:
            return _int_from_last_numeric_token(content)
        except (ValueError, IndexError):
            # Non-numeric boxed answer: keep it verbatim for string comparison.
            return content

    try:
        return _int_from_last_numeric_token(text.split("\n")[-1])
    except (ValueError, IndexError):
        return 0

def normalize_fp_text(s):
    """Canonicalise free text for lenient label comparison.

    The steps are applied in this order:
      1. Unicode NFD decomposition.
      2. Lower-casing.
      3. Deletion of every character in ``string.punctuation``.
      4. Replacement of the standalone words "a", "an" and "the" with a space.
      5. Collapsing whitespace runs to single spaces and trimming the ends.
    """
    normalized = unicodedata.normalize("NFD", s).lower()
    normalized = normalized.translate(str.maketrans("", "", string.punctuation))
    normalized = re.sub(r"\b(a|an|the)\b", " ", normalized)
    return " ".join(normalized.split())


def _normalize_choice(s: str):
    """Reduce a gold answer string to a single option letter "A"-"F".

    The stripped, upper-cased input is searched first for a parenthesised
    letter such as "(B)", then for any standalone letter. The first match of
    the first pattern that hits is returned. Returns None for None input or
    when neither pattern matches.
    """
    if s is None:
        return None
    upper = s.strip().upper()
    for pattern in (r"\(([A-F])\)", r"\b([A-F])\b"):  # "(A)" first, then bare "A"
        hit = re.search(pattern, upper)
        if hit:
            return hit.group(1)
    return None

def _extract_pred_choice(text: str):
    """Extract the predicted option letter ("A"-"F") from a model response.

    Two passes are made. Each pass looks at the last line first and then at
    the whole stripped response, and returns the right-most candidate found.

    1. Strict: a letter written as "(x)", "x)" or "(x" (any case), or a
       standalone *uppercase* letter.
    2. Lenient, only if the strict pass finds nothing: any standalone letter,
       case-insensitively.

    The strict pass stops lowercase prose from overriding an explicit answer.
    Examples are the article "a" in "(B), a good choice" or the "e" of
    "e.g.". The lenient pass still recovers bare lowercase answers such as
    "the answer is b".

    Returns:
        An uppercase letter, or None if ``text`` is None or no candidate is found.
    """
    if text is None:
        return None
    t = text.strip()
    scopes = (t.split("\n")[-1], t)

    # Lookarounds keep each match to the single letter, so findall yields
    # the letters in left-to-right order across all alternatives.
    strict = r"(?<=\()[A-Fa-f]\b|\b[A-Fa-f](?=\))|\b[A-F]\b"
    for scope in scopes:
        found = re.findall(strict, scope)
        if found:
            return found[-1].upper()

    for scope in scopes:
        found = re.findall(r"\b([A-F])\b", scope.upper())
        if found:
            return found[-1]
    return None

def multichoice_metric(example, pred):
    """Score a multiple-choice prediction: 1 if its letter equals the gold letter, else 0.

    The gold letter comes from ``_normalize_choice(example['target'])`` and
    the predicted letter from ``_extract_pred_choice(pred)``. Both are always
    computed. A gold answer that yields no letter scores 0.
    """
    expected = _normalize_choice(example['target'])
    predicted = _extract_pred_choice(pred)
    if expected is None:
        return 0
    return 1 if predicted == expected else 0


def math_metric(example, pred):
    """Exact-match score for math answers: 1 if both extracted answers are equal, else 0.

    ``extract_math_number`` is applied to the prediction first and then to
    ``example["target"]``.
    """
    predicted = extract_math_number(pred)
    expected = extract_math_number(example["target"])
    return 1 if expected == predicted else 0

def classification_metric(exmaple, pred):
    """Return True if the normalised gold label occurs as whole words in the prediction.

    Both ``exmaple["target"]`` and ``pred`` go through ``normalize_fp_text``.
    Containment is then checked on token boundaries, so a label such as "no"
    does not match inside "not" or "know". A label that normalises to the
    empty string (e.g. the bare word "the") never matches.

    Note: the misspelled parameter name ``exmaple`` is kept so existing
    keyword-argument callers keep working.
    """
    label = normalize_fp_text(exmaple["target"])
    if not label:
        return False
    # normalize_fp_text yields single-space-separated tokens with no outer
    # whitespace, so padding both sides turns substring search into a
    # whole-token search.
    return f" {label} " in f" {normalize_fp_text(pred)} "


def summary_metric(example, pred):
    """Return the ROUGE-L F-score of ``pred`` against ``example["target"]``.

    A fresh ``Rouge`` scorer is built per call. ``get_scores`` is given the
    prediction as hypothesis and the target as reference, and the
    ``["rouge-l"]["f"]`` entry of its first result is returned.
    """
    scorer = Rouge()
    reference = example["target"]
    first_result = scorer.get_scores(pred, reference)[0]
    return first_result['rouge-l']['f']


def evaluate_batched_metric(taskname, examples, preds):
    """Score aligned predictions with the task's metric and return the mean.

    Args:
        taskname: Selects the per-example metric. "xsum" uses summary_metric,
            "fp" uses classification_metric, "date"/"salient"/"gpqa" use
            multichoice_metric, and "gsm8k" uses math_metric.
        examples: Non-empty sequence of examples passed to the metric.
        preds: Sequence of predictions, the same length as ``examples``.

    Returns:
        The mean per-example score as a float. A metric that returns None
        counts as 0.0. A metric that raises also counts as 0.0, and all such
        failures are reported once through a RuntimeWarning instead of being
        silently swallowed.

    Raises:
        ValueError: if ``examples`` is empty, the lengths differ, or
            ``taskname`` is unknown.
    """
    if not examples:
        raise ValueError("Empty examples.")
    if len(examples) != len(preds):
        raise ValueError(f"Length mismatch: examples={len(examples)} preds={len(preds)}")

    if taskname == "xsum":
        metric = summary_metric
    elif taskname == "fp":
        metric = classification_metric
    elif taskname in ("date", "salient", "gpqa"):
        metric = multichoice_metric
    elif taskname == "gsm8k":
        metric = math_metric
    else:
        raise ValueError(f"unknown task: {taskname}")

    scores = []
    failures = 0
    first_error = None
    for example, pred in zip(examples, preds):
        try:
            s = metric(example, pred)
            scores.append(float(0 if s is None else s))
        except Exception as err:  # best-effort: one bad example must not abort the run
            failures += 1
            if first_error is None:
                first_error = err
            scores.append(0.0)

    if failures:
        # Surface failures (e.g. a missing "target" key) that would otherwise
        # look like genuinely wrong predictions.
        warnings.warn(
            f"{failures} of {len(scores)} examples for task {taskname!r} raised "
            f"while scoring and were counted as 0.0; first error: {first_error!r}",
            RuntimeWarning,
            stacklevel=2,
        )
    # scores is non-empty: examples is non-empty and has the same length as preds.
    return float(np.mean(scores))