from torch import nn
from torch_geometric import nn as nng

from cfd.models.base import BaseModel

class GraphSAGE(BaseModel):
    """GraphSAGE variant of ``BaseModel`` built from ``nng.SAGEConv`` layers.

    This class only supplies three layer factories (``_in_layer``,
    ``_hidden_layers``, ``_out_layer``); everything else is inherited from
    ``BaseModel``, which presumably calls these factories to assemble the
    network -- confirm in ``cfd.models.base``.

    The dimensions read below (``enc_dim``, ``dec_dim``,
    ``size_hidden_layers``, ``nb_hidden_layers``) are not set in this class;
    they are presumably populated by ``BaseModel.__init__`` -- verify there.
    """

    def __init__(self, hparams, encoder, decoder):
        # Pure delegation: GraphSAGE adds no state of its own.
        super().__init__(hparams, encoder, decoder)

    def _in_layer(self, hparams):
        """Return the input convolution mapping ``enc_dim`` -> hidden width.

        ``hparams`` is part of the factory signature but is not used here.
        """
        in_dim = self.enc_dim
        out_dim = self.size_hidden_layers
        return nng.SAGEConv(in_channels=in_dim, out_channels=out_dim)

    def _hidden_layers(self, hparams):
        """Return ``nb_hidden_layers - 1`` square hidden ``SAGEConv`` layers.

        Every layer maps ``size_hidden_layers`` -> ``size_hidden_layers``.
        The result is an ``nn.ModuleList``, empty when
        ``nb_hidden_layers <= 1``. ``hparams`` is not used here.
        """
        count = self.nb_hidden_layers - 1
        return nn.ModuleList([
            nng.SAGEConv(
                in_channels=self.size_hidden_layers,
                out_channels=self.size_hidden_layers,
            )
            for _ in range(count)
        ])

    def _out_layer(self, hparams):
        """Return the output convolution mapping hidden width -> ``dec_dim``.

        ``hparams`` is part of the factory signature but is not used here.
        """
        return nng.SAGEConv(
            in_channels=self.size_hidden_layers,
            out_channels=self.dec_dim,
        )
