# -*- coding: utf-8 -*-
from __future__ import annotations

"""
sqldb_tools.py
- SQLite-based tools for querying CMP_info / Scene_info tables.
- All column names are ENGLISH (see COLUMNS below).
- Tools:
  1) search_by_character  -> fuzzy match over character; return full rows.
  2) search_by_scene      -> filter by scene_title / subscene_title (fuzzy optional).
  3) chunk_to_scene       -> map chunk_id(s) to scene_title / subscene_title.
  4) scene_to_chunks      -> list chunk_id(s) under a (sub)scene.
  5) nlp2sql_query        -> natural language to SQL (via LangChain SQL chain).
"""

from typing import List, Dict, Any, Optional, Tuple, Union
import json
import sqlite3
import ast  # used to safely parse list-like string results

from qwen_agent.tools.base import BaseTool, register_tool
from qwen_agent.utils.utils import logger
from langchain_community.utilities import SQLDatabase
from langchain_experimental.sql.base import SQLDatabaseSequentialChain

# ---- Column list ----
# Preferred column order for SELECT lists and formatted output. The search
# tools keep only the names that actually exist in their table.
COLUMNS = [
    "name",
    "category",
    "subcategory",
    "appearance",
    "status",
    "character",
    "evidence",
    "notes",
    "chunk_id",
    "scene_id",
    "scene_title",
    "subscene_title",
]


def build_cols_sql(conn, table: str, candidate_cols: list[str]) -> str:
    """
    Build a quoted, comma-separated SELECT column list for ``table``.

    Only names from ``candidate_cols`` that exist in the table (per
    ``PRAGMA table_info``) are kept, in ``candidate_cols`` order.

    Args:
        conn: Open sqlite3 connection.
        table: Table name to inspect.
        candidate_cols: Desired column names, in output order.

    Returns:
        For example ``'"name", "chunk_id"'``. Returns ``"*"`` when none of the
        candidates exist (including when the table does not exist). This keeps
        the caller's ``SELECT {cols} FROM ...`` valid SQL instead of ``SELECT  FROM``.
    """
    # PRAGMA arguments cannot be bound as parameters, so escape the literal instead.
    table_literal = table.replace("'", "''")
    cur = conn.cursor()
    try:
        cur.execute(f"PRAGMA table_info('{table_literal}')")
        # Column 1 of each table_info row is the column name.
        cols_in_db = {row[1] for row in cur.fetchall()}
    finally:
        cur.close()

    valid_cols = [c for c in candidate_cols if c in cols_in_db]
    if not valid_cols:
        return "*"
    # Double any embedded double quotes so every identifier stays well-formed.
    return ", ".join('"' + c.replace('"', '""') + '"' for c in valid_cols)


def _parse_sql_from_steps(steps):
    if not steps:
        return None
    for st in steps:
        if isinstance(st, dict) and "sql_cmd" in st:
            return st["sql_cmd"].strip()
        if isinstance(st, str) and st.strip().lower().startswith("select"):
            return st.strip()
    return None


def _format_sql_rows_text(raw):
    """
    Format raw chain output (which can be list/str) into readable text.
    """
    if raw is None:
        return "No result."
    if isinstance(raw, str):
        try:
            data = ast.literal_eval(raw)
            if isinstance(data, (list, tuple)) and data and isinstance(data[0], (list, tuple)):
                lines = []
                for row in data:
                    items = [str(x) for x in row]
                    lines.append("- " + " | ".join(items))
                return "Query results:\n" + "\n".join(lines)
            return raw
        except Exception:
            return raw
    if isinstance(raw, (list, tuple)):
        if not raw:
            return "No result."
        lines = []
        for row in raw:
            if isinstance(row, (list, tuple)):
                items = [str(x) for x in row]
                lines.append("- " + " | ".join(items))
            else:
                lines.append("- " + str(row))
        return "Query results:\n" + "\n".join(lines)
    return str(raw)


# ---- Generic formatting helpers: dict rows -> natural language ----
def _fmt_kv_line(row: Dict[str, Any], columns: List[str]) -> str:
    items = []
    for c in columns:
        if c in row:
            v = row[c]
            v = "" if v is None else str(v).strip()
            if v != "":
                items.append(f"{c}: {v}")
    return " | ".join(items)

def format_rows_dicts_to_nl(rows: List[Dict[str, Any]],
                            columns: Optional[List[str]] = None,
                            dedup: bool = True,
                            header: Optional[str] = "Query results:") -> str:
    """
    Render dict rows as a bulleted list of ``"col: value | ..."`` lines.

    Args:
        rows: Result rows as dicts.
        columns: Columns to show, in order. When None, the COLUMNS entries
            present in any row come first (in COLUMNS order), followed by the
            remaining keys of the first row.
        dedup: Skip rows whose values over ``columns`` repeat an earlier row.
        header: Optional first line; omitted when falsy.

    Returns:
        The formatted text, or "No result." when no row produces any output.
    """
    if not rows:
        return "No result."

    if columns is None:
        columns = [col for col in COLUMNS if any(col in row for row in rows)]
        already = set(columns)
        columns.extend(key for key in rows[0] if key not in already)

    seen_keys = set()
    bullets = []
    for row in rows:
        text = _fmt_kv_line(row, columns)
        if not text:
            continue
        # Identity of a row for de-duplication: its raw values over `columns`.
        signature = tuple((col, row.get(col, None)) for col in columns)
        if dedup and signature in seen_keys:
            continue
        seen_keys.add(signature)
        bullets.append("- " + text)

    if not bullets:
        return "No result."
    body = "\n".join(bullets)
    return header + "\n" + body if header else body

def format_query_result(
    data: Union[List[Dict[str, Any]], None],
    columns: Optional[List[str]] = None,
    header: Optional[str] = "Query results:",
) -> str:
    """
    Format a query result for display.

    None yields "No result.". A list that is empty or starts with a dict is
    delegated to ``format_rows_dicts_to_nl``. Anything else yields a fixed
    "unrecognized" message.
    """
    if data is None:
        return "No result."
    is_dict_rows = isinstance(data, list) and (len(data) == 0 or isinstance(data[0], dict))
    if not is_dict_rows:
        return "(Unrecognized result type; cannot format.)"
    return format_rows_dicts_to_nl(data, columns=columns, header=header)

def format_mapping_chunks_to_scene(mappings: List[Dict[str, Any]]) -> str:
    """
    Format chunk_id -> scene mappings (used by chunk_to_scene), one line per row.

    Each mapping may carry "chunk_id", "scene_title" and "subscene_title".
    Missing keys and NULL (None) values render as empty text rather than the
    literal string "None". The subscene part is only shown when non-empty.

    Returns:
        "Query results:" followed by one bullet per mapping, or a
        "No scene info found" message when ``mappings`` is empty.
    """
    if not mappings:
        return "No scene info found for the given chunk_id(s)."

    def _text(value: Any) -> str:
        # SQLite NULLs arrive as None; show them as empty, not "None".
        return "" if value is None else str(value)

    lines = []
    for m in mappings:
        chunk = _text(m.get("chunk_id"))
        scene = _text(m.get("scene_title"))
        sub = _text(m.get("subscene_title"))
        tail = f" | subscene_title: {sub}" if sub else ""
        lines.append(f"- chunk_id: {chunk} → scene_title: {scene}{tail}")

    return "Query results:\n" + "\n".join(lines)

def format_scene_to_chunks(scene_title: str,
                           chunks: List[str],
                           subscene_title: Optional[str] = None) -> str:
    """
    Describe the chunk_id(s) found under a scene (and optional subscene).

    Returns a "No chunk_id under ..." sentence when ``chunks`` is empty.
    Otherwise returns a count line followed by one "- <chunk_id>" bullet per chunk.
    """
    label = f"scene '{scene_title}'"
    if subscene_title:
        label += f" - '{subscene_title}'"

    if not chunks:
        return f"No chunk_id under {label}."

    listing = "\n".join([f"- {cid}" for cid in chunks])
    return f"{label} has {len(chunks)} chunk_id(s):\n{listing}"


# ---- SQLite basics ----
def get_conn(db_path: str) -> sqlite3.Connection:
    """Open a SQLite connection whose rows can be indexed by column name (``sqlite3.Row``)."""
    connection = sqlite3.connect(db_path)
    connection.row_factory = sqlite3.Row
    return connection

def rows_to_dicts(rows: List[sqlite3.Row]) -> List[Dict[str, Any]]:
    """Convert ``sqlite3.Row`` objects into plain ``{column: value}`` dicts, preserving order."""
    return list(map(dict, rows))

def ensure_indices(conn: sqlite3.Connection, table: str) -> None:
    """
    Create (if missing) the lookup indices used by the search tools, then commit.

    Only "CMP_info" and "Scene_info" have an index plan. For any other table
    nothing is created, but the commit still runs.
    """
    # table -> ordered (index-name suffix, column) pairs
    index_plan = {
        "CMP_info": (
            ("character", "character"),
            ("scene_title", "scene_title"),
            ("subscene_title", "subscene_title"),
            ("chunk", "chunk_id"),
        ),
        "Scene_info": (
            ("scene_title", "scene_title"),
            ("subscene_title", "subscene_title"),
            ("chunk", "chunk_id"),
        ),
    }
    cursor = conn.cursor()
    for suffix, column in index_plan.get(table, ()):
        cursor.execute(
            f'CREATE INDEX IF NOT EXISTS idx_{table}_{suffix} ON "{table}"("{column}");'
        )
    conn.commit()


# ---- 1) fuzzy search by character ----
@register_tool("search_by_character")
class Search_By_Character(BaseTool):
    """
    Case-insensitive substring ("fuzzy") search on the 'character' column.

    Matching rows are returned with the COLUMNS entries that exist in the
    table, ordered by scene_title, subscene_title, chunk_id. Each query opens
    and closes its own SQLite connection.
    """
    name = "search_by_character"
    description = "Fuzzy-search by character name in CMP_info and return full matched records."
    parameters = [
        {"name": "query", "type": "string", "description": "Character keyword (case-insensitive, fuzzy).", "required": True},
        {"name": "limit", "type": "integer", "description": "Max number of rows to return (optional).", "required": False}
    ]

    def __init__(self, db_path: str,
                 default_table: str = "CMP_info",
                 build_indices: bool = False):
        """
        Args:
            db_path: Path to the SQLite database file.
            default_table: Table to search (default "CMP_info").
            build_indices: When True, run ensure_indices on the table first.
        """
        self.db_path = db_path
        self.default_table = default_table
        # Use one setup connection for both steps and always close it.
        conn = get_conn(self.db_path)
        try:
            if build_indices:
                ensure_indices(conn, self.default_table)
            # Quoted SELECT list restricted to COLUMNS entries present in the table.
            self.COLS_SQL = build_cols_sql(conn, self.default_table, COLUMNS)
        finally:
            conn.close()

    @staticmethod
    def _like_pattern(keyword: str) -> str:
        """Return a '%keyword%' LIKE pattern whose '%', '_' and '\\' match literally (ESCAPE '\\')."""
        escaped = (keyword.replace("\\", "\\\\")
                   .replace("%", "\\%")
                   .replace("_", "\\_"))
        return f"%{escaped}%"

    @staticmethod
    def _coerce_limit(value: Any) -> Optional[int]:
        """Normalize a caller-supplied 'limit' to a non-negative int, or None (no limit)."""
        # bool is an int subclass: without this check, True would become LIMIT 1.
        if isinstance(value, bool):
            return None
        if isinstance(value, int):
            return value if value >= 0 else None
        if isinstance(value, float) and value.is_integer() and value >= 0:
            return int(value)
        # isdecimal (not isdigit): every decimal-digit string is accepted by int().
        if isinstance(value, str) and value.strip().isdecimal():
            return int(value.strip())
        return None

    def _search_by_character(self, character_keyword: str,
                             limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Return the rows whose 'character' contains ``character_keyword``.

        Args:
            character_keyword: Literal substring to look for.
            limit: Optional maximum row count, appended as a bound LIMIT.

        Returns:
            Matching rows as dicts.
        """
        sql = f'''
            SELECT {self.COLS_SQL}
            FROM "{self.default_table}"
            WHERE "character" LIKE ? COLLATE NOCASE ESCAPE '\\'
            ORDER BY "scene_title","subscene_title","chunk_id"
        '''
        params: List[Any] = [self._like_pattern(character_keyword)]
        if limit is not None:
            sql += " LIMIT ?"
            params.append(int(limit))

        conn = get_conn(self.db_path)
        try:
            cur = conn.cursor()
            cur.execute(sql, tuple(params))
            return rows_to_dicts(cur.fetchall())
        finally:
            conn.close()

    def call(self, params: Union[str, dict], **kwargs) -> str:
        """
        Tool entry point.

        Args:
            params: JSON string (or already-decoded dict) with "query" and an
                optional "limit".

        Returns:
            Formatted matching records, or a prompt when "query" is empty.
        """
        logger.info("🔎 search_by_character: start fuzzy SQL search over 'character'")
        p: Dict[str, Any] = params if isinstance(params, dict) else json.loads(params)
        raw_query = p.get("query")
        query = "" if raw_query is None else str(raw_query).strip()
        if not query:
            # An empty keyword would match (and dump) every row of the table.
            return "Please provide 'query'."
        limit = self._coerce_limit(p.get("limit"))

        rows = self._search_by_character(query, limit=limit)
        return format_query_result(rows, header=f"Records related to '{query}':")


# ---- 2) search by scene/subscene ----
@register_tool("search_by_scene")
class Search_By_Scene(BaseTool):
    """
    Return all rows matching a scene (optionally narrowed by subscene).

    Matching is either a case-insensitive literal-substring LIKE (fuzzy, the
    default) or exact equality. Rows come back with the COLUMNS entries
    present in the table, ordered by scene_title, subscene_title, chunk_id.
    """
    name = "search_by_scene"
    description = "Search by 'scene_title' (optional 'subscene_title') in CMP_info. Supports fuzzy matching."
    parameters = [
        {"name": "scene_name", "type": "string", "description": "Scene title keyword (fuzzy or exact).", "required": True},
        {"name": "subscene_name", "type": "string", "description": "Subscene title keyword (optional).", "required": False},
        {"name": "fuzzy", "type": "boolean", "description": "Use LIKE fuzzy match (default true).", "required": False},
        {"name": "limit", "type": "integer", "description": "Max number of rows to return (optional).", "required": False},
    ]

    def __init__(self, db_path: str,
                 default_table: str = "CMP_info",
                 build_indices: bool = False):
        """
        Args:
            db_path: Path to the SQLite database file.
            default_table: Table to search (default "CMP_info").
            build_indices: When True, run ensure_indices on the table first.
        """
        self.db_path = db_path
        self.default_table = default_table
        # Use one setup connection for both steps and always close it.
        conn = get_conn(self.db_path)
        try:
            if build_indices:
                ensure_indices(conn, self.default_table)
            # Quoted SELECT list restricted to COLUMNS entries present in the table.
            self.COLS_SQL = build_cols_sql(conn, self.default_table, COLUMNS)
        finally:
            conn.close()

    @staticmethod
    def _as_bool(value: Any, default: bool) -> bool:
        """Interpret a JSON/LLM-supplied flag; strings such as "false"/"0"/"no" are False."""
        if value is None:
            return default
        if isinstance(value, str):
            text = value.strip().lower()
            if text in ("false", "0", "no", "off", ""):
                return False
            if text in ("true", "1", "yes", "on"):
                return True
            return default
        return bool(value)

    @staticmethod
    def _coerce_limit(value: Any) -> Optional[int]:
        """Normalize a caller-supplied 'limit' to a non-negative int, or None (no limit)."""
        # bool is an int subclass: without this check, True would become LIMIT 1.
        if isinstance(value, bool):
            return None
        if isinstance(value, int):
            return value if value >= 0 else None
        if isinstance(value, float) and value.is_integer() and value >= 0:
            return int(value)
        # isdecimal (not isdigit): every decimal-digit string is accepted by int().
        if isinstance(value, str) and value.strip().isdecimal():
            return int(value.strip())
        return None

    @staticmethod
    def _build_where(scene_name: str,
                     subscene_name: Optional[str],
                     fuzzy: bool) -> tuple[str, list[Any]]:
        """
        Build the WHERE clause and its bound parameters.

        In fuzzy mode each keyword is wrapped in '%' and its LIKE metacharacters
        ('%', '_', '\\') are escaped, so they match literally. In exact mode a
        plain '=' comparison is used. The subscene condition is added only when
        ``subscene_name`` is truthy.
        """
        def like(keyword: str) -> str:
            escaped = (keyword.replace("\\", "\\\\")
                       .replace("%", "\\%")
                       .replace("_", "\\_"))
            return f"%{escaped}%"

        like_clause = "LIKE ? COLLATE NOCASE ESCAPE '\\'"
        where: list[str] = []
        params: list[Any] = []
        if fuzzy:
            where.append(f'"scene_title" {like_clause}')
            params.append(like(scene_name))
        else:
            where.append('"scene_title" = ?')
            params.append(scene_name)

        if subscene_name:
            if fuzzy:
                where.append(f'"subscene_title" {like_clause}')
                params.append(like(subscene_name))
            else:
                where.append('"subscene_title" = ?')
                params.append(subscene_name)

        return " AND ".join(where), params

    def _search_by_scene(self, scene_name: str,
                         subscene_name: Optional[str] = None,
                         fuzzy: bool = True,
                         limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Run the scene/subscene query and return the rows as dicts."""
        where_sql, params = self._build_where(scene_name, subscene_name, fuzzy)
        sql = f'''
            SELECT {self.COLS_SQL}
            FROM "{self.default_table}"
            WHERE {where_sql}
            ORDER BY "scene_title","subscene_title","chunk_id"
        '''
        if limit is not None:
            sql += " LIMIT ?"
            params.append(int(limit))

        conn = get_conn(self.db_path)
        try:
            cur = conn.cursor()
            cur.execute(sql, tuple(params))
            return rows_to_dicts(cur.fetchall())
        finally:
            conn.close()

    def call(self, params: Union[str, dict], **kwargs) -> str:
        """
        Tool entry point.

        Args:
            params: JSON string (or already-decoded dict) with "scene_name" and
                optional "subscene_name", "fuzzy", "limit".

        Returns:
            Formatted matching records, or a prompt when no scene/subscene is given.
        """
        logger.info("🎬 search_by_scene: query by scene/subscene")
        p: Dict[str, Any] = params if isinstance(params, dict) else json.loads(params)
        raw_scene = p.get("scene_name")
        scene_name = "" if raw_scene is None else str(raw_scene).strip()
        raw_sub = p.get("subscene_name")
        subscene_name = None if raw_sub is None else (str(raw_sub).strip() or None)
        if not scene_name and not subscene_name:
            # With no filter at all, fuzzy mode would return the entire table.
            return "Please provide 'scene_name'."
        fuzzy = self._as_bool(p.get("fuzzy"), default=True)
        limit = self._coerce_limit(p.get("limit"))

        rows = self._search_by_scene(scene_name, subscene_name=subscene_name,
                                     fuzzy=fuzzy, limit=limit)
        head = f"scene '{scene_name}'" + (f" - '{subscene_name}'" if subscene_name else "")
        return format_query_result(rows, header=f"Records under {head}:")


# ---- 3) chunk_id list -> scene mapping ----
@register_tool("chunk_to_scene")
class Chunk_To_Scene(BaseTool):
    """
    Map chunk_id(s) to (scene_title, subscene_title) by exact match in Scene_info.
    """
    name = "chunk_to_scene"
    description = "Given one or multiple chunk_id values, map them to scene_title / subscene_title."
    parameters = [
        {"name": "chunk_ids", "type": "array", "items": {"type": "string"},
         "description": "List of chunk_id (exact match).", "required": True}
    ]

    def __init__(self, db_path: str,
                 default_table: str = "Scene_info",
                 build_indices: bool = False):
        """
        Args:
            db_path: Path to the SQLite database file.
            default_table: Table holding the chunk -> scene mapping (default "Scene_info").
            build_indices: When True, run ensure_indices on the table first.
        """
        self.db_path = db_path
        self.default_table = default_table
        if build_indices:
            conn = get_conn(self.db_path)
            try:
                ensure_indices(conn, self.default_table)
            finally:
                conn.close()

    def _chunk_to_scene(self, chunk_ids: List[str]) -> List[Dict[str, Any]]:
        """
        Look up the scene mapping rows for the given chunk ids.

        Returns:
            Rows with "chunk_id", "scene_title" and "subscene_title", ordered by
            scene_title, subscene_title, chunk_id. Returns [] for empty input.
        """
        if not chunk_ids:
            return []
        placeholders = ",".join("?" for _ in chunk_ids)
        sql = f'''
            SELECT "chunk_id","scene_title","subscene_title"
            FROM "{self.default_table}"
            WHERE "chunk_id" IN ({placeholders})
            ORDER BY "scene_title","subscene_title","chunk_id"
        '''
        conn = get_conn(self.db_path)
        try:
            cur = conn.cursor()
            cur.execute(sql, tuple(chunk_ids))
            return rows_to_dicts(cur.fetchall())
        finally:
            conn.close()

    def call(self, params: Union[str, dict], **kwargs) -> str:
        """
        Tool entry point.

        Args:
            params: JSON string (or already-decoded dict) with "chunk_ids". A
                single id given as a bare string/number is also accepted.

        Returns:
            Formatted mappings, or a prompt when no usable chunk_id is given.
        """
        logger.info("🧭 chunk_to_scene: map chunk_id(s) to scene info")
        p: Dict[str, Any] = params if isinstance(params, dict) else json.loads(params)
        chunk_ids = p.get("chunk_ids", [])
        # Tolerate a single id passed instead of a one-element list.
        if isinstance(chunk_ids, (str, int)) and not isinstance(chunk_ids, bool):
            chunk_ids = [chunk_ids]
        if not isinstance(chunk_ids, list):
            return "Please provide a non-empty list of chunk_id."
        # Drop null/blank entries (str(None) would otherwise become "None").
        cleaned = [str(cid).strip() for cid in chunk_ids if cid is not None]
        cleaned = [cid for cid in cleaned if cid]
        if not cleaned:
            return "Please provide a non-empty list of chunk_id."
        # dict.fromkeys removes duplicates while keeping first-seen order.
        cleaned = list(dict.fromkeys(cleaned))

        mappings = self._chunk_to_scene(cleaned)
        return format_mapping_chunks_to_scene(mappings)


# ---- 4) scene/subscene -> chunk list ----
@register_tool("scene_to_chunks")
class Scene_To_Chunks(BaseTool):
    """
    List the distinct chunk_id(s) under a scene (optionally a subscene) from Scene_info.

    Matching is either a case-insensitive literal-substring LIKE (fuzzy, the
    default) or exact equality.
    """
    name = "scene_to_chunks"
    description = "Given a scene_title (optional subscene_title), list all chunk_id(s). Supports fuzzy matching."
    parameters = [
        {"name": "scene_name", "type": "string", "description": "Scene title keyword (fuzzy or exact).", "required": True},
        {"name": "subscene_name", "type": "string", "description": "Subscene title keyword (optional).", "required": False},
        {"name": "fuzzy", "type": "boolean", "description": "Use LIKE fuzzy match (default true).", "required": False},
    ]

    def __init__(self, db_path: str,
                 default_table: str = "Scene_info",
                 build_indices: bool = False):
        """
        Args:
            db_path: Path to the SQLite database file.
            default_table: Table holding the scene -> chunk mapping (default "Scene_info").
            build_indices: When True, run ensure_indices on the table first.
        """
        self.db_path = db_path
        self.default_table = default_table
        if build_indices:
            conn = get_conn(self.db_path)
            try:
                ensure_indices(conn, self.default_table)
            finally:
                conn.close()

    @staticmethod
    def _as_bool(value: Any, default: bool) -> bool:
        """Interpret a JSON/LLM-supplied flag; strings such as "false"/"0"/"no" are False."""
        if value is None:
            return default
        if isinstance(value, str):
            text = value.strip().lower()
            if text in ("false", "0", "no", "off", ""):
                return False
            if text in ("true", "1", "yes", "on"):
                return True
            return default
        return bool(value)

    @staticmethod
    def _build_where(scene_name: str,
                     subscene_name: Optional[str],
                     fuzzy: bool) -> tuple[str, list[Any]]:
        """
        Build the WHERE clause and its bound parameters.

        In fuzzy mode each keyword is wrapped in '%' and its LIKE metacharacters
        ('%', '_', '\\') are escaped, so they match literally. In exact mode a
        plain '=' comparison is used. The subscene condition is added only when
        ``subscene_name`` is truthy.
        """
        def like(keyword: str) -> str:
            escaped = (keyword.replace("\\", "\\\\")
                       .replace("%", "\\%")
                       .replace("_", "\\_"))
            return f"%{escaped}%"

        like_clause = "LIKE ? COLLATE NOCASE ESCAPE '\\'"
        where: list[str] = []
        params: list[Any] = []
        if fuzzy:
            where.append(f'"scene_title" {like_clause}')
            params.append(like(scene_name))
        else:
            where.append('"scene_title" = ?')
            params.append(scene_name)

        if subscene_name:
            if fuzzy:
                where.append(f'"subscene_title" {like_clause}')
                params.append(like(subscene_name))
            else:
                where.append('"subscene_title" = ?')
                params.append(subscene_name)

        return " AND ".join(where), params

    def _scene_to_chunks(self, scene_name: str,
                         subscene_name: Optional[str] = None,
                         fuzzy: bool = True) -> List[str]:
        """Return the distinct chunk_id values under the (sub)scene, ordered by chunk_id."""
        where_sql, params = self._build_where(scene_name, subscene_name, fuzzy)
        sql = f'''
            SELECT DISTINCT "chunk_id"
            FROM "{self.default_table}"
            WHERE {where_sql}
            ORDER BY "chunk_id"
        '''
        conn = get_conn(self.db_path)
        try:
            cur = conn.cursor()
            cur.execute(sql, tuple(params))
            return [r["chunk_id"] for r in cur.fetchall()]
        finally:
            conn.close()

    def call(self, params: Union[str, dict], **kwargs) -> str:
        """
        Tool entry point.

        Args:
            params: JSON string (or already-decoded dict) with "scene_name" and
                optional "subscene_name" and "fuzzy".

        Returns:
            The chunk_id listing, or a prompt when no scene/subscene is given.
        """
        logger.info("🧩 scene_to_chunks: list chunk ids by (sub)scene")
        p: Dict[str, Any] = params if isinstance(params, dict) else json.loads(params)
        raw_scene = p.get("scene_name")
        scene_name = "" if raw_scene is None else str(raw_scene).strip()
        raw_sub = p.get("subscene_name")
        subscene_name = None if raw_sub is None else (str(raw_sub).strip() or None)
        if not scene_name and not subscene_name:
            # With no filter at all, fuzzy mode would list every chunk_id.
            return "Please provide 'scene_name'."
        fuzzy = self._as_bool(p.get("fuzzy"), default=True)

        chunks = self._scene_to_chunks(scene_name, subscene_name=subscene_name, fuzzy=fuzzy)
        return format_scene_to_chunks(scene_name, chunks, subscene_title=subscene_name)


@register_tool("nlp2sql_query")
class NLP2SQL_Query(BaseTool):
    """
    Natural-language query over the SQLite DB.

    An LLM generates SQL through LangChain's SQLDatabaseSequentialChain, the
    SQL is executed, and the raw result is formatted as text.
    """

    name = "nlp2sql_query"
    description = f"Ask in natural language about the wardrobe/makeup/prop database. The column names are: {COLUMNS}"
    parameters = [
        {"name": "query", "type": "string", "description": "Query text.", "required": True},
        {"name": "return_sql", "type": "boolean", "description": "Whether to include the generated SQL in the output (default false).", "required": False},
    ]

    def __init__(self, db_path, llm,
                 instruction="When handling the query, do not only perform exact matches. Also consider substring matching, synonym expansion, or semantic similarity to cover related results."):
        """
        Args:
            db_path: Path to the SQLite database file.
            llm: LLM handed to SQLDatabaseSequentialChain.from_llm.
            instruction: Extra guidance appended to every query; falsy disables it.
        """
        self.db_path = db_path
        self.db = SQLDatabase.from_uri(f"sqlite:///{self.db_path}")
        self.llm = llm
        self.instruction = instruction

        # Fixed chain parameters. The intermediate steps are requested so the
        # generated SQL can be echoed back when return_sql is set.
        self.chain = SQLDatabaseSequentialChain.from_llm(
            llm=self.llm,
            db=self.db,
            verbose=False,
            return_direct=True,
            return_intermediate_steps=True,
            top_k=10000,
            use_query_checker=True
        )

    @staticmethod
    def _as_bool(value: Any, default: bool) -> bool:
        """Interpret a JSON/LLM-supplied flag; strings such as "false"/"0"/"no" are False."""
        if value is None:
            return default
        if isinstance(value, str):
            text = value.strip().lower()
            if text in ("false", "0", "no", "off", ""):
                return False
            if text in ("true", "1", "yes", "on"):
                return True
            return default
        return bool(value)

    def call(self, params: Union[str, dict], **kwargs) -> str:
        """
        Tool entry point.

        Args:
            params: JSON string (or already-decoded dict) with "query" and an
                optional "return_sql".

        Returns:
            The formatted result, optionally followed by the generated SQL.
            Chain failures are logged and reported as "Query failed: ...".
        """
        logger.info("🧠 nlp2sql_query: NL → SQL → execution")
        p: Dict[str, Any] = params if isinstance(params, dict) else json.loads(params)
        raw_query = p.get("query")
        query = "" if raw_query is None else str(raw_query).strip()

        if not query:
            return "Please provide 'query'."
        if self.instruction:
            query += "\n" + self.instruction

        # bool("false") is True, so string flags are parsed explicitly.
        return_sql = self._as_bool(p.get("return_sql"), default=False)

        try:
            out = self.chain.invoke(query)
            result = out.get("result", "")
            steps = out.get("intermediate_steps", None)
            sql_cmd = _parse_sql_from_steps(steps)

            nl = _format_sql_rows_text(result)
            parts = [f"◼︎ Natural-language result:\n{nl}"]
            if return_sql and sql_cmd:
                parts.append(f"◼︎ Generated SQL:\n```sql\n{sql_cmd}\n```")
            return "\n\n".join(parts)
        except Exception as e:
            # Tool boundary: report the failure to the agent instead of raising.
            logger.exception("nlp2sql_query execution failed")
            return f"Query failed: {type(e).__name__}: {e}"
