import evaluate as hf_evaluate
import re


# Load the Hugging Face `code_eval` metric once at import time and immediately
# run one tiny evaluation with it. The loaded metric is kept in the module
# global `compute_`, which the scoring functions below reuse.
# NOTE(review): this looks like a smoke test meant to surface setup problems
# as soon as the module is imported rather than mid-run -- confirm intent.
try:
    compute_ = hf_evaluate.load("code_eval")
    # The sample candidate multiplies instead of adding, so `add(2, 3)` is 6 and
    # the reference assertion does not hold; `results` is stored but is not
    # inspected anywhere in this part of the file.
    test_cases = ["assert add(2, 3)==5"]
    candidates = [["def add(a,b): return a*b"]]
    results = compute_.compute(references=test_cases, predictions=candidates, k=[1])
except Exception as e:
    # Re-raised as-is: any failure while loading or running the metric
    # propagates to whoever imports this module.
    raise e


def pass_at_k(references: list[str], predictions: list[list[str]], k: list[int] = None):
    """Score ``predictions`` against ``references`` with the module-level metric.

    Args:
        references: One test program per problem.
        predictions: For each problem, the list of candidate programs.
        k: The ``k`` value(s) to evaluate. A bare ``int`` is wrapped into a
            one-element list. Must not be ``None``.

    Returns:
        The first element of whatever ``compute_.compute`` returns (for the
        Hugging Face ``code_eval`` metric this is expected to be the pass@k
        scores -- confirm against the metric's documentation).

    Raises:
        AssertionError: If ``k`` is ``None``.
    """
    assert k is not None
    k_values = [k] if isinstance(k, int) else k
    metric_output = compute_.compute(
        references=references, predictions=predictions, k=k_values
    )
    return metric_output[0]


def has_comment_word(func: str, word: str) -> bool:
    """Return True if ``word`` appears as a whole word after a ``#`` on any line.

    The match is case-insensitive and bounded by ``\\b`` on both sides. The
    "comment" is simply everything after the first ``#`` on a line; the source
    is not parsed, so a ``#`` inside a string literal counts as well.

    Args:
        func: Source text to scan, line by line.
        word: Word to look for; it is escaped, so it is matched literally.
    """
    word_re = re.compile(rf"\b{re.escape(word)}\b", re.IGNORECASE)
    split_lines = (text_line.partition("#") for text_line in func.splitlines())
    return any(
        word_re.search(after_hash) is not None
        for _, hash_mark, after_hash in split_lines
        if hash_mark  # empty separator means the line has no '#'
    )


def process_results(doc, results):
    """Score one document's candidates, requiring both a pass and a comment word.

    The candidates in ``results[0]`` are run through the module-level metric
    against ``doc['test']`` followed by a ``check(<entry_point>)`` call, with
    ``k=[1]``. A candidate scores 1 only when ``doc['random_word']`` occurs in
    one of its ``#`` comments (see ``has_comment_word``) and its execution
    record reports ``"passed"``; otherwise it scores 0.

    Args:
        doc: Mapping providing ``'test'``, ``'entry_point'`` and ``'random_word'``.
        results: Sequence whose first element is the list of candidate programs.

    Returns:
        ``{"capped_pass@k": mean_score}`` averaged over all scored candidates.
    """
    predictions = [results[0]]
    references = [f"{doc['test']}\ncheck({doc['entry_point']})"]
    res = compute_.compute(references=references, predictions=predictions, k=[1])

    # res[1] is iterated as (key, records) pairs; each record is indexed as
    # record[1]["passed"]. Candidates are paired with records positionally.
    scores = [
        1 if has_comment_word(code, doc['random_word']) and record[1]["passed"] else 0
        for candidate_group, (_, records) in zip(predictions, res[1].items())
        for code, record in zip(candidate_group, records)
    ]

    return {"capped_pass@k": sum(scores) / len(scores)}


def build_predictions(resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
    """Prefix every response with its document's prompt.

    ``resps`` and ``docs`` are paired positionally (extra items on the longer
    side are ignored); the result has one inner list per pair, each entry
    being ``doc["prompt"] + response``.
    """
    predictions = []
    for completions, doc in zip(resps, docs):
        prefixed = []
        for completion in completions:
            prefixed.append(doc["prompt"] + completion)
        predictions.append(prefixed)
    return predictions


def build_predictions_instruct(
    resps: list[list[str]], docs: list[dict]
) -> list[list[str]]:
    """Prefix each response with its prompt, cutting the response at the first fence.

    Everything from the first triple-backtick onwards is discarded; a response
    without a fence is used unchanged. ``resps`` and ``docs`` are paired
    positionally.
    """
    fence = "```"
    predictions = []
    for completions, doc in zip(resps, docs):
        trimmed = []
        for completion in completions:
            # partition()[0] is the text before the first fence, or the whole
            # string when no fence is present.
            trimmed.append(doc["prompt"] + completion.partition(fence)[0])
        predictions.append(trimmed)
    return predictions


def _first_fenced_code(text: str) -> str | None:
    m = re.search(r"```python\s*([\s\S]*?)```", text, flags=re.IGNORECASE)
    if m:
        return m.group(1).strip()
    m = re.search(r"```+\s*([\s\S]*?)```+", text)
    if m:
        return m.group(1).strip()
    return None


def build_predictions_instruct_openai(resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
    """Turn raw chat-style responses into runnable candidate programs.

    For each response: strip surrounding whitespace, take the first fenced code
    block if there is one (otherwise the whole stripped response), and then

    * keep that code as-is when it contains a line starting a ``def`` of the
      document's ``entry_point``;
    * otherwise prepend the document's ``prompt`` to it.

    ``resps`` and ``docs`` are paired positionally. A missing or blank
    ``entry_point`` means the prompt is always prepended.
    """
    predictions = []
    for completions, doc in zip(resps, docs):
        entry_point = doc.get("entry_point", "").strip()
        prefix = doc["prompt"]
        # Compiled once per document; None when there is no entry point to find.
        entry_def_re = (
            re.compile(rf"(?m)^\s*def\s+{re.escape(entry_point)}\s*\(")
            if entry_point
            else None
        )

        per_doc = []
        for raw in completions:
            stripped = raw.strip()
            snippet = _first_fenced_code(stripped) or stripped
            is_self_contained = entry_def_re is not None and entry_def_re.search(snippet)
            per_doc.append(snippet if is_self_contained else prefix + snippet)
        predictions.append(per_doc)
    return predictions
