# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import copy
import heapq
import logging
import os
import queue
import random
import threading
from abc import ABC, abstractmethod
from collections import defaultdict
from concurrent.futures import Future

from typing import Any, Optional, List

import hydra
import numpy as np
import ray
import torch
from cachetools import LRUCache
from omegaconf import DictConfig, OmegaConf
from pydantic import BaseModel, ConfigDict, Field
from tensordict import TensorDict
from transformers import AutoProcessor, AutoTokenizer

from verl.single_controller.ray.base import RayWorkerGroup
from verl.trainer.ppo.reward import load_reward_manager
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.fs import copy_to_local
from verl.utils.model import compute_position_id_with_mask
from verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op


from verl.protocol import DataProtoConfig
from agentmath.agent_protocol import Agent_DataProto as DataProto

# Module-level logger; its level comes from the VERL_LOGGING_LEVEL environment variable ("WARN" if unset).
logger = logging.getLogger(__file__)
logger.setLevel(os.environ.get("VERL_LOGGING_LEVEL", "WARN"))

from agentmath.rollout.replica import Agent_TokenOutput, Agent_get_rollout_replica_class

class AsyncLLMServerManager:
    """
    A class to manage multiple OpenAI compatible LLM servers. This class provides
    - Load balance: least-weight load balancing over a min-heap of per-server accumulated weights
      (each new request adds 1, or a prompt-length based score when prompt-length weighting is enabled)
    - Sticky session: send multi-turn chat completions to same server for automatic prefix caching
    """

    def __init__(self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], max_cache_size: int = 10000):
        """Initialize the AsyncLLMServerManager.

        Args:
            config (DictConfig): YAML config. Reads the ``actor_rollout_ref.rollout.LRUCache_server_weight_*``
                keys that configure request weighting.
            server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles.
                NOTE: the list is shuffled in place.
            max_cache_size (int, optional): max cache size for request_id to server mapping. Defaults to 10000.
        """
        self.config = config
        self.server_handles = server_handles
        random.shuffle(self.server_handles)

        # Least-weight load balancing: min-heap of [accumulated_weight, (server_hash, server)].
        # Entries are lists so the weight of the heap root can be updated in place.
        self.weighted_serveres = self._build_weighted_servers()

        # LRU cache to map request_id to server (sticky sessions).
        self.request_id_to_server = LRUCache(maxsize=max_cache_size)
        # server_hash -> [(request_id, prompt_length), ...]; bookkeeping only, reset by clear_cache().
        self.server_to_request_ids = defaultdict(list)

        rollout_cfg = self.config.actor_rollout_ref.rollout
        self.weight_prompt_enable = rollout_cfg.LRUCache_server_weight_prompt_length_enable
        self.long_prompt_thr = rollout_cfg.LRUCache_server_weight_prompt_length
        self.LRUCache_server_weight_score = rollout_cfg.LRUCache_server_weight_score

    def _build_weighted_servers(self) -> list:
        """Return a freshly heapified list of zero-weight entries, one per handle in ``self.server_handles``."""
        weighted = [[0, (hash(server), server)] for server in self.server_handles]
        heapq.heapify(weighted)
        return weighted

    def _request_weight(self, prompt_ids: list[int]) -> int:
        """Return the load weight that a newly routed request adds to its server.

        With prompt-length weighting enabled this is
        ``len(prompt_ids) // LRUCache_server_weight_prompt_length + LRUCache_server_weight_score``;
        otherwise every request counts as 1.
        """
        if self.weight_prompt_enable:
            return len(prompt_ids) // self.long_prompt_thr + self.LRUCache_server_weight_score
        return 1

    def _choose_server(self, request_id: str, prompt_ids: list[int]) -> ray.actor.ActorHandle:
        """Return the server for ``request_id``: sticky if seen before, else the least-weighted server."""
        # TODO: implement server pressure awareness load balancing
        if request_id in self.request_id_to_server:
            return self.request_id_to_server[request_id]

        entry = self.weighted_serveres[0]
        server_hash, server = entry[1]
        # Accumulate only this request's weight so the heap key tracks the total load assigned to the server.
        entry[0] += self._request_weight(prompt_ids)
        # The root's key grew; heapreplace re-inserts it at the root and sifts it down to restore heap order.
        heapq.heapreplace(self.weighted_serveres, entry)

        self.request_id_to_server[request_id] = server
        self.server_to_request_ids[server_hash].append((request_id, len(prompt_ids)))
        return server

    @rollout_trace_op
    async def generate(
        self,
        request_id,
        *,
        prompt_ids: list[int],
        sampling_params: dict[str, Any],
        image_data: Optional[list[Any]] = None,
    ) -> Agent_TokenOutput:
        """Generate tokens from prompt ids.

        Args:
            request_id (str): request id for sticky session.
            prompt_ids (List[int]): List of prompt token ids.
            sampling_params (Dict[str, Any]): Sampling parameters for the chat completion.
            image_data (List[Any], optional): image inputs forwarded to the server. Defaults to None.

        Returns:
            Agent_TokenOutput: token output
        """
        server = self._choose_server(request_id, prompt_ids)
        return await server.generate.remote(
            request_id=request_id,
            prompt_ids=prompt_ids,
            sampling_params=sampling_params,
            image_data=image_data,
        )

    def clear_cache(self):
        """Drop all sticky-session mappings and reset every server's load weight to zero.

        Returns:
            bool: always True.
        """
        self.request_id_to_server.clear()
        self.server_to_request_ids.clear()
        random.shuffle(self.server_handles)
        # Rebuild the heap that _choose_server actually reads (``weighted_serveres``).
        self.weighted_serveres = self._build_weighted_servers()
        return True

class AgentLoopMetrics(BaseModel):
    """Per-sample timing metrics reported by an agent loop (aggregated into ``timing`` by the manager)."""

    # Time spent generating sequences; units are set by the agent loop implementation (TODO confirm).
    generate_sequences: float = Field(default=0.0)
    # Time spent in tool calls; units are set by the agent loop implementation (TODO confirm).
    tool_calls: float = Field(default=0.0)


class AgentLoopOutput(BaseModel):
    """Agent loop output."""

    prompt_ids: list[int]
    """Prompt token ids."""
    response_ids: list[int]
    """Response token ids including LLM generated token, tool response token."""
    response_mask: list[int]
    """Response mask, 1 for LLM generated token, 0 for tool response token."""
    response_logprobs: Optional[list[float]] = None
    """Log probabilities for the response tokens."""
    multi_modal_data: Optional[dict[str, Any]] = None
    """Multi-modal data for multi-modal tools."""
    reward_score: Optional[float] = None
    """Reward score for the trajectory."""
    num_turns: int = 0
    """Number of chat turns, including user, assistant, tool."""
    metrics: AgentLoopMetrics
    """Auxiliary performance metrics"""
    # Explicit factory (like ``messages`` below) so each instance gets its own dict.
    extra_fields: dict[str, Any] = Field(default_factory=dict)
    """Extra fields for dynamic addition."""
    # Whether the trajectory finished; exported as the "finished" column by AgentLoopWorker._postprocess.
    is_finish: bool = True
    # Name of the registered agent loop that produced this output; set by AgentLoopWorker._run_agent_loop.
    agent_name: str = ""
    # Conversation messages of the trajectory; exported as the "raw_prompt" column by AgentLoopWorker._postprocess.
    messages: list[Any] = Field(default_factory=list)


class _InternalAgentLoopOutput(AgentLoopOutput):
    """Tensorized, padded variant of ``AgentLoopOutput`` for internal use."""

    # torch.Tensor is not a pydantic-native type, so arbitrary types must be allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    prompt_ids: torch.Tensor  # padded prompt token ids
    response_ids: torch.Tensor  # padded response token ids
    input_ids: torch.Tensor  # padded input ids (prompt_ids + response_ids)
    position_ids: torch.Tensor  # padded position ids
    response_mask: torch.Tensor  # padded response mask
    attention_mask: torch.Tensor  # padded attention mask
    response_logprobs: Optional[torch.Tensor] = Field(default=None)  # padded log probs of the response tokens
    # Multi-modal inputs for processors (e.g., pixel_values, image_grid_thw).
    multi_modal_inputs: Optional[dict[str, torch.Tensor]] = Field(default=None)
    extra_fields: dict[str, Any] = Field(default_factory=dict)  # extra fields for dynamic addition


class _DummyConfig:
    """Plain holder exposing the trainer config as ``.config``.

    Passed as ``trainer_config`` to ``hydra.utils.instantiate`` ("make hydra.utils.instantiate happy");
    presumably wrapping keeps hydra from treating the raw DictConfig kwarg as config to process — confirm.
    """

    def __init__(self, config: DictConfig) -> None:
        # Stored as-is; agent loops read it back via ``trainer_config.config``.
        self.config = config


class AgentLoopBase(ABC):
    """Base class for agent loops.

    An agent loop takes an input message, chats with an OpenAI compatible LLM server (through the
    server manager) and interacts with various environments. One instance is created per sample.
    """

    # Guards the one-time, class-wide setup done in ``init_class``.
    _class_initialized = False

    def __init__(
        self,
        trainer_config: _DummyConfig,
        server_manager: AsyncLLMServerManager,
        tokenizer: AutoTokenizer,
        processor: AutoProcessor,
        **kwargs,
    ):
        """Build the per-sample agent loop.

        Must be called while an asyncio event loop is running (it calls ``asyncio.get_running_loop``).

        Args:
            trainer_config (_DummyConfig): wrapper whose ``config`` attribute is the trainer config.
            server_manager (AsyncLLMServerManager): OpenAI compatible LLM server manager.
            tokenizer (AutoTokenizer): tokenizer used to tokenize messages.
            processor (AutoProcessor): processor used to process messages.
            **kwargs: forwarded unchanged to ``init_class``.
        """
        cfg = trainer_config.config
        # Class-wide setup runs first, before any per-instance attribute is assigned.
        self.init_class(config=cfg, tokenizer=tokenizer, processor=processor, **kwargs)
        self.config = cfg
        self.server_manager = server_manager
        self.tokenizer = tokenizer
        self.processor = processor
        self.loop = asyncio.get_running_loop()

    @classmethod
    def init_class(cls, config: DictConfig, tokenizer: AutoTokenizer, processor: AutoProcessor, **kwargs):
        """Hook for heavy initialization shared by all instances of a class; effective only on the first call.

        Subclasses override this to do their setup; the base implementation only records that the class
        has been initialized.

        Args:
            config (DictConfig): trainer config.
            tokenizer (AutoTokenizer): Tokenizer for tokenize messages.
            processor (AutoProcessor): Processor for process multi_modal data.
            **kwargs: extra kwargs from config file passed in by `hydra.utils.instantiate`.
        """
        if not cls._class_initialized:
            cls._class_initialized = True

    @abstractmethod
    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
        """Interact with the LLM server and the environment for one sample.

        Args:
            sampling_params (Dict[str, Any]): LLM sampling params.
            **kwargs: dataset fields from `verl.utils.dataset.RLHFDataset`.

        Returns:
            AgentLoopOutput: Agent loop output.
        """
        raise NotImplementedError


"""Agent loop registry: key is agent_name, value is a dict of agent loop config
used by hydra.utils.instantiate to initialize agent loop instance.

https://hydra.cc/docs/advanced/instantiate_objects/overview/
"""
_agent_loop_registry: dict[str, dict] = {}


def register(agent_name: str):
    """Return a class decorator that registers an agent loop class under ``agent_name``.

    The class is recorded in ``_agent_loop_registry`` as a hydra ``_target_`` pointing at its fully
    qualified name; the decorated class itself is returned unchanged.
    """

    def decorator(subclass: type[AgentLoopBase]) -> type[AgentLoopBase]:
        target = ".".join((subclass.__module__, subclass.__qualname__))
        _agent_loop_registry[agent_name] = dict(_target_=target)
        return subclass

    return decorator


@ray.remote(num_cpus=1)
class BatchExecutor:
    """Batch executor is used to collect requests into a batch execution.

    Items submitted through ``submit_task`` are queued and drained by a background daemon thread, which
    groups them into batches and calls ``batch_func`` once per batch.
    """

    def __init__(self, batch_func, micro_batch_size=1, max_batch_size=None):
        """
        Args:
            batch_func: batch processing function. Called with a list of items; expected to return one
                result per item, in the same order (extra trailing results are ignored).
            micro_batch_size (int, optional): when the queue runs dry, the worker keeps blocking for more
                items until the batch size is a multiple of this value. Defaults to 1.
            max_batch_size: upper bound on the batch size; falsy means unbounded.

        Raises:
            ValueError: if ``micro_batch_size`` is smaller than 1 (it would otherwise crash the worker
                thread and leave every submitter waiting forever).
        """
        if micro_batch_size < 1:
            raise ValueError(f"micro_batch_size must be >= 1, got {micro_batch_size}")
        self._q = queue.Queue()
        self._batch_func = batch_func
        self._max_batch = max_batch_size
        self._micro_batch_size = micro_batch_size

        self._worker = threading.Thread(target=self._worker_loop, daemon=True)
        self._worker.start()

    async def submit_task(self, item):
        """Queue ``item`` for batched execution and await its result.

        The thread-side ``concurrent.futures.Future`` is wrapped with ``asyncio.wrap_future`` so the
        actor's event loop is not blocked while waiting.

        Args:
            item: function input

        Returns:
            The result ``batch_func`` produced for ``item``; an exception raised for the batch is re-raised.
        """
        fut = Future()
        self._q.put((item, fut))
        return await asyncio.wrap_future(fut)

    def _batch_is_full(self, size: int) -> bool:
        """Return True once ``size`` reached the max batch size (never when unbounded)."""
        return bool(self._max_batch) and size >= self._max_batch

    def _collect_batch(self):
        """Block for at least one request, then gather a batch. Returns ``(items, futures)``."""
        first, first_fut = self._q.get()
        items = [first]
        futs = [first_fut]

        # Drain what is already queued without blocking; check the cap *before* fetching so a batch
        # never exceeds max_batch_size (e.g. max_batch_size=1 yields single-item batches).
        while not self._batch_is_full(len(items)):
            try:
                next_item, next_fut = self._q.get_nowait()
            except queue.Empty:
                # Queue ran dry: block until the batch is a whole number of micro batches (or full).
                while len(items) % self._micro_batch_size != 0 and not self._batch_is_full(len(items)):
                    next_item, next_fut = self._q.get()
                    items.append(next_item)
                    futs.append(next_fut)
                break
            items.append(next_item)
            futs.append(next_fut)
        return items, futs

    def _run_batch(self, items, futs):
        """Run ``batch_func`` on one batch and resolve every future with its result or the error."""
        try:
            results = list(self._batch_func(items))
        except Exception as e:
            for f in futs:
                f.set_exception(e)
            return

        if len(results) < len(futs):
            # Too few results would leave some submitters awaiting forever; fail the whole batch instead.
            err = RuntimeError(f"batch_func returned {len(results)} results for a batch of {len(futs)} items")
            for f in futs:
                f.set_exception(err)
            return

        # Extra trailing results (if any) are ignored.
        for f, r in zip(futs, results, strict=False):
            f.set_result(r)

    def _worker_loop(self):
        """Background thread body: collect and execute batches forever."""
        while True:
            items, futs = self._collect_batch()
            self._run_batch(items, futs)


@ray.remote(num_cpus=1)
class RewardManagerWorker:
    """Ray actor (one CPU) that scores agent loop outputs with the configured reward manager.

    The blocking reward manager call is pushed onto the event loop's default thread-pool executor, so
    several samples can be scored concurrently while the agent loop continues.
    """

    def __init__(
        self,
        config: DictConfig,
        local_path: str,
        rm_executor: BatchExecutor = None
    ) -> None:
        """Build the tokenizer and reward manager.

        Args:
            config (DictConfig): trainer config; ``config.reward_model.reward_kwargs`` (if present) is
                forwarded as keyword arguments to ``load_reward_manager``.
            local_path (str): Filesystem path or model ID used to load the tokenizer.
            rm_executor (BatchExecutor, optional): actor used to run reward-model inference in batches;
                when None, only the reward manager is used.

        Notes:
            - The tokenizer is loaded with ``trust_remote_code=True``; only use trusted model sources.
            - The reward manager is created with ``num_examine=0``.
        """
        tokenizer = hf_tokenizer(local_path, trust_remote_code=True)
        reward_kwargs = config.reward_model.get("reward_kwargs", {})
        self.reward_manager = load_reward_manager(config, tokenizer, num_examine=0, **reward_kwargs)

        self.rm_executor = rm_executor
        # Loop whose default executor runs the blocking reward computation in compute_score.
        self.loop = asyncio.get_event_loop()

    async def compute_score(
        self,
        data: DataProto,
    ) -> dict:
        """Score one agent loop output.

        ``reward_wrapper`` is blocking, so it runs in the default thread-pool executor.

        Args:
            data: reward function input

        Returns:
            dict: ``{"reward_score": float, "reward_extra_info": {key: first value}}``.
        """
        result = await self.loop.run_in_executor(None, self.reward_wrapper, data, True)

        total_reward = result["reward_tensor"].sum(dim=-1).item()
        extra_info = result.get("reward_extra_info", {})
        return {
            "reward_score": total_reward,
            "reward_extra_info": {name: values[0] for name, values in extra_info.items()},
        }

    def reward_wrapper(self, data: DataProto, return_dict=False) -> torch.Tensor:
        """Optionally merge reward-model outputs into ``data``, then apply the reward manager.

        Args:
            data: DataProto from compute reward score
            return_dict: forwarded to the reward manager; when True its result is a dict (as read by
                ``compute_score``) rather than a bare tensor.

        Returns:
            The reward manager's result: a reward tensor, or a dict when ``return_dict`` is True.
        """
        if self.rm_executor is None:
            return self.reward_manager(data, return_dict)

        # Blocking wait is fine here: this method runs in a worker thread, not on the event loop.
        rm_output = ray.get(self.rm_executor.submit_task.remote(data))
        return self.reward_manager(data.union(rm_output), return_dict)


@ray.remote
class AgentLoopWorker:
    """Agent loop worker takes a batch of messages and runs each message in an agent loop."""

    def __init__(
        self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], rm_executor: BatchExecutor = None
    ):
        """Initialize the agent loop worker.

        Args:
            config (DictConfig): YAML config.
            server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles.
            rm_executor (BatchExecutor, optional): batch executor handed to this worker's
                RewardManagerWorker for batched reward-model scoring. Defaults to None.
        """
        self.config = config
        self.server_manager = AsyncLLMServerManager(config, server_handles)
        self.rm_executor = rm_executor

        model_path = config.actor_rollout_ref.model.path
        # Short model name: the last two path components (e.g. "org/model").
        self.model_name = "/".join(model_path.split("/")[-2:])
        local_path = copy_to_local(model_path)
        self.tokenizer = hf_tokenizer(local_path, trust_remote_code=True)
        self.processor = hf_processor(local_path, trust_remote_code=True)

        # Optional YAML file with additional agent loop configs, registered under their `name`.
        agent_loop_config_path = config.actor_rollout_ref.rollout.agent.agent_loop_config_path
        if agent_loop_config_path:
            agent_loop_configs = OmegaConf.load(agent_loop_config_path)
            for agent_loop_config in agent_loop_configs:
                _agent_loop_registry[agent_loop_config.name] = agent_loop_config

        custom_chat_template = self.config.actor_rollout_ref.model.get("custom_chat_template", None)
        if custom_chat_template is not None:
            if self.processor is not None:
                self.processor.chat_template = custom_chat_template
            self.tokenizer.chat_template = custom_chat_template

        # Reward worker is pinned (hard affinity) to the node this worker runs on.
        self.reward_manager_worker = RewardManagerWorker.options(
            scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
                node_id=ray.get_runtime_context().get_node_id(),
                soft=False,
            ),
        ).remote(self.config, local_path, self.rm_executor)

        trace_config = self.config.actor_rollout_ref.rollout.get("trace", {})
        RolloutTraceConfig.init(
            self.config.trainer.project_name,
            self.config.trainer.experiment_name,
            trace_config.get("backend"),
            trace_config.get("token2text", False),
        )

    async def generate_sequences(self, batch: DataProto) -> DataProto:
        """Generate sequences from agent loop.

        Args:
            batch (DataProto): Input batch.

        Returns:
            DataProto: Output batch.
            - prompts: [bsz, prompt_length], prompt token ids from dataset.
            - responses: [bsz, response_length], output token ids include response tokens
              from LLM generation and observation tokens from tool_calls.
            - response_mask: [bsz, response_length], 1 for LLM generated tokens, 0 for observation/padding tokens.
            - input_ids: [bsz, prompt_length + response_length], whole sequence token ids, including prompt tokens
              and response tokens.
            - attention_mask: [bsz, prompt_length + response_length], 0 for padding tokens, 1 for other tokens.
            - position_ids: [bsz, prompt_length + response_length], incremental position ids.

            For multi-turn conversations:
            responses:     |<- LLM generation ->|<- tool_calls ->|<- LLM generation ->|<- padding ->|
            response_mask: | 1, 1, 1, ..., 1, 1 | 0, 0, .., 0, 0 | 1, 1, 1, ..., 1, 1 | 0, 0, ..., 0|
        """
        config = self.config.actor_rollout_ref.rollout
        validate = batch.meta_info.get("validate", False)
        sampling_params = dict(
            temperature=config.temperature,
            top_p=config.top_p,
            repetition_penalty=1.0,
            logprobs=config.calculate_log_probs,
            validate=False,
        )

        if validate:
            sampling_params["top_p"] = config.val_kwargs.top_p
            sampling_params["temperature"] = batch.meta_info.get("valid_temperature", config.val_kwargs.temperature)
            sampling_params["validate"] = True

        # by default, we assume it's a single turn agent
        if "agent_name" not in batch.non_tensor_batch:
            batch.non_tensor_batch["agent_name"] = np.array(["single_turn_agent"] * len(batch), dtype=object)

        if "raw_response_ids" not in batch.non_tensor_batch:
            batch.non_tensor_batch["raw_response_ids"] = np.fromiter(([] for _ in range(len(batch))), dtype=object)

        if "raw_response_mask" not in batch.non_tensor_batch:
            batch.non_tensor_batch["raw_response_mask"] = np.fromiter(([] for _ in range(len(batch))), dtype=object)

        if "index" in batch.non_tensor_batch:
            index = batch.non_tensor_batch["index"]
        else:
            index = np.arange(len(batch))

        if "age" not in batch.non_tensor_batch:
            batch.non_tensor_batch["age"] = np.ones(len(batch), dtype=int)
        trajectory_info = await get_trajectory_info(batch.meta_info.get("global_steps", -1), index.tolist(), validate)

        # One task per sample; each gets its own copy of the sampling params and its non-tensor fields as kwargs.
        tasks = []
        for i in range(len(batch)):
            sampling_params_c = copy.deepcopy(sampling_params)
            kwargs = {k: v[i] for k, v in batch.non_tensor_batch.items()}
            tasks.append(asyncio.create_task(self._run_agent_loop(sampling_params_c, trajectory_info[i], **kwargs)))
        outputs = await asyncio.gather(*tasks)

        max_resp_len = (
            self.config.data.val_max_response_length if validate else self.config.actor_rollout_ref.rollout.response_length
        )
        return self._postprocess(outputs, max_response_length=max_resp_len)

    async def _run_agent_loop(
        self,
        sampling_params: dict[str, Any],
        trajectory: dict[str, Any],
        *,
        agent_name: str,
        **kwargs,
    ) -> AgentLoopOutput:
        """Instantiate the registered agent loop ``agent_name`` and run it on one sample.

        Args:
            sampling_params (dict[str, Any]): sampling params for this sample (a private copy).
            trajectory (dict[str, Any]): trace attributes for this sample as produced by
                ``get_trajectory_info`` (keys: step, sample_index, rollout_n, validate).
            agent_name (str): key into ``_agent_loop_registry``.
            **kwargs: the sample's remaining non-tensor fields, forwarded to the agent loop.

        Returns:
            AgentLoopOutput: the agent loop's output, with ``agent_name`` filled in.

        Raises:
            AssertionError: if ``agent_name`` is not registered.
        """
        with rollout_trace_attr(
            step=trajectory["step"],
            sample_index=trajectory["sample_index"],
            rollout_n=trajectory["rollout_n"],
            validate=trajectory["validate"],
            name="agent_loop",
        ):
            assert agent_name in _agent_loop_registry, (
                f"Agent loop {agent_name} not registered, registered agent loops: {_agent_loop_registry.keys()}"
            )
            agent_loop_config = _agent_loop_registry[agent_name]
            agent_loop = hydra.utils.instantiate(
                config=agent_loop_config,
                trainer_config=_DummyConfig(config=self.config),
                server_manager=self.server_manager,
                tokenizer=self.tokenizer,
                processor=self.processor,
            )
            # Partial rollout is used when the trajectory may be split into more than one piece.
            # NOTE: method name kept exactly as the agent loop implementations spell it.
            partial_split = self.config.algorithm.partial_rollout_max_split
            if partial_split > 1:
                output: AgentLoopOutput = await agent_loop.run_patial_rollout(sampling_params, **kwargs)
            else:
                output: AgentLoopOutput = await agent_loop.run(sampling_params, **kwargs)

            output.agent_name = agent_name
            return output

    def _postprocess(self, inputs: List[AgentLoopOutput], max_response_length=32768) -> DataProto:
        """Pad per-sample outputs into batch tensors and collect the raw per-sample fields.

        Prompts are left-padded to ``rollout.prompt_length`` and responses right-padded to
        ``max_response_length``; position ids are derived from the concatenated attention mask.
        """
        self.tokenizer.padding_side = "left"
        outs = self.tokenizer.pad(
            [{"input_ids": i.prompt_ids} for i in inputs],
            padding="max_length",
            max_length=self.config.actor_rollout_ref.rollout.prompt_length,
            return_tensors="pt",
            return_attention_mask=True,
        )
        prompt_ids, prompt_mask = outs["input_ids"], outs["attention_mask"]

        self.tokenizer.padding_side = "right"
        outs = self.tokenizer.pad(
            [{"input_ids": i.response_ids} for i in inputs],
            padding="max_length",
            max_length=max_response_length,
            return_tensors="pt",
            return_attention_mask=True,
        )
        resp_ids, resp_mask = outs["input_ids"], outs["attention_mask"]

        outs = self.tokenizer.pad(
            [{"input_ids": i.response_mask} for i in inputs],
            padding="max_length",
            max_length=max_response_length,
            return_tensors="pt",
            return_attention_mask=False,
        )
        # Padding positions are filled with the pad token id; multiplying by the response attention mask
        # zeroes them so padding counts as 0 in the response mask.
        resp_logic_mask = outs["input_ids"] * resp_mask

        input_ids = torch.cat([prompt_ids, resp_ids], dim=1)
        attention_mask = torch.cat([prompt_mask, resp_mask], dim=1)
        position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask

        batch = TensorDict(
            {
                "prompts": prompt_ids,
                "responses": resp_ids,
                "response_mask": resp_logic_mask,
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
            },
            batch_size=len(input_ids),
        )

        raw_prompt_ids_np = np.array([i.prompt_ids for i in inputs], dtype=object)
        raw_response_ids_np = np.array([i.response_ids for i in inputs], dtype=object)
        raw_response_mask_np = np.array([i.response_mask for i in inputs], dtype=object)
        raw_prompt_np = np.array([i.messages for i in inputs], dtype=object)
        agent_name_np = np.array([i.agent_name for i in inputs])

        num_turns = np.array([i.num_turns for i in inputs], dtype=np.int32)
        finishs = np.array([i.is_finish for i in inputs])
        metrics = [i.metrics.model_dump() for i in inputs]

        return DataProto(
            batch=batch,
            non_tensor_batch={
                "__num_turns__": num_turns,
                "raw_prompt_ids": raw_prompt_ids_np,
                "raw_response_ids": raw_response_ids_np,
                "raw_response_mask": raw_response_mask_np,
                "finished": finishs,
                "raw_prompt": raw_prompt_np,
                "agent_name": agent_name_np,
            },
            meta_info={"metrics": metrics},
        )

    def get_server_manager(self):
        """Return the server manager instance.

        NOTE: called through the actor handle (``worker.get_server_manager.remote()``), the caller receives a
        serialized copy; mutating that copy does not change this worker. Use ``clear_server_cache`` to reset
        the worker-side state.
        """
        return self.server_manager

    def clear_server_cache(self) -> bool:
        """Reset this worker's sticky-session cache and load-balancing weights in place."""
        return self.server_manager.clear_cache()


async def get_trajectory_info(step, index, validate):
    """Get trajectory info.

    Consecutive entries with equal ``index`` are numbered 0, 1, 2, ... via ``rollout_n``; the counter
    restarts at 0 whenever the index changes.

    Args:
        step (int): global steps in the trainer.
        index (list): from the dataset's extra_info.index column.
        validate (bool): whether this is a validate step.

    Returns:
        list: one trajectory dict per entry of ``index``.
    """
    trajectory_info = []
    rollout_n = 0
    for i, sample_index in enumerate(index):
        rollout_n = rollout_n + 1 if i > 0 and index[i - 1] == sample_index else 0
        trajectory_info.append({"step": step, "sample_index": sample_index, "rollout_n": rollout_n, "validate": validate})
    return trajectory_info


class AgentLoopManager:
    """Agent loop manager that manages a group of agent loop workers."""

    def __init__(self, config: DictConfig, worker_group: RayWorkerGroup = None, rm_wg: RayWorkerGroup = None):
        """Initialize agent loop manager.

        Args:
            config (DictConfig): trainer config.
            worker_group (RayWorkerGroup): ActorRolloutRef worker group for hybrid mode; None for standalone mode.
            rm_wg (RayWorkerGroup, optional): reward model worker group. When given, reward-model scoring is
                batched through a BatchExecutor whose micro batch size is ``rm_wg.world_size``.
        """
        self.config = config
        self.worker_group = worker_group
        self.rm_executor = None
        self.rm_micro_batch_size = None
        if rm_wg:

            def batch_fn(data_list: list[DataProto]) -> list[torch.Tensor]:
                """Score a list of single-sample DataProtos in one rm_wg call and split back per sample."""
                new_data_list = []
                for data in data_list:
                    temp_non_tensor_batch = {"__num_turns__": data.non_tensor_batch["__num_turns__"]}
                    temp_data = DataProto(batch=data.batch, non_tensor_batch=temp_non_tensor_batch)
                    new_data_list.append(temp_data)

                new_batch = DataProto.concat(new_data_list)
                out_data = rm_wg.compute_rm_score(new_batch)
                return out_data.split(1)

            self.rm_executor = BatchExecutor.options(
                scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
                    node_id=ray.get_runtime_context().get_node_id(),
                    soft=False,
                ),
            ).remote(batch_fn, rm_wg.world_size)

            self.rm_micro_batch_size = rm_wg.world_size

        self._initialize_llm_servers()
        self._init_agent_loop_workers()

        # Initially we're in sleep mode.
        if self.config.actor_rollout_ref.rollout.free_cache_engine:
            self.sleep()

    def _initialize_llm_servers(self):
        """Create one rollout replica per (TP * DP) GPUs and collect their server handles/addresses."""
        rollout_world_size = (
            self.config.actor_rollout_ref.rollout.tensor_model_parallel_size
            * self.config.actor_rollout_ref.rollout.data_parallel_size
        )
        world_size = (
            self.worker_group.world_size
            if self.worker_group
            else self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes
        )
        num_replicas = world_size // rollout_world_size

        rollout_replica_class = Agent_get_rollout_replica_class(self.config.actor_rollout_ref.rollout.name)
        rollout_config = self.config.actor_rollout_ref.rollout
        model_config = self.config.actor_rollout_ref.model
        self.rollout_replicas = [
            rollout_replica_class(
                replica_rank=replica_rank,
                config=rollout_config,
                model_config=model_config,
                gpus_per_node=self.config.trainer.n_gpus_per_node,
            )
            for replica_rank in range(num_replicas)
        ]
        if self.worker_group:
            self._run_all([server.init_hybrid(self.worker_group) for server in self.rollout_replicas])
        else:
            self._run_all([server.init_standalone() for server in self.rollout_replicas])
        self.server_handles = [server._server_handle for server in self.rollout_replicas]
        self.server_addresses = [server._server_address for server in self.rollout_replicas]

    def _init_agent_loop_workers(self):
        """Spawn the configured number of AgentLoopWorker actors, round-robin over alive CPU nodes.

        Raises:
            RuntimeError: if workers are requested but no alive Ray node has CPU resources.
        """
        self.agent_loop_workers = []
        num_workers = self.config.actor_rollout_ref.rollout.agent.num_workers

        node_ids = [node["NodeID"] for node in ray.nodes() if node["Alive"] and node["Resources"].get("CPU", 0) > 0]
        if num_workers > 0 and not node_ids:
            raise RuntimeError("No alive Ray node with CPU resources is available to host agent loop workers.")
        for i in range(num_workers):
            # Round-robin scheduling over the all nodes
            node_id = node_ids[i % len(node_ids)]
            self.agent_loop_workers.append(
                AgentLoopWorker.options(
                    name=f"agent_loop_worker_{i}",
                    scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
                        node_id=node_id, soft=True
                    ),
                ).remote(self.config, self.server_handles, self.rm_executor)
            )

    def generate_sequences(self, prompts: DataProto) -> DataProto:
        """Split input batch and dispatch to agent loop workers.

        Args:
            prompts (DataProto): Input batch.

        Returns:
            DataProto: Output batch; ``meta_info["timing"]`` holds agent loop timing statistics.

        Raises:
            ValueError: if a reward model worker group is used and ``len(prompts)`` is not divisible by its
                world size.
        """
        if self.rm_micro_batch_size and len(prompts) % self.rm_micro_batch_size != 0:
            raise ValueError(
                f"The length of prompts {len(prompts)} cannot divide the world size of rm_wg {self.rm_micro_batch_size}"
            )
        rollout_config = self.config.actor_rollout_ref.rollout
        if rollout_config.free_cache_engine:
            self.wake_up()

        # Auto padding is enabled only while chunking and dispatching (presumably so a batch whose size is not
        # divisible by the number of workers can still be chunked — confirm against DataProto.chunk).
        DataProtoConfig.auto_padding = True
        try:
            chunks = prompts.chunk(len(self.agent_loop_workers))
            outputs = ray.get(
                [
                    worker.generate_sequences.remote(chunk)
                    for worker, chunk in zip(self.agent_loop_workers, chunks, strict=True)
                ]
            )
        finally:
            # Restore the class-level flag even if a worker raised, so later DataProto operations are unaffected.
            DataProtoConfig.auto_padding = False

        output = DataProto.concat_array(outputs)

        if rollout_config.free_cache_engine and rollout_config.free_cache_engine_sleep:
            self.sleep()
            self.clear_cache()

        # calculate performance metrics: each worker reports a list of per-sample metric dicts
        metrics = [worker_output.meta_info.pop("metrics") for worker_output in outputs]
        timing = self._performance_metrics(metrics, output)

        output.meta_info = {"timing": timing, **outputs[0].meta_info}
        return output

    def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: DataProto) -> dict[str, float]:
        """Summarize per-sample generation/tool-call times (min/max/mean) and describe the slowest sample."""
        timing = {}
        t_generate_sequences = np.array([metric["generate_sequences"] for chunk in metrics for metric in chunk])
        t_tool_calls = np.array([metric["tool_calls"] for chunk in metrics for metric in chunk])
        timing["agent_loop/generate_sequences/min"] = t_generate_sequences.min()
        timing["agent_loop/generate_sequences/max"] = t_generate_sequences.max()
        timing["agent_loop/generate_sequences/mean"] = t_generate_sequences.mean()
        timing["agent_loop/tool_calls/min"] = t_tool_calls.min()
        timing["agent_loop/tool_calls/max"] = t_tool_calls.max()
        timing["agent_loop/tool_calls/mean"] = t_tool_calls.mean()

        # batch sequence generation is bounded by the slowest sample
        slowest = np.argmax(t_generate_sequences + t_tool_calls)
        attention_mask = output.batch["attention_mask"][slowest]
        prompt_length = output.batch["prompts"].shape[1]
        timing["agent_loop/slowest/generate_sequences"] = t_generate_sequences[slowest]
        timing["agent_loop/slowest/tool_calls"] = t_tool_calls[slowest]
        timing["agent_loop/slowest/prompt_length"] = attention_mask[:prompt_length].sum().item()
        timing["agent_loop/slowest/response_length"] = attention_mask[prompt_length:].sum().item()

        return timing

    def wake_up(self):
        """Wake up all rollout replica instances."""
        self._run_all([replica.wake_up() for replica in self.rollout_replicas])

    def sleep(self):
        """Sleep all rollout replica instances."""
        self._run_all([replica.sleep() for replica in self.rollout_replicas])

    def _run_all(self, tasks: list[asyncio.Task]):
        """Run the given awaitables concurrently to completion on a fresh event loop via ``asyncio.run``.

        Must not be called while an event loop is already running in this thread.
        """
        async def run_all():
            await asyncio.gather(*tasks)

        asyncio.run(run_all())

    def clear_cache(self):
        """Reset sticky-session and load-balancing state inside every agent loop worker.

        The reset is executed remotely in each worker actor. Fetching the manager with
        ``get_server_manager.remote()`` only returns a serialized copy, so clearing that copy
        would leave the workers' real state untouched.
        """
        ray.get([worker.clear_server_cache.remote() for worker in self.agent_loop_workers])