import torch
import torch.nn as nn
from torch.nn import Parameter

from script.models.utils import get_activation
from script.manifolds.product import Product
from script.manifolds.euclidean import Euclidean
from script.manifolds.spherical import Spherical
from script.manifolds.hyperboloid import Hyperboloid
from script.models.kGCN.layers import kGraphConvolution



class kGCN(nn.Module):
    """Graph convolutional encoder over a product of manifold components.

    The hidden representation is split column-wise into one slice per
    configured component (Euclidean, spherical, hyperboloid). Each slice is
    mapped onto its manifold, passed through its own stack of two
    ``kGraphConvolution`` layers, and the per-component results are
    concatenated again.

    Args:
        args: Configuration namespace. The attributes read here are
            ``device``, ``nfeat``, ``nhid``, ``nout``, ``curvature``,
            ``task``, ``prod_manifold_e``, ``prod_manifold_s``,
            ``prod_manifold_h``, ``act``, ``dropout``, ``bias`` and
            ``num_nodes``.

    Raises:
        AssertionError: If the three component dimensions do not sum to
            ``args.nhid``.
        ValueError: If no component has a positive dimension, since the
            model would have no encoders to run.
    """

    def __init__(self, args):
        super(kGCN, self).__init__()
        self.device = args.device
        self.input_dim = args.nfeat
        self.hidden_dim = args.nhid
        self.output_dim = args.nout  # nout = n_classes
        self.curvature = args.curvature
        self.task = args.task

        assert (
            args.prod_manifold_e + args.prod_manifold_s + args.prod_manifold_h
            == args.nhid
        ), "prod_manifold_e + prod_manifold_s + prod_manifold_h must equal nhid"

        # Product manifold components as (manifold, dimension) pairs; only
        # components with a positive dimension take part.
        self.manifolds_array = []
        if args.prod_manifold_e > 0:
            self.manifolds_array.append((Euclidean(), args.prod_manifold_e))
        if args.prod_manifold_s > 0:
            self.manifolds_array.append((Spherical(), args.prod_manifold_s))
        if args.prod_manifold_h > 0:
            self.manifolds_array.append((Hyperboloid(), args.prod_manifold_h))

        # Fail fast: without any component there are no encoders, and the
        # model would only break later inside forward().
        if not self.manifolds_array:
            raise ValueError(
                "Error when initializing kGCN: at least one of "
                "prod_manifold_e, prod_manifold_s, prod_manifold_h must be > 0"
            )

        self.manifold = Product(self.manifolds_array)

        # One encoder (two stacked graph convolutions) per component. The
        # width of each encoder is taken from the (start, end) range that
        # calc_indices stores as the first entry of manifold.indices[i].
        # NOTE(review): indices are computed from input_dim here but from
        # total_dim in forward() -- confirm this is intended.
        self.manifold.calc_indices(self.manifolds_array, self.input_dim)
        self.total_dim = 0
        encs = []
        for i, (man, _) in enumerate(self.manifolds_array):
            start, end = self.manifold.indices[i][0]
            hid_dim = end - start
            encs.append(nn.ModuleList([
                kGraphConvolution(man, hid_dim, hid_dim,
                                  c_in=self.curvature, c_out=self.curvature,
                                  act=get_activation(args.act),
                                  dropout=args.dropout, use_bias=args.bias)
                for _ in range(2)
            ]))
            self.total_dim += hid_dim
        self.encoders = nn.ModuleList(encs)

        # Learnable node features, used when forward() receives no input x.
        self.feat = Parameter(torch.ones(args.num_nodes, self.input_dim),
                              requires_grad=True)
        self.layer3 = nn.Linear(self.total_dim, self.output_dim, False)
        self.pre_layer = nn.Linear(self.input_dim, self.hidden_dim)

    def forward(self, edge_index, x=None):
        """Encode the nodes and return embeddings or output scores.

        Args:
            edge_index: Graph connectivity, passed unchanged to every
                ``kGraphConvolution`` layer.
            x: Optional node feature matrix; the learnable ``self.feat`` is
                used when it is ``None``.

        Returns:
            For ``task == 'lp'``: the concatenated per-component embeddings
            as produced by the encoders (no ``logmap0`` applied), suitable
            for ``decoding_lp``. Otherwise: ``layer3`` applied to the
            concatenation of the ``logmap0`` images of those embeddings.
        """
        if x is None:
            x = self.feat
        x = self.pre_layer(x)

        # Recompute the per-component column ranges for the hidden features.
        self.manifold.calc_indices(self.manifolds_array, self.total_dim)
        is_lp = self.task == 'lp'
        embs = []
        for j, split in enumerate(self.manifold.indices):
            man = self.manifolds_array[j][0]
            start, end = split[0]
            x_ = x[:, start:end]
            # Map the slice onto the component manifold before convolving.
            x_ = man.expmap0(x_, self.curvature)
            x_ = man.proj(x_, self.curvature)
            for layer in self.encoders[j]:
                x_ = layer(x_, edge_index)
            # Link prediction consumes the encoder output directly, so the
            # log map is only needed (and only computed) for other tasks.
            if not is_lp:
                x_ = man.logmap0(x_, self.curvature)
            embs.append(x_)

        h = torch.cat(embs, dim=1)
        if is_lp:
            return h
        return self.layer3(h)

    def decoding_lp(self, z, edge_index):
        """Score edges by summed per-component squared distance.

        Uses the column ranges currently stored in ``self.manifold.indices``
        (i.e. those left by the most recent ``calc_indices`` call).

        Args:
            z: Node embeddings with components concatenated along dim 1, as
                returned by ``forward`` for the ``'lp'`` task.
            edge_index: Indexable pair; ``edge_index[0]`` and
                ``edge_index[1]`` hold the node indices of the two endpoints
                of each edge.

        Returns:
            A 1-D tensor with one value per edge (also for a single edge),
            or ``None`` if ``self.manifold.indices`` is empty.
        """
        edge_i = edge_index[0]
        edge_j = edge_index[1]
        d = None
        for j, split in enumerate(self.manifold.indices):
            man = self.manifolds_array[j][0]
            start, end = split[0]
            x_ = z[:, start:end]
            z_i = torch.nn.functional.embedding(edge_i, x_)
            z_j = torch.nn.functional.embedding(edge_j, x_)
            # reshape(-1) rather than squeeze(): squeeze() would collapse a
            # single-edge batch to a 0-dim tensor.
            dist = man.sqdist(z_i, z_j, self.curvature).reshape(-1)
            d = dist if d is None else d + dist  # could change to min/max
        return d