import time
from abc import ABC
from copy import deepcopy
from dataclasses import dataclass, field
from datetime import timedelta
from typing import List, Optional, Tuple, Union, Dict, Any

import ray
import torch

from openrlhf.datasets.utils import zero_pad_sequences
from openrlhf.models.utils import compute_approx_kl, compute_reward, masked_mean, process_sequences
from openrlhf.trainer.ray.launcher import PPORayActorGroup
from openrlhf.utils.logging_utils import init_logger
import random



from openrlhf.trainer.ppo_utils.appworld_prompt import prompt_template,prompt_react,use_react_prompt  # 假设这个导入可用
import requests
import re
import os
import json
from jinja2 import Template

logger = init_logger(__name__)

think_mode=True  # Whether "think" mode is enabled. NOTE(review): not referenced in the visible part of this file — confirm where it is read.



def to(tensor: Union[torch.Tensor, list[torch.Tensor]], device):
    """Recursively move a tensor, or a (nested) list of tensors, to ``device``.

    Non-tensor leaves (``None``, numbers, strings, ...) are returned unchanged.
    """
    if isinstance(tensor, torch.Tensor):
        return tensor.to(device)
    if isinstance(tensor, list):
        return [to(item, device) for item in tensor]
    return tensor


def pin_memory(tensor: Union[torch.Tensor, list[torch.Tensor]]):
    """Recursively pin a tensor, or a (nested) list of tensors, in page-locked memory.

    Non-tensor leaves are passed through untouched.
    """
    if isinstance(tensor, torch.Tensor):
        return tensor.pin_memory()
    if isinstance(tensor, list):
        return list(map(pin_memory, tensor))
    return tensor

def find_best_batch_size(total_samples: int, original_batch_size: int, num_gpus: int = 8) -> int:
    """Pick the batch size closest to ``original_batch_size`` whose batch count is a multiple of ``num_gpus``.

    The number of batches for a candidate size ``c`` is ``ceil(total_samples / c)``.
    Candidates are tried in order of distance from ``original_batch_size``
    (at equal distance the larger one first). If no size works (e.g.
    ``total_samples < num_gpus``), ``total_samples`` is returned so all samples
    form a single batch. With ``total_samples == 0`` the original batch size is
    returned (zero batches is trivially divisible).

    Args:
        total_samples: number of samples to split.
        original_batch_size: preferred batch size.
        num_gpus: the batch count must be divisible by this.

    Returns:
        The chosen batch size.

    Raises:
        ValueError: if ``num_gpus`` is not positive.
    """
    if num_gpus <= 0:
        raise ValueError(f"num_gpus must be positive, got {num_gpus}")

    # Every candidate >= total_samples yields a single batch, so searching up to
    # this distance reaches every distinct outcome — including all sizes in
    # [1, total_samples] when original_batch_size is larger than total_samples.
    max_delta = max(total_samples, original_batch_size)
    for delta in range(max_delta + 1):
        # delta == 0 has a single candidate; don't test it twice.
        if delta == 0:
            candidates = (original_batch_size,)
        else:
            candidates = (original_batch_size + delta, original_batch_size - delta)
        for candidate in candidates:
            if candidate <= 0:
                continue
            actual_batches = (total_samples + candidate - 1) // candidate
            if actual_batches % num_gpus == 0:
                return candidate
    return total_samples  # Fallback

# [New] Strips <think>...</think> reasoning blocks from model output.
def remove_think_blocks(text):
    """Return ``text`` with every ``<think>...</think>`` span removed.

    Matching is non-greedy and spans newlines, so each opening tag pairs with
    the nearest closing tag. The tags themselves are removed, but any newline
    after ``</think>`` is kept. An unterminated ``<think>`` is left untouched, and
    text without tags is returned unchanged.
    """
    think_block = re.compile(r'<think>.*?</think>', re.DOTALL)
    return think_block.sub('', text)

def extract_code_and_fix_content(text: str, ignore_multiple_calls: bool = True) -> tuple[str, str]:
    """Extract Python code from a model response and normalise the response text.

    Args:
        text: raw model output that may contain fenced ```python blocks.
        ignore_multiple_calls: when True (default), only the first complete
            ```python ... ``` block is used, and the returned text is truncated
            right after its closing fence. When False, all complete blocks are
            concatenated (each followed by a newline) and the text is kept whole.

    A trailing ```python block without a closing fence (typically because
    generation stopped on "```") is also extracted. In that case a closing
    fence is appended to the returned text.

    Returns:
        ``(code, text)``; ``code`` is ``""`` when no Python block was found.
    """
    full_code_regex = r"```python\n(.*?)```"
    partial_code_regex = r".*```python\n(.*)"

    output_code = ""
    match_end = 0
    for re_match in re.finditer(full_code_regex, text, flags=re.DOTALL):
        code = re_match.group(1).strip()
        if ignore_multiple_calls:
            # Only the first call counts: drop everything after its closing fence.
            return code, text[: re_match.end()]
        output_code += code + "\n"
        match_end = re_match.end()

    # Check for an unterminated block after the last complete one.
    partial_match = re.match(partial_code_regex, text[match_end:], flags=re.DOTALL)
    if partial_match:
        output_code += partial_match.group(1).strip()
        # Generation stopped on the stop sequence; restore the closing fence.
        if not text.endswith("\n"):
            text += "\n"
        text += "```"
    return output_code, text


# [New] Parses a role-tagged prompt string into chat messages.
def prompt_messages(prompt):
    """Split a role-tagged transcript into a list of chat messages.

    The transcript uses marker lines such as ``USER:``, ``ASSISTANT:`` or
    ``SYSTEM:`` followed by a newline (optional whitespace is allowed in between).
    The text up to the next marker becomes that role's message, with the role
    lower-cased and the content stripped.

    - A non-string or empty ``prompt`` gives ``[]``.
    - A prompt without any marker becomes a single ``user`` message.
    - Non-empty text before the first marker becomes a ``system`` message.
    - Segments with empty content are skipped. If nothing remains, a single
      placeholder ``user`` message is returned.
    """
    if not isinstance(prompt, str) or not prompt:
        return []

    role_marker = re.compile(r"(USER|ASSISTANT|SYSTEM):\s*\n")
    markers = list(role_marker.finditer(prompt))
    if not markers:
        return [{"role": "user", "content": prompt.strip()}]

    result = []

    leading = prompt[:markers[0].start()].strip()
    if leading:
        result.append({"role": "system", "content": leading})

    # Each segment runs from the end of its marker to the start of the next one
    # (or to the end of the prompt for the last marker).
    segment_ends = [m.start() for m in markers[1:]] + [len(prompt)]
    for marker, stop in zip(markers, segment_ends):
        body = prompt[marker.end():stop].strip()
        if body:
            result.append({"role": marker.group(1).lower(), "content": body})

    return result or [{"role": "user", "content": "No valid content found."}]


# [New] Helper that persists data to a JSON file.
def save_w(folder_path: str, json_name: str, data):
    """Write ``data`` as UTF-8 JSON (4-space indent, non-ASCII kept) to ``folder_path/json_name``.

    ``folder_path`` is created if needed. Any error is logged rather than
    raised, so this is a best-effort write.
    """
    try:
        os.makedirs(folder_path, exist_ok=True)
        target_path = os.path.join(folder_path, json_name)
        with open(target_path, 'w', encoding='utf-8') as handle:
            json.dump(data, handle, ensure_ascii=False, indent=4)
    except Exception as e:
        logger.error(f"保存 JSON 文件时出错: {str(e)}")


# [New] Client for interacting with the AppWorld service.
class AppWorldAPI:
    """HTTP client for an AppWorld environment server.

    Every call is a JSON ``POST`` to ``{base_url}/<endpoint>`` over one shared
    ``requests.Session``. A 200 response returns the decoded JSON body; any other
    status raises ``Exception`` whose message includes the server's response text.
    """

    def __init__(self, base_url="http://localhost:8000", timeout=100):
        """
        Args:
            base_url: root URL of the AppWorld server.
            timeout: per-request timeout passed to ``requests`` (seconds).
        """
        self.base_url = base_url
        self.session = requests.Session()  # reuse one session so connections are kept alive
        self.timeout = timeout

    def _post(self, endpoint: str, payload: dict, failure_message: str):
        """POST ``payload`` to ``endpoint``; return the JSON body, or raise ``Exception`` on a non-200 status."""
        response = self.session.post(
            f"{self.base_url}/{endpoint}",
            json=payload,
            timeout=self.timeout,
        )
        if response.status_code == 200:
            return response.json()
        raise Exception(f"{failure_message}: {response.text}")

    def initialize_task(self, task_id: str):
        """Initialize the task ``task_id`` on the server and return the server's JSON response."""
        return self._post(
            "initialize",
            {"task_id": task_id},
            "!!!!!!!!!!!!!!!!!!!!!!!!!!!!Failed to initialize task",
        )

    def execute(self, task_id: str, message: str):
        """Send ``message`` (as the ``code`` field) to task ``task_id`` and return the JSON response."""
        return self._post(
            "execute",
            {"task_id": task_id, "code": message},
            "!!!!!!!!!!!!!!!!!!!!!! Failed to execute",
        )

    def check_task_completed(self, task_id: str):
        """Query the ``task_completed`` endpoint for ``task_id`` and return the JSON response."""
        return self._post(
            "task_completed",
            {"task_id": task_id},
            "!!!!!!!!!!!!!!!!!!!!!!!!!!!Failed to check task status",
        )

    def check_task_success(self, task_id: str):
        """Evaluate task ``task_id`` (errors suppressed, no report).

        Returns:
            ``(success, output)``: ``output`` is the ``"output"`` dict of the server
            response (``{}`` if absent), and ``success`` is its ``"success"`` value
            (``False`` if absent).
        """
        data = self._post(
            "evaluate",
            {"task_id": task_id, "suppress_errors": True, "report": False},
            "!!!!!!!!!!!!!!!!!!!!!!!!!!!Failed to check task success",
        )
        output = data.get("output", {})
        return output.get("success", False), output

    def close_task(self, task_id: str):
        """Close task ``task_id`` on the server and return the JSON response."""
        # Report the response body like the other endpoints, rather than str(response),
        # which only shows "<Response [code]>".
        return self._post(
            "close",
            {"task_id": task_id},
            "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!Failed to close task",
        )


# [New] Data class for one dialogue turn.
@dataclass
class Turn:
    """One turn of a dialogue, holding the input prompt, the AI response and the environment feedback."""
    turn_idx: int  # index of this turn
    rollout_idx: int  # NOTE(review): presumably which rollout of the task this turn belongs to — confirm against callers
    prompt: str  # user/system input
    response: str  # AI response (including think blocks)
    clean_response: str  # response with think blocks removed
    feedback: str  # environment feedback
    prompt_tokens: torch.Tensor  # token IDs of the prompt
    response_tokens: torch.Tensor  # token IDs of the response
    full_prompt_tokens: Optional[torch.Tensor] = None  # token IDs of the full formatted prompt, including history
    env_feedback: Optional[Dict] = None  # additional information returned by the environment



@dataclass
class Experience:
    """Experience is a batch of data.
    These data should have the sequence length and number of actions.
    Left padding for sequences is applied.

    Shapes of each tensor:
    sequences: (B, S)
    action_log_probs: (B, A)
    base_action_log_probs: (B, A)
    values: (B, A)
    returns: (B, A)
    advantages: (B, A)
    attention_mask: (B, S)
    action_mask: (B, A)
    kl: (B, A)

    "A" is the number of actions.

    ``info`` maps names to tensors (or lists of tensors) and may be None.
    ``batch_info`` holds trajectory metadata for reward computation; it is not
    touched by ``to_device`` or ``pin_memory``.
    """

    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    base_action_log_probs: torch.Tensor
    values: torch.Tensor
    returns: Optional[torch.Tensor]
    advantages: Optional[torch.Tensor]
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]
    info: Optional[dict]
    kl: Optional[torch.Tensor] = None

    # [New] Trajectory metadata that helps reward computation.
    batch_info: Optional[Dict] = None

    def _apply_to_tensors(self, fn):
        """Replace each tensor-valued attribute, and each ``info`` value, with ``fn(value)``; return ``self``."""
        self.sequences = fn(self.sequences)
        self.action_log_probs = fn(self.action_log_probs)
        self.base_action_log_probs = fn(self.base_action_log_probs)
        self.returns = fn(self.returns)
        self.advantages = fn(self.advantages)
        self.values = fn(self.values)
        self.attention_mask = fn(self.attention_mask)
        self.action_mask = fn(self.action_mask)
        self.kl = fn(self.kl)
        # info is declared Optional; calling .items() on None would crash.
        if self.info is not None:
            self.info = {key: fn(value) for key, value in self.info.items()}
        return self

    @torch.no_grad()
    def to_device(self, device: torch.device):
        """Move all tensors (and ``info`` values) to ``device``, updating attributes in place; return ``self``."""
        return self._apply_to_tensors(lambda value: to(value, device))

    def pin_memory(self):
        """Pin all tensors (and ``info`` values) in page-locked memory, updating attributes in place; return ``self``."""
        # Inside a method body, ``pin_memory`` resolves to the module-level helper, not this method.
        return self._apply_to_tensors(pin_memory)


@dataclass
class Samples:
    """Samples is a batch of data.
    There can be 2 formats to store the samples, batched or packed.
    The batched format means padding is applied to the sequences, while the packed format
    will concatenate the prompt and response without padding.

    Shapes of each tensor, when 2 shapes are shown, the first one is for batched format
        and the second one is for packed format:
    sequences: (B, S) or (1, total_length), the tokens of both prompt and response.
    attention_mask: (B, S) or (1, total_length), the attention mask for sequences.
    action_mask: (B, A) or None, the action (response) mask to show which part of the
        sequence is the response. When the samples are packed, this is None.
    packed_seq_lens: None or (B,), the length of each sample in the packed samples.
        (Plain instance attribute set by ``__init__``; not a dataclass field.)
    response_length: (B,), the number of tokens in the response.
    total_length: (B,), the total number of tokens in the sequences.

    Multi-turn metadata (shared by all rows of the batch):
    task_id: task the trajectory belongs to.
    turn_idx: index of the dialogue turn.
    trajectory_info: extra trajectory information.
    prompt_len: number of leading prompt tokens in each sequence.
    env_feedback: additional information returned by the environment.
    str_info: string-valued information, e.g. prompt text.
    """

    sequences: torch.Tensor
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]
    response_length: torch.Tensor
    total_length: torch.Tensor

    # [New] Trajectory and turn information
    task_id: Optional[str] = None
    turn_idx: Optional[int] = None
    trajectory_info: Optional[Dict] = None
    prompt_len: Optional[int] = None
    env_feedback: Optional[Dict] = None  # additional information from the environment
    str_info: Optional[Dict] = None  # string-valued information, such as prompt text

    def __init__(
            self,
            sequences=None,
            attention_mask=None,
            action_mask=None,
            response_length=None,
            total_length=None,
            packed_seq_lens=None,
            task_id=None,
            turn_idx=None,
            trajectory_info=None,
            prompt_len=None,
            env_feedback=None,
            str_info=None,
    ):
        # Explicit __init__ (dataclass keeps it) so every argument is optional
        # and packed_seq_lens can be stored without being a dataclass field.
        self.sequences = sequences
        self.attention_mask = attention_mask
        self.action_mask = action_mask
        self.response_length = response_length
        self.total_length = total_length
        self.packed_seq_lens = packed_seq_lens
        self.task_id = task_id
        self.turn_idx = turn_idx
        self.trajectory_info = trajectory_info
        self.prompt_len = prompt_len
        self.env_feedback = env_feedback
        self.str_info = str_info

    def split(self, split_size: int):
        """Split the batch along dim 0 into ``Samples`` of at most ``split_size`` rows.

        Only the batched (non-packed) format is supported: ``sequences``,
        ``attention_mask`` and ``action_mask`` must all be tensors.
        ``response_length`` and ``total_length`` are recomputed per chunk. All
        per-batch metadata, including ``env_feedback`` and ``str_info``, is
        shared by reference with every chunk.
        """
        chunks = zip(
            self.sequences.split(split_size, dim=0),
            self.attention_mask.split(split_size, dim=0),
            self.action_mask.split(split_size, dim=0),
        )
        return [
            Samples(
                sequences=seq,
                attention_mask=mask,
                action_mask=act_mask,
                response_length=act_mask.float().sum(dim=-1),
                total_length=mask.float().sum(dim=-1),
                task_id=self.task_id,
                turn_idx=self.turn_idx,
                trajectory_info=self.trajectory_info,
                prompt_len=self.prompt_len,
                env_feedback=self.env_feedback,
                str_info=self.str_info,
            )
            for seq, mask, act_mask in chunks
        ]


class RemoteExperienceMaker(ABC):
    def __init__(
            self,
            actor_model_group: PPORayActorGroup,
            critic_model_group: PPORayActorGroup,
            reward_model_group: PPORayActorGroup,
            initial_model_group: PPORayActorGroup,
            tokenizer,
            prompt_max_len: int,
            kl_controller,
            strategy=None,
            remote_rm_url: Union[list[str], str] = None,
            vllm_engines: List = None,
            packing_samples=False,
            # [New] multi-turn dialogue configuration
            api_base_urls: Optional[List[str]] = None,
            max_turns: int = 20,
            max_gen_length: int = 2000,
            log_dir: str = "/home/OpenRLHF_2/new_log",
            **kwargs,
    ):
        """Set up the experience maker and its AppWorld API clients.

        Args:
            api_base_urls: AppWorld server URLs, one ``AppWorldAPI`` client each.
                Defaults to ``["http://localhost:8000"]``.
            max_turns: maximum dialogue turns per trajectory (stored for rollout code).
            max_gen_length: maximum generation length (stored for rollout code).
            log_dir: root log directory; created if missing.
            remote_rm_url: a URL/path or list of them. If the first entry ends in
                ``.py``, its ``reward_func`` is loaded and wrapped with ``ray.remote``.

        Raises:
            ValueError: if ``api_base_urls`` is empty.
        """
        super().__init__()

        self.vllm_engines = vllm_engines
        self.packing_samples = packing_samples
        self.actor_model_group = actor_model_group
        self.critic_model_group = critic_model_group
        self.reward_model_group = reward_model_group
        self.initial_model_group = initial_model_group
        self.tokenizer = tokenizer
        self.prompt_max_len = prompt_max_len
        self.kl_ctl = kl_controller
        self.strategy = strategy
        self.advantage_estimator = strategy.args.advantage_estimator
        self.args = strategy.args

        # [New] multi-turn dialogue configuration.
        # None instead of a list literal avoids a shared mutable default argument.
        if api_base_urls is None:
            api_base_urls = ["http://localhost:8000"]
        self.api_clients = [AppWorldAPI(url) for url in api_base_urls]
        self.num_api_clients = len(self.api_clients)
        if self.num_api_clients == 0:
            raise ValueError("Must provide at least one API base URL.")
        self.max_turns = max_turns
        self.max_gen_length = max_gen_length
        self.log_dir = log_dir
        self.log_sub_dir = None
        os.makedirs(self.log_dir, exist_ok=True)

        # custom reward func for reinforced finetuning
        self.custom_reward_func = None
        self.remote_rm_url = [remote_rm_url] if isinstance(remote_rm_url, str) else remote_rm_url
        # Inspect the normalized list: indexing the raw argument would take the first
        # *character* when a single path string is passed, so ".py" would never match.
        if self.remote_rm_url and self.remote_rm_url[0].endswith(".py"):
            reward_path = self.remote_rm_url[0]
            print(f"Loading custom `reward_func(queries, prompts, labels)` from {reward_path}")
            import importlib.util

            spec = importlib.util.spec_from_file_location("reward_func", reward_path)
            reward_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(reward_module)
            self.custom_reward_func = ray.remote(reward_module.reward_func)

    def create_log_subdir(self):
        """Point ``self.log_sub_dir`` at a ``YYYYMMDD_HHMMSS`` folder under ``self.log_dir`` and create it if missing."""
        stamp = time.strftime("%Y%m%d_%H%M%S")
        self.log_sub_dir = os.path.join(self.log_dir, stamp)
        print(f"Creating log subdirectory: {self.log_sub_dir}")
        os.makedirs(self.log_sub_dir, exist_ok=True)
    # tokenizer
    def tokenize_fn(self, texts, max_length, padding=True, device=None):
        """Tokenize ``texts`` without special tokens, truncating each to ``max_length``.

        With ``padding=False``, the tokenizer's raw, unpadded output is returned.
        Otherwise the batch is padded, returned as ``pt`` tensors, and every
        entry of the resulting dict is moved to ``device``.
        """
        shared_kwargs = {
            "add_special_tokens": False,
            "max_length": max_length,
            "truncation": True,
        }
        if padding:
            encoded = self.tokenizer(texts, return_tensors="pt", padding=True, **shared_kwargs)
            return {name: value.to(device) for name, value in encoded.items()}
        # when padding is False, return tokenized texts as list
        return self.tokenizer(texts, **shared_kwargs)

    # [Changed] Takes a list of task IDs (instead of prompt texts).
    @torch.no_grad()
    def make_experience_list(
            self, task_ids: List[str], **generate_kwargs
    ) -> List[Experience]:
        """Roll out several trajectories per task and turn them into experiences.

        Each ``(task_id, rollout_idx)`` pair is dispatched as a Ray task that pairs
        one AppWorld API client with one idle vLLM engine.
        ``generate_trajectory_samples`` is expected to return a list of per-turn
        ``Samples`` (multi-turn dialogues are split into single-turn samples).

        A failed rollout is re-queued under its own rollout index. Its API client
        then cools down for ``api_cooldown`` seconds, and is disabled permanently
        after ``max_api_failures`` consecutive failures.

        Extra keyword arguments consumed here (everything else is forwarded to
        ``generate_trajectory_samples``):
            n_samples_per_prompt: rollouts per task (default ``args.n_samples_per_prompt``).
            max_api_failures: consecutive failures before a client is disabled (default 30).
            api_cooldown: seconds a client rests after a failure (default 60).

        Returns:
            Experiences from ``make_experience`` + ``compute_advantages_and_returns``,
            or ``[]`` if no usable sample was produced.
        """
        args = self.strategy.args
        self.create_log_subdir()

        # vLLM wakeup when vllm_enable_sleep
        if args.vllm_enable_sleep:
            from openrlhf.trainer.ray.vllm_engine import batch_vllm_engine_call
            batch_vllm_engine_call(self.vllm_engines, "wake_up")

        # Scheduling knobs are popped so they are not forwarded to the rollout function.
        n_samples_per_task = generate_kwargs.pop("n_samples_per_prompt", args.n_samples_per_prompt)
        max_api_failures = generate_kwargs.pop("max_api_failures", 30)
        api_cooldown = generate_kwargs.pop("api_cooldown", 60)

        engine_count = len(self.vllm_engines)
        print(f"engine_count: {engine_count}")

        print(f"Generating {n_samples_per_task} samples per task for {len(task_ids)} tasks")

        # Per-engine busy flag: 1 while the engine serves a trajectory, 0 when idle.
        vllm_load_counter = {i: 0 for i in range(engine_count)}

        def get_free_engine():
            """Return the index of the first idle vLLM engine, or None if all are busy."""
            for idx, load in vllm_load_counter.items():
                if load == 0:
                    return idx
            return None

        # FIFO queue of (task_id, rollout_idx) pairs still to be rolled out.
        pending_tasks = [(task_id, i) for task_id in task_ids for i in range(n_samples_per_task)]
        print(f"all_task_configs num is: {len(pending_tasks)}")

        # Per-API-client scheduling state.
        api_status = {
            i: {
                'available': True,
                'available_time': time.time(),
                'failure_count': 0,
                'disabled': False,
            }
            for i in range(len(self.api_clients))
        }
        num_api_clients = len(self.api_clients)
        print(f"num_api_clients: {num_api_clients}")

        # Ray remote function that rolls out a single trajectory.
        @ray.remote
        def process_single_trajectory(task_id, api_client_idx, vllm_engine_idx, rollout_idx):
            try:
                api_client = self.api_clients[api_client_idx]
                vllm_engine = self.vllm_engines[vllm_engine_idx]
                trajectory_samples = self.generate_trajectory_samples(
                    task_id,
                    api_client=api_client,
                    vllm_engine=vllm_engine,
                    rollout_idx=rollout_idx,
                    **generate_kwargs
                )
                return trajectory_samples, api_client_idx, task_id, vllm_engine_idx
            except Exception as e:
                logger.error(f"Error processing task {task_id} with API client {api_client_idx}: {e}")
                return None, api_client_idx, task_id, vllm_engine_idx

        # {ref: (api_client_idx, task_id, vllm_engine_idx, rollout_idx)}. The rollout
        # index is stored so a failed rollout is re-queued under its own index.
        active_refs = {}
        all_samples = []

        # Main loop: launch work whenever an (API, engine) pair is free, then wait for completions.
        while pending_tasks or active_refs:
            current_time = time.time()

            # 1. End expired cooldowns.
            for status in api_status.values():
                if not status['disabled'] and not status['available'] and current_time >= status['available_time']:
                    status['available'] = True

            # 2. API clients that are enabled, not cooling down, and not running a rollout.
            busy_api_indices = {entry[0] for entry in active_refs.values()}
            available_api_indices = [
                idx for idx, status in api_status.items()
                if status['available'] and not status['disabled'] and idx not in busy_api_indices
            ]

            # 3. Launch as many rollouts as there are free (API client, engine) pairs.
            while available_api_indices and pending_tasks:
                vllm_idx = get_free_engine()
                if vllm_idx is None:
                    break

                api_idx = available_api_indices.pop(0)
                task_id, rollout_idx = pending_tasks.pop(0)

                vllm_load_counter[vllm_idx] = 1
                ref = process_single_trajectory.remote(task_id, api_idx, vllm_idx, rollout_idx)
                active_refs[ref] = (api_idx, task_id, vllm_idx, rollout_idx)
                api_status[api_idx]['available'] = False

            # 4. No progress is possible once every API client is disabled.
            if pending_tasks and all(status['disabled'] for status in api_status.values()):
                logger.error("All API clients have been disabled due to failures. Cannot complete remaining tasks.")
                break

            # 5. Nothing in flight but work pending: sleep until the next cooldown ends.
            if not active_refs and pending_tasks:
                next_available_times = [
                    status['available_time'] for status in api_status.values()
                    if not status['disabled'] and not status['available']
                ]
                if not next_available_times:
                    logger.error("No available API clients but there are pending tasks. Possible logic error.")
                    break
                wait_time = max(0, min(next_available_times) - current_time)
                if wait_time > 0:
                    time.sleep(wait_time)
                continue

            print(f"Active tasks: {len(active_refs)}, Pending tasks: {len(pending_tasks)}")
            if not active_refs and not pending_tasks:
                break

            # 6. Wait up to 5 s for any rollout to finish.
            done_refs, _ = ray.wait(list(active_refs.keys()), num_returns=1, timeout=5)

            # 7. Collect finished rollouts.
            for ref in done_refs:
                api_idx, task_id, engine_idx, rollout_idx = active_refs.pop(ref)
                vllm_load_counter[engine_idx] = 0

                try:
                    samples = ray.get(ref)[0]
                except Exception as e:
                    # The Ray task itself failed (e.g. worker died); treat it as a failed rollout.
                    logger.error(f"Ray task for task {task_id} with API client {api_idx} failed: {e}")
                    samples = None

                status = api_status[api_idx]
                if samples is not None:
                    # Success resets the consecutive-failure counter.
                    status['failure_count'] = 0
                    status['available'] = True
                    all_samples.extend(samples)
                    continue

                status['failure_count'] += 1
                status['available'] = False
                if status['failure_count'] >= max_api_failures:
                    status['disabled'] = True
                    logger.warning(f"API client {api_idx} has been disabled after {max_api_failures} failures")
                else:
                    status['available_time'] = time.time() + api_cooldown

                # Re-queue with the failed rollout's own index.
                pending_tasks.append((task_id, rollout_idx))
                print(f"Task {task_id} failed with API client {api_idx}. Adding back to task queue.")

        # vLLM offload when vllm_enable_sleep
        if args.vllm_enable_sleep:
            from openrlhf.trainer.ray.vllm_engine import batch_vllm_engine_call
            batch_vllm_engine_call(self.vllm_engines, "sleep")

        print(f"Completed {len(all_samples)} samples from {len(task_ids) * n_samples_per_task} requested samples")

        if not all_samples:
            return []

        # Keep the sample count divisible by 8.
        self._trim_samples_to_multiple(all_samples, 8)
        if not all_samples:
            # Fewer than 8 samples were produced; make_experience cannot handle an empty list.
            return []

        # Make experiences (models forward: logprobs, values, rewards, and kl divergence)
        experiences = self.make_experience(all_samples)

        # Process experiences (reward shaping, etc.)
        experiences = self.compute_advantages_and_returns(experiences)
        return experiences

    @staticmethod
    def _trim_samples_to_multiple(all_samples, multiple=8):
        """Drop samples from ``all_samples`` in place until its length is a multiple of ``multiple``.

        Removal preference, choosing uniformly at random within each tier:
        first-turn samples whose ``env_feedback['turn_reward']`` is negative
        (errors / overlong thinking), then other first-turn samples, then — only
        if still short — samples from other turns.
        """
        remainder = len(all_samples) % multiple
        if remainder == 0:
            return
        print(f"Sample count {len(all_samples)} is not divisible by {multiple}, removing {remainder} samples")

        error_indices = []
        normal_indices = []
        later_turn_indices = []
        for i, sample in enumerate(all_samples):
            if sample.turn_idx != 0:
                later_turn_indices.append(i)
            elif sample.env_feedback and sample.env_feedback.get('turn_reward', 0) < 0:
                error_indices.append(i)
            else:
                normal_indices.append(i)

        print(f"Error samples count: {len(error_indices)}, Normal samples count: {len(normal_indices)}")

        indices_to_remove = []
        for tier in (error_indices, normal_indices, later_turn_indices):
            still_needed = remainder - len(indices_to_remove)
            if still_needed <= 0:
                break
            # min() keeps random.sample from raising when a tier is too small.
            indices_to_remove.extend(random.sample(tier, min(still_needed, len(tier))))

        # Pop from the back so earlier indices stay valid.
        indices_to_remove.sort(reverse=True)
        for idx in indices_to_remove:
            all_samples.pop(idx)

        print(f"After removal: {len(all_samples)} samples")
        print(f"Removed indices: {indices_to_remove}, error count: {len(error_indices)}, normal count: {len(normal_indices)}")

    @torch.no_grad()
    def make_experience(self, samples_list: List[Samples]) -> List[Experience]:
        """
        Turn samples into experience by calculating logprobs, values, rewards, and kl divergence.
        Compatible with the original experience_maker.

        Each sample's ``sequences[0]`` is split at ``prompt_len`` into a prompt part and a
        response part. Every micro-batch is padded to the same *global* maximum prompt and
        response lengths: prompts are left-padded and responses right-padded with
        ``pad_token_id``. Samples are sorted by ``turn_idx`` and grouped into micro-batches of
        the size chosen by ``find_best_batch_size``.

        Note: ``samples_list`` is sorted in place by ``turn_idx``.

        Args:
            samples_list: single-turn samples exposing ``sequences``, ``prompt_len``,
                ``turn_idx``, ``task_id``, ``trajectory_info``, ``env_feedback`` and ``str_info``.

        Returns:
            One ``Experience`` per micro-batch, in sorted order. Empty input yields ``[]``.

        Raises:
            ValueError: if no reward source is configured (``remote_rm_url`` unset or
                ``custom_reward_func`` missing).
        """
        start_time = time.time()
        logger.info(f"🚀 Starting experience making with {len(samples_list)} samples")

        if not samples_list:
            # Nothing to batch; max() and find_best_batch_size below cannot handle empty input.
            return []

        args = self.strategy.args
        device = "cpu"
        experiences = []
        # Remote model groups return one copy of each result per ring-attention / tensor-parallel
        # rank. This is defined up front because every remote result below is de-duplicated with it.
        duplicate_factor = args.ring_attn_size * args.ds_tensor_parallel_size

        # Global max prompt/response lengths. Every micro-batch is padded to these values,
        # so all experiences share the same sequence layout.
        over_max_prompt_len = max(len(s.sequences[0][:s.prompt_len]) for s in samples_list)
        over_max_response_len = max(len(s.sequences[0][s.prompt_len:]) for s in samples_list)

        actor_gpus_num = self.args.actor_num_nodes * self.args.actor_num_gpus_per_node
        # Choose a micro-batch size whose resulting number of batches divides evenly across actor GPUs.
        opt_batch_size = find_best_batch_size(
            total_samples=len(samples_list),
            original_batch_size=args.micro_rollout_batch_size,
            num_gpus=actor_gpus_num,
        )

        pad_token_id = self.tokenizer.pad_token_id
        eos_token_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")

        # Report how many samples belong to each turn, in order of first appearance.
        turn_counts = {}
        for sample in samples_list:
            turn_idx = sample.turn_idx if sample.turn_idx is not None else 0
            turn_counts[turn_idx] = turn_counts.get(turn_idx, 0) + 1
        for turn_idx, count in turn_counts.items():
            print(f'Processing turn number is {turn_idx} with {count} samples')

        sequences_list = []
        attention_mask_list = []
        action_mask_list = []
        all_task_ids = []
        all_turn_indices = []
        all_trajectory_infos = []
        all_env_feedbacks = []
        all_response_lengths = []
        all_total_lengths = []
        batch_info = []
        str_info = []

        # Put samples of the same turn into neighbouring micro-batches (in-place sort).
        samples_list.sort(key=lambda x: (x.turn_idx if x.turn_idx is not None else 0))

        for i in range(0, len(samples_list), opt_batch_size):
            batch_samples = samples_list[i:i + opt_batch_size]

            padded_sequences = []
            for sample in batch_samples:
                full_seq = sample.sequences[0]
                prompt_part = full_seq[:sample.prompt_len]
                response_part = full_seq[sample.prompt_len:]
                # Left-pad the prompt so all prompts end at the same column.
                padded_prompt = torch.cat([
                    torch.full((over_max_prompt_len - len(prompt_part),), pad_token_id, dtype=torch.long),
                    prompt_part,
                ])
                # Right-pad the response to the global response length.
                padded_response = torch.cat([
                    response_part,
                    torch.full((over_max_response_len - len(response_part),), pad_token_id, dtype=torch.long),
                ])
                padded_sequences.append(torch.cat([padded_prompt, padded_response]))

            sequences = torch.stack(padded_sequences)
            sequences, attention_mask, action_mask = process_sequences(
                sequences, over_max_prompt_len, eos_token_id, pad_token_id
            )

            sequences_list.append(sequences)
            attention_mask_list.append(attention_mask)
            action_mask_list.append(action_mask)
            all_response_lengths.append(action_mask.float().sum(dim=-1))
            all_total_lengths.append(attention_mask.float().sum(dim=-1))

            batch_info.append({
                'task_id': [s.task_id for s in batch_samples],
                'turn_idx': [s.turn_idx for s in batch_samples],
                'trajectory_info': [s.trajectory_info for s in batch_samples],
                'env_feedback': [s.env_feedback for s in batch_samples],
            })

            # Flat, per-sample metadata in sorted order (consumed by the reward function).
            all_task_ids.extend(s.task_id for s in batch_samples)
            all_turn_indices.extend(s.turn_idx for s in batch_samples)
            all_trajectory_infos.extend(s.trajectory_info for s in batch_samples)
            all_env_feedbacks.extend(s.env_feedback for s in batch_samples)
            str_info.extend(s.str_info for s in batch_samples)

        print(f'len of sequences_list: {len(sequences_list)}')

        assert (
            len(samples_list)
            == len(all_task_ids)
            == len(all_turn_indices)
            == len(all_trajectory_infos)
            == len(all_env_feedbacks)
        ), f"len(samples_list): {len(samples_list)}, len(all_task_ids): {len(all_task_ids)}, len(all_turn_indices): {len(all_turn_indices)}, len(all_trajectory_infos): {len(all_trajectory_infos)}, len(all_env_feedbacks): {len(all_env_feedbacks)}"

        assert (
            len(sequences_list)
            == len(attention_mask_list)
            == len(action_mask_list)
            == len(all_response_lengths)
            == len(all_total_lengths)
            == len(batch_info)
        ), f"len(sequences_list): {len(sequences_list)}, len(attention_mask_list): {len(attention_mask_list)}, len(action_mask_list): {len(action_mask_list)}, len(all_response_lengths): {len(all_response_lengths)}, len(all_total_lengths): {len(all_total_lengths)},len(batch_info): {len(batch_info)}"

        # Batch call reward model
        r_refs = None
        if not self.remote_rm_url:
            print('**************************** error we do not have a reward modle !!!!!!!!!!!!!!! ***********************')
        elif self.custom_reward_func:
            reward_metadata = {
                'task_ids': all_task_ids,
                'turn_indices': all_turn_indices,
                'trajectory_infos': all_trajectory_infos,
                'env_feedbacks': all_env_feedbacks,
                'str_info': str_info,
            }
            r_refs = [self.custom_reward_func.remote(reward_metadata)]
        else:
            print('**************************** error we do not have a reward modle !!!!!!!!!!!!!!! ***********************')

        if r_refs is None:
            # Previously this fell through to ray.get(None), which fails with an opaque error.
            raise ValueError(
                "No reward source configured: remote_rm_url and custom_reward_func are both required"
            )

        # The reward function returns rewards for all samples (in sorted order). Split them by the
        # real micro-batch sizes: chunk() would misalign rewards when the last batch is smaller.
        all_rewards = torch.cat(ray.get(r_refs), dim=0)
        rewards_list = torch.split(all_rewards, [seq.size(0) for seq in sequences_list])

        # Batch call actor model
        print("actor model group is not None")
        action_log_probs_ref = self.actor_model_group.async_run_method_batch(
            method_name="forward",
            sequences=sequences_list,
            action_mask=action_mask_list,
            attention_mask=attention_mask_list,
        )

        # Sync to avoid GPU OOM when colocate models
        if args.colocate_all_models or args.colocate_actor_ref:
            ray.get(action_log_probs_ref)
            ray.get(self.actor_model_group.async_run_method(method_name="empty_cache"))

        # Batch call critic model
        if self.critic_model_group is not None:
            print("critic model group is not None")
            value_ref = self.critic_model_group.async_run_method_batch(
                method_name="forward",
                sequences=sequences_list,
                action_mask=action_mask_list,
                attention_mask=attention_mask_list,
            )
            if args.colocate_all_models or args.colocate_critic_reward:
                ray.get(value_ref)
                ray.get(self.critic_model_group.async_run_method(method_name="empty_cache"))
        else:
            # Placeholder with the same duplicated layout as a real remote result.
            value_ref = ray.put([[None]] * (len(sequences_list) * duplicate_factor))

        # Batch call initial model
        if self.initial_model_group is not None:
            print("initial model group is not None")
            base_action_log_probs_ref = self.initial_model_group.async_run_method_batch(
                method_name="forward",
                sequences=sequences_list,
                action_mask=action_mask_list,
                attention_mask=attention_mask_list,
            )

            if args.colocate_all_models or args.colocate_actor_ref:
                ray.get(base_action_log_probs_ref)
                ray.get(self.initial_model_group.async_run_method(method_name="empty_cache"))
        else:
            base_action_log_probs_ref = ray.put([[None]] * (len(sequences_list) * duplicate_factor))

        # Wait for all remote calls, keep one copy per duplicated rank and flatten.
        action_log_probs_list = sum(ray.get(action_log_probs_ref)[::duplicate_factor], [])
        base_action_log_probs_list = sum(ray.get(base_action_log_probs_ref)[::duplicate_factor], [])
        value_list = sum(ray.get(value_ref)[::duplicate_factor], [])

        assert (
            len(sequences_list)
            == len(action_mask_list)
            == len(attention_mask_list)
            == len(all_response_lengths)
            == len(all_total_lengths)
            == len(action_log_probs_list)
            == len(base_action_log_probs_list)
            == len(value_list)
            == len(rewards_list)
        ), f"len(samples_list): {len(samples_list)},len(attention_mask_list): {len(attention_mask_list)},,len(sequences_list): {len(sequences_list)},len(action_mask_list): {len(action_mask_list)},len(action_log_probs_list): {len(action_log_probs_list)}, len(base_action_log_probs_list): {len(base_action_log_probs_list)}, len(value_list): {len(value_list)}, len(rewards_list): {len(rewards_list)}, len(all_response_lengths): {len(all_response_lengths)}, len(all_total_lengths): {len(all_total_lengths)}"

        # Assemble one Experience per micro-batch.
        for (batch_info_it, response_lengths, total_lengths, sequences, action_mask, attention_mask,
             action_log_probs, base_action_log_probs, value, rewards) in zip(
                batch_info, all_response_lengths, all_total_lengths, sequences_list, action_mask_list,
                attention_mask_list, action_log_probs_list, base_action_log_probs_list, value_list, rewards_list
        ):
            # KL goes into the reward only when a reference model exists and KL is not a loss term.
            if (self.initial_model_group is not None) and (not args.use_kl_loss):
                kl = compute_approx_kl(
                    action_log_probs,
                    base_action_log_probs,
                    kl_estimator=self.strategy.args.kl_estimator,
                )
            else:
                kl = torch.zeros_like(action_log_probs, dtype=action_log_probs.dtype, device=device)
            kl_mean = masked_mean(kl, action_mask, dim=-1)

            if not args.use_kl_loss:
                base_action_log_probs = None

            info = {
                "kl": kl_mean.float(),
                "reward": rewards.float(),
                "response_length": response_lengths.float(),
                "total_length": total_lengths.float()
            }

            experiences.append(Experience(
                sequences,
                action_log_probs,
                base_action_log_probs,
                value,
                None,
                None,
                attention_mask,
                action_mask,
                info,
                kl,
                batch_info_it
            ))

        duration = time.time() - start_time
        time_str = str(timedelta(seconds=duration)).split(".")[0]
        logger.info(f"✨ Experience making completed in {time_str}")
        return experiences

    @torch.no_grad()
    def compute_advantages_and_returns(
            self, experiences: List[Experience], **kwargs
    ) -> List[Experience]:
        """
        Shape rewards and fill in ``advantages`` / ``returns`` for every experience, in place.

        For each experience, ``info["reward"]`` is passed together with ``self.kl_ctl.value``,
        ``experience.kl`` and ``experience.action_mask`` to ``compute_reward``. Advantages and
        returns are then derived from the result:

        - ``"gae"``: GAE using ``experience.values``, ``args.gamma`` and ``args.lambd``.
        - reinforce-style estimators: discounted cumulative returns, with advantages being a
          deep copy of the returns. For rloo / reinforce_baseline / group_norm / dr_grpo,
          ``args.gamma`` is forced to 1.0, which mutates the shared args.

        ``info["return"]`` receives the per-row sum of the shaped reward, and ``experience.kl``
        is cleared. Unless ``self.args.advantage_estimator`` is group_norm or dr_grpo, advantages
        are then normalized across all experiences with a mask-weighted mean. They are also
        scaled by the inverse std unless ``no_advantage_std_norm`` is set.

        Raises:
            Exception: on an unknown advantage estimator.
        """
        args = self.strategy.args
        returns_based = ["reinforce", "rloo", "reinforce_baseline", "group_norm", "dr_grpo"]
        gamma_must_be_one = ["rloo", "reinforce_baseline", "group_norm", "dr_grpo"]

        raw_rewards = [exp.info["reward"] for exp in experiences]

        for exp, raw_reward in zip(experiences, raw_rewards):
            shaped = compute_reward(
                raw_reward,
                self.kl_ctl.value,
                exp.kl,
                action_mask=exp.action_mask,
                reward_clip_range=args.reward_clip_range,
            )

            estimator = self.advantage_estimator
            if estimator == "gae":
                exp.advantages, exp.returns = self.get_advantages_and_returns(
                    exp.values, shaped, exp.action_mask, args.gamma, args.lambd,
                )
            elif estimator in returns_based:
                if args.gamma != 1.0 and estimator in gamma_must_be_one:
                    logger.warning("gamma is set to 1.0 for rloo, reinforce_baseline, and group_norm")
                    args.gamma = 1.0
                exp.returns = self.get_cumulative_returns(shaped, exp.action_mask, args.gamma)
                exp.advantages = deepcopy(exp.returns)
            else:
                raise Exception(f"Unkown advantage_estimator {self.advantage_estimator}")

            exp.info["return"] = shaped.sum(dim=-1)
            # The KL tensor is no longer needed once folded into the reward.
            exp.kl = None

        # group_norm / dr_grpo skip the global normalization step.
        if self.args.advantage_estimator in ["group_norm", "dr_grpo"]:
            return experiences

        flat_adv = zero_pad_sequences([exp.advantages for exp in experiences]).float().flatten()
        flat_mask = zero_pad_sequences([exp.action_mask for exp in experiences]).flatten()
        n_tokens = flat_mask.sum()

        mean = (flat_adv * flat_mask).sum() / n_tokens
        if self.args.no_advantage_std_norm:
            rstd = 1
        else:
            var = ((flat_adv - mean).pow(2) * flat_mask).sum() / n_tokens
            rstd = var.clamp(min=1e-8).rsqrt()

        for exp in experiences:
            exp.advantages = (exp.advantages - mean) * rstd

        return experiences

    @torch.no_grad()
    def get_advantages_and_returns(
            self,
            values: torch.Tensor,
            rewards: torch.Tensor,
            action_mask: torch.Tensor,
            gamma: float,
            lambd: float,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute GAE advantages and value targets.

        ``values`` and ``rewards`` are indexed as ``[:, t]``, i.e. 2-D with time on dim 1.
        When ``action_mask`` is given, both are multiplied by it first. The recursion is run
        backwards in time: the value after the last step is 0, and
        ``A_t = delta_t + gamma * lambd * A_{t+1}``.

        Returns:
            (advantages detached, returns = advantages + masked values)
        """
        n_steps = rewards.size(1)

        if action_mask is not None:
            values = action_mask * values
            rewards = action_mask * rewards

        gae = 0
        per_step = [None] * n_steps
        for t in range(n_steps - 1, -1, -1):
            next_value = values[:, t + 1] if t < n_steps - 1 else 0.0
            td_error = rewards[:, t] + gamma * next_value - values[:, t]
            gae = td_error + gamma * lambd * gae
            per_step[t] = gae

        advantages = torch.stack(per_step, dim=1)
        returns = advantages + values
        return advantages.detach(), returns

    @torch.no_grad()
    def get_cumulative_returns(
            self,
            rewards: torch.Tensor,
            action_mask: torch.Tensor,
            gamma: float,
    ) -> torch.Tensor:
        """
        Compute REINFORCE-style discounted returns-to-go.

        ``rewards`` is indexed as ``[:, t]`` (time on dim 1). When ``action_mask`` is given,
        rewards are multiplied by it first. Then ``G_t = r_t + gamma * G_{t+1}``, starting
        from zero after the last step. The output has the same shape/dtype as ``rewards``.
        """
        n_steps = rewards.size(1)
        returns = torch.zeros_like(rewards)
        running = torch.zeros(rewards.size(0), device=rewards.device)

        masked = rewards if action_mask is None else action_mask * rewards

        for t in range(n_steps - 1, -1, -1):
            running = masked[:, t] + gamma * running
            returns[:, t] = running

        return returns

    # [New] Roll out a complete trajectory and turn it into single-turn samples.
    def generate_trajectory_samples(self, task_id: str, rollout_idx, api_client=None, vllm_engine=None, **kwargs) -> List[Samples]:
        """Generate one full multi-turn trajectory and split it into per-turn samples.

        Raises:
            NotImplementedError: when ``self.vllm_engines`` is None (HF generation unsupported).

        Returns:
            None if ``api_client`` or ``vllm_engine`` is missing, or if trajectory generation
            returned None. Otherwise a list with one sample per turn, each built by
            ``self._create_turn_sample`` with a shared trajectory metadata dict.
        """
        if self.vllm_engines is None:
            raise NotImplementedError("HF generation is not implemented for multi-turn dialogue")

        if api_client is None or vllm_engine is None:
            logger.error("API client and vLLM engine must be provided")
            return None

        outcome = self._generate_vllm_trajectory(task_id, api_client, vllm_engine, rollout_idx, **kwargs)
        if outcome is None:
            return None
        turns, task_completed = outcome

        # Metadata shared by every turn sample of this trajectory.
        trajectory_info = dict(
            task_id=task_id,
            task_completed=task_completed,
            total_turns=len(turns),
            rollout_idx=rollout_idx,
        )

        samples = [
            self._create_turn_sample(task_id, idx, turn, trajectory_info)
            for idx, turn in enumerate(turns)
        ]

        print(f"**********************we have samples: {len(samples)}")
        return samples

    def compute_env_feedback(self, ai_response, code, env_feedback):
        """Score a single turn from the model output, the executed code and the env reply.

        Args:
            ai_response: full model output for the turn (may include a think block).
            code: the code that was sent to the environment for execution.
            env_feedback: execution result string returned by the environment.

        Returns:
            tuple: (reward_score, think_too_long)
                - (-5.0, True)  if ``<think>`` is opened but never closed in ``ai_response``;
                - (-1.0, False) if ``env_feedback`` contains a failure marker, reports that no
                  code was available, or ``code`` is blank;
                - (1.0, False)  if ``code`` contains ``"apis."``;
                - (-1.0, False) otherwise.
        """
        # An opened but unterminated think block: generation stopped mid-thought.
        if '<think>' in ai_response and '</think>' not in ai_response:
            return -5.0, True

        # These markers are only expected in failed executions.
        failure_markers = (
            "Execution failed",
            "Traceback:",
            "SyntaxError",
            "Exception",
            "Error:",
            "Maximum number of executions",
            "timed out after",
        )
        execution_failed = any(marker in env_feedback for marker in failure_markers)
        nothing_executed = "No code available to execute" in env_feedback or not code.strip()
        if execution_failed or nothing_executed:
            return -1.0, False

        # Only code that calls the AppWorld API surface is rewarded.
        if "apis." in code:
            return 1.0, False
        return -1.0, False

    def _generate_vllm_trajectory(self, task_id: str, api_client, vllm_engine, rollout_idx, **kwargs) -> Tuple[List[Turn], bool]:
        """使用vLLM生成完整的对话轨迹"""
        max_turns = kwargs.get("max_turns", self.max_turns)
        task_completed = False

        try:
            # 1. 初始化任务并获取任务描述
            task_info = api_client.initialize_task(task_id)
            task_info = task_info['output']
            #print('task_info: ',task_info)

            # 从任务信息中提取数据
            supervisor = task_info.get("supervisor", "")
            instruction = task_info.get("instruction", "")
            dict_task_init = {"supervisor": supervisor, "instruction": instruction}

            if use_react_prompt:
                # 使用REACT模板
                prompt_template_use = prompt_react
            else:
                # 使用默认模板
                prompt_template_use = prompt_template 

            # 使用模板生成初始提示
            try:
                #react模板成功渲染
                initial_prompt = Template(prompt_template_use).render(dict_task_init)
            except Exception as e:
                logger.warning(f"Error rendering prompt template: {e}")
                initial_prompt = f"Supervisor: {supervisor}\nInstruction: {instruction}"

            # 2. 准备对话
            turns = []
            conversation_history = []
            initial_history = prompt_messages(initial_prompt)
            #print(f"*************************Initial prompt: {initial_history[-1]}")

            # 添加初始提示作为用户输入
            conversation_history.extend(initial_history)

            # 添加系统消息
            conversation_history.append({
                "role": "system",
                "content": 'Your think length should **not be longger than 1024 tokens.**'
            })


            # 格式化初始提示
            formatted_prompt = self._format_conversation(conversation_history)

            # 记录用户输入
            prompt_tokens = torch.tensor(
                self.tokenize_fn([formatted_prompt[:-22]], self.prompt_max_len, padding=False)["input_ids"][0]
            )

            # 当前用户输入
            current_input = formatted_prompt[:-22]

            # 3. 开始多轮对话
            #print(f'*****************task: {task_id} in rollout idx: {rollout_idx} begin generate response ******************')
            for turn_idx in range(max_turns):
                #print(f"Task {task_id}: Turn {turn_idx + 1}/{max_turns}, history length is:{len(conversation_history)}")

                # 检查提示长度
                prompt_length_check = self._format_conversation(conversation_history)
                prompt_token_ids = self.tokenize_fn([prompt_length_check], self.prompt_max_len, padding=False)["input_ids"][0]
                prompt_token_ids_length = len(torch.tensor(prompt_token_ids))
                #print(f'in turn idx {turn_idx} length of prompt token ids: {prompt_token_ids_length}')
                if prompt_token_ids_length > self.prompt_max_len - 20:
                    logger.warning(
                        f"Prompt length {prompt_token_ids_length} exceeds max allowed length. Stopping dialogue.")
                    break

                # 使用vLLM生成响应
                vllm_response, _, current_prompt_token_ids = self._generate_response_with_vllm(
                    conversation_history,
                    vllm_engine=vllm_engine,
                    task_id=task_id,
                    turn_idx=turn_idx,
                    **kwargs
                )

                vllm_response = vllm_response[0]
                ai_response = vllm_response.outputs[0].text
                #ai_response是llm的完整回答，包含think部分和最后的回答
                #clean_response是去除think部分的"干净"回复
                #code是用于执行的代码，是干净的代码
                #用于执行的是干净的代码：code，记录于history的是think之后的回答:clean_response，普通prompt要求llm直接输出code，而react要求llm输出的是python代码块


                # 清除think部分获取"干净"的回复
                # 如果不是think模型，这个函数也会返回完整的原来内容
                clean_response = remove_think_blocks(ai_response)
                
                # 添加AI响应到对话历史（存储不含think的干净回复，便于后续对话）
                conversation_history.append({"role": "assistant", "content": clean_response})

                # 发送响应到AppWorld并获取下一个输入
                if use_react_prompt:
                    # 如果使用REACT模板，提取代码块
                    # 代码提取功能测试成功
                    code_block, _ = extract_code_and_fix_content(clean_response)
                    if code_block.strip() == "":
                        print(f"Warning: No code block found in response for task {task_id} at turn {turn_idx}. Using clean response.")
                    code = code_block
                else:
                    code = clean_response  # 用于执行的是干净的代码
                
                next_task_output = api_client.execute(task_id, code)
                next_user_input = next_task_output["output"]
                #获取回复，计算turn级别的奖励
                # 如果不是think模型，不包含任何think内容，think_too_long会是False
                env_reward, think_too_long = self.compute_env_feedback(ai_response,code,next_user_input)

                # 记录本轮对话
                if vllm_response.outputs[0].token_ids is None:
                    raise ValueError(f"Response output token IDs are None in turn {turn_idx} for task {task_id}")

                response_tokens = torch.tensor(vllm_response.outputs[0].token_ids)

                env_feedback={'turn_reward':env_reward,'think_too_long':think_too_long}
                turn = Turn(
                    turn_idx=turn_idx,
                    rollout_idx=rollout_idx,
                    prompt=current_input,
                    response=ai_response,  # 原始回复，包含think
                    clean_response=code,  # 纯code部分
                    feedback=next_user_input,
                    prompt_tokens=prompt_tokens,
                    response_tokens=response_tokens,
                    full_prompt_tokens=torch.tensor(current_prompt_token_ids),  # 保存完整提示tokens
                    env_feedback=env_feedback,
                )
                turns.append(turn)

                # if think_too_long:
                #     break

                # 检查任务是否完成
                task_completed_response = api_client.check_task_completed(task_id)
                task_completed = task_completed_response['output']
                if task_completed:
                    logger.info(f"Task {task_id}: Task completed.")
                    break

                # 准备下一轮对话
                current_input = next_user_input
                if current_input is None:
                    raise ValueError(f"Next user input is None in turn {turn_idx} for task {task_id}")

                prompt_tokens = torch.tensor(
                    self.tokenize_fn([current_input], self.prompt_max_len, padding=False)["input_ids"][0]
                )
                conversation_history.append({"role": "user", "content": current_input})

            #新增：检测任务的各个状态
            success, status = api_client.check_task_success(task_id)
            #print(type(success))
            print(f"Task {task_id} in rollout idx {rollout_idx} ,success status is {success}, and tsak_completed status is {task_completed}")
            

            # 4. 关闭任务
            api_client.close_task(task_id)

            # 5. 记录对话历史
            #current_time = time.time()
            save_name = f'{task_id}_{rollout_idx}.json'
            save_w(self.log_sub_dir, save_name, {
                "task_id": task_id,
                "conversation": conversation_history,
                "success": success,
                'evaluation': status
            })

            # 只有在有对话轮次的情况下才返回
            if not turns:
                logger.warning(f"Task {task_id}: No dialogue turns collected.")
                return None

            return turns, success

        except Exception as e:
            logger.error(f"Error generating trajectory for task {task_id}: {str(e)}")
            try:
                api_client.close_task(task_id)
            except:
                pass
            return None

    def _generate_response_with_vllm(self, conversation_history, vllm_engine=None, task_id=None, turn_idx=None,
                                     **kwargs):
        """Generate one model completion for a conversation through a Ray-hosted vLLM engine.

        The conversation is rendered with ``self._format_conversation`` and tokenized with
        ``self.tokenize_fn`` (bounded by ``self.prompt_max_len``). The tokens are submitted
        to ``vllm_engine`` via ``add_requests.remote`` and the result is fetched with
        ``get_responses.remote``.

        Args:
            conversation_history: List of ``{"role": ..., "content": ...}`` messages.
            vllm_engine: Ray actor handle exposing ``add_requests`` / ``get_responses``.
            task_id: Used only to build the per-request key.
            turn_idx: Used only to build the per-request key.
            **kwargs: Optional sampling overrides. Recognized keys are ``temperature``
                (default 1.0), ``top_p`` (0.9), ``top_k`` (50),
                ``max_new_tokens`` (``self.max_gen_length``) and ``min_new_tokens`` (1).

        Returns:
            Tuple ``(response, prompt, prompt_token_ids)``:
            - ``response`` is whatever ``get_responses`` returns. The caller in this file
              indexes it as ``response[0].outputs[0]``.
            - ``prompt`` is the formatted prompt string.
            - ``prompt_token_ids`` are the token ids sent to the engine.
        """
        from vllm import SamplingParams

        engine = vllm_engine
        prompt_text = self._format_conversation(conversation_history)
        token_ids = self.tokenize_fn([prompt_text], self.prompt_max_len, padding=False)["input_ids"][0]

        # The engine pairs add_requests/get_responses by this key. It mixes task, turn and a
        # nanosecond timestamp so that concurrent calls do not collide.
        request_key = hash(f"{task_id}_{turn_idx}_{time.time_ns()}")

        temperature = kwargs.get("temperature", 1.0)
        top_p = kwargs.get("top_p", 0.9)
        top_k = kwargs.get("top_k", 50)
        max_tokens = kwargs.get("max_new_tokens", self.max_gen_length)
        min_tokens = kwargs.get("min_new_tokens", 1)
        params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_tokens=max_tokens,
            min_tokens=min_tokens,
            skip_special_tokens=True,
            include_stop_str_in_output=False,
        )

        # Block until the request is enqueued, then block again for its result.
        submit_ref = engine.add_requests.remote(request_key, sampling_params=params,
                                                prompt_token_ids=[token_ids])
        ray.get(submit_ref)
        outputs = ray.get(engine.get_responses.remote(request_key))

        return outputs, prompt_text, token_ids

    def _format_conversation(self, conversation_history):
        """将对话历史转换为可接受的格式"""
        formatted = ""

        for msg in conversation_history:
            role = msg["role"]
            content = msg["content"]

            # 根据模型的标记格式
            if role == "system":
                formatted += f"<|im_start|>system\n{content}\n<|im_end|>\n"
            elif role == "user":
                formatted += f"<|im_start|>user\n{content}\n<|im_end|>\n"
            elif role == "assistant":
                formatted += f"<|im_start|>assistant\n{content}\n<|im_end|>\n"

        # 如果最后一条是用户消息，添加assistant标记提示模型回复
        if conversation_history and (conversation_history[-1]["role"] == "user" or conversation_history[-1]["role"] == "system"):
            # 添加assistant开始标记，提示模型回复
            formatted += "<|im_start|>assistant\n"

        return formatted

    def _create_turn_sample(self, task_id: str, turn_idx: int, current_turn: Turn,
                            trajectory_info: Dict) -> Optional[Samples]:
        """Build one un-padded training sample from a single recorded dialogue turn.

        The token sequence is ``current_turn.full_prompt_tokens`` (the complete formatted
        prompt, including the assistant start marker) followed by
        ``current_turn.response_tokens``. ``<|im_end|>`` is then appended unless the sequence
        already ends with it. Padding, the attention mask and the action mask are left to
        ``make_experience``.

        Args:
            task_id: Task identifier stored on the sample.
            turn_idx: Index of this turn within its trajectory, stored on the sample.
            current_turn: Recorded turn. Its ``full_prompt_tokens`` and ``response_tokens``
                tensors are converted to Python lists here.
            trajectory_info: Trajectory-level metadata passed through unchanged.

        Returns:
            A ``Samples`` with a batch dimension of 1, or ``None`` (after logging an error)
            when both prompt and response are empty.

        Raises:
            ValueError: If ``<|im_start|>`` or ``<|im_end|>`` does not resolve to a real
                vocabulary entry of ``self.tokenizer``.
        """
        # Resolve the ChatML markers. convert_tokens_to_ids may return None, or silently fall
        # back to the unk id, for a token that is not in the vocabulary. A round-trip check
        # catches both cases. <|im_start|> is not used below; it is validated only as a
        # sanity check that the tokenizer matches the prompt format.
        special_ids = {}
        for name, token in (("im_start_id", "<|im_start|>"), ("im_end_id", "<|im_end|>")):
            tid = self.tokenizer.convert_tokens_to_ids(token)
            if tid is None or self.tokenizer.convert_ids_to_tokens(tid) != token:
                raise ValueError(f"Special token {name} not found in tokenizer vocabulary!")
            special_ids[name] = tid
        im_end_id = special_ids["im_end_id"]

        # Prompt length includes the assistant start marker. It is stored on the sample so
        # the action mask can be built later.
        prompt_tokens = current_turn.full_prompt_tokens.tolist()
        prompt_len = len(prompt_tokens)

        # Raw generated tokens (the response including any think section).
        response_tokens = current_turn.response_tokens.tolist()

        total_tokens = prompt_tokens + response_tokens
        if not total_tokens:
            logger.error(f"Empty token sequence for task {task_id} turn {turn_idx}")
            return None

        # Terminate the sequence with exactly one <|im_end|>. The generated tokens may or may
        # not already end with it.
        if total_tokens[-1] != im_end_id:
            total_tokens.append(im_end_id)

        # Shape (1, seq_len). Padding is done later in make_experience.
        sequences = torch.tensor([total_tokens])

        # NOTE(review): response_length counts only the generated tokens, not an appended
        # <|im_end|>. Confirm this matches how make_experience builds the action mask.
        response_length = torch.tensor([len(response_tokens)])
        total_length = torch.tensor([len(total_tokens)])

        # Human-readable turn info kept for logging. 'response' holds the executed code
        # (Turn.clean_response) and 'feedback' holds the environment's reply.
        turn_str_info = {
            'response': current_turn.clean_response,
            'feedback': current_turn.feedback,
        }

        return Samples(
            sequences=sequences,
            attention_mask=None,  # built in make_experience
            action_mask=None,     # built in make_experience from prompt_len
            response_length=response_length,
            total_length=total_length,
            task_id=task_id,
            turn_idx=turn_idx,
            trajectory_info=trajectory_info,
            prompt_len=prompt_len,
            env_feedback=current_turn.env_feedback,
            str_info=turn_str_info,
        )
