
from typing import Optional, Dict, Any
import torch
import torch.nn as nn
from ..utils.common import GraphData

class SimpleGATv2Layer(nn.Module):
    """Single multi-head graph-attention layer with optional edge features.

    For every edge ``src -> dst`` a per-head attention logit is computed from the
    projected source features, the projected destination features and (when the
    layer was built with ``edge_dim > 0``) the projected edge features. Logits are
    softmax-normalised over all edges sharing the same destination node. Each
    destination node then receives the attention-weighted sum of its incoming
    source projections.

    NOTE(review): despite the name, the score is computed as
    ``LeakyReLU(a^T [W x_src || W x_dst || We e])``. The non-linearity is applied
    *after* the dot product with ``a``, which is the original GAT formulation,
    whereas GATv2 uses ``a^T LeakyReLU(...)`` on a summed projection. It is kept
    as-is because changing it would alter the behaviour of trained models.
    Confirm whether GATv2 scoring was intended.
    """

    def __init__(self, in_dim: int, out_dim: int, edge_dim: Optional[int] = None, heads: int = 1, dropout: float = 0.0, residual: bool = True):
        """
        Args:
            in_dim: Size of each input node feature vector.
            out_dim: Per-head output size; the layer emits ``out_dim * heads`` features.
            edge_dim: Size of each edge feature vector. ``None`` or ``0`` disables
                edge features entirely.
            heads: Number of attention heads.
            dropout: Dropout probability applied to the aggregated output
                (an identity when ``dropout <= 0``).
            residual: Whether to add a skip connection from the input. It is an
                identity when ``in_dim == out_dim * heads``, otherwise a learned
                linear projection.
        """
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.edge_dim = edge_dim or 0
        self.heads = heads
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.W = nn.Linear(in_dim, out_dim * heads, bias=False)
        self.We = nn.Linear(self.edge_dim, out_dim * heads, bias=False) if self.edge_dim > 0 else None
        # One attention vector per head over the concatenation
        # [x_src (O) || x_dst (O) || edge (O, only when edge features are enabled)].
        attn_dim = out_dim * 2 + (out_dim if self.We is not None else 0)
        self.a = nn.Parameter(torch.empty(heads, attn_dim))
        nn.init.xavier_uniform_(self.a)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.residual = residual
        if residual and in_dim == out_dim * heads:
            self.res_skip = nn.Identity()
        else:
            self.res_skip = nn.Linear(in_dim, out_dim * heads) if residual else None

    @staticmethod
    def _segment_softmax(logits: torch.Tensor, index: torch.Tensor, num_segments: int) -> torch.Tensor:
        """Softmax ``logits`` ([E, H]) independently over the rows sharing each ``index`` value.

        ``index`` is ``[E]`` with values in ``[0, num_segments)``. The result has the
        same shape as ``logits`` and, for every segment and head, sums to 1 over the
        segment's rows.
        """
        idx = index.unsqueeze(-1).expand_as(logits)
        seg_max = torch.full(
            (num_segments, logits.size(1)), float("-inf"),
            device=logits.device, dtype=logits.dtype,
        )
        seg_max = seg_max.scatter_reduce(0, idx, logits, reduce="amax", include_self=True)
        # Subtracting the per-segment max keeps exp() from overflowing. Softmax is
        # shift-invariant, so detaching the max leaves the gradients unchanged.
        exp = (logits - seg_max.detach()[index]).exp()
        # Every indexed segment contains its own max (exp == 1), so denom > 0.
        denom = torch.zeros_like(seg_max).index_add(0, index, exp)
        return exp / denom[index]

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: Optional[torch.Tensor] = None):
        """Run one round of attention-weighted message passing.

        Args:
            x: Node features of shape ``[N, in_dim]``.
            edge_index: ``[2, E]`` tensor. Row 0 holds source node ids and row 1
                holds destination node ids.
            edge_attr: Optional edge features of shape ``[E, edge_dim]``. If the
                layer was built with edge features but this is ``None`` or empty,
                zeros are used in their place. If the layer has no edge projection,
                this argument is ignored.

        Returns:
            ``(out, attn)``:
            - ``out`` is ``[N, heads * out_dim]``.
            - ``attn`` is ``[E]``, each edge's attention weight averaged over heads.

            Nodes with no incoming edges receive only the residual term, or zeros
            when there is no residual.
        """
        N = x.size(0)
        E = edge_index.size(1)
        H = self.heads
        O = self.out_dim

        x_proj = self.W(x).view(N, H, O)  # [N, H, O]
        src, dst = edge_index[0].long(), edge_index[1].long()
        x_src = x_proj[src]  # [E, H, O]
        x_dst = x_proj[dst]  # [E, H, O]

        parts = [x_src, x_dst]
        if self.We is not None:
            if edge_attr is not None and edge_attr.numel() > 0:
                parts.append(self.We(edge_attr).view(E, H, O))
            else:
                # `self.a` was sized for edge features; pad with zeros so the
                # concatenation still matches its width.
                parts.append(torch.zeros(E, H, O, device=x.device, dtype=x.dtype))
        cat = torch.cat(parts, dim=-1)  # [E, H, D], D = 2*O or 3*O

        attn_logits = self.leaky_relu((cat * self.a.unsqueeze(0)).sum(dim=-1))  # [E, H]
        alpha = self._segment_softmax(attn_logits, dst, N)  # [E, H]

        msg = x_src * alpha.unsqueeze(-1)  # [E, H, O]
        out = torch.zeros(N, H, O, device=x.device, dtype=x.dtype)
        # Vectorised scatter-sum of every message onto its destination node.
        out = out.index_add(0, dst, msg.to(out.dtype))
        out = self.dropout(out.reshape(N, H * O))
        if self.residual and self.res_skip is not None:
            out = out + self.res_skip(x)
        return out, alpha.mean(dim=1)

class GATv2IDS(nn.Module):
    """Node classifier: a stack of ``SimpleGATv2Layer`` blocks followed by an MLP head.

    Each attention layer outputs ``hidden * heads`` features, uses the graph's edge
    features, and is followed by a ReLU. The MLP head maps the final node
    embeddings to ``num_classes`` logits.
    """

    def __init__(self, in_dim_node: int, in_dim_edge: int, hidden: int = 64, layers: int = 2, heads: int = 2, dropout: float = 0.1, num_classes: int = 2):
        """
        Args:
            in_dim_node: Size of each input node feature vector.
            in_dim_edge: Size of each edge feature vector, passed to every layer
                as ``edge_dim``.
            hidden: Per-head width of every attention layer.
            layers: Number of stacked attention layers.
            heads: Number of attention heads per layer.
            dropout: Dropout probability used inside the layers and the head.
            num_classes: Number of output logits per node.
        """
        super().__init__()
        width = hidden * heads
        # Only the first layer consumes raw node features; every later layer
        # consumes the previous layer's ``hidden * heads`` output.
        self.layers = nn.ModuleList([
            SimpleGATv2Layer(
                in_dim=in_dim_node if depth == 0 else width,
                out_dim=hidden,
                edge_dim=in_dim_edge,
                heads=heads,
                dropout=dropout,
                residual=True,
            )
            for depth in range(layers)
        ])
        self.head = nn.Sequential(
            nn.Linear(width, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, data: GraphData) -> Dict[str, Any]:
        """Classify every node of ``data``.

        ``data`` must expose ``x``, ``edge_index`` and ``edge_attr``, which are
        forwarded unchanged to each layer.

        Returns:
            A dict with:
            - ``"node_logits"``: the head's output.
            - ``"edge_attn"``: the attention tensor returned by the last layer,
              or ``None`` when there are no layers.
            - ``"node_emb"``: the post-ReLU output of the last layer.
        """
        h, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        attn = None
        for gat in self.layers:
            h, attn = gat(h, edge_index, edge_attr)
            h = torch.relu(h)
        return {"node_logits": self.head(h), "edge_attn": attn, "node_emb": h}
