import logging

import torch.nn as nn
import torch.nn.functional as F


# Lookup table of activation functions, keyed by the names accepted for
# EFG's ``act`` argument.
ACT = dict(
    relu=F.relu,
    sigmoid=F.sigmoid,
)


class EFG(nn.Module):
    """Sequential stack of layers with activation and dropout between them.

    Each entry of ``layers`` is a layer constructor invoked as
    ``conv(in_channels, out_channels, function, fun_kwargs, **kwargs)``; the
    resulting modules are called as ``layer(x, edge_index)`` in ``forward``.

    Args:
        in_channels: Input feature size of the first layer.
        out_channels: Output feature size of the last layer.
        layers: Sequence of layer constructors, one per layer.
        layer_config: Per-layer configuration, indexed in parallel with
            ``layers``. Each entry must provide ``'function'`` and
            ``'fun_kwargs'`` (passed positionally to the constructor) and may
            provide ``'kwargs'`` (extra keyword arguments; missing or ``None``
            means none). Every entry except the last must also provide
            ``'out_channels'``, which becomes the next layer's input size; the
            last layer is sized by ``out_channels`` instead.
        act: Name of an activation in ``ACT`` (e.g. ``"relu"``) or a callable
            applied to the output of every layer except the last.
        F_dropout: Dropout probability applied after each non-final
            activation while the module is in training mode.
    """

    def __init__(self, in_channels, out_channels, layers, layer_config, act="relu", F_dropout=0.2):
        super().__init__()

        self.layers = nn.ModuleList()
        last = len(layers) - 1
        for i, conv in enumerate(layers):
            cfg = layer_config[i]
            # Hidden layers take their width from their own config; only the
            # final layer is sized by the model-level ``out_channels``.
            layer_out = out_channels if i == last else cfg['out_channels']
            self.layers.append(conv(in_channels,
                                    layer_out,
                                    cfg['function'],
                                    cfg['fun_kwargs'],
                                    **(cfg.get('kwargs') or {})))
            in_channels = layer_out

        # Debug-level logging instead of an unconditional print on every
        # construction.
        logging.getLogger(__name__).debug("EFG layers: %s", self.layers)

        # Accept either a key into ACT or any callable activation.
        self.act = act if callable(act) else ACT[act]
        self.dropout = F_dropout

    def forward(self, data):
        """Run ``data`` through every layer and return log-probabilities.

        Args:
            data: Object exposing ``x`` and ``edge_index`` attributes
                (presumably a torch_geometric ``Data``; verify against
                callers).

        Returns:
            ``F.log_softmax`` of the final layer's output over the last dim.
        """
        x, edge_index = data.x, data.edge_index

        last = len(self.layers) - 1
        for i, conv in enumerate(self.layers):
            x = conv(x, edge_index)
            # No activation/dropout after the output layer; log_softmax follows.
            if i != last:
                x = self.act(x)
                x = F.dropout(x, p=self.dropout, training=self.training)

        return F.log_softmax(x, dim=-1)


