import dgl
import dgl.nn as dglnn

import torch
import torch.nn as nn
import torch.nn.functional as F


class GAT(nn.Module):
    """Multi-layer graph attention network built from ``dglnn.GATConv``.

    The stack consists of one input layer (``in_size -> hid_size``),
    ``n_transformer - 2`` hidden layers (``hid_size * heads -> hid_size``)
    and one output layer (``hid_size * heads -> out_size``).  Every layer
    uses ``heads`` attention heads with feature/attention dropout of 0.1;
    all layers except the output layer apply ELU.

    Args:
        in_size: Number of input features per node.
        hid_size: Per-head hidden feature size.
        out_size: Per-head output feature size (heads are averaged at the
            end, so this is also the final feature size).
        heads: Number of attention heads used by every layer.
        n_transformer: Requested number of GAT layers.  The input and output
            layers are always created, so values below 2 still yield a
            two-layer network.
    """

    def __init__(self, in_size, hid_size, out_size, heads=4, n_transformer=5):
        super().__init__()
        self.n_transformer = n_transformer
        self.gat_layers = nn.ModuleList()
        # Input layer: raw node features -> `heads` hidden representations.
        self.gat_layers.append(
            dglnn.GATConv(
                in_size,
                hid_size,
                heads,
                feat_drop=0.1,
                attn_drop=0.1,
                activation=F.elu,
            )
        )
        # Hidden layers consume the concatenated heads of the previous layer.
        for _ in range(n_transformer - 2):
            self.gat_layers.append(
                dglnn.GATConv(
                    hid_size * heads,
                    hid_size,
                    heads,
                    feat_drop=0.1,
                    attn_drop=0.1,
                    activation=F.elu,
                )
            )
        # Output layer: no activation; its heads are averaged in `forward`.
        self.gat_layers.append(
            dglnn.GATConv(
                hid_size * heads,
                out_size,
                heads,
                feat_drop=0.1,
                attn_drop=0.1,
                activation=None,
            )
        )

    def forward(self, g, inputs):
        """Run the layer stack on graph ``g`` with node features ``inputs``.

        Args:
            g: Graph passed to every ``GATConv`` layer.  If it carries a
                ``'weight'`` edge feature, that tensor is forwarded as
                ``edge_weight``; otherwise the layers run unweighted.
            inputs: Node feature tensor fed to the first layer.

        Returns:
            The output of the last layer averaged over its attention heads
            (dimension 1).
        """
        # Loop-invariant: look the edge weights up once, and tolerate graphs
        # that do not define them instead of raising KeyError.
        edge_weight = g.edata['weight'] if 'weight' in g.edata else None
        # Identify the output layer by its position in the stack rather than
        # by `n_transformer`, which disagrees with the real depth when it is
        # smaller than 2.
        last_idx = len(self.gat_layers) - 1

        h = inputs
        for i, layer in enumerate(self.gat_layers):
            h = layer(g, h, edge_weight=edge_weight)
            if i == last_idx:
                # Output layer: average the attention heads.
                h = h.mean(1)
            else:
                # Earlier layers: concatenate the heads into one feature axis.
                h = h.flatten(1)
        return h


if __name__ == '__main__':
    # Smoke-test demo: build a small weighted ring graph and run one forward
    # pass through the model.
    in_features = 5   # Number of input node features
    hidden_features = 64
    out_features = 2   # Output classes

    model = GAT(in_features, hidden_features, out_features, heads=8)
    # This is an inference demo: switch to eval mode so the feature/attention
    # dropout configured in the layers does not randomly perturb the output.
    model.eval()

    # Define edges (source and destination nodes): a directed 5-node ring.
    edges_src = torch.tensor([0, 1, 2, 3, 4])
    edges_dst = torch.tensor([1, 2, 3, 4, 0])

    # Create DGL Graph
    g = dgl.graph((edges_src, edges_dst))
    g.edata['weight'] = torch.rand(5)
    # Self-loops are added after the weights are assigned, so the 'weight'
    # feature is extended to the new edges by DGL.
    # NOTE(review): the fill value for the new edges depends on the DGL
    # version / `fill_data` argument - confirm it is what you expect.
    g = dgl.add_self_loop(g)  # Adding self-loops improves stability

    print("Graph:", g)

    num_nodes = g.num_nodes()
    node_features = torch.randn(num_nodes, in_features)  # Random input features
    print("Node features shape:", node_features.shape)

    # No gradients are needed for a demo forward pass.
    with torch.no_grad():
        output = model(g, node_features)
    print("Output shape:", output.shape)
    print("Output:", output)
