# motif_miner.py
import collections
import hashlib
import logging
from dataclasses import dataclass
from typing import Dict, List, Tuple, Set, Optional, Any

import numpy as np
import networkx as nx

from config import (
    CNT_OPS, UNI_OPS, EXT_OPS,
    MOTIF_VOCAB_SIZE, NUM_NODE_TYPES
)

# Module-level logger for this file. Its level is pinned to INFO here, so INFO
# records from this module propagate regardless of the root logger's level.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)


@dataclass(frozen=True)
class MotifInstance:
    """
    Immutable record describing a single extracted motif occurrence.

    Attributes:
        center: local id (within the per-sample HRG) of the counting node the
            motif is built around.
        token: hashed motif identifier in [0, K-1].
        nodes: node set V_M that induces this occurrence, as a sorted tuple.
        u: the matched target node paired with ``center`` (kept for
            debugging / analysis).
        mtype: motif family tag, one of "M1" (nested/sibling counting node),
            "M2" (union child) or "M3" (extended-feature descendant).
    """

    center: int  # counting-node id
    token: int  # motif bucket id
    nodes: Tuple[int, ...]  # V_M, sorted
    u: int  # paired target node
    mtype: str  # "M1" / "M2" / "M3"


class MotifMiner:
    """
    Counting-centered motif extraction and hashing-based indexing.

    Implements Appendix Algorithm (Counting-Centered Motif Extraction and Encoding):
      - For each counting node c, define:
          M1 = {u in V_cnt | u in Desc(c) U Sib(c)}
          M2 = {u in Child(c) | Type(u) in UnionOps}
          M3 = {u in Desc(c) | Type(u) in {Look, Ref}}
      - For each u in M1 U M2 U M3:
          M = induced connected subgraph around {c,u}
          k = Hash(Canon(M)) mod K
          record (k, V(M)) and attach token to nodes in V(M)

    Canon(M):
      - deterministic BFS from center c
      - neighbor ordering by (edge_type, node_type, node_id)
      - sequence records visited structural attributes as strings

    Node types are decoded from the one-hot prefix of each node's ``attr``
    feature. "Hierarchical" edges are those whose ``type`` attribute equals 1;
    edges without a ``type`` attribute are treated as type 1.
    """

    def __init__(self, vocab_size: int = MOTIF_VOCAB_SIZE, max_depth: int = 6):
        """
        Args:
            vocab_size: number of hash buckets K; tokens lie in [0, K-1].
            max_depth: maximum number of hierarchical hops explored by Desc(c).

        Raises:
            ValueError: if vocab_size < 1 (hashing takes ``% K``, which would
                otherwise fail later with ZeroDivisionError).
        """
        self.K = int(vocab_size)
        if self.K < 1:
            raise ValueError(f"vocab_size must be >= 1, got {vocab_size!r}")
        self.max_depth = int(max_depth)

    # -------------------------------------------------------------------------
    # Basic helpers
    # -------------------------------------------------------------------------
    def _decode_node_type(self, node_data: Dict[str, Any]) -> int:
        """
        Decode the integer node type from the one-hot prefix of node_data['attr'].

        The first NUM_NODE_TYPES entries of the feature are arg-maxed. Any 1-D
        numeric array-like is accepted (list, tuple, numpy array). Returns 0
        when the feature is missing, not 1-D numeric, shorter than
        NUM_NODE_TYPES, or when its type prefix sums to <= 0.
        """
        feat = node_data.get("attr")
        if feat is None:
            return 0
        try:
            arr = np.asarray(feat, dtype=np.float32)
        except (TypeError, ValueError):
            # Non-numeric payload (e.g. a string or ragged nesting).
            return 0
        if arr.ndim != 1 or arr.shape[0] < NUM_NODE_TYPES:
            return 0
        prefix = arr[:NUM_NODE_TYPES]
        # An all-zero prefix carries no type information -> fallback type 0.
        if float(prefix.sum()) > 0:
            return int(np.argmax(prefix))
        return 0

    def _edge_type(self, g: nx.Graph, u: int, v: int) -> int:
        """
        Return the edge relation/type used for ordering in Canon(·).
        Defaults to 1 when the edge has no ``type`` attribute.
        """
        d = g.edges[u, v]
        return int(d.get("type", 1))

    def _hash(self, s: str) -> int:
        """
        Stable (process-independent) MD5 hash of ``s``, reduced modulo K.
        """
        h = hashlib.md5(s.encode("utf-8")).hexdigest()
        return int(h, 16) % self.K

    # -------------------------------------------------------------------------
    # Graph navigation w.r.t. hierarchical relations
    # -------------------------------------------------------------------------
    def _neighbors_by_edge_type(self, g: nx.Graph, x: int, et: int) -> List[int]:
        """Neighbors of ``x`` (in ``g.neighbors`` order) joined by an edge of type ``et``."""
        return [y for y in g.neighbors(x) if self._edge_type(g, x, y) == et]

    def _children(self, g: nx.Graph, c: int) -> List[int]:
        """
        Child(c): neighbors connected by hierarchical relation.
        Here we treat edge_type==1 as the primary syntactic/hierarchical relation.

        NOTE(review): on an undirected graph ``g.neighbors`` returns every
        adjacent node, so this also includes c's parent.
        """
        return self._neighbors_by_edge_type(g, c, et=1)

    def _descendants(self, g: nx.Graph, c: int) -> List[int]:
        """
        Desc(c): BFS within max_depth hops using only hierarchical edges (etype==1).

        Returns the reached nodes in BFS discovery order, excluding ``c``.

        NOTE(review): with undirected storage this BFS can also walk upward to
        ancestors (and their other children); it is a hop-bounded hierarchical
        neighborhood rather than strict descendants.
        """
        visited: Set[int] = {c}
        q = collections.deque([(c, 0)])
        out: List[int] = []

        while q:
            u, d = q.popleft()
            if d >= self.max_depth:
                continue
            for v in self._neighbors_by_edge_type(g, u, et=1):
                if v in visited:
                    continue
                visited.add(v)
                out.append(v)
                q.append((v, d + 1))
        return out

    def _siblings(self, g: nx.Graph, c: int) -> List[int]:
        """
        Sib(c): nodes that share a common parent-like neighbor with c through hierarchical edges.

        Since the HRG edges are stored as undirected, "parent" direction is not explicit.
        This implementation follows the structural definition "share the same parent node":
          - find pivot nodes p in N(c) via etype==1
          - siblings are other nodes connected to p via etype==1 (excluding c)

        Returns a sorted list of node ids.
        """
        sibs: Set[int] = set()
        for p in self._neighbors_by_edge_type(g, c, et=1):
            for u in self._neighbors_by_edge_type(g, p, et=1):
                if u != c:
                    sibs.add(u)
        return sorted(sibs)

    # -------------------------------------------------------------------------
    # Motif induced connected subgraph around {c,u}
    # -------------------------------------------------------------------------
    def _connected_induced_nodes(self, g: nx.Graph, c: int, u: int) -> List[int]:
        """
        Produce a connected node set V_M used to induce the motif instance subgraph.

        Minimal connected construction:
          - take shortest path nodes between c and u (in the full HRG, all edge types)
          - if no path exists, use {c,u}
        """
        if c == u:
            return [c]
        try:
            path = nx.shortest_path(g, source=c, target=u)
        except nx.NetworkXNoPath:
            return [c, u]
        return list(dict.fromkeys(path))  # preserve order, remove duplicates

    def _induced_subgraph(self, g: nx.Graph, nodes: List[int]) -> nx.Graph:
        """
        Return the induced subgraph on ``nodes`` (an independent copy with attributes).
        """
        return g.subgraph(nodes).copy()

    # -------------------------------------------------------------------------
    # Canonicalization: deterministic BFS from center
    # -------------------------------------------------------------------------
    def canon(self, subg: nx.Graph, center: int, node_types: Dict[int, int]) -> str:
        """
        Canon(subg):
          - BFS starting from center
          - neighbor ordering by (edge_type, node_type, node_id)
          - emit a canonical sequence as a single string

        The sequence encodes structural attributes visited in BFS order:
          - "C<type>" once for the center
          - "V<type>" for each dequeued node (including the center)
          - "E<edge_type>T<neighbor_type>" for every incident edge of that node

        Nodes missing from ``node_types`` are treated as type 0. If ``center``
        is not in ``subg`` the smallest node id is used instead. An empty
        subgraph yields the empty string.
        """
        if subg.number_of_nodes() == 0:
            return ""
        if center not in subg:
            # fallback: choose the smallest node id in subgraph
            center = min(subg.nodes())

        visited: Set[int] = {center}
        q = collections.deque([center])
        seq: List[str] = [f"C{node_types.get(center, 0)}"]  # center type

        while q:
            u = q.popleft()
            seq.append(f"V{node_types.get(u, 0)}")

            neigh = [
                (self._edge_type(subg, u, v), node_types.get(v, 0), v)
                for v in subg.neighbors(u)
            ]
            # Deterministic expansion order; node_id breaks ties.
            neigh.sort(key=lambda x: (x[0], x[1], x[2]))

            for et, tv, v in neigh:
                # record edge relation and neighbor type (also for back-edges)
                seq.append(f"E{et}T{tv}")
                if v not in visited:
                    visited.add(v)
                    q.append(v)

        return "|".join(seq)

    # -------------------------------------------------------------------------
    # Public APIs
    # -------------------------------------------------------------------------
    def extract(self, g: nx.Graph) -> Dict[int, List[int]]:
        """
        Backward-compatible API:
          Returns node -> list of motif tokens.
        """
        motif_map, _ = self.extract_with_instances(g)
        return motif_map

    def extract_with_instances(self, g: nx.Graph) -> Tuple[Dict[int, List[int]], List[MotifInstance]]:
        """
        Extract all counting-centered motif instances of ``g``.

        For each counting node c, each target u in the union M1 ∪ M2 ∪ M3 yields
        exactly one instance. If u qualifies for several groups, it is tagged
        with the first group it appears in (M1, then M2, then M3).

        Returns:
          motif_map : {v: [token1, token2, ...]} for node-level motif aggregation
          instances : list of MotifInstance (token, V_M) records, one per motif
        """
        # Precompute node types
        node_types: Dict[int, int] = {
            n: self._decode_node_type(data) for n, data in g.nodes(data=True)
        }

        # Counting nodes
        V_cnt = [v for v, t in node_types.items() if t in CNT_OPS]

        motif_map: Dict[int, List[int]] = collections.defaultdict(list)
        instances: List[MotifInstance] = []

        for c in V_cnt:
            desc = self._descendants(g, c)
            # Sets for O(1) membership; the lists keep deterministic ordering.
            desc_set = set(desc)
            sib_set = set(self._siblings(g, c))

            # M1: nested/sibling counting nodes (never c itself)
            M1 = [u for u in V_cnt if u != c and (u in desc_set or u in sib_set)]

            # M2: union among immediate children
            M2 = [u for u in self._children(g, c) if node_types.get(u, 0) in UNI_OPS]

            # M3: extended features (lookaround/backreference) among descendants
            M3 = [u for u in desc if node_types.get(u, 0) in EXT_OPS]

            # S_c = { induced connected subgraph around {c,u} | u in M1 U M2 U M3 }
            # Each distinct u creates one motif instance; token is Hash(Canon(M)).
            seen: Set[int] = set()
            for mtype, group in (("M1", M1), ("M2", M2), ("M3", M3)):
                for u in group:
                    if u in seen:
                        # Already emitted via an earlier group: union semantics.
                        continue
                    seen.add(u)

                    V_M = self._connected_induced_nodes(g, c, u)
                    subg = self._induced_subgraph(g, V_M)

                    # Canonical sequence and hashing
                    token = self._hash(self.canon(subg, center=c, node_types=node_types))

                    V_tuple = tuple(sorted(set(V_M)))
                    instances.append(
                        MotifInstance(center=c, token=token, nodes=V_tuple, u=u, mtype=mtype)
                    )

                    # Node-level accumulation: every node in V_M belongs to this motif
                    for v in V_tuple:
                        motif_map[v].append(token)

        return dict(motif_map), instances

    def debug_stats(self, graphs: List[nx.Graph]) -> Dict[str, float]:
        """
        Produce summary statistics for motif extraction over ``graphs``.

        Returns per-graph averages of node count, counting-center count and
        instance count, plus instances per 1000 nodes. Division by zero is
        avoided by clamping denominators to at least 1.
        """
        total_nodes = 0
        total_cnt = 0
        total_inst = 0

        for g in graphs:
            total_nodes += g.number_of_nodes()
            total_cnt += sum(
                1 for _, d in g.nodes(data=True) if self._decode_node_type(d) in CNT_OPS
            )

            _, inst = self.extract_with_instances(g)
            total_inst += len(inst)

        nG = max(1, len(graphs))
        return {
            "avg_nodes": float(total_nodes) / nG,
            "avg_cnt_centers": float(total_cnt) / nG,
            "avg_instances": float(total_inst) / nG,
            "inst_per_1k_nodes": 1000.0 * float(total_inst) / max(1, total_nodes),
        }
