import os
from typing import Annotated

from dataclasses import dataclass
from typing_extensions import TypedDict

import networkx as nx

from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.tools import tool
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.managed.is_last_step import RemainingSteps


from src.prompts import FINAL_INSTRUCT_WITH_AGENT_WITH_SAMPLE
from src.utils import logger
from src.models import select_embedding_model, log_token_usage

from sklearn.metrics.pairwise import cosine_similarity

# Graph queried by the tool functions below (get_near_edges, list_relations).
# It starts as None and is assigned from state["graph"] by Agent.invoke before
# the agent graph is run, so the tools must not be called before that.
G = None

# Embedding model used by get_near_edges to fuzzy-match relation names.
# Created once, when this module is imported.
embed_model = select_embedding_model("bge-m3")

@tool
def get_near_edges(entity: str, relations: list[str] | None = None) -> dict[str, list]:
    """Get a list of edge triples related to the specified entity

    Args:
        entity: The entity to query, it must be a node in the graph, the topic entity is better
        relations: Optional list of relations, if specified only returns edges with these relations.
            If not specified, it will return all edges related to the entity.
            The relations should be the relation name (or part of the relation name) in the graph

    Returns:
        dict: {"messages": triples} with at most 100 triples, each triple is (head entity, relation, tail entity)
    """
    # Implementation notes live in comments because @tool sends the docstring to
    # the LLM as the tool description:
    # - Reads the module-level G, which Agent.invoke assigns before the graph runs.
    # - Only G.neighbors(entity) is visited; for an nx.DiGraph those are successors,
    #   i.e. outgoing edges only.
    # - An edge is kept when its relation has cosine similarity > 0.9 with any
    #   requested relation, or when a requested relation is a substring of it.
    logger.info(f"[Tool] Get near edges for {entity} with relations {relations}")

    # Check that the entity is a node of the graph.
    if entity not in G.nodes:
        return {"messages": [f"Entity {entity} not found in graph"]}

    # Every adjacent edge as a (head, relation, tail) triple.
    triples = [
        (entity, G[entity][neighbor]["relation"], neighbor)
        for neighbor in G.neighbors(entity)
    ]

    # No filter requested. An empty list is treated like None: there is nothing
    # to match against, and encoding an empty batch would fail downstream.
    if not relations:
        return {"messages": triples[:100]}

    # Distinct edge relations in first-seen order. With no edges there is nothing
    # to filter, so skip the embedding calls entirely.
    candidate_relations = list(dict.fromkeys(relation for _, relation, _ in triples))
    if not candidate_relations:
        return {"messages": []}

    # Encode the requested relations and each distinct edge relation once,
    # rather than once per neighbour.
    requested_embed = embed_model.encode(relations)  # [n_requested, embed_dim]
    candidate_embed = embed_model.encode(candidate_relations)  # [n_candidates, embed_dim]
    similarity = cosine_similarity(candidate_embed, requested_embed)  # [n_candidates, n_requested]

    matched = {
        relation
        for relation, row in zip(candidate_relations, similarity)
        if any(row > 0.9) or any(r in relation for r in relations)
    }
    edges = [triple for triple in triples if triple[1] in matched]
    return {"messages": edges[:100]}

@tool
def list_relations(entity: str) -> dict[str, list[str]]:
    """List all relations of the specified entity

    Args:
        entity: The entity to query, it must be a node in the graph

    Returns:
        dict: {"messages": relations}, the distinct relation names on the entity's edges
    """
    # Reads the module-level G, which Agent.invoke assigns before the graph runs.
    # Only G.neighbors(entity) is visited (successors, for an nx.DiGraph).
    logger.info(f"[Tool] List relations for {entity}")

    # Check that the entity is a node of the graph.
    if entity not in G.nodes:
        return {"messages": [f"Entity {entity} not found in graph"]}

    # De-duplicate while keeping first-seen order. set() would order strings by
    # their per-process hash seed, making the tool output (and hence the LLM
    # prompt) differ from run to run.
    relations = list(
        dict.fromkeys(G[entity][neighbor]["relation"] for neighbor in G.neighbors(entity))
    )
    return {"messages": relations}


class KnowState(TypedDict):
    """State passed between the nodes of the Agent's LangGraph graph."""

    # Conversation history. Nodes return only new messages; the add_messages
    # reducer merges them into this list.
    messages: Annotated[list, add_messages]
    # Knowledge graph queried by the tools (Agent.invoke copies it into the
    # module-level G that the tool functions read).
    graph: nx.DiGraph
    # Question being answered; rendered into the prompt.
    question: str
    # Reasoning paths rendered into the prompt. Agent.run also accepts a list of
    # strings here (joined with newlines) despite the str annotation.
    reasoning_paths: str
    # Few-shot samples rendered into the prompt.
    samples: str
    # LangGraph managed value (steps left before the recursion limit);
    # compared against a threshold in Agent.should_continue.
    remaining_steps: RemainingSteps

class Agent:
    """Tool-using LLM agent that answers a question over a knowledge graph.

    The compiled LangGraph loop is ``bot`` (LLM with the graph tools bound) ->
    ``tools`` -> ``bot`` ... It ends when the LLM stops requesting tools. When
    the step budget is nearly used up, it ends via ``generate_anyway``, which
    asks the LLM (without tools) for an answer.
    """

    # Maximum number of characters of the reasoning paths put into the prompt.
    MAX_REASONING_PATH_CHARS = 1000
    # LangGraph recursion limit (maximum number of steps) for one invocation.
    RECURSION_LIMIT = 12
    # At or below this many remaining steps, stop using tools and force an answer.
    FORCE_ANSWER_REMAINING_STEPS = 3

    def __init__(self, model_name="gpt-4o-mini", agent_model="siliconflow/THUDM/GLM-Z1-9B-0414"):
        """Initialize the Agent.

        Args:
            model_name: Chat model name (see ``get_lc_model``) used for every LLM call.
            agent_model: Currently unused; kept so existing callers keep working.
        """
        # Import kept local, as before; the class is exposed as ``self.HumanMessage``.
        from langchain_core.messages import HumanMessage
        self.HumanMessage = HumanMessage

        self.llm = get_lc_model(model_name)
        self.prompt = PromptTemplate(
            input_variables=["reasoning_paths", "question", "samples"],
            template=FINAL_INSTRUCT_WITH_AGENT_WITH_SAMPLE,
        )

        self.tools = [get_near_edges, list_relations]
        self.graph = self.compile_graph()

    def compile_graph(self) -> StateGraph:
        """Build and compile the bot / tools / generate_anyway graph.

        Returns:
            The object returned by ``StateGraph.compile()`` (a compiled graph,
            not the builder itself).
        """
        builder = StateGraph(KnowState)

        builder.add_node("bot", self.run)
        builder.add_node("generate_anyway", self.generate_anyway)
        builder.add_node("tools", ToolNode(tools=self.tools))

        builder.add_edge(START, "bot")
        builder.add_edge("tools", "bot")
        builder.add_edge("generate_anyway", END)
        builder.add_conditional_edges(
            "bot",
            self.should_continue,
            ["generate_anyway", "tools", END],
        )
        return builder.compile()

    def _format_reasoning_paths(self, reasoning_paths):
        """Return reasoning paths as one string of at most MAX_REASONING_PATH_CHARS.

        A list/tuple of strings is joined with newlines; ``None`` becomes "".
        """
        if reasoning_paths is None:
            return ""
        if isinstance(reasoning_paths, (list, tuple)):
            reasoning_paths = "\n".join(reasoning_paths)
        return reasoning_paths[: self.MAX_REASONING_PATH_CHARS]

    def _build_prompt_message(self, state):
        """Render the prompt template for ``state`` as a HumanMessage."""
        prompt_value = self.prompt.invoke({
            "reasoning_paths": self._format_reasoning_paths(state["reasoning_paths"]),
            "question": state["question"],
            "samples": state["samples"],
        })
        # PromptTemplate.invoke returns a StringPromptValue. Use to_string():
        # str() on it gives the pydantic repr ("text='...'" with escaped newlines),
        # not the prompt text.
        return self.HumanMessage(content=prompt_value.to_string())

    def run(self, state: KnowState) -> KnowState:
        """``bot`` node: call the tool-enabled LLM on the history plus the prompt.

        Returns:
            State update adding the prompt message and the LLM response.
        """
        llm_with_tools = self.llm.bind_tools(self.tools)
        prompt_message = self._build_prompt_message(state)

        # Build a new list instead of appending to state["messages"]: that list may
        # be the one LangGraph stores for the channel, so an in-place append would
        # change graph state outside the add_messages reducer.
        messages = [*state["messages"], prompt_message]

        response = llm_with_tools.invoke(messages)
        log_token_usage("agent", messages, response.content)
        # Record the prompt through the reducer so later turns see the same history.
        return {"messages": [prompt_message, response]}

    def should_continue(self, state):
        """Route after ``bot`` to END, ``tools`` or ``generate_anyway``."""
        last_message = state["messages"][-1]
        wants_tools = bool(getattr(last_message, "tool_calls", None))

        # A non-empty reply without tool calls is already the final answer; keep it
        # instead of regenerating it.
        if last_message.content and not wants_tools:
            return END

        # Step budget nearly used up: force an answer rather than starting another
        # tools -> bot round trip (presumably so the run finishes before the
        # recursion limit).
        if state["remaining_steps"] <= self.FORCE_ANSWER_REMAINING_STEPS:
            return "generate_anyway"

        return "tools" if wants_tools else END

    def generate_anyway(self, state: KnowState) -> KnowState:
        """``generate_anyway`` node: ask the plain (tool-less) LLM for an answer.

        A trailing AI message that still requests tool calls is left out of the
        history sent to the LLM, since those calls will never get results.
        """
        history = list(state["messages"])  # copy: never mutate graph state in place
        if history and getattr(history[-1], "tool_calls", None):
            history.pop()
        history.append(self._build_prompt_message(state))

        response = self.llm.invoke(history)
        log_token_usage("agent", history, response.content)
        return {"messages": response}

    def invoke(self, state: KnowState) -> KnowState:
        """Run the graph on ``state`` and return the agent's final message.

        Assigns the module-level ``G`` read by the tool functions, so concurrent
        invocations in one process on different graphs would interfere.

        Returns:
            The first message from ``bot`` or ``generate_anyway`` that has content
            and no pending tool calls (a message, despite the annotation), or None
            if the stream ends without one.
        """
        global G
        G = state["graph"]

        config = {"recursion_limit": self.RECURSION_LIMIT}

        for step, event in enumerate(self.graph.stream(state, config=config)):
            # Each streamed event maps the node that just ran to its state update.
            stage_name, update = next(iter(dict(event).items()))
            msgs = update.get("messages")
            if not isinstance(msgs, list):
                msgs = [msgs]
            logger.info(f">>>>> [{step}] [Agent] {stage_name}")
            for msg in msgs:
                logger.info(f"\n{msg.pretty_repr()}")

            if stage_name in ("bot", "generate_anyway") and msgs:
                final = msgs[-1]
                # A reply that still requests tools is not final: the graph routes
                # it to the tools node, so keep streaming.
                if final.content and not getattr(final, "tool_calls", None):
                    return final
        return None

    def invoke_msg(self, question, graph, reasoning_paths, samples):
        """Build the initial state from the arguments and run ``invoke`` on it."""
        inputs = {
            'messages': [self.HumanMessage(content=question)],
            'graph': graph,
            'reasoning_paths': reasoning_paths,
            "question": question,
            "samples": samples,
            # NOTE(review): remaining_steps is a LangGraph managed value derived from
            # the recursion limit; this input value is presumably ignored — confirm.
            "remaining_steps": RemainingSteps(10)
        }
        return self.invoke(inputs)


def get_lc_model(model_name):
    """Create a LangChain ``ChatOpenAI`` client for ``model_name``.

    Every supported model is served through an OpenAI-compatible endpoint:
    OpenAI itself, OpenRouter, SiliconFlow, or a vLLM server whose address and
    key come from ``VLLM_API_BASE`` / ``VLLM_API_KEY``.

    Raises:
        ValueError: if ``model_name`` is not a supported model.
    """
    if model_name in ("gpt-4o-mini", "gpt-4o"):
        # Default OpenAI endpoint and credentials.
        endpoint_kwargs = {}
    elif model_name == "meta-llama/llama-3.1-8b-instruct:free":
        endpoint_kwargs = {
            "base_url": "https://openrouter.ai/api/v1",
            "api_key": os.getenv("OPENROUTER_API_KEY"),
        }
    elif model_name == "meta-llama/Meta-Llama-3.1-8B-Instruct":
        endpoint_kwargs = {
            "base_url": "https://api.siliconflow.cn/v1",
            "api_key": os.getenv("SILICONFLOW_API_KEY"),
        }
    elif model_name in ("llama3.1:8b", "qwen2.5:32b", "qwen3:32b"):
        # Self-hosted vLLM; thinking mode is switched off through the chat template.
        endpoint_kwargs = {
            "base_url": os.getenv("VLLM_API_BASE"),
            "api_key": os.getenv("VLLM_API_KEY", "empty"),
            "extra_body": {"chat_template_kwargs": {"enable_thinking": False}},
        }
    else:
        raise ValueError(f"Unknown model: {model_name}")

    return ChatOpenAI(model_name=model_name, **endpoint_kwargs)
