import json
import copy
from typing import Any, List, Dict
from collections import defaultdict
from uuid import uuid4
import traceback
from skyrl_agent.functional.utils import (
    Transition,
    record_transition,
    StepResult,
    StepException,
    ContextWindowExceeded,
    ParseError,
    NoToolCall,
    ToolExecutionFailed,
)

from skyrl_agent.functional.history import (
    MessageHistory,
    MessageEncoder,
    parse_tool_call,
    extract_tool_info,
    check_truncated_tool_call,
    format_output_preview,
)
from skyrl_agent.config.configuration_utils import TrajectoryConfig
from skyrl_agent.integrations.base import AsyncInferBackend
from skyrl_agent.tools.base import TOOL_REGISTRY
from skyrl_agent.dispatcher.async_utils import call_sync_from_async
from .messages import (
    TOOL_CALL_PARSE_ERROR_GUIDANCE,
    NO_TOOL_CALL_DETECTED_GUIDANCE,
    TOOL_INVOCATION_ERROR_GUIDANCE,
    get_turn_reminder_text,
)
from skyrl_agent.functional.function_calling import convert_fncall_messages_to_non_fncall_messages


class ReActAgent:
    def __init__(
        self,
        traj_config: TrajectoryConfig,
        infer_engine: AsyncInferBackend,
        tokenizer: Any,
    ) -> None:
        """Build an agent from a trajectory configuration.

        Args:
            traj_config: Supplies the sampling params, length limit, thinking flags,
                instance/trajectory ids, iteration limits, the tool names to register
                and the debug/profiling switches. ``enable_turn_reminder`` is optional
                and defaults to ``False`` when the attribute is absent.
            infer_engine: Backend used by ``_generate_with_recording``.
            tokenizer: Tokenizer handed to the ``MessageEncoder``.

        Raises:
            ValueError: Propagated from ``_register_tools`` when a tool name is unknown
                or two tools expose the same function name.
        """
        self.tokenizer = tokenizer
        self.infer_engine = infer_engine
        self.sampling_params = traj_config.sampling_params

        self.max_prompt_length = traj_config.max_prompt_length
        self.qwen3_enable_thinking = traj_config.qwen3_enable_thinking
        self.qwen3_acc_thinking = traj_config.qwen3_acc_thinking

        self.instance_id = traj_config.instance_id
        self.trajectory_id = traj_config.trajectory_id
        self.max_iterations = traj_config.max_iterations
        self.early_step_threshold = traj_config.early_step_threshold
        self.enable_turn_reminder = getattr(traj_config, "enable_turn_reminder", False)

        # Optional per-run payload; run() overwrites it. Defined here so that the
        # attribute always exists, even when step() is driven without run().
        self.instance: Dict | None = None

        self.step_count = 0
        self.history = MessageHistory()
        self.tools = {}
        self.tool_params = []
        self.transitions: List[Transition] = []  # Record transitions per LLM call

        self.agent_id = uuid4().hex
        # Must run after self.tools / self.tool_params exist: it fills both.
        self._register_tools(traj_config.tools)

        # Message encoder
        self.message_encoder = MessageEncoder(
            tokenizer, qwen3_enable_thinking=self.qwen3_enable_thinking, qwen3_acc_thinking=self.qwen3_acc_thinking
        )

        # Token accounting maintained by _prepare_llm_input()
        self.prompt_token_len = 0
        self.response_token_len = 0

        # Debug and profiling flags/counters
        self._debug = bool(traj_config.debug_log)
        self._profile_enabled = bool(traj_config.profile_tools)
        self._tool_calls_total: int = 0
        self._tool_calls_by_name: Dict[str, int] = defaultdict(int)

    def _register_tools(self, tools: List[str]) -> None:
        """Register a list of tool instances."""
        print(f"[Register Tools] {tools}")
        for name in tools:
            if name not in TOOL_REGISTRY:
                raise ValueError(f"Unknown tool '{name}'. Must be one of: {list(TOOL_REGISTRY)}")
            tool = TOOL_REGISTRY[name]()
            # Enforce unique function names per agent for ALL tools
            if tool.name in self.tools:
                raise ValueError(
                    f"Duplicate tool function name '{tool.name}' for this agent. "
                    f"Tool function names must be unique per agent."
                )
            self.tools[tool.name] = tool
            self.tool_params.append(tool.get_tool_param())

    @record_transition
    async def _generate_with_recording(self, input_ids, sampling_params, request_id):
        """Forward one generation request to the inference backend.

        The call is wrapped by ``record_transition``; per the notes that accompany this
        method the decorator captures the input tokens, the generated tokens and their
        log probabilities (the decorator itself lives in
        ``skyrl_agent.functional.utils`` and is not visible here).

        Args:
            input_ids: Tokens fed to the LLM.
            sampling_params: Sampling parameters for this request.
            request_id: Identifier passed through to the backend.

        Returns:
            Whatever ``infer_engine.async_generate_ids`` returns; ``step`` unpacks it
            as ``(response_str, meta_info)``.
        """
        generation = await self.infer_engine.async_generate_ids(
            input_ids=input_ids,
            sampling_params=sampling_params,
            request_id=request_id,
        )
        return generation

    def _prepare_llm_input(self) -> tuple[List[int], Dict]:
        """Prepare input_ids and sampling params for LLM using incremental encoding.

        When history is reset, retokenizes everything. Otherwise, performs incremental
        encoding by appending new messages to existing tokens.

        Returns:
            Tuple of (input_ids, sampling_params)
        """

        # Check if history was reset - store flag before clearing it
        history_was_reset = self.history.was_reset()
        if history_was_reset:
            self.prompt_token_len = 0
            self.history.clear_reset_flag()

        # Add turn reminder to history (optional via config)
        if self.enable_turn_reminder:
            remaining_steps = self.max_iterations - self.step_count + 1
            reminder_text = get_turn_reminder_text(
                self.step_count,
                remaining_steps,
                early_step_threshold=self.early_step_threshold,
            )
            self.history.add_turn_reminder(reminder_text)

        # Determine if we should retokenize everything or use incremental encoding
        is_prompt = len(self.history) == 2  # system + user (with reminder appended)
        should_retokenize = is_prompt or history_was_reset or not self.transitions

        if should_retokenize:
            # Retokenize everything (first message or after history reset)
            input_ids = self.message_encoder.encode_messages(
                self.history.messages,
                self.tool_params,
                is_first_message=True,
            )
            self.prompt_token_len = len(input_ids)
        else:
            # Incremental encoding: append new messages to existing tokens
            # TODO(@csy): Handle nested agent scenarios
            # When tools spawn subagents that also use LLM generation, self.transitions[-1]
            # may belong to a subagent rather than this agent, breaking incremental encoding.
            # Proposed solution: Filter transitions by agent_id:
            #   own_transitions = [t for t in self.transitions if getattr(t, 'agent_id', None) == self.agent_id]
            #   last_transition = own_transitions[-1] if own_transitions else None
            #   If last_transition is None, fallback to retokenizing everything.
            last_transition = self.transitions[-1]
            if not last_transition.ac.token_ids:
                # Retokenize the action response_str if serving endpoints do not return token_ids
                message = [{"role": "assistant", "content": last_transition.ac.text}]
                last_transition.ac.token_ids = self.message_encoder.encode_messages(
                    message, self.tool_params, add_generation=False
                )

            # Encode only the new observation message(s)
            # Simple default behavior: assuming one env observation message per step
            # Can be overridden by subclasses
            new_obs_ids = self.message_encoder.encode_messages(
                [self.history.messages[-1]], self.tool_params, add_generation=True
            )

            # Build input_ids incrementally: previous observation + previous action + new observation
            input_ids = last_transition.ob.input_ids + last_transition.ac.token_ids + new_obs_ids

        self.response_token_len = len(input_ids) - self.prompt_token_len

        # Prepare sampling params
        sampling_params = copy.deepcopy(self.sampling_params)
        sampling_params["max_tokens"] = self.max_prompt_length - self.response_token_len

        return input_ids, sampling_params

    def _prepare_llm_input_deprecated(self) -> tuple[List[int], Dict]:
        """[DEPRECATED] Prepare input_ids and sampling params for LLM.

        This is an old version that performs retokenization at every turn.
        This method is deprecated and will be removed in a future version.
        Use `_prepare_llm_input()` instead.

        Returns:
            Tuple of (input_ids, sampling_params)
        """

        # Check if history was reset - store flag before clearing it
        history_was_reset = self.history.was_reset()
        if history_was_reset:
            self.prompt_token_len = 0
            self.history.clear_reset_flag()

        # Track token lengths
        # Encode messages to input_ids
        if self.enable_turn_reminder:
            remaining_steps = self.max_iterations - self.step_count + 1
            reminder_text = get_turn_reminder_text(
                self.step_count,
                remaining_steps,
                early_step_threshold=self.early_step_threshold,
            )
            self.history.add_turn_reminder(reminder_text)

        input_ids = self.message_encoder.encode_messages(
            self.history.messages,
            self.tool_params,
            is_first_message=True,
        )
        # Set prompt_token_len on first message (initial setup) or after history reset
        is_prompt = len(self.history) == 2  # system + user (possibly with reminder appended)
        if is_prompt or history_was_reset:
            self.prompt_token_len = len(input_ids)

        self.response_token_len = len(input_ids) - self.prompt_token_len

        # Prepare sampling params
        sampling_params = copy.deepcopy(self.sampling_params)
        sampling_params["max_tokens"] = self.max_prompt_length - self.response_token_len

        return input_ids, sampling_params

    def _handle_parse_error(self, error: str) -> None:
        """Record a tool-call parse failure in the history, then abort the step.

        Adds the raw error as a tool error followed by the formatted
        ``TOOL_CALL_PARSE_ERROR_GUIDANCE`` as user guidance.

        Args:
            error: Parser error text.

        Raises:
            ParseError: Always.
        """
        print(f"[Agent Step Error] Converter failed to parse tool call: {error}")

        # Format first so a broken template fails before the history is touched.
        guidance_text = TOOL_CALL_PARSE_ERROR_GUIDANCE.format(error=error)

        history = self.history
        history.add_tool_error(error)
        history.add_user_guidance(guidance_text)

        raise ParseError()

    def _handle_no_tool_call(self, response_str: str) -> None:
        """Record that the response contained no tool call, then abort the step.

        When ``check_truncated_tool_call`` flags the response, its last 500 characters
        are printed for diagnosis. ``NO_TOOL_CALL_DETECTED_GUIDANCE`` is added to the
        history as user guidance in every case.

        Args:
            response_str: Raw LLM response text.

        Raises:
            NoToolCall: Always.
        """
        print(f"[Agent Step {self.step_count}] No tool call found in response")

        # Diagnostics only: a cut-off tool call is reported but handled the same way.
        looks_truncated = check_truncated_tool_call(response_str)
        if looks_truncated:
            print("[ERROR] Tool call appears incomplete - likely truncated!")
            tail = response_str[-500:]
            print(f"[ERROR] Last 500 chars: {tail}")

        self.history.add_user_guidance(NO_TOOL_CALL_DETECTED_GUIDANCE)
        raise NoToolCall()

    async def _execute_tool(self, tool_name: str, tool_args: Dict, tool_call_id: str) -> Any:
        """Run a registered tool and return its output.

        ``tool.call`` is dispatched through ``call_sync_from_async`` with this agent
        and its trajectory id passed as keyword arguments. When the call fails, the
        error text is logged and added to the history as a tool error for
        ``tool_call_id``, followed by ``TOOL_INVOCATION_ERROR_GUIDANCE``.

        Args:
            tool_name: Key of the tool in ``self.tools``.
            tool_args: Arguments forwarded to the tool.
            tool_call_id: Identifier of the tool call, attached to error messages.

        Returns:
            The tool's output, unchanged.

        Raises:
            KeyError: If ``tool_name`` is not registered (the lookup is not guarded).
            ToolExecutionFailed: If the tool call raised; the original exception is
                attached as ``__cause__``.
        """
        tool = self.tools[tool_name]

        try:
            output = await call_sync_from_async(
                tool.call,
                tool_args,
                agent=self,
                trajectory_id=self.trajectory_id,
            )
        except Exception as e:
            # Tool invocation failed
            error_str = str(e)
            print(f"[Agent Step Error] Tool '{tool_name}' failed: {error_str}")
            try:
                self.history.add_tool_error(error_str, tool_call_id)
            except Exception:
                # The error text itself could not be recorded; fall back to a fixed message.
                self.history.add_tool_error("Tool failed with an exception.", tool_call_id)

            self.history.add_user_guidance(TOOL_INVOCATION_ERROR_GUIDANCE)
            # Keep the root cause reachable for whoever inspects the exception.
            raise ToolExecutionFailed() from e

        # Profiling is bookkeeping only, so it lives outside the tool-failure handler.
        if self._profile_enabled:
            self._tool_calls_total += 1
            if tool_name:
                self._tool_calls_by_name[tool_name] += 1

        return output

    def _append_tool_output(self, output: Any, tool_call_id: str) -> None:
        """Append tool output to message history.

        Args:
            output: Tool output to append
            tool_call_id: ID of the tool call
        """
        try:
            self.history.add_tool_response(output, tool_call_id)

            if self._debug:
                preview = format_output_preview(output)
                print(f"[Tool Output Preview] {preview}")

        except Exception as e:
            print(f"[Agent Step Error] Error appending tool output to messages: {str(e)}")
            self.history.add_tool_error(str(e), tool_call_id)

    async def step(self):
        """Execute one agent step: LLM generation -> tool call -> tool execution.

        Expected control-flow outcomes are raised as ``StepException`` and converted
        into the step result carried by the exception (``e.step_result``). The
        handlers used here raise ``ContextWindowExceeded``, ``ParseError``,
        ``NoToolCall`` and ``ToolExecutionFailed`` for that purpose (presumably all
        ``StepException`` subclasses; they are defined outside this file). Any other
        exception is logged together with its traceback and ends the run with the
        finish reason ``"error: <message>"``.

        Returns:
            The tuple produced by ``result.to_tuple()``; ``run`` unpacks it as
            (done, finish_reason, result).
        """
        self.step_count += 1
        print(f"[Agent Step {self.step_count}] instance={self.instance_id} traj={self.trajectory_id}")

        result = None

        try:
            # 1. Prepare LLM input
            input_ids, sampling_params = self._prepare_llm_input()

            # Check context window before spending a generation call
            if self.response_token_len >= self.max_prompt_length:
                print("[Agent Step] Stopping reason: context_window_exceeded. Stopping agent.")
                raise ContextWindowExceeded()

            # 2. Generate LLM response
            response_str, meta_info = await self._generate_with_recording(
                input_ids=input_ids,
                sampling_params=sampling_params,
                request_id=self.agent_id,
            )
            stop_reason = meta_info["finish_reason"]
            print(f"[Agent Step {self.step_count}] LLM response: {response_str}. Stop reason: {stop_reason}")

            # Add assistant message to history
            self.history.add_assistant(response_str)

            # Check if generation stopped due to length
            if stop_reason == "length":
                print(f"[Agent Step] Stopping reason: {stop_reason}. Stopping agent.")
                raise ContextWindowExceeded()

            # 3. Parse tool call from response
            tool_call, parse_error = parse_tool_call(response_str, self.tool_params)

            # Handle parse error (always raises ParseError)
            if parse_error:
                self._handle_parse_error(parse_error)

            # Handle no tools scenario: the plain response is the final answer
            if not self.tools:
                print(f"[Agent Step {self.step_count}] No tools provided, returning response.")
                result = StepResult.finished("FINISH", response_str)

            # Handle no tool call detected (always raises NoToolCall)
            elif tool_call is None:
                self._handle_no_tool_call(response_str)

            else:
                # 4. Extract tool information
                tool_name, tool_args = extract_tool_info(tool_call)
                tool_call_id = tool_call.get("id")

                # Validate tool exists; an unknown tool is reported back and the loop continues
                if tool_name not in self.tools:
                    self.history.add_user_guidance(json.dumps({"error": f"Tool '{tool_name}' not found."}))
                    result = StepResult.continuing(response_str)
                else:
                    # 5. Execute tool
                    output = await self._execute_tool(tool_name, tool_args, tool_call_id)

                    # 6. Check if finish tool was called
                    if tool_name == "finish":
                        print(f"[Agent Step {self.step_count}] Finish tool called. Stopping agent.")
                        result = StepResult.finished("FINISH_TOOL", output)
                    else:
                        # Continue agent loop
                        result = StepResult.continuing(response_str)

                        # 7. Append tool output to history only if output is not None
                        # Some tools (like next_with_summary) embed feedback in user message
                        # and return None to skip adding tool output
                        if output is not None:
                            print(f"[Tool Output step {self.step_count}] {output}")
                            self._append_tool_output(output, tool_call_id)
                        else:
                            print(f"[Tool Output step {self.step_count}] No output (feedback embedded in user message)")

        except StepException as e:
            # Handle expected control flow exceptions
            result = e.step_result

        except Exception as e:
            # Handle unexpected errors. The traceback is printed here because this
            # handler swallows the exception, so callers never get to see it.
            print(f"[Agent Step Error] Error during step: {str(e)}")
            print(traceback.format_exc())
            result = StepResult.finished(f"error: {str(e)}", None)

        # Single exit point
        return result.to_tuple()

    async def run(self, instruction: List[Dict], instance: Dict | None = None) -> List[str]:
        """Run the agent till the end with the provided user input.
        Optionally accepts an instance payload for tools (stored on self.instance).
        """
        self.instance = instance
        self._init_message(instruction)
        result = None
        finish_reason = None
        while self.step_count < self.max_iterations:
            try:
                done, finish_reason, result = await self.step()
                if done:
                    break
            except Exception as e:
                finish_reason = f"error: {str(e)}"
                print(f"[Agent Run Error] Exception during step: {str(e)}")
                # traceback
                print(traceback.format_exc())
                break
        else:  # If we exit the loop without hitting a break, it means we reached max iterations
            finish_reason = "max_iterations_reached"

        return finish_reason, result

    def get_messages(self) -> List[dict]:
        """Return the history converted to non-function-calling messages.

        The conversion is delegated to ``convert_fncall_messages_to_non_fncall_messages``,
        which receives the history messages and this agent's tool params.
        """
        history_messages = self.history.messages
        return convert_fncall_messages_to_non_fncall_messages(history_messages, self.tool_params)

    def get_transitions(self) -> List[Transition]:
        """Return the list of transitions recorded during agent execution.

        Each transition contains:
        - ob: Observation with input_ids (tokens fed to LLM)
        - ac: TokensWithLogprobs with output_tokens, logprobs, and generated text
        - reward: Float reward value (default 0.0, can be updated based on outcomes)
        - episode_done: Boolean indicating if episode finished
        - metrics: Dict with finish_reason, response_length, and other metadata
        """
        return self.transitions

    def _init_message(self, instruction: List[Dict]) -> None:
        """Initialize the agent's message history with the provided instruction.

        Automatically collects system prompt prefixes from registered tools and prepends them
        to the system message if present.
        """
        if not isinstance(instruction, list):
            raise ValueError("Instruction must be a list of messages.")

        for msg in instruction:
            if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
                raise ValueError("Each message must be a dictionary with 'role' and 'content'.")

        # Collect system prompt prefixes from registered tools
        tool_prefixes = []
        for tool_name, tool in self.tools.items():
            prefix = tool.get_system_prompt_prefix()
            if prefix:
                tool_prefixes.append(prefix)

        # Prepend tool prefixes to system message if any exist
        if tool_prefixes:
            processed_instruction = copy.deepcopy(instruction)
            # Combine all tool prefixes
            combined_prefix = "\n\n---\n\n".join(tool_prefixes)

            # Find the first system message
            system_msg_found = False
            for msg in processed_instruction:
                if msg.get("role") == "system":
                    # Prepend tool prefixes to existing system message
                    msg["content"] = combined_prefix + "\n\n---\n\n" + msg["content"]
                    system_msg_found = True
                    break

            # If no system message exists, create one at the beginning
            if not system_msg_found:
                processed_instruction.insert(0, {"role": "system", "content": combined_prefix})

            self.history.initialize(processed_instruction)
        else:
            self.history.initialize(instruction)

    # Expose profiling snapshot for upstream aggregation
    def get_tool_profile(self) -> Dict[str, Any]:
        if not self._profile_enabled:
            return None
        try:
            return {
                "tool_calls_total": int(self._tool_calls_total),
                "tool_calls_by_name": dict(self._tool_calls_by_name),
            }
        except Exception:
            return None


if __name__ == "__main__":
    # Manual smoke test: runs one ReActAgent on a sample math question against the
    # backend configured below.
    from skyrl_agent.config.configuration_utils import TrajectoryConfig
    from skyrl_agent.integrations.openai import OpenAIBackend, OpenAIBackendConfig
    from transformers import AutoTokenizer
    import asyncio

    # Only the tokenizer is loaded here; generation is delegated to the backend.
    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Trajectory configuration for the test run
    demo_sampling_params = {
        "temperature": 0.7,
        "top_p": 0.95,
        "max_tokens": 2048,
    }
    traj_config = TrajectoryConfig(
        instance_id="test_instance",
        trajectory_id="test_trajectory",
        sampling_params=demo_sampling_params,
        max_prompt_length=12048,
        qwen3_enable_thinking=True,
        tools=["finish", "code_interpreter"],
        max_iterations=5,
        agent_cls="skyrl_agent.agents.react.ReActAgent",  # Use ReActAgent for testing
    )

    # Point api_url at the server you want to test against.
    # TODO: model_name need not be in config
    backend_config = OpenAIBackendConfig(
        model_name=model_name,
        api_url="http://localhost:8000",
    )
    infer_engine = OpenAIBackend(infer_engine=None, cfg=backend_config)

    # Agent under test (with tools registered)
    agent = ReActAgent(
        traj_config=traj_config,
        infer_engine=infer_engine,
        tokenizer=tokenizer,
    )

    # Sample instruction: a system prompt plus one user question
    system_message = {
        "content": "Please reason step by step, and put your final answer within \\boxed{}.",
        "role": "system",
    }
    user_message = {
        "content": "Points $A,B,C,D,E$ and $F$ lie, in that order, on $\\overline{AF}$, dividing it into five segments, each of length 1. Point $G$ is not on line $AF$. Point $H$ lies on $\\overline{GD}$, and point $J$ lies on $\\overline{GF}$. The line segments $\\overline{HC}, \\overline{JE},$ and $\\overline{AG}$ are parallel. Find $HC/JE$.",
        "role": "user",
    }
    instruction = [system_message, user_message]

    # Run the agent to completion and dump what happened
    finish_reason, result = asyncio.run(agent.run(instruction))

    print(agent.get_messages())
    print(f"Finish Reason: {finish_reason}")
    print(f"Result: {result}")
