import re
import json

from . import code_judge
from . import sanitize

from custom_verl.reward_utils import RewardType

# Module-level judge instance created once at import time and used by
# compute_score to execute candidate code against test cases.
# NOTE(review): the unit of ``timeout`` is not visible here (presumably
# seconds, per test or per run) — confirm against code_judge.OnlineJudge.
oj = code_judge.OnlineJudge(timeout=3)


def extract_code(solution_str):
    """Extract code from a model completion via ``sanitize.sanitize``.

    The completion is assumed to contain its code in a fenced ```python```
    block. ``"main"`` is passed through as the sanitizer's target name
    (presumably the entry-point function to keep — see ``sanitize``).
    """
    entry_name = "main"
    sanitized = sanitize.sanitize(solution_str, entry_name)
    return sanitized


def extract_code_thinkans(solution_str):
    """Return the body of ``<answer>...</answer>`` that follows a think block.

    The completion must contain ``<think>...</think>`` followed (after
    optional whitespace) by ``<answer>...</answer>``. The captured answer
    text of the first match is returned verbatim; when no match exists the
    empty string is returned.
    """
    found = re.search(
        r"<think>.*?</think>\s*<answer>(.*?)</answer>",
        solution_str,
        re.DOTALL | re.MULTILINE,
    )
    return found.group(1) if found is not None else ""


def compute_score(
    solution_str,
    test_cases,
    continuous=False,
    format_score=0.1,
    extract_fn=extract_code,
    lang="python",
):
    """Score a model completion by running its extracted code against tests.

    Args:
        solution_str: Raw model output from which code is extracted.
        test_cases: Test cases forwarded to ``oj.run``. A ``str`` value is
            decoded with ``json.loads`` first.
        continuous: If True, the score is the fraction of passing entries in
            the judge's per-test results. Otherwise the score is 1.0 only
            when every entry passes.
        format_score: Minimum score in binary (non-continuous) mode once code
            was successfully extracted. It is not applied in continuous mode.
        extract_fn: Callable mapping ``solution_str`` to a code string. An
            empty (or ``None``) result is treated as a format error.
        lang: When ``"python"``, a ``main()`` call is appended so the
            program's ``main`` function runs when the judge executes it.

    Returns:
        ``(0.0, RewardType.FormatError)`` when no code could be extracted.
        Otherwise ``(score, meta_data_list)``, where ``meta_data_list``
        holds one ``{"test_res": result}`` dict per judge result. If the
        judge returns no results, the pass rate is 0.0 in both modes.
    """
    code_str = extract_fn(solution_str)
    if not code_str:
        # Nothing extractable (empty string or None): the format is wrong.
        return (0.0, RewardType.FormatError)

    if lang == "python":
        # Invoke the entry point so the program actually does something
        # when executed by the judge.
        code_str = code_str + "\nmain()"

    if isinstance(test_cases, str):
        test_cases = json.loads(test_cases)

    # The judge's second return value is not used by this scorer.
    test_score, _reward_types = oj.run(code_str, test_cases)
    meta_data_list = [{"test_res": x} for x in test_score]

    num_results = len(test_score)
    num_passed = sum(test_score)

    if continuous:
        # An empty result list would otherwise raise ZeroDivisionError.
        pass_rate = num_passed / num_results if num_results else 0.0
        return pass_rate, meta_data_list

    # ``sum([]) == len([])`` is vacuously true; an empty result list must
    # not count as "all tests passed".
    eval_score = float(num_results > 0 and num_passed == num_results)
    return max(eval_score, format_score), meta_data_list
