import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch_geometric.nn.inits import glorot
import math

from script.models.utils import zeros, get_activation
from script.models.GCN.layers import GraphConvolution

class GCN(nn.Module):
    """Two-layer graph convolutional network with learnable default node features.

    Hyper-parameters are read from ``args``: ``device``, ``nfeat``, ``nhid``,
    ``nout``, ``act``, ``dropout``, ``bias`` and ``num_nodes``.
    """

    def __init__(self, args):
        super().__init__()
        self.device = args.device
        self.input_dim = args.nfeat
        self.hidden_dim = args.nhid
        self.output_dim = args.nout  # nout = n_classes

        dropout, use_bias = args.dropout, args.bias

        # The activation is built separately for each layer rather than shared
        # between them.
        self.layer1 = GraphConvolution(
            self.input_dim, self.hidden_dim, get_activation(args.act), dropout, use_bias
        )
        self.layer2 = GraphConvolution(
            self.hidden_dim, self.output_dim, get_activation(args.act), dropout, use_bias
        )

        # NOTE(review): registered but not used by forward(); kept so the
        # module's parameters / state_dict keys stay the same.
        self.linear = nn.Linear(self.hidden_dim, self.output_dim, bias=use_bias)

        # Learnable node features (one row per node, initialised to ones),
        # used by forward() whenever no explicit features are supplied.
        self.feat = Parameter(torch.ones(args.num_nodes, self.input_dim))

    def forward(self, edge_index, x=None):
        """Apply both graph-convolution layers.

        Args:
            edge_index: graph connectivity, passed unchanged to each layer.
            x: node features; when ``None`` the learnable ``self.feat`` is used.

        Returns:
            The output of ``layer2``.
        """
        features = self.feat if x is None else x
        hidden = self.layer1(features, edge_index)
        return self.layer2(hidden, edge_index)

    def decoding_lp(self, z, edge_index):
        """Score edges by the dot product of their endpoint embeddings.

        Args:
            z: node embedding table; rows are looked up by node index.
            edge_index: indexable pair ``(source_indices, target_indices)``;
                presumably a (2, E) integer tensor — confirm against callers.

        Returns:
            One score per edge: the row-wise sum of ``z[src] * z[dst]``.
        """
        src, dst = edge_index[0], edge_index[1]
        src_emb = F.embedding(src, z)
        dst_emb = F.embedding(dst, z)
        return torch.sum(src_emb * dst_emb, dim=1)