import re
import json
import pandas as pd
from core.infra.llm import LLMApi
from core.model.prompt import Prompt
from sklearn.metrics import f1_score


def parse_acc_response(response):
    """
    Return the first number (optionally negative integer or decimal) in ``response``.

    The number is returned as the matched substring, e.g. ``"-3.5"``. An empty
    string is returned when ``response`` is falsy or contains no number.
    """
    found = re.search(r"-?\d+(?:\.\d+)?", response) if response else None
    if found is None:
        return ""
    return found.group(0)


def parse_f1_response(response):
    """
    Extract a binary "YES"/"NO" label from an LLM response.

    Resolution order:
      1. ``None``, empty, or whitespace-only response -> "NO".
      2. Structured JSON: either a list whose first element is a dict, or a
         dict, carrying a string ``"label"`` that (stripped, upper-cased) is
         "YES" or "NO".
      3. Otherwise, a substring scan of the upper-cased response text:
         "YES" takes precedence over "NO".
      4. If nothing matches, print a warning and return "YES".

    Args:
        response: Raw model output text (may be ``None``).

    Returns:
        "YES" or "NO".
    """
    text = response.strip() if response else ""
    if not text:
        # Missing output is treated the same whether it is None or blank.
        return "NO"

    try:
        json_data = json.loads(text)
    except json.JSONDecodeError:
        json_data = None

    # Accept both [{"label": ...}, ...] and {"label": ...}; any other JSON
    # shape (string, number, list of non-dicts) falls through to the text scan
    # instead of raising or silently defaulting.
    if isinstance(json_data, list) and json_data:
        json_data = json_data[0]
    if isinstance(json_data, dict):
        label = json_data.get("label")
        if isinstance(label, str):
            label = label.strip().upper()
            if label in ("YES", "NO"):
                return label

    response_upper = text.upper()
    if "YES" in response_upper:
        return "YES"
    if "NO" in response_upper:
        return "NO"
    print(f"无效响应格式: {response[:200]}...")
    return "YES"


class Evaluator:
    """
    Score a prompt against a labelled dataset by querying the LLM once per row.

    With ``is_f1=True`` responses are parsed by ``parse_f1_response`` and the
    prompt is scored with macro-averaged F1; otherwise responses are parsed by
    ``parse_acc_response`` and scored by exact-match accuracy.
    """

    def __init__(self, is_f1: bool = False):
        self.is_f1: bool = is_f1
        # Response parser matching the selected metric.
        self.parse_response = parse_f1_response if is_f1 else parse_acc_response

    def _predict(self, prompt: Prompt, inputs: list) -> list:
        """Send each input to the LLM (with ``prompt.text`` as system message) and parse the reply."""
        llm_api = LLMApi()
        predictions = []
        for input_text in inputs:
            messages = [
                {"role": "system", "content": prompt.text},
                {"role": "user", "content": input_text},
            ]
            predictions.append(self.parse_response(llm_api.generate(messages)))
        return predictions

    def evaluate(self, prompt: Prompt, dataset: pd.DataFrame):
        """
        Evaluate ``prompt`` on ``dataset`` and record the results on the prompt.

        Args:
            prompt: Prompt whose ``text`` is used as the system message.
            dataset: DataFrame with an ``input`` column (user message) and a
                ``target`` column (expected parsed answer).

        Returns:
            Macro F1 (F1 mode) or accuracy; 0.0 for an empty dataset. The score
            is also stored on ``prompt.f1`` / ``prompt.acc``, and the
            misclassified ``(input, target)`` pairs on ``prompt.bad_cases``.
        """
        # dataset["input"] is a pandas Series; its values, in row order, are
        # exactly what the per-row prediction loop needs.
        inputs = dataset["input"].tolist()
        targets = dataset["target"].tolist()
        predictions = self._predict(prompt, inputs)

        # Parsers return strings, so compare against str(target); otherwise a
        # numeric target column could never match. "" means "unparseable".
        correct = [p != "" and p == str(t) for p, t in zip(predictions, targets)]
        # Only incorrect rows are kept, as (input, target) pairs.
        bad_cases = [(x, t) for x, t, ok in zip(inputs, targets, correct) if not ok]

        if self.is_f1:
            score = f1_score(targets, predictions, average='macro') if targets else 0.0
            prompt.f1 = score
        else:
            score = sum(correct) / len(correct) if correct else 0.0
            prompt.acc = score
        prompt.bad_cases = bad_cases

        return score
