import torch.nn as nn
import torch.nn.functional as F
from .layers import layers


# Activation functions selectable by name via ``config['act']``; Net.forward
# applies the chosen one after every layer except the last.
ACT = {
    "relu": F.relu,
    "sigmoid": F.sigmoid,
    "elu": F.elu,
    "leaky_relu": F.leaky_relu,
}


class Net(nn.Module):
    """Stack of graph layers built from a config, ending in a log-softmax.

    Every layer class is looked up by name in the ``layers`` registry,
    constructed as ``module(in_features, out_features, **kwargs)`` and called
    as ``layer(x, edge_index)``. The last layer produces ``num_classes``
    outputs; every other layer is followed by the configured activation and
    dropout.

    Args:
        config: Mapping with keys
            ``'layers'`` -- non-empty list of per-layer dicts with keys
                ``'name'`` (key into ``layers``), ``'hid_features'`` (output
                width passed to the constructor; not read for the last
                layer), optional ``'out_features'`` (input width of the
                following layer; defaults to ``'hid_features'``) and optional
                ``'kwargs'`` (extra constructor keyword arguments);
            ``'act'`` -- key into ``ACT`` naming the hidden activation;
            ``'dropout'`` -- dropout probability applied after hidden layers.
        num_classes: Output width of the final layer.
        num_features: Input width of the first layer.

    Raises:
        ValueError: If ``config['layers']`` is empty or ``config['act']`` is
            not a key of ``ACT``.
    """

    def __init__(
            self,
            config,
            num_classes,
            num_features,
    ):
        super().__init__()

        convs = config['layers']
        if not convs:
            # With no layers, forward would return log-probabilities over the
            # raw input features instead of over the classes.
            raise ValueError("config['layers'] must contain at least one layer")
        if config['act'] not in ACT:
            # Fail at construction instead of on the first forward pass.
            raise ValueError(
                f"unknown activation {config['act']!r}; "
                f"expected one of {sorted(ACT)}"
            )

        self.layers = nn.ModuleList()
        in_features = num_features
        last = len(convs) - 1
        for i, conv in enumerate(convs):
            module = layers[conv['name']]
            # A missing or null 'kwargs' entry means "no extra arguments".
            kwargs = conv.get('kwargs') or {}
            if i != last:
                layer = module(in_features, conv['hid_features'], **kwargs)
                # 'out_features' lets the config state the layer's real output
                # width when it differs from 'hid_features' (presumably e.g.
                # layers that concatenate several heads -- confirm against the
                # layer registry); otherwise the requested width is used.
                in_features = conv.get('out_features', conv['hid_features'])
            else:
                layer = module(in_features, num_classes, **kwargs)
            self.layers.append(layer)

        print(self.layers)

        self.act = config['act']
        self.dropout = config['dropout']

    def forward(self, data):
        """Run the layer stack and return log-probabilities.

        Args:
            data: Object exposing ``x`` and ``edge_index`` attributes; both
                are passed to every layer (presumably a
                ``torch_geometric.data.Data`` -- confirm against callers).

        Returns:
            ``F.log_softmax`` of the final layer's output over the last dim.
        """
        x, edge_index = data.x, data.edge_index

        act = ACT[self.act]
        last = len(self.layers) - 1
        for i, conv in enumerate(self.layers):
            x = conv(x, edge_index)
            if i != last:
                # Hidden layers only; the output layer feeds log_softmax
                # directly.
                x = act(x)
                x = F.dropout(x, p=self.dropout, training=self.training)

        return F.log_softmax(x, dim=-1)


