import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv


class GraphSAGE(nn.Module):
    """Stack of ``SAGEConv`` layers producing per-node log-probabilities.

    The network consists of ``layers`` hidden ``SAGEConv`` layers (widths taken
    from ``hid_features``), each followed by a sigmoid activation and dropout,
    plus a final ``SAGEConv`` projecting to ``out_channels``. ``forward``
    returns ``log_softmax`` of that final projection over the last dimension.

    Args:
        in_channels: Input feature size fed to the first ``SAGEConv``.
        out_channels: Output size of the final ``SAGEConv``.
        convs: Accepted but ignored. NOTE(review): presumably kept so several
            model classes share one constructor signature — confirm with callers.
        conv_args: Accepted but ignored (see ``convs``).
        **kwargs: Required keys:
            ``layers`` (int): number of hidden ``SAGEConv`` layers.
            ``hid_features`` (sequence of int): hidden widths; only the first
                ``layers`` entries are used. Only read when ``layers > 0``.
            ``F_dropout`` (float): dropout probability applied after each
                hidden activation (only while ``self.training`` is true).

    Raises:
        KeyError: if a required key is missing from ``kwargs``.
        ValueError: if ``hid_features`` has fewer than ``layers`` entries.
    """

    def __init__(self, in_channels, out_channels, convs=None, conv_args=None, **kwargs):
        super().__init__()

        num_layers = kwargs['layers']
        # ``hid_features`` is only needed when hidden layers are requested.
        hidden_sizes = kwargs['hid_features'] if num_layers > 0 else ()
        # Fail with a clear message up front instead of an IndexError halfway
        # through building the layer stack.
        if len(hidden_sizes) < num_layers:
            raise ValueError(
                f"hid_features has {len(hidden_sizes)} entries but layers={num_layers}"
            )

        layers = []
        in_features = in_channels
        for layer in range(num_layers):
            out_features = hidden_sizes[layer]
            layers.append(SAGEConv(in_features, out_features))
            in_features = out_features

        # Final projection; forward() applies log_softmax to its output
        # rather than the hidden-layer sigmoid/dropout.
        layers.append(SAGEConv(in_features, out_channels))

        self.layers = nn.ModuleList(layers)

        self.dropout = kwargs['F_dropout']

    def forward(self, data):
        """Run the network on one graph (or batched graph).

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``
                attributes, both passed unchanged to every ``SAGEConv``.
                NOTE(review): presumably a torch_geometric ``Data``/``Batch``.

        Returns:
            ``log_softmax`` of the final layer's output over the last dimension.
        """
        x, edge_index = data.x, data.edge_index

        # Hidden layers: conv -> sigmoid -> dropout (dropout is a no-op in eval mode).
        for conv_layer in self.layers[:-1]:
            x = conv_layer(x, edge_index)
            x = F.sigmoid(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.layers[-1](x, edge_index)

        return F.log_softmax(x, dim=-1)

