import logging

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv


class ClusterGCN(nn.Module):
    """Stack of ``SAGEConv`` layers producing per-node log-probabilities.

    The network consists of ``kwargs['layers']`` hidden ``SAGEConv`` layers
    followed by one output ``SAGEConv`` that maps to ``num_classes``. Every
    layer except the output one is followed by ReLU and dropout.

    NOTE(review): nothing in this module is Cluster-GCN specific; the graph
    partitioning presumably happens in the caller's data loading — confirm.

    Args:
        num_features: Input feature width of ``data.x``.
        num_classes: Output width of the final layer.
        convs: Accepted but ignored. Presumably kept so that all model classes
            share one constructor signature — verify against callers.
        conv_args: Accepted but ignored (see ``convs``).
        **kwargs: Model configuration. Keys read:
            ``layers`` (int) -- number of hidden ``SAGEConv`` layers;
            ``hid_features`` (indexable of int) -- output width of each hidden
            layer; needs at least ``layers`` entries;
            ``F_dropout`` (float) -- dropout probability used after each
            hidden layer.

    Raises:
        KeyError: If ``layers`` or ``F_dropout`` is missing, or if
            ``hid_features`` is missing while ``layers > 0``.
        IndexError: If ``hid_features`` has fewer than ``layers`` entries.
    """

    def __init__(self, num_features, num_classes, convs=None, conv_args=None, **kwargs):
        super().__init__()

        layers = []
        in_features = num_features
        for layer in range(kwargs['layers']):
            out_features = kwargs['hid_features'][layer]
            layers.append(SAGEConv(in_features, out_features))
            # The next layer consumes this layer's output width.
            in_features = out_features

        # Final projection to class scores; no activation/dropout follows it.
        layers.append(SAGEConv(in_features, num_classes))

        self.layers = nn.ModuleList(layers)
        # Report the architecture at DEBUG level rather than printing to
        # stdout unconditionally on every instantiation.
        logging.getLogger(__name__).debug("ClusterGCN layers: %s", self.layers)

        self.dropout = kwargs['F_dropout']

    def forward(self, data):
        """Run every layer on ``data`` and return log-probabilities.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``
                (graph connectivity). Presumably a PyG ``Data``/batch object
                with ``x`` of shape (num_nodes, num_features) — confirm.

        Returns:
            ``log_softmax`` of the output layer's result over the last
            dimension.
        """
        x, edge_index = data.x, data.edge_index

        last = len(self.layers) - 1
        for i, conv in enumerate(self.layers):
            x = conv(x, edge_index)
            # Nonlinearity and dropout only between layers, never on the
            # final output; dropout is active only in training mode.
            if i != last:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)

        return F.log_softmax(x, dim=-1)

