""" GraphGPS """
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

from script.models.GraphGPS.layers import GPSLayer

class GraphGPS(nn.Module):
    """GraphGPS model: linear input projection, a stack of ``GPSLayer`` blocks,
    then a linear output projection.

    When ``forward`` is called without node features, a learnable feature
    matrix ``self.feat`` of shape ``(args.num_nodes, args.nfeat)`` (initialised
    to ones) is used as the input instead.

    Args:
        args: Config namespace. Attributes read: ``nfeat`` (input feature
            width), ``nhid`` (hidden width), ``nout`` (output width),
            ``trans_num_layers`` (number of ``GPSLayer`` blocks),
            ``trans_num_heads`` (passed to every ``GPSLayer``), ``dropout``
            (used both as feature dropout and as ``attn_dropout``) and
            ``num_nodes`` (rows of the learnable fallback feature matrix).
    """

    def __init__(self, args):
        super().__init__()
        self.in_channels = args.nfeat
        self.hidden_channels = args.nhid
        self.out_channels = args.nout
        self.num_layers = args.trans_num_layers
        self.num_heads = args.trans_num_heads

        self.pre_mp = nn.Linear(self.in_channels, self.hidden_channels)
        self.dropout = args.dropout
        # Submodules are created in the same order as before (pre_mp, layers,
        # post_mp, feat) so RNG consumption and state_dict key order match.
        self.layers = nn.ModuleList(
            GPSLayer(
                self.hidden_channels,
                self.num_heads,
                dropout=self.dropout,
                attn_dropout=self.dropout,
                use_bn=True,
            )
            for _ in range(self.num_layers)
        )

        self.post_mp = nn.Linear(self.hidden_channels, self.out_channels)
        # Learnable fallback node features, used when forward() receives x=None.
        self.feat = Parameter(torch.ones(args.num_nodes, self.in_channels), requires_grad=True)

    def forward(self, edge_index, x=None):
        """Compute node outputs.

        Args:
            edge_index: Graph connectivity, passed unchanged to every
                ``GPSLayer`` (presumably a ``(2, E)`` index tensor — confirm
                against ``GPSLayer``).
            x: Optional node features with last dimension ``in_channels``.
                If ``None``, the learnable ``self.feat`` is used.

        Returns:
            The output of ``post_mp``, whose last dimension is
            ``out_channels`` (assuming each ``GPSLayer`` keeps the hidden
            width).
        """
        if x is None:
            x = self.feat
        x = self.pre_mp(x)
        x = F.relu(x)
        x = F.dropout(x, self.dropout, training=self.training)
        for layer in self.layers:
            x = layer(x, edge_index)
        return self.post_mp(x)

    def reset_parameters(self):
        """Re-initialise every learnable parameter of the model.

        ``self.feat`` is reset to ones, its value at construction. Without
        this, a "reset" model would silently keep trained fallback features.
        """
        self.pre_mp.reset_parameters()
        self.post_mp.reset_parameters()
        for layer in self.layers:
            layer.reset_parameters()
        nn.init.ones_(self.feat)

    def decoding_lp(self, z, edge_index):
        """Score edges by the dot product of their endpoint embeddings.

        Args:
            z: 2-D node embedding matrix (one row per node).
            edge_index: Index tensor whose rows ``edge_index[0]`` and
                ``edge_index[1]`` hold the source and target node ids.

        Returns:
            One score per edge: ``sum(z[src] * z[dst], dim=1)``.
        """
        z_i = F.embedding(edge_index[0], z)
        z_j = F.embedding(edge_index[1], z)
        return (z_i * z_j).sum(dim=1)