import asyncio
import ast
import json
import re
import sys
import uuid
from pathlib import Path
from typing import Any, TypedDict

from langchain_core.messages import AIMessage
from langgraph.graph import END, StateGraph

# Project root: the parent of the directory that holds this file. It is put at
# the front of the import path so the project-local imports below
# (``config_loader``, ``llm``, ``llm_agent``) resolve when this file is run as a
# script. The membership check avoids duplicate entries on re-import.
_ROOT = Path(__file__).resolve().parents[1]
if all(entry != str(_ROOT) for entry in sys.path):
    sys.path.insert(0, str(_ROOT))

from config_loader import load_config
from llm import load_llm
from llm_agent.coala.pm import ProceduralMemory


# Identity of this agent's profile on the SEM server (used by the profile and
# signifier tools).
PROFILE_ID = "executor"
PROFILE_URL = "http://localhost:5000/profile/executor"

# Natural-language contexts sent via ``update_profile`` before each workflow
# step (presumably these steer which signifiers the server returns next).
READ_USER_GOAL = "The agent wants to read the user goal."
PROVIDE_FEEDBACK_GOAL = "The agent wants to provide feedback to the user."
FORMALIZE_GOAL = "The agent wants to convert the natural language goal into a formal representation."
ROBOT_GOAL = "The agent wants to perform a formal operation on the robot."

# A formal robot command: the word ``move`` or ``rotate`` (any case, on a word
# boundary), optional whitespace, then one signed integer or decimal argument in
# parentheses, e.g. ``move(10)``, ``rotate(-90)``, ``MOVE ( 2.5 )``.
FORMAL_COMMAND_RE = re.compile(r"\b(?:move|rotate)\s*\(\s*[-+]?\d+(?:\.\d+)?\s*\)", re.IGNORECASE)


class ExecutorState(TypedDict, total=False):
    """State passed between the nodes of the executor graph.

    All keys are optional (``total=False``) because each node returns only the
    keys it updates.
    """

    # Name of the node the agent intends to run next; the routers decide using
    # the other keys, so this mainly serves logging.
    current_state: str
    # Natural-language user goal read from the goal signifier tool.
    current_goal: str
    # Compact ``move(...)`` / ``rotate(...)`` command derived from the goal.
    formalized_goal: str
    # Raw result returned by the robot-operation tool.
    last_result: Any
    # Human-readable failure description; when set, routers divert to ``error``.
    error: str
    # Name of the tool chosen by the most recent tool-selection step.
    selected_tool: str


class LangGraphExecutorAgent:
    """LangGraph agent that reads a user goal, formalizes it and runs it on a robot.

    One pass of the compiled graph:

    1. Register the ``executor`` profile and configure signifier exposure
       (at most 1 signifier, minimum relevance 0.75).
    2. Read the natural-language user goal and send a "Started" feedback message.
    3. Turn the goal into a formal ``move(...)``/``rotate(...)`` command.
    4. Execute that command with the robot-operation tool.
    5. Send a "Completed" feedback message.

    Before each step the agent updates its profile with a natural-language
    context. It then calls the ``read_signifiers`` tool, re-syncs the MCP tool
    list, and picks the tool for the step (keyword match first, LLM fallback).
    Every tool call and graph transition is appended to
    ``last_run_executor_agent.txt`` in the project root.

    Nodes that find no suitable tool, or get an unusable result, route to the
    ``error`` node. Exceptions raised by tool calls propagate out of :meth:`run`.
    """

    # Tools that manage the profile / signifier exposure rather than perform a
    # task step; ``_select_tool`` never returns them.
    _INFRASTRUCTURE_TOOLS = frozenset(
        {
            "register_profile",
            "update_profile",
            "read_signifiers",
            "set_minimum_relevance_value",
            "set_maximum_signifiers",
            "all_signifiers",
        }
    )

    def __init__(self) -> None:
        """Load config, build the LLM, register the MCP server and compile the graph.

        The run log file is truncated so each agent instance starts a fresh log.
        """
        config = load_config(str(_ROOT / "config.json"))
        llm_config = config["llm_agent"]
        self.llm = load_llm(
            llm_config["provider"],
            llm_config["model"],
            temperature=llm_config.get("temperature"),
        )
        self.procedural_memory = ProceduralMemory()
        self.procedural_memory.register_mcp_server(
            name="mcp_sem",
            server_url=config["sem_mcp_endpoint"],
        )
        # Upper bound, in seconds, for a single tool invocation.
        self.tool_timeout_seconds = 180
        # Upper bound, in seconds, for one MCP tool-list sync.
        self.sync_timeout_seconds = 20
        self.last_run_path = _ROOT / "last_run_executor_agent.txt"
        self.last_run_path.write_text("", encoding="utf-8")
        self.graph_transitions: list[dict[str, Any]] = []
        self.graph = self._build_graph()

    def _build_graph(self):
        """Assemble and compile the executor ``StateGraph``.

        Steps that cannot fail use plain edges. Steps that can fail use
        conditional edges whose router returns ``"error"`` to divert to the
        ``error`` node.
        """
        graph = StateGraph(ExecutorState)
        graph.add_node("start", self.start)
        graph.add_node("set_maximum_signifiers", self.set_maximum_signifiers)
        graph.add_node("set_minimum_relevance_value", self.set_minimum_relevance_value)
        graph.add_node("initialize_goal_context", self.initialize_goal_context)
        graph.add_node("read_signifier1", self.read_signifier1)
        graph.add_node("notify_goal_started", self.notify_goal_started)
        graph.add_node("update_formalize_context", self.update_formalize_context)
        graph.add_node("read_formalizer_signifier", self.read_formalizer_signifier)
        graph.add_node("use_formalizer_signifier", self.use_formalizer_signifier)
        graph.add_node("update_robot_context", self.update_robot_context)
        graph.add_node("read_robot_signifier", self.read_robot_signifier)
        graph.add_node("use_robot_signifier", self.use_robot_signifier)
        graph.add_node("final_notification", self.final_notification)
        graph.add_node("error", self.error)

        graph.set_entry_point("start")
        graph.add_edge("start", "set_maximum_signifiers")
        graph.add_edge("set_maximum_signifiers", "set_minimum_relevance_value")
        graph.add_edge("set_minimum_relevance_value", "initialize_goal_context")
        graph.add_edge("initialize_goal_context", "read_signifier1")
        graph.add_conditional_edges(
            "read_signifier1",
            self._route_after_goal_read,
            {"notify_goal_started": "notify_goal_started", "error": "error"},
        )
        graph.add_conditional_edges(
            "notify_goal_started",
            self._route_after_notify_goal_started,
            {"next": "update_formalize_context", "error": "error"},
        )
        graph.add_edge("update_formalize_context", "read_formalizer_signifier")
        graph.add_conditional_edges(
            "read_formalizer_signifier",
            self._route_after_read_formalizer_signifier,
            {"next": "use_formalizer_signifier", "error": "error"},
        )
        graph.add_conditional_edges(
            "use_formalizer_signifier",
            self._route_after_formalizer,
            {"update_robot_context": "update_robot_context", "initialize_goal_context": "initialize_goal_context", "error": "error"},
        )
        graph.add_edge("update_robot_context", "read_robot_signifier")
        graph.add_conditional_edges(
            "read_robot_signifier",
            self._route_after_read_robot_signifier,
            {"next": "use_robot_signifier", "error": "error"},
        )
        graph.add_conditional_edges(
            "use_robot_signifier",
            self._route_after_robot,
            {"final_notification": "final_notification", "error": "error"},
        )
        graph.add_conditional_edges(
            "final_notification",
            self._route_final_notification,
            {"end": END, "error": "error"},
        )
        graph.add_edge("error", END)
        return graph.compile()

    async def run(self) -> dict[str, Any]:
        """Sync tools, execute the graph once and return the final graph state.

        The returned state has an ``error`` key when the run ended in the
        ``error`` node.
        """
        await self._sync_tools()
        self._record_graph_transition("__start__", "start")
        return await self.graph.ainvoke({"current_state": "start"})

    # ------------------------------------------------------------------ nodes

    async def start(self, state: ExecutorState) -> ExecutorState:
        """Register this agent's profile with the server."""
        await self._call_tool("register_profile", {"profile_id": PROFILE_ID})
        self._record_graph_transition("start", "set_maximum_signifiers")
        return {"current_state": "set_maximum_signifiers"}

    async def set_maximum_signifiers(self, state: ExecutorState) -> ExecutorState:
        """Limit the profile to one exposed signifier."""
        await self._call_tool(
            "set_maximum_signifiers",
            {"profile_id": PROFILE_ID, "maximum_signifiers": 1},
        )
        self._record_graph_transition("set_maximum_signifiers", "set_minimum_relevance_value")
        return {"current_state": "set_minimum_relevance_value"}

    async def set_minimum_relevance_value(self, state: ExecutorState) -> ExecutorState:
        """Set the profile's minimum signifier relevance to 0.75."""
        await self._call_tool(
            "set_minimum_relevance_value",
            {"profile_id": PROFILE_ID, "minimum_relevance_value": 0.75},
        )
        self._record_graph_transition("set_minimum_relevance_value", "initialize_goal_context")
        return {"current_state": "initialize_goal_context"}

    async def initialize_goal_context(self, state: ExecutorState) -> ExecutorState:
        """Switch the profile context to "read the user goal"."""
        await self._update_profile(READ_USER_GOAL)
        self._record_graph_transition("initialize_goal_context", "read_signifier1")
        return {"current_state": "read_signifier1"}

    async def read_signifier1(self, state: ExecutorState) -> ExecutorState:
        """Find the goal-reading tool, call it and store the extracted user goal."""
        tool_name = await self._discover_tool(
            include=("goal", "current"),
            exclude=("feedback", "formal", "robot", "operation"),
            purpose="read the current natural-language user goal",
        )
        if not tool_name:
            return {"current_state": "error", "error": "No signifier tool found for reading the user goal."}
        result = await self._call_tool(tool_name, {})
        goal = self._extract_goal(result)
        if not goal:
            return {"current_state": "error", "error": f"Could not extract user goal from {result!r}."}
        return {
            "current_state": "notify_goal_started",
            "current_goal": goal,
            "selected_tool": tool_name,
        }

    async def notify_goal_started(self, state: ExecutorState) -> ExecutorState:
        """Tell the user that work on the goal has started (``achieved=False``)."""
        goal = state.get("current_goal", "")
        tool_name = await self._send_feedback(
            achieved=False,
            feedback=f"Started: {goal}",
            purpose="provide feedback to the user",
        )
        if not tool_name:
            return {"current_state": "error", "error": "No feedback signifier tool found."}
        return {"current_state": "update_formalize_context"}

    async def update_formalize_context(self, state: ExecutorState) -> ExecutorState:
        """Switch the profile context to "formalize the goal"."""
        await self._update_profile(FORMALIZE_GOAL)
        self._record_graph_transition("update_formalize_context", "read_formalizer_signifier")
        return {"current_state": "read_formalizer_signifier"}

    async def read_formalizer_signifier(self, state: ExecutorState) -> ExecutorState:
        """Select the tool that converts a natural-language goal into a command."""
        tool_name = await self._discover_tool(
            include=("formal",),
            exclude=("feedback", "currentgoal", "robot", "operation"),
            purpose="formalize a natural-language robot goal",
        )
        if not tool_name:
            return {"current_state": "error", "error": "No formalizer signifier tool found."}
        return {"current_state": "use_formalizer_signifier", "selected_tool": tool_name}

    async def use_formalizer_signifier(self, state: ExecutorState) -> ExecutorState:
        """Run the formalizer on the goal and keep the first formal command found.

        With no goal in state, the node asks to restart from goal reading.
        """
        goal = state.get("current_goal", "")
        if not goal:
            return {"current_state": "initialize_goal_context"}
        tool_name = state.get("selected_tool")
        if not tool_name:
            return {"current_state": "error", "error": "Formalizer tool was not selected."}
        result = await self._call_tool(tool_name, self._single_text_input(tool_name, goal))
        formalized_goal = self._extract_formal_command(result)
        if not formalized_goal:
            return {"current_state": "error", "error": f"Could not extract formal command from {result!r}."}
        return {
            "current_state": "update_robot_context",
            "formalized_goal": formalized_goal,
        }

    async def update_robot_context(self, state: ExecutorState) -> ExecutorState:
        """Switch the profile context to "perform a robot operation"."""
        await self._update_profile(ROBOT_GOAL)
        self._record_graph_transition("update_robot_context", "read_robot_signifier")
        return {"current_state": "read_robot_signifier"}

    async def read_robot_signifier(self, state: ExecutorState) -> ExecutorState:
        """Select the tool that executes formal commands on the robot."""
        tool_name = await self._discover_tool(
            include=("operation", "robot"),
            exclude=("feedback", "currentgoal", "formal"),
            purpose="perform a formal move(...) or rotate(...) operation on the robot",
        )
        if not tool_name:
            return {"current_state": "error", "error": "No robot operation signifier tool found."}
        return {"current_state": "use_robot_signifier", "selected_tool": tool_name}

    async def use_robot_signifier(self, state: ExecutorState) -> ExecutorState:
        """Send the formal command to the robot tool and check the result for errors."""
        formalized_goal = state.get("formalized_goal", "")
        tool_name = state.get("selected_tool")
        if not formalized_goal or not tool_name:
            return {"current_state": "error", "error": "Missing robot tool or formalized goal."}
        result = await self._call_tool(tool_name, self._single_text_input(tool_name, formalized_goal))
        if self._result_has_error(result):
            return {"current_state": "error", "error": f"Robot operation failed: {result!r}"}
        return {"current_state": "final_notification", "last_result": result}

    async def final_notification(self, state: ExecutorState) -> ExecutorState:
        """Tell the user that the goal was completed (``achieved=True``)."""
        goal = state.get("current_goal", "")
        tool_name = await self._send_feedback(
            achieved=True,
            feedback=f"Completed: {goal}",
            purpose="provide final success feedback to the user",
        )
        if not tool_name:
            return {"current_state": "error", "error": "No feedback signifier tool found for final notification."}
        return {"current_state": "end"}

    async def error(self, state: ExecutorState) -> ExecutorState:
        """Terminal failure node: log the error and end the run with it in state."""
        message = state.get("error", "Unknown executor error.")
        # Pass as ``error=`` so it is logged like every other failure (not as a result).
        self._write_log("error", {"current_state": state.get("current_state")}, error=message)
        self._record_graph_transition("error", "end", condition=message)
        return {"current_state": "end", "error": message}

    # ---------------------------------------------------------------- routers

    def _route_after_goal_read(self, state: ExecutorState) -> str:
        """Continue to feedback once a goal was read; otherwise divert to ``error``."""
        error = state.get("error")
        if error or not state.get("current_goal"):
            self._record_graph_transition(
                "read_signifier1",
                "error",
                condition=error or "current_goal missing",
            )
            return "error"
        self._record_graph_transition("read_signifier1", "notify_goal_started")
        return "notify_goal_started"

    def _route_after_notify_goal_started(self, state: ExecutorState) -> str:
        """Route from ``notify_goal_started`` to formalization or ``error``."""
        return self._route_error_or_next(
            "notify_goal_started",
            "update_formalize_context",
            state,
        )

    def _route_after_read_formalizer_signifier(self, state: ExecutorState) -> str:
        """Route from formalizer selection to its use or ``error``."""
        return self._route_error_or_next(
            "read_formalizer_signifier",
            "use_formalizer_signifier",
            state,
        )

    def _route_after_formalizer(self, state: ExecutorState) -> str:
        """Route after formalization.

        A formal command goes on to the robot step. A missing goal restarts goal
        reading. Anything else, including an explicit error, goes to ``error``.
        """
        error = state.get("error")
        if not error and state.get("formalized_goal"):
            self._record_graph_transition("use_formalizer_signifier", "update_robot_context")
            return "update_robot_context"
        if not error and not state.get("current_goal"):
            self._record_graph_transition(
                "use_formalizer_signifier",
                "initialize_goal_context",
                condition="current_goal missing",
            )
            return "initialize_goal_context"
        self._record_graph_transition(
            "use_formalizer_signifier",
            "error",
            condition=error or "formalized_goal missing",
        )
        return "error"

    def _route_after_read_robot_signifier(self, state: ExecutorState) -> str:
        """Route from robot-tool selection to its use or ``error``."""
        return self._route_error_or_next(
            "read_robot_signifier",
            "use_robot_signifier",
            state,
        )

    def _route_after_robot(self, state: ExecutorState) -> str:
        """Route after the robot operation to the final notification or ``error``."""
        error = state.get("error")
        if error:
            self._record_graph_transition("use_robot_signifier", "error", condition=error)
            return "error"
        self._record_graph_transition("use_robot_signifier", "final_notification")
        return "final_notification"

    def _route_error_or_next(
        self,
        source: str,
        next_node: str,
        state: ExecutorState,
    ) -> str:
        """Return ``"error"`` if state holds an error, else ``"next"``; log the edge taken."""
        error = state.get("error")
        if error:
            self._record_graph_transition(source, "error", condition=error)
            return "error"
        self._record_graph_transition(source, next_node)
        return "next"

    def _route_final_notification(self, state: ExecutorState) -> str:
        """Route after the final notification to ``end`` or ``error``."""
        error = state.get("error")
        if error:
            self._record_graph_transition("final_notification", "error", condition=error)
            return "error"
        self._record_graph_transition("final_notification", "end")
        return "end"

    # ------------------------------------------------------------ tool access

    async def _sync_tools(self) -> None:
        """Refresh the procedural-memory tool list, bounded by ``sync_timeout_seconds``."""
        await asyncio.wait_for(
            self.procedural_memory.sync_mcp_tools(),
            timeout=self.sync_timeout_seconds,
        )

    async def _update_profile(self, nl_context: str) -> Any:
        """Set the profile's natural-language context via the ``update_profile`` tool."""
        return await self._call_tool(
            "update_profile",
            {"profile_id": PROFILE_ID, "nl_context": nl_context},
        )

    async def _read_signifiers(self) -> Any:
        """Call ``read_signifiers`` for this profile, then re-sync the tool list."""
        result = await self._call_tool("read_signifiers", {"profile_url": PROFILE_URL})
        # Reading signifiers may change which tools the server exposes.
        await self._sync_tools()
        return result

    async def _discover_tool(
        self,
        *,
        include: tuple[str, ...],
        exclude: tuple[str, ...],
        purpose: str,
    ) -> str | None:
        """Refresh signifiers/tools, then select the tool best matching ``purpose``."""
        await self._read_signifiers()
        return self._select_tool(include=include, exclude=exclude, purpose=purpose)

    async def _send_feedback(self, *, achieved: bool, feedback: str, purpose: str) -> str | None:
        """Switch to the feedback context and send ``feedback`` to the user.

        Returns the feedback tool's name, or ``None`` if no such tool was found
        (in which case nothing is sent).
        """
        await self._update_profile(PROVIDE_FEEDBACK_GOAL)
        tool_name = await self._discover_tool(
            include=("feedback",),
            exclude=("currentgoal", "formal", "robot", "operation"),
            purpose=purpose,
        )
        if tool_name:
            await self._call_tool(tool_name, {"achieved": achieved, "feedback": feedback})
        return tool_name

    async def _call_tool(self, tool_name: str, tool_input: dict[str, Any]) -> Any:
        """Invoke ``tool_name`` with ``tool_input`` and log the call and its outcome.

        An unknown tool triggers one tool-list sync before giving up.

        Raises:
            RuntimeError: the tool is still unknown after syncing.
            asyncio.TimeoutError: the call exceeded ``tool_timeout_seconds``.
            Exception: anything the tool raises is logged and re-raised.
        """
        tool = self.procedural_memory.get_tool(tool_name)
        if tool is None:
            await self._sync_tools()
            tool = self.procedural_memory.get_tool(tool_name)
        if tool is None:
            message = f"Could not find tool '{tool_name}'."
            self._write_log(tool_name, tool_input, error=message)
            raise RuntimeError(message)

        try:
            result = await asyncio.wait_for(
                tool.ainvoke(tool_input),
                timeout=self.tool_timeout_seconds,
            )
        except asyncio.TimeoutError:
            # A timeout's str() is empty, which used to leave a blank log entry.
            self._write_log(
                tool_name,
                tool_input,
                tool=tool,
                error=f"Timed out after {self.tool_timeout_seconds} seconds.",
            )
            raise
        except Exception as exc:
            self._write_log(tool_name, tool_input, tool=tool, error=str(exc) or type(exc).__name__)
            raise
        self._write_log(tool_name, tool_input, result=result, tool=tool)
        return result

    def _select_tool(
        self,
        *,
        include: tuple[str, ...],
        exclude: tuple[str, ...],
        purpose: str,
    ) -> str | None:
        """Pick a task tool by keyword score, falling back to the LLM.

        Each non-infrastructure tool is matched against its lower-cased
        "name description" text. Tools containing any ``exclude`` term are
        skipped. The rest are scored by how many ``include`` terms they contain,
        and the highest score wins (ties broken by name).

        If no tool scores, the LLM chooses among all task tools. Returns
        ``None`` when there are no task tools or the LLM names none of them.
        """
        tools = [
            tool
            for name, tool in self.procedural_memory.tools.items()
            if name not in self._INFRASTRUCTURE_TOOLS
        ]
        scored: list[tuple[int, str]] = []
        for tool in tools:
            text = f"{tool.name} {getattr(tool, 'description', '') or ''}".lower()
            if any(term in text for term in exclude):
                continue
            score = sum(term in text for term in include)
            if score:
                scored.append((score, tool.name))
        if scored:
            return min(scored, key=lambda item: (-item[0], item[1]))[1]
        if not tools:
            return None
        # The LLM deliberately sees every task tool, including keyword-excluded
        # ones (presumably so a tool whose description merely mentions an
        # excluded word can still be chosen).
        return self._llm_select_tool(tools, purpose)

    def _llm_select_tool(self, tools: list[Any], purpose: str) -> str | None:
        """Ask the LLM to name one of ``tools`` for ``purpose``; map its answer to a name.

        An exact answer wins. Otherwise the longest tool name contained in the
        answer is used (ties broken alphabetically), so the result is
        deterministic even when names overlap.
        """
        descriptions = "\n".join(
            f"- {tool.name}: {getattr(tool, 'description', '')}" for tool in tools
        )
        response = self.llm.invoke(
            "Select exactly one tool name for this purpose. "
            "Return only the tool name, with no explanation.\n\n"
            f"Purpose: {purpose}\n\nTools:\n{descriptions}"
        )
        text = self._message_text(response).strip()
        names = {tool.name for tool in tools}
        if text in names:
            return text
        for name in sorted(names, key=lambda candidate: (-len(candidate), candidate)):
            if name and name in text:
                return name
        return None

    # -------------------------------------------------------- input building

    def _single_text_input(self, tool_name: str, value: str) -> dict[str, Any]:
        """Build the argument dict for passing one text ``value`` to ``tool_name``.

        Tried in order:
        1. An A2A-style ``message__*`` payload if the tool requires such fields.
        2. A ``*parts*`` field (required fields first, then schema properties).
        3. The single required field.
        4. A ``message`` object property.
        5. A well-known text property name.
        6. The only schema property.
        7. Finally ``body``.
        """
        tool = self.procedural_memory.get_tool(tool_name)
        required = list(getattr(tool, "required_fields", []) or [])
        if any(field_name.startswith("message__") for field_name in required):
            return self._message_payload(value)
        for field_name in required:
            if "parts" in field_name:
                return {field_name: self._text_parts(value)}
        if len(required) == 1:
            return {required[0]: value}
        properties = self._input_schema_properties(tool, tool_name)
        if isinstance(properties.get("message"), dict):
            return self._message_payload(value)
        for field_name in properties:
            if "parts" in field_name:
                return {field_name: self._text_parts(value)}
        for candidate in ("body", "command", "message", "text", "goal", "input"):
            if candidate in properties:
                return {candidate: value}
        if len(properties) == 1:
            return {next(iter(properties)): value}
        return {"body": value}

    def _input_schema_properties(self, tool: Any, tool_name: str) -> dict[str, Any]:
        """Return the ``properties`` of the tool's server-side input schema, or ``{}``."""
        server = getattr(tool, "server", None)
        if server is None:
            return {}
        tool_def = (getattr(server, "tool_definitions", None) or {}).get(tool_name)
        schema = getattr(tool_def, "input_schema", None) or getattr(tool_def, "inputSchema", None)
        properties = schema.get("properties", {}) if isinstance(schema, dict) else {}
        return properties if isinstance(properties, dict) else {}

    def _text_parts(self, value: str) -> list[dict[str, str]]:
        """Wrap ``value`` as a single text part."""
        return [{"kind": "text", "text": value}]

    def _message_payload(self, value: str) -> dict[str, Any]:
        """Build flattened ``message__*`` arguments carrying ``value`` as a user message."""
        return {
            "message__messageId": str(uuid.uuid4()),
            "message__role": "user",
            "message__parts": self._text_parts(value),
        }

    # ------------------------------------------------------ result decoding

    def _extract_goal(self, result: Any) -> str:
        """Extract the natural-language goal text from a (possibly nested) tool result.

        Returns ``""`` when no goal is found or the server reports an empty goal.
        """
        data = self._decode_result(result)
        if isinstance(data, str):
            return self._clean_goal(data)
        if isinstance(data, list):
            return self._first_goal_in(data)
        if not isinstance(data, dict):
            return ""
        result_value = data.get("result")
        if isinstance(result_value, dict):
            goal = self._extract_goal(result_value)
            if goal:
                return goal
        for key in ("content", "text", "body", "result"):
            value = data.get(key)
            if isinstance(value, str) and value.strip():
                nested = self._decode_result(value)
                if isinstance(nested, dict):
                    nested_goal = self._extract_goal(nested)
                    if nested_goal:
                        return nested_goal
                return self._clean_goal(value)
        for key in ("contents", "content"):
            items = data.get(key)
            if isinstance(items, list):
                goal = self._first_goal_in(items)
                if goal:
                    return goal
        return ""

    def _first_goal_in(self, items: list[Any]) -> str:
        """Return the first non-empty goal extracted from ``items``, or ``""``."""
        return next((goal for goal in map(self._extract_goal, items) if goal), "")

    def _clean_goal(self, text: str) -> str:
        """Trim whitespace and quotes; map the server's "Goal is empty." reply to ``""``."""
        text = text.strip().strip('"')
        if text.lower() == "goal is empty.":
            return ""
        return text

    def _extract_formal_command(self, result: Any) -> str:
        """Return the first ``move(...)``/``rotate(...)`` command in ``result``, compacted.

        All whitespace the pattern tolerates (spaces, tabs, newlines) is removed,
        e.g. ``"move ( 10 )"`` -> ``"move(10)"``. Returns ``""`` if none is found.
        """
        match = FORMAL_COMMAND_RE.search(self._result_text(result))
        return "".join(match.group(0).split()) if match else ""

    def _result_has_error(self, result: Any) -> bool:
        """Report whether a tool result signals failure.

        A result fails if it has a truthy ``error`` key, or the MCP
        ``isError`` flag. Failures nested inside ``content`` are detected too.
        """
        data = self._decode_result(result)
        if not isinstance(data, dict):
            return False
        if data.get("error") or data.get("isError"):
            return True
        content = data.get("content")
        if isinstance(content, (str, dict)):
            return self._result_has_error(content)
        if isinstance(content, list):
            return any(self._result_has_error(item) for item in content if isinstance(item, (str, dict)))
        return False

    def _decode_result(self, value: Any) -> Any:
        """Best-effort decode of a tool result into Python data.

        - Unwraps ``{"content": x}`` wrappers.
        - Decodes dict values recursively and returns a container ``body`` in
          place of its dict.
        - Parses SSE ``data:`` payloads, then JSON, then Python literals.

        Undecodable strings are returned stripped; other values pass through.
        """
        if isinstance(value, dict):
            if set(value) == {"content"}:
                return self._decode_result(value["content"])
            decoded = {key: self._decode_result(item) for key, item in value.items()}
            body = decoded.get("body")
            return body if isinstance(body, (dict, list)) else decoded
        if not isinstance(value, str):
            return value
        raw = value.strip()
        if not raw:
            return raw
        sse_data = self._decode_sse_data(raw)
        if sse_data is not None:
            return self._decode_result(sse_data)
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, RecursionError):
            pass
        try:
            # literal_eval builds Python literals only (no code execution). It
            # accepts single-quoted dict reprs that JSON rejects.
            return ast.literal_eval(raw)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # TypeError arises e.g. for "{[1]: 2}" (unhashable key).
            return raw

    def _decode_sse_data(self, text: str) -> Any:
        """Return the joined ``data:`` payload of an SSE message, or ``None``.

        The payload is JSON-decoded when possible, otherwise returned as text.
        """
        data_lines = [
            line.removeprefix("data:").strip()
            for line in text.splitlines()
            if line.startswith("data:")
        ]
        if not data_lines:
            return None
        data_text = "\n".join(data_lines).strip()
        if not data_text:
            return None
        try:
            return json.loads(data_text)
        except json.JSONDecodeError:
            return data_text

    def _result_text(self, result: Any) -> str:
        """Flatten a tool result to text.

        Text-bearing dict keys (``content``, ``text``, ``body``, ``result``,
        ``message``) and list items are joined with newlines. Other values fall
        back to ``str``.
        """
        decoded = self._decode_result(result)
        if isinstance(decoded, str):
            return decoded
        if isinstance(decoded, list):
            return "\n".join(text for text in map(self._result_text, decoded) if text)
        if isinstance(decoded, dict):
            parts = [
                self._result_text(decoded[key])
                for key in ("content", "text", "body", "result", "message")
                if key in decoded
            ]
            if parts:
                return "\n".join(part for part in parts if part)
        return str(decoded)

    def _message_text(self, message: Any) -> str:
        """Return the text of an LLM response (plain string or chat message).

        For content-block lists, string blocks and ``{"type": "text"}`` blocks
        are concatenated.
        """
        if isinstance(message, str):
            return message
        content = getattr(message, "content", None)
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return "".join(
                block if isinstance(block, str) else str(block.get("text", ""))
                for block in content
                if isinstance(block, str) or (isinstance(block, dict) and block.get("type") == "text")
            )
        return str(message)

    # ---------------------------------------------------------------- logging

    def _record_graph_transition(
        self,
        source: str,
        target: str,
        condition: str | None = None,
    ) -> None:
        """Remember a graph edge taken and append it to the run log."""
        transition: dict[str, Any] = {"source": source, "target": target}
        if condition:
            transition["condition"] = condition
        self.graph_transitions.append(transition)

        parts = [f"graph_transition: {source} -> {target}"]
        if condition:
            parts.append(f"condition: {json.dumps(condition, ensure_ascii=True)}")
        self._append_log_line(parts)

    def _write_log(
        self,
        action: str,
        tool_input: Any,
        result: Any = None,
        tool: Any = None,
        error: str | None = None,
    ) -> None:
        """Append one " | "-separated line describing a tool call to the run log.

        ``error`` (written verbatim) takes precedence over ``result`` (written
        as JSON).
        """
        parts = [f"action: {action}"]
        if tool_input not in (None, {}, ""):
            parts.append(f"input: {json.dumps(self._to_jsonable(tool_input), ensure_ascii=True)}")
        if error is not None:
            parts.append(f"output: {error}")
        elif result not in (None, ""):
            parts.append(f"output: {json.dumps(self._to_jsonable(result), ensure_ascii=True)}")
        description = self._tool_description(tool)
        if description:
            parts.append(f"tool_description: {json.dumps(description, ensure_ascii=True)}")
        self._append_log_line(parts)

    def _append_log_line(self, parts: list[str]) -> None:
        """Append ``parts`` joined by " | " as a single line to the run log file."""
        with self.last_run_path.open("a", encoding="utf-8") as handle:
            handle.write(" | ".join(parts) + "\n")

    def _tool_description(self, tool: Any) -> str:
        """Summarize a tool as "name: description | required: a, b" (empty parts omitted)."""
        if tool is None:
            return ""
        description = getattr(tool, "description", "") or ""
        required = list(getattr(tool, "required_fields", []) or [])
        # Build the summary explicitly instead of strip(": "), which also removed
        # legitimate trailing colons/spaces from descriptions.
        summary = f"{tool.name}: {description}" if description else tool.name
        if required:
            return f"{summary} | required: {', '.join(required)}"
        return summary

    def _to_jsonable(self, value: Any) -> Any:
        """Convert ``value`` into something ``json.dumps`` accepts.

        Containers are converted recursively, with tuples treated as lists.
        Non-JSON dict keys and other objects become their ``str``.
        """
        if isinstance(value, dict):
            return {self._jsonable_key(k): self._to_jsonable(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [self._to_jsonable(v) for v in value]
        if isinstance(value, (str, int, float, bool)) or value is None:
            return value
        return str(value)

    def _jsonable_key(self, key: Any) -> Any:
        """Return ``key`` if ``json.dumps`` accepts it as an object key, else ``str(key)``."""
        if isinstance(key, (str, int, float, bool)) or key is None:
            return key
        return str(key)


async def main() -> None:
    """Run the executor agent once; exit with status 1 if the run ended in error.

    ``LangGraphExecutorAgent.run`` may return the final graph state. When that
    state carries an ``error`` the process exits non-zero, so shells and CI can
    detect the failure. A ``None`` return is treated as success.
    """
    final_state = await LangGraphExecutorAgent().run()
    if isinstance(final_state, dict) and final_state.get("error"):
        sys.exit(1)


if __name__ == "__main__":
    # Script entry point: drive the async ``main`` coroutine on a fresh event loop.
    asyncio.run(main())
