{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b9162f2f",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0def110c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%capture --no-stderr\n",
    "# %pip install --upgrade --quiet langchain-community langgraph langchain-openai langchain-chroma graphrag"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7f0705e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import getpass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6281bc0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment the lines below to opt in to LangSmith tracing for this notebook.\n",
    "# if not os.environ.get(\"LANGSMITH_API_KEY\"):\n",
    "#     os.environ[\"LANGSMITH_API_KEY\"] = getpass.getpass()\n",
    "#     os.environ[\"LANGSMITH_TRACING\"] = \"true\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e93d0a4",
   "metadata": {},
   "source": [
    "# Importing SQL Database"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43ca7d3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.utilities import SQLDatabase\n",
    "\n",
    "# Open the SQLite database. The URI is relative, so license.db must be in the notebook's working directory.\n",
    "db = SQLDatabase.from_uri(\"sqlite:///license.db\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe01d4d3",
   "metadata": {},
   "source": [
    "Data Exploration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a1d64a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick look at the database: SQL dialect, usable tables, and the schema of one table.\n",
    "print(db.dialect)\n",
    "print(db.get_usable_table_names()) # these table names are needed later for the vectorization\n",
    "print(db.get_table_info(['TV_Intercity_Relay']))\n",
    "# Last expression of the cell, so the notebook displays the query result.\n",
    "db.run(\"SELECT * FROM TV_Intercity_Relay LIMIT 10;\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "554d672d",
   "metadata": {},
   "source": [
    "Setting up OpenAI and specifying model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bf1df93",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import getpass\n",
    "# import os\n",
    "\n",
    "# os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"Enter API key for OpenAI: \")\n",
    "\n",
    "# from langchain.chat_models import init_chat_model\n",
    "\n",
    "# llm = init_chat_model(\"gpt-4o\", model_provider=\"openai\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e312f089",
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "from langchain.chat_models import init_chat_model\n",
    "\n",
    "# API key setup: never hardcode the key, and never overwrite one that is already exported.\n",
    "# Prompt for it only when the environment does not provide a non-empty value.\n",
    "if not os.environ.get(\"OPENAI_API_KEY\"):\n",
    "    os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"Enter API key for OpenAI: \")\n",
    "\n",
    "# Chat model used by the agents and classifiers defined later in this notebook.\n",
    "llm = init_chat_model(\"gpt-5\", model_provider=\"openai\", temperature=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ea85dc0",
   "metadata": {},
   "source": [
    "# Graph"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1da60682",
   "metadata": {},
   "source": [
    "Importing SQL toolkit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8b71d72",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.agent_toolkits import SQLDatabaseToolkit\n",
    "\n",
    "# Bundle the database handle and the LLM into LangChain's SQL toolkit.\n",
    "toolkit = SQLDatabaseToolkit(db=db, llm=llm)\n",
    "\n",
    "# Tool list that is later handed to the SQL agent and the mix agent.\n",
    "sqltools = toolkit.get_tools()\n",
    "\n",
    "# Last expression of the cell, so the notebook displays the tool list.\n",
    "sqltools"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0782414b",
   "metadata": {},
   "source": [
    "Creating System Prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b72dec2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# System prompt for the SQL agent. The {dialect} and {top_k} placeholders in the text\n",
    "# are filled in by the .format() call at the end of the literal.\n",
    "sql_system = \"\"\"\n",
    "You are an agent designed to interact with a SQL database containing information on Spectrum Usage.\n",
    "\n",
    "## Database Structure\n",
    "- The database may contain **many tables**.\n",
    "- **Each table corresponds to one service type**, and the **table name = service type**.\n",
    "- Each table contains rows describing which **companies (licensees)** use this service, and on which **frequency bands**.\n",
    "- A single service type may have multiple companies and multiple frequency ranges per company.\n",
    "\n",
    "## Your Task\n",
    "Given a user question, you must:\n",
    "1. **List all tables** in the database first (do not skip).\n",
    "2. Decide which table(s) are relevant:\n",
    "   - If the question mentions explicit **service type(s)**, select only the corresponding table(s).\n",
    "   - If the question does not mention explicit service types, assume all tables are possible and you MUST select all tables.\n",
    "3. **Inspect the schema** of any relevant table(s) before writing queries (e.g., column names, data types).\n",
    "   - Pay special attention to columns containing **'licensee'** (or variants), as they represent company names.\n",
    "4. Construct a syntactically correct {dialect} query to answer the question.\n",
    "   - Never query all columns; select only the necessary ones.\n",
    "   - Unless the user specifies otherwise, always apply **LIMIT {top_k}**.\n",
    "   - You may order results by a relevant column (e.g., frequency span, license count) to surface the most interesting answers.\n",
    "5. If your first query fails, you must **rewrite and retry** using the actual schema.\n",
    "6. DO NOT use any DML (INSERT, UPDATE, DELETE, DROP, etc.); **SELECT only**.\n",
    "\n",
    "## Output\n",
    "- Always explain briefly which tables and columns you selected.\n",
    "- Provide the SQL you executed.\n",
    "- Summarize the results clearly in natural language.\n",
    "\"\"\".format(\n",
    "    dialect=\"SQLite\",\n",
    "    top_k=30,\n",
    ")\n",
    "\n",
    "\n",
    "# System prompt for the router agent: it must answer with a single routing token,\n",
    "# which router_node then searches for in the reply.\n",
    "router_system = (\n",
    "    \"You are a query router. This project has two knowledge sources:\\n\"\n",
    "    \"- SQL database (DB): multiple tables where each table corresponds to a spectrum service type; rows contain companies (licensees), locations, frequency bands, usage modes, etc.\\n\"\n",
    "    \"- Policy files: opinions, comments, and positions from people and organizations regarding spectrum band use.\\n\\n\"\n",
    "    \"Decide which source(s) are required and return exactly one routing token:\\n\"\n",
    "    \"- If the question can be answered purely with structured data/metrics from the SQL DB: need_sql\\n\"\n",
    "    \"- If the question is only about viewpoints/comments/stances/positions of people or organizations: need_policy\\n\"\n",
    "    \"- If the question requires combining DB facts with policy viewpoints: need_mix\\n\\n\"\n",
    "    \"Return only one of: need_sql | need_policy | need_mix | final_answer.\"\n",
    ")\n",
    "\n",
    "\n",
    "# System prompt for the policy agent: describes the three corpora (basic/local/global),\n",
    "# the three graphrag tools, and the rules for choosing between them.\n",
    "# The string is built by implicit concatenation of adjacent literals inside the parentheses.\n",
    "policy_system = (\n",
    "    \"You are a policy RAG agent operating over proceedings. Each user question belongs to exactly one proceeding. A proceeding has its own topic and contains many documents. A document is the set of viewpoints or comments expressed by an organization or an individual on the proceeding's topic.\\n\"\n",
    "    \"We provide three corpora at different abstraction levels derived from the same proceeding documents: basic, local, and global.\\n\"\n",
    "    \"- basic: no condensation; raw paragraph-level content.\\n\"\n",
    "    \"- local: per-paragraph condensation of the basic corpus (good for single-document summaries).\\n\"\n",
    "    \"- global: further condensation over the local corpus (good for multi-document synthesis).\\n\\n\"\n",
    "    \"TOOLS:\\n\"\n",
    "    \"- local_graphrag_tool(question): run RAG over the local corpus with --method local; you select exactly one proceeding root and the tool returns an answer candidate.\\n\"\n",
    "    \"- global_graphrag_tool(question): run RAG over the global corpus with --method global; you select exactly one proceeding root and the tool returns an answer candidate.\\n\"\n",
    "    \"- basic_graphrag_tool(question): run RAG over the basic corpus with --method basic; you select exactly one proceeding root and the tool returns an answer candidate.\\n\\n\"\n",
    "    \"DECISION RULES:\\n\"\n",
    "\"1) Prefer local or global as the first call:\\n\"\n",
    "\"   - Choose local_graphrag_tool when the answer scope involves a single document only — e.g., asking for a specific organization's or person's viewpoint, or a concrete fact.\\n\"\n",
    "\"   - Choose global_graphrag_tool when the answer scope may span multiple documents — e.g., opposition-type questions (many organizations/people may oppose), or 'who holds ... view' questions (many organizations/people may share the view).\\n\"\n",
    "\"   Unless you are very certain basic is required, do NOT call basic_graphrag_tool first.\\n\" \n",
    "    \"2) Fallback: If the chosen local/global tool returns no supporting information and you cannot produce a valid answer, then call basic_graphrag_tool and use its output as the FINAL ANSWER regardless of quality.\\n\\n\"\n",
    "    \"OUTPUT:\\n\"\n",
    "    \"- If local/global produced a valid answer, output it as the final answer (cite the tool used).\\n\"\n",
    "    \"- If you had to fallback, output the basic tool's result as the final answer (cite fallback).\\n\"\n",
    ")\n",
    "\n",
    "mix_system = \"\"\"\n",
    "    You are a MIX agent with two tool families. Read carefully and follow the structured procedure.\n",
    "\n",
    "KNOWLEDGE SOURCES\n",
    "1) SQL database (Spectrum Usage)\n",
    "   - Tables: each table corresponds to ONE spectrum service type; the table name equals the service type.\n",
    "   - Rows: describe which licensees/companies use that service and on which frequency ranges/bands. Typical columns include license identifiers, licensee_name, service_type, frequency_ranges, and related counts/locations.\n",
    "   - Tools: use {SQL_LIST_TABLES} to enumerate tables; {SQL_GET_SCHEMA} to inspect schema; {SQL_QUERY} to run read-only SELECT queries.\n",
    "   - Querying rules:\n",
    "     - Start by listing tables, then inspect schema for relevant tables before writing queries.\n",
    "     - Select only necessary columns; avoid SELECT *.\n",
    "     - Unless the user specifies otherwise, apply LIMIT 15.\n",
    "     - If a query fails due to schema mismatch, rewrite using the actual schema and retry.\n",
    "\n",
    "2) Policy/commentary corpora (GraphRAG over proceedings)\n",
    "   - A proceeding is a topical container of many documents (comments, opinions, positions by organizations/people).\n",
    "   - We provide three abstraction levels derived from the same proceeding documents:\n",
    "     - basic: raw paragraph-level content (no condensation).\n",
    "     - local: per-paragraph condensation of basic (good for single-document/single-entity summaries).\n",
    "     - global: condensation over many local snippets (good for multi-document, cross-entity synthesis).\n",
    "   - Tools:\n",
    "     - {LOCAL_GR_TOOL}(question): run with --method local over the selected proceeding (you must pick exactly one proceeding root).\n",
    "     - {GLOBAL_GR_TOOL}(question): run with --method global over the selected proceeding (exactly one proceeding root).\n",
    "     - {BASIC_GR_TOOL}(question): run with --method basic over the selected proceeding (exactly one proceeding root).\n",
    "\n",
    "PROCEDURE\n",
    "1) Use SQL first to extract factual anchors that will help the policy answer (e.g., which companies, what services, which frequency ranges, counts, locations).\n",
    "2) Integrate those facts into a refined natural-language question that targets the appropriate proceeding/topic.\n",
    "3) Choose EXACTLY ONE GraphRAG tool according to the decision rules below. If the chosen tool gives no usable support, fall back once as specified.\n",
    "\n",
    "DECISION RULES FOR LOCAL VS GLOBAL \n",
    "- Choose {LOCAL_GR_TOOL} when the information scope is narrow and likely contained within a single document or a single commenter/org.— e.g., asking for a specific organization's or person's viewpoint, or a concrete fact.\\n\n",
    "- Choose {GLOBAL_GR_TOOL} when the answer requires synthesis across multiple documents, organizations, viewpoints, or time periods.e.g., opposition-type questions (many organizations/people may oppose), or 'who holds ... view' questions (many organizations/people may share the view).\\n\n",
    "- Unless you are very certain basic is required, do NOT call {BASIC_GR_TOOL} first.\\n\n",
    "- Fallback to {BASIC_GR_TOOL} ONLY IF your chosen local/global call yields no sufficient evidence to answer. In that case, call {BASIC_GR_TOOL} and treat its output as the FINAL answer.\n",
    "\n",
    "PRACTICAL HEURISTICS TO REDUCE LEVEL MISJUDGMENT\n",
    "- If your SQL step yields a SINGLE prominent company/entity, prefer {LOCAL_GR_TOOL}.\n",
    "- If your SQL step yields MULTIPLE relevant companies/entities or the user asks for broader consensus/controversy/trends, prefer {GLOBAL_GR_TOOL}.\n",
    "- If still uncertain after drafting the refined question, default to {LOCAL_GR_TOOL} for coverage.\n",
    "\n",
    "OUTPUT REQUIREMENTS\n",
    "- Briefly state which SQL tables/columns you used and show the SQL you executed.\n",
    "- Show the refined policy question you created.\n",
    "- Cite the ONE GraphRAG tool chosen and present the final answer aligned with the tool’s output.\n",
    "\n",
    "END-TO-END EXAMPLE\n",
    "User asks:\n",
    "“Which companies use Common Carrier Fixed Point-to-Point Microwave services in the 12000 MHz to 12500 MHz frequency band? What’s people’s comments about it?”\n",
    "\n",
    "Steps:\n",
    "1) SQL: find companies using service “Common Carrier Fixed Point-to-Point Microwave” in 12–13 GHz.\n",
    "2) Refine: “What are commenters’ views on the frequency usage by {COMPANY_LIST} in the 12–13 GHz band (Common Carrier Fixed Point-to-Point Microwave)?”\n",
    "3) Tool choice: if a few specific companies are returned, choose {LOCAL_GR_TOOL}; if many companies span diverse viewpoints, choose {GLOBAL_GR_TOOL}. If neither yields support, call {BASIC_GR_TOOL} and treat its output as final.\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "\n",
    "# Maps a proceeding root token to a natural-language description of that proceeding.\n",
    "# The descriptions are shown to the LLM that picks a proceeding for a question, and the\n",
    "# chosen key is then used as the directory name of the graphrag root (see _resolve_root_dir).\n",
    "PROCEEDING_DESCRIPTIONS = {\n",
    "    \"NTIA_NSS\": \"The name of this proceeding is Comments for National Spectrum Strategy. The NTIA solicited public comments in 2023 for the development and implementation of the National Spectrum Strategy, an initiative to improve management and access to U.S. spectrum resources to foster innovation and economic growth.\",\n",
    "    \"FCC_19_38\": \"The name of this proceeding is Partitioning, Disaggregation, and Leasing of Spectrum. This proceeding explores ways to expand spectrum access for small and rural carriers. Potential methods include relaxing or extending buildout/performance requirements, setting conditions on spectrum transfers, offering incentives such as longer license terms or modified obligations, and allowing reaggregation of previously partitioned or disaggregated licenses to reduce regulatory burdens and promote secondary market transactions.\",\n",
    "    \"FCC_22_352\": \"The name of this proceeding is Expanding Use of the 12.7-13.25 GHz Band for Mobile Broadband or Other Expanded Use. This proceeding is exploring whether and how the 12.7–13.25 GHz band can be opened for mobile broadband or other expanded uses. It seeks input on current operations, potential sharing methods (static, dynamic, database-driven, or licensed-light), possible relocation or repacking of incumbents, new licensing frameworks, and protections for adjacent bands, with the goal of unlocking mid-band spectrum for future 5G/6G and beyond.\",\n",
    "    \"FCC_23_158\": \"The name of this proceeding is Shared Use of the 42-42.5 GHz Band. This proceeding (FCC 23-51) explores adopting a shared licensing model for the underused 42–42.5 GHz band to support future 5G/6G. It seeks comment on different sharing approaches, license rules, and measures to protect adjacent radio astronomy, while emphasizing opportunities to advance digital equity by expanding access for smaller and community providers.\",\n",
    "    \"FCC_23_232\": \"The name of this proceeding is Advancing Understanding of Non-Federal Spectrum Usage. FCC seeks public input on how to better understand non-Federal spectrum usage—by defining usage, exploring new data sources and collection methods, and leveraging modern technologies like AI/ML—to improve spectrum management, efficiency, and sharing.\",\n",
    "    \"FCC_24_72\": \"The name of this proceeding is WTB Seeks Comment on Options for Facilitating Access to Unassigned Auction Inventory Spectrum. FCC seeks comment on how the FCC can make unused auctioned spectrum (“Inventory Spectrum”) available while its auction authority is lapsed. It explores options like dynamic sharing, non-exclusive site-based licensing, and leasing, along with other temporary or experimental measures.\",\n",
    "    \"FCC_25_59\": \"The name of this proceeding is Upper C-band (3.98 GHz to 4.2 GHz). This proceeding is an FCC Notice of Inquiry on the Upper C-band (3.98–4.2 GHz). It explores whether, and how, to repurpose part or all of this spectrum for more intensive use—such as 5G/6G—while addressing the needs of existing satellite operators, considering transition mechanisms, and protecting adjacent aviation altimeter systems.\",\n",
    "}\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7419402",
   "metadata": {},
   "source": [
    "Printing tools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b042f684",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the name and description of every SQL tool, one blank line between tools.\n",
    "for tool in sqltools:\n",
    "    print(f\"Name: {tool.name}\", f\"Description: {tool.description}\\n\", sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82c8da61",
   "metadata": {},
   "source": [
    "# Multi-agent system"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "686c4c77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Multi-Agent \n",
    "from typing import TypedDict, List, Literal\n",
    "from langgraph.graph import START, END, StateGraph\n",
    "from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage\n",
    "import subprocess, shlex, os, re, json\n",
    "from langchain.tools import Tool\n",
    "from langgraph.prebuilt import create_react_agent\n",
    "\n",
    "\n",
    "# 1）define GraphState\n",
    "class GraphState(TypedDict):\n",
    "    \"\"\"State dictionary passed between the nodes of the workflow graph.\"\"\"\n",
    "    # Conversation so far; every node returns the previous list with its new messages appended.\n",
    "    messages: List[AnyMessage]\n",
    "    # Routing decision: chosen by router_node; the agent nodes set it to \"final_answer\".\n",
    "    route: Literal[\"need_sql\", \"need_policy\", \"need_mix\", \"final_answer\", \"error\"]\n",
    "\n",
    "# 2) Router agent: created without tools; its system prompt asks for a single routing token.\n",
    "router = create_react_agent(llm, tools=[], prompt=router_system)\n",
    "\n",
    "# Simple hard match classification (enabled when LLM classification fails)\n",
    "def _classify_policy_like(text: str) -> bool:\n",
    "    policy_markers = [\n",
    "        \"view\", \"views\", \"stance\", \"opinion\",\n",
    "        \"support\", \"oppose\", \"similar\", \"opposite\",\"comment\", \n",
    "    ]\n",
    "    t = text.lower()\n",
    "    return any(k in t for k in policy_markers)\n",
    "\n",
    "\n",
    "def _classify_db_like(text: str) -> bool:\n",
    "    db_markers = [\n",
    "        \"sql\", \"query\", \"records\", \"table\", \"tables\", \"schema\", \"mhz\", \"ghz\", \"khz\", \"longitude\", \"latitude\", \"经度\", \"纬度\",\n",
    "        \"licensee\", \"count\",  \"top\", \"list\", \"where\", \"filter\", \n",
    "    ]\n",
    "    t = text.lower()\n",
    "    return any(k in t for k in db_markers)\n",
    "\n",
    "\n",
    "def router_node(state: GraphState) -> GraphState:\n",
    "    \"\"\"Ask the router LLM which agent should handle the latest message.\n",
    "\n",
    "    An explicit routing token in the LLM reply always wins. The keyword heuristics\n",
    "    are a fallback only: they run on the user's text, and only when the reply\n",
    "    contains no routing token. (Applying them to the reply itself could override\n",
    "    an explicit need_sql whenever the reply happened to contain a word such as\n",
    "    \"comment\" or \"view\".)\n",
    "\n",
    "    Args:\n",
    "        state: Current graph state; only the last message is sent to the router.\n",
    "\n",
    "    Returns:\n",
    "        The state with the router's reply appended to `messages` and `route` set to\n",
    "        one of need_mix, need_policy, need_sql or final_answer.\n",
    "    \"\"\"\n",
    "    last = state[\"messages\"][-1]\n",
    "    routed_state = router.invoke({\"messages\": [last]})\n",
    "    router_msg = routed_state[\"messages\"][-1]\n",
    "    llm_text = str(router_msg.content).strip().lower()\n",
    "    # str() keeps this safe when the message content is not a plain string.\n",
    "    user_text = str(last.content).lower()\n",
    "\n",
    "    # Tokens are checked in priority order: need_mix wins if several appear in the reply.\n",
    "    for token in (\"need_mix\", \"need_policy\", \"need_sql\", \"final_answer\"):\n",
    "        if token in llm_text:\n",
    "            route = token\n",
    "            break\n",
    "    else:\n",
    "        # No token in the reply: classify the user's question with the keyword heuristics.\n",
    "        is_db = _classify_db_like(user_text)\n",
    "        is_policy = _classify_policy_like(user_text)\n",
    "        if is_db and is_policy:\n",
    "            route = \"need_mix\"\n",
    "        elif is_policy:\n",
    "            route = \"need_policy\"\n",
    "        else:\n",
    "            route = \"need_sql\"\n",
    "    return {\"messages\": state[\"messages\"] + [router_msg], \"route\": route}\n",
    "\n",
    "\n",
    "# 3) SQL agent: ReAct agent equipped with the SQL toolkit tools\n",
    "sql_agent = create_react_agent(llm, tools=sqltools, prompt=sql_system)\n",
    "\n",
    "def sql_node(state: GraphState) -> GraphState:\n",
    "    \"\"\"Run the SQL agent on the conversation and append the messages it produced.\"\"\"\n",
    "    history = state[\"messages\"]\n",
    "    agent_msgs = sql_agent.invoke({\"messages\": history})[\"messages\"]\n",
    "    # Assumes the agent returns the input messages followed by its own; that prefix is\n",
    "    # skipped unless the returned list is shorter than the input.\n",
    "    if len(agent_msgs) >= len(history):\n",
    "        produced = agent_msgs[len(history):]\n",
    "    else:\n",
    "        produced = agent_msgs\n",
    "    return {\"messages\": history + produced, \"route\": \"final_answer\"}\n",
    "\n",
    "# 4) Policy Agent：Encapsulate graphrag as Tool and call it in the node (LLM determines the root)\n",
    "\n",
    "def _extract_json(text: str) -> dict:\n",
    "    try:\n",
    "        return json.loads(text)\n",
    "    except Exception:\n",
    "        # strip code fences\n",
    "        m = re.search(r\"\\{[\\s\\S]*\\}\", text)\n",
    "        if m:\n",
    "            try:\n",
    "                return json.loads(m.group(0))\n",
    "            except Exception:\n",
    "                pass\n",
    "    return {}\n",
    "\n",
    "\n",
    "def _select_root_token_via_llm(user_query: str, descriptions: dict) -> str:\n",
    "    \"\"\"Ask the LLM to pick the proceeding root that best matches a question.\n",
    "\n",
    "    Args:\n",
    "        user_query: The question to classify.\n",
    "        descriptions: Mapping of candidate root key to its description.\n",
    "\n",
    "    Returns:\n",
    "        A key of `descriptions`: the \"root\" field of the model's JSON reply when it is\n",
    "        a valid key; otherwise the first key whose name appears in the reply text;\n",
    "        otherwise the first key.\n",
    "    \"\"\"\n",
    "    keys = list(descriptions.keys())\n",
    "    system_prompt = (\n",
    "        \"You are a classifier that MUST map any user question to EXACTLY ONE proceeding root from the provided candidates.\\n\"\n",
    "        \"Every question belongs to one of these proceedings; you MUST choose exactly one.\\n\"\n",
    "        \"Never output 'ALL', 'None', multiple roots, or an empty value.\\n\"\n",
    "        \"Base your choice strictly on the candidate descriptions below, focusing on topical fit over superficial keyword overlap.\\n\"\n",
    "        \"Prefer the proceeding whose scope most directly and specifically covers the question.\\n\"\n",
    "        \"Tie-breakers when multiple seem plausible:\\n\"\n",
    "        \"  1) Narrower/more specific match to the question,\\n\"\n",
    "        \"  2) Presence of the question's core entities/terms in the description,\\n\"\n",
    "        \"  3) The proceeding that would be the primary/most appropriate venue for that issue.\\n\"\n",
    "        \"If uncertain, choose the closest match and reflect uncertainty in the confidence score.\\n\"\n",
    "        f\"Candidates (name -> description): {descriptions}\\n\"\n",
    "        f\"Valid KEYS: {keys}\\n\"\n",
    "        f\"Return STRICT JSON only, no extra text: {{\\\"root\\\": one of {keys}, \\\"confidence\\\": number between 0 and 1, \\\"rationale\\\": \\\"short reason\\\"}}.\"\n",
    "    )\n",
    "    user_prompt = (\n",
    "        f\"User question: {user_query}\\n\"\n",
    "        \"Return only the JSON object. Do not include explanations, code fences, or any extra text.\"\n",
    "    )\n",
    "    reply = llm.invoke([SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)])\n",
    "    reply_text = str(reply.content)\n",
    "\n",
    "    parsed = _extract_json(reply_text)\n",
    "    if isinstance(parsed, dict):\n",
    "        chosen = parsed.get(\"root\")\n",
    "    else:\n",
    "        chosen = None\n",
    "    if chosen in descriptions:\n",
    "        return chosen\n",
    "\n",
    "    # Fallback: look for a candidate key mentioned anywhere in the reply text,\n",
    "    # and default to the first key when none is mentioned.\n",
    "    reply_lower = reply_text.lower()\n",
    "    return next((k for k in keys if k.lower() in reply_lower), keys[0])\n",
    "\n",
    "\n",
    "def _resolve_root_dir(root_token: str) -> str:\n",
    "    candidates = [f\"./{root_token}\", f\"../{root_token}\", f\"../../{root_token}\"]\n",
    "    for c in candidates:\n",
    "        if os.path.isdir(c):\n",
    "            return os.path.abspath(c)\n",
    "    return os.path.abspath(candidates[0])\n",
    "\n",
    "\n",
    "\n",
    "def _run_graphrag_with_method(root_abs: str, method: str, query: str) -> str:\n",
    "    method_l = (method or \"\").strip().lower()\n",
    "    method_safe = method_l if method_l in (\"local\", \"global\", \"basic\") else \"basic\"\n",
    "    cmd = f\"graphrag query --root {shlex.quote(root_abs)} --method {method_safe} --query {shlex.quote(query)}\"\n",
    "    print(cmd)\n",
    "    try:\n",
    "        completed = subprocess.run(cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)\n",
    "        return completed.stdout.strip()[-4000:]\n",
    "    except subprocess.CalledProcessError as e:\n",
    "        return f\"graphrag error:\\n{e.stdout[-4000:] if e.stdout else str(e)}\"\n",
    "\n",
    "\n",
    "\n",
    "def _query_proceeding_with_method(method: str, query: str) -> str:\n",
    "    \"\"\"Shared body of the three graphrag tools.\n",
    "\n",
    "    Picks the proceeding for `query` via the LLM, resolves that proceeding's root\n",
    "    directory, and runs graphrag over it with the given `method`.\n",
    "    \"\"\"\n",
    "    token = _select_root_token_via_llm(query, PROCEEDING_DESCRIPTIONS)\n",
    "    root_abs = _resolve_root_dir(token)\n",
    "    return _run_graphrag_with_method(root_abs, method, query)\n",
    "\n",
    "\n",
    "def _graphrag_local_tool_func(query: str) -> str:\n",
    "    \"\"\"Tool entry point: answer `query` with graphrag --method local.\"\"\"\n",
    "    return _query_proceeding_with_method(\"local\", query)\n",
    "\n",
    "\n",
    "def _graphrag_global_tool_func(query: str) -> str:\n",
    "    \"\"\"Tool entry point: answer `query` with graphrag --method global.\"\"\"\n",
    "    return _query_proceeding_with_method(\"global\", query)\n",
    "\n",
    "\n",
    "def _graphrag_basic_tool_func(query: str) -> str:\n",
    "    \"\"\"Tool entry point: answer `query` with graphrag --method basic.\"\"\"\n",
    "    return _query_proceeding_with_method(\"basic\", query)\n",
    "\n",
    "\n",
    "# Tool wrappers handed to the policy and mix agents; the tool names match those used in the prompts.\n",
    "local_graphrag_tool = Tool(name=\"local_graphrag_tool\", func=_graphrag_local_tool_func, description=\"Run graphrag with --method local over the selected proceeding. Input: user's question.\")\n",
    "\n",
    "global_graphrag_tool = Tool(name=\"global_graphrag_tool\", func=_graphrag_global_tool_func, description=\"Run graphrag with --method global over the selected proceeding. Input: user's question.\")\n",
    "\n",
    "basic_graphrag_tool = Tool(name=\"basic_graphrag_tool\", func=_graphrag_basic_tool_func, description=\"Run graphrag with --method basic over the selected proceeding. Input: user's question.\")\n",
    "\n",
    "\n",
    "# Policy agent: ReAct agent that can call the three graphrag tools\n",
    "policy_agent = create_react_agent(\n",
    "    llm,\n",
    "    tools=[local_graphrag_tool, global_graphrag_tool, basic_graphrag_tool],\n",
    "    prompt=policy_system,\n",
    ")\n",
    "\n",
    "def policy_node(state: GraphState) -> GraphState:\n",
    "    \"\"\"Run the policy agent on the conversation and append the messages it produced.\"\"\"\n",
    "    history = state[\"messages\"]\n",
    "    agent_msgs = policy_agent.invoke({\"messages\": history})[\"messages\"]\n",
    "    # Assumes the agent returns the input messages followed by its own; that prefix is\n",
    "    # skipped unless the returned list is shorter than the input.\n",
    "    skip = len(history) if len(agent_msgs) >= len(history) else 0\n",
    "    return {\"messages\": history + agent_msgs[skip:], \"route\": \"final_answer\"}\n",
    "\n",
    "\n",
    "\n",
    "# 5) Mix agent: prompt is mix_system; tools are the SQL tools plus the graphrag tools\n",
    "mix_agent = create_react_agent(\n",
    "    llm,\n",
    "    tools=sqltools + [local_graphrag_tool, global_graphrag_tool, basic_graphrag_tool],\n",
    "    prompt=mix_system,\n",
    ")\n",
    "\n",
    "def mix_node(state: GraphState) -> GraphState:\n",
    "    \"\"\"Run the mix agent on the conversation and append the messages it produced.\"\"\"\n",
    "    history = state[\"messages\"]\n",
    "    agent_msgs = mix_agent.invoke({\"messages\": history})[\"messages\"]\n",
    "    # Assumes the agent returns the input messages followed by its own; that prefix is\n",
    "    # skipped unless the returned list is shorter than the input.\n",
    "    skip = len(history) if len(agent_msgs) >= len(history) else 0\n",
    "    return {\"messages\": history + agent_msgs[skip:], \"route\": \"final_answer\"}\n",
    "\n",
    "# Workflow graph: START -> router -> one of (sql | policy | mix) -> END\n",
    "workflow = StateGraph(GraphState)\n",
    "workflow.add_node(\"router\", router_node)\n",
    "workflow.add_node(\"sql\", sql_node)\n",
    "workflow.add_node(\"policy\", policy_node)\n",
    "workflow.add_node(\"mix\", mix_node)\n",
    "\n",
    "workflow.add_edge(START, \"router\")\n",
    "# Branch on the \"route\" value written by router_node.\n",
    "# \"final_answer\" and \"error\" have no node of their own and are sent to the SQL agent.\n",
    "workflow.add_conditional_edges(\n",
    "    \"router\",\n",
    "    lambda s: s[\"route\"],\n",
    "    {\n",
    "        \"need_policy\": \"policy\",\n",
    "        \"need_sql\": \"sql\",\n",
    "        \"need_mix\": \"mix\",\n",
    "        \"final_answer\": \"sql\",\n",
    "        \"error\": \"sql\",\n",
    "    },\n",
    ")\n",
    "# Every agent node ends the run; there is no loop back to the router.\n",
    "workflow.add_edge(\"sql\", END)\n",
    "workflow.add_edge(\"policy\", END)\n",
    "workflow.add_edge(\"mix\", END)\n",
    "\n",
    "# Compiled, runnable graph.\n",
    "app = workflow.compile()\n",
    "\n",
    "# # 6) Demo ( router agent)\n",
    "# question = \"Is INTELSAT an incumbent user in 12.7-13.25 GHz band? What are their opinions about sharing this band with terrestrial service?\"\n",
    "# inputs = {\"messages\": [HumanMessage(content=question)]}\n",
    "\n",
    "# for step in app.stream(inputs, stream_mode=\"values\"):\n",
    "#     for msg in step.get(\"messages\", []):\n",
    "#         msg.pretty_print()\n",
    "\n",
    "\n",
    "# # 6) Demo ( policy agent)\n",
    "# question = \"Does SpectrumX regularly organize employees to get massages?\"\n",
    "# for step in policy_agent.stream(\n",
    "#     {\"messages\": [{\"role\": \"user\", \"content\": question}]},\n",
    "#     stream_mode=\"values\",\n",
    "# ):\n",
    "#     step[\"messages\"][-1].pretty_print()\n",
    "\n",
    "# # 6) Demo ( sql agent)\n",
    "# question = \"How many paths does the license with order number 8 have in 92-94 GHz?\"\n",
    "# for step in sql_agent.stream(\n",
    "#     {\"messages\": [{\"role\": \"user\", \"content\": question}]},\n",
    "#     stream_mode=\"values\",\n",
    "# ):\n",
    "#     step[\"messages\"][-1].pretty_print()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "langgraph",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
