from neo4j import GraphDatabase
from typing import List, Dict, Any, Tuple

import config

import json
from neo4j import GraphDatabase
from typing import Dict, Any, List
import json
import ast
from typing import Dict, Any, List, Tuple, DefaultDict, Iterable, Optional
from collections import defaultdict



# def _merge_tasks_to_list(values):
#     seen = set()
#     out = []
#     for v in values:
#         if v is None:
#             continue
#         if isinstance(v, list):
#             it = v
#         else:
#             it = [s.strip() for s in str(v).split(",")]
#         for x in it:
#             s = str(x).strip()
#             if not s or s in seen:
#                 continue
#             seen.add(s)
#             out.append(s)
#     return out  # list[str]


def _merge_tasks_to_list(tasks: Iterable[Any]) -> List[Any]:
    """
    将若干 task 合并为“列表”：
    - None / 空跳过
    - 若本身是列表则扁平化
    - 其余视为单值加入
    - 去重但保持首次出现的顺序
    """
    out = []
    seen = set()
    def add_one(v):
        key = json.dumps(v, sort_keys=True, ensure_ascii=False) if isinstance(v, (dict, list)) else v
        if key in seen:
            return
        seen.add(key)
        out.append(v)

    for t in tasks:
        if t is None:
            continue
        if isinstance(t, list):
            for x in t:
                if x is None:
                    continue
                add_one(x)
        else:
            add_one(t)
    return out


def parse_kwargs_from_call(s: str) -> Dict[str, Any]:
    """
    Parse the keyword arguments of a single call expression into a dict.

    Example:
        'JSONAction(action_type="input_text", index=9, text="Freelance Payment")'
        -> {"action_type": "input_text", "index": 9, "text": "Freelance Payment"}

    Only keyword arguments are read; positional arguments and ``**mapping``
    splats are ignored. A keyword whose value is not a Python literal
    (e.g. ``x=foo()``) maps to ``None``. Input that cannot be parsed, or that
    is not a single call expression, yields ``{}``.

    The text is only parsed (``ast.parse`` + ``ast.literal_eval``); nothing
    is executed.
    """
    try:
        node = ast.parse(s, mode="eval").body
    except (SyntaxError, ValueError, TypeError):
        # ValueError: NUL bytes on older Pythons; TypeError: non-str/bytes input
        return {}
    if not isinstance(node, ast.Call):
        return {}
    out: Dict[str, Any] = {}
    for kw in node.keywords:
        if kw.arg is None:
            # ``**mapping`` splat: there is no keyword name to record
            continue
        try:
            out[kw.arg] = ast.literal_eval(kw.value)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            out[kw.arg] = None
    return out

def get_arg(s: str, key: str, default: Any = None) -> Any:
    """
    Return keyword argument ``key`` from the call expression ``s``.

    The key must match exactly. ``default`` is returned when the keyword is
    absent or when ``s`` does not parse as a call expression.
    """
    kwargs = parse_kwargs_from_call(s)
    if key in kwargs:
        return kwargs[key]
    return default


class KG_filter:
    def __init__(self, uri: str, auth: tuple, database: str | None = None):
        self.driver = GraphDatabase.driver(uri, auth=auth)
        self.database = database

    def close(self):
        self.driver.close()

    def _run(self, cypher: str, params: Dict[str, Any] = None, write: bool = False):
        def work(tx):
            return list(tx.run(cypher, **(params or {})))
        with self.driver.session(database=self.database) as s:
            return s.execute_write(work) if write else s.execute_read(work)



    import json
    from typing import Dict, Any, List, Tuple

    def merge_duplicate_elements_between_pages_strict(self, src_page_id: str, dst_page_id: str) -> Dict[str, Any]:
        # 1) 读取两个 page 之间的 Element（只取节点属性 element_id）
        q_read = """
        MATCH (:Page {page_id:$src})-[:HAS_ELEMENT]->(e:Element)-[:LEADS_TO]->(:Page {page_id:$dst})
        RETURN e.element_id AS node_eid, e.task AS task, e.target_element AS te_raw
        """
        rows = self._run(q_read, {"src": src_page_id, "dst": dst_page_id})

        # 2) 按 target_element.element_id + bbox 四值完全相等分组
        groups: Dict[Tuple[str, float, float, float, float], List[Dict[str, Any]]] = {}
        for r in rows:
            node_eid = r["node_eid"]
            if not node_eid:
                continue
            task = r.get("task")
            te_raw = r.get("te_raw")

            te = te_raw if isinstance(te_raw, dict) else None
            if te is None and isinstance(te_raw, str) and te_raw.strip():
                try:
                    te = json.loads(te_raw)
                except Exception:
                    continue
            if not isinstance(te, dict):
                continue

            teid = te.get("element_id")
            bbox = te.get("bbox") if isinstance(te.get("bbox"), dict) else None
            if not teid or not bbox:
                continue

            try:
                x_min = float(bbox["x_min"]);
                x_max = float(bbox["x_max"])
                y_min = float(bbox["y_min"]);
                y_max = float(bbox["y_max"])
            except Exception:
                continue

            key = (str(teid), x_min, x_max, y_min, y_max)
            groups.setdefault(key, []).append({"node_eid": str(node_eid), "task": task})

        merge_jobs = []
        for _, es in groups.items():
            if len(es) <= 1:
                continue
            ids = sorted(e["node_eid"] for e in es)  # 用节点属性 element_id
            keep_id, dup_ids = ids[0], ids[1:]

            tokens: List[str] = []
            for e in es:
                t = e["task"]
                if t is None:
                    continue
                if isinstance(t, list):
                    tokens.extend([str(x).strip() for x in t if str(x).strip()])
                else:
                    tokens.extend([s.strip() for s in str(t).split(",") if s.strip()])
            merged_task = ",".join(sorted(set(tokens))) if tokens else None

            merge_jobs.append({"keep_id": keep_id, "dup_ids": dup_ids, "task": merged_task})

        if not merge_jobs:
            return {
                "src_page_id": src_page_id, "dst_page_id": dst_page_id,
                "groups_processed": 0, "duplicate_nodes_removed": 0, "kept_element_ids": []
            }

        # 3) 合并
        q_merge = """
        MATCH (p1:Page {page_id:$src}), (p2:Page {page_id:$dst}), (keep:Element {element_id:$keep_eid})
        FOREACH (_ IN CASE WHEN $task IS NULL THEN [] ELSE [1] END | SET keep.task = $task)

        MERGE (p1)-[:HAS_ELEMENT]->(keep)
        MERGE (keep)-[:LEADS_TO]->(p2)
        WITH keep, $dup_eids AS dup_eids   // <= 修复点：MERGE 后加 WITH

        UNWIND dup_eids AS deid
        MATCH (dup:Element {element_id:deid})

        // 入边：Page -[:HAS_ELEMENT]-> dup 迁移到 keep
        OPTIONAL MATCH (s:Page)-[:HAS_ELEMENT]->(dup)
        WITH keep, dup, collect(DISTINCT s) AS srcs
        FOREACH (s IN srcs | MERGE (s)-[:HAS_ELEMENT]->(keep))

        // 出边：dup -[:LEADS_TO]-> Page 迁移到 keep
        WITH keep, dup
        OPTIONAL MATCH (dup)-[:LEADS_TO]->(t:Page)
        WITH keep, dup, collect(DISTINCT t) AS tgts
        FOREACH (t IN tgts | MERGE (keep)-[:LEADS_TO]->(t))

        // 删除重复节点
        WITH dup
        DETACH DELETE dup
        """

        removed = 0
        kept_ids: List[str] = []
        for job in merge_jobs:
            if not job["dup_ids"]:
                continue
            self._run(
                q_merge,
                {
                    "src": src_page_id,
                    "dst": dst_page_id,
                    "keep_eid": job["keep_id"],
                    "dup_eids": job["dup_ids"],
                    "task": job["task"],
                },
                write=True
            )
            removed += len(job["dup_ids"])
            kept_ids.append(job["keep_id"])

        return {
            "src_page_id": src_page_id,
            "dst_page_id": dst_page_id,
            "groups_processed": len(merge_jobs),
            "duplicate_nodes_removed": removed,
            "kept_element_ids": kept_ids
        }

    def auto_merge_adjacent_pages_strict_single_element_edges(self) -> Dict[str, Any]:
        q_read_all = """
        MATCH (p1:Page)-[:HAS_ELEMENT]->(e:Element)-[:LEADS_TO]->(p2:Page)
        WITH p1, e, p2
        MATCH (e)<-[:HAS_ELEMENT]-(sp:Page)
        WITH p1, e, p2, count(sp) AS in_cnt
        MATCH (e)-[:LEADS_TO]->(tp:Page)
        WITH p1, e, p2, in_cnt, count(tp) AS out_cnt
        WHERE in_cnt = 1 AND out_cnt = 1
        RETURN p1.page_id AS src, p2.page_id AS dst, e.element_id AS node_eid, e.task AS task, e.target_element AS te_raw
        """
        rows = self._run(q_read_all, {})

        pairs = defaultdict(list)
        for r in rows:
            pairs[(r["src"], r["dst"])].append({
                "node_eid": r["node_eid"],
                "task": r.get("task"),
                "te_raw": r.get("te_raw"),
            })

        total_pairs_processed = 0
        total_groups = 0
        total_removed = 0
        per_pair_stats: Dict[str, Any] = {}

        for (src, dst), elems in pairs.items():
            groups = defaultdict(list)
            for e in elems:
                node_eid = e["node_eid"]
                if not node_eid:
                    continue
                te_raw = e["te_raw"]
                te = te_raw if isinstance(te_raw, dict) else None
                if te is None and isinstance(te_raw, str) and te_raw.strip():
                    try:
                        te = json.loads(te_raw)
                    except Exception:
                        continue
                if not isinstance(te, dict):
                    continue
                teid = te.get("element_id")
                bbox = te.get("bbox") if isinstance(te.get("bbox"), dict) else None
                if not teid or not bbox:
                    continue
                try:
                    x_min = float(bbox["x_min"]);
                    x_max = float(bbox["x_max"])
                    y_min = float(bbox["y_min"]);
                    y_max = float(bbox["y_max"])
                except Exception:
                    continue
                key = (str(teid), x_min, x_max, y_min, y_max)
                groups[key].append({"node_eid": str(node_eid), "task": e.get("task")})

            merge_jobs = []
            for _, es in groups.items():
                if len(es) <= 1:
                    continue
                ids = sorted(x["node_eid"] for x in es)
                keep_id, dup_ids = ids[0], ids[1:]
                merged_task_list = _merge_tasks_to_list([x["task"] for x in es])
                merge_jobs.append({
                    "keep_id": keep_id,
                    "dup_ids": dup_ids,
                    "task_list": merged_task_list  # 这里是 list
                })

            removed_here = 0
            q_merge_local = """
            MATCH (p1:Page {page_id:$src}), (p2:Page {page_id:$dst})
            MATCH (keep:Element {element_id:$keep_eid})
            FOREACH (_ IN CASE WHEN $task_list IS NULL OR size($task_list)=0 THEN [] ELSE [1] END | SET keep.task = $task_list)

            MERGE (p1)-[:HAS_ELEMENT]->(keep)
            MERGE (keep)-[:LEADS_TO]->(p2)
            WITH p1, p2, keep, $dup_eids AS dup_eids

            UNWIND dup_eids AS deid
            MATCH (dup:Element {element_id:deid})

            // 仅删除当前 pair 上 dup 的两条边
            OPTIONAL MATCH (p1)-[r1:HAS_ELEMENT]->(dup)
            WITH p1, p2, keep, dup, r1
            DELETE r1
            WITH p1, p2, keep, dup

            OPTIONAL MATCH (dup)-[r2:LEADS_TO]->(p2)
            WITH p1, p2, keep, dup, r2
            DELETE r2
            WITH dup

            // 若 dup 无任何关系则删除节点（用 degree() 避免 warning）
            WITH dup, degree(dup) AS rc
            FOREACH (_ IN CASE WHEN rc = 0 THEN [1] ELSE [] END | DETACH DELETE dup)
            """
            for job in merge_jobs:
                if not job["dup_ids"]:
                    continue
                self._run(
                    q_merge_local,
                    {
                        "src": src,
                        "dst": dst,
                        "keep_eid": job["keep_id"],
                        "dup_eids": job["dup_ids"],
                        "task_list": job["task_list"],
                    },
                    write=True
                )
                removed_here += len(job["dup_ids"])

            if merge_jobs:
                per_pair_stats[f"{src} -> {dst}"] = {
                    "groups_processed": len(merge_jobs),
                    "duplicate_nodes_removed": removed_here,
                    "kept_element_ids": [j["keep_id"] for j in merge_jobs],
                }
                total_pairs_processed += 1
                total_groups += len(merge_jobs)
                total_removed += removed_here

        return {
            "adjacent_pairs_processed": total_pairs_processed,
            "groups_processed": total_groups,
            "duplicate_nodes_removed": total_removed,
            "per_pair": per_pair_stats
        }

    def repair_element_tasks_to_list(self, batch_size: int = 1000) -> Dict[str, Any]:
        """
        把 :Element 的 task 若为字符串，转为列表（按逗号切分去空去重，保序）。
        对本来就是列表或为空的保持不变。
        """
        fixed, scanned = 0, 0
        skip = 0
        while True:
            rows = self._run(
                "MATCH (e:Element) RETURN e.element_id AS id, e.task AS task SKIP $skip LIMIT $limit",
                {"skip": skip, "limit": batch_size}
            )
            if not rows:
                break
            for r in rows:
                scanned += 1
                eid, task = r["id"], r.get("task")
                if task is None:
                    continue
                if isinstance(task, list):
                    continue
                # 标量或字符串 -> 列表
                new_list = _merge_tasks_to_list([task])
                self._run(
                    "MATCH (e:Element {element_id:$id}) SET e.task = $task_list",
                    {"id": eid, "task_list": new_list},
                    write=True
                )
                fixed += 1
            skip += batch_size
        return {"scanned": scanned, "fixed": fixed}

    def auto_merge_single_page_strict_element_duplicates(self) -> Dict[str, Any]:
        """
        对每个 Page，在其 HAS_ELEMENT 下面查找重复 Element 并合并为一个：
        - 判定重复：严格相等（target_element.element_id + bbox 4 值完全一致）
        - 仅处理 HAS_ELEMENT 入度为 1 的 Element（严格版本，避免影响其他页面）
        - 合并 task 为列表（不是字符串）
        - 保留并合并 dup 的所有 LEADS_TO 去向到 keep 上，然后删除 dup 在本页上的 HAS_ELEMENT 与其原 LEADS_TO；
          若 dup 成为孤点则删除。

        返回简单统计。
        """

        # 读取所有“单入”的 (p)-[:HAS_ELEMENT]->(e)
        q_read = """
        MATCH (p:Page)-[:HAS_ELEMENT]->(e:Element)
        WITH p, e
        MATCH (e)<-[:HAS_ELEMENT]-(sp:Page)
        WITH p, e, count(sp) AS in_cnt
        WHERE in_cnt = 1
        RETURN p.page_id AS pid, e.element_id AS node_eid, e.task AS task, e.target_element AS te_raw, e.converted_action AS conv_action
        """
        rows = self._run(q_read, {})

        # 按 page 内分组，再按严格等价键分桶
        per_page: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
        for r in rows:
            per_page[str(r["pid"])].append({
                "node_eid": r["node_eid"],
                "task": r.get("task"),
                "te_raw": r.get("te_raw"),
                "conv_action": r.get("conv_action"),
            })

        total_pages_processed = 0
        total_groups = 0
        total_removed = 0
        per_page_stats: Dict[str, Any] = {}

        def make_key(e: Dict[str, Any]) -> Optional[Tuple[str, float, float, float, float]]:
            te_raw = e.get("te_raw")
            # print(te_raw)
            converted_action = str(e.get("conv_action"))
            # print(e.get("conv_action"))
            if te_raw is None:
                action_type = get_arg(converted_action, "action_type")
                if action_type == 'scroll':
                    return (action_type, str(get_arg(converted_action, "direction")))
                elif action_type == 'input_text':
                    return (action_type, str(get_arg(converted_action, "text")))
                elif action_type == 'navigate_home' or action_type == 'navigate_back':
                    return (action_type)

            te = te_raw if isinstance(te_raw, dict) else None
            if te is None and isinstance(te_raw, str) and te_raw.strip():
                try:
                    te = json.loads(te_raw)
                except Exception:
                    return None
            if not isinstance(te, dict):
                return None
            teid = te.get("element_id")
            bbox = te.get("bbox") if isinstance(te.get("bbox"), dict) else None
            if not teid or not bbox:
                return None
            try:
                x_min = float(bbox["x_min"]);
                x_max = float(bbox["x_max"])
                y_min = float(bbox["y_min"]);
                y_max = float(bbox["y_max"])
            except Exception:
                return None
            if get_arg(converted_action, "action_type") == 'input_text':
                return (str(teid), x_min, x_max, y_min, y_max, str(get_arg(converted_action, "text")))
            elif get_arg(converted_action, "action_type") == 'scroll':
                return (str(teid), x_min, x_max, y_min, y_max, str(get_arg(converted_action, "direction")))
            return (str(teid), x_min, x_max, y_min, y_max)

        # 批量写入（按 page 一次提交，UNWIND 多个合并任务）


        q_merge_page = """
        MATCH (p:Page {page_id:$pid})
UNWIND $jobs AS job
MATCH (keep:Element {element_id: job.keep_eid})
FOREACH (_ IN CASE WHEN job.task_list IS NULL OR size(job.task_list)=0 THEN [] ELSE [1] END | 
  SET keep.task = job.task_list
)
MERGE (p)-[:HAS_ELEMENT]->(keep)
WITH p, keep, job

UNWIND job.dup_eids AS deid
MATCH (dup:Element {element_id: deid})

/* 先把与 dup 相关的数据都匹配并收集起来（还不删除） */
OPTIONAL MATCH (dup)-[r2:LEADS_TO]->(tp:Page)
OPTIONAL MATCH (p)-[r1:HAS_ELEMENT]->(dup)
OPTIONAL MATCH (dup)-[r_all]-()
WITH
  p, keep, dup,
  collect(DISTINCT tp)   AS tps,
  collect(DISTINCT r2)   AS r2s,
  collect(DISTINCT r1)   AS r1s,
  collect(DISTINCT r_all) AS all_rels

/* 先把 dup 的去向合并到 keep（过滤掉 NULL） */
FOREACH (tp IN [x IN tps WHERE x IS NOT NULL | x] | MERGE (keep)-[:LEADS_TO]->(tp))
WITH dup, r2s, r1s, [r IN all_rels WHERE NOT r IN r1s AND NOT r IN r2s] AS remain_rels

/* 删除收集到的关系（不再做新的 MATCH） */
FOREACH (r IN r2s | DELETE r)
FOREACH (r IN r1s | DELETE r)

/* 如果 dup 没有剩余关系，则删点 */
WITH dup, remain_rels
FOREACH (_ IN CASE WHEN size(remain_rels) = 0 THEN [1] ELSE [] END | DETACH DELETE dup)

RETURN 1 AS ok
        """

        for pid, elems in per_page.items():
            # 在该 page 内，按严格等价键分桶
            buckets: Dict[Tuple[str, float, float, float, float], List[Dict[str, Any]]] = defaultdict(list)
            for e in elems:
                key = make_key(e)
                if key is None:
                    continue
                eid = e.get("node_eid")
                if not eid:
                    continue
                buckets[key].append({"node_eid": str(eid), "task": e.get("task")})

            jobs = []
            removed_here = 0
            keeps_here = []

            for _, es in buckets.items():
                # print(_)
                if len(es) <= 1:
                    continue
                ids = sorted(x["node_eid"] for x in es)
                keep_id, dup_ids = ids[0], ids[1:]
                merged_task_list = _merge_tasks_to_list([x["task"] for x in es])
                jobs.append({
                    "keep_eid": keep_id,
                    "dup_eids": dup_ids,
                    "task_list": merged_task_list
                })
                removed_here += len(dup_ids)
                keeps_here.append(keep_id)

            if not jobs:
                continue

            # 执行该 page 的批量合并
            self._run(q_merge_page, {"pid": pid, "jobs": jobs}, write=True)

            per_page_stats[str(pid)] = {
                "groups_processed": len(jobs),
                "duplicate_nodes_removed": removed_here,  # 理论上这些 dup 都会被删除（若 strict 条件满足）
                "kept_element_ids": keeps_here,
            }
            total_pages_processed += 1
            total_groups += len(jobs)
            total_removed += removed_here

        return {
            "pages_processed": total_pages_processed,
            "groups_processed": total_groups,
            "duplicate_nodes_removed": total_removed,
            "per_page": per_page_stats
        }

