import asyncio
import os
from dataclasses import dataclass
from typing import Union

import numpy as np
from sqlalchemy import create_engine, text
from tqdm import tqdm

from graphr1.base import BaseVectorStorage, BaseKVStorage
from graphr1.utils import logger


class TiDB(object):
    def __init__(self, config, **kwargs):
        self.host = config.get("host", None)
        self.port = config.get("port", None)
        self.user = config.get("user", None)
        self.password = config.get("password", None)
        self.database = config.get("database", None)
        self.workspace = config.get("workspace", None)
        connection_string = (
            f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
            f"?ssl_verify_cert=true&ssl_verify_identity=true"
        )

        try:
            self.engine = create_engine(connection_string)
            logger.info(f"Connected to TiDB database at {self.database}")
        except Exception as e:
            logger.error(f"Failed to connect to TiDB database at {self.database}")
            logger.error(f"TiDB database error: {e}")
            raise

    async def check_tables(self):
        for k, v in TABLES.items():
            try:
                await self.query(f"SELECT 1 FROM {k}".format(k=k))
            except Exception as e:
                logger.error(f"Failed to check table {k} in TiDB database")
                logger.error(f"TiDB database error: {e}")
                try:
                    # print(v["ddl"])
                    await self.execute(v["ddl"])
                    logger.info(f"Created table {k} in TiDB database")
                except Exception as e:
                    logger.error(f"Failed to create table {k} in TiDB database")
                    logger.error(f"TiDB database error: {e}")

    async def query(
        self, sql: str, params: dict = None, multirows: bool = False
    ) -> Union[dict, None]:
        if params is None:
            params = {"workspace": self.workspace}
        else:
            params.update({"workspace": self.workspace})
        with self.engine.connect() as conn, conn.begin():
            try:
                result = conn.execute(text(sql), params)
            except Exception as e:
                logger.error(f"Tidb database error: {e}")
                print(sql)
                print(params)
                raise
            if multirows:
                rows = result.all()
                if rows:
                    data = [dict(zip(result.keys(), row)) for row in rows]
                else:
                    data = []
            else:
                row = result.first()
                if row:
                    data = dict(zip(result.keys(), row))
                else:
                    data = None
            return data

    async def execute(self, sql: str, data: list | dict = None):
        # logger.info("go into TiDBDB execute method")
        try:
            with self.engine.connect() as conn, conn.begin():
                if data is None:
                    conn.execute(text(sql))
                else:
                    conn.execute(text(sql), parameters=data)
        except Exception as e:
            logger.error(f"TiDB database error: {e}")
            print(sql)
            print(data)
            raise


@dataclass
class TiDBKVStorage(BaseKVStorage):
    # should pass db object to self.db
    def __post_init__(self):
        self._data = {}
        self._max_batch_size = self.global_config["embedding_batch_num"]

    ################ QUERY METHODS ################

    async def get_by_id(self, id: str) -> Union[dict, None]:
        """根据 id 获取 doc_full 数据."""
        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
        params = {"id": id}
        # print("get_by_id:"+SQL)
        res = await self.db.query(SQL, params)
        if res:
            data = res  # {"data":res}
            # print (data)
            return data
        else:
            return None

    # Query by id
    async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
        """根据 id 获取 doc_chunks 数据"""
        SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
            ids=",".join([f"'{id}'" for id in ids])
        )
        # print("get_by_ids:"+SQL)
        res = await self.db.query(SQL, multirows=True)
        if res:
            data = res  # [{"data":i} for i in res]
            # print(data)
            return data
        else:
            return None

    async def filter_keys(self, keys: list[str]) -> set[str]:
        """过滤掉重复内容"""
        SQL = SQL_TEMPLATES["filter_keys"].format(
            table_name=N_T[self.namespace],
            id_field=N_ID[self.namespace],
            ids=",".join([f"'{id}'" for id in keys]),
        )
        try:
            await self.db.query(SQL)
        except Exception as e:
            logger.error(f"Tidb database error: {e}")
            print(SQL)
        res = await self.db.query(SQL, multirows=True)
        if res:
            exist_keys = [key["id"] for key in res]
            data = set([s for s in keys if s not in exist_keys])
        else:
            exist_keys = []
            data = set([s for s in keys if s not in exist_keys])
        return data

    ################ INSERT full_doc AND chunks ################
    async def upsert(self, data: dict[str, dict]):
        left_data = {k: v for k, v in data.items() if k not in self._data}
        self._data.update(left_data)
        if self.namespace == "text_chunks":
            list_data = [
                {
                    "__id__": k,
                    **{k1: v1 for k1, v1 in v.items()},
                }
                for k, v in data.items()
            ]
            contents = [v["content"] for v in data.values()]
            batches = [
                contents[i : i + self._max_batch_size]
                for i in range(0, len(contents), self._max_batch_size)
            ]
            embeddings_list = await asyncio.gather(
                *[self.embedding_func(batch) for batch in batches]
            )
            embeddings = np.concatenate(embeddings_list)
            for i, d in enumerate(list_data):
                d["__vector__"] = embeddings[i]

            merge_sql = SQL_TEMPLATES["upsert_chunk"]
            data = []
            for item in list_data:
                data.append(
                    {
                        "id": item["__id__"],
                        "content": item["content"],
                        "tokens": item["tokens"],
                        "chunk_order_index": item["chunk_order_index"],
                        "full_doc_id": item["full_doc_id"],
                        "content_vector": f"{item["__vector__"].tolist()}",
                        "workspace": self.db.workspace,
                    }
                )
            await self.db.execute(merge_sql, data)

        if self.namespace == "full_docs":
            merge_sql = SQL_TEMPLATES["upsert_doc_full"]
            data = []
            for k, v in self._data.items():
                data.append(
                    {
                        "id": k,
                        "content": v["content"],
                        "workspace": self.db.workspace,
                    }
                )
            await self.db.execute(merge_sql, data)
        return left_data

    async def index_done_callback(self):
        if self.namespace in ["full_docs", "text_chunks"]:
            logger.info("full doc and chunk data had been saved into TiDB db!")


@dataclass
class TiDBVectorDBStorage(BaseVectorStorage):
    cosine_better_than_threshold: float = 0.2

    def __post_init__(self):
        self._client_file_name = os.path.join(
            self.global_config["working_dir"], f"vdb_{self.namespace}.json"
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]
        self.cosine_better_than_threshold = self.global_config.get(
            "cosine_better_than_threshold", self.cosine_better_than_threshold
        )

    async def query(self, query: str, top_k: int) -> list[dict]:
        """search from tidb vector"""

        embeddings = await self.embedding_func([query])
        embedding = embeddings[0]

        embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"

        params = {
            "embedding_string": embedding_string,
            "top_k": top_k,
            "better_than_threshold": self.cosine_better_than_threshold,
        }

        results = await self.db.query(
            SQL_TEMPLATES[self.namespace], params=params, multirows=True
        )
        print("vector search result:", results)
        if not results:
            return []
        return results

    ###### INSERT entities And relationships ######
    async def upsert(self, data: dict[str, dict]):
        # ignore, upsert in TiDBKVStorage already
        if not len(data):
            logger.warning("You insert an empty data to vector DB")
            return []
        if self.namespace == "chunks":
            return []
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")

        list_data = [
            {
                "id": k,
                **{k1: v1 for k1, v1 in v.items()},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        embedding_tasks = [self.embedding_func(batch) for batch in batches]
        embeddings_list = []
        for f in tqdm(
            asyncio.as_completed(embedding_tasks),
            total=len(embedding_tasks),
            desc="Generating embeddings",
            unit="batch",
        ):
            embeddings = await f
            embeddings_list.append(embeddings)
        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["content_vector"] = embeddings[i]

        if self.namespace == "entities":
            data = []
            for item in list_data:
                merge_sql = SQL_TEMPLATES["upsert_entity"]
                data.append(
                    {
                        "id": item["id"],
                        "name": item["entity_name"],
                        "content": item["content"],
                        "content_vector": f"{item["content_vector"].tolist()}",
                        "workspace": self.db.workspace,
                    }
                )
            await self.db.execute(merge_sql, data)

        elif self.namespace == "relationships":
            data = []
            for item in list_data:
                merge_sql = SQL_TEMPLATES["upsert_relationship"]
                data.append(
                    {
                        "id": item["id"],
                        "source_name": item["src_id"],
                        "target_name": item["tgt_id"],
                        "content": item["content"],
                        "content_vector": f"{item["content_vector"].tolist()}",
                        "workspace": self.db.workspace,
                    }
                )
            await self.db.execute(merge_sql, data)


N_T = {
    "full_docs": "HYPERGRAPHRAG_DOC_FULL",
    "text_chunks": "HYPERGRAPHRAG_DOC_CHUNKS",
    "chunks": "HYPERGRAPHRAG_DOC_CHUNKS",
    "entities": "HYPERGRAPHRAG_GRAPH_NODES",
    "relationships": "HYPERGRAPHRAG_GRAPH_EDGES",
}
N_ID = {
    "full_docs": "doc_id",
    "text_chunks": "chunk_id",
    "chunks": "chunk_id",
    "entities": "entity_id",
    "relationships": "relation_id",
}

TABLES = {
    "HYPERGRAPHRAG_DOC_FULL": {
        "ddl": """
        CREATE TABLE HYPERGRAPHRAG_DOC_FULL (
            `id` BIGINT PRIMARY KEY AUTO_RANDOM,
            `doc_id` VARCHAR(256) NOT NULL,
            `workspace` varchar(1024),
            `content` LONGTEXT,
            `meta` JSON,
            `createtime` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            `updatetime` TIMESTAMP DEFAULT NULL,
            UNIQUE KEY (`doc_id`)
        );
        """
    },
    "HYPERGRAPHRAG_DOC_CHUNKS": {
        "ddl": """
        CREATE TABLE HYPERGRAPHRAG_DOC_CHUNKS (
            `id` BIGINT PRIMARY KEY AUTO_RANDOM,
            `chunk_id` VARCHAR(256) NOT NULL,
            `full_doc_id` VARCHAR(256) NOT NULL,
            `workspace` varchar(1024),
            `chunk_order_index` INT,
            `tokens` INT,
            `content` LONGTEXT,
            `content_vector` VECTOR,
            `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
            `updatetime` DATETIME DEFAULT NULL,
            UNIQUE KEY (`chunk_id`)
        );
        """
    },
    "HYPERGRAPHRAG_GRAPH_NODES": {
        "ddl": """
        CREATE TABLE HYPERGRAPHRAG_GRAPH_NODES (
            `id` BIGINT PRIMARY KEY AUTO_RANDOM,
            `entity_id`  VARCHAR(256) NOT NULL,
            `workspace` varchar(1024),
            `name` VARCHAR(2048),
            `content` LONGTEXT,
            `content_vector` VECTOR,
            `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
            `updatetime` DATETIME DEFAULT NULL,
            UNIQUE KEY (`entity_id`)
        );
        """
    },
    "HYPERGRAPHRAG_GRAPH_EDGES": {
        "ddl": """
        CREATE TABLE HYPERGRAPHRAG_GRAPH_EDGES (
            `id` BIGINT PRIMARY KEY AUTO_RANDOM,
            `relation_id`  VARCHAR(256) NOT NULL,
            `workspace` varchar(1024),
            `source_name` VARCHAR(2048),
            `target_name` VARCHAR(2048),
            `content` LONGTEXT,
            `content_vector` VECTOR,
            `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
            `updatetime` DATETIME DEFAULT NULL,
            UNIQUE KEY (`relation_id`)
        );
        """
    },
    "HYPERGRAPHRAG_LLM_CACHE": {
        "ddl": """
        CREATE TABLE HYPERGRAPHRAG_LLM_CACHE (
            id BIGINT PRIMARY KEY AUTO_INCREMENT,
            send TEXT,
            return TEXT,
            model VARCHAR(1024),
            createtime DATETIME DEFAULT CURRENT_TIMESTAMP,
            updatetime DATETIME DEFAULT NULL
        );
        """
    },
}


SQL_TEMPLATES = {
    # SQL for KVStorage
    "get_by_id_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM HYPERGRAPHRAG_DOC_FULL WHERE doc_id = :id AND workspace = :workspace",
    "get_by_id_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM HYPERGRAPHRAG_DOC_CHUNKS WHERE chunk_id = :id AND workspace = :workspace",
    "get_by_ids_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM HYPERGRAPHRAG_DOC_FULL WHERE doc_id IN ({ids}) AND workspace = :workspace",
    "get_by_ids_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM HYPERGRAPHRAG_DOC_CHUNKS WHERE chunk_id IN ({ids}) AND workspace = :workspace",
    "filter_keys": "SELECT {id_field} AS id FROM {table_name} WHERE {id_field} IN ({ids}) AND workspace = :workspace",
    # SQL for Merge operations (TiDB version with INSERT ... ON DUPLICATE KEY UPDATE)
    "upsert_doc_full": """
        INSERT INTO HYPERGRAPHRAG_DOC_FULL (doc_id, content, workspace)
        VALUES (:id, :content, :workspace)
        ON DUPLICATE KEY UPDATE content = VALUES(content), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
        """,
    "upsert_chunk": """
        INSERT INTO HYPERGRAPHRAG_DOC_CHUNKS(chunk_id, content, tokens, chunk_order_index, full_doc_id, content_vector, workspace)
        VALUES (:id, :content, :tokens, :chunk_order_index, :full_doc_id, :content_vector, :workspace)
        ON DUPLICATE KEY UPDATE
        content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index),
        full_doc_id = VALUES(full_doc_id), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
        """,
    # SQL for VectorStorage
    "entities": """SELECT n.name as entity_name FROM
        (SELECT entity_id as id, name, VEC_COSINE_DISTANCE(content_vector,:embedding_string) as distance
        FROM HYPERGRAPHRAG_GRAPH_NODES WHERE workspace = :workspace) n
        WHERE n.distance>:better_than_threshold ORDER BY n.distance DESC LIMIT :top_k""",
    "relationships": """SELECT e.source_name as src_id, e.target_name as tgt_id FROM
        (SELECT source_name, target_name, VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
        FROM HYPERGRAPHRAG_GRAPH_EDGES WHERE workspace = :workspace) e
        WHERE e.distance>:better_than_threshold ORDER BY e.distance DESC LIMIT :top_k""",
    "chunks": """SELECT c.id FROM
        (SELECT chunk_id as id,VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
        FROM HYPERGRAPHRAG_DOC_CHUNKS WHERE workspace = :workspace) c
        WHERE c.distance>:better_than_threshold ORDER BY c.distance DESC LIMIT :top_k""",
    "upsert_entity": """
        INSERT INTO HYPERGRAPHRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace)
        VALUES(:id, :name, :content, :content_vector, :workspace)
        ON DUPLICATE KEY UPDATE
        name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector),
        workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
        """,
    "upsert_relationship": """
        INSERT INTO HYPERGRAPHRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace)
        VALUES(:id, :source_name, :target_name, :content, :content_vector, :workspace)
        ON DUPLICATE KEY UPDATE
        source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content),
        content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
        """,
}
