import torch
import torch.nn as nn
from torch.nn import Parameter

from script.models.utils import get_activation
from script.manifolds.hyperboloid import Hyperboloid
from script.models.HGCN.layers import HypGraphConvolution

class HGCN(nn.Module):
    """Two-layer hyperbolic graph convolutional network on a ``Hyperboloid`` manifold.

    Euclidean node features are lifted onto the manifold (``initHyperX``) and
    passed through two ``HypGraphConvolution`` layers. What happens next depends
    on ``task``:

    * ``task == 'lp'``: the manifold embeddings from the second layer are returned.
    * otherwise: the embeddings are mapped back with ``logmap0`` / ``proj_tan0``.
      A bias-free linear layer then produces ``nout`` outputs per node.

    Args:
        args: configuration namespace. Attributes read here: ``device``,
            ``nfeat``, ``nhid``, ``nout``, ``curvature``, ``task``, ``act``,
            ``dropout``, ``bias`` and ``num_nodes``.
    """

    def __init__(self, args):
        super().__init__()
        self.device = args.device  # stored for external callers; not used in this class
        self.input_dim = args.nfeat
        self.hidden_dim = args.nhid
        self.output_dim = args.nout  # nout = n_classes
        self.manifold = Hyperboloid()
        # The same curvature is used for every layer and every manifold call.
        self.curvature = args.curvature
        self.task = args.task

        # Every manifold-side width is "+ 1" because initHyperX prepends one
        # extra leading coordinate to the Euclidean features.
        self.layer1 = HypGraphConvolution(self.manifold, self.input_dim + 1, self.hidden_dim + 1,
                                          c_in=self.curvature, c_out=self.curvature,
                                          act=get_activation(args.act), dropout=args.dropout, use_bias=args.bias)
        self.layer2 = HypGraphConvolution(self.manifold, self.hidden_dim + 1, self.hidden_dim + 1,
                                          c_in=self.curvature, c_out=self.curvature,
                                          act=get_activation(args.act), dropout=args.dropout, use_bias=args.bias)
        self.layer3 = nn.Linear(self.hidden_dim + 1, self.output_dim, False)
        # Learnable node features of shape (num_nodes, nfeat), initialised to
        # ones; used by forward() when no explicit x is supplied.
        self.feat = Parameter(torch.ones(args.num_nodes, self.input_dim), requires_grad=True)

    def initHyperX(self, x, c=1.0):
        """Lift Euclidean features onto the manifold.

        Args:
            x: feature tensor with nodes along dim 0 and features along dim 1
                (presumably (N, D) -- confirm against callers).
            c: curvature forwarded to the manifold operations.

        Returns:
            ``toHyperX`` applied to ``x`` with a zero column prepended along dim 1,
            i.e. (N, D) -> (N, D + 1).
        """
        # Only a single zero column is needed: allocate exactly that (same dtype
        # and device as x) instead of a full zeros_like(x) copy that is sliced.
        zero_col = x.new_zeros((x.size(0), 1, *x.shape[2:]))
        x = torch.cat([zero_col, x], dim=1)
        return self.toHyperX(x, c)

    def toHyperX(self, x, c=1.0):
        """Map ``x`` onto the manifold with curvature ``c``.

        Applies the manifold's ``proj_tan0``, ``expmap0`` and ``proj`` in turn.
        """
        x_tan = self.manifold.proj_tan0(x, c)
        x_hyp = self.manifold.expmap0(x_tan, c)
        x_hyp = self.manifold.proj(x_hyp, c)
        return x_hyp

    def forward(self, edge_index, x=None):
        """Run the network.

        Args:
            edge_index: graph connectivity passed unchanged to both convolution
                layers (presumably a (2, E) index tensor -- see ``decoding_lp``).
            x: optional Euclidean node features; defaults to the learnable
                ``self.feat``.

        Returns:
            If ``self.task == 'lp'``, the manifold node embeddings produced by
            ``layer2``. Otherwise, ``layer3`` applied to the embeddings after
            ``logmap0`` + ``proj_tan0``, giving ``nout`` values per node.
        """
        if x is None:
            x = self.feat

        h = self.initHyperX(x, self.curvature)

        h = self.layer1(h, edge_index)
        h = self.layer2(h, edge_index)
        if self.task == 'lp':
            # Link prediction consumes manifold embeddings via decoding_lp.
            return h
        h = self.manifold.proj_tan0(self.manifold.logmap0(h, self.curvature), self.curvature)
        output = self.layer3(h)

        return output

    def decoding_lp(self, z, edge_index):
        """Score node pairs with the manifold's ``sqdist``.

        Args:
            z: node embeddings (rows indexed by node id), e.g. from ``forward``
                with ``task == 'lp'``.
            edge_index: ``edge_index[0]`` holds source ids and ``edge_index[1]``
                target ids.

        Returns:
            ``self.manifold.sqdist`` evaluated between each source/target
            embedding pair at ``self.curvature``.
        """
        edge_i = edge_index[0]
        edge_j = edge_index[1]
        # F.embedding performs a row lookup of z at the given indices.
        z_i = torch.nn.functional.embedding(edge_i, z)
        z_j = torch.nn.functional.embedding(edge_j, z)
        dist = self.manifold.sqdist(z_i, z_j, self.curvature)
        return dist
