from torch import nn
from torch.nn import functional as F

class BaseModel(nn.Module):
    """Encoder -> graph-layer stack -> decoder backbone.

    Subclasses provide the graph layers by overriding ``_in_layer``,
    ``_hidden_layers`` and ``_out_layer``. Each layer is called as
    ``layer(z, edge_index)``. If the subclass also defines
    ``get_edge_attr(x, edge_index)``, its result is passed as a third
    positional argument: ``layer(z, edge_index, edge_attr)``.

    Args:
        hparams: dict of hyper-parameters. Keys read here are
            ``nb_hidden_layers``, ``size_hidden_layers``, ``encoder``,
            ``decoder``, ``bn_bool`` and ``res_bool``.
            ``hparams['encoder'][-1]`` is used as the encoder output width
            and ``hparams['decoder'][0]`` as the decoder input width
            (presumably per-layer size lists -- confirm against callers).
        encoder: module applied to ``data.x`` before the graph layers.
        decoder: module applied to the output of ``out_layer``.
    """

    def __init__(self, hparams, encoder, decoder):
        super().__init__()

        self.nb_hidden_layers = hparams['nb_hidden_layers']
        self.size_hidden_layers = hparams['size_hidden_layers']
        self.enc_dim = hparams['encoder'][-1]
        self.dec_dim = hparams['decoder'][0]
        self.bn_bool = hparams['bn_bool']
        self.res_bool = hparams['res_bool']
        self.activation = F.gelu

        self.encoder = encoder
        self.decoder = decoder

        # Builders run after the attributes above are set, so subclass
        # implementations may read them. The registration order is kept
        # stable (it fixes state_dict key order and parameter-init order).
        self.in_layer = self._in_layer(hparams)
        self.hidden_layers = self._hidden_layers(hparams)
        self.out_layer = self._out_layer(hparams)
        self.bn = self._bn(hparams)

    def _in_layer(self, hparams):
        """Return the graph layer applied to the encoder output (abstract)."""
        raise NotImplementedError

    def _hidden_layers(self, hparams):
        """Return an indexable container of hidden graph layers (abstract).

        ``forward`` indexes entries ``0 .. nb_hidden_layers - 2``.
        """
        raise NotImplementedError

    def _out_layer(self, hparams):
        """Return the graph layer whose output feeds the decoder (abstract)."""
        raise NotImplementedError

    def _bn(self, hparams):
        """Build one BatchNorm1d per normalised layer, or None if disabled.

        Entry 0 follows ``in_layer``; entry ``n + 1`` follows hidden layer
        ``n``. ``track_running_stats=False`` makes BatchNorm use the current
        batch's statistics in both train and eval mode.
        """
        if not self.bn_bool:
            return None
        return nn.ModuleList(
            nn.BatchNorm1d(self.size_hidden_layers, track_running_stats=False)
            for _ in range(self.nb_hidden_layers)
        )

    def forward(self, data):
        """Run encoder, graph layers and decoder on ``data``.

        Args:
            data: object exposing ``.x`` (node features) and ``.edge_index``
                (presumably a torch_geometric ``Data`` -- confirm).

        Returns:
            The decoder output.
        """
        z, edge_index = data.x, data.edge_index

        # Optional edge features: computed once from the raw node features
        # (before the encoder) and appended to every graph-layer call.
        if hasattr(self, 'get_edge_attr'):
            layer_extra = (self.get_edge_attr(z, edge_index),)
        else:
            layer_extra = ()

        z = self.encoder(z)

        # Global skip from encoder output to decoder input, only possible
        # when the two widths agree.
        global_skip = self.enc_dim == self.dec_dim
        if global_skip:
            z_in = z

        z = self.in_layer(z, edge_index, *layer_extra)
        if self.bn_bool:
            z = self.bn[0](z)
        z = self.activation(z)

        for n in range(self.nb_hidden_layers - 1):
            z_res = z
            z = self.hidden_layers[n](z, edge_index, *layer_extra)
            if self.bn_bool:
                # bn[0] belongs to in_layer, hence the +1 offset.
                z = self.bn[n + 1](z)
            z = self.activation(z)
            if self.res_bool:
                z = z + z_res

        z = self.out_layer(z, edge_index, *layer_extra)

        if global_skip:
            z = z + z_in

        return self.decoder(z)