import torch
import torch.nn as nn
from torch.nn import Parameter

from script.models.utils import get_activation
from script.manifolds.pseudohyperboloid_sr import PseudoHyperboloidSR
from script.models.QGCN.layers import PseudoGraphConvolution
from torch_geometric.nn import GCNConv

import torch
import torch.nn as nn
from torch.nn import Parameter

from script.models.utils import get_activation
from script.manifolds.pseudohyperboloid_sr import PseudoHyperboloidSR
from script.models.QGCN.layers import PseudoGraphConvolution
from torch_geometric.nn import GCNConv

class QGCN(nn.Module):
    """Graph network on a ``PseudoHyperboloidSR`` manifold with a linear read-out.

    Node features are projected to the tangent space at the origin and lifted
    onto the manifold with ``expmap_sr_0``. Two ``PseudoGraphConvolution``
    layers then propagate them along ``edge_index``.

    When ``task == 'lp'``, the manifold embeddings are returned as-is; they are
    meant to be scored with ``decoding_lp``. Otherwise the embeddings are mapped
    back with ``logmap_sr_0`` / ``proj_tan0`` and fed to the linear head
    ``layer3``.

    Note: ``layer4``, ``layer5``, ``gcn2`` and ``act`` are constructed (the
    modules own parameters) but are not used by ``forward``.
    """

    def __init__(self, args):
        super().__init__()
        # Configuration read straight from ``args``.
        self.device = args.device
        self.input_dim = args.nfeat
        self.time_dim = args.time_dim
        self.space_dim = args.space_dim
        # ``beta`` is a frozen (non-trainable) parameter shared by the manifold
        # and every convolution layer.
        self.beta = nn.Parameter(torch.tensor(args.beta), requires_grad=False)

        # The hidden width is forced to time_dim + space_dim + 1 and written back
        # onto ``args.nhid`` (this mutates the caller's ``args`` object).
        args.nhid = self.time_dim + self.space_dim + 1
        self.hidden_dim = args.nhid
        self.output_dim = args.nout  # nout = n_classes
        self.manifold = PseudoHyperboloidSR(self.time_dim, self.space_dim, self.beta)
        self.act = get_activation(args.act)
        self.task = args.task

        def pseudo_conv(in_features, out_features):
            # Every layer gets its own activation object from get_activation.
            return PseudoGraphConvolution(
                self.time_dim, self.space_dim, self.beta, in_features, out_features,
                act=get_activation(args.act), dropout=args.dropout, use_bias=args.bias,
            )

        # Keep this construction order: it fixes the parameter registration
        # order (state_dict / optimizer indexing) and any seeded init sequence.
        self.layer1 = pseudo_conv(self.input_dim, self.hidden_dim)
        self.layer2 = pseudo_conv(self.hidden_dim, self.hidden_dim)
        self.layer4 = pseudo_conv(self.hidden_dim, self.hidden_dim)
        self.layer5 = pseudo_conv(self.hidden_dim, self.hidden_dim)

        self.layer3 = nn.Linear(self.hidden_dim, self.output_dim, bias=True)
        self.gcn2 = GCNConv(self.hidden_dim, self.output_dim, normalize=False, bias=args.bias)
        # Learnable fallback node features (all ones at init), used when
        # forward() is called without ``x``.
        self.feat = Parameter(torch.ones(args.num_nodes, self.input_dim), requires_grad=True)

    def forward(self, edge_index, x=None):
        """Embed the nodes and, unless ``task == 'lp'``, apply the linear head.

        Args:
            edge_index: graph connectivity, passed unchanged to each
                ``PseudoGraphConvolution`` layer.
            x: optional node features; ``self.feat`` is used when omitted.

        Returns:
            The output of ``layer2`` (points on the manifold) when
            ``task == 'lp'``; otherwise the output of ``layer3`` applied to the
            tangent-space representation.
        """
        features = self.feat if x is None else x

        # Lift inputs onto the manifold: tangent projection at origin, then expmap.
        tangent_in = self.manifold.proj_tan0(features, self.beta, time_dim=self.time_dim)
        lifted = self.manifold.expmap_sr_0(tangent_in, self.beta, time_dim=self.time_dim)

        # Two rounds of pseudo-hyperboloid convolution (layer4/layer5 are not applied).
        hidden = self.layer1(lifted, edge_index)
        hidden = self.layer2(hidden, edge_index)

        if self.task == 'lp':
            return hidden

        # Map back to the tangent space at the origin before the Euclidean head.
        logged = self.manifold.logmap_sr_0(hidden, self.beta, time_dim=self.time_dim)
        tangent_out = self.manifold.proj_tan0(logged, self.beta, time_dim=self.time_dim)
        return self.layer3(tangent_out)

    def decoding_lp(self, z, edge_index):
        """Score node pairs with the manifold's ``sqdist``.

        Args:
            z: node embeddings, e.g. the output of ``forward`` when
                ``task == 'lp'``; rows are gathered by node index.
            edge_index: index tensor whose rows 0 and 1 hold the two endpoint
                indices of each pair (integer dtype, as required by
                ``embedding``).

        Returns:
            ``self.manifold.sqdist`` evaluated on each endpoint pair
            (presumably a squared distance, judging by the name).
        """
        first, second = edge_index[0], edge_index[1]
        gather = nn.functional.embedding
        return self.manifold.sqdist(gather(first, z), gather(second, z), self.beta, self.time_dim)



# class QGCN(nn.Module):
#     def __init__(self, args):
#         super(QGCN, self).__init__()
#         self.device = args.device
#         self.input_dim = args.nfeat
#         self.time_dim = args.time_dim
#         self.space_dim = args.space_dim
#         self.beta = nn.Parameter(torch.tensor(args.beta), requires_grad=False)
#         args.nhid = self.time_dim + self.space_dim + 1
#         self.hidden_dim = args.nhid
#         self.output_dim = args.nout # nout = n_classes
#         self.manifold = PseudoHyperboloidSR(self.time_dim, self.space_dim, self.beta)
#         self.act = get_activation(args.act)
#         #self.gcn = GCNConv(self.input_dim, self.hidden_dim, normalize=False, bias=args.bias)
#         #self.linear = nn.Linear(self.input_dim, self.hidden_dim, True)

#         self.layer1 = PseudoGraphConvolution(self.time_dim, self.space_dim, self.beta, self.input_dim, self.hidden_dim,
#                                           act=get_activation(args.act), dropout=args.dropout, use_bias=args.bias)
#         self.layer2 = PseudoGraphConvolution(self.time_dim, self.space_dim, self.beta, self.hidden_dim, self.hidden_dim,
#                                           act=get_activation(args.act), dropout=args.dropout, use_bias=args.bias)
#         self.layer4 = PseudoGraphConvolution(self.time_dim, self.space_dim, self.beta, self.hidden_dim, self.hidden_dim,
#                                           act=get_activation(args.act), dropout=args.dropout, use_bias=args.bias)
#         self.layer5 = PseudoGraphConvolution(self.time_dim, self.space_dim, self.beta, self.hidden_dim, self.hidden_dim,
#                                           act=get_activation(args.act), dropout=args.dropout, use_bias=args.bias)
                                          
#         self.layer3 = nn.Linear(self.hidden_dim, self.output_dim, True)
#         self.gcn2 = GCNConv(self.hidden_dim, self.output_dim, normalize=False, bias=args.bias)
#         #self.feat = Parameter((torch.ones(args.num_nodes, self.input_dim)), requires_grad=True)

#     def forward(self, edge_index, x=None):
#         if x is None:
#             x = self.feat
#         #x = self.linear(x)
#         #x_ = self.act(self.gcn(x, edge_index.long()))

#         # feature map
#         x_t = self.manifold.proj_tan0(x, self.beta, time_dim=self.time_dim)
#         x_t = self.manifold.expmap_sr_0(x_t, self.beta, time_dim=self.time_dim)

#         # forward
#         h = self.layer1(x_t, edge_index)
#         h = self.layer2(h, edge_index)
#         h = self.layer4(h, edge_index)
#         h = self.layer5(h, edge_index)
#         h = self.manifold.proj_tan0(self.manifold.logmap_sr_0(h, self.beta, time_dim=self.time_dim), self.beta, time_dim=self.time_dim)
#         #output = self.gcn2(h, edge_index.long())
#         output = self.layer3(h)

#         return output
