"""Reward functions for GxPO training."""

import asyncio
import json
import math
import re
from functools import partial, update_wrapper
from typing import Callable, Dict
import torch.distributed as dist

from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from .utils import is_e2b_available
from .utils.ioi import SubtaskResult, add_includes, get_piston_client_from_env, score_subtask


if is_e2b_available():
    from dotenv import load_dotenv
    from e2b_code_interpreter import AsyncSandbox

    load_dotenv()
else:
    AsyncSandbox = None

def pre_process(completions):
    """Extract the text content of every completion in a batch.

    Three input layouts are supported; the layout is detected from the first
    element only, so the batch is assumed to be homogeneous:

      * conversational: each completion is a list of message dicts and the
        ``"content"`` of its FIRST message is used;
      * single message: each completion is a dict with a ``"content"`` key;
      * plain text: anything else is returned unchanged.

    Args:
        completions: Sequence of completions in one of the layouts above.

    Returns:
        A new list with one content entry per completion. An empty batch
        yields an empty list (instead of an IndexError from peeking at
        element 0).
    """
    if len(completions) == 0:
        return []
    first = completions[0]
    if isinstance(first, list):
        return [completion[0]["content"] for completion in completions]
    if isinstance(first, dict):
        return [completion["content"] for completion in completions]
    return list(completions)

def accuracy_reward(completions, solution, **kwargs):
    """Reward completions whose final answer is equivalent to the reference solution.

    The reference is parsed with math_verify's default LaTeX extraction. Each
    completion is parsed with a stricter LaTeX configuration that tries boxed
    answers first and does not repair malformed operators.

    Args:
        completions: Model completions in any layout accepted by ``pre_process``.
        solution: Reference solutions aligned with ``completions``.
        **kwargs: Ignored; accepted so all reward functions share a signature.

    Returns:
        One float per (completion, solution) pair:
          * 1.0 if ``verify`` judges the answer equivalent to the reference;
          * 0.0 if it does not, or if parsing/verifying the answer raises;
          * 1.0 if the reference cannot be parsed (or its parsing raises), so
            that example is effectively skipped rather than penalised.
    """
    contents = pre_process(completions)
    rewards = []
    for content, sol in zip(contents, solution):
        try:
            gold_parsed = parse(
                sol,
                extraction_mode="first_match",
                extraction_config=[LatexExtractionConfig()],
            )
        except Exception as e:
            # One broken reference (e.g. a parser timeout) must not abort the whole
            # batch; treat it like an unparseable reference below.
            print(f"accuracy_reward: gold parse error: {e}")
            gold_parsed = []

        if len(gold_parsed) == 0:
            # If the gold solution is not parseable, we reward 1 to skip this example
            print("accuracy_reward: Failed to parse gold solution: ", sol)
            rewards.append(1.0)
            continue

        try:
            # We require the answer to be provided in correct latex (no malformed operators)
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed="all",
                            units=True,
                        ),
                        # Ensures that boxed is tried first
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )
        except Exception as e:
            # A completion that breaks the parser simply earns no accuracy reward.
            print(f"accuracy_reward: answer parse error: {e}")
            rewards.append(0.0)
            continue

        # Reward 1 if the content is the same as the ground truth, 0 otherwise
        try:
            reward = float(verify(answer_parsed, gold_parsed))
        except Exception as e:
            print(f"verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}")
            reward = 0.0
        rewards.append(reward)

    return rewards


def accuracy_reward_lv35(completions, answer, **kwargs):
    """Reward completions whose final answer matches a bare reference answer.

    The reference ``answer`` is a bare expression (e.g. ``42``). It is wrapped
    as ``$\\boxed{<answer>}$`` before parsing so the boxed extractor picks it up.
    Each completion is parsed with a strict, boxed-first LaTeX configuration.

    Args:
        completions: Model completions in any layout accepted by ``pre_process``.
        answer: Bare reference answers aligned with ``completions``.
        **kwargs: Ignored; accepted so all reward functions share a signature.

    Returns:
        One float per (completion, answer) pair:
          * 1.0 if ``verify`` judges them equivalent, else 0.0;
          * 1.0 if the reference cannot be parsed or its parsing times out (the
            example is skipped rather than penalised);
          * 0.0 if parsing the completion times out or ``verify`` raises.
    """
    contents = pre_process(completions)
    rewards = []
    for content, sol in zip(contents, answer):
        # Produces "$\boxed{<sol>}$". The previous "$\\\\boxed{}$".format(sol) let
        # str.format consume the braces and doubled the backslash, yielding
        # "$\\boxed<sol>$" (no braces around the answer).
        box_sol = f"$\\boxed{{{sol}}}$"
        try:
            gold_parsed = parse(
                box_sol,
                extraction_mode="first_match",
            )
        except TimeoutError:
            rank = dist.get_rank() if dist.is_initialized() else 0
            print(f"[Rank {rank}] gold parse timeout | content='{content}' | sol='{sol}' | box_sol='{box_sol}'")
            rewards.append(1.0)
            continue
        if len(gold_parsed) != 0:
            # We require the answer to be provided in correct latex (no malformed operators)
            try:
                answer_parsed = parse(
                    content,
                    extraction_config=[
                        LatexExtractionConfig(
                            normalization_config=NormalizationConfig(
                                nits=False,
                                malformed_operators=False,
                                basic_latex=True,
                                equations=True,
                                boxed="all",
                                units=True,
                            ),
                            # Ensures that boxed is tried first
                            boxed_match_priority=0,
                            try_extract_without_anchor=False,
                        )
                    ],
                    extraction_mode="first_match",
                )
            except TimeoutError:
                rank = dist.get_rank() if dist.is_initialized() else 0
                print(f"[Rank {rank}] answer parse timeout | content='{content}' | sol='{sol}'")
                rewards.append(0.0)
                continue
            # Reward 1 if the content is the same as the ground truth, 0 otherwise
            try:
                reward = float(verify(answer_parsed, gold_parsed))
            except Exception as e:
                print(f"verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}")
                reward = 0.0
        else:
            # If the gold solution is not parseable, we reward 1 to skip this example
            reward = 1.0
            print("accuracy_reward_lv35: Failed to parse gold solution: ", box_sol)
        rewards.append(reward)

    return rewards

def extra_box_len_reward_v1(completions, threhold=100.0, **kwargs):
    """Penalise text that trails the extracted final answer (first-occurrence variant).

    Each completion is parsed with the strict, boxed-first LaTeX configuration.
    The last element of the parse result is looked up in the completion with
    ``str.find``, i.e. at its FIRST occurrence. Everything after that
    occurrence counts as "extra" text.

    Args:
        completions: Model completions in any layout accepted by ``pre_process``.
        threhold: Number of trailing characters tolerated without penalty. The
            misspelt name is kept for backward compatibility.
        **kwargs: Ignored.

    Returns:
        One float per completion, in ``[-1.0, 0.0]`` for a positive threshold:
          * 0.0 if nothing was extracted, or fewer than ``threhold`` characters
            trail the answer. An answer that does not occur verbatim counts as
            having no trailing text.
          * otherwise ``max(1 - extra_len / threhold, -1.0)``, which reaches -1.0
            once the trailing text is twice the threshold.
    """
    contents = pre_process(completions)
    penaltys = []
    for content in contents:
        # We require the answer to be provided in correct latex (no malformed operators)
        answer_parsed = parse(
            content,
            extraction_config=[
                LatexExtractionConfig(
                    normalization_config=NormalizationConfig(
                        nits=False,
                        malformed_operators=False,
                        basic_latex=True,
                        equations=True,
                        boxed="all",
                        units=True,
                    ),
                    # Ensures that boxed is tried first
                    boxed_match_priority=0,
                    try_extract_without_anchor=False,
                )
            ],
            extraction_mode="first_match",
        )

        if len(answer_parsed) == 0:
            penaltys.append(0.0)
            continue

        # NOTE(review): assumes the last parse element is the matched answer
        # string; a non-str element would raise TypeError in find().
        answer_text = answer_parsed[-1]
        index = content.find(answer_text)
        # Not found verbatim (e.g. the extractor normalised it): treat as no trailing text.
        extra_len = 0 if index == -1 else len(content) - index - len(answer_text)
        if extra_len < threhold:
            penaltys.append(0.0)
        else:
            penaltys.append(max(-(extra_len / threhold - 1.0), -1.0))
    return penaltys

def extra_box_len_reward_v2(completions, threhold=100.0, **kwargs):
    """Penalise text that trails the extracted final answer (last-occurrence variant).

    Same as the v1 reward, except that the extracted answer is located with
    ``str.rfind``, i.e. at its LAST occurrence in the completion. Only the text
    after that occurrence counts as "extra".

    Args:
        completions: Model completions in any layout accepted by ``pre_process``.
        threhold: Number of trailing characters tolerated without penalty. The
            misspelt name is kept for backward compatibility.
        **kwargs: Ignored.

    Returns:
        One float per completion, in ``[-1.0, 0.0]`` for a positive threshold:
          * 0.0 if nothing was extracted, or fewer than ``threhold`` characters
            trail the answer. An answer that does not occur verbatim counts as
            having no trailing text.
          * otherwise ``max(1 - extra_len / threhold, -1.0)``.
    """
    contents = pre_process(completions)
    penaltys = []
    for content in contents:
        # We require the answer to be provided in correct latex (no malformed operators)
        answer_parsed = parse(
            content,
            extraction_config=[
                LatexExtractionConfig(
                    normalization_config=NormalizationConfig(
                        nits=False,
                        malformed_operators=False,
                        basic_latex=True,
                        equations=True,
                        boxed="all",
                        units=True,
                    ),
                    # Ensures that boxed is tried first
                    boxed_match_priority=0,
                    try_extract_without_anchor=False,
                )
            ],
            extraction_mode="first_match",
        )

        if len(answer_parsed) == 0:
            penaltys.append(0.0)
            continue

        # NOTE(review): assumes the last parse element is the matched answer
        # string; a non-str element would raise TypeError in rfind().
        answer_text = answer_parsed[-1]
        index = content.rfind(answer_text)
        # Not found verbatim (e.g. the extractor normalised it): treat as no trailing text.
        extra_len = 0 if index == -1 else len(content) - index - len(answer_text)
        if extra_len < threhold:
            penaltys.append(0.0)
        else:
            penaltys.append(max(-(extra_len / threhold - 1.0), -1.0))
    return penaltys

def get_language_penalty_reward(completions, **kwargs):
    r"""Penalise characters outside an ASCII letters/digits/punctuation whitelist.

    LaTeX is stripped before counting: inline ``$...$`` spans (single line),
    ``\name`` control words, and a backslash followed by one non-word,
    non-space character. Every remaining non-whitespace character that is not
    an ASCII letter, a digit, or one of the whitelisted symbols is counted.

    Returns:
        One float per completion. The value is ``-(count / len(content))``,
        where the length is that of the ORIGINAL, uncleaned content. It is 0.0
        when nothing is counted or the completion is blank.
    """
    whitelist = set(
        "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
        + " +-*/=()[]{}<>^_.,:;!?|&%~#@\"'"
    )
    latex_re = r'\$.*?\$|\\[a-zA-Z]+|\\[^\w\s]'

    def score(text):
        if not text.strip():
            return 0.0
        stripped = re.sub(latex_re, '', text)
        offending = sum(1 for ch in stripped if ch not in whitelist and not ch.isspace())
        if offending == 0:
            return 0.0
        return -1.0 * (offending / len(text))

    return [score(text) for text in pre_process(completions)]

def format_reward(completions, **kwargs):
    """Score 1.0 when a completion is a think block followed by an answer block.

    The completion must start with "<think>" and a newline. It must then contain
    a newline, "</think>", a newline, "<answer>" and a newline, and later a
    newline plus "</answer>" at the end of a line. ``.`` spans newlines, and
    ``$`` runs in MULTILINE mode, so further lines may follow the closing tag.

    Returns:
        A list with 1.0 for each matching completion and 0.0 otherwise.
    """
    think_answer_re = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>$"
    flags = re.DOTALL | re.MULTILINE
    scores = []
    for text in pre_process(completions):
        found = re.match(think_answer_re, text, flags)
        scores.append(1.0 if found else 0.0)
    return scores


def format_reward_v2(completions, **kwargs):
    """Score 1.0 when the closing think tag occurs exactly once.

    Only the literal newline + "</think>" + newline is counted. The opening tag
    is ignored because it is expected to come from the system prompt.

    Returns:
        A list with 1.0 for each completion containing exactly one such closing
        tag, and 0.0 otherwise.
    """
    closing_tag = "\n</think>\n"
    texts = pre_process(completions)
    return [1.0 if text.count(closing_tag) == 1 else 0.0 for text in texts]

def format_reward_v3(completions, **kwargs):
    """Score 1.0 when a completion opens with a think block.

    The completion must start with "<think>" and a newline, and later contain a
    newline + "</think>" + newline. ``.`` spans newlines, so any text may follow
    the closing tag.

    Returns:
        A list with 1.0 for each matching completion and 0.0 otherwise.
    """
    think_re = r"^<think>\n.*?\n</think>\n.*?$"
    flags = re.DOTALL | re.MULTILINE
    scores = []
    for text in pre_process(completions):
        found = re.match(think_re, text, flags)
        scores.append(1.0 if found else 0.0)
    return scores

def tag_count_reward(completions, **kwargs) -> list[float]:
    """Reward function that checks if we produce the desired number of think and answer tags associated with `format_reward()`.
    Adapted from: https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb#file-grpo_demo-py-L90

    Each of the four newline-delimited tags below earns 0.25 when it occurs
    exactly once, so each completion scores between 0.0 and 1.0.
    """
    weighted_tags = (
        ("<think>\n", 0.25),
        ("\n</think>\n", 0.25),
        ("\n<answer>\n", 0.25),
        ("\n</answer>", 0.25),
    )

    def score(text: str) -> float:
        # Start from 0.0 so the result is a float even when no tag matches.
        return sum((weight for tag, weight in weighted_tags if text.count(tag) == 1), 0.0)

    return [score(text) for text in pre_process(completions)]


def tag_count_reward_v2(completions, **kwargs) -> list[float]:
    """Reward function that checks if we produce the desired number of think and answer tags associated with `format_reward()`.

    remove <answer>, </answer> from tag_count_reward

    Each of the two think tags contributes +0.5 when it occurs exactly once and
    -0.5 otherwise, so every completion scores -1.0, 0.0 or 1.0.
    """

    def score(text: str) -> float:
        total = 0.0
        for tag in ("<think>\n", "\n</think>\n"):
            total += 0.5 if text.count(tag) == 1 else -0.5
        return total

    return [score(text) for text in pre_process(completions)]

def fused_format_reward(completions, **kwargs):
    """Score +1.0 for one well-formed, non-empty think block, and -1.0 otherwise.

    A completion earns 1.0 only when both conditions hold:
      * it starts with "<think>" + newline, has at least one non-whitespace
        character in the block, and has a newline + "</think>" + newline later;
      * "<think>" + newline and newline + "</think>" + newline each occur
        exactly once.
    Every other completion earns -1.0.
    """
    think_re = r"^<think>\n\s*\S.*?\n</think>\n.*?$"
    flags = re.DOTALL | re.MULTILINE

    def score(text):
        well_formed = re.match(think_re, text, flags) is not None
        single_tags = text.count("<think>\n") == 1 and text.count("\n</think>\n") == 1
        return 1.0 if well_formed and single_tags else -1.0

    return [score(text) for text in pre_process(completions)]

def reasoning_steps_reward(completions, **kwargs):
    r"""Reward function that checks for clear step-by-step reasoning.
    Regex pattern:
        Step \d+: - matches "Step 1:", "Step 2:", etc.
        ^\d+\. - matches numbered lists like "1.", "2.", etc. at start of line
        \n- - matches bullet points with hyphens
        \n\* - matches bullet points with asterisks
        First,|Second,|Next,|Finally, - matches transition words

    Returns:
        One float per completion: ``min(1.0, matches / 3)``.
    """
    # re.MULTILINE makes ``^`` match at the start of EVERY line, as documented
    # above. Without it, only a numbered item on the very first line counted.
    pattern = re.compile(
        r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)", re.MULTILINE
    )
    completion_contents = pre_process(completions)

    matches = [len(pattern.findall(content)) for content in completion_contents]

    # Magic number 3 to encourage 3 steps and more, otherwise partial reward
    return [min(1.0, count / 3) for count in matches]


def len_reward(completions: list[Dict[str, str]], answer: list[str], **kwargs) -> list[float]:
    """Compute length-based rewards to discourage overthinking and promote token efficiency.

    Taken from the Kimi 1.5 tech report: https://arxiv.org/abs/2501.12599

    Args:
        completions: List of model completions (any layout accepted by ``pre_process``).
        answer: List of bare ground-truth answers; each is wrapped as
            ``$\\boxed{<answer>}$`` before parsing.

    Returns:
        List of rewards where:
        - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)
        - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))
        Lengths are character counts. Unparseable references are treated as
        correct, so they are not penalised. All rewards are 0.0 when every
        completion has the same length, and an empty batch yields [].
    """
    contents = pre_process(completions)
    if len(contents) == 0:
        # min()/max() below would raise on an empty batch.
        return []

    # First check correctness of answers
    correctness = []
    for content, sol in zip(contents, answer):
        # Produces "$\boxed{<sol>}$"; "$\\\\boxed{}$".format(sol) let str.format
        # swallow the braces and doubled the backslash.
        box_sol = f"$\\boxed{{{sol}}}$"
        gold_parsed = parse(
            box_sol,
            extraction_mode="first_match",
        )
        if len(gold_parsed) == 0:
            # Skip unparseable examples
            correctness.append(True)  # Treat as correct to avoid penalizing
            print("len_reward: Failed to parse gold solution: ", sol)
            continue

        answer_parsed = parse(
            content,
            extraction_config=[
                LatexExtractionConfig(
                    normalization_config=NormalizationConfig(
                        nits=False,
                        malformed_operators=False,
                        basic_latex=True,
                        equations=True,
                        # "all" matches the other rewards in this file; a bool is not
                        # one of the documented values for this option.
                        boxed="all",
                        units=True,
                    ),
                    boxed_match_priority=0,
                    try_extract_without_anchor=False,
                )
            ],
            extraction_mode="first_match",
        )
        try:
            is_correct = bool(verify(answer_parsed, gold_parsed))
        except Exception as e:
            print(f"len_reward: verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}")
            is_correct = False
        correctness.append(is_correct)

    # Calculate lengths
    lengths = [len(content) for content in contents]
    min_len = min(lengths)
    max_len = max(lengths)

    # If all responses have the same length, return zero rewards
    if max_len == min_len:
        return [0.0] * len(contents)

    rewards = []
    for length, is_correct in zip(lengths, correctness):
        lambda_val = 0.5 - (length - min_len) / (max_len - min_len)
        reward = lambda_val if is_correct else min(0, lambda_val)
        rewards.append(float(reward))

    return rewards


def get_cosine_scaled_reward(
    min_value_wrong: float = -1.0,
    max_value_wrong: float = -0.5,
    min_value_correct: float = 0.5,
    max_value_correct: float = 1.0,
    max_len: int = 1000,
):
    """Build a reward that scales with completion length along a cosine schedule.

    Args:
        min_value_wrong: Minimum reward for wrong answers.
        max_value_wrong: Maximum reward for wrong answers.
        min_value_correct: Minimum reward for correct answers.
        max_value_correct: Maximum reward for correct answers.
        max_len: Length (in characters) at which the schedule bottoms out;
            longer completions are treated as exactly ``max_len``.

    Returns:
        A reward function ``(completions, solution, **kwargs) -> list[float]``.
    """

    def cosine_scaled_reward(completions, solution, **kwargs):
        """Reward function that scales based on completion length using a cosine schedule.

        Shorter correct solutions are rewarded more than longer ones.
        Longer incorrect solutions are penalized less than shorter ones.
        Unparseable reference solutions earn 1.0 (the example is skipped).

        Args:
            completions: List of model completions
            solution: List of ground truth solutions
        """
        contents = pre_process(completions)
        rewards = []

        for content, sol in zip(contents, solution):
            gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
            if len(gold_parsed) == 0:
                rewards.append(1.0)  # Skip unparseable examples
                print("cosine_scaled_reward: Failed to parse gold solution: ", sol)
                continue

            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            # "all" matches the other rewards in this file.
                            boxed="all",
                            units=True,
                        ),
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )

            try:
                is_correct = verify(answer_parsed, gold_parsed)
            except Exception as e:
                print(f"cosine_scaled_reward: verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}")
                is_correct = False
            gen_len = len(content)

            # Apply cosine scaling based on length. Clamp progress at 1: cos(pi * p)
            # rises again for p > 1, which would let over-long completions score
            # like short ones.
            progress = min(gen_len / max_len, 1.0)
            cosine = math.cos(progress * math.pi)

            if is_correct:
                min_value = min_value_correct
                max_value = max_value_correct
            else:
                # Swap min/max for incorrect answers
                min_value = max_value_wrong
                max_value = min_value_wrong

            reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)
            rewards.append(float(reward))

        return rewards

    return cosine_scaled_reward


def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
    """
    Computes N-gram repetition penalty as described in Appendix C.2 of https://arxiv.org/abs/2502.03373.
    Reference implementation from: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py

    Args:
    ngram_size: size of the (lower-cased, whitespace-split) word n-grams
    max_penalty: Maximum (negative) penalty, applied when every n-gram is a repeat

    Returns:
        A reward function ``(completions, **kwargs) -> list[float]``.

    Raises:
        ValueError: If ``max_penalty`` is positive.
    """
    if max_penalty > 0:
        raise ValueError(f"max_penalty {max_penalty} should not be positive")

    def zipngram(text: str, ngram_size: int):
        words = text.lower().split()
        return zip(*[words[i:] for i in range(ngram_size)])

    def repetition_penalty_reward(completions, **kwargs) -> list[float]:
        """
        reward function the penalizes repetitions
        ref implementation: https://github.com/eddycmu/demystify-long-cot/blob/release/openrlhf/openrlhf/reward/repetition.py

        Args:
            completions: List of model completions

        Returns:
            One float per completion: ``(1 - unique_ngrams / total_ngrams) * max_penalty``,
            or 0.0 for empty texts and texts with fewer words than ``ngram_size``.
        """
        contents = pre_process(completions)
        rewards = []
        for completion in contents:
            if completion == "" or len(completion.split()) < ngram_size:
                rewards.append(0.0)
                continue

            ngrams = list(zipngram(completion, ngram_size))
            if not ngrams:
                # ngram_size <= 0 yields no n-grams; avoid dividing by zero.
                rewards.append(0.0)
                continue

            scaling = 1 - len(set(ngrams)) / len(ngrams)
            rewards.append(scaling * max_penalty)
        return rewards

    return repetition_penalty_reward


def _init_event_loop():
    """Return a usable (not closed) event loop for the current thread.

    The current loop is reused when it is still open. When there is none, or it
    has been closed (e.g. by earlier code), a new loop is created and installed
    as the current one. ``run_until_complete`` on a closed loop would otherwise
    fail.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = None
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop


def ioi_code_reward(completions, test_batch_size: int = 1, **kwargs) -> list[float]:
    """Reward function that evaluates IOI problems using Piston+our IOI package.

    Assumes the dataset has the same format as hf.co/datasets/open-r1/ioi

    Args:
        completions: Conversational completions. The content of each one's last
            message is searched for its last fenced ``cpp`` block.
        test_batch_size: evaluate these many test cases in parallel, then check if
            any of them failed (0 score): if so stop evaluating; otherwise continue
            with the next batch of test cases. Forwarded to ``score_subtask``.
        **kwargs: Dataset columns (one list per column); must include ``id``.

    Returns:
        The ``score`` of every subtask result. A worker call that raises is
        replaced by a default ``SubtaskResult()``.
    """
    # for info on setting up piston workers, see slurm/piston/README.md
    client = get_piston_client_from_env()

    # note: grading is automatically skipped if no code is extracted
    sources = []
    for completion, problem_id in zip(completions, kwargs["id"]):
        cpp_code = extract_code(completion[-1]["content"], "cpp")
        sources.append(add_includes(cpp_code, problem_id))

    async def guarded(coro):
        try:
            return await coro
        except Exception as exc:
            print(f"Error from Piston worker: {exc}")
            return SubtaskResult()  # score 0.0

    # kwargs arrive column-wise; rebuild one dict per dataset row.
    column_names = list(kwargs.keys())
    rows = [dict(zip(column_names, row)) for row in zip(*kwargs.values())]

    loop = _init_event_loop()
    tasks = []
    for row, source in zip(rows, sources):
        coro = score_subtask(client, row, source, test_batch_size=test_batch_size)
        tasks.append(loop.create_task(guarded(coro)))
    outcomes = loop.run_until_complete(asyncio.gather(*tasks))

    return [outcome.score for outcome in outcomes]


def extract_code(completion: str, language: str = "python") -> str:
    """Return the body of the last fenced code block tagged with ``language``.

    A block looks like three backticks, the language tag, a newline, then the
    code up to the next three backticks. The language tag is escaped before it
    is embedded in the regex, so tags such as "c++" are matched literally
    instead of raising ``re.error``.

    Args:
        completion: Text to search.
        language: Fence tag to look for (default "python").

    Returns:
        The captured code of the last matching block (including its trailing
        newline), or "" when there is none.
    """
    pattern = re.compile(rf"```{re.escape(language)}\n(.*?)```", re.DOTALL)
    matches = pattern.findall(completion)
    return matches[-1] if matches else ""


def binary_code_reward(completions, **kwargs) -> list[float]:
    """Binarise ``code_reward``: 1.0 when the pass rate is above 0.99, else 0.0.

    All arguments are forwarded unchanged to ``code_reward``.
    """
    threshold = 0.99
    pass_rates = code_reward(completions, **kwargs)
    return [float(rate > threshold) for rate in pass_rates]


def code_reward(completions, **kwargs) -> list[float]:
    """Reward function that evaluates code snippets using the E2B code interpreter.

    Assumes the dataset contains a `verification_info` column with test cases.

    The last fenced ``python`` block of each completion's last message is run
    once per test case inside an E2B sandbox. The reward is the fraction of test
    cases whose stripped stdout matches the expected output, compared line by
    line over the shorter of the two outputs (a known limitation; see the TODO
    in the script). A test case that exits non-zero or exceeds the per-case
    timeout counts as failed.

    Args:
        completions: Conversational completions (lists of message dicts).
        **kwargs: Must contain ``verification_info``, one dict per completion
            with ``"language"`` and ``"test_cases"`` (each case has ``"input"``
            and ``"output"``).

    Returns:
        One pass rate per completion. Every entry is 0.0 if the executor fails,
        and an empty batch yields [].

    Raises:
        ImportError: If E2B is not available.
        ValueError: If the rows do not all share the same language.
    """
    if not is_e2b_available():
        raise ImportError(
            "E2B is not available and required for this reward function. Please install E2B with "
            "`pip install e2b-code-interpreter` and add an API key to a `.env` file."
        )

    if len(completions) == 0:
        # Nothing to run; also avoids indexing verification_info[0] below.
        return []

    # TODO: add support for other languages in E2B: https://e2b.dev/docs/code-interpreting/supported-languages
    # The template is indented as a whole; the E2B (Jupyter-style) interpreter is
    # relied on to accept that, exactly as before. A TimeoutExpired in one test
    # case is caught so it fails only that case, instead of aborting the script
    # (which made the whole sample score 0.0).
    evaluation_script_template = """
    import subprocess
    import json

    def evaluate_code(code, test_cases):
        passed = 0
        total = len(test_cases)
        exec_timeout = 5

        for case in test_cases:
            try:
                process = subprocess.run(
                    ["python3", "-c", code],
                    input=case["input"],
                    text=True,
                    capture_output=True,
                    timeout=exec_timeout
                )
            except subprocess.TimeoutExpired:  # A hung test case counts as failed
                continue

            if process.returncode != 0:  # Error in execution
                continue

            output = process.stdout.strip()

            # TODO: implement a proper validator to compare against ground truth. For now we just check for exact string match on each line of stdout.
            all_correct = True
            for line1, line2 in zip(output.split('\\n'), case['output'].split('\\n')):
                all_correct = all_correct and line1.strip() == line2.strip()

            if all_correct:
                passed += 1

        success_rate = (passed / total)
        return success_rate

    code_snippet = {code}
    test_cases = json.loads({test_cases})

    evaluate_code(code_snippet, test_cases)
    """
    code_snippets = [extract_code(completion[-1]["content"]) for completion in completions]
    verification_info = kwargs["verification_info"]
    # json.dumps twice: the test cases are embedded as a JSON *string* literal
    # that the script decodes with json.loads.
    scripts = [
        evaluation_script_template.format(code=json.dumps(code), test_cases=json.dumps(json.dumps(info["test_cases"])))
        for code, info in zip(code_snippets, verification_info)
    ]

    language = verification_info[0]["language"]

    if not all(v["language"] == language for v in verification_info):
        raise ValueError("All verification_info must have the same language", verification_info)
    try:
        rewards = run_async_from_sync(scripts, language)
    except Exception as e:
        print(f"Error from E2B executor: {e}")
        rewards = [0.0] * len(completions)

    return rewards


def get_code_format_reward(language: str = "python"):
    """Format reward function specifically for code responses.

    Args:
        language: Programming language supported by E2B https://e2b.dev/docs/code-interpreting/supported-languages
            It is regex-escaped, so names such as "c++" are matched literally.

    Returns:
        A reward function ``(completions, **kwargs) -> list[float]``. It scores
        1.0 when a completion starts with a think block followed by an answer
        block containing a fenced block opened with three backticks and
        ``language``, and 0.0 otherwise.
    """
    pattern = rf"^<think>\n.*?\n</think>\n<answer>\n.*?```{re.escape(language)}.*?```.*?\n</answer>$"

    def code_format_reward(completions, **kwargs):
        # pre_process keeps this consistent with the other format rewards. For
        # conversational input it reads completion[0]["content"], as before.
        completion_contents = pre_process(completions)
        matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
        return [1.0 if match else 0.0 for match in matches]

    return code_format_reward


def run_async_from_sync(scripts: list[str], language: str) -> list[float]:
    """Run ``run_async`` to completion from synchronous code.

    Args:
        scripts: Scripts to execute, one per completion.
        language: Language passed through to the sandbox.

    Returns:
        One reward per script, as produced by ``run_async``.

    Raises:
        Whatever ``run_async`` raises. The error is logged first, and the bare
        ``raise`` keeps the original traceback.
    """
    try:
        # asyncio.run creates a fresh event loop, runs the coroutine, then closes the loop.
        return asyncio.run(run_async(scripts, language))
    except Exception as e:
        print(f"Error from E2B executor async: {e}")
        raise


async def run_async(scripts: list[str], language: str) -> list[float]:
    """Run every script concurrently in one E2B sandbox and collect the rewards.

    Args:
        scripts: Scripts to execute.
        language: Language passed to ``run_script``.

    Returns:
        One reward per script, in input order.
    """
    # Create the sandbox by hand, currently there's no context manager for this version
    sbx = await AsyncSandbox.create(timeout=30, request_timeout=3)
    try:
        # asyncio.gather preserves input order in its results.
        results = await asyncio.gather(*(run_script(sbx, script, language) for script in scripts))
    finally:
        # Kill the sandbox even if a task raised; previously an exception skipped
        # this, leaving the sandbox alive (presumably until its 30 s timeout).
        await sbx.kill()

    return list(results)


async def run_script(sbx: "AsyncSandbox", script: str, language: str) -> float:
    """Execute one script in the sandbox and return its output as a float reward.

    The sandbox call is inside the ``try``, so a failure for one script yields
    0.0 for that script only, instead of aborting the whole ``gather`` batch.
    The string annotation avoids evaluating ``AsyncSandbox`` at definition time.

    Returns:
        ``float(execution.text)``, or 0.0 if execution fails or the output is
        missing or not numeric.
    """
    try:
        execution = await sbx.run_code(script, language=language)
        return float(execution.text)
    except (TypeError, ValueError):
        # No result text (None) or non-numeric output.
        return 0.0
    except Exception as e:
        print(f"Error from E2B executor run_script: {e}")
        return 0.0


def get_reward_funcs(script_args) -> list[Callable]:
    """Resolve the reward-function names in ``script_args.reward_funcs``.

    Args:
        script_args: Config object exposing ``reward_funcs`` (an iterable of
            names) plus the ``cosine_*``, ``repetition_*``,
            ``code_eval_test_batch_size`` and ``code_language`` attributes.
            All of these are read when the registry is built, even for rewards
            that are not requested.

    Returns:
        The reward callables, in the requested order.

    Raises:
        KeyError: If a requested name is not registered. The message lists the
            available names.
    """
    REWARD_FUNCS_REGISTRY = {
        "accuracy": accuracy_reward,
        "accuracy_lv35": accuracy_reward_lv35,
        "format": format_reward,
        "format_v2": format_reward_v2,
        "format_v3": format_reward_v3,
        "reasoning_steps": reasoning_steps_reward,
        "cosine": get_cosine_scaled_reward(
            min_value_wrong=script_args.cosine_min_value_wrong,
            max_value_wrong=script_args.cosine_max_value_wrong,
            min_value_correct=script_args.cosine_min_value_correct,
            max_value_correct=script_args.cosine_max_value_correct,
            max_len=script_args.cosine_max_len,
        ),
        "repetition_penalty": get_repetition_penalty_reward(
            ngram_size=script_args.repetition_n_grams,
            max_penalty=script_args.repetition_max_penalty,
        ),
        "length": len_reward,
        "code": code_reward,
        "binary_code": binary_code_reward,
        "ioi_code": update_wrapper(
            partial(ioi_code_reward, test_batch_size=script_args.code_eval_test_batch_size), ioi_code_reward
        ),
        "code_format": get_code_format_reward(language=script_args.code_language),
        "tag_count": tag_count_reward,
        "tag_count_v2": tag_count_reward_v2,
        "extra_box_v1": extra_box_len_reward_v1,
        "extra_box_v2": extra_box_len_reward_v2,
        "fused_format": fused_format_reward,
        "language_penalty": get_language_penalty_reward,
    }

    # Materialise once, so a one-shot iterable is not exhausted by the check below.
    requested = list(script_args.reward_funcs)
    unknown = [name for name in requested if name not in REWARD_FUNCS_REGISTRY]
    if unknown:
        raise KeyError(
            f"Unknown reward function(s) {unknown}; available: {sorted(REWARD_FUNCS_REGISTRY)}"
        )
    return [REWARD_FUNCS_REGISTRY[name] for name in requested]
