import gc
import copy
import json
import os
import os.path as osp
import torch.distributed as dist
from typing import Dict, Iterator, List, Literal, Optional, Union

from datetime import datetime
# Date string (YYYY-MM-DD) captured once, when this module is imported. It is
# substituted for the "{cur_date}" placeholder of system1 system prompts in
# MultiAgent._run_batch and is not refreshed afterwards.
cur_date = datetime.now().strftime("%Y-%m-%d")

from qwen_agent.agents.rl.nous_agent_continue import NousAgent
from qwen_agent.llm.schema import ASSISTANT, USER, Message
from qwen_agent.utils.parallel_executor import parallel_exec, serial_exec
from qwen_agent.utils.utils import has_chinese_messages
from ray.util import pdb

from qwen_agent.llm import BaseChatModel
from qwen_agent.llm.schema import DEFAULT_SYSTEM_MESSAGE, FUNCTION, Message
from qwen_agent.tools import BaseTool
from qwen_agent.agents.rl.multi_agent.context.context import Context
from qwen_agent.agents.rl.multi_agent.context.node import Node
from qwen_agent.agents.rl.multi_agent.system import System1, System2
from qwen_agent.utils.utils import hash_sha256
from qwen_agent.utils.utils import extract_files_from_messages

# Name of the tag that MultiAgent strips from system1 outputs before building
# tool-role messages. It comes from the TOOL_RESPONSE_TAG environment variable
# (default 'tool_response'); the value "none" (any casing) turns it into None.
TOOL_RESPONSE_TAG = os.environ.get('TOOL_RESPONSE_TAG', 'tool_response')
if TOOL_RESPONSE_TAG.lower() == "none":
    TOOL_RESPONSE_TAG = None

# Passed as the ``jitter`` argument of ``parallel_exec`` when tools are invoked;
# overridable through the MAX_TOOL_CALL_INTERVAL environment variable.
MAX_TOOL_CALL_INTERVAL = float(os.environ.get('MAX_TOOL_CALL_INTERVAL', 0.01))
print(f'Running with MAX_TOOL_CALL_INTERVAL = {MAX_TOOL_CALL_INTERVAL}')


# Rank layout of the distributed job, resolved once at import time.
# Consecutive runs of ``mp_size`` global ranks are treated as one
# model-parallel group (see create_model_parallel_group).
global_rank = dist.get_rank()
world_size = dist.get_world_size()
mp_size = int(os.environ.get("TP_SIZE"))   # model parallel size, taken from the TP_SIZE env var
dp_size = world_size // mp_size            # data parallel size (number of model-parallel groups)
# dp_rank: index of the model-parallel group this rank falls into;
# mp_rank: position of this rank inside that group.
dp_rank, mp_rank = divmod(global_rank, mp_size)

def create_model_parallel_group(world_size, mp_size):
    """Create one process group per run of ``mp_size`` consecutive ranks.

    Args:
        world_size: Total number of ranks to partition.
        mp_size: Number of consecutive ranks placed in each group.

    Returns:
        The list of objects returned by ``dist.new_group``, ordered by the
        first rank of each group.
    """
    return [
        dist.new_group(ranks=list(range(first_rank, first_rank + mp_size)))
        for first_rank in range(0, world_size, mp_size)
    ]

# Build every model-parallel group, then keep the one this rank belongs to:
# dp_rank (global_rank // mp_size) is also the index of that group in the list.
model_parallel_groups = create_model_parallel_group(world_size, mp_size)
model_parallel_group = model_parallel_groups[dp_rank]



class MultiAgent(NousAgent):
    """Multiple Agent (system 1 and system 2)"""

    def __init__(self,
                 function_list: Optional[List[Union[str, Dict, BaseTool]]] = None,
                 llm: Optional[Union[Dict, BaseChatModel]] = None,
                 system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
                 name: Optional[str] = None,
                 description: Optional[str] = None,
                 files: Optional[List[str]] = None,
                 **kwargs):
        """Build the agent and pre-render the description of its tools.

        Args:
            function_list: Tool names, tool configurations or tool objects,
              e.g. 'code_interpreter', {'name': 'code_interpreter', 'timeout': 10}
              or CodeInterpreter().
            llm: LLM configuration ({'model': '', 'api_key': '', 'model_server': ''})
              or an LLM object.
            system_message: System message used for LLM chat.
            name: Name of this agent.
            description: Description of this agent.
            files: List of file urls. Accepted for interface compatibility; it is
              not read by this constructor.
            **kwargs: Optional settings read here:
              'tokenizer' (default None), 'multiturn_agent' (default False) and
              'system1_mode' (default "training").
        """
        super().__init__(
            function_list=function_list,
            llm=llm,
            system_message=system_message,
            name=name,
            description=description,
        )

        # Imported lazily, after the parent constructor has populated function_map.
        from .tools import MultiAgentFnCallPrompt

        registered_tools = self.function_map.values()
        self.function_desc = [tool.function for tool in registered_tools]
        self.tool_desc = MultiAgentFnCallPrompt.get_tool_desc(self.function_desc, lang='en')

        self.tokenizer = kwargs.get('tokenizer')
        self.multiturn_agent = kwargs.get('multiturn_agent', False)
        self.system1_mode = kwargs.get('system1_mode', "training")

        # Log the rendered tool description once, from global rank 0 only.
        if global_rank == 0:
            print(f"tool desc: {self.tool_desc}", flush=True)

    def run_batch(self, data: List[Dict], **kwargs) -> List[List[Dict]]:
        """Run a batch of samples, inferring per-sample languages when missing.

        Args:
            data: Batch of samples handed on to ``_run_batch``.
            **kwargs: Forwarded to ``_run_batch``. When 'lang_batch' is absent
              or empty, it is filled with 'zh' for samples for which
              ``has_chinese_messages`` is truthy and 'en' otherwise.

        Returns:
            Whatever ``_run_batch`` returns for the batch.
        """
        lang_batch = kwargs.get('lang_batch', [])
        if not lang_batch:
            for sample in data:
                lang_batch.append('zh' if has_chinese_messages(sample) else 'en')
            kwargs['lang_batch'] = lang_batch

        # One language tag per sample is required downstream.
        assert len(lang_batch) == len(data)
        return self._run_batch(data=data, **kwargs)

    @staticmethod
    def postprocess_system2_outputs(system2_messages: List[Dict], system2_outputs: List[str], parse_output_func):
        """Turn raw system2 generations into chain nodes.

        Each output is parsed with ``parse_output_func`` into
        ``(tool_part, answer_part)`` and classified as:

        * a tool call (``tool_part`` is a dict, ``answer_part`` is None): the
          batch position is stored in ``tool_part['index']`` and
          ``tool_part['t_index']`` and the node is queued for tool execution;
        * a final answer (``tool_part`` is None, ``answer_part`` is a dict):
          a valid terminal node with stop reason "Answer";
        * anything else: an invalid terminal node.

        Args:
            system2_messages: Inputs that were sent to system2, one per sample.
            system2_outputs: Generations of system2, aligned with the inputs.
            parse_output_func: Callable returning ``(tool_part, answer_part)``.

        Returns:
            ``(nodes, tool_nodes)``: two lists of ``{'idx': int, 'node': Node}``
            dicts, holding the terminal nodes and the pending tool-call nodes.
        """
        finished, awaiting_tool = [], []
        paired = zip(system2_messages, system2_outputs, strict=True)
        for position, (message, output) in enumerate(paired):
            tool_part, answer_part = parse_output_func(output)
            is_tool_call = isinstance(tool_part, dict) and answer_part is None
            is_answer = tool_part is None and isinstance(answer_part, dict)

            if is_tool_call:
                tool_part['index'] = position
                tool_part['t_index'] = position
                bucket = awaiting_tool
                node = Node(
                    system2_message=message,
                    system2_output=output,
                    tool=tool_part,
                )
            elif is_answer:
                bucket = finished
                node = Node(
                    system2_message=message,
                    system2_output=output,
                    answer_tag_part=answer_part['answer_tag_part'],
                    boxed_answer=answer_part['boxed_answer'],
                    valid=True,
                    end=True,
                    stop_reason="Answer",
                )
            else:
                # A missing generation is reported as a context-length overflow.
                if system2_output_missing := (output is None):
                    reason = "Exceeded Maximum Context Length"
                else:
                    reason = "Invalid Output"
                bucket = finished
                node = Node(
                    system2_message=message,
                    system2_output=output,
                    tool=tool_part,
                    answer_tag_part=answer_part,
                    valid=False,
                    end=True,
                    stop_reason=reason,
                )

            bucket.append({'idx': position, 'node': node})

        return finished, awaiting_tool

    @staticmethod
    def postprocess_system1_outputs(system1_messages: Optional[List], system1_outputs: Optional[List], contexts, parse_output_func, system1_mode: str, readpage: bool = False):
        """Store system1 results on the last node of each context's chain.

        Three layouts are supported:

        * ``system1_mode == "empty"``: ``system1_outputs`` holds one raw tool
          observation per context; it is stored as ``node.system1_output``
          without parsing. ``system1_messages`` is ignored.
        * non-readpage (``readpage`` falsy): ``system1_messages`` and
          ``system1_outputs`` are aligned with ``contexts``; each output is
          parsed and the node receives ``system1_message``,
          ``system1_output`` and ``system1_output_format``.
        * readpage: ``system1_messages`` is a list of dicts (one per context)
          with aligned 'message_list' and 'response_list' entries;
          ``system1_outputs`` is not read. An entry whose message is the
          marker 'non_readpage_tools' is stored verbatim; every other response
          is parsed and appended to list-valued node attributes.

        Args:
            system1_messages: See the layouts above.
            system1_outputs: See the layouts above.
            contexts: Contexts whose ``chain[-1]`` node is updated in place.
            parse_output_func: Callable ``(output, mode=...)`` returning
              ``(parsed_output, output_status)``; unused in "empty" mode.
            system1_mode: Mode forwarded to ``parse_output_func``.
            readpage: Selects the readpage layout. Defaults to False so that
              callers of the non-readpage layout may omit it.

        Raises:
            Exception: Any error raised while parsing a system1 output is
              logged with its inputs and re-raised unchanged.
        """
        if system1_mode == "empty":
            for tool_obs, context in zip(system1_outputs, contexts, strict=True):
                context.chain[-1].system1_output = tool_obs
            return

        if not readpage:
            for system1_message, system1_output, context in zip(system1_messages, system1_outputs, contexts, strict=True):
                node = context.chain[-1]
                try:
                    # if training system1, reward = 0 if format is invalid
                    parsed_output, output_status = parse_output_func(system1_output, mode=system1_mode)
                except Exception as e:
                    print(e)
                    print(f"parse_output_func system1 ERROR:\nsystem1_message: {system1_message}\nsystem1_output: {system1_output}")
                    raise
                node.system1_message = system1_message
                node.system1_output = parsed_output
                node.system1_output_format = output_status
            return

        # readpage layout: system1_messages is a List[dict], one dict per context
        for system1_dict, context in zip(system1_messages, contexts, strict=True):
            node = context.chain[-1]
            for mess, o in zip(system1_dict['message_list'], system1_dict['response_list'], strict=True):
                if mess == 'non_readpage_tools':
                    # Tool observation that did not go through system1.
                    node.system1_message = None
                    node.system1_output = o
                    continue

                try:
                    assert o is not None, "the response of system1 is None"
                    # if training system1, reward = 0 if format is invalid
                    o, output_status = parse_output_func(o, mode=system1_mode)
                except Exception as e:
                    print(e)
                    print(f"parse_output_func system1 ERROR:\nsystem1_dict: {system1_dict}\nsystem1_output: {o}")
                    raise

                # The three attributes are created lazily as lists and grow in lockstep.
                if node.system1_message is None:
                    node.system1_message = []
                node.system1_message.append(mess)

                if node.system1_output is None:
                    node.system1_output = []
                node.system1_output.append(o)

                if node.system1_output_format is None:
                    node.system1_output_format = []
                node.system1_output_format.append(output_status)

    def _run_batch(self, data: List[Dict], lang_batch: List[str],
                   **kwargs) -> List[List[Message]]:
        """Roll out the system2 -> tool -> system1 loop for a batch of samples.

        For at most ``multi_agent_pattern.max_depth`` steps, every unfinished
        context is sent to system2, its output is parsed into a node, tool
        calls are executed (by ``mp_rank == 0`` only, then broadcast inside the
        model-parallel group), terminal contexts are set aside, and the
        remaining ones get their tool observation processed by system1.

        Args:
            data: Batch of samples. Each item must provide
              ``item['prompt']['system1']['system']`` and
              ``item['prompt']['system2']['system']``; the prompts are
              modified in place ("{cur_date}" / "{tool_desc}" substitution and
              the optional multiple-choice user prompt).
            lang_batch: Per-sample language tags; only forwarded to the parent
              implementation when batching is disabled.
            **kwargs: Reads 'multi_agent_pattern' (rollout configuration),
              'system2_llm_fn' and 'system1_llm_fn' (batch inference
              callables) and 'llm_batch_enabled'.

        Returns:
            One entry per sample, ordered by ``context.idx``, produced by
            ``get_multi_turn_rl_data`` when ``self.multiturn_agent`` is set and
            by ``get_rl_data`` otherwise. When 'llm_batch_enabled' is falsy,
            the result of the parent ``_run_batch`` is returned instead.

        Raises:
            AssertionError: If a required LLM callable is missing, a system2
              prompt lacks the tool description, or 'PythonInterpreter' is not
              registered in non-readpage mode.
        """
        multi_agent_pattern = kwargs.get('multi_agent_pattern', None)
        llm_fn = kwargs.get('system2_llm_fn', None)
        system1_llm_fn = kwargs.get('system1_llm_fn', None)
        llm_batch_enabled = kwargs.get('llm_batch_enabled', False)
        if not llm_batch_enabled:
            # `data` is the batch this method received; it is handed to the
            # parent under the `messages_batch` keyword.
            return super()._run_batch(messages_batch=data, lang_batch=lang_batch, **kwargs)
        assert llm_fn, 'Please pass in llm_fn that supports batch inference externally'
        assert system1_llm_fn, 'Please pass in system1_llm_fn that supports batch inference externally'

        # dynamically change the tool desc in system prompts
        for item in data:
            system1_prompt = item['prompt']['system1']
            system2_prompt = item['prompt']['system2']
            system1_prompt['system'] = system1_prompt['system'].replace("{cur_date}", cur_date)

            assert "{tool_desc}" in system2_prompt['system'] or self.tool_desc in system2_prompt['system'], "system prompt should contain {tool_desc}, but now\n" + system2_prompt['system']
            if item.get('is_multiChoice', False) and 'mc_user' in system2_prompt:
                system2_prompt['user'] = system2_prompt['mc_user']

            # str.replace is a no-op when the placeholder was already substituted
            system2_prompt['system'] = system2_prompt['system'].replace("{tool_desc}", self.tool_desc)

        contexts = [Context(idx, item, multiturn_agent=self.multiturn_agent, tool_names=list(self.function_map.keys())) for idx, item in enumerate(data)]
        terminal_contexts = []

        for _ in range(multi_agent_pattern.max_depth):
            # system 2
            mask_pre_think = multi_agent_pattern.tool_response_role != "tool"
            system2_input_messages = [context.prepare_system2_input(mask_pre_think=mask_pre_think) for context in contexts]
            system2_outputs = llm_fn(messages=system2_input_messages, generate_cfg=multi_agent_pattern['system2_sampling_params'])

            # parse the output and prepare for tool calling
            nodes, tool_nodes = MultiAgent.postprocess_system2_outputs(system2_input_messages, system2_outputs, parse_output_func=System2.parse_output)

            tool_results = None
            if len(tool_nodes) > 0:
                tool_parts = [t_node['node'].tool for t_node in tool_nodes]

                # Only the first rank of each model-parallel group calls the
                # external tool(s); the other ranks receive the results below.
                if mp_rank == 0:
                    print(f'[exec_start] (rank-{dp_rank*mp_size}) on {len(system2_outputs)} messages, {len(tool_parts)} tool tasks')
                    tool_results = parallel_exec(self._ask_code_exec, tool_parts, max_workers=multi_agent_pattern.tool_external_concurrency, jitter=MAX_TOOL_CALL_INTERVAL)
                    print(f'[exec_done] (rank-{dp_rank*mp_size}) on {len(system2_outputs)} messages, {len(tool_parts)} tool tasks')
                    gc.collect()

            if mp_size > 1:
                dist.barrier(group=model_parallel_group)

            # broadcast_object_list needs a list, even for a single object
            obj_list = [tool_results]
            dist.broadcast_object_list(obj_list, src=dp_rank*mp_size, group=model_parallel_group)
            tool_results = obj_list[0]  # list of (index, t_index, results) tuples, or None

            if tool_results is not None and len(tool_nodes) > 0:
                # tool_nodes is already ordered by batch position; align the results with it
                tool_results = sorted(tool_results, key=lambda x: x[0])
                for tool_result, tool_node in zip(tool_results, tool_nodes, strict=True):
                    _, _, results = tool_result
                    MultiAgent._apply_tool_result(tool_node['node'], results)

            # update local contexts
            node_list = sorted(nodes + tool_nodes, key=lambda x: x['idx'])
            for context, node_dict in zip(contexts, node_list, strict=True):
                context.chain.append(node_dict['node'])

            # Set finished contexts aside (this includes failed tool calls) and
            # keep the others, preserving their relative order.
            still_running = []
            for context in contexts:
                if context.is_terminal:
                    terminal_contexts.append(context)
                else:
                    still_running.append(context)
            contexts = still_running

            if len(contexts) == 0:
                break

            # system1
            system1_params = multi_agent_pattern.system1_sampling_params
            if self.system1_mode == "empty":
                tool_observation = [context.prepare_system1_input(self.function_map, self.tokenizer, system1_params.max_prompt_length, empty_mode=True) for context in contexts]
                MultiAgent.postprocess_system1_outputs(None, tool_observation, contexts, parse_output_func=None, system1_mode=self.system1_mode, readpage=False)
            elif not system1_params.readpage:
                assert 'PythonInterpreter' in self.function_map, 'not support PythonInterpreter in non-readpage mode'
                system1_input_messages = [context.prepare_system1_input(self.function_map, self.tokenizer, system1_params.max_prompt_length, readpage=system1_params.readpage) for context in contexts]
                system1_outputs = system1_llm_fn(messages=system1_input_messages, generate_cfg=system1_params)
                MultiAgent.postprocess_system1_outputs(system1_input_messages, system1_outputs, contexts, parse_output_func=System1.parse_output, system1_mode=self.system1_mode, readpage=False)
            else:
                self._run_system1_readpage(contexts, system1_llm_fn, system1_params)

            if self.multiturn_agent:
                self._attach_tool_role_messages(contexts, multi_agent_pattern)

        # Contexts still running after the last step are closed as "max depth exceeded".
        for context in contexts:
            context.exceed_max_depth()
        terminal_contexts.extend(contexts)

        # re-sort the terminal contexts and organize the response and output
        terminal_contexts = sorted(terminal_contexts, key=lambda x: x.idx)

        if self.multiturn_agent:
            rl_data = [
                context.get_multi_turn_rl_data(
                    self.tokenizer,
                    readpage=multi_agent_pattern.system1_sampling_params.readpage,
                    system2_think=multi_agent_pattern.system2_sampling_params.enable_thinking,
                    system1_think=multi_agent_pattern.system1_sampling_params.enable_thinking,
                    mask_pre_think=multi_agent_pattern.tool_response_role != "tool",
                )
                for context in terminal_contexts
            ]
        else:
            rl_data = [context.get_rl_data(self.tokenizer) for context in terminal_contexts]

        return rl_data

    @staticmethod
    def _apply_tool_result(node, results) -> None:
        """Record the outcome of one tool call on ``node``.

        ``results`` may be None, the string 'null', a JSON string, a plain
        error string, or an already-decoded dict. A successful call is a dict
        with keys 'success', 'results' and 'params'; a failed one carries
        'error_message'. The node's ``valid``, ``end`` and ``stop_reason`` are
        always set; ``tool_json`` and ``tool_output`` are set when the payload
        reports success.
        """
        success = False
        error_message = "tool results is None"

        if results is not None and results != 'null':
            try:
                # Decode JSON text; a payload that is already a dict is used as-is
                # (json.loads would reject it with a TypeError).
                if isinstance(results, (str, bytes, bytearray)):
                    results = json.loads(results)

                success = results['success']
                if success:  # dict_keys(['success', 'results', 'params'])
                    node.tool_json = results['params']
                    node.tool_output = results['results']
                    # The interpreter reports failures inside a "successful" payload.
                    if "[Python Interpreter Error]" in results['results']:
                        success = False
                        error_message = results['results']
                else:
                    error_message = results['error_message']
            except json.JSONDecodeError:
                # vanilla error string
                error_message = results
            except Exception as e:
                import traceback
                filename, line, func, text = traceback.extract_tb(e.__traceback__)[-1]
                error_message = f"Multi-agent Error {str(e)} | Error in file '{filename}', line {line}, in {func}: {str(e)}\nCode: {text}"

        node.valid = success
        node.end = not success
        # f-string so that a non-string error payload cannot break the concatenation
        node.stop_reason = None if success else f"Invalid Tool Call: {error_message}"

    def _run_system1_readpage(self, contexts, system1_llm_fn, system1_params) -> None:
        """Run system1 in readpage mode and store its outputs on the contexts.

        ``prepare_system1_input`` returns either a dict with aligned
        'message_list' / 'response_list' entries (sent to system1) or a plain
        tool observation (stored without calling system1).
        """
        candidate_input_messages = [context.prepare_system1_input(self.function_map, self.tokenizer, system1_params.max_prompt_length, readpage=system1_params.readpage) for context in contexts]

        system1_readpage_dict, non_readpage_list = [], []
        for idx, item in enumerate(candidate_input_messages):
            if isinstance(item, dict):
                item['idx'] = idx
                # Tag every message so that identical prompts stay distinguishable.
                for idx_2, cur_mess in enumerate(item['message_list']):
                    for each_role in cur_mess:
                        each_role["mess_unique_id"] = f"{idx}_{idx_2}"
                system1_readpage_dict.append(item)
            else:
                non_readpage_list.append({
                    'idx': idx,
                    'message_list': ['non_readpage_tools'],
                    'tool_observation': item,
                    'response_list': [item]
                })

        if len(system1_readpage_dict) > 0:
            # organize batch: flatten all messages and remember their owner
            system1_input_messages = []
            sha256_idx = {}
            for idx, item in enumerate(system1_readpage_dict):
                system1_input_messages.extend(item['message_list'])
                for mess in item['message_list']:
                    sha256_idx[hash_sha256(json.dumps(mess))] = idx

            system1_outputs = system1_llm_fn(messages=system1_input_messages, generate_cfg=system1_params)

            # re-group the flat outputs back into each entry's response_list
            for mess, o in zip(system1_input_messages, system1_outputs, strict=True):
                # idx: position in system1_readpage_dict of the entry that produced this message
                idx = sha256_idx[hash_sha256(json.dumps(mess))]
                position = system1_readpage_dict[idx]['message_list'].index(mess)
                system1_readpage_dict[idx]['response_list'][position] = o

        all_system1_readpage_dict = sorted(system1_readpage_dict + non_readpage_list, key=lambda x: x['idx'])
        MultiAgent.postprocess_system1_outputs(all_system1_readpage_dict, system1_outputs=None, contexts=contexts, parse_output_func=System1.parse_output, system1_mode=self.system1_mode, readpage=system1_params.readpage)

    @staticmethod
    def _strip_tool_response_tag(text: str) -> str:
        """Remove one leading ``<TAG>`` and one trailing ``</TAG>`` from ``text``.

        Exact prefix/suffix removal is used on purpose: ``str.lstrip`` /
        ``str.rstrip`` treat their argument as a set of characters and would
        also eat content characters that happen to occur in the tag name.
        Returns ``text`` unchanged when TOOL_RESPONSE_TAG is None.
        """
        if TOOL_RESPONSE_TAG is None:
            return text
        return text.removeprefix(f'<{TOOL_RESPONSE_TAG}>').removesuffix(f'</{TOOL_RESPONSE_TAG}>')

    def _attach_tool_role_messages(self, contexts, multi_agent_pattern) -> None:
        """Build ``tool_role_message`` on the last node of every context.

        The system1 output (a single string, or a list of per-source strings
        in readpage mode) becomes the content of a "tool" message, or of a
        "user" message wrapped in <tool_response> tags, depending on
        ``multi_agent_pattern.tool_response_role``.
        """
        for context in contexts:
            node = context.chain[-1]
            if self.system1_mode == "empty" or not multi_agent_pattern.system1_sampling_params.readpage:
                if TOOL_RESPONSE_TAG is not None:
                    content = MultiAgent._strip_tool_response_tag(node.system1_output).strip()
                else:
                    content = node.system1_output
            else:
                if node.system1_message is None and isinstance(node.system1_output, str):
                    # non_readpage_tools: a single raw observation
                    merge_tool_content = [node.system1_output]
                else:
                    merge_tool_content = [
                        f"[Source {idx + 1} begin]\n"
                        + MultiAgent._strip_tool_response_tag(o)
                        + f"\n[Source {idx + 1} end]"
                        for idx, o in enumerate(node.system1_output)
                    ]
                content = "\n\n".join(merge_tool_content)

            if multi_agent_pattern.tool_response_role == "tool":
                node.tool_role_message = {"role": "tool", "content": content}
            else:
                node.tool_role_message = {"role": "user", "content": "<tool_response>\n" + content + "\n</tool_response>"}

    def save_file(self, data, file_dir, file_name: str):
        """Append ``data`` as JSON lines to ``<file_dir>/data/<file_name>.jsonl``.

        Args:
            data: Iterable of JSON-serializable items, written one per line.
            file_dir: Base directory; its "data" sub-directory is created if
              it does not exist.
            file_name: File name without the ".jsonl" extension.
        """
        output_dir = osp.join(file_dir, "data")
        os.makedirs(output_dir, exist_ok=True)
        target_path = osp.join(output_dir, f"{file_name}.jsonl")
        # Append mode: repeated calls accumulate records in the same file.
        with open(target_path, "a") as f:
            f.writelines(json.dumps(record) + "\n" for record in data)

