# layer.py
import math
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn


class HeteroAttributeMasker(nn.Module):
    """
    Node-level semantic augmentation (type-conditioned attribute masking).

    Each node type ``t`` owns a learnable gate-logit vector ``S[t]`` and a fill
    vector ``bias[t]`` (both of length ``feat_dim``). A node of type ``t`` with
    features ``x`` becomes ``x * sigmoid(S[t]) + bias[t] * (1 - sigmoid(S[t]))``:
    every feature is softly blended towards the type's fill value.
    """
    def __init__(self, num_types: int, feat_dim: int):
        super().__init__()
        self.num_types = int(num_types)
        self.feat_dim = int(feat_dim)
        table_shape = (self.num_types, self.feat_dim)
        # Per-type gate logits and per-type fill values, both [num_types, feat_dim].
        self.S = nn.Parameter(torch.empty(table_shape))
        self.bias = nn.Parameter(torch.empty(table_shape))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise gate logits with Xavier-uniform and fill values with zeros."""
        nn.init.xavier_uniform_(self.S)
        nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor, type_ids: torch.Tensor) -> torch.Tensor:
        """
        x       : [N, D] node features
        type_ids: [N] integer type ids; values outside [0, num_types-1] are
                  clamped into range (motif nodes are mapped to 0 outside this module).
        returns : [N, D] gated features
        """
        safe_ids = type_ids.clamp(min=0, max=self.num_types - 1)
        keep = self.S[safe_ids].sigmoid()   # fraction of the original feature kept
        fill = self.bias[safe_ids]          # value blended in for the rest
        return x * keep + fill * (1.0 - keep)


class RelationalFusionLayer(nn.Module):
    """
    Heterogeneous information propagation over typed edges.

    Edge types used in this project:
      1: syntactic
      2: scope
      3: backreference
      4: motif->node
      5: node->motif

    For every edge ``u -> v`` of relation ``r`` the message
    ``W_r(h_u) + psi_r([h_u || h_v])`` is summed into node ``v``. The summed
    messages then go through dropout, a residual connection, LayerNorm and ELU.

    Args:
        d_in: input feature size.
        d_out: output feature size.
        dropout: dropout probability used inside each ``psi_r`` and on the
            aggregated messages.
        num_relations: number of relation types; edge type ids are expected in
            ``1..num_relations``. Defaults to the five project relations above.
    """
    def __init__(self, d_in: int, d_out: int, dropout: float = 0.2, num_relations: int = 5):
        super().__init__()
        self.d_in = int(d_in)
        self.d_out = int(d_out)
        self.num_relations = int(num_relations)
        # Relation ids are 1-based and ModuleDict keys must be strings ("1".."n").
        keys = [str(r) for r in range(1, self.num_relations + 1)]

        # Relation-specific transforms
        self.W = nn.ModuleDict({k: nn.Linear(self.d_in, self.d_out) for k in keys})

        # Relation-specific message fusion (src || dst)
        self.psi = nn.ModuleDict({
            k: nn.Sequential(
                nn.Linear(self.d_in * 2, self.d_out),
                nn.LeakyReLU(0.2),
                nn.Dropout(dropout),
            )
            for k in keys
        })

        self.norm = nn.LayerNorm(self.d_out)
        self.act = nn.ELU()
        self.drop = nn.Dropout(dropout)
        self.res = nn.Linear(self.d_in, self.d_out) if self.d_in != self.d_out else None

    def forward(self, g: dgl.DGLGraph, h: torch.Tensor, etype: torch.Tensor) -> torch.Tensor:
        """
        g      : DGLGraph (only its edge list and node count are used)
        h      : [N, d_in] node features, N == g.num_nodes()
        etype  : [E] edge types in {1..num_relations}, aligned with g.edges()
        returns: [N, d_out]

        Edges whose type id is outside ``1..num_relations`` contribute no
        message; nodes with no incoming messages keep only the residual term.

        Raises:
            ValueError: if ``h`` or ``etype`` do not match the graph's sizes.
        """
        src, dst = g.edges()
        if h.shape[0] != g.num_nodes():
            raise ValueError(
                f"h has {h.shape[0]} rows but the graph has {g.num_nodes()} nodes"
            )
        if etype.dim() != 1 or etype.shape[0] != src.shape[0]:
            raise ValueError(
                f"etype must have shape [{src.shape[0]}], got {tuple(etype.shape)}"
            )

        out = h.new_zeros((g.num_nodes(), self.d_out))
        h_src = h[src]
        pair = torch.cat([h_src, h[dst]], dim=1)  # [E, 2*d_in]

        for key, w_r in self.W.items():
            mask = etype == int(key)
            if not bool(mask.any()):
                continue
            # src transform + fused (src || dst) message, one row per edge of type r
            msg = w_r(h_src[mask]) + self.psi[key](pair[mask])  # [E_r, d_out]
            # Under autocast the Linear output may be fp16/bf16 while `out` keeps
            # h's dtype; index_add_ requires matching dtypes.
            out.index_add_(0, dst[mask], msg.to(out.dtype))

        # Residual + norm + activation
        res = self.res(h) if self.res is not None else h
        out = self.norm(self.drop(out) + res)
        return self.act(out)


class CrossAttentionFusion(nn.Module):
    """
    Cross-attention fusion:
      Query  : global motif vector (per graph)
      Key/Val: node embeddings (per graph, padded)

    Each graph's motif vector attends over that graph's node embeddings; the
    attended summary is added back to the motif vector and layer-normalised.

    Args:
        d_model: embedding size shared by node embeddings and motif vectors.
        n_heads: number of attention heads (``nn.MultiheadAttention`` requires
            it to divide ``d_model``).
        dropout: attention dropout probability.
    """
    def __init__(self, d_model: int, n_heads: int = 4, dropout: float = 0.2):
        super().__init__()
        self.d_model = int(d_model)
        self.attn = nn.MultiheadAttention(self.d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(self.d_model)

    def forward(self, g: dgl.DGLGraph, node_h: torch.Tensor, g_motif: torch.Tensor) -> torch.Tensor:
        """
        g      : batched DGLGraph (only ``batch_num_nodes()`` is used)
        node_h : [N, D] node embeddings, laid out graph by graph in batch order
        g_motif: [B, D]
        returns: [B, D] norm(attended local summary + g_motif)

        A graph with no nodes gets a zero attended summary, so its output is
        ``norm(g_motif)``. This avoids the NaN a fully-masked attention row
        would otherwise produce.
        """
        dev = node_h.device
        sizes = g.batch_num_nodes().to(dev)
        B = int(sizes.numel())
        if B == 0:
            return g_motif
        max_n = int(sizes.max().item())
        if max_n == 0:
            # Every graph is empty: same result as the per-graph empty case below.
            return self.norm(g_motif)

        # Pack node embeddings into [B, max_n, D] with one scatter:
        # graph_id[i] is the graph node i belongs to, pos[i] its index within it.
        total = int(sizes.sum().item())
        graph_id = torch.repeat_interleave(torch.arange(B, device=dev), sizes)
        starts = torch.cumsum(sizes, dim=0) - sizes
        pos = torch.arange(total, device=dev) - starts[graph_id]

        K = node_h.new_zeros((B, max_n, self.d_model))
        K[graph_id, pos] = node_h[:total]
        pad_mask = torch.ones((B, max_n), device=dev, dtype=torch.bool)
        pad_mask[graph_id, pos] = False

        # A row with every key masked makes the attention softmax NaN. Let empty
        # graphs attend to one all-zero key; their output is discarded below.
        empty = sizes == 0
        pad_mask[empty, 0] = False

        Q = g_motif.unsqueeze(1)  # [B, 1, D]
        attn_out, _ = self.attn(Q, K, K, key_padding_mask=pad_mask)
        c = attn_out.squeeze(1).masked_fill(empty.unsqueeze(1), 0.0)  # [B, D]
        return self.norm(c + g_motif)
