from typing import List, Dict
from qwen_agent.agents import MultiAgent, NousAgent # , NousAgentChat, 
from qwen_agent.llm.schema import Message, SYSTEM, USER, ASSISTANT
from qwen_agent.utils.utils import build_text_completion_prompt
from copy import deepcopy
import re
import os
import torch
import numpy as np
import copy

from torch import nn
import torch.nn.functional as F
# from vllm import SamplingParams
from omegaconf import ListConfig
from pebble import ThreadPool
from tqdm import tqdm
from verl import DataProto
from verl.utils.torch_functional import get_eos_mask, pad_sequence_to_length
from verl.third_party.vllm import vllm_version
from tensordict import TensorDict
from ray.util import pdb

from qwen_agent.settings import DEFAULT_WORKSPACE

# Tag name read from the TOOL_RESPONSE_TAG environment variable (default
# 'tool_response'); the case-insensitive value "none" maps to None.
_tool_response_tag_env = os.getenv('TOOL_RESPONSE_TAG', 'tool_response')
TOOL_RESPONSE_TAG = _tool_response_tag_env if _tool_response_tag_env.lower() != "none" else None
# Extra console logging is enabled by VERBOSE=true/1/yes/on (case-insensitive).
VERBOSE = os.getenv("VERBOSE", "0").lower() in {"true", "1", "yes", "on"}

class vLLMRolloutMultiAgent:
    """Multi agent that handles long-horizon vllm rollout generation by using qwen-agent as backend."""

    def __init__(self, multiturn_agent, function_list, tokenizer, inference_engine, **kwargs):
        self.multiturn_agent = multiturn_agent
        AgentClass = MultiAgent # if multiturn_agent else NousAgent
        self.tokenizer = tokenizer
        self.inference_engine = inference_engine
        self.kwargs = kwargs
        self.multi_agent_pattern = kwargs['multi_agent_pattern']
        self.rl_agent = AgentClass(llm={'model': 'qwen-max'}, function_list=function_list, tokenizer=tokenizer, multiturn_agent=multiturn_agent, system1_mode=self.multi_agent_pattern.system1_sampling_params.system1_mode)
        # for system1 llm
        if self.multi_agent_pattern.system1_sampling_params.system1_mode == "online":
            from openai import OpenAI
            self.system1_client = OpenAI(
                api_key=self.multi_agent_pattern.system1_sampling_params.system1_key,
                base_url=self.multi_agent_pattern.system1_sampling_params.system1_url,
                max_retries=10,
                timeout=None,
            )
        self.debug = False
        self.show_system1_params = True
        self.show_system2_params = True
        self.params_check(self.multi_agent_pattern)

    def params_check(self, multi_agent_pattern):
        if multi_agent_pattern.system1_sampling_params.system1_mode == "online":
            assert multi_agent_pattern.system1_sampling_params.system1_key is not None, \
                f"system1_sampling_params.system1_key must be provided when only_system2 is True"
            assert multi_agent_pattern.system1_sampling_params.system1_url is not None, \
                f"system1_sampling_params.system1_url must be provided when only_system2 is True"

        elif multi_agent_pattern.system1_sampling_params.system1_mode == "training":
            pass
        #     assert multi_agent_pattern.system1_sampling_params.max_prompt_length > multi_agent_pattern.system1_sampling_params.max_observation_length, \
        #         f"system1_sampling_params.max_prompt_length ({multi_agent_pattern.system1_sampling_params.max_prompt_length}) must be more than system1_sampling_params.max_observation_length ({multi_agent_pattern.system1_sampling_params.max_observation_length})"

        elif multi_agent_pattern.system1_sampling_params.system1_mode == "empty":
            print("system1 is not used")
        else:
            raise ValueError(f"system1_sampling_params.system1_mode must be one of ['online', 'training', 'empty']")
        
        if self.multiturn_agent:
            print("training with multi-turn dialogue")
        else:
            print("training with single-turn dialogue")

    @staticmethod
    def delete_cached_codes():
        print("[rl_agent cleanup] start removing cached files generated by code interpreter")
        target_dir = os.path.join(os.getcwd(), 'workspace', 'tools', 'code_interpreter')
        all_items = os.listdir(target_dir)
        files_to_delete = [item for item in all_items if os.path.isfile(os.path.join(target_dir, item))]
        delete_count = 0
        for file_name in files_to_delete:
            file_path = os.path.join(target_dir, file_name)
            try:
                os.remove(file_path)
                delete_count += 1
                print(f"deleted: {file_path}")
            except Exception as e:
                print(f"[rl_agent cleanup error] deleting {file_path} gives: {e}")
        print(f"[rl_agent cleanup] completed with {delete_count}/{len(files_to_delete)} deleted")
    
    @staticmethod
    def delete_cached_python_codes():
        """Best-effort removal of the regular files cached by the python interpreter tool.

        The cache directory is
        ``DEFAULT_WORKSPACE/tools/PythonInterpreter/PYTHON_CACHE_UUID``. Only
        regular files directly inside it are removed. A missing directory is
        treated as "nothing to clean" instead of raising. Per-file messages are
        printed only when ``VERBOSE`` is enabled; the summary is always printed.
        """
        print("[rl_agent cleanup] start removing cached files generated by python code interpreter")
        # NOTE(review): PYTHON_CACHE_UUID is not defined or imported in the visible
        # part of this file - confirm where it comes from before enabling this cleanup.
        target_dir = os.path.join(DEFAULT_WORKSPACE, 'tools', 'PythonInterpreter', PYTHON_CACHE_UUID)
        # The directory only exists once the tool has produced output.
        if not os.path.isdir(target_dir):
            print(f"[rl_agent cleanup] nothing to remove, {target_dir} does not exist")
            return

        candidate_paths = (os.path.join(target_dir, name) for name in os.listdir(target_dir))
        files_to_delete = [path for path in candidate_paths if os.path.isfile(path)]

        delete_count = 0
        for file_path in files_to_delete:
            try:
                os.remove(file_path)
            except OSError as e:
                # Keep going: one undeletable file must not abort the cleanup.
                if VERBOSE:
                    print(f"[rl_agent cleanup error] deleting {file_path} gives: {e}")
            else:
                delete_count += 1
                if VERBOSE:
                    print(f"deleted: {file_path}")
        print(f"[rl_agent cleanup] completed with {delete_count}/{len(files_to_delete)} deleted")


    def system2_generate_oneturn_fn(self, messages: List[Dict], generate_cfg=None):
        """Single-Turn Version for multi-agent rl.

        messages: batch_size x n (list of chatml sessions), end with user
        output: tuple of (output_token_ids, output_logprobs), each of batch_size x n
        ret_messages: batch_size x n, end with assistant
        """
        completion_prompts = [
            self.tokenizer.apply_chat_template(message[:2], tokenize=False, add_generation_prompt=True) + message[-1]['content']
            if message[-1]['role'] == 'assistant'
            else self.tokenizer.apply_chat_template(message[:2], tokenize=False, add_generation_prompt=True)
            for message in messages
        ]
        if generate_cfg['add_prefix'] is not None:
            completion_prompts = [prompt + generate_cfg['add_prefix'] for prompt in completion_prompts]

        # check the length of each message
        # if exceed the max_tokens, skip the generation
        valid_completion_prompts = []
        prompt_map = {}

        for idx, completion_prompt in enumerate(completion_prompts):
            if len(self.tokenizer.encode(completion_prompt, truncation=False)) <= generate_cfg.max_prompt_length + generate_cfg.max_response_length:
                prompt_map[len(valid_completion_prompts)] = idx
                valid_completion_prompts.append(completion_prompt)

        for k in generate_cfg.keys():
            if hasattr(self.sampling_params, str(k)):
                v = generate_cfg.get(k)
                if isinstance(v, (ListConfig, list)):
                    v = list(v)
                setattr(self.sampling_params, str(k), v)
        
        if len(valid_completion_prompts) > 0:
            output = self.inference_engine.generate(
                prompts=valid_completion_prompts,
                sampling_params=self.sampling_params,
                use_tqdm=False) # this is a tuple of (output_token_ids, output_logprobs)
            if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3'):
                try:
                    decoded_responses = self.tokenizer.batch_decode(output[0], skip_special_tokens=True) # self.tokenizer.batch_decode(output[0], skip_special_tokens=False)
                except Exception as e:
                    print(f'[rl_agent rollout] FATAL decode error: {e}')
                    print(f'[rl_agent rollout] FATAL output is: {output}')
                    token_ids = [[int(token_id) for token_id in sequence] for sequence in output[0]]
                    decoded_responses = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
            else:
                # output is List[vllm.outputs.RequestOutput] if vllm >= 0.7.0
                decoded_responses = [o.outputs[0].text for o in output] # output[0].outputs[0].text
        else:
            decoded_responses = []

        system2_outputs = [None] * len(messages)
        for new_idx, response in enumerate(decoded_responses):
            orig_idx = prompt_map[new_idx]
            system2_outputs[orig_idx] = generate_cfg['add_prefix'] + response if generate_cfg['add_prefix'] is not None else response

        return system2_outputs

    def generate_multiturn_fn(self, messages: List[Dict], generate_cfg=None):
        """Multi-Turn Version for multi-agent rl.
        system -> user -> assistant -> tool -> assistant ...

        messages: batch_size x n (list of chatml sessions), end with user
        output: tuple of (output_token_ids, output_logprobs), each of batch_size x n
        ret_messages: batch_size x n, end with assistant
        """
        sampling_params = self.s1_sampling_params if 'system1_url' in generate_cfg else self.sampling_params
        # max_prompt_length = generate_cfg.max_prompt_length if 'system1_url' in generate_cfg else generate_cfg.max_prompt_length + generate_cfg.max_response_length 
        completion_prompts = [
            self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True, enable_thinking=generate_cfg.enable_thinking)
            for message in messages
        ]
        if generate_cfg['add_prefix'] is not None:
            completion_prompts = [prompt + generate_cfg['add_prefix'] for prompt in completion_prompts]
            
        # check the length of each message
        # if exceed the max_tokens, skip the generation
        valid_completion_prompts = []
        prompt_map = {}
        
        max_context_length = generate_cfg.max_prompt_length if 'system1_url' in generate_cfg else generate_cfg.max_prompt_length + generate_cfg.max_response_length
        cur_max_tokens = []
        for idx, completion_prompt in enumerate(completion_prompts):
            cur_prompt_length = len(self.tokenizer.encode(completion_prompt, truncation=False))
            
            if 'system1_url' not in generate_cfg:
                # only system2
                cur_max_tokens.append(max(1, min(generate_cfg.max_response_length, max_context_length - cur_prompt_length)))

            if cur_prompt_length < max_context_length:
                prompt_map[len(valid_completion_prompts)] = idx
                valid_completion_prompts.append(completion_prompt)
            else:
                if 'system1_url' in generate_cfg:
                    assert False, f'[rl_agent rollout] system1 prompt is too long: {cur_prompt_length} > {generate_cfg.max_prompt_length}\nMessage:\n{messages[idx]}'

        for k in generate_cfg.keys():
            if hasattr(sampling_params, str(k)):
                v = generate_cfg.get(k)
                if isinstance(v, (ListConfig, list)):
                    v = list(v)
                setattr(sampling_params, str(k), v)

        if 'system1_url' not in generate_cfg:
            sampling_params.max_tokens = max(cur_max_tokens)

        # print sampling_params:
        if self.show_system1_params and 'system1_url' in generate_cfg:
            print(f"System1_params: {sampling_params}", flush=True)
            self.show_system1_params = False

        if self.show_system2_params and 'system1_url' not in generate_cfg:
            print(f"System2_params: {sampling_params}", flush=True)
            self.show_system2_params = False

        if len(valid_completion_prompts) > 0:
            output = self.inference_engine.generate(
                prompts=valid_completion_prompts,
                sampling_params=sampling_params,
                use_tqdm=False) # this is a tuple of (output_token_ids, output_logprobs)
            if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3'):
                try:
                    decoded_responses = self.tokenizer.batch_decode(output[0], skip_special_tokens=True) # self.tokenizer.batch_decode(output[0], skip_special_tokens=False)
                except Exception as e:
                    print(f'[rl_agent rollout] FATAL decode error: {e}')
                    print(f'[rl_agent rollout] FATAL output is: {output}')
                    token_ids = [[int(token_id) for token_id in sequence] for sequence in output[0]]
                    decoded_responses = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
            else:
                # output is List[vllm.outputs.RequestOutput] if vllm >= 0.7.0
                decoded_responses = [o.outputs[0].text for o in output] # output[0].outputs[0].text
        else:
            decoded_responses = []
        system2_outputs = [None] * len(messages)
        for new_idx, response in enumerate(decoded_responses):
            orig_idx = prompt_map[new_idx]
            system2_outputs[orig_idx] = generate_cfg['add_prefix'] + response if generate_cfg['add_prefix'] is not None else response
        # if generate_cfg['add_prefix'] == "<tool_response>":
        # if generate_cfg['add_prefix'] is None:
            # pdb.set_trace() # len(self.tokenizer.encode(completion_prompts[0], truncation=False))
        # if 'system1_url' in generate_cfg:
        #     pdb.set_trace()
        return system2_outputs # output[0].outputs[0].stop_reason


    def generate_openai(self, query_list):
        messages = query_list["messages"]
        sampling_params = query_list["sampling_params"]
        outputs = self.system1_client.chat.completions.create(
            model=sampling_params.model_name,
            messages=messages,
            n=sampling_params['n'],
            temperature=sampling_params['temperature'],
            top_p=sampling_params['top_p'],
            max_tokens=sampling_params['max_tokens'],
            # stop=sampling_params.get('stop', None), # not use stop in openai
        )
        return outputs
    def system1_generate_oneturn_fn(self, messages: List[Dict], generate_cfg=None):
        """Multi-Turn Version, Inner loop of qwen_agent fncall:
        duplicate messages beforehand, force n_samples to be 1
        so that we can use the same function for iterative qwen_agent calls

        messages: batch_size x n (list of chatml sessions), end with user
        output: tuple of (output_token_ids, output_logprobs), each of batch_size x n
        ret_messages: batch_size x n, end with assistant
        """
        if generate_cfg.system1_mode == 'openai':
            query_list = [{"messages": message, "sampling_params": generate_cfg} for message in messages]
            max_workers = min(len(query_list), generate_cfg.system1_concurrency)
            with ThreadPool(max_workers=max_workers) as pool:
                future = pool.map(self.generate_openai, query_list)
                # outputs = list(future.result())
                outputs = list(tqdm(future.result(), total=len(query_list), desc="Online System1"))
            
            system2_outputs = [output.choices[0].message.content for output in outputs]
            return system2_outputs

        raise NotImplementedError(f"Only support `only_system2: True` and `multiturn_agent: False`, but now `only_system2: {self.multi_agent_pattern.only_system2}` and `multiturn_agent: {self.multiturn_agent}`")

    def get_response(self, prompts_DataProto: DataProto, max_length, sampling_params):
        self.sampling_params = sampling_params
        if self.multi_agent_pattern.system1_sampling_params.system1_mode == 'training':
            self.s1_sampling_params = copy.deepcopy(sampling_params)

        items = prompts_DataProto.non_tensor_batch['rollout_item'] # items: List[Dict], 
        tasks = prompts_DataProto.non_tensor_batch['task']
        for item, task in zip(items, tasks, strict=True): # for dummy
            item['task'] = task

        system2_llm_fn = self.generate_multiturn_fn if self.multiturn_agent else self.system2_generate_oneturn_fn
        system1_llm_fn = self.generate_multiturn_fn if self.multiturn_agent else self.system1_generate_oneturn_fn

        print(f'[rl_agent rollout] start with {len(items)} messages')
        rl_data = self.rl_agent.run_batch(
            items, system2_llm_fn=system2_llm_fn, system1_llm_fn=system1_llm_fn, llm_batch_enabled=True,
            lang_batch=['en'] * len(items),
            multi_agent_pattern=self.multi_agent_pattern,
        ) # return List[dict]
        print(f'[rl_agent rollout] completed with {len(items)} messages')
        
        # # 后处理
        # if 'code_interpreter' in self.rl_agent.function_map:
        #     vLLMRolloutMultiAgent.delete_cached_codes()

        # if 'PythonCodeInterpreter' in self.rl_agent.function_map:
        #     vLLMRolloutMultiAgent.delete_cached_python_codes()
        
        rl_DataProto = self.convert_DataProto(rl_data, prompts_DataProto)
        return rl_DataProto

    def convert_DataProto(self, rl_data, prompts_DataProto) -> DataProto:
        tasks = prompts_DataProto.non_tensor_batch['task']
        if 'uid' in prompts_DataProto.non_tensor_batch.keys():
            uids = prompts_DataProto.non_tensor_batch['uid']
            task2uid = {}
            for u, t in zip(uids, tasks):
                if t not in task2uid:
                    task2uid[t] = u
                else:
                    assert task2uid[t] == u, f'[rl_agent rollout] uid {u} is not consistent with task {t}\n now task2uid is {task2uid}'
        else:
            task2uid = None
        # [item['task'] for item in rl_data]  [item['task'] for item in prompts_DataProto.non_tensor_batch['rollout_item']]
        jobs = prompts_DataProto.non_tensor_batch['job']
        # pdb.set_trace()
        task2job = {}
        for t, j in zip(tasks, jobs):
            task2job[t] = j

        cur_device = prompts_DataProto.batch['input_ids'].device
        train_system1 = self.multi_agent_pattern.system1_sampling_params.system1_mode == "training"

        if train_system1:
            max_prompt_length = max(self.multi_agent_pattern.system1_sampling_params.max_prompt_length, self.multi_agent_pattern.system2_sampling_params.max_prompt_length)
            max_response_length = max(self.multi_agent_pattern.system1_sampling_params.max_response_length, self.multi_agent_pattern.system2_sampling_params.max_response_length)
        else:
            max_prompt_length = self.multi_agent_pattern.system2_sampling_params.max_prompt_length
            max_response_length = self.multi_agent_pattern.system2_sampling_params.max_response_length

        system2_DataProto = self.convert_system2_data(rl_data, task2uid, task2job, cur_device, max_prompt_length, max_response_length)
        if train_system1:
            system1_DataProto = self.convert_system1_data(rl_data, task2uid, task2job, cur_device, max_prompt_length, max_response_length)
        else:
            system1_DataProto = None
        # TODO: different merge strategy
        if system1_DataProto is not None:
            system2_DataProto = DataProto.concat([system2_DataProto, system1_DataProto])

        return system2_DataProto

    def process_tokenize_and_pad(
        self, 
        texts: List[str], 
        truncate_length: int,  # 截断长度
        pad_target_length: int,  # 最终需要pad到的长度
        padding_side: str = 'left'  # padding方向
    ) -> Dict[str, torch.Tensor]:
        """
        先按truncate_length进行tokenize和truncate，然后pad到pad_target_length
        
        Args:
            texts: 输入文本列表
            truncate_length: tokenize时的截断长度
            pad_target_length: 最终需要pad到的长度
            padding_side: padding方向，'left'或'right'
            
        Returns:
            包含input_ids和attention_mask的字典
        """
        # 先用原始长度进行tokenize
        padded = self.tokenizer(
            texts,
            padding='max_length',
            truncation=True,
            max_length=truncate_length,
            padding_side=padding_side,
            return_tensors="pt"
        )
        
        # 如果需要，补充padding到目标长度
        if pad_target_length > truncate_length:
            padded['input_ids'] = torch.nn.functional.pad(
                padded['input_ids'],
                (pad_target_length - truncate_length, 0) if padding_side == 'left' else (0, pad_target_length - truncate_length),
                value=self.tokenizer.pad_token_id
            )
            padded['attention_mask'] = torch.nn.functional.pad(
                padded['attention_mask'],
                (pad_target_length - truncate_length, 0) if padding_side == 'left' else (0, pad_target_length - truncate_length),
                value=0
            )
        
        return padded
    
    def convert_system1_data(self, rl_data, task2uid, task2job, cur_device, max_prompt_length, max_response_length) -> DataProto:
        prompts_str, responses_str = [], []
        system2_response = []
        for item in rl_data:  # len(rl_data[0]['system1'])
            if item['system1'] is None:
                continue
            for s_i in item['system1']:
                prompts_str.append(s_i['system1_prompt'])
                responses_str.append(s_i['system1_response'])
            system2_response.extend([item['system2_response']] * len(item['system1']))
        if len(prompts_str) == 0:
            return None
        
        padded_prompts = self.process_tokenize_and_pad(
            prompts_str,
            truncate_length=self.multi_agent_pattern.system1_sampling_params.max_prompt_length,
            pad_target_length=max_prompt_length,
            padding_side='left'
        )
        
        padded_responses = self.process_tokenize_and_pad(
            responses_str,
            truncate_length=self.multi_agent_pattern.system1_sampling_params.max_response_length,
            pad_target_length=max_response_length,
            padding_side='right'
        )

        system2_responses = self.process_tokenize_and_pad(
            system2_response,
            truncate_length=self.multi_agent_pattern.system2_sampling_params.max_response_length,
            pad_target_length=max_response_length,
            padding_side='right'
        )['input_ids']

        prompts = padded_prompts['input_ids']
        responses = padded_responses['input_ids']
        input_ids = torch.cat([prompts, responses], dim=1)
        assert system2_responses.size(0) == input_ids.size(0), f"convert_system1_data ERROR: {system2_responses.size(0)} != {input_ids.size(0)}"
        
        # position ids of prompts
        prompt_attention_mask = padded_prompts['attention_mask']
        prompt_position_ids = prompt_attention_mask.long().cumsum(-1) - 1
        # 将padding位置的position_id设为0
        prompt_position_ids.masked_fill_(prompt_attention_mask == 0, 0)

        # position ids of response
        batch_size = responses.size(0)
        response_length = responses.size(1)
        delta_position_id = torch.arange(1, response_length + 1)
        delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)

        # TODO(sgm): fix position_ids on right_pad
        # prompt: left pad + response: right pad
        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
        response_position_ids = prompt_position_ids[:, -1:] + delta_position_id
        position_ids = torch.cat([prompt_position_ids, response_position_ids], dim=-1)

        # 为 response 生成 attention mask
        if self.multiturn_agent:
            # 生成pad_mask（非pad token为1，pad token为0）
            pad_mask = responses.ne(self.tokenizer.pad_token_id).to(padded_responses['attention_mask'].dtype)

            # 处理全padding行（若无有效token则将首个位置设为1）
            all_pad_rows = (pad_mask.sum(dim=1) == 0)
            pad_mask[all_pad_rows, 0] = 1
            response_attention_mask = pad_mask
        else:
            response_attention_mask = get_eos_mask(response_id=responses, eos_token=self.tokenizer.eos_token_id, dtype=padded_responses['attention_mask'].dtype)
        attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1) # TODO(zlw) why??
        
        # # 为 function call 生成 loss mask
        # print('[rl_agent mask] start')
        # response_attention_mask = self.build_response_mask(responses, response_attention_mask, responses_str)
        # print('[rl_agent mask] completed')

        batch_tensors = TensorDict(
            {
                'prompts': prompts,
                'responses': responses,
                'input_ids': input_ids,  # here input_ids become the whole sentences
                'response_for_reward': system2_responses,
                # 'old_log_probs': log_probs, # we will recompute old log prob with actor
                'attention_mask': attention_mask,
                'response_mask': response_attention_mask,
                'position_ids': position_ids
            },
            batch_size=batch_size).to(cur_device)

        # organize corresponding system2 response for reward calculating

        non_tensor_batch = { # len(non_tensor_batch['uid'])
            'data_source': [],
            'task': [],
            'uid': [],
            'job': [],
            'messages': [],  # all sampled information for saving
            # 'decoded_responses': responses_str,
            'answer_tag_part': [],
            "boxed_answer": [],
            "answer": [],
            'stop_reason': [],
            'output_format': [], # True or False
            'tool_usage': [],
        }

        for item in rl_data:
            cur_task = item['task'] # 'multi_agent/hle_425/hle'
            if item['system1'] is None:
                continue

            for s_i in item['system1']:
                if task2uid is not None:
                    non_tensor_batch['uid'].append(task2uid[cur_task] + "_system1")
                else:
                    non_tensor_batch['uid'].append(None)
                non_tensor_batch['job'].append(task2job[cur_task])
                non_tensor_batch['messages'].append(item)
                non_tensor_batch['output_format'].append(s_i['system1_output_format'])
                # 数量上保持一致
                for k in non_tensor_batch:
                    if k in item:
                        if k in ["task", "data_source"]:
                            non_tensor_batch[k].append(item[k] + "_system1")
                        else:
                            non_tensor_batch[k].append(item[k])
        non_tensor_batch = {k: np.array(v, dtype=object) for k, v in non_tensor_batch.items()}
        assert len(batch_tensors) == len(non_tensor_batch['data_source']), f"system1 DataProto Error: batch ({len(batch_tensors)}) should be equal to non_tensor_batch ({len(non_tensor_batch['data_source'])})"

        return DataProto(
            batch=batch_tensors,
            non_tensor_batch=non_tensor_batch,
            # meta_info=self.meta_info,
        )

    def convert_system2_data(self, rl_data, task2uid, task2job, cur_device, max_prompt_length, max_response_length) -> DataProto:
        prompts_str = [item['system2_prompt'] for item in rl_data]
        responses_str = [item['system2_response'] for item in rl_data]

        padded_prompts = self.process_tokenize_and_pad(
            prompts_str,
            truncate_length=self.multi_agent_pattern.system2_sampling_params.max_prompt_length,
            pad_target_length=max_prompt_length,
            padding_side='left'
        )
        
        padded_responses = self.process_tokenize_and_pad(
            responses_str,
            truncate_length=self.multi_agent_pattern.system2_sampling_params.max_response_length,
            pad_target_length=max_response_length,
            padding_side='right'
        )

        prompts = padded_prompts['input_ids']
        responses = padded_responses['input_ids']
        input_ids = torch.cat([prompts, responses], dim=1)
        
        # position ids of prompts
        prompt_attention_mask = padded_prompts['attention_mask']
        prompt_position_ids = prompt_attention_mask.long().cumsum(-1) - 1
        # 将padding位置的position_id设为0
        prompt_position_ids.masked_fill_(prompt_attention_mask == 0, 0)

        # position ids of response
        batch_size = responses.size(0)
        response_length = responses.size(1)
        delta_position_id = torch.arange(1, response_length + 1)
        delta_position_id = delta_position_id.unsqueeze(0).repeat(batch_size, 1)

        # TODO(sgm): fix position_ids on right_pad
        # prompt: left pad + response: right pad
        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
        response_position_ids = prompt_position_ids[:, -1:] + delta_position_id
        position_ids = torch.cat([prompt_position_ids, response_position_ids], dim=-1)

        # 为 response 生成 attention mask
        if self.multiturn_agent:
            # 生成pad_mask（非pad token为1，pad token为0）
            pad_mask = responses.ne(self.tokenizer.pad_token_id).to(padded_responses['attention_mask'].dtype)

            # 处理全padding行（若无有效token则将首个位置设为1）
            all_pad_rows = (pad_mask.sum(dim=1) == 0)
            pad_mask[all_pad_rows, 0] = 1
            response_attention_mask = pad_mask
        else:
            response_attention_mask = get_eos_mask(response_id=responses, eos_token=self.tokenizer.eos_token_id, dtype=padded_responses['attention_mask'].dtype)
        attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1) # TODO(zlw) why??
        
        # 为 function call 生成 loss mask
        if VERBOSE:
            print('[rl_agent mask] start')
        response_attention_mask = self.build_response_mask(responses, response_attention_mask, responses_str)
        if VERBOSE:
            print('[rl_agent mask] completed')

        batch_tensors = TensorDict(
            {
                'prompts': prompts,
                'responses': responses,
                'input_ids': input_ids,  # here input_ids become the whole sentences
                'response_for_reward': responses,
                # 'old_log_probs': log_probs, # we will recompute old log prob with actor
                'attention_mask': attention_mask,
                'response_mask': response_attention_mask,
                'position_ids': position_ids
            },
            batch_size=batch_size).to(cur_device)

        non_tensor_batch = {
            'data_source': [],
            'task': [],
            'uid': [],
            'job': [],
            'messages': [],  # all sampled information for saving
            # 'decoded_responses': responses_str,
            'answer_tag_part': [],
            "boxed_answer": [],
            "answer": [],
            'stop_reason': [],
            'output_format': [], # True or False
            'tool_usage': [],
        }
        for item in rl_data:
            if task2uid is not None:
                non_tensor_batch['uid'].append(task2uid[item['task']] + "_system2")
            else:
                non_tensor_batch['uid'].append(None)
            non_tensor_batch['job'].append(task2job[item['task']])
            non_tensor_batch['messages'].append(item)
            non_tensor_batch['output_format'].append(None)

            for k in non_tensor_batch:
                if k in item:
                    non_tensor_batch[k].append(item[k])
        non_tensor_batch = {k: np.array(v, dtype=object) for k, v in non_tensor_batch.items()}
        return DataProto(
            batch=batch_tensors,
            non_tensor_batch=non_tensor_batch,
            # meta_info=self.meta_info,
        )

    def build_response_mask(self, response_ids, response_attention_mask, response_strings):
        if self.multiturn_agent:
            return self.build_response_mask_multiturn(response_ids, response_attention_mask, response_strings)
        else:
            return self.build_response_mask_oneturn(response_ids, response_attention_mask, response_strings)

    # def build_response_mask_multiturn(self, response_ids, response_attention_mask):
    #     # A: [assistant_content, <|im_end|>, \n]
    #     # B: [<|im_start|>, user_role, \n, <tool_response>, tool_response, </tool_response>, <|im_end|>, \n, <|im_start|>, assistant_role, \n]
    #     # response: [A] + [0 or more of BA]
    #     # we want to mask out all of B by finding each [<|im_start|>, user_role, \n, ..., <|im_start|>, assistant_role, \n]

    #     bs = response_ids.size(0)
    #     im_start_id, user_role_id, newline_id = self.tokenizer.encode("<|im_start|>user\n")
    #     assistant_role_id = self.tokenizer.convert_tokens_to_ids("assistant")

    #     for batch_idx in range(bs):
    #         tokens = response_ids[batch_idx].tolist()
    #         tool_responses_count = 0
    #         i = 0
    #         while i < len(tokens) - 2:
    #             # Reach <|endoftext|>
    #             if response_attention_mask[batch_idx, i] == 0:
    #                 break
    #             # Check if current position is the start of a B section
    #             if (tokens[i] == im_start_id and 
    #                 tokens[i+1] == user_role_id and 
    #                 tokens[i+2] == newline_id):
    #                 # Search for the end of the B section
    #                 j = i + 3
    #                 end_found = False
    #                 while j < len(tokens) - 2:
    #                     if (tokens[j] == im_start_id and 
    #                         tokens[j+1] == assistant_role_id and 
    #                         tokens[j+2] == newline_id):
    #                         # Found the end pattern; mask from i to j+2 inclusive
    #                         response_attention_mask[batch_idx, i:j+3] = 0
    #                         # Move index to the end of the current B section to avoid overlapping
    #                         i = j + 3
    #                         end_found = True
    #                         break
    #                     j += 1
    #                 if not end_found:
    #                     i += 1
    #                 else:
    #                     tool_responses_count += 1
    #             else:
    #                 i += 1
    #         print(f'[rl_agent mask] batch_idx = {batch_idx}: {tool_responses_count} tool_responses masked')
    #     return response_attention_mask

    def build_response_mask_oneturn(self, response_ids, response_attention_mask, response_strings):
        # we want to mask out all of [<tool_response>, ..., </tool_response>]

        encodings = self.tokenizer(
            response_strings,
            padding=False,
            return_offsets_mapping=True
        )
        assert TOOL_RESPONSE_TAG is None, f"TOOL_RESPONSE_TAG can not be None if use build_response_mask_oneturn"
        pattern = re.compile(r'<{tag}>.*?</{tag}>'.format(tag=TOOL_RESPONSE_TAG), re.DOTALL)
        
        for idx, (response, offset_mapping) in enumerate(zip(response_strings, encodings.offset_mapping)):
            # 使用正则表达式一次性找到所有matches
            matches = list(pattern.finditer(response))
            
            if matches:
                offset_mapping = np.array(offset_mapping)
                # 批量处理所有matches
                for match in matches:
                    start_char, end_char = match.span()
                    # 向量化操作找token位置
                    start_token = np.where(offset_mapping[:, 0] >= start_char)[0][0]
                    end_token = np.where(offset_mapping[:, 1] >= end_char)[0][0]
                    
                    response_attention_mask[idx, start_token:end_token+1] = 0
                    
                print(f'[rl_agent mask] idx = {idx}: {len(matches)} observations masked')
                
                if self.debug:
                    for start_char, end_char in [m.span() for m in matches]:
                        print(f"Original text: {response}")
                        print(f"Masked text: {response[start_char:end_char]}")
        
        return response_attention_mask

    def build_response_mask_multiturn(self, response_ids, response_attention_mask, response_strings):
        encodings = self.tokenizer(
            response_strings,
            padding=False,
            return_offsets_mapping=True
        )
        # 1. 匹配从user开始到最近的im_end
        user_pattern = re.compile(r'<\|im_start\|>user\n.*?<\|im_end\|>\n?', re.DOTALL)
        # 2. 匹配assistant开始标记
        assistant_pattern = re.compile(r'<\|im_start\|>assistant\n')
        for idx, (response, offset_mapping) in enumerate(zip(response_strings, encodings.offset_mapping)):
            offset_mapping = np.array(offset_mapping)
            total_matches = 0
            
            # 第一步：mask掉所有user部分
            user_matches = list(user_pattern.finditer(response))
            total_matches += len(user_matches)
            
            for match in user_matches:
                start_char, end_char = match.span()
                start_tokens = np.where(offset_mapping[:, 0] >= start_char)[0]
                end_tokens = np.where(offset_mapping[:, 1] >= end_char)[0]
                
                if len(start_tokens) > 0 and len(end_tokens) > 0:
                    start_token = start_tokens[0]
                    end_token = end_tokens[0]
                    response_attention_mask[idx, start_token:end_token+1] = 0
                    # mask_str = self.tokenizer.decode(response_ids[idx, start_token:end_token+1])

            # 第二步：mask掉所有assistant开始标记
            assistant_matches = list(assistant_pattern.finditer(response))
            total_matches += len(assistant_matches)

            for match in assistant_matches:
                start_char, end_char = match.span()
                start_tokens = np.where(offset_mapping[:, 0] >= start_char)[0]
                end_tokens = np.where(offset_mapping[:, 1] >= end_char)[0]
                
                if len(start_tokens) > 0 and len(end_tokens) > 0:
                    start_token = start_tokens[0]
                    end_token = end_tokens[0]
                    response_attention_mask[idx, start_token:end_token+1] = 0
                    # mask_str = self.tokenizer.decode(response_ids[idx, start_token:end_token+1])

            if VERBOSE and total_matches > 0:
                print(f'[rl_agent mask] idx = {idx}: {len(user_matches)} user sections and {len(assistant_matches)} assistant markers masked')
                
            if self.debug:
                print(f"Original response: {response}")
                print("Masked sections:")
                for m in user_matches:
                    print(f"User section: {response[m.start():m.end()]}")
                for m in assistant_matches:
                    print(f"Assistant marker: {response[m.start():m.end()]}")
                    
        return response_attention_mask
    