import torch
from torch_geometric.nn.inits import glorot
import torch.nn.functional as F
import ipdb


class GPF(torch.nn.Module):
    """Add one shared, learnable prompt vector to every feature row.

    The prompt is stored as the weight of a bias-free ``Linear(1, in_channels)``
    wrapped in a ``Sequential`` (state-dict key ``global_emb.0.weight``).
    Feeding that layer the constant input ``[1]`` yields the prompt vector,
    shape ``(in_channels,)``.

    Args:
        in_channels: Size of the feature dimension the prompt is added to.
        device: Kept for backward compatibility and stored as ``self.device``.
            ``forward`` does not rely on it: the constant input is created on
            the prompt weight's own device and dtype. The module therefore
            keeps working after ``.to(...)``, ``.double()``, ``.half()``, etc.
    """

    def __init__(self, in_channels: int, device: str):
        super().__init__()
        self.global_emb = torch.nn.Sequential(
            torch.nn.Linear(1, in_channels, bias=False)
        )
        self.device = device
        # Apply the intended glorot initialisation at construction time,
        # as GPF_plus does; otherwise the prompt keeps the default Linear init.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise the prompt weight with torch_geometric's ``glorot``."""
        glorot(self.global_emb)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the shared prompt vector.

        Args:
            x: Features whose last dimension is ``in_channels``. The
                ``(in_channels,)`` prompt is broadcast over all leading dims.

        Returns:
            A new tensor ``x + prompt``.
        """
        weight = self.global_emb[0].weight
        # Create the constant input alongside the parameter, not on a device
        # and dtype frozen at construction time.
        ones = torch.ones(1, device=weight.device, dtype=weight.dtype)
        prompt = self.global_emb(ones)
        return x + prompt

class GPF_plus(torch.nn.Module):
    """Add an input-dependent mixture of ``p_num`` learnable prompt vectors.

    Steps for each feature row:

    1. The scorer ``a`` (``Linear(in_channels, p_num)``) produces ``p_num``
       logits.
    2. A softmax over those logits gives mixing weights.
    3. ``p_list`` (a bias-free ``Linear(p_num, in_channels)`` whose weight
       columns are the prompt bases) maps the weights to a prompt, which is
       added to the row.

    Args:
        in_channels: Size of the feature dimension.
        p_num: Number of prompt basis vectors.
    """

    def __init__(self, in_channels: int, p_num: int):
        super().__init__()
        self.p_list = torch.nn.Linear(p_num, in_channels, bias=False)
        self.a = torch.nn.Linear(in_channels, p_num)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise both layers.

        Uses torch_geometric's ``glorot`` for the prompt bases and the Linear
        layer's own default re-initialisation for the scorer.
        """
        glorot(self.p_list)
        self.a.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus a per-row weighted combination of the prompt bases.

        Args:
            x: Features whose last dimension is ``in_channels``. Any number of
                leading dims is accepted (e.g. ``(N, in_channels)`` or
                ``(B, N, in_channels)``).

        Returns:
            A new tensor of the same shape as ``x``.
        """
        score = self.a(x)
        # The Linear scorer puts the p_num logits on the last axis for any
        # input rank, so normalise there. dim=1 coincides with it only for
        # 2-D input: it mixes over nodes for 3-D input and fails for 1-D.
        weight = F.softmax(score, dim=-1)
        prompt = self.p_list(weight)
        return x + prompt

