import re
import time
import multiprocessing as mp
from multiprocessing import Process, Queue
import math
import torch
from mindspeed_rl.utils.loggers import Loggers
import json
import ast
from .bfcl_ast_checker.eval_runner import *
# Module-level logger used by the reward helpers in this file.
logger = Loggers("Rule verify")
   
def code2json(input_str):
    """Convert a Python call-list string into ``[{'name': ..., 'arguments': ...}, ...]``.

    ``decode_ast`` (presumably provided by the star import from
    ``bfcl_ast_checker.eval_runner``) parses *input_str*. Each element it returns is
    treated as a mapping ``{tool_name: arguments}``, and every key becomes one entry
    in the result, in iteration order. Parse errors from ``decode_ast`` propagate.
    """
    return [
        {'name': tool_name, 'arguments': parsed_call[tool_name]}
        for parsed_call in decode_ast(input_str)
        for tool_name in parsed_call.keys()
    ]

def json2code(data_list):
    """Render JSON-style tool calls as a Python call-list string.

    Each item must have a ``"name"`` key. Its arguments come from ``"parameters"``,
    or from ``"arguments"`` when ``"parameters"`` is missing or empty. String
    arguments are written with ``json.dumps`` (double-quoted, escaped). All other
    values use their f-string form (e.g. ``True``, ``[1, 'a']``).

    Example: ``[{"name": "f", "arguments": {"a": "x", "b": 2}}]`` -> ``'[f(a="x", b=2)]'``.
    """
    def _render_argument(arg_name, arg_value):
        # Strings go through json.dumps so embedded quotes/escapes stay valid.
        if isinstance(arg_value, str):
            return f"{arg_name}={json.dumps(arg_value)}"
        return f'{arg_name}={arg_value}'

    rendered_calls = []
    for call in data_list:
        call_name = call["name"]
        call_args = call.get("parameters", {}) or call.get("arguments", {})
        joined_args = ', '.join(_render_argument(k, v) for k, v in call_args.items())
        rendered_calls.append(f"{call_name}({joined_args})")
    return "[" + ', '.join(rendered_calls) + "]"

def try_parse_tool_calls(content: str):
    """Extract JSON tool-call lists of the form ``[{"name": ...}]`` from *content*.

    Every non-greedy ``[{"name": ... }]`` span is parsed as JSON. Any call whose
    ``"arguments"`` value is a string is decoded as JSON too. A span is accepted
    only if ALL of its calls parse. A single malformed call discards the whole
    span, so a partially parsed list is never reported as the model's output.

    Returns:
        A JSON string with every call from accepted spans, or *content* unchanged
        when no span could be parsed.
    """
    tool_calls = []
    # Non-greedy: each candidate ends at the first '}]' after a '[{"name":' opener.
    for m in re.finditer(r'\[{"name":(.*?)}\]', content, re.DOTALL):
        try:
            parsed = json.loads(m.group(0).strip())
            if isinstance(parsed, list):
                calls = parsed
            elif isinstance(parsed, dict):
                calls = [parsed]
            else:
                continue
            span_calls = []
            for call in calls:
                if isinstance(call["arguments"], str):
                    call["arguments"] = json.loads(call["arguments"])
                span_calls.append(call)
        except Exception:
            # Best-effort parsing: skip malformed candidates (bad JSON, missing
            # "arguments", non-dict items) entirely rather than crashing.
            continue
        tool_calls.extend(span_calls)
    if tool_calls:
        return json.dumps(tool_calls)
    return content

def _extract_tool_calls(input_string, tool_generalized_mode: int):
    """Collect the JSON tool calls wrapped in tool-call tags.

    Both tag styles are supported: ``<tool_call>...</tool_call>`` and
    ``[tool_call]...[/tool_call]``.

    The tag style tried first depends on the mode:
      * Modes 2 and 4 (the square-bracket ``[think]``/``[mode]`` formats
        elsewhere in this module) try the bracket style first.
      * All other modes try the angle style first.

    If the preferred style finds no tags, the other style is used.

    Each tag body is parsed as JSON:
      * a list contributes all of its items;
      * a single object contributes itself;
      * anything else, or unparseable JSON, is skipped.

    Returns:
        ``json.dumps`` of all collected calls, or *input_string* unchanged when it
        contains no ``tool_call`` text or nothing could be collected.
    """
    if 'tool_call' not in input_string:
        return input_string

    angle_pattern = r"<tool_call>(.*?)</tool_call>"
    # Brackets must be escaped; unescaped they form single-character classes.
    square_pattern = r"\[tool_call\](.*?)\[/tool_call\]"
    if tool_generalized_mode in (2, 4):
        patterns = (square_pattern, angle_pattern)
    else:
        patterns = (angle_pattern, square_pattern)

    bodies = []
    for pattern in patterns:
        bodies = re.findall(pattern, input_string, re.DOTALL)
        if bodies:
            break

    result = []
    for body in bodies:
        try:
            parsed = json.loads(body.strip())
        except (ValueError, RecursionError):
            # json.JSONDecodeError subclasses ValueError; deeply nested garbage
            # raises RecursionError. Either way this tag is skipped.
            continue
        if isinstance(parsed, dict):
            # A single call object: append it (extend() would add its keys).
            result.append(parsed)
        elif isinstance(parsed, list):
            result.extend(parsed)

    if result:
        return json.dumps(result)
    return input_string
    
def extract_tool_answer(pred_str, data_name, use_last_number=True, tool_generalized_mode=1):
    r"""Pull the scorable answer out of a model response or a label string.

    Processing order:
      1. A falsy *pred_str* yields ``""``.
      2. Every occurrence of the two-character sequence U+043A U+0438 is removed
         (presumably a step-marker artifact in some datasets - confirm).
      3. If ``\boxed{`` occurs, the contents of the FIRST one are returned, tracking
         nested braces and excluding the matching closing brace. An unclosed box
         returns everything to the end of the string.
      4. Otherwise, text up to the last ``</think>`` (modes 1 and 3) or
         ``[/think]`` (all other modes) is dropped. The rest goes through
         ``_extract_tool_calls`` and is returned stripped.

    Any exception is logged and turned into ``""``. ``data_name`` and
    ``use_last_number`` are accepted for interface compatibility but unused.
    """
    try:
        if not pred_str:
            return ""

        text = pred_str.replace("\u043a\u0438", "")

        boxed = re.search(r'\\boxed\{', text)
        if boxed is not None:
            depth = 1
            inner_chars = []
            for ch in text[boxed.end():]:
                if ch == '{':
                    depth += 1
                elif ch == '}':
                    depth -= 1
                    if depth == 0:
                        break
                inner_chars.append(ch)
            return ''.join(inner_chars)

        think_close = '</think>' if tool_generalized_mode in (1, 3) else '[/think]'
        after_reasoning = text.split(think_close)[-1]
        return _extract_tool_calls(after_reasoning, tool_generalized_mode).strip()

    except Exception as e:
        logger.error(f"Error extracting answer: {repr(e)}")
        return ""

# @mstx_timer_decorator
def compute_verifier_score(batch, megatron_config, rl_config, tokenizer, ignore_token=-100):
    """Decode a rollout batch, run the rule verifiers and return per-sample scores.

    Side effect: ``batch["think_mode"]`` is set to a long tensor (on the responses'
    device) holding 1 for responses that open with the think block of
    ``rl_config.tool_generalized_mode`` and 0 otherwise. The mode defaults to 0 when
    the attribute is absent, and mode 0 always yields 0.

    ``megatron_config.dataset_additional_keys`` (if present) names extra batch
    entries to decode into the verifier data. It must include ``"labels"``,
    because both ``verifier`` and the sample logging below read it.

    When ``rl_config.overlong_buffer_enable`` is set, responses SHORTER than
    ``rl_config.overlong_buffer`` tokens receive a linear penalty of
    ``-(overlong_buffer - length) / overlong_buffer * overlong_buffer_penalty_factor``.
    Responses at or above that length are not penalized.

    Returns:
        (scores, metrics): ``scores`` is a float32 tensor on
        ``batch["response_length"]``'s device, reshaped to its shape. ``metrics``
        holds the verifier metrics plus ``think_rate`` (the per-sample list) and
        timing entries.
    """
    started_at = time.time()

    # Ignored positions become EOS before decoding the responses.
    response_ids = torch.where(batch["responses"] == ignore_token, tokenizer.eos_token_id, batch["responses"])
    str_responses = tokenizer.batch_decode(response_ids, skip_special_tokens=True)
    str_question = tokenizer.batch_decode(batch["prompts"], skip_special_tokens=True)
    reward_index = batch["response_length"]

    extra_data = {
        "prompts": str_question,
        "valid_response_length": reward_index,
        "ids": batch.get("id"),
    }
    if hasattr(megatron_config, "dataset_additional_keys"):
        for key in megatron_config.dataset_additional_keys:
            extra_data[key] = tokenizer.batch_decode(batch[key], skip_special_tokens=True)

    logger.info("=" * 50)

    generalized_mode = getattr(rl_config, 'tool_generalized_mode', 0)
    think_flags = [_classify_think_mode(text, generalized_mode) for text in str_responses]
    batch["think_mode"] = torch.tensor(think_flags, dtype=torch.long, device=batch["responses"].device)

    scores, metrics = verifier(str_responses, extra_data, rl_config)
    metrics["think_rate"] = think_flags

    if rl_config.overlong_buffer_enable:
        buffer_len = rl_config.overlong_buffer
        # How far each response falls short of buffer_len (negative when longer).
        shortfalls = [buffer_len - length.item() for length in batch["response_length"]]
        penalty_factor = rl_config.overlong_buffer_penalty_factor
        penalties = [min(-gap / buffer_len * penalty_factor, 0) for gap in shortfalls]
        scores = [base + extra for base, extra in zip(scores, penalties)]

    scores = torch.tensor(scores, dtype=torch.float32, device=reward_index.device).reshape(reward_index.shape)

    # Log the first sample of the batch for inspection.
    logger.info("=" * 50)
    logger.info(">>>>>>>>>> User:\n")
    logger.info(str_question[0])
    logger.info(">>>>>>>>>> Assistant:\n")
    logger.info(str_responses[0])
    logger.info(">>>>>>>>>> Label:\n")
    logger.info(extra_data['labels'][0])
    logger.info(">>>>>>>>>> Scores:\n")
    logger.info(f"{scores[0]}")

    finished_at = time.time()
    metrics["timing/rule_reward"] = [round(finished_at, 4), round(started_at, 4)]
    metrics["start_time/rule_reward"] = [round(started_at, 4)]
    metrics["end_time/rule_reward"] = [round(finished_at, 4)]

    return scores, metrics


# Per tool_generalized_mode: (think pattern, no-think pattern, re flags).
# Modes 1/3 use angle-bracket tags, modes 2/4 square-bracket tags; modes 3/4 also
# require a leading <mode>/[mode] header and match case-insensitively.
_THINK_MODE_PATTERNS = {
    1: (r'^<think>.*?</think>', r'^<no_think>.*?</no_think>', re.DOTALL),
    2: (r'^\[think\].*?\[/think\]', r'^\[no_think\]\s*\[/no_think\]', re.DOTALL),
    3: (r'^<mode>\s*think\s*</mode>\s*<think>.*?</think>',
        r'^<mode>\s*no_think\s*</mode>\s*<no_think>.*?</no_think>',
        re.IGNORECASE | re.DOTALL),
    4: (r'^\[mode\]\s*think\s*\[/mode\]\s*\[think\].*?\[/think\]',
        r'^\[mode\]\s*no_think\s*\[/mode\]\s*\[no_think\].*?\[/no_think\]',
        re.IGNORECASE | re.DOTALL),
}


def _classify_think_mode(response, tool_generalized_mode):
    """Return 1 if *response* opens with the think block for this mode, else 0."""
    spec = _THINK_MODE_PATTERNS.get(tool_generalized_mode)
    if spec is None:
        return 0
    think_pattern, no_think_pattern, flags = spec
    # Strip so that '^' anchors at the first non-blank character.
    text = response.strip()
    if re.search(think_pattern, text, flags):
        return 1
    if re.search(no_think_pattern, text, flags):
        return 0
    # Neither block recognized: counted the same as no-think.
    return 0


def verifier(responses, data, config, **kwargs):
    """Score *responses* with every configured rule reward and combine them.

    Args:
        responses (List[str]): Decoded actor rollouts.
        data (dict): Must provide ``"labels"``, ``"prompts"`` and
            ``"valid_response_length"``, index-aligned with *responses*.
        config: Reads the following attributes:
            ``verifier_function``: reward names.
            ``verifier_weight``: weights, index-aligned with the names.
            ``verifier_timeout``: per-run timeout.
            ``verifier_parallel``: maximum worker count.
            The whole object is also forwarded to each reward as ``rl_config``.
        **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        (rewards, metrics): ``rewards`` is the per-sample weighted sum over all
        known reward names (names not in the registry are skipped silently).
        ``metrics`` maps ``"<name>_rewards/mean"`` to that reward's raw
        per-sample score list.
    """
    reward_registry = {
        "format": format_reward,
        "tool_format": tool_format_reward,
        "step": reasoning_steps_reward,
        "tool_step": tool_reasoning_steps_reward,
        "tool_base_acc": base_tool_accuracy_reward,
    }

    labels = data["labels"]
    questions = data['prompts']
    valid_response_length = data['valid_response_length']

    total = [0.0] * len(labels)
    metrics = {}
    reward_names = config.verifier_function
    reward_weights = config.verifier_weight

    for slot, reward_name in enumerate(reward_names):
        reward_fn = reward_registry.get(reward_name)
        if reward_fn is None:
            continue
        per_sample = multiprocess_executor(
            reward_fn,
            sequences=responses,
            answers=labels,
            questions=questions,
            timeout_seconds=config.verifier_timeout,
            max_num_workers=config.verifier_parallel,
            rl_config=config,
            valid_response_length=valid_response_length,
        )
        metrics[f'{reward_name}_rewards/mean'] = per_sample
        total = [acc + value * reward_weights[slot] for acc, value in zip(total, per_sample)]

    logger.info(f"compute_verifier_score: verifier end")
    return total, metrics


def extract_tool_list(system_info, tool_generalized_mode=0):
    """Extract the tool-schema list embedded in a system prompt.

    Modes 1 and 3 look for ``<tools>`` ... ``</tools>``; every other mode looks for
    ``[tools]`` ... ``[/tools]``. In both cases the opening tag must be followed by
    a newline and a JSON list whose first object starts with ``"type"`` or
    ``"name"``. If neither marker is present, a ``[{"name": ...}]`` list that is
    immediately followed by a blank line and ``user`` is used instead.

    OpenAI-style entries (``{"type": ..., "function": {...}}``) are replaced by
    their ``function`` object, with ``parameters["type"]`` forced to ``"dict"``.

    Returns:
        The list of tool dicts, or ``[]`` when nothing is found. Any parse or
        normalization error is also logged and returns ``[]``.
    """
    if tool_generalized_mode in (1, 3):
        candidates = (
            ('<tools>\n[{"type"', r'<tools>\s*\[{"type":(.*?)}\]\s*</tools>'),
            ('<tools>\n[{"name"', r'<tools>\s*\[{"name":(.*?)}\]\s*</tools>'),
        )
    else:
        candidates = (
            ('[tools]\n[{"type"', r'\[tools\]\s*\[{"type":(.*?)}\]\s*\[/tools\]'),
            ('[tools]\n[{"name"', r'\[tools\]\s*\[{"name":(.*?)}\]\s*\[/tools\]'),
        )
    fallback = r'\[{"name":(.*)}\]\n\nuser'
    regex = next((rx for marker, rx in candidates if marker in system_info), fallback)

    found = re.search(regex, system_info, re.DOTALL)
    if not found:
        return []

    raw = found.group(0).strip()
    # Peel off the wrapper tags (both closing tags are 8 chars, both opening 7).
    if raw.endswith(('</tools>', '[/tools]')):
        raw = raw[:-8].strip()
    if raw.startswith(('<tools>', '[tools]')):
        raw = raw[7:].strip()
    if raw.endswith('user'):
        raw = raw[:-4].strip()

    try:
        tools = json.loads(raw)
        for position, entry in enumerate(tools):
            if 'type' in entry and 'function' in entry:
                tools[position] = entry['function']
                tools[position]['parameters']['type'] = "dict"
        return tools
    except Exception as e:
        logger.error(f'Error decoding JSON: {e}, Invalid JSON string: {raw}, || {system_info}')
        return []
        
def _parse_tool_call_text(text):
    """Classify *text* as a tool-call payload and return ``(kind, code_form)``.

    ``kind`` takes one of three values:
      * ``'json'``: *text* is a non-empty JSON value that ``json2code`` can render;
        ``code_form`` is that rendering.
      * ``'code'``: *text* is not such JSON but ``code2json`` accepts it;
        ``code_form`` is *text* itself.
      * ``None``: neither parser accepts *text*; ``code_form`` is ``None``.
    """
    try:
        parsed = json.loads(text)
        if len(parsed) > 0:
            return 'json', json2code(parsed)
    except Exception:
        # Not JSON, an unsized JSON scalar, or a shape json2code cannot render:
        # fall through and try the code form.
        pass
    try:
        code2json(text)
    except Exception:
        return None, None
    return 'code', text


def tool_equal_subprocess(prediction='', reference='', question='', **kwargs):
    """Score one predicted tool call against its reference.

    Both strings are stripped. Each is then classified as a JSON tool-call list, a
    Python-style call string, or neither (see ``_parse_tool_call_text``).

    Returns:
        * ``1.0`` when neither side parses as a tool call.
        * ``0`` when exactly one side parses.
        * Otherwise, ``+tool_generalized_mode`` when both sides use the same format
          (``-tool_generalized_mode`` when they differ), plus one of:
            - ``1.0`` for an exact string match;
            - in ``'code'`` decode mode, the AST checker's verdict: ``1.0`` if
              valid (or if it returns no results), ``0.2`` for a ``value_error``
              when ``tool_generalized_multi`` is 1 or 3, and the whole total reset
              to ``0`` when the verdict is neither True nor False.

    Raises:
        AttributeError: If ``rl_config`` is missing from *kwargs*.
        RuntimeError: For an unsupported ``tool_decode_mode`` (reached only when
            the strings differ).
        Any exception from the AST-checking step is logged and re-raised.
    """
    prediction = prediction.strip()
    reference = reference.strip()
    reward_value = 0
    rl_config = kwargs.get('rl_config', None)
    decode_mode = rl_config.tool_decode_mode
    tool_generalized_mode = rl_config.tool_generalized_mode
    tool_generalized_multi = int(rl_config.tool_generalized_multi)

    prediction_type, prediction_code = _parse_tool_call_text(prediction)
    reference_type, reference_code = _parse_tool_call_text(reference)

    if prediction_type is None and reference_type is None:
        # Neither side parses as a tool call: treated as agreement.
        return reward_value + 1.0
    if prediction_type is None or reference_type is None:
        return reward_value

    # Call-format reward: +mode when both sides use the same format, -mode otherwise.
    # NOTE(review): the magnitude is the tool_generalized_mode value itself (1..4),
    # so modes 3/4 weight format agreement far above the +1.0 accuracy term.
    # Confirm this is intended rather than a fixed bonus.
    reward_value += -tool_generalized_mode if reference_type != prediction_type else tool_generalized_mode

    if prediction == reference:  # identical strings
        reward_value += 1.0
    elif decode_mode == 'code':
        # Compare in code form; JSON payloads use their json2code rendering.
        prediction = prediction_code
        reference = reference_code
        # Everything before the first assistant turn (system prompt + first user turn).
        system_info = question.split('\nassistant\n')[0]
        tool_list = extract_tool_list(system_info, tool_generalized_mode)
        if len(tool_list) == 0:
            logger.info(f'tool info: {tool_list}' + system_info)
        try:
            reference_ans = decode_ast(reference.strip())
            possible_answer = {'ground_truth': wrap_values_in_lists(reference_ans)}
            model_result = {'result': prediction.strip()}
            prompt = {'function': tool_list}
            res = ast_file_runner(model_result, prompt, possible_answer)
        except Exception:
            logger.error('\n--\n'.join(
                str(part) for part in ('this debug info: ', tool_list, prediction, reference, question)
            ))
            raise

        # '== True' / '== False' (not 'is') kept on purpose: truthy ints such as
        # 1/0 must keep matching.
        if len(res) == 0 or res[0]['valid'] == True:  # noqa: E712
            reward_value += 1.0
        elif res[0]['valid'] == False:  # noqa: E712
            if 'value_error' in res[0]['error_type'] and tool_generalized_multi in (1, 3):
                reward_value += 0.2  # partial credit: right call, wrong argument value
        else:
            logger.error(f'{res}')
            reward_value = 0
    else:
        raise RuntimeError(f'{decode_mode} is not support')

    return reward_value


def base_tool_accuracy_reward(queue, sequences, answers, *args, **kwargs):
    """Tool-call accuracy reward: one ``tool_equal_subprocess`` score per sample.

    Both the response and its label go through ``extract_tool_answer`` with
    ``rl_config.tool_generalized_mode`` before comparison. ``kwargs`` (including
    ``rl_config`` and the per-sample ``questions``) is forwarded unchanged.

    Scoring errors are logged with the label and raw response, then re-raised.
    When *queue* is not None, the score list is also ``put`` on it.
    """
    rl_config = kwargs.get('rl_config', None)
    questions = kwargs.get('questions', [None] * len(sequences))

    scores = []
    for sequence, answer, question in zip(sequences, answers, questions):
        mode = rl_config.tool_generalized_mode
        predicted = extract_tool_answer(sequence, data_name="math", tool_generalized_mode=mode)
        expected = extract_tool_answer(answer, data_name="math", tool_generalized_mode=mode)
        try:
            sample_score = tool_equal_subprocess(prediction=predicted, reference=expected, question=question, **kwargs)
        except Exception as e:
            logger.error(str(e) + '\n' + str(expected) + '\n' + str(sequence))
            raise e
        scores.append(sample_score)

    if queue is not None:
        queue.put(scores)
    return scores



def format_reward(queue, sequences, *args, **kwargs):
    """
    Reward completions laid out as ``<think>...</think>`` followed by ``<answer>...</answer>``.

    Args:
        queue: Optional queue; when not None the score list is also ``put`` on it.
        sequences: List of completion strings. Each is stripped and must match the
            whole pattern (only whitespace is allowed between the two blocks).

    Returns:
        A list with 1.0 for each matching completion and 0.0 otherwise.

    Raises:
        ValueError: If *sequences* is not a list.
    """
    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    if not isinstance(sequences, list):
        raise ValueError("Input sequences must be a list.")

    scores = [
        1.0 if re.fullmatch(pattern, text.strip(), re.DOTALL) else 0.0
        for text in sequences
    ]

    if queue is not None:
        queue.put(scores)
    return scores

def tool_format_reward(queue, sequences, *args, **kwargs):
    """
    Reward completions whose opening block matches the think/no-think layout of
    ``rl_config.tool_generalized_mode``.

    Args:
        queue: Optional queue; when not None the score list is also ``put`` on it.
        sequences: List of completion strings (each is stripped before checking).
        **kwargs: Must contain ``rl_config`` with a ``tool_generalized_mode``
            attribute; a missing ``rl_config`` raises AttributeError.

    A completion is checked as follows:
      * If it contains ``<think>`` or ``[think]``, it is matched against the
        mode's think pattern.
      * Otherwise, if it contains ``<no_think>`` or ``[no_think]``, it is matched
        against the no-think pattern.
      * Otherwise it scores 0.0.

    Mode layouts:
      * Modes 1/2: angle/square think tags.
      * Modes 3/4: additionally require a leading ``<mode>``/``[mode]`` header,
        matched case-insensitively.
      * Any other mode: empty patterns, so every tagged completion scores 1.0.

    Returns:
        A list of floats, 1.0 for a matching completion and 0.0 otherwise.

    Raises:
        ValueError: If *sequences* is not a list.
    """
    rl_config = kwargs.get('rl_config', None)
    tool_generalized_mode = rl_config.tool_generalized_mode

    if not isinstance(sequences, list):
        raise ValueError("Input sequences must be a list.")

    # The mode is fixed for the whole call, so the patterns are resolved once.
    if tool_generalized_mode == 1:
        think_pattern = r'^<think>.*?</think>'
        no_think_pattern = r'^<no_think>.*?</no_think>'
        flags = re.DOTALL | re.MULTILINE
    elif tool_generalized_mode == 2:
        think_pattern = r'^\[think\].*?\[/think\]'
        no_think_pattern = r'^\[no_think\]\s*\[/no_think\]'
        flags = re.DOTALL | re.MULTILINE
    elif tool_generalized_mode == 3:
        think_pattern = r'^<mode>\s*think\s*</mode>\s*<think>.*?</think>'
        no_think_pattern = r'^<mode>\s*no_think\s*</mode>\s*<no_think>.*?</no_think>'
        flags = re.IGNORECASE | re.DOTALL
    elif tool_generalized_mode == 4:
        think_pattern = r'^\[mode\]\s*think\s*\[/mode\]\s*\[think\].*?\[/think\]'
        no_think_pattern = r'^\[mode\]\s*no_think\s*\[/mode\]\s*\[no_think\].*?\[/no_think\]'
        flags = re.IGNORECASE | re.DOTALL
    else:
        # Unknown mode: empty patterns match anything, so only tag presence counts.
        think_pattern = ''
        no_think_pattern = ''
        flags = re.DOTALL | re.MULTILINE

    scores = []
    for completion in sequences:
        resp = completion.strip()
        if "<think>" in resp or '[think]' in resp:
            pattern = think_pattern
        elif "<no_think>" in resp or '[no_think]' in resp:
            pattern = no_think_pattern
        else:
            scores.append(0.0)
            continue
        scores.append(1.0 if re.match(pattern, resp, flags) else 0.0)

    if queue is not None:
        queue.put(scores)

    return scores


def reasoning_steps_reward(queue, sequences, *args, **kwargs):
    r"""Reward function that checks for clear step-by-step reasoning.

    Counts marker occurrences per completion and scores ``min(1.0, count / 3)``.

    Regex pattern (compiled with ``re.MULTILINE`` so ``^`` anchors at every line):
        Step \d+: - matches "Step 1:", "Step 2:", etc.
        ^\d+\. - matches numbered lists like "1.", "2.", etc. at start of line
        \n- - matches bullet points with hyphens
        \n\* - matches bullet points with asterisks
        First,|Second,|Next,|Finally, - matches transition words

    When *queue* is not None the score list is also ``put`` on it.
    """
    # Without MULTILINE, '^\d+\.' only matched at the very start of the string,
    # so numbered lines after the first were never counted.
    pattern = re.compile(r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)", re.MULTILINE)
    scores = [min(1.0, len(pattern.findall(content)) / 3) for content in sequences]

    if queue is not None:
        queue.put(scores)

    return scores

def tool_reasoning_steps_reward(queue, sequences, *args, **kwargs):
    r"""Reward step-by-step reasoning, weighting tool-specific cues double.

    Step markers (compiled with ``re.MULTILINE``):
        Step \d+: - "Step 1:", "Step 2:", ...
        ^\d+\.    - numbered items like "1." at the start of any line
        First,|Second,|Next,|Finally, - transition words
    Tool markers:
        type error | value error | Based on the context, |
        Considering the context, | Given the context,

    With ``m`` step matches and ``t`` tool matches, the score is
    ``min(1.0, ((m + 2 * t) / 2) / 3)``. When *queue* is not None the score list is
    also ``put`` on it.
    """
    # MULTILINE so '^\d+\.' counts numbered items on every line, not just the first.
    math_pattern = re.compile(r"(Step \d+:|^\d+\.|First,|Second,|Next,|Finally,)", re.MULTILINE)
    tool_pattern = re.compile(r"(type error|value error|Based on the context,|Considering the context,|Given the context,)")

    scores = []
    for content in sequences:
        math_hits = len(math_pattern.findall(content))
        tool_hits = len(tool_pattern.findall(content))
        weighted = (math_hits + 2 * tool_hits) / 2
        scores.append(min(1.0, weighted / 3))

    if queue is not None:
        queue.put(scores)

    return scores

def multiprocess_executor(worker, sequences, answers, questions=None, timeout_seconds=10, max_num_workers=32, **worker_kwargs):
    """Run *worker* over contiguous shards of the inputs in parallel processes.

    Each child runs ``worker(queue, sequence_shard, answer_shard, **kwargs)``.
    ``kwargs`` is a per-worker copy of *worker_kwargs* in which ``questions`` and
    ``valid_response_length`` are replaced by the matching shards. The worker is
    expected to ``put`` one list of per-item scores on its queue.

    Args:
        worker: Callable usable as a ``multiprocessing.Process`` target.
        sequences: Items to score; output order follows input order.
        answers: Reference items, index-aligned with *sequences*.
        questions: Optional per-item prompts; defaults to ``None`` for each item.
        timeout_seconds: Wall-clock budget shared by all workers, measured from
            when the last one was started. ``None`` waits indefinitely.
        max_num_workers: Upper bound on processes. The count is also capped by
            ``len(sequences)`` and ``cpu_count() - 1``, and is never below one.
        **worker_kwargs: Forwarded to every worker. ``valid_response_length``,
            if given, must match ``len(sequences)``.

    Returns:
        A flat list of scores in input order. A shard whose worker times out,
        crashes, or puts nothing contributes ``0.0`` per item.
    """
    if not sequences:
        return []

    # At least one worker: cpu_count() - 1 is 0 on a single-core host, and a
    # non-positive max_num_workers would otherwise make the division below fail.
    num_workers = max(1, min(len(sequences), mp.cpu_count() - 1, max_num_workers))
    batch_size = len(sequences) // num_workers

    processes = []
    lengths = []
    queues = []

    # Ensure questions is aligned with sequences.
    if questions is None:
        questions = [None] * len(sequences)
    else:
        assert len(questions) == len(sequences), f"questions length mismatch,{len(questions)},{len(sequences)}"

    valid_response_length = worker_kwargs.get('valid_response_length', [0] * len(sequences))
    assert len(valid_response_length) == len(sequences), "valid_response_length length mismatch"

    for i in range(num_workers):
        start_index = i * batch_size
        # The last worker also takes the remainder of the integer division.
        end_index = (i + 1) * batch_size if i < num_workers - 1 else len(sequences)

        sequence_batch = sequences[start_index:end_index]
        answer_batch = answers[start_index:end_index]

        lengths.append(len(sequence_batch))
        q = Queue()
        queues.append(q)

        # Per-worker copy so the shard overrides never leak between workers.
        local_kwargs = worker_kwargs.copy()
        local_kwargs['questions'] = questions[start_index:end_index]
        local_kwargs['valid_response_length'] = valid_response_length[start_index:end_index]

        p = Process(
            target=worker,
            args=(q, sequence_batch, answer_batch),
            kwargs=local_kwargs
        )
        processes.append(p)
        p.start()

    # Workers run concurrently, so they share one deadline. Restarting the timeout
    # on every join() let total wall time grow to num_workers * timeout_seconds.
    deadline = None if timeout_seconds is None else time.monotonic() + timeout_seconds

    final_results = []
    for i, p in enumerate(processes):
        remaining = None if deadline is None else max(0.0, deadline - time.monotonic())
        p.join(timeout=remaining)
        if p.is_alive():
            p.terminate()
            p.join()
            logger.info(f'Process {i} timed out, returning [0.0] * {lengths[i]}')
            final_results.extend([0.0] * lengths[i])
            continue
        # NOTE(review): the queue is read only after join(). This is safe while
        # per-worker results stay small (a list of floats). Very large payloads
        # could block the child's flush and trip the timeout; see the
        # multiprocessing "joining processes that use queues" guideline.
        try:
            final_results.extend(queues[i].get_nowait())
        except Exception:
            # Worker exited without a usable result (e.g. it raised): score zeros.
            final_results.extend([0.0] * lengths[i])
    return final_results

def math_compute_score(predict_str, ground_truth, acc_ratio=0.9, format_ratio=0.1):
    """Placeholder math scorer.

    All arguments are accepted for interface compatibility and ignored; the score
    is always ``0.0``.
    """
    placeholder_score = 0.0
    return placeholder_score