import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.nn import EGATConv
from transformer import SimpleTransformer

class EGAT_layer(nn.Module):
    """EGATConv followed by LayerNorm, LeakyReLU and flattening of the head axis.

    Args:
        in_vertex_dim: Size of each incoming node feature vector.
        in_edge_dim: Size of each incoming edge feature vector.
        out_vertex_dim: Per-head node feature size produced by the convolution.
        out_edge_dim: Per-head edge feature size produced by the convolution.
        num_heads: Number of attention heads used by the convolution.
        hidden_dim: Unused; accepted only so existing call sites keep working.
            The normalisation layers are sized from ``out_vertex_dim`` and
            ``out_edge_dim``, which is what they actually normalise over.
    """

    def __init__(self, in_vertex_dim, in_edge_dim, out_vertex_dim, out_edge_dim, num_heads, hidden_dim):
        super().__init__()
        self.LeakyReLU = nn.LeakyReLU()
        self.gatconv = EGATConv(
            in_node_feats=in_vertex_dim,
            in_edge_feats=in_edge_dim,
            out_node_feats=out_vertex_dim,
            out_edge_feats=out_edge_dim,
            num_heads=num_heads,
            bias=True,
        )
        # LayerNorm works on the trailing axis of the convolution output, i.e.
        # the per-head feature size. Sizing it with `hidden_dim` only worked
        # when hidden_dim == out_vertex_dim == out_edge_dim.
        self.node_norm_layers = nn.LayerNorm(out_vertex_dim)
        self.edge_norm_layers = nn.LayerNorm(out_edge_dim)
        # NOTE(review): the two projections below are never applied in
        # `forward`. They are kept so checkpoint keys and the seeded
        # initialisation order stay the same; confirm whether they should be
        # wired in or dropped.
        self.vertex_linear = nn.Sequential(
            nn.Linear(num_heads * out_vertex_dim, out_vertex_dim),
            nn.ReLU(),
        )
        self.edge_linear = nn.Sequential(
            nn.Linear(num_heads * out_edge_dim, out_edge_dim),
            nn.ReLU(),
        )

    def forward(self, graph, point_attr, edge_attr):
        """Run one attention step and return flattened node and edge features.

        Args:
            graph: Graph object handed to ``EGATConv``.
            point_attr: Node features, one row per node.
            edge_attr: Edge features, one row per edge.

        Returns:
            A ``(node_features, edge_features)`` tuple in which every dimension
            after the first has been flattened into one (presumably
            ``num_heads * out_dim`` per row — see the EGATConv docs).
        """
        point_attr, edge_attr = self.gatconv(graph, point_attr, edge_attr)
        # Normalise and activate while the head axis is still separate, so each
        # head is normalised on its own, then merge heads into one feature axis.
        point_attr = self.LeakyReLU(self.node_norm_layers(point_attr))
        edge_attr = self.LeakyReLU(self.edge_norm_layers(edge_attr))
        return point_attr.flatten(1), edge_attr.flatten(1)


class TransformerModel(nn.Module):
    """Batch-first Transformer encoder stack with dropout on its input and output.

    Args:
        d_model: Feature size of every token in the sequence.
        nhead: Number of attention heads per encoder layer.
        dim_feedforward: Width of the feed-forward sub-layer.
        n_layers: Number of stacked encoder layers.
        dropout: Dropout probability, used inside the encoder layers and by the
            dropout module applied before and after the stack.

    NOTE(review): ``nn.TransformerEncoder`` builds its own copies of the layer
    it is given, so the parameters under ``self.encoder_layer`` show up in the
    state dict but presumably do not take part in ``forward``. Confirm before
    removing the attribute, since that would change checkpoint keys.
    """

    def __init__(self, d_model, nhead, dim_feedforward, n_layers, dropout):
        super().__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=n_layers)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, src):
        """Encode ``src`` and return a tensor of the same shape.

        The same dropout module is applied to the raw input and again to the
        encoder output.
        """
        dropped_in = self.dropout(src)
        encoded = self.encoder(dropped_in)
        return self.dropout(encoded)


class EGATTransformerNetwork(nn.Module):
    """Graph network: MLP encoders -> four EGAT layers -> transformers -> linear heads.

    Node and edge features go through separate three-layer ReLU encoders, four
    stacked ``EGAT_layer`` blocks that share one graph, a ``SimpleTransformer``
    each (all nodes, or all edges, treated as a single sequence), and finally a
    stack of linear layers down to the requested output sizes.

    Args:
        in_vertex_dim: Size of the raw node feature vectors.
        in_edge_dim: Size of the raw edge feature vectors.
        num_heads: Attention heads in every EGAT layer.
        hidden_dim: Per-head feature width used throughout the network.
        out_vertex_dim: Size of the per-node output.
        out_edge_dim: Size of the per-edge output.
    """

    def __init__(self, in_vertex_dim=4, in_edge_dim=6, num_heads=2, hidden_dim=256, out_vertex_dim=1, out_edge_dim=2):
        super().__init__()
        # Width of the features after an EGAT layer has flattened its heads.
        fused_dim = hidden_dim * num_heads

        self.in_vertex_layers = nn.ModuleList([
            nn.Linear(in_vertex_dim, 32),
            nn.Linear(32, 64),
            nn.Linear(64, hidden_dim),
        ])
        self.in_edge_layers = nn.ModuleList([
            nn.Linear(in_edge_dim, 32),
            nn.Linear(32, 64),
            nn.Linear(64, hidden_dim),
        ])

        self.node_transformer = SimpleTransformer(
            dim=fused_dim,
            n_layers=2,
            n_heads=8,
            head_dim=32,
            hidden_dim=256,
        )
        self.edge_transformer = SimpleTransformer(
            dim=fused_dim,
            n_layers=2,
            n_heads=8,
            head_dim=32,
            hidden_dim=256,
        )

        # The first layer consumes the encoder output (hidden_dim); every later
        # layer consumes the head-flattened output of the previous one.
        self.gatconv1 = EGAT_layer(hidden_dim, hidden_dim, hidden_dim, hidden_dim, num_heads, hidden_dim)
        self.gatconv2 = EGAT_layer(fused_dim, fused_dim, hidden_dim, hidden_dim, num_heads, hidden_dim)
        self.gatconv3 = EGAT_layer(fused_dim, fused_dim, hidden_dim, hidden_dim, num_heads, hidden_dim)
        self.gatconv4 = EGAT_layer(fused_dim, fused_dim, hidden_dim, hidden_dim, num_heads, hidden_dim)

        self.node_out_layers = nn.ModuleList([
            nn.Linear(fused_dim, hidden_dim),
            nn.Linear(hidden_dim, 64),
            nn.Linear(64, 32),
            nn.Linear(32, out_vertex_dim),
        ])
        self.edge_out_layers = nn.ModuleList([
            nn.Linear(fused_dim, hidden_dim),
            nn.Linear(hidden_dim, 64),
            nn.Linear(64, 32),
            nn.Linear(32, out_edge_dim),
        ])

    def forward(self, data):
        """Predict per-node and per-edge outputs for one graph.

        Args:
            data: Object exposing ``x`` (node features, one row per node),
                ``edge_index`` (row 0 holds source node ids, row 1 destination
                node ids) and ``edge_attr`` (edge features, one row per edge).
                Presumably a PyTorch-Geometric-style ``Data`` — confirm against
                the caller.

        Returns:
            ``(node_out, edge_out)`` whose last dimensions are
            ``out_vertex_dim`` and ``out_edge_dim`` respectively.
        """
        point_attr, edge_attr = data.x, data.edge_attr
        src, dst = data.edge_index[0], data.edge_index[1]

        # Pass the node count explicitly. Otherwise DGL infers it from the
        # largest index that appears in an edge, and a graph whose
        # highest-numbered nodes are isolated ends up with fewer nodes than
        # `data.x` has rows.
        graph = dgl.graph((src, dst), num_nodes=point_attr.shape[0])
        # Raw inputs stored on the graph; not read back anywhere in this method.
        graph.ndata['feat'] = point_attr
        graph.edata['feat'] = edge_attr

        for layer in self.in_vertex_layers:
            point_attr = F.relu(layer(point_attr))

        for layer in self.in_edge_layers:
            edge_attr = F.relu(layer(edge_attr))

        for egat in (self.gatconv1, self.gatconv2, self.gatconv3, self.gatconv4):
            point_attr, edge_attr = egat(graph, point_attr, edge_attr)

        # The transformers receive a batch dimension of size 1, so all nodes
        # (or all edges) form a single sequence.
        point_attr = self.node_transformer(point_attr.unsqueeze(0)).squeeze(0)
        edge_attr = self.edge_transformer(edge_attr.unsqueeze(0)).squeeze(0)

        # NOTE(review): unlike the input encoders, the output heads apply no
        # activation between their linear layers. Left as-is so trained
        # weights keep their meaning; confirm this is intended.
        for layer in self.node_out_layers:
            point_attr = layer(point_attr)

        for layer in self.edge_out_layers:
            edge_attr = layer(edge_attr)

        return point_attr, edge_attr