import torch
import torch.nn as nn
from torch.nn import Parameter

from script.models.utils import get_activation
from script.manifolds.hyperboloid import Hyperboloid
from script.models.Hypformer.layers import TransConv, GraphConv


class Hypformer(nn.Module):
    """Hybrid model combining a hyperbolic transformer branch with a Euclidean graph branch.

    The transformer branch (``TransConv``) works on the ``Hyperboloid`` manifold. Its
    output is mapped back with ``logmap0`` and decoded linearly. When ``use_graph`` is
    True, a ``GraphConv`` branch is decoded separately. The two decoded outputs are then
    blended as ``(1 - graph_weight) * trans + graph_weight * graph``.
    """

    def __init__(self, args, gnn_num_layers=3, gnn_dropout=0.4, gnn_use_weight=False, gnn_use_init=False, gnn_use_bn=False,
                 gnn_use_residual=True, gnn_use_act=True, use_graph=True, graph_weight=0.8, aggregate='add'):
        """Build the transformer branch, the graph branch and their linear decoders.

        Args:
            args: Configuration namespace. This constructor reads ``curvature``,
                ``curvature_out``, ``nfeat``, ``nhid``, ``nout``, ``act`` and
                ``num_nodes``. The whole object is also forwarded to ``TransConv``,
                which may read further fields.
            gnn_num_layers, gnn_dropout, gnn_use_weight, gnn_use_init, gnn_use_bn,
            gnn_use_residual, gnn_use_act: Forwarded positionally to ``GraphConv``.
            use_graph (bool): If True, ``forward`` also runs the graph branch and
                blends it into the output.
            graph_weight (float): Weight ``w`` of the graph branch in
                ``(1 - w) * trans + w * graph``. Only used when ``use_graph`` is True.
            aggregate (str): Stored as ``self.aggregate`` but not read by ``forward``,
                which always uses the weighted sum above.
        """
        super().__init__()
        self.manifold = Hyperboloid()
        # Curvatures are fixed (requires_grad=False). They are kept as nn.Parameter so
        # they remain part of the state_dict and follow the module across devices.
        self.c_in = self._curvature_param(args.curvature)
        self.c_hidden = self._curvature_param(args.curvature)
        self.c_out = self._curvature_param(args.curvature_out)

        self.in_channels = args.nfeat
        self.hidden_channels = args.nhid
        self.out_channels = args.nout

        self.trans_conv = TransConv(self.manifold, self.c_in, self.c_hidden, self.c_out,
                                    self.in_channels, self.hidden_channels, get_activation(args.act),
                                    args)

        self.graph_conv = GraphConv(self.in_channels, self.hidden_channels, gnn_dropout, gnn_num_layers, gnn_use_bn,
                                    gnn_use_residual, gnn_use_weight, gnn_use_init, gnn_use_act)

        self.aggregate = aggregate  # NOTE(review): currently unused by forward()
        self.use_graph = use_graph
        self.graph_weight = graph_weight
        self.decode_trans = nn.Linear(self.hidden_channels, self.out_channels)
        self.decode_graph = nn.Linear(self.hidden_channels, self.out_channels)
        # Learnable per-node input features, initialised to ones. They are used by
        # forward() when no explicit x is supplied.
        self.feat = nn.Parameter(torch.ones(args.num_nodes, self.in_channels), requires_grad=True)

    @staticmethod
    def _curvature_param(value):
        """Wrap a curvature scalar as a frozen floating-point ``nn.Parameter``.

        Casting to the default floating dtype prevents an integer curvature (e.g. ``1``)
        from producing an int64 tensor.
        """
        return nn.Parameter(torch.tensor(value, dtype=torch.get_default_dtype()), requires_grad=False)

    def forward(self, edge_index, x=None):
        """Run the model and return the decoded node outputs.

        Args:
            edge_index: Graph connectivity, passed to ``graph_conv`` only when
                ``use_graph`` is True.
            x: Node input features. Defaults to the learnable ``self.feat`` when None.

        Returns:
            The output of ``decode_trans`` on the transformer branch. When ``use_graph``
            is True, it is blended with ``decode_graph`` on the graph branch as
            ``(1 - graph_weight) * trans + graph_weight * graph``.
        """
        if x is None:
            x = self.feat
        x_trans = self.trans_conv(x)
        # Map the transformer output back to the tangent space at the origin and drop
        # the first coordinate (presumably the hyperboloid's time-like component), so
        # it matches decode_trans' hidden_channels input width.
        h_trans = self.manifold.logmap0(x_trans, self.c_out)[..., 1:]
        if not self.use_graph:
            return self.decode_trans(h_trans)
        x_graph = self.graph_conv(x, edge_index)
        return (1 - self.graph_weight) * self.decode_trans(h_trans) + self.graph_weight * self.decode_graph(x_graph)

    def decoding_lp(self, z, edge_index):
        """Score edges by the manifold squared distance between their endpoint embeddings.

        Args:
            z: Node embeddings indexed by node along dim 0 (assumed ``(num_nodes, dim)``,
                e.g. the output of ``forward``; TODO confirm at the call site).
            edge_index: ``edge_index[0]`` / ``edge_index[1]`` hold the source / target
                node indices of the edges to score.

        Returns:
            Per-edge values of ``manifold.sqdist`` computed with curvature ``c_out``.
        """
        # Prepend a constant-one coordinate, then map onto the hyperboloid via expmap0.
        # new_ones builds just that column instead of a full ones_like(z) that is sliced.
        pad = z.new_ones((*z.shape[:-1], 1))
        z_hyp = self.manifold.expmap0(torch.cat([pad, z], dim=-1), self.c_out)
        z_i = torch.nn.functional.embedding(edge_index[0], z_hyp)
        z_j = torch.nn.functional.embedding(edge_index[1], z_hyp)
        return self.manifold.sqdist(z_i, z_j, self.c_out)
