# kag/builder/database_builder.py

import os
import json
import sqlite3
import asyncio
from typing import List, Dict, Any, Tuple
import pandas as pd
from tqdm import tqdm

from core.utils.config import KAGConfig
from core.model_providers.openai_llm import OpenAILLM
from core.utils.prompt_loader import PromptLoader
from core.agent.cmp_extraction_agent import CMPExtractionAgent


def clean_df(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a cleaned copy of ``df`` that is ready to be written to SQLite.

    The following steps are applied, each only to columns that are present:
      * cast the known text columns to the pandas ``string`` dtype and strip
        surrounding whitespace;
      * drop rows whose ``name`` is missing or empty;
      * rewrite a ``category`` equal to "propitem" (case-insensitive) to "prop";
      * drop duplicates on (name, category, character, chunk_id, scene_id),
        keeping the first occurrence, then reset the index.

    Column names are expected to be in English.
    """
    text_columns = (
        "name", "category", "character", "chunk_id", "scene_id",
        "subcategory", "appearance", "status", "evidence", "notes",
        "scene_title", "subscene_title",
    )
    dedup_columns = ("name", "category", "character", "chunk_id", "scene_id")

    cleaned = df.copy()
    present = set(cleaned.columns)

    for col in text_columns:
        if col in present:
            cleaned[col] = cleaned[col].astype("string").str.strip()

    if "name" in present:
        names = cleaned["name"]
        cleaned = cleaned[names.notna() & (names != "")]

    if "category" in present:
        is_propitem = cleaned["category"].str.lower() == "propitem"
        cleaned.loc[is_propitem, "category"] = "prop"

    subset = [col for col in dedup_columns if col in present]
    if not subset:
        # Nothing to deduplicate on; the index is deliberately left untouched.
        return cleaned
    return cleaned.drop_duplicates(subset=subset, keep="first").reset_index(drop=True)


class RelationalDatabaseBuilder:
    """
    Pipeline:
      - Read all_document_chunks.json
      - Run CMPExtractionAgent over chunks (async with retries)
      - Build SQLite with ENGLISH column names (CMP_info and Scene_info)
      - Also export CSVs

    All messages/logs are English to avoid mixed-language issues.
    """

    # Column layout of one CMP_info row, as produced by _rows_from_result.
    # Used to give an empty extraction result a proper table schema.
    _CMP_COLUMNS = (
        "name", "category", "subcategory", "appearance", "status", "character",
        "evidence", "notes", "chunk_id", "scene_id", "scene_title", "subscene_title",
    )
    # Column layout of one Scene_info row, as produced by build_scene_info.
    _SCENE_COLUMNS = ("chunk_id", "scene_id", "scene_title", "subscene_title")

    # How many unrecovered extraction failures are listed individually.
    _MAX_FAILURES_SHOWN = 10

    def __init__(self, config: KAGConfig, max_retries: int = 2):
        """
        Args:
            config: project configuration. Storage paths, the prompt directory,
                agent timeouts and the worker count are read from it.
            max_retries: total number of extraction attempts per chunk, including
                the first run. Values below 1 are treated as 1.
        """
        self.config = config
        self.llm = OpenAILLM(config)
        self.max_retries = max_retries  # total attempts including the first run
        self.prompt_loader = PromptLoader(self.config.knowledge_graph_builder.prompt_dir)
        self.system_prompt = self._init_system_prompt()
        self.agent = CMPExtractionAgent(config, self.llm, self.system_prompt)

    # ---------------- small I/O helpers ----------------
    @staticmethod
    def _read_json(path: str) -> Any:
        """Load a UTF-8 JSON file, closing the handle even if parsing fails."""
        with open(path, "r", encoding="utf-8") as fh:
            return json.load(fh)

    def _chunks_path(self) -> str:
        """Path of the chunk file read by both extraction and Scene_info building."""
        return os.path.join(self.config.storage.knowledge_graph_path, "all_document_chunks.json")

    # ---------------- system prompt ----------------
    def _init_system_prompt(self) -> str:
        """
        Build the system prompt with background/abbreviations if provided.

        Settings are looked up in order: ``<graph_schema_path>/settings.json``,
        then ``config.probing.default_background_path``. If neither file exists,
        an empty background and no abbreviations are used.
        """
        settings_path = os.path.join(self.config.storage.graph_schema_path, "settings.json")
        if os.path.exists(settings_path):
            settings = self._read_json(settings_path)
        elif os.path.exists(self.config.probing.default_background_path):
            settings = self._read_json(self.config.probing.default_background_path)
        else:
            settings = {"background": "", "abbreviations": []}

        background_info = self.get_background_info(
            background=settings.get("background", ""),
            abbreviations=settings.get("abbreviations", []),
        )
        system_prompt_id = "agent_prompt_cmp"
        return self.prompt_loader.render_prompt(system_prompt_id, {"background_info": background_info})

    def get_background_info(self, background: str, abbreviations: List[dict]) -> str:
        """
        Render a concise English background block for the prompt.
        The content itself may come in any language from upstream config,
        but we format the wrapper in English.

        Each abbreviation dict becomes one bullet ``- **<label>**: <details>``.
        The label is the ``abbr`` value, else ``full``, else the first non-empty
        string value. Every OTHER non-empty string value (never the one used as
        the label) is joined with " - " as details. A ``None`` abbreviation list
        is treated as empty.
        """
        bg_block = f"**Background**: {background}\n" if background else ""

        def fmt(item: dict) -> str:
            label_key = next((k for k in ("abbr", "full") if item.get(k)), None)
            if label_key is None:
                label_key = next(
                    (k for k, v in item.items() if isinstance(v, str) and v.strip()),
                    None,
                )
            label = item[label_key] if label_key is not None else "N/A"
            # Exclude only the label's own key so "full" is kept when "abbr" is
            # the label, and a fallback label is not repeated in the details.
            parts = [
                v.strip()
                for k, v in item.items()
                if k != label_key and isinstance(v, str) and v.strip()
            ]
            return f"- **{label}**: " + " - ".join(parts) if parts else f"- **{label}**"

        abbr_block = "\n".join(fmt(x) for x in (abbreviations or []) if isinstance(x, dict))
        return f"{bg_block}\n{abbr_block}" if (background and abbr_block) else (bg_block or abbr_block)

    # ---------------- extraction helpers ----------------
    def _rows_from_result(self, chunk: Dict[str, Any], merged_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Convert agent extractions into table rows with ENGLISH keys only.

        Chunk-level fields (id and metadata scene_id/title/subtitle) are copied
        onto every row. Non-dict items in ``merged_results`` are skipped so one
        malformed item does not discard the chunk's other extractions.
        """
        md = chunk.get("metadata", {}) or {}
        rows = []
        for r in merged_results or []:
            if not isinstance(r, dict):
                continue
            rows.append({
                "name": r.get("name", ""),
                "category": r.get("category", ""),
                "subcategory": r.get("subcategory", ""),
                "appearance": r.get("appearance", ""),
                "status": r.get("status", ""),
                "character": r.get("character", ""),
                "evidence": r.get("evidence", ""),
                "notes": r.get("notes", ""),
                "chunk_id": chunk.get("id", ""),
                "scene_id": md.get("scene_id", ""),
                "scene_title": md.get("title", ""),
                "subscene_title": md.get("subtitle", "")
            })
        return rows

    async def _extract_chunk(self, chunk: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run the extraction agent on one chunk.

        Returns ``{"chunk", "rows", "error"}``. ``error`` is ``None`` on success
        (including chunks with empty content, which are not sent to the agent).
        Otherwise it is an "ExcType: message" string. Exceptions are captured
        here, rather than propagated, so one bad chunk cannot abort the gather.
        """
        try:
            content = (chunk.get("content") or "").strip()
            if not content:
                return {"chunk": chunk, "rows": [], "error": None}
            result = await self.agent.arun(content, timeout=self.config.agent.async_timeout)
            # `or []` also covers a "results" key that is present but None.
            merged = (result.get("results") or []) if isinstance(result, dict) else []
            return {"chunk": chunk, "rows": self._rows_from_result(chunk, merged), "error": None}
        except Exception as e:
            return {"chunk": chunk, "rows": [], "error": f"{e.__class__.__name__}: {e}"}

    async def _gather_once(self, chunks: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Extract all ``chunks`` concurrently, bounded by
        ``knowledge_graph_builder.max_workers``.

        Returns (rows from successful chunks, [{"chunk", "error"}, ...] for failures).
        """
        # Semaphore(0) would block forever, so allow at least one worker.
        sem = asyncio.Semaphore(max(1, self.config.knowledge_graph_builder.max_workers))

        async def _guarded(ch):
            async with sem:
                return await self._extract_chunk(ch)

        rows, failures = [], []
        for coro in tqdm(asyncio.as_completed([_guarded(ch) for ch in chunks]), total=len(chunks), desc="Extracting CMP items"):
            res = await coro
            if res["error"]:
                failures.append({"chunk": res["chunk"], "error": res["error"]})
            else:
                rows.extend(res["rows"])
        return rows, failures

    async def _gather_with_retries(self, chunks: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Run extraction, re-running only the failed chunks, for up to
        ``max(1, self.max_retries)`` attempts in total.

        Sleeps ``config.agent.async_backoff_seconds`` (default 2) between rounds.
        Returns all successful rows and the failures left after the last attempt.
        """
        # Guarantee at least one pass; otherwise max_retries < 1 would silently
        # skip every chunk and report zero failures.
        attempts = max(1, self.max_retries)
        all_rows, failures = [], []
        current_chunks = chunks

        for attempt in range(1, attempts + 1):
            rows, failures = await self._gather_once(current_chunks)
            all_rows.extend(rows)
            if not failures:
                break
            if attempt < attempts:
                backoff = getattr(self.config.agent, "async_backoff_seconds", 2)
                print(f"🔄 Retry round {attempt} still has {len(failures)} failures, sleep {backoff}s and retry...")
                await asyncio.sleep(backoff)
                current_chunks = [f["chunk"] for f in failures]

        return all_rows, failures

    # ---------------- main steps ----------------
    def extract_cmp_information(self):
        """
        Extract CMP items from all chunks and write them to
        ``<sql_database_path>/extraction_results.json``.

        Chunks that still fail after all attempts are reported on stdout;
        at most ``_MAX_FAILURES_SHOWN`` of them are listed with their errors.
        """
        chunks = self._read_json(self._chunks_path())

        all_rows, still_failed = asyncio.run(self._gather_with_retries(chunks))
        out_dir = self.config.storage.sql_database_path
        os.makedirs(out_dir, exist_ok=True)
        with open(os.path.join(out_dir, "extraction_results.json"), "w", encoding="utf-8") as f:
            json.dump(all_rows, f, ensure_ascii=False, indent=2)

        print(f"✅ CMP extraction done: {len(all_rows)} rows succeeded, {len(still_failed)} failed (max attempts = {max(1, self.max_retries)})")
        for failure in still_failed[:self._MAX_FAILURES_SHOWN]:
            ch = failure["chunk"]
            chunk_id = ch.get("id") if isinstance(ch, dict) else None
            print(f"⚠️ Failed chunk {chunk_id}: {failure['error']}")
        hidden = len(still_failed) - self._MAX_FAILURES_SHOWN
        if hidden > 0:
            print(f"⚠️ ... and {hidden} more failed chunks")

    def build_relational_database(self):
        """
        Build CMP.db with ENGLISH column names; write CMP_info and export CSV.

        The database file is deleted and recreated, so any other table already
        in it (e.g. Scene_info from an earlier build_scene_info call) is removed.
        """
        out_dir = self.config.storage.sql_database_path
        all_rows = self._read_json(os.path.join(out_dir, "extraction_results.json"))
        # With no rows, pd.DataFrame([]) has no columns and to_sql cannot create
        # a table from it, so supply the expected schema explicitly.
        df_cmp = pd.DataFrame(all_rows) if all_rows else pd.DataFrame(columns=list(self._CMP_COLUMNS))
        # Clean before touching the DB so a failure here does not leave it wiped.
        df_cmp = clean_df(df_cmp)

        db_path = os.path.join(out_dir, "CMP.db")
        os.makedirs(out_dir, exist_ok=True)
        if os.path.exists(db_path):
            os.remove(db_path)
        conn = sqlite3.connect(db_path)
        try:
            df_cmp.to_sql("CMP_info", conn, if_exists="replace", index=False)
            conn.commit()
        finally:
            conn.close()
        print(f"✅ Built SQLite successfully: {db_path}")
        df_cmp.to_csv(os.path.join(out_dir, "CMP_info.csv"), index=False)

    def build_scene_info(self):
        """
        Build Scene_info table (ENGLISH column names).

        One row per chunk (exact duplicates dropped), written to the
        ``Scene_info`` table of CMP.db (replacing it) and to Scene_info.csv.
        """
        out_dir = self.config.storage.sql_database_path
        chunks = self._read_json(self._chunks_path())

        rows = []
        for chunk in chunks:
            # `or {}` also covers a "metadata" key that is present but None.
            metadata = chunk.get("metadata") or {}
            title = metadata.get("title")
            rows.append({
                "chunk_id": chunk.get("id"),
                # Without an explicit scene_id, use the title text before the
                # first ideographic comma (U+3001).
                "scene_id": metadata.get("scene_id", (title or "").split("、")[0]),
                "scene_title": title,
                "subscene_title": metadata.get("subtitle")
            })

        df_scene = pd.DataFrame(rows) if rows else pd.DataFrame(columns=list(self._SCENE_COLUMNS))
        df_scene = df_scene.drop_duplicates()

        db_path = os.path.join(out_dir, "CMP.db")
        os.makedirs(out_dir, exist_ok=True)
        conn = sqlite3.connect(db_path)
        try:
            df_scene.to_sql("Scene_info", conn, if_exists="replace", index=False)
            conn.commit()
        finally:
            conn.close()
        df_scene.to_csv(os.path.join(out_dir, "Scene_info.csv"), index=False)
