import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data
import dgl
from dgl.nn import EGATConv, GATConv

class GAT_layer(nn.Module):
    """One DGL ``GATConv`` step that updates node features and re-derives edge weights.

    Node features go through multi-head graph attention, are layer-normalised
    per head, activated with LeakyReLU and flattened to
    ``num_heads * out_vertex_dim`` columns. The per-edge attention coefficients
    returned by ``GATConv`` are layer-normalised across heads and collapsed by
    ``Linear(num_heads, 1)`` + ReLU into one non-negative scalar per edge, so
    the result can be fed to the next layer as its ``edge_weight``.

    Args:
        in_vertex_dim: size of the incoming node features.
        out_vertex_dim: per-head output size of ``GATConv``.
        num_heads: number of attention heads.
    """

    def __init__(self, in_vertex_dim, out_vertex_dim, num_heads):
        super().__init__()
        # Submodules are created in this exact order and under these attribute
        # names so parameter initialisation, parameters() order and state_dict
        # keys stay compatible with existing checkpoints.
        self.LeakyReLU = nn.LeakyReLU()
        self.gatconv = GATConv(
            in_feats=in_vertex_dim,
            out_feats=out_vertex_dim,
            num_heads=num_heads,
            bias=True,
        )
        self.node_norm_layers = nn.LayerNorm(out_vertex_dim)
        # NOTE(review): with num_heads == 1 this LayerNorm normalises a single
        # value, so its output is just its bias and every edge receives the
        # same weight after edge_linear -- confirm num_heads > 1 is intended.
        self.edge_norm_layers = nn.LayerNorm(num_heads)
        self.edge_linear = nn.Sequential(
            nn.Linear(num_heads, 1),
            nn.ReLU(),
        )

    def forward(self, graph, point_attr, edge_attr):
        """Apply attention, then return ``(node_features, edge_weights)``.

        Args:
            graph: DGL graph the convolution runs on.
            point_attr: node features with last dim ``in_vertex_dim``.
            edge_attr: per-edge scalar weights, passed to ``GATConv`` as its
                ``edge_weight`` argument (DGL expects shape ``(E,)``).

        Returns:
            Node features of shape ``(N, num_heads * out_vertex_dim)`` and a
            1-D tensor of ``E`` non-negative edge weights.
        """
        # Per DGL docs: node_out is (N, num_heads, out_vertex_dim) and
        # attention is (E, num_heads, 1).
        node_out, attention = self.gatconv(graph, point_attr, edge_attr, get_attention=True)

        node_out = self.LeakyReLU(self.node_norm_layers(node_out))

        # Drop the trailing singleton so heads sit on the last axis, then mix them.
        head_scores = self.edge_norm_layers(attention.flatten(1))
        edge_weights = self.edge_linear(head_scores)

        return node_out.flatten(1), edge_weights.squeeze(1)

class GATConvNetwork(nn.Module):
    """Four-layer GAT model producing per-node and per-edge predictions.

    Node features are lifted to ``hidden_dim`` by a small MLP encoder. Edge
    features are encoded into a single non-negative scalar per edge, which
    serves as the initial edge weight. Four ``GAT_layer`` blocks then refine
    both, and two MLP heads map them to ``out_vertex_dim`` and
    ``out_edge_dim`` outputs.

    Args:
        in_vertex_dim: size of the input node features.
        in_edge_dim: size of the input edge features.
        hidden_dim: per-head width of every ``GAT_layer``; the concatenated
            node width after each layer is ``num_heads * hidden_dim``.
        num_heads: attention heads per ``GAT_layer``.
        out_vertex_dim: size of the per-node output.
        out_edge_dim: size of the per-edge output.
    """

    def __init__(self, in_vertex_dim=4, in_edge_dim=6, hidden_dim=256, num_heads=2, out_vertex_dim=1, out_edge_dim=2):
        super().__init__()
        # Submodule creation order and attribute names are kept fixed so that
        # parameter initialisation and state_dict keys stay stable.
        self.in_vertex_layers = nn.ModuleList([
            nn.Linear(in_vertex_dim, 32),
            nn.Linear(32, 64),
            nn.Linear(64, hidden_dim)
        ])
        # Ends in a width-1 layer: one scalar weight per edge.
        self.in_edge_layers = nn.ModuleList([
            nn.Linear(in_edge_dim, 32),
            nn.Linear(32, 64),
            nn.Linear(64, 1)
        ])
        self.out_node_layers = nn.ModuleList([
            nn.Linear(num_heads * hidden_dim, 64),
            nn.Linear(64, 32),
            nn.Linear(32, out_vertex_dim)
        ])
        self.out_edge_layers = nn.ModuleList([
            nn.Linear(1, 64),
            nn.Linear(64, 32),
            nn.Linear(32, out_edge_dim)
        ])
        self.gatconv1 = GAT_layer(hidden_dim, hidden_dim, num_heads)
        self.gatconv2 = GAT_layer(num_heads * hidden_dim, hidden_dim, num_heads)
        self.gatconv3 = GAT_layer(num_heads * hidden_dim, hidden_dim, num_heads)
        self.gatconv4 = GAT_layer(num_heads * hidden_dim, hidden_dim, num_heads)

    def forward(self, data):
        """Run the network on a PyG-style ``data`` object.

        Args:
            data: object exposing ``x`` (node features, last dim
                ``in_vertex_dim``), ``edge_index`` of shape ``(2, E)`` (row 0
                source ids, row 1 destination ids) and ``edge_attr`` (edge
                features, presumably ``(E, in_edge_dim)``).

        Returns:
            ``(node_out, edge_out)`` of shapes ``(N, out_vertex_dim)`` and
            ``(E, out_edge_dim)``, in the node/edge order of ``data``.
        """
        point_attr, edge_attr = data.x, data.edge_attr
        src, dst = data.edge_index[0], data.edge_index[1]
        # Pass num_nodes explicitly: without it dgl.graph infers the node count
        # as max(node id) + 1, silently dropping trailing isolated nodes so the
        # graph no longer matches the rows of data.x.
        # NOTE(review): GATConv is built with DGL's default
        # allow_zero_in_degree=False, so it raises on any node without incoming
        # edges (isolated nodes included) -- confirm inputs guarantee this.
        graph = dgl.graph((src, dst), num_nodes=point_attr.shape[0])

        # Encoders apply ReLU after every layer, including the last.
        for layer in self.in_vertex_layers:
            point_attr = F.relu(layer(point_attr))
        for layer in self.in_edge_layers:
            edge_attr = F.relu(layer(edge_attr))
        # (E, 1) -> (E,): GATConv's edge_weight takes one scalar per edge.
        edge_attr = edge_attr.squeeze(1)

        for gat in (self.gatconv1, self.gatconv2, self.gatconv3, self.gatconv4):
            point_attr, edge_attr = gat(graph, point_attr, edge_attr)

        # NOTE(review): the output heads have no activation between their
        # Linear layers, so each head is equivalent to a single affine map --
        # confirm this is intended.
        for layer in self.out_node_layers:
            point_attr = layer(point_attr)

        # (E,) -> (E, 1) to match the first edge-head layer, Linear(1, 64).
        edge_attr = edge_attr.unsqueeze(1)
        for layer in self.out_edge_layers:
            edge_attr = layer(edge_attr)

        return point_attr, edge_attr