import torch.nn as nn
import torch.nn.functional as F


class Linear(nn.Module):
    """Project row-normalized eigenvector features to an embedding with one affine layer.

    The input is L2-normalized along its last axis and then passed through an
    ``nn.Linear(cfg.model.num_eigenvecs, cfg.model.pos_encoder.emb_dim)``.
    """

    def __init__(self, cfg):
        """Build the projection from the model section of ``cfg``.

        Args:
            cfg: config object exposing ``cfg.model.num_eigenvecs`` (input width)
                and ``cfg.model.pos_encoder.emb_dim`` (output width).
        """
        super().__init__()
        model_cfg = cfg.model
        # Output width of the projection.
        self.dim_pe = model_cfg.pos_encoder.emb_dim
        # Input width; must match the size of the last axis of ``eigvecs``.
        self.in_dim = model_cfg.num_eigenvecs
        # Attribute name ``net`` is kept so parameter keys are ``net.weight`` / ``net.bias``.
        self.net = nn.Linear(self.in_dim, self.dim_pe)

    def forward(self, eigvecs, batch):
        """Normalize ``eigvecs`` along the last axis, then apply the projection.

        Args:
            eigvecs: tensor whose last axis has size ``self.in_dim``.
            batch: accepted for interface compatibility; not used here.

        Returns:
            Tensor with the same leading shape as ``eigvecs`` and last axis
            of size ``self.dim_pe``.
        """
        # Unit L2 norm per slice along the last axis (all-zero slices stay zero).
        unit_rows = F.normalize(eigvecs, p=2.0, dim=-1)
        projected = self.net(unit_rows)
        return projected
