
import json5
import re
import asyncio
import random

from typing import Callable, Awaitable, Dict, Any
from datetime import datetime, timedelta

# =================================================================
#  Agent Tools and Prompts
# =================================================================
from .agent_tools import (
    Search, Scholar, PythonInterpreter, Visit,
    initial_instruction_prompt, instruction_prompt, observation_prompt, last_instruction_prompt,
    AGENT_TOOLS_SCHEMA,
)

# Module-level tool instances, created once when this module is imported.
# ContextAgent._execute_tool_call dispatches to them by tool name, so every
# agent in the process uses the same four objects.
search_engine = Search()
scholar_engine = Scholar()
visit_tool = Visit()
python_executor = PythonInterpreter()

# Prompt fragment describing the four callable tools (google_search,
# google_scholar, Visit, PythonInterpreter) as JSON function signatures, plus
# the <tool_call> reply format the model is asked to follow.
# NOTE: this is runtime prompt text -- its exact wording (including the literal
# line break inside the PythonInterpreter description) is part of the program's
# behaviour. It is not referenced elsewhere in the visible part of this file.
TOOL_DESC = """# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{"type": "function", "function": {"name": "google_search", "description": "Perform Google web searches then returns a string of the top search results. Accepts multiple queries.", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string", "description": "The search query."}, "minItems": 1, "description": "The list of search queries."}}, "required": ["query"]}}}\n{"type": "function", "function": {"name": "google_scholar", "description": "Leverage Google Scholar to retrieve relevant information from academic publications. Accepts multiple queries. This tool will alse return results from google search", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string", "description": "The search query."}, "minItems": 1, "description": "The list of search queries for Google Scholar."}}, "required": ["query"]}}}\n{"type": "function", "function": {"name": "Visit", "description": "Visit webpage(s) or paper(s) and return the summary of the content.", "parameters": {"type": "object", "properties": {"url": {"type": "array", "items": {"type": "string"}, "minItems": 1, "description": "The URL(s) of the webpage(s) or paper(s) to visit. Can be a single URL or an array of URLs."}, "goal": {"type": "string", "description": "The goal of the visit for webpage(s) or paper(s)."}, "parse_type": {"type": "string", "enum": ["html", "pdf"], "default": "html", "description": "Specify whether to visit a HTML webpage or a PDF paper. Must be either \'html\' or \'pdf\'. Defaults to \'html\' if not specified."}}, "required": ["url", "goal"]}}}\n{"type": "function", "function": {"name": "PythonInterpreter", "description": "Executes arbitrary Python code in a secure, sandboxed environment. 
This tool is designed for performing complex calculations, data manipulations, string processing, logical operations, and general programming tasks that require programmatic logic. It can define variables, use built-in functions, and print results to standard output. It operates with a limited set of pre-installed libraries and cannot access local files, external networks, or system resources.", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": "The Python code to execute. This should be a complete and runnable Python script or series of statements. All output should be explicitly printed using `print()` functions."}}, "required": ["code"]}}}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": <function-name>, "arguments": <args-json-object>}\n</tool_call>"""

def random_date(start=datetime(2023, 1, 1), end=datetime(2025, 12, 31)):
    """Return a uniformly random day in ``[start, end]`` as ``YYYY-MM-DD``.

    The offset is drawn in whole days (both endpoints included) with the
    module-level ``random`` generator, so seeding ``random`` makes the result
    reproducible.
    """
    span_in_days = (end - start).days
    offset = timedelta(days=random.randint(0, span_in_days))
    return (start + offset).strftime("%Y-%m-%d")

def clean_text_for_tokenizer(text: str) -> str:
    """Normalise surrogate code points so the text can be tokenized.

    The string is round-tripped through UTF-16: a high/low surrogate pair is
    merged into the single character it encodes, and a lone (unpaired)
    surrogate -- which has no valid encoding -- is replaced by U+FFFD instead
    of raising ``UnicodeDecodeError``.

    Args:
        text: The text to clean. Non-string values are returned unchanged.

    Returns:
        The cleaned string, or ``text`` itself when it is not a ``str``.
    """
    if not isinstance(text, str):
        return text
    # 'surrogatepass' lets surrogates through the encoder; decoding with
    # 'replace' then merges valid pairs and substitutes U+FFFD for strays.
    return text.encode('utf-16', 'surrogatepass').decode('utf-16', 'replace')

def parse_text_output(text: str):
    """Extract the ``<report>``, ``<action>`` and ``<answer>`` sections of *text*.

    Returns:
        ``(True, fields)`` where ``fields`` maps ``"report"``, ``"action"`` and
        ``"answer"`` to the stripped content of the first matching tag (or
        ``None`` for a missing action/answer), or
        ``(False, 'Report tag not found.')`` when there is no report tag.
    """

    def _first_tag_body(tag):
        # First non-greedy match; DOTALL lets the body span several lines.
        found = re.search(rf"<{tag}>(.*?)</{tag}>", text, re.DOTALL)
        return found.group(1).strip() if found else None

    report_body = _first_tag_body("report")
    # The report is mandatory; an empty report still counts as present.
    if report_body is None:
        return False, 'Report tag not found.'

    return True, {
        "report": report_body,
        "action": _first_tag_body("action"),
        "answer": _first_tag_body("answer"),
    }

# =================================================================
#  Agent Logic Class
# =================================================================
# Default cap on the number of LLM turns in one rollout; used as the default
# value of ContextAgent's ``max_agent_turns`` argument.
MAX_AGENT_TURNS = 20

class ContextAgent:
    """Multi-turn agent that alternates LLM calls with tool executions.

    Every turn is sent to the LLM as a fresh single-message conversation:
    after a tool call, the report and action parsed from the assistant reply,
    together with the tool observation, are folded into the prompt of the
    following turn.
    """

    def __init__(
        self,
        question: str,
        sampling_params: Dict[str, Any],
        llm_caller: Callable[..., Awaitable[Dict]],
        max_agent_turns: int = MAX_AGENT_TURNS,
    ):
        """Initialise the agent.

        Args:
            question: The user's initial question.
            sampling_params: Sampling parameters for the LLM.
            llm_caller: Async callable invoked as
                ``llm_caller(messages, sampling_params, use_tools=...)`` that
                returns the LLM output dict.
            max_agent_turns: Maximum number of turns for this agent.
        """
        self.question = clean_text_for_tokenizer(question)
        self.sampling_params = sampling_params
        self.llm_caller = llm_caller
        self.max_agent_turns = max_agent_turns
        # One random date per agent; it is injected into every prompt.
        self.date = random_date()

    async def _execute_tool_call(self, tool_call_dict: Dict[str, Any]):
        """Execute a single tool call and return its observation.

        Args:
            tool_call_dict: Tool call with a ``name`` entry and a ``parameters``
                entry holding the arguments as a string parsed by
                ``json5.loads``.

        Returns:
            A ``(observation, summary_message_list, tool_status)`` tuple.
            ``summary_message_list`` is only populated by the ``Visit`` tool
            and is an empty list otherwise. Any exception raised while parsing
            the arguments or running the tool is caught, printed, and reported
            as ``(error_message, [], False)`` instead of propagating.
        """
        tool_name = tool_call_dict.get('name')
        try:
            arguments_str = tool_call_dict.get('parameters', '{}')
            arguments = json5.loads(arguments_str)
            summary_message_list = []

            # The tools expose synchronous ``call`` methods; they are run via
            # asyncio.to_thread so the event loop is not blocked.
            if tool_name == 'google_search':
                observation, tool_status = await asyncio.to_thread(search_engine.call, {"query": arguments['query']})
            elif tool_name == 'google_scholar':
                # Scholar results are complemented by a regular web search for
                # the same queries; both lookups run concurrently.
                scholar_task = asyncio.to_thread(scholar_engine.call, {"query": arguments['query']})
                search_task = asyncio.to_thread(search_engine.call, {"query": arguments['query']})
                (scholar_res, tool_status_scholar), (search_res, tool_status_search) = await asyncio.gather(scholar_task, search_task)
                observation = scholar_res + "\n\n" + search_res
                tool_status = tool_status_scholar and tool_status_search
            elif tool_name == 'PythonInterpreter':
                tool_output = await asyncio.to_thread(python_executor.call, {"code": arguments['code']})
                tool_status = tool_output['success']
                observation = tool_output['results']
            elif tool_name == 'Visit':
                # Visit receives the raw argument string, not the parsed dict.
                observation, summary_message_list, tool_status = await asyncio.to_thread(visit_tool.call, arguments_str)
            else:
                observation = f"Error: Unknown tool '{tool_name}'."
                tool_status = False

            return clean_text_for_tokenizer(str(observation)), summary_message_list, tool_status
        except Exception as e:
            # Best-effort boundary: a failing tool must not abort the rollout.
            print(f"Error executing tool {tool_name}: {e}")
            return f"Error parsing arguments or executing tool {tool_name}.", [], False

    def _format_turn_prompt(self, messages, is_last: bool=False):
        """Build the user prompt for the next turn from the turn just finished.

        Args:
            messages: Message list whose last three entries are, in order, the
                user prompt, the assistant reply (carrying ``no_think_text``)
                and the tool result.
            is_last: Use ``last_instruction_prompt`` instead of
                ``instruction_prompt``.

        Returns:
            The formatted prompt text, cleaned for the tokenizer.

        Raises:
            AssertionError: If the trailing roles are not user/assistant/tool,
                or the assistant text contains no ``<report>`` tag.
        """
        user, bot, tool = messages[-3], messages[-2], messages[-1]
        assert user['role'] == 'user'
        assert bot['role'] == 'assistant'
        assert tool['role'] == 'tool'

        parse_status, parse_output = parse_text_output(bot['no_think_text'])
        assert parse_status, f"format error when preparing next turn. {parse_output}"
        report, action = parse_output['report'], parse_output['action']

        tool_name = tool.get('name', 'N/A')
        tool_arguments = tool.get('arguments', '{}')
        tool_response = tool['content']

        observation = observation_prompt.format(tool_name=tool_name, tool_arguments=tool_arguments, tool_response=tool_response)

        prompt_template = last_instruction_prompt if is_last else instruction_prompt
        return clean_text_for_tokenizer(prompt_template.format(question=self.question, report=report, action=action, observation=observation, date_to_use=self.date))

    async def rollout(self) -> tuple[list, dict]:
        """Run the agent loop for up to ``max_agent_turns`` turns.

        Returns:
            ``(conversations, traj_status)``. ``conversations`` holds one
            message list per recorded turn. ``traj_status["status"]`` is the
            finish-reason type of the last LLM response and
            ``traj_status["message"]`` a sentence naming it.
        """
        traj_status = {"status": "health", "message": ""}
        messages = [{"role": "user", "content": initial_instruction_prompt.replace('{question}', self.question).replace('{date_to_use}', self.date)}]
        conversations = []
        last_llm_meta_info = {}

        for turn in range(self.max_agent_turns):
            print(f"Executing agent turn {turn + 1}/{self.max_agent_turns}...")
            is_last_turn = (turn == self.max_agent_turns - 1)

            # Tools are disabled on the final turn.
            response = await self.llm_caller(messages, self.sampling_params.copy(), use_tools=not is_last_turn)
            messages[0]['tokens'] = response['input_tokens']

            last_llm_meta_info = response['meta_info']
            current_full_text = response['text']
            finish_type = last_llm_meta_info["finish_reason"]["type"]

            if finish_type in ["exceeded_max_length", "abort"]:
                # Nothing from this turn is recorded.
                break

            elif finish_type in ["length", "format_error"]:
                # Record the (truncated or malformed) reply, then stop.
                assert current_full_text is not None, f"LLM returned no text, but status is {finish_type}"
                messages.append({"role": "assistant", "content": current_full_text, "tokens": response['new_response_tokens'], 'rollout_log_probs': response.get('rollout_log_probs', [])})
                conversations.append(messages)
                break

            assert current_full_text is not None, f"LLM returned no text, but status is {finish_type}."
            assert "parse_info" in response, "LLM returned no parse_info, but status is healthy."

            messages.append({"role": "assistant", "content": current_full_text, "tokens": response['new_response_tokens'], "no_think_text": response.get("parse_info", {}).get("no_think_text", ""), 'rollout_log_probs': response.get('rollout_log_probs', [])})

            tool_calls = response.get("parse_info", {}).get("tool_calls", None)
            if not tool_calls:
                # No tool requested: this reply ends the trajectory.
                conversations.append(messages)
                break

            # --- Tool Execution Step ---
            # Only the first tool call of a reply is executed.
            tool_call_obj = tool_calls[0]
            observation, _, tool_status = await self._execute_tool_call(tool_call_obj)
            tool_message = {
                "role": "tool",
                "tool_call_id": -1,
                "name": tool_call_obj['name'],
                "arguments": tool_call_obj['parameters'],
                "content": observation,
                'tool_status': tool_status,
            }
            messages.append(tool_message)

            # --- Prepare for next turn ---
            conversations.append(messages)

            if turn < self.max_agent_turns - 1:
                # Rebinding (not mutating) ``messages`` keeps the list stored
                # in ``conversations`` intact.
                next_turn_prompt = self._format_turn_prompt(messages, is_last=(turn == self.max_agent_turns - 2))
                messages = [{"role": "user", "content": next_turn_prompt}]

        # NOTE(review): with max_agent_turns < 1 the loop never runs and this
        # lookup raises KeyError on the empty meta info.
        final_reason = last_llm_meta_info["finish_reason"]["type"]
        traj_status['status'] = final_reason
        # Built from a local so the f-string needs no nested quotes, which are
        # a SyntaxError before Python 3.12.
        traj_status['message'] = f"LLM failed due to {final_reason}."

        return conversations, traj_status
