import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv


class GAT(nn.Module):
    """Stack of ``GATConv`` layers followed by a ``GATConv`` output layer.

    Args:
        num_features: Size of each input node feature vector.
        num_classes: Number of output channels of the final layer.
        convs: Unused by this class; accepted for signature compatibility.
        conv_args: Unused by this class; accepted for signature compatibility.
        **kwargs: Layer configuration. Required keys:

            * ``hid_features``: per-layer ``out_channels`` (per head) of the
              hidden ``GATConv`` layers; its length sets the number of
              hidden layers.
            * ``hid_concatenation``: per-layer ``concat`` flag.
            * ``hid_heads``: per-layer number of attention heads.
            * ``dropout``: per-layer value passed as ``GATConv(dropout=...)``.
            * ``out_heads``: number of heads of the output layer.
            * ``F_dropout``: probability used for ``F.dropout`` between
              hidden layers in :meth:`forward`.

            Optional keys:

            * ``out_dropout``: ``dropout`` argument of the output layer
              (defaults to ``0.2``).

    Raises:
        KeyError: If a required key is missing from ``kwargs``.
        IndexError: If a per-layer list is shorter than ``hid_features``.
    """

    def __init__(self, num_features, num_classes, convs=None, conv_args=None, **kwargs):
        super().__init__()

        hid_features = kwargs['hid_features']
        hid_concat = kwargs['hid_concatenation']
        hid_heads = kwargs['hid_heads']
        hid_dropout = kwargs['dropout']

        layers = []
        in_features = num_features
        for layer, out_features in enumerate(hid_features):
            concat = hid_concat[layer]
            heads = hid_heads[layer]
            layers.append(GATConv(in_features,
                                  out_features,
                                  concat=concat,
                                  heads=heads,
                                  dropout=hid_dropout[layer]))
            # The width seen by the next layer depends on whether this layer
            # concatenates its heads or averages them.
            in_features = self._conv_out_features(out_features, heads, concat)

        # The output layer averages its heads so it emits exactly
        # ``num_classes`` channels.
        layers.append(GATConv(in_features,
                              num_classes,
                              concat=False,
                              heads=kwargs['out_heads'],
                              dropout=kwargs.get('out_dropout', 0.2)))

        self.layers = nn.ModuleList(layers)
        # Kept from the original implementation: echoes the built architecture.
        print(self.layers)

        self.dropout = kwargs['F_dropout']

    @staticmethod
    def _conv_out_features(out_features, heads, concat):
        """Return the output width of a ``GATConv`` layer.

        With ``concat`` truthy the heads are concatenated
        (``out_features * heads``); otherwise they are averaged and the
        width stays ``out_features``.
        """
        return out_features * heads if concat else out_features

    def forward(self, data):
        """Run the network on a graph and return per-node log-probabilities.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``
                attributes, both forwarded to every ``GATConv`` layer.

        Returns:
            ``log_softmax`` over the last dimension of the output layer's
            result.
        """
        x, edge_index = data.x, data.edge_index

        # Hidden layers: convolution -> ReLU -> dropout (active only in
        # training mode).
        for conv_layer in self.layers[:-1]:
            x = conv_layer(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # Output layer: no activation or dropout before the log-softmax.
        x = self.layers[-1](x, edge_index)

        return F.log_softmax(x, dim=-1)

