# model.py
from typing import List, Tuple, Optional, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl

from config import NUM_NODE_TYPES, MOTIF_VOCAB_SIZE
from layer import HeteroAttributeMasker, RelationalFusionLayer, CrossAttentionFusion


class RMGNN(nn.Module):
    """
    RMGNN model over motif-enhanced HRGs.

    Outputs:
      - logits: [B, 2]
      - student motif distributions per graph for distillation:
          list of (S, motif_emb) where:
            motif_emb: [M_i, D]
            S        : [M_i, M_i] row-normalized similarity distribution,
                       or None when graph i has no motif nodes
    """
    def __init__(self, cfg: Dict):
        super().__init__()
        self.d_in = int(cfg["d_in"])
        self.d_h = int(cfg["d_h"])
        self.n_layers = int(cfg.get("n_layers", 3))
        self.dropout = float(cfg.get("dropout", 0.2))
        self.tau = float(cfg.get("tau", 0.5))
        if self.tau <= 0.0:
            # tau divides the cosine similarities in _student_distribution:
            # zero gives inf/NaN, negative inverts the similarity ranking.
            raise ValueError(f"tau must be positive, got {self.tau}")

        # Attribute masking for base nodes (type-conditioned)
        self.masker = HeteroAttributeMasker(NUM_NODE_TYPES, self.d_in)
        self.input_proj = nn.Sequential(
            nn.Linear(self.d_in, self.d_h),
            nn.ReLU(),
            nn.Dropout(self.dropout),
        )

        # Motif token embedding for motif super-nodes
        self.motif_embed = nn.Embedding(MOTIF_VOCAB_SIZE, self.d_h)

        # Heterogeneous propagation layers
        self.layers = nn.ModuleList([
            RelationalFusionLayer(self.d_h, self.d_h, dropout=self.dropout)
            for _ in range(self.n_layers)
        ])

        # Motif-level readout (node set already covered via super-node connections)
        self.motif_readout = nn.Linear(self.d_h, self.d_h)

        # Cross-attention fusion
        self.cross_attn = CrossAttentionFusion(self.d_h, n_heads=int(cfg.get("n_heads", 4)), dropout=self.dropout)

        # Classifier
        self.fuse = nn.Sequential(
            nn.Linear(self.d_h * 2, self.d_h),
            nn.ReLU(),
            nn.Dropout(self.dropout),
        )
        self.cls = nn.Linear(self.d_h, 2)

    def _split_by_graph(self, g: dgl.DGLGraph, x: torch.Tensor) -> List[torch.Tensor]:
        """
        Split a node-aligned tensor x [N, ...] into per-graph views [N_i, ...].

        Split sizes come from g.batch_num_nodes(), so x must follow the node
        order of the batched graph g (its first dim must equal g's node count).
        """
        sizes = g.batch_num_nodes().tolist()
        return list(torch.split(x, sizes, dim=0))

    def _student_distribution(self, motif_emb: torch.Tensor) -> Optional[torch.Tensor]:
        """
        Student similarity distribution S from motif embeddings:
          S_ij = softmax( (z_i^T z_j) / tau ) row-wise, with z = L2-normalized rows

        Returns None when there are no motif embeddings.
        """
        if motif_emb is None or motif_emb.size(0) == 0:
            return None
        z = F.normalize(motif_emb, p=2, dim=1)
        sim = torch.mm(z, z.t()) / self.tau
        return F.softmax(sim, dim=1)

    def forward(self, g: dgl.DGLGraph) -> Tuple[torch.Tensor, List[Tuple[Optional[torch.Tensor], torch.Tensor]]]:
        """
        Input graph fields expected:
          - g.ndata['feat']       : [N, d_in]
          - g.ndata['ntype_id']   : [N] base node type id or -1 for motif nodes
          - g.ndata['is_motif']   : [N] 1 if motif node else 0
          - g.ndata['motif_token']: [N] motif token for motif nodes else -1
          - g.edata['etype']      : [E] edge types in {1..5}

        Returns:
          (logits [B, 2], student_pack) where student_pack[i] is
          (S_i or None, motif_emb_i [M_i, D]) for graph i of the batch.
        """
        x = g.ndata["feat"]
        ntype_id = g.ndata["ntype_id"]
        is_motif = g.ndata["is_motif"].bool()
        motif_token = g.ndata["motif_token"]
        etype = g.edata["etype"]

        # Base nodes: type-conditioned masking + projection.
        # Motif nodes carry ntype_id == -1; clamping keeps the masker's type
        # index valid, and those rows are overwritten just below.
        x_base = self.masker(x, torch.clamp(ntype_id, min=0))
        h = self.input_proj(x_base)

        # Motif nodes: initialized from motif token embedding.
        if bool(is_motif.any()):
            tok = torch.clamp(motif_token[is_motif], min=0)
            h_m = self.motif_embed(tok).to(h.dtype)
            # Clone before the masked write: when Dropout is a no-op
            # (p == 0 or eval mode) input_proj returns the ReLU output itself,
            # which ReLU keeps for its backward pass. Writing into it in place
            # makes autograd fail with "modified by an inplace operation".
            h = h.clone()
            h[is_motif] = h_m

        # Heterogeneous propagation
        for layer in self.layers:
            h = layer(g, h, etype)

        # Graph-level motif vector: sum over motif nodes per graph
        # (motif nodes are explicit super-nodes).
        motif_out: List[torch.Tensor] = []
        student_pack: List[Tuple[Optional[torch.Tensor], torch.Tensor]] = []

        per_graph = zip(self._split_by_graph(g, h), self._split_by_graph(g, is_motif))
        for h_i, is_motif_i in per_graph:
            motif_emb_i = h_i[is_motif_i]
            if motif_emb_i.size(0) > 0:
                motif_emb_i = self.motif_readout(motif_emb_i)

            S_i = self._student_distribution(motif_emb_i)
            student_pack.append((S_i, motif_emb_i))

            if motif_emb_i.size(0) > 0:
                g_m = torch.sum(motif_emb_i, dim=0, keepdim=True)  # [1, D]
            elif h_i.size(0) > 0:
                g_m = torch.mean(h_i, dim=0, keepdim=True)         # fallback: no motifs
            else:
                # Empty graph: a mean over zero rows would be NaN.
                g_m = h_i.new_zeros(1, h_i.size(1))
            motif_out.append(g_m)

        g_motif = torch.cat(motif_out, dim=0)  # [B, D]

        # Cross-attention fusion with node embeddings (local)
        c = self.cross_attn(g, h, g_motif)     # [B, D]

        # Final graph representation (residual on the motif vector)
        z = self.fuse(torch.cat([g_motif, c], dim=1)) + g_motif
        logits = self.cls(z)

        return logits, student_pack


def rmgnn_loss(
    logits: torch.Tensor,
    y: torch.Tensor,
    student_pack: List[Tuple[Optional[torch.Tensor], torch.Tensor]],
    teacher_pack: List[Optional[torch.Tensor]],
    alpha: float = 0.1
) -> torch.Tensor:
    """
    Classification + kernel-guided motif distillation.

    Args:
        logits: class scores passed to F.cross_entropy.
        y: target class indices for F.cross_entropy.
        student_pack: per-graph (S_i, motif_emb_i) pairs as produced by
            RMGNN.forward; S_i is a row-normalized [M_i, M_i] matrix or None.
        teacher_pack: teacher_pack[i] is T_i for graph i, shaped [M_i, M_i],
            or None if no motifs. Graphs with index >= len(teacher_pack) are
            skipped.
        alpha: weight of the distillation term.

    Returns:
        Scalar loss: cross_entropy(logits, y) + alpha * mean_i KL(T_i || S_i),
        averaged over graphs that have a non-empty S_i and T_i (0 if none do).

    Raises:
        ValueError: if a student/teacher pair has different shapes.
    """
    cls_loss = F.cross_entropy(logits, y)

    # Lower bound on S before taking the log: softmax can underflow to 0.
    eps = 1e-12
    dist = torch.tensor(0.0, device=logits.device)
    cnt = 0

    for i, (S_i, _motif_emb_i) in enumerate(student_pack):
        T_i = teacher_pack[i] if i < len(teacher_pack) else None
        if S_i is None or T_i is None:
            continue
        if S_i.numel() == 0 or T_i.numel() == 0:
            continue
        if S_i.shape != T_i.shape:
            # kl_div would broadcast silently and yield a meaningless loss.
            raise ValueError(
                f"graph {i}: student S shape {tuple(S_i.shape)} does not match "
                f"teacher T shape {tuple(T_i.shape)}"
            )
        # Teacher kernels may be precomputed on another device / in another dtype.
        T_i = T_i.to(device=S_i.device, dtype=S_i.dtype)

        # KL(T || S): both are row-normalized distributions. kl_div expects its
        # input in log space; "batchmean" divides the summed KL by M_i rows.
        S_log = torch.log(torch.clamp(S_i, min=eps))
        dist = dist + F.kl_div(S_log, T_i, reduction="batchmean")
        cnt += 1

    if cnt > 0:
        dist = dist / float(cnt)

    return cls_loss + alpha * dist
