from torch import nn
from torch.nn import functional as F

from cfd.models.base import BaseModel
from cfd.models.utils import MLP

class NN(BaseModel):
    """Plain feed-forward model assembled from ``MLP`` blocks.

    Maps an encoder output of width ``hparams['encoder'][-1]`` to a decoder
    input of width ``hparams['decoder'][0]``. It does this through an input
    layer, ``nb_hidden_layers - 1`` square hidden layers of width
    ``size_hidden_layers``, and an output layer.

    NOTE(review): ``size_hidden_layers`` and ``nb_hidden_layers`` are not
    assigned in this class. Presumably ``BaseModel.__init__`` sets them from
    ``hparams`` before calling the ``_in_layer`` / ``_hidden_layers`` /
    ``_out_layer`` hooks below. Confirm against ``BaseModel``.
    """

    def __init__(self, hparams, encoder, decoder):
        # The boundary widths must exist *before* the base constructor runs.
        # The layer-builder hooks below read them, and (presumably)
        # BaseModel.__init__ is what invokes those hooks.
        self.enc_dim = hparams['encoder'][-1]
        self.dec_dim = hparams['decoder'][0]
        super().__init__(hparams, encoder, decoder)

    def _in_layer(self, hparams):
        """Return the MLP projecting encoder width -> hidden width.

        ``hparams`` is not used here. It is kept, presumably to match the
        hook signature expected by ``BaseModel``.
        """
        return MLP(
            [self.enc_dim, self.size_hidden_layers],
        )

    def _hidden_layers(self, hparams):
        """Return ``nb_hidden_layers - 1`` hidden -> hidden MLPs.

        The layers are wrapped in an ``nn.ModuleList`` so their parameters
        are registered with the model. The list is empty when
        ``nb_hidden_layers <= 1``. ``hparams`` is unused.
        """
        return nn.ModuleList(
            MLP([self.size_hidden_layers, self.size_hidden_layers])
            for _ in range(self.nb_hidden_layers - 1)
        )

    def _out_layer(self, hparams):
        """Return the MLP projecting hidden width -> decoder width.

        ``hparams`` is unused.
        """
        # A stray, unreachable ``return z`` (``z`` was undefined) used to
        # follow this statement; it has been removed.
        return MLP(
            [self.size_hidden_layers, self.dec_dim],
        )