import logging
from typing import List, Tuple, Dict, Any, Optional

import numpy as np
import torch
import dgl

from motif_miner import MotifInstance

logger = logging.getLogger(__name__)


class GraphEngine:
    """
    Builds DGL graphs for RMGNN training.

    Design:
      - Each sample HRG is a disconnected component.
      - Base HRG nodes keep original node_attributes as 'feat'
        (zero-padded / truncated to a common width).
      - Each MotifInstance becomes a motif super-node:
          * It is connected bidirectionally with all covered nodes.
          * Edge types:
              1: syntactic edges (from A.txt)
              2: scope edges (lookaround scope)
              3: backreference edges
              4: motif -> node edges
              5: node -> motif edges
      - Additional node fields:
          * ntype_id: argmax index of the node's attribute vector for base
            nodes (0 when the vector is missing/empty); motif nodes use -1
          * is_motif: int64 0/1 mask (1 for motif nodes)
          * motif_token: for motif nodes only (int(inst.token)); else -1

    Node layout of the batched graph: all base nodes first (sample by
    sample, in input order), followed by all motif nodes (same order).

    Each input graph is expected to expose the NetworkX-style API used here
    (``number_of_nodes()``, ``nodes[i]``, ``edges(data=True)``, ``graph``)
    and to label its nodes ``0 .. n-1``, because nodes are looked up by
    ``g.nodes[i]`` for ``i in range(n)``.
    """

    # Edge type ids for the links between a motif super-node and its members.
    _ET_MOTIF_TO_NODE = 4
    _ET_NODE_TO_MOTIF = 5

    def __init__(
        self,
        nx_graphs: List[Any],
        motif_instances_list: "List[List[MotifInstance]]",
    ):
        """
        Args:
            nx_graphs: one graph per sample.
            motif_instances_list: motif instances per sample, aligned with
                ``nx_graphs`` by position.

        Raises:
            ValueError: if the two lists differ in length. Without this
                check the later ``zip`` would silently drop samples while
                the node counts below still covered the full lists.
        """
        if len(nx_graphs) != len(motif_instances_list):
            raise ValueError(
                "nx_graphs and motif_instances_list must have the same length "
                f"(got {len(nx_graphs)} and {len(motif_instances_list)})"
            )

        self.graphs = nx_graphs
        self.instances = motif_instances_list

        self.n_base = int(sum(g.number_of_nodes() for g in nx_graphs))
        self.n_motif = int(sum(len(insts) for insts in motif_instances_list))

        logger.info("GraphEngine: base_nodes=%d, motif_nodes=%d", self.n_base, self.n_motif)

    @staticmethod
    def _vec_len(vec: Any) -> int:
        """Return ``len(vec)`` for list/tuple/ndarray attribute vectors, else 0."""
        return len(vec) if isinstance(vec, (list, tuple, np.ndarray)) else 0

    @staticmethod
    def _pad_features(feat_list: List[List[float]], dim: int) -> np.ndarray:
        """
        Stack per-node feature vectors into a ``(len(feat_list), dim)``
        float32 matrix.

        Vectors shorter than ``dim`` are zero-padded on the right, longer
        ones are truncated; ``None`` or empty entries yield an all-zero row.
        """
        mat = np.zeros((len(feat_list), dim), dtype=np.float32)
        for row, vec in enumerate(feat_list):
            if vec is None or len(vec) == 0:
                continue
            width = min(dim, len(vec))
            mat[row, :width] = np.asarray(vec[:width], dtype=np.float32)
        return mat

    @staticmethod
    def _decode_node_types(attr_list: List[Any]) -> np.ndarray:
        """
        Return an int64 array holding, per node, the argmax index of its
        attribute vector (0 when the vector is missing or empty).
        """
        ntype = np.zeros((len(attr_list),), dtype=np.int64)
        for i, vec in enumerate(attr_list):
            if isinstance(vec, (list, tuple, np.ndarray)) and len(vec) > 0:
                # NOTE(review): argmax runs over the whole vector; this only
                # equals the node type if the one-hot part dominates - confirm.
                ntype[i] = int(np.argmax(np.asarray(vec, dtype=np.float32)))
        return ntype

    def _infer_feature_dim(self) -> int:
        """Return the widest 'attr' vector length over all base nodes (at least 1)."""
        widest = 0
        for g in self.graphs:
            for i in range(g.number_of_nodes()):
                widest = max(widest, self._vec_len(g.nodes[i].get("attr", [])))
        # Keep at least one feature column so 'feat' is never zero-width.
        return max(widest, 1)

    def build_batched_graph(self) -> "Tuple[dgl.DGLGraph, np.ndarray]":
        """
        Assemble all samples and their motif super-nodes into one graph.

        Returns:
          - dgl_graph: batched disconnected graph with node data 'feat',
                       'ntype_id', 'is_motif', 'motif_token' and edge data
                       'etype'
          - labels   : per-sample graph labels (numpy int64), taken from
                       ``g.graph['label']`` (default 0)

        Raises:
            ValueError: if a motif instance references a node id outside
                ``[0, number_of_nodes)`` of its own sample graph.
        """
        n_total = self.n_base + self.n_motif
        feat_dim = self._infer_feature_dim()

        src: List[int] = []
        dst: List[int] = []
        etypes: List[int] = []
        labels: List[int] = []

        feat_blocks: List[np.ndarray] = []
        ntype_blocks: List[np.ndarray] = []

        # -1 everywhere; motif rows are filled as their ids are allocated so
        # token order can never drift from motif node order.
        motif_token = np.full((n_total,), -1, dtype=np.int64)

        base_offset = 0
        next_motif_id = self.n_base  # motif nodes are appended after all base nodes

        for sample_idx, (g, insts) in enumerate(zip(self.graphs, self.instances)):
            n_local = g.number_of_nodes()
            labels.append(int(g.graph.get("label", 0)))

            # Base node features and decoded node types
            local_attrs = [g.nodes[i].get("attr", []) for i in range(n_local)]
            feat_blocks.append(self._pad_features(local_attrs, feat_dim))
            ntype_blocks.append(self._decode_node_types(local_attrs))

            # Base edges: every reported edge is emitted in both directions
            # with the same type (default type 1).
            # NOTE(review): if the input already stores both directions as
            # separate edges, each direction ends up duplicated - confirm
            # this is acceptable for the downstream aggregator.
            for u, v, data in g.edges(data=True):
                gu, gv = int(u + base_offset), int(v + base_offset)
                et = int(data.get("type", 1))
                src.extend((gu, gv))
                dst.extend((gv, gu))
                etypes.extend((et, et))

            # Motif super-nodes: one per instance
            for inst in insts:
                m_id = next_motif_id
                next_motif_id += 1
                motif_token[m_id] = int(inst.token)

                # Connect motif <-> each covered node (two typed edges)
                for nid in inst.nodes:
                    local_id = int(nid)
                    if not 0 <= local_id < n_local:
                        # An out-of-range id would silently attach the motif
                        # to a node belonging to a different sample.
                        raise ValueError(
                            f"motif instance in sample {sample_idx} references node "
                            f"{local_id}, outside [0, {n_local})"
                        )
                    gn = local_id + base_offset
                    src.extend((m_id, gn))
                    dst.extend((gn, m_id))
                    etypes.extend((self._ET_MOTIF_TO_NODE, self._ET_NODE_TO_MOTIF))

            base_offset += n_local

        # Node features: base rows first, zero rows for motif nodes
        if feat_blocks:
            base_feat = np.concatenate(feat_blocks, axis=0)
        else:
            base_feat = np.zeros((0, feat_dim), np.float32)
        motif_feat = np.zeros((self.n_motif, feat_dim), dtype=np.float32)
        feat = np.concatenate([base_feat, motif_feat], axis=0)

        # Node type ids: decoded ids for base nodes, -1 for motif nodes
        if ntype_blocks:
            base_ntype = np.concatenate(ntype_blocks, axis=0)
        else:
            base_ntype = np.zeros((0,), np.int64)
        motif_ntype = np.full((self.n_motif,), -1, dtype=np.int64)
        ntype_id = np.concatenate([base_ntype, motif_ntype], axis=0)

        # 0/1 mask marking the motif rows
        is_motif = np.zeros((n_total,), dtype=np.int64)
        is_motif[self.n_base:] = 1

        # Build DGL graph
        g_dgl = dgl.graph(
            (torch.tensor(src, dtype=torch.long), torch.tensor(dst, dtype=torch.long)),
            num_nodes=n_total,
        )
        g_dgl.ndata["feat"] = torch.tensor(feat, dtype=torch.float32)
        g_dgl.ndata["ntype_id"] = torch.tensor(ntype_id, dtype=torch.long)
        g_dgl.ndata["is_motif"] = torch.tensor(is_motif, dtype=torch.long)
        g_dgl.ndata["motif_token"] = torch.tensor(motif_token, dtype=torch.long)
        g_dgl.edata["etype"] = torch.tensor(etypes, dtype=torch.long)

        return g_dgl, np.asarray(labels, dtype=np.int64)
