


from tool.envs.retool import AgentMathEnv
from tool.tools.python_tool import PythonTool

def batch_call_python_code_function(batch_slice_dict):
    """Run one slice of raw responses through a Python-tool environment on one sandbox.

    Used as the ``multiprocessing.Pool.map`` worker: a fresh ``PythonTool`` and
    ``AgentMathEnv`` are constructed on every call.

    Args:
        batch_slice_dict: Dict with keys ``'use_batch_tool_calls'`` (truthy flag
            choosing ``env.batch_step`` over per-response ``env.step``),
            ``'silce_raw_responses'`` (the responses of this slice; the key's
            spelling is what callers populate) and ``'ip_address'`` (sandbox
            address forwarded to the env).

    Returns:
        ``batch_slice_dict`` itself, updated in place with ``'tool_responses'``,
        ``'tool_successes'``, ``'new_active_masks'`` and ``'tool_images'``
        (one independent empty list per response).
    """
    p_c = PythonTool()
    env = AgentMathEnv(tools=[p_c], max_tool_response_length=2048)
    use_batch_tool_calls = batch_slice_dict['use_batch_tool_calls']
    raw_responses = batch_slice_dict['silce_raw_responses']
    ip_address = batch_slice_dict['ip_address']
    print(f'sandbox python call ip_address ======={ip_address}')

    if use_batch_tool_calls:
        tool_responses, tool_successes, new_active_masks = env.batch_step(raw_responses, ip_address=ip_address)
    else:
        tool_responses = []
        new_active_masks = []
        tool_successes = []
        for raw_response in raw_responses:
            tool_response, tool_success, active = env.step(raw_response, ip_address=ip_address)
            tool_responses.append(tool_response)
            tool_successes.append(tool_success)
            new_active_masks.append(active)

    # One distinct list per response. `[[]] * n` would alias a single list n
    # times, and pickle preserves that aliasing when the result is returned
    # from the worker process, so mutating one entry would change all of them.
    tool_images = [[] for _ in raw_responses]

    batch_slice_dict['tool_responses'] = tool_responses
    batch_slice_dict['tool_successes'] = tool_successes
    batch_slice_dict['new_active_masks'] = new_active_masks
    batch_slice_dict['tool_images'] = tool_images
    return batch_slice_dict

def multi_process_batch_call_python_code(raw_response_lists, ips_string, env=None, use_batch_tool_calls=True):
    """Execute tool calls for a batch of responses across several sandbox hosts in parallel.

    The batch is cut into contiguous slices of near-equal size, one per sandbox
    address in ``ips_string``. Each non-empty slice is handed to
    ``batch_call_python_code_function`` in a worker process, and the per-slice
    results are concatenated back in the original batch order.

    Args:
        raw_response_lists: Sequence of raw model responses (must support
            ``len`` and slicing).
        ips_string: Comma-separated sandbox addresses, e.g.
            ``"10.0.0.1,10.0.0.2"``. Surrounding whitespace and empty entries
            (such as a trailing comma) are ignored.
        env: Unused; kept so existing callers keep working.
        use_batch_tool_calls: Forwarded to each worker; selects
            ``env.batch_step`` over per-response ``env.step`` calls.

    Returns:
        Tuple ``(tool_responses, new_active_masks, tool_successes, tool_images)``,
        each a list built by concatenating the per-slice results in slice order.

    Raises:
        ValueError: If the batch is non-empty but ``ips_string`` contains no address.
    """
    import multiprocessing

    batch_length = len(raw_response_lists)
    if batch_length == 0:
        # Nothing to execute; avoid creating a process pool just for empty slices.
        return [], [], [], []

    ip_lists = [ip.strip() for ip in ips_string.split(',') if ip.strip()]
    if not ip_lists:
        raise ValueError(f"no sandbox ip address found in ips_string={ips_string!r}")

    # Ceiling division: every response is covered and slice sizes differ by at
    # most one. (`len // n + 1` overshoots, e.g. 10 items over 5 hosts would give
    # sizes 3,3,3,1,0 and a process that has no work.)
    slice_size = -(-batch_length // len(ip_lists))

    multi_process_batch_lists = []
    for idx, ip_addr in enumerate(ip_lists):
        start = idx * slice_size
        if start >= batch_length:
            break  # the remaining hosts would only receive empty slices
        multi_process_batch_lists.append({
            "idx": idx,
            # Key spelling matches what batch_call_python_code_function reads.
            "silce_raw_responses": raw_response_lists[start:start + slice_size],
            "ip_address": ip_addr,
            "env": "ReToolEnv",
            "use_batch_tool_calls": use_batch_tool_calls,
        })

    with multiprocessing.Pool(processes=min(len(multi_process_batch_lists), 1000)) as pool:
        batch_python_results = pool.map(batch_call_python_code_function, multi_process_batch_lists)

    # pool.map already preserves input order; sorting by idx keeps it explicit.
    all_tool_responses = []
    all_new_active_masks = []
    all_tool_successes = []
    all_tool_images = []
    for batch_dict in sorted(batch_python_results, key=lambda d: d["idx"]):
        all_tool_responses.extend(batch_dict['tool_responses'])
        all_new_active_masks.extend(batch_dict['new_active_masks'])
        all_tool_successes.extend(batch_dict['tool_successes'])
        all_tool_images.extend(batch_dict['tool_images'])

    return all_tool_responses, all_new_active_masks, all_tool_successes, all_tool_images


import re

def extract_code_blocks(response_text: str):
    """Locate every ``<code>…</code> … <interpreter>…</interpreter>`` group in *response_text*.

    Matching is non-greedy (shortest spans preferred) and crosses newlines
    (``re.DOTALL``).

    Returns:
        A list, in order of appearance, of dicts with keys ``'block_idx'``
        (0-based match index), ``'code'`` (text inside the code tags),
        ``'code_format'`` (that text re-wrapped in ``<code>`` tags),
        ``'between_text'`` (text between ``</code>`` and ``<interpreter>``),
        ``'interpreter_output'`` (text inside the interpreter tags), and
        ``'start'`` / ``'end'`` (character span of the whole match).
    """
    block_regex = re.compile(r"<code>(.*?)</code>(.*?)<interpreter>(.*?)</interpreter>", re.DOTALL)

    found = []
    for position, hit in enumerate(block_regex.finditer(response_text)):
        code_text, gap_text, output_text = hit.groups()
        found.append({
            'block_idx': position,
            'code': code_text,
            'code_format': "<code>" + code_text + "</code>",
            'between_text': gap_text,
            'interpreter_output': output_text,
            'start': hit.start(),
            'end': hit.end(),
        })
    return found


def identify_error_score(solution_str):
    """Return True when *solution_str* contains any known failure marker, else False.

    Markers are matched as case-sensitive substrings: the path fragments
    ``"/root/lib/python3.11"`` and ``"/tmp"`` (presumably from traceback file
    paths — confirm), plus ``"Error"`` and ``"Unknown error"``.
    """
    failure_markers = ("/root/lib/python3.11", "/tmp", "Error", "Unknown error")
    return any(marker in solution_str for marker in failure_markers)

def extract_code_blocks_exclude_fail_code(response_text: str):
    """Extract code/interpreter blocks, dropping those whose interpreter output looks failed.

    A block is dropped when ``identify_error_score`` flags its
    ``'interpreter_output'``. The surviving dicts have exactly the shape
    produced by ``extract_code_blocks``. ``'block_idx'`` is the block's
    position among *all* matches, including dropped ones, so the kept
    indices may have gaps.
    """
    # Reuse the shared extractor and failure detector so the regex and the
    # failure-marker list each live in one place.
    return [
        block
        for block in extract_code_blocks(response_text)
        if not identify_error_score(block['interpreter_output'])
    ]

def reassemble_response(original_text: str, processed_blocks) -> str:
    """Rebuild *original_text*, replacing each block's span with its current fields.

    Each item of *processed_blocks* is a dict like those from
    ``extract_code_blocks``. It needs ``'start'``/``'end'`` (the span to replace
    in *original_text*) plus ``'code'``, ``'between_text'`` and
    ``'interpreter_output'``. Spans are consumed in the given order. The text
    between consecutive spans is copied verbatim, so spans are expected to be
    ascending and non-overlapping. An empty or falsy *processed_blocks* returns
    *original_text* unchanged.
    """
    if not processed_blocks:
        return original_text

    template = "<code>{}</code>{}<interpreter>{}</interpreter>"
    pieces = []
    cursor = 0
    for blk in processed_blocks:
        # untouched text before this block, then the rebuilt block itself
        pieces += [
            original_text[cursor:blk['start']],
            template.format(blk['code'], blk['between_text'], blk['interpreter_output']),
        ]
        cursor = blk['end']
    pieces.append(original_text[cursor:])
    return "".join(pieces)


def reassemble_response_real_interpreter_output(original_text: str, processed_blocks) -> str:
    """Rebuild *original_text*, using each block's ``'interpreter_output_real_result'`` as output.

    Works like a span-replacing reassembly. Each dict in *processed_blocks*
    needs ``'start'``/``'end'`` (the span to replace in *original_text*),
    ``'code'``, ``'between_text'`` and ``'interpreter_output_real_result'``
    (presumably the output from actually executing the code — confirm with the
    caller). Its ``'interpreter_output'`` is ignored. Spans are consumed in the
    given order and should be ascending and non-overlapping. An empty or falsy
    *processed_blocks* returns *original_text* unchanged.
    """
    if not processed_blocks:
        return original_text

    template = "<code>{}</code>{}<interpreter>{}</interpreter>"
    pieces = []
    cursor = 0
    for blk in processed_blocks:
        # untouched text before this block, then the block with the real output
        pieces += [
            original_text[cursor:blk['start']],
            template.format(blk['code'], blk['between_text'], blk['interpreter_output_real_result']),
        ]
        cursor = blk['end']
    pieces.append(original_text[cursor:])
    return "".join(pieces)

