import torch.nn as nn
import torch.nn.functional as F


class DeepMLP(nn.Module):
    """MLP encoder mapping ``num_eigenvecs``-dim inputs to ``emb_dim``-dim outputs.

    Architecture::

        Linear(num_eigenvecs, hidden_dim) -> ReLU -> Dropout
        [Linear(hidden_dim, hidden_dim) -> ReLU -> Dropout] * (n_hidden_layers - 1)
        Linear(hidden_dim, emb_dim)

    Inputs are L2-normalized along their last dimension before entering the
    network (see ``forward``).

    Config fields read:
        cfg.model.num_eigenvecs: input feature size.
        cfg.model.ffn_dropout: dropout probability after every hidden ReLU.
        cfg.model.pos_encoder.emb_dim: output feature size (stored as
            ``self.dim_pe``).
        cfg.model.pos_encoder.hidden_dim: width of every hidden layer.
        cfg.model.pos_encoder.n_hidden_layers: number of hidden
            ``Linear -> ReLU -> Dropout`` stages (stored as ``self.layers``);
            must be >= 1.

    Raises:
        ValueError: if ``n_hidden_layers`` is less than 1.
    """

    def __init__(self, cfg):
        super().__init__()
        self.dim_pe = cfg.model.pos_encoder.emb_dim
        dim_in = cfg.model.num_eigenvecs
        hidden_dim = cfg.model.pos_encoder.hidden_dim
        dropout = cfg.model.ffn_dropout
        # Count of hidden stages, not of nn.Linear modules (there is one extra
        # output projection).
        self.layers = cfg.model.pos_encoder.n_hidden_layers
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, which would silently accept 0 or negative values.
        if self.layers < 1:
            raise ValueError(
                f"pos_encoder.n_hidden_layers must be >= 1, got {self.layers}"
            )

        # Build in a local list; only the final nn.Sequential is registered as
        # a submodule. Creation order (and thus Sequential indices / state_dict
        # keys / init RNG consumption) matches the input -> hidden -> output
        # layout described above.
        modules = []
        in_features = dim_in
        for _ in range(self.layers):
            modules.extend(
                [
                    nn.Linear(in_features, hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(p=dropout),
                ]
            )
            in_features = hidden_dim

        # hidden -> output
        modules.append(nn.Linear(hidden_dim, self.dim_pe))

        self.encoder = nn.Sequential(*modules)

    def forward(self, x, batch):
        """Encode ``x`` after L2-normalizing it along its last dimension.

        Args:
            x: tensor whose last dimension has size ``num_eigenvecs``.
            batch: not used by this module; presumably accepted so the call
                signature matches sibling encoders — TODO confirm.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``emb_dim``.
        """
        return self.encoder(F.normalize(x, dim=-1))
