# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import copy
import json
import logging
import os
import re
from typing import Any
from uuid import uuid4
from verl.tools.utils.tool_registry import initialize_tools_from_config
from verl.utils.profiler import simple_timer
from verl.utils.rollout_trace import rollout_trace_op
from .tool_parser import FunctionCall, ToolParser
from .agent_math_loop import AgentLoopBase, AgentLoopOutput, register


# Module-level logger. Note that it is keyed by this file's path (`__file__`),
# not by the dotted module name. The level is read once at import time from the
# VERL_LOGGING_LEVEL environment variable and falls back to "WARN".
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

@register("ag_math_agent_patial_rollout")
class AgentMathAgentLoopPatialRollout(AgentLoopBase):
    @classmethod
    def init_class(cls, config, tokenizer, processor, **kwargs):
        """Populate the class-level state shared by every agent-loop instance.

        The call is a no-op when ``cls._class_initialized`` is already truthy
        (the flag is expected to be declared outside this method, presumably on
        the base class -- confirm).

        Args:
            config: Configuration object. Everything is read from
                ``config.actor_rollout_ref.rollout`` and its ``multi_turn`` node.
            tokenizer: Stored on the class, handed to the tool parser and used
                to pre-tokenize ``cls.system_prompt``.
            processor: Stored on the class as ``cls.processor``; not used here.
            **kwargs: Accepted and ignored.
        """
        if cls._class_initialized:
            return
        cls._class_initialized = True
        print("Performing class-level AgentLoop initialization")

        # Shared handles reused by all instances.
        cls.tokenizer, cls.config, cls.processor = tokenizer, config, processor

        rollout_cfg = config.actor_rollout_ref.rollout
        multi_turn_cfg = rollout_cfg.multi_turn
        num_splits = rollout_cfg.partial_rollout_max_split

        # Per-segment turn budgets: the configured totals floor-divided by the
        # number of partial-rollout segments.
        cls.max_user_turns = multi_turn_cfg.max_user_turns // num_splits
        cls.max_assistant_turns = multi_turn_cfg.max_assistant_turns // num_splits

        # Tool-call fan-out and tool-response truncation settings.
        cls.max_parallel_calls = multi_turn_cfg.max_parallel_calls
        cls.max_tool_response_length = multi_turn_cfg.max_tool_response_length
        cls.tool_response_truncate_side = multi_turn_cfg.tool_response_truncate_side

        # Tools are only loaded when a tool config path is configured.
        tool_config_path = multi_turn_cfg.tool_config_path
        if tool_config_path:
            loaded_tools = initialize_tools_from_config(tool_config_path)
        else:
            loaded_tools = []

        tools_by_name = {}
        schemas = []
        for tool in loaded_tools:
            tools_by_name[tool.name] = tool
            schemas.append(tool.tool_schema.model_dump(exclude_unset=True, exclude_none=True))
        cls.tools = tools_by_name
        cls.tool_schemas = schemas

        # Parser selected by `multi_turn.format_real`; `function_name` is later
        # forwarded to `extract_tool_calls`.
        cls.tool_parser = ToolParser.get_tool_parser(multi_turn_cfg.format_real, cls.tokenizer)
        cls.function_name = multi_turn_cfg.function_name
        print(f"Initialized tools: {cls.tools}")

        # Token budgets: `max_response_length` is the full configured response
        # length, `response_length` is the per-segment share of it.
        cls.prompt_length = rollout_cfg.prompt_length
        cls.max_response_length = rollout_cfg.response_length
        cls.response_length = rollout_cfg.response_length // num_splits

        # Token ids obtained by applying the chat template to one empty message
        # (presumably the template's default system scaffold -- confirm against
        # the tokenizer in use).
        cls.system_prompt = tokenizer.apply_chat_template(
            [{}], add_generation_prompt=False, tokenize=True
        )

        # Copied into a plain list from `rollout.stop_tokens`.
        cls.stop_tokens_list = list(rollout_cfg.stop_tokens)

        # Prompt budget plus the full (un-split) response budget.
        cls.max_model_len = rollout_cfg.prompt_length + rollout_cfg.response_length


    @rollout_trace_op
    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
        """Run one multi-turn generate / tool-call loop starting from the raw prompt.

        Each iteration generates an assistant turn, extracts tool calls from it,
        executes them, and appends the first tool result (wrapped in
        ``<interpreter>`` tags) as masked-out tokens. The loop ends when the
        response budget or a turn limit is reached, when no tool call is found,
        or when appending the tool result would exceed the response budget.

        Args:
            sampling_params: Generation settings. The dict is copied before it is
                modified, so the caller's object is left untouched. ``"stop"`` is
                overwritten with ``self.stop_tokens_list``. An optional
                ``"validate"`` entry is removed; when truthy, the validation
                turn limits and ``config.data.val_max_response_length`` are used
                instead of the per-segment class budgets.
            **kwargs: ``raw_prompt`` (required) is the chat message list;
                ``tools_kwargs`` (optional) is forwarded to ``_call_tool``.

        Returns:
            AgentLoopOutput holding the prompt ids, the response ids and mask
            (both truncated to the response budget; mask is 1 for generated
            tokens and 0 for tool-result tokens), the turn count and timings.
        """
        messages = list(kwargs["raw_prompt"])
        tools_kwargs = kwargs.get("tools_kwargs", {})
        metrics = {}
        request_id = uuid4().hex

        # Copy first: popping "validate" from a dict the caller shares between
        # several rollouts would hide the flag from every later rollout.
        sampling_params = dict(sampling_params)
        sampling_params["stop"] = self.stop_tokens_list
        validate = sampling_params.pop("validate", False)

        if validate:
            multi_turn_cfg = self.config.actor_rollout_ref.rollout.multi_turn
            max_user_turns = multi_turn_cfg.val_max_user_turns
            max_assistant_turns = multi_turn_cfg.val_max_assistant_turns
            response_length = self.config.data.val_max_response_length
        else:
            max_user_turns = self.max_user_turns
            max_assistant_turns = self.max_assistant_turns
            response_length = self.response_length

        print(
            f"validate ==== {validate}, "
            f"sampling_params === {sampling_params},"
            f"max_user_turns==={max_user_turns}, "
            f"max_assistant_turns==={max_assistant_turns},"
            )
        prompt_ids = await self.loop.run_in_executor(
            None,
            lambda: self.tokenizer.apply_chat_template(
                messages, tools=None, add_generation_prompt=True, tokenize=True
            ),
        )
        response_mask = []
        user_turns, assistant_turns = 0, 0
        while True:
            with simple_timer("generate_sequences", metrics):
                # A fresh deep copy per request keeps our settings intact even if
                # the server manager modifies what it receives.
                vllm_output = await self.server_manager.generate(
                    request_id=request_id,
                    prompt_ids=prompt_ids,
                    sampling_params=copy.deepcopy(sampling_params),
                )
            response_ids = vllm_output.token_ids

            prompt_ids += response_ids
            response_mask += [1] * len(response_ids)
            assistant_turns += 1

            # reach max response length
            if len(response_mask) >= response_length:
                break

            # reach max assistant turns (a falsy limit disables the check)
            if max_assistant_turns and assistant_turns >= max_assistant_turns:
                break

            # reach max user turns (a falsy limit disables the check)
            if max_user_turns and user_turns >= max_user_turns:
                break

            # no tool calls
            _, tool_calls = await self.tool_parser.extract_tool_calls(response_ids, function_name=self.function_name)
            if not tool_calls:
                break

            # call tools
            tasks = [
                self._call_tool(tool_call, tools_kwargs)
                for tool_call in tool_calls[: self.max_parallel_calls]
            ]
            with simple_timer("tool_calls", metrics):
                tool_responses = await asyncio.gather(*tasks)
            # NOTE(review): gather() is called without return_exceptions=True, so
            # a failing task raises instead of showing up here as an object.
            if any(isinstance(item, Exception) for item in tool_responses):
                break
            # Nothing was executed (e.g. max_parallel_calls == 0): no result to append.
            if not tool_responses:
                break

            # Only the first tool result is fed back to the model.
            temp_tool_response = tool_responses[0]["content"]
            interpreter_results = "\n<interpreter>\n" + temp_tool_response + "\n</interpreter>\n"

            tool_response_ids = await self.loop.run_in_executor(
                None,
                lambda tool_results=interpreter_results: self.tokenizer.encode(tool_results)
            )

            # NOTE: last turn should not be user turn, or the EOS token reward
            # can't be propagated to previous token in GAE.
            if len(response_mask) + len(tool_response_ids) >= response_length:
                break

            prompt_ids += tool_response_ids
            response_mask += [0] * len(tool_response_ids)
            user_turns += 1

        # Split by an explicit index: `prompt_ids[-0:]` would return the whole
        # sequence when nothing was generated.
        split_at = len(prompt_ids) - len(response_mask)
        response_ids = prompt_ids[split_at:]
        prompt_ids = prompt_ids[:split_at]
        output = AgentLoopOutput(
            prompt_ids=prompt_ids,
            response_ids=response_ids[:response_length],
            response_mask=response_mask[:response_length],
            num_turns=user_turns + assistant_turns + 1,
            metrics=metrics,
        )
        return output
    @rollout_trace_op
    async def run_patial_rollout(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
        """Continue a rollout for one partial-rollout segment.

        The prompt is rebuilt from ``raw_prompt``, the tokens produced by earlier
        segments are appended, and the generate / tool-call loop then runs until
        no tool call is emitted or one of this segment's budgets is exhausted.
        (The method name keeps its historical spelling so existing callers work.)

        Outside validation, budget that earlier segments did not use is carried
        over: ``(age - 1)`` per-segment budgets minus what the previous response
        already consumed (counted ``<code>...</code>`` blocks for turns, token
        count for length) is added to this segment's limits.

        Args:
            sampling_params: Generation settings. The dict is copied before it is
                modified. ``"stop"``, ``"return_fininsh_reason"`` and
                ``"max_tokens"`` are set here. An optional ``"validate"`` entry
                is removed; when truthy the validation limits from the config
                are used and no budget is carried over.
            **kwargs: Keys read by this method:
                - ``raw_prompt`` (required): chat message list.
                - ``raw_response_ids`` (required): token ids from earlier segments.
                - ``raw_response_mask`` (required): mask for those ids; assumed
                  to have the same length as ``raw_response_ids``.
                - ``age`` (required): numeric segment counter used as described
                  above (presumably 1 for the first segment -- confirm with caller).
                - ``tools_kwargs`` (optional): forwarded to ``_call_tool``.

        Returns:
            AgentLoopOutput with the prompt ids, the combined (previous + new)
            response ids and mask truncated to the full response length, the
            turn count of this segment, timings, ``is_finish`` and ``messages``.
            ``is_finish`` is True only when the last generation reported
            ``finish_reason[0] == "stop"`` and ``finish_reason[1] is None``.
        """
        messages = list(kwargs["raw_prompt"])
        pre_response_id = list(kwargs["raw_response_ids"])
        pre_response_mask = list(kwargs["raw_response_mask"])
        age = kwargs["age"]
        tools_kwargs = kwargs.get("tools_kwargs", {})

        metrics = {}
        request_id = uuid4().hex
        print(f"request_id ===== {request_id}")

        # Copy first: popping "validate" from a dict the caller shares between
        # several rollouts would hide the flag from every later rollout.
        sampling_params = dict(sampling_params)
        sampling_params["stop"] = self.stop_tokens_list
        # NOTE(review): the key spelling is kept exactly as it has always been
        # sent; presumably the serving side looks it up under this name.
        sampling_params["return_fininsh_reason"] = True
        validate = sampling_params.pop("validate", False)

        pre_code_numbers = 0
        pre_response_length = 0
        complement_code_numbers = 0
        complement_token_numbers = 0
        if validate:
            multi_turn_cfg = self.config.actor_rollout_ref.rollout.multi_turn
            max_user_turns = multi_turn_cfg.val_max_user_turns
            max_assistant_turns = multi_turn_cfg.val_max_assistant_turns
            response_length = self.config.data.val_max_response_length
            max_response_length = response_length
            max_model_len = self.config.actor_rollout_ref.rollout.prompt_length + max_response_length
            sampling_params["max_tokens"] = response_length
        else:
            # Measure what earlier segments consumed.
            pre_response_length = len(pre_response_id)
            pre_response_text = self.tokenizer.decode(pre_response_id, skip_special_tokens=False)
            pre_code_numbers = len(
                re.findall(r"<code>(.*?)</code>", pre_response_text, flags=re.S | re.I)
            )
            # Unused budget from the (age - 1) earlier segments.
            complement_code_numbers = int((age - 1) * self.max_user_turns - pre_code_numbers)
            complement_token_numbers = int((age - 1) * self.response_length - pre_response_length)

            max_user_turns = self.max_user_turns + complement_code_numbers
            max_assistant_turns = self.max_assistant_turns + complement_code_numbers
            response_length = self.response_length + complement_token_numbers
            max_model_len = self.max_model_len
            max_response_length = self.max_response_length

            # Never request fewer than 1024 new tokens per generation call.
            sampling_params["max_tokens"] = max(response_length, 1024)

        print(
            f"validate === {validate}, "
            f"sampling_params === {sampling_params}, "
            f"age === {age}, "
            f"max_user_turns === {max_user_turns}, "
            f"max_assistant_turns === {max_assistant_turns}, "
            f"complement_code_numbers === {complement_code_numbers}, "
            f"pre_code_numbers === {pre_code_numbers}, "
            f"response_length === {response_length}, "
            f"complement_token_numbers === {complement_token_numbers}, "
            f"pre_response_length === {pre_response_length}, "
            f"max_model_len === {max_model_len}, "
            f"max_response_length === {max_response_length}"
        )

        prompt_ids = await self.loop.run_in_executor(
            None,
            lambda: self.tokenizer.apply_chat_template(
                messages, tools=None, add_generation_prompt=True, tokenize=True
            ),
        )

        # Resume from where the previous segments stopped.
        prompt_ids += pre_response_id
        response_mask = []

        user_turns, assistant_turns = 0, 0
        # Substrings that mark a tool result as a failed execution.
        error_markers = ("/tmp", "/root", "Unknown error", "/lib", "/python", "Traceback")

        while True:
            with simple_timer("generate_sequences", metrics):
                # A fresh deep copy per request keeps our settings intact even if
                # the server manager modifies what it receives.
                vllm_output = await self.server_manager.generate(
                    request_id=request_id,
                    prompt_ids=prompt_ids,
                    sampling_params=copy.deepcopy(sampling_params),
                )
            response_ids = vllm_output.token_ids
            # Only the finish reason of the final generation is needed afterwards.
            last_finish_reason = vllm_output.finish_reason

            prompt_ids += response_ids
            response_mask += [1] * len(response_ids)
            assistant_turns += 1

            # no tool calls
            _, tool_calls = await self.tool_parser.extract_tool_calls(response_ids, function_name=self.function_name)
            if not tool_calls:
                break

            # call tools
            tasks = [
                self._call_tool(tool_call, tools_kwargs)
                for tool_call in tool_calls[: self.max_parallel_calls]
            ]
            with simple_timer("tool_calls", metrics):
                tool_responses = await asyncio.gather(*tasks)
            # NOTE(review): gather() is called without return_exceptions=True, so
            # a failing task raises instead of showing up here as an object.
            if any(isinstance(item, Exception) for item in tool_responses):
                break
            # Nothing was executed (e.g. max_parallel_calls == 0): no result to append.
            if not tool_responses:
                break

            # Only the first tool result is fed back to the model.
            temp_tool_response = tool_responses[0]["content"]
            interpreter_results = "\n<interpreter>\n" + temp_tool_response + "\n</interpreter>\n"
            if any(marker in temp_tool_response for marker in error_markers):
                # Nudge the model to repair its code after a failed execution.
                interpreter_results += (
                    "Okay, the above is the result of using Python code.\n"
                    "Oops, the code above appears to be throwing an error. "
                    "I need to fix this to ensure it runs successfully."
                )

            tool_response_ids = await self.loop.run_in_executor(
                None,
                lambda tool_results=interpreter_results: self.tokenizer.encode(tool_results)
            )

            # Unlike `run`, the tool result is appended before the budget checks,
            # so a segment may end on a tool turn.
            prompt_ids += tool_response_ids
            response_mask += [0] * len(tool_response_ids)

            # reach max response length
            if len(response_mask) >= response_length:
                break

            # reach max assistant turns (a falsy limit disables the check)
            if max_assistant_turns and assistant_turns >= max_assistant_turns:
                break

            # reach max user turns (a falsy limit disables the check); the
            # counter is only advanced when the loop is going to continue
            if max_user_turns and user_turns >= max_user_turns:
                break

            user_turns += 1
            if len(prompt_ids) >= max_model_len:
                break

        all_response_mask = pre_response_mask + response_mask
        # Split by an explicit index: `prompt_ids[-0:]` would return the whole
        # sequence when the combined mask is empty.
        split_at = len(prompt_ids) - len(all_response_mask)
        response_ids = prompt_ids[split_at:]
        prompt_ids = prompt_ids[:split_at]

        finish_reason_text = last_finish_reason[0]
        stop_reason_text = last_finish_reason[1]
        is_finish = finish_reason_text == "stop" and stop_reason_text is None

        output = AgentLoopOutput(
            prompt_ids=prompt_ids,
            response_ids=response_ids[:max_response_length],
            response_mask=all_response_mask[:max_response_length],
            num_turns=user_turns + assistant_turns + 1,
            metrics=metrics,
            is_finish=is_finish,
            messages=messages,
        )
        return output

    async def _call_tool(self, tool_call: FunctionCall, tools_kwargs: dict[str, Any]) -> dict[str, str]:
        """Execute one tool call and return its text as a ``tool`` message.

        A tool instance is created, executed with the parsed arguments and
        always released again. Any exception raised while parsing, creating or
        executing is logged and turned into an error message instead of being
        propagated.

        Args:
            tool_call: Parsed call emitted by the model. ``name`` must be a key
                of ``self.tools`` and ``arguments`` must be a JSON string (it is
                always passed through ``json.loads``).
            tools_kwargs: Per-tool settings keyed by tool name. Only the
                ``"create_kwargs"`` entry of the matching tool is used, and it
                is passed to ``tool.create``.

        Returns:
            ``{"role": "tool", "content": <text>}`` where the text is either the
            tool output (with sandbox paths rewritten and truncated to
            ``self.max_tool_response_length`` characters) or an error message.
        """
        tool, instance_id = None, None
        try:
            # TODO: append malformed tool_call to the prompt: invalid function name or arguments
            tool_name = tool_call.name
            tool_args = json.loads(tool_call.arguments)
            tool = self.tools[tool_name]
            kwargs = tools_kwargs.get(tool_name, {})
            instance_id, _ = await tool.create(create_kwargs=kwargs.get("create_kwargs", {}))
            # The second and third items of the result are not used by this loop.
            tool_execution_response, _, _ = await tool.execute(instance_id, tool_args)
        except Exception as e:
            logger.exception("Error when executing tool: %s", e)
            return {"role": "tool", "content": f"Error when executing tool: {e}"}
        finally:
            if tool and instance_id:
                await tool.release(instance_id)

        # A response without text is treated as empty output rather than
        # crashing on `None.replace`.
        tool_response = tool_execution_response.text or ""

        # Rewrite sandbox installation paths. The order matters: the longest
        # prefix has to be replaced first.
        tool_response = tool_response.replace("/data/miniconda3/envs/sandbox-runtime", '/root')
        tool_response = tool_response.replace("/miniconda3/envs/sandbox-runtime", '')
        tool_response = tool_response.replace("/data/miniconda3", '/root')

        limit = self.max_tool_response_length
        total = len(tool_response)
        if total > limit:
            # Tails are sliced from an explicit start index because `s[-0:]`
            # would keep the whole string when the width is 0.
            if self.tool_response_truncate_side == "left":
                # "left": keep the beginning of the output.
                tool_response = tool_response[:limit] + "...(truncated)"
            elif self.tool_response_truncate_side == "right":
                # "right": keep the end of the output.
                tool_response = "(truncated)..." + tool_response[total - limit:]
            else:
                # Any other value: keep both ends, half the limit each.
                half = limit // 2
                tool_response = tool_response[:half] + "...(truncated)..." + tool_response[total - half:]

        return {
            "role": "tool",
            "content": tool_response,
        }
