import evaluate as hf_evaluate
import re

# Load the "code_eval" metric once at import time and immediately run it on a
# tiny problem, so that any failure to load or execute the metric surfaces when
# this module is imported instead of partway through an evaluation run.
# The candidate below is deliberately wrong (it multiplies instead of adding);
# `results` is not inspected here -- only the fact that compute() completes
# matters. Any exception is left to propagate unchanged.
compute_ = hf_evaluate.load("code_eval")
test_cases = ["assert add(2, 3)==5"]
candidates = [["def add(a,b): return a*b"]]
results = compute_.compute(references=test_cases, predictions=candidates, k=[1])


def pass_at_k(references: list[str], predictions: list[list[str]], k: list[int] = None):
    """Score ``predictions`` against ``references`` with the module-level metric.

    Args:
        references: One test program per problem, forwarded as ``references``.
        predictions: For each problem, the list of candidate programs,
            forwarded as ``predictions``.
        k: The ``k`` value(s) to evaluate. Required despite the ``None``
            default; a single ``int`` is accepted and wrapped in a list.

    Returns:
        The first element of what ``compute_.compute`` returns (presumably the
        mapping of pass@k scores -- confirm against the metric's documentation).

    Raises:
        ValueError: If ``k`` is not provided.
    """
    # Explicit raise rather than `assert`: assertions are stripped under
    # `python -O`, which would let k=None reach the metric unchecked.
    if k is None:
        raise ValueError("pass_at_k requires `k` (an int or a list of ints)")
    if isinstance(k, int):
        k = [k]
    # `compute_` is only read here, so no `global` declaration is needed.
    res = compute_.compute(
        references=references,
        predictions=predictions,
        k=k,
    )
    return res[0]


def build_predictions(resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
    """Prefix every response with the ``"prompt"`` of its paired document.

    ``resps`` and ``docs`` are walked in lockstep with ``zip``, so any surplus
    entries in the longer of the two are ignored.

    Returns:
        One list per (responses, document) pair, holding ``prompt + response``
        for each response in that pair.
    """
    predictions = []
    for completions, doc in zip(resps, docs):
        predictions.append([doc["prompt"] + completion for completion in completions])
    return predictions


def build_predictions_instruct(
    resps: list[list[str]], docs: list[dict]
) -> list[list[str]]:
    """Prefix each response with its document's prompt, cut at the first fence.

    Everything from the first triple-backtick onward is dropped from a
    response; a response without a triple-backtick is used in full.
    ``resps`` and ``docs`` are paired with ``zip``.

    Returns:
        One list per (responses, document) pair, holding
        ``prompt + truncated_response`` for each response.
    """
    predictions = []
    for completions, doc in zip(resps, docs):
        trimmed = []
        for completion in completions:
            # partition() yields the whole string as the head when the
            # separator is absent, and the text before the first fence otherwise.
            head, _, _ = completion.partition("```")
            trimmed.append(doc["prompt"] + head)
        predictions.append(trimmed)
    return predictions

def build_predictions_cot(resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
    """Build predictions from responses that may carry reasoning and fences.

    For each response:
      1. every closed ``<think>...</think>`` span is removed and the remainder
         is stripped of surrounding whitespace;
      2. if that remainder contains fenced blocks (a triple-backtick, optionally
         followed by ``python``, then a newline), the content of the last such
         block is taken and stripped; otherwise the whole remainder is used;
      3. the paired document's ``"prompt"`` is prepended.

    ``resps`` and ``docs`` are paired with ``zip``.

    Returns:
        One list per (responses, document) pair, holding the assembled code
        string for each response.
    """
    think_span = re.compile(r"<think>.*?</think>", flags=re.DOTALL)
    fenced_block = re.compile(r"```(?:python)?\n(.*?)```", flags=re.DOTALL)

    def extract_code(raw: str) -> str:
        visible = think_span.sub("", raw).strip()
        blocks = fenced_block.findall(visible)
        # The last fenced block wins; with no fence, keep the visible text as is.
        return blocks[-1].strip() if blocks else visible

    return [
        [doc["prompt"] + extract_code(raw) for raw in samples]
        for samples, doc in zip(resps, docs)
    ]