import torch
import torch.nn as nn



class Attention(nn.Module):
    """Score every node embedding against every (hyper)edge embedding.

    Node features and edge features are each projected to ``hidden_dim`` by
    their own ``Linear -> ReLU`` layer. The result is the matrix of dot
    products between every projected node row and every projected edge row.

    Args:
        embed_dim: Size of the last dimension of both ``n_x`` and ``e_x``.
        hidden_dim: Width of the shared projection space the scores are
            computed in. Defaults to 16.
    """

    def __init__(self, embed_dim: int, hidden_dim: int = 16):
        super().__init__()

        self.embed_dims = embed_dim
        # Stored so ``forward`` flattens to the real projection width instead
        # of a hard-coded 16 (which broke any non-default ``hidden_dim``).
        self.hidden_dim = hidden_dim
        # Learnable scalar applied to the edge projections when ``is_source``
        # is falsy; initialised to 1.0 so it starts out as a no-op.
        self.bias = nn.Parameter(torch.tensor([1.0]))

        self.emb_linear_node = nn.Sequential(
            nn.Linear(self.embed_dims, hidden_dim), nn.ReLU()
        )
        self.emb_linear_hedge = nn.Sequential(
            nn.Linear(self.embed_dims, hidden_dim), nn.ReLU()
        )

        self._xavier_unif_init()

    def _xavier_unif_init(self):
        """Xavier-uniform init for every Linear weight; set Linear biases to 0.01."""

        def init_weights(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0.01)

        self.emb_linear_node.apply(init_weights)
        self.emb_linear_hedge.apply(init_weights)

    def forward(self, n_x, e_x, is_source):
        """Return the node-by-edge score matrix.

        Args:
            n_x: Node features whose last dimension is ``embed_dim``. Any
                leading dimensions are flattened into rows.
            e_x: Edge features whose last dimension is ``embed_dim``. Leading
                dimensions are flattened the same way.
            is_source: If truthy, the edge projections are used unchanged.
                Otherwise they are scaled by the learnable ``self.bias``.

        Returns:
            Tensor of shape ``(rows of n_x, rows of e_x)``. Entry ``[i, j]``
            is the dot product of projected node ``i`` and projected edge ``j``.
        """
        out_n = self.emb_linear_node(n_x)
        out_e = self.emb_linear_hedge(e_x)
        if not is_source:
            out_e = self.bias * out_e

        # Flatten any leading dims so each row is one hidden_dim-sized vector.
        out_n = torch.reshape(out_n, (-1, self.hidden_dim))
        out_e = torch.reshape(out_e, (-1, self.hidden_dim))

        return torch.matmul(out_n, out_e.transpose(0, 1))