import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric.nn.inits import glorot


class SimplePrompt(nn.Module):
    """A single learnable prompt vector applied to input features.

    Holds one trainable parameter ``global_emb`` of shape ``(1, in_channels)``.
    It can be combined with an input either additively (:meth:`add`) or
    multiplicatively (:meth:`mul`). The leading dimension of size 1 lets the
    same vector broadcast across every row of the input.

    Args:
        in_channels: Size of the prompt vector. It should match the last
            dimension of the tensors passed to :meth:`add` / :meth:`mul`.
        device: Optional device on which to allocate the parameter.
        dtype: Optional dtype of the parameter.
    """

    def __init__(
        self,
        in_channels: int,
        *,
        device: "torch.device | str | None" = None,
        dtype: "torch.dtype | None" = None,
    ):
        super().__init__()
        # torch.empty is the documented way to allocate an uninitialized
        # tensor; the legacy torch.Tensor(...) constructor is discouraged.
        # The contents are immediately overwritten by reset_parameters().
        self.global_emb = nn.Parameter(
            torch.empty(1, in_channels, device=device, dtype=dtype)
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize ``global_emb`` in place using
        ``torch_geometric.nn.inits.glorot``."""
        glorot(self.global_emb)

    def add(self, x: Tensor) -> Tensor:
        """Return ``x + global_emb``, broadcasting the prompt over ``x``.

        ``global_emb`` has shape ``(1, in_channels)``, so ``x`` must be
        broadcast-compatible with it, e.g. ``(num_rows, in_channels)``.
        A new tensor is returned; ``x`` is not modified.
        """
        return x + self.global_emb

    def mul(self, x: Tensor) -> Tensor:
        """Return ``x * global_emb`` (element-wise), broadcasting the prompt.

        ``global_emb`` has shape ``(1, in_channels)``, so ``x`` must be
        broadcast-compatible with it, e.g. ``(num_rows, in_channels)``.
        A new tensor is returned; ``x`` is not modified.
        """
        return x * self.global_emb