"""
SignNet https://arxiv.org/abs/2202.13013
based on https://github.com/cptq/SignNet-BasisNet
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_node_encoder
from torch_geometric.nn import GINConv
from torch_scatter import scatter
from custom_modules.network.attention import ElementwiseVectorMul


@register_node_encoder("RotPE")
class RotatePENodeEncoder(torch.nn.Module):
    """Node encoder registered under the name ``"RotPE"``.

    Holds a single ``ElementwiseVectorMul`` module, constructed from
    ``cfg.posenc_RotPE.eigen.max_freqs`` with ``1`` as its second positional
    argument and ``bias=False``. The module is applied, followed by
    ``squeeze(1)``, either to a raw tensor or to the ``eigvecs_sn`` and
    ``pos_latents`` attributes of a batch.
    """

    def __init__(self, dim_pos_emb=None):
        """Build the wrapped encoder from the global ``cfg.posenc_RotPE`` section.

        Args:
            dim_pos_emb: Not used by this encoder. Presumably it is kept so the
                constructor signature matches other registered node encoders;
                confirm against the caller.
        """
        super().__init__()
        max_freqs = cfg.posenc_RotPE.eigen.max_freqs
        self.encoder = ElementwiseVectorMul(max_freqs, 1, bias=False)

    def forward(self, batch, return_pos_only=False):
        """Run the wrapped encoder on positional features.

        Args:
            batch: When ``return_pos_only`` is true, the tensor to encode.
                Otherwise, an object whose ``eigvecs_sn`` and ``pos_latents``
                attributes are encoded and overwritten in place.
            return_pos_only: If true, encode ``batch`` itself and return the
                resulting tensor rather than a batch object.

        Returns:
            The encoded tensor when ``return_pos_only`` is true. Otherwise, the
            same ``batch`` object with both attributes replaced.
        """
        # NOTE(review): squeeze(1) presumably removes a singleton dimension
        # produced by ElementwiseVectorMul(..., 1, ...); confirm its output shape.
        if return_pos_only:
            return self.encoder(batch).squeeze(1)

        # Order matters only for side effects: eigvecs_sn first, then pos_latents.
        for attr_name in ("eigvecs_sn", "pos_latents"):
            encoded = self.encoder(getattr(batch, attr_name))
            setattr(batch, attr_name, encoded.squeeze(1))
        return batch
